Remove all Backend "get_any_client"-using methods

thumbnail() and download() were remaining.
Use a direct and carefully chosen MatrixClient's methods instead to
avoid problems mentioned in the previous commit 7502c1.
This commit is contained in:
miruka
2020-08-23 16:57:53 -04:00
parent 0f2efa9ba3
commit df3f1fb645
24 changed files with 117 additions and 129 deletions

View File

@@ -349,46 +349,7 @@ class Backend:
failures += 1
async def get_any_client(self) -> MatrixClient:
"""Return any healthy syncing `MatrixClient` registered in model."""
failures = 0
while True:
for client in self.clients.values():
if client.healthy:
return client
if failures and failures % 300 == 0:
log.warn(
"No healthy client found after %ds, stack trace:\n%s",
failures / 10, traceback.format_stack(),
)
await asyncio.sleep(0.1)
failures += 1
# Client functions that don't need authentification
async def thumbnail(
self, server_name: str, media_id: str, width: int, height: int,
) -> nio.ThumbnailResponse:
"""Return thumbnail for a matrix media."""
args = (server_name, media_id, width, height)
client = await self.get_any_client()
return await client.thumbnail(*args)
async def download(
self, server_name: str, media_id: str,
) -> nio.DownloadResponse:
"""Return the content of a matrix media."""
client = await self.get_any_client()
return await client.download(server_name, media_id)
# Multi-client Matrix functions
async def update_room_read_marker(
self, room_id: str, event_id: str,
@@ -416,6 +377,37 @@ class Backend:
await asyncio.gather(*[update(c) for c in self.clients.values()])
async def verify_device(
self, user_id: str, device_id: str, ed25519_key: str,
) -> None:
"""Mark a device as verified on all our accounts."""
for client in self.clients.values():
try:
device = client.device_store[user_id][device_id]
except KeyError:
continue
if device.ed25519 == ed25519_key:
client.verify_device(device)
async def blacklist_device(
self, user_id: str, device_id: str, ed25519_key: str,
) -> None:
"""Mark a device as blacklisted on all our accounts."""
for client in self.clients.values():
try:
# This won't include the client's current device, as expected
device = client.device_store[user_id][device_id]
except KeyError:
continue
if device.ed25519 == ed25519_key:
client.blacklist_device(device)
# General functions
async def get_config_dir(self) -> Path:
@@ -463,37 +455,6 @@ class Backend:
self.models["all_rooms"].set_account_collapse(user_id, collapse)
async def verify_device(
self, user_id: str, device_id: str, ed25519_key: str,
) -> None:
"""Mark a device as verified on all our accounts."""
for client in self.clients.values():
try:
device = client.device_store[user_id][device_id]
except KeyError:
continue
if device.ed25519 == ed25519_key:
client.verify_device(device)
async def blacklist_device(
self, user_id: str, device_id: str, ed25519_key: str,
) -> None:
"""Mark a device as blacklisted on all our accounts."""
for client in self.clients.values():
try:
# This won't include the client's current device, as expected
device = client.device_store[user_id][device_id]
except KeyError:
continue
if device.ed25519 == ed25519_key:
client.blacklist_device(device)
async def _ping_homeserver(
self, session: aiohttp.ClientSession, homeserver_url: str,
) -> None:

View File

@@ -210,18 +210,6 @@ class MatrixClient(nio.AsyncClient):
)
@property
def healthy(self) -> bool:
"""Return whether we're syncing and last sync was successful."""
task = self.sync_task
if not task or not self.first_sync_date or self.last_sync_error:
return False
return not task.done()
@property
def default_device_name(self) -> str:
"""Device name to set at login if the user hasn't set a custom one."""
@@ -734,7 +722,7 @@ class MatrixClient(nio.AsyncClient):
upload_item.status = UploadStatus.Caching
local_media = await Media.from_existing_file(
self.backend.media_cache, url, path,
self.backend.media_cache, self.user_id, url, path,
)
kind = (mime or "").split("/")[0]
@@ -816,6 +804,7 @@ class MatrixClient(nio.AsyncClient):
await Thumbnail.from_bytes(
self.backend.media_cache,
self.user_id,
thumb_url,
path.name,
thumb_data,

View File

@@ -49,29 +49,31 @@ class MediaCache:
async def get_media(
self,
mxc: str,
title: str,
crypt_dict: CryptDict = None,
client_user_id: str,
mxc: str,
title: str,
crypt_dict: CryptDict = None,
) -> Path:
"""Return `Media.get()`'s result. Intended for QML."""
return await Media(self, mxc, title, crypt_dict).get()
return await Media(self, client_user_id, mxc, title, crypt_dict).get()
async def get_thumbnail(
self,
mxc: str,
title: str,
width: int,
height: int,
crypt_dict: CryptDict = None,
client_user_id: str,
mxc: str,
title: str,
width: int,
height: int,
crypt_dict: CryptDict = None,
) -> Path:
"""Return `Thumbnail.get()`'s result. Intended for QML."""
thumb = Thumbnail(
# QML sometimes pass float sizes, which matrix API doesn't like.
self, mxc, title, crypt_dict, (round(width), round(height)),
)
# QML sometimes pass float sizes, which matrix API doesn't like.
size = (round(width), round(height))
thumb = Thumbnail(self, client_user_id, mxc, title, crypt_dict, size)
return await thumb.get()
@@ -79,10 +81,11 @@ class MediaCache:
class Media:
"""A matrix media file."""
cache: "MediaCache" = field()
mxc: str = field()
title: str = field()
crypt_dict: CryptDict = field(repr=False)
cache: "MediaCache" = field()
client_user_id: str = field()
mxc: str = field()
title: str = field()
crypt_dict: CryptDict = field(repr=False)
def __post_init__(self) -> None:
@@ -154,7 +157,7 @@ class Media:
parsed = urlparse(self.mxc)
resp = await self.cache.backend.download(
resp = await self.cache.backend.clients[self.client_user_id].download(
server_name = parsed.netloc,
media_id = parsed.path.lstrip("/"),
)
@@ -183,15 +186,18 @@ class Media:
@classmethod
async def from_existing_file(
cls,
cache: "MediaCache",
mxc: str,
existing: Path,
overwrite: bool = False,
cache: "MediaCache",
client_user_id: str,
mxc: str,
existing: Path,
overwrite: bool = False,
**kwargs,
) -> "Media":
"""Copy an existing file to cache and return a `Media` for it."""
media = cls(cache, mxc, existing.name, {}, **kwargs) # type: ignore
media = cls(
cache, client_user_id, mxc, existing.name, {}, **kwargs,
) # type: ignore
media.local_path.parent.mkdir(parents=True, exist_ok=True)
if not media.local_path.exists() or overwrite:
@@ -204,16 +210,19 @@ class Media:
@classmethod
async def from_bytes(
cls,
cache: "MediaCache",
mxc: str,
filename: str,
data: bytes,
overwrite: bool = False,
cache: "MediaCache",
client_user_id: str,
mxc: str,
filename: str,
data: bytes,
overwrite: bool = False,
**kwargs,
) -> "Media":
"""Create a cached file from bytes data and return a `Media` for it."""
media = cls(cache, mxc, filename, {}, **kwargs) # type: ignore
media = cls(
cache, client_user_id, mxc, filename, {}, **kwargs,
) # type: ignore
media.local_path.parent.mkdir(parents=True, exist_ok=True)
if not media.local_path.exists() or overwrite:
@@ -230,11 +239,12 @@ class Media:
class Thumbnail(Media):
"""The thumbnail of a matrix media, which is a media itself."""
cache: "MediaCache" = field()
mxc: str = field()
title: str = field()
crypt_dict: CryptDict = field(repr=False)
wanted_size: Size = field()
cache: "MediaCache" = field()
client_user_id: str = field()
mxc: str = field()
title: str = field()
crypt_dict: CryptDict = field(repr=False)
wanted_size: Size = field()
server_size: Optional[Size] = field(init=False, repr=False, default=None)
@@ -322,16 +332,17 @@ class Thumbnail(Media):
"""Return the (decrypted) media file's content from the server."""
parsed = urlparse(self.mxc)
client = self.cache.backend.clients[self.client_user_id]
if self.crypt_dict:
# Matrix makes encrypted thumbs only available through the download
# end-point, not the thumbnail one
resp = await self.cache.backend.download(
resp = await client.download(
server_name = parsed.netloc,
media_id = parsed.path.lstrip("/"),
)
else:
resp = await self.cache.backend.thumbnail(
resp = await client.thumbnail(
server_name = parsed.netloc,
media_id = parsed.path.lstrip("/"),
width = self.wanted_size[0],

View File

@@ -192,10 +192,11 @@ class NioCallbacks:
try:
media_local_path: Union[Path, str] = await Media(
cache = self.client.backend.media_cache,
mxc = ev.url,
title = ev.body,
crypt_dict = media_crypt_dict,
cache = self.client.backend.media_cache,
client_user_id = self.user_id,
mxc = ev.url,
title = ev.body,
crypt_dict = media_crypt_dict,
).get_local()
except FileNotFoundError:
media_local_path = ""