Coverage for portality/models/account.py: 70%
183 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-04 15:38 +0100
1import uuid
2from flask_login import UserMixin
3from datetime import datetime, timedelta
4from werkzeug.security import generate_password_hash, check_password_hash
6from portality.dao import DomainObject as DomainObject
7from portality.core import app
8from portality.authorise import Authorise
class Account(DomainObject, UserMixin):
    """A user account persisted as a DomainObject of type ``account``.

    Holds the account id, email, name, a hashed password, roles, associated
    journal ids, an optional password reset token with its expiry, and an API
    key (only exposed while the account has the ``api`` role).
    """
    __type__ = 'account'

    # Format for every timestamp this class writes or parses. The trailing "Z"
    # declares the value to be UTC, so all timestamps are taken from utcnow().
    _TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

    def __init__(self, **kwargs):
        """Build an account from raw field values.

        The supplied ``id`` is first checked with ``ReservedUsernames().validate``;
        whatever that validator raises for a reserved id propagates to the caller.
        """
        # Local import - presumably to avoid an import cycle with the forms package.
        from portality.forms.validate import ReservedUsernames
        ReservedUsernames().validate(kwargs.get('id', ''))
        super(Account, self).__init__(**kwargs)

    @staticmethod
    def _quote_query_value(value):
        """Return ``value`` as a double-quoted query-string phrase.

        Backslashes and double quotes are escaped so that a user-supplied value
        (email, reset token, API key) cannot close the phrase early and inject
        extra query syntax. Ordinary values are passed through unchanged.
        """
        escaped = str(value).replace('\\', '\\\\').replace('"', '\\"')
        return '"' + escaped + '"'

    @classmethod
    def make_account(cls, email, username=None, name=None, roles=None, associated_journal_ids=None):
        """Return an existing account for ``username``/``email``, or construct a new one.

        A newly constructed account is NOT saved here. It has no password;
        instead it carries a reset token so the user can create their first one.

        :param email: email address for the account
        :param username: account id; a random short id is generated when omitted
        :param name: optional display name (only set when truthy)
        :param roles: roles to add to a new account
        :param associated_journal_ids: journal ids to associate with a new account
        :return: the existing account, or the new unsaved account
        """
        # An existing account with these credentials takes precedence.
        existing = cls.pull(username) or cls.pull_by_email(email)
        if existing:
            return existing

        account = cls(id=username or cls.new_short_uuid())
        account.set_email(email)
        if name:
            account.set_name(name)
        for role in roles or []:
            account.add_role(role)
        for jid in associated_journal_ids or []:
            account.add_journal(jid)

        # New accounts have no password, so issue a token for creating one. The
        # default lifetime is 14x PASSWORD_RESET_TIMEOUT (seconds), i.e. 14 days
        # with that setting's 86400s default.
        reset_token = uuid.uuid4().hex
        timeout = app.config.get(
            "PASSWORD_CREATE_TIMEOUT",
            app.config.get('PASSWORD_RESET_TIMEOUT', 86400) * 14,
        )
        account.set_reset_token(reset_token, timeout)
        return account

    @classmethod
    def pull_by_email(cls, email: str):
        """Return the account whose stored email exactly equals ``email``.

        Returns None when ``email`` is None, when the search does not match
        exactly one account, or when the match is not an exact string match.
        """
        if email is None:
            return None
        res = cls.query(q='email:' + cls._quote_query_value(email))
        if res.get('hits', {}).get('total', {}).get('value', 0) == 1:
            acc = cls(**res['hits']['hits'][0]['_source'])
            # The search may be looser than string equality, so confirm exactly.
            if acc.email == email:
                return acc
        return None

    @classmethod
    def email_in_use(cls, email: str):
        """Return True if any account's email field matches ``email``.

        Returns None (not False) when ``email`` is None.
        """
        if email is None:
            return None
        res = cls.query(q='email:' + cls._quote_query_value(email))
        return res.get('hits', {}).get('total', {}).get('value', 0) > 0

    @classmethod
    def get_by_reset_token(cls, reset_token, not_expired=True):
        """Return the single account holding ``reset_token``, or None.

        None is returned when the token is None, matches zero or several
        accounts, has no stored expiry, or (when ``not_expired`` is True) has an
        unparseable or already-passed expiry.
        """
        if reset_token is None:
            return None
        res = cls.query(q='reset_token.exact:' + cls._quote_query_value(reset_token))
        obs = [hit.get("_source") for hit in res.get("hits", {}).get("hits", [])]
        if len(obs) != 1:
            return None
        expires = obs[0].get("reset_expires")
        if expires is None:
            return None
        if not_expired:
            try:
                expiry = datetime.strptime(expires, cls._TIMESTAMP_FORMAT)
            except ValueError:
                return None
            # Expiry is stored as UTC (see set_reset_token), so compare with utcnow.
            if expiry < datetime.utcnow():
                return None
        return cls(**obs[0])

    @property
    def marketing_consent(self):
        return self.data.get("marketing_consent")

    def set_marketing_consent(self, consent):
        """Store ``consent`` coerced to a bool."""
        self.data["marketing_consent"] = bool(consent)

    @property
    def name(self):
        return self.data.get("name")

    def set_name(self, name):
        self.data["name"] = name

    @property
    def email(self):
        return self.data.get("email")

    def set_email(self, email):
        self.data["email"] = email

    def set_password(self, password):
        """Store a werkzeug hash of ``password`` (never the plain text)."""
        self.data['password'] = generate_password_hash(password)

    def clear_password(self):
        """Remove the stored password hash, if a non-empty one is present."""
        if self.data.get('password'):
            del self.data['password']

    def check_password(self, password):
        """Return whether ``password`` matches the stored hash.

        :raises KeyError: if the account has no password field (the problem is
            logged before re-raising)
        """
        try:
            return check_password_hash(self.data['password'], password)
        except KeyError:
            # .get so a missing id cannot raise a second KeyError inside this handler.
            app.logger.error("Problem with user '{}' account: no password field".format(self.data.get('id')))
            raise

    @property
    def journal(self):
        """The list of associated journal ids, or None if none were ever added."""
        return self.data.get("journal")

    def add_journal(self, jid):
        """Associate journal ``jid`` with this account; no-op if already present."""
        journals = self.data.setdefault("journal", [])
        if jid not in journals:
            journals.append(jid)

    def remove_journal(self, jid):
        """Disassociate journal ``jid``; no-op if it is not associated."""
        journals = self.data.get("journal")
        if journals and jid in journals:
            journals.remove(jid)

    @property
    def reset_token(self):
        return self.data.get('reset_token')

    def set_reset_token(self, token, timeout):
        """Store ``token`` with an expiry ``timeout`` seconds from now (UTC)."""
        expires = datetime.utcnow() + timedelta(seconds=timeout)
        self.data["reset_token"] = token
        self.data["reset_expires"] = expires.strftime(self._TIMESTAMP_FORMAT)

    def remove_reset_token(self):
        """Delete the reset token and its expiry, if present."""
        self.data.pop("reset_token", None)
        self.data.pop("reset_expires", None)

    @property
    def reset_expires(self):
        """The raw stored expiry string, or None."""
        return self.data.get("reset_expires")

    @property
    def reset_expires_timestamp(self):
        """The stored expiry parsed to a naive UTC datetime, or None.

        :raises ValueError: if the stored value does not match the timestamp format
        """
        expires = self.reset_expires
        if expires is None:
            return None
        return datetime.strptime(expires, self._TIMESTAMP_FORMAT)

    def is_reset_expired(self):
        """Return True if there is no reset expiry or it has already passed."""
        expires = self.reset_expires_timestamp
        if expires is None:
            return True
        return expires < datetime.utcnow()

    @property
    def is_super(self):
        """Whether the account holds the configured SUPER_USER_ROLE (via Authorise.has_role)."""
        return Authorise.has_role(app.config["SUPER_USER_ROLE"], self.data.get("role", []))

    def has_role(self, role):
        """Delegate to ``Authorise.has_role`` with this account's roles."""
        return Authorise.has_role(role, self.data.get("role", []))

    @classmethod
    def all_top_level_roles(cls):
        return Authorise.top_level_roles()

    def add_role(self, role):
        """Add ``role`` if absent; adding ``api`` also ensures an API key exists."""
        roles = self.data.setdefault("role", [])
        if role not in roles:
            roles.append(role)
        # An API-role account needs a key to authenticate with.
        if role == 'api' and not self.data.get('api_key', None):
            self.generate_api_key()

    def remove_role(self, role):
        """Remove ``role`` if present; no-op otherwise."""
        roles = self.data.get("role")
        if roles is not None and role in roles:
            roles.remove(role)

    @property
    def role(self):
        return self.data.get("role", [])

    def set_role(self, role):
        """Replace all roles; a single non-list value is wrapped in a list."""
        if not isinstance(role, list):
            role = [role]
        self.data["role"] = role

    def prep(self):
        """Stamp ``last_updated`` with the current UTC time."""
        self.data['last_updated'] = datetime.utcnow().strftime(self._TIMESTAMP_FORMAT)

    @property
    def api_key(self):
        """The stored API key, exposed only while the account has the ``api`` role."""
        return self.data.get('api_key', None) if self.has_role('api') else None

    def generate_api_key(self):
        """Create, store and return a new random API key."""
        k = uuid.uuid4().hex
        self.data['api_key'] = k
        return k

    @classmethod
    def pull_by_api_key(cls, key):
        """Find a user by their API key - only succeed if they currently have API access.

        Returns None for a None key, when the key does not match exactly one
        account, or when that account lacks the ``api`` role.
        """
        if key is None:
            return None
        res = cls.query(q='api_key.exact:' + cls._quote_query_value(key))
        if res.get('hits', {}).get('total', {}).get('value', 0) == 1:
            usr = cls(**res['hits']['hits'][0]['_source'])
            if usr.has_role('api'):
                return usr
        return None

    @classmethod
    def new_short_uuid(cls):
        """Generate a short (8-char) UUID prefix that no existing account uses as its id."""
        while True:
            candidate = str(uuid.uuid4())[:8]
            if cls.pull(candidate) is None:
                return candidate