# SPDX-License-Identifier: LGPL-3.0-or-later

"""Various utilities that are used throughout the package."""

import collections
import collections.abc
import html
import inspect
import io
import json
# `xml.etree.cElementTree` was removed in Python 3.9; `ElementTree` uses the
# C accelerator automatically when it is available.
import xml.etree.ElementTree as xml_etree  # FIXME: bandit warning
from datetime import timedelta
from enum import Enum
from enum import auto as autostr
from pathlib import Path
from types import ModuleType
from typing import Any, Dict, Mapping, Sequence, Tuple, Type
from uuid import UUID

import filetype
from aiofiles.threadpool.binary import AsyncBufferedReader
from nio.crypto import AsyncDataT as File
from nio.crypto import async_generator_from_data

from .models.model_item import ModelItem

Size = Tuple[int, int]
auto = autostr

# Width/height reported for a SVG that declares no usable dimension.
_SVG_FALLBACK_SIZE = 256


class AutoStrEnum(Enum):
    """An Enum where auto() assigns the member's name instead of an int.

    Example:
    >>> class Fruits(AutoStrEnum):
    ...     apple = auto()
    >>> Fruits.apple.value
    'apple'
    """

    @staticmethod
    def _generate_next_value_(name, *_):
        # Called by `enum` for every `auto()`; only the member name matters.
        return name


def dict_update_recursive(dict1: dict, dict2: dict) -> None:
    """Deep-merge `dict2` into `dict1`, recursive version of `dict.update()`.

    `dict1` is modified in place. When a key exists in both and the two
    values are a dict and a mapping, they are merged recursively; in every
    other case the value from `dict2` overwrites the one in `dict1`.
    """
    # https://gist.github.com/angstwad/bf22d1822c38a92ec0a9

    for key, value in dict2.items():
        mergeable = (
            key in dict1 and
            isinstance(dict1[key], dict) and
            # `collections.Mapping` was removed in Python 3.10
            isinstance(value, collections.abc.Mapping)
        )

        if mergeable:
            dict_update_recursive(dict1[key], value)
        else:
            dict1[key] = value


async def is_svg(file: File) -> bool:
    """Return whether the file is a SVG (`xml.etree` is used for detection).

    The data is read entirely, then only the first XML start tag is looked
    at. Empty or unparsable content results in `False`.
    """

    chunks = [c async for c in async_generator_from_data(file)]

    with io.BytesIO(b"".join(chunks)) as buffer:
        try:
            _, element = next(xml_etree.iterparse(buffer, ("start",)))
            return element.tag == "{http://www.w3.org/2000/svg}svg"
        except (StopIteration, xml_etree.ParseError):
            return False


def _svg_dimension(
    attrs: Mapping[str, str],
    name: str,
    view_box: Sequence[str],
    view_box_index: int,
) -> int:
    """Return `attrs[name]` rounded, else the viewBox item, else a fallback."""

    candidates = (
        lambda: attrs[name],
        lambda: view_box[view_box_index],
    )

    for get_candidate in candidates:
        try:
            return round(float(get_candidate()))
        except (KeyError, IndexError, ValueError, TypeError, OverflowError):
            # Missing, or not a plain number (e.g. "100%"): try next source
            continue

    return _SVG_FALLBACK_SIZE


async def svg_dimensions(file: File) -> Size:
    """Return the width and height, or viewBox width and height for a SVG.

    The root element's `width` and `height` attributes are preferred.
    If one is missing or not a plain number, the matching `viewBox` item is
    used instead. If that is unusable too, 256 is returned for that dimension.

    Raises:
        xml.etree.ElementTree.ParseError: if the data is not parsable XML.
    """

    chunks = [c async for c in async_generator_from_data(file)]

    with io.BytesIO(b"".join(chunks)) as buffer:
        attrs = xml_etree.parse(buffer).getroot().attrib

    # viewBox is "min-x min-y width height", separated by spaces and/or commas
    view_box = attrs.get("viewBox", "").replace(",", " ").split()

    return (
        _svg_dimension(attrs, "width", view_box, 2),
        _svg_dimension(attrs, "height", view_box, 3),
    )


async def _rewind(file: File) -> None:
    """Move the cursor back to the start if `file` is a file object."""

    if isinstance(file, io.IOBase):
        file.seek(0, 0)
    elif isinstance(file, AsyncBufferedReader):
        await file.seek(0, 0)


async def guess_mime(file: File) -> str:
    """Return the file's mimetype, or `application/octet-stream` if unknown.

    `inode/x-empty` is returned for data that yields no chunk at all.
    File objects are rewound to their start before and after the detection.
    """

    await _rewind(file)

    try:
        first_chunk: bytes
        async for first_chunk in async_generator_from_data(file):
            break
        else:
            return "inode/x-empty"  # empty file

        # TODO: plaintext
        mime = filetype.guess_mime(first_chunk)

        if mime:
            return mime

        # NOTE(review): presumably the chunk generator reads file objects
        # from their current position; rewind so `is_svg` sees the whole
        # document and not only what follows the first chunk.
        await _rewind(file)

        if await is_svg(file):
            return "image/svg+xml"

        return "application/octet-stream"
    finally:
        await _rewind(file)


def plain2html(text: str) -> str:
    """Convert `\\n` into `<br>` tags and `\\t` into four spaces.

    The text is HTML-escaped first, so any markup it contains is shown
    literally instead of being interpreted.
    """

    return html.escape(text)\
               .replace("\n", "<br>")\
               .replace("\t", " " * 4)


def serialize_value_for_qml(value: Any, json_lists: bool = False) -> Any:
    """Convert a value to make it easier to use from QML.

    Returns:

    - A JSON string for lists, only if `json_lists` is `True`
    - The member's actual value for `Enum` members
    - `ModelItem.serialized` for `ModelItem`s
    - A `file://...` string for `Path` objects
    - Strings for `UUID` objects
    - A number of milliseconds for `datetime.timedelta` objects
    - The class `__name__` for class types
    - The value itself, unchanged, for anything else
    """

    if json_lists and isinstance(value, list):
        return json.dumps(value)

    if isinstance(value, Enum):
        return value.value

    if isinstance(value, ModelItem):
        return value.serialized

    if isinstance(value, Path):
        return f"file://{value!s}"

    if isinstance(value, UUID):
        return str(value)

    if isinstance(value, timedelta):
        return value.total_seconds() * 1000

    if inspect.isclass(value):
        return value.__name__

    return value


def classes_defined_in(module: ModuleType) -> Dict[str, Type]:
    """Return a `{name: class}` dict of all the classes a module defines.

    Classes whose name starts with `_` are skipped, as well as those whose
    `__module__` does not start with the module's name (e.g. classes that
    the module merely imported from somewhere else).
    """

    return {
        name: cls
        for name, cls in inspect.getmembers(module, inspect.isclass)
        if not name.startswith("_") and
        cls.__module__.startswith(module.__name__)
    }