···66[dependencies]
77axum = { version = "0.8.4", features = ["macros", "json"] }
88tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros", "signal"] }
99-sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate"] }
99+sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate", "chrono"] }
1010dotenvy = "0.15.7"
1111serde = { version = "1.0", features = ["derive"] }
1212serde_json = "1.0"
···2222handlebars = { version = "6.3.2", features = ["rust-embed"] }
2323rust-embed = "8.7.2"
2424axum-template = { version = "3.0.0", features = ["handlebars"] }
2525+rand = "0.9.2"
2626+anyhow = "1.0.99"
2727+chrono = "0.4.41"
2828+sha2 = "0.10"
+5-6
README.md
···12121313## 2FA
14141515-- [x] Ability to turn on/off 2FA
1616-- [x] getSession overwrite to set the `emailAuthFactor` flag if the user has 2FA turned on
1717-- [x] send an email using the `PDS_EMAIL_SMTP_URL` with a handlebar email template like Bluesky's 2FA sign in email.
1818-- [ ] generate a 2FA code
1919-- [ ] createSession gatekeeping (It does stop logins, just eh, doesn't actually send a real code or check it yet)
2020-- [ ] oauth endpoint gatekeeping
1515+- Overrides The login endpoint to add 2FA for both Bluesky client logged in and OAuth logins
1616+- Overrides the settings endpoints as well. As long as you have a confirmed email you can turn on 2FA
21172218## Captcha on Create Account
23192420Future feature?
25212622# Setup
2323+2424+We are getting close! Testing now
27252826Nothing here yet! If you are brave enough to try before full release, let me know and I'll help you set it up.
2927But I want to run it locally on my own PDS first to test run it a bit.
···3735 path /xrpc/com.atproto.server.getSession
3836 path /xrpc/com.atproto.server.updateEmail
3937 path /xrpc/com.atproto.server.createSession
3838+ path /@atproto/oauth-provider/~api/sign-in
4039 }
41404241 handle @gatekeeper {
-3
migrations_bells_and_whistles/.keep
···11-# This directory holds SQLx migrations for the bells_and_whistles.sqlite database.
22-# It is intentionally empty for now; running `sqlx::migrate!` will still ensure the
33-# migrations table exists and succeed with zero migrations.
+524
src/helpers.rs
···11+use crate::AppState;
22+use crate::helpers::TokenCheckError::InvalidToken;
33+use anyhow::anyhow;
44+use axum::body::{Body, to_bytes};
55+use axum::extract::Request;
66+use axum::http::header::CONTENT_TYPE;
77+use axum::http::{HeaderMap, StatusCode, Uri};
88+use axum::response::{IntoResponse, Response};
99+use axum_template::TemplateEngine;
1010+use chrono::Utc;
1111+use lettre::message::{MultiPart, SinglePart, header};
1212+use lettre::{AsyncTransport, Message};
1313+use rand::Rng;
1414+use serde::de::DeserializeOwned;
1515+use serde_json::{Map, Value};
1616+use sha2::{Digest, Sha256};
1717+use sqlx::SqlitePool;
1818+use tracing::{error, log};
1919+2020+///Used to generate the email 2fa code
2121+const UPPERCASE_BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
2222+2323+/// The result of a proxied call that attempts to parse JSON.
2424+pub enum ProxiedResult<T> {
2525+ /// Successfully parsed JSON body along with original response headers.
2626+ Parsed { value: T, _headers: HeaderMap },
2727+ /// Could not or should not parse: return the original (or rebuilt) response as-is.
2828+ Passthrough(Response<Body>),
2929+}
3030+3131+/// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse
3232+/// the successful response body as JSON into `T`.
3333+///
3434+pub async fn proxy_get_json<T>(
3535+ state: &AppState,
3636+ mut req: Request,
3737+ path: &str,
3838+) -> Result<ProxiedResult<T>, StatusCode>
3939+where
4040+ T: DeserializeOwned,
4141+{
4242+ let uri = format!("{}{}", state.pds_base_url, path);
4343+ *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
4444+4545+ let result = state
4646+ .reverse_proxy_client
4747+ .request(req)
4848+ .await
4949+ .map_err(|_| StatusCode::BAD_REQUEST)?
5050+ .into_response();
5151+5252+ if result.status() != StatusCode::OK {
5353+ return Ok(ProxiedResult::Passthrough(result));
5454+ }
5555+5656+ let response_headers = result.headers().clone();
5757+ let body = result.into_body();
5858+ let body_bytes = to_bytes(body, usize::MAX)
5959+ .await
6060+ .map_err(|_| StatusCode::BAD_REQUEST)?;
6161+6262+ match serde_json::from_slice::<T>(&body_bytes) {
6363+ Ok(value) => Ok(ProxiedResult::Parsed {
6464+ value,
6565+ _headers: response_headers,
6666+ }),
6767+ Err(err) => {
6868+ error!(%err, "failed to parse proxied JSON response; returning original body");
6969+ let mut builder = Response::builder().status(StatusCode::OK);
7070+ if let Some(headers) = builder.headers_mut() {
7171+ *headers = response_headers;
7272+ }
7373+ let resp = builder
7474+ .body(Body::from(body_bytes))
7575+ .map_err(|_| StatusCode::BAD_REQUEST)?;
7676+ Ok(ProxiedResult::Passthrough(resp))
7777+ }
7878+ }
7979+}
8080+8181+/// Build a JSON error response with the required Content-Type header
8282+/// Content-Type: application/json;charset=utf-8
8383+/// Body shape: { "error": string, "message": string }
8484+pub fn json_error_response(
8585+ status: StatusCode,
8686+ error: impl Into<String>,
8787+ message: impl Into<String>,
8888+) -> Result<Response<Body>, StatusCode> {
8989+ let body_str = match serde_json::to_string(&serde_json::json!({
9090+ "error": error.into(),
9191+ "message": message.into(),
9292+ })) {
9393+ Ok(s) => s,
9494+ Err(_) => return Err(StatusCode::BAD_REQUEST),
9595+ };
9696+9797+ Response::builder()
9898+ .status(status)
9999+ .header(CONTENT_TYPE, "application/json;charset=utf-8")
100100+ .body(Body::from(body_str))
101101+ .map_err(|_| StatusCode::BAD_REQUEST)
102102+}
103103+104104+/// Build a JSON error response with the required Content-Type header
105105+/// Content-Type: application/json (oauth endpoint does not like utf ending)
106106+/// Body shape: { "error": string, "error_description": string }
107107+pub fn oauth_json_error_response(
108108+ status: StatusCode,
109109+ error: impl Into<String>,
110110+ message: impl Into<String>,
111111+) -> Result<Response<Body>, StatusCode> {
112112+ let body_str = match serde_json::to_string(&serde_json::json!({
113113+ "error": error.into(),
114114+ "error_description": message.into(),
115115+ })) {
116116+ Ok(s) => s,
117117+ Err(_) => return Err(StatusCode::BAD_REQUEST),
118118+ };
119119+120120+ Response::builder()
121121+ .status(status)
122122+ .header(CONTENT_TYPE, "application/json")
123123+ .body(Body::from(body_str))
124124+ .map_err(|_| StatusCode::BAD_REQUEST)
125125+}
126126+127127+/// Creates a random token of 10 characters for email 2FA
128128+pub fn get_random_token() -> String {
129129+ let mut rng = rand::rng();
130130+131131+ let mut full_code = String::with_capacity(10);
132132+ for _ in 0..10 {
133133+ let idx = rng.random_range(0..UPPERCASE_BASE32_CHARS.len());
134134+ full_code.push(UPPERCASE_BASE32_CHARS[idx] as char);
135135+ }
136136+137137+ //The PDS implementation creates in lowercase, then converts to uppercase.
138138+ //Just going a head and doing uppercase here.
139139+ let slice_one = &full_code[0..5].to_ascii_uppercase();
140140+ let slice_two = &full_code[5..10].to_ascii_uppercase();
141141+ format!("{slice_one}-{slice_two}")
142142+}
143143+144144+pub enum TokenCheckError {
145145+ InvalidToken,
146146+ ExpiredToken,
147147+}
148148+149149+pub enum AuthResult {
150150+ WrongIdentityOrPassword,
151151+ /// The string here is the email address to create a hint for oauth
152152+ TwoFactorRequired(String),
153153+ /// User does not have 2FA enabled, or using an app password, or passes it
154154+ ProxyThrough,
155155+ TokenCheckFailed(TokenCheckError),
156156+}
157157+158158+pub enum IdentifierType {
159159+ Email,
160160+ Did,
161161+ Handle,
162162+}
163163+164164+impl IdentifierType {
165165+ fn what_is_it(identifier: String) -> Self {
166166+ if identifier.contains("@") {
167167+ IdentifierType::Email
168168+ } else if identifier.contains("did:") {
169169+ IdentifierType::Did
170170+ } else {
171171+ IdentifierType::Handle
172172+ }
173173+ }
174174+}
175175+176176+/// Creates a hex string from the password and salt to find app passwords
177177+fn scrypt_hex(password: &str, salt: &str) -> anyhow::Result<String> {
178178+ let params = scrypt::Params::new(14, 8, 1, 64)?;
179179+ let mut derived = [0u8; 64];
180180+ scrypt::scrypt(password.as_bytes(), salt.as_bytes(), ¶ms, &mut derived)?;
181181+ Ok(hex::encode(derived))
182182+}
183183+184184+/// Hashes the app password. did is used as the salt.
185185+pub fn hash_app_password(did: &str, password: &str) -> anyhow::Result<String> {
186186+ let mut hasher = Sha256::new();
187187+ hasher.update(did.as_bytes());
188188+ let sha = hasher.finalize();
189189+ let salt = hex::encode(&sha[..16]);
190190+ let hash_hex = scrypt_hex(password, &salt)?;
191191+ Ok(format!("{salt}:{hash_hex}"))
192192+}
193193+194194+async fn verify_password(password: &str, password_scrypt: &str) -> anyhow::Result<bool> {
195195+ // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes)
196196+ let mut parts = password_scrypt.splitn(2, ':');
197197+ let salt = match parts.next() {
198198+ Some(s) if !s.is_empty() => s,
199199+ _ => return Ok(false),
200200+ };
201201+ let stored_hash_hex = match parts.next() {
202202+ Some(h) if !h.is_empty() => h,
203203+ _ => return Ok(false),
204204+ };
205205+206206+ // Derive using the shared helper and compare
207207+ let derived_hex = match scrypt_hex(password, salt) {
208208+ Ok(h) => h,
209209+ Err(_) => return Ok(false),
210210+ };
211211+212212+ Ok(derived_hex.as_str() == stored_hash_hex)
213213+}
214214+215215+/// Handles the auth checks along with sending a 2fa email
216216+pub async fn preauth_check(
217217+ state: &AppState,
218218+ identifier: &str,
219219+ password: &str,
220220+ two_factor_code: Option<String>,
221221+ oauth: bool,
222222+) -> anyhow::Result<AuthResult> {
223223+ // Determine identifier type
224224+ let id_type = IdentifierType::what_is_it(identifier.to_string());
225225+226226+ // Query account DB for did and passwordScrypt based on identifier type
227227+ let account_row: Option<(String, String, String, String)> = match id_type {
228228+ IdentifierType::Email => {
229229+ sqlx::query_as::<_, (String, String, String, String)>(
230230+ "SELECT account.did, account.passwordScrypt, account.email, actor.handle
231231+ FROM actor
232232+ LEFT JOIN account ON actor.did = account.did
233233+ where account.email = ? LIMIT 1",
234234+ )
235235+ .bind(identifier)
236236+ .fetch_optional(&state.account_pool)
237237+ .await?
238238+ }
239239+ IdentifierType::Handle => {
240240+ sqlx::query_as::<_, (String, String, String, String)>(
241241+ "SELECT account.did, account.passwordScrypt, account.email, actor.handle
242242+ FROM actor
243243+ LEFT JOIN account ON actor.did = account.did
244244+ where actor.handle = ? LIMIT 1",
245245+ )
246246+ .bind(identifier)
247247+ .fetch_optional(&state.account_pool)
248248+ .await?
249249+ }
250250+ IdentifierType::Did => {
251251+ sqlx::query_as::<_, (String, String, String, String)>(
252252+ "SELECT account.did, account.passwordScrypt, account.email, actor.handle
253253+ FROM actor
254254+ LEFT JOIN account ON actor.did = account.did
255255+ where account.did = ? LIMIT 1",
256256+ )
257257+ .bind(identifier)
258258+ .fetch_optional(&state.account_pool)
259259+ .await?
260260+ }
261261+ };
262262+263263+ if let Some((did, password_scrypt, email, handle)) = account_row {
264264+ // Verify password before proceeding to 2FA email step
265265+ let verified = verify_password(password, &password_scrypt).await?;
266266+ if !verified {
267267+ if oauth {
268268+ //OAuth does not allow app password logins so just go ahead and send it along it's way
269269+ return Ok(AuthResult::WrongIdentityOrPassword);
270270+ }
271271+ //Theres a chance it could be an app password so check that as well
272272+ return match verify_app_password(&state.account_pool, &did, password).await {
273273+ Ok(valid) => {
274274+ if valid {
275275+ //Was a valid app password up to the PDS now
276276+ Ok(AuthResult::ProxyThrough)
277277+ } else {
278278+ Ok(AuthResult::WrongIdentityOrPassword)
279279+ }
280280+ }
281281+ Err(err) => {
282282+ log::error!("Error checking the app password: {err}");
283283+ Err(err)
284284+ }
285285+ };
286286+ }
287287+288288+ // Check two-factor requirement for this DID in the gatekeeper DB
289289+ let required_opt = sqlx::query_as::<_, (u8,)>(
290290+ "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
291291+ )
292292+ .bind(did.clone())
293293+ .fetch_optional(&state.pds_gatekeeper_pool)
294294+ .await?;
295295+296296+ let two_factor_required = match required_opt {
297297+ Some(row) => row.0 != 0,
298298+ None => false,
299299+ };
300300+301301+ if two_factor_required {
302302+ //Two factor is required and a taken was provided
303303+ if let Some(two_factor_code) = two_factor_code {
304304+ //if the two_factor_code is set need to see if we have a valid token
305305+ if !two_factor_code.is_empty() {
306306+ return match assert_valid_token(
307307+ &state.account_pool,
308308+ did.clone(),
309309+ two_factor_code,
310310+ )
311311+ .await
312312+ {
313313+ Ok(_) => {
314314+ let result_of_cleanup =
315315+ delete_all_email_tokens(&state.account_pool, did.clone()).await;
316316+ if result_of_cleanup.is_err() {
317317+ log::error!(
318318+ "There was an error deleting the email tokens after login: {:?}",
319319+ result_of_cleanup.err()
320320+ )
321321+ }
322322+ Ok(AuthResult::ProxyThrough)
323323+ }
324324+ Err(err) => Ok(AuthResult::TokenCheckFailed(err)),
325325+ };
326326+ }
327327+ }
328328+329329+ return match create_two_factor_token(&state.account_pool, did).await {
330330+ Ok(code) => {
331331+ let mut email_data = Map::new();
332332+ email_data.insert("token".to_string(), Value::from(code.clone()));
333333+ email_data.insert("handle".to_string(), Value::from(handle.clone()));
334334+ let email_body = state
335335+ .template_engine
336336+ .render("two_factor_code.hbs", email_data)?;
337337+338338+ let email_message = Message::builder()
339339+ //TODO prob get the proper type in the state
340340+ .from(state.mailer_from.parse()?)
341341+ .to(email.parse()?)
342342+ .subject("Sign in to Bluesky")
343343+ .multipart(
344344+ MultiPart::alternative() // This is composed of two parts.
345345+ .singlepart(
346346+ SinglePart::builder()
347347+ .header(header::ContentType::TEXT_PLAIN)
348348+ .body(format!("We received a sign-in request for the account @{handle}. Use the code: {code} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.")), // Every message should have a plain text fallback.
349349+ )
350350+ .singlepart(
351351+ SinglePart::builder()
352352+ .header(header::ContentType::TEXT_HTML)
353353+ .body(email_body),
354354+ ),
355355+ )?;
356356+ match state.mailer.send(email_message).await {
357357+ Ok(_) => Ok(AuthResult::TwoFactorRequired(mask_email(email))),
358358+ Err(err) => {
359359+ log::error!("Error sending the 2FA email: {err}");
360360+ Err(anyhow!(err))
361361+ }
362362+ }
363363+ }
364364+ Err(err) => {
365365+ log::error!("error on creating a 2fa token: {err}");
366366+ Err(anyhow!(err))
367367+ }
368368+ };
369369+ }
370370+ }
371371+372372+ // No local 2FA requirement (or account not found)
373373+ Ok(AuthResult::ProxyThrough)
374374+}
375375+376376+pub async fn create_two_factor_token(
377377+ account_db: &SqlitePool,
378378+ did: String,
379379+) -> anyhow::Result<String> {
380380+ let purpose = "2fa_code";
381381+382382+ let token = get_random_token();
383383+ let right_now = Utc::now();
384384+385385+ let res = sqlx::query(
386386+ "INSERT INTO email_token (purpose, did, token, requestedAt)
387387+ VALUES (?, ?, ?, ?)
388388+ ON CONFLICT(purpose, did) DO UPDATE SET
389389+ token=excluded.token,
390390+ requestedAt=excluded.requestedAt
391391+ WHERE did=excluded.did",
392392+ )
393393+ .bind(purpose)
394394+ .bind(&did)
395395+ .bind(&token)
396396+ .bind(right_now)
397397+ .execute(account_db)
398398+ .await;
399399+400400+ match res {
401401+ Ok(_) => Ok(token),
402402+ Err(err) => {
403403+ log::error!("Error creating a two factor token: {err}");
404404+ Err(anyhow::anyhow!(err))
405405+ }
406406+ }
407407+}
408408+409409+pub async fn delete_all_email_tokens(account_db: &SqlitePool, did: String) -> anyhow::Result<()> {
410410+ sqlx::query("DELETE FROM email_token WHERE did = ?")
411411+ .bind(did)
412412+ .execute(account_db)
413413+ .await?;
414414+ Ok(())
415415+}
416416+417417+pub async fn assert_valid_token(
418418+ account_db: &SqlitePool,
419419+ did: String,
420420+ token: String,
421421+) -> Result<(), TokenCheckError> {
422422+ let token_upper = token.to_ascii_uppercase();
423423+ let purpose = "2fa_code";
424424+425425+ let row: Option<(String,)> = sqlx::query_as(
426426+ "SELECT requestedAt FROM email_token WHERE purpose = ? AND did = ? AND token = ? LIMIT 1",
427427+ )
428428+ .bind(purpose)
429429+ .bind(did)
430430+ .bind(token_upper)
431431+ .fetch_optional(account_db)
432432+ .await
433433+ .map_err(|err| {
434434+ log::error!("Error getting the 2fa token: {err}");
435435+ InvalidToken
436436+ })?;
437437+438438+ match row {
439439+ None => Err(InvalidToken),
440440+ Some(row) => {
441441+ // Token lives for 15 minutes
442442+ let expiration_ms = 15 * 60_000;
443443+444444+ let requested_at_utc = match chrono::DateTime::parse_from_rfc3339(&row.0) {
445445+ Ok(dt) => dt.with_timezone(&Utc),
446446+ Err(_) => {
447447+ return Err(TokenCheckError::InvalidToken);
448448+ }
449449+ };
450450+451451+ let now = Utc::now();
452452+ let age_ms = (now - requested_at_utc).num_milliseconds();
453453+ let expired = age_ms > expiration_ms;
454454+ if expired {
455455+ return Err(TokenCheckError::ExpiredToken);
456456+ }
457457+458458+ Ok(())
459459+ }
460460+ }
461461+}
462462+463463+/// We just need to confirm if it's there or not. Will let the PDS do the actual figuring of permissions
464464+pub async fn verify_app_password(
465465+ account_db: &SqlitePool,
466466+ did: &str,
467467+ password: &str,
468468+) -> anyhow::Result<bool> {
469469+ let password_scrypt = hash_app_password(did, password)?;
470470+471471+ let row: Option<(i64,)> = sqlx::query_as(
472472+ "SELECT Count(*) FROM app_password WHERE did = ? AND passwordScrypt = ? LIMIT 1",
473473+ )
474474+ .bind(did)
475475+ .bind(password_scrypt)
476476+ .fetch_optional(account_db)
477477+ .await?;
478478+479479+ Ok(match row {
480480+ None => false,
481481+ Some((count,)) => count > 0,
482482+ })
483483+}
484484+485485+/// Mask an email address into a hint like "2***0@p***m".
486486+pub fn mask_email(email: String) -> String {
487487+ // Basic split on first '@'
488488+ let mut parts = email.splitn(2, '@');
489489+ let local = match parts.next() {
490490+ Some(l) => l,
491491+ None => return email.to_string(),
492492+ };
493493+ let domain_rest = match parts.next() {
494494+ Some(d) if !d.is_empty() => d,
495495+ _ => return email.to_string(),
496496+ };
497497+498498+ // Helper to mask a single label (keep first and last, middle becomes ***).
499499+ fn mask_label(s: &str) -> String {
500500+ let chars: Vec<char> = s.chars().collect();
501501+ match chars.len() {
502502+ 0 => String::new(),
503503+ 1 => format!("{}***", chars[0]),
504504+ 2 => format!("{}***{}", chars[0], chars[1]),
505505+ _ => format!("{}***{}", chars[0], chars[chars.len() - 1]),
506506+ }
507507+ }
508508+509509+ // Mask local
510510+ let masked_local = mask_label(local);
511511+512512+ // Mask first domain label only, keep the rest of the domain intact
513513+ let mut dom_parts = domain_rest.splitn(2, '.');
514514+ let first_label = dom_parts.next().unwrap_or("");
515515+ let rest = dom_parts.next();
516516+ let masked_first = mask_label(first_label);
517517+ let masked_domain = if let Some(rest) = rest {
518518+ format!("{}.{rest}", masked_first)
519519+ } else {
520520+ masked_first
521521+ };
522522+523523+ format!("{masked_local}@{masked_domain}")
524524+}
+53-26
src/main.rs
···11+#![warn(clippy::unwrap_used)]
22+use crate::oauth_provider::sign_in;
13use crate::xrpc::com_atproto_server::{create_session, get_session, update_email};
22-use axum::middleware as ax_middleware;
33-mod middleware;
44use axum::body::Body;
55use axum::handler::Handler;
66use axum::http::{Method, header};
77+use axum::middleware as ax_middleware;
78use axum::routing::post;
89use axum::{Router, routing::get};
910use axum_template::engine::Engine;
···2122use tower_governor::governor::GovernorConfigBuilder;
2223use tower_http::compression::CompressionLayer;
2324use tower_http::cors::{Any, CorsLayer};
2424-use tracing::{error, log};
2525+use tracing::log;
2526use tracing_subscriber::{EnvFilter, fmt, prelude::*};
26272828+pub mod helpers;
2929+mod middleware;
3030+mod oauth_provider;
2731mod xrpc;
28322933type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
···3438struct EmailTemplates;
35393640#[derive(Clone)]
3737-struct AppState {
4141+pub struct AppState {
3842 account_pool: SqlitePool,
3943 pds_gatekeeper_pool: SqlitePool,
4044 reverse_proxy_client: HyperUtilClient,
···73777478 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";
75797676- let banner = format!(" {}\n{}", body, intro);
8080+ let banner = format!(" {body}\n{intro}");
77817882 (
7983 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
···8488#[tokio::main]
8589async fn main() -> Result<(), Box<dyn std::error::Error>> {
8690 setup_tracing();
8787- //TODO prod
9191+ //TODO may need to change where this reads from? Like an env variable for it's location? Or arg?
8892 dotenvy::from_path(Path::new("./pds.env"))?;
8993 let pds_root = env::var("PDS_DATA_DIRECTORY")?;
9090- // let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data";
9191- let account_db_url = format!("{}/account.sqlite", pds_root);
9292- log::info!("accounts_db_url: {}", account_db_url);
9494+ let account_db_url = format!("{pds_root}/account.sqlite");
93959496 let account_options = SqliteConnectOptions::new()
9595- .journal_mode(SqliteJournalMode::Wal)
9696- .filename(account_db_url);
9797+ .filename(account_db_url)
9898+ .busy_timeout(Duration::from_secs(5));
979998100 let account_pool = SqlitePoolOptions::new()
99101 .max_connections(5)
100102 .connect_with(account_options)
101103 .await?;
102104103103- let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root);
105105+ let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite");
104106 let options = SqliteConnectOptions::new()
105107 .journal_mode(SqliteJournalMode::Wal)
106108 .filename(bells_db_url)
107107- .create_if_missing(true);
109109+ .create_if_missing(true)
110110+ .busy_timeout(Duration::from_secs(5));
108111 let pds_gatekeeper_pool = SqlitePoolOptions::new()
109112 .max_connections(5)
110113 .connect_with(options)
111114 .await?;
112115113113- // Run migrations for the bells_and_whistles database
116116+ // Run migrations for the extra database
114117 // Note: the migrations are embedded at compile time from the given directory
115118 // sqlx
116119 sqlx::migrate!("./migrations")
···130133 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build();
131134 //Email templates setup
132135 let mut hbs = Handlebars::new();
133133- let _ = hbs.register_embed_templates::<EmailTemplates>();
136136+137137+ let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY");
138138+ if let Ok(users_email_directory) = users_email_directory {
139139+ hbs.register_template_file(
140140+ "two_factor_code.hbs",
141141+ format!("{users_email_directory}/two_factor_code.hbs"),
142142+ )?;
143143+ } else {
144144+ let _ = hbs.register_embed_templates::<EmailTemplates>();
145145+ }
146146+147147+ let pds_base_url =
148148+ env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string());
134149135150 let state = AppState {
136151 account_pool,
137152 pds_gatekeeper_pool,
138153 reverse_proxy_client: client,
139139- //TODO should be env prob
140140- pds_base_url: "http://localhost:3000".to_string(),
154154+ pds_base_url,
141155 mailer,
142156 mailer_from: sent_from,
143157 template_engine: Engine::from(hbs),
···145159146160 // Rate limiting
147161 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
148148- let governor_conf = GovernorConfigBuilder::default()
162162+ let create_session_governor_conf = GovernorConfigBuilder::default()
149163 .per_second(60)
150164 .burst_size(5)
151165 .finish()
152152- .unwrap();
153153- let governor_limiter = governor_conf.limiter().clone();
166166+ .expect("failed to create governor config. this should not happen and is a bug");
167167+168168+ // Create a second config with the same settings for the other endpoint
169169+ let sign_in_governor_conf = GovernorConfigBuilder::default()
170170+ .per_second(60)
171171+ .burst_size(5)
172172+ .finish()
173173+ .expect("failed to create governor config. this should not happen and is a bug");
174174+175175+ let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
176176+ let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
154177 let interval = Duration::from_secs(60);
155178 // a separate background task to clean up
156179 std::thread::spawn(move || {
157180 loop {
158181 std::thread::sleep(interval);
159159- tracing::info!("rate limiting storage size: {}", governor_limiter.len());
160160- governor_limiter.retain_recent();
182182+ create_session_governor_limiter.retain_recent();
183183+ sign_in_governor_limiter.retain_recent();
161184 }
162185 });
163186···177200 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
178201 )
179202 .route(
203203+ "/@atproto/oauth-provider/~api/sign-in",
204204+ post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)),
205205+ )
206206+ .route(
180207 "/xrpc/com.atproto.server.createSession",
181181- post(create_session.layer(GovernorLayer::new(governor_conf))),
208208+ post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
182209 )
183210 .layer(CompressionLayer::new())
184211 .layer(cors)
185212 .with_state(state);
186213187187- let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
188188- let port: u16 = env::var("PORT")
214214+ let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
215215+ let port: u16 = env::var("GATEKEEPER_PORT")
189216 .ok()
190217 .and_then(|s| s.parse().ok())
191218 .unwrap_or(8080);
···202229 .with_graceful_shutdown(shutdown_signal());
203230204231 if let Err(err) = server.await {
205205- error!(error = %err, "server error");
232232+ log::error!("server error:{err}");
206233 }
207234208235 Ok(())
+19-34
src/middleware.rs
···11-use crate::xrpc::helpers::json_error_response;
11+use crate::helpers::json_error_response;
22use axum::extract::Request;
33use axum::http::{HeaderMap, StatusCode};
44use axum::middleware::Next;
···77use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError};
88use serde::{Deserialize, Serialize};
99use std::env;
1010+use tracing::log;
10111112#[derive(Clone, Debug)]
1213pub struct Did(pub Option<String>);
···2223 match token {
2324 Ok(token) => {
2425 match token {
2525- None => {
2626- return json_error_response(
2727- StatusCode::BAD_REQUEST,
2828- "TokenRequired",
2929- "",
3030- ).unwrap();
3131- }
2626+ None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
2727+ .expect("Error creating an error response"),
3228 Some(token) => {
3329 let token = UntrustedToken::new(&token);
3434- //Doing weird unwraps cause I can't do Result for middleware?
3530 if token.is_err() {
3636- return json_error_response(
3737- StatusCode::BAD_REQUEST,
3838- "TokenRequired",
3939- "",
4040- ).unwrap();
3131+ return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
3232+ .expect("Error creating an error response");
4133 }
4242- let parsed_token = token.unwrap();
3434+ let parsed_token = token.expect("Already checked for error");
4335 let claims: Result<Claims<TokenClaims>, ValidationError> =
4436 parsed_token.deserialize_claims_unchecked();
4537 if claims.is_err() {
4646- return json_error_response(
4747- StatusCode::BAD_REQUEST,
4848- "TokenRequired",
4949- "",
5050- ).unwrap();
3838+ return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
3939+ .expect("Error creating an error response");
5140 }
52415353- let key = Hs256Key::new(env::var("PDS_JWT_SECRET").unwrap());
4242+ let key = Hs256Key::new(
4343+ env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"),
4444+ );
5445 let token: Result<Token<TokenClaims>, ValidationError> =
5546 Hs256.validator(&key).validate(&parsed_token);
5647 if token.is_err() {
5757- return json_error_response(
5858- StatusCode::BAD_REQUEST,
5959- "InvalidToken",
6060- "",
6161- ).unwrap();
4848+ return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
4949+ .expect("Error creating an error response");
6250 }
6363- let token = token.unwrap();
5151+ let token = token.expect("Already checked for error,");
6452 //Not going to worry about expiration since it still goes to the PDS
6565-6653 req.extensions_mut()
6754 .insert(Did(Some(token.claims().custom.sub.clone())));
6855 next.run(req).await
6956 }
7057 }
7158 }
7272- Err(_) => {
7373- return json_error_response(
7474- StatusCode::BAD_REQUEST,
7575- "InvalidToken",
7676- "",
7777- ).unwrap();
5959+ Err(err) => {
6060+ log::error!("Error extracting token: {err}");
6161+ json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
6262+ .expect("Error creating an error response")
7863 }
7964 }
8065}
+141
src/oauth_provider.rs
···11+use crate::AppState;
22+use crate::helpers::{AuthResult, oauth_json_error_response, preauth_check};
33+use axum::body::Body;
44+use axum::extract::State;
55+use axum::http::header::CONTENT_TYPE;
66+use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
77+use axum::response::{IntoResponse, Response};
88+use axum::{Json, extract};
99+use serde::{Deserialize, Serialize};
1010+use tracing::log;
1111+1212+#[derive(Serialize, Deserialize, Clone)]
1313+pub struct SignInRequest {
1414+ pub username: String,
1515+ pub password: String,
1616+ pub remember: bool,
1717+ pub locale: String,
1818+ #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")]
1919+ pub email_otp: Option<String>,
2020+}
2121+2222+pub async fn sign_in(
2323+ State(state): State<AppState>,
2424+ headers: HeaderMap,
2525+ Json(mut payload): extract::Json<SignInRequest>,
2626+) -> Result<Response<Body>, StatusCode> {
2727+ let identifier = payload.username.clone();
2828+ let password = payload.password.clone();
2929+ let auth_factor_token = payload.email_otp.clone();
3030+3131+ match preauth_check(&state, &identifier, &password, auth_factor_token, true).await {
3232+ Ok(result) => match result {
3333+ AuthResult::WrongIdentityOrPassword => oauth_json_error_response(
3434+ StatusCode::BAD_REQUEST,
3535+ "invalid_request",
3636+ "Invalid identifier or password",
3737+ ),
3838+ AuthResult::TwoFactorRequired(masked_email) => {
3939+ // Email sending step can be handled here if needed in the future.
4040+4141+ // {"error":"second_authentication_factor_required","error_description":"emailOtp authentication factor required (hint: 2***0@p***m)","type":"emailOtp","hint":"2***0@p***m"}
4242+ let body_str = match serde_json::to_string(&serde_json::json!({
4343+ "error": "second_authentication_factor_required",
4444+ "error_description": format!("emailOtp authentication factor required (hint: {})", masked_email),
4545+ "type": "emailOtp",
4646+ "hint": masked_email,
4747+ })) {
4848+ Ok(s) => s,
4949+ Err(_) => return Err(StatusCode::BAD_REQUEST),
5050+ };
5151+5252+ Response::builder()
5353+ .status(StatusCode::BAD_REQUEST)
5454+ .header(CONTENT_TYPE, "application/json")
5555+ .body(Body::from(body_str))
5656+ .map_err(|_| StatusCode::BAD_REQUEST)
5757+ }
5858+ AuthResult::ProxyThrough => {
5959+ //No 2FA or already passed
6060+ let uri = format!(
6161+ "{}{}",
6262+ state.pds_base_url, "/@atproto/oauth-provider/~api/sign-in"
6363+ );
6464+6565+ let mut req = axum::http::Request::post(uri);
6666+ if let Some(req_headers) = req.headers_mut() {
6767+ // Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers
6868+ copy_filtered_headers(&headers, req_headers);
6969+ //Setting the content type to application/json manually
7070+ req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
7171+ }
7272+7373+ //Clears the email_otp because the pds will reject a request with it.
7474+ payload.email_otp = None;
7575+ let payload_bytes =
7676+ serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
7777+7878+ let req = req
7979+ .body(Body::from(payload_bytes))
8080+ .map_err(|_| StatusCode::BAD_REQUEST)?;
8181+8282+ let proxied = state
8383+ .reverse_proxy_client
8484+ .request(req)
8585+ .await
8686+ .map_err(|_| StatusCode::BAD_REQUEST)?
8787+ .into_response();
8888+8989+ Ok(proxied)
9090+ }
9191+ //Ignoring the type of token check failure. Looks like oauth on the entry treads them the same.
9292+ AuthResult::TokenCheckFailed(_) => oauth_json_error_response(
9393+ StatusCode::BAD_REQUEST,
9494+ "invalid_request",
9595+ "Unable to sign-in due to an unexpected server error",
9696+ ),
9797+ },
9898+ Err(err) => {
9999+ log::error!(
100100+ "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}"
101101+ );
102102+ oauth_json_error_response(
103103+ StatusCode::BAD_REQUEST,
104104+ "pds_gatekeeper_error",
105105+ "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
106106+ )
107107+ }
108108+ }
109109+}
110110+111111+fn is_disallowed_header(name: &HeaderName) -> bool {
112112+ // possible problematic headers with proxying
113113+ matches!(
114114+ name.as_str(),
115115+ "connection"
116116+ | "keep-alive"
117117+ | "proxy-authenticate"
118118+ | "proxy-authorization"
119119+ | "te"
120120+ | "trailer"
121121+ | "transfer-encoding"
122122+ | "upgrade"
123123+ | "host"
124124+ | "content-length"
125125+ | "content-encoding"
126126+ | "expect"
127127+ | "accept-encoding"
128128+ )
129129+}
130130+131131+fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) {
132132+ for (name, value) in src.iter() {
133133+ if is_disallowed_header(name) {
134134+ continue;
135135+ }
136136+ // Only copy valid headers
137137+ if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) {
138138+ dst.insert(name.clone(), hv);
139139+ }
140140+ }
141141+}
+66-211
src/xrpc/com_atproto_server.rs
···11use crate::AppState;
22+use crate::helpers::{
33+ AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json,
44+};
25use crate::middleware::Did;
33-use crate::xrpc::helpers::{ProxiedResult, json_error_response, proxy_get_json};
46use axum::body::Body;
57use axum::extract::State;
68use axum::http::{HeaderMap, StatusCode};
79use axum::response::{IntoResponse, Response};
810use axum::{Extension, Json, debug_handler, extract, extract::Request};
99-use axum_template::TemplateEngine;
1010-use lettre::message::{MultiPart, SinglePart, header};
1111-use lettre::{AsyncTransport, Message};
1211use serde::{Deserialize, Serialize};
1312use serde_json;
1414-use serde_json::Value;
1515-use serde_json::value::Map;
1613use tracing::log;
17141815#[derive(Serialize, Deserialize, Debug, Clone)]
···5855pub struct CreateSessionRequest {
5956 identifier: String,
6057 password: String,
6161- auth_factor_token: String,
6262- allow_takendown: bool,
6363-}
6464-6565-pub enum AuthResult {
6666- WrongIdentityOrPassword,
6767- TwoFactorRequired,
6868- TwoFactorFailed,
6969- /// User does not have 2FA enabled, or passes it
7070- ProxyThrough,
7171-}
7272-7373-pub enum IdentifierType {
7474- Email,
7575- DID,
7676- Handle,
7777-}
7878-7979-impl IdentifierType {
8080- fn what_is_it(identifier: String) -> Self {
8181- if identifier.contains("@") {
8282- IdentifierType::Email
8383- } else if identifier.contains("did:") {
8484- IdentifierType::DID
8585- } else {
8686- IdentifierType::Handle
8787- }
8888- }
8989-}
9090-9191-async fn verify_password(password: &str, password_scrypt: &str) -> Result<bool, StatusCode> {
9292- // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes)
9393- let mut parts = password_scrypt.splitn(2, ':');
9494- let salt = match parts.next() {
9595- Some(s) if !s.is_empty() => s,
9696- _ => return Ok(false),
9797- };
9898- let stored_hash_hex = match parts.next() {
9999- Some(h) if !h.is_empty() => h,
100100- _ => return Ok(false),
101101- };
102102-103103- //Sets up scrypt to mimic node's scrypt
104104- let params = match scrypt::Params::new(14, 8, 1, 64) {
105105- Ok(p) => p,
106106- Err(_) => return Ok(false),
107107- };
108108- let mut derived = [0u8; 64];
109109- if scrypt::scrypt(password.as_bytes(), salt.as_bytes(), ¶ms, &mut derived).is_err() {
110110- return Ok(false);
111111- }
112112-113113- let stored_bytes = match hex::decode(stored_hash_hex) {
114114- Ok(b) => b,
115115- Err(e) => {
116116- log::error!("Error decoding stored hash: {}", e);
117117- return Ok(false);
118118- }
119119- };
120120-121121- Ok(derived.as_slice() == stored_bytes.as_slice())
122122-}
123123-124124-async fn preauth_check(
125125- state: &AppState,
126126- identifier: &str,
127127- password: &str,
128128-) -> Result<AuthResult, StatusCode> {
129129- // Determine identifier type
130130- let id_type = IdentifierType::what_is_it(identifier.to_string());
131131-132132- // Query account DB for did and passwordScrypt based on identifier type
133133- let account_row: Option<(String, String, String)> = match id_type {
134134- IdentifierType::Email => sqlx::query_as::<_, (String, String, String)>(
135135- "SELECT did, passwordScrypt, account.email FROM account WHERE email = ? LIMIT 1",
136136- )
137137- .bind(identifier)
138138- .fetch_optional(&state.account_pool)
139139- .await
140140- .map_err(|_| StatusCode::BAD_REQUEST)?,
141141- IdentifierType::Handle => sqlx::query_as::<_, (String, String, String)>(
142142- "SELECT account.did, account.passwordScrypt, account.email
143143- FROM actor
144144- LEFT JOIN account ON actor.did = account.did
145145- where actor.handle =? LIMIT 1",
146146- )
147147- .bind(identifier)
148148- .fetch_optional(&state.account_pool)
149149- .await
150150- .map_err(|_| StatusCode::BAD_REQUEST)?,
151151- IdentifierType::DID => sqlx::query_as::<_, (String, String, String)>(
152152- "SELECT did, passwordScrypt, account.email FROM account WHERE did = ? LIMIT 1",
153153- )
154154- .bind(identifier)
155155- .fetch_optional(&state.account_pool)
156156- .await
157157- .map_err(|_| StatusCode::BAD_REQUEST)?,
158158- };
159159-160160- if let Some((did, password_scrypt, email)) = account_row {
161161- // Check two-factor requirement for this DID in the gatekeeper DB
162162- let required_opt = sqlx::query_as::<_, (u8,)>(
163163- "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
164164- )
165165- .bind(&did)
166166- .fetch_optional(&state.pds_gatekeeper_pool)
167167- .await
168168- .map_err(|_| StatusCode::BAD_REQUEST)?;
169169-170170- let two_factor_required = match required_opt {
171171- Some(row) => row.0 != 0,
172172- None => false,
173173- };
174174-175175- if two_factor_required {
176176- // Verify password before proceeding to 2FA email step
177177- let verified = verify_password(password, &password_scrypt).await?;
178178- if !verified {
179179- return Ok(AuthResult::WrongIdentityOrPassword);
180180- }
181181- let mut email_data = Map::new();
182182- //TODO these need real values
183183- let token = "test".to_string();
184184- let handle = "baileytownsend.dev".to_string();
185185- email_data.insert("token".to_string(), Value::from(token.clone()));
186186- email_data.insert("handle".to_string(), Value::from(handle.clone()));
187187- //TODO bad unwrap
188188- let email_body = state
189189- .template_engine
190190- .render("two_factor_code.hbs", email_data)
191191- .unwrap();
192192-193193- let email = Message::builder()
194194- //TODO prob get the proper type in the state
195195- .from(state.mailer_from.parse().unwrap())
196196- .to(email.parse().unwrap())
197197- .subject("Sign in to Bluesky")
198198- .multipart(
199199- MultiPart::alternative() // This is composed of two parts.
200200- .singlepart(
201201- SinglePart::builder()
202202- .header(header::ContentType::TEXT_PLAIN)
203203- .body(format!("We received a sign-in request for the account @{}. Use the code: {} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.", handle, token)), // Every message should have a plain text fallback.
204204- )
205205- .singlepart(
206206- SinglePart::builder()
207207- .header(header::ContentType::TEXT_HTML)
208208- .body(email_body),
209209- ),
210210- )
211211- //TODO bad
212212- .unwrap();
213213- return match state.mailer.send(email).await {
214214- Ok(_) => Ok(AuthResult::TwoFactorRequired),
215215- Err(err) => {
216216- log::error!("Error sending the 2FA email: {}", err);
217217- Err(StatusCode::BAD_REQUEST)
218218- }
219219- };
220220- }
221221- }
222222-223223- // No local 2FA requirement (or account not found)
224224- Ok(AuthResult::ProxyThrough)
5858+ #[serde(skip_serializing_if = "Option::is_none")]
5959+ auth_factor_token: Option<String>,
6060+ #[serde(skip_serializing_if = "Option::is_none")]
6161+ allow_takendown: Option<bool>,
22562}
2266322764pub async fn create_session(
···23168) -> Result<Response<Body>, StatusCode> {
23269 let identifier = payload.identifier.clone();
23370 let password = payload.password.clone();
7171+ let auth_factor_token = payload.auth_factor_token.clone();
2347223573 // Run the shared pre-auth logic to validate and check 2FA requirement
236236- match preauth_check(&state, &identifier, &password).await? {
237237- AuthResult::WrongIdentityOrPassword => json_error_response(
238238- StatusCode::UNAUTHORIZED,
239239- "AuthenticationRequired",
240240- "Invalid identifier or password",
241241- ),
242242- AuthResult::TwoFactorRequired => {
243243- // Email sending step can be handled here if needed in the future.
244244- json_error_response(
7474+ match preauth_check(&state, &identifier, &password, auth_factor_token, false).await {
7575+ Ok(result) => match result {
7676+ AuthResult::WrongIdentityOrPassword => json_error_response(
24577 StatusCode::UNAUTHORIZED,
246246- "AuthFactorTokenRequired",
247247- "A sign in code has been sent to your email address",
248248- )
249249- }
250250- AuthResult::TwoFactorFailed => {
251251- //Not sure what the errors are for this response is yet
252252- json_error_response(StatusCode::UNAUTHORIZED, "PLACEHOLDER", "PLACEHOLDER")
253253- }
254254- AuthResult::ProxyThrough => {
255255- //No 2FA or already passed
256256- let uri = format!(
257257- "{}{}",
258258- state.pds_base_url, "/xrpc/com.atproto.server.createSession"
259259- );
260260-261261- let mut req = axum::http::Request::post(uri);
262262- if let Some(req_headers) = req.headers_mut() {
263263- req_headers.extend(headers.clone());
7878+ "AuthenticationRequired",
7979+ "Invalid identifier or password",
8080+ ),
8181+ AuthResult::TwoFactorRequired(_) => {
8282+ // Email sending step can be handled here if needed in the future.
8383+ json_error_response(
8484+ StatusCode::UNAUTHORIZED,
8585+ "AuthFactorTokenRequired",
8686+ "A sign in code has been sent to your email address",
8787+ )
26488 }
8989+ AuthResult::ProxyThrough => {
9090+ log::info!("Proxying through");
9191+ //No 2FA or already passed
9292+ let uri = format!(
9393+ "{}{}",
9494+ state.pds_base_url, "/xrpc/com.atproto.server.createSession"
9595+ );
26596266266- let payload_bytes =
267267- serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
268268- let req = req
269269- .body(Body::from(payload_bytes))
270270- .map_err(|_| StatusCode::BAD_REQUEST)?;
9797+ let mut req = axum::http::Request::post(uri);
9898+ if let Some(req_headers) = req.headers_mut() {
9999+ req_headers.extend(headers.clone());
100100+ }
101101+102102+ let payload_bytes =
103103+ serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
104104+ let req = req
105105+ .body(Body::from(payload_bytes))
106106+ .map_err(|_| StatusCode::BAD_REQUEST)?;
271107272272- let proxied = state
273273- .reverse_proxy_client
274274- .request(req)
275275- .await
276276- .map_err(|_| StatusCode::BAD_REQUEST)?
277277- .into_response();
108108+ let proxied = state
109109+ .reverse_proxy_client
110110+ .request(req)
111111+ .await
112112+ .map_err(|_| StatusCode::BAD_REQUEST)?
113113+ .into_response();
278114279279- Ok(proxied)
115115+ Ok(proxied)
116116+ }
117117+ AuthResult::TokenCheckFailed(err) => match err {
118118+ TokenCheckError::InvalidToken => {
119119+ json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid")
120120+ }
121121+ TokenCheckError::ExpiredToken => {
122122+ json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired")
123123+ }
124124+ },
125125+ },
126126+ Err(err) => {
127127+ log::error!(
128128+ "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}"
129129+ );
130130+ json_error_response(
131131+ StatusCode::INTERNAL_SERVER_ERROR,
132132+ "InternalServerError",
133133+ "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
134134+ )
280135 }
281136 }
282137}
···290145) -> Result<Response<Body>, StatusCode> {
291146 //If email auth is not set at all it is a update email address
292147 let email_auth_not_set = payload.email_auth_factor.is_none();
293293- //If email aurth is set it is to either turn on or off 2fa
148148+ //If email auth is set it is to either turn on or off 2fa
294149 let email_auth_update = payload.email_auth_factor.unwrap_or(false);
295150296151 // Email update asked for
···350205 }
351206 }
352207353353- // Updating the acutal email address
208208+ // Updating the actual email address by sending it on to the PDS
354209 let uri = format!(
355210 "{}{}",
356211 state.pds_base_url, "/xrpc/com.atproto.server.updateEmail"
-150
src/xrpc/helpers.rs
···11-use axum::body::{Body, to_bytes};
22-use axum::extract::Request;
33-use axum::http::{HeaderMap, Method, StatusCode, Uri};
44-use axum::http::header::CONTENT_TYPE;
55-use axum::response::{IntoResponse, Response};
66-use serde::de::DeserializeOwned;
77-use tracing::error;
88-99-use crate::AppState;
1010-1111-/// The result of a proxied call that attempts to parse JSON.
1212-pub enum ProxiedResult<T> {
1313- /// Successfully parsed JSON body along with original response headers.
1414- Parsed { value: T, _headers: HeaderMap },
1515- /// Could not or should not parse: return the original (or rebuilt) response as-is.
1616- Passthrough(Response<Body>),
1717-}
1818-1919-/// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse
2020-/// the successful response body as JSON into `T`.
2121-///
2222-/// Behavior:
2323-/// - If the proxied response is non-200, returns Passthrough with the original response.
2424-/// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers.
2525-/// - If parsing succeeds, returns Parsed { value, headers }.
2626-pub async fn proxy_get_json<T>(
2727- state: &AppState,
2828- mut req: Request,
2929- path: &str,
3030-) -> Result<ProxiedResult<T>, StatusCode>
3131-where
3232- T: DeserializeOwned,
3333-{
3434- let uri = format!("{}{}", state.pds_base_url, path);
3535- *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
3636-3737- let result = state
3838- .reverse_proxy_client
3939- .request(req)
4040- .await
4141- .map_err(|_| StatusCode::BAD_REQUEST)?
4242- .into_response();
4343-4444- if result.status() != StatusCode::OK {
4545- return Ok(ProxiedResult::Passthrough(result));
4646- }
4747-4848- let response_headers = result.headers().clone();
4949- let body = result.into_body();
5050- let body_bytes = to_bytes(body, usize::MAX)
5151- .await
5252- .map_err(|_| StatusCode::BAD_REQUEST)?;
5353-5454- match serde_json::from_slice::<T>(&body_bytes) {
5555- Ok(value) => Ok(ProxiedResult::Parsed {
5656- value,
5757- _headers: response_headers,
5858- }),
5959- Err(err) => {
6060- error!(%err, "failed to parse proxied JSON response; returning original body");
6161- let mut builder = Response::builder().status(StatusCode::OK);
6262- if let Some(headers) = builder.headers_mut() {
6363- *headers = response_headers;
6464- }
6565- let resp = builder
6666- .body(Body::from(body_bytes))
6767- .map_err(|_| StatusCode::BAD_REQUEST)?;
6868- Ok(ProxiedResult::Passthrough(resp))
6969- }
7070- }
7171-}
7272-7373-/// Proxy the incoming request as a POST to the PDS base URL plus the provided path and attempt to parse
7474-/// the successful response body as JSON into `T`.
7575-///
7676-/// Behavior mirrors `proxy_get_json`:
7777-/// - If the proxied response is non-200, returns Passthrough with the original response.
7878-/// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers.
7979-/// - If parsing succeeds, returns Parsed { value, headers }.
8080-pub async fn _proxy_post_json<T>(
8181- state: &AppState,
8282- mut req: Request,
8383- path: &str,
8484-) -> Result<ProxiedResult<T>, StatusCode>
8585-where
8686- T: DeserializeOwned,
8787-{
8888- let uri = format!("{}{}", state.pds_base_url, path);
8989- *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
9090- *req.method_mut() = Method::POST;
9191-9292- let result = state
9393- .reverse_proxy_client
9494- .request(req)
9595- .await
9696- .map_err(|_| StatusCode::BAD_REQUEST)?
9797- .into_response();
9898-9999- if result.status() != StatusCode::OK {
100100- return Ok(ProxiedResult::Passthrough(result));
101101- }
102102-103103- let response_headers = result.headers().clone();
104104- let body = result.into_body();
105105- let body_bytes = to_bytes(body, usize::MAX)
106106- .await
107107- .map_err(|_| StatusCode::BAD_REQUEST)?;
108108-109109- match serde_json::from_slice::<T>(&body_bytes) {
110110- Ok(value) => Ok(ProxiedResult::Parsed {
111111- value,
112112- _headers: response_headers,
113113- }),
114114- Err(err) => {
115115- error!(%err, "failed to parse proxied JSON response (POST); returning original body");
116116- let mut builder = Response::builder().status(StatusCode::OK);
117117- if let Some(headers) = builder.headers_mut() {
118118- *headers = response_headers;
119119- }
120120- let resp = builder
121121- .body(Body::from(body_bytes))
122122- .map_err(|_| StatusCode::BAD_REQUEST)?;
123123- Ok(ProxiedResult::Passthrough(resp))
124124- }
125125- }
126126-}
127127-128128-129129-/// Build a JSON error response with the required Content-Type header
130130-/// Content-Type: application/json;charset=utf-8
131131-/// Body shape: { "error": string, "message": string }
132132-pub fn json_error_response(
133133- status: StatusCode,
134134- error: impl Into<String>,
135135- message: impl Into<String>,
136136-) -> Result<Response<Body>, StatusCode> {
137137- let body_str = match serde_json::to_string(&serde_json::json!({
138138- "error": error.into(),
139139- "message": message.into(),
140140- })) {
141141- Ok(s) => s,
142142- Err(_) => return Err(StatusCode::BAD_REQUEST),
143143- };
144144-145145- Response::builder()
146146- .status(status)
147147- .header(CONTENT_TYPE, "application/json;charset=utf-8")
148148- .body(Body::from(body_str))
149149- .map_err(|_| StatusCode::BAD_REQUEST)
150150-}
-1
src/xrpc/mod.rs
···11pub mod com_atproto_server;
22-pub mod helpers;