It's for when you want to get like notifications for your reposts.

feat: shard jetstream connections

ptr.pet 1e47dacb 76405e9f

verified
+152 -33
+119 -13
server/jetstream.go
··· 2 2 3 3 import ( 4 4 "context" 5 + "fmt" 5 6 "log/slog" 7 + "sync" 8 + "time" 6 9 7 10 "github.com/bluesky-social/jetstream/pkg/client" 8 11 "github.com/bluesky-social/jetstream/pkg/client/schedulers/sequential" 9 12 "github.com/bluesky-social/jetstream/pkg/models" 10 13 ) 11 14 15 + func Chunk[T any](slice []T, chunkSize int) [][]T { 16 + var chunks [][]T 17 + for i := 0; i < len(slice); i += chunkSize { 18 + end := min(i+chunkSize, len(slice)) 19 + chunks = append(chunks, slice[i:end]) 20 + } 21 + return chunks 22 + } 23 + 24 + type Stream struct { 25 + inner *client.Client 26 + cancel context.CancelFunc 27 + } 28 + type StreamManager struct { 29 + ctx context.Context 30 + logger *slog.Logger 31 + streamsLock sync.Mutex 32 + streams map[int]Stream 33 + name string 34 + handleEvent HandleEvent 35 + optsFn OptsFn 36 + } 37 + 38 + func NewStreamManager(logger *slog.Logger, name string, handleEvent HandleEvent, optsFn OptsFn) StreamManager { 39 + return StreamManager{ 40 + ctx: context.TODO(), 41 + logger: logger.With("stream", name), 42 + streamsLock: sync.Mutex{}, 43 + streams: make(map[int]Stream), 44 + name: name, 45 + handleEvent: handleEvent, 46 + optsFn: optsFn, 47 + } 48 + } 49 + 50 + // doesnt lock streams!!! 
51 + func (manager *StreamManager) startSingle(id int, opts models.SubscriberOptionsUpdatePayload) { 52 + ctx, cancel := context.WithCancel(manager.ctx) 53 + stream := Stream{inner: nil, cancel: cancel} 54 + // add to streams and put on wait group 55 + manager.streams[id] = stream 56 + go startJetstreamLoop(ctx, manager.logger.With("streamId", id), &stream.inner, fmt.Sprintf("%s_%d", manager.name, id), manager.handleEvent, opts) 57 + } 58 + 59 + func (manager *StreamManager) chunkedOpts() ([]models.SubscriberOptionsUpdatePayload, int) { 60 + results := make([]models.SubscriberOptionsUpdatePayload, 0) 61 + opts := manager.optsFn() 62 + for _, wantedDidsChunk := range Chunk(opts.WantedDIDs, 9999) { 63 + results = append(results, models.SubscriberOptionsUpdatePayload{ 64 + WantedCollections: opts.WantedCollections, 65 + WantedDIDs: wantedDidsChunk, 66 + MaxMessageSizeBytes: opts.MaxMessageSizeBytes, 67 + }) 68 + } 69 + return results, len(opts.WantedDIDs) 70 + } 71 + 72 + func (manager *StreamManager) updateOpts() { 73 + chunks, userCount := manager.chunkedOpts() 74 + manager.streamsLock.Lock() 75 + idsSeen := make(map[int]struct{}, 0) 76 + // update existing streams or create new ones 77 + for id, opts := range chunks { 78 + idsSeen[id] = struct{}{} 79 + if len(manager.streams) > id { 80 + stream := manager.streams[id] 81 + if stream.inner == nil { 82 + continue 83 + } 84 + if err := stream.inner.SendOptionsUpdate(opts); err != nil { 85 + manager.logger.Error("couldnt update follow stream opts", "error", err, "streamId", id) 86 + } 87 + } else { 88 + manager.startSingle(id, opts) 89 + } 90 + } 91 + // cancel and delete unused streams 92 + for k := range manager.streams { 93 + if _, exists := idsSeen[k]; !exists { 94 + manager.streams[k].cancel() 95 + delete(manager.streams, k) 96 + } 97 + } 98 + manager.streamsLock.Unlock() 99 + manager.logger.Info("updated opts", "userCount", userCount) 100 + } 101 + 12 102 type HandleEvent func(context.Context, *models.Event) error 
103 + type OptsFn func() models.SubscriberOptionsUpdatePayload 13 104 14 - func startJetstreamLoop(logger *slog.Logger, outStream **client.Client, name string, handleEvent HandleEvent, optsFn func() models.SubscriberOptionsUpdatePayload) { 105 + func startJetstreamLoop(ctx context.Context, logger *slog.Logger, outStream **client.Client, name string, handleEvent HandleEvent, opts models.SubscriberOptionsUpdatePayload) { 106 + backoff := time.Second 15 107 for { 16 - stream, startFn, err := startJetstreamClient(name, optsFn(), handleEvent) 108 + done := make(chan struct{}) 109 + if ctx.Err() != nil { 110 + break 111 + } 112 + stream, startFn, err := startJetstreamClient(ctx, logger, name, handleEvent) 17 113 *outStream = stream 18 114 if startFn != nil { 19 - err = startFn() 115 + logger.Info("starting jetstream client", "collections", opts.WantedCollections, "userCount", len(opts.WantedDIDs)) 116 + go func() { 117 + err = startFn() 118 + done <- struct{}{} 119 + }() 120 + // HACK: we need to wait for the websocket connection to start here. 
so we do 121 + // need to upstream something to jetstream client 122 + time.Sleep(time.Second * 2) 123 + err = stream.SendOptionsUpdate(opts) 124 + if err == nil { 125 + <-done 126 + } 20 127 } 21 128 if err != nil { 22 - logger.Error("stream failed", "name", name, "error", err) 129 + logger.Error("stream failed", "error", err, "backoff", backoff) 130 + time.Sleep(backoff) 131 + backoff = backoff * 2 132 + } else { 133 + backoff = time.Second 23 134 } 24 135 } 25 136 } 26 137 27 - func startJetstreamClient(name string, opts models.SubscriberOptionsUpdatePayload, handleEvent HandleEvent) (*client.Client, func() error, error) { 28 - ctx := context.Background() 29 - 138 + func startJetstreamClient(ctx context.Context, logger *slog.Logger, name string, handleEvent HandleEvent) (*client.Client, func() error, error) { 30 139 config := client.DefaultClientConfig() 31 140 config.WebsocketURL = "wss://jetstream1.us-west.bsky.network/subscribe" 32 141 config.Compress = true 33 - config.WantedCollections = opts.WantedCollections 34 - config.WantedDids = opts.WantedDIDs 35 - config.RequireHello = false 142 + config.RequireHello = true 36 143 37 144 scheduler := sequential.NewScheduler(name, logger, handleEvent) 38 145 39 146 c, err := client.NewClient(config, logger, scheduler) 40 147 if err != nil { 41 - logger.Error("failed to create jetstream client", "name", name, "error", err) 148 + logger.Error("failed to create jetstream client", "error", err) 42 149 return nil, nil, err 43 150 } 44 151 45 152 startFn := func() error { 46 - logger.Info("starting jetstream client", "name", name, "collections", opts.WantedCollections, "wanted_dids", len(opts.WantedDIDs)) 47 153 if err := c.ConnectAndRead(ctx, nil); err != nil { 48 - logger.Error("jetstream client failed", "name", name, "error", err) 154 + logger.Error("jetstream client failed", "error", err) 49 155 return err 50 156 } 51 157
+33 -20
server/main.go
··· 12 12 "github.com/bluesky-social/indigo/api/bsky" 13 13 "github.com/bluesky-social/indigo/atproto/syntax" 14 14 "github.com/bluesky-social/indigo/xrpc" 15 - "github.com/bluesky-social/jetstream/pkg/client" 16 15 "github.com/bluesky-social/jetstream/pkg/models" 17 16 "github.com/cornelk/hashmap" 18 17 "github.com/google/uuid" ··· 34 33 35 34 type ActorData struct { 36 35 targets *hashmap.Map[string, *SubscriberData] 37 - likes map[syntax.RecordKey]bsky.FeedLike 36 + likes *hashmap.Map[syntax.RecordKey, bsky.FeedLike] 38 37 follows *hashmap.Map[syntax.RecordKey, bsky.GraphFollow] 39 38 followsCursor atomic.Pointer[string] 40 39 profile *bsky.ActorDefs_ProfileViewDetailed ··· 68 67 subscribers = hashmap.New[string, *SubscriberData]() 69 68 actorData = hashmap.New[syntax.DID, *ActorData]() 70 69 71 - likeStream *client.Client 72 - followStream *client.Client 70 + likeStreams StreamManager 71 + followStreams StreamManager 73 72 74 73 upgrader = websocket.Upgrader{ 75 74 CheckOrigin: func(r *http.Request) bool { ··· 93 92 return dids 94 93 } 95 94 95 + func getLikeDids() []string { 96 + _dids := make(Set[string], subscribers.Len()*5000) 97 + subscribers.Range(func(s string, sd *SubscriberData) bool { 98 + for did := range sd.listenTo { 99 + _dids[string(did)] = struct{}{} 100 + } 101 + return true 102 + }) 103 + dids := make([]string, 0, len(_dids)) 104 + for k := range _dids { 105 + dids = append(dids, k) 106 + } 107 + return dids 108 + } 109 + 96 110 func getActorData(did syntax.DID) *ActorData { 97 111 ud, _ := actorData.GetOrInsert(did, &ActorData{ 98 112 targets: hashmap.New[string, *SubscriberData](), 99 - likes: make(map[syntax.RecordKey]bsky.FeedLike), 113 + likes: hashmap.New[syntax.RecordKey, bsky.FeedLike](), 100 114 follows: hashmap.New[syntax.RecordKey, bsky.GraphFollow](), 101 115 }) 102 116 return ud ··· 120 134 func main() { 121 135 logger = slog.Default() 122 136 123 - go startJetstreamLoop(logger, &likeStream, "like_tracker", HandleLikeEvent, 
getLikeStreamOpts) 124 - go startJetstreamLoop(logger, &followStream, "subscriber", HandleFollowEvent, getFollowStreamOpts) 137 + likeStreams = NewStreamManager(logger, "like-tracker", HandleLikeEvent, getLikeStreamOpts) 138 + followStreams = NewStreamManager(logger, "subscriber", HandleFollowEvent, getFollowStreamOpts) 125 139 126 140 r := mux.NewRouter() 127 141 r.HandleFunc("/subscribe/{did}", handleSubscribe).Methods("GET") ··· 209 223 for listenDid := range sd.listenTo { 210 224 markActorForLikes(sid, sd, listenDid) 211 225 } 212 - updateFollowStreamOpts() 226 + updateStreamOpts() 213 227 // delete subscriber after we are done 214 228 defer func() { 215 229 for listenDid := range sd.listenTo { 216 230 unmarkActorForLikes(sid, listenDid) 217 231 } 218 232 subscribers.Del(sid) 219 - updateFollowStreamOpts() 233 + updateStreamOpts() 220 234 }() 221 235 222 236 logger.Info("serving subscriber") ··· 240 254 logger.Info("invalid message", "error", err) 241 255 break 242 256 } 257 + 243 258 // remove all current listens and add the ones the user requested 244 259 for listenDid := range sd.listenTo { 245 260 unmarkActorForLikes(sid, listenDid) ··· 249 264 sd.listenTo[listenDid] = struct{}{} 250 265 markActorForLikes(sid, sd, listenDid) 251 266 } 267 + 268 + updateStreamOpts() 252 269 } 253 270 } 254 271 } ··· 256 273 func getLikeStreamOpts() models.SubscriberOptionsUpdatePayload { 257 274 return models.SubscriberOptionsUpdatePayload{ 258 275 WantedCollections: []string{"app.bsky.feed.like"}, 276 + WantedDIDs: getLikeDids(), 259 277 } 260 278 } 261 279 ··· 266 284 } 267 285 } 268 286 269 - func updateFollowStreamOpts() { 270 - opts := getFollowStreamOpts() 271 - err := followStream.SendOptionsUpdate(opts) 272 - if err != nil { 273 - logger.Error("couldnt update follow stream opts", "error", err) 274 - return 275 - } 276 - logger.Info("updated follow stream opts", "userCount", len(opts.WantedDIDs)) 287 + func updateStreamOpts() { 288 + likeStreams.updateOpts() 289 + 
followStreams.updateOpts() 277 290 } 278 291 279 292 func HandleLikeEvent(ctx context.Context, event *models.Event) error { ··· 295 308 296 309 var like bsky.FeedLike 297 310 if deleted { 298 - if l, exists := ud.likes[rkey]; exists { 311 + if l, exists := ud.likes.Get(rkey); exists { 299 312 like = l 300 - defer delete(ud.likes, rkey) 313 + defer ud.likes.Del(rkey) 301 314 } else { 302 315 logger.Error("like record not found", "rkey", rkey) 303 316 return nil ··· 313 326 314 327 // store for later when it gets deleted so we can fetch the record 315 328 if !deleted { 316 - ud.likes[rkey] = like 329 + ud.likes.Insert(rkey, like) 317 330 } 318 331 319 332 repostURI := syntax.ATURI(like.Via.Uri)