Bluesky feed server - NSFW Likes

refactor: Move auth logic to Ktor plugin

Closes #7

+341 -184
+14 -9
darkfeed/src/main/kotlin/Main.kt
··· 1 1 package rs.averyrive.darkfeed 2 2 3 + import AuthAccount 4 + import AuthManager 3 5 import io.ktor.http.* 4 6 import kotlinx.coroutines.launch 5 7 import kotlinx.coroutines.runBlocking ··· 18 20 /** Hostname of the feed generator server. */ 19 21 val hostname: String, 20 22 /** Record key for the feed generator record. */ 21 - val recordKey: String = "darkfeed", 23 + val recordKey: String = "darkfeed-dev", 22 24 /** Display name for the feed. */ 23 - val feedDisplayName: String = "DarkFeed", 25 + val feedDisplayName: String = "DarkFeed (Dev)", 24 26 /** Description for the feed. */ 25 27 val description: String = "hi :3", 26 28 ) ··· 78 80 ?: printMessageAndExit("error: variable HOSTNAME not set"), 79 81 ) 80 82 81 - // Create API instance. 82 - val bskyApi = BskyApi(buildUrl { 83 - protocol = URLProtocol.HTTPS 84 - host = ctx.ownerPds 85 - }) 83 + // Create Bluesky API instance. 84 + val bskyApi = BskyApi( 85 + authManager = AuthManager( 86 + authAccount = AuthAccount( 87 + username = ctx.ownerDid, 88 + password = ctx.ownerPassword, 89 + pdsHost = ctx.ownerPds, 90 + ), 91 + ) 92 + ) 86 93 87 94 // Verify and update the feed generator record. 88 95 launch { 89 - bskyApi.login(ctx.ownerDid, ctx.ownerPassword) 90 - 91 96 try { 92 97 verifyAndUpdateFeedGeneratorRecord(bskyApi, ctx) 93 98 println("main: feed generator record verified")
+163
darkfeed/src/main/kotlin/api/AuthManager.kt
··· 1 + import io.ktor.client.* 2 + import io.ktor.client.call.* 3 + import io.ktor.client.engine.cio.* 4 + import io.ktor.client.plugins.contentnegotiation.* 5 + import io.ktor.client.plugins.logging.* 6 + import io.ktor.client.request.* 7 + import io.ktor.http.* 8 + import io.ktor.serialization.kotlinx.json.* 9 + import kotlinx.serialization.Serializable 10 + import kotlinx.serialization.json.Json 11 + import org.slf4j.Logger 12 + import org.slf4j.LoggerFactory 13 + 14 + /** 15 + * Stores auth tokens and handles all auth related requests. 16 + * 17 + * @param authAccount ATProto account to use for authorization. 18 + * @param httpClient Client to use for requests. 19 + */ 20 + class AuthManager( 21 + private val authAccount: AuthAccount, 22 + private val httpClient: HttpClient = HttpClient(CIO) { 23 + install(ContentNegotiation) { 24 + json(Json { 25 + explicitNulls = false 26 + ignoreUnknownKeys = true 27 + }) 28 + } 29 + 30 + install(Logging) 31 + }, 32 + ) { 33 + /** The current bearer tokens. */ 34 + var authTokens: AuthTokens? = null 35 + private set 36 + 37 + /** Logger to use for this class. */ 38 + private val log: Logger = LoggerFactory.getLogger(this.javaClass) 39 + 40 + /** Create a new auth session. */ 41 + suspend fun createSession() { 42 + @Serializable 43 + data class Request(val identifier: String, val password: String) 44 + 45 + val createSessionUrl = buildUrl { 46 + protocol = URLProtocol.HTTPS 47 + host = this@AuthManager.authAccount.pdsHost 48 + path("/xrpc/com.atproto.server.createSession") 49 + } 50 + 51 + val requestBody = Request( 52 + this.authAccount.username, 53 + this.authAccount.password, 54 + ) 55 + 56 + log.debug( 57 + "Creating new session for '{}' at '{}'.", 58 + requestBody.identifier, 59 + createSessionUrl.hostWithPortIfSpecified, 60 + ) 61 + 62 + val response = this.httpClient.post(createSessionUrl) { 63 + contentType(ContentType.Application.Json) 64 + setBody(requestBody) 65 + } 66 + 67 + log.debug( 68 + "Received response for new session with status '{}'.", 69 + response.status, 70 + ) 71 + 72 + when (response.status) { 73 + HttpStatusCode.OK -> { 74 + try { 75 + val authSession: AuthSession = response.body() 76 + this.authTokens = authSession.into() 77 + log.debug("New session created, updated auth tokens.") 78 + } catch (error: Exception) { 79 + TODO("Handle deserialization errors") 80 + } 81 + } 82 + 83 + else -> { 84 + TODO("Handle failures") 85 + } 86 + } 87 + } 88 + 89 + /** Refresh the current auth session. */ 90 + suspend fun refreshSession() { 91 + val refreshSessionUrl = buildUrl { 92 + protocol = URLProtocol.HTTPS 93 + host = this@AuthManager.authAccount.pdsHost 94 + path("/xrpc/com.atproto.server.refreshSession") 95 + } 96 + 97 + val refreshToken = this.authTokens?.refreshToken!! 98 + val authHeaderValue = "Bearer ${refreshToken}" 99 + 100 + log.debug("Refreshing current session at '{}'", refreshSessionUrl.hostWithPortIfSpecified) 101 + 102 + val response = this.httpClient.post(refreshSessionUrl) { 103 + header(HttpHeaders.Authorization, authHeaderValue) 104 + } 105 + 106 + log.debug("Received session refresh response with status {}", response.status) 107 + 108 + when (response.status) { 109 + HttpStatusCode.OK -> { 110 + try { 111 + val authSession: AuthSession = response.body() 112 + this.authTokens = authSession.into() 113 + log.debug("Session refreshed, updated auth tokens") 114 + } catch (error: Exception) { 115 + TODO("Handle deserialization errors") 116 + } 117 + } 118 + 119 + else -> { 120 + TODO("Handle failures") 121 + } 122 + } 123 + } 124 + } 125 + 126 + /** 127 + * ATProto account details. 128 + * 129 + * @param username Account handle or DID. 130 + * @param password Account password or app password. 131 + * @param pdsHost Hostname of account's PDS. 132 + */ 133 + data class AuthAccount( 134 + val username: String, 135 + val password: String, 136 + val pdsHost: String, 137 + ) 138 + 139 + /** 140 + * ATProto bearer tokens. 141 + * 142 + * @param accessToken Token used for normal authorized requests. 143 + * @param refreshToken Token used for session refresh requests. 144 + */ 145 + data class AuthTokens( 146 + val accessToken: String, 147 + val refreshToken: String, 148 + ) 149 + 150 + /** 151 + * ATProto auth session details. 152 + * 153 + * @param accessJwt: Token used for normal authorized requests. 154 + * @param refreshJwt: Token used to session refresh requests. 155 + */ 156 + @Serializable 157 + data class AuthSession( 158 + val accessJwt: String, 159 + val refreshJwt: String 160 + ) { 161 + /** Create `AuthTokens` from an `AuthSession`. */ 162 + fun into(): AuthTokens = AuthTokens(this.accessJwt, this.refreshJwt) 163 + }
+158
darkfeed/src/main/kotlin/api/AuthPlugin.kt
··· 1 + package rs.averyrive.darkfeed.api 2 + 3 + import AuthManager 4 + import io.ktor.client.call.* 5 + import io.ktor.client.plugins.api.* 6 + import io.ktor.client.statement.* 7 + import io.ktor.http.* 8 + import kotlinx.coroutines.sync.Mutex 9 + import kotlinx.coroutines.sync.withLock 10 + import kotlinx.serialization.Serializable 11 + import org.slf4j.LoggerFactory 12 + 13 + const val PLUGIN_NAME: String = "AuthPlugin" 14 + 15 + val AuthPlugin = createClientPlugin(PLUGIN_NAME, ::AuthPluginConfig) { 16 + val authManager = pluginConfig.authManager ?: throw AuthPluginConfigurationError("Auth manager is required") 17 + val authMutex = Mutex() 18 + val log = LoggerFactory.getLogger(PLUGIN_NAME) 19 + 20 + // Add authorization header to requests. 21 + onRequest { request, _ -> 22 + // Format the request's endpoint as '<protocol>://<host>/<path>' for use in logs. 23 + val endpoint = with(request.url) { "${protocol.name}://${host}${encodedPath}" } 24 + 25 + // Remove any existing `Authorization` headers. 26 + if (request.headers.contains(HttpHeaders.Authorization)) { 27 + log.info("Replacing 'Authorization' header on request to '{}'.", endpoint) 28 + request.headers.remove(HttpHeaders.Authorization) 29 + } 30 + 31 + // If another request is already refreshing tokens, this request will 32 + // pause until the other request is finished, then use the new tokens. 33 + val accessToken = authMutex.withLock { 34 + authManager.authTokens?.accessToken 35 + } 36 + 37 + // If the auth manager doesn't have auth tokens, try running the request 38 + // normally. This is a last resort in case the request doesn't require 39 + // authorization. 40 + if (accessToken == null) { 41 + log.debug( 42 + "No access token retrieved from 'AuthManager'. Sending request to '{}' without 'Authorization' header.", 43 + endpoint 44 + ) 45 + return@onRequest 46 + } 47 + 48 + // Add the authorization header. 49 + request.headers.append(HttpHeaders.Authorization, "Bearer $accessToken") 50 + 51 + log.debug("'Authorization' header added to request to '{}'.", endpoint) 52 + } 53 + 54 + // Check responses for authorization failures. 55 + on(Send) { request -> 56 + // Format the request's endpoint as '<protocol>://<host>/<path>' for use in logs. 57 + val endpoint = with(request.url) { "${protocol.name}://${host}${encodedPath}" } 58 + 59 + // Send the request. 60 + val originalCall = proceed(request) 61 + 62 + // Get the original access token from the `Authoriation` header. 63 + val originalAccessToken = originalCall.request.headers[HttpHeaders.Authorization]?.removePrefix("Bearer ") 64 + 65 + // Try to get a new access token to use when retrying the request. 66 + val newAccessToken = when (originalCall.response.status) { 67 + // An unauthorized response means a new session needs to be created. 68 + HttpStatusCode.Unauthorized -> { 69 + val newAccessToken = authMutex.withLock { authManager.authTokens?.accessToken } 70 + 71 + if (originalAccessToken != newAccessToken) { 72 + // Another request has already retrieved new tokens. 73 + log.debug("New tokens already retrieved.") 74 + newAccessToken 75 + } else { 76 + log.debug("Request to '{}' received '401 Unauthorized'. Creating new session.", endpoint) 77 + 78 + // Create a new session. By locking the auth mutex here, 79 + // other requests will block when they try to get the 80 + // tokens. 81 + authMutex.withLock { 82 + authManager.createSession() 83 + authManager.authTokens?.accessToken 84 + } 85 + } 86 + } 87 + 88 + // If the access token is expired, the response will have error code 89 + // 400 with the error `ExpiredToken`. The session needs to be 90 + // refreshed in that case. 91 + HttpStatusCode.BadRequest -> { 92 + log.debug( 93 + "Request to '{}' received '400 Bad Request'. Response: '{}'.", 94 + endpoint, 95 + originalCall.response.bodyAsText() 96 + ) 97 + 98 + // Try to deserialize the error response. If the response isn't 99 + // what's expected, just let the error go to the caller. 100 + val errorResponse = try { 101 + originalCall.response.body<ErrorResponse>() 102 + } catch (e: Exception) { 103 + null 104 + } 105 + 106 + if (errorResponse?.error == "ExpiredToken") { 107 + val newAccessToken = authMutex.withLock { authManager.authTokens?.accessToken } 108 + 109 + if (originalAccessToken != newAccessToken) { 110 + // Another request has already retrieved new tokens. 111 + log.debug("New tokens already retrieved.") 112 + newAccessToken 113 + } else { 114 + log.debug("Received error 'ExpiredToken'. Refreshing session.") 115 + 116 + // Create a new session. By locking the auth mutex here, 117 + // other requests will block when they try to get the 118 + // tokens. 119 + authMutex.withLock { 120 + authManager.refreshSession() 121 + authManager.authTokens?.accessToken 122 + } 123 + } 124 + } else { 125 + // A non auth-related error occurred. 126 + return@on originalCall 127 + } 128 + } 129 + 130 + // A non-auth related response was received. 131 + else -> return@on originalCall 132 + } 133 + 134 + log.debug("Retrying request with new access token.") 135 + 136 + // Retry the original request with the new access token. 137 + originalCall.run { 138 + request.headers.remove(HttpHeaders.Authorization) 139 + request.headers.append(HttpHeaders.Authorization, "Bearer $newAccessToken") 140 + 141 + proceed(request) 142 + } 143 + } 144 + } 145 + 146 + class AuthPluginConfig { 147 + var authManager: AuthManager? = null 148 + } 149 + 150 + @Serializable 151 + data class ErrorResponse( 152 + val error: String, 153 + val message: String, 154 + ) 155 + 156 + open class AuthPluginError(message: String, cause: Throwable? = null) : RuntimeException(message, cause) 157 + 158 + class AuthPluginConfigurationError(message: String) : AuthPluginError(message)
+6 -175
darkfeed/src/main/kotlin/api/BskyApi.kt
··· 1 1 package rs.averyrive.darkfeed.api 2 2 3 + import AuthManager 3 4 import io.ktor.client.* 4 5 import io.ktor.client.call.* 5 6 import io.ktor.client.engine.cio.* 6 7 import io.ktor.client.plugins.* 7 - import io.ktor.client.plugins.auth.providers.* 8 8 import io.ktor.client.plugins.contentnegotiation.* 9 9 import io.ktor.client.plugins.logging.* 10 10 import io.ktor.client.request.* 11 11 import io.ktor.client.statement.* 12 12 import io.ktor.http.* 13 13 import io.ktor.serialization.kotlinx.json.* 14 - import kotlinx.coroutines.sync.Mutex 15 - import kotlinx.coroutines.sync.withLock 16 14 import kotlinx.serialization.Serializable 17 15 import kotlinx.serialization.json.Json 18 16 import org.slf4j.Logger ··· 22 20 import rs.averyrive.darkfeed.api.lexicon.app.bsky.feed.defs.PostView 23 21 24 22 class BskyApi( 23 + private val authManager: AuthManager, 25 24 private val pdsUrl: Url = Url("https://bsky.social"), 26 - 27 - private var bearerTokens: BearerTokens? = null, 28 - 29 - private val bearerTokensMutex: Mutex = Mutex(), 30 - 31 - private val httpClient: HttpClient = HttpClient(CIO) { 25 + ) { 26 + private val httpClient = HttpClient(CIO) { 32 27 install(ContentNegotiation) { 33 28 json(Json { 34 29 explicitNulls = false ··· 38 33 39 34 install(Logging) 40 35 41 - defaultRequest { 42 - url { 43 - protocol = pdsUrl.protocol 44 - host = pdsUrl.host 45 - path("xrpc/") 46 - } 47 - } 48 - }, 49 - 50 - private val authHttpClient: HttpClient = HttpClient(CIO) { 51 - install(ContentNegotiation) { 52 - json(Json { 53 - explicitNulls = false 54 - ignoreUnknownKeys = true 55 - }) 36 + install(AuthPlugin) { 37 + this.authManager = this@BskyApi.authManager 56 38 } 57 - 58 - install(Logging) 59 39 60 40 defaultRequest { 61 41 url { ··· 64 44 path("xrpc/") 65 45 } 66 46 } 67 - }, 68 - ) { 69 - companion object { 70 - val unauthorizedPaths = setOf( 71 - "com.atproto.server.createSession", 72 - "com.atproto.server.refreshSession", 73 - "com.atproto.repo.getRecord", 74 - "com.atproto.repo.listRecords", 75 - ) 76 47 } 77 48 78 49 private val log: Logger = LoggerFactory.getLogger(this::class.java) 79 50 80 - init { 81 - httpClient.plugin(HttpSend).intercept { request -> 82 - log.debug( 83 - "Intercepting request to {}://{}{}", 84 - request.url.protocol.name, 85 - request.url.host, 86 - request.url.encodedPath 87 - ) 88 - 89 - // If this request does not require authorization, send it normally. 90 - if (unauthorizedPaths.any { request.url.encodedPath.contains(it) }) { 91 - log.debug("Request does not require authentication, sending normally") 92 - return@intercept execute(request) 93 - } 94 - 95 - // Get the current access token. If another coroutine is currently 96 - // refreshing the tokens, this will block until finished and get 97 - // new tokens. 98 - val accessToken = bearerTokensMutex.withLock { 99 - bearerTokens?.accessToken ?: throw RuntimeException("No auth tokens") 100 - } 101 - 102 - // Add authorization header to request. 103 - request.headers.remove(HttpHeaders.Authorization) 104 - request.headers.append(HttpHeaders.Authorization, "Bearer $accessToken") 105 - 106 - // Send request. 107 - val call = execute(request) 108 - 109 - // Check the response. 110 - val newAccessToken = when (call.response.status) { 111 - HttpStatusCode.Unauthorized -> { 112 - // Get new tokens using username and app password. 113 - log.debug("Received {}, refreshing session with username and password", call.response.status) 114 - 115 - TODO("Session refresh with username and password is not implemented yet") 116 - } 117 - 118 - HttpStatusCode.BadRequest -> { 119 - log.debug("Received {}, error: {}", call.response.status, call.response.bodyAsText()) 120 - 121 - // Check error code. 122 - val errorResponse = try { 123 - call.response.body<ErrorResponse>() 124 - } catch (e: Exception) { 125 - null 126 - } 127 - 128 - // Access token is expired, use the refresh token to get new tokens. 129 - if (errorResponse?.error == "ExpiredToken") { 130 - // Get the new access token. 131 - val newAccessToken = bearerTokensMutex.withLock { bearerTokens?.accessToken } 132 - 133 - // If the tokens have changed since the original call, 134 - // then another coroutine has updated them and the new 135 - // access token should be used. 136 - if (newAccessToken == accessToken) { 137 - log.debug("Access token is expired, using refresh token to get new tokens") 138 - 139 - // Get new tokens using the refresh token. 140 - bearerTokensMutex.withLock { 141 - val refreshToken = 142 - bearerTokens?.refreshToken ?: throw RuntimeException("No refresh token") 143 - 144 - @Serializable 145 - data class Response( 146 - val accessJwt: String, 147 - val refreshJwt: String, 148 - val handle: String, 149 - val did: String, 150 - ) 151 - 152 - val refreshRequest = authHttpClient.post("com.atproto.server.refreshSession") { 153 - header(HttpHeaders.Authorization, "Bearer $refreshToken") 154 - } 155 - 156 - // TODO: Check status codes. 157 - 158 - val response: Response = refreshRequest.body() 159 - 160 - bearerTokens = BearerTokens( 161 - accessToken = response.accessJwt, 162 - refreshToken = response.refreshJwt, 163 - ) 164 - 165 - // Return the newly refreshed access token. 166 - bearerTokens?.accessToken!! 167 - } 168 - } else { 169 - log.debug("Tokens refreshed by another coroutine") 170 - 171 - // Return the newly retrieved access token. 172 - newAccessToken!! 173 - } 174 - } else { 175 - // Another error has occurred. Return the original call. 176 - return@intercept call 177 - } 178 - } 179 - 180 - // Another status code was returned. Return the original call. 181 - else -> return@intercept call 182 - } 183 - 184 - // Resend the request with the new access token. 185 - // TODO: Check if this is necessary. If this request gets intercepted, this won't be necessary. 186 - request.headers.remove(HttpHeaders.Authorization) 187 - request.headers.append(HttpHeaders.Authorization, "Bearer $newAccessToken") 188 - 189 - log.debug("Retrying original request with new access token") 190 - 191 - execute(request) 192 - } 193 - } 194 - 195 51 @Serializable 196 52 data class ErrorResponse( 197 53 val error: String, 198 54 val message: String, 199 55 ) 200 - 201 - suspend fun login(identifier: String, password: String) { 202 - @Serializable 203 - data class Request(val identifier: String, val password: String) 204 - 205 - @Serializable 206 - data class Response(val did: String, val accessJwt: String, val refreshJwt: String) 207 - 208 - val response = httpClient.post("com.atproto.server.createSession") { 209 - contentType(ContentType.Application.Json) 210 - setBody(Request(identifier, password)) 211 - } 212 - 213 - when (response.status) { 214 - HttpStatusCode.OK -> { 215 - val tokens: Response = response.body() 216 - bearerTokens = BearerTokens(tokens.accessJwt, tokens.refreshJwt) 217 - } 218 - 219 - HttpStatusCode.BadRequest, 220 - HttpStatusCode.Unauthorized -> throw RuntimeException("Failed to create session: ${response.bodyAsText()}") 221 - 222 - else -> throw RuntimeException("Unexpected response received: ${response.bodyAsText()}") 223 - } 224 - } 225 56 226 57 suspend fun getFeedGeneratorRecord(repo: String, rkey: String): Generator? { 227 58 @Serializable