A container registry that uses the AT Protocol for manifest storage and S3 for blob storage. atcr.io
docker container atproto go

Update scanner, fix tests, fix Dockerfile, and move signing keys into the database (instead of flat files) for the appview

evan.jarrett.net 49b914ba 2df53775

verified
+685 -88
+1 -1
Dockerfile.hold
··· 19 19 20 20 # Build frontend assets (Tailwind CSS, JS bundle, SVG icons) 21 21 RUN npm ci 22 - RUN npm run css:build && npm run css:copy-hold && npm run js:build:hold && npm run icons:build 22 + RUN go generate ./... 23 23 24 24 # Conditionally add billing tag based on build arg 25 25 RUN if [ "$BILLING_ENABLED" = "true" ]; then \
-1
pkg/appview/config.go
··· 312 312 "service": cfg.Auth.ServiceName, 313 313 "issuer": cfg.Auth.ServiceName, 314 314 "rootcertbundle": cfg.Auth.CertPath, 315 - "privatekey": cfg.Auth.KeyPath, 316 315 "expiration": int(cfg.Auth.TokenExpiration.Seconds()), 317 316 }, 318 317 }
+182
pkg/appview/crypto_keys.go
··· 1 + package appview 2 + 3 + import ( 4 + "crypto/rand" 5 + "crypto/rsa" 6 + "crypto/x509" 7 + "crypto/x509/pkix" 8 + "database/sql" 9 + "encoding/pem" 10 + "fmt" 11 + "log/slog" 12 + "math/big" 13 + "os" 14 + "path/filepath" 15 + "time" 16 + 17 + "atcr.io/pkg/appview/db" 18 + "github.com/bluesky-social/indigo/atproto/atcrypto" 19 + ) 20 + 21 + // loadOAuthKey loads the OAuth P-256 key with priority: DB → file → generate. 22 + // Keys loaded from file or newly generated are stored in the DB. 23 + func loadOAuthKey(database *sql.DB, keyPath string) (*atcrypto.PrivateKeyP256, error) { 24 + // Try database first 25 + data, err := db.GetCryptoKey(database, "oauth_p256") 26 + if err != nil { 27 + return nil, fmt.Errorf("failed to query crypto_keys: %w", err) 28 + } 29 + if data != nil { 30 + key, err := atcrypto.ParsePrivateBytesP256(data) 31 + if err != nil { 32 + return nil, fmt.Errorf("failed to parse OAuth key from database: %w", err) 33 + } 34 + slog.Info("Loaded OAuth P-256 key from database") 35 + return key, nil 36 + } 37 + 38 + // Try file fallback 39 + if keyPath != "" { 40 + if fileData, err := os.ReadFile(keyPath); err == nil { 41 + key, err := atcrypto.ParsePrivateBytesP256(fileData) 42 + if err != nil { 43 + return nil, fmt.Errorf("failed to parse OAuth key from file %s: %w", keyPath, err) 44 + } 45 + // Migrate to database 46 + if err := db.PutCryptoKey(database, "oauth_p256", fileData); err != nil { 47 + return nil, fmt.Errorf("failed to store OAuth key in database: %w", err) 48 + } 49 + slog.Info("Migrated OAuth P-256 key from file to database", "path", keyPath) 50 + return key, nil 51 + } 52 + } 53 + 54 + // Generate new key 55 + p256Key, err := atcrypto.GeneratePrivateKeyP256() 56 + if err != nil { 57 + return nil, fmt.Errorf("failed to generate OAuth P-256 key: %w", err) 58 + } 59 + 60 + keyBytes := p256Key.Bytes() 61 + if err := db.PutCryptoKey(database, "oauth_p256", keyBytes); err != nil { 62 + return nil, fmt.Errorf("failed to store generated 
OAuth key in database: %w", err) 63 + } 64 + slog.Info("Generated new OAuth P-256 key and stored in database") 65 + 66 + return p256Key, nil 67 + } 68 + 69 + // loadJWTKeyAndCert loads the JWT RSA key from DB (with file fallback) and generates 70 + // a self-signed certificate. The cert is always regenerated and written to certPath 71 + // on disk because the distribution library reads it via os.Open(). 72 + func loadJWTKeyAndCert(database *sql.DB, keyPath, certPath string) (*rsa.PrivateKey, []byte, error) { 73 + rsaKey, err := loadRSAKey(database, keyPath) 74 + if err != nil { 75 + return nil, nil, err 76 + } 77 + 78 + // Generate cert and write to disk for distribution library 79 + certDER, err := generateAndWriteCert(rsaKey, certPath) 80 + if err != nil { 81 + return nil, nil, err 82 + } 83 + 84 + return rsaKey, certDER, nil 85 + } 86 + 87 + // loadRSAKey loads the RSA private key with priority: DB → file → generate. 88 + func loadRSAKey(database *sql.DB, keyPath string) (*rsa.PrivateKey, error) { 89 + // Try database first 90 + data, err := db.GetCryptoKey(database, "jwt_rsa") 91 + if err != nil { 92 + return nil, fmt.Errorf("failed to query crypto_keys: %w", err) 93 + } 94 + if data != nil { 95 + key, err := parseRSAKeyPEM(data) 96 + if err != nil { 97 + return nil, fmt.Errorf("failed to parse RSA key from database: %w", err) 98 + } 99 + slog.Info("Loaded JWT RSA key from database") 100 + return key, nil 101 + } 102 + 103 + // Try file fallback 104 + if keyPath != "" { 105 + if fileData, err := os.ReadFile(keyPath); err == nil { 106 + key, err := parseRSAKeyPEM(fileData) 107 + if err != nil { 108 + return nil, fmt.Errorf("failed to parse RSA key from file %s: %w", keyPath, err) 109 + } 110 + // Migrate to database 111 + if err := db.PutCryptoKey(database, "jwt_rsa", fileData); err != nil { 112 + return nil, fmt.Errorf("failed to store RSA key in database: %w", err) 113 + } 114 + slog.Info("Migrated JWT RSA key from file to database", "path", keyPath) 115 + 
return key, nil 116 + } 117 + } 118 + 119 + // Generate new key 120 + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) 121 + if err != nil { 122 + return nil, fmt.Errorf("failed to generate RSA key: %w", err) 123 + } 124 + 125 + keyPEM := pem.EncodeToMemory(&pem.Block{ 126 + Type: "RSA PRIVATE KEY", 127 + Bytes: x509.MarshalPKCS1PrivateKey(rsaKey), 128 + }) 129 + if err := db.PutCryptoKey(database, "jwt_rsa", keyPEM); err != nil { 130 + return nil, fmt.Errorf("failed to store generated RSA key in database: %w", err) 131 + } 132 + slog.Info("Generated new JWT RSA key and stored in database") 133 + 134 + return rsaKey, nil 135 + } 136 + 137 + func parseRSAKeyPEM(data []byte) (*rsa.PrivateKey, error) { 138 + block, _ := pem.Decode(data) 139 + if block == nil || block.Type != "RSA PRIVATE KEY" { 140 + return nil, fmt.Errorf("failed to decode PEM block containing RSA private key") 141 + } 142 + return x509.ParsePKCS1PrivateKey(block.Bytes) 143 + } 144 + 145 + // generateAndWriteCert creates a self-signed certificate from the RSA key and writes 146 + // it to certPath. Returns the DER-encoded certificate bytes for the JWT x5c header. 
147 + func generateAndWriteCert(rsaKey *rsa.PrivateKey, certPath string) ([]byte, error) { 148 + template := x509.Certificate{ 149 + SerialNumber: big.NewInt(1), 150 + Subject: pkix.Name{ 151 + Organization: []string{"ATCR"}, 152 + CommonName: "ATCR Token Signing Certificate", 153 + }, 154 + NotBefore: time.Now(), 155 + NotAfter: time.Now().Add(10 * 365 * 24 * time.Hour), 156 + KeyUsage: x509.KeyUsageDigitalSignature, 157 + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 158 + BasicConstraintsValid: true, 159 + } 160 + 161 + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &rsaKey.PublicKey, rsaKey) 162 + if err != nil { 163 + return nil, fmt.Errorf("failed to create certificate: %w", err) 164 + } 165 + 166 + // Write cert to disk for distribution library 167 + certPEM := pem.EncodeToMemory(&pem.Block{ 168 + Type: "CERTIFICATE", 169 + Bytes: certDER, 170 + }) 171 + 172 + dir := filepath.Dir(certPath) 173 + if err := os.MkdirAll(dir, 0700); err != nil { 174 + return nil, fmt.Errorf("failed to create cert directory: %w", err) 175 + } 176 + if err := os.WriteFile(certPath, certPEM, 0644); err != nil { 177 + return nil, fmt.Errorf("failed to write certificate: %w", err) 178 + } 179 + 180 + slog.Info("Generated JWT signing certificate", "path", certPath) 181 + return certDER, nil 182 + }
+26
pkg/appview/db/crypto_keys.go
··· 1 + package db 2 + 3 + import "database/sql" 4 + 5 + // GetCryptoKey retrieves a key by name from the database. 6 + // Returns nil, nil if no key with that name exists. 7 + func GetCryptoKey(db DBTX, name string) ([]byte, error) { 8 + var data []byte 9 + err := db.QueryRow("SELECT key_data FROM crypto_keys WHERE name = ?", name).Scan(&data) 10 + if err == sql.ErrNoRows { 11 + return nil, nil 12 + } 13 + if err != nil { 14 + return nil, err 15 + } 16 + return data, nil 17 + } 18 + 19 + // PutCryptoKey stores a key in the database, replacing any existing key with the same name. 20 + func PutCryptoKey(db DBTX, name string, data []byte) error { 21 + _, err := db.Exec( 22 + "INSERT INTO crypto_keys (name, key_data) VALUES (?, ?) ON CONFLICT(name) DO UPDATE SET key_data = excluded.key_data", 23 + name, data, 24 + ) 25 + return err 26 + }
+7
pkg/appview/db/migrations/0013_create_crypto_keys.yaml
··· 1 + description: Create crypto_keys table for storing signing keys in the database 2 + query: | 3 + CREATE TABLE IF NOT EXISTS crypto_keys ( 4 + name TEXT PRIMARY KEY, 5 + key_data BLOB NOT NULL, 6 + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP 7 + );
+6
pkg/appview/db/schema.sql
··· 237 237 FOREIGN KEY(did) REFERENCES users(did) ON DELETE CASCADE 238 238 ); 239 239 CREATE INDEX IF NOT EXISTS idx_repo_pages_did ON repo_pages(did); 240 + 241 + CREATE TABLE IF NOT EXISTS crypto_keys ( 242 + name TEXT PRIMARY KEY, 243 + key_data BLOB NOT NULL, 244 + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP 245 + );
+19 -19
pkg/appview/handlers/scan_result_test.go
··· 84 84 85 85 body := rr.Body.String() 86 86 87 - // Should contain severity badges 88 - if !strings.Contains(body, "badge-error") { 89 - t.Error("Expected body to contain badge-error for critical vulnerabilities") 87 + // Should contain vuln-strip severity boxes 88 + if !strings.Contains(body, "vuln-box-critical") { 89 + t.Error("Expected body to contain vuln-box-critical for critical vulnerabilities") 90 90 } 91 - if !strings.Contains(body, "C:2") { 92 - t.Error("Expected body to contain 'C:2' for critical count") 91 + if !strings.Contains(body, `data-tip="Critical">2<`) { 92 + t.Error("Expected critical count of 2") 93 93 } 94 - if !strings.Contains(body, "badge-warning") { 95 - t.Error("Expected body to contain badge-warning for high vulnerabilities") 94 + if !strings.Contains(body, "vuln-box-high") { 95 + t.Error("Expected body to contain vuln-box-high for high vulnerabilities") 96 96 } 97 - if !strings.Contains(body, "H:5") { 98 - t.Error("Expected body to contain 'H:5' for high count") 97 + if !strings.Contains(body, `data-tip="High">5<`) { 98 + t.Error("Expected high count of 5") 99 99 } 100 - if !strings.Contains(body, "M:10") { 101 - t.Error("Expected body to contain 'M:10' for medium count") 100 + if !strings.Contains(body, `data-tip="Medium">10<`) { 101 + t.Error("Expected medium count of 10") 102 102 } 103 - if !strings.Contains(body, "L:3") { 104 - t.Error("Expected body to contain 'L:3' for low count") 103 + if !strings.Contains(body, `data-tip="Low">3<`) { 104 + t.Error("Expected low count of 3") 105 105 } 106 106 // Should be clickable (has openVulnDetails) 107 107 if !strings.Contains(body, "openVulnDetails") { ··· 267 267 268 268 body := rr.Body.String() 269 269 270 - if !strings.Contains(body, "C:3") { 271 - t.Error("Expected body to contain 'C:3'") 270 + if !strings.Contains(body, `data-tip="Critical">3<`) { 271 + t.Error("Expected critical count of 3") 272 272 } 273 273 // Zero-count badges should NOT appear 274 274 if strings.Contains(body, 
"H:0") { ··· 346 346 } 347 347 348 348 // abc123 should have vulnerability badges 349 - if !strings.Contains(body, "C:2") { 350 - t.Error("Expected body to contain 'C:2' for abc123") 349 + if !strings.Contains(body, `data-tip="Critical">2<`) { 350 + t.Error("Expected critical count of 2 for abc123") 351 351 } 352 352 // def456 should have clean badge 353 353 if !strings.Contains(body, "Clean") { ··· 430 430 if !strings.Contains(body, `id="scan-badge-abc123"`) { 431 431 t.Error("Expected OOB span for abc123") 432 432 } 433 - if !strings.Contains(body, "C:1") { 434 - t.Error("Expected body to contain 'C:1'") 433 + if !strings.Contains(body, `data-tip="Critical">1<`) { 434 + t.Error("Expected critical count of 1") 435 435 } 436 436 }
+31 -27
pkg/appview/handlers/vuln_details_test.go
··· 17 17 18 18 // mockGrypeReport returns a minimal Grype JSON report 19 19 func mockGrypeReport() string { 20 + // Grype v0.107+ uses PascalCase JSON keys, and severity is nested under Metadata 20 21 report := map[string]any{ 21 22 "matches": []map[string]any{ 22 23 { 23 - "vulnerability": map[string]any{ 24 - "id": "CVE-2024-1234", 25 - "severity": "Critical", 26 - "fix": map[string]any{"versions": []string{"1.2.4"}, "state": "fixed"}, 24 + "Vulnerability": map[string]any{ 25 + "ID": "CVE-2024-1234", 26 + "Metadata": map[string]any{"Severity": "Critical"}, 27 + "Fix": map[string]any{"Versions": []string{"1.2.4"}, "State": "fixed"}, 27 28 }, 28 - "artifact": map[string]any{ 29 - "name": "libssl", 30 - "version": "1.1.1", 31 - "type": "deb", 29 + "Package": map[string]any{ 30 + "Name": "libssl", 31 + "Version": "1.1.1", 32 + "Type": "deb", 32 33 }, 33 34 }, 34 35 { 35 - "vulnerability": map[string]any{ 36 - "id": "CVE-2024-5678", 37 - "severity": "Low", 38 - "fix": map[string]any{"versions": []string{}, "state": "not-fixed"}, 36 + "Vulnerability": map[string]any{ 37 + "ID": "CVE-2024-5678", 38 + "Metadata": map[string]any{"Severity": "Low"}, 39 + "Fix": map[string]any{"Versions": []string{}, "State": "not-fixed"}, 39 40 }, 40 - "artifact": map[string]any{ 41 - "name": "zlib", 42 - "version": "1.2.11", 43 - "type": "deb", 41 + "Package": map[string]any{ 42 + "Name": "zlib", 43 + "Version": "1.2.11", 44 + "Type": "deb", 44 45 }, 45 46 }, 46 47 { 47 - "vulnerability": map[string]any{ 48 - "id": "GHSA-abcd-efgh-ijkl", 49 - "severity": "High", 50 - "fix": map[string]any{"versions": []string{"2.0.0"}, "state": "fixed"}, 48 + "Vulnerability": map[string]any{ 49 + "ID": "GHSA-abcd-efgh-ijkl", 50 + "Metadata": map[string]any{"Severity": "High"}, 51 + "Fix": map[string]any{"Versions": []string{"2.0.0"}, "State": "fixed"}, 51 52 }, 52 - "artifact": map[string]any{ 53 - "name": "express", 54 - "version": "4.17.1", 55 - "type": "npm", 53 + "Package": map[string]any{ 54 + 
"Name": "express", 55 + "Version": "4.17.1", 56 + "Type": "npm", 56 57 }, 57 58 }, 58 59 }, ··· 251 252 252 253 body := rr.Body.String() 253 254 254 - // Should show summary counts 255 - if !strings.Contains(body, "2 Critical") { 256 - t.Error("Expected body to contain '2 Critical' summary") 255 + // Should show summary counts in vuln-strip boxes 256 + if !strings.Contains(body, "vuln-box-critical") { 257 + t.Error("Expected body to contain vuln-box-critical in summary") 258 + } 259 + if !strings.Contains(body, `data-tip="Critical">2<`) { 260 + t.Error("Expected critical count of 2 in summary") 257 261 } 258 262 259 263 // Should indicate no detailed report
+11 -14
pkg/appview/server.go
··· 185 185 slog.Info("TEST_MODE enabled - will use HTTP for local DID resolution") 186 186 } 187 187 188 + // Load crypto keys from database (with file fallback and migration) 189 + oauthKey, err := loadOAuthKey(s.Database, cfg.Server.OAuthKeyPath) 190 + if err != nil { 191 + return nil, fmt.Errorf("failed to load OAuth key: %w", err) 192 + } 193 + 188 194 // Create OAuth client app 189 195 desiredScopes := oauth.GetDefaultScopes(defaultHoldDID) 190 - var err error 191 - s.OAuthClientApp, err = oauth.NewClientApp(baseURL, s.OAuthStore, desiredScopes, cfg.Server.OAuthKeyPath, cfg.Server.ClientName) 196 + s.OAuthClientApp, err = oauth.NewClientAppWithKey(baseURL, s.OAuthStore, desiredScopes, oauthKey, cfg.Server.ClientName) 192 197 if err != nil { 193 198 return nil, fmt.Errorf("failed to create OAuth client app: %w", err) 194 199 } ··· 404 409 405 410 // Create token issuer 406 411 if cfg.Distribution.Auth["token"] != nil { 407 - s.TokenIssuer, err = s.createTokenIssuer() 412 + rsaKey, certDER, err := loadJWTKeyAndCert(s.Database, cfg.Auth.KeyPath, cfg.Auth.CertPath) 408 413 if err != nil { 409 - return nil, fmt.Errorf("failed to create token issuer: %w", err) 414 + return nil, fmt.Errorf("failed to load JWT key material: %w", err) 410 415 } 411 - slog.Info("Auth keys initialized", "path", cfg.Auth.KeyPath) 416 + s.TokenIssuer = token.NewIssuerFromKey(rsaKey, certDER, cfg.Auth.ServiceName, cfg.Auth.ServiceName, cfg.Auth.TokenExpiration) 417 + slog.Info("Auth keys initialized") 412 418 } 413 419 414 420 // Create registry app (distribution library handler) ··· 593 599 return nil 594 600 } 595 601 596 - // createTokenIssuer creates a token issuer for auth handlers. 
597 - func (s *AppViewServer) createTokenIssuer() (*token.Issuer, error) { 598 - return token.NewIssuer( 599 - s.Config.Auth.KeyPath, 600 - s.Config.Auth.ServiceName, 601 - s.Config.Auth.ServiceName, 602 - s.Config.Auth.TokenExpiration, 603 - ) 604 - } 605 602 606 603 // DomainRoutingMiddleware enforces three-tier domain routing: 607 604 //
+33
pkg/auth/oauth/client.go
··· 13 13 "time" 14 14 15 15 "atcr.io/pkg/atproto" 16 + "github.com/bluesky-social/indigo/atproto/atcrypto" 16 17 "github.com/bluesky-social/indigo/atproto/auth/oauth" 17 18 "github.com/bluesky-social/indigo/atproto/syntax" 18 19 ) ··· 86 87 } else { 87 88 config = oauth.NewLocalhostConfig(redirectURI, scopes) 88 89 90 + slog.Info("Using public OAuth client (localhost development)") 91 + } 92 + 93 + clientApp := oauth.NewClientApp(&config, store) 94 + clientApp.Dir = atproto.GetDirectory() 95 + 96 + return clientApp, nil 97 + } 98 + 99 + // NewClientAppWithKey creates an indigo OAuth ClientApp with a pre-loaded P-256 key. 100 + // Used by AppView when loading keys from the database instead of disk. 101 + // For localhost development, privateKey is ignored (public client). 102 + func NewClientAppWithKey(baseURL string, store oauth.ClientAuthStore, scopes []string, privateKey *atcrypto.PrivateKeyP256, clientName string) (*oauth.ClientApp, error) { 103 + var config oauth.ClientConfig 104 + redirectURI := RedirectURI(baseURL) 105 + 106 + if !isLocalhost(baseURL) { 107 + clientID := baseURL + "/oauth-client-metadata.json" 108 + config = oauth.NewPublicConfig(clientID, redirectURI, scopes) 109 + 110 + keyID, err := GenerateKeyID(privateKey) 111 + if err != nil { 112 + return nil, fmt.Errorf("failed to generate key ID: %w", err) 113 + } 114 + 115 + if err := config.SetClientSecret(privateKey, keyID); err != nil { 116 + return nil, fmt.Errorf("failed to configure confidential client: %w", err) 117 + } 118 + 119 + slog.Info("Configured confidential OAuth client", "key_id", keyID) 120 + } else { 121 + config = oauth.NewLocalhostConfig(redirectURI, scopes) 89 122 slog.Info("Using public OAuth client (localhost development)") 90 123 } 91 124
+13
pkg/auth/token/issuer.go
··· 59 59 }, nil 60 60 } 61 61 62 + // NewIssuerFromKey creates a JWT issuer from pre-loaded key material. 63 + // certDER is the DER-encoded X.509 certificate for the x5c JWT header. 64 + func NewIssuerFromKey(privateKey *rsa.PrivateKey, certDER []byte, issuer, service string, expiration time.Duration) *Issuer { 65 + return &Issuer{ 66 + privateKey: privateKey, 67 + publicKey: &privateKey.PublicKey, 68 + certificate: certDER, 69 + issuer: issuer, 70 + service: service, 71 + expiration: expiration, 72 + } 73 + } 74 + 62 75 // Issue creates and signs a new JWT token 63 76 func (i *Issuer) Issue(subject string, access []auth.AccessEntry, authMethod string) (string, error) { 64 77 claims := NewClaims(subject, i.issuer, i.service, i.expiration, access, authMethod)
+4
pkg/hold/config.go
··· 139 139 type ScannerConfig struct { 140 140 // Shared secret for scanner WebSocket authentication. Empty disables scanning. 141 141 Secret string `yaml:"secret" comment:"Shared secret for scanner WebSocket auth. Empty disables scanning."` 142 + 143 + // Minimum interval between re-scans of the same manifest. 0 disables proactive scanning. 144 + RescanInterval time.Duration `yaml:"rescan_interval" comment:"Minimum interval between re-scans of the same manifest. When set, the hold proactively scans manifests when the scanner is idle. Default: 24h. Set to 0 to disable."` 142 145 } 143 146 144 147 // DatabaseConfig defines embedded PDS database settings ··· 220 223 v.SetDefault("gc.enabled", false) 221 224 // Scanner defaults 222 225 v.SetDefault("scanner.secret", "") 226 + v.SetDefault("scanner.rescan_interval", "24h") 223 227 224 228 // Log shipper defaults 225 229 v.SetDefault("log_shipper.batch_size", 100)
+347 -23
pkg/hold/pds/scan_broadcaster.go
··· 7 7 "encoding/hex" 8 8 "encoding/json" 9 9 "fmt" 10 + "io" 10 11 "log/slog" 12 + "net/http" 13 + "net/url" 11 14 "strings" 12 15 "sync" 13 16 "time" ··· 33 36 ackTimeout time.Duration 34 37 secret string // Shared secret for scanner authentication 35 38 ownsDB bool // true when this broadcaster opened the connection itself 39 + 40 + // Proactive scan scheduling 41 + rescanInterval time.Duration // Minimum interval between re-scans (0 = disabled) 42 + stopCh chan struct{} // Signal to stop background goroutines 43 + wg sync.WaitGroup // Wait for background goroutines to finish 44 + userIdx int // Round-robin index through users for proactive scanning 45 + predecessorCache map[string]bool // holdDID → "is this hold's successor us?" 36 46 } 37 47 38 48 // ScanSubscriber represents a connected scanner WebSocket client ··· 80 90 81 91 // NewScanBroadcaster creates a new scan job broadcaster 82 92 // dbPath should point to a SQLite database file (e.g., "/path/to/pds/db.sqlite3") 83 - func NewScanBroadcaster(holdDID, holdEndpoint, secret, dbPath string, s3svc *s3.S3Service, holdPDS *HoldPDS) (*ScanBroadcaster, error) { 93 + func NewScanBroadcaster(holdDID, holdEndpoint, secret, dbPath string, s3svc *s3.S3Service, holdPDS *HoldPDS, rescanInterval time.Duration) (*ScanBroadcaster, error) { 84 94 dsn := dbPath 85 95 if dbPath != ":memory:" && !strings.HasPrefix(dbPath, "file:") { 86 96 dsn = "file:" + dbPath ··· 107 117 } 108 118 109 119 sb := &ScanBroadcaster{ 110 - subscribers: make([]*ScanSubscriber, 0), 111 - db: db, 112 - holdDID: holdDID, 113 - holdEndpoint: holdEndpoint, 114 - s3: s3svc, 115 - pds: holdPDS, 116 - ackTimeout: 5 * time.Minute, 117 - secret: secret, 118 - ownsDB: true, 120 + subscribers: make([]*ScanSubscriber, 0), 121 + db: db, 122 + holdDID: holdDID, 123 + holdEndpoint: holdEndpoint, 124 + s3: s3svc, 125 + pds: holdPDS, 126 + ackTimeout: 5 * time.Minute, 127 + secret: secret, 128 + ownsDB: true, 129 + rescanInterval: rescanInterval, 130 + stopCh: 
make(chan struct{}), 131 + predecessorCache: make(map[string]bool), 119 132 } 120 133 121 134 if err := sb.initSchema(); err != nil { ··· 124 137 } 125 138 126 139 // Start re-dispatch loop for timed-out jobs 140 + sb.wg.Add(1) 127 141 go sb.reDispatchLoop() 128 142 143 + // Start proactive scan loop if rescan interval is configured 144 + if rescanInterval > 0 { 145 + sb.wg.Add(1) 146 + go sb.proactiveScanLoop() 147 + slog.Info("Proactive scan scheduler started", "rescanInterval", rescanInterval) 148 + } 149 + 129 150 return sb, nil 130 151 } 131 152 132 153 // NewScanBroadcasterWithDB creates a scan job broadcaster using an existing *sql.DB connection. 133 154 // The caller is responsible for the DB lifecycle. 134 - func NewScanBroadcasterWithDB(holdDID, holdEndpoint, secret string, db *sql.DB, s3svc *s3.S3Service, holdPDS *HoldPDS) (*ScanBroadcaster, error) { 155 + func NewScanBroadcasterWithDB(holdDID, holdEndpoint, secret string, db *sql.DB, s3svc *s3.S3Service, holdPDS *HoldPDS, rescanInterval time.Duration) (*ScanBroadcaster, error) { 135 156 sb := &ScanBroadcaster{ 136 - subscribers: make([]*ScanSubscriber, 0), 137 - db: db, 138 - holdDID: holdDID, 139 - holdEndpoint: holdEndpoint, 140 - s3: s3svc, 141 - pds: holdPDS, 142 - ackTimeout: 5 * time.Minute, 143 - secret: secret, 144 - ownsDB: false, 157 + subscribers: make([]*ScanSubscriber, 0), 158 + db: db, 159 + holdDID: holdDID, 160 + holdEndpoint: holdEndpoint, 161 + s3: s3svc, 162 + pds: holdPDS, 163 + ackTimeout: 5 * time.Minute, 164 + secret: secret, 165 + ownsDB: false, 166 + rescanInterval: rescanInterval, 167 + stopCh: make(chan struct{}), 168 + predecessorCache: make(map[string]bool), 145 169 } 146 170 147 171 if err := sb.initSchema(); err != nil { 148 172 return nil, fmt.Errorf("failed to initialize scan_jobs schema: %w", err) 149 173 } 150 174 175 + sb.wg.Add(1) 151 176 go sb.reDispatchLoop() 177 + 178 + if rescanInterval > 0 { 179 + sb.wg.Add(1) 180 + go sb.proactiveScanLoop() 181 + 
slog.Info("Proactive scan scheduler started", "rescanInterval", rescanInterval) 182 + } 152 183 153 184 return sb, nil 154 185 } ··· 587 618 588 619 // reDispatchLoop periodically checks for timed-out jobs and re-dispatches them 589 620 func (sb *ScanBroadcaster) reDispatchLoop() { 621 + defer sb.wg.Done() 622 + 590 623 ticker := time.NewTicker(30 * time.Second) 591 624 defer ticker.Stop() 592 625 593 - for range ticker.C { 594 - sb.reDispatchTimedOut() 626 + for { 627 + select { 628 + case <-sb.stopCh: 629 + return 630 + case <-ticker.C: 631 + sb.reDispatchTimedOut() 632 + } 595 633 } 596 634 } 597 635 ··· 649 687 } 650 688 } 651 689 652 - // Close closes the scan broadcaster's database connection 690 + // Close stops background goroutines and closes the scan broadcaster's database connection 653 691 func (sb *ScanBroadcaster) Close() error { 692 + if sb.stopCh != nil { 693 + close(sb.stopCh) 694 + sb.wg.Wait() 695 + } 654 696 if sb.db != nil && sb.ownsDB { 655 697 return sb.db.Close() 656 698 } ··· 665 707 // ValidateScannerSecret checks if the provided secret matches 666 708 func (sb *ScanBroadcaster) ValidateScannerSecret(secret string) bool { 667 709 return sb.secret != "" && secret == sb.secret 710 + } 711 + 712 + // proactiveScanLoop periodically finds manifests needing scanning and enqueues jobs. 713 + // It fetches manifest records from users' PDS (the source of truth) and creates scan 714 + // jobs for manifests that haven't been scanned recently. 
715 + func (sb *ScanBroadcaster) proactiveScanLoop() { 716 + defer sb.wg.Done() 717 + 718 + // Wait a bit before starting to let the system settle 719 + select { 720 + case <-sb.stopCh: 721 + return 722 + case <-time.After(30 * time.Second): 723 + } 724 + 725 + ticker := time.NewTicker(60 * time.Second) 726 + defer ticker.Stop() 727 + 728 + for { 729 + select { 730 + case <-sb.stopCh: 731 + slog.Info("Proactive scan loop stopped") 732 + return 733 + case <-ticker.C: 734 + sb.tryEnqueueProactiveScan() 735 + } 736 + } 737 + } 738 + 739 + // tryEnqueueProactiveScan finds the next manifest needing a scan and enqueues it. 740 + // Only enqueues one job per call to avoid flooding the scanner. 741 + func (sb *ScanBroadcaster) tryEnqueueProactiveScan() { 742 + if !sb.hasConnectedScanners() { 743 + return 744 + } 745 + if sb.hasActiveJobs() { 746 + return 747 + } 748 + 749 + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) 750 + defer cancel() 751 + 752 + // Get all users who have pushed to this hold 753 + stats, err := sb.pds.ListStats(ctx) 754 + if err != nil { 755 + slog.Error("Proactive scan: failed to list stats", "error", err) 756 + return 757 + } 758 + 759 + // Extract unique user DIDs 760 + seen := make(map[string]bool) 761 + var userDIDs []string 762 + for _, s := range stats { 763 + if !seen[s.OwnerDID] { 764 + seen[s.OwnerDID] = true 765 + userDIDs = append(userDIDs, s.OwnerDID) 766 + } 767 + } 768 + 769 + if len(userDIDs) == 0 { 770 + return 771 + } 772 + 773 + // Round-robin through users, trying each until we find work or exhaust the list 774 + for attempts := 0; attempts < len(userDIDs); attempts++ { 775 + idx := sb.userIdx % len(userDIDs) 776 + sb.userIdx++ 777 + userDID := userDIDs[idx] 778 + 779 + if sb.tryEnqueueForUser(ctx, userDID) { 780 + return // Enqueued one job, done for this tick 781 + } 782 + } 783 + } 784 + 785 + // tryEnqueueForUser fetches manifests from a user's PDS and enqueues a scan for the 786 + // first one that 
needs scanning. Returns true if a job was enqueued. 787 + func (sb *ScanBroadcaster) tryEnqueueForUser(ctx context.Context, userDID string) bool { 788 + // Resolve user DID to PDS endpoint and handle 789 + did, userHandle, pdsEndpoint, err := atproto.ResolveIdentity(ctx, userDID) 790 + if err != nil { 791 + slog.Debug("Proactive scan: failed to resolve user identity", 792 + "userDID", userDID, "error", err) 793 + return false 794 + } 795 + 796 + // Fetch manifest records from user's PDS 797 + client := atproto.NewClient(pdsEndpoint, did, "") 798 + var cursor string 799 + for { 800 + records, nextCursor, err := client.ListRecordsForRepo(ctx, did, atproto.ManifestCollection, 100, cursor) 801 + if err != nil { 802 + slog.Debug("Proactive scan: failed to list manifest records", 803 + "userDID", did, "pds", pdsEndpoint, "error", err) 804 + return false 805 + } 806 + 807 + for _, record := range records { 808 + var manifest atproto.ManifestRecord 809 + if err := json.Unmarshal(record.Value, &manifest); err != nil { 810 + slog.Debug("Proactive scan: failed to unmarshal manifest record", 811 + "uri", record.URI, "error", err) 812 + continue 813 + } 814 + 815 + // Check if this manifest belongs to us (directly or via successor) 816 + holdDID := manifest.HoldDID 817 + if holdDID == "" { 818 + holdDID = manifest.HoldEndpoint // Legacy field 819 + } 820 + if !sb.isOurManifest(ctx, holdDID) { 821 + continue 822 + } 823 + 824 + // Skip manifest lists (no layers to scan) 825 + if len(manifest.Layers) == 0 { 826 + continue 827 + } 828 + 829 + // Skip if config is nil (shouldn't happen for image manifests, but be safe) 830 + if manifest.Config == nil { 831 + continue 832 + } 833 + 834 + // Check if already scanned recently 835 + if sb.isRecentlyScanned(ctx, manifest.Digest) { 836 + continue 837 + } 838 + 839 + // Construct and enqueue scan job 840 + configJSON, _ := json.Marshal(manifest.Config) 841 + layersJSON, _ := json.Marshal(manifest.Layers) 842 + 843 + slog.Info("Enqueuing 
proactive scan", 844 + "manifestDigest", manifest.Digest, 845 + "repository", manifest.Repository, 846 + "userDID", did) 847 + 848 + if err := sb.Enqueue(&ScanJobEvent{ 849 + ManifestDigest: manifest.Digest, 850 + Repository: manifest.Repository, 851 + UserDID: did, 852 + UserHandle: userHandle, 853 + Tier: "deckhand", 854 + Config: configJSON, 855 + Layers: layersJSON, 856 + }); err != nil { 857 + slog.Error("Proactive scan: failed to enqueue", 858 + "manifest", manifest.Digest, "error", err) 859 + return false 860 + } 861 + return true 862 + } 863 + 864 + if nextCursor == "" || len(records) == 0 { 865 + break 866 + } 867 + cursor = nextCursor 868 + } 869 + 870 + return false 871 + } 872 + 873 + // isOurManifest checks if a manifest's holdDID matches this hold, either directly 874 + // or via successor (the manifest's hold has set us as its successor). 875 + func (sb *ScanBroadcaster) isOurManifest(ctx context.Context, holdDID string) bool { 876 + if holdDID == "" { 877 + return false 878 + } 879 + 880 + // Direct match 881 + if holdDID == sb.holdDID { 882 + return true 883 + } 884 + 885 + // Check predecessor cache 886 + if isPredecessor, cached := sb.predecessorCache[holdDID]; cached { 887 + return isPredecessor 888 + } 889 + 890 + // Fetch captain record from the other hold's PDS to check successor 891 + isPredecessor := sb.checkPredecessor(ctx, holdDID) 892 + sb.predecessorCache[holdDID] = isPredecessor 893 + return isPredecessor 894 + } 895 + 896 + // checkPredecessor fetches a hold's captain record to check if its successor is us. 
897 + func (sb *ScanBroadcaster) checkPredecessor(ctx context.Context, holdDID string) bool { 898 + fetchCtx, cancel := context.WithTimeout(ctx, 5*time.Second) 899 + defer cancel() 900 + 901 + holdURL, err := atproto.ResolveHoldURL(fetchCtx, holdDID) 902 + if err != nil { 903 + slog.Debug("Proactive scan: failed to resolve predecessor hold URL", 904 + "holdDID", holdDID, "error", err) 905 + return false 906 + } 907 + 908 + // Fetch captain record: com.atproto.repo.getRecord 909 + recordURL := fmt.Sprintf("%s/xrpc/com.atproto.repo.getRecord?repo=%s&collection=%s&rkey=self", 910 + holdURL, 911 + url.QueryEscape(holdDID), 912 + url.QueryEscape(atproto.CaptainCollection), 913 + ) 914 + 915 + req, err := http.NewRequestWithContext(fetchCtx, "GET", recordURL, nil) 916 + if err != nil { 917 + return false 918 + } 919 + 920 + resp, err := http.DefaultClient.Do(req) 921 + if err != nil { 922 + slog.Debug("Proactive scan: failed to fetch predecessor captain record", 923 + "holdDID", holdDID, "error", err) 924 + return false 925 + } 926 + defer resp.Body.Close() 927 + 928 + if resp.StatusCode != http.StatusOK { 929 + return false 930 + } 931 + 932 + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit 933 + if err != nil { 934 + return false 935 + } 936 + 937 + var envelope struct { 938 + Value json.RawMessage `json:"value"` 939 + } 940 + if err := json.Unmarshal(body, &envelope); err != nil { 941 + return false 942 + } 943 + 944 + var captain atproto.CaptainRecord 945 + if err := json.Unmarshal(envelope.Value, &captain); err != nil { 946 + return false 947 + } 948 + 949 + if captain.Successor == sb.holdDID { 950 + slog.Info("Proactive scan: discovered predecessor hold", 951 + "predecessorDID", holdDID, "successor", sb.holdDID) 952 + return true 953 + } 954 + 955 + return false 956 + } 957 + 958 + // isRecentlyScanned checks if a manifest has been scanned within the rescan interval. 
959 + func (sb *ScanBroadcaster) isRecentlyScanned(ctx context.Context, manifestDigest string) bool { 960 + _, scanRecord, err := sb.pds.GetScanRecord(ctx, manifestDigest) 961 + if err != nil { 962 + return false // Not scanned or error reading → needs scanning 963 + } 964 + 965 + scannedAt, err := time.Parse(time.RFC3339, scanRecord.ScannedAt) 966 + if err != nil { 967 + return false // Can't parse timestamp → treat as needing scan 968 + } 969 + 970 + return time.Since(scannedAt) < sb.rescanInterval 971 + } 972 + 973 + // hasConnectedScanners returns true if at least one scanner is connected. 974 + func (sb *ScanBroadcaster) hasConnectedScanners() bool { 975 + sb.mu.RLock() 976 + defer sb.mu.RUnlock() 977 + return len(sb.subscribers) > 0 978 + } 979 + 980 + // hasActiveJobs returns true if there are any pending, assigned, or processing scan jobs. 981 + func (sb *ScanBroadcaster) hasActiveJobs() bool { 982 + var count int 983 + err := sb.db.QueryRow(` 984 + SELECT COUNT(*) FROM scan_jobs 985 + WHERE status IN ('pending', 'assigned', 'processing') 986 + `).Scan(&count) 987 + if err != nil { 988 + slog.Error("Failed to check active scan jobs", "error", err) 989 + return true // Assume busy on error 990 + } 991 + return count > 0 668 992 } 669 993 670 994 func generateSubscriberID() string {
+5 -3
pkg/hold/server.go
··· 193 193 // Initialize scan broadcaster if scanner secret is configured 194 194 if cfg.Scanner.Secret != "" { 195 195 holdDID := s.PDS.DID() 196 + rescanInterval := cfg.Scanner.RescanInterval 196 197 var sb *pds.ScanBroadcaster 197 198 if s.holdDB != nil { 198 - sb, err = pds.NewScanBroadcasterWithDB(holdDID, cfg.Server.PublicURL, cfg.Scanner.Secret, s.holdDB.DB, s3Service, s.PDS) 199 + sb, err = pds.NewScanBroadcasterWithDB(holdDID, cfg.Server.PublicURL, cfg.Scanner.Secret, s.holdDB.DB, s3Service, s.PDS, rescanInterval) 199 200 } else { 200 201 scanDBPath := cfg.Database.Path + "/db.sqlite3" 201 - sb, err = pds.NewScanBroadcaster(holdDID, cfg.Server.PublicURL, cfg.Scanner.Secret, scanDBPath, s3Service, s.PDS) 202 + sb, err = pds.NewScanBroadcaster(holdDID, cfg.Server.PublicURL, cfg.Scanner.Secret, scanDBPath, s3Service, s.PDS, rescanInterval) 202 203 } 203 204 if err != nil { 204 205 return nil, fmt.Errorf("failed to initialize scan broadcaster: %w", err) ··· 206 207 s.scanBroadcaster = sb 207 208 xrpcHandler.SetScanBroadcaster(sb) 208 209 ociHandler.SetScanBroadcaster(sb) 209 - slog.Info("Scan broadcaster initialized (scanner WebSocket enabled)") 210 + slog.Info("Scan broadcaster initialized (scanner WebSocket enabled)", 211 + "rescanInterval", rescanInterval) 210 212 } 211 213 212 214 // Initialize garbage collector