diff --git a/src/python/matrix_client.py b/src/python/matrix_client.py index cd0b4061..96561453 100644 --- a/src/python/matrix_client.py +++ b/src/python/matrix_client.py @@ -11,7 +11,8 @@ from datetime import datetime from functools import partial from pathlib import Path from typing import ( - Any, BinaryIO, DefaultDict, Dict, Optional, Set, Tuple, Type, Union, + Any, AsyncIterable, BinaryIO, DefaultDict, Dict, Optional, Set, Tuple, + Type, Union, ) from urllib.parse import urlparse from uuid import uuid4 @@ -34,7 +35,8 @@ from .models.items import ( from .models.model_store import ModelStore from .pyotherside_events import AlertRequested -CryptDict = Dict[str, Any] +UploadData = Union[bytes, BinaryIO, AsyncIterable[bytes]] +CryptDict = Dict[str, Any] class MatrixClient(nio.AsyncClient): @@ -225,7 +227,7 @@ class MatrixClient(nio.AsyncClient): content["msgtype"] = "m.image" content["info"]["w"], content["info"]["h"] = ( - utils.svg_dimensions(str(path)) if is_svg else + utils.svg_dimensions(path) if is_svg else PILImage.open(path).size ) @@ -508,7 +510,7 @@ class MatrixClient(nio.AsyncClient): try: if is_svg: - svg_width, svg_height = utils.svg_dimensions(str(path)) + svg_width, svg_height = utils.svg_dimensions(path) thumb = PILImage.open(io.BytesIO( cairosvg.svg2png( @@ -581,7 +583,6 @@ class MatrixClient(nio.AsyncClient): with open(path, "rb") as file: mime = utils.guess_mime(file) - file.seek(0, 0) data: Union[BinaryIO, bytes] @@ -604,8 +605,9 @@ class MatrixClient(nio.AsyncClient): ) - async def upload(self, data, mime: str, filename: Optional[str] = None, - ) -> str: + async def upload( + self, data: UploadData, mime: str, filename: Optional[str] = None, + ) -> str: response = await super().upload(data, mime, filename) if isinstance(response, nio.UploadError): diff --git a/src/python/utils.py b/src/python/utils.py index ca533105..c091dbb0 100644 --- a/src/python/utils.py +++ b/src/python/utils.py @@ -96,7 +96,11 @@ def guess_mime(file: File) -> str: elif isinstance(file, io.IOBase): file.seek(0, 0) - return filetype.guess_mime(file) or "application/octet-stream" + try: + return filetype.guess_mime(file) or "application/octet-stream" + finally: + if isinstance(file, io.IOBase): + file.seek(0, 0) def plain2html(text: str) -> str: