Fixed indentation (w). Probably fixed redactions sometimes displaying viewing user's name in place of actor's name. Fixed room history never loading sometimes (but not missing chunks in the middle yet).

This commit is contained in:
Zergling_man 2023-10-27 20:25:20 +11:00
parent bc20e47fb1
commit b6543b09cc
29 changed files with 7158 additions and 7155 deletions

View File

@ -22,42 +22,42 @@ ROOT = Path(__file__).parent
class Watcher(DefaultWatcher): class Watcher(DefaultWatcher):
def accept_change(self, entry: os.DirEntry) -> bool: def accept_change(self, entry: os.DirEntry) -> bool:
path = Path(entry.path) path = Path(entry.path)
for bad in ("src/config", "src/themes"): for bad in ("src/config", "src/themes"):
if path.is_relative_to(ROOT / bad): if path.is_relative_to(ROOT / bad):
return False return False
for good in ("src", "submodules"): for good in ("src", "submodules"):
if path.is_relative_to(ROOT / good): if path.is_relative_to(ROOT / good):
return True return True
return False return False
def should_watch_dir(self, entry: os.DirEntry) -> bool: def should_watch_dir(self, entry: os.DirEntry) -> bool:
return super().should_watch_dir(entry) and self.accept_change(entry) return super().should_watch_dir(entry) and self.accept_change(entry)
def should_watch_file(self, entry: os.DirEntry) -> bool: def should_watch_file(self, entry: os.DirEntry) -> bool:
return super().should_watch_file(entry) and self.accept_change(entry) return super().should_watch_file(entry) and self.accept_change(entry)
def cmd(*parts) -> subprocess.CompletedProcess: def cmd(*parts) -> subprocess.CompletedProcess:
return subprocess.run(parts, cwd=ROOT, check=True) return subprocess.run(parts, cwd=ROOT, check=True)
def run_app(args=sys.argv[1:]) -> None: def run_app(args=sys.argv[1:]) -> None:
print("\n\x1b[36m", "" * term_size().columns, "\x1b[0m\n", sep="") print("\n\x1b[36m", "" * term_size().columns, "\x1b[0m\n", sep="")
with suppress(KeyboardInterrupt): with suppress(KeyboardInterrupt):
cmd("qmake", "moment.pro", "CONFIG+=dev") cmd("qmake", "moment.pro", "CONFIG+=dev")
cmd("make") cmd("make")
cmd("./moment", "-name", "dev", *args) cmd("./moment", "-name", "dev", *args)
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) > 2 and sys.argv[1] in ("-h", "--help"): if len(sys.argv) > 2 and sys.argv[1] in ("-h", "--help"):
print(__doc__) print(__doc__)
else: else:
(ROOT / "Makefile").exists() and cmd("make", "clean") (ROOT / "Makefile").exists() and cmd("make", "clean")
run_process(ROOT, run_app, callback=print, watcher_cls=Watcher) run_process(ROOT, run_app, callback=print, watcher_cls=Watcher)

View File

@ -2,31 +2,31 @@ import json
import yaml import yaml
with open("moment.flatpak.base.yaml") as f: with open("moment.flatpak.base.yaml") as f:
base = yaml.load(f, Loader=yaml.FullLoader) base = yaml.load(f, Loader=yaml.FullLoader)
with open("flatpak-pip.json") as f: with open("flatpak-pip.json") as f:
modules = json.load(f)["modules"] modules = json.load(f)["modules"]
# set some modules in front as dependencies and dropping matrix-nio # set some modules in front as dependencies and dropping matrix-nio
# which is declared separately # which is declared separately
front = [] front = []
back = [] back = []
for m in modules: for m in modules:
n = m["name"] n = m["name"]
if n.startswith("python3-") and \ if n.startswith("python3-") and \
n[len("python3-"):] in ["cffi", "importlib-metadata", "multidict", "pytest-runner", "setuptools-scm"]: n[len("python3-"):] in ["cffi", "importlib-metadata", "multidict", "pytest-runner", "setuptools-scm"]:
front.append(m) front.append(m)
else: else:
back.append(m) back.append(m)
# replace placeholder with modules # replace placeholder with modules
phold = None phold = None
for i in range(len(base["modules"])): for i in range(len(base["modules"])):
if base["modules"][i]["name"] == "PLACEHOLDER PYTHON DEPENDENCIES": if base["modules"][i]["name"] == "PLACEHOLDER PYTHON DEPENDENCIES":
phold = i phold = i
break break
base["modules"] = base["modules"][:i] + front + back + base["modules"][i+1:] base["modules"] = base["modules"][:i] + front + back + base["modules"][i+1:]
with open("moment.flatpak.yaml", "w") as f: with open("moment.flatpak.yaml", "w") as f:
f.write(yaml.dump(base, sort_keys=False, indent=2)) f.write(yaml.dump(base, sort_keys=False, indent=2))

View File

@ -4,29 +4,29 @@ import html
import re import re
from pathlib import Path from pathlib import Path
root = Path(__file__).resolve().parent.parent root = Path(__file__).resolve().parent.parent
title_pattern = re.compile(r"## (\d+\.\d+\.\d+) \((\d{4}-\d\d-\d\d)\)") title_pattern = re.compile(r"## (\d+\.\d+\.\d+) \((\d{4}-\d\d-\d\d)\)")
release_lines = [" <releases>"] release_lines = [" <releases>"]
for line in (root / "docs" / "CHANGELOG.md").read_text().splitlines(): for line in (root / "docs" / "CHANGELOG.md").read_text().splitlines():
match = title_pattern.match(line) match = title_pattern.match(line)
if match: if match:
args = (html.escape(match.group(1)), html.escape(match.group(2))) args = (html.escape(match.group(1)), html.escape(match.group(2)))
release_lines.append(' <release version="%s" date="%s"/>' % args) release_lines.append(' <release version="%s" date="%s"/>' % args)
appdata = root / "packaging" / "moment.metainfo.xml" appdata = root / "packaging" / "moment.metainfo.xml"
in_releases = False in_releases = False
final_lines = [] final_lines = []
for line in appdata.read_text().splitlines(): for line in appdata.read_text().splitlines():
if line == " <releases>": if line == " <releases>":
in_releases = True in_releases = True
final_lines += release_lines final_lines += release_lines
elif line == " </releases>": elif line == " </releases>":
in_releases = False in_releases = False
if not in_releases: if not in_releases:
final_lines.append(line) final_lines.append(line)
appdata.write_text("\n".join(final_lines)) appdata.write_text("\n".join(final_lines))

View File

@ -13,7 +13,7 @@ documentation in the following modules first:
- `nio_callbacks` - `nio_callbacks`
""" """
__app_name__ = "moment" __app_name__ = "moment"
__display_name__ = "Moment" __display_name__ = "Moment"
__reverse_dns__ = "xyz.mx-moment" __reverse_dns__ = "xyz.mx-moment"
__version__ = "0.7.3" __version__ = "0.7.3"

File diff suppressed because it is too large Load Diff

View File

@ -17,442 +17,442 @@ ColorTuple = Tuple[float, float, float, float]
@dataclass(repr=False) @dataclass(repr=False)
class Color: class Color:
"""A color manipulable in HSLuv, HSL, RGB, hexadecimal and by SVG name. """A color manipulable in HSLuv, HSL, RGB, hexadecimal and by SVG name.
The `Color` object constructor accepts hexadecimal string The `Color` object constructor accepts hexadecimal string
("#RGB", "#RRGGBB" or "#RRGGBBAA") or another `Color` to copy. ("#RGB", "#RRGGBB" or "#RRGGBBAA") or another `Color` to copy.
Attributes representing the color in HSLuv, HSL, RGB, hexadecimal and Attributes representing the color in HSLuv, HSL, RGB, hexadecimal and
SVG name formats can be accessed and modified on these `Color` objects. SVG name formats can be accessed and modified on these `Color` objects.
The `hsluv()`/`hsluva()`, `hsl()`/`hsla()` and `rgb()`/`rgba()` The `hsluv()`/`hsluva()`, `hsl()`/`hsla()` and `rgb()`/`rgba()`
functions in this module are provided to create an object by specifying functions in this module are provided to create an object by specifying
a color in other formats. a color in other formats.
Copies of objects with modified attributes can be created with the Copies of objects with modified attributes can be created with the
with the `Color.but()`, `Color.plus()` and `Copy.times()` methods. with the `Color.but()`, `Color.plus()` and `Copy.times()` methods.
If the `hue` is outside of the normal 0-359 range, the number is If the `hue` is outside of the normal 0-359 range, the number is
interpreted as `hue % 360`, e.g. `360` is `0`, `460` is `100`, interpreted as `hue % 360`, e.g. `360` is `0`, `460` is `100`,
or `-20` is `340`. or `-20` is `340`.
""" """
# The saturation and luv are properties due to the need for a setter # The saturation and luv are properties due to the need for a setter
# capping the value between 0-100, as hsluv handles numbers outside # capping the value between 0-100, as hsluv handles numbers outside
# this range incorrectly. # this range incorrectly.
color_or_hex: InitVar[str] = "#00000000" color_or_hex: InitVar[str] = "#00000000"
hue: float = field(init=False, default=0) hue: float = field(init=False, default=0)
_saturation: float = field(init=False, default=0) _saturation: float = field(init=False, default=0)
_luv: float = field(init=False, default=0) _luv: float = field(init=False, default=0)
alpha: float = field(init=False, default=1) alpha: float = field(init=False, default=1)
def __post_init__(self, color_or_hex: Union["Color", str]) -> None: def __post_init__(self, color_or_hex: Union["Color", str]) -> None:
if isinstance(color_or_hex, Color): if isinstance(color_or_hex, Color):
hsluva = color_or_hex.hsluva hsluva = color_or_hex.hsluva
self.hue, self.saturation, self.luv, self.alpha = hsluva self.hue, self.saturation, self.luv, self.alpha = hsluva
else: else:
self.hex = color_or_hex self.hex = color_or_hex
# HSLuv # HSLuv
@property @property
def hsluva(self) -> ColorTuple: def hsluva(self) -> ColorTuple:
return (self.hue, self.saturation, self.luv, self.alpha) return (self.hue, self.saturation, self.luv, self.alpha)
@hsluva.setter @hsluva.setter
def hsluva(self, value: ColorTuple) -> None: def hsluva(self, value: ColorTuple) -> None:
self.hue, self.saturation, self.luv, self.alpha = value self.hue, self.saturation, self.luv, self.alpha = value
@property @property
def saturation(self) -> float: def saturation(self) -> float:
return self._saturation return self._saturation
@saturation.setter @saturation.setter
def saturation(self, value: float) -> None: def saturation(self, value: float) -> None:
self._saturation = max(0, min(100, value)) self._saturation = max(0, min(100, value))
@property @property
def luv(self) -> float: def luv(self) -> float:
return self._luv return self._luv
@luv.setter @luv.setter
def luv(self, value: float) -> None: def luv(self, value: float) -> None:
self._luv = max(0, min(100, value)) self._luv = max(0, min(100, value))
# HSL # HSL
@property @property
def hsla(self) -> ColorTuple: def hsla(self) -> ColorTuple:
r, g, b = (self.red / 255, self.green / 255, self.blue / 255) r, g, b = (self.red / 255, self.green / 255, self.blue / 255)
h, l, s = colorsys.rgb_to_hls(r, g, b) h, l, s = colorsys.rgb_to_hls(r, g, b)
return (h * 360, s * 100, l * 100, self.alpha) return (h * 360, s * 100, l * 100, self.alpha)
@hsla.setter @hsla.setter
def hsla(self, value: ColorTuple) -> None: def hsla(self, value: ColorTuple) -> None:
h, s, l = (value[0] / 360, value[1] / 100, value[2] / 100) # noqa h, s, l = (value[0] / 360, value[1] / 100, value[2] / 100) # noqa
r, g, b = colorsys.hls_to_rgb(h, l, s) r, g, b = colorsys.hls_to_rgb(h, l, s)
self.rgba = (r * 255, g * 255, b * 255, value[3]) self.rgba = (r * 255, g * 255, b * 255, value[3])
@property @property
def light(self) -> float: def light(self) -> float:
return self.hsla[2] return self.hsla[2]
@light.setter @light.setter
def light(self, value: float) -> None: def light(self, value: float) -> None:
self.hsla = (self.hue, self.saturation, value, self.alpha) self.hsla = (self.hue, self.saturation, value, self.alpha)
# RGB # RGB
@property @property
def rgba(self) -> ColorTuple: def rgba(self) -> ColorTuple:
r, g, b = hsluv_to_rgb(self.hsluva) r, g, b = hsluv_to_rgb(self.hsluva)
return r * 255, g * 255, b * 255, self.alpha return r * 255, g * 255, b * 255, self.alpha
@rgba.setter @rgba.setter
def rgba(self, value: ColorTuple) -> None: def rgba(self, value: ColorTuple) -> None:
r, g, b = (value[0] / 255, value[1] / 255, value[2] / 255) r, g, b = (value[0] / 255, value[1] / 255, value[2] / 255)
self.hsluva = rgb_to_hsluv((r, g, b)) + (self.alpha,) self.hsluva = rgb_to_hsluv((r, g, b)) + (self.alpha,)
@property @property
def red(self) -> float: def red(self) -> float:
return self.rgba[0] return self.rgba[0]
@red.setter @red.setter
def red(self, value: float) -> None: def red(self, value: float) -> None:
self.rgba = (value, self.green, self.blue, self.alpha) self.rgba = (value, self.green, self.blue, self.alpha)
@property @property
def green(self) -> float: def green(self) -> float:
return self.rgba[1] return self.rgba[1]
@green.setter @green.setter
def green(self, value: float) -> None: def green(self, value: float) -> None:
self.rgba = (self.red, value, self.blue, self.alpha) self.rgba = (self.red, value, self.blue, self.alpha)
@property @property
def blue(self) -> float: def blue(self) -> float:
return self.rgba[2] return self.rgba[2]
@blue.setter @blue.setter
def blue(self, value: float) -> None: def blue(self, value: float) -> None:
self.rgba = (self.red, self.green, value, self.alpha) self.rgba = (self.red, self.green, value, self.alpha)
# Hexadecimal # Hexadecimal
@property @property
def hex(self) -> str: def hex(self) -> str:
rgb = hsluv_to_hex(self.hsluva) rgb = hsluv_to_hex(self.hsluva)
alpha = builtins.hex(int(self.alpha * 255))[2:] alpha = builtins.hex(int(self.alpha * 255))[2:]
alpha = f"0{alpha}" if len(alpha) == 1 else alpha alpha = f"0{alpha}" if len(alpha) == 1 else alpha
return f"{alpha if self.alpha < 1 else ''}{rgb}".lower() return f"{alpha if self.alpha < 1 else ''}{rgb}".lower()
@hex.setter @hex.setter
def hex(self, value: str) -> None: def hex(self, value: str) -> None:
if len(value) == 4: if len(value) == 4:
template = "#{r}{r}{g}{g}{b}{b}" template = "#{r}{r}{g}{g}{b}{b}"
value = template.format(r=value[1], g=value[2], b=value[3]) value = template.format(r=value[1], g=value[2], b=value[3])
alpha = int(value[-2:] if len(value) == 9 else "ff", 16) / 255 alpha = int(value[-2:] if len(value) == 9 else "ff", 16) / 255
self.hsluva = hex_to_hsluv(value) + (alpha,) self.hsluva = hex_to_hsluv(value) + (alpha,)
# name color # name color
@property @property
def name(self) -> Optional[str]: def name(self) -> Optional[str]:
try: try:
return SVGColor(self.hex).name return SVGColor(self.hex).name
except ValueError: except ValueError:
return None return None
@name.setter @name.setter
def name(self, value: str) -> None: def name(self, value: str) -> None:
self.hex = SVGColor[value.lower()].value.hex self.hex = SVGColor[value.lower()].value.hex
# Other methods # Other methods
def __repr__(self) -> str: def __repr__(self) -> str:
r, g, b = int(self.red), int(self.green), int(self.blue) r, g, b = int(self.red), int(self.green), int(self.blue)
h, s, luv = int(self.hue), int(self.saturation), int(self.luv) h, s, luv = int(self.hue), int(self.saturation), int(self.luv)
l = int(self.light) # noqa l = int(self.light) # noqa
a = self.alpha a = self.alpha
block = f"\x1b[38;2;{r};{g};{b}m█████\x1b[0m" block = f"\x1b[38;2;{r};{g};{b}m█████\x1b[0m"
sep = "\x1b[1;33m/\x1b[0m" sep = "\x1b[1;33m/\x1b[0m"
end = f" {sep} {self.name}" if self.name else "" end = f" {sep} {self.name}" if self.name else ""
# Need a terminal with true color support to render the block! # Need a terminal with true color support to render the block!
return ( return (
f"{block} hsluva({h}, {s}, {luv}, {a}) {sep} " f"{block} hsluva({h}, {s}, {luv}, {a}) {sep} "
f"hsla({h}, {s}, {l}, {a}) {sep} rgba({r}, {g}, {b}, {a}) {sep} " f"hsla({h}, {s}, {l}, {a}) {sep} rgba({r}, {g}, {b}, {a}) {sep} "
f"{self.hex}{end}" f"{self.hex}{end}"
) )
def but( def but(
self, self,
hue: Optional[float] = None, hue: Optional[float] = None,
saturation: Optional[float] = None, saturation: Optional[float] = None,
luv: Optional[float] = None, luv: Optional[float] = None,
alpha: Optional[float] = None, alpha: Optional[float] = None,
*, *,
hsluva: Optional[ColorTuple] = None, hsluva: Optional[ColorTuple] = None,
hsla: Optional[ColorTuple] = None, hsla: Optional[ColorTuple] = None,
rgba: Optional[ColorTuple] = None, rgba: Optional[ColorTuple] = None,
hex: Optional[str] = None, hex: Optional[str] = None,
name: Optional[str] = None, name: Optional[str] = None,
light: Optional[float] = None, light: Optional[float] = None,
red: Optional[float] = None, red: Optional[float] = None,
green: Optional[float] = None, green: Optional[float] = None,
blue: Optional[float] = None, blue: Optional[float] = None,
) -> "Color": ) -> "Color":
"""Return a copy of this `Color` with overriden attributes. """Return a copy of this `Color` with overriden attributes.
Example: Example:
>>> first = Color(100, 50, 50) >>> first = Color(100, 50, 50)
>>> second = c.but(hue=20, saturation=100) >>> second = c.but(hue=20, saturation=100)
>>> second.hsluva >>> second.hsluva
(20, 50, 100, 1) (20, 50, 100, 1)
""" """
new = copy(self) new = copy(self)
for arg, value in locals().items(): for arg, value in locals().items():
if arg not in ("new", "self") and value is not None: if arg not in ("new", "self") and value is not None:
setattr(new, arg, value) setattr(new, arg, value)
return new return new
def plus( def plus(
self, self,
hue: Optional[float] = None, hue: Optional[float] = None,
saturation: Optional[float] = None, saturation: Optional[float] = None,
luv: Optional[float] = None, luv: Optional[float] = None,
alpha: Optional[float] = None, alpha: Optional[float] = None,
*, *,
light: Optional[float] = None, light: Optional[float] = None,
red: Optional[float] = None, red: Optional[float] = None,
green: Optional[float] = None, green: Optional[float] = None,
blue: Optional[float] = None, blue: Optional[float] = None,
) -> "Color": ) -> "Color":
"""Return a copy of this `Color` with values added to attributes. """Return a copy of this `Color` with values added to attributes.
Example: Example:
>>> first = Color(100, 50, 50) >>> first = Color(100, 50, 50)
>>> second = c.plus(hue=10, saturation=-20) >>> second = c.plus(hue=10, saturation=-20)
>>> second.hsluva >>> second.hsluva
(110, 30, 50, 1) (110, 30, 50, 1)
""" """
new = copy(self) new = copy(self)
for arg, value in locals().items(): for arg, value in locals().items():
if arg not in ("new", "self") and value is not None: if arg not in ("new", "self") and value is not None:
setattr(new, arg, getattr(self, arg) + value) setattr(new, arg, getattr(self, arg) + value)
return new return new
def times( def times(
self, self,
hue: Optional[float] = None, hue: Optional[float] = None,
saturation: Optional[float] = None, saturation: Optional[float] = None,
luv: Optional[float] = None, luv: Optional[float] = None,
alpha: Optional[float] = None, alpha: Optional[float] = None,
*, *,
light: Optional[float] = None, light: Optional[float] = None,
red: Optional[float] = None, red: Optional[float] = None,
green: Optional[float] = None, green: Optional[float] = None,
blue: Optional[float] = None, blue: Optional[float] = None,
) -> "Color": ) -> "Color":
"""Return a copy of this `Color` with multiplied attributes. """Return a copy of this `Color` with multiplied attributes.
Example: Example:
>>> first = Color(100, 50, 50, 0.8) >>> first = Color(100, 50, 50, 0.8)
>>> second = c.times(luv=2, alpha=0.5) >>> second = c.times(luv=2, alpha=0.5)
>>> second.hsluva >>> second.hsluva
(100, 50, 100, 0.4) (100, 50, 100, 0.4)
""" """
new = copy(self) new = copy(self)
for arg, value in locals().items(): for arg, value in locals().items():
if arg not in ("new", "self") and value is not None: if arg not in ("new", "self") and value is not None:
setattr(new, arg, getattr(self, arg) * value) setattr(new, arg, getattr(self, arg) * value)
return new return new
class SVGColor(Enum): class SVGColor(Enum):
"""Standard SVG/HTML/CSS colors, with the addition of `transparent`.""" """Standard SVG/HTML/CSS colors, with the addition of `transparent`."""
aliceblue = Color("#f0f8ff") aliceblue = Color("#f0f8ff")
antiquewhite = Color("#faebd7") antiquewhite = Color("#faebd7")
aqua = Color("#00ffff") aqua = Color("#00ffff")
aquamarine = Color("#7fffd4") aquamarine = Color("#7fffd4")
azure = Color("#f0ffff") azure = Color("#f0ffff")
beige = Color("#f5f5dc") beige = Color("#f5f5dc")
bisque = Color("#ffe4c4") bisque = Color("#ffe4c4")
black = Color("#000000") black = Color("#000000")
blanchedalmond = Color("#ffebcd") blanchedalmond = Color("#ffebcd")
blue = Color("#0000ff") blue = Color("#0000ff")
blueviolet = Color("#8a2be2") blueviolet = Color("#8a2be2")
brown = Color("#a52a2a") brown = Color("#a52a2a")
burlywood = Color("#deb887") burlywood = Color("#deb887")
cadetblue = Color("#5f9ea0") cadetblue = Color("#5f9ea0")
chartreuse = Color("#7fff00") chartreuse = Color("#7fff00")
chocolate = Color("#d2691e") chocolate = Color("#d2691e")
coral = Color("#ff7f50") coral = Color("#ff7f50")
cornflowerblue = Color("#6495ed") cornflowerblue = Color("#6495ed")
cornsilk = Color("#fff8dc") cornsilk = Color("#fff8dc")
crimson = Color("#dc143c") crimson = Color("#dc143c")
cyan = Color("#00ffff") cyan = Color("#00ffff")
darkblue = Color("#00008b") darkblue = Color("#00008b")
darkcyan = Color("#008b8b") darkcyan = Color("#008b8b")
darkgoldenrod = Color("#b8860b") darkgoldenrod = Color("#b8860b")
darkgray = Color("#a9a9a9") darkgray = Color("#a9a9a9")
darkgreen = Color("#006400") darkgreen = Color("#006400")
darkgrey = Color("#a9a9a9") darkgrey = Color("#a9a9a9")
darkkhaki = Color("#bdb76b") darkkhaki = Color("#bdb76b")
darkmagenta = Color("#8b008b") darkmagenta = Color("#8b008b")
darkolivegreen = Color("#556b2f") darkolivegreen = Color("#556b2f")
darkorange = Color("#ff8c00") darkorange = Color("#ff8c00")
darkorchid = Color("#9932cc") darkorchid = Color("#9932cc")
darkred = Color("#8b0000") darkred = Color("#8b0000")
darksalmon = Color("#e9967a") darksalmon = Color("#e9967a")
darkseagreen = Color("#8fbc8f") darkseagreen = Color("#8fbc8f")
darkslateblue = Color("#483d8b") darkslateblue = Color("#483d8b")
darkslategray = Color("#2f4f4f") darkslategray = Color("#2f4f4f")
darkslategrey = Color("#2f4f4f") darkslategrey = Color("#2f4f4f")
darkturquoise = Color("#00ced1") darkturquoise = Color("#00ced1")
darkviolet = Color("#9400d3") darkviolet = Color("#9400d3")
deeppink = Color("#ff1493") deeppink = Color("#ff1493")
deepskyblue = Color("#00bfff") deepskyblue = Color("#00bfff")
dimgray = Color("#696969") dimgray = Color("#696969")
dimgrey = Color("#696969") dimgrey = Color("#696969")
dodgerblue = Color("#1e90ff") dodgerblue = Color("#1e90ff")
firebrick = Color("#b22222") firebrick = Color("#b22222")
floralwhite = Color("#fffaf0") floralwhite = Color("#fffaf0")
forestgreen = Color("#228b22") forestgreen = Color("#228b22")
fuchsia = Color("#ff00ff") fuchsia = Color("#ff00ff")
gainsboro = Color("#dcdcdc") gainsboro = Color("#dcdcdc")
ghostwhite = Color("#f8f8ff") ghostwhite = Color("#f8f8ff")
gold = Color("#ffd700") gold = Color("#ffd700")
goldenrod = Color("#daa520") goldenrod = Color("#daa520")
gray = Color("#808080") gray = Color("#808080")
green = Color("#008000") green = Color("#008000")
greenyellow = Color("#adff2f") greenyellow = Color("#adff2f")
grey = Color("#808080") grey = Color("#808080")
honeydew = Color("#f0fff0") honeydew = Color("#f0fff0")
hotpink = Color("#ff69b4") hotpink = Color("#ff69b4")
indianred = Color("#cd5c5c") indianred = Color("#cd5c5c")
indigo = Color("#4b0082") indigo = Color("#4b0082")
ivory = Color("#fffff0") ivory = Color("#fffff0")
khaki = Color("#f0e68c") khaki = Color("#f0e68c")
lavender = Color("#e6e6fa") lavender = Color("#e6e6fa")
lavenderblush = Color("#fff0f5") lavenderblush = Color("#fff0f5")
lawngreen = Color("#7cfc00") lawngreen = Color("#7cfc00")
lemonchiffon = Color("#fffacd") lemonchiffon = Color("#fffacd")
lightblue = Color("#add8e6") lightblue = Color("#add8e6")
lightcoral = Color("#f08080") lightcoral = Color("#f08080")
lightcyan = Color("#e0ffff") lightcyan = Color("#e0ffff")
lightgoldenrodyellow = Color("#fafad2") lightgoldenrodyellow = Color("#fafad2")
lightgray = Color("#d3d3d3") lightgray = Color("#d3d3d3")
lightgreen = Color("#90ee90") lightgreen = Color("#90ee90")
lightgrey = Color("#d3d3d3") lightgrey = Color("#d3d3d3")
lightpink = Color("#ffb6c1") lightpink = Color("#ffb6c1")
lightsalmon = Color("#ffa07a") lightsalmon = Color("#ffa07a")
lightseagreen = Color("#20b2aa") lightseagreen = Color("#20b2aa")
lightskyblue = Color("#87cefa") lightskyblue = Color("#87cefa")
lightslategray = Color("#778899") lightslategray = Color("#778899")
lightslategrey = Color("#778899") lightslategrey = Color("#778899")
lightsteelblue = Color("#b0c4de") lightsteelblue = Color("#b0c4de")
lightyellow = Color("#ffffe0") lightyellow = Color("#ffffe0")
lime = Color("#00ff00") lime = Color("#00ff00")
limegreen = Color("#32cd32") limegreen = Color("#32cd32")
linen = Color("#faf0e6") linen = Color("#faf0e6")
magenta = Color("#ff00ff") magenta = Color("#ff00ff")
maroon = Color("#800000") maroon = Color("#800000")
mediumaquamarine = Color("#66cdaa") mediumaquamarine = Color("#66cdaa")
mediumblue = Color("#0000cd") mediumblue = Color("#0000cd")
mediumorchid = Color("#ba55d3") mediumorchid = Color("#ba55d3")
mediumpurple = Color("#9370db") mediumpurple = Color("#9370db")
mediumseagreen = Color("#3cb371") mediumseagreen = Color("#3cb371")
mediumslateblue = Color("#7b68ee") mediumslateblue = Color("#7b68ee")
mediumspringgreen = Color("#00fa9a") mediumspringgreen = Color("#00fa9a")
mediumturquoise = Color("#48d1cc") mediumturquoise = Color("#48d1cc")
mediumvioletred = Color("#c71585") mediumvioletred = Color("#c71585")
midnightblue = Color("#191970") midnightblue = Color("#191970")
mintcream = Color("#f5fffa") mintcream = Color("#f5fffa")
mistyrose = Color("#ffe4e1") mistyrose = Color("#ffe4e1")
moccasin = Color("#ffe4b5") moccasin = Color("#ffe4b5")
navajowhite = Color("#ffdead") navajowhite = Color("#ffdead")
navy = Color("#000080") navy = Color("#000080")
oldlace = Color("#fdf5e6") oldlace = Color("#fdf5e6")
olive = Color("#808000") olive = Color("#808000")
olivedrab = Color("#6b8e23") olivedrab = Color("#6b8e23")
orange = Color("#ffa500") orange = Color("#ffa500")
orangered = Color("#ff4500") orangered = Color("#ff4500")
orchid = Color("#da70d6") orchid = Color("#da70d6")
palegoldenrod = Color("#eee8aa") palegoldenrod = Color("#eee8aa")
palegreen = Color("#98fb98") palegreen = Color("#98fb98")
paleturquoise = Color("#afeeee") paleturquoise = Color("#afeeee")
palevioletred = Color("#db7093") palevioletred = Color("#db7093")
papayawhip = Color("#ffefd5") papayawhip = Color("#ffefd5")
peachpuff = Color("#ffdab9") peachpuff = Color("#ffdab9")
peru = Color("#cd853f") peru = Color("#cd853f")
pink = Color("#ffc0cb") pink = Color("#ffc0cb")
plum = Color("#dda0dd") plum = Color("#dda0dd")
powderblue = Color("#b0e0e6") powderblue = Color("#b0e0e6")
purple = Color("#800080") purple = Color("#800080")
rebeccapurple = Color("#663399") rebeccapurple = Color("#663399")
red = Color("#ff0000") red = Color("#ff0000")
rosybrown = Color("#bc8f8f") rosybrown = Color("#bc8f8f")
royalblue = Color("#4169e1") royalblue = Color("#4169e1")
saddlebrown = Color("#8b4513") saddlebrown = Color("#8b4513")
salmon = Color("#fa8072") salmon = Color("#fa8072")
sandybrown = Color("#f4a460") sandybrown = Color("#f4a460")
seagreen = Color("#2e8b57") seagreen = Color("#2e8b57")
seashell = Color("#fff5ee") seashell = Color("#fff5ee")
sienna = Color("#a0522d") sienna = Color("#a0522d")
silver = Color("#c0c0c0") silver = Color("#c0c0c0")
skyblue = Color("#87ceeb") skyblue = Color("#87ceeb")
slateblue = Color("#6a5acd") slateblue = Color("#6a5acd")
slategray = Color("#708090") slategray = Color("#708090")
slategrey = Color("#708090") slategrey = Color("#708090")
snow = Color("#fffafa") snow = Color("#fffafa")
springgreen = Color("#00ff7f") springgreen = Color("#00ff7f")
steelblue = Color("#4682b4") steelblue = Color("#4682b4")
tan = Color("#d2b48c") tan = Color("#d2b48c")
teal = Color("#008080") teal = Color("#008080")
thistle = Color("#d8bfd8") thistle = Color("#d8bfd8")
tomato = Color("#ff6347") tomato = Color("#ff6347")
transparent = Color("#00000000") # not standard but exists in QML transparent = Color("#00000000") # not standard but exists in QML
turquoise = Color("#40e0d0") turquoise = Color("#40e0d0")
violet = Color("#ee82ee") violet = Color("#ee82ee")
wheat = Color("#f5deb3") wheat = Color("#f5deb3")
white = Color("#ffffff") white = Color("#ffffff")
whitesmoke = Color("#f5f5f5") whitesmoke = Color("#f5f5f5")
yellow = Color("#ffff00") yellow = Color("#ffff00")
yellowgreen = Color("#9acd32") yellowgreen = Color("#9acd32")
def hsluva( def hsluva(
hue: float = 0, saturation: float = 0, luv: float = 0, alpha: float = 1, hue: float = 0, saturation: float = 0, luv: float = 0, alpha: float = 1,
) -> Color: ) -> Color:
"""Return a `Color` from `(0-359, 0-100, 0-100, 0-1)` HSLuv arguments.""" """Return a `Color` from `(0-359, 0-100, 0-100, 0-1)` HSLuv arguments."""
return Color().but(hue, saturation, luv, alpha) return Color().but(hue, saturation, luv, alpha)
def hsla( def hsla(
hue: float = 0, saturation: float = 0, light: float = 0, alpha: float = 1, hue: float = 0, saturation: float = 0, light: float = 0, alpha: float = 1,
) -> Color: ) -> Color:
"""Return a `Color` from `(0-359, 0-100, 0-100, 0-1)` HSL arguments.""" """Return a `Color` from `(0-359, 0-100, 0-100, 0-1)` HSL arguments."""
return Color().but(hue, saturation, light=light, alpha=alpha) return Color().but(hue, saturation, light=light, alpha=alpha)
def rgba( def rgba(
red: float = 0, green: float = 0, blue: float = 0, alpha: float = 1, red: float = 0, green: float = 0, blue: float = 0, alpha: float = 1,
) -> Color: ) -> Color:
"""Return a `Color` from `(0-255, 0-255, 0-255, 0-1)` RGB arguments.""" """Return a `Color` from `(0-255, 0-255, 0-255, 0-1)` RGB arguments."""
return Color().but(red=red, green=green, blue=blue, alpha=alpha) return Color().but(red=red, green=green, blue=blue, alpha=alpha)
# Aliases # Aliases

View File

@ -12,117 +12,117 @@ import nio
@dataclass @dataclass
class MatrixError(Exception): class MatrixError(Exception):
"""An error returned by a Matrix server.""" """An error returned by a Matrix server."""
http_code: int = 400 http_code: int = 400
m_code: Optional[str] = None m_code: Optional[str] = None
message: Optional[str] = None message: Optional[str] = None
content: str = "" content: str = ""
@classmethod @classmethod
async def from_nio(cls, response: nio.ErrorResponse) -> "MatrixError": async def from_nio(cls, response: nio.ErrorResponse) -> "MatrixError":
"""Return a `MatrixError` subclass from a nio `ErrorResponse`.""" """Return a `MatrixError` subclass from a nio `ErrorResponse`."""
http_code = response.transport_response.status http_code = response.transport_response.status
m_code = response.status_code m_code = response.status_code
message = response.message message = response.message
content = await response.transport_response.text() content = await response.transport_response.text()
for subcls in cls.__subclasses__(): for subcls in cls.__subclasses__():
if subcls.m_code and subcls.m_code == m_code: if subcls.m_code and subcls.m_code == m_code:
return subcls(http_code, m_code, message, content) return subcls(http_code, m_code, message, content)
# If error doesn't have a M_CODE, look for a generic http error class # If error doesn't have a M_CODE, look for a generic http error class
for subcls in cls.__subclasses__(): for subcls in cls.__subclasses__():
if not subcls.m_code and subcls.http_code == http_code: if not subcls.m_code and subcls.http_code == http_code:
return subcls(http_code, m_code, message, content) return subcls(http_code, m_code, message, content)
return cls(http_code, m_code, message, content) return cls(http_code, m_code, message, content)
@dataclass @dataclass
class MatrixUnrecognized(MatrixError): class MatrixUnrecognized(MatrixError):
http_code: int = 400 http_code: int = 400
m_code: str = "M_UNRECOGNIZED" m_code: str = "M_UNRECOGNIZED"
@dataclass @dataclass
class MatrixInvalidAccessToken(MatrixError): class MatrixInvalidAccessToken(MatrixError):
http_code: int = 401 http_code: int = 401
m_code: str = "M_UNKNOWN_TOKEN" m_code: str = "M_UNKNOWN_TOKEN"
@dataclass @dataclass
class MatrixUnauthorized(MatrixError): class MatrixUnauthorized(MatrixError):
http_code: int = 401 http_code: int = 401
m_code: str = "M_UNAUTHORIZED" m_code: str = "M_UNAUTHORIZED"
@dataclass @dataclass
class MatrixForbidden(MatrixError): class MatrixForbidden(MatrixError):
http_code: int = 403 http_code: int = 403
m_code: str = "M_FORBIDDEN" m_code: str = "M_FORBIDDEN"
@dataclass @dataclass
class MatrixBadJson(MatrixError): class MatrixBadJson(MatrixError):
http_code: int = 403 http_code: int = 403
m_code: str = "M_BAD_JSON" m_code: str = "M_BAD_JSON"
@dataclass @dataclass
class MatrixNotJson(MatrixError): class MatrixNotJson(MatrixError):
http_code: int = 403 http_code: int = 403
m_code: str = "M_NOT_JSON" m_code: str = "M_NOT_JSON"
@dataclass @dataclass
class MatrixUserDeactivated(MatrixError): class MatrixUserDeactivated(MatrixError):
http_code: int = 403 http_code: int = 403
m_code: str = "M_USER_DEACTIVATED" m_code: str = "M_USER_DEACTIVATED"
@dataclass @dataclass
class MatrixNotFound(MatrixError): class MatrixNotFound(MatrixError):
http_code: int = 404 http_code: int = 404
m_code: str = "M_NOT_FOUND" m_code: str = "M_NOT_FOUND"
@dataclass @dataclass
class MatrixTooLarge(MatrixError): class MatrixTooLarge(MatrixError):
http_code: int = 413 http_code: int = 413
m_code: str = "M_TOO_LARGE" m_code: str = "M_TOO_LARGE"
@dataclass @dataclass
class MatrixBadGateway(MatrixError): class MatrixBadGateway(MatrixError):
http_code: int = 502 http_code: int = 502
m_code: Optional[str] = None m_code: Optional[str] = None
# Client errors # Client errors
@dataclass @dataclass
class InvalidUserId(Exception): class InvalidUserId(Exception):
user_id: str = field() user_id: str = field()
@dataclass @dataclass
class InvalidUserInContext(Exception): class InvalidUserInContext(Exception):
user_id: str = field() user_id: str = field()
@dataclass @dataclass
class UserFromOtherServerDisallowed(Exception): class UserFromOtherServerDisallowed(Exception):
user_id: str = field() user_id: str = field()
@dataclass @dataclass
class UneededThumbnail(Exception): class UneededThumbnail(Exception):
pass pass
@dataclass @dataclass
class BadMimeType(Exception): class BadMimeType(Exception):
wanted: str = field() wanted: str = field()
got: str = field() got: str = field()

View File

@ -19,503 +19,503 @@ from .color import SVGColor
def parse_colour(inline, m, state): def parse_colour(inline, m, state):
colour = m.group(1) colour = m.group(1)
text = m.group(2) text = m.group(2)
return "colour", colour, text return "colour", colour, text
def render_html_colour(colour, text): def render_html_colour(colour, text):
return f'<span data-mx-color="{colour}">{text}</span>' return f'<span data-mx-color="{colour}">{text}</span>'
def plugin_matrix(md): def plugin_matrix(md):
# test string: r"<b>(x) <r>(x) \<a>b>(x) <a\>b>(x) <b>(\(z) <c>(foo\)xyz)" # test string: r"<b>(x) <r>(x) \<a>b>(x) <a\>b>(x) <b>(\(z) <c>(foo\)xyz)"
colour = ( colour = (
r"^<(.+?)>" # capture the colour in `<colour>` r"^<(.+?)>" # capture the colour in `<colour>`
r"\((.+?)" # capture text in `(text` r"\((.+?)" # capture text in `(text`
r"(?<!\\)(?:\\\\)*" # ignore the next `)` if it's \escaped r"(?<!\\)(?:\\\\)*" # ignore the next `)` if it's \escaped
r"\)" # finish on a `)` r"\)" # finish on a `)`
) )
# Mark colour as high priority as otherwise e.g. <red>(hi) matches the # Mark colour as high priority as otherwise e.g. <red>(hi) matches the
# inline_html rule instead of the colour rule. # inline_html rule instead of the colour rule.
md.inline.rules.insert(1, "colour") md.inline.rules.insert(1, "colour")
md.inline.register_rule("colour", colour, parse_colour) md.inline.register_rule("colour", colour, parse_colour)
if md.renderer.NAME == "html": if md.renderer.NAME == "html":
md.renderer.register("colour", render_html_colour) md.renderer.register("colour", render_html_colour)
class HTMLProcessor: class HTMLProcessor:
"""Provide HTML filtering and conversion from Markdown. """Provide HTML filtering and conversion from Markdown.
Filtering sanitizes HTML and ensures it complies both with the Matrix Filtering sanitizes HTML and ensures it complies both with the Matrix
specification: specification:
https://matrix.org/docs/spec/client_server/latest#m-room-message-msgtypes https://matrix.org/docs/spec/client_server/latest#m-room-message-msgtypes
and the supported Qt HTML subset for usage in QML: and the supported Qt HTML subset for usage in QML:
https://doc.qt.io/qt-5/richtext-html-subset.html https://doc.qt.io/qt-5/richtext-html-subset.html
Some methods take an `outgoing` argument, specifying if the HTML is Some methods take an `outgoing` argument, specifying if the HTML is
intended to be sent to matrix servers or used locally in our application. intended to be sent to matrix servers or used locally in our application.
For local usage, extra transformations are applied: For local usage, extra transformations are applied:
- Wrap text lines starting with a `>` in `<span>` with a `quote` class. - Wrap text lines starting with a `>` in `<span>` with a `quote` class.
This allows them to be styled appropriately from QML. This allows them to be styled appropriately from QML.
Some methods take an `inline` argument, which return text appropriate Some methods take an `inline` argument, which return text appropriate
for UI elements restricted to display a single line, e.g. the room for UI elements restricted to display a single line, e.g. the room
last message subtitles in QML or notifications. last message subtitles in QML or notifications.
In inline filtered HTML, block tags are stripped or substituted and In inline filtered HTML, block tags are stripped or substituted and
newlines are turned into symbols (U+23CE). newlines are turned into symbols (U+23CE).
""" """
inline_tags = { inline_tags = {
"span", "font", "a", "sup", "sub", "b", "i", "s", "u", "code", "span", "font", "a", "sup", "sub", "b", "i", "s", "u", "code",
"mx-reply", "mx-reply",
} }
block_tags = { block_tags = {
"h1", "h2", "h3", "h4", "h5", "h6", "blockquote", "h1", "h2", "h3", "h4", "h5", "h6", "blockquote",
"p", "ul", "ol", "li", "hr", "br", "img", "p", "ul", "ol", "li", "hr", "br", "img",
"table", "thead", "tbody", "tr", "th", "td", "pre", "table", "thead", "tbody", "tr", "th", "td", "pre",
"mx-reply", "mx-reply",
} }
opaque_id = r"[a-zA-Z\d._-]+?" opaque_id = r"[a-zA-Z\d._-]+?"
user_id_localpart = r"[\x21-\x39\x3B-\x7E]+?" user_id_localpart = r"[\x21-\x39\x3B-\x7E]+?"
user_id_regex = re.compile( user_id_regex = re.compile(
rf"(?P<body>@{user_id_localpart}:(?P<host>[a-zA-Z\d.:-]*[a-zA-Z\d]))", rf"(?P<body>@{user_id_localpart}:(?P<host>[a-zA-Z\d.:-]*[a-zA-Z\d]))",
) )
room_id_regex = re.compile( room_id_regex = re.compile(
rf"(?P<body>!{opaque_id}:(?P<host>[a-zA-Z\d.:-]*[a-zA-Z\d]))", rf"(?P<body>!{opaque_id}:(?P<host>[a-zA-Z\d.:-]*[a-zA-Z\d]))",
) )
room_alias_regex = re.compile( room_alias_regex = re.compile(
r"(?=^|\W)(?P<body>#\S+?:(?P<host>[a-zA-Z\d.:-]*[a-zA-Z\d]))", r"(?=^|\W)(?P<body>#\S+?:(?P<host>[a-zA-Z\d.:-]*[a-zA-Z\d]))",
) )
link_regexes = [re.compile(r, re.IGNORECASE) link_regexes = [re.compile(r, re.IGNORECASE)
if isinstance(r, str) else r for r in [ if isinstance(r, str) else r for r in [
# Normal :// URLs # Normal :// URLs
(r"(?P<body>[a-z\d]+://(?P<host>[a-z\d._-]+(?:\:\d+)?)" (r"(?P<body>[a-z\d]+://(?P<host>[a-z\d._-]+(?:\:\d+)?)"
r"(?:/[/\-.,\w#%&?:;=~!$*+^@']*)?(?:\([/\-_.,a-z\d#%&?;=~]*\))?)"), r"(?:/[/\-.,\w#%&?:;=~!$*+^@']*)?(?:\([/\-_.,a-z\d#%&?;=~]*\))?)"),
# mailto: and tel: # mailto: and tel:
r"mailto:(?P<body>[a-z0-9._-]+@(?P<host>[a-z0-9.:-]*[a-z\d]))", r"mailto:(?P<body>[a-z0-9._-]+@(?P<host>[a-z0-9.:-]*[a-z\d]))",
r"tel:(?P<body>[0-9+-]+)(?P<host>)", r"tel:(?P<body>[0-9+-]+)(?P<host>)",
# magnet: # magnet:
r"(?P<body>magnet:\?xt=urn:[a-z0-9]+:.+)(?P<host>)", r"(?P<body>magnet:\?xt=urn:[a-z0-9]+:.+)(?P<host>)",
user_id_regex, room_id_regex, room_alias_regex, user_id_regex, room_id_regex, room_alias_regex,
]] ]]
matrix_to_regex = re.compile(r"^https?://matrix.to/#/", re.IGNORECASE) matrix_to_regex = re.compile(r"^https?://matrix.to/#/", re.IGNORECASE)
link_is_matrix_to_regex = re.compile( link_is_matrix_to_regex = re.compile(
r"https?://matrix.to/#/\S+", re.IGNORECASE, r"https?://matrix.to/#/\S+", re.IGNORECASE,
) )
link_is_user_id_regex = re.compile( link_is_user_id_regex = re.compile(
r"https?://matrix.to/#/@\S+", re.IGNORECASE, r"https?://matrix.to/#/@\S+", re.IGNORECASE,
) )
link_is_room_id_regex = re.compile( link_is_room_id_regex = re.compile(
r"https?://matrix.to/#/!\S+", re.IGNORECASE, r"https?://matrix.to/#/!\S+", re.IGNORECASE,
) )
link_is_room_alias_regex = re.compile( link_is_room_alias_regex = re.compile(
r"https?://matrix.to/#/#\S+", re.IGNORECASE, r"https?://matrix.to/#/#\S+", re.IGNORECASE,
) )
link_is_message_id_regex = re.compile( link_is_message_id_regex = re.compile(
r"https?://matrix.to/#/[!#]\S+/\$\S+", re.IGNORECASE, r"https?://matrix.to/#/[!#]\S+/\$\S+", re.IGNORECASE,
) )
inline_quote_regex = re.compile(r"(^|⏎|>)(\s*&gt;[^⏎\n]*)", re.MULTILINE) inline_quote_regex = re.compile(r"(^|⏎|>)(\s*&gt;[^⏎\n]*)", re.MULTILINE)
quote_regex = re.compile( quote_regex = re.compile(
r"(^|<span/?>|<p/?>|<br/?>|<h\d/?>|<mx-reply/?>)" r"(^|<span/?>|<p/?>|<br/?>|<h\d/?>|<mx-reply/?>)"
r"(\s*&gt;.*?)" r"(\s*&gt;.*?)"
r"(<span/?>|</?p>|<br/?>|</?h\d>|</mx-reply/?>|$)", r"(<span/?>|</?p>|<br/?>|</?h\d>|</mx-reply/?>|$)",
re.MULTILINE, re.MULTILINE,
) )
extra_newlines_regex = re.compile(r"\n(\n*)") extra_newlines_regex = re.compile(r"\n(\n*)")
def __init__(self) -> None: def __init__(self) -> None:
# The whitespace remover doesn't take <pre> into account # The whitespace remover doesn't take <pre> into account
sanitizer.normalize_overall_whitespace = lambda html, *args, **kw: html sanitizer.normalize_overall_whitespace = lambda html, *args, **kw: html
sanitizer.normalize_whitespace_in_text_or_tail = \ sanitizer.normalize_whitespace_in_text_or_tail = \
lambda el, *args, **kw: el lambda el, *args, **kw: el
# hard_wrap: convert all \n to <br> without required two spaces # hard_wrap: convert all \n to <br> without required two spaces
# escape: escape HTML characters in the input string, e.g. tags # escape: escape HTML characters in the input string, e.g. tags
self._markdown_to_html = mistune.create_markdown( self._markdown_to_html = mistune.create_markdown(
hard_wrap = True, hard_wrap = True,
escape = True, escape = True,
renderer = "html", renderer = "html",
plugins = ['strikethrough', plugin_matrix], plugins = ['strikethrough', plugin_matrix],
) )
def mentions_in_html(self, html: str) -> List[Tuple[str, str]]: def mentions_in_html(self, html: str) -> List[Tuple[str, str]]:
"""Return list of (text, href) tuples for all mention links in html.""" """Return list of (text, href) tuples for all mention links in html."""
if not html.strip(): if not html.strip():
return [] return []
return [ return [
(a_tag.text, href) (a_tag.text, href)
for a_tag, _, href, _ in lxml.html.iterlinks(html) for a_tag, _, href, _ in lxml.html.iterlinks(html)
if a_tag.text and if a_tag.text and
self.link_is_matrix_to_regex.match(unquote(href.strip())) self.link_is_matrix_to_regex.match(unquote(href.strip()))
] ]
def from_markdown( def from_markdown(
self, self,
text: str, text: str,
inline: bool = False, inline: bool = False,
outgoing: bool = False, outgoing: bool = False,
display_name_mentions: Optional[Dict[str, str]] = None, display_name_mentions: Optional[Dict[str, str]] = None,
) -> str: ) -> str:
"""Return filtered HTML from Markdown text.""" """Return filtered HTML from Markdown text."""
return self.filter( return self.filter(
self._markdown_to_html(text), self._markdown_to_html(text),
inline, inline,
outgoing, outgoing,
display_name_mentions, display_name_mentions,
) )
def filter( def filter(
self, self,
html: str, html: str,
inline: bool = False, inline: bool = False,
outgoing: bool = False, outgoing: bool = False,
display_name_mentions: Optional[Dict[str, str]] = None, display_name_mentions: Optional[Dict[str, str]] = None,
) -> str: ) -> str:
"""Filter and return HTML.""" """Filter and return HTML."""
mentions = display_name_mentions mentions = display_name_mentions
sanit = Sanitizer(self.sanitize_settings(inline, outgoing, mentions)) sanit = Sanitizer(self.sanitize_settings(inline, outgoing, mentions))
html = sanit.sanitize(html).rstrip("\n") html = sanit.sanitize(html).rstrip("\n")
if not html.strip(): if not html.strip():
return html return html
tree = etree.fromstring( tree = etree.fromstring(
html, parser=etree.HTMLParser(encoding="utf-8"), html, parser=etree.HTMLParser(encoding="utf-8"),
) )
for a_tag in tree.iterdescendants("a"): for a_tag in tree.iterdescendants("a"):
self._mentions_to_matrix_to_links(a_tag, mentions, outgoing) self._mentions_to_matrix_to_links(a_tag, mentions, outgoing)
if not outgoing: if not outgoing:
self._matrix_to_links_add_classes(a_tag) self._matrix_to_links_add_classes(a_tag)
html = etree.tostring(tree, encoding="utf-8", method="html").decode() html = etree.tostring(tree, encoding="utf-8", method="html").decode()
html = sanit.sanitize(html).rstrip("\n") html = sanit.sanitize(html).rstrip("\n")
if outgoing: if outgoing:
return html return html
# Client-side modifications # Client-side modifications
html = self.quote_regex.sub(r'\1<span class="quote">\2</span>\3', html) html = self.quote_regex.sub(r'\1<span class="quote">\2</span>\3', html)
if not inline: if not inline:
return html return html
return self.inline_quote_regex.sub( return self.inline_quote_regex.sub(
r'\1<span class="quote">\2</span>', html, r'\1<span class="quote">\2</span>', html,
) )
def sanitize_settings( def sanitize_settings(
self, self,
inline: bool = False, inline: bool = False,
outgoing: bool = False, outgoing: bool = False,
display_name_mentions: Optional[Dict[str, str]] = None, display_name_mentions: Optional[Dict[str, str]] = None,
) -> dict: ) -> dict:
"""Return an html_sanitizer configuration.""" """Return an html_sanitizer configuration."""
# https://matrix.org/docs/spec/client_server/latest#m-room-message-msgtypes # https://matrix.org/docs/spec/client_server/latest#m-room-message-msgtypes
inline_tags = self.inline_tags inline_tags = self.inline_tags
all_tags = inline_tags | self.block_tags all_tags = inline_tags | self.block_tags
inlines_attributes = { inlines_attributes = {
"font": {"color"}, "font": {"color"},
"a": {"href", "class", "data-mention"}, "a": {"href", "class", "data-mention"},
"code": {"class"}, "code": {"class"},
} }
attributes = {**inlines_attributes, **{ attributes = {**inlines_attributes, **{
"ol": {"start"}, "ol": {"start"},
"hr": {"width"}, "hr": {"width"},
"span": {"data-mx-color"}, "span": {"data-mx-color"},
"img": { "img": {
"data-mx-emote", "src", "alt", "title", "width", "height", "data-mx-emote", "src", "alt", "title", "width", "height",
}, },
}} }}
username_link_regexes = [re.compile(r) for r in [ username_link_regexes = [re.compile(r) for r in [
rf"(?<!\w)(?P<body>{re.escape(name or user_id)})(?!\w)(?P<host>)" rf"(?<!\w)(?P<body>{re.escape(name or user_id)})(?!\w)(?P<host>)"
for user_id, name in (display_name_mentions or {}).items() for user_id, name in (display_name_mentions or {}).items()
]] ]]
return { return {
"tags": inline_tags if inline else all_tags, "tags": inline_tags if inline else all_tags,
"attributes": inlines_attributes if inline else attributes, "attributes": inlines_attributes if inline else attributes,
"empty": {} if inline else {"hr", "br", "img"}, "empty": {} if inline else {"hr", "br", "img"},
"separate": {"a"} if inline else { "separate": {"a"} if inline else {
"a", "p", "li", "table", "tr", "th", "td", "br", "hr", "img", "a", "p", "li", "table", "tr", "th", "td", "br", "hr", "img",
}, },
"whitespace": {}, "whitespace": {},
"keep_typographic_whitespace": True, "keep_typographic_whitespace": True,
"add_nofollow": False, "add_nofollow": False,
"autolink": { "autolink": {
"link_regexes": "link_regexes":
self.link_regexes + username_link_regexes, # type: ignore self.link_regexes + username_link_regexes, # type: ignore
"avoid_hosts": [], "avoid_hosts": [],
}, },
"sanitize_href": lambda href: href, "sanitize_href": lambda href: href,
"element_preprocessors": [ "element_preprocessors": [
sanitizer.bold_span_to_strong, sanitizer.bold_span_to_strong,
sanitizer.italic_span_to_em, sanitizer.italic_span_to_em,
sanitizer.tag_replacer("strong", "b"), sanitizer.tag_replacer("strong", "b"),
sanitizer.tag_replacer("em", "i"), sanitizer.tag_replacer("em", "i"),
sanitizer.tag_replacer("strike", "s"), sanitizer.tag_replacer("strike", "s"),
sanitizer.tag_replacer("del", "s"), sanitizer.tag_replacer("del", "s"),
sanitizer.tag_replacer("form", "p"), sanitizer.tag_replacer("form", "p"),
sanitizer.tag_replacer("div", "p"), sanitizer.tag_replacer("div", "p"),
sanitizer.tag_replacer("caption", "p"), sanitizer.tag_replacer("caption", "p"),
sanitizer.target_blank_noopener, sanitizer.target_blank_noopener,
self._span_color_to_font if not outgoing else lambda el: el, self._span_color_to_font if not outgoing else lambda el: el,
self._img_to_a, self._img_to_a,
self._remove_extra_newlines, self._remove_extra_newlines,
self._newlines_to_return_symbol if inline else lambda el: el, self._newlines_to_return_symbol if inline else lambda el: el,
self._reply_to_inline if inline else lambda el: el, self._reply_to_inline if inline else lambda el: el,
], ],
"element_postprocessors": [ "element_postprocessors": [
self._font_color_to_span if outgoing else lambda el: el, self._font_color_to_span if outgoing else lambda el: el,
self._hr_to_dashes if not outgoing else lambda el: el, self._hr_to_dashes if not outgoing else lambda el: el,
], ],
"is_mergeable": lambda e1, e2: e1.attrib == e2.attrib, "is_mergeable": lambda e1, e2: e1.attrib == e2.attrib,
} }
@staticmethod @staticmethod
def _span_color_to_font(el: HtmlElement) -> HtmlElement: def _span_color_to_font(el: HtmlElement) -> HtmlElement:
"""Convert HTML `<span data-mx-color=...` to `<font color=...>`.""" """Convert HTML `<span data-mx-color=...` to `<font color=...>`."""
if el.tag not in ("span", "font"): if el.tag not in ("span", "font"):
return el return el
color = el.attrib.pop("data-mx-color", None) color = el.attrib.pop("data-mx-color", None)
if color: if color:
el.tag = "font" el.tag = "font"
el.attrib["color"] = color el.attrib["color"] = color
return el return el
@staticmethod @staticmethod
def _font_color_to_span(el: HtmlElement) -> HtmlElement: def _font_color_to_span(el: HtmlElement) -> HtmlElement:
"""Convert HTML `<font color=...>` to `<span data-mx-color=...`.""" """Convert HTML `<font color=...>` to `<span data-mx-color=...`."""
if el.tag not in ("span", "font"): if el.tag not in ("span", "font"):
return el return el
color = el.attrib.pop("color", None) color = el.attrib.pop("color", None)
if color: if color:
el.tag = "span" el.tag = "span"
el.attrib["data-mx-color"] = color el.attrib["data-mx-color"] = color
return el return el
@staticmethod @staticmethod
def _img_to_a(el: HtmlElement) -> HtmlElement: def _img_to_a(el: HtmlElement) -> HtmlElement:
"""Linkify images by wrapping `<img>` tags in `<a>`.""" """Linkify images by wrapping `<img>` tags in `<a>`."""
if el.tag != "img": if el.tag != "img":
return el return el
src = el.attrib.get("src", "") src = el.attrib.get("src", "")
width = el.attrib.get("width") width = el.attrib.get("width")
height = el.attrib.get("height") height = el.attrib.get("height")
is_emote = "data-mx-emote" in el.attrib is_emote = "data-mx-emote" in el.attrib
if src.startswith("mxc://"): if src.startswith("mxc://"):
el.attrib["src"] = nio.Api.mxc_to_http(src) el.attrib["src"] = nio.Api.mxc_to_http(src)
if is_emote and not width and not height: if is_emote and not width and not height:
el.attrib["width"] = 32 el.attrib["width"] = 32
el.attrib["height"] = 32 el.attrib["height"] = 32
elif is_emote and width and not height: elif is_emote and width and not height:
el.attrib["height"] = width el.attrib["height"] = width
elif is_emote and height and not width: elif is_emote and height and not width:
el.attrib["width"] = height el.attrib["width"] = height
elif not is_emote and (not width or not height): elif not is_emote and (not width or not height):
el.tag = "a" el.tag = "a"
el.attrib["href"] = el.attrib.pop("src", "") el.attrib["href"] = el.attrib.pop("src", "")
el.text = el.attrib.pop("alt", None) or el.attrib["href"] el.text = el.attrib.pop("alt", None) or el.attrib["href"]
return el return el
def _remove_extra_newlines(self, el: HtmlElement) -> HtmlElement: def _remove_extra_newlines(self, el: HtmlElement) -> HtmlElement:
r"""Remove excess `\n` characters from HTML elements. r"""Remove excess `\n` characters from HTML elements.
This is done to avoid additional blank lines when the CSS directive This is done to avoid additional blank lines when the CSS directive
`white-space: pre` is used. `white-space: pre` is used.
Text inside `<pre>` tags is ignored, except for the final newlines. Text inside `<pre>` tags is ignored, except for the final newlines.
""" """
pre_parent = any(parent.tag == "pre" for parent in el.iterancestors()) pre_parent = any(parent.tag == "pre" for parent in el.iterancestors())
if el.tag != "pre" and not pre_parent: if el.tag != "pre" and not pre_parent:
if el.text: if el.text:
el.text = self.extra_newlines_regex.sub(r"\1", el.text) el.text = self.extra_newlines_regex.sub(r"\1", el.text)
if el.tail: if el.tail:
el.tail = self.extra_newlines_regex.sub(r"\1", el.tail) el.tail = self.extra_newlines_regex.sub(r"\1", el.tail)
else: else:
if el.text and el.text.endswith("\n"): if el.text and el.text.endswith("\n"):
el.text = el.text[:-1] el.text = el.text[:-1]
if el.tail and el.tail.endswith("\n"): if el.tail and el.tail.endswith("\n"):
el.tail = el.tail[:-1] el.tail = el.tail[:-1]
return el return el
def _newlines_to_return_symbol(self, el: HtmlElement) -> HtmlElement: def _newlines_to_return_symbol(self, el: HtmlElement) -> HtmlElement:
"""Turn newlines into unicode return symbols (⏎, U+23CE). """Turn newlines into unicode return symbols (⏎, U+23CE).
The symbol is added to blocks with siblings (e.g. a `<p>` followed by The symbol is added to blocks with siblings (e.g. a `<p>` followed by
another `<p>`) and `<br>` tags. another `<p>`) and `<br>` tags.
The `<br>` themselves will be removed by the inline sanitizer. The `<br>` themselves will be removed by the inline sanitizer.
""" """
is_block_with_siblings = (el.tag in self.block_tags and is_block_with_siblings = (el.tag in self.block_tags and
next(el.itersiblings(), None) is not None) next(el.itersiblings(), None) is not None)
if el.tag == "br" or is_block_with_siblings: if el.tag == "br" or is_block_with_siblings:
el.tail = f"{el.tail or ''}" el.tail = f"{el.tail or ''}"
# Replace left \n in text/tail of <pre> content by the return symbol. # Replace left \n in text/tail of <pre> content by the return symbol.
if el.text: if el.text:
el.text = re.sub(r"\n", r"", el.text) el.text = re.sub(r"\n", r"", el.text)
if el.tail: if el.tail:
el.tail = re.sub(r"\n", r"", el.tail) el.tail = re.sub(r"\n", r"", el.tail)
return el return el
def _reply_to_inline(self, el: HtmlElement) -> HtmlElement: def _reply_to_inline(self, el: HtmlElement) -> HtmlElement:
"""Shorten <mx-reply> to only include the replied to event's sender.""" """Shorten <mx-reply> to only include the replied to event's sender."""
if el.tag != "mx-reply": if el.tag != "mx-reply":
return el return el
try: try:
user_id = el.find("blockquote").findall("a")[1].text user_id = el.find("blockquote").findall("a")[1].text
text = f"{user_id[1: ].split(':')[0]}: " # U+21A9 arrow text = f"{user_id[1: ].split(':')[0]}: " # U+21A9 arrow
tail = el.tail.rstrip().rstrip("") tail = el.tail.rstrip().rstrip("")
except (AttributeError, IndexError): except (AttributeError, IndexError):
return el return el
el.clear() el.clear()
el.text = text el.text = text
el.tail = tail el.tail = tail
return el return el
def _mentions_to_matrix_to_links( def _mentions_to_matrix_to_links(
self, self,
el: HtmlElement, el: HtmlElement,
display_name_mentions: Optional[Dict[str, str]] = None, display_name_mentions: Optional[Dict[str, str]] = None,
outgoing: bool = False, outgoing: bool = False,
) -> HtmlElement: ) -> HtmlElement:
"""Turn user ID, usernames and room ID/aliases into matrix.to URL. """Turn user ID, usernames and room ID/aliases into matrix.to URL.
After the HTML sanitizer autolinks these, the links's hrefs are the After the HTML sanitizer autolinks these, the links's hrefs are the
link text, e.g. `<a href="@foo:bar.com">@foo:bar.com</a>`. link text, e.g. `<a href="@foo:bar.com">@foo:bar.com</a>`.
We turn them into proper matrix.to URL in this function. We turn them into proper matrix.to URL in this function.
""" """
if el.tag != "a" or not el.attrib.get("href"): if el.tag != "a" or not el.attrib.get("href"):
return el return el
id_regexes = ( id_regexes = (
self.user_id_regex, self.room_id_regex, self.room_alias_regex, self.user_id_regex, self.room_id_regex, self.room_alias_regex,
) )
for regex in id_regexes: for regex in id_regexes:
if regex.match(unquote(el.attrib["href"])): if regex.match(unquote(el.attrib["href"])):
el.attrib["href"] = f"https://matrix.to/#/{el.attrib['href']}" el.attrib["href"] = f"https://matrix.to/#/{el.attrib['href']}"
return el return el
for user_id, name in (display_name_mentions or {}).items(): for user_id, name in (display_name_mentions or {}).items():
if unquote(el.attrib["href"]) == (name or user_id): if unquote(el.attrib["href"]) == (name or user_id):
el.attrib["href"] = f"https://matrix.to/#/{user_id}" el.attrib["href"] = f"https://matrix.to/#/{user_id}"
return el return el
return el return el
def _matrix_to_links_add_classes(self, el: HtmlElement) -> HtmlElement: def _matrix_to_links_add_classes(self, el: HtmlElement) -> HtmlElement:
"""Add special CSS classes to matrix.to mention links.""" """Add special CSS classes to matrix.to mention links."""
href = unquote(el.attrib.get("href", "")) href = unquote(el.attrib.get("href", ""))
if not href or not el.text: if not href or not el.text:
return el return el
el.text = self.matrix_to_regex.sub("", el.text or "") el.text = self.matrix_to_regex.sub("", el.text or "")
# This must be first, or link will be mistaken by room ID/alias regex # This must be first, or link will be mistaken by room ID/alias regex
if self.link_is_message_id_regex.match(href): if self.link_is_message_id_regex.match(href):
el.attrib["class"] = "mention message-id-mention" el.attrib["class"] = "mention message-id-mention"
el.attrib["data-mention"] = el.text.strip() el.attrib["data-mention"] = el.text.strip()
elif self.link_is_user_id_regex.match(href): elif self.link_is_user_id_regex.match(href):
if el.text.strip().startswith("@"): if el.text.strip().startswith("@"):
el.attrib["class"] = "mention user-id-mention" el.attrib["class"] = "mention user-id-mention"
else: else:
el.attrib["class"] = "mention username-mention" el.attrib["class"] = "mention username-mention"
el.attrib["data-mention"] = el.text.strip() el.attrib["data-mention"] = el.text.strip()
elif self.link_is_room_id_regex.match(href): elif self.link_is_room_id_regex.match(href):
el.attrib["class"] = "mention room-id-mention" el.attrib["class"] = "mention room-id-mention"
el.attrib["data-mention"] = el.text.strip() el.attrib["data-mention"] = el.text.strip()
elif self.link_is_room_alias_regex.match(href): elif self.link_is_room_alias_regex.match(href):
el.attrib["class"] = "mention room-alias-mention" el.attrib["class"] = "mention room-alias-mention"
el.attrib["data-mention"] = el.text.strip() el.attrib["data-mention"] = el.text.strip()
return el return el
def _hr_to_dashes(self, el: HtmlElement) -> HtmlElement: def _hr_to_dashes(self, el: HtmlElement) -> HtmlElement:
if el.tag != "hr": if el.tag != "hr":
return el return el
el.tag = "p" el.tag = "p"
el.attrib["class"] = "ruler" el.attrib["class"] = "ruler"
el.text = "" * 19 el.text = "" * 19
return el return el
HTML_PROCESSOR = HTMLProcessor() HTML_PROCESSOR = HTMLProcessor()

File diff suppressed because it is too large Load Diff

View File

@ -24,354 +24,354 @@ from .models.model import Model
from .utils import Size, atomic_write, current_task from .utils import Size, atomic_write, current_task
if TYPE_CHECKING: if TYPE_CHECKING:
from .backend import Backend from .backend import Backend
if sys.version_info < (3, 8): if sys.version_info < (3, 8):
import pyfastcopy # noqa import pyfastcopy # noqa
CONCURRENT_DOWNLOADS_LIMIT = asyncio.BoundedSemaphore(8) CONCURRENT_DOWNLOADS_LIMIT = asyncio.BoundedSemaphore(8)
ACCESS_LOCKS: DefaultDict[str, asyncio.Lock] = DefaultDict(asyncio.Lock) ACCESS_LOCKS: DefaultDict[str, asyncio.Lock] = DefaultDict(asyncio.Lock)
@dataclass @dataclass
class MediaCache: class MediaCache:
"""Matrix downloaded media cache.""" """Matrix downloaded media cache."""
backend: "Backend" = field() backend: "Backend" = field()
base_dir: Path = field() base_dir: Path = field()
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.thumbs_dir = self.base_dir / "thumbnails" self.thumbs_dir = self.base_dir / "thumbnails"
self.downloads_dir = self.base_dir / "downloads" self.downloads_dir = self.base_dir / "downloads"
self.thumbs_dir.mkdir(parents=True, exist_ok=True) self.thumbs_dir.mkdir(parents=True, exist_ok=True)
self.downloads_dir.mkdir(parents=True, exist_ok=True) self.downloads_dir.mkdir(parents=True, exist_ok=True)
async def get_media(self, *args) -> Path: async def get_media(self, *args) -> Path:
"""Return `Media(self, ...).get()`'s result. Intended for QML.""" """Return `Media(self, ...).get()`'s result. Intended for QML."""
return await Media(self, *args).get() return await Media(self, *args).get()
async def get_thumbnail(self, width: float, height: float, *args) -> Path: async def get_thumbnail(self, width: float, height: float, *args) -> Path:
"""Return `Thumbnail(self, ...).get()`'s result. Intended for QML.""" """Return `Thumbnail(self, ...).get()`'s result. Intended for QML."""
# QML sometimes pass float sizes, which matrix API doesn't like. # QML sometimes pass float sizes, which matrix API doesn't like.
size = (round(width), round(height)) size = (round(width), round(height))
return await Thumbnail( return await Thumbnail(
self, *args, wanted_size=size, # type: ignore self, *args, wanted_size=size, # type: ignore
).get() ).get()
@dataclass @dataclass
class Media: class Media:
"""A matrix media file that is downloaded or has yet to be. """A matrix media file that is downloaded or has yet to be.
If the `room_id` is not set, no `Transfer` model item will be registered If the `room_id` is not set, no `Transfer` model item will be registered
while this media is being downloaded. while this media is being downloaded.
""" """
cache: "MediaCache" = field() cache: "MediaCache" = field()
client_user_id: str = field() client_user_id: str = field()
mxc: str = field() mxc: str = field()
title: str = field() title: str = field()
room_id: Optional[str] = None room_id: Optional[str] = None
filesize: Optional[int] = None filesize: Optional[int] = None
crypt_dict: Optional[Dict[str, Any]] = field(default=None, repr=False) crypt_dict: Optional[Dict[str, Any]] = field(default=None, repr=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.mxc = re.sub(r"#auto$", "", self.mxc) self.mxc = re.sub(r"#auto$", "", self.mxc)
if not re.match(r"^mxc://.+/.+", self.mxc): if not re.match(r"^mxc://.+/.+", self.mxc):
raise ValueError(f"Invalid mxc URI: {self.mxc}") raise ValueError(f"Invalid mxc URI: {self.mxc}")
@property @property
def local_path(self) -> Path: def local_path(self) -> Path:
"""The path where the file either exists or should be downloaded. """The path where the file either exists or should be downloaded.
The returned paths are in this form: The returned paths are in this form:
``` ```
<base download folder>/<homeserver domain>/ <base download folder>/<homeserver domain>/
<file title>_<mxc id>.<file extension>` <file title>_<mxc id>.<file extension>`
``` ```
e.g. `~/.cache/moment/downloads/matrix.org/foo_Hm24ar11i768b0el.png`. e.g. `~/.cache/moment/downloads/matrix.org/foo_Hm24ar11i768b0el.png`.
""" """
parsed = urlparse(self.mxc) parsed = urlparse(self.mxc)
mxc_id = parsed.path.lstrip("/") mxc_id = parsed.path.lstrip("/")
title = Path(self.title) title = Path(self.title)
filename = f"{title.stem}_{mxc_id}{title.suffix}" filename = f"{title.stem}_{mxc_id}{title.suffix}"
return self.cache.downloads_dir / parsed.netloc / filename return self.cache.downloads_dir / parsed.netloc / filename
async def get(self) -> Path: async def get(self) -> Path:
"""Return the cached file's path, downloading it first if needed.""" """Return the cached file's path, downloading it first if needed."""
async with ACCESS_LOCKS[self.mxc]: async with ACCESS_LOCKS[self.mxc]:
try: try:
return await self.get_local() return await self.get_local()
except FileNotFoundError: except FileNotFoundError:
return await self.create() return await self.create()
async def get_local(self) -> Path: async def get_local(self) -> Path:
"""Return a cached local existing path for this media or raise.""" """Return a cached local existing path for this media or raise."""
if not self.local_path.exists(): if not self.local_path.exists():
raise FileNotFoundError() raise FileNotFoundError()
return self.local_path return self.local_path
async def create(self) -> Path: async def create(self) -> Path:
"""Download and cache the media file to disk.""" """Download and cache the media file to disk."""
async with CONCURRENT_DOWNLOADS_LIMIT: async with CONCURRENT_DOWNLOADS_LIMIT:
data = await self._get_remote_data() data = await self._get_remote_data()
self.local_path.parent.mkdir(parents=True, exist_ok=True) self.local_path.parent.mkdir(parents=True, exist_ok=True)
async with atomic_write(self.local_path, binary=True) as (file, done): async with atomic_write(self.local_path, binary=True) as (file, done):
await file.write(data) await file.write(data)
done() done()
if type(self) is Media: if type(self) is Media:
for event in self.cache.backend.mxc_events[self.mxc]: for event in self.cache.backend.mxc_events[self.mxc]:
event.media_local_path = self.local_path event.media_local_path = self.local_path
return self.local_path return self.local_path
async def _get_remote_data(self) -> bytes: async def _get_remote_data(self) -> bytes:
"""Return the file's data from the matrix server, decrypt if needed.""" """Return the file's data from the matrix server, decrypt if needed."""
client = self.cache.backend.clients[self.client_user_id] client = self.cache.backend.clients[self.client_user_id]
transfer: Optional[Transfer] = None transfer: Optional[Transfer] = None
model: Optional[Model] = None model: Optional[Model] = None
if self.room_id: if self.room_id:
model = self.cache.backend.models[self.room_id, "transfers"] model = self.cache.backend.models[self.room_id, "transfers"]
transfer = Transfer( transfer = Transfer(
id = uuid4(), id = uuid4(),
is_upload = False, is_upload = False,
filepath = self.local_path, filepath = self.local_path,
total_size = self.filesize or 0, total_size = self.filesize or 0,
status = TransferStatus.Transfering, status = TransferStatus.Transfering,
) )
assert model is not None assert model is not None
client.transfer_tasks[transfer.id] = current_task() # type: ignore client.transfer_tasks[transfer.id] = current_task() # type: ignore
model[str(transfer.id)] = transfer model[str(transfer.id)] = transfer
try: try:
parsed = urlparse(self.mxc) parsed = urlparse(self.mxc)
resp = await client.download( resp = await client.download(
server_name = parsed.netloc, server_name = parsed.netloc,
media_id = parsed.path.lstrip("/"), media_id = parsed.path.lstrip("/"),
) )
except (nio.TransferCancelledError, asyncio.CancelledError): except (nio.TransferCancelledError, asyncio.CancelledError):
if transfer and model: if transfer and model:
del model[str(transfer.id)] del model[str(transfer.id)]
del client.transfer_tasks[transfer.id] del client.transfer_tasks[transfer.id]
raise raise
if transfer and model: if transfer and model:
del model[str(transfer.id)] del model[str(transfer.id)]
del client.transfer_tasks[transfer.id] del client.transfer_tasks[transfer.id]
return await self._decrypt(resp.body) return await self._decrypt(resp.body)
async def _decrypt(self, data: bytes) -> bytes: async def _decrypt(self, data: bytes) -> bytes:
"""Decrypt an encrypted file's data.""" """Decrypt an encrypted file's data."""
if not self.crypt_dict: if not self.crypt_dict:
return data return data
func = functools.partial( func = functools.partial(
nio.crypto.attachments.decrypt_attachment, nio.crypto.attachments.decrypt_attachment,
data, data,
self.crypt_dict["key"]["k"], self.crypt_dict["key"]["k"],
self.crypt_dict["hashes"]["sha256"], self.crypt_dict["hashes"]["sha256"],
self.crypt_dict["iv"], self.crypt_dict["iv"],
) )
# Run in a separate thread # Run in a separate thread
return await asyncio.get_event_loop().run_in_executor(None, func) return await asyncio.get_event_loop().run_in_executor(None, func)
@classmethod @classmethod
async def from_existing_file( async def from_existing_file(
cls, cls,
cache: "MediaCache", cache: "MediaCache",
client_user_id: str, client_user_id: str,
mxc: str, mxc: str,
existing: Path, existing: Path,
overwrite: bool = False, overwrite: bool = False,
**kwargs, **kwargs,
) -> "Media": ) -> "Media":
"""Copy an existing file to cache and return a `Media` for it.""" """Copy an existing file to cache and return a `Media` for it."""
media = cls( media = cls(
cache = cache, cache = cache,
client_user_id = client_user_id, client_user_id = client_user_id,
mxc = mxc, mxc = mxc,
title = existing.name, title = existing.name,
filesize = existing.stat().st_size, filesize = existing.stat().st_size,
**kwargs, **kwargs,
) )
media.local_path.parent.mkdir(parents=True, exist_ok=True) media.local_path.parent.mkdir(parents=True, exist_ok=True)
if not media.local_path.exists() or overwrite: if not media.local_path.exists() or overwrite:
func = functools.partial(shutil.copy, existing, media.local_path) func = functools.partial(shutil.copy, existing, media.local_path)
await asyncio.get_event_loop().run_in_executor(None, func) await asyncio.get_event_loop().run_in_executor(None, func)
return media return media
@classmethod @classmethod
async def from_bytes( async def from_bytes(
cls, cls,
cache: "MediaCache", cache: "MediaCache",
client_user_id: str, client_user_id: str,
mxc: str, mxc: str,
filename: str, filename: str,
data: bytes, data: bytes,
overwrite: bool = False, overwrite: bool = False,
**kwargs, **kwargs,
) -> "Media": ) -> "Media":
"""Create a cached file from bytes data and return a `Media` for it.""" """Create a cached file from bytes data and return a `Media` for it."""
media = cls( media = cls(
cache, client_user_id, mxc, filename, filesize=len(data), **kwargs, cache, client_user_id, mxc, filename, filesize=len(data), **kwargs,
) )
media.local_path.parent.mkdir(parents=True, exist_ok=True) media.local_path.parent.mkdir(parents=True, exist_ok=True)
if not media.local_path.exists() or overwrite: if not media.local_path.exists() or overwrite:
path = media.local_path path = media.local_path
async with atomic_write(path, binary=True) as (file, done): async with atomic_write(path, binary=True) as (file, done):
await file.write(data) await file.write(data)
done() done()
return media return media
@dataclass @dataclass
class Thumbnail(Media): class Thumbnail(Media):
"""A matrix media's thumbnail, which is downloaded or has yet to be.""" """A matrix media's thumbnail, which is downloaded or has yet to be."""
wanted_size: Size = (800, 600) wanted_size: Size = (800, 600)
server_size: Optional[Size] = field(init=False, repr=False, default=None) server_size: Optional[Size] = field(init=False, repr=False, default=None)
@staticmethod @staticmethod
def normalize_size(size: Size) -> Size: def normalize_size(size: Size) -> Size:
"""Return standard `(width, height)` matrix thumbnail dimensions. """Return standard `(width, height)` matrix thumbnail dimensions.
The Matrix specification defines a few standard thumbnail dimensions The Matrix specification defines a few standard thumbnail dimensions
for homeservers to store and return: 32x32, 96x96, 320x240, 640x480, for homeservers to store and return: 32x32, 96x96, 320x240, 640x480,
and 800x600. and 800x600.
This method returns the best matching size for a `size` without This method returns the best matching size for a `size` without
upscaling, e.g. passing `(641, 480)` will return `(800, 600)`. upscaling, e.g. passing `(641, 480)` will return `(800, 600)`.
""" """
if size[0] > 640 or size[1] > 480: if size[0] > 640 or size[1] > 480:
return (800, 600) return (800, 600)
if size[0] > 320 or size[1] > 240: if size[0] > 320 or size[1] > 240:
return (640, 480) return (640, 480)
if size[0] > 96 or size[1] > 96: if size[0] > 96 or size[1] > 96:
return (320, 240) return (320, 240)
if size[0] > 32 or size[1] > 32: if size[0] > 32 or size[1] > 32:
return (96, 96) return (96, 96)
return (32, 32) return (32, 32)
@property @property
def local_path(self) -> Path: def local_path(self) -> Path:
"""The path where the thumbnail either exists or should be downloaded. """The path where the thumbnail either exists or should be downloaded.
The returned paths are in this form: The returned paths are in this form:
``` ```
<base thumbnail folder>/<homeserver domain>/<standard size>/ <base thumbnail folder>/<homeserver domain>/<standard size>/
<file title>_<mxc id>.<file extension>` <file title>_<mxc id>.<file extension>`
``` ```
e.g. e.g.
`~/.cache/moment/thumbnails/matrix.org/32x32/foo_Hm24ar11i768b0el.png`. `~/.cache/moment/thumbnails/matrix.org/32x32/foo_Hm24ar11i768b0el.png`.
""" """
size = self.normalize_size(self.server_size or self.wanted_size) size = self.normalize_size(self.server_size or self.wanted_size)
size_dir = f"{size[0]}x{size[1]}" size_dir = f"{size[0]}x{size[1]}"
parsed = urlparse(self.mxc) parsed = urlparse(self.mxc)
mxc_id = parsed.path.lstrip("/") mxc_id = parsed.path.lstrip("/")
title = Path(self.title) title = Path(self.title)
filename = f"{title.stem}_{mxc_id}{title.suffix}" filename = f"{title.stem}_{mxc_id}{title.suffix}"
return self.cache.thumbs_dir / parsed.netloc / size_dir / filename return self.cache.thumbs_dir / parsed.netloc / size_dir / filename
async def get_local(self) -> Path: async def get_local(self) -> Path:
"""Return an existing thumbnail path or raise `FileNotFoundError`. """Return an existing thumbnail path or raise `FileNotFoundError`.
If we have a bigger size thumbnail downloaded than the `wanted_size` If we have a bigger size thumbnail downloaded than the `wanted_size`
for the media, return it instead of asking the server for a for the media, return it instead of asking the server for a
smaller thumbnail. smaller thumbnail.
""" """
if self.local_path.exists(): if self.local_path.exists():
return self.local_path return self.local_path
try_sizes = ((32, 32), (96, 96), (320, 240), (640, 480), (800, 600)) try_sizes = ((32, 32), (96, 96), (320, 240), (640, 480), (800, 600))
parts = list(self.local_path.parts) parts = list(self.local_path.parts)
size = self.normalize_size(self.server_size or self.wanted_size) size = self.normalize_size(self.server_size or self.wanted_size)
for width, height in try_sizes: for width, height in try_sizes:
if width < size[0] or height < size[1]: if width < size[0] or height < size[1]:
continue continue
parts[-2] = f"{width}x{height}" parts[-2] = f"{width}x{height}"
path = Path("/".join(parts)) path = Path("/".join(parts))
if path.exists(): if path.exists():
return path return path
raise FileNotFoundError() raise FileNotFoundError()
async def _get_remote_data(self) -> bytes: async def _get_remote_data(self) -> bytes:
"""Return the (decrypted) media file's content from the server.""" """Return the (decrypted) media file's content from the server."""
parsed = urlparse(self.mxc) parsed = urlparse(self.mxc)
client = self.cache.backend.clients[self.client_user_id] client = self.cache.backend.clients[self.client_user_id]
if self.crypt_dict: if self.crypt_dict:
# Matrix makes encrypted thumbs only available through the download # Matrix makes encrypted thumbs only available through the download
# end-point, not the thumbnail one # end-point, not the thumbnail one
resp = await client.download( resp = await client.download(
server_name = parsed.netloc, server_name = parsed.netloc,
media_id = parsed.path.lstrip("/"), media_id = parsed.path.lstrip("/"),
) )
else: else:
resp = await client.thumbnail( resp = await client.thumbnail(
server_name = parsed.netloc, server_name = parsed.netloc,
media_id = parsed.path.lstrip("/"), media_id = parsed.path.lstrip("/"),
width = self.wanted_size[0], width = self.wanted_size[0],
height = self.wanted_size[1], height = self.wanted_size[1],
) )
decrypted = await self._decrypt(resp.body) decrypted = await self._decrypt(resp.body)
with io.BytesIO(decrypted) as img: with io.BytesIO(decrypted) as img:
# The server may return a thumbnail bigger than what we asked for # The server may return a thumbnail bigger than what we asked for
self.server_size = PILImage.open(img).size self.server_size = PILImage.open(img).size
return decrypted return decrypted

View File

@ -2,7 +2,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later # SPDX-License-Identifier: LGPL-3.0-or-later
from typing import ( from typing import (
TYPE_CHECKING, Any, Callable, Collection, Dict, List, Optional, Tuple, TYPE_CHECKING, Any, Callable, Collection, Dict, List, Optional, Tuple,
) )
from . import SyncId from . import SyncId
@ -10,185 +10,185 @@ from .model import Model
from .proxy import ModelProxy from .proxy import ModelProxy
if TYPE_CHECKING: if TYPE_CHECKING:
from .model_item import ModelItem from .model_item import ModelItem
class ModelFilter(ModelProxy): class ModelFilter(ModelProxy):
"""Filter data from one or more source models.""" """Filter data from one or more source models."""
def __init__(self, sync_id: SyncId) -> None: def __init__(self, sync_id: SyncId) -> None:
self.filtered_out: Dict[Tuple[Optional[SyncId], str], "ModelItem"] = {} self.filtered_out: Dict[Tuple[Optional[SyncId], str], "ModelItem"] = {}
self.items_changed_callbacks: List[Callable[[], None]] = [] self.items_changed_callbacks: List[Callable[[], None]] = []
super().__init__(sync_id) super().__init__(sync_id)
def accept_item(self, item: "ModelItem") -> bool: def accept_item(self, item: "ModelItem") -> bool:
"""Return whether an item should be present or filtered out.""" """Return whether an item should be present or filtered out."""
return True return True
def source_item_set( def source_item_set(
self, self,
source: Model, source: Model,
key, key,
value: "ModelItem", value: "ModelItem",
_changed_fields: Optional[Dict[str, Any]] = None, _changed_fields: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
with self.write_lock: with self.write_lock:
if self.accept_source(source): if self.accept_source(source):
value = self.convert_item(value) value = self.convert_item(value)
if self.accept_item(value): if self.accept_item(value):
self.__setitem__( self.__setitem__(
(source.sync_id, key), value, _changed_fields, (source.sync_id, key), value, _changed_fields,
) )
self.filtered_out.pop((source.sync_id, key), None) self.filtered_out.pop((source.sync_id, key), None)
else: else:
self.filtered_out[source.sync_id, key] = value self.filtered_out[source.sync_id, key] = value
self.pop((source.sync_id, key), None) self.pop((source.sync_id, key), None)
for callback in self.items_changed_callbacks: for callback in self.items_changed_callbacks:
callback() callback()
def source_item_deleted(self, source: Model, key) -> None: def source_item_deleted(self, source: Model, key) -> None:
with self.write_lock: with self.write_lock:
if self.accept_source(source): if self.accept_source(source):
try: try:
del self[source.sync_id, key] del self[source.sync_id, key]
except KeyError: except KeyError:
del self.filtered_out[source.sync_id, key] del self.filtered_out[source.sync_id, key]
for callback in self.items_changed_callbacks: for callback in self.items_changed_callbacks:
callback() callback()
def source_cleared(self, source: Model) -> None: def source_cleared(self, source: Model) -> None:
with self.write_lock: with self.write_lock:
if self.accept_source(source): if self.accept_source(source):
for source_sync_id, key in self.copy(): for source_sync_id, key in self.copy():
if source_sync_id == source.sync_id: if source_sync_id == source.sync_id:
try: try:
del self[source.sync_id, key] del self[source.sync_id, key]
except KeyError: except KeyError:
del self.filtered_out[source.sync_id, key] del self.filtered_out[source.sync_id, key]
for callback in self.items_changed_callbacks: for callback in self.items_changed_callbacks:
callback() callback()
def refilter( def refilter(
self, self,
only_if: Optional[Callable[["ModelItem"], bool]] = None, only_if: Optional[Callable[["ModelItem"], bool]] = None,
) -> None: ) -> None:
"""Recheck every item to decide if they should be filtered out.""" """Recheck every item to decide if they should be filtered out."""
with self.write_lock: with self.write_lock:
take_out = [] take_out = []
bring_back = [] bring_back = []
for key, item in sorted(self.items(), key=lambda kv: kv[1]): for key, item in sorted(self.items(), key=lambda kv: kv[1]):
if only_if and not only_if(item): if only_if and not only_if(item):
continue continue
if not self.accept_item(item): if not self.accept_item(item):
take_out.append(key) take_out.append(key)
for key, item in self.filtered_out.items(): for key, item in self.filtered_out.items():
if only_if and not only_if(item): if only_if and not only_if(item):
continue continue
if self.accept_item(item): if self.accept_item(item):
bring_back.append(key) bring_back.append(key)
with self.batch_remove(): with self.batch_remove():
for key in take_out: for key in take_out:
self.filtered_out[key] = self.pop(key) self.filtered_out[key] = self.pop(key)
for key in bring_back: for key in bring_back:
self[key] = self.filtered_out.pop(key) self[key] = self.filtered_out.pop(key)
if take_out or bring_back: if take_out or bring_back:
for callback in self.items_changed_callbacks: for callback in self.items_changed_callbacks:
callback() callback()
class FieldStringFilter(ModelFilter): class FieldStringFilter(ModelFilter):
"""Filter source models based on if their fields matches a string. """Filter source models based on if their fields matches a string.
This is used for filter fields in QML: the user enters some text and only This is used for filter fields in QML: the user enters some text and only
items with a certain field (typically `display_name`) that starts with the items with a certain field (typically `display_name`) that starts with the
entered text will be shown. entered text will be shown.
Matching is done using "smart case": insensitive if the filter text is Matching is done using "smart case": insensitive if the filter text is
all lowercase, sensitive otherwise. all lowercase, sensitive otherwise.
""" """
def __init__( def __init__(
self, self,
sync_id: SyncId, sync_id: SyncId,
fields: Collection[str], fields: Collection[str],
no_filter_accept_all_items: bool = True, no_filter_accept_all_items: bool = True,
) -> None: ) -> None:
self.fields = fields self.fields = fields
self.no_filter_accept_all_items = no_filter_accept_all_items self.no_filter_accept_all_items = no_filter_accept_all_items
self._filter: str = "" self._filter: str = ""
super().__init__(sync_id) super().__init__(sync_id)
@property @property
def filter(self) -> str: def filter(self) -> str:
return self._filter return self._filter
@filter.setter @filter.setter
def filter(self, value: str) -> None: def filter(self, value: str) -> None:
if value != self._filter: if value != self._filter:
self._filter = value self._filter = value
self.refilter() self.refilter()
def accept_item(self, item: "ModelItem") -> bool: def accept_item(self, item: "ModelItem") -> bool:
if not self.filter: if not self.filter:
return self.no_filter_accept_all_items return self.no_filter_accept_all_items
fields = {f: getattr(item, f) for f in self.fields} fields = {f: getattr(item, f) for f in self.fields}
filtr = self.filter filtr = self.filter
lowercase = filtr.lower() lowercase = filtr.lower()
if lowercase == filtr: if lowercase == filtr:
# Consider case only if filter isn't all lowercase # Consider case only if filter isn't all lowercase
filtr = lowercase filtr = lowercase
fields = {name: value.lower() for name, value in fields.items()} fields = {name: value.lower() for name, value in fields.items()}
return self.match(fields, filtr) return self.match(fields, filtr)
def match(self, fields: Dict[str, str], filtr: str) -> bool: def match(self, fields: Dict[str, str], filtr: str) -> bool:
for value in fields.values(): for value in fields.values():
if value.startswith(filtr): if value.startswith(filtr):
return True return True
return False return False
class FieldSubstringFilter(FieldStringFilter): class FieldSubstringFilter(FieldStringFilter):
"""Fuzzy-like alternative to `FieldStringFilter`. """Fuzzy-like alternative to `FieldStringFilter`.
All words in the filter string must fully or partially match words in the All words in the filter string must fully or partially match words in the
item field values, e.g. "red l" can match "red light", item field values, e.g. "red l" can match "red light",
"tired legs", "light red" (order of the filter words doesn't matter), "tired legs", "light red" (order of the filter words doesn't matter),
but not just "red" or "light" by themselves. but not just "red" or "light" by themselves.
""" """
def match(self, fields: Dict[str, str], filtr: str) -> bool: def match(self, fields: Dict[str, str], filtr: str) -> bool:
text = " ".join(fields.values()) text = " ".join(fields.values())
for word in filtr.split(): for word in filtr.split():
if word and word not in text: if word and word not in text:
return False return False
return True return True

View File

@ -23,415 +23,415 @@ ZERO_DATE = datetime.fromtimestamp(0)
class TypeSpecifier(AutoStrEnum): class TypeSpecifier(AutoStrEnum):
"""Enum providing clarification of purpose for some matrix events.""" """Enum providing clarification of purpose for some matrix events."""
Unset = auto() Unset = auto()
ProfileChange = auto() ProfileChange = auto()
MembershipChange = auto() MembershipChange = auto()
class PingStatus(AutoStrEnum): class PingStatus(AutoStrEnum):
"""Enum for the status of a homeserver ping operation.""" """Enum for the status of a homeserver ping operation."""
Done = auto() Done = auto()
Pinging = auto() Pinging = auto()
Failed = auto() Failed = auto()
class RoomNotificationOverride(AutoStrEnum): class RoomNotificationOverride(AutoStrEnum):
"""Possible per-room notification override settings, as displayed in the """Possible per-room notification override settings, as displayed in the
left sidepane's context menu when right-clicking a room. left sidepane's context menu when right-clicking a room.
""" """
UseDefaultSettings = auto() UseDefaultSettings = auto()
AllEvents = auto() AllEvents = auto()
HighlightsOnly = auto() HighlightsOnly = auto()
IgnoreEvents = auto() IgnoreEvents = auto()
@dataclass(eq=False) @dataclass(eq=False)
class Homeserver(ModelItem): class Homeserver(ModelItem):
"""A homeserver we can connect to. The `id` field is the server's URL.""" """A homeserver we can connect to. The `id` field is the server's URL."""
id: str = field() id: str = field()
name: str = field() name: str = field()
site_url: str = field() site_url: str = field()
country: str = field() country: str = field()
ping: int = -1 ping: int = -1
status: PingStatus = PingStatus.Pinging status: PingStatus = PingStatus.Pinging
stability: float = -1 stability: float = -1
downtimes_ms: List[float] = field(default_factory=list) downtimes_ms: List[float] = field(default_factory=list)
def __lt__(self, other: "Homeserver") -> bool: def __lt__(self, other: "Homeserver") -> bool:
return (self.name.lower(), self.id) < (other.name.lower(), other.id) return (self.name.lower(), self.id) < (other.name.lower(), other.id)
@dataclass(eq=False) @dataclass(eq=False)
class Account(ModelItem): class Account(ModelItem):
"""A logged in matrix account.""" """A logged in matrix account."""
id: str = field() id: str = field()
order: int = -1 order: int = -1
display_name: str = "" display_name: str = ""
avatar_url: str = "" avatar_url: str = ""
max_upload_size: int = 0 max_upload_size: int = 0
profile_updated: datetime = ZERO_DATE profile_updated: datetime = ZERO_DATE
connecting: bool = False connecting: bool = False
total_unread: int = 0 total_unread: int = 0
total_highlights: int = 0 total_highlights: int = 0
local_unreads: bool = False local_unreads: bool = False
ignored_users: Set[str] = field(default_factory=set) ignored_users: Set[str] = field(default_factory=set)
# For some reason, Account cannot inherit Presence, because QML keeps # For some reason, Account cannot inherit Presence, because QML keeps
# complaining type error on unknown file # complaining type error on unknown file
presence_support: bool = False presence_support: bool = False
save_presence: bool = True save_presence: bool = True
presence: Presence.State = Presence.State.offline presence: Presence.State = Presence.State.offline
currently_active: bool = False currently_active: bool = False
last_active_at: datetime = ZERO_DATE last_active_at: datetime = ZERO_DATE
status_msg: str = "" status_msg: str = ""
def __lt__(self, other: "Account") -> bool: def __lt__(self, other: "Account") -> bool:
return (self.order, self.id) < (other.order, other.id) return (self.order, self.id) < (other.order, other.id)
@dataclass(eq=False) @dataclass(eq=False)
class PushRule(ModelItem): class PushRule(ModelItem):
"""A push rule configured for one of our account.""" """A push rule configured for one of our account."""
id: Tuple[str, str] = field() # (kind.value, rule_id) id: Tuple[str, str] = field() # (kind.value, rule_id)
kind: nio.PushRuleKind = field() kind: nio.PushRuleKind = field()
rule_id: str = field() rule_id: str = field()
order: int = field() order: int = field()
default: bool = field() default: bool = field()
enabled: bool = True enabled: bool = True
conditions: List[Dict[str, Any]] = field(default_factory=list) conditions: List[Dict[str, Any]] = field(default_factory=list)
pattern: str = "" pattern: str = ""
actions: List[Dict[str, Any]] = field(default_factory=list) actions: List[Dict[str, Any]] = field(default_factory=list)
notify: bool = False notify: bool = False
highlight: bool = False highlight: bool = False
bubble: bool = False bubble: bool = False
sound: str = "" # usually "default" when set sound: str = "" # usually "default" when set
urgency_hint: bool = False urgency_hint: bool = False
def __lt__(self, other: "PushRule") -> bool: def __lt__(self, other: "PushRule") -> bool:
return ( return (
self.kind is nio.PushRuleKind.underride, self.kind is nio.PushRuleKind.underride,
self.kind is nio.PushRuleKind.sender, self.kind is nio.PushRuleKind.sender,
self.kind is nio.PushRuleKind.room, self.kind is nio.PushRuleKind.room,
self.kind is nio.PushRuleKind.content, self.kind is nio.PushRuleKind.content,
self.kind is nio.PushRuleKind.override, self.kind is nio.PushRuleKind.override,
self.order, self.order,
self.id, self.id,
) < ( ) < (
other.kind is nio.PushRuleKind.underride, other.kind is nio.PushRuleKind.underride,
other.kind is nio.PushRuleKind.sender, other.kind is nio.PushRuleKind.sender,
other.kind is nio.PushRuleKind.room, other.kind is nio.PushRuleKind.room,
other.kind is nio.PushRuleKind.content, other.kind is nio.PushRuleKind.content,
other.kind is nio.PushRuleKind.override, other.kind is nio.PushRuleKind.override,
other.order, other.order,
other.id, other.id,
) )
@dataclass @dataclass
class Room(ModelItem): class Room(ModelItem):
"""A matrix room we are invited to, are or were member of.""" """A matrix room we are invited to, are or were member of."""
id: str = field() id: str = field()
for_account: str = "" for_account: str = ""
given_name: str = "" given_name: str = ""
display_name: str = "" display_name: str = ""
main_alias: str = "" main_alias: str = ""
avatar_url: str = "" avatar_url: str = ""
plain_topic: str = "" plain_topic: str = ""
topic: str = "" topic: str = ""
inviter_id: str = "" inviter_id: str = ""
inviter_name: str = "" inviter_name: str = ""
inviter_avatar: str = "" inviter_avatar: str = ""
left: bool = False left: bool = False
typing_members: List[str] = field(default_factory=list) typing_members: List[str] = field(default_factory=list)
federated: bool = True federated: bool = True
encrypted: bool = False encrypted: bool = False
unverified_devices: bool = False unverified_devices: bool = False
invite_required: bool = True invite_required: bool = True
guests_allowed: bool = True guests_allowed: bool = True
default_power_level: int = 0 default_power_level: int = 0
own_power_level: int = 0 own_power_level: int = 0
can_invite: bool = False can_invite: bool = False
can_kick: bool = False can_kick: bool = False
can_redact_all: bool = False can_redact_all: bool = False
can_send_messages: bool = False can_send_messages: bool = False
can_set_name: bool = False can_set_name: bool = False
can_set_topic: bool = False can_set_topic: bool = False
can_set_avatar: bool = False can_set_avatar: bool = False
can_set_encryption: bool = False can_set_encryption: bool = False
can_set_join_rules: bool = False can_set_join_rules: bool = False
can_set_guest_access: bool = False can_set_guest_access: bool = False
can_set_power_levels: bool = False can_set_power_levels: bool = False
last_event_date: datetime = ZERO_DATE last_event_date: datetime = ZERO_DATE
unreads: int = 0 unreads: int = 0
highlights: int = 0 highlights: int = 0
local_unreads: bool = False local_unreads: bool = False
notification_setting: RoomNotificationOverride = \ notification_setting: RoomNotificationOverride = \
RoomNotificationOverride.UseDefaultSettings RoomNotificationOverride.UseDefaultSettings
lexical_sorting: bool = False lexical_sorting: bool = False
pinned: bool = False pinned: bool = False
# Allowed keys: "last_event_date", "unreads", "highlights", "local_unreads" # Allowed keys: "last_event_date", "unreads", "highlights", "local_unreads"
# Keys in this dict will override their corresponding item fields for the # Keys in this dict will override their corresponding item fields for the
# __lt__ method. This is used when we want to lock a room's position, # __lt__ method. This is used when we want to lock a room's position,
# e.g. to avoid having the room move around when it is focused in the GUI # e.g. to avoid having the room move around when it is focused in the GUI
_sort_overrides: Dict[str, Any] = field(default_factory=dict) _sort_overrides: Dict[str, Any] = field(default_factory=dict)
def _sorting(self, key: str) -> Any: def _sorting(self, key: str) -> Any:
return self._sort_overrides.get(key, getattr(self, key)) return self._sort_overrides.get(key, getattr(self, key))
def __lt__(self, other: "Room") -> bool: def __lt__(self, other: "Room") -> bool:
by_activity = not self.lexical_sorting by_activity = not self.lexical_sorting
return ( return (
self.for_account, self.for_account,
other.pinned, other.pinned,
self.left, # Left rooms may have an inviter_id, check them first self.left, # Left rooms may have an inviter_id, check them first
bool(other.inviter_id), bool(other.inviter_id),
bool(by_activity and other._sorting("highlights")), bool(by_activity and other._sorting("highlights")),
bool(by_activity and other._sorting("unreads")), bool(by_activity and other._sorting("unreads")),
bool(by_activity and other._sorting("local_unreads")), bool(by_activity and other._sorting("local_unreads")),
other._sorting("last_event_date") if by_activity else ZERO_DATE, other._sorting("last_event_date") if by_activity else ZERO_DATE,
(self.display_name or self.id).lower(), (self.display_name or self.id).lower(),
self.id, self.id,
) < ( ) < (
other.for_account, other.for_account,
self.pinned, self.pinned,
other.left, other.left,
bool(self.inviter_id), bool(self.inviter_id),
bool(by_activity and self._sorting("highlights")), bool(by_activity and self._sorting("highlights")),
bool(by_activity and self._sorting("unreads")), bool(by_activity and self._sorting("unreads")),
bool(by_activity and self._sorting("local_unreads")), bool(by_activity and self._sorting("local_unreads")),
self._sorting("last_event_date") if by_activity else ZERO_DATE, self._sorting("last_event_date") if by_activity else ZERO_DATE,
(other.display_name or other.id).lower(), (other.display_name or other.id).lower(),
other.id, other.id,
) )
@dataclass(eq=False) @dataclass(eq=False)
class AccountOrRoom(Account, Room): class AccountOrRoom(Account, Room):
"""The left sidepane in the GUI lists a mixture of accounts and rooms """The left sidepane in the GUI lists a mixture of accounts and rooms
giving a tree view illusion. Since all items in a QML ListView must have giving a tree view illusion. Since all items in a QML ListView must have
the same available properties, this class inherits both the same available properties, this class inherits both
`Account` and `Room` to fulfill that purpose. `Account` and `Room` to fulfill that purpose.
""" """
type: Union[Type[Account], Type[Room]] = Account type: Union[Type[Account], Type[Room]] = Account
account_order: int = -1 account_order: int = -1
def __lt__(self, other: "AccountOrRoom") -> bool: # type: ignore def __lt__(self, other: "AccountOrRoom") -> bool: # type: ignore
by_activity = not self.lexical_sorting by_activity = not self.lexical_sorting
return ( return (
self.account_order, self.account_order,
self.id if self.type is Account else self.for_account, self.id if self.type is Account else self.for_account,
other.type is Account, other.type is Account,
other.pinned, other.pinned,
self.left, self.left,
bool(other.inviter_id), bool(other.inviter_id),
bool(by_activity and other._sorting("highlights")), bool(by_activity and other._sorting("highlights")),
bool(by_activity and other._sorting("unreads")), bool(by_activity and other._sorting("unreads")),
bool(by_activity and other._sorting("local_unreads")), bool(by_activity and other._sorting("local_unreads")),
other._sorting("last_event_date") if by_activity else ZERO_DATE, other._sorting("last_event_date") if by_activity else ZERO_DATE,
(self.display_name or self.id).lower(), (self.display_name or self.id).lower(),
self.id, self.id,
) < ( ) < (
other.account_order, other.account_order,
other.id if other.type is Account else other.for_account, other.id if other.type is Account else other.for_account,
self.type is Account, self.type is Account,
self.pinned, self.pinned,
other.left, other.left,
bool(self.inviter_id), bool(self.inviter_id),
bool(by_activity and self._sorting("highlights")), bool(by_activity and self._sorting("highlights")),
bool(by_activity and self._sorting("unreads")), bool(by_activity and self._sorting("unreads")),
bool(by_activity and self._sorting("local_unreads")), bool(by_activity and self._sorting("local_unreads")),
self._sorting("last_event_date") if by_activity else ZERO_DATE, self._sorting("last_event_date") if by_activity else ZERO_DATE,
(other.display_name or other.id).lower(), (other.display_name or other.id).lower(),
other.id, other.id,
) )
@dataclass(eq=False) @dataclass(eq=False)
class Member(ModelItem): class Member(ModelItem):
"""A member in a matrix room.""" """A member in a matrix room."""
id: str = field() id: str = field()
display_name: str = "" display_name: str = ""
avatar_url: str = "" avatar_url: str = ""
typing: bool = False typing: bool = False
power_level: int = 0 power_level: int = 0
invited: bool = False invited: bool = False
ignored: bool = False ignored: bool = False
profile_updated: datetime = ZERO_DATE profile_updated: datetime = ZERO_DATE
last_read_event: str = "" last_read_event: str = ""
presence: Presence.State = Presence.State.offline presence: Presence.State = Presence.State.offline
currently_active: bool = False currently_active: bool = False
last_active_at: datetime = ZERO_DATE last_active_at: datetime = ZERO_DATE
status_msg: str = "" status_msg: str = ""
def __lt__(self, other: "Member") -> bool: def __lt__(self, other: "Member") -> bool:
return ( return (
self.invited, self.invited,
other.power_level, other.power_level,
self.ignored, self.ignored,
Presence.State.offline if self.ignored else self.presence, Presence.State.offline if self.ignored else self.presence,
(self.display_name or self.id[1:]).lower(), (self.display_name or self.id[1:]).lower(),
self.id, self.id,
) < ( ) < (
other.invited, other.invited,
self.power_level, self.power_level,
other.ignored, other.ignored,
Presence.State.offline if other.ignored else other.presence, Presence.State.offline if other.ignored else other.presence,
(other.display_name or other.id[1:]).lower(), (other.display_name or other.id[1:]).lower(),
other.id, other.id,
) )
class TransferStatus(AutoStrEnum): class TransferStatus(AutoStrEnum):
"""Enum describing the status of an upload operation.""" """Enum describing the status of an upload operation."""
Preparing = auto() Preparing = auto()
Transfering = auto() Transfering = auto()
Caching = auto() Caching = auto()
Error = auto() Error = auto()
@dataclass(eq=False) @dataclass(eq=False)
class Transfer(ModelItem): class Transfer(ModelItem):
"""Represent a running or failed file upload/download operation.""" """Represent a running or failed file upload/download operation."""
id: UUID = field() id: UUID = field()
is_upload: bool = field() is_upload: bool = field()
filepath: Path = Path("-") filepath: Path = Path("-")
total_size: int = 0 total_size: int = 0
transferred: int = 0 transferred: int = 0
speed: float = 0 speed: float = 0
time_left: timedelta = timedelta(0) time_left: timedelta = timedelta(0)
paused: bool = False paused: bool = False
status: TransferStatus = TransferStatus.Preparing status: TransferStatus = TransferStatus.Preparing
error: OptionalExceptionType = type(None) error: OptionalExceptionType = type(None)
error_args: Tuple[Any, ...] = () error_args: Tuple[Any, ...] = ()
start_date: datetime = field(init=False, default_factory=datetime.now) start_date: datetime = field(init=False, default_factory=datetime.now)
def __lt__(self, other: "Transfer") -> bool: def __lt__(self, other: "Transfer") -> bool:
return (self.start_date, self.id) > (other.start_date, other.id) return (self.start_date, self.id) > (other.start_date, other.id)
@dataclass(eq=False) @dataclass(eq=False)
class Event(ModelItem): class Event(ModelItem):
"""A matrix state event or message.""" """A matrix state event or message."""
id: str = field() id: str = field()
event_id: str = field() event_id: str = field()
event_type: Type[nio.Event] = field() event_type: Type[nio.Event] = field()
date: datetime = field() date: datetime = field()
sender_id: str = field() sender_id: str = field()
sender_name: str = field() sender_name: str = field()
sender_avatar: str = field() sender_avatar: str = field()
fetch_profile: bool = False fetch_profile: bool = False
content: str = "" content: str = ""
inline_content: str = "" inline_content: str = ""
reason: str = "" reason: str = ""
links: List[str] = field(default_factory=list) links: List[str] = field(default_factory=list)
mentions: List[Tuple[str, str]] = field(default_factory=list) mentions: List[Tuple[str, str]] = field(default_factory=list)
type_specifier: TypeSpecifier = TypeSpecifier.Unset type_specifier: TypeSpecifier = TypeSpecifier.Unset
target_id: str = "" target_id: str = ""
target_name: str = "" target_name: str = ""
target_avatar: str = "" target_avatar: str = ""
redacter_id: str = "" redacter_id: str = ""
redacter_name: str = "" redacter_name: str = ""
# {user_id: server_timestamp} - QML can't parse dates from JSONified dicts # {user_id: server_timestamp} - QML can't parse dates from JSONified dicts
last_read_by: Dict[str, int] = field(default_factory=dict) last_read_by: Dict[str, int] = field(default_factory=dict)
read_by_count: int = 0 read_by_count: int = 0
is_local_echo: bool = False is_local_echo: bool = False
source: Optional[nio.Event] = None source: Optional[nio.Event] = None
media_url: str = "" media_url: str = ""
media_http_url: str = "" media_http_url: str = ""
media_title: str = "" media_title: str = ""
media_width: int = 0 media_width: int = 0
media_height: int = 0 media_height: int = 0
media_duration: int = 0 media_duration: int = 0
media_size: int = 0 media_size: int = 0
media_mime: str = "" media_mime: str = ""
media_crypt_dict: Dict[str, Any] = field(default_factory=dict) media_crypt_dict: Dict[str, Any] = field(default_factory=dict)
media_local_path: Union[str, Path] = "" media_local_path: Union[str, Path] = ""
thumbnail_url: str = "" thumbnail_url: str = ""
thumbnail_mime: str = "" thumbnail_mime: str = ""
thumbnail_width: int = 0 thumbnail_width: int = 0
thumbnail_height: int = 0 thumbnail_height: int = 0
thumbnail_crypt_dict: Dict[str, Any] = field(default_factory=dict) thumbnail_crypt_dict: Dict[str, Any] = field(default_factory=dict)
def __lt__(self, other: "Event") -> bool: def __lt__(self, other: "Event") -> bool:
return (self.date, self.id) > (other.date, other.id) return (self.date, self.id) > (other.date, other.id)
@property @property
def plain_content(self) -> str: def plain_content(self) -> str:
"""Plaintext version of the event's content.""" """Plaintext version of the event's content."""
if isinstance(self.source, nio.RoomMessageText): if isinstance(self.source, nio.RoomMessageText):
return self.source.body return self.source.body
return strip_html_tags(self.content) return strip_html_tags(self.content)
@staticmethod @staticmethod
def parse_links(text: str) -> List[str]: def parse_links(text: str) -> List[str]:
"""Return list of URLs (`<a href=...>` tags) present in the content.""" """Return list of URLs (`<a href=...>` tags) present in the content."""
ignore = [] ignore = []
if "<mx-reply>" in text or "mention" in text: if "<mx-reply>" in text or "mention" in text:
parser = lxml.html.etree.HTMLParser() parser = lxml.html.etree.HTMLParser()
tree = lxml.etree.fromstring(text, parser) tree = lxml.etree.fromstring(text, parser)
ignore = [ ignore = [
lxml.etree.tostring(matching_element) lxml.etree.tostring(matching_element)
for ugly_disgusting_xpath in [ for ugly_disgusting_xpath in [
# Match mx-reply > blockquote > second a (user ID link) # Match mx-reply > blockquote > second a (user ID link)
"//mx-reply/blockquote/a[count(preceding-sibling::*)<=1]", "//mx-reply/blockquote/a[count(preceding-sibling::*)<=1]",
# Match <a> tags with a mention class # Match <a> tags with a mention class
'//a[contains(concat(" ",normalize-space(@class)," ")' '//a[contains(concat(" ",normalize-space(@class)," ")'
'," mention ")]', '," mention ")]',
] ]
for matching_element in tree.xpath(ugly_disgusting_xpath) for matching_element in tree.xpath(ugly_disgusting_xpath)
] ]
if not text.strip(): if not text.strip():
return [] return []
return [ return [
url for el, attrib, url, pos in lxml.html.iterlinks(text) url for el, attrib, url, pos in lxml.html.iterlinks(text)
if lxml.etree.tostring(el) not in ignore if lxml.etree.tostring(el) not in ignore
] ]
def serialized_field(self, field: str) -> Any: def serialized_field(self, field: str) -> Any:
if field == "source": if field == "source":
source_dict = asdict(self.source) if self.source else {} source_dict = asdict(self.source) if self.source else {}
return json.dumps(source_dict) return json.dumps(source_dict)
return super().serialized_field(field) return super().serialized_field(field)

View File

@ -5,7 +5,7 @@ import itertools
from contextlib import contextmanager from contextlib import contextmanager
from threading import RLock from threading import RLock
from typing import ( from typing import (
TYPE_CHECKING, Any, Dict, Iterator, List, MutableMapping, Optional, Tuple, TYPE_CHECKING, Any, Dict, Iterator, List, MutableMapping, Optional, Tuple,
) )
from sortedcontainers import SortedList from sortedcontainers import SortedList
@ -15,199 +15,199 @@ from ..utils import serialize_value_for_qml
from . import SyncId from . import SyncId
if TYPE_CHECKING: if TYPE_CHECKING:
from .model_item import ModelItem from .model_item import ModelItem
from .proxy import ModelProxy # noqa from .proxy import ModelProxy # noqa
class Model(MutableMapping): class Model(MutableMapping):
"""A mapping of `{ModelItem.id: ModelItem}` synced between Python & QML. """A mapping of `{ModelItem.id: ModelItem}` synced between Python & QML.
From the Python side, the model is usable like a normal dict of From the Python side, the model is usable like a normal dict of
`ModelItem` subclass objects. `ModelItem` subclass objects.
Different types of `ModelItem` must not be mixed in the same model. Different types of `ModelItem` must not be mixed in the same model.
When items are added, replaced, removed, have field value changes, or the When items are added, replaced, removed, have field value changes, or the
model is cleared, corresponding `PyOtherSideEvent` are fired to inform model is cleared, corresponding `PyOtherSideEvent` are fired to inform
QML of the changes so that it can keep its models in sync. QML of the changes so that it can keep its models in sync.
Items in the model are kept sorted using the `ModelItem` subclass `__lt__`. Items in the model are kept sorted using the `ModelItem` subclass `__lt__`.
""" """
instances: Dict[SyncId, "Model"] = {} instances: Dict[SyncId, "Model"] = {}
proxies: Dict[SyncId, "ModelProxy"] = {} proxies: Dict[SyncId, "ModelProxy"] = {}
def __init__(self, sync_id: Optional[SyncId]) -> None: def __init__(self, sync_id: Optional[SyncId]) -> None:
self.sync_id: Optional[SyncId] = sync_id self.sync_id: Optional[SyncId] = sync_id
self.write_lock: RLock = RLock() self.write_lock: RLock = RLock()
self._data: Dict[Any, "ModelItem"] = {} self._data: Dict[Any, "ModelItem"] = {}
self._sorted_data: SortedList["ModelItem"] = SortedList() self._sorted_data: SortedList["ModelItem"] = SortedList()
self.take_items_ownership: bool = True self.take_items_ownership: bool = True
# [(index, item.id), ...] # [(index, item.id), ...]
self._active_batch_removed: Optional[List[Tuple[int, Any]]] = None self._active_batch_removed: Optional[List[Tuple[int, Any]]] = None
if self.sync_id: if self.sync_id:
self.instances[self.sync_id] = self self.instances[self.sync_id] = self
def __repr__(self) -> str: def __repr__(self) -> str:
"""Provide a full representation of the model and its content.""" """Provide a full representation of the model and its content."""
return "%s(sync_id=%s, %s)" % ( return "%s(sync_id=%s, %s)" % (
type(self).__name__, self.sync_id, self._data, type(self).__name__, self.sync_id, self._data,
) )
def __str__(self) -> str: def __str__(self) -> str:
"""Provide a short "<sync_id>: <num> items" representation.""" """Provide a short "<sync_id>: <num> items" representation."""
return f"{self.sync_id}: {len(self)} items" return f"{self.sync_id}: {len(self)} items"
def __getitem__(self, key): def __getitem__(self, key):
return self._data[key] return self._data[key]
def __setitem__( def __setitem__(
self, self,
key, key,
value: "ModelItem", value: "ModelItem",
_changed_fields: Optional[Dict[str, Any]] = None, _changed_fields: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
with self.write_lock: with self.write_lock:
existing = self._data.get(key) existing = self._data.get(key)
new = value new = value
# Collect changed fields # Collect changed fields
changed_fields = _changed_fields or {} changed_fields = _changed_fields or {}
if not changed_fields: if not changed_fields:
for field in new.__dataclass_fields__: # type: ignore for field in new.__dataclass_fields__: # type: ignore
if field.startswith("_"): if field.startswith("_"):
continue continue
changed = True changed = True
if existing: if existing:
changed = \ changed = \
getattr(new, field) != getattr(existing, field) getattr(new, field) != getattr(existing, field)
if changed: if changed:
changed_fields[field] = new.serialized_field(field) changed_fields[field] = new.serialized_field(field)
# Set parent model on new item # Set parent model on new item
if self.sync_id and self.take_items_ownership: if self.sync_id and self.take_items_ownership:
new.parent_model = self new.parent_model = self
# Insert into sorted data # Insert into sorted data
index_then = None index_then = None
if existing: if existing:
index_then = self._sorted_data.index(existing) index_then = self._sorted_data.index(existing)
del self._sorted_data[index_then] del self._sorted_data[index_then]
self._sorted_data.add(new) self._sorted_data.add(new)
index_now = self._sorted_data.index(new) index_now = self._sorted_data.index(new)
# Insert into dict data # Insert into dict data
self._data[key] = new self._data[key] = new
# Callbacks # Callbacks
for sync_id, proxy in self.proxies.items(): for sync_id, proxy in self.proxies.items():
if sync_id != self.sync_id: if sync_id != self.sync_id:
proxy.source_item_set(self, key, value) proxy.source_item_set(self, key, value)
# Emit PyOtherSide event # Emit PyOtherSide event
if self.sync_id and (index_then != index_now or changed_fields): if self.sync_id and (index_then != index_now or changed_fields):
ModelItemSet( ModelItemSet(
self.sync_id, index_then, index_now, changed_fields, self.sync_id, index_then, index_now, changed_fields,
) )
def __delitem__(self, key) -> None: def __delitem__(self, key) -> None:
with self.write_lock: with self.write_lock:
item = self._data[key] item = self._data[key]
if self.sync_id and self.take_items_ownership: if self.sync_id and self.take_items_ownership:
item.parent_model = None item.parent_model = None
del self._data[key] del self._data[key]
index = self._sorted_data.index(item) index = self._sorted_data.index(item)
del self._sorted_data[index] del self._sorted_data[index]
for sync_id, proxy in self.proxies.items(): for sync_id, proxy in self.proxies.items():
if sync_id != self.sync_id: if sync_id != self.sync_id:
proxy.source_item_deleted(self, key) proxy.source_item_deleted(self, key)
if self.sync_id: if self.sync_id:
if self._active_batch_removed is None: if self._active_batch_removed is None:
i = serialize_value_for_qml(item.id, json_list_dicts=True) i = serialize_value_for_qml(item.id, json_list_dicts=True)
ModelItemDeleted(self.sync_id, index, 1, (i,)) ModelItemDeleted(self.sync_id, index, 1, (i,))
else: else:
self._active_batch_removed.append((index, item.id)) self._active_batch_removed.append((index, item.id))
def __iter__(self) -> Iterator: def __iter__(self) -> Iterator:
return iter(self._data) return iter(self._data)
def __len__(self) -> int: def __len__(self) -> int:
return len(self._data) return len(self._data)
def __lt__(self, other: "Model") -> bool: def __lt__(self, other: "Model") -> bool:
"""Sort `Model` objects lexically by `sync_id`.""" """Sort `Model` objects lexically by `sync_id`."""
return str(self.sync_id) < str(other.sync_id) return str(self.sync_id) < str(other.sync_id)
def clear(self) -> None: def clear(self) -> None:
super().clear() super().clear()
if self.sync_id: if self.sync_id:
ModelCleared(self.sync_id) ModelCleared(self.sync_id)
def copy(self, sync_id: Optional[SyncId] = None) -> "Model": def copy(self, sync_id: Optional[SyncId] = None) -> "Model":
new = type(self)(sync_id=sync_id) new = type(self)(sync_id=sync_id)
new.update(self) new.update(self)
return new return new
@contextmanager @contextmanager
def batch_remove(self): def batch_remove(self):
"""Context manager that accumulates item removal events. """Context manager that accumulates item removal events.
When the context manager exits, sequences of removed items are grouped When the context manager exits, sequences of removed items are grouped
and one `ModelItemDeleted` pyotherside event is fired per sequence. and one `ModelItemDeleted` pyotherside event is fired per sequence.
""" """
with self.write_lock: with self.write_lock:
try: try:
self._active_batch_removed = [] self._active_batch_removed = []
yield None yield None
finally: finally:
batch = self._active_batch_removed batch = self._active_batch_removed
groups = [ groups = [
list(group) for item, group in list(group) for item, group in
itertools.groupby(batch, key=lambda x: x[0]) itertools.groupby(batch, key=lambda x: x[0])
] ]
def serialize_id(id_): def serialize_id(id_):
return serialize_value_for_qml(id_, json_list_dicts=True) return serialize_value_for_qml(id_, json_list_dicts=True)
for group in groups: for group in groups:
ModelItemDeleted( ModelItemDeleted(
self.sync_id, self.sync_id,
index = group[0][0], index = group[0][0],
count = len(group), count = len(group),
ids = [serialize_id(item[1]) for item in group], ids = [serialize_id(item[1]) for item in group],
) )
self._active_batch_removed = None self._active_batch_removed = None

View File

@ -8,122 +8,122 @@ from ..pyotherside_events import ModelItemSet
from ..utils import serialize_value_for_qml from ..utils import serialize_value_for_qml
if TYPE_CHECKING: if TYPE_CHECKING:
from .model import Model from .model import Model
@dataclass(eq=False) @dataclass(eq=False)
class ModelItem: class ModelItem:
"""Base class for items stored inside a `Model`. """Base class for items stored inside a `Model`.
This class must be subclassed and not used directly. This class must be subclassed and not used directly.
All subclasses must use the `@dataclass(eq=False)` decorator. All subclasses must use the `@dataclass(eq=False)` decorator.
Subclasses are also expected to implement `__lt__()`, Subclasses are also expected to implement `__lt__()`,
to provide support for comparisons with the `<`, `>`, `<=`, `=>` operators to provide support for comparisons with the `<`, `>`, `<=`, `=>` operators
and thus allow a `Model` to keep its data sorted. and thus allow a `Model` to keep its data sorted.
Make sure to respect SortedList requirements when implementing `__lt__()`: Make sure to respect SortedList requirements when implementing `__lt__()`:
http://www.grantjenks.com/docs/sortedcontainers/introduction.html#caveats http://www.grantjenks.com/docs/sortedcontainers/introduction.html#caveats
""" """
id: Any = field() id: Any = field()
def __new__(cls, *_args, **_kwargs) -> "ModelItem": def __new__(cls, *_args, **_kwargs) -> "ModelItem":
cls.parent_model: Optional[Model] = None cls.parent_model: Optional[Model] = None
return super().__new__(cls) return super().__new__(cls)
def __setattr__(self, name: str, value) -> None: def __setattr__(self, name: str, value) -> None:
self.set_fields(**{name: value}) self.set_fields(**{name: value})
def __delattr__(self, name: str) -> None: def __delattr__(self, name: str) -> None:
raise NotImplementedError() raise NotImplementedError()
@property @property
def serialized(self) -> Dict[str, Any]: def serialized(self) -> Dict[str, Any]:
"""Return this item as a dict ready to be passed to QML.""" """Return this item as a dict ready to be passed to QML."""
return { return {
name: self.serialized_field(name) name: self.serialized_field(name)
for name in self.__dataclass_fields__ # type: ignore for name in self.__dataclass_fields__ # type: ignore
if not name.startswith("_") if not name.startswith("_")
} }
def serialized_field(self, field: str) -> Any: def serialized_field(self, field: str) -> Any:
"""Return a field's value in a form suitable for passing to QML.""" """Return a field's value in a form suitable for passing to QML."""
value = getattr(self, field) value = getattr(self, field)
return serialize_value_for_qml(value, json_list_dicts=True) return serialize_value_for_qml(value, json_list_dicts=True)
def set_fields(self, _force: bool = False, **fields: Any) -> None: def set_fields(self, _force: bool = False, **fields: Any) -> None:
"""Set one or more field's value and call `ModelItem.notify_change`. """Set one or more field's value and call `ModelItem.notify_change`.
For efficiency, to change multiple fields, this method should be For efficiency, to change multiple fields, this method should be
used rather than setting them one after another with `=` or `setattr`. used rather than setting them one after another with `=` or `setattr`.
""" """
parent = self.parent_model parent = self.parent_model
# If we're currently being created or haven't been put in a model yet: # If we're currently being created or haven't been put in a model yet:
if not parent: if not parent:
for name, value in fields.items(): for name, value in fields.items():
super().__setattr__(name, value) super().__setattr__(name, value)
return return
with parent.write_lock: with parent.write_lock:
qml_changes = {} qml_changes = {}
changes = { changes = {
name: value for name, value in fields.items() name: value for name, value in fields.items()
if _force or getattr(self, name) != value if _force or getattr(self, name) != value
} }
if not changes: if not changes:
return return
# To avoid corrupting the SortedList, we have to take out the item, # To avoid corrupting the SortedList, we have to take out the item,
# apply the field changes, *then* add it back in. # apply the field changes, *then* add it back in.
index_then = parent._sorted_data.index(self) index_then = parent._sorted_data.index(self)
del parent._sorted_data[index_then] del parent._sorted_data[index_then]
for name, value in changes.items(): for name, value in changes.items():
super().__setattr__(name, value) super().__setattr__(name, value)
is_field = name in self.__dataclass_fields__ # type: ignore is_field = name in self.__dataclass_fields__ # type: ignore
if is_field and not name.startswith("_"): if is_field and not name.startswith("_"):
qml_changes[name] = self.serialized_field(name) qml_changes[name] = self.serialized_field(name)
parent._sorted_data.add(self) parent._sorted_data.add(self)
index_now = parent._sorted_data.index(self) index_now = parent._sorted_data.index(self)
index_change = index_then != index_now index_change = index_then != index_now
# Now, inform QML about changed dataclass fields if any. # Now, inform QML about changed dataclass fields if any.
if not parent.sync_id or (not qml_changes and not index_change): if not parent.sync_id or (not qml_changes and not index_change):
return return
ModelItemSet(parent.sync_id, index_then, index_now, qml_changes) ModelItemSet(parent.sync_id, index_then, index_now, qml_changes)
# Inform any proxy connected to the parent model of the field changes # Inform any proxy connected to the parent model of the field changes
for sync_id, proxy in parent.proxies.items(): for sync_id, proxy in parent.proxies.items():
if sync_id != parent.sync_id: if sync_id != parent.sync_id:
proxy.source_item_set(parent, self.id, self, qml_changes) proxy.source_item_set(parent, self.id, self, qml_changes)
def notify_change(self, *fields: str) -> None: def notify_change(self, *fields: str) -> None:
"""Notify the parent model that fields of this item have changed. """Notify the parent model that fields of this item have changed.
The model cannot automatically detect changes inside The model cannot automatically detect changes inside
object fields, such as list or dicts having their data modified. object fields, such as list or dicts having their data modified.
In these cases, this method should be called. In these cases, this method should be called.
""" """
kwargs = {name: getattr(self, name) for name in fields} kwargs = {name: getattr(self, name) for name in fields}
kwargs["_force"] = True kwargs["_force"] = True
self.set_fields(**kwargs) self.set_fields(**kwargs)

View File

@ -8,66 +8,66 @@ from typing import Dict, List, Union
from . import SyncId from . import SyncId
from .model import Model from .model import Model
from .special_models import ( from .special_models import (
AllRooms, AutoCompletedMembers, FilteredHomeservers, FilteredMembers, AllRooms, AutoCompletedMembers, FilteredHomeservers, FilteredMembers,
MatchingAccounts, MatchingAccounts,
) )
@dataclass(frozen=True) @dataclass(frozen=True)
class ModelStore(UserDict): class ModelStore(UserDict):
"""Dict of sync ID keys and `Model` values. """Dict of sync ID keys and `Model` values.
The dict keys must be the sync ID of `Model` values. The dict keys must be the sync ID of `Model` values.
If a non-existent key is accessed, a corresponding `Model` will be If a non-existent key is accessed, a corresponding `Model` will be
created, put into the internal `data` dict and returned. created, put into the internal `data` dict and returned.
""" """
data: Dict[SyncId, Model] = field(default_factory=dict) data: Dict[SyncId, Model] = field(default_factory=dict)
def __missing__(self, key: SyncId) -> Model: def __missing__(self, key: SyncId) -> Model:
"""When accessing a non-existent model, create and return it. """When accessing a non-existent model, create and return it.
Special models rather than a generic `Model` object may be returned Special models rather than a generic `Model` object may be returned
depending on the passed key. depending on the passed key.
""" """
is_tuple = isinstance(key, tuple) is_tuple = isinstance(key, tuple)
model: Model model: Model
if key == "all_rooms": if key == "all_rooms":
model = AllRooms(self["accounts"]) model = AllRooms(self["accounts"])
elif key == "matching_accounts": elif key == "matching_accounts":
model = MatchingAccounts(self["all_rooms"]) model = MatchingAccounts(self["all_rooms"])
elif key == "filtered_homeservers": elif key == "filtered_homeservers":
model = FilteredHomeservers() model = FilteredHomeservers()
elif is_tuple and len(key) == 3 and key[2] == "filtered_members": elif is_tuple and len(key) == 3 and key[2] == "filtered_members":
model = FilteredMembers(user_id=key[0], room_id=key[1]) model = FilteredMembers(user_id=key[0], room_id=key[1])
elif is_tuple and len(key) == 3 and key[2] == "autocompleted_members": elif is_tuple and len(key) == 3 and key[2] == "autocompleted_members":
model = AutoCompletedMembers(user_id=key[0], room_id=key[1]) model = AutoCompletedMembers(user_id=key[0], room_id=key[1])
else: else:
model = Model(sync_id=key) model = Model(sync_id=key)
self.data[key] = model self.data[key] = model
return model return model
def __str__(self) -> str: def __str__(self) -> str:
"""Provide a nice overview of stored models when `print()` called.""" """Provide a nice overview of stored models when `print()` called."""
return "%s(\n %s\n)" % ( return "%s(\n %s\n)" % (
type(self).__name__, type(self).__name__,
"\n ".join(sorted(str(v) for v in self.values())), "\n ".join(sorted(str(v) for v in self.values())),
) )
async def ensure_exists_from_qml( async def ensure_exists_from_qml(
self, sync_id: Union[SyncId, List[str]], self, sync_id: Union[SyncId, List[str]],
) -> None: ) -> None:
"""Create model if it doesn't exist. Should only be called by QML.""" """Create model if it doesn't exist. Should only be called by QML."""
if isinstance(sync_id, list): # QML can't pass tuples if isinstance(sync_id, list): # QML can't pass tuples
sync_id = tuple(sync_id) sync_id = tuple(sync_id)
self[sync_id] # will call __missing__ if needed self[sync_id] # will call __missing__ if needed

View File

@ -8,68 +8,68 @@ from . import SyncId
from .model import Model from .model import Model
if TYPE_CHECKING: if TYPE_CHECKING:
from .model_item import ModelItem from .model_item import ModelItem
class ModelProxy(Model): class ModelProxy(Model):
"""Proxies data from one or more `Model` objects.""" """Proxies data from one or more `Model` objects."""
def __init__(self, sync_id: SyncId) -> None: def __init__(self, sync_id: SyncId) -> None:
super().__init__(sync_id) super().__init__(sync_id)
self.take_items_ownership = False self.take_items_ownership = False
Model.proxies[sync_id] = self Model.proxies[sync_id] = self
with self.write_lock: with self.write_lock:
for sync_id, model in Model.instances.items(): for sync_id, model in Model.instances.items():
if sync_id != self.sync_id and self.accept_source(model): if sync_id != self.sync_id and self.accept_source(model):
for key, item in model.items(): for key, item in model.items():
self.source_item_set(model, key, item) self.source_item_set(model, key, item)
def accept_source(self, source: Model) -> bool: def accept_source(self, source: Model) -> bool:
"""Return whether passed `Model` should be proxied by this proxy.""" """Return whether passed `Model` should be proxied by this proxy."""
return True return True
def convert_item(self, item: "ModelItem") -> "ModelItem": def convert_item(self, item: "ModelItem") -> "ModelItem":
"""Take a source `ModelItem`, return an appropriate one for proxy. """Take a source `ModelItem`, return an appropriate one for proxy.
By default, this returns the passed item unchanged. By default, this returns the passed item unchanged.
Due to QML `ListModel` restrictions, if multiple source models Due to QML `ListModel` restrictions, if multiple source models
containing different subclasses of `ModelItem` are proxied, containing different subclasses of `ModelItem` are proxied,
they should be converted to a same `ModelItem` they should be converted to a same `ModelItem`
subclass by overriding this function. subclass by overriding this function.
""" """
return copy(item) return copy(item)
def source_item_set( def source_item_set(
self, self,
source: Model, source: Model,
key, key,
value: "ModelItem", value: "ModelItem",
_changed_fields: Optional[Dict[str, Any]] = None, _changed_fields: Optional[Dict[str, Any]] = None,
) -> None: ) -> None:
"""Called when a source model item is added or changed.""" """Called when a source model item is added or changed."""
if self.accept_source(source): if self.accept_source(source):
value = self.convert_item(value) value = self.convert_item(value)
self.__setitem__((source.sync_id, key), value, _changed_fields) self.__setitem__((source.sync_id, key), value, _changed_fields)
def source_item_deleted(self, source: Model, key) -> None: def source_item_deleted(self, source: Model, key) -> None:
"""Called when a source model item is removed.""" """Called when a source model item is removed."""
if self.accept_source(source): if self.accept_source(source):
del self[source.sync_id, key] del self[source.sync_id, key]
def source_cleared(self, source: Model) -> None: def source_cleared(self, source: Model) -> None:
"""Called when a source model is cleared.""" """Called when a source model is cleared."""
if self.accept_source(source): if self.accept_source(source):
with self.batch_remove(): with self.batch_remove():
for source_sync_id, key in self.copy(): for source_sync_id, key in self.copy():
if source_sync_id == source.sync_id: if source_sync_id == source.sync_id:
del self[source_sync_id, key] del self[source_sync_id, key]

View File

@ -11,143 +11,143 @@ from .model_item import ModelItem
class AllRooms(FieldSubstringFilter): class AllRooms(FieldSubstringFilter):
"""Flat filtered list of all accounts and their rooms.""" """Flat filtered list of all accounts and their rooms."""
def __init__(self, accounts: Model) -> None: def __init__(self, accounts: Model) -> None:
self.accounts = accounts self.accounts = accounts
self._collapsed: Set[str] = set() self._collapsed: Set[str] = set()
super().__init__(sync_id="all_rooms", fields=("display_name",)) super().__init__(sync_id="all_rooms", fields=("display_name",))
self.items_changed_callbacks.append(self.refilter_accounts) self.items_changed_callbacks.append(self.refilter_accounts)
def set_account_collapse(self, user_id: str, collapsed: bool) -> None: def set_account_collapse(self, user_id: str, collapsed: bool) -> None:
"""Set whether the rooms for an account should be filtered out.""" """Set whether the rooms for an account should be filtered out."""
def only_if(item): def only_if(item):
return item.type is Room and item.for_account == user_id return item.type is Room and item.for_account == user_id
if collapsed and user_id not in self._collapsed: if collapsed and user_id not in self._collapsed:
self._collapsed.add(user_id) self._collapsed.add(user_id)
self.refilter(only_if) self.refilter(only_if)
if not collapsed and user_id in self._collapsed: if not collapsed and user_id in self._collapsed:
self._collapsed.remove(user_id) self._collapsed.remove(user_id)
self.refilter(only_if) self.refilter(only_if)
def accept_source(self, source: Model) -> bool: def accept_source(self, source: Model) -> bool:
return source.sync_id == "accounts" or ( return source.sync_id == "accounts" or (
isinstance(source.sync_id, tuple) and isinstance(source.sync_id, tuple) and
len(source.sync_id) == 2 and len(source.sync_id) == 2 and
source.sync_id[1] == "rooms" source.sync_id[1] == "rooms"
) )
def convert_item(self, item: ModelItem) -> AccountOrRoom: def convert_item(self, item: ModelItem) -> AccountOrRoom:
return AccountOrRoom( return AccountOrRoom(
**asdict(item), **asdict(item),
type = type(item), # type: ignore type = type(item), # type: ignore
account_order = account_order =
item.order if isinstance(item, Account) else item.order if isinstance(item, Account) else
self.accounts[item.for_account].order, # type: ignore self.accounts[item.for_account].order, # type: ignore
) )
def accept_item(self, item: ModelItem) -> bool: def accept_item(self, item: ModelItem) -> bool:
assert isinstance(item, AccountOrRoom) # nosec assert isinstance(item, AccountOrRoom) # nosec
if not self.filter and \ if not self.filter and \
item.type is Room and \ item.type is Room and \
item.for_account in self._collapsed: item.for_account in self._collapsed:
return False return False
matches_filter = super().accept_item(item) matches_filter = super().accept_item(item)
if item.type is not Account or not self.filter: if item.type is not Account or not self.filter:
return matches_filter return matches_filter
return next( return next(
(i for i in self.values() if i.for_account == item.id), False, (i for i in self.values() if i.for_account == item.id), False,
) )
def refilter_accounts(self) -> None: def refilter_accounts(self) -> None:
self.refilter(lambda i: i.type is Account) # type: ignore self.refilter(lambda i: i.type is Account) # type: ignore
class MatchingAccounts(ModelFilter): class MatchingAccounts(ModelFilter):
"""List of our accounts in `AllRooms` with at least one matching room if """List of our accounts in `AllRooms` with at least one matching room if
a `filter` is set, else list of all accounts. a `filter` is set, else list of all accounts.
""" """
def __init__(self, all_rooms: AllRooms) -> None: def __init__(self, all_rooms: AllRooms) -> None:
self.all_rooms = all_rooms self.all_rooms = all_rooms
self.all_rooms.items_changed_callbacks.append(self.refilter) self.all_rooms.items_changed_callbacks.append(self.refilter)
super().__init__(sync_id="matching_accounts") super().__init__(sync_id="matching_accounts")
def accept_source(self, source: Model) -> bool: def accept_source(self, source: Model) -> bool:
return source.sync_id == "accounts" return source.sync_id == "accounts"
def accept_item(self, item: ModelItem) -> bool: def accept_item(self, item: ModelItem) -> bool:
if not self.all_rooms.filter: if not self.all_rooms.filter:
return True return True
return next( return next(
(i for i in self.all_rooms.values() if i.id == item.id), (i for i in self.all_rooms.values() if i.id == item.id),
False, False,
) )
class FilteredMembers(FieldSubstringFilter): class FilteredMembers(FieldSubstringFilter):
"""Filtered list of members for a room.""" """Filtered list of members for a room."""
def __init__(self, user_id: str, room_id: str) -> None: def __init__(self, user_id: str, room_id: str) -> None:
self.user_id = user_id self.user_id = user_id
self.room_id = room_id self.room_id = room_id
sync_id = (user_id, room_id, "filtered_members") sync_id = (user_id, room_id, "filtered_members")
super().__init__(sync_id=sync_id, fields=("display_name",)) super().__init__(sync_id=sync_id, fields=("display_name",))
def accept_source(self, source: Model) -> bool: def accept_source(self, source: Model) -> bool:
return source.sync_id == (self.user_id, self.room_id, "members") return source.sync_id == (self.user_id, self.room_id, "members")
class AutoCompletedMembers(FieldStringFilter): class AutoCompletedMembers(FieldStringFilter):
"""Filtered list of mentionable members for tab-completion.""" """Filtered list of mentionable members for tab-completion."""
def __init__(self, user_id: str, room_id: str) -> None: def __init__(self, user_id: str, room_id: str) -> None:
self.user_id = user_id self.user_id = user_id
self.room_id = room_id self.room_id = room_id
sync_id = (user_id, room_id, "autocompleted_members") sync_id = (user_id, room_id, "autocompleted_members")
super().__init__( super().__init__(
sync_id = sync_id, sync_id = sync_id,
fields = ("display_name", "id"), fields = ("display_name", "id"),
no_filter_accept_all_items = False, no_filter_accept_all_items = False,
) )
def accept_source(self, source: Model) -> bool: def accept_source(self, source: Model) -> bool:
return source.sync_id == (self.user_id, self.room_id, "members") return source.sync_id == (self.user_id, self.room_id, "members")
def match(self, fields: Dict[str, str], filtr: str) -> bool: def match(self, fields: Dict[str, str], filtr: str) -> bool:
fields["id"] = fields["id"][1:] # remove leading @ fields["id"] = fields["id"][1:] # remove leading @
return super().match(fields, filtr) return super().match(fields, filtr)
class FilteredHomeservers(FieldSubstringFilter): class FilteredHomeservers(FieldSubstringFilter):
"""Filtered list of public Matrix homeservers.""" """Filtered list of public Matrix homeservers."""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(sync_id="filtered_homeservers", fields=("id", "name")) super().__init__(sync_id="filtered_homeservers", fields=("id", "name"))
def accept_source(self, source: Model) -> bool: def accept_source(self, source: Model) -> bool:
return source.sync_id == "homeservers" return source.sync_id == "homeservers"

File diff suppressed because it is too large Load Diff

View File

@ -2,46 +2,46 @@ from collections import UserDict
from typing import TYPE_CHECKING, Any, Dict, Iterator from typing import TYPE_CHECKING, Any, Dict, Iterator
if TYPE_CHECKING: if TYPE_CHECKING:
from .section import Section from .section import Section
from .. import color from .. import color
PCN_GLOBALS: Dict[str, Any] = { PCN_GLOBALS: Dict[str, Any] = {
"color": color.Color, "color": color.Color,
"hsluv": color.hsluv, "hsluv": color.hsluv,
"hsluva": color.hsluva, "hsluva": color.hsluva,
"hsl": color.hsl, "hsl": color.hsl,
"hsla": color.hsla, "hsla": color.hsla,
"rgb": color.rgb, "rgb": color.rgb,
"rgba": color.rgba, "rgba": color.rgba,
} }
class GlobalsDict(UserDict): class GlobalsDict(UserDict):
def __init__(self, section: "Section") -> None: def __init__(self, section: "Section") -> None:
super().__init__() super().__init__()
self.section = section self.section = section
@property @property
def full_dict(self) -> Dict[str, Any]: def full_dict(self) -> Dict[str, Any]:
return { return {
**PCN_GLOBALS, **PCN_GLOBALS,
**(self.section.root if self.section.root else {}), **(self.section.root if self.section.root else {}),
**(self.section.root.globals if self.section.root else {}), **(self.section.root.globals if self.section.root else {}),
"self": self.section, "self": self.section,
"parent": self.section.parent, "parent": self.section.parent,
"root": self.section.parent, "root": self.section.parent,
**self.data, **self.data,
} }
def __getitem__(self, key: str) -> Any: def __getitem__(self, key: str) -> Any:
return self.full_dict[key] return self.full_dict[key]
def __iter__(self) -> Iterator[str]: def __iter__(self) -> Iterator[str]:
return iter(self.full_dict) return iter(self.full_dict)
def __len__(self) -> int: def __len__(self) -> int:
return len(self.full_dict) return len(self.full_dict)
def __repr__(self) -> str: def __repr__(self) -> str:
return repr(self.full_dict) return repr(self.full_dict)

View File

@ -3,50 +3,50 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Dict, Type from typing import TYPE_CHECKING, Any, Callable, Dict, Type
if TYPE_CHECKING: if TYPE_CHECKING:
from .section import Section from .section import Section
TYPE_PROCESSORS: Dict[str, Callable[[Any], Any]] = { TYPE_PROCESSORS: Dict[str, Callable[[Any], Any]] = {
"tuple": lambda v: tuple(v), "tuple": lambda v: tuple(v),
"set": lambda v: set(v), "set": lambda v: set(v),
} }
class Unset: class Unset:
pass pass
@dataclass @dataclass
class Property: class Property:
name: str = field() name: str = field()
annotation: str = field() annotation: str = field()
expression: str = field() expression: str = field()
section: "Section" = field() section: "Section" = field()
value_override: Any = Unset value_override: Any = Unset
def __get__(self, obj: "Section", objtype: Type["Section"]) -> Any: def __get__(self, obj: "Section", objtype: Type["Section"]) -> Any:
if not obj: if not obj:
return self return self
if self.value_override is not Unset: if self.value_override is not Unset:
return self.value_override return self.value_override
env = obj.globals env = obj.globals
result = eval(self.expression, dict(env), env) # nosec result = eval(self.expression, dict(env), env) # nosec
return process_value(self.annotation, result) return process_value(self.annotation, result)
def __set__(self, obj: "Section", value: Any) -> None: def __set__(self, obj: "Section", value: Any) -> None:
self.value_override = value self.value_override = value
obj._edited[self.name] = value obj._edited[self.name] = value
def process_value(annotation: str, value: Any) -> Any: def process_value(annotation: str, value: Any) -> Any:
annotation = re.sub(r"\[.*\]$", "", annotation) annotation = re.sub(r"\[.*\]$", "", annotation)
if annotation in TYPE_PROCESSORS: if annotation in TYPE_PROCESSORS:
return TYPE_PROCESSORS[annotation](value) return TYPE_PROCESSORS[annotation](value)
if annotation.lower() in TYPE_PROCESSORS: if annotation.lower() in TYPE_PROCESSORS:
return TYPE_PROCESSORS[annotation.lower()](value) return TYPE_PROCESSORS[annotation.lower()](value)
return value return value

View File

@ -7,8 +7,8 @@ from dataclasses import dataclass, field
from operator import attrgetter from operator import attrgetter
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any, Callable, ClassVar, Dict, Generator, List, Optional, Set, Tuple, Type, Any, Callable, ClassVar, Dict, Generator, List, Optional, Set, Tuple, Type,
Union, Union,
) )
import pyotherside import pyotherside
@ -25,423 +25,423 @@ assert BUILTINS_DIR.name == "src"
@dataclass(repr=False, eq=False) @dataclass(repr=False, eq=False)
class Section(MutableMapping): class Section(MutableMapping):
sections: ClassVar[Set[str]] = set() sections: ClassVar[Set[str]] = set()
methods: ClassVar[Set[str]] = set() methods: ClassVar[Set[str]] = set()
properties: ClassVar[Set[str]] = set() properties: ClassVar[Set[str]] = set()
order: ClassVar[Dict[str, None]] = OrderedDict() order: ClassVar[Dict[str, None]] = OrderedDict()
source_path: Optional[Path] = None source_path: Optional[Path] = None
root: Optional["Section"] = None root: Optional["Section"] = None
parent: Optional["Section"] = None parent: Optional["Section"] = None
builtins_path: Path = BUILTINS_DIR builtins_path: Path = BUILTINS_DIR
included: List[Path] = field(default_factory=list) included: List[Path] = field(default_factory=list)
globals: GlobalsDict = field(init=False) globals: GlobalsDict = field(init=False)
_edited: Dict[str, Any] = field(init=False, default_factory=dict) _edited: Dict[str, Any] = field(init=False, default_factory=dict)
def __init_subclass__(cls, **kwargs) -> None: def __init_subclass__(cls, **kwargs) -> None:
# Make these attributes not shared between Section and its subclasses # Make these attributes not shared between Section and its subclasses
cls.sections = set() cls.sections = set()
cls.methods = set() cls.methods = set()
cls.properties = set() cls.properties = set()
cls.order = OrderedDict() cls.order = OrderedDict()
for parent_class in cls.__bases__: for parent_class in cls.__bases__:
if not issubclass(parent_class, Section): if not issubclass(parent_class, Section):
continue continue
cls.sections |= parent_class.sections # union operator cls.sections |= parent_class.sections # union operator
cls.methods |= parent_class.methods cls.methods |= parent_class.methods
cls.properties |= parent_class.properties cls.properties |= parent_class.properties
cls.order.update(parent_class.order) cls.order.update(parent_class.order)
super().__init_subclass__(**kwargs) # type: ignore super().__init_subclass__(**kwargs) # type: ignore
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.globals = GlobalsDict(self) self.globals = GlobalsDict(self)
def __getattr__(self, name: str) -> Union["Section", Any]: def __getattr__(self, name: str) -> Union["Section", Any]:
# This method signature tells mypy about the dynamic attribute types # This method signature tells mypy about the dynamic attribute types
# we can access. The body is run for attributes that aren't found. # we can access. The body is run for attributes that aren't found.
return super().__getattribute__(name) return super().__getattribute__(name)
def __setattr__(self, name: str, value: Any) -> None: def __setattr__(self, name: str, value: Any) -> None:
# This method tells mypy about the dynamic attribute types we can set. # This method tells mypy about the dynamic attribute types we can set.
# The body is also run when setting an existing or new attribute. # The body is also run when setting an existing or new attribute.
if name in self.__dataclass_fields__: if name in self.__dataclass_fields__:
super().__setattr__(name, value) super().__setattr__(name, value)
return return
if name in self.properties: if name in self.properties:
value = process_value(getattr(type(self), name).annotation, value) value = process_value(getattr(type(self), name).annotation, value)
if self[name] == value: if self[name] == value:
return return
getattr(type(self), name).value_override = value getattr(type(self), name).value_override = value
self._edited[name] = value self._edited[name] = value
return return
if name in self.sections or isinstance(value, Section): if name in self.sections or isinstance(value, Section):
raise NotImplementedError(f"cannot set section {name!r}") raise NotImplementedError(f"cannot set section {name!r}")
if name in self.methods or callable(value): if name in self.methods or callable(value):
raise NotImplementedError(f"cannot set method {name!r}") raise NotImplementedError(f"cannot set method {name!r}")
self._set_property(name, "Any", "None") self._set_property(name, "Any", "None")
getattr(type(self), name).value_override = value getattr(type(self), name).value_override = value
self._edited[name] = value self._edited[name] = value
def __delattr__(self, name: str) -> None: def __delattr__(self, name: str) -> None:
raise NotImplementedError(f"cannot delete existing attribute {name!r}") raise NotImplementedError(f"cannot delete existing attribute {name!r}")
def __getitem__(self, key: str) -> Any: def __getitem__(self, key: str) -> Any:
try: try:
return getattr(self, key) return getattr(self, key)
except AttributeError as err: except AttributeError as err:
raise KeyError(str(err)) raise KeyError(str(err))
def __setitem__(self, key: str, value: Union["Section", str]) -> None: def __setitem__(self, key: str, value: Union["Section", str]) -> None:
setattr(self, key, value) setattr(self, key, value)
def __delitem__(self, key: str) -> None: def __delitem__(self, key: str) -> None:
delattr(self, key) delattr(self, key)
def __iter__(self) -> Generator[str, None, None]: def __iter__(self) -> Generator[str, None, None]:
for attr_name in self.order: for attr_name in self.order:
yield attr_name yield attr_name
def __len__(self) -> int: def __len__(self) -> int:
return len(self.order) return len(self.order)
def __eq__(self, obj: Any) -> bool: def __eq__(self, obj: Any) -> bool:
if not isinstance(obj, Section): if not isinstance(obj, Section):
return False return False
if self.globals.data != obj.globals.data or self.order != obj.order: if self.globals.data != obj.globals.data or self.order != obj.order:
return False return False
return not any(self[attr] != obj[attr] for attr in self.order) return not any(self[attr] != obj[attr] for attr in self.order)
def __repr__(self) -> str: def __repr__(self) -> str:
name: str = type(self).__name__ name: str = type(self).__name__
children: List[str] = [] children: List[str] = []
content: str = "" content: str = ""
newline: bool = False newline: bool = False
for attr_name in self.order: for attr_name in self.order:
value = getattr(self, attr_name) value = getattr(self, attr_name)
if attr_name in self.sections: if attr_name in self.sections:
before = "\n" if children else "" before = "\n" if children else ""
newline = True newline = True
try: try:
children.append(f"{before}{value!r},") children.append(f"{before}{value!r},")
except RecursionError as err: except RecursionError as err:
name = type(value).__name__ name = type(value).__name__
children.append(f"{before}{name}(\n {err!r}\n),") children.append(f"{before}{name}(\n {err!r}\n),")
pass pass
elif attr_name in self.methods: elif attr_name in self.methods:
before = "\n" if children else "" before = "\n" if children else ""
newline = True newline = True
children.append(f"{before}def {value.__name__}(…),") children.append(f"{before}def {value.__name__}(…),")
elif attr_name in self.properties: elif attr_name in self.properties:
before = "\n" if newline else "" before = "\n" if newline else ""
newline = False newline = False
try: try:
children.append(f"{before}{attr_name} = {value!r},") children.append(f"{before}{attr_name} = {value!r},")
except RecursionError as err: except RecursionError as err:
children.append(f"{before}{attr_name} = {err!r},") children.append(f"{before}{attr_name} = {err!r},")
else: else:
newline = False newline = False
if children: if children:
content = "\n%s\n" % textwrap.indent("\n".join(children), " " * 4) content = "\n%s\n" % textwrap.indent("\n".join(children), " " * 4)
return f"{name}({content})" return f"{name}({content})"
def children(self) -> Tuple[Tuple[str, Union["Section", Any]], ...]: def children(self) -> Tuple[Tuple[str, Union["Section", Any]], ...]:
"""Return pairs of (name, value) for child sections and properties.""" """Return pairs of (name, value) for child sections and properties."""
return tuple((name, getattr(self, name)) for name in self) return tuple((name, getattr(self, name)) for name in self)
@classmethod @classmethod
def _register_set_attr(cls, name: str, add_to_set_name: str) -> None: def _register_set_attr(cls, name: str, add_to_set_name: str) -> None:
cls.methods.discard(name) cls.methods.discard(name)
cls.properties.discard(name) cls.properties.discard(name)
cls.sections.discard(name) cls.sections.discard(name)
getattr(cls, add_to_set_name).add(name) getattr(cls, add_to_set_name).add(name)
cls.order[name] = None cls.order[name] = None
for subclass in cls.__subclasses__(): for subclass in cls.__subclasses__():
subclass._register_set_attr(name, add_to_set_name) subclass._register_set_attr(name, add_to_set_name)
def _set_section(self, section: "Section") -> None: def _set_section(self, section: "Section") -> None:
name = type(section).__name__ name = type(section).__name__
if hasattr(self, name) and name not in self.order: if hasattr(self, name) and name not in self.order:
raise AttributeError(f"{name!r}: forbidden name") raise AttributeError(f"{name!r}: forbidden name")
if name in self.sections: if name in self.sections:
self[name].deep_merge(section) self[name].deep_merge(section)
return return
self._register_set_attr(name, "sections") self._register_set_attr(name, "sections")
setattr(type(self), name, section) setattr(type(self), name, section)
def _set_method(self, name: str, method: Callable) -> None: def _set_method(self, name: str, method: Callable) -> None:
if hasattr(self, name) and name not in self.order: if hasattr(self, name) and name not in self.order:
raise AttributeError(f"{name!r}: forbidden name") raise AttributeError(f"{name!r}: forbidden name")
self._register_set_attr(name, "methods") self._register_set_attr(name, "methods")
setattr(type(self), name, method) setattr(type(self), name, method)
def _set_property( def _set_property(
self, name: str, annotation: str, expression: str, self, name: str, annotation: str, expression: str,
) -> None: ) -> None:
if hasattr(self, name) and name not in self.order: if hasattr(self, name) and name not in self.order:
raise AttributeError(f"{name!r}: forbidden name") raise AttributeError(f"{name!r}: forbidden name")
prop = Property(name, annotation, expression, self) prop = Property(name, annotation, expression, self)
self._register_set_attr(name, "properties") self._register_set_attr(name, "properties")
setattr(type(self), name, prop) setattr(type(self), name, prop)
def deep_merge(self, section2: "Section") -> None: def deep_merge(self, section2: "Section") -> None:
self.included += section2.included self.included += section2.included
for key in section2: for key in section2:
if key in self.sections and key in section2.sections: if key in self.sections and key in section2.sections:
self.globals.data.update(section2.globals.data) self.globals.data.update(section2.globals.data)
self[key].deep_merge(section2[key]) self[key].deep_merge(section2[key])
elif key in section2.sections: elif key in section2.sections:
self.globals.data.update(section2.globals.data) self.globals.data.update(section2.globals.data)
new_type = type(key, (Section,), {}) new_type = type(key, (Section,), {})
instance = new_type( instance = new_type(
source_path = self.source_path, source_path = self.source_path,
root = self.root or self, root = self.root or self,
parent = self, parent = self,
builtins_path = self.builtins_path, builtins_path = self.builtins_path,
) )
self._set_section(instance) self._set_section(instance)
instance.deep_merge(section2[key]) instance.deep_merge(section2[key])
elif key in section2.methods: elif key in section2.methods:
self._set_method(key, section2[key]) self._set_method(key, section2[key])
else: else:
prop2 = getattr(type(section2), key) prop2 = getattr(type(section2), key)
self._set_property(key, prop2.annotation, prop2.expression) self._set_property(key, prop2.annotation, prop2.expression)
def include_file(self, path: Union[Path, str]) -> None: def include_file(self, path: Union[Path, str]) -> None:
path = Path(path) path = Path(path)
if not path.is_absolute() and self.source_path: if not path.is_absolute() and self.source_path:
path = self.source_path.parent / path path = self.source_path.parent / path
with suppress(ValueError): with suppress(ValueError):
self.included.remove(path) self.included.remove(path)
self.included.append(path) self.included.append(path)
self.deep_merge(Section.from_file(path)) self.deep_merge(Section.from_file(path))
def include_builtin(self, relative_path: Union[Path, str]) -> None: def include_builtin(self, relative_path: Union[Path, str]) -> None:
path = self.builtins_path / relative_path path = self.builtins_path / relative_path
with suppress(ValueError): with suppress(ValueError):
self.included.remove(path) self.included.remove(path)
self.included.append(path) self.included.append(path)
self.deep_merge(Section.from_file(path)) self.deep_merge(Section.from_file(path))
def as_dict(self, _section: Optional["Section"] = None) -> Dict[str, Any]: def as_dict(self, _section: Optional["Section"] = None) -> Dict[str, Any]:
dct = {} dct = {}
section = self if _section is None else _section section = self if _section is None else _section
for key, value in section.items(): for key, value in section.items():
if isinstance(value, Section): if isinstance(value, Section):
dct[key] = self.as_dict(value) dct[key] = self.as_dict(value)
else: else:
dct[key] = value dct[key] = value
return dct return dct
def edits_as_dict( def edits_as_dict(
self, _section: Optional["Section"] = None, self, _section: Optional["Section"] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
warning = ( warning = (
"This file is generated when settings are changed from the GUI, " "This file is generated when settings are changed from the GUI, "
"and properties in it override the ones in the corresponding " "and properties in it override the ones in the corresponding "
"PCN user config file. " "PCN user config file. "
"If a property is gets changed in the PCN file, any corresponding " "If a property is gets changed in the PCN file, any corresponding "
"property override here is removed." "property override here is removed."
) )
if _section is None: if _section is None:
section = self section = self
dct = {"__comment": warning, "set": section._edited.copy()} dct = {"__comment": warning, "set": section._edited.copy()}
add_to = dct["set"] add_to = dct["set"]
else: else:
section = _section section = _section
dct = { dct = {
prop_name: ( prop_name: (
getattr(type(section), prop_name).expression, getattr(type(section), prop_name).expression,
value_override, value_override,
) )
for prop_name, value_override in section._edited.items() for prop_name, value_override in section._edited.items()
} }
add_to = dct add_to = dct
for name in section.sections: for name in section.sections:
edits = section.edits_as_dict(section[name]) edits = section.edits_as_dict(section[name])
if edits: if edits:
add_to[name] = edits # type: ignore add_to[name] = edits # type: ignore
return dct return dct
def deep_merge_edits( def deep_merge_edits(
self, edits: Dict[str, Any], has_expressions: bool = True, self, edits: Dict[str, Any], has_expressions: bool = True,
) -> bool: ) -> bool:
changes = False changes = False
if not self.parent: # this is Root if not self.parent: # this is Root
edits = edits.get("set", {}) edits = edits.get("set", {})
for name, value in edits.copy().items(): for name, value in edits.copy().items():
if isinstance(self.get(name), Section) and isinstance(value, dict): if isinstance(self.get(name), Section) and isinstance(value, dict):
if self[name].deep_merge_edits(value, has_expressions): if self[name].deep_merge_edits(value, has_expressions):
changes = True changes = True
elif not has_expressions: elif not has_expressions:
self[name] = value self[name] = value
elif isinstance(value, (tuple, list)): elif isinstance(value, (tuple, list)):
user_expression, gui_value = value user_expression, gui_value = value
if not hasattr(type(self), name): if not hasattr(type(self), name):
self[name] = gui_value self[name] = gui_value
elif getattr(type(self), name).expression == user_expression: elif getattr(type(self), name).expression == user_expression:
self[name] = gui_value self[name] = gui_value
else: else:
# If user changed their config file, discard the GUI edit # If user changed their config file, discard the GUI edit
del edits[name] del edits[name]
changes = True changes = True
return changes return changes
@property @property
def all_includes(self) -> Generator[Path, None, None]: def all_includes(self) -> Generator[Path, None, None]:
yield from self.included yield from self.included
for sub in self.sections: for sub in self.sections:
yield from self[sub].all_includes yield from self[sub].all_includes
@classmethod @classmethod
def from_source_code( def from_source_code(
cls, cls,
code: str, code: str,
path: Optional[Path] = None, path: Optional[Path] = None,
builtins: Optional[Path] = None, builtins: Optional[Path] = None,
*, *,
inherit: Tuple[Type["Section"], ...] = (), inherit: Tuple[Type["Section"], ...] = (),
node: Union[None, red.RedBaron, red.ClassNode] = None, node: Union[None, red.RedBaron, red.ClassNode] = None,
name: str = "Root", name: str = "Root",
root: Optional["Section"] = None, root: Optional["Section"] = None,
parent: Optional["Section"] = None, parent: Optional["Section"] = None,
) -> "Section": ) -> "Section":
builtins = builtins or BUILTINS_DIR builtins = builtins or BUILTINS_DIR
section: Type["Section"] = type(name, inherit or (Section,), {}) section: Type["Section"] = type(name, inherit or (Section,), {})
instance: Section = section(path, root, parent, builtins) instance: Section = section(path, root, parent, builtins)
node = node or red.RedBaron(code) node = node or red.RedBaron(code)
for child in node.node_list: for child in node.node_list:
if isinstance(child, red.ClassNode): if isinstance(child, red.ClassNode):
root_arg = instance if root is None else root root_arg = instance if root is None else root
child_inherit = [] child_inherit = []
for name in child.inherit_from.dumps().split(","): for name in child.inherit_from.dumps().split(","):
name = name.strip() name = name.strip()
if name: if name:
child_inherit.append(type(attrgetter(name)(root_arg))) child_inherit.append(type(attrgetter(name)(root_arg)))
instance._set_section(section.from_source_code( instance._set_section(section.from_source_code(
code = code, code = code,
path = path, path = path,
builtins = builtins, builtins = builtins,
inherit = tuple(child_inherit), inherit = tuple(child_inherit),
node = child, node = child,
name = child.name, name = child.name,
root = root_arg, root = root_arg,
parent = instance, parent = instance,
)) ))
elif isinstance(child, red.AssignmentNode): elif isinstance(child, red.AssignmentNode):
if isinstance(child.target, red.NameNode): if isinstance(child.target, red.NameNode):
name = child.target.value name = child.target.value
else: else:
name = str(child.target.to_python()) name = str(child.target.to_python())
instance._set_property( instance._set_property(
name, name,
child.annotation.dumps() if child.annotation else "", child.annotation.dumps() if child.annotation else "",
child.value.dumps(), child.value.dumps(),
) )
else: else:
env = instance.globals env = instance.globals
exec(child.dumps(), dict(env), env) # nosec exec(child.dumps(), dict(env), env) # nosec
if isinstance(child, red.DefNode): if isinstance(child, red.DefNode):
instance._set_method(child.name, env[child.name]) instance._set_method(child.name, env[child.name])
return instance return instance
@classmethod @classmethod
def from_file( def from_file(
cls, path: Union[str, Path], builtins: Union[str, Path] = BUILTINS_DIR, cls, path: Union[str, Path], builtins: Union[str, Path] = BUILTINS_DIR,
) -> "Section": ) -> "Section":
path = Path(re.sub(r"^qrc:/", "", str(path))) path = Path(re.sub(r"^qrc:/", "", str(path)))
try: try:
content = pyotherside.qrc_get_file_contents(str(path)).decode() content = pyotherside.qrc_get_file_contents(str(path)).decode()
except ValueError: # App was compiled without QRC except ValueError: # App was compiled without QRC
content = path.read_text() content = path.read_text()
return Section.from_source_code(content, path, Path(builtins)) return Section.from_source_code(content, path, Path(builtins))

View File

@ -8,89 +8,89 @@ from typing import TYPE_CHECKING, Dict, Optional
from .utils import AutoStrEnum, auto from .utils import AutoStrEnum, auto
if TYPE_CHECKING: if TYPE_CHECKING:
from .models.items import Account, Member from .models.items import Account, Member
ORDER: Dict[str, int] = { ORDER: Dict[str, int] = {
"online": 0, "online": 0,
"unavailable": 1, "unavailable": 1,
"invisible": 2, "invisible": 2,
"offline": 3, "offline": 3,
} }
@dataclass @dataclass
class Presence: class Presence:
"""Represents a single matrix user's presence fields. """Represents a single matrix user's presence fields.
These objects are stored in `Backend.presences`, indexed by user ID. These objects are stored in `Backend.presences`, indexed by user ID.
It must only be instanced when receiving a `PresenceEvent` or It must only be instanced when receiving a `PresenceEvent` or
registering an `Account` model item. registering an `Account` model item.
When receiving a `PresenceEvent`, we get or create a `Presence` object in When receiving a `PresenceEvent`, we get or create a `Presence` object in
`Backend.presences` for the targeted user. If the user is registered in any `Backend.presences` for the targeted user. If the user is registered in any
room, add its `Member` model item to `members`. Finally, update every room, add its `Member` model item to `members`. Finally, update every
`Member` presence fields inside `members`. `Member` presence fields inside `members`.
When a room member is registered, we try to find a `Presence` in When a room member is registered, we try to find a `Presence` in
`Backend.presences` for that user ID. If found, the `Member` item is added `Backend.presences` for that user ID. If found, the `Member` item is added
to `members`. to `members`.
When an Account model is registered, we create a `Presence` in When an Account model is registered, we create a `Presence` in
`Backend.presences` for the accountu's user ID whether the server supports `Backend.presences` for the accountu's user ID whether the server supports
presence or not (we cannot know yet at this point), presence or not (we cannot know yet at this point),
and assign that `Account` to the `Presence.account` field. and assign that `Account` to the `Presence.account` field.
Special attributes: Special attributes:
members: A `{room_id: Member}` dict for storing room members related to members: A `{room_id: Member}` dict for storing room members related to
this `Presence`. As each room has its own `Member`s objects, we this `Presence`. As each room has its own `Member`s objects, we
have to keep track of their presence fields. `Member`s are indexed have to keep track of their presence fields. `Member`s are indexed
by room ID. by room ID.
account: `Account` related to this `Presence`, if any. Should be account: `Account` related to this `Presence`, if any. Should be
assigned when client starts (`MatrixClient._start()`) and assigned when client starts (`MatrixClient._start()`) and
cleared when client stops (`MatrixClient._start()`). cleared when client stops (`MatrixClient._start()`).
""" """
class State(AutoStrEnum): class State(AutoStrEnum):
offline = auto() # can mean offline, invisible or unknwon offline = auto() # can mean offline, invisible or unknwon
unavailable = auto() unavailable = auto()
online = auto() online = auto()
invisible = auto() invisible = auto()
def __lt__(self, other: "Presence.State") -> bool: def __lt__(self, other: "Presence.State") -> bool:
return ORDER[self.value] < ORDER[other.value] return ORDER[self.value] < ORDER[other.value]
presence: State = State.offline presence: State = State.offline
currently_active: bool = False currently_active: bool = False
last_active_at: datetime = datetime.fromtimestamp(0) last_active_at: datetime = datetime.fromtimestamp(0)
status_msg: str = "" status_msg: str = ""
members: Dict[str, "Member"] = field(default_factory=dict) members: Dict[str, "Member"] = field(default_factory=dict)
account: Optional["Account"] = None account: Optional["Account"] = None
def update_members(self) -> None: def update_members(self) -> None:
"""Update presence fields of every `Member` in `members`. """Update presence fields of every `Member` in `members`.
Currently it is only called when receiving a `PresenceEvent` and when Currently it is only called when receiving a `PresenceEvent` and when
registering room members. registering room members.
""" """
for member in self.members.values(): for member in self.members.values():
member.set_fields( member.set_fields(
presence = self.presence, presence = self.presence,
status_msg = self.status_msg, status_msg = self.status_msg,
last_active_at = self.last_active_at, last_active_at = self.last_active_at,
currently_active = self.currently_active, currently_active = self.currently_active,
) )
def update_account(self) -> None: def update_account(self) -> None:
"""Update presence fields of `Account` related to this `Presence`.""" """Update presence fields of `Account` related to this `Presence`."""
if self.account: if self.account:
self.account.set_fields( self.account.set_fields(
presence = self.presence, presence = self.presence,
status_msg = self.status_msg, status_msg = self.status_msg,
last_active_at = self.last_active_at, last_active_at = self.last_active_at,
currently_active = self.currently_active, currently_active = self.currently_active,
) )

View File

@ -10,117 +10,117 @@ import pyotherside
from .utils import serialize_value_for_qml from .utils import serialize_value_for_qml
if TYPE_CHECKING: if TYPE_CHECKING:
from .models import SyncId from .models import SyncId
from .user_files import UserFile from .user_files import UserFile
@dataclass @dataclass
class PyOtherSideEvent: class PyOtherSideEvent:
"""Event that will be sent on instanciation to QML by PyOtherSide.""" """Event that will be sent on instanciation to QML by PyOtherSide."""
def __post_init__(self) -> None: def __post_init__(self) -> None:
# XXX: CPython 3.6 or any Python implemention >= 3.7 is required for # XXX: CPython 3.6 or any Python implemention >= 3.7 is required for
# correct __dataclass_fields__ dict order. # correct __dataclass_fields__ dict order.
args = [ args = [
serialize_value_for_qml(getattr(self, field)) serialize_value_for_qml(getattr(self, field))
for field in self.__dataclass_fields__ # type: ignore for field in self.__dataclass_fields__ # type: ignore
if field != "callbacks" if field != "callbacks"
] ]
pyotherside.send(type(self).__name__, *args) pyotherside.send(type(self).__name__, *args)
@dataclass @dataclass
class NotificationRequested(PyOtherSideEvent): class NotificationRequested(PyOtherSideEvent):
"""Request a notification bubble, sound or window urgency hint. """Request a notification bubble, sound or window urgency hint.
Urgency hints usually flash or highlight the program's icon in a taskbar, Urgency hints usually flash or highlight the program's icon in a taskbar,
dock or panel. dock or panel.
""" """
id: str = field() id: str = field()
critical: bool = False critical: bool = False
bubble: bool = False bubble: bool = False
sound: bool = False sound: bool = False
urgency_hint: bool = False urgency_hint: bool = False
# Bubble parameters # Bubble parameters
title: str = "" title: str = ""
body: str = "" body: str = ""
image: Union[Path, str] = "" image: Union[Path, str] = ""
@dataclass @dataclass
class CoroutineDone(PyOtherSideEvent): class CoroutineDone(PyOtherSideEvent):
"""Indicate that an asyncio coroutine finished.""" """Indicate that an asyncio coroutine finished."""
uuid: str = field() uuid: str = field()
result: Any = None result: Any = None
exception: Optional[Exception] = None exception: Optional[Exception] = None
traceback: Optional[str] = None traceback: Optional[str] = None
@dataclass @dataclass
class LoopException(PyOtherSideEvent): class LoopException(PyOtherSideEvent):
"""Indicate an uncaught exception occurance in the asyncio loop.""" """Indicate an uncaught exception occurance in the asyncio loop."""
message: str = field() message: str = field()
exception: Optional[Exception] = field() exception: Optional[Exception] = field()
traceback: Optional[str] = None traceback: Optional[str] = None
@dataclass @dataclass
class Pre070SettingsDetected(PyOtherSideEvent): class Pre070SettingsDetected(PyOtherSideEvent):
"""Warn that a pre-0.7.0 settings.json file exists.""" """Warn that a pre-0.7.0 settings.json file exists."""
path: Path = field() path: Path = field()
@dataclass @dataclass
class UserFileChanged(PyOtherSideEvent): class UserFileChanged(PyOtherSideEvent):
"""Indicate that a config or data file changed on disk.""" """Indicate that a config or data file changed on disk."""
type: Type["UserFile"] = field() type: Type["UserFile"] = field()
new_data: Any = field() new_data: Any = field()
@dataclass @dataclass
class ModelEvent(PyOtherSideEvent): class ModelEvent(PyOtherSideEvent):
"""Base class for model change events.""" """Base class for model change events."""
sync_id: "SyncId" = field() sync_id: "SyncId" = field()
@dataclass @dataclass
class ModelItemSet(ModelEvent): class ModelItemSet(ModelEvent):
"""Indicate `ModelItem` insert or field changes in a `Backend` `Model`.""" """Indicate `ModelItem` insert or field changes in a `Backend` `Model`."""
index_then: Optional[int] = field() index_then: Optional[int] = field()
index_now: int = field() index_now: int = field()
fields: Dict[str, Any] = field() fields: Dict[str, Any] = field()
@dataclass @dataclass
class ModelItemDeleted(ModelEvent): class ModelItemDeleted(ModelEvent):
"""Indicate the removal of a `ModelItem` from a `Backend` `Model`.""" """Indicate the removal of a `ModelItem` from a `Backend` `Model`."""
index: int = field() index: int = field()
count: int = 1 count: int = 1
ids: Sequence[Any] = () ids: Sequence[Any] = ()
@dataclass @dataclass
class ModelCleared(ModelEvent): class ModelCleared(ModelEvent):
"""Indicate that a `Backend` `Model` was cleared.""" """Indicate that a `Backend` `Model` was cleared."""
@dataclass @dataclass
class DevicesUpdated(PyOtherSideEvent): class DevicesUpdated(PyOtherSideEvent):
"""Indicate changes in devices for us or users we share a room with.""" """Indicate changes in devices for us or users we share a room with."""
our_user_id: str = field() our_user_id: str = field()
@dataclass @dataclass
class InvalidAccessToken(PyOtherSideEvent): class InvalidAccessToken(PyOtherSideEvent):
"""Indicate one of our account's access token is invalid or revoked.""" """Indicate one of our account's access token is invalid or revoked."""
user_id: str = field() user_id: str = field()

View File

@ -29,143 +29,143 @@ from .pyotherside_events import CoroutineDone, LoopException
class QMLBridge: class QMLBridge:
"""Setup asyncio and provide methods to call coroutines from QML. """Setup asyncio and provide methods to call coroutines from QML.
A thread is created to run the asyncio loop in, to ensure all calls from A thread is created to run the asyncio loop in, to ensure all calls from
QML return instantly. QML return instantly.
Synchronous methods are provided for QML to call coroutines using Synchronous methods are provided for QML to call coroutines using
PyOtherSide, which doesn't have this ability out of the box. PyOtherSide, which doesn't have this ability out of the box.
Attributes: Attributes:
backend: The `backend.Backend` object containing general coroutines backend: The `backend.Backend` object containing general coroutines
for QML and that manages `MatrixClient` objects. for QML and that manages `MatrixClient` objects.
""" """
def __init__(self) -> None: def __init__(self) -> None:
try: try:
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
except RuntimeError: except RuntimeError:
self._loop = asyncio.new_event_loop() self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop) asyncio.set_event_loop(self._loop)
self._loop.set_exception_handler(self._loop_exception_handler) self._loop.set_exception_handler(self._loop_exception_handler)
from .backend import Backend from .backend import Backend
self.backend: Backend = Backend() self.backend: Backend = Backend()
self._running_futures: Dict[str, Future] = {} self._running_futures: Dict[str, Future] = {}
self._cancelled_early: Set[str] = set() self._cancelled_early: Set[str] = set()
Thread(target=self._start_asyncio_loop).start() Thread(target=self._start_asyncio_loop).start()
def _loop_exception_handler( def _loop_exception_handler(
self, loop: asyncio.AbstractEventLoop, context: dict, self, loop: asyncio.AbstractEventLoop, context: dict,
) -> None: ) -> None:
if "exception" in context: if "exception" in context:
err = context["exception"] err = context["exception"]
trace = "".join( trace = "".join(
traceback.format_exception(type(err), err, err.__traceback__), traceback.format_exception(type(err), err, err.__traceback__),
) )
LoopException(context["message"], err, trace) LoopException(context["message"], err, trace)
loop.default_exception_handler(context) loop.default_exception_handler(context)
def _start_asyncio_loop(self) -> None: def _start_asyncio_loop(self) -> None:
asyncio.set_event_loop(self._loop) asyncio.set_event_loop(self._loop)
self._loop.run_forever() self._loop.run_forever()
def _call_coro(self, coro: Coroutine, uuid: str) -> None: def _call_coro(self, coro: Coroutine, uuid: str) -> None:
"""Schedule a coroutine to run in our thread and return a `Future`.""" """Schedule a coroutine to run in our thread and return a `Future`."""
if uuid in self._cancelled_early: if uuid in self._cancelled_early:
self._cancelled_early.remove(uuid) self._cancelled_early.remove(uuid)
return return
def on_done(future: Future) -> None: def on_done(future: Future) -> None:
"""Send a PyOtherSide event with the coro's result/exception.""" """Send a PyOtherSide event with the coro's result/exception."""
result = exception = trace = None result = exception = trace = None
try: try:
result = future.result() result = future.result()
except Exception as err: # noqa except Exception as err: # noqa
exception = err exception = err
trace = traceback.format_exc().rstrip() trace = traceback.format_exc().rstrip()
CoroutineDone(uuid, result, exception, trace) CoroutineDone(uuid, result, exception, trace)
del self._running_futures[uuid] del self._running_futures[uuid]
future = asyncio.run_coroutine_threadsafe(coro, self._loop) future = asyncio.run_coroutine_threadsafe(coro, self._loop)
self._running_futures[uuid] = future self._running_futures[uuid] = future
future.add_done_callback(on_done) future.add_done_callback(on_done)
def call_backend_coro( def call_backend_coro(
self, name: str, uuid: str, args: Sequence[str] = (), self, name: str, uuid: str, args: Sequence[str] = (),
) -> None: ) -> None:
"""Schedule a coroutine from the `QMLBridge.backend` object.""" """Schedule a coroutine from the `QMLBridge.backend` object."""
if uuid in self._cancelled_early: if uuid in self._cancelled_early:
self._cancelled_early.remove(uuid) self._cancelled_early.remove(uuid)
else: else:
self._call_coro(attrgetter(name)(self.backend)(*args), uuid) self._call_coro(attrgetter(name)(self.backend)(*args), uuid)
def call_client_coro( def call_client_coro(
self, user_id: str, name: str, uuid: str, args: Sequence[str] = (), self, user_id: str, name: str, uuid: str, args: Sequence[str] = (),
) -> None: ) -> None:
"""Schedule a coroutine from a `QMLBridge.backend.clients` client.""" """Schedule a coroutine from a `QMLBridge.backend.clients` client."""
if uuid in self._cancelled_early: if uuid in self._cancelled_early:
self._cancelled_early.remove(uuid) self._cancelled_early.remove(uuid)
else: else:
client = self.backend.clients[user_id] client = self.backend.clients[user_id]
self._call_coro(attrgetter(name)(client)(*args), uuid) self._call_coro(attrgetter(name)(client)(*args), uuid)
def cancel_coro(self, uuid: str) -> None: def cancel_coro(self, uuid: str) -> None:
"""Cancel a couroutine scheduled by the `QMLBridge` methods.""" """Cancel a couroutine scheduled by the `QMLBridge` methods."""
if uuid in self._running_futures: if uuid in self._running_futures:
self._running_futures[uuid].cancel() self._running_futures[uuid].cancel()
else: else:
self._cancelled_early.add(uuid) self._cancelled_early.add(uuid)
def pdb(self, extra_data: Sequence = (), remote: bool = False) -> None: def pdb(self, extra_data: Sequence = (), remote: bool = False) -> None:
"""Call the python debugger, defining some conveniance variables.""" """Call the python debugger, defining some conveniance variables."""
ad = extra_data # noqa ad = extra_data # noqa
ba = self.backend # noqa ba = self.backend # noqa
mo = self.backend.models # noqa mo = self.backend.models # noqa
cl = self.backend.clients cl = self.backend.clients
gcl = lambda user: cl[f"@{user}"] # noqa gcl = lambda user: cl[f"@{user}"] # noqa
rc = lambda c: asyncio.run_coroutine_threadsafe(c, self._loop) # noqa rc = lambda c: asyncio.run_coroutine_threadsafe(c, self._loop) # noqa
try: try:
from devtools import debug # noqa from devtools import debug # noqa
d = debug # noqa d = debug # noqa
except ModuleNotFoundError: except ModuleNotFoundError:
log.warning("Module python-devtools not found, can't use debug()") log.warning("Module python-devtools not found, can't use debug()")
if remote: if remote:
# Run `socat readline tcp:127.0.0.1:4444` in a terminal to connect # Run `socat readline tcp:127.0.0.1:4444` in a terminal to connect
import remote_pdb import remote_pdb
remote_pdb.RemotePdb("127.0.0.1", 4444).set_trace() remote_pdb.RemotePdb("127.0.0.1", 4444).set_trace()
else: else:
import pdb import pdb
pdb.set_trace() pdb.set_trace()
def exit(self) -> None: def exit(self) -> None:
try: try:
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
self.backend.terminate_clients(), self._loop, self.backend.terminate_clients(), self._loop,
).result() ).result()
except Exception as e: # noqa except Exception as e: # noqa
print(e) print(e)
# The AppImage AppRun script overwrites some environment path variables to # The AppImage AppRun script overwrites some environment path variables to
@ -174,8 +174,8 @@ class QMLBridge:
# to prevent problems like QML Qt.openUrlExternally() failing because # to prevent problems like QML Qt.openUrlExternally() failing because
# the external launched program is affected by our AppImage-specific variables. # the external launched program is affected by our AppImage-specific variables.
for var in ("LD_LIBRARY_PATH", "PYTHONHOME", "PYTHONUSERBASE"): for var in ("LD_LIBRARY_PATH", "PYTHONHOME", "PYTHONUSERBASE"):
if f"RESTORE_{var}" in os.environ: if f"RESTORE_{var}" in os.environ:
os.environ[var] = os.environ[f"RESTORE_{var}"] os.environ[var] = os.environ[f"RESTORE_{var}"]
BRIDGE = QMLBridge() BRIDGE = QMLBridge()

View File

@ -9,99 +9,99 @@ from . import __display_name__
_SUCCESS_HTML_PAGE = """<!DOCTYPE html> _SUCCESS_HTML_PAGE = """<!DOCTYPE html>
<html> <html>
<head> <head>
<title>""" + __display_name__ + """</title> <title>""" + __display_name__ + """</title>
<meta charset="utf-8"> <meta charset="utf-8">
<style> <style>
body { background: hsl(0, 0%, 90%); } body { background: hsl(0, 0%, 90%); }
@keyframes appear { @keyframes appear {
0% { transform: scale(0); } 0% { transform: scale(0); }
45% { transform: scale(0); } 45% { transform: scale(0); }
80% { transform: scale(1.6); } 80% { transform: scale(1.6); }
100% { transform: scale(1); } 100% { transform: scale(1); }
} }
.circle { .circle {
width: 90px; width: 90px;
height: 90px; height: 90px;
position: absolute; position: absolute;
top: 50%; top: 50%;
left: 50%; left: 50%;
margin: -45px 0 0 -45px; margin: -45px 0 0 -45px;
border-radius: 50%; border-radius: 50%;
font-size: 60px; font-size: 60px;
line-height: 90px; line-height: 90px;
text-align: center; text-align: center;
background: hsl(203, 51%, 15%); background: hsl(203, 51%, 15%);
color: hsl(162, 56%, 42%, 1); color: hsl(162, 56%, 42%, 1);
animation: appear 0.4s linear; animation: appear 0.4s linear;
} }
</style> </style>
</head> </head>
<body><div class="circle"></div></body> <body><div class="circle"></div></body>
</html>""" </html>"""
class _SSORequestHandler(BaseHTTPRequestHandler): class _SSORequestHandler(BaseHTTPRequestHandler):
def do_GET(self) -> None: def do_GET(self) -> None:
self.server: "SSOServer" self.server: "SSOServer"
redirect = "%s/_matrix/client/r0/login/sso/redirect?redirectUrl=%s" % ( redirect = "%s/_matrix/client/r0/login/sso/redirect?redirectUrl=%s" % (
self.server.for_homeserver, self.server.for_homeserver,
quote(self.server.url_to_open), quote(self.server.url_to_open),
) )
parameters = parse_qs(urlparse(self.path).query) parameters = parse_qs(urlparse(self.path).query)
if "loginToken" in parameters: if "loginToken" in parameters:
self.server._token = parameters["loginToken"][0] self.server._token = parameters["loginToken"][0]
self.send_response(200) # OK self.send_response(200) # OK
self.send_header("Content-type", "text/html") self.send_header("Content-type", "text/html")
self.end_headers() self.end_headers()
self.wfile.write(_SUCCESS_HTML_PAGE.encode()) self.wfile.write(_SUCCESS_HTML_PAGE.encode())
else: else:
self.send_response(308) # Permanent redirect, same method only self.send_response(308) # Permanent redirect, same method only
self.send_header("Location", redirect) self.send_header("Location", redirect)
self.end_headers() self.end_headers()
self.close_connection = True self.close_connection = True
class SSOServer(HTTPServer): class SSOServer(HTTPServer):
"""Local HTTP server to retrieve a SSO login token. """Local HTTP server to retrieve a SSO login token.
Call `SSOServer.wait_for_token()` in a background task to start waiting Call `SSOServer.wait_for_token()` in a background task to start waiting
for a SSO login token from the Matrix homeserver. for a SSO login token from the Matrix homeserver.
Once the task is running, the user must open `SSOServer.url_to_open` in Once the task is running, the user must open `SSOServer.url_to_open` in
their browser, where they will be able to complete the login process. their browser, where they will be able to complete the login process.
Once they are done, the homeserver will call us back with a login token Once they are done, the homeserver will call us back with a login token
and the `SSOServer.wait_for_token()` task will return. and the `SSOServer.wait_for_token()` task will return.
""" """
def __init__(self, for_homeserver: str) -> None: def __init__(self, for_homeserver: str) -> None:
self.for_homeserver: str = for_homeserver self.for_homeserver: str = for_homeserver
self._token: str = "" self._token: str = ""
# Pick the first available port # Pick the first available port
super().__init__(("127.0.0.1", 0), _SSORequestHandler) super().__init__(("127.0.0.1", 0), _SSORequestHandler)
@property @property
def url_to_open(self) -> str: def url_to_open(self) -> str:
"""URL for the user to open in their browser, to do the SSO process.""" """URL for the user to open in their browser, to do the SSO process."""
return f"http://{self.server_address[0]}:{self.server_port}" return f"http://{self.server_address[0]}:{self.server_port}"
async def wait_for_token(self) -> str: async def wait_for_token(self) -> str:
"""Wait until the homeserver gives us a login token and return it.""" """Wait until the homeserver gives us a login token and return it."""
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
while not self._token: while not self._token:
await loop.run_in_executor(None, self.handle_request) await loop.run_in_executor(None, self.handle_request)
return self._token return self._token

View File

@ -11,77 +11,77 @@ import re
from typing import Generator from typing import Generator
PROPERTY_TYPES = {"bool", "double", "int", "list", "real", "string", "url", PROPERTY_TYPES = {"bool", "double", "int", "list", "real", "string", "url",
"var", "date", "point", "rect", "size", "color"} "var", "date", "point", "rect", "size", "color"}
def _add_property(line: str) -> str: def _add_property(line: str) -> str:
"""Return a QML property declaration line from a QPL property line.""" """Return a QML property declaration line from a QPL property line."""
if re.match(r"^\s*[a-zA-Z\d_]+\s*:$", line): if re.match(r"^\s*[a-zA-Z\d_]+\s*:$", line):
return re.sub(r"^(\s*)(\S*\s*):$", return re.sub(r"^(\s*)(\S*\s*):$",
r"\1readonly property QtObject \2: QtObject", r"\1readonly property QtObject \2: QtObject",
line) line)
types = "|".join(PROPERTY_TYPES) types = "|".join(PROPERTY_TYPES)
if re.match(fr"^\s*({types}) [a-zA-Z\d_]+\s*:", line): if re.match(fr"^\s*({types}) [a-zA-Z\d_]+\s*:", line):
return re.sub(r"^(\s*)(\S*)", r"\1property \2", line) return re.sub(r"^(\s*)(\S*)", r"\1property \2", line)
return line return line
def _process_lines(content: str) -> Generator[str, None, None]: def _process_lines(content: str) -> Generator[str, None, None]:
"""Yield lines of real QML from lines of QPL.""" """Yield lines of real QML from lines of QPL."""
skip = False skip = False
indent = " " * 4 indent = " " * 4
current_indent = 0 current_indent = 0
for line in content.split("\n"): for line in content.split("\n"):
line = line.rstrip() line = line.rstrip()
if not line.strip() or line.strip().startswith("//"): if not line.strip() or line.strip().startswith("//"):
continue continue
start_space_list = re.findall(r"^ +", line) start_space_list = re.findall(r"^ +", line)
start_space = start_space_list[0] if start_space_list else "" start_space = start_space_list[0] if start_space_list else ""
line_indents = len(re.findall(indent, start_space)) line_indents = len(re.findall(indent, start_space))
if not skip: if not skip:
if line_indents > current_indent: if line_indents > current_indent:
yield "%s{" % (indent * current_indent) yield "%s{" % (indent * current_indent)
current_indent = line_indents current_indent = line_indents
while line_indents < current_indent: while line_indents < current_indent:
current_indent -= 1 current_indent -= 1
yield "%s}" % (indent * current_indent) yield "%s}" % (indent * current_indent)
line = _add_property(line) line = _add_property(line)
yield line yield line
skip = any((line.endswith(e) for e in "([{+\\,?:")) skip = any((line.endswith(e) for e in "([{+\\,?:"))
while current_indent: while current_indent:
current_indent -= 1 current_indent -= 1
yield "%s}" % (indent * current_indent) yield "%s}" % (indent * current_indent)
def convert_to_qml(theme_content: str) -> str: def convert_to_qml(theme_content: str) -> str:
"""Return valid QML code with imports from QPL content.""" """Return valid QML code with imports from QPL content."""
theme_content = theme_content.replace("\t", " ") theme_content = theme_content.replace("\t", " ")
lines = [ lines = [
"import QtQuick 2.12", "import QtQuick 2.12",
'import "../Base"', 'import "../Base"',
"QtObject {", "QtObject {",
" function hsluv(h, s, l, a) { return utils.hsluv(h, s, l, a) }", " function hsluv(h, s, l, a) { return utils.hsluv(h, s, l, a) }",
" function hsl(h, s, l) { return utils.hsl(h, s, l) }", " function hsl(h, s, l) { return utils.hsl(h, s, l) }",
" function hsla(h, s, l, a) { return utils.hsla(h, s, l, a) }", " function hsla(h, s, l, a) { return utils.hsla(h, s, l, a) }",
" id: theme", " id: theme",
] ]
lines += [f" {line}" for line in _process_lines(theme_content)] lines += [f" {line}" for line in _process_lines(theme_content)]
lines += ["}"] lines += ["}"]
return "\n".join(lines) return "\n".join(lines)

View File

@ -12,7 +12,7 @@ from collections.abc import MutableMapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
TYPE_CHECKING, Any, ClassVar, Dict, Iterator, Optional, Tuple, TYPE_CHECKING, Any, ClassVar, Dict, Iterator, Optional, Tuple,
) )
import pyotherside import pyotherside
@ -20,521 +20,521 @@ from watchgod import Change, awatch
from .pcn.section import Section from .pcn.section import Section
from .pyotherside_events import ( from .pyotherside_events import (
LoopException, Pre070SettingsDetected, UserFileChanged, LoopException, Pre070SettingsDetected, UserFileChanged,
) )
from .theme_parser import convert_to_qml from .theme_parser import convert_to_qml
from .utils import ( from .utils import (
aiopen, atomic_write, deep_serialize_for_qml, dict_update_recursive, aiopen, atomic_write, deep_serialize_for_qml, dict_update_recursive,
flatten_dict_keys, flatten_dict_keys,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from .backend import Backend from .backend import Backend
@dataclass @dataclass
class UserFile: class UserFile:
"""Base class representing a user config or data file.""" """Base class representing a user config or data file."""
create_missing: ClassVar[bool] = True create_missing: ClassVar[bool] = True
backend: "Backend" = field(repr=False) backend: "Backend" = field(repr=False)
filename: str = field() filename: str = field()
parent: Optional["UserFile"] = None parent: Optional["UserFile"] = None
children: Dict[Path, "UserFile"] = field(default_factory=dict) children: Dict[Path, "UserFile"] = field(default_factory=dict)
data: Any = field(init=False, default_factory=dict) data: Any = field(init=False, default_factory=dict)
_need_write: bool = field(init=False, default=False) _need_write: bool = field(init=False, default=False)
_mtime: Optional[float] = field(init=False, default=None) _mtime: Optional[float] = field(init=False, default=None)
_reader: Optional[asyncio.Future] = field(init=False, default=None) _reader: Optional[asyncio.Future] = field(init=False, default=None)
_writer: Optional[asyncio.Future] = field(init=False, default=None) _writer: Optional[asyncio.Future] = field(init=False, default=None)
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.data = self.default_data self.data = self.default_data
self._need_write = self.create_missing self._need_write = self.create_missing
if self.path.exists(): if self.path.exists():
try: try:
text = self.path.read_text() text = self.path.read_text()
self.data, self._need_write = self.deserialized(text) self.data, self._need_write = self.deserialized(text)
except Exception as err: # noqa except Exception as err: # noqa
LoopException(str(err), err, traceback.format_exc().rstrip()) LoopException(str(err), err, traceback.format_exc().rstrip())
self._reader = asyncio.ensure_future(self._start_reader()) self._reader = asyncio.ensure_future(self._start_reader())
self._writer = asyncio.ensure_future(self._start_writer()) self._writer = asyncio.ensure_future(self._start_writer())
@property @property
def path(self) -> Path: def path(self) -> Path:
"""Full path of the file to read, can exist or not exist.""" """Full path of the file to read, can exist or not exist."""
raise NotImplementedError() raise NotImplementedError()
@property @property
def write_path(self) -> Path: def write_path(self) -> Path:
"""Full path of the file to write, can exist or not exist.""" """Full path of the file to write, can exist or not exist."""
return self.path return self.path
@property @property
def default_data(self) -> Any: def default_data(self) -> Any:
"""Default deserialized content to use if the file doesn't exist.""" """Default deserialized content to use if the file doesn't exist."""
raise NotImplementedError() raise NotImplementedError()
@property @property
def qml_data(self) -> Any: def qml_data(self) -> Any:
"""Data converted for usage in QML.""" """Data converted for usage in QML."""
return self.data return self.data
def deserialized(self, data: str) -> Tuple[Any, bool]: def deserialized(self, data: str) -> Tuple[Any, bool]:
"""Return parsed data from file text and whether to call `save()`.""" """Return parsed data from file text and whether to call `save()`."""
return (data, False) return (data, False)
def serialized(self) -> str: def serialized(self) -> str:
"""Return text from `UserFile.data` that can be written to disk.""" """Return text from `UserFile.data` that can be written to disk."""
raise NotImplementedError() raise NotImplementedError()
def save(self) -> None: def save(self) -> None:
"""Inform the disk writer coroutine that the data has changed.""" """Inform the disk writer coroutine that the data has changed."""
self._need_write = True self._need_write = True
def stop_watching(self) -> None: def stop_watching(self) -> None:
"""Stop watching the on-disk file for changes.""" """Stop watching the on-disk file for changes."""
if self._reader: if self._reader:
self._reader.cancel() self._reader.cancel()
if self._writer: if self._writer:
self._writer.cancel() self._writer.cancel()
for child in self.children.values(): for child in self.children.values():
child.stop_watching() child.stop_watching()
async def set_data(self, data: Any) -> None: async def set_data(self, data: Any) -> None:
"""Set `data` and call `save()`, conveniance method for QML.""" """Set `data` and call `save()`, conveniance method for QML."""
self.data = data self.data = data
self.save() self.save()
async def update_from_file(self) -> None: async def update_from_file(self) -> None:
"""Read file at `path`, update `data` and call `save()` if needed.""" """Read file at `path`, update `data` and call `save()` if needed."""
if not self.path.exists(): if not self.path.exists():
self.data = self.default_data self.data = self.default_data
self._need_write = self.create_missing self._need_write = self.create_missing
return return
async with aiopen(self.path) as file: async with aiopen(self.path) as file:
self.data, self._need_write = self.deserialized(await file.read()) self.data, self._need_write = self.deserialized(await file.read())
async def _start_reader(self) -> None: async def _start_reader(self) -> None:
"""Disk reader coroutine, watches for file changes to update `data`.""" """Disk reader coroutine, watches for file changes to update `data`."""
while not self.path.exists(): while not self.path.exists():
await asyncio.sleep(1) await asyncio.sleep(1)
async for changes in awatch(self.path): async for changes in awatch(self.path):
try: try:
ignored = 0 ignored = 0
for change in changes: for change in changes:
if change[0] in (Change.added, Change.modified): if change[0] in (Change.added, Change.modified):
mtime = self.path.stat().st_mtime mtime = self.path.stat().st_mtime
if mtime == self._mtime: if mtime == self._mtime:
ignored += 1 ignored += 1
continue continue
await self.update_from_file() await self.update_from_file()
self._mtime = mtime self._mtime = mtime
elif change[0] == Change.deleted: elif change[0] == Change.deleted:
self._mtime = None self._mtime = None
self.data = self.default_data self.data = self.default_data
self._need_write = self.create_missing self._need_write = self.create_missing
if changes and ignored < len(changes): if changes and ignored < len(changes):
UserFileChanged(type(self), self.qml_data) UserFileChanged(type(self), self.qml_data)
parent = self.parent parent = self.parent
while parent: while parent:
await parent.update_from_file() await parent.update_from_file()
UserFileChanged(type(parent), parent.qml_data) UserFileChanged(type(parent), parent.qml_data)
parent = parent.parent parent = parent.parent
while not self.path.exists(): while not self.path.exists():
# Prevent error spam after file gets deleted # Prevent error spam after file gets deleted
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
except Exception as err: # noqa except Exception as err: # noqa
LoopException(str(err), err, traceback.format_exc().rstrip()) LoopException(str(err), err, traceback.format_exc().rstrip())
async def _start_writer(self) -> None: async def _start_writer(self) -> None:
"""Disk writer coroutine, update the file with a 1 second cooldown.""" """Disk writer coroutine, update the file with a 1 second cooldown."""
if self.write_path.parts[0] == "qrc:": if self.write_path.parts[0] == "qrc:":
return return
self.write_path.parent.mkdir(parents=True, exist_ok=True) self.write_path.parent.mkdir(parents=True, exist_ok=True)
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)
try: try:
if self._need_write: if self._need_write:
async with atomic_write(self.write_path) as (new, done): async with atomic_write(self.write_path) as (new, done):
await new.write(self.serialized()) await new.write(self.serialized())
done() done()
self._need_write = False self._need_write = False
self._mtime = self.write_path.stat().st_mtime self._mtime = self.write_path.stat().st_mtime
except Exception as err: # noqa except Exception as err: # noqa
self._need_write = False self._need_write = False
LoopException(str(err), err, traceback.format_exc().rstrip()) LoopException(str(err), err, traceback.format_exc().rstrip())
@dataclass @dataclass
class ConfigFile(UserFile): class ConfigFile(UserFile):
"""A file that goes in the configuration directory, e.g. ~/.config/app.""" """A file that goes in the configuration directory, e.g. ~/.config/app."""
@property @property
def path(self) -> Path: def path(self) -> Path:
return Path( return Path(
os.environ.get("MOMENT_CONFIG_DIR") or os.environ.get("MOMENT_CONFIG_DIR") or
self.backend.appdirs.user_config_dir, self.backend.appdirs.user_config_dir,
) / self.filename ) / self.filename
@dataclass @dataclass
class UserDataFile(UserFile): class UserDataFile(UserFile):
"""A file that goes in the user data directory, e.g. ~/.local/share/app.""" """A file that goes in the user data directory, e.g. ~/.local/share/app."""
@property @property
def path(self) -> Path: def path(self) -> Path:
return Path( return Path(
os.environ.get("MOMENT_DATA_DIR") or os.environ.get("MOMENT_DATA_DIR") or
self.backend.appdirs.user_data_dir, self.backend.appdirs.user_data_dir,
) / self.filename ) / self.filename
@dataclass @dataclass
class MappingFile(MutableMapping, UserFile): class MappingFile(MutableMapping, UserFile):
"""A file manipulable like a dict. `data` must be a mutable mapping.""" """A file manipulable like a dict. `data` must be a mutable mapping."""
def __getitem__(self, key: Any) -> Any: def __getitem__(self, key: Any) -> Any:
return self.data[key] return self.data[key]
def __setitem__(self, key: Any, value: Any) -> None: def __setitem__(self, key: Any, value: Any) -> None:
self.data[key] = value self.data[key] = value
def __delitem__(self, key: Any) -> None: def __delitem__(self, key: Any) -> None:
del self.data[key] del self.data[key]
def __iter__(self) -> Iterator: def __iter__(self) -> Iterator:
return iter(self.data) return iter(self.data)
def __len__(self) -> int: def __len__(self) -> int:
return len(self.data) return len(self.data)
def __getattr__(self, key: Any) -> Any: def __getattr__(self, key: Any) -> Any:
try: try:
return self.data[key] return self.data[key]
except KeyError: except KeyError:
return super().__getattribute__(key) return super().__getattribute__(key)
def __setattr__(self, key: Any, value: Any) -> None: def __setattr__(self, key: Any, value: Any) -> None:
if key in self.__dataclass_fields__: if key in self.__dataclass_fields__:
super().__setattr__(key, value) super().__setattr__(key, value)
return return
self.data[key] = value self.data[key] = value
def __delattr__(self, key: Any) -> None: def __delattr__(self, key: Any) -> None:
del self.data[key] del self.data[key]
@dataclass @dataclass
class JSONFile(MappingFile): class JSONFile(MappingFile):
"""A file stored on disk in the JSON format.""" """A file stored on disk in the JSON format."""
@property @property
def default_data(self) -> dict: def default_data(self) -> dict:
return {} return {}
def deserialized(self, data: str) -> Tuple[dict, bool]: def deserialized(self, data: str) -> Tuple[dict, bool]:
"""Return parsed data from file text and whether to call `save()`. """Return parsed data from file text and whether to call `save()`.
If the file has missing keys, the missing data will be merged to the If the file has missing keys, the missing data will be merged to the
returned dict and the second tuple item will be `True`. returned dict and the second tuple item will be `True`.
""" """
loaded = json.loads(data) loaded = json.loads(data)
all_data = self.default_data.copy() all_data = self.default_data.copy()
dict_update_recursive(all_data, loaded) dict_update_recursive(all_data, loaded)
return (all_data, loaded != all_data) return (all_data, loaded != all_data)
def serialized(self) -> str: def serialized(self) -> str:
data = self.data data = self.data
return json.dumps(data, indent=4, ensure_ascii=False, sort_keys=True) return json.dumps(data, indent=4, ensure_ascii=False, sort_keys=True)
@dataclass @dataclass
class PCNFile(MappingFile): class PCNFile(MappingFile):
"""File stored in the PCN format, with machine edits in a separate JSON.""" """File stored in the PCN format, with machine edits in a separate JSON."""
create_missing = False create_missing = False
path_override: Optional[Path] = None path_override: Optional[Path] = None
@property @property
def path(self) -> Path: def path(self) -> Path:
return self.path_override or super().path return self.path_override or super().path
@property @property
def write_path(self) -> Path: def write_path(self) -> Path:
"""Full path of file where programatically-done edits are stored.""" """Full path of file where programatically-done edits are stored."""
return self.path.with_suffix(".gui.json") return self.path.with_suffix(".gui.json")
@property @property
def qml_data(self) -> Dict[str, Any]: def qml_data(self) -> Dict[str, Any]:
return deep_serialize_for_qml(self.data.as_dict()) # type: ignore return deep_serialize_for_qml(self.data.as_dict()) # type: ignore
@property @property
def default_data(self) -> Section: def default_data(self) -> Section:
return Section() return Section()
def deserialized(self, data: str) -> Tuple[Section, bool]: def deserialized(self, data: str) -> Tuple[Section, bool]:
root = Section.from_source_code(data, self.path) root = Section.from_source_code(data, self.path)
edits = "{}" edits = "{}"
if self.write_path.exists(): if self.write_path.exists():
edits = self.write_path.read_text() edits = self.write_path.read_text()
includes_now = list(root.all_includes) includes_now = list(root.all_includes)
for path, pcn in self.children.copy().items(): for path, pcn in self.children.copy().items():
if path not in includes_now: if path not in includes_now:
pcn.stop_watching() pcn.stop_watching()
del self.children[path] del self.children[path]
for path in includes_now: for path in includes_now:
if path not in self.children: if path not in self.children:
self.children[path] = PCNFile( self.children[path] = PCNFile(
self.backend, self.backend,
filename = path.name, filename = path.name,
parent = self, parent = self,
path_override = path, path_override = path,
) )
return (root, root.deep_merge_edits(json.loads(edits))) return (root, root.deep_merge_edits(json.loads(edits)))
def serialized(self) -> str: def serialized(self) -> str:
edits = self.data.edits_as_dict() edits = self.data.edits_as_dict()
return json.dumps(edits, indent=4, ensure_ascii=False) return json.dumps(edits, indent=4, ensure_ascii=False)
async def set_data(self, data: Dict[str, Any]) -> None: async def set_data(self, data: Dict[str, Any]) -> None:
self.data.deep_merge_edits({"set": data}, has_expressions=False) self.data.deep_merge_edits({"set": data}, has_expressions=False)
self.save() self.save()
@dataclass @dataclass
class Accounts(ConfigFile, JSONFile): class Accounts(ConfigFile, JSONFile):
"""Config file for saved matrix accounts: user ID, access tokens, etc""" """Config file for saved matrix accounts: user ID, access tokens, etc"""
filename: str = "accounts.json" filename: str = "accounts.json"
async def any_saved(self) -> bool: async def any_saved(self) -> bool:
"""Return for QML whether there are any accounts saved on disk.""" """Return for QML whether there are any accounts saved on disk."""
return bool(self.data) return bool(self.data)
async def add(self, user_id: str) -> None: async def add(self, user_id: str) -> None:
"""Add an account to the config and write it on disk. """Add an account to the config and write it on disk.
The account's details such as its access token are retrieved from The account's details such as its access token are retrieved from
the corresponding `MatrixClient` in `backend.clients`. the corresponding `MatrixClient` in `backend.clients`.
""" """
client = self.backend.clients[user_id] client = self.backend.clients[user_id]
account = self.backend.models["accounts"][user_id] account = self.backend.models["accounts"][user_id]
self.update({ self.update({
client.user_id: { client.user_id: {
"homeserver": client.homeserver, "homeserver": client.homeserver,
"token": client.access_token, "token": client.access_token,
"device_id": client.device_id, "device_id": client.device_id,
"enabled": True, "enabled": True,
"presence": account.presence.value.replace("echo_", ""), "presence": account.presence.value.replace("echo_", ""),
"status_msg": account.status_msg, "status_msg": account.status_msg,
"order": account.order, "order": account.order,
}, },
}) })
self.save() self.save()
async def set( async def set(
self, self,
user_id: str, user_id: str,
enabled: Optional[str] = None, enabled: Optional[str] = None,
presence: Optional[str] = None, presence: Optional[str] = None,
order: Optional[int] = None, order: Optional[int] = None,
status_msg: Optional[str] = None, status_msg: Optional[str] = None,
) -> None: ) -> None:
"""Update an account if found in the config file and write to disk.""" """Update an account if found in the config file and write to disk."""
if user_id not in self: if user_id not in self:
return return
if enabled is not None: if enabled is not None:
self[user_id]["enabled"] = enabled self[user_id]["enabled"] = enabled
if presence is not None: if presence is not None:
self[user_id]["presence"] = presence self[user_id]["presence"] = presence
if order is not None: if order is not None:
self[user_id]["order"] = order self[user_id]["order"] = order
if status_msg is not None: if status_msg is not None:
self[user_id]["status_msg"] = status_msg self[user_id]["status_msg"] = status_msg
self.save() self.save()
async def forget(self, user_id: str) -> None: async def forget(self, user_id: str) -> None:
"""Delete an account from the config and write it on disk.""" """Delete an account from the config and write it on disk."""
self.pop(user_id, None) self.pop(user_id, None)
self.save() self.save()
@dataclass @dataclass
class Pre070Settings(ConfigFile): class Pre070Settings(ConfigFile):
"""Detect and warn about the presence of a pre-0.7.0 settings.json file.""" """Detect and warn about the presence of a pre-0.7.0 settings.json file."""
filename: str = "settings.json" filename: str = "settings.json"
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.path.exists(): if self.path.exists():
Pre070SettingsDetected(self.path) Pre070SettingsDetected(self.path)
@dataclass @dataclass
class Settings(ConfigFile, PCNFile): class Settings(ConfigFile, PCNFile):
"""General config file for UI and backend settings""" """General config file for UI and backend settings"""
filename: str = "settings.py" filename: str = "settings.py"
@property @property
def default_data(self) -> Section: def default_data(self) -> Section:
root = Section.from_file("src/config/settings.py") root = Section.from_file("src/config/settings.py")
edits = "{}" edits = "{}"
if self.write_path.exists(): if self.write_path.exists():
edits = self.write_path.read_text() edits = self.write_path.read_text()
root.deep_merge_edits(json.loads(edits)) root.deep_merge_edits(json.loads(edits))
return root return root
def deserialized(self, data: str) -> Tuple[Section, bool]: def deserialized(self, data: str) -> Tuple[Section, bool]:
section, save = super().deserialized(data) section, save = super().deserialized(data)
if self and self.General.theme != section.General.theme: if self and self.General.theme != section.General.theme:
if hasattr(self.backend, "theme"): if hasattr(self.backend, "theme"):
self.backend.theme.stop_watching() self.backend.theme.stop_watching()
self.backend.theme = Theme( self.backend.theme = Theme(
self.backend, section.General.theme, # type: ignore self.backend, section.General.theme, # type: ignore
) )
UserFileChanged(Theme, self.backend.theme.qml_data) UserFileChanged(Theme, self.backend.theme.qml_data)
# if self and self.General.new_theme != section.General.new_theme: # if self and self.General.new_theme != section.General.new_theme:
# self.backend.new_theme.stop_watching() # self.backend.new_theme.stop_watching()
# self.backend.new_theme = NewTheme( # self.backend.new_theme = NewTheme(
# self.backend, section.General.new_theme, # type: ignore # self.backend, section.General.new_theme, # type: ignore
# ) # )
# UserFileChanged(Theme, self.backend.new_theme.qml_data) # UserFileChanged(Theme, self.backend.new_theme.qml_data)
return (section, save) return (section, save)
@dataclass @dataclass
class NewTheme(UserDataFile, PCNFile): class NewTheme(UserDataFile, PCNFile):
"""A theme file defining the look of QML components.""" """A theme file defining the look of QML components."""
create_missing = False create_missing = False
@property @property
def path(self) -> Path: def path(self) -> Path:
data_dir = Path( data_dir = Path(
os.environ.get("MOMENT_DATA_DIR") or os.environ.get("MOMENT_DATA_DIR") or
self.backend.appdirs.user_data_dir, self.backend.appdirs.user_data_dir,
) )
return data_dir / "themes" / self.filename return data_dir / "themes" / self.filename
@property @property
def qml_data(self) -> Dict[str, Any]: def qml_data(self) -> Dict[str, Any]:
return flatten_dict_keys(super().qml_data, last_level=False) return flatten_dict_keys(super().qml_data, last_level=False)
@dataclass @dataclass
class UIState(UserDataFile, JSONFile): class UIState(UserDataFile, JSONFile):
"""File used to save and restore the state of QML components.""" """File used to save and restore the state of QML components."""
filename: str = "state.json" filename: str = "state.json"
@property @property
def default_data(self) -> dict: def default_data(self) -> dict:
return { return {
"collapseAccounts": {}, "collapseAccounts": {},
"page": "Pages/Default.qml", "page": "Pages/Default.qml",
"pageProperties": {}, "pageProperties": {},
} }
def deserialized(self, data: str) -> Tuple[dict, bool]: def deserialized(self, data: str) -> Tuple[dict, bool]:
dict_data, save = super().deserialized(data) dict_data, save = super().deserialized(data)
for user_id, do in dict_data["collapseAccounts"].items(): for user_id, do in dict_data["collapseAccounts"].items():
self.backend.models["all_rooms"].set_account_collapse(user_id, do) self.backend.models["all_rooms"].set_account_collapse(user_id, do)
return (dict_data, save) return (dict_data, save)
@dataclass @dataclass
class History(UserDataFile, JSONFile): class History(UserDataFile, JSONFile):
"""File to save and restore lines typed by the user in QML components.""" """File to save and restore lines typed by the user in QML components."""
filename: str = "history.json" filename: str = "history.json"
@property @property
def default_data(self) -> dict: def default_data(self) -> dict:
return {"console": []} return {"console": []}
@dataclass @dataclass
class Theme(UserDataFile): class Theme(UserDataFile):
"""A theme file defining the look of QML components.""" """A theme file defining the look of QML components."""
# Since it currently breaks at every update and the file format will be # Since it currently breaks at every update and the file format will be
# changed later, don't copy the theme to user data dir if it doesn't exist. # changed later, don't copy the theme to user data dir if it doesn't exist.
create_missing = False create_missing = False
@property @property
def path(self) -> Path: def path(self) -> Path:
data_dir = Path( data_dir = Path(
os.environ.get("MOMENT_DATA_DIR") or os.environ.get("MOMENT_DATA_DIR") or
self.backend.appdirs.user_data_dir, self.backend.appdirs.user_data_dir,
) )
return data_dir / "themes" / self.filename return data_dir / "themes" / self.filename
@property @property
def default_data(self) -> str: def default_data(self) -> str:
if self.filename in ("Foliage.qpl", "Midnight.qpl", "Glass.qpl"): if self.filename in ("Foliage.qpl", "Midnight.qpl", "Glass.qpl"):
path = f"src/themes/{self.filename}" path = f"src/themes/{self.filename}"
else: else:
path = "src/themes/Foliage.qpl" path = "src/themes/Foliage.qpl"
try: try:
byte_content = pyotherside.qrc_get_file_contents(path) byte_content = pyotherside.qrc_get_file_contents(path)
except ValueError: except ValueError:
# App was compiled without QRC # App was compiled without QRC
return convert_to_qml(Path(path).read_text()) return convert_to_qml(Path(path).read_text())
else: else:
return convert_to_qml(byte_content.decode()) return convert_to_qml(byte_content.decode())
def deserialized(self, data: str) -> Tuple[str, bool]: def deserialized(self, data: str) -> Tuple[str, bool]:
return (convert_to_qml(data), False) return (convert_to_qml(data), False)

View File

@ -20,8 +20,8 @@ from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from types import ModuleType from types import ModuleType
from typing import ( from typing import (
Any, AsyncIterator, Callable, Collection, Dict, Iterable, Mapping, Any, AsyncIterator, Callable, Collection, Dict, Iterable, Mapping,
Optional, Tuple, Type, Union, Optional, Tuple, Type, Union,
) )
from uuid import UUID from uuid import UUID
@ -36,348 +36,348 @@ from .color import Color
from .pcn.section import Section from .pcn.section import Section
if sys.version_info >= (3, 7): if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
current_task = asyncio.current_task current_task = asyncio.current_task
else: else:
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
current_task = asyncio.Task.current_task current_task = asyncio.Task.current_task
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
import collections.abc as collections import collections.abc as collections
else: else:
import collections import collections
Size = Tuple[int, int] Size = Tuple[int, int]
BytesOrPIL = Union[bytes, PILImage.Image] BytesOrPIL = Union[bytes, PILImage.Image]
auto = autostr auto = autostr
COMPRESSION_POOL = ProcessPoolExecutor() COMPRESSION_POOL = ProcessPoolExecutor()
class AutoStrEnum(Enum): class AutoStrEnum(Enum):
"""An Enum where auto() assigns the member's name instead of an integer. """An Enum where auto() assigns the member's name instead of an integer.
Example: Example:
>>> class Fruits(AutoStrEnum): apple = auto() >>> class Fruits(AutoStrEnum): apple = auto()
>>> Fruits.apple.value >>> Fruits.apple.value
"apple" "apple"
""" """
@staticmethod @staticmethod
def _generate_next_value_(name, *_): def _generate_next_value_(name, *_):
return name return name
def dict_update_recursive(dict1: dict, dict2: dict) -> None: def dict_update_recursive(dict1: dict, dict2: dict) -> None:
"""Deep-merge `dict1` and `dict2`, recursive version of `dict.update()`.""" """Deep-merge `dict1` and `dict2`, recursive version of `dict.update()`."""
# https://gist.github.com/angstwad/bf22d1822c38a92ec0a9 # https://gist.github.com/angstwad/bf22d1822c38a92ec0a9
for k in dict2: for k in dict2:
if (k in dict1 and isinstance(dict1[k], dict) and if (k in dict1 and isinstance(dict1[k], dict) and
isinstance(dict2[k], collections.Mapping)): isinstance(dict2[k], collections.Mapping)):
dict_update_recursive(dict1[k], dict2[k]) dict_update_recursive(dict1[k], dict2[k])
else: else:
dict1[k] = dict2[k] dict1[k] = dict2[k]
def flatten_dict_keys( def flatten_dict_keys(
source: Optional[Dict[str, Any]] = None, source: Optional[Dict[str, Any]] = None,
separator: str = ".", separator: str = ".",
last_level: bool = True, last_level: bool = True,
_flat: Optional[Dict[str, Any]] = None, _flat: Optional[Dict[str, Any]] = None,
_prefix: str = "", _prefix: str = "",
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Return a flattened version of the ``source`` dict. """Return a flattened version of the ``source`` dict.
Example: Example:
>>> dct >>> dct
{"content": {"body": "foo"}, "m.test": {"key": {"bar": 1}}} {"content": {"body": "foo"}, "m.test": {"key": {"bar": 1}}}
>>> flatten_dict_keys(dct) >>> flatten_dict_keys(dct)
{"content.body": "foo", "m.test.key.bar": 1} {"content.body": "foo", "m.test.key.bar": 1}
>>> flatten_dict_keys(dct, last_level=False) >>> flatten_dict_keys(dct, last_level=False)
{"content": {"body": "foo"}, "m.test.key": {bar": 1}} {"content": {"body": "foo"}, "m.test.key": {bar": 1}}
""" """
flat = {} if _flat is None else _flat flat = {} if _flat is None else _flat
for key, value in (source or {}).items(): for key, value in (source or {}).items():
if isinstance(value, dict): if isinstance(value, dict):
prefix = f"{_prefix}{key}{separator}" prefix = f"{_prefix}{key}{separator}"
flatten_dict_keys(value, separator, last_level, flat, prefix) flatten_dict_keys(value, separator, last_level, flat, prefix)
elif last_level: elif last_level:
flat[f"{_prefix}{key}"] = value flat[f"{_prefix}{key}"] = value
else: else:
prefix = _prefix[:-len(separator)] # remove trailing separator prefix = _prefix[:-len(separator)] # remove trailing separator
flat.setdefault(prefix, {})[key] = value flat.setdefault(prefix, {})[key] = value
return flat return flat
def config_get_account_room_rule( def config_get_account_room_rule(
rules: Section, user_id: str, room_id: str, rules: Section, user_id: str, room_id: str,
) -> Any: ) -> Any:
"""Return best matching rule value for an account/room PCN free Section.""" """Return best matching rule value for an account/room PCN free Section."""
for name, value in reversed(rules.children()): for name, value in reversed(rules.children()):
name = re.sub(r"\s+", " ", name.strip()) name = re.sub(r"\s+", " ", name.strip())
if name in (user_id, room_id, f"{user_id} {room_id}"): if name in (user_id, room_id, f"{user_id} {room_id}"):
return value return value
return rules.default return rules.default
async def is_svg(file: File) -> bool: async def is_svg(file: File) -> bool:
"""Return whether the file is a SVG (`lxml` is used for detection).""" """Return whether the file is a SVG (`lxml` is used for detection)."""
chunks = [c async for c in async_generator_from_data(file)] chunks = [c async for c in async_generator_from_data(file)]
with io.BytesIO(b"".join(chunks)) as file: with io.BytesIO(b"".join(chunks)) as file:
try: try:
_, element = next(xml_etree.iterparse(file, ("start",))) _, element = next(xml_etree.iterparse(file, ("start",)))
return element.tag == "{http://www.w3.org/2000/svg}svg" return element.tag == "{http://www.w3.org/2000/svg}svg"
except (StopIteration, xml_etree.ParseError): except (StopIteration, xml_etree.ParseError):
return False return False
async def svg_dimensions(file: File) -> Size: async def svg_dimensions(file: File) -> Size:
"""Return the width and height, or viewBox width and height for a SVG. """Return the width and height, or viewBox width and height for a SVG.
If these properties are missing (broken file), ``(256, 256)`` is returned. If these properties are missing (broken file), ``(256, 256)`` is returned.
""" """
chunks = [c async for c in async_generator_from_data(file)] chunks = [c async for c in async_generator_from_data(file)]
with io.BytesIO(b"".join(chunks)) as file: with io.BytesIO(b"".join(chunks)) as file:
attrs = xml_etree.parse(file).getroot().attrib attrs = xml_etree.parse(file).getroot().attrib
try: try:
width = round(float(attrs.get("width", attrs["viewBox"].split()[3]))) width = round(float(attrs.get("width", attrs["viewBox"].split()[3])))
except (KeyError, IndexError, ValueError, TypeError): except (KeyError, IndexError, ValueError, TypeError):
width = 256 width = 256
try: try:
height = round(float(attrs.get("height", attrs["viewBox"].split()[4]))) height = round(float(attrs.get("height", attrs["viewBox"].split()[4])))
except (KeyError, IndexError, ValueError, TypeError): except (KeyError, IndexError, ValueError, TypeError):
height = 256 height = 256
return (width, height) return (width, height)
async def guess_mime(file: File) -> str: async def guess_mime(file: File) -> str:
"""Return the file's mimetype, or `application/octet-stream` if unknown.""" """Return the file's mimetype, or `application/octet-stream` if unknown."""
if isinstance(file, io.IOBase): if isinstance(file, io.IOBase):
file.seek(0, 0) file.seek(0, 0)
elif isinstance(file, AsyncBufferedIOBase): elif isinstance(file, AsyncBufferedIOBase):
await file.seek(0, 0) await file.seek(0, 0)
try: try:
first_chunk: bytes first_chunk: bytes
async for first_chunk in async_generator_from_data(file): async for first_chunk in async_generator_from_data(file):
break break
else: else:
return "inode/x-empty" # empty file return "inode/x-empty" # empty file
# TODO: plaintext # TODO: plaintext
mime = filetype.guess_mime(first_chunk) mime = filetype.guess_mime(first_chunk)
return mime or ( return mime or (
"image/svg+xml" if await is_svg(file) else "image/svg+xml" if await is_svg(file) else
"application/octet-stream" "application/octet-stream"
) )
finally: finally:
if isinstance(file, io.IOBase): if isinstance(file, io.IOBase):
file.seek(0, 0) file.seek(0, 0)
elif isinstance(file, AsyncBufferedIOBase): elif isinstance(file, AsyncBufferedIOBase):
await file.seek(0, 0) await file.seek(0, 0)
def plain2html(text: str) -> str: def plain2html(text: str) -> str:
"""Convert `\\n` into `<br>` tags and `\\t` into four spaces.""" """Convert `\\n` into `<br>` tags and `\\t` into four spaces."""
return html.escape(text)\ return html.escape(text)\
.replace("\n", "<br>")\ .replace("\n", "<br>")\
.replace("\t", "&nbsp;" * 4) .replace("\t", "&nbsp;" * 4)
def strip_html_tags(text: str) -> str: def strip_html_tags(text: str) -> str:
"""Remove HTML tags from text.""" """Remove HTML tags from text."""
return re.sub(r"<\/?[^>]+(>|$)", "", text) return re.sub(r"<\/?[^>]+(>|$)", "", text)
def serialize_value_for_qml( def serialize_value_for_qml(
value: Any, json_list_dicts: bool = False, reject_unknown: bool = False, value: Any, json_list_dicts: bool = False, reject_unknown: bool = False,
) -> Any: ) -> Any:
"""Convert a value to make it easier to use from QML. """Convert a value to make it easier to use from QML.
Returns: Returns:
- For `bool`, `int`, `float`, `bytes`, `str`, `datetime`, `date`, `time`: - For `bool`, `int`, `float`, `bytes`, `str`, `datetime`, `date`, `time`:
the unchanged value (PyOtherSide handles these) the unchanged value (PyOtherSide handles these)
- For `Collection` objects (includes `list` and `dict`): - For `Collection` objects (includes `list` and `dict`):
a JSON dump if `json_list_dicts` is `True`, else the unchanged value a JSON dump if `json_list_dicts` is `True`, else the unchanged value
- If the value is an instancied object and has a `serialized` attribute or - If the value is an instancied object and has a `serialized` attribute or
property, return that property, return that
- For `Enum` members, the actual value of the member - For `Enum` members, the actual value of the member
- For `Path` objects, a `file://<path...>` string - For `Path` objects, a `file://<path...>` string
- For `UUID` object: the UUID in string form - For `UUID` object: the UUID in string form
- For `timedelta` objects: the delta as a number of milliseconds `int` - For `timedelta` objects: the delta as a number of milliseconds `int`
- For `Color` objects: the color's hexadecimal value - For `Color` objects: the color's hexadecimal value
- For class types: the class `__name__` - For class types: the class `__name__`
- For anything else: raise a `TypeError` if `reject_unknown` is `True`, - For anything else: raise a `TypeError` if `reject_unknown` is `True`,
else return the unchanged value. else return the unchanged value.
""" """
if isinstance(value, (bool, int, float, bytes, str, datetime, date, time)): if isinstance(value, (bool, int, float, bytes, str, datetime, date, time)):
return value return value
if json_list_dicts and isinstance(value, Collection): if json_list_dicts and isinstance(value, Collection):
if isinstance(value, set): if isinstance(value, set):
value = list(value) value = list(value)
return json.dumps(value) return json.dumps(value)
if not inspect.isclass(value) and hasattr(value, "serialized"): if not inspect.isclass(value) and hasattr(value, "serialized"):
return value.serialized return value.serialized
if isinstance(value, Iterable): if isinstance(value, Iterable):
return value return value
if hasattr(value, "__class__") and issubclass(value.__class__, Enum): if hasattr(value, "__class__") and issubclass(value.__class__, Enum):
return value.value return value.value
if isinstance(value, Path): if isinstance(value, Path):
return f"file://{value!s}" return f"file://{value!s}"
if isinstance(value, UUID): if isinstance(value, UUID):
return str(value) return str(value)
if isinstance(value, timedelta): if isinstance(value, timedelta):
return value.total_seconds() * 1000 return value.total_seconds() * 1000
if isinstance(value, Color): if isinstance(value, Color):
return value.hex return value.hex
if inspect.isclass(value): if inspect.isclass(value):
return value.__name__ return value.__name__
if reject_unknown: if reject_unknown:
raise TypeError("Unknown type reject") raise TypeError("Unknown type reject")
return value return value
def deep_serialize_for_qml(obj: Iterable) -> Union[list, dict]: def deep_serialize_for_qml(obj: Iterable) -> Union[list, dict]:
"""Recursively serialize lists and dict values for QML.""" """Recursively serialize lists and dict values for QML."""
if isinstance(obj, Mapping): if isinstance(obj, Mapping):
dct = {} dct = {}
for key, value in obj.items(): for key, value in obj.items():
if isinstance(value, Iterable) and not isinstance(value, str): if isinstance(value, Iterable) and not isinstance(value, str):
# PyOtherSide only accept dicts with string keys # PyOtherSide only accept dicts with string keys
dct[str(key)] = deep_serialize_for_qml(value) dct[str(key)] = deep_serialize_for_qml(value)
continue continue
with suppress(TypeError): with suppress(TypeError):
dct[str(key)] = \ dct[str(key)] = \
serialize_value_for_qml(value, reject_unknown=True) serialize_value_for_qml(value, reject_unknown=True)
return dct return dct
lst = [] lst = []
for value in obj: for value in obj:
if isinstance(value, Iterable) and not isinstance(value, str): if isinstance(value, Iterable) and not isinstance(value, str):
lst.append(deep_serialize_for_qml(value)) lst.append(deep_serialize_for_qml(value))
continue continue
with suppress(TypeError): with suppress(TypeError):
lst.append(serialize_value_for_qml(value, reject_unknown=True)) lst.append(serialize_value_for_qml(value, reject_unknown=True))
return lst return lst
def classes_defined_in(module: ModuleType) -> Dict[str, Type]: def classes_defined_in(module: ModuleType) -> Dict[str, Type]:
"""Return a `{name: class}` dict of all the classes a module defines.""" """Return a `{name: class}` dict of all the classes a module defines."""
return { return {
m[0]: m[1] for m in inspect.getmembers(module, inspect.isclass) m[0]: m[1] for m in inspect.getmembers(module, inspect.isclass)
if not m[0].startswith("_") and if not m[0].startswith("_") and
m[1].__module__.startswith(module.__name__) m[1].__module__.startswith(module.__name__)
} }
@asynccontextmanager @asynccontextmanager
async def aiopen(*args, **kwargs) -> AsyncIterator[Any]: async def aiopen(*args, **kwargs) -> AsyncIterator[Any]:
"""Wrapper for `aiofiles.open()` that doesn't break mypy""" """Wrapper for `aiofiles.open()` that doesn't break mypy"""
async with aiofiles.open(*args, **kwargs) as file: async with aiofiles.open(*args, **kwargs) as file:
yield file yield file
@asynccontextmanager @asynccontextmanager
async def atomic_write( async def atomic_write(
path: Union[Path, str], binary: bool = False, **kwargs, path: Union[Path, str], binary: bool = False, **kwargs,
) -> AsyncIterator[Tuple[Any, Callable[[], None]]]: ) -> AsyncIterator[Tuple[Any, Callable[[], None]]]:
"""Write a file asynchronously (using aiofiles) and atomically. """Write a file asynchronously (using aiofiles) and atomically.
Yields a `(open_temporary_file, done_function)` tuple. Yields a `(open_temporary_file, done_function)` tuple.
The done function should be called after writing to the given file. The done function should be called after writing to the given file.
When the context manager exits, the temporary file will either replace When the context manager exits, the temporary file will either replace
`path` if the function was called, or be deleted. `path` if the function was called, or be deleted.
Example: Example:
>>> async with atomic_write("foo.txt") as (file, done): >>> async with atomic_write("foo.txt") as (file, done):
>>> await file.write("Sample text") >>> await file.write("Sample text")
>>> done() >>> done()
""" """
mode = "wb" if binary else "w" mode = "wb" if binary else "w"
path = Path(path) path = Path(path)
temp = NamedTemporaryFile(dir=path.parent, delete=False) temp = NamedTemporaryFile(dir=path.parent, delete=False)
temp_path = Path(temp.name) temp_path = Path(temp.name)
can_replace = False can_replace = False
def done() -> None: def done() -> None:
nonlocal can_replace nonlocal can_replace
can_replace = True can_replace = True
try: try:
async with aiopen(temp_path, mode, **kwargs) as out: async with aiopen(temp_path, mode, **kwargs) as out:
yield (out, done) yield (out, done)
finally: finally:
if can_replace: if can_replace:
temp_path.replace(path) temp_path.replace(path)
else: else:
temp_path.unlink() temp_path.unlink()
def _compress(image: BytesOrPIL, fmt: str, optimize: bool) -> bytes: def _compress(image: BytesOrPIL, fmt: str, optimize: bool) -> bytes:
if isinstance(image, bytes): if isinstance(image, bytes):
pil_image = PILImage.open(io.BytesIO(image)) pil_image = PILImage.open(io.BytesIO(image))
else: else:
pil_image = image pil_image = image
with io.BytesIO() as buffer: with io.BytesIO() as buffer:
pil_image.save(buffer, fmt, optimize=optimize) pil_image.save(buffer, fmt, optimize=optimize)
return buffer.getvalue() return buffer.getvalue()
async def compress_image( async def compress_image(
image: BytesOrPIL, fmt: str = "PNG", optimize: bool = True, image: BytesOrPIL, fmt: str = "PNG", optimize: bool = True,
) -> bytes: ) -> bytes:
"""Compress image in a separate process, without blocking event loop.""" """Compress image in a separate process, without blocking event loop."""
return await asyncio.get_event_loop().run_in_executor( return await asyncio.get_event_loop().run_in_executor(
COMPRESSION_POOL, _compress, image, fmt, optimize, COMPRESSION_POOL, _compress, image, fmt, optimize,
) )

File diff suppressed because it is too large Load Diff