Support a order key for accounts in config

This commit is contained in:
miruka 2020-05-14 03:33:34 -04:00
parent fd3fe06d15
commit 9b5b3a9f06
5 changed files with 38 additions and 17 deletions

View File

@ -106,6 +106,7 @@ class Backend:
password: str,
device_id: Optional[str] = None,
homeserver: str = "https://matrix.org",
order: int = -1,
) -> str:
"""Create and register a `MatrixClient`, login and return a user ID."""
@ -120,15 +121,18 @@ class Backend:
raise
self.clients[client.user_id] = client
self.models["accounts"][client.user_id] = Account(client.user_id)
self.models["accounts"][client.user_id] = Account(client.user_id,order)
return client.user_id
async def resume_client(self,
user_id: str,
token: str,
device_id: str,
homeserver: str = "https://matrix.org") -> None:
async def resume_client(
self,
user_id: str,
token: str,
device_id: str,
homeserver: str = "https://matrix.org",
order: int = -1,
) -> None:
"""Create and register a `MatrixClient` with known account details."""
client = MatrixClient(
@ -137,7 +141,7 @@ class Backend:
)
self.clients[user_id] = client
self.models["accounts"][user_id] = Account(user_id)
self.models["accounts"][user_id] = Account(user_id, order)
await client.resume(user_id=user_id, token=token, device_id=device_id)
@ -145,12 +149,13 @@ class Backend:
async def load_saved_accounts(self) -> List[str]:
"""Call `resume_client` for all saved accounts in user config."""
async def resume(user_id: str, info: Dict[str, str]) -> str:
async def resume(user_id: str, info: Dict[str, Any]) -> str:
await self.resume_client(
user_id = user_id,
token = info["token"],
device_id = info["device_id"],
homeserver = info["homeserver"],
order = info.get("order", -1),
)
return user_id

View File

@ -33,6 +33,7 @@ class Account(ModelItem):
"""A logged in matrix account."""
id: str = field()
order: int = -1
display_name: str = ""
avatar_url: str = ""
max_upload_size: int = 0
@ -42,10 +43,8 @@ class Account(ModelItem):
total_mentions: int = 0
def __lt__(self, other: "Account") -> bool:
"""Sort by user ID."""
name = self.id[1:]
other_name = other.id[1:]
return name.lower() < other_name.lower()
"""Sort by order, then by user ID."""
return (self.order, self.id.lower()) < (other.order, other.id.lower())
@dataclass
@ -120,10 +119,12 @@ class Room(ModelItem):
@dataclass
class AccountOrRoom(Account, Room):
type: Union[Type[Account], Type[Room]] = Account
type: Union[Type[Account], Type[Room]] = Account
account_order: int = -1
def __lt__(self, other: "AccountOrRoom") -> bool: # type: ignore
return (
self.account_order,
self.id if self.type is Account else self.for_account,
other.type is Account,
self.left,
@ -134,6 +135,7 @@ class AccountOrRoom(Account, Room):
(self.display_name or self.id).lower(),
) < (
other.account_order,
other.id if other.type is Account else other.for_account,
self.type is Account,
other.left,

View File

@ -29,7 +29,7 @@ class ModelStore(UserDict):
model: Model
if key == "all_rooms":
model = AllRooms()
model = AllRooms(self["accounts"])
elif key == "matching_accounts":
model = MatchingAccounts(self["all_rooms"])
elif is_tuple and len(key) == 3 and key[2] == "filtered_members":

View File

@ -10,10 +10,12 @@ from .model_item import ModelItem
class AllRooms(FieldSubstringFilter):
def __init__(self) -> None:
def __init__(self, accounts: Model) -> None:
super().__init__(sync_id="all_rooms", fields=("display_name",))
self.items_changed_callbacks.append(self.refilter_accounts)
self.accounts = accounts
self._collapsed: Set[str] = set()
@ -39,7 +41,14 @@ class AllRooms(FieldSubstringFilter):
def convert_item(self, item: ModelItem) -> AccountOrRoom:
return AccountOrRoom(**asdict(item), type=type(item)) # type: ignore
return AccountOrRoom(
**asdict(item),
type = type(item), # type: ignore
account_order =
item.order if isinstance(item, Account) else
self.accounts[item.for_account].order, # type: ignore
)
def accept_item(self, item: ModelItem) -> bool:

View File

@ -183,14 +183,19 @@ class Accounts(JSONDataFile):
"""
client = self.backend.clients[user_id]
saved = await self.read()
await self.write({
**await self.read(),
**saved,
client.user_id: {
"homeserver": client.homeserver,
"token": client.access_token,
"device_id": client.device_id,
"enabled": True,
"order": max(
account.get("order", i)
for i, account in enumerate(saved.values())
) + 1,
},
})