Coverage for portality / models / account.py: 76%

206 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-04 09:41 +0100

1import uuid 

2from flask_login import UserMixin 

3from datetime import timedelta 

4from werkzeug.security import generate_password_hash, check_password_hash 

5 

6from portality import constants 

7from portality.dao import DomainObject as DomainObject 

8from portality.core import app 

9from portality.authorise import Authorise 

10from portality.lib import dates 

11from portality.lib.dates import FMT_DATETIME_STD 

12 

13class Account(DomainObject, UserMixin): 

14 __type__ = 'account' 

15 

16 def __init__(self, **kwargs): 

17 from portality.forms.validate import ReservedUsernames 

18 ReservedUsernames().validate(kwargs.get('id', '')) 

19 super(Account, self).__init__(**kwargs) 

20 

21 @classmethod 

22 def make_account(cls, email, username=None, name=None, roles=None, associated_journal_ids=None): 

23 if roles is None: 

24 roles = [] 

25 

26 if associated_journal_ids is None: 

27 associated_journal_ids = [] 

28 

29 # If we have an existing account with these credentials, supply it 

30 a = cls.pull(username) or cls.pull_by_email(email) 

31 if a: 

32 return a 

33 

34 # Create a new account 

35 _id = username or cls.new_short_uuid() 

36 a = Account(id=_id) 

37 a.set_email(email) 

38 a.set_name(name) if name else None 

39 

40 for role in roles: 

41 a.add_role(role) 

42 for jid in associated_journal_ids: 

43 a.add_journal(jid) 

44 

45 # New accounts don't have passwords set - create a reset token for password. 

46 reset_token = uuid.uuid4().hex 

47 # give them 14 days to create their first password if timeout not specified in config 

48 a.set_reset_token(reset_token, app.config.get("PASSWORD_CREATE_TIMEOUT", 

49 app.config.get('PASSWORD_RESET_TIMEOUT', 86400) * 14)) 

50 return a 

51 

52 @classmethod 

53 def pull_by_email(cls, email: str): 

54 if email is None: 

55 return None 

56 res = cls.query(q='email:"' + email + '"') 

57 if res.get('hits', {}).get('total', {}).get('value', 0) == 1: 

58 acc = cls(**res['hits']['hits'][0]['_source']) 

59 if acc.email.lower() == email.lower(): # allow case insensitive login 

60 return acc 

61 return None 

62 

63 @classmethod 

64 def email_in_use(cls, email: str): 

65 if email is None: 

66 return None 

67 res = cls.query(q='email:"' + email + '"') 

68 return res.get('hits', {}).get('total', {}).get('value', 0) > 0 

69 

70 @classmethod 

71 def get_by_reset_token(cls, reset_token, not_expired=True): 

72 res = cls.query(q='reset_token.exact:"' + reset_token + '"') 

73 obs = [hit.get("_source") for hit in res.get("hits", {}).get("hits", [])] 

74 if len(obs) == 0 or len(obs) > 1: 

75 return None 

76 expires = obs[0].get("reset_expires") 

77 if expires is None: 

78 return None 

79 if not_expired: 

80 try: 

81 ed = dates.parse(expires) 

82 if ed < dates.now(): 

83 return None 

84 except ValueError: 

85 return None 

86 return cls(**obs[0]) 

87 

88 # @classmethod 

89 # def autocomplete(cls, field, prefix, admin_only=False, size=5): 

90 # 

91 # return {"suggestions": super().autocomplete(field, prefix, filter_condition=filter_condition, size=size)} 

92 

93 @classmethod 

94 def admin_autocomplete(cls, field, prefix, size=5): 

95 """Autocomplete for admin users only.""" 

96 filter_condition = {"role.exact": "admin"} 

97 return cls.autocomplete(field, prefix, filter_condition=filter_condition, size=size) 

98 

99 @property 

100 def marketing_consent(self): 

101 return self.data.get("marketing_consent") 

102 

103 def set_marketing_consent(self, consent): 

104 self.data["marketing_consent"] = bool(consent) 

105 

106 @property 

107 def name(self): 

108 return self.data.get("name") 

109 

110 def set_name(self, name): 

111 self.data["name"] = name 

112 

113 @property 

114 def email(self): 

115 return self.data.get("email") 

116 

117 def set_email(self, email): 

118 self.data["email"] = email 

119 

120 def set_password(self, password): 

121 self.data['password'] = generate_password_hash(password) 

122 

123 def set_password_hash(self, hash): 

124 self.data['password'] = hash 

125 

126 def clear_password(self): 

127 if self.data.get('password'): 

128 del self.data['password'] 

129 

130 def check_password(self, password): 

131 try: 

132 return check_password_hash(self.data['password'], password) 

133 except KeyError: 

134 app.logger.error("Problem with user '{}' account: no password field".format(self.data['id'])) 

135 raise 

136 

137 @property 

138 def journal(self): 

139 return self.data.get("journal") 

140 

141 def add_journal(self, jid): 

142 if jid in self.data.get("journal", []): 

143 return 

144 if "journal" not in self.data: 

145 self.data["journal"] = [] 

146 if jid not in self.data["journal"]: 

147 self.data["journal"].append(jid) 

148 

149 def remove_journal(self, jid): 

150 if "journal" not in self.data: 

151 return 

152 self.data["journal"].remove(jid) 

153 

154 @property 

155 def reset_token(self): 

156 return self.data.get('reset_token') 

157 

158 def set_reset_token(self, token, timeout): 

159 expires = dates.now() + timedelta(0, timeout) 

160 self.data["reset_token"] = token 

161 self.data["reset_expires"] = expires.strftime(FMT_DATETIME_STD) 

162 

163 def remove_reset_token(self): 

164 if "reset_token" in self.data: 

165 del self.data["reset_token"] 

166 if "reset_expires" in self.data: 

167 del self.data["reset_expires"] 

168 

169 @property 

170 def reset_expires(self): 

171 return self.data.get("reset_expires") 

172 

173 @property 

174 def reset_expires_timestamp(self): 

175 expires = self.reset_expires 

176 if expires is None: 

177 return None 

178 return dates.parse(expires) 

179 

180 def is_reset_expired(self): 

181 expires = self.reset_expires_timestamp 

182 if expires is None: 

183 return True 

184 return expires < dates.now() 

185 

186 @property 

187 def is_super(self): 

188 # return not self.is_anonymous and self.id in app.config['SUPER_USER'] 

189 return Authorise.has_role(app.config["SUPER_USER_ROLE"], self.data.get("role", [])) 

190 

191 def has_role(self, role): 

192 return Authorise.has_role(role, self.data.get("role", [])) 

193 

194 @classmethod 

195 def all_top_level_roles(cls): 

196 return Authorise.top_level_roles() 

197 

198 def add_role(self, role): 

199 if "role" not in self.data: 

200 self.data["role"] = [] 

201 if role not in self.data["role"]: 

202 self.data["role"].append(role) 

203 # If we're adding the API role, ensure we also have a key to validate 

204 if role == 'api' and not self.data.get('api_key', None): 

205 self.generate_api_key() 

206 

207 def remove_role(self, role): 

208 if "role" not in self.data: 

209 return 

210 if role in self.data["role"]: 

211 self.data["role"].remove(role) 

212 

213 @property 

214 def role(self): 

215 return self.data.get("role", []) 

216 

217 def set_role(self, role): 

218 if not isinstance(role, list): 

219 role = [role] 

220 self.data["role"] = role 

221 

222 def prep(self): 

223 self.data['last_updated'] = dates.now_str() 

224 

225 @property 

226 def api_key(self): 

227 if self.has_role('api'): 

228 return self.data.get('api_key', None) 

229 else: 

230 return None 

231 

232 def generate_api_key(self): 

233 k = uuid.uuid4().hex 

234 self.data['api_key'] = k 

235 return k 

236 

237 @property 

238 def is_premium(self): 

239 return (self.has_role(constants.ROLE_PREMIUM) or 

240 self.has_role(constants.ROLE_PREMIUM_OAI) or 

241 self.has_role(constants.ROLE_PREMIUM_PDD) or 

242 self.has_role(constants.ROLE_PREMIUM_CSV)) 

243 

244 @classmethod 

245 def pull_by_api_key(cls, key): 

246 """Find a user by their API key - only succeed if they currently have API access.""" 

247 res = cls.query(q='api_key.exact:"' + key + '"') 

248 if res.get('hits', {}).get('total', {}).get('value', 0) == 1: 

249 usr = cls(**res['hits']['hits'][0]['_source']) 

250 if usr.has_role('api'): 

251 return usr 

252 return None 

253 

254 @classmethod 

255 def new_short_uuid(cls): 

256 """ Generate a short UUID and check it's unique in this type """ 

257 trunc_uuid = str(uuid.uuid4())[:8] 

258 if cls.pull(trunc_uuid) is None: 

259 return trunc_uuid 

260 else: 

261 return cls.new_short_uuid() 

262 

263 @classmethod 

264 def get_name_safe(cls, account_id) -> str: 

265 if account_id: 

266 author = Account.pull(account_id) 

267 if author is not None and author.name: 

268 return author.name 

269 return '' 

270 

271 @classmethod 

272 def is_enable_publisher_email(cls) -> bool: 

273 # TODO: in the long run this needs to move out to the user's email preferences but for now it 

274 # is here to replicate the behaviour in the code it replaces 

275 return app.config.get("ENABLE_PUBLISHER_EMAIL", False)