Ensure we don't atomically write an empty file

This commit is contained in:
miruka 2020-03-17 12:09:08 -04:00
parent 9a66afb4eb
commit e4d47b9b9c
3 changed files with 36 additions and 11 deletions

View File

@ -138,8 +138,9 @@ class Media:
self.local_path.parent.mkdir(parents=True, exist_ok=True)
async with atomic_write(self.local_path, binary=True) as file:
async with atomic_write(self.local_path, binary=True) as (file, done):
await file.write(data)
done()
return self.local_path
@ -212,8 +213,11 @@ class Media:
media.local_path.parent.mkdir(parents=True, exist_ok=True)
if not media.local_path.exists() or overwrite:
async with atomic_write(media.local_path, binary=True) as file:
path = media.local_path
async with atomic_write(path, binary=True) as (file, done):
await file.write(data)
done()
return media

View File

@ -90,8 +90,9 @@ class DataFile:
if not self.create_missing and not self.path.exists():
continue
async with atomic_write(self.path) as new:
async with atomic_write(self.path) as (new, done):
await new.write(self._to_write)
done()
self._to_write = None

View File

@ -16,7 +16,7 @@ from pathlib import Path
from tempfile import NamedTemporaryFile
from types import ModuleType
from typing import (
Any, AsyncIterator, Dict, Mapping, Sequence, Tuple, Type, Union,
Any, AsyncIterator, Callable, Dict, Mapping, Sequence, Tuple, Type, Union,
)
from uuid import UUID
@ -24,7 +24,6 @@ import aiofiles
import filetype
from aiofiles.threadpool.binary import AsyncBufferedReader
from aiofiles.threadpool.text import AsyncTextIOWrapper
from nio.crypto import AsyncDataT as File
from nio.crypto import async_generator_from_data
@ -33,8 +32,9 @@ if sys.version_info >= (3, 7):
else:
from async_generator import asynccontextmanager
Size = Tuple[int, int]
auto = autostr
AsyncOpenFile = Union[AsyncTextIOWrapper, AsyncBufferedReader]
Size = Tuple[int, int]
auto = autostr
class AutoStrEnum(Enum):
@ -202,16 +202,36 @@ def classes_defined_in(module: ModuleType) -> Dict[str, Type]:
@asynccontextmanager
async def atomic_write(
path: Union[Path, str], binary: bool = False, **kwargs,
) -> AsyncIterator[Union[AsyncTextIOWrapper, AsyncBufferedReader]]:
"""Write a file asynchronously (using aiofiles) and atomically."""
) -> AsyncIterator[Tuple[AsyncOpenFile, Callable[[], None]]]:
"""Write a file asynchronously (using aiofiles) and atomically.
Yields a `(open_temporary_file, done_function)` tuple.
The done function should be called after writing to the given file.
When the context manager exits, the temporary file will either replace
`path` if the function was called, or be deleted.
Example:
>>> async with atomic_write("foo.txt") as (file, done):
>>> await file.write("Sample text")
>>> done()
"""
mode = "wb" if binary else "w"
path = Path(path)
temp = NamedTemporaryFile(dir=path.parent, delete=False)
temp_path = Path(temp.name)
can_replace = False
def done() -> None:
nonlocal can_replace
can_replace = True
try:
async with aiofiles.open(temp_path, mode, **kwargs) as out:
yield out
yield (out, done)
finally:
temp_path.replace(path)
if can_replace:
temp_path.replace(path)
else:
temp_path.unlink()