Fix flake8/mypy/isort code errors from new setup
This commit is contained in:
		@@ -13,11 +13,10 @@ from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Union
 | 
			
		||||
from urllib.parse import urlparse
 | 
			
		||||
 | 
			
		||||
import aiohttp
 | 
			
		||||
import nio
 | 
			
		||||
import plyer
 | 
			
		||||
from appdirs import AppDirs
 | 
			
		||||
 | 
			
		||||
import nio
 | 
			
		||||
 | 
			
		||||
from . import __app_name__
 | 
			
		||||
from .errors import MatrixError, MatrixInvalidAccessToken
 | 
			
		||||
from .matrix_client import MatrixClient
 | 
			
		||||
@@ -29,7 +28,7 @@ from .models.model import Model
 | 
			
		||||
from .models.model_store import ModelStore
 | 
			
		||||
from .presence import Presence
 | 
			
		||||
from .sso_server import SSOServer
 | 
			
		||||
from .user_files import Accounts, History, Theme, Settings, UIState
 | 
			
		||||
from .user_files import Accounts, History, Settings, Theme, UIState
 | 
			
		||||
 | 
			
		||||
# Logging configuration
 | 
			
		||||
log.getLogger().setLevel(log.INFO)
 | 
			
		||||
@@ -195,7 +194,7 @@ class Backend:
 | 
			
		||||
 | 
			
		||||
    async def password_auth(
 | 
			
		||||
        self, user: str, password: str, homeserver: str,
 | 
			
		||||
   ) -> str:
 | 
			
		||||
    ) -> str:
 | 
			
		||||
        """Create & register a `MatrixClient`, login using the password
 | 
			
		||||
        and return the user ID we get.
 | 
			
		||||
        """
 | 
			
		||||
@@ -437,7 +436,9 @@ class Backend:
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    async def set_string_filter(self, model_id: SyncId, value: str) -> None:
 | 
			
		||||
    async def set_string_filter(
 | 
			
		||||
        self, model_id: Union[SyncId, List[str]], value: str,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Set a FieldStringFilter (or derived class) model's filter property.
 | 
			
		||||
 | 
			
		||||
        This should only be called from QML.
 | 
			
		||||
@@ -475,7 +476,7 @@ class Backend:
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                await session.get(f"{homeserver_url}/_matrix/client/versions")
 | 
			
		||||
            except Exception as err:
 | 
			
		||||
            except aiohttp.ClientError as err:
 | 
			
		||||
                log.warning("Failed pinging %s: %r", homeserver_url, err)
 | 
			
		||||
                item.status = PingStatus.Failed
 | 
			
		||||
                return
 | 
			
		||||
 
 | 
			
		||||
@@ -10,11 +10,10 @@ from urllib.parse import unquote
 | 
			
		||||
import html_sanitizer.sanitizer as sanitizer
 | 
			
		||||
import lxml.html  # nosec
 | 
			
		||||
import mistune
 | 
			
		||||
import nio
 | 
			
		||||
from html_sanitizer.sanitizer import Sanitizer
 | 
			
		||||
from lxml.html import HtmlElement, etree  # nosec
 | 
			
		||||
 | 
			
		||||
import nio
 | 
			
		||||
 | 
			
		||||
from .svg_colors import SVG_COLORS
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -113,7 +112,7 @@ class HTMLProcessor:
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    block_tags = {
 | 
			
		||||
        "h1", "h2", "h3", "h4", "h5", "h6","blockquote",
 | 
			
		||||
        "h1", "h2", "h3", "h4", "h5", "h6", "blockquote",
 | 
			
		||||
        "p", "ul", "ol", "li", "hr", "br", "img",
 | 
			
		||||
        "table", "thead", "tbody", "tr", "th", "td", "pre",
 | 
			
		||||
        "mx-reply",
 | 
			
		||||
 
 | 
			
		||||
@@ -25,14 +25,12 @@ from typing import (
 | 
			
		||||
from urllib.parse import urlparse
 | 
			
		||||
from uuid import UUID, uuid4
 | 
			
		||||
 | 
			
		||||
import aiofiles
 | 
			
		||||
import cairosvg
 | 
			
		||||
from PIL import Image as PILImage
 | 
			
		||||
from pymediainfo import MediaInfo
 | 
			
		||||
 | 
			
		||||
import nio
 | 
			
		||||
from nio.crypto import AsyncDataT as UploadData
 | 
			
		||||
from nio.crypto import async_generator_from_data
 | 
			
		||||
from PIL import Image as PILImage
 | 
			
		||||
from pymediainfo import MediaInfo
 | 
			
		||||
 | 
			
		||||
from . import __display_name__, __reverse_dns__, utils
 | 
			
		||||
from .errors import (
 | 
			
		||||
@@ -69,16 +67,18 @@ PathCallable = Union[
 | 
			
		||||
 | 
			
		||||
IS_WINDOWS = platform.system() == "Windows"
 | 
			
		||||
 | 
			
		||||
MATRIX_TO = "https://matrix.to/#"
 | 
			
		||||
 | 
			
		||||
REPLY_FALLBACK = (
 | 
			
		||||
"<mx-reply>"
 | 
			
		||||
    "<blockquote>"
 | 
			
		||||
        '<a href="https://matrix.to/#/{room_id}/{event_id}">In reply to</a> '
 | 
			
		||||
        '<a href="https://matrix.to/#/{user_id}">{user_id}</a>'
 | 
			
		||||
        "<br>"
 | 
			
		||||
        "{content}"
 | 
			
		||||
    "</blockquote>"
 | 
			
		||||
"</mx-reply>"
 | 
			
		||||
"{reply_content}"
 | 
			
		||||
    "<mx-reply>"
 | 
			
		||||
        "<blockquote>"
 | 
			
		||||
            '<a href="{matrix_to}/{room_id}/{event_id}">In reply to</a> '
 | 
			
		||||
            '<a href="{matrix_to}/{user_id}">{user_id}</a>'
 | 
			
		||||
            "<br>"
 | 
			
		||||
            "{content}"
 | 
			
		||||
        "</blockquote>"
 | 
			
		||||
    "</mx-reply>"
 | 
			
		||||
    "{reply_content}"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -88,6 +88,7 @@ class SyncFilterIds(NamedTuple):
 | 
			
		||||
    first:  str
 | 
			
		||||
    others: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UploadReturn(NamedTuple):
 | 
			
		||||
    """Details for an uploaded file."""
 | 
			
		||||
 | 
			
		||||
@@ -161,10 +162,10 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
         backend,
 | 
			
		||||
         user:       str           = "",
 | 
			
		||||
         homeserver: str           = "https://matrix.org",
 | 
			
		||||
         device_id:  Optional[str] = None,
 | 
			
		||||
        backend,
 | 
			
		||||
        user:       str           = "",
 | 
			
		||||
        homeserver: str           = "https://matrix.org",
 | 
			
		||||
        device_id:  Optional[str] = None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
 | 
			
		||||
        store = Path(backend.appdirs.user_data_dir) / "encryption"
 | 
			
		||||
@@ -364,7 +365,7 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                account.max_upload_size = future.result() or 0
 | 
			
		||||
            except Exception:
 | 
			
		||||
            except MatrixError:
 | 
			
		||||
                trace = traceback.format_exc().rstrip()
 | 
			
		||||
                log.warn(
 | 
			
		||||
                    "On %s server config retrieval: %s", self.user_id, trace,
 | 
			
		||||
@@ -403,8 +404,9 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
                    sync_filter       = sync_filter_ids.others,
 | 
			
		||||
                ))
 | 
			
		||||
                await self.sync_task
 | 
			
		||||
                self.last_sync_error = None
 | 
			
		||||
                break  # task cancelled
 | 
			
		||||
            except Exception as err:
 | 
			
		||||
            except Exception as err:  # noqa
 | 
			
		||||
                self.last_sync_error = err
 | 
			
		||||
 | 
			
		||||
                trace = traceback.format_exc().rstrip()
 | 
			
		||||
@@ -417,8 +419,6 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    LoopException(str(err), err, trace)
 | 
			
		||||
            else:
 | 
			
		||||
                self.last_sync_error = None
 | 
			
		||||
 | 
			
		||||
            await asyncio.sleep(5)
 | 
			
		||||
 | 
			
		||||
@@ -562,7 +562,7 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
 | 
			
		||||
        if text.startswith("/me ") and not escape:
 | 
			
		||||
            event_type = nio.RoomMessageEmote
 | 
			
		||||
            text       = text[len("/me "): ]
 | 
			
		||||
            text       = text[len("/me "):]
 | 
			
		||||
            content    = {"body": text, "msgtype": "m.emote"}
 | 
			
		||||
            to_html    = from_md(text, inline=True, outgoing=True)
 | 
			
		||||
            echo_body  = from_md(text, inline=True)
 | 
			
		||||
@@ -586,14 +586,15 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
            plain_source_body = "\n".join(
 | 
			
		||||
                f"> <{to.sender_id}> {line}" if i == 0 else f"> {line}"
 | 
			
		||||
                for i, line in enumerate(source_body.splitlines())
 | 
			
		||||
        )
 | 
			
		||||
            )
 | 
			
		||||
            content["body"]   = f"{plain_source_body}\n\n{text}"
 | 
			
		||||
 | 
			
		||||
            to_html = REPLY_FALLBACK.format(
 | 
			
		||||
                room_id  = room_id,
 | 
			
		||||
                event_id = to.event_id,
 | 
			
		||||
                user_id  = to.sender_id,
 | 
			
		||||
                content  =
 | 
			
		||||
                matrix_to = MATRIX_TO,
 | 
			
		||||
                room_id   = room_id,
 | 
			
		||||
                event_id  = to.event_id,
 | 
			
		||||
                user_id   = to.sender_id,
 | 
			
		||||
                content   =
 | 
			
		||||
                    getattr(to.source, "formatted_body", "") or
 | 
			
		||||
                    source_body or
 | 
			
		||||
                    html.escape(to.source.source["type"] if to.source else ""),
 | 
			
		||||
@@ -605,7 +606,7 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
            content["formatted_body"] = HTML.filter(to_html, outgoing=True)
 | 
			
		||||
 | 
			
		||||
            content["m.relates_to"] = {
 | 
			
		||||
                "m.in_reply_to": { "event_id": to.event_id },
 | 
			
		||||
                "m.in_reply_to": {"event_id": to.event_id},
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        # Can't use the standard Matrix transaction IDs; they're only visible
 | 
			
		||||
@@ -662,7 +663,7 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
                # optimize is too slow for large images
 | 
			
		||||
                compressed = await utils.compress_image(image, optimize=False)
 | 
			
		||||
 | 
			
		||||
                async with aiofiles.open(temp.name, "wb") as file:
 | 
			
		||||
                async with utils.aiopen(temp.name, "wb") as file:
 | 
			
		||||
                    await file.write(compressed)
 | 
			
		||||
 | 
			
		||||
                return Path(temp.name)
 | 
			
		||||
@@ -814,7 +815,7 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
                )
 | 
			
		||||
            except UneededThumbnail:
 | 
			
		||||
                pass
 | 
			
		||||
            except Exception:
 | 
			
		||||
            except Exception:  # noqa
 | 
			
		||||
                trace = traceback.format_exc().rstrip()
 | 
			
		||||
                log.warning("Failed thumbnailing %s:\n%s", path, trace)
 | 
			
		||||
            else:
 | 
			
		||||
@@ -1366,7 +1367,8 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
 | 
			
		||||
                event.set_fields(
 | 
			
		||||
                    content = await self.get_redacted_event_content(
 | 
			
		||||
                        event.event_type, self.user_id, event.sender_id,reason,
 | 
			
		||||
                        event.event_type, self.user_id, event.sender_id,
 | 
			
		||||
                        reason,
 | 
			
		||||
                    ),
 | 
			
		||||
 | 
			
		||||
                    event_type       = nio.RedactedEvent,
 | 
			
		||||
@@ -1568,7 +1570,7 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
            account.status_msg = status_msg
 | 
			
		||||
 | 
			
		||||
            await super().set_presence(
 | 
			
		||||
                "offline"  if presence == "invisible" else presence,
 | 
			
		||||
                "offline" if presence == "invisible" else presence,
 | 
			
		||||
                status_msg,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
@@ -2070,7 +2072,7 @@ class MatrixClient(nio.AsyncClient):
 | 
			
		||||
        image_data = None
 | 
			
		||||
        create     = False
 | 
			
		||||
 | 
			
		||||
        async with aiofiles.open(avatar_path, "rb") as file:
 | 
			
		||||
        async with utils.aiopen(avatar_path, "rb") as file:
 | 
			
		||||
            if await utils.is_svg(file):
 | 
			
		||||
                await file.seek(0, 0)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -14,9 +14,8 @@ from pathlib import Path
 | 
			
		||||
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, Optional
 | 
			
		||||
from urllib.parse import urlparse
 | 
			
		||||
 | 
			
		||||
from PIL import Image as PILImage
 | 
			
		||||
 | 
			
		||||
import nio
 | 
			
		||||
from PIL import Image as PILImage
 | 
			
		||||
 | 
			
		||||
from .utils import Size, atomic_write
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -11,7 +11,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
 | 
			
		||||
from uuid import UUID
 | 
			
		||||
 | 
			
		||||
import lxml  # nosec
 | 
			
		||||
 | 
			
		||||
import nio
 | 
			
		||||
 | 
			
		||||
from ..presence import Presence
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,7 @@
 | 
			
		||||
 | 
			
		||||
from collections import UserDict
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from typing import Dict
 | 
			
		||||
from typing import Dict, List, Union
 | 
			
		||||
 | 
			
		||||
from . import SyncId
 | 
			
		||||
from .model import Model
 | 
			
		||||
@@ -47,7 +47,7 @@ class ModelStore(UserDict):
 | 
			
		||||
        elif is_tuple and len(key) == 3 and key[2] == "autocompleted_members":
 | 
			
		||||
            model = AutoCompletedMembers(user_id=key[0], room_id=key[1])
 | 
			
		||||
        else:
 | 
			
		||||
            model = Model(sync_id=key)  # type: ignore
 | 
			
		||||
            model = Model(sync_id=key)
 | 
			
		||||
 | 
			
		||||
        self.data[key] = model
 | 
			
		||||
        return model
 | 
			
		||||
@@ -62,7 +62,9 @@ class ModelStore(UserDict):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    async def ensure_exists_from_qml(self, sync_id: SyncId) -> None:
 | 
			
		||||
    async def ensure_exists_from_qml(
 | 
			
		||||
        self, sync_id: Union[SyncId, List[str]],
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Create model if it doesn't exist. Should only be called by QML."""
 | 
			
		||||
 | 
			
		||||
        if isinstance(sync_id, list):  # QML can't pass tuples
 | 
			
		||||
 
 | 
			
		||||
@@ -40,7 +40,7 @@ class AllRooms(FieldSubstringFilter):
 | 
			
		||||
        return source.sync_id == "accounts" or (
 | 
			
		||||
            isinstance(source.sync_id, tuple) and
 | 
			
		||||
            len(source.sync_id) == 2 and
 | 
			
		||||
            source.sync_id[1] == "rooms"  # type: ignore
 | 
			
		||||
            source.sync_id[1] == "rooms"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -366,9 +366,9 @@ class NioCallbacks:
 | 
			
		||||
 | 
			
		||||
        def lvl(level: int) -> str:
 | 
			
		||||
            return (
 | 
			
		||||
                f"Admin ({level})"     if level == 100 else
 | 
			
		||||
                f"Admin ({level})" if level == 100 else
 | 
			
		||||
                f"Moderator ({level})" if level >= 50 else
 | 
			
		||||
                f"User ({level})"      if level >= 0 else
 | 
			
		||||
                f"User ({level})" if level >= 0 else
 | 
			
		||||
                f"Muted ({level})"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
@@ -788,8 +788,9 @@ class NioCallbacks:
 | 
			
		||||
            datetime.now() - timedelta(milliseconds=ev.last_active_ago)
 | 
			
		||||
        ) if ev.last_active_ago else datetime.fromtimestamp(0)
 | 
			
		||||
 | 
			
		||||
        presence.presence = Presence.State(ev.presence) if ev.presence\
 | 
			
		||||
                                    else Presence.State.offline
 | 
			
		||||
        presence.presence = \
 | 
			
		||||
            Presence.State(ev.presence) if ev.presence else \
 | 
			
		||||
            Presence.State.offline
 | 
			
		||||
 | 
			
		||||
        # Add all existing members related to this presence
 | 
			
		||||
        for room_id in self.models[self.user_id, "rooms"]:
 | 
			
		||||
@@ -818,9 +819,9 @@ class NioCallbacks:
 | 
			
		||||
 | 
			
		||||
            # Set status_msg if none is set on the server and we have one
 | 
			
		||||
            if (
 | 
			
		||||
                not presence.status_msg                           and
 | 
			
		||||
                account.status_msg                                and
 | 
			
		||||
                ev.user_id in self.client.backend.clients         and
 | 
			
		||||
                not presence.status_msg and
 | 
			
		||||
                account.status_msg and
 | 
			
		||||
                ev.user_id in self.client.backend.clients and
 | 
			
		||||
                account.presence != Presence.State.echo_invisible and
 | 
			
		||||
                presence.presence == Presence.State.offline
 | 
			
		||||
            ):
 | 
			
		||||
 
 | 
			
		||||
@@ -364,7 +364,7 @@ class Section(MutableMapping):
 | 
			
		||||
                for name in child.inherit_from.dumps().split(","):
 | 
			
		||||
                    name = name.strip()
 | 
			
		||||
 | 
			
		||||
                    if root_arg is not None and name:
 | 
			
		||||
                    if name:
 | 
			
		||||
                        child_inherit.append(type(attrgetter(name)(root_arg)))
 | 
			
		||||
 | 
			
		||||
                instance._set_section(section.from_source_code(
 | 
			
		||||
 
 | 
			
		||||
@@ -89,7 +89,7 @@ class QMLBridge:
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                result = future.result()
 | 
			
		||||
            except Exception as err:
 | 
			
		||||
            except Exception as err:  # noqa
 | 
			
		||||
                exception = err
 | 
			
		||||
                trace     = traceback.format_exc().rstrip()
 | 
			
		||||
 | 
			
		||||
@@ -173,7 +173,7 @@ class QMLBridge:
 | 
			
		||||
            asyncio.run_coroutine_threadsafe(
 | 
			
		||||
                self.backend.terminate_clients(), self._loop,
 | 
			
		||||
            ).result()
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
        except Exception as e:  # noqa
 | 
			
		||||
            print(e)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@
 | 
			
		||||
import asyncio
 | 
			
		||||
from http.server import BaseHTTPRequestHandler, HTTPServer
 | 
			
		||||
from urllib.parse import parse_qs, quote, urlparse
 | 
			
		||||
 | 
			
		||||
from . import __display_name__
 | 
			
		||||
 | 
			
		||||
_SUCCESS_HTML_PAGE = """<!DOCTYPE html>
 | 
			
		||||
 
 | 
			
		||||
@@ -14,15 +14,15 @@ from typing import (
 | 
			
		||||
    TYPE_CHECKING, Any, ClassVar, Dict, Iterator, Optional, Tuple,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import aiofiles
 | 
			
		||||
from watchgod import Change, awatch
 | 
			
		||||
 | 
			
		||||
import pyotherside
 | 
			
		||||
from watchgod import Change, awatch
 | 
			
		||||
 | 
			
		||||
from .pcn.section import Section
 | 
			
		||||
from .pyotherside_events import LoopException, UserFileChanged
 | 
			
		||||
from .theme_parser import convert_to_qml
 | 
			
		||||
from .utils import atomic_write, deep_serialize_for_qml, dict_update_recursive
 | 
			
		||||
from .utils import (
 | 
			
		||||
    aiopen, atomic_write, deep_serialize_for_qml, dict_update_recursive,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from .backend import Backend
 | 
			
		||||
@@ -122,7 +122,7 @@ class UserFile:
 | 
			
		||||
                            ignored += 1
 | 
			
		||||
                            continue
 | 
			
		||||
 | 
			
		||||
                        async with aiofiles.open(self.path) as file:
 | 
			
		||||
                        async with aiopen(self.path) as file:
 | 
			
		||||
                            text            = await file.read()
 | 
			
		||||
                            self.data, save = self.deserialized(text)
 | 
			
		||||
 | 
			
		||||
@@ -138,7 +138,7 @@ class UserFile:
 | 
			
		||||
                if changes and ignored < len(changes):
 | 
			
		||||
                    UserFileChanged(type(self), self.qml_data)
 | 
			
		||||
 | 
			
		||||
            except Exception as err:
 | 
			
		||||
            except Exception as err:  # noqa
 | 
			
		||||
                LoopException(str(err), err, traceback.format_exc().rstrip())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -159,7 +159,7 @@ class UserFile:
 | 
			
		||||
                    self._need_write = False
 | 
			
		||||
                    self._mtime      = self.write_path.stat().st_mtime
 | 
			
		||||
 | 
			
		||||
            except Exception as err:
 | 
			
		||||
            except Exception as err:  # noqa
 | 
			
		||||
                self._need_write = False
 | 
			
		||||
                LoopException(str(err), err, traceback.format_exc().rstrip())
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -10,7 +10,7 @@ import inspect
 | 
			
		||||
import io
 | 
			
		||||
import json
 | 
			
		||||
import sys
 | 
			
		||||
import xml.etree.cElementTree as xml_etree  # FIXME: bandit warning
 | 
			
		||||
import xml.etree.cElementTree as xml_etree
 | 
			
		||||
from concurrent.futures import ProcessPoolExecutor
 | 
			
		||||
from contextlib import suppress
 | 
			
		||||
from datetime import date, datetime, time, timedelta
 | 
			
		||||
@@ -27,8 +27,7 @@ from uuid import UUID
 | 
			
		||||
 | 
			
		||||
import aiofiles
 | 
			
		||||
import filetype
 | 
			
		||||
from aiofiles.threadpool.binary import AsyncBufferedReader
 | 
			
		||||
from aiofiles.threadpool.text import AsyncTextIOWrapper
 | 
			
		||||
from aiofiles.threadpool.binary import AsyncBufferedIOBase
 | 
			
		||||
from nio.crypto import AsyncDataT as File
 | 
			
		||||
from nio.crypto import async_generator_from_data
 | 
			
		||||
from PIL import Image as PILImage
 | 
			
		||||
@@ -38,10 +37,9 @@ if sys.version_info >= (3, 7):
 | 
			
		||||
else:
 | 
			
		||||
    from async_generator import asynccontextmanager
 | 
			
		||||
 | 
			
		||||
AsyncOpenFile = Union[AsyncTextIOWrapper, AsyncBufferedReader]
 | 
			
		||||
Size          = Tuple[int, int]
 | 
			
		||||
BytesOrPIL    = Union[bytes, PILImage.Image]
 | 
			
		||||
auto          = autostr
 | 
			
		||||
Size       = Tuple[int, int]
 | 
			
		||||
BytesOrPIL = Union[bytes, PILImage.Image]
 | 
			
		||||
auto       = autostr
 | 
			
		||||
 | 
			
		||||
COMPRESSION_POOL = ProcessPoolExecutor()
 | 
			
		||||
 | 
			
		||||
@@ -114,8 +112,8 @@ async def guess_mime(file: File) -> str:
 | 
			
		||||
 | 
			
		||||
    if isinstance(file, io.IOBase):
 | 
			
		||||
        file.seek(0, 0)
 | 
			
		||||
    elif isinstance(file, AsyncBufferedReader):
 | 
			
		||||
        await file.seek(0, 0)
 | 
			
		||||
    elif isinstance(file, AsyncBufferedIOBase):
 | 
			
		||||
        await file.seek(0, 0)  # type: ignore
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        first_chunk: bytes
 | 
			
		||||
@@ -134,8 +132,8 @@ async def guess_mime(file: File) -> str:
 | 
			
		||||
    finally:
 | 
			
		||||
        if isinstance(file, io.IOBase):
 | 
			
		||||
            file.seek(0, 0)
 | 
			
		||||
        elif isinstance(file, AsyncBufferedReader):
 | 
			
		||||
            await file.seek(0, 0)
 | 
			
		||||
        elif isinstance(file, AsyncBufferedIOBase):
 | 
			
		||||
            await file.seek(0, 0)  # type: ignore
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plain2html(text: str) -> str:
 | 
			
		||||
@@ -250,10 +248,17 @@ def classes_defined_in(module: ModuleType) -> Dict[str, Type]:
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@asynccontextmanager
 | 
			
		||||
async def aiopen(*args, **kwargs) -> AsyncIterator[Any]:
 | 
			
		||||
    """Wrapper for `aiofiles.open()` that doesn't break mypy"""
 | 
			
		||||
    async with aiofiles.open(*args, **kwargs) as file:
 | 
			
		||||
        yield file
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@asynccontextmanager
 | 
			
		||||
async def atomic_write(
 | 
			
		||||
    path: Union[Path, str], binary: bool = False, **kwargs,
 | 
			
		||||
) -> AsyncIterator[Tuple[AsyncOpenFile, Callable[[], None]]]:
 | 
			
		||||
) -> AsyncIterator[Tuple[Any, Callable[[], None]]]:
 | 
			
		||||
    """Write a file asynchronously (using aiofiles) and atomically.
 | 
			
		||||
 | 
			
		||||
    Yields a `(open_temporary_file, done_function)` tuple.
 | 
			
		||||
@@ -279,7 +284,7 @@ async def atomic_write(
 | 
			
		||||
        can_replace = True
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        async with aiofiles.open(temp_path, mode, **kwargs) as out:
 | 
			
		||||
        async with aiopen(temp_path, mode, **kwargs) as out:
 | 
			
		||||
            yield (out, done)
 | 
			
		||||
    finally:
 | 
			
		||||
        if can_replace:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user