social media crossposting tool. 3rd time's the charm
mastodon misskey crossposting bluesky

use jwt key expiry time instead of a random ttl. minor cleanup

zenfyr.dev acff38ee 49f701f6

verified
+69 -111
+27 -54
atproto/auth.py
··· 12 12 13 13 def login(self, identifier: str, password: str) -> Session: 14 14 cached = self.store.get_session_by_pds(self.pds_url, identifier) 15 - if cached: 15 + if cached and not cached.is_refresh_token_expired(): 16 16 LOGGER.info("Using cached session for %s", identifier) 17 17 return cached 18 18 return self.create_session(identifier, password) 19 19 20 + def get_session(self, did: str) -> Session | None: 21 + session = self.store.get_session(did) 22 + if not session: 23 + return None 24 + if session.is_access_token_expired(): 25 + if not session.is_refresh_token_expired(): 26 + LOGGER.info("refreshing session for %s", session.did) 27 + return self.refresh_session(session) 28 + LOGGER.info("both tokens expired for %s, removing session", session.did) 29 + self.store.remove_session(did) 30 + raise ValueError("Both access and refresh tokens expired. Please login again.") 31 + return session 32 + 20 33 def create_session( 21 34 self, 22 35 identifier: str, ··· 24 37 auth_factor_token: str | None = None, 25 38 ) -> Session: 26 39 url = f"{self.pds_url}/xrpc/com.atproto.server.createSession" 27 - payload: dict[str, str] = { 28 - "identifier": identifier, 29 - "password": password, 30 - } 40 + payload: dict[str, str] = {"identifier": identifier, "password": password} 31 41 if auth_factor_token: 32 42 payload["authFactorToken"] = auth_factor_token 33 43 ··· 43 53 case _: 44 54 raise ValueError(f"Authentication failed with status {response.status_code}") 45 55 46 - data = response.json() 47 - session = Session.from_dict(data, self.pds_url) 56 + session = Session.from_dict(response.json(), self.pds_url) 48 57 self.store.set_session(session) 49 58 LOGGER.info("Created session for %s (%s)", session.handle, session.did) 50 59 return session ··· 62 71 error_data = response.json() if response.content else {} 63 72 raise ValueError(f"Refresh failed: {error_data}") 64 73 case 400: 65 - error_data = response.json() 66 - raise ValueError(f"Refresh failed: {error_data}") 74 + raise ValueError(f"Refresh failed: {response.json()}") 67 75 case _: 68 76 raise ValueError(f"Refresh failed with status {response.status_code}") 69 77 70 - data = response.json() 71 - new_session = Session.from_dict(data, self.pds_url) 78 + new_session = Session.from_dict(response.json(), self.pds_url) 72 79 self.store.set_session(new_session) 73 80 LOGGER.info("Refreshed session for %s (%s)", new_session.handle, new_session.did) 74 81 return new_session 75 82 76 - def get_session(self, did: str) -> Session | None: 77 - return self.store.get_session(did) 78 - 79 83 def get_access_token(self, did: str) -> str | None: 80 84 session = self.get_session(did) 81 - if not session: 82 - return None 83 - return session.access_jwt 85 + return session.access_jwt if session else None 84 86 85 - def list_sessions(self) -> list[Session]: 86 - return self.store.list_sessions_by_pds(self.pds_url) 87 87 88 - def remove_session(self, did: str) -> None: 89 - self.store.remove_session(did) 90 - 91 - 92 - _auth_instances: dict[str, PDSAuth] = None # type: ignore 88 + _auth_instances: dict[str, PDSAuth] = {} 93 89 _store: AtprotoStore | None = None 94 90 95 91 96 92 def init_atproto_store(db) -> AtprotoStore: 97 - global _store, _auth_instances 93 + global _store 98 94 if _store is None: 99 95 _store = AtprotoStore(db.get_conn()) 100 - _auth_instances = {} 101 - return _store 102 - 103 - 104 - def get_atproto_store() -> AtprotoStore | None: 105 96 return _store 106 97 107 98 ··· 114 105 return _auth_instances[normalized] 115 106 116 107 117 - def get_auth_by_did(did: str) -> PDSAuth | None: 118 - if _store is None: 119 - return None 120 - for auth in _auth_instances.values(): 121 - if auth.get_session(did): 122 - return auth 123 - return None 124 - 125 - 126 - def cleanup_expired() -> None: 127 - if _store is not None: 128 - _store.cleanup_expired() 129 - 130 - 131 - def flush_caches() -> tuple[int, int]: 132 - if _store is not None: 133 - return _store.flush_all() 134 - return 0, 0 135 - 136 - 137 108 def resolve_identity(identifier: str) -> IdentityInfo: 138 109 if _store is None: 139 110 raise RuntimeError("AtprotoStore not initialized") ··· 159 130 raise ValueError(f"Failed to resolve identity {identifier}: {e}") from e 160 131 161 132 162 - def _cleanup_hook(): 163 - if _store: 164 - cleanup_expired() 133 + shutdown_hook.append(lambda: _store.cleanup_expired() if _store else None) 165 134 166 135 167 - shutdown_hook.append(_cleanup_hook) 136 + def flush_caches() -> tuple[int, int]: 137 + if _store is not None: 138 + return _store.flush_all() 139 + return 0, 0 140 +
+41 -56
atproto/store.py
··· 1 + import base64 2 + import json 1 3 import sqlite3 2 4 import time 3 5 from dataclasses import dataclass 6 + from functools import cached_property 4 7 from typing import Any 5 8 6 9 10 + def _decode_jwt_payload(token: str) -> dict[str, Any]: 11 + try: 12 + _, claims, _ = token.split(".") 13 + claims = claims + '=' * (4 - len(claims) % 4) if len(claims) % 4 else claims 14 + return json.loads(base64.urlsafe_b64decode(claims)) # type: ignore[no-any-return] 15 + except Exception: 16 + return {} 17 + 18 + 7 19 @dataclass 8 20 class Session: 9 21 access_jwt: str ··· 17 29 active: bool = True 18 30 status: str | None = None 19 31 32 + @cached_property 33 + def access_payload(self) -> dict[str, Any]: 34 + return _decode_jwt_payload(self.access_jwt) 35 + 36 + @cached_property 37 + def refresh_payload(self) -> dict[str, Any]: 38 + return _decode_jwt_payload(self.refresh_jwt) 39 + 40 + def is_access_token_expired(self, buffer_seconds: int = 60) -> bool: 41 + exp = self.access_payload.get("exp", 0) 42 + return bool(time.time() >= (exp - buffer_seconds)) 43 + 44 + def is_refresh_token_expired(self, buffer_seconds: int = 60) -> bool: 45 + exp = self.refresh_payload.get("exp", 0) 46 + return bool(time.time() >= (exp - buffer_seconds)) 47 + 20 48 @classmethod 21 49 def from_row(cls, row: sqlite3.Row) -> "Session": 22 50 return cls( ··· 78 106 def __init__( 79 107 self, 80 108 db: sqlite3.Connection, 81 - session_ttl: int = 2 * 60 * 60, 82 109 identity_ttl: int = 12 * 60 * 60, 83 110 ) -> None: 84 111 self.db = db 85 112 self.db.row_factory = sqlite3.Row 86 - self.session_ttl = session_ttl 87 113 self.identity_ttl = identity_ttl 88 114 89 115 def get_session(self, did: str) -> Session | None: 90 - row = self.db.execute( 91 - "SELECT * FROM atproto_sessions WHERE did = ? AND created_at + ? > ?", 92 - (did, self.session_ttl, time.time()) 93 - ).fetchone() 94 - if not row: 95 - return None 96 - return Session.from_row(row) 116 + row = self.db.execute("SELECT * FROM atproto_sessions WHERE did = ?", (did,)).fetchone() 117 + return Session.from_row(row) if row else None 97 118 98 119 def set_session(self, session: Session) -> None: 99 120 now = time.time() ··· 103 124 email_auth_factor, active, status, created_at) 104 125 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 105 126 """, ( 106 - session.did, 107 - session.pds, 108 - session.handle, 109 - session.access_jwt, 110 - session.refresh_jwt, 111 - session.email, 112 - session.email_confirmed, 113 - session.email_auth_factor, 114 - session.active, 115 - session.status, 116 - now, 127 + session.did, session.pds, session.handle, session.access_jwt, 128 + session.refresh_jwt, session.email, session.email_confirmed, 129 + session.email_auth_factor, session.active, session.status, now, 117 130 )) 118 131 self.db.commit() 119 132 120 133 def get_session_by_pds(self, pds: str, identifier: str) -> Session | None: 121 134 row = self.db.execute(""" 122 135 SELECT * FROM atproto_sessions 123 - WHERE pds = ? AND (did = ? OR handle = ?) AND created_at + ? > ? 124 - """, (pds, identifier, identifier, self.session_ttl, time.time())).fetchone() 125 - if not row: 126 - return None 127 - return Session.from_row(row) 136 + WHERE pds = ? AND (did = ? OR handle = ?) 137 + """, (pds, identifier, identifier)).fetchone() 138 + return Session.from_row(row) if row else None 128 139 129 140 def list_sessions_by_pds(self, pds: str) -> list[Session]: 130 - rows = self.db.execute(""" 131 - SELECT * FROM atproto_sessions 132 - WHERE pds = ? AND created_at + ? > ? 133 - """, (pds, self.session_ttl, time.time())).fetchall() 141 + rows = self.db.execute("SELECT * FROM atproto_sessions WHERE pds = ?", (pds,)).fetchall() 134 142 return [Session.from_row(row) for row in rows] 135 143 136 144 def remove_session(self, did: str) -> None: 137 145 self.db.execute("DELETE FROM atproto_sessions WHERE did = ?", (did,)) 138 146 self.db.commit() 139 147 140 - def cleanup_expired_sessions(self) -> int: 141 - cutoff = time.time() - self.session_ttl 142 - cursor = self.db.execute( 143 - "DELETE FROM atproto_sessions WHERE created_at + ? < ?", 144 - (self.session_ttl, cutoff) 145 - ) 146 - self.db.commit() 147 - return cursor.rowcount 148 - 149 148 def get_identity(self, identifier: str) -> IdentityInfo | None: 150 149 row = self.db.execute( 151 150 "SELECT * FROM atproto_identities WHERE identifier = ? AND created_at + ? > ?", 152 151 (identifier, self.identity_ttl, time.time()) 153 152 ).fetchone() 154 - if not row: 155 - return None 156 - return IdentityInfo.from_row(row) 153 + return IdentityInfo.from_row(row) if row else None 157 154 158 155 def set_identity(self, identifier: str, identity: IdentityInfo) -> None: 159 156 now = time.time() ··· 162 159 INSERT OR REPLACE INTO atproto_identities 163 160 (identifier, did, handle, pds, signing_key, created_at) 164 161 VALUES (?, ?, ?, ?, ?, ?) 165 - """, ( 166 - key, 167 - identity.did, 168 - identity.handle, 169 - identity.pds, 170 - identity.signing_key, 171 - now, 172 - )) 162 + """, (key, identity.did, identity.handle, identity.pds, identity.signing_key, now)) 173 163 self.db.commit() 174 164 175 165 def remove_identity(self, identifier: str) -> None: 176 166 self.db.execute("DELETE FROM atproto_identities WHERE identifier = ?", (identifier,)) 177 167 self.db.commit() 178 168 179 - def cleanup_expired_identities(self) -> int: 169 + def cleanup_expired(self) -> None: 180 170 cutoff = time.time() - self.identity_ttl 181 - cursor = self.db.execute( 171 + self.db.execute( 182 172 "DELETE FROM atproto_identities WHERE created_at + ? < ?", 183 173 (self.identity_ttl, cutoff) 184 174 ) 185 175 self.db.commit() 186 - return cursor.rowcount 187 - 188 - def cleanup_expired(self) -> None: 189 - self.cleanup_expired_sessions() 190 - self.cleanup_expired_identities() 191 176 192 177 def flush_all(self) -> tuple[int, int]: 193 178 """Delete all cached sessions and identities."""
+1 -1
bluesky/output.py
··· 29 29 def __init__(self, db: DatabasePool, options: BlueskyOutputOptions) -> None: 30 30 super().__init__(SERVICE, db) 31 31 self.options: BlueskyOutputOptions = options 32 - self._init_identity() 33 32 init_atproto_store(db) 33 + self._init_identity() 34 34 self._auth = get_auth(self.pds) 35 35 self._auth.login(self.did, options.password) 36 36