Ensure we don't atomically write an empty file

This commit is contained in:
miruka 2020-03-17 12:09:08 -04:00
parent 9a66afb4eb
commit e4d47b9b9c
3 changed files with 36 additions and 11 deletions

View File

@ -138,8 +138,9 @@ class Media:
self.local_path.parent.mkdir(parents=True, exist_ok=True) self.local_path.parent.mkdir(parents=True, exist_ok=True)
async with atomic_write(self.local_path, binary=True) as file: async with atomic_write(self.local_path, binary=True) as (file, done):
await file.write(data) await file.write(data)
done()
return self.local_path return self.local_path
@ -212,8 +213,11 @@ class Media:
media.local_path.parent.mkdir(parents=True, exist_ok=True) media.local_path.parent.mkdir(parents=True, exist_ok=True)
if not media.local_path.exists() or overwrite: if not media.local_path.exists() or overwrite:
async with atomic_write(media.local_path, binary=True) as file: path = media.local_path
async with atomic_write(path, binary=True) as (file, done):
await file.write(data) await file.write(data)
done()
return media return media

View File

@ -90,8 +90,9 @@ class DataFile:
if not self.create_missing and not self.path.exists(): if not self.create_missing and not self.path.exists():
continue continue
async with atomic_write(self.path) as new: async with atomic_write(self.path) as (new, done):
await new.write(self._to_write) await new.write(self._to_write)
done()
self._to_write = None self._to_write = None

View File

@ -16,7 +16,7 @@ from pathlib import Path
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from types import ModuleType from types import ModuleType
from typing import ( from typing import (
Any, AsyncIterator, Dict, Mapping, Sequence, Tuple, Type, Union, Any, AsyncIterator, Callable, Dict, Mapping, Sequence, Tuple, Type, Union,
) )
from uuid import UUID from uuid import UUID
@ -24,7 +24,6 @@ import aiofiles
import filetype import filetype
from aiofiles.threadpool.binary import AsyncBufferedReader from aiofiles.threadpool.binary import AsyncBufferedReader
from aiofiles.threadpool.text import AsyncTextIOWrapper from aiofiles.threadpool.text import AsyncTextIOWrapper
from nio.crypto import AsyncDataT as File from nio.crypto import AsyncDataT as File
from nio.crypto import async_generator_from_data from nio.crypto import async_generator_from_data
@ -33,6 +32,7 @@ if sys.version_info >= (3, 7):
else: else:
from async_generator import asynccontextmanager from async_generator import asynccontextmanager
AsyncOpenFile = Union[AsyncTextIOWrapper, AsyncBufferedReader]
Size = Tuple[int, int] Size = Tuple[int, int]
auto = autostr auto = autostr
@ -202,16 +202,36 @@ def classes_defined_in(module: ModuleType) -> Dict[str, Type]:
@asynccontextmanager @asynccontextmanager
async def atomic_write( async def atomic_write(
path: Union[Path, str], binary: bool = False, **kwargs, path: Union[Path, str], binary: bool = False, **kwargs,
) -> AsyncIterator[Union[AsyncTextIOWrapper, AsyncBufferedReader]]: ) -> AsyncIterator[Tuple[AsyncOpenFile, Callable[[], None]]]:
"""Write a file asynchronously (using aiofiles) and atomically.""" """Write a file asynchronously (using aiofiles) and atomically.
Yields a `(open_temporary_file, done_function)` tuple.
The done function should be called after writing to the given file.
When the context manager exits, the temporary file will either replace
`path` if the function was called, or be deleted.
Example:
>>> async with atomic_write("foo.txt") as (file, done):
>>> await file.write("Sample text")
>>> done()
"""
mode = "wb" if binary else "w" mode = "wb" if binary else "w"
path = Path(path) path = Path(path)
temp = NamedTemporaryFile(dir=path.parent, delete=False) temp = NamedTemporaryFile(dir=path.parent, delete=False)
temp_path = Path(temp.name) temp_path = Path(temp.name)
can_replace = False
def done() -> None:
nonlocal can_replace
can_replace = True
try: try:
async with aiofiles.open(temp_path, mode, **kwargs) as out: async with aiofiles.open(temp_path, mode, **kwargs) as out:
yield out yield (out, done)
finally: finally:
if can_replace:
temp_path.replace(path) temp_path.replace(path)
else:
temp_path.unlink()