A container registry that uses the AT Protocol for manifest storage and S3 for blob storage.

move download stats to the hold account so they can persist across different appviews

evan.jarrett.net f19dfa27 af99929a

verified
+1387 -351
+16 -3
cmd/appview/serve.go
··· 149 149 // Set global refresher for middleware 150 150 middleware.SetGlobalRefresher(refresher) 151 151 152 - // Set global database for pull/push metrics tracking 153 - metricsDB := db.NewMetricsDB(uiDatabase) 154 - middleware.SetGlobalDatabase(metricsDB) 152 + // Set global database for hold DID lookups (used by blob routing) 153 + holdDIDDB := db.NewHoldDIDDB(uiDatabase) 154 + middleware.SetGlobalDatabase(holdDIDDB) 155 155 156 156 // Create RemoteHoldAuthorizer for hold authorization with caching 157 157 holdAuthorizer := auth.NewRemoteHoldAuthorizer(uiDatabase, testMode) ··· 160 160 161 161 // Initialize Jetstream workers (background services before HTTP routes) 162 162 initializeJetstream(uiDatabase, &cfg.Jetstream, defaultHoldDID, testMode, refresher) 163 + 164 + // Run stats migration to holds (one-time migration, skipped if already done) 165 + go func() { 166 + // Wait for services to be ready (Docker startup race condition) 167 + time.Sleep(10 * time.Second) 168 + // Create service token getter callback that uses auth.GetOrFetchServiceToken 169 + getServiceToken := func(ctx context.Context, userDID, holdDID, pdsEndpoint string) (string, error) { 170 + return auth.GetOrFetchServiceToken(ctx, refresher, userDID, holdDID, pdsEndpoint) 171 + } 172 + if err := db.MigrateStatsToHolds(context.Background(), uiDatabase, getServiceToken); err != nil { 173 + slog.Warn("Stats migration failed", "error", err) 174 + } 175 + }() 163 176 164 177 // Create main chi router 165 178 mainRouter := chi.NewRouter()
+12 -45
pkg/appview/db/queries.go
··· 1446 1446 } 1447 1447 1448 1448 // UpsertRepositoryStats inserts or updates repository stats 1449 + // Note: star_count is calculated dynamically from the stars table, not stored here 1449 1450 func UpsertRepositoryStats(db *sql.DB, stats *RepositoryStats) error { 1450 1451 _, err := db.Exec(` 1451 - INSERT INTO repository_stats (did, repository, star_count, pull_count, last_pull, push_count, last_push) 1452 - VALUES (?, ?, ?, ?, ?, ?, ?) 1452 + INSERT INTO repository_stats (did, repository, pull_count, last_pull, push_count, last_push) 1453 + VALUES (?, ?, ?, ?, ?, ?) 1453 1454 ON CONFLICT(did, repository) DO UPDATE SET 1454 - star_count = excluded.star_count, 1455 1455 pull_count = excluded.pull_count, 1456 1456 last_pull = excluded.last_pull, 1457 1457 push_count = excluded.push_count, 1458 1458 last_push = excluded.last_push 1459 - `, stats.DID, stats.Repository, stats.StarCount, stats.PullCount, stats.LastPull, stats.PushCount, stats.LastPush) 1459 + `, stats.DID, stats.Repository, stats.PullCount, stats.LastPull, stats.PushCount, stats.LastPush) 1460 1460 return err 1461 1461 } 1462 1462 ··· 1593 1593 return nil 1594 1594 } 1595 1595 1596 - // IncrementPullCount increments the pull count for a repository 1597 - func IncrementPullCount(db *sql.DB, did, repository string) error { 1598 - _, err := db.Exec(` 1599 - INSERT INTO repository_stats (did, repository, pull_count, last_pull) 1600 - VALUES (?, ?, 1, datetime('now')) 1601 - ON CONFLICT(did, repository) DO UPDATE SET 1602 - pull_count = pull_count + 1, 1603 - last_pull = datetime('now') 1604 - `, did, repository) 1605 - return err 1606 - } 1607 - 1608 - // IncrementPushCount increments the push count for a repository 1609 - func IncrementPushCount(db *sql.DB, did, repository string) error { 1610 - _, err := db.Exec(` 1611 - INSERT INTO repository_stats (did, repository, push_count, last_push) 1612 - VALUES (?, ?, 1, datetime('now')) 1613 - ON CONFLICT(did, repository) DO UPDATE SET 1614 - push_count = 
push_count + 1, 1615 - last_push = datetime('now') 1616 - `, did, repository) 1617 - return err 1618 - } 1619 - 1620 1596 // parseTimestamp parses a timestamp string with multiple format attempts 1621 1597 func parseTimestamp(s string) (time.Time, error) { 1622 1598 formats := []string{ ··· 1634 1610 return time.Time{}, fmt.Errorf("unable to parse timestamp: %s", s) 1635 1611 } 1636 1612 1637 - // MetricsDB wraps a sql.DB and implements the metrics interface for middleware 1638 - type MetricsDB struct { 1613 + // HoldDIDDB wraps a sql.DB and implements the HoldDIDLookup interface for middleware 1614 + // This is a minimal wrapper that only provides hold DID lookups for blob routing 1615 + type HoldDIDDB struct { 1639 1616 db *sql.DB 1640 1617 } 1641 1618 1642 - // NewMetricsDB creates a new metrics database wrapper 1643 - func NewMetricsDB(db *sql.DB) *MetricsDB { 1644 - return &MetricsDB{db: db} 1645 - } 1646 - 1647 - // IncrementPullCount increments the pull count for a repository 1648 - func (m *MetricsDB) IncrementPullCount(did, repository string) error { 1649 - return IncrementPullCount(m.db, did, repository) 1650 - } 1651 - 1652 - // IncrementPushCount increments the push count for a repository 1653 - func (m *MetricsDB) IncrementPushCount(did, repository string) error { 1654 - return IncrementPushCount(m.db, did, repository) 1619 + // NewHoldDIDDB creates a new hold DID database wrapper 1620 + func NewHoldDIDDB(db *sql.DB) *HoldDIDDB { 1621 + return &HoldDIDDB{db: db} 1655 1622 } 1656 1623 1657 1624 // GetLatestHoldDIDForRepo returns the hold DID from the most recent manifest for a repository 1658 - func (m *MetricsDB) GetLatestHoldDIDForRepo(did, repository string) (string, error) { 1659 - return GetLatestHoldDIDForRepo(m.db, did, repository) 1625 + func (h *HoldDIDDB) GetLatestHoldDIDForRepo(did, repository string) (string, error) { 1626 + return GetLatestHoldDIDForRepo(h.db, did, repository) 1660 1627 } 1661 1628 1662 1629 // GetFeaturedRepositories 
fetches top repositories sorted by stars and pulls
+231
pkg/appview/db/stats_migration.go
··· 1 + package db 2 + 3 + import ( 4 + "bytes" 5 + "context" 6 + "database/sql" 7 + "encoding/json" 8 + "fmt" 9 + "io" 10 + "log/slog" 11 + "net/http" 12 + "time" 13 + 14 + "atcr.io/pkg/atproto" 15 + ) 16 + 17 + // ServiceTokenGetter is a function type for getting service tokens. 18 + // This avoids importing auth from db (which would create import cycles with tests). 19 + type ServiceTokenGetter func(ctx context.Context, userDID, holdDID, pdsEndpoint string) (string, error) 20 + 21 + // MigrateStatsToHolds migrates existing repository_stats data to hold services. 22 + // This is a one-time migration that runs on startup. 23 + // 24 + // The migration: 25 + // 1. Checks if migration has already completed 26 + // 2. Reads all repository_stats entries 27 + // 3. For each entry, looks up the hold DID from manifests table 28 + // 4. Gets a service token for the user and calls the hold's setStats endpoint 29 + // 5. Marks migration complete after all entries are processed 30 + // 31 + // If a hold is offline, the migration logs a warning and continues. 32 + // The hold will receive real-time stats updates via Jetstream once online. 33 + // 34 + // The getServiceToken parameter is a callback to avoid import cycles with pkg/auth. 
35 + func MigrateStatsToHolds(ctx context.Context, db *sql.DB, getServiceToken ServiceTokenGetter) error { 36 + // Check if migration already done 37 + var migrationDone bool 38 + err := db.QueryRowContext(ctx, ` 39 + SELECT EXISTS( 40 + SELECT 1 FROM schema_migrations WHERE version = 1000 41 + ) 42 + `).Scan(&migrationDone) 43 + 44 + // Table might not exist yet on fresh install 45 + if err == sql.ErrNoRows { 46 + migrationDone = false 47 + } else if err != nil { 48 + // Check if it's a "no such table" error (fresh install) 49 + if err.Error() != "no such table: schema_migrations" { 50 + return fmt.Errorf("failed to check migration status: %w", err) 51 + } 52 + migrationDone = false 53 + } 54 + 55 + if migrationDone { 56 + slog.Debug("Stats migration already complete, skipping", "component", "migration") 57 + return nil 58 + } 59 + 60 + slog.Info("Starting stats migration to holds", "component", "migration") 61 + 62 + // Get all repository_stats entries 63 + rows, err := db.QueryContext(ctx, ` 64 + SELECT did, repository, pull_count, last_pull, push_count, last_push 65 + FROM repository_stats 66 + WHERE pull_count > 0 OR push_count > 0 67 + `) 68 + if err != nil { 69 + // Table might not exist on fresh install 70 + if err.Error() == "no such table: repository_stats" { 71 + slog.Info("No repository_stats table found, skipping migration", "component", "migration") 72 + return markMigrationComplete(db) 73 + } 74 + return fmt.Errorf("failed to query repository_stats: %w", err) 75 + } 76 + defer rows.Close() 77 + 78 + var stats []struct { 79 + DID string 80 + Repository string 81 + PullCount int64 82 + LastPull sql.NullString 83 + PushCount int64 84 + LastPush sql.NullString 85 + } 86 + 87 + for rows.Next() { 88 + var stat struct { 89 + DID string 90 + Repository string 91 + PullCount int64 92 + LastPull sql.NullString 93 + PushCount int64 94 + LastPush sql.NullString 95 + } 96 + if err := rows.Scan(&stat.DID, &stat.Repository, &stat.PullCount, &stat.LastPull, 
&stat.PushCount, &stat.LastPush); err != nil { 97 + return fmt.Errorf("failed to scan stat: %w", err) 98 + } 99 + stats = append(stats, stat) 100 + } 101 + 102 + if len(stats) == 0 { 103 + slog.Info("No stats to migrate", "component", "migration") 104 + return markMigrationComplete(db) 105 + } 106 + 107 + slog.Info("Found stats entries to migrate", "component", "migration", "count", len(stats)) 108 + 109 + // Process each stat 110 + successCount := 0 111 + skipCount := 0 112 + errorCount := 0 113 + 114 + for _, stat := range stats { 115 + // Look up hold DID from manifests table 116 + holdDID, err := GetLatestHoldDIDForRepo(db, stat.DID, stat.Repository) 117 + if err != nil || holdDID == "" { 118 + slog.Debug("No hold DID found for repo, skipping", "component", "migration", 119 + "did", stat.DID, "repository", stat.Repository) 120 + skipCount++ 121 + continue 122 + } 123 + 124 + // Get user's PDS endpoint 125 + user, err := GetUserByDID(db, stat.DID) 126 + if err != nil || user == nil { 127 + slog.Debug("User not found in database, skipping", "component", "migration", 128 + "did", stat.DID, "repository", stat.Repository) 129 + skipCount++ 130 + continue 131 + } 132 + 133 + // Get service token for the user 134 + serviceToken, err := getServiceToken(ctx, stat.DID, holdDID, user.PDSEndpoint) 135 + if err != nil { 136 + slog.Warn("Failed to get service token, skipping", "component", "migration", 137 + "did", stat.DID, "repository", stat.Repository, "error", err) 138 + errorCount++ 139 + continue 140 + } 141 + 142 + // Resolve hold DID to HTTP URL 143 + holdURL := atproto.ResolveHoldURL(holdDID) 144 + if holdURL == "" { 145 + slog.Warn("Failed to resolve hold DID, skipping", "component", "migration", 146 + "hold_did", holdDID) 147 + errorCount++ 148 + continue 149 + } 150 + 151 + // Call hold's setStats endpoint 152 + err = callSetStats(ctx, holdURL, serviceToken, stat.DID, stat.Repository, 153 + stat.PullCount, stat.PushCount, stat.LastPull.String, 
stat.LastPush.String) 154 + if err != nil { 155 + slog.Warn("Failed to migrate stats to hold, continuing", "component", "migration", 156 + "did", stat.DID, "repository", stat.Repository, "hold", holdDID, "error", err) 157 + errorCount++ 158 + continue 159 + } 160 + 161 + successCount++ 162 + slog.Debug("Migrated stats", "component", "migration", 163 + "did", stat.DID, "repository", stat.Repository, "hold", holdDID, 164 + "pull_count", stat.PullCount, "push_count", stat.PushCount) 165 + } 166 + 167 + slog.Info("Stats migration completed", "component", "migration", 168 + "success", successCount, "skipped", skipCount, "errors", errorCount, "total", len(stats)) 169 + 170 + // Mark migration complete (even if some failed - they'll get updates via Jetstream) 171 + return markMigrationComplete(db) 172 + } 173 + 174 + // markMigrationComplete records that the stats migration has been done 175 + func markMigrationComplete(db *sql.DB) error { 176 + _, err := db.Exec(` 177 + INSERT INTO schema_migrations (version, applied_at) 178 + VALUES (1000, datetime('now')) 179 + ON CONFLICT(version) DO NOTHING 180 + `) 181 + if err != nil { 182 + return fmt.Errorf("failed to mark migration complete: %w", err) 183 + } 184 + return nil 185 + } 186 + 187 + // callSetStats calls the hold's io.atcr.hold.setStats endpoint 188 + func callSetStats(ctx context.Context, holdURL, serviceToken, ownerDID, repository string, pullCount, pushCount int64, lastPull, lastPush string) error { 189 + // Build request 190 + reqBody := map[string]any{ 191 + "ownerDid": ownerDID, 192 + "repository": repository, 193 + "pullCount": pullCount, 194 + "pushCount": pushCount, 195 + } 196 + if lastPull != "" { 197 + reqBody["lastPull"] = lastPull 198 + } 199 + if lastPush != "" { 200 + reqBody["lastPush"] = lastPush 201 + } 202 + 203 + body, err := json.Marshal(reqBody) 204 + if err != nil { 205 + return fmt.Errorf("failed to marshal request: %w", err) 206 + } 207 + 208 + // Create HTTP request 209 + req, err := 
http.NewRequestWithContext(ctx, "POST", holdURL+atproto.HoldSetStats, bytes.NewReader(body)) 210 + if err != nil { 211 + return fmt.Errorf("failed to create request: %w", err) 212 + } 213 + 214 + req.Header.Set("Content-Type", "application/json") 215 + req.Header.Set("Authorization", "Bearer "+serviceToken) 216 + 217 + // Send request with timeout 218 + client := &http.Client{Timeout: 10 * time.Second} 219 + resp, err := client.Do(req) 220 + if err != nil { 221 + return fmt.Errorf("request failed: %w", err) 222 + } 223 + defer resp.Body.Close() 224 + 225 + if resp.StatusCode != http.StatusOK { 226 + body, _ := io.ReadAll(resp.Body) 227 + return fmt.Errorf("setStats failed: status %d, body: %s", resp.StatusCode, body) 228 + } 229 + 230 + return nil 231 + }
+2 -2
pkg/appview/jetstream/backfill.go
··· 48 48 49 49 return &BackfillWorker{ 50 50 db: database, 51 - client: client, // This points to the relay 52 - processor: NewProcessor(database, false), // No cache for batch processing 51 + client: client, // This points to the relay 52 + processor: NewProcessor(database, false, nil), // No cache for batch processing, no stats 53 53 defaultHoldDID: defaultHoldDID, 54 54 testMode: testMode, 55 55 refresher: refresher,
+63 -6
pkg/appview/jetstream/processor.go
··· 16 16 // Processor handles shared database operations for both Worker (live) and Backfill (sync) 17 17 // This eliminates code duplication between the two data ingestion paths 18 18 type Processor struct { 19 - db *sql.DB 20 - userCache *UserCache // Optional - enabled for Worker, disabled for Backfill 21 - useCache bool 19 + db *sql.DB 20 + userCache *UserCache // Optional - enabled for Worker, disabled for Backfill 21 + statsCache *StatsCache // In-memory cache for per-hold stats aggregation 22 + useCache bool 22 23 } 23 24 24 25 // NewProcessor creates a new shared processor 25 26 // useCache: true for Worker (live streaming), false for Backfill (batch processing) 26 - func NewProcessor(database *sql.DB, useCache bool) *Processor { 27 + // statsCache: shared stats cache for aggregating across holds (nil to skip stats processing) 28 + func NewProcessor(database *sql.DB, useCache bool, statsCache *StatsCache) *Processor { 27 29 p := &Processor{ 28 - db: database, 29 - useCache: useCache, 30 + db: database, 31 + useCache: useCache, 32 + statsCache: statsCache, 30 33 } 31 34 32 35 if useCache { ··· 367 370 "new_handle", newHandle) 368 371 369 372 return nil 373 + } 374 + 375 + // ProcessStats handles stats record events from hold PDSes 376 + // This is called when Jetstream receives a stats create/update/delete event from a hold 377 + // The holdDID is the DID of the hold PDS (event.DID), and the record contains ownerDID + repository 378 + func (p *Processor) ProcessStats(ctx context.Context, holdDID string, recordData []byte, isDelete bool) error { 379 + // Skip if no stats cache configured 380 + if p.statsCache == nil { 381 + return nil 382 + } 383 + 384 + // Unmarshal stats record 385 + var statsRecord atproto.StatsRecord 386 + if err := json.Unmarshal(recordData, &statsRecord); err != nil { 387 + return fmt.Errorf("failed to unmarshal stats record: %w", err) 388 + } 389 + 390 + if isDelete { 391 + // Delete from in-memory cache 392 + 
p.statsCache.Delete(holdDID, statsRecord.OwnerDID, statsRecord.Repository) 393 + } else { 394 + // Parse timestamps 395 + var lastPull, lastPush *time.Time 396 + if statsRecord.LastPull != "" { 397 + t, err := time.Parse(time.RFC3339, statsRecord.LastPull) 398 + if err == nil { 399 + lastPull = &t 400 + } 401 + } 402 + if statsRecord.LastPush != "" { 403 + t, err := time.Parse(time.RFC3339, statsRecord.LastPush) 404 + if err == nil { 405 + lastPush = &t 406 + } 407 + } 408 + 409 + // Update in-memory cache 410 + p.statsCache.Update(holdDID, statsRecord.OwnerDID, statsRecord.Repository, 411 + statsRecord.PullCount, statsRecord.PushCount, lastPull, lastPush) 412 + } 413 + 414 + // Get aggregated stats across all holds 415 + totalPull, totalPush, latestPull, latestPush := p.statsCache.GetAggregated( 416 + statsRecord.OwnerDID, statsRecord.Repository) 417 + 418 + // Upsert aggregated stats to repository_stats 419 + return db.UpsertRepositoryStats(p.db, &db.RepositoryStats{ 420 + DID: statsRecord.OwnerDID, 421 + Repository: statsRecord.Repository, 422 + PullCount: int(totalPull), 423 + PushCount: int(totalPush), 424 + LastPull: latestPull, 425 + LastPush: latestPush, 426 + }) 370 427 } 371 428 372 429 // ProcessAccount handles account status events (deactivation/reactivation)
+9 -9
pkg/appview/jetstream/processor_test.go
··· 115 115 116 116 for _, tt := range tests { 117 117 t.Run(tt.name, func(t *testing.T) { 118 - p := NewProcessor(database, tt.useCache) 118 + p := NewProcessor(database, tt.useCache, nil) 119 119 if p == nil { 120 120 t.Fatal("NewProcessor returned nil") 121 121 } ··· 139 139 database := setupTestDB(t) 140 140 defer database.Close() 141 141 142 - p := NewProcessor(database, false) 142 + p := NewProcessor(database, false, nil) 143 143 ctx := context.Background() 144 144 145 145 // Create test manifest record ··· 238 238 database := setupTestDB(t) 239 239 defer database.Close() 240 240 241 - p := NewProcessor(database, false) 241 + p := NewProcessor(database, false, nil) 242 242 ctx := context.Background() 243 243 244 244 // Create test manifest list record ··· 322 322 database := setupTestDB(t) 323 323 defer database.Close() 324 324 325 - p := NewProcessor(database, false) 325 + p := NewProcessor(database, false, nil) 326 326 ctx := context.Background() 327 327 328 328 // Create test tag record (using ManifestDigest field for simplicity) ··· 403 403 database := setupTestDB(t) 404 404 defer database.Close() 405 405 406 - p := NewProcessor(database, false) 406 + p := NewProcessor(database, false, nil) 407 407 ctx := context.Background() 408 408 409 409 // Create test star record ··· 463 463 database := setupTestDB(t) 464 464 defer database.Close() 465 465 466 - p := NewProcessor(database, false) 466 + p := NewProcessor(database, false, nil) 467 467 ctx := context.Background() 468 468 469 469 manifestRecord := &atproto.ManifestRecord{ ··· 514 514 database := setupTestDB(t) 515 515 defer database.Close() 516 516 517 - p := NewProcessor(database, false) 517 + p := NewProcessor(database, false, nil) 518 518 ctx := context.Background() 519 519 520 520 // Manifest with nil annotations ··· 555 555 db := setupTestDB(t) 556 556 defer db.Close() 557 557 558 - processor := NewProcessor(db, false) 558 + processor := NewProcessor(db, false, nil) 559 559 560 560 // Setup: Create 
test user 561 561 testDID := "did:plc:alice123" ··· 621 621 db := setupTestDB(t) 622 622 defer db.Close() 623 623 624 - processor := NewProcessor(db, false) 624 + processor := NewProcessor(db, false, nil) 625 625 626 626 // Setup: Create test user 627 627 testDID := "did:plc:bob456"
+100
pkg/appview/jetstream/stats_cache.go
··· 1 + package jetstream 2 + 3 + import ( 4 + "sync" 5 + "time" 6 + ) 7 + 8 + // HoldRepoStats represents stats for a single owner+repo from a specific hold 9 + type HoldRepoStats struct { 10 + OwnerDID string 11 + Repository string 12 + PullCount int64 13 + PushCount int64 14 + LastPull *time.Time 15 + LastPush *time.Time 16 + } 17 + 18 + // StatsCache provides in-memory caching of per-hold stats with aggregation 19 + // This allows summing stats across multiple holds for the same owner+repo 20 + type StatsCache struct { 21 + mu sync.RWMutex 22 + // holdDID -> (ownerDID/repo -> stats) 23 + holds map[string]map[string]*HoldRepoStats 24 + } 25 + 26 + // NewStatsCache creates a new in-memory stats cache 27 + func NewStatsCache() *StatsCache { 28 + return &StatsCache{ 29 + holds: make(map[string]map[string]*HoldRepoStats), 30 + } 31 + } 32 + 33 + // makeKey creates a cache key from ownerDID and repository 34 + func makeKey(ownerDID, repo string) string { 35 + return ownerDID + "/" + repo 36 + } 37 + 38 + // Update stores or updates stats for a hold+owner+repo combination 39 + func (c *StatsCache) Update(holdDID, ownerDID, repo string, pullCount, pushCount int64, lastPull, lastPush *time.Time) { 40 + c.mu.Lock() 41 + defer c.mu.Unlock() 42 + 43 + // Ensure hold map exists 44 + if c.holds[holdDID] == nil { 45 + c.holds[holdDID] = make(map[string]*HoldRepoStats) 46 + } 47 + 48 + key := makeKey(ownerDID, repo) 49 + c.holds[holdDID][key] = &HoldRepoStats{ 50 + OwnerDID: ownerDID, 51 + Repository: repo, 52 + PullCount: pullCount, 53 + PushCount: pushCount, 54 + LastPull: lastPull, 55 + LastPush: lastPush, 56 + } 57 + } 58 + 59 + // Delete removes stats for a hold+owner+repo combination 60 + func (c *StatsCache) Delete(holdDID, ownerDID, repo string) { 61 + c.mu.Lock() 62 + defer c.mu.Unlock() 63 + 64 + if c.holds[holdDID] != nil { 65 + key := makeKey(ownerDID, repo) 66 + delete(c.holds[holdDID], key) 67 + } 68 + } 69 + 70 + // GetAggregated returns aggregated stats for an 
owner+repo by summing across all holds 71 + // Returns (pullCount, pushCount, lastPull, lastPush) 72 + func (c *StatsCache) GetAggregated(ownerDID, repo string) (int64, int64, *time.Time, *time.Time) { 73 + c.mu.RLock() 74 + defer c.mu.RUnlock() 75 + 76 + key := makeKey(ownerDID, repo) 77 + var totalPull, totalPush int64 78 + var latestPull, latestPush *time.Time 79 + 80 + for _, holdStats := range c.holds { 81 + if stats, ok := holdStats[key]; ok { 82 + totalPull += stats.PullCount 83 + totalPush += stats.PushCount 84 + 85 + // Track latest timestamps 86 + if stats.LastPull != nil { 87 + if latestPull == nil || stats.LastPull.After(*latestPull) { 88 + latestPull = stats.LastPull 89 + } 90 + } 91 + if stats.LastPush != nil { 92 + if latestPush == nil || stats.LastPush.After(*latestPush) { 93 + latestPush = stats.LastPush 94 + } 95 + } 96 + } 97 + } 98 + 99 + return totalPull, totalPush, latestPull, latestPush 100 + }
+39 -2
pkg/appview/jetstream/worker.go
··· 34 34 startCursor int64 35 35 wantedCollections []string 36 36 debugCollectionCount int 37 - processor *Processor // Shared processor for DB operations 37 + processor *Processor // Shared processor for DB operations 38 + statsCache *StatsCache // In-memory cache for stats aggregation across holds 38 39 eventCallback EventCallback 39 40 connStartTime time.Time // Track when connection started for debugging 40 41 ··· 56 57 jetstreamURL = "wss://jetstream2.us-west.bsky.network/subscribe" 57 58 } 58 59 60 + // Create shared stats cache for aggregating across holds 61 + statsCache := NewStatsCache() 62 + 59 63 return &Worker{ 60 64 db: database, 61 65 jetstreamURL: jetstreamURL, ··· 63 67 wantedCollections: []string{ 64 68 "io.atcr.*", // Subscribe to all ATCR collections 65 69 }, 66 - processor: NewProcessor(database, true), // Use cache for live streaming 70 + statsCache: statsCache, 71 + processor: NewProcessor(database, true, statsCache), // Use cache for live streaming 67 72 } 68 73 } 69 74 ··· 313 318 case atproto.RepoPageCollection: 314 319 slog.Info("Jetstream processing repo page event", "did", commit.DID, "operation", commit.Operation, "rkey", commit.RKey) 315 320 return w.processRepoPage(commit) 321 + case atproto.StatsCollection: 322 + slog.Info("Jetstream processing stats event", "did", commit.DID, "operation", commit.Operation, "rkey", commit.RKey) 323 + return w.processStats(commit) 316 324 default: 317 325 // Ignore other collections 318 326 return nil ··· 470 478 471 479 // Use shared processor for DB operations 472 480 return w.processor.ProcessRepoPage(context.Background(), commit.DID, commit.RKey, recordBytes, false) 481 + } 482 + 483 + // processStats processes a stats commit event from a hold PDS 484 + func (w *Worker) processStats(commit *CommitEvent) error { 485 + isDelete := commit.Operation == "delete" 486 + 487 + if isDelete { 488 + // For delete events, we need to parse the rkey to get ownerDID + repository 489 + // The rkey is 
deterministic: base32(sha256(ownerDID + "/" + repository)[:16]) 490 + // Unfortunately, we can't reverse this - we need the record data 491 + // Delete events don't include record data, so we can't delete from cache 492 + // This is acceptable - stats will be refreshed on next update from hold 493 + slog.Debug("Jetstream ignoring stats delete event (cannot reverse rkey)", "did", commit.DID, "rkey", commit.RKey) 494 + return nil 495 + } 496 + 497 + // Parse stats record 498 + if commit.Record == nil { 499 + return nil 500 + } 501 + 502 + // Marshal map to bytes for processing 503 + recordBytes, err := json.Marshal(commit.Record) 504 + if err != nil { 505 + return fmt.Errorf("failed to marshal record: %w", err) 506 + } 507 + 508 + // Use shared processor - commit.DID is the hold's DID 509 + return w.processor.ProcessStats(context.Background(), commit.DID, recordBytes, false) 473 510 } 474 511 475 512 // processIdentity processes an identity event (handle change)
+10 -10
pkg/appview/middleware/registry.go
··· 174 174 // After initialization, request handling uses the NamespaceResolver's instance fields. 175 175 var ( 176 176 globalRefresher *oauth.Refresher 177 - globalDatabase storage.DatabaseMetrics 177 + globalDatabase storage.HoldDIDLookup 178 178 globalAuthorizer auth.HoldAuthorizer 179 179 ) 180 180 ··· 186 186 187 187 // SetGlobalDatabase sets the database instance during initialization 188 188 // Must be called before the registry starts serving requests 189 - func SetGlobalDatabase(database storage.DatabaseMetrics) { 189 + func SetGlobalDatabase(database storage.HoldDIDLookup) { 190 190 globalDatabase = database 191 191 } 192 192 ··· 204 204 // NamespaceResolver wraps a namespace and resolves names 205 205 type NamespaceResolver struct { 206 206 distribution.Namespace 207 - defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io") 208 - baseURL string // Base URL for error messages (e.g., "https://atcr.io") 209 - testMode bool // If true, fallback to default hold when user's hold is unreachable 210 - refresher *oauth.Refresher // OAuth session manager (copied from global on init) 211 - database storage.DatabaseMetrics // Metrics database (copied from global on init) 212 - authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init) 213 - validationCache *validationCache // Request-level service token cache 214 - readmeFetcher *readme.Fetcher // README fetcher for repo pages 207 + defaultHoldDID string // Default hold DID (e.g., "did:web:hold01.atcr.io") 208 + baseURL string // Base URL for error messages (e.g., "https://atcr.io") 209 + testMode bool // If true, fallback to default hold when user's hold is unreachable 210 + refresher *oauth.Refresher // OAuth session manager (copied from global on init) 211 + database storage.HoldDIDLookup // Database for hold DID lookups (copied from global on init) 212 + authorizer auth.HoldAuthorizer // Hold authorization (copied from global on init) 213 + validationCache 
*validationCache // Request-level service token cache 214 + readmeFetcher *readme.Fetcher // README fetcher for repo pages 215 215 } 216 216 217 217 // initATProtoResolver initializes the name resolution middleware
+3 -5
pkg/appview/storage/context.go
··· 7 7 "atcr.io/pkg/auth/oauth" 8 8 ) 9 9 10 - // DatabaseMetrics interface for tracking pull/push counts and querying hold DIDs 11 - type DatabaseMetrics interface { 12 - IncrementPullCount(did, repository string) error 13 - IncrementPushCount(did, repository string) error 10 + // HoldDIDLookup interface for querying hold DIDs from manifests 11 + type HoldDIDLookup interface { 14 12 GetLatestHoldDIDForRepo(did, repository string) (string, error) 15 13 } 16 14 ··· 32 30 PullerPDSEndpoint string // Puller's PDS endpoint URL 33 31 34 32 // Shared services (same for all requests) 35 - Database DatabaseMetrics // Metrics tracking database 33 + Database HoldDIDLookup // Database for hold DID lookups 36 34 Authorizer auth.HoldAuthorizer // Hold access authorization 37 35 Refresher *oauth.Refresher // OAuth session manager 38 36 ReadmeFetcher *readme.Fetcher // README fetcher for repo pages
+11 -41
pkg/appview/storage/context_test.go
··· 1 1 package storage 2 2 3 3 import ( 4 - "sync" 5 4 "testing" 6 5 7 6 "atcr.io/pkg/atproto" 8 7 ) 9 8 10 9 // Mock implementations for testing 11 - type mockDatabaseMetrics struct { 12 - mu sync.Mutex 13 - pullCount int 14 - pushCount int 15 - } 16 - 17 - func (m *mockDatabaseMetrics) IncrementPullCount(did, repository string) error { 18 - m.mu.Lock() 19 - defer m.mu.Unlock() 20 - m.pullCount++ 21 - return nil 22 - } 23 10 24 - func (m *mockDatabaseMetrics) IncrementPushCount(did, repository string) error { 25 - m.mu.Lock() 26 - defer m.mu.Unlock() 27 - m.pushCount++ 28 - return nil 29 - } 30 - 31 - func (m *mockDatabaseMetrics) GetLatestHoldDIDForRepo(did, repository string) (string, error) { 32 - // Return empty string for mock - tests can override if needed 33 - return "", nil 34 - } 35 - 36 - func (m *mockDatabaseMetrics) getPullCount() int { 37 - m.mu.Lock() 38 - defer m.mu.Unlock() 39 - return m.pullCount 11 + // mockHoldDIDLookup implements HoldDIDLookup for testing 12 + type mockHoldDIDLookup struct { 13 + holdDID string // Return value for GetLatestHoldDIDForRepo 40 14 } 41 15 42 - func (m *mockDatabaseMetrics) getPushCount() int { 43 - m.mu.Lock() 44 - defer m.mu.Unlock() 45 - return m.pushCount 16 + func (m *mockHoldDIDLookup) GetLatestHoldDIDForRepo(did, repository string) (string, error) { 17 + return m.holdDID, nil 46 18 } 47 19 48 20 type mockHoldAuthorizer struct{} ··· 63 35 ATProtoClient: &atproto.Client{ 64 36 // Mock client - would need proper initialization in real tests 65 37 }, 66 - Database: &mockDatabaseMetrics{}, 38 + Database: &mockHoldDIDLookup{holdDID: "did:web:hold01.atcr.io"}, 67 39 } 68 40 69 41 // Verify fields are accessible ··· 88 60 } 89 61 90 62 func TestRegistryContext_DatabaseInterface(t *testing.T) { 91 - db := &mockDatabaseMetrics{} 63 + db := &mockHoldDIDLookup{holdDID: "did:web:test-hold.example.com"} 92 64 ctx := &RegistryContext{ 93 65 Database: db, 94 66 } 95 67 96 - // Test that interface methods are callable 97 - 
err := ctx.Database.IncrementPullCount("did:plc:test", "repo") 68 + // Test that interface method is callable 69 + holdDID, err := ctx.Database.GetLatestHoldDIDForRepo("did:plc:test", "repo") 98 70 if err != nil { 99 71 t.Errorf("Unexpected error: %v", err) 100 72 } 101 - 102 - err = ctx.Database.IncrementPushCount("did:plc:test", "repo") 103 - if err != nil { 104 - t.Errorf("Unexpected error: %v", err) 73 + if holdDID != "did:web:test-hold.example.com" { 74 + t.Errorf("Expected holdDID %q, got %q", "did:web:test-hold.example.com", holdDID) 105 75 } 106 76 } 107 77
+66 -61
pkg/appview/storage/manifest_store.go
··· 73 73 } 74 74 } 75 75 76 - // Track pull count (increment asynchronously to avoid blocking the response) 76 + // Notify hold about manifest pull (for stats tracking) 77 77 // Only count GET requests (actual downloads), not HEAD requests (existence checks) 78 - if s.ctx.Database != nil { 79 - // Check HTTP method from context (distribution library stores it as "http.request.method") 80 - if method, ok := ctx.Value("http.request.method").(string); ok && method == "GET" { 78 + // Check HTTP method from context (distribution library stores it as "http.request.method") 79 + if method, ok := ctx.Value("http.request.method").(string); ok && method == "GET" { 80 + // Do this asynchronously to avoid blocking the response 81 + if s.ctx.ServiceToken != "" && s.ctx.Handle != "" { 81 82 go func() { 82 - if err := s.ctx.Database.IncrementPullCount(s.ctx.DID, s.ctx.Repository); err != nil { 83 - slog.Warn("Failed to increment pull count", "did", s.ctx.DID, "repository", s.ctx.Repository, "error", err) 83 + defer func() { 84 + if r := recover(); r != nil { 85 + slog.Error("Panic in notifyHoldAboutManifest (pull)", "panic", r) 86 + } 87 + }() 88 + if err := s.notifyHoldAboutManifest(context.Background(), nil, "", "", "pull"); err != nil { 89 + slog.Warn("Failed to notify hold about manifest pull", "error", err) 84 90 } 85 91 }() 86 92 } ··· 190 196 return "", fmt.Errorf("failed to store manifest record in ATProto: %w", err) 191 197 } 192 198 193 - // Track push count (increment asynchronously to avoid blocking the response) 194 - if s.ctx.Database != nil { 195 - go func() { 196 - if err := s.ctx.Database.IncrementPushCount(s.ctx.DID, s.ctx.Repository); err != nil { 197 - slog.Warn("Failed to increment push count", "did", s.ctx.DID, "repository", s.ctx.Repository, "error", err) 198 - } 199 - }() 200 - } 201 - 202 199 // Also handle tag if specified 203 200 var tag string 204 201 for _, option := range options { ··· 213 210 } 214 211 } 215 212 216 - // Notify hold about manifest 
upload (for layer tracking and Bluesky posts) 213 + // Notify hold about manifest push (for layer tracking, Bluesky posts, and stats) 217 214 // Do this asynchronously to avoid blocking the push 218 215 if tag != "" && s.ctx.ServiceToken != "" && s.ctx.Handle != "" { 219 216 go func() { ··· 222 219 slog.Error("Panic in notifyHoldAboutManifest", "panic", r) 223 220 } 224 221 }() 225 - if err := s.notifyHoldAboutManifest(context.Background(), manifestRecord, tag, dgst.String()); err != nil { 226 - slog.Warn("Failed to notify hold about manifest", "error", err) 222 + if err := s.notifyHoldAboutManifest(context.Background(), manifestRecord, tag, dgst.String(), "push"); err != nil { 223 + slog.Warn("Failed to notify hold about manifest push", "error", err) 227 224 } 228 225 }() 229 226 } ··· 298 295 return configJSON.Config.Labels, nil 299 296 } 300 297 301 - // notifyHoldAboutManifest notifies the hold service about a manifest upload 302 - // This enables the hold to create layer records and Bluesky posts 303 - func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRecord *atproto.ManifestRecord, tag, manifestDigest string) error { 298 + // notifyHoldAboutManifest notifies the hold service about a manifest operation 299 + // For push: Creates layer records and optionally posts to Bluesky 300 + // For pull: Just increments stats (no layer records or posts) 301 + // operation should be "push" or "pull" 302 + func (s *ManifestStore) notifyHoldAboutManifest(ctx context.Context, manifestRecord *atproto.ManifestRecord, tag, manifestDigest, operation string) error { 304 303 // Skip if no service token configured (e.g., anonymous pulls) 305 304 if s.ctx.ServiceToken == "" { 306 305 return nil ··· 314 313 serviceToken := s.ctx.ServiceToken 315 314 316 315 // Build notification request 317 - manifestData := map[string]any{ 318 - "mediaType": manifestRecord.MediaType, 316 + notifyReq := map[string]any{ 317 + "repository": s.ctx.Repository, 318 + "userDid": 
s.ctx.DID, 319 + "userHandle": s.ctx.Handle, 320 + "operation": operation, 319 321 } 320 322 321 - // Add config if present (not present in manifest lists/indexes) 322 - if manifestRecord.Config != nil { 323 - manifestData["config"] = map[string]any{ 324 - "digest": manifestRecord.Config.Digest, 325 - "size": manifestRecord.Config.Size, 323 + // For push operations, include full manifest data 324 + if operation == "push" && manifestRecord != nil { 325 + notifyReq["tag"] = tag 326 + 327 + manifestData := map[string]any{ 328 + "mediaType": manifestRecord.MediaType, 326 329 } 327 - } 328 330 329 - // Add layers if present 330 - if len(manifestRecord.Layers) > 0 { 331 - layers := make([]map[string]any, len(manifestRecord.Layers)) 332 - for i, layer := range manifestRecord.Layers { 333 - layers[i] = map[string]any{ 334 - "digest": layer.Digest, 335 - "size": layer.Size, 336 - "mediaType": layer.MediaType, 331 + // Add config if present (not present in manifest lists/indexes) 332 + if manifestRecord.Config != nil { 333 + manifestData["config"] = map[string]any{ 334 + "digest": manifestRecord.Config.Digest, 335 + "size": manifestRecord.Config.Size, 337 336 } 338 337 } 339 - manifestData["layers"] = layers 340 - } 341 338 342 - // Add manifests if present (for multi-arch images / manifest lists) 343 - if len(manifestRecord.Manifests) > 0 { 344 - manifests := make([]map[string]any, len(manifestRecord.Manifests)) 345 - for i, m := range manifestRecord.Manifests { 346 - mData := map[string]any{ 347 - "digest": m.Digest, 348 - "size": m.Size, 349 - "mediaType": m.MediaType, 339 + // Add layers if present 340 + if len(manifestRecord.Layers) > 0 { 341 + layers := make([]map[string]any, len(manifestRecord.Layers)) 342 + for i, layer := range manifestRecord.Layers { 343 + layers[i] = map[string]any{ 344 + "digest": layer.Digest, 345 + "size": layer.Size, 346 + "mediaType": layer.MediaType, 347 + } 350 348 } 351 - if m.Platform != nil { 352 - mData["platform"] = map[string]any{ 353 
- "os": m.Platform.OS, 354 - "architecture": m.Platform.Architecture, 349 + manifestData["layers"] = layers 350 + } 351 + 352 + // Add manifests if present (for multi-arch images / manifest lists) 353 + if len(manifestRecord.Manifests) > 0 { 354 + manifests := make([]map[string]any, len(manifestRecord.Manifests)) 355 + for i, m := range manifestRecord.Manifests { 356 + mData := map[string]any{ 357 + "digest": m.Digest, 358 + "size": m.Size, 359 + "mediaType": m.MediaType, 360 + } 361 + if m.Platform != nil { 362 + mData["platform"] = map[string]any{ 363 + "os": m.Platform.OS, 364 + "architecture": m.Platform.Architecture, 365 + } 355 366 } 367 + manifests[i] = mData 356 368 } 357 - manifests[i] = mData 369 + manifestData["manifests"] = manifests 358 370 } 359 - manifestData["manifests"] = manifests 360 - } 361 371 362 - notifyReq := map[string]any{ 363 - "repository": s.ctx.Repository, 364 - "tag": tag, 365 - "userDid": s.ctx.DID, 366 - "userHandle": s.ctx.Handle, 367 - "manifest": manifestData, 372 + notifyReq["manifest"] = manifestData 368 373 } 369 374 370 375 // Marshal request ··· 401 406 // Parse response (optional logging) 402 407 var notifyResp map[string]any 403 408 if err := json.NewDecoder(resp.Body).Decode(&notifyResp); err == nil { 404 - slog.Info("Hold notification successful", "repository", s.ctx.Repository, "tag", tag, "response", notifyResp) 409 + slog.Debug("Hold notification successful", "repository", s.ctx.Repository, "operation", operation, "response", notifyResp) 405 410 } 406 411 407 412 return nil
+12 -90
pkg/appview/storage/manifest_store_test.go
··· 8 8 "net/http" 9 9 "net/http/httptest" 10 10 "testing" 11 - "time" 12 11 13 12 "atcr.io/pkg/atproto" 14 13 "github.com/distribution/distribution/v3" 15 14 "github.com/opencontainers/go-digest" 16 15 ) 17 16 18 - // mockDatabaseMetrics removed - using the one from context_test.go 17 + // mockHoldDIDLookup defined in context_test.go 19 18 20 19 // mockBlobStore is a minimal mock of distribution.BlobStore for testing 21 20 type mockBlobStore struct { ··· 73 72 } 74 73 75 74 // mockRegistryContext creates a mock RegistryContext for testing 76 - func mockRegistryContext(client *atproto.Client, repository, holdDID, did, handle string, database DatabaseMetrics) *RegistryContext { 75 + func mockRegistryContext(client *atproto.Client, repository, holdDID, did, handle string, database HoldDIDLookup) *RegistryContext { 77 76 return &RegistryContext{ 78 77 ATProtoClient: client, 79 78 Repository: repository, ··· 117 116 func TestNewManifestStore(t *testing.T) { 118 117 client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token") 119 118 blobStore := newMockBlobStore() 120 - db := &mockDatabaseMetrics{} 119 + db := &mockHoldDIDLookup{} 121 120 122 121 ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", db) 123 122 store := NewManifestStore(ctx, blobStore) ··· 274 273 } 275 274 } 276 275 277 - // TestManifestStore_WithMetrics tests that metrics are tracked 278 - func TestManifestStore_WithMetrics(t *testing.T) { 279 - db := &mockDatabaseMetrics{} 276 + // TestManifestStore_WithDatabase tests that database is wired up 277 + func TestManifestStore_WithDatabase(t *testing.T) { 278 + db := &mockHoldDIDLookup{holdDID: "did:web:test-hold.example.com"} 280 279 client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token") 281 280 ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", db) 282 281 store := NewManifestStore(ctx, nil) ··· 
285 284 t.Error("ManifestStore should store database reference") 286 285 } 287 286 288 - // Note: Actual metrics tracking happens in Put() and Get() which require 289 - // full mock setup. The important thing is that the database is wired up. 287 + // Database is used for hold DID lookups during blob routing 290 288 } 291 289 292 - // TestManifestStore_WithoutMetrics tests that nil database is acceptable 293 - func TestManifestStore_WithoutMetrics(t *testing.T) { 290 + // TestManifestStore_WithoutDatabase tests that nil database is acceptable 291 + func TestManifestStore_WithoutDatabase(t *testing.T) { 294 292 client := atproto.NewClient("https://pds.example.com", "did:plc:test123", "token") 295 293 ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:alice123", "alice.test", nil) 296 294 store := NewManifestStore(ctx, nil) ··· 464 462 defer server.Close() 465 463 466 464 client := atproto.NewClient(server.URL, "did:plc:test123", "token") 467 - db := &mockDatabaseMetrics{} 465 + db := &mockHoldDIDLookup{} 468 466 ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db) 469 467 store := NewManifestStore(ctx, nil) 470 468 ··· 487 485 } 488 486 } 489 487 490 - // TestManifestStore_Get_OnlyCountsGETRequests verifies that HEAD requests don't increment pull count 491 - func TestManifestStore_Get_OnlyCountsGETRequests(t *testing.T) { 492 - ociManifest := []byte(`{"schemaVersion":2}`) 493 - 494 - tests := []struct { 495 - name string 496 - httpMethod string 497 - expectPullIncrement bool 498 - }{ 499 - { 500 - name: "GET request increments pull count", 501 - httpMethod: "GET", 502 - expectPullIncrement: true, 503 - }, 504 - { 505 - name: "HEAD request does not increment pull count", 506 - httpMethod: "HEAD", 507 - expectPullIncrement: false, 508 - }, 509 - { 510 - name: "POST request does not increment pull count", 511 - httpMethod: "POST", 512 - expectPullIncrement: false, 513 - }, 514 - } 
515 - 516 - for _, tt := range tests { 517 - t.Run(tt.name, func(t *testing.T) { 518 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 519 - if r.URL.Path == atproto.SyncGetBlob { 520 - w.Write(ociManifest) 521 - return 522 - } 523 - w.Write([]byte(`{ 524 - "uri": "at://did:plc:test123/io.atcr.manifest/abc123", 525 - "value": { 526 - "$type":"io.atcr.manifest", 527 - "holdDid":"did:web:hold01.atcr.io", 528 - "mediaType":"application/vnd.oci.image.manifest.v1+json", 529 - "manifestBlob":{"ref":{"$link":"bafytest"},"size":100} 530 - } 531 - }`)) 532 - })) 533 - defer server.Close() 534 - 535 - client := atproto.NewClient(server.URL, "did:plc:test123", "token") 536 - mockDB := &mockDatabaseMetrics{} 537 - ctx := mockRegistryContext(client, "myapp", "did:web:hold01.atcr.io", "did:plc:test123", "test.handle", mockDB) 538 - store := NewManifestStore(ctx, nil) 539 - 540 - // Create a context with the HTTP method stored (as distribution library does) 541 - testCtx := context.WithValue(context.Background(), "http.request.method", tt.httpMethod) 542 - 543 - _, err := store.Get(testCtx, "sha256:abc123") 544 - if err != nil { 545 - t.Fatalf("Get() error = %v", err) 546 - } 547 - 548 - // Wait for async goroutine to complete (metrics are incremented asynchronously) 549 - time.Sleep(50 * time.Millisecond) 550 - 551 - if tt.expectPullIncrement { 552 - // Check that IncrementPullCount was called 553 - if mockDB.getPullCount() == 0 { 554 - t.Error("Expected pull count to be incremented for GET request, but it wasn't") 555 - } 556 - } else { 557 - // Check that IncrementPullCount was NOT called 558 - if mockDB.getPullCount() > 0 { 559 - t.Errorf("Expected pull count NOT to be incremented for %s request, but it was (count=%d)", tt.httpMethod, mockDB.getPullCount()) 560 - } 561 - } 562 - }) 563 - } 564 - } 565 - 566 488 // TestManifestStore_Put tests storing manifests 567 489 func TestManifestStore_Put(t *testing.T) { 568 490 ociManifest := 
[]byte(`{ ··· 655 577 defer server.Close() 656 578 657 579 client := atproto.NewClient(server.URL, "did:plc:test123", "token") 658 - db := &mockDatabaseMetrics{} 580 + db := &mockHoldDIDLookup{} 659 581 ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db) 660 582 store := NewManifestStore(ctx, nil) 661 583 ··· 939 861 defer server.Close() 940 862 941 863 client := atproto.NewClient(server.URL, "did:plc:test123", "token") 942 - db := &mockDatabaseMetrics{} 864 + db := &mockHoldDIDLookup{} 943 865 ctx := mockRegistryContext(client, "myapp", "did:web:hold.example.com", "did:plc:test123", "test.handle", db) 944 866 store := NewManifestStore(ctx, nil) 945 867
+381
pkg/atproto/cbor_gen.go
··· 1419 1419 1420 1420 return nil 1421 1421 } 1422 + func (t *StatsRecord) MarshalCBOR(w io.Writer) error { 1423 + if t == nil { 1424 + _, err := w.Write(cbg.CborNull) 1425 + return err 1426 + } 1427 + 1428 + cw := cbg.NewCborWriter(w) 1429 + fieldCount := 8 1430 + 1431 + if t.LastPull == "" { 1432 + fieldCount-- 1433 + } 1434 + 1435 + if t.LastPush == "" { 1436 + fieldCount-- 1437 + } 1438 + 1439 + if _, err := cw.Write(cbg.CborEncodeMajorType(cbg.MajMap, uint64(fieldCount))); err != nil { 1440 + return err 1441 + } 1442 + 1443 + // t.Type (string) (string) 1444 + if len("$type") > 8192 { 1445 + return xerrors.Errorf("Value in field \"$type\" was too long") 1446 + } 1447 + 1448 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("$type"))); err != nil { 1449 + return err 1450 + } 1451 + if _, err := cw.WriteString(string("$type")); err != nil { 1452 + return err 1453 + } 1454 + 1455 + if len(t.Type) > 8192 { 1456 + return xerrors.Errorf("Value in field t.Type was too long") 1457 + } 1458 + 1459 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.Type))); err != nil { 1460 + return err 1461 + } 1462 + if _, err := cw.WriteString(string(t.Type)); err != nil { 1463 + return err 1464 + } 1465 + 1466 + // t.LastPull (string) (string) 1467 + if t.LastPull != "" { 1468 + 1469 + if len("lastPull") > 8192 { 1470 + return xerrors.Errorf("Value in field \"lastPull\" was too long") 1471 + } 1472 + 1473 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("lastPull"))); err != nil { 1474 + return err 1475 + } 1476 + if _, err := cw.WriteString(string("lastPull")); err != nil { 1477 + return err 1478 + } 1479 + 1480 + if len(t.LastPull) > 8192 { 1481 + return xerrors.Errorf("Value in field t.LastPull was too long") 1482 + } 1483 + 1484 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.LastPull))); err != nil { 1485 + return err 1486 + } 1487 + if _, err := cw.WriteString(string(t.LastPull)); err != nil { 1488 + 
return err 1489 + } 1490 + } 1491 + 1492 + // t.LastPush (string) (string) 1493 + if t.LastPush != "" { 1494 + 1495 + if len("lastPush") > 8192 { 1496 + return xerrors.Errorf("Value in field \"lastPush\" was too long") 1497 + } 1498 + 1499 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("lastPush"))); err != nil { 1500 + return err 1501 + } 1502 + if _, err := cw.WriteString(string("lastPush")); err != nil { 1503 + return err 1504 + } 1505 + 1506 + if len(t.LastPush) > 8192 { 1507 + return xerrors.Errorf("Value in field t.LastPush was too long") 1508 + } 1509 + 1510 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.LastPush))); err != nil { 1511 + return err 1512 + } 1513 + if _, err := cw.WriteString(string(t.LastPush)); err != nil { 1514 + return err 1515 + } 1516 + } 1517 + 1518 + // t.OwnerDID (string) (string) 1519 + if len("ownerDid") > 8192 { 1520 + return xerrors.Errorf("Value in field \"ownerDid\" was too long") 1521 + } 1522 + 1523 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("ownerDid"))); err != nil { 1524 + return err 1525 + } 1526 + if _, err := cw.WriteString(string("ownerDid")); err != nil { 1527 + return err 1528 + } 1529 + 1530 + if len(t.OwnerDID) > 8192 { 1531 + return xerrors.Errorf("Value in field t.OwnerDID was too long") 1532 + } 1533 + 1534 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.OwnerDID))); err != nil { 1535 + return err 1536 + } 1537 + if _, err := cw.WriteString(string(t.OwnerDID)); err != nil { 1538 + return err 1539 + } 1540 + 1541 + // t.PullCount (int64) (int64) 1542 + if len("pullCount") > 8192 { 1543 + return xerrors.Errorf("Value in field \"pullCount\" was too long") 1544 + } 1545 + 1546 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("pullCount"))); err != nil { 1547 + return err 1548 + } 1549 + if _, err := cw.WriteString(string("pullCount")); err != nil { 1550 + return err 1551 + } 1552 + 1553 + if t.PullCount >= 0 { 1554 + 
if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(t.PullCount)); err != nil { 1555 + return err 1556 + } 1557 + } else { 1558 + if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-t.PullCount-1)); err != nil { 1559 + return err 1560 + } 1561 + } 1562 + 1563 + // t.PushCount (int64) (int64) 1564 + if len("pushCount") > 8192 { 1565 + return xerrors.Errorf("Value in field \"pushCount\" was too long") 1566 + } 1567 + 1568 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("pushCount"))); err != nil { 1569 + return err 1570 + } 1571 + if _, err := cw.WriteString(string("pushCount")); err != nil { 1572 + return err 1573 + } 1574 + 1575 + if t.PushCount >= 0 { 1576 + if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(t.PushCount)); err != nil { 1577 + return err 1578 + } 1579 + } else { 1580 + if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-t.PushCount-1)); err != nil { 1581 + return err 1582 + } 1583 + } 1584 + 1585 + // t.UpdatedAt (string) (string) 1586 + if len("updatedAt") > 8192 { 1587 + return xerrors.Errorf("Value in field \"updatedAt\" was too long") 1588 + } 1589 + 1590 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("updatedAt"))); err != nil { 1591 + return err 1592 + } 1593 + if _, err := cw.WriteString(string("updatedAt")); err != nil { 1594 + return err 1595 + } 1596 + 1597 + if len(t.UpdatedAt) > 8192 { 1598 + return xerrors.Errorf("Value in field t.UpdatedAt was too long") 1599 + } 1600 + 1601 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.UpdatedAt))); err != nil { 1602 + return err 1603 + } 1604 + if _, err := cw.WriteString(string(t.UpdatedAt)); err != nil { 1605 + return err 1606 + } 1607 + 1608 + // t.Repository (string) (string) 1609 + if len("repository") > 8192 { 1610 + return xerrors.Errorf("Value in field \"repository\" was too long") 1611 + } 1612 + 1613 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("repository"))); err != nil 
{ 1614 + return err 1615 + } 1616 + if _, err := cw.WriteString(string("repository")); err != nil { 1617 + return err 1618 + } 1619 + 1620 + if len(t.Repository) > 8192 { 1621 + return xerrors.Errorf("Value in field t.Repository was too long") 1622 + } 1623 + 1624 + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.Repository))); err != nil { 1625 + return err 1626 + } 1627 + if _, err := cw.WriteString(string(t.Repository)); err != nil { 1628 + return err 1629 + } 1630 + return nil 1631 + } 1632 + 1633 + func (t *StatsRecord) UnmarshalCBOR(r io.Reader) (err error) { 1634 + *t = StatsRecord{} 1635 + 1636 + cr := cbg.NewCborReader(r) 1637 + 1638 + maj, extra, err := cr.ReadHeader() 1639 + if err != nil { 1640 + return err 1641 + } 1642 + defer func() { 1643 + if err == io.EOF { 1644 + err = io.ErrUnexpectedEOF 1645 + } 1646 + }() 1647 + 1648 + if maj != cbg.MajMap { 1649 + return fmt.Errorf("cbor input should be of type map") 1650 + } 1651 + 1652 + if extra > cbg.MaxLength { 1653 + return fmt.Errorf("StatsRecord: map struct too large (%d)", extra) 1654 + } 1655 + 1656 + n := extra 1657 + 1658 + nameBuf := make([]byte, 10) 1659 + for i := uint64(0); i < n; i++ { 1660 + nameLen, ok, err := cbg.ReadFullStringIntoBuf(cr, nameBuf, 8192) 1661 + if err != nil { 1662 + return err 1663 + } 1664 + 1665 + if !ok { 1666 + // Field doesn't exist on this type, so ignore it 1667 + if err := cbg.ScanForLinks(cr, func(cid.Cid) {}); err != nil { 1668 + return err 1669 + } 1670 + continue 1671 + } 1672 + 1673 + switch string(nameBuf[:nameLen]) { 1674 + // t.Type (string) (string) 1675 + case "$type": 1676 + 1677 + { 1678 + sval, err := cbg.ReadStringWithMax(cr, 8192) 1679 + if err != nil { 1680 + return err 1681 + } 1682 + 1683 + t.Type = string(sval) 1684 + } 1685 + // t.LastPull (string) (string) 1686 + case "lastPull": 1687 + 1688 + { 1689 + sval, err := cbg.ReadStringWithMax(cr, 8192) 1690 + if err != nil { 1691 + return err 1692 + } 1693 + 1694 + t.LastPull = 
string(sval) 1695 + } 1696 + // t.LastPush (string) (string) 1697 + case "lastPush": 1698 + 1699 + { 1700 + sval, err := cbg.ReadStringWithMax(cr, 8192) 1701 + if err != nil { 1702 + return err 1703 + } 1704 + 1705 + t.LastPush = string(sval) 1706 + } 1707 + // t.OwnerDID (string) (string) 1708 + case "ownerDid": 1709 + 1710 + { 1711 + sval, err := cbg.ReadStringWithMax(cr, 8192) 1712 + if err != nil { 1713 + return err 1714 + } 1715 + 1716 + t.OwnerDID = string(sval) 1717 + } 1718 + // t.PullCount (int64) (int64) 1719 + case "pullCount": 1720 + { 1721 + maj, extra, err := cr.ReadHeader() 1722 + if err != nil { 1723 + return err 1724 + } 1725 + var extraI int64 1726 + switch maj { 1727 + case cbg.MajUnsignedInt: 1728 + extraI = int64(extra) 1729 + if extraI < 0 { 1730 + return fmt.Errorf("int64 positive overflow") 1731 + } 1732 + case cbg.MajNegativeInt: 1733 + extraI = int64(extra) 1734 + if extraI < 0 { 1735 + return fmt.Errorf("int64 negative overflow") 1736 + } 1737 + extraI = -1 - extraI 1738 + default: 1739 + return fmt.Errorf("wrong type for int64 field: %d", maj) 1740 + } 1741 + 1742 + t.PullCount = int64(extraI) 1743 + } 1744 + // t.PushCount (int64) (int64) 1745 + case "pushCount": 1746 + { 1747 + maj, extra, err := cr.ReadHeader() 1748 + if err != nil { 1749 + return err 1750 + } 1751 + var extraI int64 1752 + switch maj { 1753 + case cbg.MajUnsignedInt: 1754 + extraI = int64(extra) 1755 + if extraI < 0 { 1756 + return fmt.Errorf("int64 positive overflow") 1757 + } 1758 + case cbg.MajNegativeInt: 1759 + extraI = int64(extra) 1760 + if extraI < 0 { 1761 + return fmt.Errorf("int64 negative overflow") 1762 + } 1763 + extraI = -1 - extraI 1764 + default: 1765 + return fmt.Errorf("wrong type for int64 field: %d", maj) 1766 + } 1767 + 1768 + t.PushCount = int64(extraI) 1769 + } 1770 + // t.UpdatedAt (string) (string) 1771 + case "updatedAt": 1772 + 1773 + { 1774 + sval, err := cbg.ReadStringWithMax(cr, 8192) 1775 + if err != nil { 1776 + return err 1777 + } 
1778 + 1779 + t.UpdatedAt = string(sval) 1780 + } 1781 + // t.Repository (string) (string) 1782 + case "repository": 1783 + 1784 + { 1785 + sval, err := cbg.ReadStringWithMax(cr, 8192) 1786 + if err != nil { 1787 + return err 1788 + } 1789 + 1790 + t.Repository = string(sval) 1791 + } 1792 + 1793 + default: 1794 + // Field doesn't exist on this type, so ignore it 1795 + if err := cbg.ScanForLinks(r, func(cid.Cid) {}); err != nil { 1796 + return err 1797 + } 1798 + } 1799 + } 1800 + 1801 + return nil 1802 + }
+6
pkg/atproto/endpoints.go
··· 45 45 // Request: {"repository": "...", "tag": "...", "userDid": "...", "userHandle": "...", "operation": "push"|"pull", "manifest": {...}} 46 46 // Response: {"success": true, "operation": "push", "statsUpdated": true, "layersCreated": 5, "postCreated": true, "postUri": "at://..."} 47 47 HoldNotifyManifest = "/xrpc/io.atcr.hold.notifyManifest" 48 + 49 + // HoldSetStats sets absolute stats values for a repository (used by migration). 50 + // Method: POST 51 + // Request: {"ownerDid": "...", "repository": "...", "pullCount": 10, "pushCount": 5, "lastPull": "...", "lastPush": "..."} 52 + // Response: {"success": true} 53 + HoldSetStats = "/xrpc/io.atcr.hold.setStats" 48 54 ) 49 55 50 56 // Hold service crew management endpoints (io.atcr.hold.*)
+2 -1
pkg/atproto/generate.go
··· 25 25 ) 26 26 27 27 func main() { 28 - // Generate map-style encoders for CrewRecord, CaptainRecord, LayerRecord, and TangledProfileRecord 28 + // Generate map-style encoders 29 29 if err := cbg.WriteMapEncodersToFile("cbor_gen.go", "atproto", 30 30 atproto.CrewRecord{}, 31 31 atproto.CaptainRecord{}, 32 32 atproto.LayerRecord{}, 33 33 atproto.TangledProfileRecord{}, 34 + atproto.StatsRecord{}, 34 35 ); err != nil { 35 36 fmt.Printf("Failed to generate CBOR encoders: %v\n", err) 36 37 os.Exit(1)
+45
pkg/atproto/lexicon.go
··· 3 3 //go:generate go run generate.go 4 4 5 5 import ( 6 + "crypto/sha256" 7 + "encoding/base32" 6 8 "encoding/base64" 7 9 "encoding/json" 8 10 "fmt" ··· 34 36 // LayerCollection is the collection name for container layer metadata 35 37 // Stored in hold's embedded PDS to track which layers are stored 36 38 LayerCollection = "io.atcr.hold.layer" 39 + 40 + // StatsCollection is the collection name for repository statistics 41 + // Stored in hold's embedded PDS to track pull/push counts per owner+repo 42 + StatsCollection = "io.atcr.hold.stats" 37 43 38 44 // TangledProfileCollection is the collection name for tangled profiles 39 45 // Stored in hold's embedded PDS (singleton record at rkey "self") ··· 618 624 UserHandle: userHandle, 619 625 CreatedAt: time.Now().Format(time.RFC3339), 620 626 } 627 + } 628 + 629 + // StatsRecord represents repository statistics stored in the hold's PDS 630 + // Collection: io.atcr.hold.stats 631 + // Stored in the hold's embedded PDS for tracking manifest pull/push counts 632 + // Uses CBOR encoding for efficient storage in hold's carstore 633 + // RKey is deterministic: base32(sha256(ownerDID + "/" + repository)[:16]) 634 + type StatsRecord struct { 635 + Type string `json:"$type" cborgen:"$type"` 636 + OwnerDID string `json:"ownerDid" cborgen:"ownerDid"` // DID of the image owner (e.g., "did:plc:xyz123") 637 + Repository string `json:"repository" cborgen:"repository"` // Repository name (e.g., "myapp") 638 + PullCount int64 `json:"pullCount" cborgen:"pullCount"` // Number of manifest downloads 639 + LastPull string `json:"lastPull,omitempty" cborgen:"lastPull,omitempty"` 640 + PushCount int64 `json:"pushCount" cborgen:"pushCount"` // Number of manifest uploads 641 + LastPush string `json:"lastPush,omitempty" cborgen:"lastPush,omitempty"` 642 + UpdatedAt string `json:"updatedAt" cborgen:"updatedAt"` // RFC3339 timestamp 643 + } 644 + 645 + // NewStatsRecord creates a new stats record 646 + func NewStatsRecord(ownerDID, repository 
string) *StatsRecord { 647 + return &StatsRecord{ 648 + Type: StatsCollection, 649 + OwnerDID: ownerDID, 650 + Repository: repository, 651 + PullCount: 0, 652 + PushCount: 0, 653 + UpdatedAt: time.Now().Format(time.RFC3339), 654 + } 655 + } 656 + 657 + // StatsRecordKey generates a deterministic record key for stats 658 + // Uses base32 encoding of first 16 bytes of SHA-256 hash of "ownerDID/repository" 659 + // This ensures same owner+repo always maps to same rkey 660 + func StatsRecordKey(ownerDID, repository string) string { 661 + combined := ownerDID + "/" + repository 662 + hash := sha256.Sum256([]byte(combined)) 663 + // Use first 16 bytes (128 bits) for collision resistance 664 + // Encode with base32 (alphanumeric, lowercase, no padding) for ATProto rkey compatibility 665 + return strings.ToLower(base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(hash[:16])) 621 666 } 622 667 623 668 // TangledProfileRecord represents a Tangled profile for the hold
+159 -75
pkg/hold/oci/xrpc.go
··· 50 50 r.Post(atproto.HoldCompleteUpload, h.HandleCompleteUpload) 51 51 r.Post(atproto.HoldAbortUpload, h.HandleAbortUpload) 52 52 r.Post(atproto.HoldNotifyManifest, h.HandleNotifyManifest) 53 + r.Post(atproto.HoldSetStats, h.HandleSetStats) 53 54 }) 54 55 } 55 56 ··· 201 202 }) 202 203 } 203 204 204 - // HandleNotifyManifest handles manifest upload notifications from AppView 205 - // Creates layer records and optionally posts to Bluesky 205 + // HandleNotifyManifest handles manifest notifications from AppView 206 + // For pushes: Creates layer records and optionally posts to Bluesky 207 + // For pulls: Just increments stats (no layer records or posts) 208 + // Always increments stats (pull or push counts) 206 209 func (h *XRPCHandler) HandleNotifyManifest(w http.ResponseWriter, r *http.Request) { 207 210 ctx := r.Context() 208 211 ··· 219 222 Tag string `json:"tag"` 220 223 UserDID string `json:"userDid"` 221 224 UserHandle string `json:"userHandle"` 225 + Operation string `json:"operation"` // "push" or "pull", defaults to "push" for backward compatibility 222 226 Manifest struct { 223 227 MediaType string `json:"mediaType"` 224 228 Config struct { ··· 253 257 return 254 258 } 255 259 256 - // Check if manifest posts are enabled 257 - // Read from captain record (which is synced with HOLD_BLUESKY_POSTS_ENABLED env var) 258 - postsEnabled := false 259 - _, captain, err := h.pds.GetCaptainRecord(ctx) 260 - if err == nil { 261 - postsEnabled = captain.EnableBlueskyPosts 262 - } else { 263 - // Fallback to env var if captain record doesn't exist (shouldn't happen in normal operation) 264 - postsEnabled = h.enableBlueskyPosts 260 + // Default operation to "push" for backward compatibility 261 + operation := req.Operation 262 + if operation == "" { 263 + operation = "push" 265 264 } 266 265 267 - // Create layer records for each blob 268 - layersCreated := 0 269 - for _, layer := range req.Manifest.Layers { 270 - record := atproto.NewLayerRecord( 271 - layer.Digest, 
272 - layer.Size, 273 - layer.MediaType, 274 - req.Repository, 275 - req.UserDID, 276 - req.UserHandle, 277 - ) 266 + // Validate operation 267 + if operation != "push" && operation != "pull" { 268 + RespondError(w, http.StatusBadRequest, fmt.Sprintf("invalid operation: %s (must be 'push' or 'pull')", operation)) 269 + return 270 + } 271 + 272 + var layersCreated int 273 + var postCreated bool 274 + var postURI string 278 275 279 - _, _, err := h.pds.CreateLayerRecord(ctx, record) 280 - if err != nil { 281 - slog.Error("Failed to create layer record", "error", err) 282 - // Continue creating other records 276 + // Only create layer records and Bluesky posts for pushes 277 + if operation == "push" { 278 + // Check if manifest posts are enabled 279 + // Read from captain record (which is synced with HOLD_BLUESKY_POSTS_ENABLED env var) 280 + postsEnabled := false 281 + _, captain, err := h.pds.GetCaptainRecord(ctx) 282 + if err == nil { 283 + postsEnabled = captain.EnableBlueskyPosts 283 284 } else { 284 - layersCreated++ 285 + // Fallback to env var if captain record doesn't exist (shouldn't happen in normal operation) 286 + postsEnabled = h.enableBlueskyPosts 285 287 } 286 - } 287 288 288 - // Check if this is a multi-arch image (has manifests instead of layers) 289 - isMultiArch := len(req.Manifest.Manifests) > 0 289 + // Create layer records for each blob 290 + for _, layer := range req.Manifest.Layers { 291 + record := atproto.NewLayerRecord( 292 + layer.Digest, 293 + layer.Size, 294 + layer.MediaType, 295 + req.Repository, 296 + req.UserDID, 297 + req.UserHandle, 298 + ) 290 299 291 - // Calculate total size from all layers (for single-arch images) 292 - var totalSize int64 293 - for _, layer := range req.Manifest.Layers { 294 - totalSize += layer.Size 295 - } 296 - totalSize += req.Manifest.Config.Size // Add config blob size 297 - 298 - // Extract platforms for multi-arch images 299 - var platforms []string 300 - if isMultiArch { 301 - for _, m := range 
req.Manifest.Manifests { 302 - if m.Platform != nil { 303 - platforms = append(platforms, m.Platform.OS+"/"+m.Platform.Architecture) 300 + _, _, err := h.pds.CreateLayerRecord(ctx, record) 301 + if err != nil { 302 + slog.Error("Failed to create layer record", "error", err) 303 + // Continue creating other records 304 + } else { 305 + layersCreated++ 304 306 } 305 307 } 306 - } 308 + 309 + // Check if this is a multi-arch image (has manifests instead of layers) 310 + isMultiArch := len(req.Manifest.Manifests) > 0 307 311 308 - // Create Bluesky post if enabled 309 - var postURI string 310 - postCreated := false 311 - if postsEnabled { 312 - // Extract manifest digest from first layer (or use config digest as fallback) 313 - manifestDigest := req.Manifest.Config.Digest 314 - if len(req.Manifest.Layers) > 0 { 315 - manifestDigest = req.Manifest.Layers[0].Digest 312 + // Calculate total size from all layers (for single-arch images) 313 + var totalSize int64 314 + for _, layer := range req.Manifest.Layers { 315 + totalSize += layer.Size 316 316 } 317 + totalSize += req.Manifest.Config.Size // Add config blob size 317 318 318 - postURI, err = h.pds.CreateManifestPost( 319 - ctx, 320 - h.driver, 321 - req.Repository, 322 - req.Tag, 323 - req.UserHandle, 324 - req.UserDID, 325 - manifestDigest, 326 - totalSize, 327 - platforms, 328 - ) 329 - if err != nil { 330 - slog.Error("Failed to create manifest post", "error", err) 331 - } else { 332 - postCreated = true 319 + // Extract platforms for multi-arch images 320 + var platforms []string 321 + if isMultiArch { 322 + for _, m := range req.Manifest.Manifests { 323 + if m.Platform != nil { 324 + platforms = append(platforms, m.Platform.OS+"/"+m.Platform.Architecture) 325 + } 326 + } 327 + } 328 + 329 + // Create Bluesky post if enabled 330 + if postsEnabled { 331 + // Extract manifest digest from first layer (or use config digest as fallback) 332 + manifestDigest := req.Manifest.Config.Digest 333 + if len(req.Manifest.Layers) 
> 0 { 334 + manifestDigest = req.Manifest.Layers[0].Digest 335 + } 336 + 337 + postURI, err = h.pds.CreateManifestPost( 338 + ctx, 339 + h.driver, 340 + req.Repository, 341 + req.Tag, 342 + req.UserHandle, 343 + req.UserDID, 344 + manifestDigest, 345 + totalSize, 346 + platforms, 347 + ) 348 + if err != nil { 349 + slog.Error("Failed to create manifest post", "error", err) 350 + } else { 351 + postCreated = true 352 + } 333 353 } 354 + } 355 + 356 + // ALWAYS increment stats (even if Bluesky posts disabled, even for pulls) 357 + statsUpdated := false 358 + if err := h.pds.IncrementStats(ctx, req.UserDID, req.Repository, operation); err != nil { 359 + slog.Error("Failed to increment stats", "operation", operation, "error", err) 360 + } else { 361 + statsUpdated = true 334 362 } 335 363 336 364 // Return response 337 365 resp := map[string]any{ 338 - "success": layersCreated > 0 || postCreated, 339 - "layersCreated": layersCreated, 340 - "postCreated": postCreated, 366 + "success": statsUpdated || layersCreated > 0 || postCreated, 367 + "operation": operation, 368 + "statsUpdated": statsUpdated, 369 + } 370 + 371 + // Only include push-specific fields for push operations 372 + if operation == "push" { 373 + resp["layersCreated"] = layersCreated 374 + resp["postCreated"] = postCreated 375 + if postURI != "" { 376 + resp["postUri"] = postURI 377 + } 378 + } 379 + 380 + RespondJSON(w, http.StatusOK, resp) 381 + } 382 + 383 + // HandleSetStats sets absolute stats values for a repository (used by migration) 384 + // This is a migration-only endpoint that allows AppView to sync existing stats to holds 385 + func (h *XRPCHandler) HandleSetStats(w http.ResponseWriter, r *http.Request) { 386 + ctx := r.Context() 387 + 388 + // Validate service token (same auth as blob:write endpoints) 389 + validatedUser, err := pds.ValidateBlobWriteAccess(r, h.pds, h.httpClient) 390 + if err != nil { 391 + RespondError(w, http.StatusForbidden, fmt.Sprintf("authorization failed: %v", err)) 
392 + return 341 393 } 342 - if postURI != "" { 343 - resp["postUri"] = postURI 394 + 395 + // Parse request 396 + var req struct { 397 + OwnerDID string `json:"ownerDid"` 398 + Repository string `json:"repository"` 399 + PullCount int64 `json:"pullCount"` 400 + PushCount int64 `json:"pushCount"` 401 + LastPull string `json:"lastPull,omitempty"` 402 + LastPush string `json:"lastPush,omitempty"` 344 403 } 345 - if err != nil && layersCreated == 0 && !postCreated { 346 - resp["error"] = err.Error() 404 + 405 + if err := DecodeJSON(r, &req); err != nil { 406 + RespondError(w, http.StatusBadRequest, err.Error()) 407 + return 347 408 } 348 409 349 - RespondJSON(w, http.StatusOK, resp) 410 + // Verify user DID matches token (user can only set stats for their own repos) 411 + if req.OwnerDID != validatedUser.DID { 412 + RespondError(w, http.StatusForbidden, "owner DID mismatch") 413 + return 414 + } 415 + 416 + // Validate required fields 417 + if req.OwnerDID == "" || req.Repository == "" { 418 + RespondError(w, http.StatusBadRequest, "ownerDid and repository are required") 419 + return 420 + } 421 + 422 + // Set stats using the SetStats method 423 + if err := h.pds.SetStats(ctx, req.OwnerDID, req.Repository, req.PullCount, req.PushCount, req.LastPull, req.LastPush); err != nil { 424 + slog.Error("Failed to set stats", "error", err) 425 + RespondError(w, http.StatusInternalServerError, fmt.Sprintf("failed to set stats: %v", err)) 426 + return 427 + } 428 + 429 + slog.Info("Stats set via migration", "owner_did", req.OwnerDID, "repository", req.Repository, "pull_count", req.PullCount, "push_count", req.PushCount) 430 + 431 + RespondJSON(w, http.StatusOK, map[string]any{ 432 + "success": true, 433 + }) 350 434 } 351 435 352 436 // requireBlobWriteAccess middleware - validates DPoP + OAuth and checks for blob:write permission
+2 -1
pkg/hold/pds/server.go
··· 22 22 // init registers our custom ATProto types with indigo's lexutil type registry 23 23 // This allows repomgr.GetRecord to automatically unmarshal our types 24 24 func init() { 25 - // Register captain, crew, tangled profile, and layer record types 25 + // Register captain, crew, tangled profile, layer, and stats record types 26 26 // These must match the $type field in the records 27 27 lexutil.RegisterType(atproto.CaptainCollection, &atproto.CaptainRecord{}) 28 28 lexutil.RegisterType(atproto.CrewCollection, &atproto.CrewRecord{}) 29 29 lexutil.RegisterType(atproto.LayerCollection, &atproto.LayerRecord{}) 30 30 lexutil.RegisterType(atproto.TangledProfileCollection, &atproto.TangledProfileRecord{}) 31 + lexutil.RegisterType(atproto.StatsCollection, &atproto.StatsRecord{}) 31 32 } 32 33 33 34 // HoldPDS is a minimal ATProto PDS implementation for a hold service
+218
pkg/hold/pds/stats.go
··· 1 + package pds 2 + 3 + import ( 4 + "bytes" 5 + "context" 6 + "errors" 7 + "fmt" 8 + "log/slog" 9 + "strings" 10 + "time" 11 + 12 + "atcr.io/pkg/atproto" 13 + "github.com/bluesky-social/indigo/repo" 14 + "github.com/ipfs/go-cid" 15 + ) 16 + 17 + // IncrementStats increments the pull or push count for a repository 18 + // operation should be "pull" or "push" 19 + // Creates a new record if none exists, updates existing record otherwise 20 + func (p *HoldPDS) IncrementStats(ctx context.Context, ownerDID, repository, operation string) error { 21 + if operation != "pull" && operation != "push" { 22 + return fmt.Errorf("invalid operation: %s (must be 'pull' or 'push')", operation) 23 + } 24 + 25 + rkey := atproto.StatsRecordKey(ownerDID, repository) 26 + now := time.Now().Format(time.RFC3339) 27 + 28 + // Try to get existing record 29 + _, existing, err := p.GetStats(ctx, ownerDID, repository) 30 + if err != nil { 31 + // Record doesn't exist - create new one 32 + record := atproto.NewStatsRecord(ownerDID, repository) 33 + if operation == "pull" { 34 + record.PullCount = 1 35 + record.LastPull = now 36 + } else { 37 + record.PushCount = 1 38 + record.LastPush = now 39 + } 40 + record.UpdatedAt = now 41 + 42 + _, _, err := p.repomgr.PutRecord(ctx, p.uid, atproto.StatsCollection, rkey, record) 43 + if err != nil { 44 + return fmt.Errorf("failed to create stats record: %w", err) 45 + } 46 + 47 + slog.Debug("Created stats record", 48 + "ownerDID", ownerDID, 49 + "repository", repository, 50 + "operation", operation) 51 + return nil 52 + } 53 + 54 + // Record exists - update it 55 + if operation == "pull" { 56 + existing.PullCount++ 57 + existing.LastPull = now 58 + } else { 59 + existing.PushCount++ 60 + existing.LastPush = now 61 + } 62 + existing.UpdatedAt = now 63 + 64 + _, err = p.repomgr.UpdateRecord(ctx, p.uid, atproto.StatsCollection, rkey, existing) 65 + if err != nil { 66 + return fmt.Errorf("failed to update stats record: %w", err) 67 + } 68 + 69 + 
slog.Debug("Updated stats record", 70 + "ownerDID", ownerDID, 71 + "repository", repository, 72 + "operation", operation, 73 + "pullCount", existing.PullCount, 74 + "pushCount", existing.PushCount) 75 + return nil 76 + } 77 + 78 + // GetStats retrieves the stats record for a repository 79 + // Returns nil, nil if no stats record exists 80 + func (p *HoldPDS) GetStats(ctx context.Context, ownerDID, repository string) (cid.Cid, *atproto.StatsRecord, error) { 81 + rkey := atproto.StatsRecordKey(ownerDID, repository) 82 + 83 + recordCID, val, err := p.repomgr.GetRecord(ctx, p.uid, atproto.StatsCollection, rkey, cid.Undef) 84 + if err != nil { 85 + return cid.Undef, nil, err 86 + } 87 + 88 + statsRecord, ok := val.(*atproto.StatsRecord) 89 + if !ok { 90 + return cid.Undef, nil, fmt.Errorf("unexpected type for stats record: %T", val) 91 + } 92 + 93 + return recordCID, statsRecord, nil 94 + } 95 + 96 + // SetStats directly sets the stats for a repository (used for migration) 97 + // Creates or updates the stats record with the specified counts 98 + func (p *HoldPDS) SetStats(ctx context.Context, ownerDID, repository string, pullCount, pushCount int64, lastPull, lastPush string) error { 99 + rkey := atproto.StatsRecordKey(ownerDID, repository) 100 + now := time.Now().Format(time.RFC3339) 101 + 102 + // Try to get existing record 103 + _, existing, err := p.GetStats(ctx, ownerDID, repository) 104 + if err != nil { 105 + // Record doesn't exist - create new one 106 + record := &atproto.StatsRecord{ 107 + Type: atproto.StatsCollection, 108 + OwnerDID: ownerDID, 109 + Repository: repository, 110 + PullCount: pullCount, 111 + PushCount: pushCount, 112 + LastPull: lastPull, 113 + LastPush: lastPush, 114 + UpdatedAt: now, 115 + } 116 + 117 + _, _, err := p.repomgr.PutRecord(ctx, p.uid, atproto.StatsCollection, rkey, record) 118 + if err != nil { 119 + return fmt.Errorf("failed to create stats record: %w", err) 120 + } 121 + return nil 122 + } 123 + 124 + // Record exists - update 
it 125 + existing.PullCount = pullCount 126 + existing.PushCount = pushCount 127 + existing.LastPull = lastPull 128 + existing.LastPush = lastPush 129 + existing.UpdatedAt = now 130 + 131 + _, err = p.repomgr.UpdateRecord(ctx, p.uid, atproto.StatsCollection, rkey, existing) 132 + if err != nil { 133 + return fmt.Errorf("failed to update stats record: %w", err) 134 + } 135 + 136 + return nil 137 + } 138 + 139 + // ListStats returns all stats records in the hold's PDS 140 + // This is used by AppView to aggregate stats from all holds 141 + func (p *HoldPDS) ListStats(ctx context.Context) ([]*atproto.StatsRecord, error) { 142 + // Get read-only session from carstore 143 + session, err := p.carstore.ReadOnlySession(p.uid) 144 + if err != nil { 145 + return nil, fmt.Errorf("failed to get read-only session: %w", err) 146 + } 147 + 148 + // Get repo head 149 + head, err := p.carstore.GetUserRepoHead(ctx, p.uid) 150 + if err != nil { 151 + return nil, fmt.Errorf("failed to get repo head: %w", err) 152 + } 153 + 154 + if !head.Defined() { 155 + // No repo yet, return empty list 156 + return []*atproto.StatsRecord{}, nil 157 + } 158 + 159 + // Open repo 160 + r, err := repo.OpenRepo(ctx, session, head) 161 + if err != nil { 162 + return nil, fmt.Errorf("failed to open repo: %w", err) 163 + } 164 + 165 + var stats []*atproto.StatsRecord 166 + 167 + // Iterate over all stats records 168 + err = r.ForEach(ctx, atproto.StatsCollection, func(k string, v cid.Cid) error { 169 + // Extract collection and rkey from full path (k is like "io.atcr.hold.stats/abcd1234...") 170 + parts := strings.Split(k, "/") 171 + if len(parts) < 2 { 172 + return nil // Skip invalid keys 173 + } 174 + 175 + // Extract actual collection 176 + actualCollection := strings.Join(parts[:len(parts)-1], "/") 177 + 178 + // MST keys are sorted, so once we hit a different collection, stop walking 179 + if actualCollection != atproto.StatsCollection { 180 + return repo.ErrDoneIterating 181 + } 182 + 183 + // Get 
record bytes 184 + _, recBytes, err := r.GetRecordBytes(ctx, k) 185 + if err != nil { 186 + slog.Warn("Failed to get stats record bytes", "key", k, "error", err) 187 + return nil // Continue with other records 188 + } 189 + 190 + if recBytes == nil { 191 + return nil 192 + } 193 + 194 + // Unmarshal the CBOR bytes 195 + var statsRecord atproto.StatsRecord 196 + if err := statsRecord.UnmarshalCBOR(bytes.NewReader(*recBytes)); err != nil { 197 + slog.Warn("Failed to unmarshal stats record", "key", k, "error", err) 198 + return nil // Continue with other records 199 + } 200 + 201 + stats = append(stats, &statsRecord) 202 + return nil 203 + }) 204 + 205 + if err != nil { 206 + // ErrDoneIterating is expected when we stop walking early 207 + if errors.Is(err, repo.ErrDoneIterating) { 208 + // Successfully stopped at collection boundary 209 + } else if strings.Contains(err.Error(), "not found") { 210 + // Collection doesn't exist yet - return empty list 211 + return []*atproto.StatsRecord{}, nil 212 + } else { 213 + return nil, fmt.Errorf("failed to iterate stats records: %w", err) 214 + } 215 + } 216 + 217 + return stats, nil 218 + }