Ensure we don't atomically write an empty file
This commit is contained in:
parent
9a66afb4eb
commit
e4d47b9b9c
|
@ -138,8 +138,9 @@ class Media:
|
||||||
|
|
||||||
self.local_path.parent.mkdir(parents=True, exist_ok=True)
|
self.local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
async with atomic_write(self.local_path, binary=True) as file:
|
async with atomic_write(self.local_path, binary=True) as (file, done):
|
||||||
await file.write(data)
|
await file.write(data)
|
||||||
|
done()
|
||||||
|
|
||||||
return self.local_path
|
return self.local_path
|
||||||
|
|
||||||
|
@ -212,8 +213,11 @@ class Media:
|
||||||
media.local_path.parent.mkdir(parents=True, exist_ok=True)
|
media.local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if not media.local_path.exists() or overwrite:
|
if not media.local_path.exists() or overwrite:
|
||||||
async with atomic_write(media.local_path, binary=True) as file:
|
path = media.local_path
|
||||||
|
|
||||||
|
async with atomic_write(path, binary=True) as (file, done):
|
||||||
await file.write(data)
|
await file.write(data)
|
||||||
|
done()
|
||||||
|
|
||||||
return media
|
return media
|
||||||
|
|
||||||
|
|
|
@ -90,8 +90,9 @@ class DataFile:
|
||||||
if not self.create_missing and not self.path.exists():
|
if not self.create_missing and not self.path.exists():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async with atomic_write(self.path) as new:
|
async with atomic_write(self.path) as (new, done):
|
||||||
await new.write(self._to_write)
|
await new.write(self._to_write)
|
||||||
|
done()
|
||||||
|
|
||||||
self._to_write = None
|
self._to_write = None
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from pathlib import Path
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import (
|
from typing import (
|
||||||
Any, AsyncIterator, Dict, Mapping, Sequence, Tuple, Type, Union,
|
Any, AsyncIterator, Callable, Dict, Mapping, Sequence, Tuple, Type, Union,
|
||||||
)
|
)
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
@ -24,7 +24,6 @@ import aiofiles
|
||||||
import filetype
|
import filetype
|
||||||
from aiofiles.threadpool.binary import AsyncBufferedReader
|
from aiofiles.threadpool.binary import AsyncBufferedReader
|
||||||
from aiofiles.threadpool.text import AsyncTextIOWrapper
|
from aiofiles.threadpool.text import AsyncTextIOWrapper
|
||||||
|
|
||||||
from nio.crypto import AsyncDataT as File
|
from nio.crypto import AsyncDataT as File
|
||||||
from nio.crypto import async_generator_from_data
|
from nio.crypto import async_generator_from_data
|
||||||
|
|
||||||
|
@ -33,6 +32,7 @@ if sys.version_info >= (3, 7):
|
||||||
else:
|
else:
|
||||||
from async_generator import asynccontextmanager
|
from async_generator import asynccontextmanager
|
||||||
|
|
||||||
|
AsyncOpenFile = Union[AsyncTextIOWrapper, AsyncBufferedReader]
|
||||||
Size = Tuple[int, int]
|
Size = Tuple[int, int]
|
||||||
auto = autostr
|
auto = autostr
|
||||||
|
|
||||||
|
@ -202,16 +202,36 @@ def classes_defined_in(module: ModuleType) -> Dict[str, Type]:
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def atomic_write(
|
async def atomic_write(
|
||||||
path: Union[Path, str], binary: bool = False, **kwargs,
|
path: Union[Path, str], binary: bool = False, **kwargs,
|
||||||
) -> AsyncIterator[Union[AsyncTextIOWrapper, AsyncBufferedReader]]:
|
) -> AsyncIterator[Tuple[AsyncOpenFile, Callable[[], None]]]:
|
||||||
"""Write a file asynchronously (using aiofiles) and atomically."""
|
"""Write a file asynchronously (using aiofiles) and atomically.
|
||||||
|
|
||||||
|
Yields a `(open_temporary_file, done_function)` tuple.
|
||||||
|
The done function should be called after writing to the given file.
|
||||||
|
When the context manager exits, the temporary file will either replace
|
||||||
|
`path` if the function was called, or be deleted.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> async with atomic_write("foo.txt") as (file, done):
|
||||||
|
>>> await file.write("Sample text")
|
||||||
|
>>> done()
|
||||||
|
"""
|
||||||
|
|
||||||
mode = "wb" if binary else "w"
|
mode = "wb" if binary else "w"
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
temp = NamedTemporaryFile(dir=path.parent, delete=False)
|
temp = NamedTemporaryFile(dir=path.parent, delete=False)
|
||||||
temp_path = Path(temp.name)
|
temp_path = Path(temp.name)
|
||||||
|
|
||||||
|
can_replace = False
|
||||||
|
|
||||||
|
def done() -> None:
|
||||||
|
nonlocal can_replace
|
||||||
|
can_replace = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiofiles.open(temp_path, mode, **kwargs) as out:
|
async with aiofiles.open(temp_path, mode, **kwargs) as out:
|
||||||
yield out
|
yield (out, done)
|
||||||
finally:
|
finally:
|
||||||
|
if can_replace:
|
||||||
temp_path.replace(path)
|
temp_path.replace(path)
|
||||||
|
else:
|
||||||
|
temp_path.unlink()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user