social media crossposting tool. 3rd time's the charm
mastodon misskey crossposting bluesky

run a formatter

zenfyr.dev dd2bea73 c0211a8f

verified
+1541 -1238
+82 -72
bluesky/atproto2.py
··· 1 1 from typing import Any 2 - from atproto import client_utils, Client, AtUri, IdResolver 2 + 3 + from atproto import AtUri, Client, IdResolver, client_utils 3 4 from atproto_client import models 5 + 4 6 from util.util import LOGGER 7 + 5 8 6 9 def resolve_identity( 7 - handle: str | None = None, 8 - did: str | None = None, 9 - pds: str | None = None): 10 + handle: str | None = None, did: str | None = None, pds: str | None = None 11 + ): 10 12 """helper to try and resolve identity from provided parameters, a valid handle is enough""" 11 - 13 + 12 14 if did and pds: 13 - return did, pds[:-1] if pds.endswith('/') else pds 14 - 15 + return did, pds[:-1] if pds.endswith("/") else pds 16 + 15 17 resolver = IdResolver() 16 18 if not did: 17 19 if not handle: ··· 20 22 did = resolver.handle.resolve(handle) 21 23 if not did: 22 24 raise Exception("Failed to resolve DID!") 23 - 25 + 24 26 if not pds: 25 27 LOGGER.info("Resolving PDS from DID document...") 26 28 did_doc = resolver.did.resolve(did) ··· 29 31 pds = did_doc.get_pds_endpoint() 30 32 if not pds: 31 33 raise Exception("Failed to resolve PDS!") 32 - 33 - return did, pds[:-1] if pds.endswith('/') else pds 34 + 35 + return did, pds[:-1] if pds.endswith("/") else pds 36 + 34 37 35 38 class Client2(Client): 36 39 def __init__(self, base_url: str | None = None, *args: Any, **kwargs: Any) -> None: 37 40 super().__init__(base_url, *args, **kwargs) 38 - 41 + 39 42 def send_video( 40 - self, 41 - text: str | client_utils.TextBuilder, 43 + self, 44 + text: str | client_utils.TextBuilder, 42 45 video: bytes, 43 46 video_alt: str | None = None, 44 47 video_aspect_ratio: models.AppBskyEmbedDefs.AspectRatio | None = None, ··· 46 49 langs: list[str] | None = None, 47 50 facets: list[models.AppBskyRichtextFacet.Main] | None = None, 48 51 labels: models.ComAtprotoLabelDefs.SelfLabels | None = None, 49 - time_iso: str | None = None 50 - ) -> models.AppBskyFeedPost.CreateRecordResponse: 52 + time_iso: str | None = None, 53 + ) 
-> models.AppBskyFeedPost.CreateRecordResponse: 51 54 """same as send_video, but with labels""" 52 - 55 + 53 56 if video_alt is None: 54 - video_alt = '' 57 + video_alt = "" 55 58 56 59 upload = self.upload_blob(video) 57 - 60 + 58 61 return self.send_post( 59 62 text, 60 63 reply_to=reply_to, 61 - embed=models.AppBskyEmbedVideo.Main(video=upload.blob, alt=video_alt, aspect_ratio=video_aspect_ratio), 64 + embed=models.AppBskyEmbedVideo.Main( 65 + video=upload.blob, alt=video_alt, aspect_ratio=video_aspect_ratio 66 + ), 62 67 langs=langs, 63 68 facets=facets, 64 69 labels=labels, 65 - time_iso=time_iso 70 + time_iso=time_iso, 66 71 ) 67 - 72 + 68 73 def send_images( 69 - self, 70 - text: str | client_utils.TextBuilder, 74 + self, 75 + text: str | client_utils.TextBuilder, 71 76 images: list[bytes], 72 77 image_alts: list[str] | None = None, 73 78 image_aspect_ratios: list[models.AppBskyEmbedDefs.AspectRatio] | None = None, ··· 75 80 langs: list[str] | None = None, 76 81 facets: list[models.AppBskyRichtextFacet.Main] | None = None, 77 82 labels: models.ComAtprotoLabelDefs.SelfLabels | None = None, 78 - time_iso: str | None = None 79 - ) -> models.AppBskyFeedPost.CreateRecordResponse: 83 + time_iso: str | None = None, 84 + ) -> models.AppBskyFeedPost.CreateRecordResponse: 80 85 """same as send_images, but with labels""" 81 - 86 + 82 87 if image_alts is None: 83 - image_alts = [''] * len(images) 88 + image_alts = [""] * len(images) 84 89 else: 85 90 diff = len(images) - len(image_alts) 86 - image_alts = image_alts + [''] * diff 87 - 91 + image_alts = image_alts + [""] * diff 92 + 88 93 if image_aspect_ratios is None: 89 94 aligned_image_aspect_ratios = [None] * len(images) 90 95 else: 91 96 diff = len(images) - len(image_aspect_ratios) 92 97 aligned_image_aspect_ratios = image_aspect_ratios + [None] * diff 93 - 98 + 94 99 uploads = [self.upload_blob(image) for image in images] 95 - 100 + 96 101 embed_images = [ 97 - models.AppBskyEmbedImages.Image(alt=alt, 
image=upload.blob, aspect_ratio=aspect_ratio) 98 - for alt, upload, aspect_ratio in zip(image_alts, uploads, aligned_image_aspect_ratios) 102 + models.AppBskyEmbedImages.Image( 103 + alt=alt, image=upload.blob, aspect_ratio=aspect_ratio 104 + ) 105 + for alt, upload, aspect_ratio in zip( 106 + image_alts, uploads, aligned_image_aspect_ratios 107 + ) 99 108 ] 100 - 109 + 101 110 return self.send_post( 102 111 text, 103 112 reply_to=reply_to, ··· 105 114 langs=langs, 106 115 facets=facets, 107 116 labels=labels, 108 - time_iso=time_iso 117 + time_iso=time_iso, 109 118 ) 110 - 119 + 111 120 def send_post( 112 - self, 113 - text: str | client_utils.TextBuilder, 121 + self, 122 + text: str | client_utils.TextBuilder, 114 123 reply_to: models.AppBskyFeedPost.ReplyRef | None = None, 115 - embed: 116 - None | 117 - models.AppBskyEmbedImages.Main | 118 - models.AppBskyEmbedExternal.Main | 119 - models.AppBskyEmbedRecord.Main | 120 - models.AppBskyEmbedRecordWithMedia.Main | 121 - models.AppBskyEmbedVideo.Main = None, 124 + embed: None 125 + | models.AppBskyEmbedImages.Main 126 + | models.AppBskyEmbedExternal.Main 127 + | models.AppBskyEmbedRecord.Main 128 + | models.AppBskyEmbedRecordWithMedia.Main 129 + | models.AppBskyEmbedVideo.Main = None, 122 130 langs: list[str] | None = None, 123 131 facets: list[models.AppBskyRichtextFacet.Main] | None = None, 124 132 labels: models.ComAtprotoLabelDefs.SelfLabels | None = None, 125 - time_iso: str | None = None 126 - ) -> models.AppBskyFeedPost.CreateRecordResponse: 133 + time_iso: str | None = None, 134 + ) -> models.AppBskyFeedPost.CreateRecordResponse: 127 135 """same as send_post, but with labels""" 128 - 136 + 129 137 if isinstance(text, client_utils.TextBuilder): 130 138 facets = text.build_facets() 131 139 text = text.build_text() 132 - 140 + 133 141 repo = self.me and self.me.did 134 142 if not repo: 135 143 raise Exception("Client not logged in!") 136 - 144 + 137 145 if not langs: 138 - langs = ['en'] 139 - 146 + langs = 
["en"] 147 + 140 148 record = models.AppBskyFeedPost.Record( 141 149 created_at=time_iso or self.get_current_time_iso(), 142 150 text=text, ··· 144 152 embed=embed or None, 145 153 langs=langs, 146 154 facets=facets or None, 147 - labels=labels or None 155 + labels=labels or None, 148 156 ) 149 157 return self.app.bsky.feed.post.create(repo, record) 150 - 151 - def create_gates(self, thread_gate_opts: list[str], quote_gate: bool, post_uri: str, time_iso: str | None = None): 158 + 159 + def create_gates( 160 + self, 161 + thread_gate_opts: list[str], 162 + quote_gate: bool, 163 + post_uri: str, 164 + time_iso: str | None = None, 165 + ): 152 166 account = self.me 153 167 if not account: 154 168 raise Exception("Client not logged in!") 155 - 169 + 156 170 rkey = AtUri.from_str(post_uri).rkey 157 171 time_iso = time_iso or self.get_current_time_iso() 158 - 159 - if 'everybody' not in thread_gate_opts: 172 + 173 + if "everybody" not in thread_gate_opts: 160 174 allow = [] 161 175 if thread_gate_opts: 162 - if 'following' in thread_gate_opts: 176 + if "following" in thread_gate_opts: 163 177 allow.append(models.AppBskyFeedThreadgate.FollowingRule()) 164 - if 'followers' in thread_gate_opts: 178 + if "followers" in thread_gate_opts: 165 179 allow.append(models.AppBskyFeedThreadgate.FollowerRule()) 166 - if 'mentioned' in thread_gate_opts: 180 + if "mentioned" in thread_gate_opts: 167 181 allow.append(models.AppBskyFeedThreadgate.MentionRule()) 168 - 182 + 169 183 thread_gate = models.AppBskyFeedThreadgate.Record( 170 - post=post_uri, 171 - created_at=time_iso, 172 - allow=allow 184 + post=post_uri, created_at=time_iso, allow=allow 173 185 ) 174 - 186 + 175 187 self.app.bsky.feed.threadgate.create(account.did, thread_gate, rkey) 176 - 188 + 177 189 if quote_gate: 178 190 post_gate = models.AppBskyFeedPostgate.Record( 179 191 post=post_uri, 180 192 created_at=time_iso, 181 - embedding_rules=[ 182 - models.AppBskyFeedPostgate.DisableRule() 183 - ] 193 + 
embedding_rules=[models.AppBskyFeedPostgate.DisableRule()], 184 194 ) 185 - 186 - self.app.bsky.feed.postgate.create(account.did, post_gate, rkey) 195 + 196 + self.app.bsky.feed.postgate.create(account.did, post_gate, rkey)
+90 -71
bluesky/common.py
··· 1 - import re, json 1 + import re 2 2 3 3 from atproto import client_utils 4 4 ··· 7 7 from util.util import canonical_label 8 8 9 9 # only for lexicon reference 10 - SERVICE = 'https://bsky.app' 10 + SERVICE = "https://bsky.app" 11 11 12 12 # TODO this is terrible and stupid 13 - ADULT_PATTERN = re.compile(r"\b(sexual content|nsfw|erotic|adult only|18\+)\b", re.IGNORECASE) 14 - PORN_PATTERN = re.compile(r"\b(porn|yiff|hentai|pornographic|fetish)\b", re.IGNORECASE) 13 + ADULT_PATTERN = re.compile( 14 + r"\b(sexual content|nsfw|erotic|adult only|18\+)\b", re.IGNORECASE 15 + ) 16 + PORN_PATTERN = re.compile(r"\b(porn|yiff|hentai|pornographic|fetish)\b", re.IGNORECASE) 17 + 15 18 16 19 class BlueskyPost(cross.Post): 17 - def __init__(self, record: dict, tokens: list[cross.Token], attachments: list[MediaInfo]) -> None: 20 + def __init__( 21 + self, record: dict, tokens: list[cross.Token], attachments: list[MediaInfo] 22 + ) -> None: 18 23 super().__init__() 19 - self.uri = record['$xpost.strongRef']['uri'] 24 + self.uri = record["$xpost.strongRef"]["uri"] 20 25 self.parent_uri = None 21 - if record.get('reply'): 22 - self.parent_uri = record['reply']['parent']['uri'] 23 - 26 + if record.get("reply"): 27 + self.parent_uri = record["reply"]["parent"]["uri"] 28 + 24 29 self.tokens = tokens 25 - self.timestamp = record['createdAt'] 26 - labels = record.get('labels', {}).get('values') 30 + self.timestamp = record["createdAt"] 31 + labels = record.get("labels", {}).get("values") 27 32 self.spoiler = None 28 33 if labels: 29 - self.spoiler = ', '.join([str(label['val']).replace('-', ' ') for label in labels]) 30 - 34 + self.spoiler = ", ".join( 35 + [str(label["val"]).replace("-", " ") for label in labels] 36 + ) 37 + 31 38 self.attachments = attachments 32 - self.languages = record.get('langs', []) 33 - 39 + self.languages = record.get("langs", []) 40 + 34 41 # at:// of the post record 35 42 def get_id(self) -> str: 36 43 return self.uri 37 - 44 + 38 45 def 
get_parent_id(self) -> str | None: 39 46 return self.parent_uri 40 - 47 + 41 48 def get_tokens(self) -> list[cross.Token]: 42 49 return self.tokens 43 - 50 + 44 51 def get_text_type(self) -> str: 45 52 return "text/plain" 46 - 53 + 47 54 def get_timestamp(self) -> str: 48 55 return self.timestamp 49 56 50 57 def get_attachments(self) -> list[MediaInfo]: 51 58 return self.attachments 52 - 59 + 53 60 def get_spoiler(self) -> str | None: 54 61 return self.spoiler 55 62 56 63 def get_languages(self) -> list[str]: 57 64 return self.languages 58 - 65 + 59 66 def is_sensitive(self) -> bool: 60 67 return self.spoiler is not None 61 68 62 69 def get_post_url(self) -> str | None: 63 - did, _, post_id = str(self.uri[len("at://"):]).split("/") 64 - 70 + did, _, post_id = str(self.uri[len("at://") :]).split("/") 71 + 65 72 return f"https://bsky.app/profile/{did}/post/{post_id}" 73 + 66 74 67 75 def tokenize_post(post: dict) -> list[cross.Token]: 68 - text: str = post.get('text', '') 76 + text: str = post.get("text", "") 69 77 if not text: 70 78 return [] 71 - ut8_text = text.encode(encoding='utf-8') 72 - 79 + ut8_text = text.encode(encoding="utf-8") 80 + 73 81 def decode(ut8: bytes) -> str: 74 - return ut8.decode(encoding='utf-8') 75 - 76 - facets: list[dict] = post.get('facets', []) 82 + return ut8.decode(encoding="utf-8") 83 + 84 + facets: list[dict] = post.get("facets", []) 77 85 if not facets: 78 86 return [cross.TextToken(decode(ut8_text))] 79 - 87 + 80 88 slices: list[tuple[int, int, str, str]] = [] 81 - 89 + 82 90 for facet in facets: 83 - features: list[dict] = facet.get('features', []) 91 + features: list[dict] = facet.get("features", []) 84 92 if not features: 85 93 continue 86 - 94 + 87 95 # we don't support overlapping facets/features 88 96 feature = features[0] 89 - feature_type = feature['$type'] 90 - index = facet['index'] 97 + feature_type = feature["$type"] 98 + index = facet["index"] 91 99 match feature_type: 92 - case 'app.bsky.richtext.facet#tag': 93 - 
slices.append((index['byteStart'], index['byteEnd'], 'tag', feature['tag'])) 94 - case 'app.bsky.richtext.facet#link': 95 - slices.append((index['byteStart'], index['byteEnd'], 'link', feature['uri'])) 96 - case 'app.bsky.richtext.facet#mention': 97 - slices.append((index['byteStart'], index['byteEnd'], 'mention', feature['did'])) 98 - 100 + case "app.bsky.richtext.facet#tag": 101 + slices.append( 102 + (index["byteStart"], index["byteEnd"], "tag", feature["tag"]) 103 + ) 104 + case "app.bsky.richtext.facet#link": 105 + slices.append( 106 + (index["byteStart"], index["byteEnd"], "link", feature["uri"]) 107 + ) 108 + case "app.bsky.richtext.facet#mention": 109 + slices.append( 110 + (index["byteStart"], index["byteEnd"], "mention", feature["did"]) 111 + ) 112 + 99 113 if not slices: 100 114 return [cross.TextToken(decode(ut8_text))] 101 - 115 + 102 116 slices.sort(key=lambda s: s[0]) 103 117 unique: list[tuple[int, int, str, str]] = [] 104 118 current_end = 0 ··· 106 120 if start >= current_end: 107 121 unique.append((start, end, ttype, val)) 108 122 current_end = end 109 - 123 + 110 124 if not unique: 111 125 return [cross.TextToken(decode(ut8_text))] 112 - 126 + 113 127 tokens: list[cross.Token] = [] 114 128 prev = 0 115 - 129 + 116 130 for start, end, ttype, val in unique: 117 131 if start > prev: 118 132 # text between facets 119 133 tokens.append(cross.TextToken(decode(ut8_text[prev:start]))) 120 134 # facet token 121 135 match ttype: 122 - case 'link': 136 + case "link": 123 137 label = decode(ut8_text[start:end]) 124 - 138 + 125 139 # try to unflatten links 126 - split = val.split('://', 1) 140 + split = val.split("://", 1) 127 141 if len(split) > 1: 128 142 if split[1].startswith(label): 129 - tokens.append(cross.LinkToken(val, '')) 143 + tokens.append(cross.LinkToken(val, "")) 130 144 prev = end 131 145 continue 132 - 133 - if label.endswith('...') and split[1].startswith(label[:-3]): 134 - tokens.append(cross.LinkToken(val, '')) 146 + 147 + if 
label.endswith("...") and split[1].startswith(label[:-3]): 148 + tokens.append(cross.LinkToken(val, "")) 135 149 prev = end 136 - continue 137 - 150 + continue 151 + 138 152 tokens.append(cross.LinkToken(val, label)) 139 - case 'tag': 153 + case "tag": 140 154 tag = decode(ut8_text[start:end]) 141 - tokens.append(cross.TagToken(tag[1:] if tag.startswith('#') else tag)) 142 - case 'mention': 155 + tokens.append(cross.TagToken(tag[1:] if tag.startswith("#") else tag)) 156 + case "mention": 143 157 mention = decode(ut8_text[start:end]) 144 - tokens.append(cross.MentionToken(mention[1:] if mention.startswith('@') else mention, val)) 158 + tokens.append( 159 + cross.MentionToken( 160 + mention[1:] if mention.startswith("@") else mention, val 161 + ) 162 + ) 145 163 prev = end 146 164 147 165 if prev < len(ut8_text): 148 166 tokens.append(cross.TextToken(decode(ut8_text[prev:]))) 149 - 150 - return tokens 167 + 168 + return tokens 169 + 151 170 152 171 def tokens_to_richtext(tokens: list[cross.Token]) -> client_utils.TextBuilder | None: 153 172 builder = client_utils.TextBuilder() 154 - 173 + 155 174 def flatten_link(href: str): 156 - split = href.split('://', 1) 175 + split = href.split("://", 1) 157 176 if len(split) > 1: 158 177 href = split[1] 159 - 178 + 160 179 if len(href) > 32: 161 - href = href[:32] + '...' 162 - 180 + href = href[:32] + "..." 181 + 163 182 return href 164 - 183 + 165 184 for token in tokens: 166 185 if isinstance(token, cross.TextToken): 167 186 builder.text(token.text) ··· 169 188 if canonical_label(token.label, token.href): 170 189 builder.link(flatten_link(token.href), token.href) 171 190 continue 172 - 191 + 173 192 builder.link(token.label, token.href) 174 193 elif isinstance(token, cross.TagToken): 175 - builder.tag('#' + token.tag, token.tag.lower()) 194 + builder.tag("#" + token.tag, token.tag.lower()) 176 195 else: 177 196 # fail on unsupported tokens 178 197 return None 179 - 180 - return builder 198 + 199 + return builder
+105 -79
bluesky/input.py
··· 1 - import re, json, websockets, asyncio 1 + import asyncio 2 + import json 3 + import re 4 + from typing import Any, Callable 2 5 6 + import websockets 3 7 from atproto_client import models 4 8 from atproto_client.models.utils import get_or_create as get_model_or_create 9 + 10 + import cross 11 + import util.database as database 5 12 from bluesky.atproto2 import resolve_identity 6 - 7 - from bluesky.common import BlueskyPost, SERVICE, tokenize_post 8 - 9 - import cross, util.database as database 13 + from bluesky.common import SERVICE, BlueskyPost, tokenize_post 14 + from util.database import DataBaseWorker 15 + from util.media import MediaInfo, download_media 10 16 from util.util import LOGGER, as_envvar 11 - from util.media import MediaInfo, download_media 12 - from util.database import DataBaseWorker 13 17 14 - from typing import Callable, Any 15 18 16 - class BlueskyInputOptions(): 19 + class BlueskyInputOptions: 17 20 def __init__(self, o: dict) -> None: 18 - self.filters = [re.compile(f) for f in o.get('regex_filters', [])] 21 + self.filters = [re.compile(f) for f in o.get("regex_filters", [])] 22 + 19 23 20 24 class BlueskyInput(cross.Input): 21 25 def __init__(self, settings: dict, db: DataBaseWorker) -> None: 22 - self.options = BlueskyInputOptions(settings.get('options', {})) 26 + self.options = BlueskyInputOptions(settings.get("options", {})) 23 27 did, pds = resolve_identity( 24 - handle=as_envvar(settings.get('handle')), 25 - did=as_envvar(settings.get('did')), 26 - pds=as_envvar(settings.get('pds')) 28 + handle=as_envvar(settings.get("handle")), 29 + did=as_envvar(settings.get("did")), 30 + pds=as_envvar(settings.get("pds")), 27 31 ) 28 32 self.pds = pds 29 - 33 + 30 34 # PDS is Not a service, the lexicon and rids are the same across pds 31 35 super().__init__(SERVICE, did, settings, db) 32 - 36 + 33 37 def _on_post(self, outputs: list[cross.Output], post: dict[str, Any]): 34 - post_uri = post['$xpost.strongRef']['uri'] 35 - post_cid = 
post['$xpost.strongRef']['cid'] 36 - 38 + post_uri = post["$xpost.strongRef"]["uri"] 39 + post_cid = post["$xpost.strongRef"]["cid"] 40 + 37 41 parent_uri = None 38 - if post.get('reply'): 39 - parent_uri = post['reply']['parent']['uri'] 40 - 41 - embed = post.get('embed', {}) 42 - if embed.get('$type') in ('app.bsky.embed.record', 'app.bsky.embed.recordWithMedia'): 43 - did, collection, rid = str(embed['record']['uri'][len('at://'):]).split('/') 44 - if collection == 'app.bsky.feed.post': 42 + if post.get("reply"): 43 + parent_uri = post["reply"]["parent"]["uri"] 44 + 45 + embed = post.get("embed", {}) 46 + if embed.get("$type") in ( 47 + "app.bsky.embed.record", 48 + "app.bsky.embed.recordWithMedia", 49 + ): 50 + did, collection, rid = str(embed["record"]["uri"][len("at://") :]).split( 51 + "/" 52 + ) 53 + if collection == "app.bsky.feed.post": 45 54 LOGGER.info("Skipping '%s'! Quote..", post_uri) 46 55 return 47 - 48 - success = database.try_insert_post(self.db, post_uri, parent_uri, self.user_id, self.service) 56 + 57 + success = database.try_insert_post( 58 + self.db, post_uri, parent_uri, self.user_id, self.service 59 + ) 49 60 if not success: 50 61 LOGGER.info("Skipping '%s' as parent post was not found in db!", post_uri) 51 62 return 52 - database.store_data(self.db, post_uri, self.user_id, self.service, {'cid': post_cid}) 53 - 63 + database.store_data( 64 + self.db, post_uri, self.user_id, self.service, {"cid": post_cid} 65 + ) 66 + 54 67 tokens = tokenize_post(post) 55 68 if not cross.test_filters(tokens, self.options.filters): 56 69 LOGGER.info("Skipping '%s'. 
Matched a filter!", post_uri) 57 70 return 58 - 71 + 59 72 LOGGER.info("Crossposting '%s'...", post_uri) 60 - 73 + 61 74 def get_blob_url(blob: str): 62 - return f'{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.user_id}&cid={blob}' 63 - 75 + return f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.user_id}&cid={blob}" 76 + 64 77 attachments: list[MediaInfo] = [] 65 - if embed.get('$type') == 'app.bsky.embed.images': 78 + if embed.get("$type") == "app.bsky.embed.images": 66 79 model = get_model_or_create(embed, model=models.AppBskyEmbedImages.Main) 67 80 assert isinstance(model, models.AppBskyEmbedImages.Main) 68 - 81 + 69 82 for image in model.images: 70 83 url = get_blob_url(image.image.cid.encode()) 71 84 LOGGER.info("Downloading %s...", url) ··· 74 87 LOGGER.error("Skipping '%s'. Failed to download media!", post_uri) 75 88 return 76 89 attachments.append(io) 77 - elif embed.get('$type') == 'app.bsky.embed.video': 90 + elif embed.get("$type") == "app.bsky.embed.video": 78 91 model = get_model_or_create(embed, model=models.AppBskyEmbedVideo.Main) 79 92 assert isinstance(model, models.AppBskyEmbedVideo.Main) 80 93 url = get_blob_url(model.video.cid.encode()) 81 94 LOGGER.info("Downloading %s...", url) 82 - io = download_media(url, model.alt if model.alt else '') 95 + io = download_media(url, model.alt if model.alt else "") 83 96 if not io: 84 97 LOGGER.error("Skipping '%s'. 
Failed to download media!", post_uri) 85 98 return 86 99 attachments.append(io) 87 - 100 + 88 101 cross_post = BlueskyPost(post, tokens, attachments) 89 102 for output in outputs: 90 103 output.accept_post(cross_post) ··· 93 106 post = database.find_post(self.db, post_id, self.user_id, self.service) 94 107 if not post: 95 108 return 96 - 109 + 97 110 LOGGER.info("Deleting '%s'...", post_id) 98 111 if repost: 99 112 for output in outputs: ··· 102 115 for output in outputs: 103 116 output.delete_post(post_id) 104 117 database.delete_post(self.db, post_id, self.user_id, self.service) 105 - 118 + 106 119 def _on_repost(self, outputs: list[cross.Output], post: dict[str, Any]): 107 - post_uri = post['$xpost.strongRef']['uri'] 108 - post_cid = post['$xpost.strongRef']['cid'] 109 - 110 - reposted_uri = post['subject']['uri'] 111 - 112 - success = database.try_insert_repost(self.db, post_uri, reposted_uri, self.user_id, self.service) 120 + post_uri = post["$xpost.strongRef"]["uri"] 121 + post_cid = post["$xpost.strongRef"]["cid"] 122 + 123 + reposted_uri = post["subject"]["uri"] 124 + 125 + success = database.try_insert_repost( 126 + self.db, post_uri, reposted_uri, self.user_id, self.service 127 + ) 113 128 if not success: 114 129 LOGGER.info("Skipping '%s' as reposted post was not found in db!", post_uri) 115 130 return 116 - database.store_data(self.db, post_uri, self.user_id, self.service, {'cid': post_cid}) 117 - 131 + database.store_data( 132 + self.db, post_uri, self.user_id, self.service, {"cid": post_cid} 133 + ) 134 + 118 135 LOGGER.info("Crossposting '%s'...", post_uri) 119 136 for output in outputs: 120 137 output.accept_repost(post_uri, reposted_uri) 121 138 139 + 122 140 class BlueskyJetstreamInput(BlueskyInput): 123 141 def __init__(self, settings: dict, db: DataBaseWorker) -> None: 124 142 super().__init__(settings, db) 125 - self.jetstream = settings.get("jetstream", "wss://jetstream2.us-east.bsky.network/subscribe") 126 - 143 + self.jetstream = 
settings.get( 144 + "jetstream", "wss://jetstream2.us-east.bsky.network/subscribe" 145 + ) 146 + 127 147 def __on_commit(self, outputs: list[cross.Output], msg: dict): 128 - if msg.get('did') != self.user_id: 148 + if msg.get("did") != self.user_id: 129 149 return 130 - 131 - commit: dict = msg.get('commit', {}) 150 + 151 + commit: dict = msg.get("commit", {}) 132 152 if not commit: 133 153 return 134 - 135 - commit_type = commit['operation'] 154 + 155 + commit_type = commit["operation"] 136 156 match commit_type: 137 - case 'create': 138 - record = dict(commit.get('record', {})) 139 - record['$xpost.strongRef'] = { 140 - 'cid': commit['cid'], 141 - 'uri': f'at://{self.user_id}/{commit['collection']}/{commit['rkey']}' 157 + case "create": 158 + record = dict(commit.get("record", {})) 159 + record["$xpost.strongRef"] = { 160 + "cid": commit["cid"], 161 + "uri": f"at://{self.user_id}/{commit['collection']}/{commit['rkey']}", 142 162 } 143 - 144 - match commit['collection']: 145 - case 'app.bsky.feed.post': 163 + 164 + match commit["collection"]: 165 + case "app.bsky.feed.post": 146 166 self._on_post(outputs, record) 147 - case 'app.bsky.feed.repost': 167 + case "app.bsky.feed.repost": 148 168 self._on_repost(outputs, record) 149 - case 'delete': 150 - post_id: str = f'at://{self.user_id}/{commit['collection']}/{commit['rkey']}' 151 - match commit['collection']: 152 - case 'app.bsky.feed.post': 169 + case "delete": 170 + post_id: str = ( 171 + f"at://{self.user_id}/{commit['collection']}/{commit['rkey']}" 172 + ) 173 + match commit["collection"]: 174 + case "app.bsky.feed.post": 153 175 self._on_delete_post(outputs, post_id, False) 154 - case 'app.bsky.feed.repost': 176 + case "app.bsky.feed.repost": 155 177 self._on_delete_post(outputs, post_id, True) 156 - 157 - async def listen(self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]): 158 - uri = self.jetstream + '?' 
178 + 179 + async def listen( 180 + self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any] 181 + ): 182 + uri = self.jetstream + "?" 159 183 uri += "wantedCollections=app.bsky.feed.post" 160 184 uri += "&wantedCollections=app.bsky.feed.repost" 161 185 uri += f"&wantedDids={self.user_id}" 162 - 163 - async for ws in websockets.connect(uri, extra_headers={"User-Agent": "XPost/0.0.3"}): 186 + 187 + async for ws in websockets.connect( 188 + uri, extra_headers={"User-Agent": "XPost/0.0.3"} 189 + ): 164 190 try: 165 191 LOGGER.info("Listening to %s...", self.jetstream) 166 - 192 + 167 193 async def listen_for_messages(): 168 194 async for msg in ws: 169 195 submit(lambda: self.__on_commit(outputs, json.loads(msg))) 170 - 196 + 171 197 listen = asyncio.create_task(listen_for_messages()) 172 - 198 + 173 199 await asyncio.gather(listen) 174 200 except websockets.ConnectionClosedError as e: 175 201 LOGGER.error(e, stack_info=True, exc_info=True) 176 202 LOGGER.info("Reconnecting to %s...", self.jetstream) 177 - continue 203 + continue
+238 -209
bluesky/output.py
··· 1 - import json 2 - from httpx import Timeout 3 - 4 - from atproto import client_utils, Request 1 + from atproto import Request, client_utils 5 2 from atproto_client import models 6 - from bluesky.atproto2 import Client2, resolve_identity 7 - 8 - from bluesky.common import SERVICE, ADULT_PATTERN, PORN_PATTERN, tokens_to_richtext 3 + from httpx import Timeout 9 4 10 - import cross, util.database as database 5 + import cross 11 6 import misskey.mfm_util as mfm_util 12 - from util.util import LOGGER, as_envvar 13 - from util.media import MediaInfo, get_filename_from_url, get_media_meta, compress_image, convert_to_mp4 7 + import util.database as database 8 + from bluesky.atproto2 import Client2, resolve_identity 9 + from bluesky.common import ADULT_PATTERN, PORN_PATTERN, SERVICE, tokens_to_richtext 14 10 from util.database import DataBaseWorker 11 + from util.media import ( 12 + MediaInfo, 13 + compress_image, 14 + convert_to_mp4, 15 + get_filename_from_url, 16 + get_media_meta, 17 + ) 18 + from util.util import LOGGER, as_envvar 15 19 16 - ALLOWED_GATES = ['mentioned', 'following', 'followers', 'everybody'] 20 + ALLOWED_GATES = ["mentioned", "following", "followers", "everybody"] 21 + 17 22 18 23 class BlueskyOutputOptions: 19 24 def __init__(self, o: dict) -> None: 20 25 self.quote_gate: bool = False 21 - self.thread_gate: list[str] = ['everybody'] 26 + self.thread_gate: list[str] = ["everybody"] 22 27 self.encode_videos: bool = True 23 - 24 - quote_gate = o.get('quote_gate') 28 + 29 + quote_gate = o.get("quote_gate") 25 30 if quote_gate is not None: 26 31 self.quote_gate = bool(quote_gate) 27 - 28 - thread_gate = o.get('thread_gate') 32 + 33 + thread_gate = o.get("thread_gate") 29 34 if thread_gate is not None: 30 35 if any([v not in ALLOWED_GATES for v in thread_gate]): 31 - raise ValueError(f"'thread_gate' only accepts {', '.join(ALLOWED_GATES)} or [], got: {thread_gate}") 36 + raise ValueError( 37 + f"'thread_gate' only accepts {', '.join(ALLOWED_GATES)} or 
[], got: {thread_gate}" 38 + ) 32 39 self.thread_gate = thread_gate 33 - 34 - encode_videos = o.get('encode_videos') 40 + 41 + encode_videos = o.get("encode_videos") 35 42 if encode_videos is not None: 36 43 self.encode_videos = bool(encode_videos) 44 + 37 45 38 46 class BlueskyOutput(cross.Output): 39 47 def __init__(self, input: cross.Input, settings: dict, db: DataBaseWorker) -> None: 40 48 super().__init__(input, settings, db) 41 - self.options = BlueskyOutputOptions(settings.get('options') or {}) 42 - 43 - if not as_envvar(settings.get('app-password')): 49 + self.options = BlueskyOutputOptions(settings.get("options") or {}) 50 + 51 + if not as_envvar(settings.get("app-password")): 44 52 raise Exception("Account app password not provided!") 45 - 53 + 46 54 did, pds = resolve_identity( 47 - handle=as_envvar(settings.get('handle')), 48 - did=as_envvar(settings.get('did')), 49 - pds=as_envvar(settings.get('pds')) 55 + handle=as_envvar(settings.get("handle")), 56 + did=as_envvar(settings.get("did")), 57 + pds=as_envvar(settings.get("pds")), 50 58 ) 51 - 59 + 52 60 reqs = Request(timeout=Timeout(None, connect=30.0)) 53 - 61 + 54 62 self.bsky = Client2(pds, request=reqs) 55 63 self.bsky.configure_proxy_header( 56 - service_type='bsky_appview', 57 - did=as_envvar(settings.get('bsky_appview')) or 'did:web:api.bsky.app' 64 + service_type="bsky_appview", 65 + did=as_envvar(settings.get("bsky_appview")) or "did:web:api.bsky.app", 58 66 ) 59 - self.bsky.login(did, as_envvar(settings.get('app-password'))) 60 - 67 + self.bsky.login(did, as_envvar(settings.get("app-password"))) 68 + 61 69 def __check_login(self): 62 70 login = self.bsky.me 63 71 if not login: 64 72 raise Exception("Client not logged in!") 65 73 return login 66 - 74 + 67 75 def _find_parent(self, parent_id: str): 68 76 login = self.__check_login() 69 - 77 + 70 78 thread_tuple = database.find_mapped_thread( 71 79 self.db, 72 80 parent_id, 73 81 self.input.user_id, 74 82 self.input.service, 75 83 login.did, 76 - 
SERVICE 84 + SERVICE, 77 85 ) 78 - 86 + 79 87 if not thread_tuple: 80 88 LOGGER.error("Failed to find thread tuple in the database!") 81 89 return None 82 - 90 + 83 91 root_uri: str = thread_tuple[0] 84 92 reply_uri: str = thread_tuple[1] 85 - 86 - root_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)['cid'] 87 - reply_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)['cid'] 88 - 89 - root_record = models.AppBskyFeedPost.CreateRecordResponse(uri=root_uri, cid=root_cid) 90 - reply_record = models.AppBskyFeedPost.CreateRecordResponse(uri=reply_uri, cid=reply_cid) 91 - 93 + 94 + root_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)["cid"] 95 + reply_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)["cid"] 96 + 97 + root_record = models.AppBskyFeedPost.CreateRecordResponse( 98 + uri=root_uri, cid=root_cid 99 + ) 100 + reply_record = models.AppBskyFeedPost.CreateRecordResponse( 101 + uri=reply_uri, cid=reply_cid 102 + ) 103 + 92 104 return ( 93 105 models.create_strong_ref(root_record), 94 106 models.create_strong_ref(reply_record), 95 107 thread_tuple[2], 96 - thread_tuple[3] 108 + thread_tuple[3], 97 109 ) 98 - 110 + 99 111 def _split_attachments(self, attachments: list[MediaInfo]): 100 112 sup_media: list[MediaInfo] = [] 101 113 unsup_media: list[MediaInfo] = [] 102 - 114 + 103 115 for a in attachments: 104 - if a.mime.startswith('image/') or a.mime.startswith('video/'): # TODO convert gifs to videos 116 + if a.mime.startswith("image/") or a.mime.startswith( 117 + "video/" 118 + ): # TODO convert gifs to videos 105 119 sup_media.append(a) 106 120 else: 107 121 unsup_media.append(a) 108 - 122 + 109 123 return (sup_media, unsup_media) 110 124 111 125 def _split_media_per_post( 112 - self, 113 - tokens: list[client_utils.TextBuilder], 114 - media: list[MediaInfo]): 115 - 126 + self, tokens: list[client_utils.TextBuilder], media: list[MediaInfo] 127 + ): 116 128 posts: list[dict] = [{"tokens": tokens, 
"attachments": []} for tokens in tokens] 117 129 available_indices: list[int] = list(range(len(posts))) 118 - 130 + 119 131 current_image_post_idx: int | None = None 120 132 121 133 def make_blank_post() -> dict: 122 - return { 123 - "tokens": [client_utils.TextBuilder().text('')], 124 - "attachments": [] 125 - } 126 - 134 + return {"tokens": [client_utils.TextBuilder().text("")], "attachments": []} 135 + 127 136 def pop_next_empty_index() -> int: 128 137 if available_indices: 129 138 return available_indices.pop(0) ··· 131 140 new_idx = len(posts) 132 141 posts.append(make_blank_post()) 133 142 return new_idx 134 - 143 + 135 144 for att in media: 136 - if att.mime.startswith('video/'): 145 + if att.mime.startswith("video/"): 137 146 current_image_post_idx = None 138 147 idx = pop_next_empty_index() 139 148 posts[idx]["attachments"].append(att) 140 - elif att.mime.startswith('image/'): 149 + elif att.mime.startswith("image/"): 141 150 if ( 142 151 current_image_post_idx is not None 143 152 and len(posts[current_image_post_idx]["attachments"]) < 4 ··· 147 156 idx = pop_next_empty_index() 148 157 posts[idx]["attachments"].append(att) 149 158 current_image_post_idx = idx 150 - 159 + 151 160 result: list[tuple[client_utils.TextBuilder, list[MediaInfo]]] = [] 152 161 for p in posts: 153 162 result.append((p["tokens"], p["attachments"])) 154 163 return result 155 - 164 + 156 165 def accept_post(self, post: cross.Post): 157 166 login = self.__check_login() 158 - 167 + 159 168 parent_id = post.get_parent_id() 160 - 169 + 161 170 # used for db insertion 162 171 new_root_id = None 163 172 new_parent_id = None 164 - 173 + 165 174 root_ref = None 166 175 reply_ref = None 167 176 if parent_id: ··· 169 178 if not parents: 170 179 return 171 180 root_ref, reply_ref, new_root_id, new_parent_id = parents 172 - 181 + 173 182 tokens = post.get_tokens().copy() 174 - 183 + 175 184 unique_labels: set[str] = set() 176 185 cw = post.get_spoiler() 177 186 if cw: 178 187 tokens.insert(0, 
cross.TextToken("CW: " + cw + "\n\n")) 179 - unique_labels.add('graphic-media') 180 - 188 + unique_labels.add("graphic-media") 189 + 181 190 # from bsky.app, a post can only have one of those labels 182 191 if PORN_PATTERN.search(cw): 183 - unique_labels.add('porn') 192 + unique_labels.add("porn") 184 193 elif ADULT_PATTERN.search(cw): 185 - unique_labels.add('sexual') 186 - 194 + unique_labels.add("sexual") 195 + 187 196 if post.is_sensitive(): 188 - unique_labels.add('graphic-media') 189 - 190 - labels = models.ComAtprotoLabelDefs.SelfLabels(values=[models.ComAtprotoLabelDefs.SelfLabel(val=label) for label in unique_labels]) 197 + unique_labels.add("graphic-media") 198 + 199 + labels = models.ComAtprotoLabelDefs.SelfLabels( 200 + values=[ 201 + models.ComAtprotoLabelDefs.SelfLabel(val=label) 202 + for label in unique_labels 203 + ] 204 + ) 191 205 192 206 sup_media, unsup_media = self._split_attachments(post.get_attachments()) 193 207 194 208 if unsup_media: 195 209 if tokens: 196 - tokens.append(cross.TextToken('\n')) 210 + tokens.append(cross.TextToken("\n")) 197 211 for i, attachment in enumerate(unsup_media): 198 - tokens.append(cross.LinkToken( 199 - attachment.url, 200 - f"[{get_filename_from_url(attachment.url)}]" 201 - )) 202 - tokens.append(cross.TextToken(' ')) 203 - 212 + tokens.append( 213 + cross.LinkToken( 214 + attachment.url, f"[{get_filename_from_url(attachment.url)}]" 215 + ) 216 + ) 217 + tokens.append(cross.TextToken(" ")) 218 + 204 219 if post.get_text_type() == "text/x.misskeymarkdown": 205 220 tokens, status = mfm_util.strip_mfm(tokens) 206 221 post_url = post.get_post_url() 207 222 if status and post_url: 208 - tokens.append(cross.TextToken('\n')) 209 - tokens.append(cross.LinkToken(post_url, "[Post contains MFM, see original]")) 210 - 223 + tokens.append(cross.TextToken("\n")) 224 + tokens.append( 225 + cross.LinkToken(post_url, "[Post contains MFM, see original]") 226 + ) 227 + 211 228 split_tokens: list[list[cross.Token]] = 
cross.split_tokens(tokens, 300) 212 229 post_text: list[client_utils.TextBuilder] = [] 213 - 230 + 214 231 # convert tokens into rich text. skip post if contains unsupported tokens 215 232 for block in split_tokens: 216 233 rich_text = tokens_to_richtext(block) 217 - 234 + 218 235 if not rich_text: 219 - LOGGER.error("Skipping '%s' as it contains invalid rich text types!", post.get_id()) 236 + LOGGER.error( 237 + "Skipping '%s' as it contains invalid rich text types!", 238 + post.get_id(), 239 + ) 220 240 return 221 241 post_text.append(rich_text) 222 - 242 + 223 243 if not post_text: 224 - post_text = [client_utils.TextBuilder().text('')] 225 - 244 + post_text = [client_utils.TextBuilder().text("")] 245 + 226 246 for m in sup_media: 227 - if m.mime.startswith('image/'): 247 + if m.mime.startswith("image/"): 228 248 if len(m.io) > 2_000_000: 229 - LOGGER.error("Skipping post_id '%s', failed to download attachment! File too large.", post.get_id()) 249 + LOGGER.error( 250 + "Skipping post_id '%s', failed to download attachment! File too large.", 251 + post.get_id(), 252 + ) 230 253 return 231 - 232 - if m.mime.startswith('video/'): 233 - if m.mime != 'video/mp4' and not self.options.encode_videos: 234 - LOGGER.info("Video is not mp4, but encoding is disabled. Skipping '%s'...", post.get_id()) 254 + 255 + if m.mime.startswith("video/"): 256 + if m.mime != "video/mp4" and not self.options.encode_videos: 257 + LOGGER.info( 258 + "Video is not mp4, but encoding is disabled. Skipping '%s'...", 259 + post.get_id(), 260 + ) 235 261 return 236 - 262 + 237 263 if len(m.io) > 100_000_000: 238 - LOGGER.error("Skipping post_id '%s', failed to download attachment! File too large?", post.get_id()) 264 + LOGGER.error( 265 + "Skipping post_id '%s', failed to download attachment! 
File too large?", 266 + post.get_id(), 267 + ) 239 268 return 240 - 269 + 241 270 created_records: list[models.AppBskyFeedPost.CreateRecordResponse] = [] 242 271 baked_media = self._split_media_per_post(post_text, sup_media) 243 - 272 + 244 273 for text, attachments in baked_media: 245 274 if not attachments: 246 275 if reply_ref and root_ref: 247 - new_post = self.bsky.send_post(text, reply_to=models.AppBskyFeedPost.ReplyRef( 248 - parent=reply_ref, 249 - root=root_ref 250 - ), labels=labels, time_iso=post.get_timestamp()) 276 + new_post = self.bsky.send_post( 277 + text, 278 + reply_to=models.AppBskyFeedPost.ReplyRef( 279 + parent=reply_ref, root=root_ref 280 + ), 281 + labels=labels, 282 + time_iso=post.get_timestamp(), 283 + ) 251 284 else: 252 - new_post = self.bsky.send_post(text, labels=labels, time_iso=post.get_timestamp()) 285 + new_post = self.bsky.send_post( 286 + text, labels=labels, time_iso=post.get_timestamp() 287 + ) 253 288 root_ref = models.create_strong_ref(new_post) 254 - 289 + 255 290 self.bsky.create_gates( 256 - self.options.thread_gate, 257 - self.options.quote_gate, 258 - new_post.uri, 259 - time_iso=post.get_timestamp() 291 + self.options.thread_gate, 292 + self.options.quote_gate, 293 + new_post.uri, 294 + time_iso=post.get_timestamp(), 260 295 ) 261 296 reply_ref = models.create_strong_ref(new_post) 262 297 created_records.append(new_post) 263 298 else: 264 299 # if a single post is an image - everything else is an image 265 - if attachments[0].mime.startswith('image/'): 300 + if attachments[0].mime.startswith("image/"): 266 301 images: list[bytes] = [] 267 302 image_alts: list[str] = [] 268 303 image_aspect_ratios: list[models.AppBskyEmbedDefs.AspectRatio] = [] 269 - 304 + 270 305 for attachment in attachments: 271 306 image_io = compress_image(attachment.io, quality=100) 272 307 metadata = get_media_meta(image_io) 273 - 308 + 274 309 if len(image_io) > 1_000_000: 275 310 LOGGER.info("Compressing %s...", attachment.name) 276 311 
image_io = compress_image(image_io) 277 - 312 + 278 313 images.append(image_io) 279 314 image_alts.append(attachment.alt) 280 - image_aspect_ratios.append(models.AppBskyEmbedDefs.AspectRatio( 281 - width=metadata['width'], 282 - height=metadata['height'] 283 - )) 284 - 315 + image_aspect_ratios.append( 316 + models.AppBskyEmbedDefs.AspectRatio( 317 + width=metadata["width"], height=metadata["height"] 318 + ) 319 + ) 320 + 285 321 new_post = self.bsky.send_images( 286 322 text=post_text[0], 287 323 images=images, 288 324 image_alts=image_alts, 289 325 image_aspect_ratios=image_aspect_ratios, 290 - reply_to= models.AppBskyFeedPost.ReplyRef( 291 - parent=reply_ref, 292 - root=root_ref 293 - ) if root_ref and reply_ref else None, 294 - labels=labels, 295 - time_iso=post.get_timestamp() 326 + reply_to=models.AppBskyFeedPost.ReplyRef( 327 + parent=reply_ref, root=root_ref 328 + ) 329 + if root_ref and reply_ref 330 + else None, 331 + labels=labels, 332 + time_iso=post.get_timestamp(), 296 333 ) 297 334 if not root_ref: 298 335 root_ref = models.create_strong_ref(new_post) 299 - 336 + 300 337 self.bsky.create_gates( 301 - self.options.thread_gate, 338 + self.options.thread_gate, 302 339 self.options.quote_gate, 303 - new_post.uri, 304 - time_iso=post.get_timestamp() 340 + new_post.uri, 341 + time_iso=post.get_timestamp(), 305 342 ) 306 343 reply_ref = models.create_strong_ref(new_post) 307 344 created_records.append(new_post) 308 - else: # video is guarantedd to be one 345 + else: # video is guarantedd to be one 309 346 metadata = get_media_meta(attachments[0].io) 310 - if metadata['duration'] > 180: 311 - LOGGER.info("Skipping post_id '%s', video attachment too long!", post.get_id()) 347 + if metadata["duration"] > 180: 348 + LOGGER.info( 349 + "Skipping post_id '%s', video attachment too long!", 350 + post.get_id(), 351 + ) 312 352 return 313 - 353 + 314 354 video_io = attachments[0].io 315 - if attachments[0].mime != 'video/mp4': 355 + if attachments[0].mime != 
"video/mp4": 316 356 LOGGER.info("Converting %s to mp4...", attachments[0].name) 317 357 video_io = convert_to_mp4(video_io) 318 - 358 + 319 359 aspect_ratio = models.AppBskyEmbedDefs.AspectRatio( 320 - width=metadata['width'], 321 - height=metadata['height'] 360 + width=metadata["width"], height=metadata["height"] 322 361 ) 323 - 362 + 324 363 new_post = self.bsky.send_video( 325 364 text=post_text[0], 326 365 video=video_io, 327 366 video_aspect_ratio=aspect_ratio, 328 367 video_alt=attachments[0].alt, 329 - reply_to= models.AppBskyFeedPost.ReplyRef( 330 - parent=reply_ref, 331 - root=root_ref 332 - ) if root_ref and reply_ref else None, 368 + reply_to=models.AppBskyFeedPost.ReplyRef( 369 + parent=reply_ref, root=root_ref 370 + ) 371 + if root_ref and reply_ref 372 + else None, 333 373 labels=labels, 334 - time_iso=post.get_timestamp() 374 + time_iso=post.get_timestamp(), 335 375 ) 336 376 if not root_ref: 337 377 root_ref = models.create_strong_ref(new_post) 338 - 378 + 339 379 self.bsky.create_gates( 340 380 self.options.thread_gate, 341 - self.options.quote_gate, 342 - new_post.uri, 343 - time_iso=post.get_timestamp() 381 + self.options.quote_gate, 382 + new_post.uri, 383 + time_iso=post.get_timestamp(), 344 384 ) 345 385 reply_ref = models.create_strong_ref(new_post) 346 386 created_records.append(new_post) 347 - 348 - db_post = database.find_post(self.db, post.get_id(), self.input.user_id, self.input.service) 387 + 388 + db_post = database.find_post( 389 + self.db, post.get_id(), self.input.user_id, self.input.service 390 + ) 349 391 assert db_post, "ghghghhhhh" 350 - 351 - if new_root_id is None or new_parent_id is None: 392 + 393 + if new_root_id is None or new_parent_id is None: 352 394 new_root_id = database.insert_post( 395 + self.db, created_records[0].uri, login.did, SERVICE 396 + ) 397 + database.store_data( 353 398 self.db, 354 399 created_records[0].uri, 355 400 login.did, 356 - SERVICE 357 - ) 358 - database.store_data( 359 - self.db, 360 - 
created_records[0].uri, 361 - login.did, 362 401 SERVICE, 363 - {'cid': created_records[0].cid} 402 + {"cid": created_records[0].cid}, 364 403 ) 365 - 404 + 366 405 new_parent_id = new_root_id 367 - database.insert_mapping(self.db, db_post['id'], new_parent_id) 406 + database.insert_mapping(self.db, db_post["id"], new_parent_id) 368 407 created_records = created_records[1:] 369 - 408 + 370 409 for record in created_records: 371 410 new_parent_id = database.insert_reply( 372 - self.db, 373 - record.uri, 374 - login.did, 375 - SERVICE, 376 - new_parent_id, 377 - new_root_id 411 + self.db, record.uri, login.did, SERVICE, new_parent_id, new_root_id 378 412 ) 379 413 database.store_data( 380 - self.db, 381 - record.uri, 382 - login.did, 383 - SERVICE, 384 - {'cid': record.cid} 414 + self.db, record.uri, login.did, SERVICE, {"cid": record.cid} 385 415 ) 386 - database.insert_mapping(self.db, db_post['id'], new_parent_id) 387 - 416 + database.insert_mapping(self.db, db_post["id"], new_parent_id) 417 + 388 418 def delete_post(self, identifier: str): 389 419 login = self.__check_login() 390 - 391 - post = database.find_post(self.db, identifier, self.input.user_id, self.input.service) 420 + 421 + post = database.find_post( 422 + self.db, identifier, self.input.user_id, self.input.service 423 + ) 392 424 if not post: 393 425 return 394 - 395 - mappings = database.find_mappings(self.db, post['id'], SERVICE, login.did) 426 + 427 + mappings = database.find_mappings(self.db, post["id"], SERVICE, login.did) 396 428 for mapping in mappings[::-1]: 397 429 LOGGER.info("Deleting '%s'...", mapping[0]) 398 430 self.bsky.delete_post(mapping[0]) 399 431 database.delete_post(self.db, mapping[0], SERVICE, login.did) 400 - 432 + 401 433 def accept_repost(self, repost_id: str, reposted_id: str): 402 434 login, repost = self.__delete_repost(repost_id) 403 435 if not (login and repost): 404 436 return 405 - 406 - reposted = database.find_post(self.db, reposted_id, self.input.user_id, 
self.input.service) 437 + 438 + reposted = database.find_post( 439 + self.db, reposted_id, self.input.user_id, self.input.service 440 + ) 407 441 if not reposted: 408 442 return 409 - 443 + 410 444 # mappings of the reposted post 411 - mappings = database.find_mappings(self.db, reposted['id'], SERVICE, login.did) 445 + mappings = database.find_mappings(self.db, reposted["id"], SERVICE, login.did) 412 446 if mappings: 413 - cid = database.fetch_data(self.db, mappings[0][0], login.did, SERVICE)['cid'] 447 + cid = database.fetch_data(self.db, mappings[0][0], login.did, SERVICE)[ 448 + "cid" 449 + ] 414 450 rsp = self.bsky.repost(mappings[0][0], cid) 415 - 451 + 416 452 internal_id = database.insert_repost( 417 - self.db, 418 - rsp.uri, 419 - reposted['id'], 420 - login.did, 421 - SERVICE) 422 - database.store_data( 423 - self.db, 424 - rsp.uri, 425 - login.did, 426 - SERVICE, 427 - {'cid': rsp.cid} 453 + self.db, rsp.uri, reposted["id"], login.did, SERVICE 428 454 ) 429 - database.insert_mapping(self.db, repost['id'], internal_id) 430 - 431 - def __delete_repost(self, repost_id: str) -> tuple[models.AppBskyActorDefs.ProfileViewDetailed | None, dict | None]: 455 + database.store_data(self.db, rsp.uri, login.did, SERVICE, {"cid": rsp.cid}) 456 + database.insert_mapping(self.db, repost["id"], internal_id) 457 + 458 + def __delete_repost( 459 + self, repost_id: str 460 + ) -> tuple[models.AppBskyActorDefs.ProfileViewDetailed | None, dict | None]: 432 461 login = self.__check_login() 433 - 434 - repost = database.find_post(self.db, repost_id, self.input.user_id, self.input.service) 462 + 463 + repost = database.find_post( 464 + self.db, repost_id, self.input.user_id, self.input.service 465 + ) 435 466 if not repost: 436 467 return None, None 437 - 438 - mappings = database.find_mappings(self.db, repost['id'], SERVICE, login.did) 468 + 469 + mappings = database.find_mappings(self.db, repost["id"], SERVICE, login.did) 439 470 if mappings: 440 471 LOGGER.info("Deleting 
'%s'...", mappings[0][0]) 441 472 self.bsky.unrepost(mappings[0][0]) 442 473 database.delete_post(self.db, mappings[0][0], login.did, SERVICE) 443 474 return login, repost 444 - 475 + 445 476 def delete_repost(self, repost_id: str): 446 477 self.__delete_repost(repost_id) 447 - 448 -
+79 -61
cross.py
··· 1 + import re 1 2 from abc import ABC, abstractmethod 2 - from typing import Callable, Any 3 + from datetime import datetime, timezone 4 + from typing import Any, Callable 5 + 3 6 from util.database import DataBaseWorker 4 - from datetime import datetime, timezone 5 7 from util.media import MediaInfo 6 8 from util.util import LOGGER, canonical_label 7 - import re 8 9 9 - ALTERNATE = re.compile(r'\S+|\s+') 10 + ALTERNATE = re.compile(r"\S+|\s+") 11 + 10 12 11 13 # generic token 12 - class Token(): 14 + class Token: 13 15 def __init__(self, type: str) -> None: 14 16 self.type = type 15 17 18 + 16 19 class TextToken(Token): 17 20 def __init__(self, text: str) -> None: 18 - super().__init__('text') 21 + super().__init__("text") 19 22 self.text = text 23 + 20 24 21 25 # token that represents a link to a website. e.g. [link](https://google.com/) 22 26 class LinkToken(Token): 23 27 def __init__(self, href: str, label: str) -> None: 24 - super().__init__('link') 28 + super().__init__("link") 25 29 self.href = href 26 30 self.label = label 27 - 28 - # token that represents a hashtag. e.g. #SocialMedia 31 + 32 + 33 + # token that represents a hashtag. e.g. #SocialMedia 29 34 class TagToken(Token): 30 35 def __init__(self, tag: str) -> None: 31 - super().__init__('tag') 36 + super().__init__("tag") 32 37 self.tag = tag 38 + 33 39 34 40 # token that represents a mention of a user. 
35 41 class MentionToken(Token): 36 42 def __init__(self, username: str, uri: str) -> None: 37 - super().__init__('mention') 43 + super().__init__("mention") 38 44 self.username = username 39 45 self.uri = uri 40 - 41 - class MediaMeta(): 46 + 47 + 48 + class MediaMeta: 42 49 def __init__(self, width: int, height: int, duration: float) -> None: 43 50 self.width = width 44 51 self.height = height 45 52 self.duration = duration 46 - 53 + 47 54 def get_width(self) -> int: 48 55 return self.width 49 - 56 + 50 57 def get_height(self) -> int: 51 58 return self.height 52 - 59 + 53 60 def get_duration(self) -> float: 54 61 return self.duration 55 - 62 + 63 + 56 64 class Post(ABC): 57 65 @abstractmethod 58 66 def get_id(self) -> str: 59 - return '' 60 - 67 + return "" 68 + 61 69 @abstractmethod 62 70 def get_parent_id(self) -> str | None: 63 71 pass 64 - 72 + 65 73 @abstractmethod 66 74 def get_tokens(self) -> list[Token]: 67 75 pass ··· 71 79 @abstractmethod 72 80 def get_text_type(self) -> str: 73 81 pass 74 - 82 + 75 83 # post iso timestamp 76 84 @abstractmethod 77 85 def get_timestamp(self) -> str: 78 86 pass 79 - 87 + 80 88 def get_attachments(self) -> list[MediaInfo]: 81 89 return [] 82 - 90 + 83 91 def get_spoiler(self) -> str | None: 84 92 return None 85 - 93 + 86 94 def get_languages(self) -> list[str]: 87 95 return [] 88 - 96 + 89 97 def is_sensitive(self) -> bool: 90 98 return False 91 - 99 + 92 100 def get_post_url(self) -> str | None: 93 101 return None 94 102 103 + 95 104 # generic input service. 
96 105 # user and service for db queries 97 - class Input(): 98 - def __init__(self, service: str, user_id: str, settings: dict, db: DataBaseWorker) -> None: 106 + class Input: 107 + def __init__( 108 + self, service: str, user_id: str, settings: dict, db: DataBaseWorker 109 + ) -> None: 99 110 self.service = service 100 111 self.user_id = user_id 101 112 self.settings = settings 102 113 self.db = db 103 - 114 + 104 115 async def listen(self, outputs: list, handler: Callable[[Post], Any]): 105 116 pass 106 117 107 - class Output(): 118 + 119 + class Output: 108 120 def __init__(self, input: Input, settings: dict, db: DataBaseWorker) -> None: 109 121 self.input = input 110 122 self.settings = settings 111 123 self.db = db 112 - 124 + 113 125 def accept_post(self, post: Post): 114 126 LOGGER.warning('Not Implemented.. "posted" %s', post.get_id()) 115 - 127 + 116 128 def delete_post(self, identifier: str): 117 129 LOGGER.warning('Not Implemented.. "deleted" %s', identifier) 118 - 130 + 119 131 def accept_repost(self, repost_id: str, reposted_id: str): 120 132 LOGGER.warning('Not Implemented.. "reblogged" %s, %s', repost_id, reposted_id) 121 - 133 + 122 134 def delete_repost(self, repost_id: str): 123 135 LOGGER.warning('Not Implemented.. 
"removed reblog" %s', repost_id) 136 + 124 137 125 138 def test_filters(tokens: list[Token], filters: list[re.Pattern[str]]): 126 139 if not tokens or not filters: 127 140 return True 128 - 129 - markdown = '' 130 - 141 + 142 + markdown = "" 143 + 131 144 for token in tokens: 132 145 if isinstance(token, TextToken): 133 146 markdown += token.text 134 147 elif isinstance(token, LinkToken): 135 - markdown += f'[{token.label}]({token.href})' 148 + markdown += f"[{token.label}]({token.href})" 136 149 elif isinstance(token, TagToken): 137 - markdown += '#' + token.tag 150 + markdown += "#" + token.tag 138 151 elif isinstance(token, MentionToken): 139 152 markdown += token.username 140 - 153 + 141 154 for filter in filters: 142 155 if filter.search(markdown): 143 156 return False 144 - 157 + 145 158 return True 146 159 147 - def split_tokens(tokens: list[Token], max_chars: int, max_link_len: int = 35) -> list[list[Token]]: 160 + 161 + def split_tokens( 162 + tokens: list[Token], max_chars: int, max_link_len: int = 35 163 + ) -> list[list[Token]]: 148 164 def new_block(): 149 165 nonlocal blocks, block, length 150 166 if block: 151 167 blocks.append(block) 152 168 block = [] 153 169 length = 0 154 - 170 + 155 171 def append_text(text_segment): 156 172 nonlocal block 157 173 # if the last element in the current block is also text, just append to it ··· 159 175 block[-1].text += text_segment 160 176 else: 161 177 block.append(TextToken(text_segment)) 162 - 178 + 163 179 blocks: list[list[Token]] = [] 164 180 block: list[Token] = [] 165 181 length = 0 166 - 182 + 167 183 for tk in tokens: 168 184 if isinstance(tk, TagToken): 169 - tag_len = 1 + len(tk.tag) # (#) + tag 185 + tag_len = 1 + len(tk.tag) # (#) + tag 170 186 if length + tag_len > max_chars: 171 - new_block() # create new block if the current one is too large 172 - 187 + new_block() # create new block if the current one is too large 188 + 173 189 block.append(tk) 174 190 length += tag_len 175 - elif isinstance(tk, 
LinkToken): # TODO labels should proably be split too 191 + elif isinstance(tk, LinkToken): # TODO labels should proably be split too 176 192 link_len = len(tk.label) 177 - if canonical_label(tk.label, tk.href): # cut down the link if the label is canonical 193 + if canonical_label( 194 + tk.label, tk.href 195 + ): # cut down the link if the label is canonical 178 196 link_len = min(link_len, max_link_len) 179 - 197 + 180 198 if length + link_len > max_chars: 181 199 new_block() 182 200 block.append(tk) 183 201 length += link_len 184 202 elif isinstance(tk, TextToken): 185 203 segments: list[str] = ALTERNATE.findall(tk.text) 186 - 204 + 187 205 for seg in segments: 188 206 seg_len: int = len(seg) 189 207 if length + seg_len <= max_chars - (0 if seg.isspace() else 1): 190 208 append_text(seg) 191 209 length += seg_len 192 210 continue 193 - 211 + 194 212 if length > 0: 195 213 new_block() 196 - 214 + 197 215 if not seg.isspace(): 198 216 while len(seg) > max_chars - 1: 199 217 chunk = seg[: max_chars - 1] + "-" ··· 202 220 seg = seg[max_chars - 1 :] 203 221 else: 204 222 while len(seg) > max_chars: 205 - chunk = seg[: max_chars] 223 + chunk = seg[:max_chars] 206 224 append_text(chunk) 207 225 new_block() 208 - seg = seg[max_chars :] 209 - 226 + seg = seg[max_chars:] 227 + 210 228 if seg: 211 229 append_text(seg) 212 230 length = len(seg) 213 - else: #TODO fix mentions 231 + else: # TODO fix mentions 214 232 block.append(tk) 215 - 233 + 216 234 if block: 217 235 blocks.append(block) 218 - 219 - return blocks 236 + 237 + return blocks
+59 -54
main.py
··· 1 - import os 1 + import asyncio 2 2 import json 3 - import asyncio, threading, queue, traceback 3 + import os 4 + import queue 5 + import threading 6 + import traceback 4 7 5 - from util.util import LOGGER, as_json 6 - import cross, util.database as database 7 - 8 + import cross 9 + import util.database as database 8 10 from bluesky.input import BlueskyJetstreamInput 9 - from bluesky.output import BlueskyOutputOptions, BlueskyOutput 10 - 11 - from mastodon.input import MastodonInputOptions, MastodonInput 11 + from bluesky.output import BlueskyOutput, BlueskyOutputOptions 12 + from mastodon.input import MastodonInput, MastodonInputOptions 12 13 from mastodon.output import MastodonOutput 13 - 14 14 from misskey.input import MisskeyInput 15 + from util.util import LOGGER, as_json 15 16 16 17 DEFAULT_SETTINGS: dict = { 17 - 'input': { 18 - 'type': 'mastodon-wss', 19 - 'instance': 'env:MASTODON_INSTANCE', 20 - 'token': 'env:MASTODON_TOKEN', 21 - "options": MastodonInputOptions({}) 18 + "input": { 19 + "type": "mastodon-wss", 20 + "instance": "env:MASTODON_INSTANCE", 21 + "token": "env:MASTODON_TOKEN", 22 + "options": MastodonInputOptions({}), 22 23 }, 23 - 'outputs': [ 24 + "outputs": [ 24 25 { 25 - 'type': 'bluesky', 26 - 'handle': 'env:BLUESKY_HANDLE', 27 - 'app-password': 'env:BLUESKY_APP_PASSWORD', 28 - 'options': BlueskyOutputOptions({}) 26 + "type": "bluesky", 27 + "handle": "env:BLUESKY_HANDLE", 28 + "app-password": "env:BLUESKY_APP_PASSWORD", 29 + "options": BlueskyOutputOptions({}), 29 30 } 30 - ] 31 + ], 31 32 } 32 33 33 34 INPUTS = { 34 35 "mastodon-wss": lambda settings, db: MastodonInput(settings, db), 35 36 "misskey-wss": lambda settigs, db: MisskeyInput(settigs, db), 36 - "bluesky-jetstream-wss": lambda settings, db: BlueskyJetstreamInput(settings, db) 37 + "bluesky-jetstream-wss": lambda settings, db: BlueskyJetstreamInput(settings, db), 37 38 } 38 39 39 40 OUTPUTS = { 40 41 "bluesky": lambda input, settings, db: BlueskyOutput(input, settings, db), 
41 - "mastodon": lambda input, settings, db: MastodonOutput(input, settings, db) 42 + "mastodon": lambda input, settings, db: MastodonOutput(input, settings, db), 42 43 } 44 + 43 45 44 46 def execute(data_dir): 45 47 if not os.path.exists(data_dir): 46 48 os.makedirs(data_dir) 47 - 48 - settings_path = os.path.join(data_dir, 'settings.json') 49 - database_path = os.path.join(data_dir, 'data.db') 50 - 49 + 50 + settings_path = os.path.join(data_dir, "settings.json") 51 + database_path = os.path.join(data_dir, "data.db") 52 + 51 53 if not os.path.exists(settings_path): 52 54 LOGGER.info("First launch detected! Creating %s and exiting!", settings_path) 53 - 54 - with open(settings_path, 'w') as f: 55 + 56 + with open(settings_path, "w") as f: 55 57 f.write(as_json(DEFAULT_SETTINGS, indent=2)) 56 58 return 0 57 59 58 - LOGGER.info('Loading settings...') 59 - with open(settings_path, 'rb') as f: 60 + LOGGER.info("Loading settings...") 61 + with open(settings_path, "rb") as f: 60 62 settings = json.load(f) 61 - 62 - LOGGER.info('Starting database worker...') 63 + 64 + LOGGER.info("Starting database worker...") 63 65 db_worker = database.DataBaseWorker(os.path.abspath(database_path)) 64 - 65 - db_worker.execute('PRAGMA foreign_keys = ON;') 66 - 66 + 67 + db_worker.execute("PRAGMA foreign_keys = ON;") 68 + 67 69 # create the posts table 68 70 # id - internal id of the post 69 71 # user_id - user id on the service (e.g. 
a724sknj5y9ydk0w) ··· 82 84 ); 83 85 """ 84 86 ) 85 - 87 + 86 88 columns = db_worker.execute("PRAGMA table_info(posts)") 87 89 column_names = [col[1] for col in columns] 88 90 if "reposted_id" not in column_names: ··· 95 97 ALTER TABLE posts 96 98 ADD COLUMN extra_data TEXT NULL 97 99 """) 98 - 100 + 99 101 # create the mappings table 100 102 # original_post_id - the post this was mapped from 101 103 # mapped_post_id - the post this was mapped to ··· 107 109 ); 108 110 """ 109 111 ) 110 - 111 - input_settings = settings.get('input') 112 + 113 + input_settings = settings.get("input") 112 114 if not input_settings: 113 115 raise Exception("No input specified!") 114 - outputs_settings = settings.get('outputs', []) 115 - 116 - input = INPUTS[input_settings['type']](input_settings, db_worker) 117 - 116 + outputs_settings = settings.get("outputs", []) 117 + 118 + input = INPUTS[input_settings["type"]](input_settings, db_worker) 119 + 118 120 if not outputs_settings: 119 121 LOGGER.warning("No outputs specified! 
Check the config!") 120 - 122 + 121 123 outputs: list[cross.Output] = [] 122 124 for output_settings in outputs_settings: 123 - outputs.append(OUTPUTS[output_settings['type']](input, output_settings, db_worker)) 124 - 125 - LOGGER.info('Starting task worker...') 125 + outputs.append( 126 + OUTPUTS[output_settings["type"]](input, output_settings, db_worker) 127 + ) 128 + 129 + LOGGER.info("Starting task worker...") 130 + 126 131 def worker(queue: queue.Queue): 127 132 while True: 128 133 task = queue.get() 129 134 if task is None: 130 135 break 131 - 136 + 132 137 try: 133 138 task() 134 139 except Exception as e: ··· 136 141 traceback.print_exc() 137 142 finally: 138 143 queue.task_done() 139 - 144 + 140 145 task_queue = queue.Queue() 141 146 thread = threading.Thread(target=worker, args=(task_queue,), daemon=True) 142 147 thread.start() 143 - 144 - LOGGER.info('Connecting to %s...', input.service) 148 + 149 + LOGGER.info("Connecting to %s...", input.service) 145 150 try: 146 151 asyncio.run(input.listen(outputs, lambda x: task_queue.put(x))) 147 152 except KeyboardInterrupt: 148 153 LOGGER.info("Stopping...") 149 - 154 + 150 155 task_queue.join() 151 156 task_queue.put(None) 152 157 thread.join() 153 - 158 + 154 159 155 160 if __name__ == "__main__": 156 - execute('./data') 161 + execute("./data")
+26 -20
mastodon/common.py
··· 1 1 import cross 2 2 from util.media import MediaInfo 3 3 4 + 4 5 class MastodonPost(cross.Post): 5 - def __init__(self, status: dict, tokens: list[cross.Token], media_attachments: list[MediaInfo]) -> None: 6 + def __init__( 7 + self, 8 + status: dict, 9 + tokens: list[cross.Token], 10 + media_attachments: list[MediaInfo], 11 + ) -> None: 6 12 super().__init__() 7 - self.id = status['id'] 8 - self.parent_id = status.get('in_reply_to_id') 13 + self.id = status["id"] 14 + self.parent_id = status.get("in_reply_to_id") 9 15 self.tokens = tokens 10 - self.content_type = status.get('content_type', 'text/plain') 11 - self.timestamp = status['created_at'] 16 + self.content_type = status.get("content_type", "text/plain") 17 + self.timestamp = status["created_at"] 12 18 self.media_attachments = media_attachments 13 - self.spoiler = status.get('spoiler_text') 14 - self.language = [status['language']] if status.get('language') else [] 15 - self.sensitive = status.get('sensitive', False) 16 - self.url = status.get('url') 17 - 19 + self.spoiler = status.get("spoiler_text") 20 + self.language = [status["language"]] if status.get("language") else [] 21 + self.sensitive = status.get("sensitive", False) 22 + self.url = status.get("url") 23 + 18 24 def get_id(self) -> str: 19 25 return self.id 20 - 26 + 21 27 def get_parent_id(self) -> str | None: 22 28 return self.parent_id 23 - 29 + 24 30 def get_tokens(self) -> list[cross.Token]: 25 31 return self.tokens 26 - 32 + 27 33 def get_text_type(self) -> str: 28 34 return self.content_type 29 - 35 + 30 36 def get_timestamp(self) -> str: 31 37 return self.timestamp 32 - 38 + 33 39 def get_attachments(self) -> list[MediaInfo]: 34 40 return self.media_attachments 35 - 41 + 36 42 def get_spoiler(self) -> str | None: 37 43 return self.spoiler 38 - 44 + 39 45 def get_languages(self) -> list[str]: 40 46 return self.language 41 - 47 + 42 48 def is_sensitive(self) -> bool: 43 49 return self.sensitive or (self.spoiler is not None) 44 - 50 + 45 
51 def get_post_url(self) -> str | None: 46 - return self.url 52 + return self.url
+124 -96
mastodon/input.py
··· 1 - import requests, websockets 1 + import asyncio 2 2 import json 3 3 import re 4 - import asyncio 4 + from typing import Any, Callable 5 5 6 - from mastodon.common import MastodonPost 6 + import requests 7 + import websockets 8 + 9 + import cross 10 + import util.database as database 7 11 import util.html_util as html_util 8 12 import util.md_util as md_util 9 - 10 - import cross, util.database as database 11 - from util.util import LOGGER, as_envvar 12 - from util.media import MediaInfo, download_media 13 + from mastodon.common import MastodonPost 13 14 from util.database import DataBaseWorker 15 + from util.media import MediaInfo, download_media 16 + from util.util import LOGGER, as_envvar 14 17 15 - from typing import Callable, Any 18 + ALLOWED_VISIBILITY = ["public", "unlisted"] 19 + MARKDOWNY = ["text/x.misskeymarkdown", "text/markdown", "text/plain"] 16 20 17 - ALLOWED_VISIBILITY = ['public', 'unlisted'] 18 - MARKDOWNY = ['text/x.misskeymarkdown', 'text/markdown', 'text/plain'] 19 21 20 - class MastodonInputOptions(): 22 + class MastodonInputOptions: 21 23 def __init__(self, o: dict) -> None: 22 24 self.allowed_visibility = ALLOWED_VISIBILITY 23 - self.filters = [re.compile(f) for f in o.get('regex_filters', [])] 24 - 25 - allowed_visibility = o.get('allowed_visibility') 25 + self.filters = [re.compile(f) for f in o.get("regex_filters", [])] 26 + 27 + allowed_visibility = o.get("allowed_visibility") 26 28 if allowed_visibility is not None: 27 29 if any([v not in ALLOWED_VISIBILITY for v in allowed_visibility]): 28 - raise ValueError(f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}") 30 + raise ValueError( 31 + f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}" 32 + ) 29 33 self.allowed_visibility = allowed_visibility 34 + 30 35 31 36 class MastodonInput(cross.Input): 32 37 def __init__(self, settings: dict, db: DataBaseWorker) -> None: 33 - self.options = 
MastodonInputOptions(settings.get('options', {})) 34 - self.token = as_envvar(settings.get('token')) or (_ for _ in ()).throw(ValueError("'token' is required")) 35 - instance: str = as_envvar(settings.get('instance')) or (_ for _ in ()).throw(ValueError("'instance' is required")) 36 - 37 - service = instance[:-1] if instance.endswith('/') else instance 38 - 38 + self.options = MastodonInputOptions(settings.get("options", {})) 39 + self.token = as_envvar(settings.get("token")) or (_ for _ in ()).throw( 40 + ValueError("'token' is required") 41 + ) 42 + instance: str = as_envvar(settings.get("instance")) or (_ for _ in ()).throw( 43 + ValueError("'instance' is required") 44 + ) 45 + 46 + service = instance[:-1] if instance.endswith("/") else instance 47 + 39 48 LOGGER.info("Verifying %s credentails...", service) 40 - responce = requests.get(f"{service}/api/v1/accounts/verify_credentials", headers={ 41 - 'Authorization': f'Bearer {self.token}' 42 - }) 49 + responce = requests.get( 50 + f"{service}/api/v1/accounts/verify_credentials", 51 + headers={"Authorization": f"Bearer {self.token}"}, 52 + ) 43 53 if responce.status_code != 200: 44 54 LOGGER.error("Failed to validate user credentials!") 45 55 responce.raise_for_status() 46 56 return 47 - 57 + 48 58 super().__init__(service, responce.json()["id"], settings, db) 49 59 self.streaming = self._get_streaming_url() 50 - 60 + 51 61 if not self.streaming: 52 62 raise Exception("Instance %s does not support streaming!", service) 53 63 ··· 55 65 response = requests.get(f"{self.service}/api/v1/instance") 56 66 response.raise_for_status() 57 67 data: dict = response.json() 58 - return (data.get('urls') or {}).get('streaming_api') 68 + return (data.get("urls") or {}).get("streaming_api") 59 69 60 70 def __to_tokens(self, status: dict): 61 - content_type = status.get('content_type', 'text/plain') 62 - raw_text = status.get('text') 63 - 71 + content_type = status.get("content_type", "text/plain") 72 + raw_text = 
status.get("text") 73 + 64 74 tags: list[str] = [] 65 - for tag in status.get('tags', []): 66 - tags.append(tag['name']) 67 - 75 + for tag in status.get("tags", []): 76 + tags.append(tag["name"]) 77 + 68 78 mentions: list[tuple[str, str]] = [] 69 - for mention in status.get('mentions', []): 70 - mentions.append(('@' + mention['username'], '@' + mention['acct'])) 71 - 79 + for mention in status.get("mentions", []): 80 + mentions.append(("@" + mention["username"], "@" + mention["acct"])) 81 + 72 82 if raw_text and content_type in MARKDOWNY: 73 83 return md_util.tokenize_markdown(raw_text, tags, mentions) 74 - 75 - akkoma_ext: dict | None = status.get('akkoma', {}).get('source') 84 + 85 + akkoma_ext: dict | None = status.get("akkoma", {}).get("source") 76 86 if akkoma_ext: 77 - if akkoma_ext.get('mediaType') in MARKDOWNY: 87 + if akkoma_ext.get("mediaType") in MARKDOWNY: 78 88 return md_util.tokenize_markdown(akkoma_ext["content"], tags, mentions) 79 - 89 + 80 90 tokenizer = html_util.HTMLPostTokenizer() 81 91 tokenizer.mentions = mentions 82 92 tokenizer.tags = tags 83 - tokenizer.feed(status.get('content', "")) 93 + tokenizer.feed(status.get("content", "")) 84 94 return tokenizer.get_tokens() 85 - 95 + 86 96 def _on_create_post(self, outputs: list[cross.Output], status: dict): 87 97 # skip events from other users 88 - if (status.get('account') or {})['id'] != self.user_id: 98 + if (status.get("account") or {})["id"] != self.user_id: 89 99 return 90 - 91 - if status.get('visibility') not in self.options.allowed_visibility: 100 + 101 + if status.get("visibility") not in self.options.allowed_visibility: 92 102 # Skip f/o and direct posts 93 - LOGGER.info("Skipping '%s'! '%s' visibility..", status['id'], status.get('visibility')) 103 + LOGGER.info( 104 + "Skipping '%s'! '%s' visibility..", 105 + status["id"], 106 + status.get("visibility"), 107 + ) 94 108 return 95 - 109 + 96 110 # TODO polls not supported on bsky. maybe 3rd party? 
skip for now 97 111 # we don't handle reblogs. possible with bridgy(?) and self 98 112 # we don't handle quotes. 99 - if status.get('poll'): 100 - LOGGER.info("Skipping '%s'! Contains a poll..", status['id']) 113 + if status.get("poll"): 114 + LOGGER.info("Skipping '%s'! Contains a poll..", status["id"]) 101 115 return 102 - 103 - if status.get('quote_id') or status.get('quote'): 104 - LOGGER.info("Skipping '%s'! Quote..", status['id']) 116 + 117 + if status.get("quote_id") or status.get("quote"): 118 + LOGGER.info("Skipping '%s'! Quote..", status["id"]) 105 119 return 106 - 107 - reblog: dict | None = status.get('reblog') 120 + 121 + reblog: dict | None = status.get("reblog") 108 122 if reblog: 109 - if (reblog.get('account') or {})['id'] != self.user_id: 110 - LOGGER.info("Skipping '%s'! Reblog of other user..", status['id']) 123 + if (reblog.get("account") or {})["id"] != self.user_id: 124 + LOGGER.info("Skipping '%s'! Reblog of other user..", status["id"]) 111 125 return 112 - 113 - success = database.try_insert_repost(self.db, status['id'], reblog['id'], self.user_id, self.service) 126 + 127 + success = database.try_insert_repost( 128 + self.db, status["id"], reblog["id"], self.user_id, self.service 129 + ) 114 130 if not success: 115 - LOGGER.info("Skipping '%s' as reblogged post was not found in db!", status['id']) 131 + LOGGER.info( 132 + "Skipping '%s' as reblogged post was not found in db!", status["id"] 133 + ) 116 134 return 117 - 135 + 118 136 for output in outputs: 119 - output.accept_repost(status['id'], reblog['id']) 137 + output.accept_repost(status["id"], reblog["id"]) 120 138 return 121 - 122 - in_reply: str | None = status.get('in_reply_to_id') 123 - in_reply_to: str | None = status.get('in_reply_to_account_id') 139 + 140 + in_reply: str | None = status.get("in_reply_to_id") 141 + in_reply_to: str | None = status.get("in_reply_to_account_id") 124 142 if in_reply_to and in_reply_to != self.user_id: 125 143 # We don't support replies. 
126 - LOGGER.info("Skipping '%s'! Reply to other user..", status['id']) 144 + LOGGER.info("Skipping '%s'! Reply to other user..", status["id"]) 127 145 return 128 - 129 - success = database.try_insert_post(self.db, status['id'], in_reply, self.user_id, self.service) 146 + 147 + success = database.try_insert_post( 148 + self.db, status["id"], in_reply, self.user_id, self.service 149 + ) 130 150 if not success: 131 - LOGGER.info("Skipping '%s' as parent post was not found in db!", status['id']) 151 + LOGGER.info( 152 + "Skipping '%s' as parent post was not found in db!", status["id"] 153 + ) 132 154 return 133 - 155 + 134 156 tokens = self.__to_tokens(status) 135 157 if not cross.test_filters(tokens, self.options.filters): 136 - LOGGER.info("Skipping '%s'. Matched a filter!", status['id']) 158 + LOGGER.info("Skipping '%s'. Matched a filter!", status["id"]) 137 159 return 138 - 139 - LOGGER.info("Crossposting '%s'...", status['id']) 140 - 160 + 161 + LOGGER.info("Crossposting '%s'...", status["id"]) 162 + 141 163 media_attachments: list[MediaInfo] = [] 142 - for attachment in status.get('media_attachments', []): 143 - LOGGER.info("Downloading %s...", attachment['url']) 144 - info = download_media(attachment['url'], attachment.get('description') or '') 164 + for attachment in status.get("media_attachments", []): 165 + LOGGER.info("Downloading %s...", attachment["url"]) 166 + info = download_media( 167 + attachment["url"], attachment.get("description") or "" 168 + ) 145 169 if not info: 146 - LOGGER.error("Skipping '%s'. Failed to download media!", status['id']) 170 + LOGGER.error("Skipping '%s'. 
Failed to download media!", status["id"]) 147 171 return 148 172 media_attachments.append(info) 149 - 173 + 150 174 cross_post = MastodonPost(status, tokens, media_attachments) 151 175 for output in outputs: 152 176 output.accept_post(cross_post) 153 - 177 + 154 178 def _on_delete_post(self, outputs: list[cross.Output], identifier: str): 155 179 post = database.find_post(self.db, identifier, self.user_id, self.service) 156 180 if not post: 157 181 return 158 - 182 + 159 183 LOGGER.info("Deleting '%s'...", identifier) 160 - if post['reposted_id']: 184 + if post["reposted_id"]: 161 185 for output in outputs: 162 186 output.delete_repost(identifier) 163 187 else: 164 188 for output in outputs: 165 189 output.delete_post(identifier) 166 - 190 + 167 191 database.delete_post(self.db, identifier, self.user_id, self.service) 168 - 192 + 169 193 def _on_post(self, outputs: list[cross.Output], event: str, payload: str): 170 194 match event: 171 - case 'update': 195 + case "update": 172 196 self._on_create_post(outputs, json.loads(payload)) 173 - case 'delete': 197 + case "delete": 174 198 self._on_delete_post(outputs, payload) 175 - 176 - async def listen(self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]): 199 + 200 + async def listen( 201 + self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any] 202 + ): 177 203 uri = f"{self.streaming}/api/v1/streaming?stream=user&access_token={self.token}" 178 - 179 - async for ws in websockets.connect(uri, extra_headers={"User-Agent": "XPost/0.0.3"}): 204 + 205 + async for ws in websockets.connect( 206 + uri, extra_headers={"User-Agent": "XPost/0.0.3"} 207 + ): 180 208 try: 181 209 LOGGER.info("Listening to %s...", self.streaming) 182 - 210 + 183 211 async def listen_for_messages(): 184 212 async for msg in ws: 185 213 data = json.loads(msg) 186 - event: str = data.get('event') 187 - payload: str = data.get('payload') 188 - 214 + event: str = data.get("event") 215 + payload: str = 
data.get("payload") 216 + 189 217 submit(lambda: self._on_post(outputs, str(event), str(payload))) 190 - 218 + 191 219 listen = asyncio.create_task(listen_for_messages()) 192 - 220 + 193 221 await asyncio.gather(listen) 194 222 except websockets.ConnectionClosedError as e: 195 223 LOGGER.error(e, stack_info=True, exc_info=True) 196 224 LOGGER.info("Reconnecting to %s...", self.streaming) 197 - continue 225 + continue
+252 -214
mastodon/output.py
··· 1 - import requests, time 1 + import time 2 2 3 - import cross, util.database as database 3 + import requests 4 + 5 + import cross 4 6 import misskey.mfm_util as mfm_util 5 - from util.util import LOGGER, as_envvar, canonical_label 7 + import util.database as database 8 + from util.database import DataBaseWorker 6 9 from util.media import MediaInfo 7 - from util.database import DataBaseWorker 10 + from util.util import LOGGER, as_envvar, canonical_label 8 11 9 12 POSSIBLE_MIMES = [ 10 - 'audio/ogg', 11 - 'audio/mp3', 12 - 'image/webp', 13 - 'image/jpeg', 14 - 'image/png', 15 - 'video/mp4', 16 - 'video/quicktime', 17 - 'video/webm' 13 + "audio/ogg", 14 + "audio/mp3", 15 + "image/webp", 16 + "image/jpeg", 17 + "image/png", 18 + "video/mp4", 19 + "video/quicktime", 20 + "video/webm", 18 21 ] 19 22 20 - TEXT_MIMES = [ 21 - 'text/x.misskeymarkdown', 22 - 'text/markdown', 23 - 'text/plain' 24 - ] 23 + TEXT_MIMES = ["text/x.misskeymarkdown", "text/markdown", "text/plain"] 24 + 25 + ALLOWED_POSTING_VISIBILITY = ["public", "unlisted", "private"] 25 26 26 - ALLOWED_POSTING_VISIBILITY = ['public', 'unlisted', 'private'] 27 27 28 - class MastodonOutputOptions(): 28 + class MastodonOutputOptions: 29 29 def __init__(self, o: dict) -> None: 30 - self.visibility = 'public' 31 - 32 - visibility = o.get('visibility') 30 + self.visibility = "public" 31 + 32 + visibility = o.get("visibility") 33 33 if visibility is not None: 34 34 if visibility not in ALLOWED_POSTING_VISIBILITY: 35 - raise ValueError(f"'visibility' only accepts {', '.join(ALLOWED_POSTING_VISIBILITY)}, got: {visibility}") 35 + raise ValueError( 36 + f"'visibility' only accepts {', '.join(ALLOWED_POSTING_VISIBILITY)}, got: {visibility}" 37 + ) 36 38 self.visibility = visibility 39 + 37 40 38 41 class MastodonOutput(cross.Output): 39 42 def __init__(self, input: cross.Input, settings: dict, db: DataBaseWorker) -> None: 40 43 super().__init__(input, settings, db) 41 - self.options = settings.get('options') or {} 42 - 
self.token = as_envvar(settings.get('token')) or (_ for _ in ()).throw(ValueError("'token' is required")) 43 - instance: str = as_envvar(settings.get('instance')) or (_ for _ in ()).throw(ValueError("'instance' is required")) 44 - 45 - self.service = instance[:-1] if instance.endswith('/') else instance 46 - 44 + self.options = settings.get("options") or {} 45 + self.token = as_envvar(settings.get("token")) or (_ for _ in ()).throw( 46 + ValueError("'token' is required") 47 + ) 48 + instance: str = as_envvar(settings.get("instance")) or (_ for _ in ()).throw( 49 + ValueError("'instance' is required") 50 + ) 51 + 52 + self.service = instance[:-1] if instance.endswith("/") else instance 53 + 47 54 LOGGER.info("Verifying %s credentails...", self.service) 48 - responce = requests.get(f"{self.service}/api/v1/accounts/verify_credentials", headers={ 49 - 'Authorization': f'Bearer {self.token}' 50 - }) 55 + responce = requests.get( 56 + f"{self.service}/api/v1/accounts/verify_credentials", 57 + headers={"Authorization": f"Bearer {self.token}"}, 58 + ) 51 59 if responce.status_code != 200: 52 60 LOGGER.error("Failed to validate user credentials!") 53 61 responce.raise_for_status() ··· 55 63 self.user_id: str = responce.json()["id"] 56 64 57 65 LOGGER.info("Getting %s configuration...", self.service) 58 - responce = requests.get(f"{self.service}/api/v1/instance", headers={ 59 - 'Authorization': f'Bearer {self.token}' 60 - }) 66 + responce = requests.get( 67 + f"{self.service}/api/v1/instance", 68 + headers={"Authorization": f"Bearer {self.token}"}, 69 + ) 61 70 if responce.status_code != 200: 62 71 LOGGER.error("Failed to get instance info!") 63 72 responce.raise_for_status() 64 73 return 65 - 74 + 66 75 instance_info: dict = responce.json() 67 - configuration: dict = instance_info['configuration'] 68 - 69 - statuses_config: dict = configuration.get('statuses', {}) 70 - self.max_characters: int = statuses_config.get('max_characters', 500) 71 - self.max_media_attachments: int 
= statuses_config.get('max_media_attachments', 4) 72 - self.characters_reserved_per_url: int = statuses_config.get('characters_reserved_per_url', 23) 73 - 74 - media_config: dict = configuration.get('media_attachments', {}) 75 - self.image_size_limit: int = media_config.get('image_size_limit', 16777216) 76 - self.video_size_limit: int = media_config.get('video_size_limit', 103809024) 77 - self.supported_mime_types: list[str] = media_config.get('supported_mime_types', POSSIBLE_MIMES) 78 - 76 + configuration: dict = instance_info["configuration"] 77 + 78 + statuses_config: dict = configuration.get("statuses", {}) 79 + self.max_characters: int = statuses_config.get("max_characters", 500) 80 + self.max_media_attachments: int = statuses_config.get( 81 + "max_media_attachments", 4 82 + ) 83 + self.characters_reserved_per_url: int = statuses_config.get( 84 + "characters_reserved_per_url", 23 85 + ) 86 + 87 + media_config: dict = configuration.get("media_attachments", {}) 88 + self.image_size_limit: int = media_config.get("image_size_limit", 16777216) 89 + self.video_size_limit: int = media_config.get("video_size_limit", 103809024) 90 + self.supported_mime_types: list[str] = media_config.get( 91 + "supported_mime_types", POSSIBLE_MIMES 92 + ) 93 + 79 94 # *oma: max post chars 80 - max_toot_chars = instance_info.get('max_toot_chars') 95 + max_toot_chars = instance_info.get("max_toot_chars") 81 96 if max_toot_chars: 82 97 self.max_characters: int = max_toot_chars 83 - 98 + 84 99 # *oma: max upload limit 85 - upload_limit = instance_info.get('upload_limit') 100 + upload_limit = instance_info.get("upload_limit") 86 101 if upload_limit: 87 102 self.image_size_limit: int = upload_limit 88 103 self.video_size_limit: int = upload_limit 89 - 104 + 90 105 # chuckya: supported text types 91 - chuckya_text_mimes: list[str] = statuses_config.get('supported_mime_types', []) 106 + chuckya_text_mimes: list[str] = statuses_config.get("supported_mime_types", []) 92 107 self.text_format = 
next( 93 - (mime for mime in TEXT_MIMES if mime in (chuckya_text_mimes)), 94 - 'text/plain' 108 + (mime for mime in TEXT_MIMES if mime in (chuckya_text_mimes)), "text/plain" 95 109 ) 96 - 110 + 97 111 # *oma ext: supported text types 98 - pleroma = instance_info.get('pleroma') 112 + pleroma = instance_info.get("pleroma") 99 113 if pleroma: 100 - post_formats: list[str] = pleroma.get('metadata', {}).get('post_formats', []) 114 + post_formats: list[str] = pleroma.get("metadata", {}).get( 115 + "post_formats", [] 116 + ) 101 117 self.text_format = next( 102 - (mime for mime in TEXT_MIMES if mime in post_formats), 103 - self.text_format 118 + (mime for mime in TEXT_MIMES if mime in post_formats), self.text_format 104 119 ) 105 - 120 + 106 121 def upload_media(self, attachments: list[MediaInfo]) -> list[str] | None: 107 122 for a in attachments: 108 - if a.mime.startswith('image/') and len(a.io) > self.image_size_limit: 123 + if a.mime.startswith("image/") and len(a.io) > self.image_size_limit: 109 124 return None 110 - 111 - if a.mime.startswith('video/') and len(a.io) > self.video_size_limit: 125 + 126 + if a.mime.startswith("video/") and len(a.io) > self.video_size_limit: 112 127 return None 113 - 114 - if not a.mime.startswith('image/') and not a.mime.startswith('video/'): 128 + 129 + if not a.mime.startswith("image/") and not a.mime.startswith("video/"): 115 130 if len(a.io) > 7_000_000: 116 131 return None 117 - 132 + 118 133 uploads: list[dict] = [] 119 134 for a in attachments: 120 135 data = {} 121 136 if a.alt: 122 - data['description'] = a.alt 123 - 124 - req = requests.post(f"{self.service}/api/v2/media", headers= { 125 - 'Authorization': f'Bearer {self.token}' 126 - }, files={'file': (a.name, a.io, a.mime)}, data=data) 127 - 137 + data["description"] = a.alt 138 + 139 + req = requests.post( 140 + f"{self.service}/api/v2/media", 141 + headers={"Authorization": f"Bearer {self.token}"}, 142 + files={"file": (a.name, a.io, a.mime)}, 143 + data=data, 144 + ) 145 
+ 128 146 if req.status_code == 200: 129 - LOGGER.info("Uploaded %s! (%s)", a.name, req.json()['id']) 130 - uploads.append({ 131 - 'done': True, 132 - 'id': req.json()['id'] 133 - }) 147 + LOGGER.info("Uploaded %s! (%s)", a.name, req.json()["id"]) 148 + uploads.append({"done": True, "id": req.json()["id"]}) 134 149 elif req.status_code == 202: 135 150 LOGGER.info("Waiting for %s to process!", a.name) 136 - uploads.append({ 137 - 'done': False, 138 - 'id': req.json()['id'] 139 - }) 151 + uploads.append({"done": False, "id": req.json()["id"]}) 140 152 else: 141 153 LOGGER.error("Failed to upload %s! %s", a.name, req.text) 142 154 req.raise_for_status() 143 - 144 - while any([not val['done'] for val in uploads]): 155 + 156 + while any([not val["done"] for val in uploads]): 145 157 LOGGER.info("Waiting for media to process...") 146 158 time.sleep(3) 147 159 for media in uploads: 148 - if media['done']: 160 + if media["done"]: 149 161 continue 150 - 151 - reqs = requests.get(f'{self.service}/api/v1/media/{media['id']}', headers={ 152 - 'Authorization': f'Bearer {self.token}' 153 - }) 154 - 162 + 163 + reqs = requests.get( 164 + f"{self.service}/api/v1/media/{media['id']}", 165 + headers={"Authorization": f"Bearer {self.token}"}, 166 + ) 167 + 155 168 if reqs.status_code == 206: 156 169 continue 157 - 170 + 158 171 if reqs.status_code == 200: 159 - media['done'] = True 172 + media["done"] = True 160 173 continue 161 174 reqs.raise_for_status() 162 - 163 - return [val['id'] for val in uploads] 175 + 176 + return [val["id"] for val in uploads] 164 177 165 178 def token_to_string(self, tokens: list[cross.Token]) -> str | None: 166 - p_text: str = '' 167 - 179 + p_text: str = "" 180 + 168 181 for token in tokens: 169 182 if isinstance(token, cross.TextToken): 170 183 p_text += token.text 171 184 elif isinstance(token, cross.TagToken): 172 - p_text += '#' + token.tag 185 + p_text += "#" + token.tag 173 186 elif isinstance(token, cross.LinkToken): 174 187 if 
canonical_label(token.label, token.href): 175 188 p_text += token.href 176 189 else: 177 - if self.text_format == 'text/plain': 178 - p_text += f'{token.label} ({token.href})' 179 - elif self.text_format in {'text/x.misskeymarkdown', 'text/markdown'}: 180 - p_text += f'[{token.label}]({token.href})' 190 + if self.text_format == "text/plain": 191 + p_text += f"{token.label} ({token.href})" 192 + elif self.text_format in { 193 + "text/x.misskeymarkdown", 194 + "text/markdown", 195 + }: 196 + p_text += f"[{token.label}]({token.href})" 181 197 else: 182 198 return None 183 - 199 + 184 200 return p_text 185 201 186 202 def split_tokens_media(self, tokens: list[cross.Token], media: list[MediaInfo]): 187 - split_tokens = cross.split_tokens(tokens, self.max_characters, self.characters_reserved_per_url) 203 + split_tokens = cross.split_tokens( 204 + tokens, self.max_characters, self.characters_reserved_per_url 205 + ) 188 206 post_text: list[str] = [] 189 - 207 + 190 208 for block in split_tokens: 191 209 baked_text = self.token_to_string(block) 192 - 210 + 193 211 if baked_text is None: 194 212 return None 195 213 post_text.append(baked_text) 196 - 214 + 197 215 if not post_text: 198 - post_text = [''] 199 - 200 - posts: list[dict] = [{"text": post_text, "attachments": []} for post_text in post_text] 216 + post_text = [""] 217 + 218 + posts: list[dict] = [ 219 + {"text": post_text, "attachments": []} for post_text in post_text 220 + ] 201 221 available_indices: list[int] = list(range(len(posts))) 202 - 222 + 203 223 current_image_post_idx: int | None = None 204 - 224 + 205 225 def make_blank_post() -> dict: 206 - return { 207 - "text": '', 208 - "attachments": [] 209 - } 210 - 226 + return {"text": "", "attachments": []} 227 + 211 228 def pop_next_empty_index() -> int: 212 229 if available_indices: 213 230 return available_indices.pop(0) ··· 215 232 new_idx = len(posts) 216 233 posts.append(make_blank_post()) 217 234 return new_idx 218 - 235 + 219 236 for att in media: 220 
237 if ( 221 238 current_image_post_idx is not None 222 - and len(posts[current_image_post_idx]["attachments"]) < self.max_media_attachments 239 + and len(posts[current_image_post_idx]["attachments"]) 240 + < self.max_media_attachments 223 241 ): 224 242 posts[current_image_post_idx]["attachments"].append(att) 225 243 else: 226 244 idx = pop_next_empty_index() 227 245 posts[idx]["attachments"].append(att) 228 246 current_image_post_idx = idx 229 - 247 + 230 248 result: list[tuple[str, list[MediaInfo]]] = [] 231 - 249 + 232 250 for p in posts: 233 - result.append((p['text'], p["attachments"])) 234 - 251 + result.append((p["text"], p["attachments"])) 252 + 235 253 return result 236 - 254 + 237 255 def accept_post(self, post: cross.Post): 238 256 parent_id = post.get_parent_id() 239 - 257 + 240 258 new_root_id: int | None = None 241 259 new_parent_id: int | None = None 242 - 260 + 243 261 reply_ref: str | None = None 244 262 if parent_id: 245 263 thread_tuple = database.find_mapped_thread( ··· 248 266 self.input.user_id, 249 267 self.input.service, 250 268 self.user_id, 251 - self.service 269 + self.service, 252 270 ) 253 - 271 + 254 272 if not thread_tuple: 255 273 LOGGER.error("Failed to find thread tuple in the database!") 256 274 return None 257 - 275 + 258 276 _, reply_ref, new_root_id, new_parent_id = thread_tuple 259 - 277 + 260 278 lang: str 261 279 if post.get_languages(): 262 280 lang = post.get_languages()[0] 263 281 else: 264 - lang = 'en' 265 - 282 + lang = "en" 283 + 266 284 post_tokens = post.get_tokens() 267 285 if post.get_text_type() == "text/x.misskeymarkdown": 268 286 post_tokens, status = mfm_util.strip_mfm(post_tokens) 269 287 post_url = post.get_post_url() 270 288 if status and post_url: 271 - post_tokens.append(cross.TextToken('\n')) 272 - post_tokens.append(cross.LinkToken(post_url, "[Post contains MFM, see original]")) 273 - 289 + post_tokens.append(cross.TextToken("\n")) 290 + post_tokens.append( 291 + cross.LinkToken(post_url, "[Post 
contains MFM, see original]") 292 + ) 293 + 274 294 raw_statuses = self.split_tokens_media(post_tokens, post.get_attachments()) 275 295 if not raw_statuses: 276 296 LOGGER.error("Failed to split post into statuses?") 277 297 return None 278 298 baked_statuses = [] 279 - 299 + 280 300 for status, raw_media in raw_statuses: 281 301 media: list[str] | None = None 282 302 if raw_media: ··· 286 306 return None 287 307 baked_statuses.append((status, media)) 288 308 continue 289 - baked_statuses.append((status,[])) 290 - 309 + baked_statuses.append((status, [])) 310 + 291 311 created_statuses: list[str] = [] 292 - 312 + 293 313 for status, media in baked_statuses: 294 314 payload = { 295 - 'status': status, 296 - 'media_ids': media or [], 297 - 'spoiler_text': post.get_spoiler() or '', 298 - 'visibility': self.options.get('visibility', 'public'), 299 - 'content_type': self.text_format, 300 - 'language': lang 315 + "status": status, 316 + "media_ids": media or [], 317 + "spoiler_text": post.get_spoiler() or "", 318 + "visibility": self.options.get("visibility", "public"), 319 + "content_type": self.text_format, 320 + "language": lang, 301 321 } 302 - 322 + 303 323 if media: 304 - payload['sensitive'] = post.is_sensitive() 305 - 324 + payload["sensitive"] = post.is_sensitive() 325 + 306 326 if post.get_spoiler(): 307 - payload['sensitive'] = True 308 - 327 + payload["sensitive"] = True 328 + 309 329 if not status: 310 - payload['status'] = '🖼️' 311 - 330 + payload["status"] = "🖼️" 331 + 312 332 if reply_ref: 313 - payload['in_reply_to_id'] = reply_ref 314 - 315 - reqs = requests.post(f'{self.service}/api/v1/statuses', headers={ 316 - 'Authorization': f'Bearer {self.token}', 317 - 'Content-Type': 'application/json' 318 - }, json=payload) 319 - 333 + payload["in_reply_to_id"] = reply_ref 334 + 335 + reqs = requests.post( 336 + f"{self.service}/api/v1/statuses", 337 + headers={ 338 + "Authorization": f"Bearer {self.token}", 339 + "Content-Type": "application/json", 340 + }, 
341 + json=payload, 342 + ) 343 + 320 344 if reqs.status_code != 200: 321 - LOGGER.info("Failed to post status! %s - %s", reqs.status_code, reqs.text) 345 + LOGGER.info( 346 + "Failed to post status! %s - %s", reqs.status_code, reqs.text 347 + ) 322 348 reqs.raise_for_status() 323 - 324 - reply_ref = reqs.json()['id'] 349 + 350 + reply_ref = reqs.json()["id"] 325 351 LOGGER.info("Created new status %s!", reply_ref) 326 - 327 - created_statuses.append(reqs.json()['id']) 328 - 329 - db_post = database.find_post(self.db, post.get_id(), self.input.user_id, self.input.service) 352 + 353 + created_statuses.append(reqs.json()["id"]) 354 + 355 + db_post = database.find_post( 356 + self.db, post.get_id(), self.input.user_id, self.input.service 357 + ) 330 358 assert db_post, "ghghghhhhh" 331 - 332 - if new_root_id is None or new_parent_id is None: 359 + 360 + if new_root_id is None or new_parent_id is None: 333 361 new_root_id = database.insert_post( 334 - self.db, 335 - created_statuses[0], 336 - self.user_id, 337 - self.service 362 + self.db, created_statuses[0], self.user_id, self.service 338 363 ) 339 364 new_parent_id = new_root_id 340 - database.insert_mapping(self.db, db_post['id'], new_parent_id) 365 + database.insert_mapping(self.db, db_post["id"], new_parent_id) 341 366 created_statuses = created_statuses[1:] 342 - 367 + 343 368 for db_id in created_statuses: 344 369 new_parent_id = database.insert_reply( 345 - self.db, 346 - db_id, 347 - self.user_id, 348 - self.service, 349 - new_parent_id, 350 - new_root_id 370 + self.db, db_id, self.user_id, self.service, new_parent_id, new_root_id 351 371 ) 352 - database.insert_mapping(self.db, db_post['id'], new_parent_id) 353 - 372 + database.insert_mapping(self.db, db_post["id"], new_parent_id) 373 + 354 374 def delete_post(self, identifier: str): 355 - post = database.find_post(self.db, identifier, self.input.user_id, self.input.service) 375 + post = database.find_post( 376 + self.db, identifier, self.input.user_id, 
self.input.service 377 + ) 356 378 if not post: 357 379 return 358 - 359 - mappings = database.find_mappings(self.db, post['id'], self.service, self.user_id) 380 + 381 + mappings = database.find_mappings( 382 + self.db, post["id"], self.service, self.user_id 383 + ) 360 384 for mapping in mappings[::-1]: 361 385 LOGGER.info("Deleting '%s'...", mapping[0]) 362 - requests.delete(f'{self.service}/api/v1/statuses/{mapping[0]}', headers={ 363 - 'Authorization': f'Bearer {self.token}' 364 - }) 386 + requests.delete( 387 + f"{self.service}/api/v1/statuses/{mapping[0]}", 388 + headers={"Authorization": f"Bearer {self.token}"}, 389 + ) 365 390 database.delete_post(self.db, mapping[0], self.service, self.user_id) 366 - 391 + 367 392 def accept_repost(self, repost_id: str, reposted_id: str): 368 393 repost = self.__delete_repost(repost_id) 369 394 if not repost: 370 395 return None 371 - 372 - reposted = database.find_post(self.db, reposted_id, self.input.user_id, self.input.service) 396 + 397 + reposted = database.find_post( 398 + self.db, reposted_id, self.input.user_id, self.input.service 399 + ) 373 400 if not reposted: 374 401 return 375 - 376 - mappings = database.find_mappings(self.db, reposted['id'], self.service, self.user_id) 402 + 403 + mappings = database.find_mappings( 404 + self.db, reposted["id"], self.service, self.user_id 405 + ) 377 406 if mappings: 378 - rsp = requests.post(f'{self.service}/api/v1/statuses/{mappings[0][0]}/reblog', headers={ 379 - 'Authorization': f'Bearer {self.token}' 380 - }) 381 - 407 + rsp = requests.post( 408 + f"{self.service}/api/v1/statuses/{mappings[0][0]}/reblog", 409 + headers={"Authorization": f"Bearer {self.token}"}, 410 + ) 411 + 382 412 if rsp.status_code != 200: 383 - LOGGER.error("Failed to boost status! status_code: %s, msg: %s", rsp.status_code, rsp.content) 413 + LOGGER.error( 414 + "Failed to boost status! 
status_code: %s, msg: %s", 415 + rsp.status_code, 416 + rsp.content, 417 + ) 384 418 return 385 - 419 + 386 420 internal_id = database.insert_repost( 387 - self.db, 388 - rsp.json()['id'], 389 - reposted['id'], 390 - self.user_id, 391 - self.service) 392 - database.insert_mapping(self.db, repost['id'], internal_id) 393 - 421 + self.db, rsp.json()["id"], reposted["id"], self.user_id, self.service 422 + ) 423 + database.insert_mapping(self.db, repost["id"], internal_id) 424 + 394 425 def __delete_repost(self, repost_id: str) -> dict | None: 395 - repost = database.find_post(self.db, repost_id, self.input.user_id, self.input.service) 426 + repost = database.find_post( 427 + self.db, repost_id, self.input.user_id, self.input.service 428 + ) 396 429 if not repost: 397 430 return None 398 - 399 - mappings = database.find_mappings(self.db, repost['id'], self.service, self.user_id) 400 - reposted_mappings = database.find_mappings(self.db, repost['reposted_id'], self.service, self.user_id) 431 + 432 + mappings = database.find_mappings( 433 + self.db, repost["id"], self.service, self.user_id 434 + ) 435 + reposted_mappings = database.find_mappings( 436 + self.db, repost["reposted_id"], self.service, self.user_id 437 + ) 401 438 if mappings and reposted_mappings: 402 439 LOGGER.info("Deleting '%s'...", mappings[0][0]) 403 - requests.post(f'{self.service}/api/v1/statuses/{reposted_mappings[0][0]}/unreblog', headers={ 404 - 'Authorization': f'Bearer {self.token}' 405 - }) 440 + requests.post( 441 + f"{self.service}/api/v1/statuses/{reposted_mappings[0][0]}/unreblog", 442 + headers={"Authorization": f"Bearer {self.token}"}, 443 + ) 406 444 database.delete_post(self.db, mappings[0][0], self.user_id, self.service) 407 445 return repost 408 - 446 + 409 447 def delete_repost(self, repost_id: str): 410 - self.__delete_repost(repost_id) 448 + self.__delete_repost(repost_id)
+26 -17
misskey/common.py
··· 1 1 import cross 2 2 from util.media import MediaInfo 3 3 4 + 4 5 class MisskeyPost(cross.Post): 5 - def __init__(self, instance_url: str, note: dict, tokens: list[cross.Token], files: list[MediaInfo]) -> None: 6 + def __init__( 7 + self, 8 + instance_url: str, 9 + note: dict, 10 + tokens: list[cross.Token], 11 + files: list[MediaInfo], 12 + ) -> None: 6 13 super().__init__() 7 14 self.note = note 8 - self.id = note['id'] 9 - self.parent_id = note.get('replyId') 15 + self.id = note["id"] 16 + self.parent_id = note.get("replyId") 10 17 self.tokens = tokens 11 - self.timestamp = note['createdAt'] 18 + self.timestamp = note["createdAt"] 12 19 self.media_attachments = files 13 - self.spoiler = note.get('cw') 14 - self.sensitive = any([a.get('isSensitive', False) for a in note.get('files', [])]) 15 - self.url = instance_url + '/notes/' + note['id'] 16 - 20 + self.spoiler = note.get("cw") 21 + self.sensitive = any( 22 + [a.get("isSensitive", False) for a in note.get("files", [])] 23 + ) 24 + self.url = instance_url + "/notes/" + note["id"] 25 + 17 26 def get_id(self) -> str: 18 27 return self.id 19 - 28 + 20 29 def get_parent_id(self) -> str | None: 21 30 return self.parent_id 22 - 31 + 23 32 def get_tokens(self) -> list[cross.Token]: 24 33 return self.tokens 25 34 26 35 def get_text_type(self) -> str: 27 36 return "text/x.misskeymarkdown" 28 - 37 + 29 38 def get_timestamp(self) -> str: 30 39 return self.timestamp 31 - 40 + 32 41 def get_attachments(self) -> list[MediaInfo]: 33 42 return self.media_attachments 34 - 43 + 35 44 def get_spoiler(self) -> str | None: 36 45 return self.spoiler 37 - 46 + 38 47 def get_languages(self) -> list[str]: 39 48 return [] 40 - 49 + 41 50 def is_sensitive(self) -> bool: 42 51 return self.sensitive or (self.spoiler is not None) 43 - 52 + 44 53 def get_post_url(self) -> str | None: 45 - return self.url 54 + return self.url
+115 -92
misskey/input.py
··· 1 - import requests, websockets 2 1 import asyncio 3 - import json, uuid 2 + import json 4 3 import re 4 + import uuid 5 + from typing import Any, Callable 5 6 6 - from misskey.common import MisskeyPost 7 + import requests 8 + import websockets 7 9 8 - import cross, util.database as database 10 + import cross 11 + import util.database as database 9 12 import util.md_util as md_util 13 + from misskey.common import MisskeyPost 10 14 from util.media import MediaInfo, download_media 11 15 from util.util import LOGGER, as_envvar 12 16 13 - from typing import Callable, Any 14 - 15 - ALLOWED_VISIBILITY = ['public', 'home'] 16 - 17 - class MisskeyInputOptions(): 17 + ALLOWED_VISIBILITY = ["public", "home"] 18 + 19 + 20 + class MisskeyInputOptions: 18 21 def __init__(self, o: dict) -> None: 19 22 self.allowed_visibility = ALLOWED_VISIBILITY 20 - self.filters = [re.compile(f) for f in o.get('regex_filters', [])] 21 - 22 - allowed_visibility = o.get('allowed_visibility') 23 + self.filters = [re.compile(f) for f in o.get("regex_filters", [])] 24 + 25 + allowed_visibility = o.get("allowed_visibility") 23 26 if allowed_visibility is not None: 24 27 if any([v not in ALLOWED_VISIBILITY for v in allowed_visibility]): 25 - raise ValueError(f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}") 28 + raise ValueError( 29 + f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}" 30 + ) 26 31 self.allowed_visibility = allowed_visibility 32 + 27 33 28 34 class MisskeyInput(cross.Input): 29 35 def __init__(self, settings: dict, db: cross.DataBaseWorker) -> None: 30 - self.options = MisskeyInputOptions(settings.get('options', {})) 31 - self.token = as_envvar(settings.get('token')) or (_ for _ in ()).throw(ValueError("'token' is required")) 32 - instance: str = as_envvar(settings.get('instance')) or (_ for _ in ()).throw(ValueError("'instance' is required")) 33 - 34 - service = instance[:-1] if 
instance.endswith('/') else instance 35 - 36 + self.options = MisskeyInputOptions(settings.get("options", {})) 37 + self.token = as_envvar(settings.get("token")) or (_ for _ in ()).throw( 38 + ValueError("'token' is required") 39 + ) 40 + instance: str = as_envvar(settings.get("instance")) or (_ for _ in ()).throw( 41 + ValueError("'instance' is required") 42 + ) 43 + 44 + service = instance[:-1] if instance.endswith("/") else instance 45 + 36 46 LOGGER.info("Verifying %s credentails...", service) 37 - responce = requests.post(f"{instance}/api/i", json={ 'i': self.token }, headers={ 38 - "Content-Type": "application/json" 39 - }) 47 + responce = requests.post( 48 + f"{instance}/api/i", 49 + json={"i": self.token}, 50 + headers={"Content-Type": "application/json"}, 51 + ) 40 52 if responce.status_code != 200: 41 53 LOGGER.error("Failed to validate user credentials!") 42 54 responce.raise_for_status() 43 55 return 44 - 56 + 45 57 super().__init__(service, responce.json()["id"], settings, db) 46 - 58 + 47 59 def _on_note(self, outputs: list[cross.Output], note: dict): 48 - if note['userId'] != self.user_id: 60 + if note["userId"] != self.user_id: 49 61 return 50 - 51 - if note.get('visibility') not in self.options.allowed_visibility: 52 - LOGGER.info("Skipping '%s'! '%s' visibility..", note['id'], note.get('visibility')) 62 + 63 + if note.get("visibility") not in self.options.allowed_visibility: 64 + LOGGER.info( 65 + "Skipping '%s'! '%s' visibility..", note["id"], note.get("visibility") 66 + ) 53 67 return 54 - 68 + 55 69 # TODO polls not supported on bsky. maybe 3rd party? skip for now 56 70 # we don't handle reblogs. possible with bridgy(?) and self 57 - if note.get('poll'): 58 - LOGGER.info("Skipping '%s'! Contains a poll..", note['id']) 71 + if note.get("poll"): 72 + LOGGER.info("Skipping '%s'! 
Contains a poll..", note["id"]) 59 73 return 60 - 61 - renote: dict | None = note.get('renote') 74 + 75 + renote: dict | None = note.get("renote") 62 76 if renote: 63 - if note.get('text') is not None: 64 - LOGGER.info("Skipping '%s'! Quote..", note['id']) 77 + if note.get("text") is not None: 78 + LOGGER.info("Skipping '%s'! Quote..", note["id"]) 65 79 return 66 - 67 - if renote.get('userId') != self.user_id: 68 - LOGGER.info("Skipping '%s'! Reblog of other user..", note['id']) 80 + 81 + if renote.get("userId") != self.user_id: 82 + LOGGER.info("Skipping '%s'! Reblog of other user..", note["id"]) 69 83 return 70 - 71 - success = database.try_insert_repost(self.db, note['id'], renote['id'], self.user_id, self.service) 84 + 85 + success = database.try_insert_repost( 86 + self.db, note["id"], renote["id"], self.user_id, self.service 87 + ) 72 88 if not success: 73 - LOGGER.info("Skipping '%s' as renoted note was not found in db!", note['id']) 89 + LOGGER.info( 90 + "Skipping '%s' as renoted note was not found in db!", note["id"] 91 + ) 74 92 return 75 - 93 + 76 94 for output in outputs: 77 - output.accept_repost(note['id'], renote['id']) 95 + output.accept_repost(note["id"], renote["id"]) 78 96 return 79 - 80 - reply_id: str | None = note.get('replyId') 97 + 98 + reply_id: str | None = note.get("replyId") 81 99 if reply_id: 82 - if note.get('reply', {}).get('userId') != self.user_id: 83 - LOGGER.info("Skipping '%s'! Reply to other user..", note['id']) 100 + if note.get("reply", {}).get("userId") != self.user_id: 101 + LOGGER.info("Skipping '%s'! 
Reply to other user..", note["id"]) 84 102 return 85 - 86 - success = database.try_insert_post(self.db, note['id'], reply_id, self.user_id, self.service) 103 + 104 + success = database.try_insert_post( 105 + self.db, note["id"], reply_id, self.user_id, self.service 106 + ) 87 107 if not success: 88 - LOGGER.info("Skipping '%s' as parent note was not found in db!", note['id']) 108 + LOGGER.info("Skipping '%s' as parent note was not found in db!", note["id"]) 89 109 return 90 - 91 - mention_handles: dict = note.get('mentionHandles') or {} 92 - tags: list[str] = note.get('tags') or [] 93 - 110 + 111 + mention_handles: dict = note.get("mentionHandles") or {} 112 + tags: list[str] = note.get("tags") or [] 113 + 94 114 handles: list[tuple[str, str]] = [] 95 115 for key, value in mention_handles.items(): 96 116 handles.append((value, value)) 97 - 98 - tokens = md_util.tokenize_markdown(note.get('text', ''), tags, handles) 117 + 118 + tokens = md_util.tokenize_markdown(note.get("text", ""), tags, handles) 99 119 if not cross.test_filters(tokens, self.options.filters): 100 - LOGGER.info("Skipping '%s'. Matched a filter!", note['id']) 120 + LOGGER.info("Skipping '%s'. Matched a filter!", note["id"]) 101 121 return 102 - 103 - LOGGER.info("Crossposting '%s'...", note['id']) 104 - 122 + 123 + LOGGER.info("Crossposting '%s'...", note["id"]) 124 + 105 125 media_attachments: list[MediaInfo] = [] 106 - for attachment in note.get('files', []): 107 - LOGGER.info("Downloading %s...", attachment['url']) 108 - info = download_media(attachment['url'], attachment.get('comment') or '') 126 + for attachment in note.get("files", []): 127 + LOGGER.info("Downloading %s...", attachment["url"]) 128 + info = download_media(attachment["url"], attachment.get("comment") or "") 109 129 if not info: 110 - LOGGER.error("Skipping '%s'. Failed to download media!", note['id']) 130 + LOGGER.error("Skipping '%s'. 
Failed to download media!", note["id"]) 111 131 return 112 132 media_attachments.append(info) 113 - 133 + 114 134 cross_post = MisskeyPost(self.service, note, tokens, media_attachments) 115 135 for output in outputs: 116 136 output.accept_post(cross_post) 117 - 137 + 118 138 def _on_delete(self, outputs: list[cross.Output], note: dict): 119 139 # TODO handle deletes 120 140 pass 121 - 141 + 122 142 def _on_message(self, outputs: list[cross.Output], data: dict): 123 - 124 - if data['type'] == 'channel': 125 - type: str = data['body']['type'] 126 - if type == 'note' or type == 'reply': 127 - note_body = data['body']['body'] 143 + if data["type"] == "channel": 144 + type: str = data["body"]["type"] 145 + if type == "note" or type == "reply": 146 + note_body = data["body"]["body"] 128 147 self._on_note(outputs, note_body) 129 148 return 130 - 149 + 131 150 pass 132 - 151 + 133 152 async def _send_keepalive(self, ws: websockets.WebSocketClientProtocol): 134 153 while ws.open: 135 154 try: ··· 143 162 except Exception as e: 144 163 LOGGER.error(f"Error sending keepalive: {e}") 145 164 break 146 - 165 + 147 166 async def _subscribe_to_home(self, ws: websockets.WebSocketClientProtocol): 148 - await ws.send(json.dumps({ 149 - "type": "connect", 150 - "body": { 151 - "channel": "homeTimeline", 152 - "id": str(uuid.uuid4()) 153 - } 154 - })) 167 + await ws.send( 168 + json.dumps( 169 + { 170 + "type": "connect", 171 + "body": {"channel": "homeTimeline", "id": str(uuid.uuid4())}, 172 + } 173 + ) 174 + ) 155 175 LOGGER.info("Subscribed to 'homeTimeline' channel...") 156 - 157 - 158 - async def listen(self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]): 159 - streaming: str = f"wss://{self.service.split("://", 1)[1]}" 176 + 177 + async def listen( 178 + self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any] 179 + ): 180 + streaming: str = f"wss://{self.service.split('://', 1)[1]}" 160 181 url: str = 
f"{streaming}/streaming?i={self.token}" 161 - 162 - async for ws in websockets.connect(url, extra_headers={"User-Agent": "XPost/0.0.3"}): 182 + 183 + async for ws in websockets.connect( 184 + url, extra_headers={"User-Agent": "XPost/0.0.3"} 185 + ): 163 186 try: 164 187 LOGGER.info("Listening to %s...", streaming) 165 188 await self._subscribe_to_home(ws) 166 - 189 + 167 190 async def listen_for_messages(): 168 191 async for msg in ws: 169 192 # TODO listen to deletes somehow 170 193 submit(lambda: self._on_message(outputs, json.loads(msg))) 171 - 194 + 172 195 keepalive = asyncio.create_task(self._send_keepalive(ws)) 173 196 listen = asyncio.create_task(listen_for_messages()) 174 - 197 + 175 198 await asyncio.gather(keepalive, listen) 176 199 except websockets.ConnectionClosedError as e: 177 200 LOGGER.error(e, stack_info=True, exc_info=True) 178 201 LOGGER.info("Reconnecting to %s...", streaming) 179 - continue 202 + continue
+8 -5
misskey/mfm_util.py
··· 1 - import re, cross 1 + import re 2 + 3 + import cross 4 + 5 + MFM_PATTERN = re.compile(r"\$\[([^\[\]]+)\]") 2 6 3 - MFM_PATTERN = re.compile(r'\$\[([^\[\]]+)\]') 4 7 5 8 def strip_mfm(tokens: list[cross.Token]) -> tuple[list[cross.Token], bool]: 6 9 modified = False ··· 22 25 23 26 return tokens, modified 24 27 28 + 25 29 def __strip_mfm(text: str) -> str: 26 30 def match_contents(match: re.Match[str]): 27 31 content = match.group(1).strip() 28 - parts = content.split(' ', 1) 29 - return parts[1] if len(parts) > 1 else '' 32 + parts = content.split(" ", 1) 33 + return parts[1] if len(parts) > 1 else "" 30 34 31 35 while MFM_PATTERN.search(text): 32 36 text = MFM_PATTERN.sub(match_contents, text) 33 37 34 38 return text 35 -
+118 -67
util/database.py
··· 1 + import json 2 + import queue 1 3 import sqlite3 2 - from concurrent.futures import Future 3 4 import threading 4 - import queue 5 - import json 5 + from concurrent.futures import Future 6 6 7 - class DataBaseWorker(): 7 + 8 + class DataBaseWorker: 8 9 def __init__(self, database: str) -> None: 9 10 super(DataBaseWorker, self).__init__() 10 11 self.database = database ··· 14 15 self.conn = sqlite3.connect(self.database, check_same_thread=False) 15 16 self.lock = threading.Lock() 16 17 self.thread.start() 17 - 18 + 18 19 def _run(self): 19 20 while not self.shutdown_event.is_set(): 20 21 try: ··· 29 30 self.queue.task_done() 30 31 except queue.Empty: 31 32 continue 32 - 33 - def execute(self, sql: str, params = ()): 33 + 34 + def execute(self, sql: str, params=()): 34 35 def task(conn: sqlite3.Connection): 35 36 cursor = conn.execute(sql, params) 36 37 conn.commit() 37 38 return cursor.fetchall() 38 - 39 + 39 40 future = Future() 40 41 self.queue.put((task, future)) 41 42 return future.result() 42 - 43 + 43 44 def close(self): 44 45 self.shutdown_event.set() 45 46 self.thread.join() 46 47 with self.lock: 47 48 self.conn.close() 49 + 48 50 49 51 def try_insert_repost( 50 52 db: DataBaseWorker, 51 53 post_id: str, 52 54 reposted_id: str, 53 55 input_user: str, 54 - input_service: str) -> bool: 55 - 56 + input_service: str, 57 + ) -> bool: 56 58 reposted = find_post(db, reposted_id, input_user, input_service) 57 59 if not reposted: 58 60 return False 59 - 60 - insert_repost(db, post_id, reposted['id'], input_user, input_service) 61 + 62 + insert_repost(db, post_id, reposted["id"], input_user, input_service) 61 63 return True 62 - 64 + 63 65 64 66 def try_insert_post( 65 - db: DataBaseWorker, 67 + db: DataBaseWorker, 66 68 post_id: str, 67 69 in_reply: str | None, 68 70 input_user: str, 69 - input_service: str) -> bool: 71 + input_service: str, 72 + ) -> bool: 70 73 root_id = None 71 74 parent_id = None 72 - 75 + 73 76 if in_reply: 74 77 parent_post = 
find_post(db, in_reply, input_user, input_service) 75 78 if not parent_post: 76 79 return False 77 - 78 - root_id = parent_post['id'] 80 + 81 + root_id = parent_post["id"] 79 82 parent_id = root_id 80 - if parent_post['root_id']: 81 - root_id = parent_post['root_id'] 82 - 83 + if parent_post["root_id"]: 84 + root_id = parent_post["root_id"] 85 + 83 86 if root_id and parent_id: 84 - insert_reply(db,post_id, input_user, input_service, parent_id, root_id) 87 + insert_reply(db, post_id, input_user, input_service, parent_id, root_id) 85 88 else: 86 89 insert_post(db, post_id, input_user, input_service) 87 - 90 + 88 91 return True 89 92 90 - def insert_repost(db: DataBaseWorker, identifier: str, reposted_id: int, user_id: str, serivce: str) -> int: 93 + 94 + def insert_repost( 95 + db: DataBaseWorker, identifier: str, reposted_id: int, user_id: str, serivce: str 96 + ) -> int: 91 97 db.execute( 92 98 """ 93 99 INSERT INTO posts (user_id, service, identifier, reposted_id) 94 100 VALUES (?, ?, ?, ?); 95 - """, (user_id, serivce, identifier, reposted_id)) 101 + """, 102 + (user_id, serivce, identifier, reposted_id), 103 + ) 96 104 return db.execute("SELECT last_insert_rowid();", ())[0][0] 105 + 97 106 98 107 def insert_post(db: DataBaseWorker, identifier: str, user_id: str, serivce: str) -> int: 99 108 db.execute( 100 109 """ 101 110 INSERT INTO posts (user_id, service, identifier) 102 111 VALUES (?, ?, ?); 103 - """, (user_id, serivce, identifier)) 112 + """, 113 + (user_id, serivce, identifier), 114 + ) 104 115 return db.execute("SELECT last_insert_rowid();", ())[0][0] 105 116 106 - def insert_reply(db: DataBaseWorker, identifier: str, user_id: str, serivce: str, parent: int, root: int) -> int: 117 + 118 + def insert_reply( 119 + db: DataBaseWorker, 120 + identifier: str, 121 + user_id: str, 122 + serivce: str, 123 + parent: int, 124 + root: int, 125 + ) -> int: 107 126 db.execute( 108 127 """ 109 128 INSERT INTO posts (user_id, service, identifier, parent_id, root_id) 
110 129 VALUES (?, ?, ?, ?, ?); 111 - """, (user_id, serivce, identifier, parent, root)) 130 + """, 131 + (user_id, serivce, identifier, parent, root), 132 + ) 112 133 return db.execute("SELECT last_insert_rowid();", ())[0][0] 113 134 135 + 114 136 def insert_mapping(db: DataBaseWorker, original: int, mapped: int): 115 - db.execute(""" 137 + db.execute( 138 + """ 116 139 INSERT INTO mappings (original_post_id, mapped_post_id) 117 140 VALUES (?, ?); 118 - """, (original, mapped)) 141 + """, 142 + (original, mapped), 143 + ) 144 + 119 145 120 146 def delete_post(db: DataBaseWorker, identifier: str, user_id: str, serivce: str): 121 147 db.execute( ··· 124 150 WHERE identifier = ? 125 151 AND service = ? 126 152 AND user_id = ? 127 - """, (identifier, serivce, user_id)) 128 - 153 + """, 154 + (identifier, serivce, user_id), 155 + ) 156 + 157 + 129 158 def fetch_data(db: DataBaseWorker, identifier: str, user_id: str, service: str) -> dict: 130 159 result = db.execute( 131 160 """ ··· 134 163 WHERE identifier = ? 135 164 AND user_id = ? 136 165 AND service = ? 137 - """, (identifier, user_id, service)) 166 + """, 167 + (identifier, user_id, service), 168 + ) 138 169 if not result or not result[0]: 139 170 return {} 140 171 return json.loads(result[0][0]) 141 172 142 - def store_data(db: DataBaseWorker, identifier: str, user_id: str, service: str, extra_data: dict) -> None: 173 + 174 + def store_data( 175 + db: DataBaseWorker, identifier: str, user_id: str, service: str, extra_data: dict 176 + ) -> None: 143 177 db.execute( 144 178 """ 145 179 UPDATE posts ··· 148 182 AND user_id = ? 149 183 AND service = ? 
150 184 """, 151 - (json.dumps(extra_data), identifier, user_id, service) 185 + (json.dumps(extra_data), identifier, user_id, service), 152 186 ) 153 187 154 - def find_mappings(db: DataBaseWorker, original_post: int, service: str, user_id: str) -> list[str]: 188 + 189 + def find_mappings( 190 + db: DataBaseWorker, original_post: int, service: str, user_id: str 191 + ) -> list[str]: 155 192 return db.execute( 156 193 """ 157 194 SELECT p.identifier ··· 163 200 AND p.user_id = ? 164 201 ORDER BY p.id; 165 202 """, 166 - (original_post, service, user_id)) 167 - 203 + (original_post, service, user_id), 204 + ) 205 + 206 + 168 207 def find_post_by_id(db: DataBaseWorker, id: int) -> dict | None: 169 208 result = db.execute( 170 209 """ 171 210 SELECT user_id, service, identifier, parent_id, root_id, reposted_id 172 211 FROM posts 173 212 WHERE id = ? 174 - """, (id,)) 213 + """, 214 + (id,), 215 + ) 175 216 if not result: 176 217 return None 177 218 user_id, service, identifier, parent_id, root_id, reposted_id = result[0] 178 219 return { 179 - 'user_id': user_id, 180 - 'service': service, 181 - 'identifier': identifier, 182 - 'parent_id': parent_id, 183 - 'root_id': root_id, 184 - 'reposted_id': reposted_id 220 + "user_id": user_id, 221 + "service": service, 222 + "identifier": identifier, 223 + "parent_id": parent_id, 224 + "root_id": root_id, 225 + "reposted_id": reposted_id, 185 226 } 186 227 187 - def find_post(db: DataBaseWorker, identifier: str, user_id: str, service: str) -> dict | None: 228 + 229 + def find_post( 230 + db: DataBaseWorker, identifier: str, user_id: str, service: str 231 + ) -> dict | None: 188 232 result = db.execute( 189 233 """ 190 234 SELECT id, parent_id, root_id, reposted_id ··· 192 236 WHERE identifier = ? 193 237 AND user_id = ? 194 238 AND service = ? 
195 - """, (identifier, user_id, service)) 239 + """, 240 + (identifier, user_id, service), 241 + ) 196 242 if not result: 197 243 return None 198 244 id, parent_id, root_id, reposted_id = result[0] 199 245 return { 200 - 'id': id, 201 - 'parent_id': parent_id, 202 - 'root_id': root_id, 203 - 'reposted_id': reposted_id 246 + "id": id, 247 + "parent_id": parent_id, 248 + "root_id": root_id, 249 + "reposted_id": reposted_id, 204 250 } 251 + 205 252 206 253 def find_mapped_thread( 207 - db: DataBaseWorker, 254 + db: DataBaseWorker, 208 255 parent_id: str, 209 256 input_user: str, 210 257 input_service: str, 211 258 output_user: str, 212 - output_service: str): 213 - 259 + output_service: str, 260 + ): 214 261 reply_data: dict | None = find_post(db, parent_id, input_user, input_service) 215 262 if not reply_data: 216 263 return None 217 - 218 - reply_mappings: list[str] | None = find_mappings(db, reply_data['id'], output_service, output_user) 264 + 265 + reply_mappings: list[str] | None = find_mappings( 266 + db, reply_data["id"], output_service, output_user 267 + ) 219 268 if not reply_mappings: 220 269 return None 221 - 270 + 222 271 reply_identifier: str = reply_mappings[-1] 223 272 root_identifier: str = reply_mappings[0] 224 - if reply_data['root_id']: 225 - root_data = find_post_by_id(db, reply_data['root_id']) 273 + if reply_data["root_id"]: 274 + root_data = find_post_by_id(db, reply_data["root_id"]) 226 275 if not root_data: 227 276 return None 228 - 229 - root_mappings = find_mappings(db, reply_data['root_id'], output_service, output_user) 277 + 278 + root_mappings = find_mappings( 279 + db, reply_data["root_id"], output_service, output_user 280 + ) 230 281 if not root_mappings: 231 282 return None 232 283 root_identifier = root_mappings[0] 233 - 284 + 234 285 return ( 235 - root_identifier[0], # real ids 286 + root_identifier[0], # real ids 236 287 reply_identifier[0], 237 - reply_data['root_id'], # db ids 238 - reply_data['id'] 239 - ) 288 + 
reply_data["root_id"], # db ids 289 + reply_data["id"], 290 + )
+82 -79
util/html_util.py
··· 1 1 from html.parser import HTMLParser 2 + 2 3 import cross 4 + 3 5 4 6 class HTMLPostTokenizer(HTMLParser): 5 7 def __init__(self) -> None: 6 8 super().__init__() 7 9 self.tokens: list[cross.Token] = [] 8 - 10 + 9 11 self.mentions: list[tuple[str, str]] 10 12 self.tags: list[str] 11 - 13 + 12 14 self.in_pre = False 13 15 self.in_code = False 14 - 16 + 15 17 self.current_tag_stack = [] 16 18 self.list_stack = [] 17 - 19 + 18 20 self.anchor_stack = [] 19 21 self.anchor_data = [] 20 - 22 + 21 23 def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: 22 24 attrs_dict = dict(attrs) 23 - 25 + 24 26 def append_newline(): 25 27 if self.tokens: 26 28 last_token = self.tokens[-1] 27 - if isinstance(last_token, cross.TextToken) and not last_token.text.endswith('\n'): 28 - self.tokens.append(cross.TextToken('\n')) 29 - 29 + if isinstance( 30 + last_token, cross.TextToken 31 + ) and not last_token.text.endswith("\n"): 32 + self.tokens.append(cross.TextToken("\n")) 33 + 30 34 match tag: 31 - case 'br': 32 - self.tokens.append(cross.TextToken(' \n')) 33 - case 'a': 34 - href = attrs_dict.get('href', '') 35 + case "br": 36 + self.tokens.append(cross.TextToken(" \n")) 37 + case "a": 38 + href = attrs_dict.get("href", "") 35 39 self.anchor_stack.append(href) 36 - case 'strong', 'b': 37 - self.tokens.append(cross.TextToken('**')) 38 - case 'em', 'i': 39 - self.tokens.append(cross.TextToken('*')) 40 - case 'del', 's': 41 - self.tokens.append(cross.TextToken('~~')) 42 - case 'code': 40 + case "strong", "b": 41 + self.tokens.append(cross.TextToken("**")) 42 + case "em", "i": 43 + self.tokens.append(cross.TextToken("*")) 44 + case "del", "s": 45 + self.tokens.append(cross.TextToken("~~")) 46 + case "code": 43 47 if not self.in_pre: 44 - self.tokens.append(cross.TextToken('`')) 48 + self.tokens.append(cross.TextToken("`")) 45 49 self.in_code = True 46 - case 'pre': 50 + case "pre": 47 51 append_newline() 48 - self.tokens.append(cross.TextToken('```\n')) 52 
+ self.tokens.append(cross.TextToken("```\n")) 49 53 self.in_pre = True 50 - case 'blockquote': 54 + case "blockquote": 51 55 append_newline() 52 - self.tokens.append(cross.TextToken('> ')) 53 - case 'ul', 'ol': 56 + self.tokens.append(cross.TextToken("> ")) 57 + case "ul", "ol": 54 58 self.list_stack.append(tag) 55 59 append_newline() 56 - case 'li': 57 - indent = ' ' * (len(self.list_stack) - 1) 58 - if self.list_stack and self.list_stack[-1] == 'ul': 59 - self.tokens.append(cross.TextToken(f'{indent}- ')) 60 - elif self.list_stack and self.list_stack[-1] == 'ol': 61 - self.tokens.append(cross.TextToken(f'{indent}1. ')) 60 + case "li": 61 + indent = " " * (len(self.list_stack) - 1) 62 + if self.list_stack and self.list_stack[-1] == "ul": 63 + self.tokens.append(cross.TextToken(f"{indent}- ")) 64 + elif self.list_stack and self.list_stack[-1] == "ol": 65 + self.tokens.append(cross.TextToken(f"{indent}1. ")) 62 66 case _: 63 - if tag in {'h1', 'h2', 'h3', 'h4', 'h5', 'h6'}: 67 + if tag in {"h1", "h2", "h3", "h4", "h5", "h6"}: 64 68 level = int(tag[1]) 65 69 self.tokens.append(cross.TextToken("\n" + "#" * level + " ")) 66 - 70 + 67 71 self.current_tag_stack.append(tag) 68 - 72 + 69 73 def handle_data(self, data: str) -> None: 70 74 if self.anchor_stack: 71 75 self.anchor_data.append(data) 72 76 else: 73 77 self.tokens.append(cross.TextToken(data)) 74 - 78 + 75 79 def handle_endtag(self, tag: str) -> None: 76 80 if not self.current_tag_stack: 77 81 return 78 - 82 + 79 83 if tag in self.current_tag_stack: 80 84 self.current_tag_stack.remove(tag) 81 - 85 + 82 86 match tag: 83 - case 'p': 84 - self.tokens.append(cross.TextToken('\n\n')) 85 - case 'a': 87 + case "p": 88 + self.tokens.append(cross.TextToken("\n\n")) 89 + case "a": 86 90 href = self.anchor_stack.pop() 87 - anchor_data = ''.join(self.anchor_data) 91 + anchor_data = "".join(self.anchor_data) 88 92 self.anchor_data = [] 89 - 90 - if anchor_data.startswith('#'): 93 + 94 + if anchor_data.startswith("#"): 91 95 
as_tag = anchor_data[1:].lower() 92 96 if any(as_tag == block for block in self.tags): 93 97 self.tokens.append(cross.TagToken(anchor_data[1:])) 94 - elif anchor_data.startswith('@'): 98 + elif anchor_data.startswith("@"): 95 99 match = next( 96 - (pair for pair in self.mentions if anchor_data in pair), 97 - None 100 + (pair for pair in self.mentions if anchor_data in pair), None 98 101 ) 99 - 102 + 100 103 if match: 101 - self.tokens.append(cross.MentionToken(match[1], '')) 104 + self.tokens.append(cross.MentionToken(match[1], "")) 102 105 else: 103 106 self.tokens.append(cross.LinkToken(href, anchor_data)) 104 - case 'strong', 'b': 105 - self.tokens.append(cross.TextToken('**')) 106 - case 'em', 'i': 107 - self.tokens.append(cross.TextToken('*')) 108 - case 'del', 's': 109 - self.tokens.append(cross.TextToken('~~')) 110 - case 'code': 107 + case "strong", "b": 108 + self.tokens.append(cross.TextToken("**")) 109 + case "em", "i": 110 + self.tokens.append(cross.TextToken("*")) 111 + case "del", "s": 112 + self.tokens.append(cross.TextToken("~~")) 113 + case "code": 111 114 if not self.in_pre and self.in_code: 112 - self.tokens.append(cross.TextToken('`')) 115 + self.tokens.append(cross.TextToken("`")) 113 116 self.in_code = False 114 - case 'pre': 115 - self.tokens.append(cross.TextToken('\n```\n')) 117 + case "pre": 118 + self.tokens.append(cross.TextToken("\n```\n")) 116 119 self.in_pre = False 117 - case 'blockquote': 118 - self.tokens.append(cross.TextToken('\n')) 119 - case 'ul', 'ol': 120 + case "blockquote": 121 + self.tokens.append(cross.TextToken("\n")) 122 + case "ul", "ol": 120 123 if self.list_stack: 121 124 self.list_stack.pop() 122 - self.tokens.append(cross.TextToken('\n')) 123 - case 'li': 124 - self.tokens.append(cross.TextToken('\n')) 125 + self.tokens.append(cross.TextToken("\n")) 126 + case "li": 127 + self.tokens.append(cross.TextToken("\n")) 125 128 case _: 126 - if tag in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']: 127 - 
self.tokens.append(cross.TextToken('\n')) 128 - 129 + if tag in ["h1", "h2", "h3", "h4", "h5", "h6"]: 130 + self.tokens.append(cross.TextToken("\n")) 131 + 129 132 def get_tokens(self) -> list[cross.Token]: 130 133 if not self.tokens: 131 134 return [] 132 - 135 + 133 136 combined: list[cross.Token] = [] 134 137 buffer: list[str] = [] 135 - 138 + 136 139 def flush_buffer(): 137 140 if buffer: 138 - merged = ''.join(buffer) 141 + merged = "".join(buffer) 139 142 combined.append(cross.TextToken(text=merged)) 140 143 buffer.clear() 141 144 ··· 145 148 else: 146 149 flush_buffer() 147 150 combined.append(token) 148 - 151 + 149 152 flush_buffer() 150 - 153 + 151 154 if combined and isinstance(combined[-1], cross.TextToken): 152 - if combined[-1].text.endswith('\n\n'): 155 + if combined[-1].text.endswith("\n\n"): 153 156 combined[-1] = cross.TextToken(combined[-1].text[:-2]) 154 157 return combined 155 - 158 + 156 159 def reset(self): 157 160 """Reset the parser state for reuse.""" 158 161 super().reset() 159 162 self.tokens = [] 160 - 163 + 161 164 self.mentions = [] 162 165 self.tags = [] 163 - 166 + 164 167 self.in_pre = False 165 168 self.in_code = False 166 - 169 + 167 170 self.current_tag_stack = [] 168 171 self.anchor_stack = [] 169 - self.list_stack = [] 172 + self.list_stack = []
+44 -33
util/md_util.py
··· 4 4 import util.html_util as html_util 5 5 import util.util as util 6 6 7 - URL = re.compile(r'(?:(?:[A-Za-z][A-Za-z0-9+.-]*://)|mailto:)[^\s]+', re.IGNORECASE) 8 - MD_INLINE_LINK = re.compile(r"\[([^\]]+)\]\(\s*((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s\)]+)\s*\)", re.IGNORECASE) 9 - MD_AUTOLINK = re.compile(r"<((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s>]+)>", re.IGNORECASE) 10 - HASHTAG = re.compile(r'(?<!\w)\#([\w]+)') 11 - FEDIVERSE_HANDLE = re.compile(r'(?<![\w@])@([\w\.-]+)(?:@([\w\.-]+\.[\w\.-]+))?') 7 + URL = re.compile(r"(?:(?:[A-Za-z][A-Za-z0-9+.-]*://)|mailto:)[^\s]+", re.IGNORECASE) 8 + MD_INLINE_LINK = re.compile( 9 + r"\[([^\]]+)\]\(\s*((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s\)]+)\s*\)", 10 + re.IGNORECASE, 11 + ) 12 + MD_AUTOLINK = re.compile( 13 + r"<((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s>]+)>", re.IGNORECASE 14 + ) 15 + HASHTAG = re.compile(r"(?<!\w)\#([\w]+)") 16 + FEDIVERSE_HANDLE = re.compile(r"(?<![\w@])@([\w\.-]+)(?:@([\w\.-]+\.[\w\.-]+))?") 12 17 13 - def tokenize_markdown(text: str, tags: list[str], handles: list[tuple[str, str]]) -> list[cross.Token]: 18 + 19 + def tokenize_markdown( 20 + text: str, tags: list[str], handles: list[tuple[str, str]] 21 + ) -> list[cross.Token]: 14 22 if not text: 15 23 return [] 16 - 24 + 17 25 tokenizer = html_util.HTMLPostTokenizer() 18 26 tokenizer.mentions = handles 19 27 tokenizer.tags = tags 20 28 tokenizer.feed(text) 21 29 html_tokens = tokenizer.get_tokens() 22 - 30 + 23 31 tokens: list[cross.Token] = [] 24 - 32 + 25 33 for tk in html_tokens: 26 34 if isinstance(tk, cross.TextToken): 27 35 tokens.extend(__tokenize_md(tk.text, tags, handles)) ··· 29 37 if not tk.label or util.canonical_label(tk.label, tk.href): 30 38 tokens.append(tk) 31 39 continue 32 - 40 + 33 41 tokens.extend(__tokenize_md(f"[{tk.label}]({tk.href})", tags, handles)) 34 42 else: 35 43 tokens.append(tk) 36 - 44 + 37 45 return tokens 38 - 46 + 39 47 40 - def __tokenize_md(text: str, tags: list[str], 
handles: list[tuple[str, str]]) -> list[cross.Token]: 48 + def __tokenize_md( 49 + text: str, tags: list[str], handles: list[tuple[str, str]] 50 + ) -> list[cross.Token]: 41 51 index: int = 0 42 52 total: int = len(text) 43 53 buffer: list[str] = [] 44 - 54 + 45 55 tokens: list[cross.Token] = [] 46 - 56 + 47 57 def flush(): 48 58 nonlocal buffer 49 59 if buffer: 50 - tokens.append(cross.TextToken(''.join(buffer))) 60 + tokens.append(cross.TextToken("".join(buffer))) 51 61 buffer = [] 52 - 62 + 53 63 while index < total: 54 - if text[index] == '[': 64 + if text[index] == "[": 55 65 md_inline = MD_INLINE_LINK.match(text, index) 56 66 if md_inline: 57 67 flush() ··· 60 70 tokens.append(cross.LinkToken(href, label)) 61 71 index = md_inline.end() 62 72 continue 63 - 64 - if text[index] == '<': 73 + 74 + if text[index] == "<": 65 75 md_auto = MD_AUTOLINK.match(text, index) 66 76 if md_auto: 67 77 flush() ··· 69 79 tokens.append(cross.LinkToken(href, href)) 70 80 index = md_auto.end() 71 81 continue 72 - 73 - if text[index] == '#': 82 + 83 + if text[index] == "#": 74 84 tag = HASHTAG.match(text, index) 75 85 if tag: 76 86 tag_text = tag.group(1) ··· 79 89 tokens.append(cross.TagToken(tag_text)) 80 90 index = tag.end() 81 91 continue 82 - 83 - if text[index] == '@': 92 + 93 + if text[index] == "@": 84 94 handle = FEDIVERSE_HANDLE.match(text, index) 85 95 if handle: 86 96 handle_text = handle.group(0) 87 97 stripped_handle = handle_text.strip() 88 - 98 + 89 99 match = next( 90 - (pair for pair in handles if stripped_handle in pair), 91 - None 100 + (pair for pair in handles if stripped_handle in pair), None 92 101 ) 93 - 102 + 94 103 if match: 95 104 flush() 96 - tokens.append(cross.MentionToken(match[1], '')) # TODO: misskey doesn’t provide a uri 105 + tokens.append( 106 + cross.MentionToken(match[1], "") 107 + ) # TODO: misskey doesn’t provide a uri 97 108 index = handle.end() 98 109 continue 99 - 110 + 100 111 url = URL.match(text, index) 101 112 if url: 102 113 flush() 
··· 104 115 tokens.append(cross.LinkToken(href, href)) 105 116 index = url.end() 106 117 continue 107 - 118 + 108 119 buffer.append(text[index]) 109 120 index += 1 110 - 121 + 111 122 flush() 112 - return tokens 123 + return tokens
+73 -56
util/media.py
import json
import os
import re
import subprocess
import urllib.parse

import magic
import requests

from util.util import LOGGER

# Matches the filename token of a Content-Disposition header, quoted or not.
FILENAME = re.compile(r'filename="?([^\";]*)"?')
MAGIC = magic.Magic(mime=True)


class MediaInfo:
    """A downloaded media attachment: origin URL, filename, MIME type, alt text, raw bytes."""

    def __init__(self, url: str, name: str, mime: str, alt: str, io: bytes) -> None:
        self.url = url
        self.name = name
        # NOTE(review): this line was hidden in the reviewed view; reconstructed
        # from the constructor signature — verify against the original.
        self.mime = mime
        self.alt = alt
        self.io = io


def download_media(url: str, alt: str) -> MediaInfo | None:
    """Fetch *url* (up to 100 MB) and wrap it in a MediaInfo, or return None on failure.

    The MIME type is sniffed from the payload with libmagic and falls back to
    application/octet-stream when sniffing yields nothing.
    """
    name = get_filename_from_url(url)
    io = download_blob(url, max_bytes=100_000_000)

    # NOTE(review): this guard was partially hidden in the reviewed view;
    # reconstructed as a plain early return — verify against the original.
    if not io:
        return None
    mime = MAGIC.from_buffer(io)
    if not mime:
        mime = "application/octet-stream"
    return MediaInfo(url, name, mime, alt, io)


def get_filename_from_url(url):
    """Best-effort filename for *url*: Content-Disposition first, then the URL path.

    Any error while probing headers is swallowed and the path-based fallback is used.
    """
    try:
        response = requests.head(url, allow_redirects=True)
        disposition = response.headers.get("Content-Disposition")
        if disposition:
            filename = FILENAME.findall(disposition)
            if filename:
                # NOTE(review): these lines were hidden in the reviewed view;
                # reconstructed as return-first-match plus a best-effort
                # exception swallow — verify against the original.
                return filename[0]
    except Exception:
        pass

    parsed_url = urllib.parse.urlparse(url)
    base_name = os.path.basename(parsed_url.path)

    # hardcoded fix to return the cid for pds
    if base_name == "com.atproto.sync.getBlob":
        qs = urllib.parse.parse_qs(parsed_url.query)
        if qs and qs.get("cid"):
            return qs["cid"][0]

    return base_name


def probe_bytes(bytes: bytes) -> dict:
    """Run ffprobe over in-memory media and return its parsed JSON report.

    Raises RuntimeError when ffprobe exits non-zero. (The parameter name shadows
    the builtin `bytes`; kept unchanged for interface compatibility.)
    """
    cmd = [
        "ffprobe",
        "-v", "error",
        "-show_format",
        "-show_streams",
        "-print_format", "json",
        "pipe:0",
    ]
    proc = subprocess.run(
        cmd, input=bytes, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )

    if proc.returncode != 0:
        raise RuntimeError(f"ffprobe failed: {proc.stderr.decode()}")

    return json.loads(proc.stdout)


def convert_to_mp4(video_bytes: bytes) -> bytes:
    """Transcode arbitrary video bytes to fragmented MP4 (H.264 + AAC) via ffmpeg.

    Fragmented output (frag_keyframe+empty_moov) lets ffmpeg write MP4 to a
    non-seekable pipe. Raises RuntimeError when ffmpeg exits non-zero.
    """
    cmd = [
        "ffmpeg",
        "-i", "pipe:0",
        "-c:v", "libx264",
        "-crf", "30",
        "-preset", "slow",
        "-c:a", "aac",
        "-b:a", "128k",
        "-movflags", "frag_keyframe+empty_moov+default_base_moof",
        "-f", "mp4",
        "pipe:1",
    ]

    proc = subprocess.Popen(
        cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    out_bytes, err = proc.communicate(input=video_bytes)

    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg compress failed: {err.decode()}")

    return out_bytes


def compress_image(image_bytes: bytes, quality: int = 90):
    """Re-encode an image as WebP at the given quality via ffmpeg.

    Raises RuntimeError when ffmpeg exits non-zero.
    """
    cmd = [
        "ffmpeg",
        "-f", "image2pipe",
        "-i", "pipe:0",
        "-c:v", "webp",
        "-q:v", str(quality),
        "-f", "image2pipe",
        "pipe:1",
    ]

    proc = subprocess.Popen(
        cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    out_bytes, err = proc.communicate(input=image_bytes)

    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg compress failed: {err.decode()}")

    return out_bytes


def download_blob(url: str, max_bytes: int = 5_000_000) -> bytes | None:
    """Stream *url* into memory; None on HTTP error or when the body exceeds *max_bytes*."""
    response = requests.get(url, stream=True, timeout=20)
    # `with` releases the connection on every path — the original only closed it
    # on the size-limit path, leaking the streamed connection otherwise.
    with response:
        if response.status_code != 200:
            LOGGER.info("Failed to download %s! %s", url, response.text)
            return None

        # Collect chunks and join once: the original `bytes +=` copies the whole
        # buffer per chunk, which is quadratic on large downloads.
        chunks: list[bytes] = []
        current_size = 0

        for chunk in response.iter_content(chunk_size=8192):
            if not chunk:
                continue

            current_size += len(chunk)
            if current_size > max_bytes:
                return None

            chunks.append(chunk)

        return b"".join(chunks)


def get_media_meta(bytes: bytes):
    """Return {"width", "height", "duration"} of the first video stream in *bytes*.

    Duration falls back to the container-level value, then -1.0 when absent.
    Raises ValueError when no video stream is found. (Parameter name shadows the
    builtin `bytes`; kept unchanged for interface compatibility.)
    """
    probe = probe_bytes(bytes)
    streams = [s for s in probe["streams"] if s["codec_type"] == "video"]
    if not streams:
        raise ValueError("No video stream found")

    media = streams[0]
    return {
        "width": int(media["width"]),
        "height": int(media["height"]),
        "duration": float(media.get("duration", probe["format"].get("duration", -1))),
    }
+20 -13
util/util.py
··· 1 - import logging, sys, os 2 1 import json 2 + import logging 3 + import os 4 + import sys 3 5 4 6 logging.basicConfig(stream=sys.stdout, level=logging.INFO) 5 7 LOGGER = logging.getLogger("XPost") 6 8 7 - def as_json(obj, indent=None,sort_keys=False) -> str: 9 + 10 + def as_json(obj, indent=None, sort_keys=False) -> str: 8 11 return json.dumps( 9 - obj.__dict__ if not isinstance(obj, dict) else obj, 10 - default=lambda o: o.__json__() if hasattr(o, '__json__') else o.__dict__, 12 + obj.__dict__ if not isinstance(obj, dict) else obj, 13 + default=lambda o: o.__json__() if hasattr(o, "__json__") else o.__dict__, 11 14 indent=indent, 12 - sort_keys=sort_keys) 15 + sort_keys=sort_keys, 16 + ) 17 + 13 18 14 19 def canonical_label(label: str | None, href: str): 15 20 if not label or label == href: 16 21 return True 17 - 18 - split = href.split('://', 1) 22 + 23 + split = href.split("://", 1) 19 24 if len(split) > 1: 20 25 if split[1] == label: 21 26 return True 22 - 27 + 23 28 return False 24 29 30 + 25 31 def safe_get(obj: dict, key: str, default): 26 32 val = obj.get(key, default) 27 33 return val if val else default 34 + 28 35 29 36 def as_envvar(text: str | None) -> str | None: 30 37 if not text: 31 38 return None 32 - 33 - if text.startswith('env:'): 34 - return os.environ.get(text[4:], '') 35 - 36 - return text 39 + 40 + if text.startswith("env:"): 41 + return os.environ.get(text[4:], "") 42 + 43 + return text