Clarify upload data type, guess_mime seek on end

This commit is contained in:
miruka 2019-11-12 09:37:21 -04:00
parent 47bfad1d72
commit 37f5f5973c
2 changed files with 14 additions and 8 deletions

View File

@ -11,7 +11,8 @@ from datetime import datetime
from functools import partial
from pathlib import Path
from typing import (
Any, BinaryIO, DefaultDict, Dict, Optional, Set, Tuple, Type, Union,
Any, AsyncIterable, BinaryIO, DefaultDict, Dict, Optional, Set, Tuple,
Type, Union,
)
from urllib.parse import urlparse
from uuid import uuid4
@ -34,7 +35,8 @@ from .models.items import (
from .models.model_store import ModelStore
from .pyotherside_events import AlertRequested
CryptDict = Dict[str, Any]
UploadData = Union[bytes, BinaryIO, AsyncIterable[bytes]]
CryptDict = Dict[str, Any]
class MatrixClient(nio.AsyncClient):
@ -225,7 +227,7 @@ class MatrixClient(nio.AsyncClient):
content["msgtype"] = "m.image"
content["info"]["w"], content["info"]["h"] = (
utils.svg_dimensions(str(path)) if is_svg else
utils.svg_dimensions(path) if is_svg else
PILImage.open(path).size
)
@ -508,7 +510,7 @@ class MatrixClient(nio.AsyncClient):
try:
if is_svg:
svg_width, svg_height = utils.svg_dimensions(str(path))
svg_width, svg_height = utils.svg_dimensions(path)
thumb = PILImage.open(io.BytesIO(
cairosvg.svg2png(
@ -581,7 +583,6 @@ class MatrixClient(nio.AsyncClient):
with open(path, "rb") as file:
mime = utils.guess_mime(file)
file.seek(0, 0)
data: Union[BinaryIO, bytes]
@ -604,8 +605,9 @@ class MatrixClient(nio.AsyncClient):
)
async def upload(self, data, mime: str, filename: Optional[str] = None,
) -> str:
async def upload(
self, data: UploadData, mime: str, filename: Optional[str] = None,
) -> str:
response = await super().upload(data, mime, filename)
if isinstance(response, nio.UploadError):

View File

@ -96,7 +96,11 @@ def guess_mime(file: File) -> str:
elif isinstance(file, io.IOBase):
file.seek(0, 0)
return filetype.guess_mime(file) or "application/octet-stream"
try:
return filetype.guess_mime(file) or "application/octet-stream"
finally:
if isinstance(file, io.IOBase):
file.seek(0, 0)
def plain2html(text: str) -> str: