tangled
alpha
login
or
join now
zenfyr.dev
/
xpost
2
fork
atom
social media crossposting tool. 3rd time's the charm
mastodon
misskey
crossposting
bluesky
2
fork
atom
overview
issues
1
pulls
pipelines
run a formatter
zenfyr.dev
7 months ago
dd2bea73
c0211a8f
verified
This commit was signed with the committer's
known signature
.
zenfyr.dev
SSH Key Fingerprint:
SHA256:TtcIcnTnoAB5mqHofsaOxIgiMzfVBxej1AXT7DQdrTE=
+1541
-1238
17 changed files
expand all
collapse all
unified
split
bluesky
atproto2.py
common.py
input.py
output.py
cross.py
main.py
mastodon
common.py
input.py
output.py
misskey
common.py
input.py
mfm_util.py
util
database.py
html_util.py
md_util.py
media.py
util.py
+82
-72
bluesky/atproto2.py
reviewed
···
1
1
from typing import Any
2
2
-
from atproto import client_utils, Client, AtUri, IdResolver
2
2
+
3
3
+
from atproto import AtUri, Client, IdResolver, client_utils
3
4
from atproto_client import models
5
5
+
4
6
from util.util import LOGGER
7
7
+
5
8
6
9
def resolve_identity(
7
7
-
handle: str | None = None,
8
8
-
did: str | None = None,
9
9
-
pds: str | None = None):
10
10
+
handle: str | None = None, did: str | None = None, pds: str | None = None
11
11
+
):
10
12
"""helper to try and resolve identity from provided parameters, a valid handle is enough"""
11
11
-
13
13
+
12
14
if did and pds:
13
13
-
return did, pds[:-1] if pds.endswith('/') else pds
14
14
-
15
15
+
return did, pds[:-1] if pds.endswith("/") else pds
16
16
+
15
17
resolver = IdResolver()
16
18
if not did:
17
19
if not handle:
···
20
22
did = resolver.handle.resolve(handle)
21
23
if not did:
22
24
raise Exception("Failed to resolve DID!")
23
23
-
25
25
+
24
26
if not pds:
25
27
LOGGER.info("Resolving PDS from DID document...")
26
28
did_doc = resolver.did.resolve(did)
···
29
31
pds = did_doc.get_pds_endpoint()
30
32
if not pds:
31
33
raise Exception("Failed to resolve PDS!")
32
32
-
33
33
-
return did, pds[:-1] if pds.endswith('/') else pds
34
34
+
35
35
+
return did, pds[:-1] if pds.endswith("/") else pds
36
36
+
34
37
35
38
class Client2(Client):
36
39
def __init__(self, base_url: str | None = None, *args: Any, **kwargs: Any) -> None:
37
40
super().__init__(base_url, *args, **kwargs)
38
38
-
41
41
+
39
42
def send_video(
40
40
-
self,
41
41
-
text: str | client_utils.TextBuilder,
43
43
+
self,
44
44
+
text: str | client_utils.TextBuilder,
42
45
video: bytes,
43
46
video_alt: str | None = None,
44
47
video_aspect_ratio: models.AppBskyEmbedDefs.AspectRatio | None = None,
···
46
49
langs: list[str] | None = None,
47
50
facets: list[models.AppBskyRichtextFacet.Main] | None = None,
48
51
labels: models.ComAtprotoLabelDefs.SelfLabels | None = None,
49
49
-
time_iso: str | None = None
50
50
-
) -> models.AppBskyFeedPost.CreateRecordResponse:
52
52
+
time_iso: str | None = None,
53
53
+
) -> models.AppBskyFeedPost.CreateRecordResponse:
51
54
"""same as send_video, but with labels"""
52
52
-
55
55
+
53
56
if video_alt is None:
54
54
-
video_alt = ''
57
57
+
video_alt = ""
55
58
56
59
upload = self.upload_blob(video)
57
57
-
60
60
+
58
61
return self.send_post(
59
62
text,
60
63
reply_to=reply_to,
61
61
-
embed=models.AppBskyEmbedVideo.Main(video=upload.blob, alt=video_alt, aspect_ratio=video_aspect_ratio),
64
64
+
embed=models.AppBskyEmbedVideo.Main(
65
65
+
video=upload.blob, alt=video_alt, aspect_ratio=video_aspect_ratio
66
66
+
),
62
67
langs=langs,
63
68
facets=facets,
64
69
labels=labels,
65
65
-
time_iso=time_iso
70
70
+
time_iso=time_iso,
66
71
)
67
67
-
72
72
+
68
73
def send_images(
69
69
-
self,
70
70
-
text: str | client_utils.TextBuilder,
74
74
+
self,
75
75
+
text: str | client_utils.TextBuilder,
71
76
images: list[bytes],
72
77
image_alts: list[str] | None = None,
73
78
image_aspect_ratios: list[models.AppBskyEmbedDefs.AspectRatio] | None = None,
···
75
80
langs: list[str] | None = None,
76
81
facets: list[models.AppBskyRichtextFacet.Main] | None = None,
77
82
labels: models.ComAtprotoLabelDefs.SelfLabels | None = None,
78
78
-
time_iso: str | None = None
79
79
-
) -> models.AppBskyFeedPost.CreateRecordResponse:
83
83
+
time_iso: str | None = None,
84
84
+
) -> models.AppBskyFeedPost.CreateRecordResponse:
80
85
"""same as send_images, but with labels"""
81
81
-
86
86
+
82
87
if image_alts is None:
83
83
-
image_alts = [''] * len(images)
88
88
+
image_alts = [""] * len(images)
84
89
else:
85
90
diff = len(images) - len(image_alts)
86
86
-
image_alts = image_alts + [''] * diff
87
87
-
91
91
+
image_alts = image_alts + [""] * diff
92
92
+
88
93
if image_aspect_ratios is None:
89
94
aligned_image_aspect_ratios = [None] * len(images)
90
95
else:
91
96
diff = len(images) - len(image_aspect_ratios)
92
97
aligned_image_aspect_ratios = image_aspect_ratios + [None] * diff
93
93
-
98
98
+
94
99
uploads = [self.upload_blob(image) for image in images]
95
95
-
100
100
+
96
101
embed_images = [
97
97
-
models.AppBskyEmbedImages.Image(alt=alt, image=upload.blob, aspect_ratio=aspect_ratio)
98
98
-
for alt, upload, aspect_ratio in zip(image_alts, uploads, aligned_image_aspect_ratios)
102
102
+
models.AppBskyEmbedImages.Image(
103
103
+
alt=alt, image=upload.blob, aspect_ratio=aspect_ratio
104
104
+
)
105
105
+
for alt, upload, aspect_ratio in zip(
106
106
+
image_alts, uploads, aligned_image_aspect_ratios
107
107
+
)
99
108
]
100
100
-
109
109
+
101
110
return self.send_post(
102
111
text,
103
112
reply_to=reply_to,
···
105
114
langs=langs,
106
115
facets=facets,
107
116
labels=labels,
108
108
-
time_iso=time_iso
117
117
+
time_iso=time_iso,
109
118
)
110
110
-
119
119
+
111
120
def send_post(
112
112
-
self,
113
113
-
text: str | client_utils.TextBuilder,
121
121
+
self,
122
122
+
text: str | client_utils.TextBuilder,
114
123
reply_to: models.AppBskyFeedPost.ReplyRef | None = None,
115
115
-
embed:
116
116
-
None |
117
117
-
models.AppBskyEmbedImages.Main |
118
118
-
models.AppBskyEmbedExternal.Main |
119
119
-
models.AppBskyEmbedRecord.Main |
120
120
-
models.AppBskyEmbedRecordWithMedia.Main |
121
121
-
models.AppBskyEmbedVideo.Main = None,
124
124
+
embed: None
125
125
+
| models.AppBskyEmbedImages.Main
126
126
+
| models.AppBskyEmbedExternal.Main
127
127
+
| models.AppBskyEmbedRecord.Main
128
128
+
| models.AppBskyEmbedRecordWithMedia.Main
129
129
+
| models.AppBskyEmbedVideo.Main = None,
122
130
langs: list[str] | None = None,
123
131
facets: list[models.AppBskyRichtextFacet.Main] | None = None,
124
132
labels: models.ComAtprotoLabelDefs.SelfLabels | None = None,
125
125
-
time_iso: str | None = None
126
126
-
) -> models.AppBskyFeedPost.CreateRecordResponse:
133
133
+
time_iso: str | None = None,
134
134
+
) -> models.AppBskyFeedPost.CreateRecordResponse:
127
135
"""same as send_post, but with labels"""
128
128
-
136
136
+
129
137
if isinstance(text, client_utils.TextBuilder):
130
138
facets = text.build_facets()
131
139
text = text.build_text()
132
132
-
140
140
+
133
141
repo = self.me and self.me.did
134
142
if not repo:
135
143
raise Exception("Client not logged in!")
136
136
-
144
144
+
137
145
if not langs:
138
138
-
langs = ['en']
139
139
-
146
146
+
langs = ["en"]
147
147
+
140
148
record = models.AppBskyFeedPost.Record(
141
149
created_at=time_iso or self.get_current_time_iso(),
142
150
text=text,
···
144
152
embed=embed or None,
145
153
langs=langs,
146
154
facets=facets or None,
147
147
-
labels=labels or None
155
155
+
labels=labels or None,
148
156
)
149
157
return self.app.bsky.feed.post.create(repo, record)
150
150
-
151
151
-
def create_gates(self, thread_gate_opts: list[str], quote_gate: bool, post_uri: str, time_iso: str | None = None):
158
158
+
159
159
+
def create_gates(
160
160
+
self,
161
161
+
thread_gate_opts: list[str],
162
162
+
quote_gate: bool,
163
163
+
post_uri: str,
164
164
+
time_iso: str | None = None,
165
165
+
):
152
166
account = self.me
153
167
if not account:
154
168
raise Exception("Client not logged in!")
155
155
-
169
169
+
156
170
rkey = AtUri.from_str(post_uri).rkey
157
171
time_iso = time_iso or self.get_current_time_iso()
158
158
-
159
159
-
if 'everybody' not in thread_gate_opts:
172
172
+
173
173
+
if "everybody" not in thread_gate_opts:
160
174
allow = []
161
175
if thread_gate_opts:
162
162
-
if 'following' in thread_gate_opts:
176
176
+
if "following" in thread_gate_opts:
163
177
allow.append(models.AppBskyFeedThreadgate.FollowingRule())
164
164
-
if 'followers' in thread_gate_opts:
178
178
+
if "followers" in thread_gate_opts:
165
179
allow.append(models.AppBskyFeedThreadgate.FollowerRule())
166
166
-
if 'mentioned' in thread_gate_opts:
180
180
+
if "mentioned" in thread_gate_opts:
167
181
allow.append(models.AppBskyFeedThreadgate.MentionRule())
168
168
-
182
182
+
169
183
thread_gate = models.AppBskyFeedThreadgate.Record(
170
170
-
post=post_uri,
171
171
-
created_at=time_iso,
172
172
-
allow=allow
184
184
+
post=post_uri, created_at=time_iso, allow=allow
173
185
)
174
174
-
186
186
+
175
187
self.app.bsky.feed.threadgate.create(account.did, thread_gate, rkey)
176
176
-
188
188
+
177
189
if quote_gate:
178
190
post_gate = models.AppBskyFeedPostgate.Record(
179
191
post=post_uri,
180
192
created_at=time_iso,
181
181
-
embedding_rules=[
182
182
-
models.AppBskyFeedPostgate.DisableRule()
183
183
-
]
193
193
+
embedding_rules=[models.AppBskyFeedPostgate.DisableRule()],
184
194
)
185
185
-
186
186
-
self.app.bsky.feed.postgate.create(account.did, post_gate, rkey)
195
195
+
196
196
+
self.app.bsky.feed.postgate.create(account.did, post_gate, rkey)
+90
-71
bluesky/common.py
reviewed
···
1
1
-
import re, json
1
1
+
import re
2
2
3
3
from atproto import client_utils
4
4
···
7
7
from util.util import canonical_label
8
8
9
9
# only for lexicon reference
10
10
-
SERVICE = 'https://bsky.app'
10
10
+
SERVICE = "https://bsky.app"
11
11
12
12
# TODO this is terrible and stupid
13
13
-
ADULT_PATTERN = re.compile(r"\b(sexual content|nsfw|erotic|adult only|18\+)\b", re.IGNORECASE)
14
14
-
PORN_PATTERN = re.compile(r"\b(porn|yiff|hentai|pornographic|fetish)\b", re.IGNORECASE)
13
13
+
ADULT_PATTERN = re.compile(
14
14
+
r"\b(sexual content|nsfw|erotic|adult only|18\+)\b", re.IGNORECASE
15
15
+
)
16
16
+
PORN_PATTERN = re.compile(r"\b(porn|yiff|hentai|pornographic|fetish)\b", re.IGNORECASE)
17
17
+
15
18
16
19
class BlueskyPost(cross.Post):
17
17
-
def __init__(self, record: dict, tokens: list[cross.Token], attachments: list[MediaInfo]) -> None:
20
20
+
def __init__(
21
21
+
self, record: dict, tokens: list[cross.Token], attachments: list[MediaInfo]
22
22
+
) -> None:
18
23
super().__init__()
19
19
-
self.uri = record['$xpost.strongRef']['uri']
24
24
+
self.uri = record["$xpost.strongRef"]["uri"]
20
25
self.parent_uri = None
21
21
-
if record.get('reply'):
22
22
-
self.parent_uri = record['reply']['parent']['uri']
23
23
-
26
26
+
if record.get("reply"):
27
27
+
self.parent_uri = record["reply"]["parent"]["uri"]
28
28
+
24
29
self.tokens = tokens
25
25
-
self.timestamp = record['createdAt']
26
26
-
labels = record.get('labels', {}).get('values')
30
30
+
self.timestamp = record["createdAt"]
31
31
+
labels = record.get("labels", {}).get("values")
27
32
self.spoiler = None
28
33
if labels:
29
29
-
self.spoiler = ', '.join([str(label['val']).replace('-', ' ') for label in labels])
30
30
-
34
34
+
self.spoiler = ", ".join(
35
35
+
[str(label["val"]).replace("-", " ") for label in labels]
36
36
+
)
37
37
+
31
38
self.attachments = attachments
32
32
-
self.languages = record.get('langs', [])
33
33
-
39
39
+
self.languages = record.get("langs", [])
40
40
+
34
41
# at:// of the post record
35
42
def get_id(self) -> str:
36
43
return self.uri
37
37
-
44
44
+
38
45
def get_parent_id(self) -> str | None:
39
46
return self.parent_uri
40
40
-
47
47
+
41
48
def get_tokens(self) -> list[cross.Token]:
42
49
return self.tokens
43
43
-
50
50
+
44
51
def get_text_type(self) -> str:
45
52
return "text/plain"
46
46
-
53
53
+
47
54
def get_timestamp(self) -> str:
48
55
return self.timestamp
49
56
50
57
def get_attachments(self) -> list[MediaInfo]:
51
58
return self.attachments
52
52
-
59
59
+
53
60
def get_spoiler(self) -> str | None:
54
61
return self.spoiler
55
62
56
63
def get_languages(self) -> list[str]:
57
64
return self.languages
58
58
-
65
65
+
59
66
def is_sensitive(self) -> bool:
60
67
return self.spoiler is not None
61
68
62
69
def get_post_url(self) -> str | None:
63
63
-
did, _, post_id = str(self.uri[len("at://"):]).split("/")
64
64
-
70
70
+
did, _, post_id = str(self.uri[len("at://") :]).split("/")
71
71
+
65
72
return f"https://bsky.app/profile/{did}/post/{post_id}"
73
73
+
66
74
67
75
def tokenize_post(post: dict) -> list[cross.Token]:
68
68
-
text: str = post.get('text', '')
76
76
+
text: str = post.get("text", "")
69
77
if not text:
70
78
return []
71
71
-
ut8_text = text.encode(encoding='utf-8')
72
72
-
79
79
+
ut8_text = text.encode(encoding="utf-8")
80
80
+
73
81
def decode(ut8: bytes) -> str:
74
74
-
return ut8.decode(encoding='utf-8')
75
75
-
76
76
-
facets: list[dict] = post.get('facets', [])
82
82
+
return ut8.decode(encoding="utf-8")
83
83
+
84
84
+
facets: list[dict] = post.get("facets", [])
77
85
if not facets:
78
86
return [cross.TextToken(decode(ut8_text))]
79
79
-
87
87
+
80
88
slices: list[tuple[int, int, str, str]] = []
81
81
-
89
89
+
82
90
for facet in facets:
83
83
-
features: list[dict] = facet.get('features', [])
91
91
+
features: list[dict] = facet.get("features", [])
84
92
if not features:
85
93
continue
86
86
-
94
94
+
87
95
# we don't support overlapping facets/features
88
96
feature = features[0]
89
89
-
feature_type = feature['$type']
90
90
-
index = facet['index']
97
97
+
feature_type = feature["$type"]
98
98
+
index = facet["index"]
91
99
match feature_type:
92
92
-
case 'app.bsky.richtext.facet#tag':
93
93
-
slices.append((index['byteStart'], index['byteEnd'], 'tag', feature['tag']))
94
94
-
case 'app.bsky.richtext.facet#link':
95
95
-
slices.append((index['byteStart'], index['byteEnd'], 'link', feature['uri']))
96
96
-
case 'app.bsky.richtext.facet#mention':
97
97
-
slices.append((index['byteStart'], index['byteEnd'], 'mention', feature['did']))
98
98
-
100
100
+
case "app.bsky.richtext.facet#tag":
101
101
+
slices.append(
102
102
+
(index["byteStart"], index["byteEnd"], "tag", feature["tag"])
103
103
+
)
104
104
+
case "app.bsky.richtext.facet#link":
105
105
+
slices.append(
106
106
+
(index["byteStart"], index["byteEnd"], "link", feature["uri"])
107
107
+
)
108
108
+
case "app.bsky.richtext.facet#mention":
109
109
+
slices.append(
110
110
+
(index["byteStart"], index["byteEnd"], "mention", feature["did"])
111
111
+
)
112
112
+
99
113
if not slices:
100
114
return [cross.TextToken(decode(ut8_text))]
101
101
-
115
115
+
102
116
slices.sort(key=lambda s: s[0])
103
117
unique: list[tuple[int, int, str, str]] = []
104
118
current_end = 0
···
106
120
if start >= current_end:
107
121
unique.append((start, end, ttype, val))
108
122
current_end = end
109
109
-
123
123
+
110
124
if not unique:
111
125
return [cross.TextToken(decode(ut8_text))]
112
112
-
126
126
+
113
127
tokens: list[cross.Token] = []
114
128
prev = 0
115
115
-
129
129
+
116
130
for start, end, ttype, val in unique:
117
131
if start > prev:
118
132
# text between facets
119
133
tokens.append(cross.TextToken(decode(ut8_text[prev:start])))
120
134
# facet token
121
135
match ttype:
122
122
-
case 'link':
136
136
+
case "link":
123
137
label = decode(ut8_text[start:end])
124
124
-
138
138
+
125
139
# try to unflatten links
126
126
-
split = val.split('://', 1)
140
140
+
split = val.split("://", 1)
127
141
if len(split) > 1:
128
142
if split[1].startswith(label):
129
129
-
tokens.append(cross.LinkToken(val, ''))
143
143
+
tokens.append(cross.LinkToken(val, ""))
130
144
prev = end
131
145
continue
132
132
-
133
133
-
if label.endswith('...') and split[1].startswith(label[:-3]):
134
134
-
tokens.append(cross.LinkToken(val, ''))
146
146
+
147
147
+
if label.endswith("...") and split[1].startswith(label[:-3]):
148
148
+
tokens.append(cross.LinkToken(val, ""))
135
149
prev = end
136
136
-
continue
137
137
-
150
150
+
continue
151
151
+
138
152
tokens.append(cross.LinkToken(val, label))
139
139
-
case 'tag':
153
153
+
case "tag":
140
154
tag = decode(ut8_text[start:end])
141
141
-
tokens.append(cross.TagToken(tag[1:] if tag.startswith('#') else tag))
142
142
-
case 'mention':
155
155
+
tokens.append(cross.TagToken(tag[1:] if tag.startswith("#") else tag))
156
156
+
case "mention":
143
157
mention = decode(ut8_text[start:end])
144
144
-
tokens.append(cross.MentionToken(mention[1:] if mention.startswith('@') else mention, val))
158
158
+
tokens.append(
159
159
+
cross.MentionToken(
160
160
+
mention[1:] if mention.startswith("@") else mention, val
161
161
+
)
162
162
+
)
145
163
prev = end
146
164
147
165
if prev < len(ut8_text):
148
166
tokens.append(cross.TextToken(decode(ut8_text[prev:])))
149
149
-
150
150
-
return tokens
167
167
+
168
168
+
return tokens
169
169
+
151
170
152
171
def tokens_to_richtext(tokens: list[cross.Token]) -> client_utils.TextBuilder | None:
153
172
builder = client_utils.TextBuilder()
154
154
-
173
173
+
155
174
def flatten_link(href: str):
156
156
-
split = href.split('://', 1)
175
175
+
split = href.split("://", 1)
157
176
if len(split) > 1:
158
177
href = split[1]
159
159
-
178
178
+
160
179
if len(href) > 32:
161
161
-
href = href[:32] + '...'
162
162
-
180
180
+
href = href[:32] + "..."
181
181
+
163
182
return href
164
164
-
183
183
+
165
184
for token in tokens:
166
185
if isinstance(token, cross.TextToken):
167
186
builder.text(token.text)
···
169
188
if canonical_label(token.label, token.href):
170
189
builder.link(flatten_link(token.href), token.href)
171
190
continue
172
172
-
191
191
+
173
192
builder.link(token.label, token.href)
174
193
elif isinstance(token, cross.TagToken):
175
175
-
builder.tag('#' + token.tag, token.tag.lower())
194
194
+
builder.tag("#" + token.tag, token.tag.lower())
176
195
else:
177
196
# fail on unsupported tokens
178
197
return None
179
179
-
180
180
-
return builder
198
198
+
199
199
+
return builder
+105
-79
bluesky/input.py
reviewed
···
1
1
-
import re, json, websockets, asyncio
1
1
+
import asyncio
2
2
+
import json
3
3
+
import re
4
4
+
from typing import Any, Callable
2
5
6
6
+
import websockets
3
7
from atproto_client import models
4
8
from atproto_client.models.utils import get_or_create as get_model_or_create
9
9
+
10
10
+
import cross
11
11
+
import util.database as database
5
12
from bluesky.atproto2 import resolve_identity
6
6
-
7
7
-
from bluesky.common import BlueskyPost, SERVICE, tokenize_post
8
8
-
9
9
-
import cross, util.database as database
13
13
+
from bluesky.common import SERVICE, BlueskyPost, tokenize_post
14
14
+
from util.database import DataBaseWorker
15
15
+
from util.media import MediaInfo, download_media
10
16
from util.util import LOGGER, as_envvar
11
11
-
from util.media import MediaInfo, download_media
12
12
-
from util.database import DataBaseWorker
13
17
14
14
-
from typing import Callable, Any
15
18
16
16
-
class BlueskyInputOptions():
19
19
+
class BlueskyInputOptions:
17
20
def __init__(self, o: dict) -> None:
18
18
-
self.filters = [re.compile(f) for f in o.get('regex_filters', [])]
21
21
+
self.filters = [re.compile(f) for f in o.get("regex_filters", [])]
22
22
+
19
23
20
24
class BlueskyInput(cross.Input):
21
25
def __init__(self, settings: dict, db: DataBaseWorker) -> None:
22
22
-
self.options = BlueskyInputOptions(settings.get('options', {}))
26
26
+
self.options = BlueskyInputOptions(settings.get("options", {}))
23
27
did, pds = resolve_identity(
24
24
-
handle=as_envvar(settings.get('handle')),
25
25
-
did=as_envvar(settings.get('did')),
26
26
-
pds=as_envvar(settings.get('pds'))
28
28
+
handle=as_envvar(settings.get("handle")),
29
29
+
did=as_envvar(settings.get("did")),
30
30
+
pds=as_envvar(settings.get("pds")),
27
31
)
28
32
self.pds = pds
29
29
-
33
33
+
30
34
# PDS is Not a service, the lexicon and rids are the same across pds
31
35
super().__init__(SERVICE, did, settings, db)
32
32
-
36
36
+
33
37
def _on_post(self, outputs: list[cross.Output], post: dict[str, Any]):
34
34
-
post_uri = post['$xpost.strongRef']['uri']
35
35
-
post_cid = post['$xpost.strongRef']['cid']
36
36
-
38
38
+
post_uri = post["$xpost.strongRef"]["uri"]
39
39
+
post_cid = post["$xpost.strongRef"]["cid"]
40
40
+
37
41
parent_uri = None
38
38
-
if post.get('reply'):
39
39
-
parent_uri = post['reply']['parent']['uri']
40
40
-
41
41
-
embed = post.get('embed', {})
42
42
-
if embed.get('$type') in ('app.bsky.embed.record', 'app.bsky.embed.recordWithMedia'):
43
43
-
did, collection, rid = str(embed['record']['uri'][len('at://'):]).split('/')
44
44
-
if collection == 'app.bsky.feed.post':
42
42
+
if post.get("reply"):
43
43
+
parent_uri = post["reply"]["parent"]["uri"]
44
44
+
45
45
+
embed = post.get("embed", {})
46
46
+
if embed.get("$type") in (
47
47
+
"app.bsky.embed.record",
48
48
+
"app.bsky.embed.recordWithMedia",
49
49
+
):
50
50
+
did, collection, rid = str(embed["record"]["uri"][len("at://") :]).split(
51
51
+
"/"
52
52
+
)
53
53
+
if collection == "app.bsky.feed.post":
45
54
LOGGER.info("Skipping '%s'! Quote..", post_uri)
46
55
return
47
47
-
48
48
-
success = database.try_insert_post(self.db, post_uri, parent_uri, self.user_id, self.service)
56
56
+
57
57
+
success = database.try_insert_post(
58
58
+
self.db, post_uri, parent_uri, self.user_id, self.service
59
59
+
)
49
60
if not success:
50
61
LOGGER.info("Skipping '%s' as parent post was not found in db!", post_uri)
51
62
return
52
52
-
database.store_data(self.db, post_uri, self.user_id, self.service, {'cid': post_cid})
53
53
-
63
63
+
database.store_data(
64
64
+
self.db, post_uri, self.user_id, self.service, {"cid": post_cid}
65
65
+
)
66
66
+
54
67
tokens = tokenize_post(post)
55
68
if not cross.test_filters(tokens, self.options.filters):
56
69
LOGGER.info("Skipping '%s'. Matched a filter!", post_uri)
57
70
return
58
58
-
71
71
+
59
72
LOGGER.info("Crossposting '%s'...", post_uri)
60
60
-
73
73
+
61
74
def get_blob_url(blob: str):
62
62
-
return f'{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.user_id}&cid={blob}'
63
63
-
75
75
+
return f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.user_id}&cid={blob}"
76
76
+
64
77
attachments: list[MediaInfo] = []
65
65
-
if embed.get('$type') == 'app.bsky.embed.images':
78
78
+
if embed.get("$type") == "app.bsky.embed.images":
66
79
model = get_model_or_create(embed, model=models.AppBskyEmbedImages.Main)
67
80
assert isinstance(model, models.AppBskyEmbedImages.Main)
68
68
-
81
81
+
69
82
for image in model.images:
70
83
url = get_blob_url(image.image.cid.encode())
71
84
LOGGER.info("Downloading %s...", url)
···
74
87
LOGGER.error("Skipping '%s'. Failed to download media!", post_uri)
75
88
return
76
89
attachments.append(io)
77
77
-
elif embed.get('$type') == 'app.bsky.embed.video':
90
90
+
elif embed.get("$type") == "app.bsky.embed.video":
78
91
model = get_model_or_create(embed, model=models.AppBskyEmbedVideo.Main)
79
92
assert isinstance(model, models.AppBskyEmbedVideo.Main)
80
93
url = get_blob_url(model.video.cid.encode())
81
94
LOGGER.info("Downloading %s...", url)
82
82
-
io = download_media(url, model.alt if model.alt else '')
95
95
+
io = download_media(url, model.alt if model.alt else "")
83
96
if not io:
84
97
LOGGER.error("Skipping '%s'. Failed to download media!", post_uri)
85
98
return
86
99
attachments.append(io)
87
87
-
100
100
+
88
101
cross_post = BlueskyPost(post, tokens, attachments)
89
102
for output in outputs:
90
103
output.accept_post(cross_post)
···
93
106
post = database.find_post(self.db, post_id, self.user_id, self.service)
94
107
if not post:
95
108
return
96
96
-
109
109
+
97
110
LOGGER.info("Deleting '%s'...", post_id)
98
111
if repost:
99
112
for output in outputs:
···
102
115
for output in outputs:
103
116
output.delete_post(post_id)
104
117
database.delete_post(self.db, post_id, self.user_id, self.service)
105
105
-
118
118
+
106
119
def _on_repost(self, outputs: list[cross.Output], post: dict[str, Any]):
107
107
-
post_uri = post['$xpost.strongRef']['uri']
108
108
-
post_cid = post['$xpost.strongRef']['cid']
109
109
-
110
110
-
reposted_uri = post['subject']['uri']
111
111
-
112
112
-
success = database.try_insert_repost(self.db, post_uri, reposted_uri, self.user_id, self.service)
120
120
+
post_uri = post["$xpost.strongRef"]["uri"]
121
121
+
post_cid = post["$xpost.strongRef"]["cid"]
122
122
+
123
123
+
reposted_uri = post["subject"]["uri"]
124
124
+
125
125
+
success = database.try_insert_repost(
126
126
+
self.db, post_uri, reposted_uri, self.user_id, self.service
127
127
+
)
113
128
if not success:
114
129
LOGGER.info("Skipping '%s' as reposted post was not found in db!", post_uri)
115
130
return
116
116
-
database.store_data(self.db, post_uri, self.user_id, self.service, {'cid': post_cid})
117
117
-
131
131
+
database.store_data(
132
132
+
self.db, post_uri, self.user_id, self.service, {"cid": post_cid}
133
133
+
)
134
134
+
118
135
LOGGER.info("Crossposting '%s'...", post_uri)
119
136
for output in outputs:
120
137
output.accept_repost(post_uri, reposted_uri)
121
138
139
139
+
122
140
class BlueskyJetstreamInput(BlueskyInput):
123
141
def __init__(self, settings: dict, db: DataBaseWorker) -> None:
124
142
super().__init__(settings, db)
125
125
-
self.jetstream = settings.get("jetstream", "wss://jetstream2.us-east.bsky.network/subscribe")
126
126
-
143
143
+
self.jetstream = settings.get(
144
144
+
"jetstream", "wss://jetstream2.us-east.bsky.network/subscribe"
145
145
+
)
146
146
+
127
147
def __on_commit(self, outputs: list[cross.Output], msg: dict):
128
128
-
if msg.get('did') != self.user_id:
148
148
+
if msg.get("did") != self.user_id:
129
149
return
130
130
-
131
131
-
commit: dict = msg.get('commit', {})
150
150
+
151
151
+
commit: dict = msg.get("commit", {})
132
152
if not commit:
133
153
return
134
134
-
135
135
-
commit_type = commit['operation']
154
154
+
155
155
+
commit_type = commit["operation"]
136
156
match commit_type:
137
137
-
case 'create':
138
138
-
record = dict(commit.get('record', {}))
139
139
-
record['$xpost.strongRef'] = {
140
140
-
'cid': commit['cid'],
141
141
-
'uri': f'at://{self.user_id}/{commit['collection']}/{commit['rkey']}'
157
157
+
case "create":
158
158
+
record = dict(commit.get("record", {}))
159
159
+
record["$xpost.strongRef"] = {
160
160
+
"cid": commit["cid"],
161
161
+
"uri": f"at://{self.user_id}/{commit['collection']}/{commit['rkey']}",
142
162
}
143
143
-
144
144
-
match commit['collection']:
145
145
-
case 'app.bsky.feed.post':
163
163
+
164
164
+
match commit["collection"]:
165
165
+
case "app.bsky.feed.post":
146
166
self._on_post(outputs, record)
147
147
-
case 'app.bsky.feed.repost':
167
167
+
case "app.bsky.feed.repost":
148
168
self._on_repost(outputs, record)
149
149
-
case 'delete':
150
150
-
post_id: str = f'at://{self.user_id}/{commit['collection']}/{commit['rkey']}'
151
151
-
match commit['collection']:
152
152
-
case 'app.bsky.feed.post':
169
169
+
case "delete":
170
170
+
post_id: str = (
171
171
+
f"at://{self.user_id}/{commit['collection']}/{commit['rkey']}"
172
172
+
)
173
173
+
match commit["collection"]:
174
174
+
case "app.bsky.feed.post":
153
175
self._on_delete_post(outputs, post_id, False)
154
154
-
case 'app.bsky.feed.repost':
176
176
+
case "app.bsky.feed.repost":
155
177
self._on_delete_post(outputs, post_id, True)
156
156
-
157
157
-
async def listen(self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]):
158
158
-
uri = self.jetstream + '?'
178
178
+
179
179
+
async def listen(
180
180
+
self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]
181
181
+
):
182
182
+
uri = self.jetstream + "?"
159
183
uri += "wantedCollections=app.bsky.feed.post"
160
184
uri += "&wantedCollections=app.bsky.feed.repost"
161
185
uri += f"&wantedDids={self.user_id}"
162
162
-
163
163
-
async for ws in websockets.connect(uri, extra_headers={"User-Agent": "XPost/0.0.3"}):
186
186
+
187
187
+
async for ws in websockets.connect(
188
188
+
uri, extra_headers={"User-Agent": "XPost/0.0.3"}
189
189
+
):
164
190
try:
165
191
LOGGER.info("Listening to %s...", self.jetstream)
166
166
-
192
192
+
167
193
async def listen_for_messages():
168
194
async for msg in ws:
169
195
submit(lambda: self.__on_commit(outputs, json.loads(msg)))
170
170
-
196
196
+
171
197
listen = asyncio.create_task(listen_for_messages())
172
172
-
198
198
+
173
199
await asyncio.gather(listen)
174
200
except websockets.ConnectionClosedError as e:
175
201
LOGGER.error(e, stack_info=True, exc_info=True)
176
202
LOGGER.info("Reconnecting to %s...", self.jetstream)
177
177
-
continue
203
203
+
continue
+238
-209
bluesky/output.py
reviewed
···
1
1
-
import json
2
2
-
from httpx import Timeout
3
3
-
4
4
-
from atproto import client_utils, Request
1
1
+
from atproto import Request, client_utils
5
2
from atproto_client import models
6
6
-
from bluesky.atproto2 import Client2, resolve_identity
7
7
-
8
8
-
from bluesky.common import SERVICE, ADULT_PATTERN, PORN_PATTERN, tokens_to_richtext
3
3
+
from httpx import Timeout
9
4
10
10
-
import cross, util.database as database
5
5
+
import cross
11
6
import misskey.mfm_util as mfm_util
12
12
-
from util.util import LOGGER, as_envvar
13
13
-
from util.media import MediaInfo, get_filename_from_url, get_media_meta, compress_image, convert_to_mp4
7
7
+
import util.database as database
8
8
+
from bluesky.atproto2 import Client2, resolve_identity
9
9
+
from bluesky.common import ADULT_PATTERN, PORN_PATTERN, SERVICE, tokens_to_richtext
14
10
from util.database import DataBaseWorker
11
11
+
from util.media import (
12
12
+
MediaInfo,
13
13
+
compress_image,
14
14
+
convert_to_mp4,
15
15
+
get_filename_from_url,
16
16
+
get_media_meta,
17
17
+
)
18
18
+
from util.util import LOGGER, as_envvar
15
19
16
16
-
ALLOWED_GATES = ['mentioned', 'following', 'followers', 'everybody']
20
20
+
ALLOWED_GATES = ["mentioned", "following", "followers", "everybody"]
21
21
+
17
22
18
23
class BlueskyOutputOptions:
19
24
def __init__(self, o: dict) -> None:
20
25
self.quote_gate: bool = False
21
21
-
self.thread_gate: list[str] = ['everybody']
26
26
+
self.thread_gate: list[str] = ["everybody"]
22
27
self.encode_videos: bool = True
23
23
-
24
24
-
quote_gate = o.get('quote_gate')
28
28
+
29
29
+
quote_gate = o.get("quote_gate")
25
30
if quote_gate is not None:
26
31
self.quote_gate = bool(quote_gate)
27
27
-
28
28
-
thread_gate = o.get('thread_gate')
32
32
+
33
33
+
thread_gate = o.get("thread_gate")
29
34
if thread_gate is not None:
30
35
if any([v not in ALLOWED_GATES for v in thread_gate]):
31
31
-
raise ValueError(f"'thread_gate' only accepts {', '.join(ALLOWED_GATES)} or [], got: {thread_gate}")
36
36
+
raise ValueError(
37
37
+
f"'thread_gate' only accepts {', '.join(ALLOWED_GATES)} or [], got: {thread_gate}"
38
38
+
)
32
39
self.thread_gate = thread_gate
33
33
-
34
34
-
encode_videos = o.get('encode_videos')
40
40
+
41
41
+
encode_videos = o.get("encode_videos")
35
42
if encode_videos is not None:
36
43
self.encode_videos = bool(encode_videos)
44
44
+
37
45
38
46
class BlueskyOutput(cross.Output):
39
47
def __init__(self, input: cross.Input, settings: dict, db: DataBaseWorker) -> None:
40
48
super().__init__(input, settings, db)
41
41
-
self.options = BlueskyOutputOptions(settings.get('options') or {})
42
42
-
43
43
-
if not as_envvar(settings.get('app-password')):
49
49
+
self.options = BlueskyOutputOptions(settings.get("options") or {})
50
50
+
51
51
+
if not as_envvar(settings.get("app-password")):
44
52
raise Exception("Account app password not provided!")
45
45
-
53
53
+
46
54
did, pds = resolve_identity(
47
47
-
handle=as_envvar(settings.get('handle')),
48
48
-
did=as_envvar(settings.get('did')),
49
49
-
pds=as_envvar(settings.get('pds'))
55
55
+
handle=as_envvar(settings.get("handle")),
56
56
+
did=as_envvar(settings.get("did")),
57
57
+
pds=as_envvar(settings.get("pds")),
50
58
)
51
51
-
59
59
+
52
60
reqs = Request(timeout=Timeout(None, connect=30.0))
53
53
-
61
61
+
54
62
self.bsky = Client2(pds, request=reqs)
55
63
self.bsky.configure_proxy_header(
56
56
-
service_type='bsky_appview',
57
57
-
did=as_envvar(settings.get('bsky_appview')) or 'did:web:api.bsky.app'
64
64
+
service_type="bsky_appview",
65
65
+
did=as_envvar(settings.get("bsky_appview")) or "did:web:api.bsky.app",
58
66
)
59
59
-
self.bsky.login(did, as_envvar(settings.get('app-password')))
60
60
-
67
67
+
self.bsky.login(did, as_envvar(settings.get("app-password")))
68
68
+
61
69
def __check_login(self):
62
70
login = self.bsky.me
63
71
if not login:
64
72
raise Exception("Client not logged in!")
65
73
return login
66
66
-
74
74
+
67
75
def _find_parent(self, parent_id: str):
68
76
login = self.__check_login()
69
69
-
77
77
+
70
78
thread_tuple = database.find_mapped_thread(
71
79
self.db,
72
80
parent_id,
73
81
self.input.user_id,
74
82
self.input.service,
75
83
login.did,
76
76
-
SERVICE
84
84
+
SERVICE,
77
85
)
78
78
-
86
86
+
79
87
if not thread_tuple:
80
88
LOGGER.error("Failed to find thread tuple in the database!")
81
89
return None
82
82
-
90
90
+
83
91
root_uri: str = thread_tuple[0]
84
92
reply_uri: str = thread_tuple[1]
85
85
-
86
86
-
root_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)['cid']
87
87
-
reply_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)['cid']
88
88
-
89
89
-
root_record = models.AppBskyFeedPost.CreateRecordResponse(uri=root_uri, cid=root_cid)
90
90
-
reply_record = models.AppBskyFeedPost.CreateRecordResponse(uri=reply_uri, cid=reply_cid)
91
91
-
93
93
+
94
94
+
root_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)["cid"]
95
95
+
reply_cid = database.fetch_data(self.db, root_uri, login.did, SERVICE)["cid"]
96
96
+
97
97
+
root_record = models.AppBskyFeedPost.CreateRecordResponse(
98
98
+
uri=root_uri, cid=root_cid
99
99
+
)
100
100
+
reply_record = models.AppBskyFeedPost.CreateRecordResponse(
101
101
+
uri=reply_uri, cid=reply_cid
102
102
+
)
103
103
+
92
104
return (
93
105
models.create_strong_ref(root_record),
94
106
models.create_strong_ref(reply_record),
95
107
thread_tuple[2],
96
96
-
thread_tuple[3]
108
108
+
thread_tuple[3],
97
109
)
98
98
-
110
110
+
99
111
def _split_attachments(self, attachments: list[MediaInfo]):
100
112
sup_media: list[MediaInfo] = []
101
113
unsup_media: list[MediaInfo] = []
102
102
-
114
114
+
103
115
for a in attachments:
104
104
-
if a.mime.startswith('image/') or a.mime.startswith('video/'): # TODO convert gifs to videos
116
116
+
if a.mime.startswith("image/") or a.mime.startswith(
117
117
+
"video/"
118
118
+
): # TODO convert gifs to videos
105
119
sup_media.append(a)
106
120
else:
107
121
unsup_media.append(a)
108
108
-
122
122
+
109
123
return (sup_media, unsup_media)
110
124
111
125
def _split_media_per_post(
112
112
-
self,
113
113
-
tokens: list[client_utils.TextBuilder],
114
114
-
media: list[MediaInfo]):
115
115
-
126
126
+
self, tokens: list[client_utils.TextBuilder], media: list[MediaInfo]
127
127
+
):
116
128
posts: list[dict] = [{"tokens": tokens, "attachments": []} for tokens in tokens]
117
129
available_indices: list[int] = list(range(len(posts)))
118
118
-
130
130
+
119
131
current_image_post_idx: int | None = None
120
132
121
133
def make_blank_post() -> dict:
122
122
-
return {
123
123
-
"tokens": [client_utils.TextBuilder().text('')],
124
124
-
"attachments": []
125
125
-
}
126
126
-
134
134
+
return {"tokens": [client_utils.TextBuilder().text("")], "attachments": []}
135
135
+
127
136
def pop_next_empty_index() -> int:
128
137
if available_indices:
129
138
return available_indices.pop(0)
···
131
140
new_idx = len(posts)
132
141
posts.append(make_blank_post())
133
142
return new_idx
134
134
-
143
143
+
135
144
for att in media:
136
136
-
if att.mime.startswith('video/'):
145
145
+
if att.mime.startswith("video/"):
137
146
current_image_post_idx = None
138
147
idx = pop_next_empty_index()
139
148
posts[idx]["attachments"].append(att)
140
140
-
elif att.mime.startswith('image/'):
149
149
+
elif att.mime.startswith("image/"):
141
150
if (
142
151
current_image_post_idx is not None
143
152
and len(posts[current_image_post_idx]["attachments"]) < 4
···
147
156
idx = pop_next_empty_index()
148
157
posts[idx]["attachments"].append(att)
149
158
current_image_post_idx = idx
150
150
-
159
159
+
151
160
result: list[tuple[client_utils.TextBuilder, list[MediaInfo]]] = []
152
161
for p in posts:
153
162
result.append((p["tokens"], p["attachments"]))
154
163
return result
155
155
-
164
164
+
156
165
def accept_post(self, post: cross.Post):
157
166
login = self.__check_login()
158
158
-
167
167
+
159
168
parent_id = post.get_parent_id()
160
160
-
169
169
+
161
170
# used for db insertion
162
171
new_root_id = None
163
172
new_parent_id = None
164
164
-
173
173
+
165
174
root_ref = None
166
175
reply_ref = None
167
176
if parent_id:
···
169
178
if not parents:
170
179
return
171
180
root_ref, reply_ref, new_root_id, new_parent_id = parents
172
172
-
181
181
+
173
182
tokens = post.get_tokens().copy()
174
174
-
183
183
+
175
184
unique_labels: set[str] = set()
176
185
cw = post.get_spoiler()
177
186
if cw:
178
187
tokens.insert(0, cross.TextToken("CW: " + cw + "\n\n"))
179
179
-
unique_labels.add('graphic-media')
180
180
-
188
188
+
unique_labels.add("graphic-media")
189
189
+
181
190
# from bsky.app, a post can only have one of those labels
182
191
if PORN_PATTERN.search(cw):
183
183
-
unique_labels.add('porn')
192
192
+
unique_labels.add("porn")
184
193
elif ADULT_PATTERN.search(cw):
185
185
-
unique_labels.add('sexual')
186
186
-
194
194
+
unique_labels.add("sexual")
195
195
+
187
196
if post.is_sensitive():
188
188
-
unique_labels.add('graphic-media')
189
189
-
190
190
-
labels = models.ComAtprotoLabelDefs.SelfLabels(values=[models.ComAtprotoLabelDefs.SelfLabel(val=label) for label in unique_labels])
197
197
+
unique_labels.add("graphic-media")
198
198
+
199
199
+
labels = models.ComAtprotoLabelDefs.SelfLabels(
200
200
+
values=[
201
201
+
models.ComAtprotoLabelDefs.SelfLabel(val=label)
202
202
+
for label in unique_labels
203
203
+
]
204
204
+
)
191
205
192
206
sup_media, unsup_media = self._split_attachments(post.get_attachments())
193
207
194
208
if unsup_media:
195
209
if tokens:
196
196
-
tokens.append(cross.TextToken('\n'))
210
210
+
tokens.append(cross.TextToken("\n"))
197
211
for i, attachment in enumerate(unsup_media):
198
198
-
tokens.append(cross.LinkToken(
199
199
-
attachment.url,
200
200
-
f"[{get_filename_from_url(attachment.url)}]"
201
201
-
))
202
202
-
tokens.append(cross.TextToken(' '))
203
203
-
212
212
+
tokens.append(
213
213
+
cross.LinkToken(
214
214
+
attachment.url, f"[{get_filename_from_url(attachment.url)}]"
215
215
+
)
216
216
+
)
217
217
+
tokens.append(cross.TextToken(" "))
218
218
+
204
219
if post.get_text_type() == "text/x.misskeymarkdown":
205
220
tokens, status = mfm_util.strip_mfm(tokens)
206
221
post_url = post.get_post_url()
207
222
if status and post_url:
208
208
-
tokens.append(cross.TextToken('\n'))
209
209
-
tokens.append(cross.LinkToken(post_url, "[Post contains MFM, see original]"))
210
210
-
223
223
+
tokens.append(cross.TextToken("\n"))
224
224
+
tokens.append(
225
225
+
cross.LinkToken(post_url, "[Post contains MFM, see original]")
226
226
+
)
227
227
+
211
228
split_tokens: list[list[cross.Token]] = cross.split_tokens(tokens, 300)
212
229
post_text: list[client_utils.TextBuilder] = []
213
213
-
230
230
+
214
231
# convert tokens into rich text. skip post if contains unsupported tokens
215
232
for block in split_tokens:
216
233
rich_text = tokens_to_richtext(block)
217
217
-
234
234
+
218
235
if not rich_text:
219
219
-
LOGGER.error("Skipping '%s' as it contains invalid rich text types!", post.get_id())
236
236
+
LOGGER.error(
237
237
+
"Skipping '%s' as it contains invalid rich text types!",
238
238
+
post.get_id(),
239
239
+
)
220
240
return
221
241
post_text.append(rich_text)
222
222
-
242
242
+
223
243
if not post_text:
224
224
-
post_text = [client_utils.TextBuilder().text('')]
225
225
-
244
244
+
post_text = [client_utils.TextBuilder().text("")]
245
245
+
226
246
for m in sup_media:
227
227
-
if m.mime.startswith('image/'):
247
247
+
if m.mime.startswith("image/"):
228
248
if len(m.io) > 2_000_000:
229
229
-
LOGGER.error("Skipping post_id '%s', failed to download attachment! File too large.", post.get_id())
249
249
+
LOGGER.error(
250
250
+
"Skipping post_id '%s', failed to download attachment! File too large.",
251
251
+
post.get_id(),
252
252
+
)
230
253
return
231
231
-
232
232
-
if m.mime.startswith('video/'):
233
233
-
if m.mime != 'video/mp4' and not self.options.encode_videos:
234
234
-
LOGGER.info("Video is not mp4, but encoding is disabled. Skipping '%s'...", post.get_id())
254
254
+
255
255
+
if m.mime.startswith("video/"):
256
256
+
if m.mime != "video/mp4" and not self.options.encode_videos:
257
257
+
LOGGER.info(
258
258
+
"Video is not mp4, but encoding is disabled. Skipping '%s'...",
259
259
+
post.get_id(),
260
260
+
)
235
261
return
236
236
-
262
262
+
237
263
if len(m.io) > 100_000_000:
238
238
-
LOGGER.error("Skipping post_id '%s', failed to download attachment! File too large?", post.get_id())
264
264
+
LOGGER.error(
265
265
+
"Skipping post_id '%s', failed to download attachment! File too large?",
266
266
+
post.get_id(),
267
267
+
)
239
268
return
240
240
-
269
269
+
241
270
created_records: list[models.AppBskyFeedPost.CreateRecordResponse] = []
242
271
baked_media = self._split_media_per_post(post_text, sup_media)
243
243
-
272
272
+
244
273
for text, attachments in baked_media:
245
274
if not attachments:
246
275
if reply_ref and root_ref:
247
247
-
new_post = self.bsky.send_post(text, reply_to=models.AppBskyFeedPost.ReplyRef(
248
248
-
parent=reply_ref,
249
249
-
root=root_ref
250
250
-
), labels=labels, time_iso=post.get_timestamp())
276
276
+
new_post = self.bsky.send_post(
277
277
+
text,
278
278
+
reply_to=models.AppBskyFeedPost.ReplyRef(
279
279
+
parent=reply_ref, root=root_ref
280
280
+
),
281
281
+
labels=labels,
282
282
+
time_iso=post.get_timestamp(),
283
283
+
)
251
284
else:
252
252
-
new_post = self.bsky.send_post(text, labels=labels, time_iso=post.get_timestamp())
285
285
+
new_post = self.bsky.send_post(
286
286
+
text, labels=labels, time_iso=post.get_timestamp()
287
287
+
)
253
288
root_ref = models.create_strong_ref(new_post)
254
254
-
289
289
+
255
290
self.bsky.create_gates(
256
256
-
self.options.thread_gate,
257
257
-
self.options.quote_gate,
258
258
-
new_post.uri,
259
259
-
time_iso=post.get_timestamp()
291
291
+
self.options.thread_gate,
292
292
+
self.options.quote_gate,
293
293
+
new_post.uri,
294
294
+
time_iso=post.get_timestamp(),
260
295
)
261
296
reply_ref = models.create_strong_ref(new_post)
262
297
created_records.append(new_post)
263
298
else:
264
299
# if a single post is an image - everything else is an image
265
265
-
if attachments[0].mime.startswith('image/'):
300
300
+
if attachments[0].mime.startswith("image/"):
266
301
images: list[bytes] = []
267
302
image_alts: list[str] = []
268
303
image_aspect_ratios: list[models.AppBskyEmbedDefs.AspectRatio] = []
269
269
-
304
304
+
270
305
for attachment in attachments:
271
306
image_io = compress_image(attachment.io, quality=100)
272
307
metadata = get_media_meta(image_io)
273
273
-
308
308
+
274
309
if len(image_io) > 1_000_000:
275
310
LOGGER.info("Compressing %s...", attachment.name)
276
311
image_io = compress_image(image_io)
277
277
-
312
312
+
278
313
images.append(image_io)
279
314
image_alts.append(attachment.alt)
280
280
-
image_aspect_ratios.append(models.AppBskyEmbedDefs.AspectRatio(
281
281
-
width=metadata['width'],
282
282
-
height=metadata['height']
283
283
-
))
284
284
-
315
315
+
image_aspect_ratios.append(
316
316
+
models.AppBskyEmbedDefs.AspectRatio(
317
317
+
width=metadata["width"], height=metadata["height"]
318
318
+
)
319
319
+
)
320
320
+
285
321
new_post = self.bsky.send_images(
286
322
text=post_text[0],
287
323
images=images,
288
324
image_alts=image_alts,
289
325
image_aspect_ratios=image_aspect_ratios,
290
290
-
reply_to= models.AppBskyFeedPost.ReplyRef(
291
291
-
parent=reply_ref,
292
292
-
root=root_ref
293
293
-
) if root_ref and reply_ref else None,
294
294
-
labels=labels,
295
295
-
time_iso=post.get_timestamp()
326
326
+
reply_to=models.AppBskyFeedPost.ReplyRef(
327
327
+
parent=reply_ref, root=root_ref
328
328
+
)
329
329
+
if root_ref and reply_ref
330
330
+
else None,
331
331
+
labels=labels,
332
332
+
time_iso=post.get_timestamp(),
296
333
)
297
334
if not root_ref:
298
335
root_ref = models.create_strong_ref(new_post)
299
299
-
336
336
+
300
337
self.bsky.create_gates(
301
301
-
self.options.thread_gate,
338
338
+
self.options.thread_gate,
302
339
self.options.quote_gate,
303
303
-
new_post.uri,
304
304
-
time_iso=post.get_timestamp()
340
340
+
new_post.uri,
341
341
+
time_iso=post.get_timestamp(),
305
342
)
306
343
reply_ref = models.create_strong_ref(new_post)
307
344
created_records.append(new_post)
308
308
-
else: # video is guarantedd to be one
345
345
+
else: # video is guarantedd to be one
309
346
metadata = get_media_meta(attachments[0].io)
310
310
-
if metadata['duration'] > 180:
311
311
-
LOGGER.info("Skipping post_id '%s', video attachment too long!", post.get_id())
347
347
+
if metadata["duration"] > 180:
348
348
+
LOGGER.info(
349
349
+
"Skipping post_id '%s', video attachment too long!",
350
350
+
post.get_id(),
351
351
+
)
312
352
return
313
313
-
353
353
+
314
354
video_io = attachments[0].io
315
315
-
if attachments[0].mime != 'video/mp4':
355
355
+
if attachments[0].mime != "video/mp4":
316
356
LOGGER.info("Converting %s to mp4...", attachments[0].name)
317
357
video_io = convert_to_mp4(video_io)
318
318
-
358
358
+
319
359
aspect_ratio = models.AppBskyEmbedDefs.AspectRatio(
320
320
-
width=metadata['width'],
321
321
-
height=metadata['height']
360
360
+
width=metadata["width"], height=metadata["height"]
322
361
)
323
323
-
362
362
+
324
363
new_post = self.bsky.send_video(
325
364
text=post_text[0],
326
365
video=video_io,
327
366
video_aspect_ratio=aspect_ratio,
328
367
video_alt=attachments[0].alt,
329
329
-
reply_to= models.AppBskyFeedPost.ReplyRef(
330
330
-
parent=reply_ref,
331
331
-
root=root_ref
332
332
-
) if root_ref and reply_ref else None,
368
368
+
reply_to=models.AppBskyFeedPost.ReplyRef(
369
369
+
parent=reply_ref, root=root_ref
370
370
+
)
371
371
+
if root_ref and reply_ref
372
372
+
else None,
333
373
labels=labels,
334
334
-
time_iso=post.get_timestamp()
374
374
+
time_iso=post.get_timestamp(),
335
375
)
336
376
if not root_ref:
337
377
root_ref = models.create_strong_ref(new_post)
338
338
-
378
378
+
339
379
self.bsky.create_gates(
340
380
self.options.thread_gate,
341
341
-
self.options.quote_gate,
342
342
-
new_post.uri,
343
343
-
time_iso=post.get_timestamp()
381
381
+
self.options.quote_gate,
382
382
+
new_post.uri,
383
383
+
time_iso=post.get_timestamp(),
344
384
)
345
385
reply_ref = models.create_strong_ref(new_post)
346
386
created_records.append(new_post)
347
347
-
348
348
-
db_post = database.find_post(self.db, post.get_id(), self.input.user_id, self.input.service)
387
387
+
388
388
+
db_post = database.find_post(
389
389
+
self.db, post.get_id(), self.input.user_id, self.input.service
390
390
+
)
349
391
assert db_post, "ghghghhhhh"
350
350
-
351
351
-
if new_root_id is None or new_parent_id is None:
392
392
+
393
393
+
if new_root_id is None or new_parent_id is None:
352
394
new_root_id = database.insert_post(
395
395
+
self.db, created_records[0].uri, login.did, SERVICE
396
396
+
)
397
397
+
database.store_data(
353
398
self.db,
354
399
created_records[0].uri,
355
400
login.did,
356
356
-
SERVICE
357
357
-
)
358
358
-
database.store_data(
359
359
-
self.db,
360
360
-
created_records[0].uri,
361
361
-
login.did,
362
401
SERVICE,
363
363
-
{'cid': created_records[0].cid}
402
402
+
{"cid": created_records[0].cid},
364
403
)
365
365
-
404
404
+
366
405
new_parent_id = new_root_id
367
367
-
database.insert_mapping(self.db, db_post['id'], new_parent_id)
406
406
+
database.insert_mapping(self.db, db_post["id"], new_parent_id)
368
407
created_records = created_records[1:]
369
369
-
408
408
+
370
409
for record in created_records:
371
410
new_parent_id = database.insert_reply(
372
372
-
self.db,
373
373
-
record.uri,
374
374
-
login.did,
375
375
-
SERVICE,
376
376
-
new_parent_id,
377
377
-
new_root_id
411
411
+
self.db, record.uri, login.did, SERVICE, new_parent_id, new_root_id
378
412
)
379
413
database.store_data(
380
380
-
self.db,
381
381
-
record.uri,
382
382
-
login.did,
383
383
-
SERVICE,
384
384
-
{'cid': record.cid}
414
414
+
self.db, record.uri, login.did, SERVICE, {"cid": record.cid}
385
415
)
386
386
-
database.insert_mapping(self.db, db_post['id'], new_parent_id)
387
387
-
416
416
+
database.insert_mapping(self.db, db_post["id"], new_parent_id)
417
417
+
388
418
def delete_post(self, identifier: str):
389
419
login = self.__check_login()
390
390
-
391
391
-
post = database.find_post(self.db, identifier, self.input.user_id, self.input.service)
420
420
+
421
421
+
post = database.find_post(
422
422
+
self.db, identifier, self.input.user_id, self.input.service
423
423
+
)
392
424
if not post:
393
425
return
394
394
-
395
395
-
mappings = database.find_mappings(self.db, post['id'], SERVICE, login.did)
426
426
+
427
427
+
mappings = database.find_mappings(self.db, post["id"], SERVICE, login.did)
396
428
for mapping in mappings[::-1]:
397
429
LOGGER.info("Deleting '%s'...", mapping[0])
398
430
self.bsky.delete_post(mapping[0])
399
431
database.delete_post(self.db, mapping[0], SERVICE, login.did)
400
400
-
432
432
+
401
433
def accept_repost(self, repost_id: str, reposted_id: str):
402
434
login, repost = self.__delete_repost(repost_id)
403
435
if not (login and repost):
404
436
return
405
405
-
406
406
-
reposted = database.find_post(self.db, reposted_id, self.input.user_id, self.input.service)
437
437
+
438
438
+
reposted = database.find_post(
439
439
+
self.db, reposted_id, self.input.user_id, self.input.service
440
440
+
)
407
441
if not reposted:
408
442
return
409
409
-
443
443
+
410
444
# mappings of the reposted post
411
411
-
mappings = database.find_mappings(self.db, reposted['id'], SERVICE, login.did)
445
445
+
mappings = database.find_mappings(self.db, reposted["id"], SERVICE, login.did)
412
446
if mappings:
413
413
-
cid = database.fetch_data(self.db, mappings[0][0], login.did, SERVICE)['cid']
447
447
+
cid = database.fetch_data(self.db, mappings[0][0], login.did, SERVICE)[
448
448
+
"cid"
449
449
+
]
414
450
rsp = self.bsky.repost(mappings[0][0], cid)
415
415
-
451
451
+
416
452
internal_id = database.insert_repost(
417
417
-
self.db,
418
418
-
rsp.uri,
419
419
-
reposted['id'],
420
420
-
login.did,
421
421
-
SERVICE)
422
422
-
database.store_data(
423
423
-
self.db,
424
424
-
rsp.uri,
425
425
-
login.did,
426
426
-
SERVICE,
427
427
-
{'cid': rsp.cid}
453
453
+
self.db, rsp.uri, reposted["id"], login.did, SERVICE
428
454
)
429
429
-
database.insert_mapping(self.db, repost['id'], internal_id)
430
430
-
431
431
-
def __delete_repost(self, repost_id: str) -> tuple[models.AppBskyActorDefs.ProfileViewDetailed | None, dict | None]:
455
455
+
database.store_data(self.db, rsp.uri, login.did, SERVICE, {"cid": rsp.cid})
456
456
+
database.insert_mapping(self.db, repost["id"], internal_id)
457
457
+
458
458
+
def __delete_repost(
459
459
+
self, repost_id: str
460
460
+
) -> tuple[models.AppBskyActorDefs.ProfileViewDetailed | None, dict | None]:
432
461
login = self.__check_login()
433
433
-
434
434
-
repost = database.find_post(self.db, repost_id, self.input.user_id, self.input.service)
462
462
+
463
463
+
repost = database.find_post(
464
464
+
self.db, repost_id, self.input.user_id, self.input.service
465
465
+
)
435
466
if not repost:
436
467
return None, None
437
437
-
438
438
-
mappings = database.find_mappings(self.db, repost['id'], SERVICE, login.did)
468
468
+
469
469
+
mappings = database.find_mappings(self.db, repost["id"], SERVICE, login.did)
439
470
if mappings:
440
471
LOGGER.info("Deleting '%s'...", mappings[0][0])
441
472
self.bsky.unrepost(mappings[0][0])
442
473
database.delete_post(self.db, mappings[0][0], login.did, SERVICE)
443
474
return login, repost
444
444
-
475
475
+
445
476
def delete_repost(self, repost_id: str):
446
477
self.__delete_repost(repost_id)
447
447
-
448
448
-
+79
-61
cross.py
reviewed
···
1
1
+
import re
1
2
from abc import ABC, abstractmethod
2
2
-
from typing import Callable, Any
3
3
+
from datetime import datetime, timezone
4
4
+
from typing import Any, Callable
5
5
+
3
6
from util.database import DataBaseWorker
4
4
-
from datetime import datetime, timezone
5
7
from util.media import MediaInfo
6
8
from util.util import LOGGER, canonical_label
7
7
-
import re
8
9
9
9
-
ALTERNATE = re.compile(r'\S+|\s+')
10
10
+
ALTERNATE = re.compile(r"\S+|\s+")
11
11
+
10
12
11
13
# generic token
12
12
-
class Token():
14
14
+
class Token:
13
15
def __init__(self, type: str) -> None:
14
16
self.type = type
15
17
18
18
+
16
19
class TextToken(Token):
17
20
def __init__(self, text: str) -> None:
18
18
-
super().__init__('text')
21
21
+
super().__init__("text")
19
22
self.text = text
23
23
+
20
24
21
25
# token that represents a link to a website. e.g. [link](https://google.com/)
22
26
class LinkToken(Token):
23
27
def __init__(self, href: str, label: str) -> None:
24
24
-
super().__init__('link')
28
28
+
super().__init__("link")
25
29
self.href = href
26
30
self.label = label
27
27
-
28
28
-
# token that represents a hashtag. e.g. #SocialMedia
31
31
+
32
32
+
33
33
+
# token that represents a hashtag. e.g. #SocialMedia
29
34
class TagToken(Token):
30
35
def __init__(self, tag: str) -> None:
31
31
-
super().__init__('tag')
36
36
+
super().__init__("tag")
32
37
self.tag = tag
38
38
+
33
39
34
40
# token that represents a mention of a user.
35
41
class MentionToken(Token):
36
42
def __init__(self, username: str, uri: str) -> None:
37
37
-
super().__init__('mention')
43
43
+
super().__init__("mention")
38
44
self.username = username
39
45
self.uri = uri
40
40
-
41
41
-
class MediaMeta():
46
46
+
47
47
+
48
48
+
class MediaMeta:
42
49
def __init__(self, width: int, height: int, duration: float) -> None:
43
50
self.width = width
44
51
self.height = height
45
52
self.duration = duration
46
46
-
53
53
+
47
54
def get_width(self) -> int:
48
55
return self.width
49
49
-
56
56
+
50
57
def get_height(self) -> int:
51
58
return self.height
52
52
-
59
59
+
53
60
def get_duration(self) -> float:
54
61
return self.duration
55
55
-
62
62
+
63
63
+
56
64
class Post(ABC):
57
65
@abstractmethod
58
66
def get_id(self) -> str:
59
59
-
return ''
60
60
-
67
67
+
return ""
68
68
+
61
69
@abstractmethod
62
70
def get_parent_id(self) -> str | None:
63
71
pass
64
64
-
72
72
+
65
73
@abstractmethod
66
74
def get_tokens(self) -> list[Token]:
67
75
pass
···
71
79
@abstractmethod
72
80
def get_text_type(self) -> str:
73
81
pass
74
74
-
82
82
+
75
83
# post iso timestamp
76
84
@abstractmethod
77
85
def get_timestamp(self) -> str:
78
86
pass
79
79
-
87
87
+
80
88
def get_attachments(self) -> list[MediaInfo]:
81
89
return []
82
82
-
90
90
+
83
91
def get_spoiler(self) -> str | None:
84
92
return None
85
85
-
93
93
+
86
94
def get_languages(self) -> list[str]:
87
95
return []
88
88
-
96
96
+
89
97
def is_sensitive(self) -> bool:
90
98
return False
91
91
-
99
99
+
92
100
def get_post_url(self) -> str | None:
93
101
return None
94
102
103
103
+
95
104
# generic input service.
96
105
# user and service for db queries
97
97
-
class Input():
98
98
-
def __init__(self, service: str, user_id: str, settings: dict, db: DataBaseWorker) -> None:
106
106
+
class Input:
107
107
+
def __init__(
108
108
+
self, service: str, user_id: str, settings: dict, db: DataBaseWorker
109
109
+
) -> None:
99
110
self.service = service
100
111
self.user_id = user_id
101
112
self.settings = settings
102
113
self.db = db
103
103
-
114
114
+
104
115
async def listen(self, outputs: list, handler: Callable[[Post], Any]):
105
116
pass
106
117
107
107
-
class Output():
118
118
+
119
119
+
class Output:
108
120
def __init__(self, input: Input, settings: dict, db: DataBaseWorker) -> None:
109
121
self.input = input
110
122
self.settings = settings
111
123
self.db = db
112
112
-
124
124
+
113
125
def accept_post(self, post: Post):
114
126
LOGGER.warning('Not Implemented.. "posted" %s', post.get_id())
115
115
-
127
127
+
116
128
def delete_post(self, identifier: str):
117
129
LOGGER.warning('Not Implemented.. "deleted" %s', identifier)
118
118
-
130
130
+
119
131
def accept_repost(self, repost_id: str, reposted_id: str):
120
132
LOGGER.warning('Not Implemented.. "reblogged" %s, %s', repost_id, reposted_id)
121
121
-
133
133
+
122
134
def delete_repost(self, repost_id: str):
123
135
LOGGER.warning('Not Implemented.. "removed reblog" %s', repost_id)
136
136
+
124
137
125
138
def test_filters(tokens: list[Token], filters: list[re.Pattern[str]]):
126
139
if not tokens or not filters:
127
140
return True
128
128
-
129
129
-
markdown = ''
130
130
-
141
141
+
142
142
+
markdown = ""
143
143
+
131
144
for token in tokens:
132
145
if isinstance(token, TextToken):
133
146
markdown += token.text
134
147
elif isinstance(token, LinkToken):
135
135
-
markdown += f'[{token.label}]({token.href})'
148
148
+
markdown += f"[{token.label}]({token.href})"
136
149
elif isinstance(token, TagToken):
137
137
-
markdown += '#' + token.tag
150
150
+
markdown += "#" + token.tag
138
151
elif isinstance(token, MentionToken):
139
152
markdown += token.username
140
140
-
153
153
+
141
154
for filter in filters:
142
155
if filter.search(markdown):
143
156
return False
144
144
-
157
157
+
145
158
return True
146
159
147
147
-
def split_tokens(tokens: list[Token], max_chars: int, max_link_len: int = 35) -> list[list[Token]]:
160
160
+
161
161
+
def split_tokens(
162
162
+
tokens: list[Token], max_chars: int, max_link_len: int = 35
163
163
+
) -> list[list[Token]]:
148
164
def new_block():
149
165
nonlocal blocks, block, length
150
166
if block:
151
167
blocks.append(block)
152
168
block = []
153
169
length = 0
154
154
-
170
170
+
155
171
def append_text(text_segment):
156
172
nonlocal block
157
173
# if the last element in the current block is also text, just append to it
···
159
175
block[-1].text += text_segment
160
176
else:
161
177
block.append(TextToken(text_segment))
162
162
-
178
178
+
163
179
blocks: list[list[Token]] = []
164
180
block: list[Token] = []
165
181
length = 0
166
166
-
182
182
+
167
183
for tk in tokens:
168
184
if isinstance(tk, TagToken):
169
169
-
tag_len = 1 + len(tk.tag) # (#) + tag
185
185
+
tag_len = 1 + len(tk.tag) # (#) + tag
170
186
if length + tag_len > max_chars:
171
171
-
new_block() # create new block if the current one is too large
172
172
-
187
187
+
new_block() # create new block if the current one is too large
188
188
+
173
189
block.append(tk)
174
190
length += tag_len
175
175
-
elif isinstance(tk, LinkToken): # TODO labels should proably be split too
191
191
+
elif isinstance(tk, LinkToken): # TODO labels should proably be split too
176
192
link_len = len(tk.label)
177
177
-
if canonical_label(tk.label, tk.href): # cut down the link if the label is canonical
193
193
+
if canonical_label(
194
194
+
tk.label, tk.href
195
195
+
): # cut down the link if the label is canonical
178
196
link_len = min(link_len, max_link_len)
179
179
-
197
197
+
180
198
if length + link_len > max_chars:
181
199
new_block()
182
200
block.append(tk)
183
201
length += link_len
184
202
elif isinstance(tk, TextToken):
185
203
segments: list[str] = ALTERNATE.findall(tk.text)
186
186
-
204
204
+
187
205
for seg in segments:
188
206
seg_len: int = len(seg)
189
207
if length + seg_len <= max_chars - (0 if seg.isspace() else 1):
190
208
append_text(seg)
191
209
length += seg_len
192
210
continue
193
193
-
211
211
+
194
212
if length > 0:
195
213
new_block()
196
196
-
214
214
+
197
215
if not seg.isspace():
198
216
while len(seg) > max_chars - 1:
199
217
chunk = seg[: max_chars - 1] + "-"
···
202
220
seg = seg[max_chars - 1 :]
203
221
else:
204
222
while len(seg) > max_chars:
205
205
-
chunk = seg[: max_chars]
223
223
+
chunk = seg[:max_chars]
206
224
append_text(chunk)
207
225
new_block()
208
208
-
seg = seg[max_chars :]
209
209
-
226
226
+
seg = seg[max_chars:]
227
227
+
210
228
if seg:
211
229
append_text(seg)
212
230
length = len(seg)
213
213
-
else: #TODO fix mentions
231
231
+
else: # TODO fix mentions
214
232
block.append(tk)
215
215
-
233
233
+
216
234
if block:
217
235
blocks.append(block)
218
218
-
219
219
-
return blocks
236
236
+
237
237
+
return blocks
+59
-54
main.py
reviewed
···
1
1
-
import os
1
1
+
import asyncio
2
2
import json
3
3
-
import asyncio, threading, queue, traceback
3
3
+
import os
4
4
+
import queue
5
5
+
import threading
6
6
+
import traceback
4
7
5
5
-
from util.util import LOGGER, as_json
6
6
-
import cross, util.database as database
7
7
-
8
8
+
import cross
9
9
+
import util.database as database
8
10
from bluesky.input import BlueskyJetstreamInput
9
9
-
from bluesky.output import BlueskyOutputOptions, BlueskyOutput
10
10
-
11
11
-
from mastodon.input import MastodonInputOptions, MastodonInput
11
11
+
from bluesky.output import BlueskyOutput, BlueskyOutputOptions
12
12
+
from mastodon.input import MastodonInput, MastodonInputOptions
12
13
from mastodon.output import MastodonOutput
13
13
-
14
14
from misskey.input import MisskeyInput
15
15
+
from util.util import LOGGER, as_json
15
16
16
17
DEFAULT_SETTINGS: dict = {
17
17
-
'input': {
18
18
-
'type': 'mastodon-wss',
19
19
-
'instance': 'env:MASTODON_INSTANCE',
20
20
-
'token': 'env:MASTODON_TOKEN',
21
21
-
"options": MastodonInputOptions({})
18
18
+
"input": {
19
19
+
"type": "mastodon-wss",
20
20
+
"instance": "env:MASTODON_INSTANCE",
21
21
+
"token": "env:MASTODON_TOKEN",
22
22
+
"options": MastodonInputOptions({}),
22
23
},
23
23
-
'outputs': [
24
24
+
"outputs": [
24
25
{
25
25
-
'type': 'bluesky',
26
26
-
'handle': 'env:BLUESKY_HANDLE',
27
27
-
'app-password': 'env:BLUESKY_APP_PASSWORD',
28
28
-
'options': BlueskyOutputOptions({})
26
26
+
"type": "bluesky",
27
27
+
"handle": "env:BLUESKY_HANDLE",
28
28
+
"app-password": "env:BLUESKY_APP_PASSWORD",
29
29
+
"options": BlueskyOutputOptions({}),
29
30
}
30
30
-
]
31
31
+
],
31
32
}
32
33
33
34
INPUTS = {
34
35
"mastodon-wss": lambda settings, db: MastodonInput(settings, db),
35
36
"misskey-wss": lambda settigs, db: MisskeyInput(settigs, db),
36
36
-
"bluesky-jetstream-wss": lambda settings, db: BlueskyJetstreamInput(settings, db)
37
37
+
"bluesky-jetstream-wss": lambda settings, db: BlueskyJetstreamInput(settings, db),
37
38
}
38
39
39
40
OUTPUTS = {
40
41
"bluesky": lambda input, settings, db: BlueskyOutput(input, settings, db),
41
41
-
"mastodon": lambda input, settings, db: MastodonOutput(input, settings, db)
42
42
+
"mastodon": lambda input, settings, db: MastodonOutput(input, settings, db),
42
43
}
44
44
+
43
45
44
46
def execute(data_dir):
45
47
if not os.path.exists(data_dir):
46
48
os.makedirs(data_dir)
47
47
-
48
48
-
settings_path = os.path.join(data_dir, 'settings.json')
49
49
-
database_path = os.path.join(data_dir, 'data.db')
50
50
-
49
49
+
50
50
+
settings_path = os.path.join(data_dir, "settings.json")
51
51
+
database_path = os.path.join(data_dir, "data.db")
52
52
+
51
53
if not os.path.exists(settings_path):
52
54
LOGGER.info("First launch detected! Creating %s and exiting!", settings_path)
53
53
-
54
54
-
with open(settings_path, 'w') as f:
55
55
+
56
56
+
with open(settings_path, "w") as f:
55
57
f.write(as_json(DEFAULT_SETTINGS, indent=2))
56
58
return 0
57
59
58
58
-
LOGGER.info('Loading settings...')
59
59
-
with open(settings_path, 'rb') as f:
60
60
+
LOGGER.info("Loading settings...")
61
61
+
with open(settings_path, "rb") as f:
60
62
settings = json.load(f)
61
61
-
62
62
-
LOGGER.info('Starting database worker...')
63
63
+
64
64
+
LOGGER.info("Starting database worker...")
63
65
db_worker = database.DataBaseWorker(os.path.abspath(database_path))
64
64
-
65
65
-
db_worker.execute('PRAGMA foreign_keys = ON;')
66
66
-
66
66
+
67
67
+
db_worker.execute("PRAGMA foreign_keys = ON;")
68
68
+
67
69
# create the posts table
68
70
# id - internal id of the post
69
71
# user_id - user id on the service (e.g. a724sknj5y9ydk0w)
···
82
84
);
83
85
"""
84
86
)
85
85
-
87
87
+
86
88
columns = db_worker.execute("PRAGMA table_info(posts)")
87
89
column_names = [col[1] for col in columns]
88
90
if "reposted_id" not in column_names:
···
95
97
ALTER TABLE posts
96
98
ADD COLUMN extra_data TEXT NULL
97
99
""")
98
98
-
100
100
+
99
101
# create the mappings table
100
102
# original_post_id - the post this was mapped from
101
103
# mapped_post_id - the post this was mapped to
···
107
109
);
108
110
"""
109
111
)
110
110
-
111
111
-
input_settings = settings.get('input')
112
112
+
113
113
+
input_settings = settings.get("input")
112
114
if not input_settings:
113
115
raise Exception("No input specified!")
114
114
-
outputs_settings = settings.get('outputs', [])
115
115
-
116
116
-
input = INPUTS[input_settings['type']](input_settings, db_worker)
117
117
-
116
116
+
outputs_settings = settings.get("outputs", [])
117
117
+
118
118
+
input = INPUTS[input_settings["type"]](input_settings, db_worker)
119
119
+
118
120
if not outputs_settings:
119
121
LOGGER.warning("No outputs specified! Check the config!")
120
120
-
122
122
+
121
123
outputs: list[cross.Output] = []
122
124
for output_settings in outputs_settings:
123
123
-
outputs.append(OUTPUTS[output_settings['type']](input, output_settings, db_worker))
124
124
-
125
125
-
LOGGER.info('Starting task worker...')
125
125
+
outputs.append(
126
126
+
OUTPUTS[output_settings["type"]](input, output_settings, db_worker)
127
127
+
)
128
128
+
129
129
+
LOGGER.info("Starting task worker...")
130
130
+
126
131
def worker(queue: queue.Queue):
127
132
while True:
128
133
task = queue.get()
129
134
if task is None:
130
135
break
131
131
-
136
136
+
132
137
try:
133
138
task()
134
139
except Exception as e:
···
136
141
traceback.print_exc()
137
142
finally:
138
143
queue.task_done()
139
139
-
144
144
+
140
145
task_queue = queue.Queue()
141
146
thread = threading.Thread(target=worker, args=(task_queue,), daemon=True)
142
147
thread.start()
143
143
-
144
144
-
LOGGER.info('Connecting to %s...', input.service)
148
148
+
149
149
+
LOGGER.info("Connecting to %s...", input.service)
145
150
try:
146
151
asyncio.run(input.listen(outputs, lambda x: task_queue.put(x)))
147
152
except KeyboardInterrupt:
148
153
LOGGER.info("Stopping...")
149
149
-
154
154
+
150
155
task_queue.join()
151
156
task_queue.put(None)
152
157
thread.join()
153
153
-
158
158
+
154
159
155
160
if __name__ == "__main__":
156
156
-
execute('./data')
161
161
+
execute("./data")
+26
-20
mastodon/common.py
reviewed
···
1
1
import cross
2
2
from util.media import MediaInfo
3
3
4
4
+
4
5
class MastodonPost(cross.Post):
5
5
-
def __init__(self, status: dict, tokens: list[cross.Token], media_attachments: list[MediaInfo]) -> None:
6
6
+
def __init__(
7
7
+
self,
8
8
+
status: dict,
9
9
+
tokens: list[cross.Token],
10
10
+
media_attachments: list[MediaInfo],
11
11
+
) -> None:
6
12
super().__init__()
7
7
-
self.id = status['id']
8
8
-
self.parent_id = status.get('in_reply_to_id')
13
13
+
self.id = status["id"]
14
14
+
self.parent_id = status.get("in_reply_to_id")
9
15
self.tokens = tokens
10
10
-
self.content_type = status.get('content_type', 'text/plain')
11
11
-
self.timestamp = status['created_at']
16
16
+
self.content_type = status.get("content_type", "text/plain")
17
17
+
self.timestamp = status["created_at"]
12
18
self.media_attachments = media_attachments
13
13
-
self.spoiler = status.get('spoiler_text')
14
14
-
self.language = [status['language']] if status.get('language') else []
15
15
-
self.sensitive = status.get('sensitive', False)
16
16
-
self.url = status.get('url')
17
17
-
19
19
+
self.spoiler = status.get("spoiler_text")
20
20
+
self.language = [status["language"]] if status.get("language") else []
21
21
+
self.sensitive = status.get("sensitive", False)
22
22
+
self.url = status.get("url")
23
23
+
18
24
def get_id(self) -> str:
19
25
return self.id
20
20
-
26
26
+
21
27
def get_parent_id(self) -> str | None:
22
28
return self.parent_id
23
23
-
29
29
+
24
30
def get_tokens(self) -> list[cross.Token]:
25
31
return self.tokens
26
26
-
32
32
+
27
33
def get_text_type(self) -> str:
28
34
return self.content_type
29
29
-
35
35
+
30
36
def get_timestamp(self) -> str:
31
37
return self.timestamp
32
32
-
38
38
+
33
39
def get_attachments(self) -> list[MediaInfo]:
34
40
return self.media_attachments
35
35
-
41
41
+
36
42
def get_spoiler(self) -> str | None:
37
43
return self.spoiler
38
38
-
44
44
+
39
45
def get_languages(self) -> list[str]:
40
46
return self.language
41
41
-
47
47
+
42
48
def is_sensitive(self) -> bool:
43
49
return self.sensitive or (self.spoiler is not None)
44
44
-
50
50
+
45
51
def get_post_url(self) -> str | None:
46
46
-
return self.url
52
52
+
return self.url
+124
-96
mastodon/input.py
reviewed
···
1
1
-
import requests, websockets
1
1
+
import asyncio
2
2
import json
3
3
import re
4
4
-
import asyncio
4
4
+
from typing import Any, Callable
5
5
6
6
-
from mastodon.common import MastodonPost
6
6
+
import requests
7
7
+
import websockets
8
8
+
9
9
+
import cross
10
10
+
import util.database as database
7
11
import util.html_util as html_util
8
12
import util.md_util as md_util
9
9
-
10
10
-
import cross, util.database as database
11
11
-
from util.util import LOGGER, as_envvar
12
12
-
from util.media import MediaInfo, download_media
13
13
+
from mastodon.common import MastodonPost
13
14
from util.database import DataBaseWorker
15
15
+
from util.media import MediaInfo, download_media
16
16
+
from util.util import LOGGER, as_envvar
14
17
15
15
-
from typing import Callable, Any
18
18
+
ALLOWED_VISIBILITY = ["public", "unlisted"]
19
19
+
MARKDOWNY = ["text/x.misskeymarkdown", "text/markdown", "text/plain"]
16
20
17
17
-
ALLOWED_VISIBILITY = ['public', 'unlisted']
18
18
-
MARKDOWNY = ['text/x.misskeymarkdown', 'text/markdown', 'text/plain']
19
21
20
20
-
class MastodonInputOptions():
22
22
+
class MastodonInputOptions:
21
23
def __init__(self, o: dict) -> None:
22
24
self.allowed_visibility = ALLOWED_VISIBILITY
23
23
-
self.filters = [re.compile(f) for f in o.get('regex_filters', [])]
24
24
-
25
25
-
allowed_visibility = o.get('allowed_visibility')
25
25
+
self.filters = [re.compile(f) for f in o.get("regex_filters", [])]
26
26
+
27
27
+
allowed_visibility = o.get("allowed_visibility")
26
28
if allowed_visibility is not None:
27
29
if any([v not in ALLOWED_VISIBILITY for v in allowed_visibility]):
28
28
-
raise ValueError(f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}")
30
30
+
raise ValueError(
31
31
+
f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}"
32
32
+
)
29
33
self.allowed_visibility = allowed_visibility
34
34
+
30
35
31
36
class MastodonInput(cross.Input):
32
37
def __init__(self, settings: dict, db: DataBaseWorker) -> None:
33
33
-
self.options = MastodonInputOptions(settings.get('options', {}))
34
34
-
self.token = as_envvar(settings.get('token')) or (_ for _ in ()).throw(ValueError("'token' is required"))
35
35
-
instance: str = as_envvar(settings.get('instance')) or (_ for _ in ()).throw(ValueError("'instance' is required"))
36
36
-
37
37
-
service = instance[:-1] if instance.endswith('/') else instance
38
38
-
38
38
+
self.options = MastodonInputOptions(settings.get("options", {}))
39
39
+
self.token = as_envvar(settings.get("token")) or (_ for _ in ()).throw(
40
40
+
ValueError("'token' is required")
41
41
+
)
42
42
+
instance: str = as_envvar(settings.get("instance")) or (_ for _ in ()).throw(
43
43
+
ValueError("'instance' is required")
44
44
+
)
45
45
+
46
46
+
service = instance[:-1] if instance.endswith("/") else instance
47
47
+
39
48
LOGGER.info("Verifying %s credentails...", service)
40
40
-
responce = requests.get(f"{service}/api/v1/accounts/verify_credentials", headers={
41
41
-
'Authorization': f'Bearer {self.token}'
42
42
-
})
49
49
+
responce = requests.get(
50
50
+
f"{service}/api/v1/accounts/verify_credentials",
51
51
+
headers={"Authorization": f"Bearer {self.token}"},
52
52
+
)
43
53
if responce.status_code != 200:
44
54
LOGGER.error("Failed to validate user credentials!")
45
55
responce.raise_for_status()
46
56
return
47
47
-
57
57
+
48
58
super().__init__(service, responce.json()["id"], settings, db)
49
59
self.streaming = self._get_streaming_url()
50
50
-
60
60
+
51
61
if not self.streaming:
52
62
raise Exception("Instance %s does not support streaming!", service)
53
63
···
55
65
response = requests.get(f"{self.service}/api/v1/instance")
56
66
response.raise_for_status()
57
67
data: dict = response.json()
58
58
-
return (data.get('urls') or {}).get('streaming_api')
68
68
+
return (data.get("urls") or {}).get("streaming_api")
59
69
60
70
def __to_tokens(self, status: dict):
61
61
-
content_type = status.get('content_type', 'text/plain')
62
62
-
raw_text = status.get('text')
63
63
-
71
71
+
content_type = status.get("content_type", "text/plain")
72
72
+
raw_text = status.get("text")
73
73
+
64
74
tags: list[str] = []
65
65
-
for tag in status.get('tags', []):
66
66
-
tags.append(tag['name'])
67
67
-
75
75
+
for tag in status.get("tags", []):
76
76
+
tags.append(tag["name"])
77
77
+
68
78
mentions: list[tuple[str, str]] = []
69
69
-
for mention in status.get('mentions', []):
70
70
-
mentions.append(('@' + mention['username'], '@' + mention['acct']))
71
71
-
79
79
+
for mention in status.get("mentions", []):
80
80
+
mentions.append(("@" + mention["username"], "@" + mention["acct"]))
81
81
+
72
82
if raw_text and content_type in MARKDOWNY:
73
83
return md_util.tokenize_markdown(raw_text, tags, mentions)
74
74
-
75
75
-
akkoma_ext: dict | None = status.get('akkoma', {}).get('source')
84
84
+
85
85
+
akkoma_ext: dict | None = status.get("akkoma", {}).get("source")
76
86
if akkoma_ext:
77
77
-
if akkoma_ext.get('mediaType') in MARKDOWNY:
87
87
+
if akkoma_ext.get("mediaType") in MARKDOWNY:
78
88
return md_util.tokenize_markdown(akkoma_ext["content"], tags, mentions)
79
79
-
89
89
+
80
90
tokenizer = html_util.HTMLPostTokenizer()
81
91
tokenizer.mentions = mentions
82
92
tokenizer.tags = tags
83
83
-
tokenizer.feed(status.get('content', ""))
93
93
+
tokenizer.feed(status.get("content", ""))
84
94
return tokenizer.get_tokens()
85
85
-
95
95
+
86
96
def _on_create_post(self, outputs: list[cross.Output], status: dict):
87
97
# skip events from other users
88
88
-
if (status.get('account') or {})['id'] != self.user_id:
98
98
+
if (status.get("account") or {})["id"] != self.user_id:
89
99
return
90
90
-
91
91
-
if status.get('visibility') not in self.options.allowed_visibility:
100
100
+
101
101
+
if status.get("visibility") not in self.options.allowed_visibility:
92
102
# Skip f/o and direct posts
93
93
-
LOGGER.info("Skipping '%s'! '%s' visibility..", status['id'], status.get('visibility'))
103
103
+
LOGGER.info(
104
104
+
"Skipping '%s'! '%s' visibility..",
105
105
+
status["id"],
106
106
+
status.get("visibility"),
107
107
+
)
94
108
return
95
95
-
109
109
+
96
110
# TODO polls not supported on bsky. maybe 3rd party? skip for now
97
111
# we don't handle reblogs. possible with bridgy(?) and self
98
112
# we don't handle quotes.
99
99
-
if status.get('poll'):
100
100
-
LOGGER.info("Skipping '%s'! Contains a poll..", status['id'])
113
113
+
if status.get("poll"):
114
114
+
LOGGER.info("Skipping '%s'! Contains a poll..", status["id"])
101
115
return
102
102
-
103
103
-
if status.get('quote_id') or status.get('quote'):
104
104
-
LOGGER.info("Skipping '%s'! Quote..", status['id'])
116
116
+
117
117
+
if status.get("quote_id") or status.get("quote"):
118
118
+
LOGGER.info("Skipping '%s'! Quote..", status["id"])
105
119
return
106
106
-
107
107
-
reblog: dict | None = status.get('reblog')
120
120
+
121
121
+
reblog: dict | None = status.get("reblog")
108
122
if reblog:
109
109
-
if (reblog.get('account') or {})['id'] != self.user_id:
110
110
-
LOGGER.info("Skipping '%s'! Reblog of other user..", status['id'])
123
123
+
if (reblog.get("account") or {})["id"] != self.user_id:
124
124
+
LOGGER.info("Skipping '%s'! Reblog of other user..", status["id"])
111
125
return
112
112
-
113
113
-
success = database.try_insert_repost(self.db, status['id'], reblog['id'], self.user_id, self.service)
126
126
+
127
127
+
success = database.try_insert_repost(
128
128
+
self.db, status["id"], reblog["id"], self.user_id, self.service
129
129
+
)
114
130
if not success:
115
115
-
LOGGER.info("Skipping '%s' as reblogged post was not found in db!", status['id'])
131
131
+
LOGGER.info(
132
132
+
"Skipping '%s' as reblogged post was not found in db!", status["id"]
133
133
+
)
116
134
return
117
117
-
135
135
+
118
136
for output in outputs:
119
119
-
output.accept_repost(status['id'], reblog['id'])
137
137
+
output.accept_repost(status["id"], reblog["id"])
120
138
return
121
121
-
122
122
-
in_reply: str | None = status.get('in_reply_to_id')
123
123
-
in_reply_to: str | None = status.get('in_reply_to_account_id')
139
139
+
140
140
+
in_reply: str | None = status.get("in_reply_to_id")
141
141
+
in_reply_to: str | None = status.get("in_reply_to_account_id")
124
142
if in_reply_to and in_reply_to != self.user_id:
125
143
# We don't support replies.
126
126
-
LOGGER.info("Skipping '%s'! Reply to other user..", status['id'])
144
144
+
LOGGER.info("Skipping '%s'! Reply to other user..", status["id"])
127
145
return
128
128
-
129
129
-
success = database.try_insert_post(self.db, status['id'], in_reply, self.user_id, self.service)
146
146
+
147
147
+
success = database.try_insert_post(
148
148
+
self.db, status["id"], in_reply, self.user_id, self.service
149
149
+
)
130
150
if not success:
131
131
-
LOGGER.info("Skipping '%s' as parent post was not found in db!", status['id'])
151
151
+
LOGGER.info(
152
152
+
"Skipping '%s' as parent post was not found in db!", status["id"]
153
153
+
)
132
154
return
133
133
-
155
155
+
134
156
tokens = self.__to_tokens(status)
135
157
if not cross.test_filters(tokens, self.options.filters):
136
136
-
LOGGER.info("Skipping '%s'. Matched a filter!", status['id'])
158
158
+
LOGGER.info("Skipping '%s'. Matched a filter!", status["id"])
137
159
return
138
138
-
139
139
-
LOGGER.info("Crossposting '%s'...", status['id'])
140
140
-
160
160
+
161
161
+
LOGGER.info("Crossposting '%s'...", status["id"])
162
162
+
141
163
media_attachments: list[MediaInfo] = []
142
142
-
for attachment in status.get('media_attachments', []):
143
143
-
LOGGER.info("Downloading %s...", attachment['url'])
144
144
-
info = download_media(attachment['url'], attachment.get('description') or '')
164
164
+
for attachment in status.get("media_attachments", []):
165
165
+
LOGGER.info("Downloading %s...", attachment["url"])
166
166
+
info = download_media(
167
167
+
attachment["url"], attachment.get("description") or ""
168
168
+
)
145
169
if not info:
146
146
-
LOGGER.error("Skipping '%s'. Failed to download media!", status['id'])
170
170
+
LOGGER.error("Skipping '%s'. Failed to download media!", status["id"])
147
171
return
148
172
media_attachments.append(info)
149
149
-
173
173
+
150
174
cross_post = MastodonPost(status, tokens, media_attachments)
151
175
for output in outputs:
152
176
output.accept_post(cross_post)
153
153
-
177
177
+
154
178
def _on_delete_post(self, outputs: list[cross.Output], identifier: str):
155
179
post = database.find_post(self.db, identifier, self.user_id, self.service)
156
180
if not post:
157
181
return
158
158
-
182
182
+
159
183
LOGGER.info("Deleting '%s'...", identifier)
160
160
-
if post['reposted_id']:
184
184
+
if post["reposted_id"]:
161
185
for output in outputs:
162
186
output.delete_repost(identifier)
163
187
else:
164
188
for output in outputs:
165
189
output.delete_post(identifier)
166
166
-
190
190
+
167
191
database.delete_post(self.db, identifier, self.user_id, self.service)
168
168
-
192
192
+
169
193
def _on_post(self, outputs: list[cross.Output], event: str, payload: str):
170
194
match event:
171
171
-
case 'update':
195
195
+
case "update":
172
196
self._on_create_post(outputs, json.loads(payload))
173
173
-
case 'delete':
197
197
+
case "delete":
174
198
self._on_delete_post(outputs, payload)
175
175
-
176
176
-
async def listen(self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]):
199
199
+
200
200
+
async def listen(
201
201
+
self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]
202
202
+
):
177
203
uri = f"{self.streaming}/api/v1/streaming?stream=user&access_token={self.token}"
178
178
-
179
179
-
async for ws in websockets.connect(uri, extra_headers={"User-Agent": "XPost/0.0.3"}):
204
204
+
205
205
+
async for ws in websockets.connect(
206
206
+
uri, extra_headers={"User-Agent": "XPost/0.0.3"}
207
207
+
):
180
208
try:
181
209
LOGGER.info("Listening to %s...", self.streaming)
182
182
-
210
210
+
183
211
async def listen_for_messages():
184
212
async for msg in ws:
185
213
data = json.loads(msg)
186
186
-
event: str = data.get('event')
187
187
-
payload: str = data.get('payload')
188
188
-
214
214
+
event: str = data.get("event")
215
215
+
payload: str = data.get("payload")
216
216
+
189
217
submit(lambda: self._on_post(outputs, str(event), str(payload)))
190
190
-
218
218
+
191
219
listen = asyncio.create_task(listen_for_messages())
192
192
-
220
220
+
193
221
await asyncio.gather(listen)
194
222
except websockets.ConnectionClosedError as e:
195
223
LOGGER.error(e, stack_info=True, exc_info=True)
196
224
LOGGER.info("Reconnecting to %s...", self.streaming)
197
197
-
continue
225
225
+
continue
+252
-214
mastodon/output.py
reviewed
···
1
1
-
import requests, time
1
1
+
import time
2
2
3
3
-
import cross, util.database as database
3
3
+
import requests
4
4
+
5
5
+
import cross
4
6
import misskey.mfm_util as mfm_util
5
5
-
from util.util import LOGGER, as_envvar, canonical_label
7
7
+
import util.database as database
8
8
+
from util.database import DataBaseWorker
6
9
from util.media import MediaInfo
7
7
-
from util.database import DataBaseWorker
10
10
+
from util.util import LOGGER, as_envvar, canonical_label
8
11
9
12
POSSIBLE_MIMES = [
10
10
-
'audio/ogg',
11
11
-
'audio/mp3',
12
12
-
'image/webp',
13
13
-
'image/jpeg',
14
14
-
'image/png',
15
15
-
'video/mp4',
16
16
-
'video/quicktime',
17
17
-
'video/webm'
13
13
+
"audio/ogg",
14
14
+
"audio/mp3",
15
15
+
"image/webp",
16
16
+
"image/jpeg",
17
17
+
"image/png",
18
18
+
"video/mp4",
19
19
+
"video/quicktime",
20
20
+
"video/webm",
18
21
]
19
22
20
20
-
TEXT_MIMES = [
21
21
-
'text/x.misskeymarkdown',
22
22
-
'text/markdown',
23
23
-
'text/plain'
24
24
-
]
23
23
+
TEXT_MIMES = ["text/x.misskeymarkdown", "text/markdown", "text/plain"]
24
24
+
25
25
+
ALLOWED_POSTING_VISIBILITY = ["public", "unlisted", "private"]
25
26
26
26
-
ALLOWED_POSTING_VISIBILITY = ['public', 'unlisted', 'private']
27
27
28
28
-
class MastodonOutputOptions():
28
28
+
class MastodonOutputOptions:
29
29
def __init__(self, o: dict) -> None:
30
30
-
self.visibility = 'public'
31
31
-
32
32
-
visibility = o.get('visibility')
30
30
+
self.visibility = "public"
31
31
+
32
32
+
visibility = o.get("visibility")
33
33
if visibility is not None:
34
34
if visibility not in ALLOWED_POSTING_VISIBILITY:
35
35
-
raise ValueError(f"'visibility' only accepts {', '.join(ALLOWED_POSTING_VISIBILITY)}, got: {visibility}")
35
35
+
raise ValueError(
36
36
+
f"'visibility' only accepts {', '.join(ALLOWED_POSTING_VISIBILITY)}, got: {visibility}"
37
37
+
)
36
38
self.visibility = visibility
39
39
+
37
40
38
41
class MastodonOutput(cross.Output):
39
42
def __init__(self, input: cross.Input, settings: dict, db: DataBaseWorker) -> None:
40
43
super().__init__(input, settings, db)
41
41
-
self.options = settings.get('options') or {}
42
42
-
self.token = as_envvar(settings.get('token')) or (_ for _ in ()).throw(ValueError("'token' is required"))
43
43
-
instance: str = as_envvar(settings.get('instance')) or (_ for _ in ()).throw(ValueError("'instance' is required"))
44
44
-
45
45
-
self.service = instance[:-1] if instance.endswith('/') else instance
46
46
-
44
44
+
self.options = settings.get("options") or {}
45
45
+
self.token = as_envvar(settings.get("token")) or (_ for _ in ()).throw(
46
46
+
ValueError("'token' is required")
47
47
+
)
48
48
+
instance: str = as_envvar(settings.get("instance")) or (_ for _ in ()).throw(
49
49
+
ValueError("'instance' is required")
50
50
+
)
51
51
+
52
52
+
self.service = instance[:-1] if instance.endswith("/") else instance
53
53
+
47
54
LOGGER.info("Verifying %s credentails...", self.service)
48
48
-
responce = requests.get(f"{self.service}/api/v1/accounts/verify_credentials", headers={
49
49
-
'Authorization': f'Bearer {self.token}'
50
50
-
})
55
55
+
responce = requests.get(
56
56
+
f"{self.service}/api/v1/accounts/verify_credentials",
57
57
+
headers={"Authorization": f"Bearer {self.token}"},
58
58
+
)
51
59
if responce.status_code != 200:
52
60
LOGGER.error("Failed to validate user credentials!")
53
61
responce.raise_for_status()
···
55
63
self.user_id: str = responce.json()["id"]
56
64
57
65
LOGGER.info("Getting %s configuration...", self.service)
58
58
-
responce = requests.get(f"{self.service}/api/v1/instance", headers={
59
59
-
'Authorization': f'Bearer {self.token}'
60
60
-
})
66
66
+
responce = requests.get(
67
67
+
f"{self.service}/api/v1/instance",
68
68
+
headers={"Authorization": f"Bearer {self.token}"},
69
69
+
)
61
70
if responce.status_code != 200:
62
71
LOGGER.error("Failed to get instance info!")
63
72
responce.raise_for_status()
64
73
return
65
65
-
74
74
+
66
75
instance_info: dict = responce.json()
67
67
-
configuration: dict = instance_info['configuration']
68
68
-
69
69
-
statuses_config: dict = configuration.get('statuses', {})
70
70
-
self.max_characters: int = statuses_config.get('max_characters', 500)
71
71
-
self.max_media_attachments: int = statuses_config.get('max_media_attachments', 4)
72
72
-
self.characters_reserved_per_url: int = statuses_config.get('characters_reserved_per_url', 23)
73
73
-
74
74
-
media_config: dict = configuration.get('media_attachments', {})
75
75
-
self.image_size_limit: int = media_config.get('image_size_limit', 16777216)
76
76
-
self.video_size_limit: int = media_config.get('video_size_limit', 103809024)
77
77
-
self.supported_mime_types: list[str] = media_config.get('supported_mime_types', POSSIBLE_MIMES)
78
78
-
76
76
+
configuration: dict = instance_info["configuration"]
77
77
+
78
78
+
statuses_config: dict = configuration.get("statuses", {})
79
79
+
self.max_characters: int = statuses_config.get("max_characters", 500)
80
80
+
self.max_media_attachments: int = statuses_config.get(
81
81
+
"max_media_attachments", 4
82
82
+
)
83
83
+
self.characters_reserved_per_url: int = statuses_config.get(
84
84
+
"characters_reserved_per_url", 23
85
85
+
)
86
86
+
87
87
+
media_config: dict = configuration.get("media_attachments", {})
88
88
+
self.image_size_limit: int = media_config.get("image_size_limit", 16777216)
89
89
+
self.video_size_limit: int = media_config.get("video_size_limit", 103809024)
90
90
+
self.supported_mime_types: list[str] = media_config.get(
91
91
+
"supported_mime_types", POSSIBLE_MIMES
92
92
+
)
93
93
+
79
94
# *oma: max post chars
80
80
-
max_toot_chars = instance_info.get('max_toot_chars')
95
95
+
max_toot_chars = instance_info.get("max_toot_chars")
81
96
if max_toot_chars:
82
97
self.max_characters: int = max_toot_chars
83
83
-
98
98
+
84
99
# *oma: max upload limit
85
85
-
upload_limit = instance_info.get('upload_limit')
100
100
+
upload_limit = instance_info.get("upload_limit")
86
101
if upload_limit:
87
102
self.image_size_limit: int = upload_limit
88
103
self.video_size_limit: int = upload_limit
89
89
-
104
104
+
90
105
# chuckya: supported text types
91
91
-
chuckya_text_mimes: list[str] = statuses_config.get('supported_mime_types', [])
106
106
+
chuckya_text_mimes: list[str] = statuses_config.get("supported_mime_types", [])
92
107
self.text_format = next(
93
93
-
(mime for mime in TEXT_MIMES if mime in (chuckya_text_mimes)),
94
94
-
'text/plain'
108
108
+
(mime for mime in TEXT_MIMES if mime in (chuckya_text_mimes)), "text/plain"
95
109
)
96
96
-
110
110
+
97
111
# *oma ext: supported text types
98
98
-
pleroma = instance_info.get('pleroma')
112
112
+
pleroma = instance_info.get("pleroma")
99
113
if pleroma:
100
100
-
post_formats: list[str] = pleroma.get('metadata', {}).get('post_formats', [])
114
114
+
post_formats: list[str] = pleroma.get("metadata", {}).get(
115
115
+
"post_formats", []
116
116
+
)
101
117
self.text_format = next(
102
102
-
(mime for mime in TEXT_MIMES if mime in post_formats),
103
103
-
self.text_format
118
118
+
(mime for mime in TEXT_MIMES if mime in post_formats), self.text_format
104
119
)
105
105
-
120
120
+
106
121
def upload_media(self, attachments: list[MediaInfo]) -> list[str] | None:
107
122
for a in attachments:
108
108
-
if a.mime.startswith('image/') and len(a.io) > self.image_size_limit:
123
123
+
if a.mime.startswith("image/") and len(a.io) > self.image_size_limit:
109
124
return None
110
110
-
111
111
-
if a.mime.startswith('video/') and len(a.io) > self.video_size_limit:
125
125
+
126
126
+
if a.mime.startswith("video/") and len(a.io) > self.video_size_limit:
112
127
return None
113
113
-
114
114
-
if not a.mime.startswith('image/') and not a.mime.startswith('video/'):
128
128
+
129
129
+
if not a.mime.startswith("image/") and not a.mime.startswith("video/"):
115
130
if len(a.io) > 7_000_000:
116
131
return None
117
117
-
132
132
+
118
133
uploads: list[dict] = []
119
134
for a in attachments:
120
135
data = {}
121
136
if a.alt:
122
122
-
data['description'] = a.alt
123
123
-
124
124
-
req = requests.post(f"{self.service}/api/v2/media", headers= {
125
125
-
'Authorization': f'Bearer {self.token}'
126
126
-
}, files={'file': (a.name, a.io, a.mime)}, data=data)
127
127
-
137
137
+
data["description"] = a.alt
138
138
+
139
139
+
req = requests.post(
140
140
+
f"{self.service}/api/v2/media",
141
141
+
headers={"Authorization": f"Bearer {self.token}"},
142
142
+
files={"file": (a.name, a.io, a.mime)},
143
143
+
data=data,
144
144
+
)
145
145
+
128
146
if req.status_code == 200:
129
129
-
LOGGER.info("Uploaded %s! (%s)", a.name, req.json()['id'])
130
130
-
uploads.append({
131
131
-
'done': True,
132
132
-
'id': req.json()['id']
133
133
-
})
147
147
+
LOGGER.info("Uploaded %s! (%s)", a.name, req.json()["id"])
148
148
+
uploads.append({"done": True, "id": req.json()["id"]})
134
149
elif req.status_code == 202:
135
150
LOGGER.info("Waiting for %s to process!", a.name)
136
136
-
uploads.append({
137
137
-
'done': False,
138
138
-
'id': req.json()['id']
139
139
-
})
151
151
+
uploads.append({"done": False, "id": req.json()["id"]})
140
152
else:
141
153
LOGGER.error("Failed to upload %s! %s", a.name, req.text)
142
154
req.raise_for_status()
143
143
-
144
144
-
while any([not val['done'] for val in uploads]):
155
155
+
156
156
+
while any([not val["done"] for val in uploads]):
145
157
LOGGER.info("Waiting for media to process...")
146
158
time.sleep(3)
147
159
for media in uploads:
148
148
-
if media['done']:
160
160
+
if media["done"]:
149
161
continue
150
150
-
151
151
-
reqs = requests.get(f'{self.service}/api/v1/media/{media['id']}', headers={
152
152
-
'Authorization': f'Bearer {self.token}'
153
153
-
})
154
154
-
162
162
+
163
163
+
reqs = requests.get(
164
164
+
f"{self.service}/api/v1/media/{media['id']}",
165
165
+
headers={"Authorization": f"Bearer {self.token}"},
166
166
+
)
167
167
+
155
168
if reqs.status_code == 206:
156
169
continue
157
157
-
170
170
+
158
171
if reqs.status_code == 200:
159
159
-
media['done'] = True
172
172
+
media["done"] = True
160
173
continue
161
174
reqs.raise_for_status()
162
162
-
163
163
-
return [val['id'] for val in uploads]
175
175
+
176
176
+
return [val["id"] for val in uploads]
164
177
165
178
def token_to_string(self, tokens: list[cross.Token]) -> str | None:
166
166
-
p_text: str = ''
167
167
-
179
179
+
p_text: str = ""
180
180
+
168
181
for token in tokens:
169
182
if isinstance(token, cross.TextToken):
170
183
p_text += token.text
171
184
elif isinstance(token, cross.TagToken):
172
172
-
p_text += '#' + token.tag
185
185
+
p_text += "#" + token.tag
173
186
elif isinstance(token, cross.LinkToken):
174
187
if canonical_label(token.label, token.href):
175
188
p_text += token.href
176
189
else:
177
177
-
if self.text_format == 'text/plain':
178
178
-
p_text += f'{token.label} ({token.href})'
179
179
-
elif self.text_format in {'text/x.misskeymarkdown', 'text/markdown'}:
180
180
-
p_text += f'[{token.label}]({token.href})'
190
190
+
if self.text_format == "text/plain":
191
191
+
p_text += f"{token.label} ({token.href})"
192
192
+
elif self.text_format in {
193
193
+
"text/x.misskeymarkdown",
194
194
+
"text/markdown",
195
195
+
}:
196
196
+
p_text += f"[{token.label}]({token.href})"
181
197
else:
182
198
return None
183
183
-
199
199
+
184
200
return p_text
185
201
186
202
def split_tokens_media(self, tokens: list[cross.Token], media: list[MediaInfo]):
187
187
-
split_tokens = cross.split_tokens(tokens, self.max_characters, self.characters_reserved_per_url)
203
203
+
split_tokens = cross.split_tokens(
204
204
+
tokens, self.max_characters, self.characters_reserved_per_url
205
205
+
)
188
206
post_text: list[str] = []
189
189
-
207
207
+
190
208
for block in split_tokens:
191
209
baked_text = self.token_to_string(block)
192
192
-
210
210
+
193
211
if baked_text is None:
194
212
return None
195
213
post_text.append(baked_text)
196
196
-
214
214
+
197
215
if not post_text:
198
198
-
post_text = ['']
199
199
-
200
200
-
posts: list[dict] = [{"text": post_text, "attachments": []} for post_text in post_text]
216
216
+
post_text = [""]
217
217
+
218
218
+
posts: list[dict] = [
219
219
+
{"text": post_text, "attachments": []} for post_text in post_text
220
220
+
]
201
221
available_indices: list[int] = list(range(len(posts)))
202
202
-
222
222
+
203
223
current_image_post_idx: int | None = None
204
204
-
224
224
+
205
225
def make_blank_post() -> dict:
206
206
-
return {
207
207
-
"text": '',
208
208
-
"attachments": []
209
209
-
}
210
210
-
226
226
+
return {"text": "", "attachments": []}
227
227
+
211
228
def pop_next_empty_index() -> int:
212
229
if available_indices:
213
230
return available_indices.pop(0)
···
215
232
new_idx = len(posts)
216
233
posts.append(make_blank_post())
217
234
return new_idx
218
218
-
235
235
+
219
236
for att in media:
220
237
if (
221
238
current_image_post_idx is not None
222
222
-
and len(posts[current_image_post_idx]["attachments"]) < self.max_media_attachments
239
239
+
and len(posts[current_image_post_idx]["attachments"])
240
240
+
< self.max_media_attachments
223
241
):
224
242
posts[current_image_post_idx]["attachments"].append(att)
225
243
else:
226
244
idx = pop_next_empty_index()
227
245
posts[idx]["attachments"].append(att)
228
246
current_image_post_idx = idx
229
229
-
247
247
+
230
248
result: list[tuple[str, list[MediaInfo]]] = []
231
231
-
249
249
+
232
250
for p in posts:
233
233
-
result.append((p['text'], p["attachments"]))
234
234
-
251
251
+
result.append((p["text"], p["attachments"]))
252
252
+
235
253
return result
236
236
-
254
254
+
237
255
def accept_post(self, post: cross.Post):
238
256
parent_id = post.get_parent_id()
239
239
-
257
257
+
240
258
new_root_id: int | None = None
241
259
new_parent_id: int | None = None
242
242
-
260
260
+
243
261
reply_ref: str | None = None
244
262
if parent_id:
245
263
thread_tuple = database.find_mapped_thread(
···
248
266
self.input.user_id,
249
267
self.input.service,
250
268
self.user_id,
251
251
-
self.service
269
269
+
self.service,
252
270
)
253
253
-
271
271
+
254
272
if not thread_tuple:
255
273
LOGGER.error("Failed to find thread tuple in the database!")
256
274
return None
257
257
-
275
275
+
258
276
_, reply_ref, new_root_id, new_parent_id = thread_tuple
259
259
-
277
277
+
260
278
lang: str
261
279
if post.get_languages():
262
280
lang = post.get_languages()[0]
263
281
else:
264
264
-
lang = 'en'
265
265
-
282
282
+
lang = "en"
283
283
+
266
284
post_tokens = post.get_tokens()
267
285
if post.get_text_type() == "text/x.misskeymarkdown":
268
286
post_tokens, status = mfm_util.strip_mfm(post_tokens)
269
287
post_url = post.get_post_url()
270
288
if status and post_url:
271
271
-
post_tokens.append(cross.TextToken('\n'))
272
272
-
post_tokens.append(cross.LinkToken(post_url, "[Post contains MFM, see original]"))
273
273
-
289
289
+
post_tokens.append(cross.TextToken("\n"))
290
290
+
post_tokens.append(
291
291
+
cross.LinkToken(post_url, "[Post contains MFM, see original]")
292
292
+
)
293
293
+
274
294
raw_statuses = self.split_tokens_media(post_tokens, post.get_attachments())
275
295
if not raw_statuses:
276
296
LOGGER.error("Failed to split post into statuses?")
277
297
return None
278
298
baked_statuses = []
279
279
-
299
299
+
280
300
for status, raw_media in raw_statuses:
281
301
media: list[str] | None = None
282
302
if raw_media:
···
286
306
return None
287
307
baked_statuses.append((status, media))
288
308
continue
289
289
-
baked_statuses.append((status,[]))
290
290
-
309
309
+
baked_statuses.append((status, []))
310
310
+
291
311
created_statuses: list[str] = []
292
292
-
312
312
+
293
313
for status, media in baked_statuses:
294
314
payload = {
295
295
-
'status': status,
296
296
-
'media_ids': media or [],
297
297
-
'spoiler_text': post.get_spoiler() or '',
298
298
-
'visibility': self.options.get('visibility', 'public'),
299
299
-
'content_type': self.text_format,
300
300
-
'language': lang
315
315
+
"status": status,
316
316
+
"media_ids": media or [],
317
317
+
"spoiler_text": post.get_spoiler() or "",
318
318
+
"visibility": self.options.get("visibility", "public"),
319
319
+
"content_type": self.text_format,
320
320
+
"language": lang,
301
321
}
302
302
-
322
322
+
303
323
if media:
304
304
-
payload['sensitive'] = post.is_sensitive()
305
305
-
324
324
+
payload["sensitive"] = post.is_sensitive()
325
325
+
306
326
if post.get_spoiler():
307
307
-
payload['sensitive'] = True
308
308
-
327
327
+
payload["sensitive"] = True
328
328
+
309
329
if not status:
310
310
-
payload['status'] = '🖼️'
311
311
-
330
330
+
payload["status"] = "🖼️"
331
331
+
312
332
if reply_ref:
313
313
-
payload['in_reply_to_id'] = reply_ref
314
314
-
315
315
-
reqs = requests.post(f'{self.service}/api/v1/statuses', headers={
316
316
-
'Authorization': f'Bearer {self.token}',
317
317
-
'Content-Type': 'application/json'
318
318
-
}, json=payload)
319
319
-
333
333
+
payload["in_reply_to_id"] = reply_ref
334
334
+
335
335
+
reqs = requests.post(
336
336
+
f"{self.service}/api/v1/statuses",
337
337
+
headers={
338
338
+
"Authorization": f"Bearer {self.token}",
339
339
+
"Content-Type": "application/json",
340
340
+
},
341
341
+
json=payload,
342
342
+
)
343
343
+
320
344
if reqs.status_code != 200:
321
321
-
LOGGER.info("Failed to post status! %s - %s", reqs.status_code, reqs.text)
345
345
+
LOGGER.info(
346
346
+
"Failed to post status! %s - %s", reqs.status_code, reqs.text
347
347
+
)
322
348
reqs.raise_for_status()
323
323
-
324
324
-
reply_ref = reqs.json()['id']
349
349
+
350
350
+
reply_ref = reqs.json()["id"]
325
351
LOGGER.info("Created new status %s!", reply_ref)
326
326
-
327
327
-
created_statuses.append(reqs.json()['id'])
328
328
-
329
329
-
db_post = database.find_post(self.db, post.get_id(), self.input.user_id, self.input.service)
352
352
+
353
353
+
created_statuses.append(reqs.json()["id"])
354
354
+
355
355
+
db_post = database.find_post(
356
356
+
self.db, post.get_id(), self.input.user_id, self.input.service
357
357
+
)
330
358
assert db_post, "ghghghhhhh"
331
331
-
332
332
-
if new_root_id is None or new_parent_id is None:
359
359
+
360
360
+
if new_root_id is None or new_parent_id is None:
333
361
new_root_id = database.insert_post(
334
334
-
self.db,
335
335
-
created_statuses[0],
336
336
-
self.user_id,
337
337
-
self.service
362
362
+
self.db, created_statuses[0], self.user_id, self.service
338
363
)
339
364
new_parent_id = new_root_id
340
340
-
database.insert_mapping(self.db, db_post['id'], new_parent_id)
365
365
+
database.insert_mapping(self.db, db_post["id"], new_parent_id)
341
366
created_statuses = created_statuses[1:]
342
342
-
367
367
+
343
368
for db_id in created_statuses:
344
369
new_parent_id = database.insert_reply(
345
345
-
self.db,
346
346
-
db_id,
347
347
-
self.user_id,
348
348
-
self.service,
349
349
-
new_parent_id,
350
350
-
new_root_id
370
370
+
self.db, db_id, self.user_id, self.service, new_parent_id, new_root_id
351
371
)
352
352
-
database.insert_mapping(self.db, db_post['id'], new_parent_id)
353
353
-
372
372
+
database.insert_mapping(self.db, db_post["id"], new_parent_id)
373
373
+
354
374
def delete_post(self, identifier: str):
355
355
-
post = database.find_post(self.db, identifier, self.input.user_id, self.input.service)
375
375
+
post = database.find_post(
376
376
+
self.db, identifier, self.input.user_id, self.input.service
377
377
+
)
356
378
if not post:
357
379
return
358
358
-
359
359
-
mappings = database.find_mappings(self.db, post['id'], self.service, self.user_id)
380
380
+
381
381
+
mappings = database.find_mappings(
382
382
+
self.db, post["id"], self.service, self.user_id
383
383
+
)
360
384
for mapping in mappings[::-1]:
361
385
LOGGER.info("Deleting '%s'...", mapping[0])
362
362
-
requests.delete(f'{self.service}/api/v1/statuses/{mapping[0]}', headers={
363
363
-
'Authorization': f'Bearer {self.token}'
364
364
-
})
386
386
+
requests.delete(
387
387
+
f"{self.service}/api/v1/statuses/{mapping[0]}",
388
388
+
headers={"Authorization": f"Bearer {self.token}"},
389
389
+
)
365
390
database.delete_post(self.db, mapping[0], self.service, self.user_id)
366
366
-
391
391
+
367
392
def accept_repost(self, repost_id: str, reposted_id: str):
368
393
repost = self.__delete_repost(repost_id)
369
394
if not repost:
370
395
return None
371
371
-
372
372
-
reposted = database.find_post(self.db, reposted_id, self.input.user_id, self.input.service)
396
396
+
397
397
+
reposted = database.find_post(
398
398
+
self.db, reposted_id, self.input.user_id, self.input.service
399
399
+
)
373
400
if not reposted:
374
401
return
375
375
-
376
376
-
mappings = database.find_mappings(self.db, reposted['id'], self.service, self.user_id)
402
402
+
403
403
+
mappings = database.find_mappings(
404
404
+
self.db, reposted["id"], self.service, self.user_id
405
405
+
)
377
406
if mappings:
378
378
-
rsp = requests.post(f'{self.service}/api/v1/statuses/{mappings[0][0]}/reblog', headers={
379
379
-
'Authorization': f'Bearer {self.token}'
380
380
-
})
381
381
-
407
407
+
rsp = requests.post(
408
408
+
f"{self.service}/api/v1/statuses/{mappings[0][0]}/reblog",
409
409
+
headers={"Authorization": f"Bearer {self.token}"},
410
410
+
)
411
411
+
382
412
if rsp.status_code != 200:
383
383
-
LOGGER.error("Failed to boost status! status_code: %s, msg: %s", rsp.status_code, rsp.content)
413
413
+
LOGGER.error(
414
414
+
"Failed to boost status! status_code: %s, msg: %s",
415
415
+
rsp.status_code,
416
416
+
rsp.content,
417
417
+
)
384
418
return
385
385
-
419
419
+
386
420
internal_id = database.insert_repost(
387
387
-
self.db,
388
388
-
rsp.json()['id'],
389
389
-
reposted['id'],
390
390
-
self.user_id,
391
391
-
self.service)
392
392
-
database.insert_mapping(self.db, repost['id'], internal_id)
393
393
-
421
421
+
self.db, rsp.json()["id"], reposted["id"], self.user_id, self.service
422
422
+
)
423
423
+
database.insert_mapping(self.db, repost["id"], internal_id)
424
424
+
394
425
def __delete_repost(self, repost_id: str) -> dict | None:
395
395
-
repost = database.find_post(self.db, repost_id, self.input.user_id, self.input.service)
426
426
+
repost = database.find_post(
427
427
+
self.db, repost_id, self.input.user_id, self.input.service
428
428
+
)
396
429
if not repost:
397
430
return None
398
398
-
399
399
-
mappings = database.find_mappings(self.db, repost['id'], self.service, self.user_id)
400
400
-
reposted_mappings = database.find_mappings(self.db, repost['reposted_id'], self.service, self.user_id)
431
431
+
432
432
+
mappings = database.find_mappings(
433
433
+
self.db, repost["id"], self.service, self.user_id
434
434
+
)
435
435
+
reposted_mappings = database.find_mappings(
436
436
+
self.db, repost["reposted_id"], self.service, self.user_id
437
437
+
)
401
438
if mappings and reposted_mappings:
402
439
LOGGER.info("Deleting '%s'...", mappings[0][0])
403
403
-
requests.post(f'{self.service}/api/v1/statuses/{reposted_mappings[0][0]}/unreblog', headers={
404
404
-
'Authorization': f'Bearer {self.token}'
405
405
-
})
440
440
+
requests.post(
441
441
+
f"{self.service}/api/v1/statuses/{reposted_mappings[0][0]}/unreblog",
442
442
+
headers={"Authorization": f"Bearer {self.token}"},
443
443
+
)
406
444
database.delete_post(self.db, mappings[0][0], self.user_id, self.service)
407
445
return repost
408
408
-
446
446
+
409
447
def delete_repost(self, repost_id: str):
410
410
-
self.__delete_repost(repost_id)
448
448
+
self.__delete_repost(repost_id)
+26
-17
misskey/common.py
reviewed
···
1
1
import cross
2
2
from util.media import MediaInfo
3
3
4
4
+
4
5
class MisskeyPost(cross.Post):
5
5
-
def __init__(self, instance_url: str, note: dict, tokens: list[cross.Token], files: list[MediaInfo]) -> None:
6
6
+
def __init__(
7
7
+
self,
8
8
+
instance_url: str,
9
9
+
note: dict,
10
10
+
tokens: list[cross.Token],
11
11
+
files: list[MediaInfo],
12
12
+
) -> None:
6
13
super().__init__()
7
14
self.note = note
8
8
-
self.id = note['id']
9
9
-
self.parent_id = note.get('replyId')
15
15
+
self.id = note["id"]
16
16
+
self.parent_id = note.get("replyId")
10
17
self.tokens = tokens
11
11
-
self.timestamp = note['createdAt']
18
18
+
self.timestamp = note["createdAt"]
12
19
self.media_attachments = files
13
13
-
self.spoiler = note.get('cw')
14
14
-
self.sensitive = any([a.get('isSensitive', False) for a in note.get('files', [])])
15
15
-
self.url = instance_url + '/notes/' + note['id']
16
16
-
20
20
+
self.spoiler = note.get("cw")
21
21
+
self.sensitive = any(
22
22
+
[a.get("isSensitive", False) for a in note.get("files", [])]
23
23
+
)
24
24
+
self.url = instance_url + "/notes/" + note["id"]
25
25
+
17
26
def get_id(self) -> str:
18
27
return self.id
19
19
-
28
28
+
20
29
def get_parent_id(self) -> str | None:
21
30
return self.parent_id
22
22
-
31
31
+
23
32
def get_tokens(self) -> list[cross.Token]:
24
33
return self.tokens
25
34
26
35
def get_text_type(self) -> str:
27
36
return "text/x.misskeymarkdown"
28
28
-
37
37
+
29
38
def get_timestamp(self) -> str:
30
39
return self.timestamp
31
31
-
40
40
+
32
41
def get_attachments(self) -> list[MediaInfo]:
33
42
return self.media_attachments
34
34
-
43
43
+
35
44
def get_spoiler(self) -> str | None:
36
45
return self.spoiler
37
37
-
46
46
+
38
47
def get_languages(self) -> list[str]:
39
48
return []
40
40
-
49
49
+
41
50
def is_sensitive(self) -> bool:
42
51
return self.sensitive or (self.spoiler is not None)
43
43
-
52
52
+
44
53
def get_post_url(self) -> str | None:
45
45
-
return self.url
54
54
+
return self.url
+115
-92
misskey/input.py
reviewed
···
1
1
-
import requests, websockets
2
1
import asyncio
3
3
-
import json, uuid
2
2
+
import json
4
3
import re
4
4
+
import uuid
5
5
+
from typing import Any, Callable
5
6
6
6
-
from misskey.common import MisskeyPost
7
7
+
import requests
8
8
+
import websockets
7
9
8
8
-
import cross, util.database as database
10
10
+
import cross
11
11
+
import util.database as database
9
12
import util.md_util as md_util
13
13
+
from misskey.common import MisskeyPost
10
14
from util.media import MediaInfo, download_media
11
15
from util.util import LOGGER, as_envvar
12
16
13
13
-
from typing import Callable, Any
14
14
-
15
15
-
ALLOWED_VISIBILITY = ['public', 'home']
16
16
-
17
17
-
class MisskeyInputOptions():
17
17
+
ALLOWED_VISIBILITY = ["public", "home"]
18
18
+
19
19
+
20
20
+
class MisskeyInputOptions:
18
21
def __init__(self, o: dict) -> None:
19
22
self.allowed_visibility = ALLOWED_VISIBILITY
20
20
-
self.filters = [re.compile(f) for f in o.get('regex_filters', [])]
21
21
-
22
22
-
allowed_visibility = o.get('allowed_visibility')
23
23
+
self.filters = [re.compile(f) for f in o.get("regex_filters", [])]
24
24
+
25
25
+
allowed_visibility = o.get("allowed_visibility")
23
26
if allowed_visibility is not None:
24
27
if any([v not in ALLOWED_VISIBILITY for v in allowed_visibility]):
25
25
-
raise ValueError(f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}")
28
28
+
raise ValueError(
29
29
+
f"'allowed_visibility' only accepts {', '.join(ALLOWED_VISIBILITY)}, got: {allowed_visibility}"
30
30
+
)
26
31
self.allowed_visibility = allowed_visibility
32
32
+
27
33
28
34
class MisskeyInput(cross.Input):
29
35
def __init__(self, settings: dict, db: cross.DataBaseWorker) -> None:
30
30
-
self.options = MisskeyInputOptions(settings.get('options', {}))
31
31
-
self.token = as_envvar(settings.get('token')) or (_ for _ in ()).throw(ValueError("'token' is required"))
32
32
-
instance: str = as_envvar(settings.get('instance')) or (_ for _ in ()).throw(ValueError("'instance' is required"))
33
33
-
34
34
-
service = instance[:-1] if instance.endswith('/') else instance
35
35
-
36
36
+
self.options = MisskeyInputOptions(settings.get("options", {}))
37
37
+
self.token = as_envvar(settings.get("token")) or (_ for _ in ()).throw(
38
38
+
ValueError("'token' is required")
39
39
+
)
40
40
+
instance: str = as_envvar(settings.get("instance")) or (_ for _ in ()).throw(
41
41
+
ValueError("'instance' is required")
42
42
+
)
43
43
+
44
44
+
service = instance[:-1] if instance.endswith("/") else instance
45
45
+
36
46
LOGGER.info("Verifying %s credentails...", service)
37
37
-
responce = requests.post(f"{instance}/api/i", json={ 'i': self.token }, headers={
38
38
-
"Content-Type": "application/json"
39
39
-
})
47
47
+
responce = requests.post(
48
48
+
f"{instance}/api/i",
49
49
+
json={"i": self.token},
50
50
+
headers={"Content-Type": "application/json"},
51
51
+
)
40
52
if responce.status_code != 200:
41
53
LOGGER.error("Failed to validate user credentials!")
42
54
responce.raise_for_status()
43
55
return
44
44
-
56
56
+
45
57
super().__init__(service, responce.json()["id"], settings, db)
46
46
-
58
58
+
47
59
def _on_note(self, outputs: list[cross.Output], note: dict):
48
48
-
if note['userId'] != self.user_id:
60
60
+
if note["userId"] != self.user_id:
49
61
return
50
50
-
51
51
-
if note.get('visibility') not in self.options.allowed_visibility:
52
52
-
LOGGER.info("Skipping '%s'! '%s' visibility..", note['id'], note.get('visibility'))
62
62
+
63
63
+
if note.get("visibility") not in self.options.allowed_visibility:
64
64
+
LOGGER.info(
65
65
+
"Skipping '%s'! '%s' visibility..", note["id"], note.get("visibility")
66
66
+
)
53
67
return
54
54
-
68
68
+
55
69
# TODO polls not supported on bsky. maybe 3rd party? skip for now
56
70
# we don't handle reblogs. possible with bridgy(?) and self
57
57
-
if note.get('poll'):
58
58
-
LOGGER.info("Skipping '%s'! Contains a poll..", note['id'])
71
71
+
if note.get("poll"):
72
72
+
LOGGER.info("Skipping '%s'! Contains a poll..", note["id"])
59
73
return
60
60
-
61
61
-
renote: dict | None = note.get('renote')
74
74
+
75
75
+
renote: dict | None = note.get("renote")
62
76
if renote:
63
63
-
if note.get('text') is not None:
64
64
-
LOGGER.info("Skipping '%s'! Quote..", note['id'])
77
77
+
if note.get("text") is not None:
78
78
+
LOGGER.info("Skipping '%s'! Quote..", note["id"])
65
79
return
66
66
-
67
67
-
if renote.get('userId') != self.user_id:
68
68
-
LOGGER.info("Skipping '%s'! Reblog of other user..", note['id'])
80
80
+
81
81
+
if renote.get("userId") != self.user_id:
82
82
+
LOGGER.info("Skipping '%s'! Reblog of other user..", note["id"])
69
83
return
70
70
-
71
71
-
success = database.try_insert_repost(self.db, note['id'], renote['id'], self.user_id, self.service)
84
84
+
85
85
+
success = database.try_insert_repost(
86
86
+
self.db, note["id"], renote["id"], self.user_id, self.service
87
87
+
)
72
88
if not success:
73
73
-
LOGGER.info("Skipping '%s' as renoted note was not found in db!", note['id'])
89
89
+
LOGGER.info(
90
90
+
"Skipping '%s' as renoted note was not found in db!", note["id"]
91
91
+
)
74
92
return
75
75
-
93
93
+
76
94
for output in outputs:
77
77
-
output.accept_repost(note['id'], renote['id'])
95
95
+
output.accept_repost(note["id"], renote["id"])
78
96
return
79
79
-
80
80
-
reply_id: str | None = note.get('replyId')
97
97
+
98
98
+
reply_id: str | None = note.get("replyId")
81
99
if reply_id:
82
82
-
if note.get('reply', {}).get('userId') != self.user_id:
83
83
-
LOGGER.info("Skipping '%s'! Reply to other user..", note['id'])
100
100
+
if note.get("reply", {}).get("userId") != self.user_id:
101
101
+
LOGGER.info("Skipping '%s'! Reply to other user..", note["id"])
84
102
return
85
85
-
86
86
-
success = database.try_insert_post(self.db, note['id'], reply_id, self.user_id, self.service)
103
103
+
104
104
+
success = database.try_insert_post(
105
105
+
self.db, note["id"], reply_id, self.user_id, self.service
106
106
+
)
87
107
if not success:
88
88
-
LOGGER.info("Skipping '%s' as parent note was not found in db!", note['id'])
108
108
+
LOGGER.info("Skipping '%s' as parent note was not found in db!", note["id"])
89
109
return
90
90
-
91
91
-
mention_handles: dict = note.get('mentionHandles') or {}
92
92
-
tags: list[str] = note.get('tags') or []
93
93
-
110
110
+
111
111
+
mention_handles: dict = note.get("mentionHandles") or {}
112
112
+
tags: list[str] = note.get("tags") or []
113
113
+
94
114
handles: list[tuple[str, str]] = []
95
115
for key, value in mention_handles.items():
96
116
handles.append((value, value))
97
97
-
98
98
-
tokens = md_util.tokenize_markdown(note.get('text', ''), tags, handles)
117
117
+
118
118
+
tokens = md_util.tokenize_markdown(note.get("text", ""), tags, handles)
99
119
if not cross.test_filters(tokens, self.options.filters):
100
100
-
LOGGER.info("Skipping '%s'. Matched a filter!", note['id'])
120
120
+
LOGGER.info("Skipping '%s'. Matched a filter!", note["id"])
101
121
return
102
102
-
103
103
-
LOGGER.info("Crossposting '%s'...", note['id'])
104
104
-
122
122
+
123
123
+
LOGGER.info("Crossposting '%s'...", note["id"])
124
124
+
105
125
media_attachments: list[MediaInfo] = []
106
106
-
for attachment in note.get('files', []):
107
107
-
LOGGER.info("Downloading %s...", attachment['url'])
108
108
-
info = download_media(attachment['url'], attachment.get('comment') or '')
126
126
+
for attachment in note.get("files", []):
127
127
+
LOGGER.info("Downloading %s...", attachment["url"])
128
128
+
info = download_media(attachment["url"], attachment.get("comment") or "")
109
129
if not info:
110
110
-
LOGGER.error("Skipping '%s'. Failed to download media!", note['id'])
130
130
+
LOGGER.error("Skipping '%s'. Failed to download media!", note["id"])
111
131
return
112
132
media_attachments.append(info)
113
113
-
133
133
+
114
134
cross_post = MisskeyPost(self.service, note, tokens, media_attachments)
115
135
for output in outputs:
116
136
output.accept_post(cross_post)
117
117
-
137
137
+
118
138
def _on_delete(self, outputs: list[cross.Output], note: dict):
119
139
# TODO handle deletes
120
140
pass
121
121
-
141
141
+
122
142
def _on_message(self, outputs: list[cross.Output], data: dict):
123
123
-
124
124
-
if data['type'] == 'channel':
125
125
-
type: str = data['body']['type']
126
126
-
if type == 'note' or type == 'reply':
127
127
-
note_body = data['body']['body']
143
143
+
if data["type"] == "channel":
144
144
+
type: str = data["body"]["type"]
145
145
+
if type == "note" or type == "reply":
146
146
+
note_body = data["body"]["body"]
128
147
self._on_note(outputs, note_body)
129
148
return
130
130
-
149
149
+
131
150
pass
132
132
-
151
151
+
133
152
async def _send_keepalive(self, ws: websockets.WebSocketClientProtocol):
134
153
while ws.open:
135
154
try:
···
143
162
except Exception as e:
144
163
LOGGER.error(f"Error sending keepalive: {e}")
145
164
break
146
146
-
165
165
+
147
166
async def _subscribe_to_home(self, ws: websockets.WebSocketClientProtocol):
148
148
-
await ws.send(json.dumps({
149
149
-
"type": "connect",
150
150
-
"body": {
151
151
-
"channel": "homeTimeline",
152
152
-
"id": str(uuid.uuid4())
153
153
-
}
154
154
-
}))
167
167
+
await ws.send(
168
168
+
json.dumps(
169
169
+
{
170
170
+
"type": "connect",
171
171
+
"body": {"channel": "homeTimeline", "id": str(uuid.uuid4())},
172
172
+
}
173
173
+
)
174
174
+
)
155
175
LOGGER.info("Subscribed to 'homeTimeline' channel...")
156
156
-
157
157
-
158
158
-
async def listen(self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]):
159
159
-
streaming: str = f"wss://{self.service.split("://", 1)[1]}"
176
176
+
177
177
+
async def listen(
178
178
+
self, outputs: list[cross.Output], submit: Callable[[Callable[[], Any]], Any]
179
179
+
):
180
180
+
streaming: str = f"wss://{self.service.split('://', 1)[1]}"
160
181
url: str = f"{streaming}/streaming?i={self.token}"
161
161
-
162
162
-
async for ws in websockets.connect(url, extra_headers={"User-Agent": "XPost/0.0.3"}):
182
182
+
183
183
+
async for ws in websockets.connect(
184
184
+
url, extra_headers={"User-Agent": "XPost/0.0.3"}
185
185
+
):
163
186
try:
164
187
LOGGER.info("Listening to %s...", streaming)
165
188
await self._subscribe_to_home(ws)
166
166
-
189
189
+
167
190
async def listen_for_messages():
168
191
async for msg in ws:
169
192
# TODO listen to deletes somehow
170
193
submit(lambda: self._on_message(outputs, json.loads(msg)))
171
171
-
194
194
+
172
195
keepalive = asyncio.create_task(self._send_keepalive(ws))
173
196
listen = asyncio.create_task(listen_for_messages())
174
174
-
197
197
+
175
198
await asyncio.gather(keepalive, listen)
176
199
except websockets.ConnectionClosedError as e:
177
200
LOGGER.error(e, stack_info=True, exc_info=True)
178
201
LOGGER.info("Reconnecting to %s...", streaming)
179
179
-
continue
202
202
+
continue
+8
-5
misskey/mfm_util.py
reviewed
···
1
1
-
import re, cross
1
1
+
import re
2
2
+
3
3
+
import cross
4
4
+
5
5
+
MFM_PATTERN = re.compile(r"\$\[([^\[\]]+)\]")
2
6
3
3
-
MFM_PATTERN = re.compile(r'\$\[([^\[\]]+)\]')
4
7
5
8
def strip_mfm(tokens: list[cross.Token]) -> tuple[list[cross.Token], bool]:
6
9
modified = False
···
22
25
23
26
return tokens, modified
24
27
28
28
+
25
29
def __strip_mfm(text: str) -> str:
26
30
def match_contents(match: re.Match[str]):
27
31
content = match.group(1).strip()
28
28
-
parts = content.split(' ', 1)
29
29
-
return parts[1] if len(parts) > 1 else ''
32
32
+
parts = content.split(" ", 1)
33
33
+
return parts[1] if len(parts) > 1 else ""
30
34
31
35
while MFM_PATTERN.search(text):
32
36
text = MFM_PATTERN.sub(match_contents, text)
33
37
34
38
return text
35
35
-
+118
-67
util/database.py
reviewed
···
1
1
+
import json
2
2
+
import queue
1
3
import sqlite3
2
2
-
from concurrent.futures import Future
3
4
import threading
4
4
-
import queue
5
5
-
import json
5
5
+
from concurrent.futures import Future
6
6
7
7
-
class DataBaseWorker():
7
7
+
8
8
+
class DataBaseWorker:
8
9
def __init__(self, database: str) -> None:
9
10
super(DataBaseWorker, self).__init__()
10
11
self.database = database
···
14
15
self.conn = sqlite3.connect(self.database, check_same_thread=False)
15
16
self.lock = threading.Lock()
16
17
self.thread.start()
17
17
-
18
18
+
18
19
def _run(self):
19
20
while not self.shutdown_event.is_set():
20
21
try:
···
29
30
self.queue.task_done()
30
31
except queue.Empty:
31
32
continue
32
32
-
33
33
-
def execute(self, sql: str, params = ()):
33
33
+
34
34
+
def execute(self, sql: str, params=()):
34
35
def task(conn: sqlite3.Connection):
35
36
cursor = conn.execute(sql, params)
36
37
conn.commit()
37
38
return cursor.fetchall()
38
38
-
39
39
+
39
40
future = Future()
40
41
self.queue.put((task, future))
41
42
return future.result()
42
42
-
43
43
+
43
44
def close(self):
44
45
self.shutdown_event.set()
45
46
self.thread.join()
46
47
with self.lock:
47
48
self.conn.close()
49
49
+
48
50
49
51
def try_insert_repost(
50
52
db: DataBaseWorker,
51
53
post_id: str,
52
54
reposted_id: str,
53
55
input_user: str,
54
54
-
input_service: str) -> bool:
55
55
-
56
56
+
input_service: str,
57
57
+
) -> bool:
56
58
reposted = find_post(db, reposted_id, input_user, input_service)
57
59
if not reposted:
58
60
return False
59
59
-
60
60
-
insert_repost(db, post_id, reposted['id'], input_user, input_service)
61
61
+
62
62
+
insert_repost(db, post_id, reposted["id"], input_user, input_service)
61
63
return True
62
62
-
64
64
+
63
65
64
66
def try_insert_post(
65
65
-
db: DataBaseWorker,
67
67
+
db: DataBaseWorker,
66
68
post_id: str,
67
69
in_reply: str | None,
68
70
input_user: str,
69
69
-
input_service: str) -> bool:
71
71
+
input_service: str,
72
72
+
) -> bool:
70
73
root_id = None
71
74
parent_id = None
72
72
-
75
75
+
73
76
if in_reply:
74
77
parent_post = find_post(db, in_reply, input_user, input_service)
75
78
if not parent_post:
76
79
return False
77
77
-
78
78
-
root_id = parent_post['id']
80
80
+
81
81
+
root_id = parent_post["id"]
79
82
parent_id = root_id
80
80
-
if parent_post['root_id']:
81
81
-
root_id = parent_post['root_id']
82
82
-
83
83
+
if parent_post["root_id"]:
84
84
+
root_id = parent_post["root_id"]
85
85
+
83
86
if root_id and parent_id:
84
84
-
insert_reply(db,post_id, input_user, input_service, parent_id, root_id)
87
87
+
insert_reply(db, post_id, input_user, input_service, parent_id, root_id)
85
88
else:
86
89
insert_post(db, post_id, input_user, input_service)
87
87
-
90
90
+
88
91
return True
89
92
90
90
-
def insert_repost(db: DataBaseWorker, identifier: str, reposted_id: int, user_id: str, serivce: str) -> int:
93
93
+
94
94
+
def insert_repost(
95
95
+
db: DataBaseWorker, identifier: str, reposted_id: int, user_id: str, serivce: str
96
96
+
) -> int:
91
97
db.execute(
92
98
"""
93
99
INSERT INTO posts (user_id, service, identifier, reposted_id)
94
100
VALUES (?, ?, ?, ?);
95
95
-
""", (user_id, serivce, identifier, reposted_id))
101
101
+
""",
102
102
+
(user_id, serivce, identifier, reposted_id),
103
103
+
)
96
104
return db.execute("SELECT last_insert_rowid();", ())[0][0]
105
105
+
97
106
98
107
def insert_post(db: DataBaseWorker, identifier: str, user_id: str, serivce: str) -> int:
99
108
db.execute(
100
109
"""
101
110
INSERT INTO posts (user_id, service, identifier)
102
111
VALUES (?, ?, ?);
103
103
-
""", (user_id, serivce, identifier))
112
112
+
""",
113
113
+
(user_id, serivce, identifier),
114
114
+
)
104
115
return db.execute("SELECT last_insert_rowid();", ())[0][0]
105
116
106
106
-
def insert_reply(db: DataBaseWorker, identifier: str, user_id: str, serivce: str, parent: int, root: int) -> int:
117
117
+
118
118
+
def insert_reply(
119
119
+
db: DataBaseWorker,
120
120
+
identifier: str,
121
121
+
user_id: str,
122
122
+
serivce: str,
123
123
+
parent: int,
124
124
+
root: int,
125
125
+
) -> int:
107
126
db.execute(
108
127
"""
109
128
INSERT INTO posts (user_id, service, identifier, parent_id, root_id)
110
129
VALUES (?, ?, ?, ?, ?);
111
111
-
""", (user_id, serivce, identifier, parent, root))
130
130
+
""",
131
131
+
(user_id, serivce, identifier, parent, root),
132
132
+
)
112
133
return db.execute("SELECT last_insert_rowid();", ())[0][0]
113
134
135
135
+
114
136
def insert_mapping(db: DataBaseWorker, original: int, mapped: int):
115
115
-
db.execute("""
137
137
+
db.execute(
138
138
+
"""
116
139
INSERT INTO mappings (original_post_id, mapped_post_id)
117
140
VALUES (?, ?);
118
118
-
""", (original, mapped))
141
141
+
""",
142
142
+
(original, mapped),
143
143
+
)
144
144
+
119
145
120
146
def delete_post(db: DataBaseWorker, identifier: str, user_id: str, serivce: str):
121
147
db.execute(
···
124
150
WHERE identifier = ?
125
151
AND service = ?
126
152
AND user_id = ?
127
127
-
""", (identifier, serivce, user_id))
128
128
-
153
153
+
""",
154
154
+
(identifier, serivce, user_id),
155
155
+
)
156
156
+
157
157
+
129
158
def fetch_data(db: DataBaseWorker, identifier: str, user_id: str, service: str) -> dict:
130
159
result = db.execute(
131
160
"""
···
134
163
WHERE identifier = ?
135
164
AND user_id = ?
136
165
AND service = ?
137
137
-
""", (identifier, user_id, service))
166
166
+
""",
167
167
+
(identifier, user_id, service),
168
168
+
)
138
169
if not result or not result[0]:
139
170
return {}
140
171
return json.loads(result[0][0])
141
172
142
142
-
def store_data(db: DataBaseWorker, identifier: str, user_id: str, service: str, extra_data: dict) -> None:
173
173
+
174
174
+
def store_data(
175
175
+
db: DataBaseWorker, identifier: str, user_id: str, service: str, extra_data: dict
176
176
+
) -> None:
143
177
db.execute(
144
178
"""
145
179
UPDATE posts
···
148
182
AND user_id = ?
149
183
AND service = ?
150
184
""",
151
151
-
(json.dumps(extra_data), identifier, user_id, service)
185
185
+
(json.dumps(extra_data), identifier, user_id, service),
152
186
)
153
187
154
154
-
def find_mappings(db: DataBaseWorker, original_post: int, service: str, user_id: str) -> list[str]:
188
188
+
189
189
+
def find_mappings(
190
190
+
db: DataBaseWorker, original_post: int, service: str, user_id: str
191
191
+
) -> list[str]:
155
192
return db.execute(
156
193
"""
157
194
SELECT p.identifier
···
163
200
AND p.user_id = ?
164
201
ORDER BY p.id;
165
202
""",
166
166
-
(original_post, service, user_id))
167
167
-
203
203
+
(original_post, service, user_id),
204
204
+
)
205
205
+
206
206
+
168
207
def find_post_by_id(db: DataBaseWorker, id: int) -> dict | None:
169
208
result = db.execute(
170
209
"""
171
210
SELECT user_id, service, identifier, parent_id, root_id, reposted_id
172
211
FROM posts
173
212
WHERE id = ?
174
174
-
""", (id,))
213
213
+
""",
214
214
+
(id,),
215
215
+
)
175
216
if not result:
176
217
return None
177
218
user_id, service, identifier, parent_id, root_id, reposted_id = result[0]
178
219
return {
179
179
-
'user_id': user_id,
180
180
-
'service': service,
181
181
-
'identifier': identifier,
182
182
-
'parent_id': parent_id,
183
183
-
'root_id': root_id,
184
184
-
'reposted_id': reposted_id
220
220
+
"user_id": user_id,
221
221
+
"service": service,
222
222
+
"identifier": identifier,
223
223
+
"parent_id": parent_id,
224
224
+
"root_id": root_id,
225
225
+
"reposted_id": reposted_id,
185
226
}
186
227
187
187
-
def find_post(db: DataBaseWorker, identifier: str, user_id: str, service: str) -> dict | None:
228
228
+
229
229
+
def find_post(
230
230
+
db: DataBaseWorker, identifier: str, user_id: str, service: str
231
231
+
) -> dict | None:
188
232
result = db.execute(
189
233
"""
190
234
SELECT id, parent_id, root_id, reposted_id
···
192
236
WHERE identifier = ?
193
237
AND user_id = ?
194
238
AND service = ?
195
195
-
""", (identifier, user_id, service))
239
239
+
""",
240
240
+
(identifier, user_id, service),
241
241
+
)
196
242
if not result:
197
243
return None
198
244
id, parent_id, root_id, reposted_id = result[0]
199
245
return {
200
200
-
'id': id,
201
201
-
'parent_id': parent_id,
202
202
-
'root_id': root_id,
203
203
-
'reposted_id': reposted_id
246
246
+
"id": id,
247
247
+
"parent_id": parent_id,
248
248
+
"root_id": root_id,
249
249
+
"reposted_id": reposted_id,
204
250
}
251
251
+
205
252
206
253
def find_mapped_thread(
207
207
-
db: DataBaseWorker,
254
254
+
db: DataBaseWorker,
208
255
parent_id: str,
209
256
input_user: str,
210
257
input_service: str,
211
258
output_user: str,
212
212
-
output_service: str):
213
213
-
259
259
+
output_service: str,
260
260
+
):
214
261
reply_data: dict | None = find_post(db, parent_id, input_user, input_service)
215
262
if not reply_data:
216
263
return None
217
217
-
218
218
-
reply_mappings: list[str] | None = find_mappings(db, reply_data['id'], output_service, output_user)
264
264
+
265
265
+
reply_mappings: list[str] | None = find_mappings(
266
266
+
db, reply_data["id"], output_service, output_user
267
267
+
)
219
268
if not reply_mappings:
220
269
return None
221
221
-
270
270
+
222
271
reply_identifier: str = reply_mappings[-1]
223
272
root_identifier: str = reply_mappings[0]
224
224
-
if reply_data['root_id']:
225
225
-
root_data = find_post_by_id(db, reply_data['root_id'])
273
273
+
if reply_data["root_id"]:
274
274
+
root_data = find_post_by_id(db, reply_data["root_id"])
226
275
if not root_data:
227
276
return None
228
228
-
229
229
-
root_mappings = find_mappings(db, reply_data['root_id'], output_service, output_user)
277
277
+
278
278
+
root_mappings = find_mappings(
279
279
+
db, reply_data["root_id"], output_service, output_user
280
280
+
)
230
281
if not root_mappings:
231
282
return None
232
283
root_identifier = root_mappings[0]
233
233
-
284
284
+
234
285
return (
235
235
-
root_identifier[0], # real ids
286
286
+
root_identifier[0], # real ids
236
287
reply_identifier[0],
237
237
-
reply_data['root_id'], # db ids
238
238
-
reply_data['id']
239
239
-
)
288
288
+
reply_data["root_id"], # db ids
289
289
+
reply_data["id"],
290
290
+
)
+82
-79
util/html_util.py
reviewed
···
1
1
from html.parser import HTMLParser
2
2
+
2
3
import cross
4
4
+
3
5
4
6
class HTMLPostTokenizer(HTMLParser):
5
7
def __init__(self) -> None:
6
8
super().__init__()
7
9
self.tokens: list[cross.Token] = []
8
8
-
10
10
+
9
11
self.mentions: list[tuple[str, str]]
10
12
self.tags: list[str]
11
11
-
13
13
+
12
14
self.in_pre = False
13
15
self.in_code = False
14
14
-
16
16
+
15
17
self.current_tag_stack = []
16
18
self.list_stack = []
17
17
-
19
19
+
18
20
self.anchor_stack = []
19
21
self.anchor_data = []
20
20
-
22
22
+
21
23
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
22
24
attrs_dict = dict(attrs)
23
23
-
25
25
+
24
26
def append_newline():
25
27
if self.tokens:
26
28
last_token = self.tokens[-1]
27
27
-
if isinstance(last_token, cross.TextToken) and not last_token.text.endswith('\n'):
28
28
-
self.tokens.append(cross.TextToken('\n'))
29
29
-
29
29
+
if isinstance(
30
30
+
last_token, cross.TextToken
31
31
+
) and not last_token.text.endswith("\n"):
32
32
+
self.tokens.append(cross.TextToken("\n"))
33
33
+
30
34
match tag:
31
31
-
case 'br':
32
32
-
self.tokens.append(cross.TextToken(' \n'))
33
33
-
case 'a':
34
34
-
href = attrs_dict.get('href', '')
35
35
+
case "br":
36
36
+
self.tokens.append(cross.TextToken(" \n"))
37
37
+
case "a":
38
38
+
href = attrs_dict.get("href", "")
35
39
self.anchor_stack.append(href)
36
36
-
case 'strong', 'b':
37
37
-
self.tokens.append(cross.TextToken('**'))
38
38
-
case 'em', 'i':
39
39
-
self.tokens.append(cross.TextToken('*'))
40
40
-
case 'del', 's':
41
41
-
self.tokens.append(cross.TextToken('~~'))
42
42
-
case 'code':
40
40
+
case "strong", "b":
41
41
+
self.tokens.append(cross.TextToken("**"))
42
42
+
case "em", "i":
43
43
+
self.tokens.append(cross.TextToken("*"))
44
44
+
case "del", "s":
45
45
+
self.tokens.append(cross.TextToken("~~"))
46
46
+
case "code":
43
47
if not self.in_pre:
44
44
-
self.tokens.append(cross.TextToken('`'))
48
48
+
self.tokens.append(cross.TextToken("`"))
45
49
self.in_code = True
46
46
-
case 'pre':
50
50
+
case "pre":
47
51
append_newline()
48
48
-
self.tokens.append(cross.TextToken('```\n'))
52
52
+
self.tokens.append(cross.TextToken("```\n"))
49
53
self.in_pre = True
50
50
-
case 'blockquote':
54
54
+
case "blockquote":
51
55
append_newline()
52
52
-
self.tokens.append(cross.TextToken('> '))
53
53
-
case 'ul', 'ol':
56
56
+
self.tokens.append(cross.TextToken("> "))
57
57
+
case "ul", "ol":
54
58
self.list_stack.append(tag)
55
59
append_newline()
56
56
-
case 'li':
57
57
-
indent = ' ' * (len(self.list_stack) - 1)
58
58
-
if self.list_stack and self.list_stack[-1] == 'ul':
59
59
-
self.tokens.append(cross.TextToken(f'{indent}- '))
60
60
-
elif self.list_stack and self.list_stack[-1] == 'ol':
61
61
-
self.tokens.append(cross.TextToken(f'{indent}1. '))
60
60
+
case "li":
61
61
+
indent = " " * (len(self.list_stack) - 1)
62
62
+
if self.list_stack and self.list_stack[-1] == "ul":
63
63
+
self.tokens.append(cross.TextToken(f"{indent}- "))
64
64
+
elif self.list_stack and self.list_stack[-1] == "ol":
65
65
+
self.tokens.append(cross.TextToken(f"{indent}1. "))
62
66
case _:
63
63
-
if tag in {'h1', 'h2', 'h3', 'h4', 'h5', 'h6'}:
67
67
+
if tag in {"h1", "h2", "h3", "h4", "h5", "h6"}:
64
68
level = int(tag[1])
65
69
self.tokens.append(cross.TextToken("\n" + "#" * level + " "))
66
66
-
70
70
+
67
71
self.current_tag_stack.append(tag)
68
68
-
72
72
+
69
73
def handle_data(self, data: str) -> None:
70
74
if self.anchor_stack:
71
75
self.anchor_data.append(data)
72
76
else:
73
77
self.tokens.append(cross.TextToken(data))
74
74
-
78
78
+
75
79
def handle_endtag(self, tag: str) -> None:
76
80
if not self.current_tag_stack:
77
81
return
78
78
-
82
82
+
79
83
if tag in self.current_tag_stack:
80
84
self.current_tag_stack.remove(tag)
81
81
-
85
85
+
82
86
match tag:
83
83
-
case 'p':
84
84
-
self.tokens.append(cross.TextToken('\n\n'))
85
85
-
case 'a':
87
87
+
case "p":
88
88
+
self.tokens.append(cross.TextToken("\n\n"))
89
89
+
case "a":
86
90
href = self.anchor_stack.pop()
87
87
-
anchor_data = ''.join(self.anchor_data)
91
91
+
anchor_data = "".join(self.anchor_data)
88
92
self.anchor_data = []
89
89
-
90
90
-
if anchor_data.startswith('#'):
93
93
+
94
94
+
if anchor_data.startswith("#"):
91
95
as_tag = anchor_data[1:].lower()
92
96
if any(as_tag == block for block in self.tags):
93
97
self.tokens.append(cross.TagToken(anchor_data[1:]))
94
94
-
elif anchor_data.startswith('@'):
98
98
+
elif anchor_data.startswith("@"):
95
99
match = next(
96
96
-
(pair for pair in self.mentions if anchor_data in pair),
97
97
-
None
100
100
+
(pair for pair in self.mentions if anchor_data in pair), None
98
101
)
99
99
-
102
102
+
100
103
if match:
101
101
-
self.tokens.append(cross.MentionToken(match[1], ''))
104
104
+
self.tokens.append(cross.MentionToken(match[1], ""))
102
105
else:
103
106
self.tokens.append(cross.LinkToken(href, anchor_data))
104
104
-
case 'strong', 'b':
105
105
-
self.tokens.append(cross.TextToken('**'))
106
106
-
case 'em', 'i':
107
107
-
self.tokens.append(cross.TextToken('*'))
108
108
-
case 'del', 's':
109
109
-
self.tokens.append(cross.TextToken('~~'))
110
110
-
case 'code':
107
107
+
case "strong", "b":
108
108
+
self.tokens.append(cross.TextToken("**"))
109
109
+
case "em", "i":
110
110
+
self.tokens.append(cross.TextToken("*"))
111
111
+
case "del", "s":
112
112
+
self.tokens.append(cross.TextToken("~~"))
113
113
+
case "code":
111
114
if not self.in_pre and self.in_code:
112
112
-
self.tokens.append(cross.TextToken('`'))
115
115
+
self.tokens.append(cross.TextToken("`"))
113
116
self.in_code = False
114
114
-
case 'pre':
115
115
-
self.tokens.append(cross.TextToken('\n```\n'))
117
117
+
case "pre":
118
118
+
self.tokens.append(cross.TextToken("\n```\n"))
116
119
self.in_pre = False
117
117
-
case 'blockquote':
118
118
-
self.tokens.append(cross.TextToken('\n'))
119
119
-
case 'ul', 'ol':
120
120
+
case "blockquote":
121
121
+
self.tokens.append(cross.TextToken("\n"))
122
122
+
case "ul", "ol":
120
123
if self.list_stack:
121
124
self.list_stack.pop()
122
122
-
self.tokens.append(cross.TextToken('\n'))
123
123
-
case 'li':
124
124
-
self.tokens.append(cross.TextToken('\n'))
125
125
+
self.tokens.append(cross.TextToken("\n"))
126
126
+
case "li":
127
127
+
self.tokens.append(cross.TextToken("\n"))
125
128
case _:
126
126
-
if tag in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
127
127
-
self.tokens.append(cross.TextToken('\n'))
128
128
-
129
129
+
if tag in ["h1", "h2", "h3", "h4", "h5", "h6"]:
130
130
+
self.tokens.append(cross.TextToken("\n"))
131
131
+
129
132
def get_tokens(self) -> list[cross.Token]:
130
133
if not self.tokens:
131
134
return []
132
132
-
135
135
+
133
136
combined: list[cross.Token] = []
134
137
buffer: list[str] = []
135
135
-
138
138
+
136
139
def flush_buffer():
137
140
if buffer:
138
138
-
merged = ''.join(buffer)
141
141
+
merged = "".join(buffer)
139
142
combined.append(cross.TextToken(text=merged))
140
143
buffer.clear()
141
144
···
145
148
else:
146
149
flush_buffer()
147
150
combined.append(token)
148
148
-
151
151
+
149
152
flush_buffer()
150
150
-
153
153
+
151
154
if combined and isinstance(combined[-1], cross.TextToken):
152
152
-
if combined[-1].text.endswith('\n\n'):
155
155
+
if combined[-1].text.endswith("\n\n"):
153
156
combined[-1] = cross.TextToken(combined[-1].text[:-2])
154
157
return combined
155
155
-
158
158
+
156
159
def reset(self):
157
160
"""Reset the parser state for reuse."""
158
161
super().reset()
159
162
self.tokens = []
160
160
-
163
163
+
161
164
self.mentions = []
162
165
self.tags = []
163
163
-
166
166
+
164
167
self.in_pre = False
165
168
self.in_code = False
166
166
-
169
169
+
167
170
self.current_tag_stack = []
168
171
self.anchor_stack = []
169
169
-
self.list_stack = []
172
172
+
self.list_stack = []
+44
-33
util/md_util.py
reviewed
···
4
4
import util.html_util as html_util
5
5
import util.util as util
6
6
7
7
-
URL = re.compile(r'(?:(?:[A-Za-z][A-Za-z0-9+.-]*://)|mailto:)[^\s]+', re.IGNORECASE)
8
8
-
MD_INLINE_LINK = re.compile(r"\[([^\]]+)\]\(\s*((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s\)]+)\s*\)", re.IGNORECASE)
9
9
-
MD_AUTOLINK = re.compile(r"<((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s>]+)>", re.IGNORECASE)
10
10
-
HASHTAG = re.compile(r'(?<!\w)\#([\w]+)')
11
11
-
FEDIVERSE_HANDLE = re.compile(r'(?<![\w@])@([\w\.-]+)(?:@([\w\.-]+\.[\w\.-]+))?')
7
7
+
URL = re.compile(r"(?:(?:[A-Za-z][A-Za-z0-9+.-]*://)|mailto:)[^\s]+", re.IGNORECASE)
8
8
+
MD_INLINE_LINK = re.compile(
9
9
+
r"\[([^\]]+)\]\(\s*((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s\)]+)\s*\)",
10
10
+
re.IGNORECASE,
11
11
+
)
12
12
+
MD_AUTOLINK = re.compile(
13
13
+
r"<((?:(?:[A-Za-z][A-Za-z0-9+.\-]*://)|mailto:)[^\s>]+)>", re.IGNORECASE
14
14
+
)
15
15
+
HASHTAG = re.compile(r"(?<!\w)\#([\w]+)")
16
16
+
FEDIVERSE_HANDLE = re.compile(r"(?<![\w@])@([\w\.-]+)(?:@([\w\.-]+\.[\w\.-]+))?")
12
17
13
13
-
def tokenize_markdown(text: str, tags: list[str], handles: list[tuple[str, str]]) -> list[cross.Token]:
18
18
+
19
19
+
def tokenize_markdown(
20
20
+
text: str, tags: list[str], handles: list[tuple[str, str]]
21
21
+
) -> list[cross.Token]:
14
22
if not text:
15
23
return []
16
16
-
24
24
+
17
25
tokenizer = html_util.HTMLPostTokenizer()
18
26
tokenizer.mentions = handles
19
27
tokenizer.tags = tags
20
28
tokenizer.feed(text)
21
29
html_tokens = tokenizer.get_tokens()
22
22
-
30
30
+
23
31
tokens: list[cross.Token] = []
24
24
-
32
32
+
25
33
for tk in html_tokens:
26
34
if isinstance(tk, cross.TextToken):
27
35
tokens.extend(__tokenize_md(tk.text, tags, handles))
···
29
37
if not tk.label or util.canonical_label(tk.label, tk.href):
30
38
tokens.append(tk)
31
39
continue
32
32
-
40
40
+
33
41
tokens.extend(__tokenize_md(f"[{tk.label}]({tk.href})", tags, handles))
34
42
else:
35
43
tokens.append(tk)
36
36
-
44
44
+
37
45
return tokens
38
38
-
46
46
+
39
47
40
40
-
def __tokenize_md(text: str, tags: list[str], handles: list[tuple[str, str]]) -> list[cross.Token]:
48
48
+
def __tokenize_md(
49
49
+
text: str, tags: list[str], handles: list[tuple[str, str]]
50
50
+
) -> list[cross.Token]:
41
51
index: int = 0
42
52
total: int = len(text)
43
53
buffer: list[str] = []
44
44
-
54
54
+
45
55
tokens: list[cross.Token] = []
46
46
-
56
56
+
47
57
def flush():
48
58
nonlocal buffer
49
59
if buffer:
50
50
-
tokens.append(cross.TextToken(''.join(buffer)))
60
60
+
tokens.append(cross.TextToken("".join(buffer)))
51
61
buffer = []
52
52
-
62
62
+
53
63
while index < total:
54
54
-
if text[index] == '[':
64
64
+
if text[index] == "[":
55
65
md_inline = MD_INLINE_LINK.match(text, index)
56
66
if md_inline:
57
67
flush()
···
60
70
tokens.append(cross.LinkToken(href, label))
61
71
index = md_inline.end()
62
72
continue
63
63
-
64
64
-
if text[index] == '<':
73
73
+
74
74
+
if text[index] == "<":
65
75
md_auto = MD_AUTOLINK.match(text, index)
66
76
if md_auto:
67
77
flush()
···
69
79
tokens.append(cross.LinkToken(href, href))
70
80
index = md_auto.end()
71
81
continue
72
72
-
73
73
-
if text[index] == '#':
82
82
+
83
83
+
if text[index] == "#":
74
84
tag = HASHTAG.match(text, index)
75
85
if tag:
76
86
tag_text = tag.group(1)
···
79
89
tokens.append(cross.TagToken(tag_text))
80
90
index = tag.end()
81
91
continue
82
82
-
83
83
-
if text[index] == '@':
92
92
+
93
93
+
if text[index] == "@":
84
94
handle = FEDIVERSE_HANDLE.match(text, index)
85
95
if handle:
86
96
handle_text = handle.group(0)
87
97
stripped_handle = handle_text.strip()
88
88
-
98
98
+
89
99
match = next(
90
90
-
(pair for pair in handles if stripped_handle in pair),
91
91
-
None
100
100
+
(pair for pair in handles if stripped_handle in pair), None
92
101
)
93
93
-
102
102
+
94
103
if match:
95
104
flush()
96
96
-
tokens.append(cross.MentionToken(match[1], '')) # TODO: misskey doesn’t provide a uri
105
105
+
tokens.append(
106
106
+
cross.MentionToken(match[1], "")
107
107
+
) # TODO: misskey doesn’t provide a uri
97
108
index = handle.end()
98
109
continue
99
99
-
110
110
+
100
111
url = URL.match(text, index)
101
112
if url:
102
113
flush()
···
104
115
tokens.append(cross.LinkToken(href, href))
105
116
index = url.end()
106
117
continue
107
107
-
118
118
+
108
119
buffer.append(text[index])
109
120
index += 1
110
110
-
121
121
+
111
122
flush()
112
112
-
return tokens
123
123
+
return tokens
+73
-56
util/media.py
reviewed
···
1
1
+
import json
2
2
+
import os
3
3
+
import re
4
4
+
import subprocess
5
5
+
import urllib.parse
6
6
+
7
7
+
import magic
1
8
import requests
2
2
-
import subprocess
3
3
-
import json
4
4
-
import re, urllib.parse, os
9
9
+
5
10
from util.util import LOGGER
6
6
-
import magic
7
11
8
12
FILENAME = re.compile(r'filename="?([^\";]*)"?')
9
13
MAGIC = magic.Magic(mime=True)
10
14
11
11
-
class MediaInfo():
15
15
+
16
16
+
class MediaInfo:
12
17
def __init__(self, url: str, name: str, mime: str, alt: str, io: bytes) -> None:
13
18
self.url = url
14
19
self.name = name
···
16
21
self.alt = alt
17
22
self.io = io
18
23
24
24
+
19
25
def download_media(url: str, alt: str) -> MediaInfo | None:
20
26
name = get_filename_from_url(url)
21
27
io = download_blob(url, max_bytes=100_000_000)
···
24
30
return None
25
31
mime = MAGIC.from_buffer(io)
26
32
if not mime:
27
27
-
mime = 'application/octet-stream'
33
33
+
mime = "application/octet-stream"
28
34
return MediaInfo(url, name, mime, alt, io)
35
35
+
29
36
30
37
def get_filename_from_url(url):
31
38
try:
32
39
response = requests.head(url, allow_redirects=True)
33
33
-
disposition = response.headers.get('Content-Disposition')
40
40
+
disposition = response.headers.get("Content-Disposition")
34
41
if disposition:
35
42
filename = FILENAME.findall(disposition)
36
43
if filename:
···
40
47
41
48
parsed_url = urllib.parse.urlparse(url)
42
49
base_name = os.path.basename(parsed_url.path)
43
43
-
50
50
+
44
51
# hardcoded fix to return the cid for pds
45
45
-
if base_name == 'com.atproto.sync.getBlob':
52
52
+
if base_name == "com.atproto.sync.getBlob":
46
53
qs = urllib.parse.parse_qs(parsed_url.query)
47
47
-
if qs and qs.get('cid'):
48
48
-
return qs['cid'][0]
54
54
+
if qs and qs.get("cid"):
55
55
+
return qs["cid"][0]
49
56
50
57
return base_name
51
58
59
59
+
52
60
def probe_bytes(bytes: bytes) -> dict:
53
61
cmd = [
54
54
-
'ffprobe',
55
55
-
'-v', 'error',
56
56
-
'-show_format',
57
57
-
'-show_streams',
58
58
-
'-print_format', 'json',
59
59
-
'pipe:0'
62
62
+
"ffprobe",
63
63
+
"-v", "error",
64
64
+
"-show_format",
65
65
+
"-show_streams",
66
66
+
"-print_format", "json",
67
67
+
"pipe:0",
60
68
]
61
61
-
proc = subprocess.run(cmd, input=bytes, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
69
69
+
proc = subprocess.run(
70
70
+
cmd, input=bytes, stdout=subprocess.PIPE, stderr=subprocess.PIPE
71
71
+
)
62
72
63
73
if proc.returncode != 0:
64
74
raise RuntimeError(f"ffprobe failed: {proc.stderr.decode()}")
65
75
66
76
return json.loads(proc.stdout)
67
77
78
78
+
68
79
def convert_to_mp4(video_bytes: bytes) -> bytes:
69
80
cmd = [
70
70
-
'ffmpeg',
71
71
-
'-i', 'pipe:0',
72
72
-
'-c:v', 'libx264',
73
73
-
'-crf', '30',
74
74
-
'-preset', 'slow',
75
75
-
'-c:a', 'aac',
76
76
-
'-b:a', '128k',
77
77
-
'-movflags', 'frag_keyframe+empty_moov+default_base_moof',
78
78
-
'-f', 'mp4',
79
79
-
'pipe:1'
81
81
+
"ffmpeg",
82
82
+
"-i", "pipe:0",
83
83
+
"-c:v", "libx264",
84
84
+
"-crf", "30",
85
85
+
"-preset", "slow",
86
86
+
"-c:a", "aac",
87
87
+
"-b:a", "128k",
88
88
+
"-movflags", "frag_keyframe+empty_moov+default_base_moof",
89
89
+
"-f", "mp4",
90
90
+
"pipe:1",
80
91
]
81
81
-
82
82
-
proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
92
92
+
93
93
+
proc = subprocess.Popen(
94
94
+
cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE
95
95
+
)
83
96
out_bytes, err = proc.communicate(input=video_bytes)
84
84
-
97
97
+
85
98
if proc.returncode != 0:
86
99
raise RuntimeError(f"ffmpeg compress failed: {err.decode()}")
87
87
-
100
100
+
88
101
return out_bytes
89
102
103
103
+
90
104
def compress_image(image_bytes: bytes, quality: int = 90):
91
105
cmd = [
92
92
-
'ffmpeg',
93
93
-
'-f', 'image2pipe',
94
94
-
'-i', 'pipe:0',
95
95
-
'-c:v', 'webp',
96
96
-
'-q:v', str(quality),
97
97
-
'-f', 'image2pipe',
98
98
-
'pipe:1'
99
99
-
]
106
106
+
"ffmpeg",
107
107
+
"-f", "image2pipe",
108
108
+
"-i", "pipe:0",
109
109
+
"-c:v", "webp",
110
110
+
"-q:v", str(quality),
111
111
+
"-f", "image2pipe",
112
112
+
"pipe:1",
113
113
+
]
100
114
101
101
-
proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
115
115
+
proc = subprocess.Popen(
116
116
+
cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE
117
117
+
)
102
118
out_bytes, err = proc.communicate(input=image_bytes)
103
103
-
119
119
+
104
120
if proc.returncode != 0:
105
121
raise RuntimeError(f"ffmpeg compress failed: {err.decode()}")
106
106
-
122
122
+
107
123
return out_bytes
124
124
+
108
125
109
126
def download_blob(url: str, max_bytes: int = 5_000_000) -> bytes | None:
110
127
response = requests.get(url, stream=True, timeout=20)
111
128
if response.status_code != 200:
112
129
LOGGER.info("Failed to download %s! %s", url, response.text)
113
130
return None
114
114
-
131
131
+
115
132
downloaded_bytes = b""
116
133
current_size = 0
117
117
-
134
134
+
118
135
for chunk in response.iter_content(chunk_size=8192):
119
119
-
if not chunk:
136
136
+
if not chunk:
120
137
continue
121
121
-
138
138
+
122
139
current_size += len(chunk)
123
140
if current_size > max_bytes:
124
141
response.close()
125
142
return None
126
126
-
143
143
+
127
144
downloaded_bytes += chunk
128
128
-
145
145
+
129
146
return downloaded_bytes
130
130
-
147
147
+
131
148
132
149
def get_media_meta(bytes: bytes):
133
150
probe = probe_bytes(bytes)
134
134
-
streams = [s for s in probe['streams'] if s['codec_type'] == 'video']
151
151
+
streams = [s for s in probe["streams"] if s["codec_type"] == "video"]
135
152
if not streams:
136
153
raise ValueError("No video stream found")
137
137
-
154
154
+
138
155
media = streams[0]
139
156
return {
140
140
-
'width': int(media['width']),
141
141
-
'height': int(media['height']),
142
142
-
'duration': float(media.get('duration', probe['format'].get('duration', -1)))
143
143
-
}
157
157
+
"width": int(media["width"]),
158
158
+
"height": int(media["height"]),
159
159
+
"duration": float(media.get("duration", probe["format"].get("duration", -1))),
160
160
+
}
+20
-13
util/util.py
reviewed
···
1
1
-
import logging, sys, os
2
1
import json
2
2
+
import logging
3
3
+
import os
4
4
+
import sys
3
5
4
6
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
5
7
LOGGER = logging.getLogger("XPost")
6
8
7
7
-
def as_json(obj, indent=None,sort_keys=False) -> str:
9
9
+
10
10
+
def as_json(obj, indent=None, sort_keys=False) -> str:
8
11
return json.dumps(
9
9
-
obj.__dict__ if not isinstance(obj, dict) else obj,
10
10
-
default=lambda o: o.__json__() if hasattr(o, '__json__') else o.__dict__,
12
12
+
obj.__dict__ if not isinstance(obj, dict) else obj,
13
13
+
default=lambda o: o.__json__() if hasattr(o, "__json__") else o.__dict__,
11
14
indent=indent,
12
12
-
sort_keys=sort_keys)
15
15
+
sort_keys=sort_keys,
16
16
+
)
17
17
+
13
18
14
19
def canonical_label(label: str | None, href: str):
15
20
if not label or label == href:
16
21
return True
17
17
-
18
18
-
split = href.split('://', 1)
22
22
+
23
23
+
split = href.split("://", 1)
19
24
if len(split) > 1:
20
25
if split[1] == label:
21
26
return True
22
22
-
27
27
+
23
28
return False
24
29
30
30
+
25
31
def safe_get(obj: dict, key: str, default):
26
32
val = obj.get(key, default)
27
33
return val if val else default
34
34
+
28
35
29
36
def as_envvar(text: str | None) -> str | None:
30
37
if not text:
31
38
return None
32
32
-
33
33
-
if text.startswith('env:'):
34
34
-
return os.environ.get(text[4:], '')
35
35
-
36
36
-
return text
39
39
+
40
40
+
if text.startswith("env:"):
41
41
+
return os.environ.get(text[4:], "")
42
42
+
43
43
+
return text