Rename some filers and folder for clarity
This commit is contained in:
12
src/backend/__about__.py
Normal file
12
src/backend/__about__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""<SHORTDESC>"""
|
||||
|
||||
__pkg_name__ = "harmonyqml"
|
||||
__pretty_name__ = "Harmony QML"
|
||||
__version__ = "0.2.3"
|
||||
__status__ = "Development"
|
||||
# __status__ = "Production"
|
||||
|
||||
__author__ = "miruka"
|
||||
__email__ = "miruka@disroot.org"
|
||||
|
||||
__license__ = "LGPLv3"
|
1
src/backend/__init__.py
Normal file
1
src/backend/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .app import APP # noqa
|
110
src/backend/app.py
Normal file
110
src/backend/app.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import asyncio
|
||||
import logging as log
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
from concurrent.futures import Future
|
||||
from operator import attrgetter
|
||||
from threading import Thread
|
||||
from typing import Coroutine, Sequence
|
||||
|
||||
import nio
|
||||
from appdirs import AppDirs
|
||||
|
||||
from . import __about__
|
||||
from .pyotherside_events import CoroutineDone
|
||||
|
||||
log.getLogger().setLevel(log.INFO)
|
||||
nio.logger_group.level = nio.log.logbook.ERROR
|
||||
nio.log.logbook.StreamHandler(sys.stderr).push_application()
|
||||
|
||||
try:
|
||||
import uvloop
|
||||
except ModuleNotFoundError:
|
||||
UVLOOP = False
|
||||
log.info("uvloop not available, using default asyncio loop.")
|
||||
else:
|
||||
UVLOOP = True
|
||||
log.info("uvloop is available.")
|
||||
|
||||
|
||||
class App:
|
||||
def __init__(self) -> None:
|
||||
self.appdirs = AppDirs(appname=__about__.__pkg_name__, roaming=True)
|
||||
|
||||
from .backend import Backend
|
||||
self.backend = Backend(app=self)
|
||||
self.debug = False
|
||||
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.loop_thread = Thread(target=self._loop_starter)
|
||||
self.loop_thread.start()
|
||||
|
||||
|
||||
def _loop_starter(self) -> None:
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
if UVLOOP:
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
self.loop.run_forever()
|
||||
|
||||
|
||||
def run_in_loop(self, coro: Coroutine) -> Future:
|
||||
return asyncio.run_coroutine_threadsafe(coro, self.loop)
|
||||
|
||||
|
||||
def _call_coro(self, coro: Coroutine, uuid: str) -> Future:
|
||||
def on_done(future: Future) -> None:
|
||||
result = exception = trace = None
|
||||
|
||||
try:
|
||||
result = future.result()
|
||||
except Exception as err:
|
||||
exception = err
|
||||
trace = traceback.format_exc().rstrip()
|
||||
|
||||
CoroutineDone(uuid, result, exception, trace)
|
||||
|
||||
future = self.run_in_loop(coro)
|
||||
future.add_done_callback(on_done)
|
||||
return future
|
||||
|
||||
|
||||
def call_backend_coro(self, name: str, uuid: str, args: Sequence[str] = (),
|
||||
) -> Future:
|
||||
return self._call_coro(attrgetter(name)(self.backend)(*args), uuid)
|
||||
|
||||
|
||||
def call_client_coro(self,
|
||||
account_id: str,
|
||||
name: str,
|
||||
uuid: str,
|
||||
args: Sequence[str] = ()) -> Future:
|
||||
client = self.backend.clients[account_id]
|
||||
return self._call_coro(attrgetter(name)(client)(*args), uuid)
|
||||
|
||||
|
||||
def pdb(self, additional_data: Sequence = ()) -> None:
|
||||
ad = additional_data # noqa
|
||||
rl = self.run_in_loop # noqa
|
||||
ba = self.backend # noqa
|
||||
mo = self.backend.models # noqa
|
||||
cl = self.backend.clients
|
||||
tcl = lambda user: cl[f"@{user}:matrix.org"] # noqa
|
||||
|
||||
from .models.items import Account, Room, Member, Event, Device # noqa
|
||||
|
||||
p = print # pdb's `p` doesn't print a class's __str__ # noqa
|
||||
from pprintpp import pprint as pp # noqa
|
||||
|
||||
log.info("\n=> Run `socat readline tcp:127.0.0.1:4444` in a terminal "
|
||||
"to connect to pdb.")
|
||||
import remote_pdb
|
||||
remote_pdb.RemotePdb("127.0.0.1", 4444).set_trace()
|
||||
|
||||
|
||||
# Make CTRL-C work again
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
|
||||
APP = App()
|
233
src/backend/backend.py
Normal file
233
src/backend/backend.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import asyncio
|
||||
import logging as log
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple
|
||||
|
||||
import hsluv
|
||||
|
||||
import nio
|
||||
|
||||
from .app import App
|
||||
from .errors import MatrixError
|
||||
from .matrix_client import MatrixClient
|
||||
from .models.items import Account, Device, Event, Member, Room, Upload
|
||||
from .models.model_store import ModelStore
|
||||
|
||||
|
||||
class Backend:
|
||||
def __init__(self, app: App) -> None:
|
||||
self.app = app
|
||||
|
||||
from . import config_files
|
||||
self.saved_accounts = config_files.Accounts(self)
|
||||
self.ui_settings = config_files.UISettings(self)
|
||||
self.ui_state = config_files.UIState(self)
|
||||
self.history = config_files.History(self)
|
||||
|
||||
self.models = ModelStore(allowed_key_types={
|
||||
Account, # Logged-in accounts
|
||||
(Device, str), # Devices of user_id
|
||||
(Room, str), # Rooms for user_id
|
||||
(Upload, str), # Uploads running in room_id
|
||||
(Member, str, str), # Members for account user_id for room_id
|
||||
(Event, str, str), # Events for account user_id for room_id
|
||||
})
|
||||
|
||||
self.clients: Dict[str, MatrixClient] = {}
|
||||
|
||||
self.profile_cache: Dict[str, nio.ProfileGetResponse] = {}
|
||||
self.get_profile_locks: DefaultDict[str, asyncio.Lock] = \
|
||||
DefaultDict(asyncio.Lock) # {user_id: lock}
|
||||
|
||||
self.send_locks: DefaultDict[str, asyncio.Lock] = \
|
||||
DefaultDict(asyncio.Lock) # {room_id: lock}
|
||||
|
||||
from .media_cache import MediaCache
|
||||
cache_dir = Path(self.app.appdirs.user_cache_dir)
|
||||
self.media_cache = MediaCache(self, cache_dir)
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{type(self).__name__}(clients={self.clients!r})"
|
||||
|
||||
|
||||
# Clients management
|
||||
|
||||
async def login_client(self,
|
||||
user: str,
|
||||
password: str,
|
||||
device_id: Optional[str] = None,
|
||||
homeserver: str = "https://matrix.org",
|
||||
) -> str:
|
||||
|
||||
client = MatrixClient(
|
||||
self, user=user, homeserver=homeserver, device_id=device_id,
|
||||
)
|
||||
|
||||
try:
|
||||
await client.login(password)
|
||||
except MatrixError:
|
||||
await client.close()
|
||||
raise
|
||||
|
||||
self.clients[client.user_id] = client
|
||||
self.models[Account][client.user_id] = Account(client.user_id)
|
||||
return client.user_id
|
||||
|
||||
|
||||
async def resume_client(self,
|
||||
user_id: str,
|
||||
token: str,
|
||||
device_id: str,
|
||||
homeserver: str = "https://matrix.org") -> None:
|
||||
|
||||
client = MatrixClient(
|
||||
backend=self,
|
||||
user=user_id, homeserver=homeserver, device_id=device_id,
|
||||
)
|
||||
|
||||
self.clients[user_id] = client
|
||||
self.models[Account][user_id] = Account(user_id)
|
||||
|
||||
await client.resume(user_id=user_id, token=token, device_id=device_id)
|
||||
|
||||
|
||||
async def load_saved_accounts(self) -> Tuple[str, ...]:
|
||||
async def resume(user_id: str, info: Dict[str, str]) -> str:
|
||||
await self.resume_client(
|
||||
user_id = user_id,
|
||||
token = info["token"],
|
||||
device_id = info["device_id"],
|
||||
homeserver = info["homeserver"],
|
||||
)
|
||||
return user_id
|
||||
|
||||
return await asyncio.gather(*(
|
||||
resume(uid, info)
|
||||
for uid, info in (await self.saved_accounts.read()).items()
|
||||
))
|
||||
|
||||
|
||||
async def logout_client(self, user_id: str) -> None:
|
||||
client = self.clients.pop(user_id, None)
|
||||
if client:
|
||||
self.models[Account].pop(user_id, None)
|
||||
await client.logout()
|
||||
|
||||
await self.saved_accounts.delete(user_id)
|
||||
|
||||
|
||||
async def wait_until_client_exists(self, user_id: str) -> None:
|
||||
loops = 0
|
||||
while True:
|
||||
if user_id in self.clients:
|
||||
return
|
||||
|
||||
if loops and loops % 100 == 0: # every 10s except first time
|
||||
log.warning("Waiting for account %s to exist, %ds passed",
|
||||
user_id, loops // 10)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
loops += 1
|
||||
|
||||
|
||||
# General functions
|
||||
|
||||
@staticmethod
|
||||
def hsluv(hue: int, saturation: int, lightness: int) -> List[float]:
|
||||
# (0-360, 0-100, 0-100) -> [0-1, 0-1, 0-1]
|
||||
return hsluv.hsluv_to_rgb([hue, saturation, lightness])
|
||||
|
||||
|
||||
async def load_settings(self) -> tuple:
|
||||
from .config_files import Theme
|
||||
settings = await self.ui_settings.read()
|
||||
ui_state = await self.ui_state.read()
|
||||
history = await self.history.read()
|
||||
theme = await Theme(self, settings["theme"]).read()
|
||||
|
||||
return (settings, ui_state, history, theme)
|
||||
|
||||
|
||||
async def get_flat_mainpane_data(self) -> List[Dict[str, Any]]:
|
||||
data = []
|
||||
|
||||
for account in sorted(self.models[Account].values()):
|
||||
data.append({
|
||||
"type": "Account",
|
||||
"id": account.user_id,
|
||||
"user_id": account.user_id,
|
||||
"data": account.serialized,
|
||||
})
|
||||
|
||||
for room in sorted(self.models[Room, account.user_id].values()):
|
||||
data.append({
|
||||
"type": "Room",
|
||||
"id": "/".join((account.user_id, room.room_id)),
|
||||
"user_id": account.user_id,
|
||||
"data": room.serialized,
|
||||
})
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# Client functions that don't need authentification
|
||||
|
||||
async def _any_client(self, caller: Callable, *args, **kw) -> MatrixClient:
|
||||
failures = 0
|
||||
|
||||
while True:
|
||||
for client in self.clients.values():
|
||||
if client.syncing:
|
||||
return client
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
failures += 1
|
||||
|
||||
if failures and failures % 300 == 0:
|
||||
log.warn(
|
||||
"No syncing client found after %ds of wait for %s %r %r",
|
||||
failures / 10, caller.__name__, args, kw,
|
||||
)
|
||||
|
||||
|
||||
async def get_profile(self, user_id: str) -> nio.ProfileGetResponse:
|
||||
if user_id in self.profile_cache:
|
||||
return self.profile_cache[user_id]
|
||||
|
||||
async with self.get_profile_locks[user_id]:
|
||||
client = await self._any_client(self.get_profile, user_id)
|
||||
response = await client.get_profile(user_id)
|
||||
|
||||
if isinstance(response, nio.ProfileGetError):
|
||||
raise MatrixError.from_nio(response)
|
||||
|
||||
self.profile_cache[user_id] = response
|
||||
return response
|
||||
|
||||
|
||||
async def thumbnail(
|
||||
self, server_name: str, media_id: str, width: int, height: int,
|
||||
) -> nio.ThumbnailResponse:
|
||||
|
||||
args = (server_name, media_id, width, height)
|
||||
client = await self._any_client(self.thumbnail, *args)
|
||||
response = await client.thumbnail(*args)
|
||||
|
||||
if isinstance(response, nio.ThumbnailError):
|
||||
raise MatrixError.from_nio(response)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def download(
|
||||
self, server_name: str, media_id: str,
|
||||
) -> nio.DownloadResponse:
|
||||
|
||||
client = await self._any_client(self.download, server_name, media_id)
|
||||
response = await client.download(server_name, media_id)
|
||||
|
||||
if isinstance(response, nio.DownloadError):
|
||||
raise MatrixError.from_nio(response)
|
||||
|
||||
return response
|
225
src/backend/config_files.py
Normal file
225
src/backend/config_files.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging as log
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiofiles
|
||||
|
||||
from .backend import Backend
|
||||
from .theme_parser import convert_to_qml
|
||||
from .utils import dict_update_recursive
|
||||
|
||||
JsonData = Dict[str, Any]
|
||||
|
||||
WRITE_LOCK = asyncio.Lock()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigFile:
|
||||
backend: Backend = field(repr=False)
|
||||
filename: str = field()
|
||||
|
||||
_to_write: Optional[str] = field(init=False, default=None)
|
||||
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
asyncio.ensure_future(self._write_loop())
|
||||
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
return Path(self.backend.app.appdirs.user_config_dir) / self.filename
|
||||
|
||||
|
||||
async def default_data(self):
|
||||
return ""
|
||||
|
||||
|
||||
async def read(self):
|
||||
log.debug("Reading config %s at %s", type(self).__name__, self.path)
|
||||
return self.path.read_text()
|
||||
|
||||
|
||||
async def write(self, data) -> None:
|
||||
self._to_write = data
|
||||
|
||||
|
||||
async def _write_loop(self) -> None:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
while True:
|
||||
if self._to_write is not None:
|
||||
async with aiofiles.open(self.path, "w") as new:
|
||||
await new.write(self._to_write)
|
||||
|
||||
self._to_write = None
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONConfigFile(ConfigFile):
|
||||
async def default_data(self) -> JsonData:
|
||||
return {}
|
||||
|
||||
|
||||
async def read(self) -> JsonData:
|
||||
try:
|
||||
data = json.loads(await super().read())
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
data = {}
|
||||
|
||||
all_data = await self.default_data()
|
||||
dict_update_recursive(all_data, data)
|
||||
|
||||
if data != all_data:
|
||||
await self.write(all_data)
|
||||
|
||||
return all_data
|
||||
|
||||
|
||||
async def write(self, data: JsonData) -> None:
|
||||
js = json.dumps(data, indent=4, ensure_ascii=False, sort_keys=True)
|
||||
await super().write(js)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Accounts(JSONConfigFile):
|
||||
filename: str = "accounts.json"
|
||||
|
||||
|
||||
async def any_saved(self) -> bool:
|
||||
return bool(await self.read())
|
||||
|
||||
|
||||
async def add(self, user_id: str) -> None:
|
||||
client = self.backend.clients[user_id]
|
||||
|
||||
await self.write({
|
||||
**await self.read(),
|
||||
client.user_id: {
|
||||
"homeserver": client.homeserver,
|
||||
"token": client.access_token,
|
||||
"device_id": client.device_id,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
async def delete(self, user_id: str) -> None:
|
||||
await self.write({
|
||||
uid: info
|
||||
for uid, info in (await self.read()).items() if uid != user_id
|
||||
})
|
||||
|
||||
|
||||
@dataclass
|
||||
class UISettings(JSONConfigFile):
|
||||
filename: str = "settings.json"
|
||||
|
||||
|
||||
async def default_data(self) -> JsonData:
|
||||
return {
|
||||
"alertOnMessageForMsec": 4000,
|
||||
"clearRoomFilterOnEnter": True,
|
||||
"clearRoomFilterOnEscape": True,
|
||||
"theme": "Default.qpl",
|
||||
"writeAliases": {},
|
||||
"media": {
|
||||
"autoLoad": True,
|
||||
"autoPlay": False,
|
||||
"autoPlayGIF": True,
|
||||
"autoHideOSDAfterMsec": 3000,
|
||||
"defaultVolume": 100,
|
||||
"startMuted": False,
|
||||
},
|
||||
"keys": {
|
||||
"startPythonDebugger": ["Alt+Shift+D"],
|
||||
"toggleDebugConsole": ["Alt+Shift+C", "F1"],
|
||||
"reloadConfig": ["Alt+Shift+R"],
|
||||
|
||||
"zoomIn": ["Ctrl++"],
|
||||
"zoomOut": ["Ctrl+-"],
|
||||
"zoomReset": ["Ctrl+="],
|
||||
|
||||
"scrollUp": ["Alt+Up", "Alt+K"],
|
||||
"scrollDown": ["Alt+Down", "Alt+J"],
|
||||
"scrollPageUp": ["Alt+Ctrl+Up", "Alt+Ctrl+K", "PgUp"],
|
||||
"scrollPageDown": ["Alt+Ctrl+Down", "Alt+Ctrl+J", "PgDown"],
|
||||
"scrollToTop":
|
||||
["Alt+Ctrl+Shift+Up", "Alt+Ctrl+Shift+K", "Home"],
|
||||
"scrollToBottom":
|
||||
["Alt+Ctrl+Shift+Down", "Alt+Ctrl+Shift+J", "End"],
|
||||
|
||||
"previousTab": ["Alt+Shift+Left", "Alt+Shift+H"],
|
||||
"nextTab": ["Alt+Shift+Right", "Alt+Shift+L"],
|
||||
|
||||
"focusMainPane": ["Alt+S"],
|
||||
"clearRoomFilter": ["Alt+Shift+S"],
|
||||
"accountSettings": ["Alt+A"],
|
||||
"addNewChat": ["Alt+N"],
|
||||
"addNewAccount": ["Alt+Shift+N"],
|
||||
|
||||
"goToLastPage": ["Ctrl+Tab"],
|
||||
"goToPreviousRoom": ["Alt+Shift+Up", "Alt+Shift+K"],
|
||||
"goToNextRoom": ["Alt+Shift+Down", "Alt+Shift+J"],
|
||||
"toggleCollapseAccount": [ "Alt+O"],
|
||||
|
||||
"clearRoomMessages": ["Ctrl+L"],
|
||||
"sendFile": ["Alt+F"],
|
||||
"sendFileFromPathInClipboard": ["Alt+Shift+F"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class UIState(JSONConfigFile):
|
||||
filename: str = "state.json"
|
||||
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
return Path(self.backend.app.appdirs.user_data_dir) / self.filename
|
||||
|
||||
|
||||
async def default_data(self) -> JsonData:
|
||||
return {
|
||||
"collapseAccounts": {},
|
||||
"page": "Pages/Default.qml",
|
||||
"pageProperties": {},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class History(JSONConfigFile):
|
||||
filename: str = "history.json"
|
||||
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
return Path(self.backend.app.appdirs.user_data_dir) / self.filename
|
||||
|
||||
|
||||
async def default_data(self) -> JsonData:
|
||||
return {"console": []}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Theme(ConfigFile):
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
data_dir = Path(self.backend.app.appdirs.user_data_dir)
|
||||
return data_dir / "themes" / self.filename
|
||||
|
||||
|
||||
async def default_data(self) -> str:
|
||||
async with aiofiles.open("src/themes/Default.qpl") as file:
|
||||
return await file.read()
|
||||
|
||||
|
||||
async def read(self) -> str:
|
||||
if not self.path.exists():
|
||||
await self.write(await self.default_data())
|
||||
|
||||
return convert_to_qml(await super().read())
|
89
src/backend/errors.py
Normal file
89
src/backend/errors.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import nio
|
||||
|
||||
|
||||
# Matrix Errors
|
||||
|
||||
@dataclass
|
||||
class MatrixError(Exception):
|
||||
http_code: int = 400
|
||||
m_code: str = "M_UNKNOWN"
|
||||
|
||||
@classmethod
|
||||
def from_nio(cls, response: nio.ErrorResponse) -> "MatrixError":
|
||||
# Check for the M_CODE first: some errors for an API share the same
|
||||
# http code, but have different M_CODEs (e.g. POST /login 403).
|
||||
for subcls in cls.__subclasses__():
|
||||
if subcls.m_code == response.status_code:
|
||||
return subcls()
|
||||
|
||||
for subcls in cls.__subclasses__():
|
||||
if subcls.http_code == response.transport_response.status:
|
||||
return subcls()
|
||||
|
||||
return cls(response.transport_response.status, response.status_code)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatrixForbidden(MatrixError):
|
||||
http_code: int = 403
|
||||
m_code: str = "M_FORBIDDEN"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatrixBadJson(MatrixError):
|
||||
http_code: int = 403
|
||||
m_code: str = "M_BAD_JSON"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatrixNotJson(MatrixError):
|
||||
http_code: int = 403
|
||||
m_code: str = "M_NOT_JSON"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatrixUserDeactivated(MatrixError):
|
||||
http_code: int = 403
|
||||
m_code: str = "M_USER_DEACTIVATED"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatrixNotFound(MatrixError):
|
||||
http_code: int = 404
|
||||
m_code: str = "M_NOT_FOUND"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatrixTooLarge(MatrixError):
|
||||
http_code: int = 413
|
||||
m_code: str = "M_TOO_LARGE"
|
||||
|
||||
|
||||
# Client errors
|
||||
|
||||
@dataclass
|
||||
class UserNotFound(Exception):
|
||||
user_id: str = field()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidUserId(Exception):
|
||||
user_id: str = field()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvalidUserInContext(Exception):
|
||||
user_id: str = field()
|
||||
|
||||
|
||||
@dataclass
|
||||
class UneededThumbnail(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BadMimeType(Exception):
|
||||
wanted: str = field()
|
||||
got: str = field()
|
217
src/backend/html_filter.py
Normal file
217
src/backend/html_filter.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import re
|
||||
|
||||
import html_sanitizer.sanitizer as sanitizer
|
||||
import mistune
|
||||
from html_sanitizer.sanitizer import Sanitizer
|
||||
from lxml.html import HtmlElement # nosec
|
||||
|
||||
|
||||
class MarkdownInlineGrammar(mistune.InlineGrammar):
|
||||
# Enable *word* but not _word_ syntaxes (TODO: config option for that)
|
||||
emphasis = re.compile(r"^\*((?:\*\*|[^\*])+?)\*(?!\*)")
|
||||
double_emphasis = re.compile(r"^\*{2}([\s\S]+?)\*{2}(?!\*)")
|
||||
|
||||
|
||||
class MarkdownInlineLexer(mistune.InlineLexer):
|
||||
grammar_class = MarkdownInlineGrammar
|
||||
|
||||
|
||||
def output_double_emphasis(self, m):
|
||||
return self.renderer.double_emphasis(self.output(m.group(1)))
|
||||
|
||||
|
||||
def output_emphasis(self, m):
|
||||
return self.renderer.emphasis(self.output(m.group(1)))
|
||||
|
||||
|
||||
class HtmlFilter:
|
||||
inline_tags = {"font", "a", "sup", "sub", "b", "i", "s", "u", "code"}
|
||||
|
||||
block_tags = {
|
||||
"h1", "h2", "h3", "h4", "h5", "h6","blockquote",
|
||||
"p", "ul", "ol", "li", "hr", "br",
|
||||
"table", "thead", "tbody", "tr", "th", "td", "pre",
|
||||
}
|
||||
|
||||
link_regexes = [re.compile(r, re.IGNORECASE) for r in [
|
||||
(r"(?P<body>[a-zA-Z\d]+://(?P<host>[a-z\d._-]+(?:\:\d+)?)"
|
||||
r"(?:/[/\-_.,a-z\d#%&?;=~]*)?(?:\([/\-_.,a-z\d#%&?;=~]*\))?)"),
|
||||
r"mailto:(?P<body>[a-z0-9._-]+@(?P<host>[a-z0-9_.-]+[a-z](?:\:\d+)?))",
|
||||
r"tel:(?P<body>[0-9+-]+)(?P<host>)",
|
||||
r"(?P<body>magnet:\?xt=urn:[a-z0-9]+:.+)(?P<host>)",
|
||||
]]
|
||||
|
||||
inline_quote_regex = re.compile(r"(^|⏎)(\s*>[^⏎\n]*)", re.MULTILINE)
|
||||
|
||||
quote_regex = re.compile(
|
||||
r"(^|<p/?>|<br/?>|<h\d/?>)(\s*>.*?)(</?p>|<br/?>|</?h\d>|$)",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
extra_newlines_regex = re.compile(r"\n(\n*)")
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sanitizer = Sanitizer(self.sanitize_settings())
|
||||
self._inline_sanitizer = Sanitizer(self.sanitize_settings(inline=True))
|
||||
|
||||
# The whitespace remover doesn't take <pre> into account
|
||||
sanitizer.normalize_overall_whitespace = lambda html, *args, **kw: html
|
||||
sanitizer.normalize_whitespace_in_text_or_tail = \
|
||||
lambda el, *args, **kw: el
|
||||
|
||||
# hard_wrap: convert all \n to <br> without required two spaces
|
||||
# escape: escape HTML characters in the input string, e.g. tags
|
||||
self._markdown_to_html = mistune.Markdown(
|
||||
hard_wrap=True, escape=True, inline=MarkdownInlineLexer,
|
||||
)
|
||||
|
||||
self._markdown_to_html.block.default_rules = [
|
||||
rule for rule in self._markdown_to_html.block.default_rules
|
||||
if rule != "block_quote"
|
||||
]
|
||||
|
||||
|
||||
def from_markdown(self, text: str, outgoing: bool = False) -> str:
|
||||
return self.filter(self._markdown_to_html(text), outgoing)
|
||||
|
||||
|
||||
def from_markdown_inline(self, text: str, outgoing: bool = False) -> str:
|
||||
return self.filter_inline(self._markdown_to_html(text), outgoing)
|
||||
|
||||
|
||||
def filter_inline(self, html: str, outgoing: bool = False) -> str:
|
||||
html = self._inline_sanitizer.sanitize(html)
|
||||
|
||||
if outgoing:
|
||||
return html
|
||||
|
||||
# Client-side modifications
|
||||
return self.inline_quote_regex.sub(
|
||||
r'\1<span class="quote">\2</span>', html,
|
||||
)
|
||||
|
||||
|
||||
def filter(self, html: str, outgoing: bool = False) -> str:
|
||||
html = self._sanitizer.sanitize(html).rstrip("\n")
|
||||
|
||||
if outgoing:
|
||||
return html
|
||||
|
||||
return self.quote_regex.sub(r'\1<span class="quote">\2</span>\3', html)
|
||||
|
||||
|
||||
def sanitize_settings(self, inline: bool = False) -> dict:
|
||||
# https://matrix.org/docs/spec/client_server/latest#m-room-message-msgtypes
|
||||
# TODO: mx-reply and the new hidden thing
|
||||
|
||||
inline_tags = self.inline_tags
|
||||
all_tags = inline_tags | self.block_tags
|
||||
|
||||
inlines_attributes = {
|
||||
"font": {"color"},
|
||||
"a": {"href"},
|
||||
"code": {"class"},
|
||||
}
|
||||
attributes = {**inlines_attributes, **{
|
||||
"ol": {"start"},
|
||||
"hr": {"width"},
|
||||
}}
|
||||
|
||||
return {
|
||||
"tags": inline_tags if inline else all_tags,
|
||||
"attributes": inlines_attributes if inline else attributes,
|
||||
"empty": {} if inline else {"hr", "br"},
|
||||
"separate": {"a"} if inline else {
|
||||
"a", "p", "li", "table", "tr", "th", "td", "br", "hr",
|
||||
},
|
||||
"whitespace": {},
|
||||
"keep_typographic_whitespace": True,
|
||||
"add_nofollow": False,
|
||||
"autolink": {
|
||||
"link_regexes": self.link_regexes,
|
||||
"avoid_hosts": [],
|
||||
},
|
||||
"sanitize_href": lambda href: href,
|
||||
"element_preprocessors": [
|
||||
sanitizer.bold_span_to_strong,
|
||||
sanitizer.italic_span_to_em,
|
||||
sanitizer.tag_replacer("strong", "b"),
|
||||
sanitizer.tag_replacer("em", "i"),
|
||||
sanitizer.tag_replacer("strike", "s"),
|
||||
sanitizer.tag_replacer("del", "s"),
|
||||
sanitizer.tag_replacer("form", "p"),
|
||||
sanitizer.tag_replacer("div", "p"),
|
||||
sanitizer.tag_replacer("caption", "p"),
|
||||
sanitizer.target_blank_noopener,
|
||||
self._process_span_font,
|
||||
self._img_to_a,
|
||||
self._remove_extra_newlines,
|
||||
self._newlines_to_return_symbol if inline else lambda el: el,
|
||||
],
|
||||
"element_postprocessors": [],
|
||||
"is_mergeable": lambda e1, e2: e1.attrib == e2.attrib,
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _process_span_font(el: HtmlElement) -> HtmlElement:
|
||||
if el.tag not in ("span", "font"):
|
||||
return el
|
||||
|
||||
color = el.attrib.pop("data-mx-color", None)
|
||||
if color:
|
||||
el.tag = "font"
|
||||
el.attrib["color"] = color
|
||||
|
||||
return el
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _img_to_a(el: HtmlElement) -> HtmlElement:
|
||||
if el.tag == "img":
|
||||
el.tag = "a"
|
||||
el.attrib["href"] = el.attrib.pop("src", "")
|
||||
el.text = el.attrib.pop("alt", None) or el.attrib["href"]
|
||||
|
||||
return el
|
||||
|
||||
|
||||
def _remove_extra_newlines(self, el: HtmlElement) -> HtmlElement:
|
||||
# Remove excess \n characters to avoid additional blank lines with
|
||||
# HTML/CSS using `white-space: pre`, except in <pre> content.
|
||||
|
||||
pre_parent = any(parent.tag == "pre" for parent in el.iterancestors())
|
||||
|
||||
if el.tag != "pre" and not pre_parent:
|
||||
if el.text:
|
||||
el.text = self.extra_newlines_regex.sub(r"\1", el.text)
|
||||
if el.tail:
|
||||
el.tail = self.extra_newlines_regex.sub(r"\1", el.tail)
|
||||
|
||||
return el
|
||||
|
||||
|
||||
def _newlines_to_return_symbol(self, el: HtmlElement) -> HtmlElement:
|
||||
# Add a return unicode symbol (U+23CE) to blocks with siblings
|
||||
# (e.g. a <p> followed by another <p>) or <br>.
|
||||
# The <br> themselves will be removed by the inline sanitizer.
|
||||
|
||||
is_block_with_siblings = (el.tag in self.block_tags and
|
||||
next(el.itersiblings(), None) is not None)
|
||||
|
||||
if el.tag == "br" or is_block_with_siblings:
|
||||
el.tail = f" ⏎ {el.tail or ''}"
|
||||
|
||||
|
||||
# Replace left \n in text/tail of <pre> content by the return symbol.
|
||||
if el.text:
|
||||
el.text = re.sub(r"\n", r" ⏎ ", el.text)
|
||||
|
||||
if el.tail:
|
||||
el.tail = re.sub(r"\n", r" ⏎ ", el.tail)
|
||||
|
||||
return el
|
||||
|
||||
|
||||
HTML_FILTER = HtmlFilter()
|
944
src/backend/matrix_client.py
Normal file
944
src/backend/matrix_client.py
Normal file
@@ -0,0 +1,944 @@
|
||||
import asyncio
|
||||
import html
|
||||
import io
|
||||
import logging as log
|
||||
import platform
|
||||
import re
|
||||
import traceback
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any, DefaultDict, Dict, List, NamedTuple, Optional, Set, Tuple, Type,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import cairosvg
|
||||
from PIL import Image as PILImage
|
||||
from pymediainfo import MediaInfo
|
||||
|
||||
import nio
|
||||
from nio.crypto import AsyncDataT as UploadData
|
||||
from nio.crypto import async_generator_from_data
|
||||
|
||||
from .__about__ import __pkg_name__, __pretty_name__
|
||||
from . import utils
|
||||
from .errors import (
|
||||
BadMimeType, InvalidUserId, InvalidUserInContext, MatrixError,
|
||||
UneededThumbnail, UserNotFound,
|
||||
)
|
||||
from .html_filter import HTML_FILTER
|
||||
from .models.items import (
|
||||
Account, Event, Member, Room, TypeSpecifier, Upload, UploadStatus,
|
||||
)
|
||||
from .models.model_store import ModelStore
|
||||
from .pyotherside_events import AlertRequested
|
||||
|
||||
CryptDict = Dict[str, Any]
|
||||
|
||||
|
||||
class UploadReturn(NamedTuple):
|
||||
mxc: str
|
||||
mime: str
|
||||
decryption_dict: Dict[str, Any]
|
||||
|
||||
|
||||
class MatrixImageInfo(NamedTuple):
|
||||
w: int
|
||||
h: int
|
||||
mimetype: str
|
||||
size: int
|
||||
|
||||
|
||||
class MatrixClient(nio.AsyncClient):
|
||||
user_id_regex = re.compile(r"^@.+:.+")
|
||||
room_id_or_alias_regex = re.compile(r"^[#!].+:.+")
|
||||
http_s_url = re.compile(r"^https?://")
|
||||
|
||||
def __init__(self,
|
||||
backend,
|
||||
user: str,
|
||||
homeserver: str = "https://matrix.org",
|
||||
device_id: Optional[str] = None) -> None:
|
||||
|
||||
if not urlparse(homeserver).scheme:
|
||||
raise ValueError(
|
||||
f"homeserver is missing scheme (e.g. https://): {homeserver}",
|
||||
)
|
||||
|
||||
store = Path(backend.app.appdirs.user_data_dir) / "encryption"
|
||||
store.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
super().__init__(
|
||||
homeserver = homeserver,
|
||||
user = user,
|
||||
device_id = device_id,
|
||||
store_path = store,
|
||||
config = nio.AsyncClientConfig(
|
||||
max_timeout_retry_wait_time = 10,
|
||||
# TODO: pass a custom encryption DB pickle key?
|
||||
),
|
||||
)
|
||||
|
||||
from .backend import Backend
|
||||
self.backend: Backend = backend
|
||||
self.models: ModelStore = self.backend.models
|
||||
|
||||
self.profile_task: Optional[asyncio.Future] = None
|
||||
self.sync_task: Optional[asyncio.Future] = None
|
||||
self.load_rooms_task: Optional[asyncio.Future] = None
|
||||
self.first_sync_done: asyncio.Event = asyncio.Event()
|
||||
self.first_sync_date: Optional[datetime] = None
|
||||
|
||||
self.past_tokens: Dict[str, str] = {} # {room_id: token}
|
||||
self.fully_loaded_rooms: Set[str] = set() # {room_id}
|
||||
self.loaded_once_rooms: Set[str] = set() # {room_id}
|
||||
self.cleared_events_rooms: Set[str] = set() # {room_id}
|
||||
|
||||
self.skipped_events: DefaultDict[str, int] = DefaultDict(lambda: 0)
|
||||
|
||||
from .nio_callbacks import NioCallbacks
|
||||
self.nio_callbacks = NioCallbacks(self)
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "%s(user_id=%r, homeserver=%r, device_id=%r)" % (
|
||||
type(self).__name__, self.user_id, self.homeserver, self.device_id,
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def default_device_name(self) -> str:
|
||||
os_ = f" on {platform.system()}".rstrip()
|
||||
os_ = f"{os_} {platform.release()}".rstrip() if os_ != " on" else ""
|
||||
return f"{__pretty_name__}{os_}"
|
||||
|
||||
|
||||
async def login(self, password: str, device_name: str = "") -> None:
|
||||
response = await super().login(
|
||||
password, device_name or self.default_device_name,
|
||||
)
|
||||
|
||||
if isinstance(response, nio.LoginError):
|
||||
raise MatrixError.from_nio(response)
|
||||
|
||||
asyncio.ensure_future(self.start())
|
||||
|
||||
|
||||
async def resume(self, user_id: str, token: str, device_id: str) -> None:
|
||||
response = nio.LoginResponse(user_id, device_id, token)
|
||||
await self.receive_response(response)
|
||||
|
||||
asyncio.ensure_future(self.start())
|
||||
|
||||
|
||||
async def logout(self) -> None:
|
||||
for task in (self.profile_task, self.load_rooms_task, self.sync_task):
|
||||
if task:
|
||||
task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
await super().logout()
|
||||
await self.close()
|
||||
|
||||
|
||||
@property
|
||||
def syncing(self) -> bool:
|
||||
if not self.sync_task:
|
||||
return False
|
||||
|
||||
return not self.sync_task.done()
|
||||
|
||||
|
||||
async def start(self) -> None:
|
||||
def on_profile_response(future) -> None:
|
||||
exception = future.exception()
|
||||
|
||||
if exception:
|
||||
log.warn("On %s client startup: %r", self.user_id, exception)
|
||||
self.profile_task = asyncio.ensure_future(
|
||||
self.backend.get_profile(self.user_id),
|
||||
)
|
||||
self.profile_task.add_done_callback(on_profile_response)
|
||||
return
|
||||
|
||||
resp = future.result()
|
||||
account = self.models[Account][self.user_id]
|
||||
account.profile_updated = datetime.now()
|
||||
account.display_name = resp.displayname or ""
|
||||
account.avatar_url = resp.avatar_url or ""
|
||||
|
||||
self.profile_task = asyncio.ensure_future(
|
||||
self.backend.get_profile(self.user_id),
|
||||
)
|
||||
self.profile_task.add_done_callback(on_profile_response)
|
||||
|
||||
while True:
|
||||
try:
|
||||
self.sync_task = asyncio.ensure_future(
|
||||
self.sync_forever(timeout=10_000),
|
||||
)
|
||||
await self.sync_task
|
||||
break # task cancelled
|
||||
except Exception:
|
||||
trace = traceback.format_exc().rstrip()
|
||||
log.error("Exception during sync, will restart:\n%s", trace)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
|
||||
@property
|
||||
def all_rooms(self) -> Dict[str, nio.MatrixRoom]:
|
||||
return {**self.invited_rooms, **self.rooms}
|
||||
|
||||
|
||||
async def send_text(self, room_id: str, text: str) -> None:
|
||||
escape = False
|
||||
if text.startswith("//") or text.startswith(r"\/"):
|
||||
escape = True
|
||||
text = text[1:]
|
||||
|
||||
if text.startswith("/me ") and not escape:
|
||||
event_type = nio.RoomMessageEmote
|
||||
text = text[len("/me "): ]
|
||||
content = {"body": text, "msgtype": "m.emote"}
|
||||
to_html = HTML_FILTER.from_markdown_inline(text, outgoing=True)
|
||||
echo_body = HTML_FILTER.from_markdown_inline(text)
|
||||
else:
|
||||
event_type = nio.RoomMessageText
|
||||
content = {"body": text, "msgtype": "m.text"}
|
||||
to_html = HTML_FILTER.from_markdown(text, outgoing=True)
|
||||
echo_body = HTML_FILTER.from_markdown(text)
|
||||
|
||||
if to_html not in (html.escape(text), f"<p>{html.escape(text)}</p>"):
|
||||
content["format"] = "org.matrix.custom.html"
|
||||
content["formatted_body"] = to_html
|
||||
|
||||
# Can't use the standard Matrix transaction IDs; they're only visible
|
||||
# to the sender so our other accounts wouldn't be able to replace
|
||||
# local echoes by real messages.
|
||||
tx_id = uuid4()
|
||||
content[f"{__pkg_name__}.transaction_id"] = str(tx_id)
|
||||
|
||||
await self._local_echo(room_id, tx_id, event_type, content=echo_body)
|
||||
await self._send_message(room_id, content)
|
||||
|
||||
|
||||
async def send_file(self, room_id: str, path: Union[Path, str]) -> None:
|
||||
item_uuid = uuid4()
|
||||
|
||||
try:
|
||||
await self._send_file(item_uuid, room_id, path)
|
||||
except (nio.TransferCancelledError, asyncio.CancelledError):
|
||||
log.info("Deleting item for cancelled upload %s", item_uuid)
|
||||
del self.models[Upload, room_id][str(item_uuid)]
|
||||
|
||||
|
||||
async def _send_file(
|
||||
self, item_uuid: UUID, room_id: str, path: Union[Path, str],
|
||||
) -> None:
|
||||
from .media_cache import Media, Thumbnail
|
||||
|
||||
transaction_id = uuid4()
|
||||
path = Path(path)
|
||||
encrypt = room_id in self.encrypted_rooms
|
||||
|
||||
try:
|
||||
size = path.resolve().stat().st_size
|
||||
except (PermissionError, FileNotFoundError):
|
||||
# This error will be caught again by the try block later below
|
||||
size = 0
|
||||
|
||||
task = asyncio.Task.current_task()
|
||||
monitor = nio.TransferMonitor(size)
|
||||
upload_item = Upload(item_uuid, task, monitor, path, total_size=size)
|
||||
self.models[Upload, room_id][str(item_uuid)] = upload_item
|
||||
|
||||
def on_transferred(transferred: int) -> None:
|
||||
upload_item.uploaded = transferred
|
||||
|
||||
def on_speed_changed(speed: float) -> None:
|
||||
upload_item.speed = speed
|
||||
upload_item.time_left = monitor.remaining_time
|
||||
|
||||
monitor.on_transferred = on_transferred
|
||||
monitor.on_speed_changed = on_speed_changed
|
||||
|
||||
try:
|
||||
url, mime, crypt_dict = await self.upload(
|
||||
lambda *_: path,
|
||||
filename = path.name,
|
||||
encrypt = encrypt, monitor=monitor,
|
||||
)
|
||||
except (MatrixError, OSError) as err:
|
||||
upload_item.status = UploadStatus.Error
|
||||
upload_item.error = type(err)
|
||||
upload_item.error_args = err.args
|
||||
|
||||
# Wait for cancellation from UI, see parent send_file() method
|
||||
while True:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
upload_item.status = UploadStatus.Caching
|
||||
await Media.from_existing_file(self.backend.media_cache, url, path)
|
||||
|
||||
kind = (mime or "").split("/")[0]
|
||||
|
||||
thumb_url: str = ""
|
||||
thumb_info: Optional[MatrixImageInfo] = None
|
||||
|
||||
content: dict = {
|
||||
f"{__pkg_name__}.transaction_id": str(transaction_id),
|
||||
|
||||
"body": path.name,
|
||||
"info": {
|
||||
"mimetype": mime,
|
||||
"size": upload_item.total_size,
|
||||
},
|
||||
}
|
||||
|
||||
if encrypt:
|
||||
content["file"] = {"url": url, **crypt_dict}
|
||||
else:
|
||||
content["url"] = url
|
||||
|
||||
if kind == "image":
|
||||
is_svg = mime == "image/svg+xml"
|
||||
|
||||
event_type = \
|
||||
nio.RoomEncryptedImage if encrypt else nio.RoomMessageImage
|
||||
|
||||
content["msgtype"] = "m.image"
|
||||
|
||||
content["info"]["w"], content["info"]["h"] = (
|
||||
await utils.svg_dimensions(path) if is_svg else
|
||||
PILImage.open(path).size
|
||||
)
|
||||
|
||||
try:
|
||||
thumb_data, thumb_info = await self.generate_thumbnail(
|
||||
path, is_svg=is_svg,
|
||||
)
|
||||
except UneededThumbnail:
|
||||
pass
|
||||
except OSError as err:
|
||||
log.warning(f"Failed thumbnailing {path}: {err}")
|
||||
else:
|
||||
thumb_name = f"{path.stem}_thumbnail{''.join(path.suffixes)}"
|
||||
|
||||
upload_item.status = UploadStatus.Uploading
|
||||
upload_item.filepath = Path(thumb_name)
|
||||
upload_item.total_size = len(thumb_data)
|
||||
|
||||
try:
|
||||
thumb_url, _, thumb_crypt_dict = await self.upload(
|
||||
lambda *_: thumb_data,
|
||||
filename =
|
||||
f"{path.stem}_sample{''.join(path.suffixes)}",
|
||||
encrypt = encrypt,
|
||||
)
|
||||
except MatrixError as err:
|
||||
log.warning(f"Failed uploading thumbnail {path}: {err}")
|
||||
else:
|
||||
upload_item.status = UploadStatus.Caching
|
||||
|
||||
await Thumbnail.from_bytes(
|
||||
self.backend.media_cache,
|
||||
thumb_url,
|
||||
thumb_data,
|
||||
wanted_size = (content["info"]["w"],
|
||||
content["info"]["h"]),
|
||||
)
|
||||
|
||||
if encrypt:
|
||||
content["info"]["thumbnail_file"] = {
|
||||
"url": thumb_url,
|
||||
**thumb_crypt_dict,
|
||||
}
|
||||
else:
|
||||
content["info"]["thumbnail_url"] = thumb_url
|
||||
|
||||
content["info"]["thumbnail_info"] = thumb_info._asdict()
|
||||
|
||||
elif kind == "audio":
|
||||
event_type = \
|
||||
nio.RoomEncryptedAudio if encrypt else nio.RoomMessageAudio
|
||||
|
||||
content["msgtype"] = "m.audio"
|
||||
content["info"]["duration"] = getattr(
|
||||
MediaInfo.parse(path).tracks[0], "duration", 0,
|
||||
) or 0
|
||||
|
||||
elif kind == "video":
|
||||
event_type = \
|
||||
nio.RoomEncryptedVideo if encrypt else nio.RoomMessageVideo
|
||||
|
||||
content["msgtype"] = "m.video"
|
||||
|
||||
tracks = MediaInfo.parse(path).tracks
|
||||
|
||||
content["info"]["duration"] = \
|
||||
getattr(tracks[0], "duration", 0) or 0
|
||||
|
||||
content["info"]["w"] = max(
|
||||
getattr(t, "width", 0) or 0 for t in tracks
|
||||
)
|
||||
content["info"]["h"] = max(
|
||||
getattr(t, "height", 0) or 0 for t in tracks
|
||||
)
|
||||
|
||||
else:
|
||||
event_type = \
|
||||
nio.RoomEncryptedFile if encrypt else nio.RoomMessageFile
|
||||
|
||||
content["msgtype"] = "m.file"
|
||||
content["filename"] = path.name
|
||||
|
||||
del self.models[Upload, room_id][str(upload_item.uuid)]
|
||||
|
||||
await self._local_echo(
|
||||
room_id,
|
||||
transaction_id,
|
||||
event_type,
|
||||
inline_content = path.name,
|
||||
media_url = url,
|
||||
media_title = path.name,
|
||||
media_width = content["info"].get("w", 0),
|
||||
media_height = content["info"].get("h", 0),
|
||||
media_duration = content["info"].get("duration", 0),
|
||||
media_size = content["info"]["size"],
|
||||
media_mime = content["info"]["mimetype"],
|
||||
thumbnail_url = thumb_url,
|
||||
thumbnail_width =
|
||||
content["info"].get("thumbnail_info", {}).get("w", 0),
|
||||
thumbnail_height =
|
||||
content["info"].get("thumbnail_info", {}).get("h", 0),
|
||||
)
|
||||
|
||||
await self._send_message(room_id, content)
|
||||
|
||||
|
||||
async def _local_echo(
|
||||
self, room_id: str, transaction_id: UUID,
|
||||
event_type: Type[nio.Event], **event_fields,
|
||||
) -> None:
|
||||
|
||||
our_info = self.models[Member, self.user_id, room_id][self.user_id]
|
||||
|
||||
event = Event(
|
||||
source = None,
|
||||
client_id = f"echo-{transaction_id}",
|
||||
event_id = "",
|
||||
date = datetime.now(),
|
||||
sender_id = self.user_id,
|
||||
sender_name = our_info.display_name,
|
||||
sender_avatar = our_info.avatar_url,
|
||||
is_local_echo = True,
|
||||
local_event_type = event_type,
|
||||
**event_fields,
|
||||
)
|
||||
|
||||
for user_id in self.models[Account]:
|
||||
if user_id in self.models[Member, self.user_id, room_id]:
|
||||
key = f"echo-{transaction_id}"
|
||||
self.models[Event, user_id, room_id][key] = event
|
||||
|
||||
if user_id == self.user_id:
|
||||
self.models[Event, user_id, room_id].sync_now()
|
||||
|
||||
await self.set_room_last_event(room_id, event)
|
||||
|
||||
|
||||
async def _send_message(self, room_id: str, content: dict) -> None:
|
||||
|
||||
async with self.backend.send_locks[room_id]:
|
||||
response = await self.room_send(
|
||||
room_id = room_id,
|
||||
message_type = "m.room.message",
|
||||
content = content,
|
||||
ignore_unverified_devices = True,
|
||||
)
|
||||
|
||||
if isinstance(response, nio.RoomSendError):
|
||||
raise MatrixError.from_nio(response)
|
||||
|
||||
|
||||
async def load_past_events(self, room_id: str) -> bool:
|
||||
if room_id in self.fully_loaded_rooms or \
|
||||
room_id in self.invited_rooms or \
|
||||
room_id in self.cleared_events_rooms:
|
||||
return False
|
||||
|
||||
await self.first_sync_done.wait()
|
||||
|
||||
while not self.past_tokens.get(room_id):
|
||||
# If a new room was added, wait for onSyncResponse to set the token
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
response = await self.room_messages(
|
||||
room_id = room_id,
|
||||
start = self.past_tokens[room_id],
|
||||
limit = 100 if room_id in self.loaded_once_rooms else 25,
|
||||
)
|
||||
|
||||
if isinstance(response, nio.RoomMessagesError):
|
||||
log.error("Loading past messages for room %s failed: %s",
|
||||
room_id, response)
|
||||
return True
|
||||
|
||||
self.loaded_once_rooms.add(room_id)
|
||||
more_to_load = True
|
||||
|
||||
self.past_tokens[room_id] = response.end
|
||||
|
||||
for event in response.chunk:
|
||||
if isinstance(event, nio.RoomCreateEvent):
|
||||
self.fully_loaded_rooms.add(room_id)
|
||||
more_to_load = False
|
||||
|
||||
for cb in self.event_callbacks:
|
||||
if (cb.filter is None or isinstance(event, cb.filter)):
|
||||
await cb.func(self.all_rooms[room_id], event)
|
||||
|
||||
return more_to_load
|
||||
|
||||
|
||||
async def load_rooms_without_visible_events(self) -> None:
|
||||
for room_id in self.models[Room, self.user_id]:
|
||||
asyncio.ensure_future(
|
||||
self._load_room_without_visible_events(room_id),
|
||||
)
|
||||
|
||||
|
||||
async def _load_room_without_visible_events(self, room_id: str) -> None:
|
||||
events = self.models[Event, self.user_id, room_id]
|
||||
more = True
|
||||
|
||||
while self.skipped_events[room_id] and not events and more:
|
||||
more = await self.load_past_events(room_id)
|
||||
|
||||
|
||||
async def new_direct_chat(self, invite: str, encrypt: bool = False) -> str:
|
||||
if invite == self.user_id:
|
||||
raise InvalidUserInContext(invite)
|
||||
|
||||
if not self.user_id_regex.match(invite):
|
||||
raise InvalidUserId(invite)
|
||||
|
||||
if isinstance(await self.get_profile(invite), nio.ProfileGetError):
|
||||
raise UserNotFound(invite)
|
||||
|
||||
response = await super().room_create(
|
||||
invite = [invite],
|
||||
is_direct = True,
|
||||
visibility = nio.RoomVisibility.private,
|
||||
initial_state =
|
||||
[nio.EnableEncryptionBuilder().as_dict()] if encrypt else [],
|
||||
)
|
||||
|
||||
if isinstance(response, nio.RoomCreateError):
|
||||
raise MatrixError.from_nio(response)
|
||||
|
||||
return response.room_id
|
||||
|
||||
|
||||
async def new_group_chat(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
topic: Optional[str] = None,
|
||||
public: bool = False,
|
||||
encrypt: bool = False,
|
||||
federate: bool = True,
|
||||
) -> str:
|
||||
|
||||
response = await super().room_create(
|
||||
name = name or None,
|
||||
topic = topic or None,
|
||||
federate = federate,
|
||||
visibility =
|
||||
nio.RoomVisibility.public if public else
|
||||
nio.RoomVisibility.private,
|
||||
initial_state =
|
||||
[nio.EnableEncryptionBuilder().as_dict()] if encrypt else [],
|
||||
)
|
||||
|
||||
if isinstance(response, nio.RoomCreateError):
|
||||
raise MatrixError.from_nio(response)
|
||||
|
||||
return response.room_id
|
||||
|
||||
async def room_join(self, alias_or_id_or_url: str) -> str:
|
||||
string = alias_or_id_or_url.strip()
|
||||
|
||||
if self.http_s_url.match(string):
|
||||
for part in urlparse(string).fragment.split("/"):
|
||||
if self.room_id_or_alias_regex.match(part):
|
||||
string = part
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"No alias or room id found in url {string}")
|
||||
|
||||
if not self.room_id_or_alias_regex.match(string):
|
||||
raise ValueError("Not an alias or room id")
|
||||
|
||||
response = await super().join(string)
|
||||
|
||||
if isinstance(response, nio.JoinError):
|
||||
raise MatrixError.from_nio(response)
|
||||
|
||||
return response.room_id
|
||||
|
||||
|
||||
async def room_forget(self, room_id: str) -> None:
|
||||
await super().room_leave(room_id)
|
||||
await super().room_forget(room_id)
|
||||
self.models[Room, self.user_id].pop(room_id, None)
|
||||
self.models.pop((Event, self.user_id, room_id), None)
|
||||
self.models.pop((Member, self.user_id, room_id), None)
|
||||
|
||||
|
||||
async def room_mass_invite(
|
||||
self, room_id: str, *user_ids: str,
|
||||
) -> Tuple[List[str], List[Tuple[str, Exception]]]:
|
||||
|
||||
user_ids = tuple(
|
||||
uid for uid in user_ids
|
||||
# Server would return a 403 forbidden for users already in the room
|
||||
if uid not in self.all_rooms[room_id].users
|
||||
)
|
||||
|
||||
async def invite(user):
|
||||
if not self.user_id_regex.match(user):
|
||||
return InvalidUserId(user)
|
||||
|
||||
if isinstance(await self.get_profile(user), nio.ProfileGetError):
|
||||
return UserNotFound(user)
|
||||
|
||||
return await self.room_invite(room_id, user)
|
||||
|
||||
coros = [invite(uid) for uid in user_ids]
|
||||
successes = []
|
||||
errors: list = []
|
||||
responses = await asyncio.gather(*coros)
|
||||
|
||||
for user_id, response in zip(user_ids, responses):
|
||||
if isinstance(response, nio.RoomInviteError):
|
||||
errors.append((user_id, MatrixError.from_nio(response)))
|
||||
|
||||
elif isinstance(response, Exception):
|
||||
errors.append((user_id, response))
|
||||
|
||||
else:
|
||||
successes.append(user_id)
|
||||
|
||||
return (successes, errors)
|
||||
|
||||
|
||||
async def generate_thumbnail(
|
||||
self, data: UploadData, is_svg: bool = False,
|
||||
) -> Tuple[bytes, MatrixImageInfo]:
|
||||
|
||||
png_modes = ("1", "L", "P", "RGBA")
|
||||
|
||||
data = b"".join([c async for c in async_generator_from_data(data)])
|
||||
is_svg = await utils.guess_mime(data) == "image/svg+xml"
|
||||
|
||||
if is_svg:
|
||||
svg_width, svg_height = await utils.svg_dimensions(data)
|
||||
|
||||
data = cairosvg.svg2png(
|
||||
bytestring = data,
|
||||
parent_width = svg_width,
|
||||
parent_height = svg_height,
|
||||
)
|
||||
|
||||
thumb = PILImage.open(io.BytesIO(data))
|
||||
|
||||
small = thumb.width <= 800 and thumb.height <= 600
|
||||
is_jpg_png = thumb.format in ("JPEG", "PNG")
|
||||
jpgable_png = thumb.format == "PNG" and thumb.mode not in png_modes
|
||||
|
||||
if small and is_jpg_png and not jpgable_png and not is_svg:
|
||||
raise UneededThumbnail()
|
||||
|
||||
if not small:
|
||||
thumb.thumbnail((800, 600), PILImage.LANCZOS)
|
||||
|
||||
with io.BytesIO() as out:
|
||||
if thumb.mode in png_modes:
|
||||
thumb.save(out, "PNG", optimize=True)
|
||||
mime = "image/png"
|
||||
else:
|
||||
thumb.convert("RGB").save(out, "JPEG", optimize=True)
|
||||
mime = "image/jpeg"
|
||||
|
||||
data = out.getvalue()
|
||||
|
||||
info = MatrixImageInfo(thumb.width, thumb.height, mime, len(data))
|
||||
return (data, info)
|
||||
|
||||
|
||||
async def upload(
|
||||
self,
|
||||
data_provider: nio.DataProvider,
|
||||
mime: Optional[str] = None,
|
||||
filename: Optional[str] = None,
|
||||
encrypt: bool = False,
|
||||
monitor: Optional[nio.TransferMonitor] = None,
|
||||
) -> UploadReturn:
|
||||
|
||||
mime = mime or await utils.guess_mime(data_provider(0, 0))
|
||||
|
||||
response, decryption_dict = await super().upload(
|
||||
data_provider,
|
||||
"application/octet-stream" if encrypt else mime,
|
||||
filename,
|
||||
encrypt,
|
||||
monitor,
|
||||
)
|
||||
|
||||
if isinstance(response, nio.UploadError):
|
||||
raise MatrixError.from_nio(response)
|
||||
|
||||
return UploadReturn(response.content_uri, mime, decryption_dict)
|
||||
|
||||
|
||||
async def set_avatar_from_file(self, path: Union[Path, str]) -> None:
|
||||
mime = await utils.guess_mime(path)
|
||||
|
||||
if mime.split("/")[0] != "image":
|
||||
raise BadMimeType(wanted="image/*", got=mime)
|
||||
|
||||
mxc, *_ = await self.upload(lambda *_: path, mime, Path(path).name)
|
||||
await self.set_avatar(mxc)
|
||||
|
||||
|
||||
async def import_keys(self, infile: str, passphrase: str) -> None:
|
||||
await super().import_keys(infile, passphrase)
|
||||
await self.retry_decrypting_events()
|
||||
|
||||
|
||||
async def export_keys(self, outfile: str, passphrase: str) -> None:
|
||||
path = Path(outfile)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# The QML dialog asks the user if he wants to overwrite before this
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
|
||||
await super().export_keys(outfile, passphrase)
|
||||
|
||||
|
||||
async def retry_decrypting_events(self) -> None:
|
||||
for sync_id, model in self.models.items():
|
||||
if not (isinstance(sync_id, tuple) and
|
||||
sync_id[0:2] == (Event, self.user_id)):
|
||||
continue
|
||||
|
||||
_, _, room_id = sync_id
|
||||
|
||||
for ev in model.values():
|
||||
room = self.all_rooms[room_id]
|
||||
|
||||
if isinstance(ev.source, nio.MegolmEvent):
|
||||
try:
|
||||
decrypted = self.decrypt_event(ev.source)
|
||||
|
||||
if not decrypted:
|
||||
raise nio.EncryptionError()
|
||||
|
||||
except nio.EncryptionError:
|
||||
continue
|
||||
|
||||
for cb in self.event_callbacks:
|
||||
if not cb.filter or isinstance(decrypted, cb.filter):
|
||||
await asyncio.coroutine(cb.func)(room, decrypted)
|
||||
|
||||
|
||||
async def clear_events(self, room_id: str) -> None:
|
||||
self.cleared_events_rooms.add(room_id)
|
||||
model = self.models[Event, self.user_id, room_id]
|
||||
if model:
|
||||
model.clear()
|
||||
model.sync_now()
|
||||
|
||||
|
||||
# Functions to register data into models
|
||||
|
||||
async def event_is_past(self, ev: Union[nio.Event, Event]) -> bool:
|
||||
if not self.first_sync_date:
|
||||
return True
|
||||
|
||||
if isinstance(ev, Event):
|
||||
return ev.date < self.first_sync_date
|
||||
|
||||
date = datetime.fromtimestamp(ev.server_timestamp / 1000)
|
||||
return date < self.first_sync_date
|
||||
|
||||
|
||||
async def set_room_last_event(self, room_id: str, item: Event) -> None:
|
||||
model = self.models[Room, self.user_id]
|
||||
room = model[room_id]
|
||||
|
||||
if room.last_event is None:
|
||||
room.last_event = item.serialized
|
||||
|
||||
if item.is_local_echo:
|
||||
model.sync_now()
|
||||
|
||||
return
|
||||
|
||||
is_profile_ev = item.type_specifier == TypeSpecifier.profile_change
|
||||
|
||||
# If there were no better events available to show previously
|
||||
prev_is_profile_ev = \
|
||||
room.last_event["type_specifier"] == TypeSpecifier.profile_change
|
||||
|
||||
# If this is a profile event, only replace the currently shown one if
|
||||
# it was also a profile event (we had nothing better to show).
|
||||
if is_profile_ev and not prev_is_profile_ev:
|
||||
return
|
||||
|
||||
# If this event is older than the currently shown one, only replace
|
||||
# it if the previous was a profile event.
|
||||
if item.date < room.last_event["date"] and not prev_is_profile_ev:
|
||||
return
|
||||
|
||||
room.last_event = item.serialized
|
||||
|
||||
if item.is_local_echo:
|
||||
model.sync_now()
|
||||
|
||||
|
||||
async def register_nio_room(self, room: nio.MatrixRoom, left: bool = False,
|
||||
) -> None:
|
||||
# Add room
|
||||
try:
|
||||
last_ev = self.models[Room, self.user_id][room.room_id].last_event
|
||||
except KeyError:
|
||||
last_ev = None
|
||||
|
||||
inviter = getattr(room, "inviter", "") or ""
|
||||
levels = room.power_levels
|
||||
can_send_state = partial(levels.can_user_send_state, self.user_id)
|
||||
can_send_msg = partial(levels.can_user_send_message, self.user_id)
|
||||
|
||||
self.models[Room, self.user_id][room.room_id] = Room(
|
||||
room_id = room.room_id,
|
||||
given_name = room.name or "",
|
||||
display_name = room.display_name or "",
|
||||
avatar_url = room.gen_avatar_url or "",
|
||||
plain_topic = room.topic or "",
|
||||
topic = HTML_FILTER.filter_inline(room.topic or ""),
|
||||
inviter_id = inviter,
|
||||
inviter_name = room.user_name(inviter) if inviter else "",
|
||||
inviter_avatar =
|
||||
(room.avatar_url(inviter) or "") if inviter else "",
|
||||
left = left,
|
||||
|
||||
encrypted = room.encrypted,
|
||||
invite_required = room.join_rule == "invite",
|
||||
guests_allowed = room.guest_access == "can_join",
|
||||
|
||||
can_invite = levels.can_user_invite(self.user),
|
||||
can_send_messages = can_send_msg(),
|
||||
can_set_name = can_send_state("m.room.name"),
|
||||
can_set_topic = can_send_state("m.room.topic"),
|
||||
can_set_avatar = can_send_state("m.room.avatar"),
|
||||
can_set_encryption = can_send_state("m.room.encryption"),
|
||||
can_set_join_rules = can_send_state("m.room.join_rules"),
|
||||
can_set_guest_access = can_send_state("m.room.guest_access"),
|
||||
|
||||
last_event = last_ev,
|
||||
)
|
||||
|
||||
# List members that left the room, then remove them from our model
|
||||
left_the_room = [
|
||||
user_id
|
||||
for user_id in self.models[Member, self.user_id, room.room_id]
|
||||
if user_id not in room.users
|
||||
]
|
||||
|
||||
for user_id in left_the_room:
|
||||
del self.models[Member, self.user_id, room.room_id][user_id]
|
||||
|
||||
# Add the room members to the added room
|
||||
new_dict = {
|
||||
user_id: Member(
|
||||
user_id = user_id,
|
||||
display_name = room.user_name(user_id) # disambiguated
|
||||
if member.display_name else "",
|
||||
avatar_url = member.avatar_url or "",
|
||||
typing = user_id in room.typing_users,
|
||||
power_level = member.power_level,
|
||||
invited = member.invited,
|
||||
) for user_id, member in room.users.items()
|
||||
}
|
||||
self.models[Member, self.user_id, room.room_id].update(new_dict)
|
||||
|
||||
|
||||
async def get_member_name_avatar(self, room_id: str, user_id: str,
|
||||
) -> Tuple[str, str]:
|
||||
try:
|
||||
item = self.models[Member, self.user_id, room_id][user_id]
|
||||
except KeyError: # e.g. user is not anymore in the room
|
||||
try:
|
||||
info = await self.backend.get_profile(user_id)
|
||||
return (info.displayname or "", info.avatar_url or "")
|
||||
except MatrixError:
|
||||
return ("", "")
|
||||
else:
|
||||
return (item.display_name, item.avatar_url)
|
||||
|
||||
|
||||
async def register_nio_event(
|
||||
self, room: nio.MatrixRoom, ev: nio.Event, **fields,
|
||||
) -> None:
|
||||
|
||||
await self.register_nio_room(room)
|
||||
|
||||
sender_name, sender_avatar = \
|
||||
await self.get_member_name_avatar(room.room_id, ev.sender)
|
||||
|
||||
target_id = getattr(ev, "state_key", "") or ""
|
||||
|
||||
target_name, target_avatar = \
|
||||
await self.get_member_name_avatar(room.room_id, target_id) \
|
||||
if target_id else ("", "")
|
||||
|
||||
# Create Event ModelItem
|
||||
item = Event(
|
||||
source = ev,
|
||||
client_id = ev.event_id,
|
||||
event_id = ev.event_id,
|
||||
date = datetime.fromtimestamp(ev.server_timestamp / 1000),
|
||||
sender_id = ev.sender,
|
||||
sender_name = sender_name,
|
||||
sender_avatar = sender_avatar,
|
||||
target_id = target_id,
|
||||
target_name = target_name,
|
||||
target_avatar = target_avatar,
|
||||
**fields,
|
||||
)
|
||||
|
||||
# Add the Event to model
|
||||
tx_id = ev.source.get("content", {}).get(
|
||||
f"{__pkg_name__}.transaction_id",
|
||||
)
|
||||
local_sender = ev.sender in self.backend.clients
|
||||
|
||||
if local_sender and tx_id:
|
||||
item.client_id = f"echo-{tx_id}"
|
||||
|
||||
if not local_sender and not await self.event_is_past(ev):
|
||||
AlertRequested()
|
||||
|
||||
self.models[Event, self.user_id, room.room_id][item.client_id] = item
|
||||
|
||||
await self.set_room_last_event(room.room_id, item)
|
||||
|
||||
if item.sender_id == self.user_id:
|
||||
self.models[Event, self.user_id, room.room_id].sync_now()
|
263
src/backend/media_cache.py
Normal file
263
src/backend/media_cache.py
Normal file
@@ -0,0 +1,263 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import io
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, DefaultDict, Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiofiles
|
||||
from PIL import Image as PILImage
|
||||
|
||||
import nio
|
||||
|
||||
from .backend import Backend
|
||||
from .utils import Size
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
import pyfastcopy # noqa
|
||||
|
||||
CryptDict = Optional[Dict[str, Any]]
|
||||
|
||||
CONCURRENT_DOWNLOADS_LIMIT = asyncio.BoundedSemaphore(8)
|
||||
ACCESS_LOCKS: DefaultDict[str, asyncio.Lock] = DefaultDict(asyncio.Lock)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Media:
|
||||
cache: "MediaCache" = field()
|
||||
mxc: str = field()
|
||||
crypt_dict: CryptDict = field(repr=False)
|
||||
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.mxc = re.sub(r"#auto$", "", self.mxc)
|
||||
|
||||
if not re.match(r"^mxc://.+/.+", self.mxc):
|
||||
raise ValueError(f"Invalid mxc URI: {self.mxc}")
|
||||
|
||||
|
||||
@property
|
||||
def local_path(self) -> Path:
|
||||
parsed = urlparse(self.mxc)
|
||||
name = parsed.path.lstrip("/")
|
||||
return self.cache.downloads_dir / parsed.netloc / name
|
||||
|
||||
|
||||
async def get(self) -> Path:
|
||||
async with ACCESS_LOCKS[self.mxc]:
|
||||
try:
|
||||
return await self._get_local_existing_file()
|
||||
except FileNotFoundError:
|
||||
return await self.create()
|
||||
|
||||
|
||||
async def _get_local_existing_file(self) -> Path:
|
||||
if not self.local_path.exists():
|
||||
raise FileNotFoundError()
|
||||
|
||||
return self.local_path
|
||||
|
||||
|
||||
async def create(self) -> Path:
|
||||
async with CONCURRENT_DOWNLOADS_LIMIT:
|
||||
data = await self._get_remote_data()
|
||||
|
||||
self.local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with aiofiles.open(self.local_path, "wb") as file:
|
||||
await file.write(data)
|
||||
|
||||
return self.local_path
|
||||
|
||||
|
||||
async def _get_remote_data(self) -> bytes:
|
||||
parsed = urlparse(self.mxc)
|
||||
|
||||
resp = await self.cache.backend.download(
|
||||
server_name = parsed.netloc,
|
||||
media_id = parsed.path.lstrip("/"),
|
||||
)
|
||||
|
||||
return await self._decrypt(resp.body)
|
||||
|
||||
|
||||
async def _decrypt(self, data: bytes) -> bytes:
|
||||
if not self.crypt_dict:
|
||||
return data
|
||||
|
||||
func = functools.partial(
|
||||
nio.crypto.attachments.decrypt_attachment,
|
||||
data,
|
||||
self.crypt_dict["key"]["k"],
|
||||
self.crypt_dict["hashes"]["sha256"],
|
||||
self.crypt_dict["iv"],
|
||||
)
|
||||
|
||||
# Run in a separate thread
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
||||
|
||||
@classmethod
|
||||
async def from_existing_file(
|
||||
cls,
|
||||
cache: "MediaCache",
|
||||
mxc: str,
|
||||
existing: Path,
|
||||
overwrite: bool = False,
|
||||
**kwargs,
|
||||
) -> "Media":
|
||||
|
||||
media = cls(cache, mxc, {}, **kwargs) # type: ignore
|
||||
media.local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not media.local_path.exists() or overwrite:
|
||||
func = functools.partial(shutil.copy, existing, media.local_path)
|
||||
await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
||||
return media
|
||||
|
||||
|
||||
@classmethod
|
||||
async def from_bytes(
|
||||
cls,
|
||||
cache: "MediaCache",
|
||||
mxc: str,
|
||||
data: bytes,
|
||||
overwrite: bool = False,
|
||||
**kwargs,
|
||||
) -> "Media":
|
||||
|
||||
media = cls(cache, mxc, {}, **kwargs) # type: ignore
|
||||
media.local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not media.local_path.exists() or overwrite:
|
||||
async with aiofiles.open(media.local_path, "wb") as file:
|
||||
await file.write(data)
|
||||
|
||||
return media
|
||||
|
||||
|
||||
@dataclass
|
||||
class Thumbnail(Media):
|
||||
cache: "MediaCache" = field()
|
||||
mxc: str = field()
|
||||
crypt_dict: CryptDict = field(repr=False)
|
||||
wanted_size: Size = field()
|
||||
|
||||
server_size: Optional[Size] = field(init=False, repr=False, default=None)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def normalize_size(size: Size) -> Size:
|
||||
# https://matrix.org/docs/spec/client_server/latest#thumbnails
|
||||
|
||||
if size[0] > 640 or size[1] > 480:
|
||||
return (800, 600)
|
||||
|
||||
if size[0] > 320 or size[1] > 240:
|
||||
return (640, 480)
|
||||
|
||||
if size[0] > 96 or size[1] > 96:
|
||||
return (320, 240)
|
||||
|
||||
if size[0] > 32 or size[1] > 32:
|
||||
return (96, 96)
|
||||
|
||||
return (32, 32)
|
||||
|
||||
|
||||
@property
|
||||
def local_path(self) -> Path:
|
||||
# example: thumbnails/matrix.org/32x32/<mxc id>
|
||||
|
||||
parsed = urlparse(self.mxc)
|
||||
size = self.normalize_size(self.server_size or self.wanted_size)
|
||||
name = "%dx%d/%s" % (size[0], size[1], parsed.path.lstrip("/"))
|
||||
|
||||
return self.cache.thumbs_dir / parsed.netloc / name
|
||||
|
||||
|
||||
async def _get_local_existing_file(self) -> Path:
|
||||
if self.local_path.exists():
|
||||
return self.local_path
|
||||
|
||||
# If we have a bigger size thumbnail than the wanted_size for this pic,
|
||||
# return it instead of asking the server for a smaller thumbnail.
|
||||
|
||||
try_sizes = ((32, 32), (96, 96), (320, 240), (640, 480), (800, 600))
|
||||
parts = list(self.local_path.parts)
|
||||
size = self.normalize_size(self.server_size or self.wanted_size)
|
||||
|
||||
for width, height in try_sizes:
|
||||
if width < size[0] or height < size[1]:
|
||||
continue
|
||||
|
||||
parts[-2] = f"{width}x{height}"
|
||||
path = Path("/".join(parts))
|
||||
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
raise FileNotFoundError()
|
||||
|
||||
|
||||
async def _get_remote_data(self) -> bytes:
|
||||
parsed = urlparse(self.mxc)
|
||||
|
||||
if self.crypt_dict:
|
||||
# Matrix makes encrypted thumbs only available through the download
|
||||
# end-point, not the thumbnail one
|
||||
resp = await self.cache.backend.download(
|
||||
server_name = parsed.netloc,
|
||||
media_id = parsed.path.lstrip("/"),
|
||||
)
|
||||
else:
|
||||
resp = await self.cache.backend.thumbnail(
|
||||
server_name = parsed.netloc,
|
||||
media_id = parsed.path.lstrip("/"),
|
||||
width = self.wanted_size[0],
|
||||
height = self.wanted_size[1],
|
||||
)
|
||||
|
||||
decrypted = await self._decrypt(resp.body)
|
||||
|
||||
with io.BytesIO(decrypted) as img:
|
||||
# The server may return a thumbnail bigger than what we asked for
|
||||
self.server_size = PILImage.open(img).size
|
||||
|
||||
return decrypted
|
||||
|
||||
|
||||
@dataclass
|
||||
class MediaCache:
|
||||
backend: Backend = field()
|
||||
base_dir: Path = field()
|
||||
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.thumbs_dir = self.base_dir / "thumbnails"
|
||||
self.downloads_dir = self.base_dir / "downloads"
|
||||
|
||||
self.thumbs_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# These methods are for conveniant usage from QML
|
||||
|
||||
async def get_media(self, mxc: str, crypt_dict: CryptDict = None) -> Path:
|
||||
return await Media(self, mxc, crypt_dict).get()
|
||||
|
||||
|
||||
async def get_thumbnail(
|
||||
self, mxc: str, width: int, height: int, crypt_dict: CryptDict = None,
|
||||
) -> Path:
|
||||
|
||||
thumb = Thumbnail(
|
||||
# QML sometimes pass float sizes, which matrix API doesn't like.
|
||||
self, mxc, crypt_dict, (round(width), round(height)),
|
||||
)
|
||||
return await thumb.get()
|
6
src/backend/models/__init__.py
Normal file
6
src/backend/models/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from typing import Tuple, Type, Union
|
||||
|
||||
from .model_item import ModelItem
|
||||
|
||||
# last one: Tuple[Union[Type[ModelItem], Tuple[Type[ModelItem]]], str...]
|
||||
SyncId = Union[Type[ModelItem], Tuple[Type[ModelItem]], tuple]
|
238
src/backend/models/items.py
Normal file
238
src/backend/models/items.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import asyncio
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
import lxml # nosec
|
||||
|
||||
import nio
|
||||
|
||||
from ..html_filter import HTML_FILTER
|
||||
from ..utils import AutoStrEnum, auto
|
||||
from .model_item import ModelItem
|
||||
|
||||
OptionalExceptionType = Union[Type[None], Type[Exception]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Account(ModelItem):
|
||||
user_id: str = field()
|
||||
display_name: str = ""
|
||||
avatar_url: str = ""
|
||||
first_sync_done: bool = False
|
||||
profile_updated: Optional[datetime] = None
|
||||
|
||||
def __lt__(self, other: "Account") -> bool:
|
||||
name = self.display_name or self.user_id[1:]
|
||||
other_name = other.display_name or other.user_id[1:]
|
||||
return name < other_name
|
||||
|
||||
@property
|
||||
def filter_string(self) -> str:
|
||||
return self.display_name
|
||||
|
||||
|
||||
@dataclass
|
||||
class Room(ModelItem):
|
||||
room_id: str = field()
|
||||
given_name: str = ""
|
||||
display_name: str = ""
|
||||
avatar_url: str = ""
|
||||
plain_topic: str = ""
|
||||
topic: str = ""
|
||||
inviter_id: str = ""
|
||||
inviter_name: str = ""
|
||||
inviter_avatar: str = ""
|
||||
left: bool = False
|
||||
typing_members: List[str] = field(default_factory=list)
|
||||
|
||||
encrypted: bool = False
|
||||
invite_required: bool = True
|
||||
guests_allowed: bool = True
|
||||
|
||||
can_invite: bool = False
|
||||
can_send_messages: bool = False
|
||||
can_set_name: bool = False
|
||||
can_set_topic: bool = False
|
||||
can_set_avatar: bool = False
|
||||
can_set_encryption: bool = False
|
||||
can_set_join_rules: bool = False
|
||||
can_set_guest_access: bool = False
|
||||
|
||||
# Event.serialized
|
||||
last_event: Optional[Dict[str, Any]] = field(default=None, repr=False)
|
||||
|
||||
def __lt__(self, other: "Room") -> bool:
|
||||
# Order: Invited rooms > joined rooms > left rooms.
|
||||
# Within these categories, sort by date then by name.
|
||||
# Left rooms may still have an inviter_id, so check left first.
|
||||
return (
|
||||
self.left,
|
||||
|
||||
other.inviter_id,
|
||||
|
||||
other.last_event["date"] if other.last_event else
|
||||
datetime.fromtimestamp(0),
|
||||
|
||||
self.display_name.lower() or self.room_id,
|
||||
) < (
|
||||
other.left,
|
||||
|
||||
self.inviter_id,
|
||||
|
||||
self.last_event["date"] if self.last_event else
|
||||
datetime.fromtimestamp(0),
|
||||
|
||||
other.display_name.lower() or other.room_id,
|
||||
)
|
||||
|
||||
@property
|
||||
def filter_string(self) -> str:
|
||||
return " ".join((
|
||||
self.display_name,
|
||||
self.topic,
|
||||
re.sub(r"<.*?>", "", self.last_event["inline_content"])
|
||||
if self.last_event else "",
|
||||
))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Member(ModelItem):
|
||||
user_id: str = field()
|
||||
display_name: str = ""
|
||||
avatar_url: str = ""
|
||||
typing: bool = False
|
||||
power_level: int = 0
|
||||
invited: bool = False
|
||||
|
||||
def __lt__(self, other: "Member") -> bool:
|
||||
# Sort by name, but have members with higher power-level first and
|
||||
# invited-but-not-joined members last
|
||||
name = (self.display_name or self.user_id[1:]).lower()
|
||||
other_name = (other.display_name or other.user_id[1:]).lower()
|
||||
|
||||
return (
|
||||
self.invited, other.power_level, name,
|
||||
) < (
|
||||
other.invited, self.power_level, other_name,
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def filter_string(self) -> str:
|
||||
return self.display_name
|
||||
|
||||
|
||||
class UploadStatus(AutoStrEnum):
|
||||
Uploading = auto()
|
||||
Caching = auto()
|
||||
Error = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Upload(ModelItem):
|
||||
uuid: UUID = field()
|
||||
task: asyncio.Task = field()
|
||||
monitor: nio.TransferMonitor = field()
|
||||
filepath: Path = field()
|
||||
|
||||
total_size: int = 0
|
||||
uploaded: int = 0
|
||||
speed: float = 0
|
||||
time_left: Optional[timedelta] = None
|
||||
|
||||
status: UploadStatus = UploadStatus.Uploading
|
||||
error: OptionalExceptionType = type(None)
|
||||
error_args: Tuple[Any, ...] = ()
|
||||
|
||||
start_date: datetime = field(init=False, default_factory=datetime.now)
|
||||
|
||||
|
||||
def __lt__(self, other: "Upload") -> bool:
|
||||
# Sort from newest upload to oldest.
|
||||
return self.start_date > other.start_date
|
||||
|
||||
|
||||
class TypeSpecifier(AutoStrEnum):
|
||||
none = auto()
|
||||
profile_change = auto()
|
||||
membership_change = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Event(ModelItem):
|
||||
source: Optional[nio.Event] = field()
|
||||
client_id: str = field()
|
||||
event_id: str = field()
|
||||
date: datetime = field()
|
||||
sender_id: str = field()
|
||||
sender_name: str = field()
|
||||
sender_avatar: str = field()
|
||||
|
||||
content: str = ""
|
||||
inline_content: str = ""
|
||||
|
||||
type_specifier: TypeSpecifier = TypeSpecifier.none
|
||||
|
||||
target_id: str = ""
|
||||
target_name: str = ""
|
||||
target_avatar: str = ""
|
||||
|
||||
is_local_echo: bool = False
|
||||
local_event_type: Optional[Type[nio.Event]] = None
|
||||
|
||||
media_url: str = ""
|
||||
media_title: str = ""
|
||||
media_width: int = 0
|
||||
media_height: int = 0
|
||||
media_duration: int = 0
|
||||
media_size: int = 0
|
||||
media_mime: str = ""
|
||||
media_crypt_dict: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
thumbnail_url: str = ""
|
||||
thumbnail_width: int = 0
|
||||
thumbnail_height: int = 0
|
||||
thumbnail_crypt_dict: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.inline_content:
|
||||
self.inline_content = HTML_FILTER.filter_inline(self.content)
|
||||
|
||||
|
||||
def __lt__(self, other: "Event") -> bool:
|
||||
# Sort events from newest to oldest. return True means return False.
|
||||
return self.date > other.date
|
||||
|
||||
@property
|
||||
def event_type(self) -> Type:
|
||||
if self.local_event_type:
|
||||
return self.local_event_type
|
||||
|
||||
return type(self.source)
|
||||
|
||||
@property
|
||||
def links(self) -> List[str]:
|
||||
urls: List[str] = []
|
||||
|
||||
if self.content.strip():
|
||||
urls += [link[2] for link in lxml.html.iterlinks(self.content)]
|
||||
|
||||
if self.media_url:
|
||||
urls.append(self.media_url)
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
@dataclass
|
||||
class Device(ModelItem):
|
||||
device_id: str = field()
|
||||
ed25519_key: str = field()
|
||||
trusted: bool = False
|
||||
blacklisted: bool = False
|
||||
display_name: str = ""
|
||||
last_seen_ip: str = ""
|
||||
last_seen_date: str = ""
|
111
src/backend/models/model.py
Normal file
111
src/backend/models/model.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import logging as log
|
||||
import time
|
||||
from threading import Lock, Thread
|
||||
from typing import Any, Dict, Iterator, List, MutableMapping
|
||||
|
||||
from . import SyncId
|
||||
from ..pyotherside_events import ModelUpdated
|
||||
from .model_item import ModelItem
|
||||
|
||||
|
||||
class Model(MutableMapping):
|
||||
def __init__(self, sync_id: SyncId) -> None:
|
||||
self.sync_id: SyncId = sync_id
|
||||
self._data: Dict[Any, ModelItem] = {}
|
||||
|
||||
self._changed: bool = False
|
||||
self._sync_lock: Lock = Lock()
|
||||
self._sync_thread: Thread = Thread(target=self._sync_loop, daemon=True)
|
||||
self._sync_thread.start()
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
try:
|
||||
from pprintpp import pformat
|
||||
except ImportError:
|
||||
from pprint import pformat # type: ignore
|
||||
|
||||
if isinstance(self.sync_id, tuple):
|
||||
sid = (self.sync_id[0].__name__, *self.sync_id[1:])
|
||||
else:
|
||||
sid = self.sync_id.__name__ # type: ignore
|
||||
|
||||
return "%s(sync_id=%s, %s)" % (
|
||||
type(self).__name__, sid, pformat(self._data),
|
||||
)
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
if isinstance(self.sync_id, tuple):
|
||||
reprs = tuple(repr(s) for s in self.sync_id[1:])
|
||||
sid = ", ".join((self.sync_id[0].__name__, *reprs))
|
||||
sid = f"({sid})"
|
||||
else:
|
||||
sid = self.sync_id.__name__
|
||||
|
||||
return f"{sid!s}: {len(self)} items"
|
||||
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._data[key]
|
||||
|
||||
|
||||
def __setitem__(self, key, value: ModelItem) -> None:
|
||||
new = value
|
||||
|
||||
if key in self:
|
||||
existing = dict(self[key].serialized) # copy to not alter with pop
|
||||
merged = {**existing, **value.serialized}
|
||||
|
||||
existing.pop("parent_model", None)
|
||||
merged.pop("parent_model", None)
|
||||
|
||||
if merged == existing:
|
||||
return
|
||||
|
||||
merged_init_kwargs = {**vars(self[key]), **vars(value)}
|
||||
merged_init_kwargs.pop("parent_model", None)
|
||||
new = type(value)(**merged_init_kwargs)
|
||||
|
||||
new.parent_model = self
|
||||
|
||||
with self._sync_lock:
|
||||
self._data[key] = new
|
||||
self._changed = True
|
||||
|
||||
|
||||
def __delitem__(self, key) -> None:
|
||||
with self._sync_lock:
|
||||
del self._data[key]
|
||||
self._changed = True
|
||||
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
return iter(self._data)
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._data)
|
||||
|
||||
|
||||
def _sync_loop(self) -> None:
|
||||
while True:
|
||||
time.sleep(0.25)
|
||||
|
||||
if self._changed:
|
||||
with self._sync_lock:
|
||||
log.debug("Syncing %s", self)
|
||||
self.sync_now()
|
||||
|
||||
|
||||
def sync_now(self) -> None:
|
||||
ModelUpdated(self.sync_id, self.serialized())
|
||||
self._changed = False
|
||||
|
||||
|
||||
def serialized(self) -> List[Dict[str, Any]]:
|
||||
return [item.serialized for item in sorted(self._data.values())]
|
||||
|
||||
|
||||
def __lt__(self, other: "Model") -> bool:
|
||||
return str(self.sync_id) < str(other.sync_id)
|
33
src/backend/models/model_item.py
Normal file
33
src/backend/models/model_item.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ..utils import serialize_value_for_qml
|
||||
|
||||
|
||||
class ModelItem:
|
||||
def __new__(cls, *_args, **_kwargs) -> "ModelItem":
|
||||
from .model import Model
|
||||
cls.parent_model: Optional[Model] = None
|
||||
return super().__new__(cls)
|
||||
|
||||
|
||||
def __setattr__(self, name: str, value) -> None:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
if name != "parent_model" and self.parent_model is not None:
|
||||
with self.parent_model._sync_lock:
|
||||
self.parent_model._changed = True
|
||||
|
||||
|
||||
def __delattr__(self, name: str) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@property
|
||||
def serialized(self) -> Dict[str, Any]:
|
||||
return {
|
||||
name: serialize_value_for_qml(getattr(self, name))
|
||||
for name in dir(self)
|
||||
if not (
|
||||
name.startswith("_") or name in ("parent_model", "serialized")
|
||||
)
|
||||
}
|
61
src/backend/models/model_store.py
Normal file
61
src/backend/models/model_store.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import Dict, Iterator, MutableMapping, Set, Tuple, Type, Union
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from . import SyncId
|
||||
from .model_item import ModelItem
|
||||
from .model import Model
|
||||
|
||||
KeyType = Union[Type[ModelItem], Tuple[Type, ...]]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelStore(MutableMapping):
|
||||
allowed_key_types: Set[KeyType] = field()
|
||||
|
||||
data: Dict[SyncId, Model] = field(init=False, default_factory=dict)
|
||||
|
||||
|
||||
def __getitem__(self, key: SyncId) -> Model:
|
||||
try:
|
||||
return self.data[key]
|
||||
except KeyError:
|
||||
if isinstance(key, tuple):
|
||||
for i in key:
|
||||
if not i:
|
||||
raise ValueError(f"Empty string in key: {key!r}")
|
||||
|
||||
key_type = (key[0],) + \
|
||||
tuple(type(el) for el in key[1:])
|
||||
else:
|
||||
key_type = key # type: ignore
|
||||
|
||||
if key_type not in self.allowed_key_types:
|
||||
raise TypeError(f"{key_type!r} not in allowed key types: "
|
||||
f"{self.allowed_key_types!r}")
|
||||
|
||||
model = Model(key)
|
||||
self.data[key] = model
|
||||
return model
|
||||
|
||||
|
||||
def __setitem__(self, key, item) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def __delitem__(self, key: SyncId) -> None:
|
||||
del self.data[key]
|
||||
|
||||
|
||||
def __iter__(self) -> Iterator[SyncId]:
|
||||
return iter(self.data)
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "%s(\n %s\n)" % (
|
||||
type(self).__name__,
|
||||
"\n ".join(sorted(str(v) for v in self.values())),
|
||||
)
|
381
src/backend/nio_callbacks.py
Normal file
381
src/backend/nio_callbacks.py
Normal file
@@ -0,0 +1,381 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging as log
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
import nio
|
||||
|
||||
from . import utils
|
||||
from .html_filter import HTML_FILTER
|
||||
from .matrix_client import MatrixClient
|
||||
from .models.items import Account, Room, TypeSpecifier
|
||||
|
||||
|
||||
@dataclass
|
||||
class NioCallbacks:
|
||||
client: MatrixClient = field()
|
||||
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
c = self.client
|
||||
|
||||
for name, class_ in utils.classes_defined_in(nio.responses).items():
|
||||
with suppress(AttributeError):
|
||||
c.add_response_callback(getattr(self, f"on{name}"), class_)
|
||||
|
||||
for name, class_ in utils.classes_defined_in(nio.events).items():
|
||||
with suppress(AttributeError):
|
||||
c.add_event_callback(getattr(self, f"on{name}"), class_)
|
||||
|
||||
c.add_ephemeral_callback(
|
||||
self.onTypingNoticeEvent, nio.events.TypingNoticeEvent,
|
||||
)
|
||||
|
||||
|
||||
async def onSyncResponse(self, resp: nio.SyncResponse) -> None:
|
||||
c = self.client
|
||||
|
||||
for room_id, info in resp.rooms.join.items():
|
||||
if room_id not in c.past_tokens:
|
||||
c.past_tokens[room_id] = info.timeline.prev_batch
|
||||
|
||||
# TODO: way of knowing if a nio.MatrixRoom is left
|
||||
for room_id, info in resp.rooms.leave.items():
|
||||
# TODO: handle in nio, these are rooms that were left before
|
||||
# starting the client.
|
||||
if room_id not in c.all_rooms:
|
||||
log.warning("Left room not in MatrixClient.rooms: %r", room_id)
|
||||
continue
|
||||
|
||||
# TODO: handle left events in nio async client
|
||||
for ev in info.timeline.events:
|
||||
if isinstance(ev, nio.RoomMemberEvent):
|
||||
await self.onRoomMemberEvent(c.all_rooms[room_id], ev)
|
||||
|
||||
await c.register_nio_room(c.all_rooms[room_id], left=True)
|
||||
|
||||
if not c.first_sync_done.is_set():
|
||||
self.client.load_rooms_task = asyncio.ensure_future(
|
||||
c.load_rooms_without_visible_events(),
|
||||
)
|
||||
|
||||
c.first_sync_done.set()
|
||||
c.first_sync_date = datetime.now()
|
||||
c.models[Account][c.user_id].first_sync_done = True
|
||||
|
||||
|
||||
async def onErrorResponse(self, resp: nio.ErrorResponse) -> None:
|
||||
# TODO: show something in the client, must be seen on login screen too
|
||||
try:
|
||||
log.warning("%s - %s", resp, json.dumps(vars(resp), indent=4))
|
||||
except Exception:
|
||||
log.warning(repr(resp))
|
||||
|
||||
|
||||
# Callbacks for nio room events
|
||||
# Content: %1 is the sender, %2 the target (ev.state_key).
|
||||
|
||||
async def onRoomMessageText(self, room, ev) -> None:
|
||||
co = HTML_FILTER.filter(
|
||||
ev.formatted_body
|
||||
if ev.format == "org.matrix.custom.html" else
|
||||
utils.plain2html(ev.body),
|
||||
)
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onRoomMessageNotice(self, room, ev) -> None:
|
||||
await self.onRoomMessageText(room, ev)
|
||||
|
||||
|
||||
async def onRoomMessageEmote(self, room, ev) -> None:
|
||||
await self.onRoomMessageText(room, ev)
|
||||
|
||||
|
||||
async def onRoomMessageUnknown(self, room, ev) -> None:
|
||||
co = f"%1 sent an unsupported <b>{ev.msgtype}</b> message"
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onRoomMessageMedia(self, room, ev) -> None:
|
||||
info = ev.source["content"].get("info", {})
|
||||
media_crypt_dict = ev.source["content"].get("file", {})
|
||||
thumb_info = info.get("thumbnail_info", {})
|
||||
thumb_crypt_dict = info.get("thumbnail_file", {})
|
||||
|
||||
await self.client.register_nio_event(
|
||||
room,
|
||||
ev,
|
||||
content = "",
|
||||
inline_content = ev.body,
|
||||
|
||||
media_url = ev.url,
|
||||
media_title = ev.body,
|
||||
media_width = info.get("w") or 0,
|
||||
media_height = info.get("h") or 0,
|
||||
media_duration = info.get("duration") or 0,
|
||||
media_size = info.get("size") or 0,
|
||||
media_mime = info.get("mimetype") or 0,
|
||||
media_crypt_dict = media_crypt_dict,
|
||||
|
||||
thumbnail_url =
|
||||
info.get("thumbnail_url") or thumb_crypt_dict.get("url") or "",
|
||||
|
||||
thumbnail_width = thumb_info.get("w") or 0,
|
||||
thumbnail_height = thumb_info.get("h") or 0,
|
||||
thumbnail_crypt_dict = thumb_crypt_dict,
|
||||
)
|
||||
|
||||
|
||||
async def onRoomEncryptedMedia(self, room, ev) -> None:
|
||||
await self.onRoomMessageMedia(room, ev)
|
||||
|
||||
|
||||
async def onRoomCreateEvent(self, room, ev) -> None:
|
||||
co = "%1 allowed users on other matrix servers to join this room" \
|
||||
if ev.federate else \
|
||||
"%1 blocked users on other matrix servers from joining this room"
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onRoomGuestAccessEvent(self, room, ev) -> None:
|
||||
allowed = "allowed" if ev.guest_access else "forbad"
|
||||
co = f"%1 {allowed} guests to join the room"
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onRoomJoinRulesEvent(self, room, ev) -> None:
|
||||
access = "public" if ev.join_rule == "public" else "invite-only"
|
||||
co = f"%1 made the room {access}"
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onRoomHistoryVisibilityEvent(self, room, ev) -> None:
|
||||
if ev.history_visibility == "shared":
|
||||
to = "all room members"
|
||||
elif ev.history_visibility == "world_readable":
|
||||
to = "any member or outsider"
|
||||
elif ev.history_visibility == "joined":
|
||||
to = "all room members, since the time they joined"
|
||||
elif ev.history_visibility == "invited":
|
||||
to = "all room members, since the time they were invited"
|
||||
else:
|
||||
to = "???"
|
||||
log.warning("Invalid visibility - %s",
|
||||
json.dumps(vars(ev), indent=4))
|
||||
|
||||
co = f"%1 made future room history visible to {to}"
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onPowerLevelsEvent(self, room, ev) -> None:
|
||||
co = "%1 changed the room's permissions" # TODO: improve
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def process_room_member_event(
|
||||
self, room, ev,
|
||||
) -> Optional[Tuple[TypeSpecifier, str]]:
|
||||
|
||||
if ev.prev_content == ev.content:
|
||||
return None
|
||||
|
||||
prev = ev.prev_content
|
||||
now = ev.content
|
||||
membership = ev.membership
|
||||
prev_membership = ev.prev_membership
|
||||
ev_date = datetime.fromtimestamp(ev.server_timestamp / 1000)
|
||||
|
||||
member_change = TypeSpecifier.membership_change
|
||||
|
||||
# Membership changes
|
||||
if not prev or membership != prev_membership:
|
||||
reason = f". Reason: {now['reason']}" if now.get("reason") else ""
|
||||
|
||||
if membership == "join":
|
||||
return (
|
||||
member_change,
|
||||
"%1 accepted their invitation"
|
||||
if prev and prev_membership == "invite" else
|
||||
"%1 joined the room",
|
||||
)
|
||||
|
||||
if membership == "invite":
|
||||
return (member_change, "%1 invited %2 to the room")
|
||||
|
||||
if membership == "leave":
|
||||
if ev.state_key == ev.sender:
|
||||
return (
|
||||
member_change,
|
||||
f"%1 declined their invitation{reason}"
|
||||
if prev and prev_membership == "invite" else
|
||||
f"%1 left the room{reason}",
|
||||
)
|
||||
|
||||
return (
|
||||
member_change,
|
||||
|
||||
f"%1 withdrew %2's invitation{reason}"
|
||||
if prev and prev_membership == "invite" else
|
||||
|
||||
f"%1 unbanned %2 from the room{reason}"
|
||||
if prev and prev_membership == "ban" else
|
||||
|
||||
f"%1 kicked out %2 from the room{reason}",
|
||||
)
|
||||
|
||||
if membership == "ban":
|
||||
return (member_change, f"%1 banned %2 from the room{reason}")
|
||||
|
||||
# Profile changes
|
||||
changed = []
|
||||
|
||||
if prev and now["avatar_url"] != prev["avatar_url"]:
|
||||
changed.append("profile picture") # TODO: <img>s
|
||||
|
||||
if prev and now["displayname"] != prev["displayname"]:
|
||||
changed.append('display name from "{}" to "{}"'.format(
|
||||
prev["displayname"] or ev.state_key,
|
||||
now["displayname"] or ev.state_key,
|
||||
))
|
||||
|
||||
if changed:
|
||||
# Update our account profile if the event is newer than last update
|
||||
if ev.state_key == self.client.user_id:
|
||||
account = self.client.models[Account][self.client.user_id]
|
||||
updated = account.profile_updated
|
||||
|
||||
if not updated or updated < ev_date:
|
||||
account.profile_updated = ev_date
|
||||
account.display_name = now["displayname"] or ""
|
||||
account.avatar_url = now["avatar_url"] or ""
|
||||
|
||||
# Hide profile events from the timeline - XXX
|
||||
self.client.skipped_events[room.room_id] += 1
|
||||
return None
|
||||
|
||||
return (
|
||||
TypeSpecifier.profile_change,
|
||||
"%1 changed their {}".format(" and ".join(changed)),
|
||||
)
|
||||
|
||||
log.warning("Unknown member event: %s", json.dumps(vars(ev), indent=4))
|
||||
return None
|
||||
|
||||
|
||||
async def onRoomMemberEvent(self, room, ev) -> None:
|
||||
type_and_content = await self.process_room_member_event(room, ev)
|
||||
|
||||
if type_and_content is not None:
|
||||
type_specifier, content = type_and_content
|
||||
|
||||
await self.client.register_nio_event(
|
||||
room, ev, content=content, type_specifier=type_specifier,
|
||||
)
|
||||
else:
|
||||
# Normally, register_nio_event() will call register_nio_room().
|
||||
# but in this case we don't have any event we want to register.
|
||||
await self.client.register_nio_room(room)
|
||||
|
||||
|
||||
async def onRoomAliasEvent(self, room, ev) -> None:
|
||||
if ev.canonical_alias:
|
||||
url = f"https://matrix.to/#/{quote(ev.canonical_alias)}"
|
||||
link = f"<a href='{url}'>{ev.canonical_alias}</a>"
|
||||
co = f"%1 set the room's main address to {link}"
|
||||
else:
|
||||
co = "%1 removed the room's main address"
|
||||
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onRoomNameEvent(self, room, ev) -> None:
|
||||
if ev.name:
|
||||
co = f"%1 changed the room's name to \"{ev.name}\""
|
||||
else:
|
||||
co = "%1 removed the room's name"
|
||||
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onRoomAvatarEvent(self, room, ev) -> None:
|
||||
if ev.avatar_url:
|
||||
co = "%1 changed the room's picture"
|
||||
else:
|
||||
co = "%1 removed the room's picture"
|
||||
|
||||
await self.client.register_nio_event(
|
||||
room, ev, content=co, media_url=ev.avatar_url,
|
||||
)
|
||||
|
||||
|
||||
async def onRoomTopicEvent(self, room, ev) -> None:
|
||||
if ev.topic:
|
||||
topic = HTML_FILTER.filter_inline(ev.topic)
|
||||
co = f"%1 changed the room's topic to \"{topic}\""
|
||||
else:
|
||||
co = "%1 removed the room's topic"
|
||||
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onRoomEncryptionEvent(self, room, ev) -> None:
|
||||
co = "%1 turned on encryption for this room"
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onMegolmEvent(self, room, ev) -> None:
|
||||
co = "%1 sent an undecryptable message"
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onBadEvent(self, room, ev) -> None:
|
||||
co = f"%1 sent a malformed <b>{ev.type}</b> event"
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onUnknownBadEvent(self, room, ev) -> None:
|
||||
co = "%1 sent a malformed event lacking a minimal structure"
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onUnknownEvent(self, room, ev) -> None:
|
||||
co = f"%1 sent an unsupported <b>{ev.type}</b> event"
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
async def onUnknownEncryptedEvent(self, room, ev) -> None:
|
||||
co = (
|
||||
f"%1 sent an <b>{ev.type}</b> event encrypted with "
|
||||
f"unsupported <b>{ev.algorithm}</b> algorithm"
|
||||
)
|
||||
await self.client.register_nio_event(room, ev, content=co)
|
||||
|
||||
|
||||
# Callbacks for nio invite events
|
||||
|
||||
async def onInviteEvent(self, room, ev) -> None:
|
||||
await self.client.register_nio_room(room)
|
||||
|
||||
|
||||
# Callbacks for nio ephemeral events
|
||||
|
||||
async def onTypingNoticeEvent(self, room, ev) -> None:
|
||||
# Prevent recent past typing notices from being shown for a split
|
||||
# second on client startup:
|
||||
if not self.client.first_sync_done.is_set():
|
||||
return
|
||||
|
||||
if room.room_id not in self.client.models[Room, self.client.user_id]:
|
||||
return
|
||||
|
||||
room_item = self.client.models[Room, self.client.user_id][room.room_id]
|
||||
|
||||
room_item.typing_members = sorted(
|
||||
room.user_name(user_id) for user_id in ev.users
|
||||
if user_id not in self.client.backend.clients
|
||||
)
|
65
src/backend/pyotherside_events.py
Normal file
65
src/backend/pyotherside_events.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import pyotherside
|
||||
|
||||
from .models import SyncId
|
||||
from .utils import serialize_value_for_qml
|
||||
|
||||
|
||||
@dataclass
|
||||
class PyOtherSideEvent:
|
||||
"""Event that will be sent to QML by PyOtherSide."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# CPython >= 3.6 or any Python >= 3.7 needed for correct dict order
|
||||
args = [
|
||||
serialize_value_for_qml(getattr(self, field))
|
||||
for field in self.__dataclass_fields__ # type: ignore
|
||||
]
|
||||
pyotherside.send(type(self).__name__, *args)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExitRequested(PyOtherSideEvent):
|
||||
"""Request for the application to exit."""
|
||||
|
||||
exit_code: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlertRequested(PyOtherSideEvent):
|
||||
"""Request an alert to be shown for msec milliseconds.
|
||||
The Alert state for example sets the urgency hint on X11/Wayland,
|
||||
or flashes the taskbar icon on Windows.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CoroutineDone(PyOtherSideEvent):
|
||||
"""Indicate that an asyncio coroutine finished."""
|
||||
|
||||
uuid: str = field()
|
||||
result: Any = None
|
||||
exception: Optional[Exception] = None
|
||||
traceback: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelUpdated(PyOtherSideEvent):
|
||||
"""Indicate that a backend model's data changed."""
|
||||
|
||||
sync_id: SyncId = field()
|
||||
data: List[Dict[str, Any]] = field()
|
||||
|
||||
serialized_sync_id: Union[str, List[str]] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if isinstance(self.sync_id, tuple):
|
||||
self.serialized_sync_id = [
|
||||
e.__name__ if isinstance(e, type) else e for e in self.sync_id
|
||||
]
|
||||
else:
|
||||
self.serialized_sync_id = self.sync_id.__name__
|
||||
|
||||
super().__post_init__()
|
69
src/backend/theme_parser.py
Normal file
69
src/backend/theme_parser.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import re
|
||||
from typing import Generator
|
||||
|
||||
PROPERTY_TYPES = {"bool", "double", "int", "list", "real", "string", "url",
|
||||
"var", "date", "point", "rect", "size", "color"}
|
||||
|
||||
|
||||
def _add_property(line: str) -> str:
|
||||
if re.match(r"^\s*[a-zA-Z\d_]+\s*:$", line):
|
||||
return re.sub(r"^(\s*)(\S*\s*):$",
|
||||
r"\1readonly property QtObject \2: QtObject",
|
||||
line)
|
||||
|
||||
types = "|".join(PROPERTY_TYPES)
|
||||
if re.match(fr"^\s*({types}) [a-zA-Z\d_]+\s*:", line):
|
||||
return re.sub(r"^(\s*)(\S*)", r"\1property \2", line)
|
||||
|
||||
return line
|
||||
|
||||
|
||||
def _process_lines(content: str) -> Generator[str, None, None]:
|
||||
skip = False
|
||||
indent = " " * 4
|
||||
current_indent = 0
|
||||
|
||||
for line in content.split("\n"):
|
||||
line = line.rstrip()
|
||||
|
||||
if not line.strip() or line.strip().startswith("//"):
|
||||
continue
|
||||
|
||||
start_space_list = re.findall(r"^ +", line)
|
||||
start_space = start_space_list[0] if start_space_list else ""
|
||||
|
||||
line_indents = len(re.findall(indent, start_space))
|
||||
|
||||
if not skip:
|
||||
if line_indents > current_indent:
|
||||
yield "%s{" % (indent * current_indent)
|
||||
current_indent = line_indents
|
||||
|
||||
while line_indents < current_indent:
|
||||
current_indent -= 1
|
||||
yield "%s}" % (indent * current_indent)
|
||||
|
||||
line = _add_property(line)
|
||||
|
||||
yield line
|
||||
|
||||
skip = any((line.endswith(e) for e in "([{+\\,?:"))
|
||||
|
||||
while current_indent:
|
||||
current_indent -= 1
|
||||
yield "%s}" % (indent * current_indent)
|
||||
|
||||
|
||||
def convert_to_qml(theme_content: str) -> str:
|
||||
lines = [
|
||||
"import QtQuick 2.12",
|
||||
'import "Base"',
|
||||
"QtObject {",
|
||||
" function hsluv(h, s, l, a) { return utils.hsluv(h, s, l, a) }",
|
||||
" function hsl(h, s, l) { return utils.hsl(h, s, l) }",
|
||||
" function hsla(h, s, l, a) { return utils.hsla(h, s, l, a) }",
|
||||
" id: theme",
|
||||
]
|
||||
lines += [f" {line}" for line in _process_lines(theme_content)]
|
||||
lines += ["}"]
|
||||
return "\n".join(lines)
|
156
src/backend/utils.py
Normal file
156
src/backend/utils.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Contains various utilities that are used throughout the package."""
|
||||
|
||||
import collections
|
||||
import html
|
||||
import inspect
|
||||
import io
|
||||
import xml.etree.cElementTree as xml_etree # FIXME: bandit warning
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from enum import auto as autostr
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, Tuple, Type
|
||||
from uuid import UUID
|
||||
|
||||
import filetype
|
||||
from aiofiles.threadpool.binary import AsyncBufferedReader
|
||||
|
||||
from nio.crypto import AsyncDataT as File
|
||||
from nio.crypto import async_generator_from_data
|
||||
|
||||
Size = Tuple[int, int]
|
||||
auto = autostr
|
||||
|
||||
|
||||
class AutoStrEnum(Enum):
|
||||
"""An Enum where auto() assigns the member's name instead of an int.
|
||||
|
||||
Example:
|
||||
>>> class Fruits(AutoStrEnum): apple = auto()
|
||||
>>> Fruits.apple.value
|
||||
"apple"
|
||||
"""
|
||||
@staticmethod
|
||||
def _generate_next_value_(name, *_):
|
||||
return name
|
||||
|
||||
|
||||
def dict_update_recursive(dict1: dict, dict2: dict) -> None:
|
||||
"""Recursive version of dict.update()."""
|
||||
# https://gist.github.com/angstwad/bf22d1822c38a92ec0a9
|
||||
|
||||
for k in dict2:
|
||||
if (k in dict1 and isinstance(dict1[k], dict) and
|
||||
isinstance(dict2[k], collections.Mapping)):
|
||||
dict_update_recursive(dict1[k], dict2[k])
|
||||
else:
|
||||
dict1[k] = dict2[k]
|
||||
|
||||
|
||||
async def is_svg(file: File) -> bool:
|
||||
"""Return True if the file is a SVG. Uses lxml for detection."""
|
||||
|
||||
chunks = [c async for c in async_generator_from_data(file)]
|
||||
|
||||
with io.BytesIO(b"".join(chunks)) as file:
|
||||
try:
|
||||
_, element = next(xml_etree.iterparse(file, ("start",)))
|
||||
return element.tag == "{http://www.w3.org/2000/svg}svg"
|
||||
except (StopIteration, xml_etree.ParseError):
|
||||
return False
|
||||
|
||||
|
||||
async def svg_dimensions(file: File) -> Size:
|
||||
"""Return the width & height or viewBox width & height for a SVG.
|
||||
If these properties are missing (broken file), ``(256, 256)`` is returned.
|
||||
"""
|
||||
|
||||
chunks = [c async for c in async_generator_from_data(file)]
|
||||
|
||||
with io.BytesIO(b"".join(chunks)) as file:
|
||||
attrs = xml_etree.parse(file).getroot().attrib
|
||||
|
||||
try:
|
||||
width = round(float(attrs.get("width", attrs["viewBox"].split()[3])))
|
||||
except (KeyError, IndexError, ValueError, TypeError):
|
||||
width = 256
|
||||
|
||||
try:
|
||||
height = round(float(attrs.get("height", attrs["viewBox"].split()[4])))
|
||||
except (KeyError, IndexError, ValueError, TypeError):
|
||||
height = 256
|
||||
|
||||
return (width, height)
|
||||
|
||||
|
||||
async def guess_mime(file: File) -> str:
|
||||
"""Return the mime type for a file, or application/octet-stream if it
|
||||
can't be guessed.
|
||||
"""
|
||||
|
||||
if isinstance(file, io.IOBase):
|
||||
file.seek(0, 0)
|
||||
elif isinstance(file, AsyncBufferedReader):
|
||||
await file.seek(0, 0)
|
||||
|
||||
try:
|
||||
first_chunk: bytes
|
||||
async for first_chunk in async_generator_from_data(file):
|
||||
break
|
||||
|
||||
# TODO: plaintext
|
||||
mime = filetype.guess_mime(first_chunk)
|
||||
|
||||
return mime or (
|
||||
"image/svg+xml" if await is_svg(file) else
|
||||
"application/octet-stream"
|
||||
)
|
||||
finally:
|
||||
if isinstance(file, io.IOBase):
|
||||
file.seek(0, 0)
|
||||
elif isinstance(file, AsyncBufferedReader):
|
||||
await file.seek(0, 0)
|
||||
|
||||
|
||||
def plain2html(text: str) -> str:
|
||||
"""Transform plain text into HTML, this converts \n and \t."""
|
||||
|
||||
return html.escape(text)\
|
||||
.replace("\n", "<br>")\
|
||||
.replace("\t", " " * 4)
|
||||
|
||||
|
||||
def serialize_value_for_qml(value: Any) -> Any:
|
||||
"""Transform a value to make it easier to use from QML.
|
||||
|
||||
Currently, this transforms Enum members to their actual value and Path
|
||||
objects to their string version.
|
||||
"""
|
||||
|
||||
if hasattr(value, "__class__") and issubclass(value.__class__, Enum):
|
||||
return value.value
|
||||
|
||||
if isinstance(value, Path):
|
||||
return f"file://{value!s}"
|
||||
|
||||
if isinstance(value, UUID):
|
||||
return str(value)
|
||||
|
||||
if isinstance(value, timedelta):
|
||||
return value.total_seconds() * 1000
|
||||
|
||||
if inspect.isclass(value):
|
||||
return value.__name__
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def classes_defined_in(module: ModuleType) -> Dict[str, Type]:
|
||||
"""Return a {name: class} dict of all the classes a module defines."""
|
||||
|
||||
return {
|
||||
m[0]: m[1] for m in inspect.getmembers(module, inspect.isclass)
|
||||
if not m[0].startswith("_") and
|
||||
m[1].__module__.startswith(module.__name__)
|
||||
}
|
Reference in New Issue
Block a user