diff --git a/TODO.md b/TODO.md index d1a959a7..fc92bc53 100644 --- a/TODO.md +++ b/TODO.md @@ -2,7 +2,6 @@ ## Before release -- Atomic - Catch server 5xx errors when sending message and retry - Update README.md diff --git a/src/backend/media_cache.py b/src/backend/media_cache.py index 57b7992e..104a19f3 100644 --- a/src/backend/media_cache.py +++ b/src/backend/media_cache.py @@ -13,11 +13,11 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, DefaultDict, Dict, Optional from urllib.parse import urlparse -import aiofiles -import nio from PIL import Image as PILImage -from .utils import Size +import nio + +from .utils import Size, atomic_write if TYPE_CHECKING: from .backend import Backend @@ -138,7 +138,7 @@ class Media: self.local_path.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(self.local_path, "wb") as file: + async with atomic_write(self.local_path, binary=True) as file: await file.write(data) return self.local_path @@ -212,7 +212,7 @@ class Media: media.local_path.parent.mkdir(parents=True, exist_ok=True) if not media.local_path.exists() or overwrite: - async with aiofiles.open(media.local_path, "wb") as file: + async with atomic_write(media.local_path, binary=True) as file: await file.write(data) return media diff --git a/src/backend/user_files.py b/src/backend/user_files.py index a793ba60..232b3780 100644 --- a/src/backend/user_files.py +++ b/src/backend/user_files.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional import aiofiles from .theme_parser import convert_to_qml -from .utils import dict_update_recursive +from .utils import atomic_write, dict_update_recursive if TYPE_CHECKING: from .backend import Backend @@ -88,7 +88,7 @@ class DataFile: if not self.create_missing and not self.path.exists(): continue - async with aiofiles.open(self.path, "w") as new: + async with atomic_write(self.path) as new: await new.write(self._to_write) self._to_write = None diff --git a/src/backend/utils.py b/src/backend/utils.py index e5401587..7a0fb280 100644 --- a/src/backend/utils.py +++ b/src/backend/utils.py @@ -8,16 +8,23 @@ import inspect import io import json import xml.etree.cElementTree as xml_etree # FIXME: bandit warning +from contextlib import asynccontextmanager from datetime import datetime, timedelta from enum import Enum from enum import auto as autostr from pathlib import Path +from tempfile import NamedTemporaryFile from types import ModuleType -from typing import Any, Dict, Mapping, Sequence, Tuple, Type +from typing import ( + Any, AsyncIterator, Dict, Mapping, Sequence, Tuple, Type, Union, +) from uuid import UUID +import aiofiles import filetype from aiofiles.threadpool.binary import AsyncBufferedReader +from aiofiles.threadpool.text import AsyncTextIOWrapper + from nio.crypto import AsyncDataT as File from nio.crypto import async_generator_from_data @@ -185,3 +192,21 @@ def classes_defined_in(module: ModuleType) -> Dict[str, Type]: if not m[0].startswith("_") and m[1].__module__.startswith(module.__name__) } + + +@asynccontextmanager +async def atomic_write( + path: Union[Path, str], binary: bool = False, **kwargs, +) -> AsyncIterator[Union[AsyncTextIOWrapper, AsyncBufferedReader]]: + """Write a file asynchronously (using aiofiles) and atomically.""" + + mode = "wb" if binary else "w" + path = Path(path) + temp = NamedTemporaryFile(dir=path.parent, delete=False) + temp_path = Path(temp.name) + + try: + async with aiofiles.open(temp_path, mode, **kwargs) as out: + yield out + finally: + temp_path.replace(path)