Bluesky feed server - NSFW Likes

fix: Improve token refresh logic

+162 -92
+1 -1
darkfeed/build.gradle.kts
··· 29 29 } 30 30 31 31 application { 32 - mainClass = "MainKt" 32 + mainClass = "rs.averyrive.darkfeed.MainKt" 33 33 }
+6 -7
darkfeed/src/main/kotlin/Main.kt
··· 1 - import api.BskyApi 2 - import api.lexicon.app.bsky.feed.Generator 1 + package rs.averyrive.darkfeed 2 + 3 3 import io.ktor.http.* 4 4 import kotlinx.coroutines.launch 5 5 import kotlinx.coroutines.runBlocking 6 - import server.FeedServer 6 + import rs.averyrive.darkfeed.api.BskyApi 7 + import rs.averyrive.darkfeed.api.lexicon.app.bsky.feed.Generator 8 + import rs.averyrive.darkfeed.server.FeedServer 7 9 import kotlin.system.exitProcess 8 10 9 - /** 10 - * 11 - */ 12 11 data class AppContext( 13 12 /** PDS of the feed owner's account. */ 14 13 val ownerPds: String, ··· 103 102 FeedServer( 104 103 hostname = ctx.hostname, 105 104 bskyApi = bskyApi, 106 - port = 8080, 105 + port = 1234, 107 106 ).serve() 108 107 }
+136 -67
darkfeed/src/main/kotlin/api/BskyApi.kt
··· 1 - package api 1 + package rs.averyrive.darkfeed.api 2 2 3 - import api.lexicon.app.bsky.feed.Generator 4 - import api.lexicon.app.bsky.feed.LikeRef 5 - import api.lexicon.app.bsky.feed.defs.PostView 6 3 import io.ktor.client.* 7 4 import io.ktor.client.call.* 8 5 import io.ktor.client.engine.cio.* 9 6 import io.ktor.client.plugins.* 10 - import io.ktor.client.plugins.auth.* 11 7 import io.ktor.client.plugins.auth.providers.* 12 8 import io.ktor.client.plugins.contentnegotiation.* 13 9 import io.ktor.client.plugins.logging.* ··· 15 11 import io.ktor.client.statement.* 16 12 import io.ktor.http.* 17 13 import io.ktor.serialization.kotlinx.json.* 14 + import kotlinx.coroutines.sync.Mutex 15 + import kotlinx.coroutines.sync.withLock 18 16 import kotlinx.serialization.Serializable 19 17 import kotlinx.serialization.json.Json 18 + import org.slf4j.Logger 19 + import org.slf4j.LoggerFactory 20 + import rs.averyrive.darkfeed.api.lexicon.app.bsky.feed.Generator 21 + import rs.averyrive.darkfeed.api.lexicon.app.bsky.feed.LikeRef 22 + import rs.averyrive.darkfeed.api.lexicon.app.bsky.feed.defs.PostView 20 23 21 24 class BskyApi( 22 25 private val pdsUrl: Url = Url("https://bsky.social"), 23 26 24 - private val bearerTokens: MutableList<BearerTokens> = mutableListOf(), 27 + private var bearerTokens: BearerTokens? = null, 28 + 29 + private val bearerTokensMutex: Mutex = Mutex(), 25 30 26 31 private val httpClient: HttpClient = HttpClient(CIO) { 27 32 install(ContentNegotiation) { ··· 33 38 34 39 install(Logging) 35 40 36 - install(Auth) { 37 - bearer { 38 - loadTokens { 39 - bearerTokens.lastOrNull() 40 - } 41 - 42 - refreshTokens { 43 - val currentRefreshToken = bearerTokens.lastOrNull()?.refreshToken ?: return@refreshTokens null 44 - 45 - @Serializable 46 - data class Response(val accessJwt: String, val refreshJwt: String) 47 - 48 - val refreshSessionResponse = client.post("com.atproto.server.refreshSession") { 49 - header("Authorization", "Bearer $currentRefreshToken") 50 - markAsRefreshTokenRequest() 51 - } 52 - 53 - when (refreshSessionResponse.status) { 54 - HttpStatusCode.OK -> { 55 - val refreshSessionTokens = refreshSessionResponse.body<Response>() 56 - val newBearerTokens = 57 - BearerTokens(refreshSessionTokens.accessJwt, refreshSessionTokens.refreshJwt) 58 - 59 - bearerTokens.addLast(newBearerTokens) 60 - 61 - return@refreshTokens newBearerTokens 62 - } 63 - 64 - HttpStatusCode.BadRequest, 65 - HttpStatusCode.Unauthorized -> return@refreshTokens null 66 - 67 - else -> return@refreshTokens null 68 - } 69 - } 41 + defaultRequest { 42 + url { 43 + protocol = pdsUrl.protocol 44 + host = pdsUrl.host 45 + path("xrpc/") 70 46 } 71 47 } 48 + }, 49 + 50 + private val authHttpClient: HttpClient = HttpClient(CIO) { 51 + install(ContentNegotiation) { 52 + json(Json { 53 + explicitNulls = false 54 + ignoreUnknownKeys = true 55 + }) 56 + } 57 + 58 + install(Logging) 72 59 73 60 defaultRequest { 74 61 url { ··· 79 66 } 80 67 }, 81 68 ) { 69 + companion object { 70 + val unauthorizedPaths = setOf( 71 + "com.atproto.server.createSession", 72 + "com.atproto.server.refreshSession", 73 + "com.atproto.repo.getRecord", 74 + "com.atproto.repo.listRecords", 75 + ) 76 + } 77 + 78 + private val log: Logger = LoggerFactory.getLogger(this::class.java) 79 + 82 80 init { 83 81 httpClient.plugin(HttpSend).intercept { request -> 84 - val originalCall = execute(request) 82 + log.debug( 83 + "Intercepting request to {}://{}{}", 84 + request.url.protocol.name, 85 + request.url.host, 86 + request.url.encodedPath 87 + ) 88 + 89 + // If this request does not require authorization, send it normally. 90 + if (unauthorizedPaths.any { request.url.encodedPath.contains(it) }) { 91 + log.debug("Request does not require authentication, sending normally") 92 + return@intercept execute(request) 93 + } 94 + 95 + // Get the current access token. If another coroutine is currently 96 + // refreshing the tokens, this will block until finished and get 97 + // new tokens. 98 + val accessToken = bearerTokensMutex.withLock { 99 + bearerTokens?.accessToken ?: throw RuntimeException("No auth tokens") 100 + } 101 + 102 + // Add authorization header to request. 103 + request.headers.remove(HttpHeaders.Authorization) 104 + request.headers.append(HttpHeaders.Authorization, "Bearer $accessToken") 105 + 106 + // Send request. 107 + val call = execute(request) 85 108 86 - if (originalCall.response.status == HttpStatusCode.BadRequest) { 87 - val errorResponse = try { 88 - originalCall.response.body<ErrorResponse>() 89 - } catch (e: Exception) { 90 - null 109 + // Check the response. 110 + val newAccessToken = when (call.response.status) { 111 + HttpStatusCode.Unauthorized -> { 112 + // Get new tokens using username and app password. 113 + log.debug("Received {}, refreshing session with username and password", call.response.status) 114 + 115 + TODO("Session refresh with username and password is not implemented yet") 91 116 } 92 117 93 - if (errorResponse?.error == "ExpiredToken") { 94 - val currentRefreshToken = bearerTokens.lastOrNull()?.refreshToken ?: return@intercept originalCall 118 + HttpStatusCode.BadRequest -> { 119 + log.debug("Received {}, error: {}", call.response.status, call.response.bodyAsText()) 120 + 121 + // Check error code. 122 + val errorResponse = try { 123 + call.response.body<ErrorResponse>() 124 + } catch (e: Exception) { 125 + null 126 + } 127 + 128 + // Access token is expired, use the refresh token to get new tokens. 129 + if (errorResponse?.error == "ExpiredToken") { 130 + // Get the new access token. 131 + val newAccessToken = bearerTokensMutex.withLock { bearerTokens?.accessToken } 132 + 133 + // If the tokens have changed since the original call, 134 + // then another coroutine has updated them and the new 135 + // access token should be used. 136 + if (newAccessToken == accessToken) { 137 + log.debug("Access token is expired, using refresh token to get new tokens") 138 + 139 + // Get new tokens using the refresh token. 140 + bearerTokensMutex.withLock { 141 + val refreshToken = 142 + bearerTokens?.refreshToken ?: throw RuntimeException("No refresh token") 143 + 144 + @Serializable 145 + data class Response( 146 + val accessJwt: String, 147 + val refreshJwt: String, 148 + val handle: String, 149 + val did: String, 150 + ) 151 + 152 + val refreshRequest = authHttpClient.post("com.atproto.server.refreshSession") { 153 + header(HttpHeaders.Authorization, "Bearer $refreshToken") 154 + } 95 155 96 - @Serializable 97 - data class Response(val accessJwt: String, val refreshJwt: String) 156 + // TODO: Check status codes. 98 157 99 - val refreshSessionResponse = httpClient.post("com.atproto.server.refreshSession") { 100 - header("Authorization", "Bearer $currentRefreshToken") 101 - } 158 + val response: Response = refreshRequest.body() 102 159 103 - if (refreshSessionResponse.status == HttpStatusCode.OK) { 104 - val refreshSessionTokens = refreshSessionResponse.body<Response>() 105 - val newBearerTokens = 106 - BearerTokens(refreshSessionTokens.accessJwt, refreshSessionTokens.refreshJwt) 160 + bearerTokens = BearerTokens( 161 + accessToken = response.accessJwt, 162 + refreshToken = response.refreshJwt, 163 + ) 107 164 108 - bearerTokens.addLast(newBearerTokens) 165 + // Return the newly refreshed access token. 166 + bearerTokens?.accessToken!! 167 + } 168 + } else { 169 + log.debug("Tokens refreshed by another coroutine") 109 170 110 - val newRequest = HttpRequestBuilder() 111 - newRequest.takeFrom(request) 112 - newRequest.headers { 113 - remove(HttpHeaders.Authorization) 114 - append(HttpHeaders.Authorization, "Bearer ${newBearerTokens.accessToken}") 171 + // Return the newly retrieved access token. 172 + newAccessToken!! 115 173 } 116 - 117 - return@intercept execute(newRequest) 174 + } else { 175 + // Another error has occurred. Return the original call. 176 + return@intercept call 118 177 } 119 178 } 179 + 180 + // Another status code was returned. Return the original call. 181 + else -> return@intercept call 120 182 } 121 183 122 - originalCall 184 + // Resend the request with the new access token. 185 + // TODO: Check if this is necessary. If this request gets intercepted, this won't be necessary. 186 + request.headers.remove(HttpHeaders.Authorization) 187 + request.headers.append(HttpHeaders.Authorization, "Bearer $newAccessToken") 188 + 189 + log.debug("Retrying original request with new access token") 190 + 191 + execute(request) 123 192 } 124 193 } 125 194 ··· 144 213 when (response.status) { 145 214 HttpStatusCode.OK -> { 146 215 val tokens: Response = response.body() 147 - bearerTokens.addLast(BearerTokens(tokens.accessJwt, tokens.refreshJwt)) 216 + bearerTokens = BearerTokens(tokens.accessJwt, tokens.refreshJwt) 148 217 } 149 218 150 219 HttpStatusCode.BadRequest,
+2 -2
darkfeed/src/main/kotlin/api/lexicon/app/bsky/feed/FeedSkeleton.kt
··· 1 - package api.lexicon.app.bsky.feed 1 + package rs.averyrive.darkfeed.api.lexicon.app.bsky.feed 2 2 3 - import api.lexicon.app.bsky.feed.defs.SkeletonFeedPost 4 3 import kotlinx.serialization.Serializable 4 + import rs.averyrive.darkfeed.api.lexicon.app.bsky.feed.defs.SkeletonFeedPost 5 5 6 6 @Serializable 7 7 data class FeedSkeleton(
+1 -1
darkfeed/src/main/kotlin/api/lexicon/app/bsky/feed/Generator.kt
··· 1 - package api.lexicon.app.bsky.feed 1 + package rs.averyrive.darkfeed.api.lexicon.app.bsky.feed 2 2 3 3 import kotlinx.serialization.Serializable 4 4
+2 -2
darkfeed/src/main/kotlin/api/lexicon/app/bsky/feed/Like.kt
··· 1 - package api.lexicon.app.bsky.feed 1 + package rs.averyrive.darkfeed.api.lexicon.app.bsky.feed 2 2 3 - import api.lexicon.com.atproto.repo.StrongRef 4 3 import kotlinx.serialization.Serializable 4 + import rs.averyrive.darkfeed.api.lexicon.com.atproto.repo.StrongRef 5 5 6 6 @Serializable 7 7 data class Like(
+1 -1
darkfeed/src/main/kotlin/api/lexicon/app/bsky/feed/Post.kt
··· 1 - package api.lexicon.app.bsky.feed 1 + package rs.averyrive.darkfeed.api.lexicon.app.bsky.feed 2 2 3 3 import kotlinx.serialization.Serializable 4 4
+2 -2
darkfeed/src/main/kotlin/api/lexicon/app/bsky/feed/defs/PostView.kt
··· 1 - package api.lexicon.app.bsky.feed.defs 1 + package rs.averyrive.darkfeed.api.lexicon.app.bsky.feed.defs 2 2 3 - import api.lexicon.com.atproto.label.defs.Label 4 3 import kotlinx.serialization.Serializable 4 + import rs.averyrive.darkfeed.api.lexicon.com.atproto.label.defs.Label 5 5 6 6 @Serializable 7 7 data class PostView(
+1 -1
darkfeed/src/main/kotlin/api/lexicon/app/bsky/feed/defs/SkeletonFeedPost.kt
··· 1 - package api.lexicon.app.bsky.feed.defs 1 + package rs.averyrive.darkfeed.api.lexicon.app.bsky.feed.defs 2 2 3 3 import kotlinx.serialization.Serializable 4 4
+1 -1
darkfeed/src/main/kotlin/api/lexicon/com/atproto/label/defs/Label.kt
··· 1 - package api.lexicon.com.atproto.label.defs 1 + package rs.averyrive.darkfeed.api.lexicon.com.atproto.label.defs 2 2 3 3 import kotlinx.serialization.SerialName 4 4 import kotlinx.serialization.Serializable
+1 -1
darkfeed/src/main/kotlin/api/lexicon/com/atproto/repo/StrongRef.kt
··· 1 - package api.lexicon.com.atproto.repo 1 + package rs.averyrive.darkfeed.api.lexicon.com.atproto.repo 2 2 3 3 import kotlinx.serialization.Serializable 4 4
+8 -6
darkfeed/src/main/kotlin/server/FeedServer.kt
··· 1 - package server 1 + package rs.averyrive.darkfeed.server 2 2 3 - import api.BskyApi 4 - import api.lexicon.app.bsky.feed.FeedSkeleton 5 - import api.lexicon.app.bsky.feed.defs.PostView 6 - import api.lexicon.app.bsky.feed.defs.SkeletonFeedPost 7 3 import com.auth0.jwt.JWT 8 4 import io.ktor.serialization.kotlinx.json.* 9 5 import io.ktor.server.application.* ··· 12 8 import io.ktor.server.plugins.contentnegotiation.* 13 9 import io.ktor.server.response.* 14 10 import io.ktor.server.routing.* 15 - import kotlinx.coroutines.* 11 + import kotlinx.coroutines.Deferred 12 + import kotlinx.coroutines.async 13 + import kotlinx.coroutines.runBlocking 16 14 import kotlinx.serialization.Serializable 17 15 import kotlinx.serialization.json.Json 16 + import rs.averyrive.darkfeed.api.BskyApi 17 + import rs.averyrive.darkfeed.api.lexicon.app.bsky.feed.FeedSkeleton 18 + import rs.averyrive.darkfeed.api.lexicon.app.bsky.feed.defs.PostView 19 + import rs.averyrive.darkfeed.api.lexicon.app.bsky.feed.defs.SkeletonFeedPost 18 20 19 21 val DESIRED_LABELS: List<String> = listOf("porn", "sexual", "nudity", "sexual-figurative") 20 22