An easy-to-host PDS on the ATProtocol, MacOS. Grandma-approved.

feat(db): consume_authorization_code + store_oauth_refresh_token

+220
+220
crates/relay/src/db/oauth.rs
··· 162 162 Ok(()) 163 163 } 164 164 165 + /// A row read from `oauth_authorization_codes` during code exchange. 166 + pub struct AuthCodeRow { 167 + pub client_id: String, 168 + pub did: String, 169 + #[allow(dead_code)] 170 + pub code_challenge: String, 171 + #[allow(dead_code)] 172 + pub code_challenge_method: String, 173 + #[allow(dead_code)] 174 + pub redirect_uri: String, 175 + #[allow(dead_code)] 176 + pub scope: String, 177 + } 178 + 179 + /// Atomically consume an authorization code: SELECT + DELETE in one transaction. 180 + /// 181 + /// Returns `None` if the code does not exist or has already expired (`expires_at <= now`). 182 + /// Callers must treat `None` as `invalid_grant`. 183 + /// 184 + /// The code column stores the SHA-256 hex hash of the raw code bytes. Callers must 185 + /// hash the presented code before calling this function (use `routes::token::sha256_hex`). 186 + pub async fn consume_authorization_code( 187 + pool: &SqlitePool, 188 + code_hash: &str, 189 + ) -> Result<Option<AuthCodeRow>, sqlx::Error> { 190 + let mut tx = pool.begin().await?; 191 + 192 + let row: Option<(String, String, String, String, String, String)> = sqlx::query_as( 193 + "SELECT client_id, did, code_challenge, code_challenge_method, redirect_uri, scope \ 194 + FROM oauth_authorization_codes \ 195 + WHERE code = ? AND expires_at > datetime('now')", 196 + ) 197 + .bind(code_hash) 198 + .fetch_optional(&mut *tx) 199 + .await?; 200 + 201 + if row.is_some() { 202 + sqlx::query("DELETE FROM oauth_authorization_codes WHERE code = ?") 203 + .bind(code_hash) 204 + .execute(&mut *tx) 205 + .await?; 206 + } 207 + 208 + tx.commit().await?; 209 + 210 + Ok(row.map( 211 + |(client_id, did, code_challenge, code_challenge_method, redirect_uri, scope)| { 212 + AuthCodeRow { 213 + client_id, 214 + did, 215 + code_challenge, 216 + code_challenge_method, 217 + redirect_uri, 218 + scope, 219 + } 220 + }, 221 + )) 222 + } 223 + 224 + /// Store a new refresh token in `oauth_tokens`. 225 + /// 226 + /// `token_hash` is used as the row's `id` (PRIMARY KEY). This follows the same 227 + /// pattern as `oauth_authorization_codes` where `code` IS the hash. 228 + /// `scope` is always `'com.atproto.refresh'` for OAuth refresh tokens. 229 + /// `jkt` is the DPoP key thumbprint binding this token to the client's keypair. 230 + /// Expires 24 hours after insertion. 231 + pub async fn store_oauth_refresh_token( 232 + pool: &SqlitePool, 233 + token_hash: &str, 234 + client_id: &str, 235 + did: &str, 236 + jkt: &str, 237 + ) -> Result<(), sqlx::Error> { 238 + sqlx::query( 239 + "INSERT INTO oauth_tokens (id, client_id, did, scope, jkt, expires_at, created_at) \ 240 + VALUES (?, ?, ?, 'com.atproto.refresh', ?, datetime('now', '+24 hours'), datetime('now'))", 241 + ) 242 + .bind(token_hash) 243 + .bind(client_id) 244 + .bind(did) 245 + .bind(jkt) 246 + .execute(pool) 247 + .await?; 248 + Ok(()) 249 + } 250 + 165 251 #[cfg(test)] 166 252 mod tests { 167 253 use super::*; ··· 327 413 let pool = test_pool().await; 328 414 let result = get_oauth_signing_key(&pool).await.unwrap(); 329 415 assert!(result.is_none()); 416 + } 417 + 418 + /// Insert an account row needed to satisfy oauth_tokens FK. 419 + async fn insert_test_account(pool: &SqlitePool) { 420 + sqlx::query( 421 + "INSERT INTO accounts (did, email, password_hash, created_at, updated_at) \ 422 + VALUES ('did:plc:testaccount000000000000', 'test@example.com', NULL, \ 423 + datetime('now'), datetime('now'))", 424 + ) 425 + .execute(pool) 426 + .await 427 + .unwrap(); 428 + } 429 + 430 + #[tokio::test] 431 + async fn consume_authorization_code_returns_row_and_deletes_it() { 432 + let pool = test_pool().await; 433 + register_oauth_client( 434 + &pool, 435 + "https://app.example.com/client-metadata.json", 436 + r#"{"redirect_uris":["https://app.example.com/callback"]}"#, 437 + ) 438 + .await 439 + .unwrap(); 440 + insert_test_account(&pool).await; 441 + 442 + store_authorization_code( 443 + &pool, 444 + "hash-abc123", 445 + "https://app.example.com/client-metadata.json", 446 + "did:plc:testaccount000000000000", 447 + "s256challenge", 448 + "S256", 449 + "https://app.example.com/callback", 450 + "atproto", 451 + ) 452 + .await 453 + .unwrap(); 454 + 455 + let row = consume_authorization_code(&pool, "hash-abc123") 456 + .await 457 + .unwrap() 458 + .expect("code should be found"); 459 + 460 + assert_eq!(row.client_id, "https://app.example.com/client-metadata.json"); 461 + assert_eq!(row.did, "did:plc:testaccount000000000000"); 462 + 463 + // Second consume: must return None (already deleted). 464 + let second = consume_authorization_code(&pool, "hash-abc123").await.unwrap(); 465 + assert!(second.is_none(), "consumed code must not be found again (AC1.6)"); 466 + } 467 + 468 + #[tokio::test] 469 + async fn consume_authorization_code_returns_none_for_unknown_code() { 470 + let pool = test_pool().await; 471 + let result = consume_authorization_code(&pool, "nonexistent-hash").await.unwrap(); 472 + assert!(result.is_none()); 473 + } 474 + 475 + #[tokio::test] 476 + async fn consume_authorization_code_returns_none_for_expired_code() { 477 + // AC1.5: expired auth codes (>60s) are rejected. 478 + let pool = test_pool().await; 479 + register_oauth_client( 480 + &pool, 481 + "https://app.example.com/client-metadata.json", 482 + r#"{"redirect_uris":["https://app.example.com/callback"]}"#, 483 + ) 484 + .await 485 + .unwrap(); 486 + 487 + sqlx::query( 488 + "INSERT INTO accounts (did, email, password_hash, created_at, updated_at) \ 489 + VALUES ('did:plc:testaccount000000000000', 'test@example.com', NULL, \ 490 + datetime('now'), datetime('now'))", 491 + ) 492 + .execute(&pool) 493 + .await 494 + .unwrap(); 495 + 496 + // Insert an already-expired auth code directly (bypassing store_authorization_code's +60s default). 497 + sqlx::query( 498 + "INSERT INTO oauth_authorization_codes \ 499 + (code, client_id, did, code_challenge, code_challenge_method, redirect_uri, scope, expires_at, created_at) \ 500 + VALUES (?, ?, ?, ?, 'S256', ?, 'atproto', datetime('now', '-1 seconds'), datetime('now'))", 501 + ) 502 + .bind("expired-code-hash") 503 + .bind("https://app.example.com/client-metadata.json") 504 + .bind("did:plc:testaccount000000000000") 505 + .bind("s256challenge") 506 + .bind("https://app.example.com/callback") 507 + .execute(&pool) 508 + .await 509 + .unwrap(); 510 + 511 + let result = consume_authorization_code(&pool, "expired-code-hash") 512 + .await 513 + .unwrap(); 514 + assert!(result.is_none(), "expired auth code must return None (AC1.5)"); 515 + } 516 + 517 + #[tokio::test] 518 + async fn store_oauth_refresh_token_persists_row() { 519 + let pool = test_pool().await; 520 + register_oauth_client( 521 + &pool, 522 + "https://app.example.com/client-metadata.json", 523 + r#"{"redirect_uris":["https://app.example.com/callback"]}"#, 524 + ) 525 + .await 526 + .unwrap(); 527 + insert_test_account(&pool).await; 528 + 529 + store_oauth_refresh_token( 530 + &pool, 531 + "refresh-token-hash-01", 532 + "https://app.example.com/client-metadata.json", 533 + "did:plc:testaccount000000000000", 534 + "jkt-thumbprint", 535 + ) 536 + .await 537 + .unwrap(); 538 + 539 + let row: Option<(String, String, Option<String>)> = 540 + sqlx::query_as("SELECT id, scope, jkt FROM oauth_tokens WHERE id = ?") 541 + .bind("refresh-token-hash-01") 542 + .fetch_optional(&pool) 543 + .await 544 + .unwrap(); 545 + 546 + let (id, scope, jkt) = row.expect("refresh token row must exist"); 547 + assert_eq!(id, "refresh-token-hash-01"); 548 + assert_eq!(scope, "com.atproto.refresh", "scope must be com.atproto.refresh (AC1.3)"); 549 + assert_eq!(jkt.as_deref(), Some("jkt-thumbprint")); 330 550 } 331 551 }