Write user files and media atomically

This commit is contained in:
miruka 2020-03-13 04:35:51 -04:00
parent 9d3e2dbfc4
commit 190eb58187
4 changed files with 33 additions and 9 deletions

View File

@ -2,7 +2,6 @@
## Before release
- Atomic
- Catch server 5xx errors when sending message and retry
- Update README.md

View File

@ -13,11 +13,11 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, Optional
from urllib.parse import urlparse
import aiofiles
import nio
from PIL import Image as PILImage
from .utils import Size
import nio
from .utils import Size, atomic_write
if TYPE_CHECKING:
from .backend import Backend
@ -138,7 +138,7 @@ class Media:
self.local_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(self.local_path, "wb") as file:
async with atomic_write(self.local_path, binary=True) as file:
await file.write(data)
return self.local_path
@ -212,7 +212,7 @@ class Media:
media.local_path.parent.mkdir(parents=True, exist_ok=True)
if not media.local_path.exists() or overwrite:
async with aiofiles.open(media.local_path, "wb") as file:
async with atomic_write(media.local_path, binary=True) as file:
await file.write(data)
return media

View File

@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional
import aiofiles
from .theme_parser import convert_to_qml
from .utils import dict_update_recursive
from .utils import atomic_write, dict_update_recursive
if TYPE_CHECKING:
from .backend import Backend
@ -88,7 +88,7 @@ class DataFile:
if not self.create_missing and not self.path.exists():
continue
async with aiofiles.open(self.path, "w") as new:
async with atomic_write(self.path) as new:
await new.write(self._to_write)
self._to_write = None

View File

@ -8,16 +8,23 @@ import inspect
import io
import json
import xml.etree.cElementTree as xml_etree # FIXME: bandit warning
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from enum import Enum
from enum import auto as autostr
from pathlib import Path
from tempfile import NamedTemporaryFile
from types import ModuleType
from typing import Any, Dict, Mapping, Sequence, Tuple, Type
from typing import (
Any, AsyncIterator, Dict, Mapping, Sequence, Tuple, Type, Union,
)
from uuid import UUID
import aiofiles
import filetype
from aiofiles.threadpool.binary import AsyncBufferedReader
from aiofiles.threadpool.text import AsyncTextIOWrapper
from nio.crypto import AsyncDataT as File
from nio.crypto import async_generator_from_data
@ -185,3 +192,21 @@ def classes_defined_in(module: ModuleType) -> Dict[str, Type]:
if not m[0].startswith("_") and
m[1].__module__.startswith(module.__name__)
}
@asynccontextmanager
async def atomic_write(
path: Union[Path, str], binary: bool = False, **kwargs,
) -> AsyncIterator[Union[AsyncTextIOWrapper, AsyncBufferedReader]]:
"""Write a file asynchronously (using aiofiles) and atomically."""
mode = "wb" if binary else "w"
path = Path(path)
temp = NamedTemporaryFile(dir=path.parent, delete=False)
temp_path = Path(temp.name)
try:
async with aiofiles.open(temp_path, mode, **kwargs) as out:
yield out
finally:
temp_path.replace(path)