Make utils function accept Path, str, bytes and IO

This commit is contained in:
miruka 2019-11-12 09:24:58 -04:00
parent 5832c3ca2d
commit ef391d1eb1

View File

@ -4,6 +4,7 @@ import asyncio
import collections import collections
import html import html
import inspect import inspect
import io
import logging as log import logging as log
import xml.etree.cElementTree as xml_etree # FIXME: bandit warning import xml.etree.cElementTree as xml_etree # FIXME: bandit warning
from enum import Enum from enum import Enum
@ -14,6 +15,7 @@ from typing import IO, Any, Callable, Dict, Tuple, Type, Union
import filetype import filetype
File = Union[IO, bytes, str, Path]
auto = autostr auto = autostr
CANCELLABLE_FUTURES: Dict[Tuple[Any, Callable], asyncio.Future] = {} CANCELLABLE_FUTURES: Dict[Tuple[Any, Callable], asyncio.Future] = {}
@ -44,9 +46,12 @@ def dict_update_recursive(dict1: dict, dict2: dict) -> None:
dict1[k] = dict2[k] dict1[k] = dict2[k]
def is_svg(file: Union[IO, bytes, str]) -> bool: def is_svg(file: File) -> bool:
"""Return True if the file is a SVG. Uses lxml for detection.""" """Return True if the file is a SVG. Uses lxml for detection."""
if isinstance(file, Path):
file = str(file)
try: try:
_, element = next(xml_etree.iterparse(file, ("start",))) _, element = next(xml_etree.iterparse(file, ("start",)))
return element.tag == "{http://www.w3.org/2000/svg}svg" return element.tag == "{http://www.w3.org/2000/svg}svg"
@ -54,11 +59,14 @@ def is_svg(file: Union[IO, bytes, str]) -> bool:
return False return False
def svg_dimensions(file: Union[IO, bytes, str]) -> Tuple[int, int]: def svg_dimensions(file: File) -> Tuple[int, int]:
"""Return the width & height or viewBox width & height for a SVG. """Return the width & height or viewBox width & height for a SVG.
If these properties are missing (broken file), ``(256, 256)`` is returned. If these properties are missing (broken file), ``(256, 256)`` is returned.
""" """
if isinstance(file, Path):
file = str(file)
attrs = xml_etree.parse(file).getroot().attrib attrs = xml_etree.parse(file).getroot().attrib
try: try:
@ -74,7 +82,7 @@ def svg_dimensions(file: Union[IO, bytes, str]) -> Tuple[int, int]:
return (width, height) return (width, height)
def guess_mime(file: IO) -> str: def guess_mime(file: File) -> str:
"""Return the mime type for a file, or application/octet-stream if it """Return the mime type for a file, or application/octet-stream if it
can't be guessed. can't be guessed.
""" """
@ -82,7 +90,11 @@ def guess_mime(file: IO) -> str:
if is_svg(file): if is_svg(file):
return "image/svg+xml" return "image/svg+xml"
file.seek(0, 0) if isinstance(file, Path):
file = str(file)
elif isinstance(file, io.IOBase):
file.seek(0, 0)
return filetype.guess_mime(file) or "application/octet-stream" return filetype.guess_mime(file) or "application/octet-stream"