Coverage for portality / models / account.py: 74%
253 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 00:09 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 00:09 +0100
1import uuid
2import hashlib
3import hmac
4import re
5from flask_login import UserMixin
6from datetime import timedelta
7from werkzeug.security import generate_password_hash, check_password_hash
9from portality import constants
10from portality.dao import DomainObject as DomainObject
11from portality.core import app
12from portality.authorise import Authorise
13from portality.lib import dates
14from portality.lib.dates import FMT_DATETIME_STD
class Account(DomainObject, UserMixin):
    """A user account, stored as the ``account`` document type.

    All state lives in ``self.data`` (the dict managed by ``DomainObject``);
    the properties and setters below read and write its keys directly.
    """

    __type__ = 'account'

    # --- Legacy SHA1 compatibility (Werkzeug 3 removal) ---
    # A plain 40-character hex string is treated as an unsalted SHA1 digest.
    _SHA1_HEX_RE = re.compile(r"^[a-f0-9]{40}$", re.IGNORECASE)

    def __init__(self, **kwargs):
        """Build an account from raw document fields.

        The ``id`` keyword (``''`` when absent) is passed to
        ``ReservedUsernames().validate`` before construction.
        NOTE(review): presumably that raises for reserved usernames — confirm
        in ``portality.forms.validate``.
        """
        # Imported locally, presumably to avoid an import cycle with the forms package.
        from portality.forms.validate import ReservedUsernames
        ReservedUsernames().validate(kwargs.get('id', ''))
        super().__init__(**kwargs)

    # ------------------------------------------------------------------
    # Query helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _quote_phrase(value) -> str:
        """Return *value* as a double-quoted phrase for a query-string query.

        Backslashes and double quotes are escaped so that user-supplied values
        (emails, reset tokens, API keys) cannot close the phrase early and
        inject extra query syntax. Ordinary values are quoted unchanged.
        """
        escaped = str(value).replace('\\', '\\\\').replace('"', '\\"')
        return '"' + escaped + '"'

    @classmethod
    def make_account(cls, email, username=None, name=None, roles=None, associated_journal_ids=None):
        """Return an existing account for these credentials, or build a new one.

        An existing account is looked up first by ``cls.pull(username)``, then
        by ``cls.pull_by_email(email)``; if found it is returned as-is and the
        other arguments are ignored.

        Otherwise a new account is built (this method does not save it):

        * id is *username*, or a fresh short UUID when no username is given;
        * email, optional name, *roles* and *associated_journal_ids* are set;
        * a password-creation reset token is attached, because new accounts
          have no password.

        :param email: email address for the account.
        :param username: desired account id; optional.
        :param name: display name; only set when truthy.
        :param roles: iterable of role names to add (default: none).
        :param associated_journal_ids: iterable of journal ids to associate.
        :return: an ``Account`` (or subclass) instance.
        """
        if roles is None:
            roles = []
        if associated_journal_ids is None:
            associated_journal_ids = []

        # If we have an existing account with these credentials, supply it
        existing = cls.pull(username) or cls.pull_by_email(email)
        if existing:
            return existing

        # Create a new account; use cls so subclasses build their own type.
        account = cls(id=username or cls.new_short_uuid())
        account.set_email(email)
        if name:
            account.set_name(name)

        for role in roles:
            account.add_role(role)
        for jid in associated_journal_ids:
            account.add_journal(jid)

        # New accounts don't have passwords set - create a reset token for password.
        # Timeout: PASSWORD_CREATE_TIMEOUT if configured, else 14 x PASSWORD_RESET_TIMEOUT
        # (which itself defaults to 86400 s, giving 14 days).
        create_timeout = app.config.get(
            "PASSWORD_CREATE_TIMEOUT",
            app.config.get('PASSWORD_RESET_TIMEOUT', 86400) * 14,
        )
        account.set_reset_token(uuid.uuid4().hex, create_timeout)
        return account

    @classmethod
    def pull_by_email(cls, email: str):
        """Return the single account whose email matches *email* (case-insensitive).

        Returns ``None`` when *email* is ``None``, when the query does not
        return exactly one hit, or when the hit's email does not match
        *email* ignoring case.
        """
        if email is None:
            return None
        res = cls.query(q='email:' + cls._quote_phrase(email))
        if res.get('hits', {}).get('total', {}).get('value', 0) == 1:
            acc = cls(**res['hits']['hits'][0]['_source'])
            # allow case insensitive login; guard against a record with no email
            if acc.email is not None and acc.email.lower() == email.lower():
                return acc
        return None

    @classmethod
    def email_in_use(cls, email: str):
        """Return whether any account matches *email*.

        Returns ``None`` (falsy) when *email* is ``None``, otherwise a bool.
        """
        if email is None:
            return None
        res = cls.query(q='email:' + cls._quote_phrase(email))
        return res.get('hits', {}).get('total', {}).get('value', 0) > 0

    @classmethod
    def get_by_reset_token(cls, reset_token, not_expired=True):
        """Return the unique account holding *reset_token*, or ``None``.

        Returns ``None`` in any of these cases:

        * *reset_token* is empty or ``None``;
        * the query returns zero or several accounts;
        * the account has no ``reset_expires``;
        * *not_expired* is true and the expiry cannot be parsed or is in the past.

        :param reset_token: token string to look up via ``reset_token.exact``.
        :param not_expired: when true, reject expired or unparseable expiries.
        """
        # An empty token must never match anything (and None would crash the query build).
        if not reset_token:
            return None
        res = cls.query(q='reset_token.exact:' + cls._quote_phrase(reset_token))
        sources = [hit.get("_source") for hit in res.get("hits", {}).get("hits", [])]
        if len(sources) != 1:
            return None
        expires = sources[0].get("reset_expires")
        if expires is None:
            return None
        if not_expired:
            try:
                if dates.parse(expires) < dates.now():
                    return None
            except ValueError:
                return None
        return cls(**sources[0])

    @classmethod
    def admin_autocomplete(cls, field, prefix, size=5):
        """Autocomplete over *field* restricted to accounts with the ``admin`` role."""
        filter_condition = {"role.exact": "admin"}
        return cls.autocomplete(field, prefix, filter_condition=filter_condition, size=size)

    # ------------------------------------------------------------------
    # Simple profile fields
    # ------------------------------------------------------------------

    @property
    def marketing_consent(self):
        """Stored marketing-consent flag, or ``None`` if never set."""
        return self.data.get("marketing_consent")

    def set_marketing_consent(self, consent):
        """Store *consent* coerced to ``bool``."""
        self.data["marketing_consent"] = bool(consent)

    @property
    def name(self):
        """Display name, or ``None``."""
        return self.data.get("name")

    def set_name(self, name):
        """Set the display name."""
        self.data["name"] = name

    @property
    def email(self):
        """Email address, or ``None``."""
        return self.data.get("email")

    def set_email(self, email):
        """Set the email address."""
        self.data["email"] = email

    # ------------------------------------------------------------------
    # Passwords
    # ------------------------------------------------------------------

    def set_password(self, password):
        """Hash *password* with Werkzeug's ``generate_password_hash`` and store it."""
        self.data['password'] = generate_password_hash(password)

    def set_password_hash(self, hash):
        """Store an already-computed password hash verbatim."""
        self.data['password'] = hash

    def clear_password(self):
        """Remove the stored password hash, if one is present and non-empty."""
        if self.data.get('password'):
            del self.data['password']

    def check_password(self, password):
        """Check the provided password against the stored hash.

        Handles legacy hashes removed in Werkzeug 3 (e.g. 'sha1$...' or raw
        40-hex SHA1) by verifying once and upgrading them to a modern hash.
        This preserves behaviour for existing records while moving them
        forward to supported hash schemes.

        :param password: plaintext password to verify.
        :return: ``True`` if it matches, else ``False``.
        :raises KeyError: if the account has no ``password`` field (logged first).
        """
        try:
            stored = self.data['password']
        except KeyError:
            # .get so a record missing 'id' does not mask the real KeyError.
            app.logger.error("Problem with user '{}' account: no password field".format(self.data.get('id')))
            raise

        # If the stored hash looks like a legacy SHA1 format, verify via compatibility shim first.
        if self._is_legacy_sha1_hash(stored):
            if self._verify_legacy_sha1(stored, password):
                self._upgrade_legacy_password(
                    password, "Password upgraded for user '%s' but save failed: %s"
                )
                return True
            return False

        # Otherwise, use Werkzeug's checker. If Werkzeug raises due to an unsupported legacy format,
        # fall back to the legacy verifier as a last resort.
        try:
            return check_password_hash(stored, password)
        except ValueError:
            if self._verify_legacy_sha1(stored, password):
                self._upgrade_legacy_password(
                    password, "Password upgraded for user '%s' after ValueError but save failed: %s"
                )
                return True
            return False

    def _upgrade_legacy_password(self, password, failure_message):
        """Re-hash *password* with the current scheme and try to persist it.

        Saving is best effort: a failure is logged with *failure_message*
        (``%s`` placeholders for user id and error) but never blocks the
        successful login that triggered the upgrade.
        """
        self.set_password(password)
        try:
            # DomainObject.save() is expected to exist; failure to save should not block login success.
            self.save()
        except Exception as e:  # deliberate best effort; logged, not swallowed silently
            app.logger.warning(failure_message, self.data.get('id'), str(e))

    @classmethod
    def _is_legacy_sha1_hash(cls, stored: str) -> bool:
        """Detect legacy SHA1 formats that Werkzeug 3 no longer supports.

        Supported legacy patterns:
          - 'sha1$<salt>$<hexdigest>' (old Werkzeug salted SHA1)
          - '<40-hex>' (unsalted plain SHA1 of password)
        """
        if not stored or not isinstance(stored, str):
            return False
        if stored.startswith('sha1$'):
            parts = stored.split('$')
            return len(parts) == 3 and bool(parts[1]) and bool(parts[2])
        # plain 40 hex characters implies unsalted SHA1
        return bool(cls._SHA1_HEX_RE.fullmatch(stored))

    @classmethod
    def _verify_legacy_sha1(cls, stored: str, password: str) -> bool:
        """Verify a password against legacy SHA1 formats.

        - 'sha1$<salt>$<hexdigest>': Werkzeug < 0.9 computed this as
          HMAC-SHA1 keyed by the salt over the password. ``sha1(salt + password)``
          is also accepted, matching what earlier versions of this shim checked.
        - '<40-hex>': ``sha1(password)``.

        Digests are compared case-insensitively in constant time. Any parsing
        or encoding problem is treated as a non-match.
        """
        if not stored or password is None:
            return False
        try:
            password_bytes = password.encode('utf-8')
            if stored.startswith('sha1$'):
                # salted format: sha1$<salt>$<hexdigest>
                _, salt, hexdigest = stored.split('$', 2)
                expected = hexdigest.lower()
                salt_bytes = salt.encode('utf-8')
                hmac_digest = hmac.new(salt_bytes, password_bytes, hashlib.sha1).hexdigest()
                concat_digest = hashlib.sha1(salt_bytes + password_bytes).hexdigest()
                hmac_ok = hmac.compare_digest(hmac_digest, expected)
                concat_ok = hmac.compare_digest(concat_digest, expected)
                return hmac_ok or concat_ok
            # unsalted plain SHA1 hex
            if cls._SHA1_HEX_RE.fullmatch(stored):
                digest = hashlib.sha1(password_bytes).hexdigest()
                return hmac.compare_digest(digest, stored.lower())
        except Exception:
            # Any parsing/encoding issues -> treat as non-match
            return False
        return False

    # ------------------------------------------------------------------
    # Journals
    # ------------------------------------------------------------------

    @property
    def journal(self):
        """List of associated journal ids, or ``None`` if never set."""
        return self.data.get("journal")

    def add_journal(self, jid):
        """Associate journal *jid*, creating the list if needed; duplicates are ignored."""
        journals = self.data.setdefault("journal", [])
        if jid not in journals:
            journals.append(jid)

    def remove_journal(self, jid):
        """Dissociate journal *jid*; a no-op if it is not associated."""
        journals = self.data.get("journal")
        if journals and jid in journals:
            journals.remove(jid)

    # ------------------------------------------------------------------
    # Password reset tokens
    # ------------------------------------------------------------------

    @property
    def reset_token(self):
        """Current password-reset token, or ``None``."""
        return self.data.get('reset_token')

    def set_reset_token(self, token, timeout):
        """Store *token* with an expiry *timeout* seconds from now.

        The expiry is written as a string in ``FMT_DATETIME_STD``.
        """
        expires = dates.now() + timedelta(seconds=timeout)
        self.data["reset_token"] = token
        self.data["reset_expires"] = expires.strftime(FMT_DATETIME_STD)

    def remove_reset_token(self):
        """Delete the reset token and its expiry, if present."""
        self.data.pop("reset_token", None)
        self.data.pop("reset_expires", None)

    @property
    def reset_expires(self):
        """Raw expiry string for the reset token, or ``None``."""
        return self.data.get("reset_expires")

    @property
    def reset_expires_timestamp(self):
        """Expiry parsed with ``dates.parse``, or ``None`` if unset."""
        expires = self.reset_expires
        if expires is None:
            return None
        return dates.parse(expires)

    def is_reset_expired(self):
        """Return ``True`` if there is no expiry or the expiry is in the past."""
        expires = self.reset_expires_timestamp
        if expires is None:
            return True
        return expires < dates.now()

    # ------------------------------------------------------------------
    # Roles
    # ------------------------------------------------------------------

    @property
    def is_super(self):
        """Whether the account holds the configured ``SUPER_USER_ROLE``."""
        return Authorise.has_role(app.config["SUPER_USER_ROLE"], self.data.get("role", []))

    def has_role(self, role):
        """Whether the account holds *role*, as decided by ``Authorise.has_role``."""
        return Authorise.has_role(role, self.data.get("role", []))

    @classmethod
    def all_top_level_roles(cls):
        """Return ``Authorise.top_level_roles()``."""
        return Authorise.top_level_roles()

    def add_role(self, role):
        """Add *role* if absent; adding ``'api'`` also ensures an API key exists."""
        roles = self.data.setdefault("role", [])
        if role not in roles:
            roles.append(role)
        # If we're adding the API role, ensure we also have a key to validate
        if role == 'api' and not self.data.get('api_key', None):
            self.generate_api_key()

    def remove_role(self, role):
        """Remove *role* if present; a no-op otherwise."""
        roles = self.data.get("role")
        if roles is not None and role in roles:
            roles.remove(role)

    @property
    def role(self):
        """List of role names (empty list if never set)."""
        return self.data.get("role", [])

    def set_role(self, role):
        """Replace all roles with *role* (a single value is wrapped in a list)."""
        if not isinstance(role, list):
            role = [role]
        self.data["role"] = role

    def prep(self):
        """Stamp ``last_updated`` with the current time.

        NOTE(review): presumably invoked by ``DomainObject`` before saving — confirm.
        """
        self.data['last_updated'] = dates.now_str()

    # ------------------------------------------------------------------
    # API access
    # ------------------------------------------------------------------

    @property
    def api_key(self):
        """The stored API key, but only while the account holds the ``'api'`` role."""
        if self.has_role('api'):
            return self.data.get('api_key', None)
        return None

    def generate_api_key(self):
        """Create, store and return a new random hex API key."""
        key = uuid.uuid4().hex
        self.data['api_key'] = key
        return key

    @property
    def is_premium(self):
        """Whether the account holds any of the premium roles."""
        return (self.has_role(constants.ROLE_PREMIUM) or
                self.has_role(constants.ROLE_PREMIUM_OAI) or
                self.has_role(constants.ROLE_PREMIUM_PDD) or
                self.has_role(constants.ROLE_PREMIUM_CSV))

    @classmethod
    def pull_by_api_key(cls, key):
        """Find a user by their API key - only succeed if they currently have API access.

        Returns ``None`` for an empty key, when the query does not return
        exactly one hit, or when that account lacks the ``'api'`` role.
        """
        if not key:
            return None
        res = cls.query(q='api_key.exact:' + cls._quote_phrase(key))
        if res.get('hits', {}).get('total', {}).get('value', 0) == 1:
            usr = cls(**res['hits']['hits'][0]['_source'])
            if usr.has_role('api'):
                return usr
        return None

    # ------------------------------------------------------------------
    # Misc helpers
    # ------------------------------------------------------------------

    @classmethod
    def new_short_uuid(cls):
        """Generate an 8-character UUID prefix not already used as an account id."""
        # Loop rather than recurse so a run of collisions cannot exhaust the stack.
        while True:
            candidate = str(uuid.uuid4())[:8]
            if cls.pull(candidate) is None:
                return candidate

    @classmethod
    def get_name_safe(cls, account_id) -> str:
        """Return the display name for *account_id*, or ``''`` if unavailable."""
        if account_id:
            author = cls.pull(account_id)
            if author is not None and author.name:
                return author.name
        return ''

    @classmethod
    def is_enable_publisher_email(cls) -> bool:
        """Return the ``ENABLE_PUBLISHER_EMAIL`` config flag (default ``False``)."""
        # TODO: in the long run this needs to move out to the user's email preferences but for now it
        # is here to replicate the behaviour in the code it replaces
        return app.config.get("ENABLE_PUBLISHER_EMAIL", False)