Coverage for portality / models / account.py: 74%

253 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 00:09 +0100

1import uuid 

2import hashlib 

3import hmac 

4import re 

5from flask_login import UserMixin 

6from datetime import timedelta 

7from werkzeug.security import generate_password_hash, check_password_hash 

8 

9from portality import constants 

10from portality.dao import DomainObject as DomainObject 

11from portality.core import app 

12from portality.authorise import Authorise 

13from portality.lib import dates 

14from portality.lib.dates import FMT_DATETIME_STD 

15 

class Account(DomainObject, UserMixin):
    """A user account, persisted as the ``account`` document type.

    All state lives in the ``self.data`` dict provided by ``DomainObject``; this
    class adds accessors for identity (name, email), credentials (password hash,
    reset token, API key), roles, and the journals associated with the account.
    """
    __type__ = 'account'

    # --- Legacy SHA1 compatibility (Werkzeug 3 removal) ---
    # Matches an unsalted SHA1 hex digest (40 hex characters, either case).
    _SHA1_HEX_RE = re.compile(r"^[a-f0-9]{40}$", re.IGNORECASE)

    def __init__(self, **kwargs):
        # Local import; presumably avoids an import cycle with portality.forms - TODO confirm.
        from portality.forms.validate import ReservedUsernames
        # Refuse to construct an account whose id is a reserved username.
        ReservedUsernames().validate(kwargs.get('id', ''))
        super(Account, self).__init__(**kwargs)

    # ------------------------------------------------------------------
    # Construction and lookup
    # ------------------------------------------------------------------

    @staticmethod
    def _quoted(value) -> str:
        """Return *value* as a double-quoted phrase for the ``q`` query string.

        Backslashes and double quotes are escaped so that user-supplied values
        (emails, reset tokens, API keys) cannot close the phrase early and inject
        additional query syntax. Assumes ``cls.query(q=...)`` interprets ``q`` as
        Lucene/Elasticsearch query-string syntax - NOTE(review): confirm in DomainObject.
        """
        escaped = str(value).replace('\\', '\\\\').replace('"', '\\"')
        return '"' + escaped + '"'

    @classmethod
    def make_account(cls, email, username=None, name=None, roles=None, associated_journal_ids=None):
        """Return the existing account for *username*/*email*, or build a new unsaved one.

        :param email: email address for the account.
        :param username: account id; a short random id is generated when falsy.
        :param name: optional display name.
        :param roles: optional iterable of roles to add.
        :param associated_journal_ids: optional iterable of journal ids to associate.
        :return: an existing Account if one matches, otherwise a new Account with a
            password-creation reset token set. The new account is not saved here.
        """
        roles = [] if roles is None else roles
        associated_journal_ids = [] if associated_journal_ids is None else associated_journal_ids

        # If we have an existing account with these credentials, supply it
        existing = cls.pull(username) if username else None
        existing = existing or cls.pull_by_email(email)
        if existing:
            return existing

        # Create a new account
        account = cls(id=username or cls.new_short_uuid())
        account.set_email(email)
        if name:
            account.set_name(name)

        for role in roles:
            account.add_role(role)
        for jid in associated_journal_ids:
            account.add_journal(jid)

        # New accounts don't have passwords set - create a reset token for password.
        # Give them 14 reset-timeouts to create their first password unless
        # PASSWORD_CREATE_TIMEOUT is configured.
        timeout = app.config.get("PASSWORD_CREATE_TIMEOUT",
                                 app.config.get('PASSWORD_RESET_TIMEOUT', 86400) * 14)
        account.set_reset_token(uuid.uuid4().hex, timeout)
        return account

    @classmethod
    def pull_by_email(cls, email: str):
        """Return the single account with *email* (case-insensitive), or None.

        Returns None if *email* is None, if the query does not yield exactly one
        hit, or if the stored email does not match case-insensitively.
        """
        if email is None:
            return None
        res = cls.query(q='email:' + cls._quoted(email))
        if res.get('hits', {}).get('total', {}).get('value', 0) == 1:
            acc = cls(**res['hits']['hits'][0]['_source'])
            # allow case insensitive login
            if acc.email and acc.email.lower() == email.lower():
                return acc
        return None

    @classmethod
    def email_in_use(cls, email: str) -> bool:
        """Return True if at least one account matches *email*; False for a None email."""
        if email is None:
            return False
        res = cls.query(q='email:' + cls._quoted(email))
        return res.get('hits', {}).get('total', {}).get('value', 0) > 0

    @classmethod
    def get_by_reset_token(cls, reset_token, not_expired=True):
        """Return the account holding *reset_token*, or None.

        None is returned when the token is empty, when zero or several accounts
        match, when the match has no ``reset_expires``, or (with *not_expired*)
        when the expiry is in the past or cannot be parsed.
        """
        if not reset_token:
            return None
        res = cls.query(q='reset_token.exact:' + cls._quoted(reset_token))
        obs = [hit.get("_source") for hit in res.get("hits", {}).get("hits", [])]
        if len(obs) != 1:
            return None
        expires = obs[0].get("reset_expires")
        if expires is None:
            return None
        if not_expired:
            try:
                if dates.parse(expires) < dates.now():
                    return None
            except ValueError:
                return None
        return cls(**obs[0])

    @classmethod
    def admin_autocomplete(cls, field, prefix, size=5):
        """Autocomplete for admin users only."""
        filter_condition = {"role.exact": "admin"}
        return cls.autocomplete(field, prefix, filter_condition=filter_condition, size=size)

    @classmethod
    def pull_by_api_key(cls, key):
        """Find a user by their API key - only succeed if they currently have API access."""
        if not key:
            return None
        res = cls.query(q='api_key.exact:' + cls._quoted(key))
        if res.get('hits', {}).get('total', {}).get('value', 0) == 1:
            usr = cls(**res['hits']['hits'][0]['_source'])
            if usr.has_role('api'):
                return usr
        return None

    @classmethod
    def new_short_uuid(cls):
        """Generate an 8-character id not already used by an account of this type."""
        while True:
            candidate = str(uuid.uuid4())[:8]
            if cls.pull(candidate) is None:
                return candidate

    @classmethod
    def get_name_safe(cls, account_id) -> str:
        """Return the name of account *account_id*, or '' if unknown or unnamed."""
        if account_id:
            author = cls.pull(account_id)
            if author is not None and author.name:
                return author.name
        return ''

    # ------------------------------------------------------------------
    # Simple profile fields
    # ------------------------------------------------------------------

    @property
    def marketing_consent(self):
        """Stored marketing consent flag, or None if never set."""
        return self.data.get("marketing_consent")

    def set_marketing_consent(self, consent):
        self.data["marketing_consent"] = bool(consent)

    @property
    def name(self):
        return self.data.get("name")

    def set_name(self, name):
        self.data["name"] = name

    @property
    def email(self):
        return self.data.get("email")

    def set_email(self, email):
        self.data["email"] = email

    # ------------------------------------------------------------------
    # Passwords
    # ------------------------------------------------------------------

    def set_password(self, password):
        """Hash *password* with Werkzeug's default scheme and store the hash."""
        self.data['password'] = generate_password_hash(password)

    def set_password_hash(self, hash):
        """Store an already-computed password hash verbatim."""
        self.data['password'] = hash

    def clear_password(self):
        if self.data.get('password'):
            del self.data['password']

    def check_password(self, password):
        """Check the provided password against the stored hash.

        Handles legacy hashes removed in Werkzeug 3 (e.g. 'sha1$...' or raw 40-hex SHA1) by
        verifying once and upgrading them to a modern hash. This preserves behaviour for
        existing records while moving them forward to supported hash schemes.

        :raises KeyError: if the account has no password field (logged first).
        :return: True if *password* matches the stored hash.
        """
        try:
            stored = self.data['password']
        except KeyError:
            app.logger.error("Problem with user '{}' account: no password field".format(self.data.get('id')))
            raise

        # Legacy SHA1 formats are verified via the compatibility shim first.
        if self._is_legacy_sha1_hash(stored):
            if not self._verify_legacy_sha1(stored, password):
                return False
            self._upgrade_password(password)
            return True

        # Otherwise, use Werkzeug's checker. If Werkzeug raises due to an unsupported
        # legacy format, fall back to the legacy verifier as a last resort.
        try:
            return check_password_hash(stored, password)
        except ValueError:
            if not self._verify_legacy_sha1(stored, password):
                return False
            self._upgrade_password(password, " after ValueError")
            return True

    def _upgrade_password(self, password, context=""):
        """Re-hash an already-verified *password* with a modern scheme and try to save.

        A failed save is logged and swallowed: the password has been verified, so
        login success must not depend on persisting the upgrade. *context* is
        inserted into the warning message to say which path triggered the upgrade.
        """
        self.set_password(password)
        try:
            self.save()
        except Exception as e:
            app.logger.warning(
                "Password upgraded for user '%s'%s but save failed: %s",
                self.data.get('id'), context, str(e)
            )

    @classmethod
    def _is_legacy_sha1_hash(cls, stored: str) -> bool:
        """Detect legacy SHA1 formats that Werkzeug 3 no longer supports.

        Supported legacy patterns:
        - 'sha1$<salt>$<hexdigest>' (old Werkzeug SHA1; salt may be empty)
        - '<40-hex>' (unsalted plain SHA1 of password)
        """
        if not stored or not isinstance(stored, str):
            return False
        if stored.startswith('sha1$'):
            parts = stored.split('$')
            return len(parts) == 3 and bool(parts[2])
        # plain 40 hex characters implies unsalted SHA1
        return bool(cls._SHA1_HEX_RE.fullmatch(stored))

    @classmethod
    def _verify_legacy_sha1(cls, stored: str, password: str) -> bool:
        """Verify a password against legacy SHA1 formats.

        - 'sha1$<salt>$<hexdigest>' with a salt: old Werkzeug computed
          HMAC-SHA1(key=salt, msg=password); sha1(salt + password) is also accepted
          because earlier versions of this method accepted that form.
        - 'sha1$$<hexdigest>' (empty salt) and '<40-hex>': sha1(password).

        Digests are compared case-insensitively in constant time per candidate.
        Any parsing/encoding problem is treated as a non-match.
        """
        if not stored or password is None:
            return False
        try:
            pw_bytes = password.encode('utf-8')
            if stored.startswith('sha1$'):
                # format: sha1$<salt>$<hexdigest>
                _, salt, hexdigest = stored.split('$', 2)
                expected = hexdigest.lower()
                if not salt:
                    return hmac.compare_digest(hashlib.sha1(pw_bytes).hexdigest(), expected)
                salt_bytes = salt.encode('utf-8')
                candidates = (
                    hmac.new(salt_bytes, pw_bytes, hashlib.sha1).hexdigest(),
                    hashlib.sha1(salt_bytes + pw_bytes).hexdigest(),
                )
                # Evaluate every candidate (no short-circuit) before deciding.
                matches = [hmac.compare_digest(c, expected) for c in candidates]
                return any(matches)
            # unsalted plain SHA1 hex
            if cls._SHA1_HEX_RE.fullmatch(stored):
                return hmac.compare_digest(hashlib.sha1(pw_bytes).hexdigest(), stored.lower())
        except Exception:
            # Any parsing/encoding issues -> treat as non-match
            return False
        return False

    # ------------------------------------------------------------------
    # Associated journals
    # ------------------------------------------------------------------

    @property
    def journal(self):
        """List of associated journal ids, or None if never set."""
        return self.data.get("journal")

    def add_journal(self, jid):
        """Associate *jid* with this account (no duplicates)."""
        journals = self.data.setdefault("journal", [])
        if jid not in journals:
            journals.append(jid)

    def remove_journal(self, jid):
        """Disassociate *jid*; a no-op if it is not associated."""
        journals = self.data.get("journal")
        if journals and jid in journals:
            journals.remove(jid)

    # ------------------------------------------------------------------
    # Password reset token
    # ------------------------------------------------------------------

    @property
    def reset_token(self):
        return self.data.get('reset_token')

    def set_reset_token(self, token, timeout):
        """Store *token* with an expiry *timeout* seconds from now."""
        expires = dates.now() + timedelta(seconds=timeout)
        self.data["reset_token"] = token
        self.data["reset_expires"] = expires.strftime(FMT_DATETIME_STD)

    def remove_reset_token(self):
        self.data.pop("reset_token", None)
        self.data.pop("reset_expires", None)

    @property
    def reset_expires(self):
        """Expiry of the reset token as the stored string, or None."""
        return self.data.get("reset_expires")

    @property
    def reset_expires_timestamp(self):
        """Expiry of the reset token parsed via ``dates.parse``, or None."""
        expires = self.reset_expires
        if expires is None:
            return None
        return dates.parse(expires)

    def is_reset_expired(self):
        """True if there is no reset expiry or it lies in the past."""
        expires = self.reset_expires_timestamp
        if expires is None:
            return True
        return expires < dates.now()

    # ------------------------------------------------------------------
    # Roles
    # ------------------------------------------------------------------

    @property
    def is_super(self):
        """True if the account holds the configured SUPER_USER_ROLE."""
        return Authorise.has_role(app.config["SUPER_USER_ROLE"], self.data.get("role", []))

    def has_role(self, role):
        return Authorise.has_role(role, self.data.get("role", []))

    @classmethod
    def all_top_level_roles(cls):
        return Authorise.top_level_roles()

    def add_role(self, role):
        """Add *role* (no duplicates); adding 'api' also ensures an API key exists."""
        roles = self.data.setdefault("role", [])
        if role not in roles:
            roles.append(role)
        # If we're adding the API role, ensure we also have a key to validate
        if role == 'api' and not self.data.get('api_key', None):
            self.generate_api_key()

    def remove_role(self, role):
        roles = self.data.get("role")
        if roles is not None and role in roles:
            roles.remove(role)

    @property
    def role(self):
        return self.data.get("role", [])

    def set_role(self, role):
        """Replace all roles; a single non-list value is wrapped in a list."""
        if not isinstance(role, list):
            role = [role]
        self.data["role"] = role

    def prep(self):
        """Stamp ``last_updated`` with the current time before saving."""
        self.data['last_updated'] = dates.now_str()

    # ------------------------------------------------------------------
    # API access
    # ------------------------------------------------------------------

    @property
    def api_key(self):
        """The stored API key, exposed only while the account has the 'api' role."""
        if self.has_role('api'):
            return self.data.get('api_key', None)
        return None

    def generate_api_key(self):
        """Create, store and return a new random API key."""
        k = uuid.uuid4().hex
        self.data['api_key'] = k
        return k

    @property
    def is_premium(self):
        """True if the account holds any of the premium roles."""
        return (self.has_role(constants.ROLE_PREMIUM) or
                self.has_role(constants.ROLE_PREMIUM_OAI) or
                self.has_role(constants.ROLE_PREMIUM_PDD) or
                self.has_role(constants.ROLE_PREMIUM_CSV))

    @classmethod
    def is_enable_publisher_email(cls) -> bool:
        # TODO: in the long run this needs to move out to the user's email preferences but for now it
        # is here to replicate the behaviour in the code it replaces
        return app.config.get("ENABLE_PUBLISHER_EMAIL", False)