diff --git a/backends/file.py b/backends/file.py index ff45c42..fb03e06 100644 --- a/backends/file.py +++ b/backends/file.py @@ -15,7 +15,7 @@ def _load(name,pos,file,obj): for line in linegen(file): split=line.split(',') if split[pos]==name: return obj(*split) - raise Exception(f'No such {obj.__name__} with identifier {name}') + return None def _load_multi(name,pos,file,obj): out=[] for line in linegen(file): @@ -46,53 +46,56 @@ def load_token(tokeid): def load_tokens(user): return _load_multi(user.username,1,access_tokens,models.AccessToken) def save_token(token): - return _save(token,0,access_tokens,token.value) + return _save(token,0,access_tokens,str(token)) def delete_token(token): - return _delete(token,0,access_tokens,token.value) + return _delete(token,0,access_tokens,str(token)) def load_invite(tokeid): return _load(tokeid,0,invite_tokens,models.InviteToken) def load_invites(user): return _load_multi(user.username,1,invite_tokens,models.InviteToken) def save_invite(token): - return _save(token,0,invite_tokens,token.value) + return _save(token,0,invite_tokens,str(token)) def delete_invite(token): - return _delete(token,0,invite_tokens,token.value) + return _delete(token,0,invite_tokens,str(token)) def linegen(path): i=0 file=open(path) - line=file.readline().rstrip('\n') # Maybe sometimes there won't be one. + line=_rl(file) while line: i+=1 response=yield line if response is not None: yield i - line=file.readline() + line=_rl(file) file.close() def update(path,lines): def _action(file,i,line): if i in lines: file.write(lines[i]+'\n') - else: file.write(line) + else: file.write(line+'\n') def _post(file): if -1 in lines: file.write(lines[-1]+'\n') writefile(path,_action,_post) def remove(path,lines): def _action(file,i,line): - if i not in lines: file.write(line) + if i not in lines: file.write(line+'\n') writefile(path,_action) def writefile(path,action,post=None): i=0 file=open(path) write=open(f'{path}.tmp','w') - line=file.readline() + line=_rl(file) while line: i+=1 action(write,i,line) - line=file.readline().rstrip('\n') + line=_rl(file) if post is not None: post(write) file.close() write.close() - rename(f'{path}.tmp',path) \ No newline at end of file + rename(f'{path}.tmp',path) + +def _rl(file): + return file.readline().rstrip('\n') # Maybe sometimes there won't be one. \ No newline at end of file diff --git a/main.py b/main.py index b473a2a..917284c 100644 --- a/main.py +++ b/main.py @@ -6,5 +6,6 @@ globals()['backend']=import_module('.backends.'+config['backend']['type'],'auth' backend.config=config['backend']['options'] backend.utils=utils backend.models=models +models.utils=utils models.backend=backend backend.init() \ No newline at end of file diff --git a/models.py b/models.py index 92f2c36..d836701 100644 --- a/models.py +++ b/models.py @@ -1,19 +1,15 @@ from dataclasses import dataclass +from pydantic.v1 import validate_arguments from datetime import datetime -from hashlib import shake_256 import uuid -def phash(pw,salt=None): - if salt is None: salt=uuid.uuid4() - return shake_256(f'{pw}{salt}'.encode('utf-8')).hexdigest(256),salt - class AbstractUser(): pass # It fixes circular dep on user.invited_by # And on User.register (InviteToken) @dataclass class Token(): - value: str + value: uuid.UUID owner: AbstractUser def __init__(self,value,owner=None): if owner is None: @@ -26,15 +22,18 @@ class Token(): backend.delete_token(self) @property def serialise(self): - return ','.join([self.value,self.owner]) + return ','.join(map(str,[self,self.owner])) + def __str__(self): return str(self.value) @dataclass class InviteToken(Token): _uses: int=0 _max_uses: int=-1 expires: datetime=None - def __init__(self,*args,**kwargs): - return super().__init__(*args,**kwargs) + @validate_arguments + def __init__(self,value,owner,uses:int,maxuses:int,expiry:datetime|None): + self.uses,self.max_uses,self.expires=uses,maxuses,expiry + return super().__init__(value,owner) @property def uses(self): return self._uses @uses.setter @@ -49,7 +48,7 @@ class InviteToken(Token): self._max_uses=val @property def serialise(self): - return ','.join([self.value,self.owner,self.uses,self.max_uses,self.expires]) + return super().serialise+','+','.join(map(str,[self.uses,self.max_uses,self.expires])) @dataclass class AccessToken(Token): @@ -61,38 +60,49 @@ class User(AbstractUser): username: str password_hash: str salt: str - _invited_by: AbstractUser # Root node will just reference itself + _invited_by: AbstractUser|str # Root node will just reference itself + email: str='' + def _load_invite(self): + if self._invited_by==self.username: return self # Sanity-check to prevent infinite recursion. + return backend.load_user(self._invited_by) @property def invited_by(self): - if isinstance(self._invited_by,str): self._invited_by=backend.load_user(self._invited_by) + if isinstance(self._invited_by,str): + self._invited_by=self._load_invite() return self._invited_by - email: str='' @property def serialise(self): - return ','.join([self.username,self.password_hash,str(self.salt),self.invited_by.username,self.email]) + return ','.join([self.username,self.password_hash,str(self.salt),self.invited_by.username,self.email or '']) def create_inv_token(self,*args,**kwargs): tok=InviteToken(self,*args,**kwargs) backend.save_invite(tok) return tok - def change_password(self,old_pw:str|None,new_pw:str): pass + def auth(self,pw): + return utils.phash(pw,self.salt)[0]==self.password_hash + def change_password(self,old_pw:str|None,new_pw:str): + if self.auth(old_pw): + self.password_hash,self.salt=utils.phash(new_pw) + self.save() def save(self): backend.save_user(self) + def __str__(self): return self.username @classmethod def login(cls,username:str,password:str): u=backend.load_user(username) if u is None: raise Exception("User doesn't exist") - if phash(password,u.salt)[0]==u.password_hash: + if u.auth(password): a=AccessToken(u) backend.save_token(a) return a raise Exception("Incorrect password") @classmethod - def register(cls,username:str,password:str,invite:InviteToken,email:str|None): + def register(cls,username:str,password:str,invite:InviteToken,email:str=''): if set([chr(n) for n in range(32)]+[','])&set(username): raise Exception('Invalid username') u=backend.load_user(username) if u is not None: raise Exception("User already exists") - u=User(username,*phash(password),invite.owner,email) + u=User(username,*utils.phash(password),invite.owner,email) u.save() invite.uses+=1 + backend.save_invite(invite) return u \ No newline at end of file diff --git a/utils.py b/utils.py index e69de29..8874617 100644 --- a/utils.py +++ b/utils.py @@ -0,0 +1,6 @@ +from hashlib import shake_256 +import uuid + +def phash(pw,salt=None): + if salt is None: salt=uuid.uuid4() + return shake_256(f'{pw}{salt}'.encode('utf-8')).hexdigest(256),salt \ No newline at end of file