diff --git a/src/backend/media_cache.py b/src/backend/media_cache.py index 104a19f3..e2e49c23 100644 --- a/src/backend/media_cache.py +++ b/src/backend/media_cache.py @@ -138,8 +138,9 @@ class Media: self.local_path.parent.mkdir(parents=True, exist_ok=True) - async with atomic_write(self.local_path, binary=True) as file: + async with atomic_write(self.local_path, binary=True) as (file, done): await file.write(data) + done() return self.local_path @@ -212,8 +213,11 @@ class Media: media.local_path.parent.mkdir(parents=True, exist_ok=True) if not media.local_path.exists() or overwrite: - async with atomic_write(media.local_path, binary=True) as file: + path = media.local_path + + async with atomic_write(path, binary=True) as (file, done): await file.write(data) + done() return media diff --git a/src/backend/user_files.py b/src/backend/user_files.py index bf72eb99..dbe332fb 100644 --- a/src/backend/user_files.py +++ b/src/backend/user_files.py @@ -90,8 +90,9 @@ class DataFile: if not self.create_missing and not self.path.exists(): continue - async with atomic_write(self.path) as new: + async with atomic_write(self.path) as (new, done): await new.write(self._to_write) + done() self._to_write = None diff --git a/src/backend/utils.py b/src/backend/utils.py index ee021219..eb0a8378 100644 --- a/src/backend/utils.py +++ b/src/backend/utils.py @@ -16,7 +16,7 @@ from pathlib import Path from tempfile import NamedTemporaryFile from types import ModuleType from typing import ( - Any, AsyncIterator, Dict, Mapping, Sequence, Tuple, Type, Union, + Any, AsyncIterator, Callable, Dict, Mapping, Sequence, Tuple, Type, Union, ) from uuid import UUID @@ -24,7 +24,6 @@ import aiofiles import filetype from aiofiles.threadpool.binary import AsyncBufferedReader from aiofiles.threadpool.text import AsyncTextIOWrapper - from nio.crypto import AsyncDataT as File from nio.crypto import async_generator_from_data @@ -33,8 +32,9 @@ if sys.version_info >= (3, 7): else: from async_generator import asynccontextmanager -Size = Tuple[int, int] -auto = autostr +AsyncOpenFile = Union[AsyncTextIOWrapper, AsyncBufferedReader] +Size = Tuple[int, int] +auto = autostr class AutoStrEnum(Enum): @@ -202,16 +202,36 @@ def classes_defined_in(module: ModuleType) -> Dict[str, Type]: @asynccontextmanager async def atomic_write( path: Union[Path, str], binary: bool = False, **kwargs, -) -> AsyncIterator[Union[AsyncTextIOWrapper, AsyncBufferedReader]]: - """Write a file asynchronously (using aiofiles) and atomically.""" +) -> AsyncIterator[Tuple[AsyncOpenFile, Callable[[], None]]]: + """Write a file asynchronously (using aiofiles) and atomically. + + Yields a `(open_temporary_file, done_function)` tuple. + The done function should be called after writing to the given file. + When the context manager exits, the temporary file will either replace + `path` if the function was called, or be deleted. + + Example: + >>> async with atomic_write("foo.txt") as (file, done): + >>> await file.write("Sample text") + >>> done() + """ mode = "wb" if binary else "w" path = Path(path) temp = NamedTemporaryFile(dir=path.parent, delete=False) temp_path = Path(temp.name) + can_replace = False + + def done() -> None: + nonlocal can_replace + can_replace = True + try: async with aiofiles.open(temp_path, mode, **kwargs) as out: - yield out + yield (out, done) finally: - temp_path.replace(path) + if can_replace: + temp_path.replace(path) + else: + temp_path.unlink()