Coverage for portality/models/account.py: 70%

183 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-07-22 15:59 +0100

1import uuid 

2from flask_login import UserMixin 

3from datetime import datetime, timedelta 

4from werkzeug.security import generate_password_hash, check_password_hash 

5 

6from portality.dao import DomainObject as DomainObject 

7from portality.core import app 

8from portality.authorise import Authorise 

9 

10 

class Account(DomainObject, UserMixin):
    """User account record whose fields live in the ``self.data`` dict.

    Provides accessors for name, email, password hash, roles, associated
    journals, password-reset token and API key, plus class-level lookups by
    email, reset token and API key.
    """
    __type__ = 'account'

    # Format used for ``reset_expires`` and ``last_updated``.  The trailing "Z"
    # denotes UTC, so every timestamp written or compared here must be UTC.
    _TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

    def __init__(self, **kwargs):
        """Validate the requested id against the reserved usernames, then initialise.

        :param kwargs: passed through unchanged to the parent constructor;
            ``id`` (default ``''``) is handed to ``ReservedUsernames().validate``.
        """
        # Imported at call time rather than module level - presumably to avoid
        # a circular import; confirm before hoisting.
        from portality.forms.validate import ReservedUsernames
        ReservedUsernames().validate(kwargs.get('id', ''))
        super().__init__(**kwargs)

    @staticmethod
    def _query_phrase(value):
        """Return ``value`` as a double-quoted phrase safe to embed in a query string.

        Backslashes and double quotes are backslash-escaped so that a
        caller-supplied value cannot close the phrase early and inject extra
        query syntax.
        NOTE(review): assumes ``cls.query(q=...)`` is parsed with Lucene-style
        query-string syntax - confirm against DomainObject.query.
        """
        escaped = value.replace('\\', '\\\\').replace('"', '\\"')
        return '"' + escaped + '"'

    @classmethod
    def make_account(cls, email, username=None, name=None, roles=None, associated_journal_ids=None):
        """Return the existing account for these credentials, or build a new one.

        An account found by ``username`` (as id) or by ``email`` is returned
        untouched.  Otherwise a new account is populated with the supplied
        details and a password reset token.  No save call is made here.

        :param email: email address for the account
        :param username: account id; a short uuid is generated when falsy
        :param name: display name, only set when truthy
        :param roles: iterable of roles to add (default: none)
        :param associated_journal_ids: iterable of journal ids to add (default: none)
        :return: the existing or newly built account
        """
        if roles is None:
            roles = []

        if associated_journal_ids is None:
            associated_journal_ids = []

        # If we have an existing account with these credentials, supply it
        existing = cls.pull(username) or cls.pull_by_email(email)
        if existing:
            return existing

        # Create a new account.  Built via ``cls`` so the alternate constructor
        # agrees with the other class-level lookups on this class.
        _id = username or cls.new_short_uuid()
        a = cls(id=_id)
        a.set_email(email)
        if name:
            a.set_name(name)

        for role in roles:
            a.add_role(role)
        for jid in associated_journal_ids:
            a.add_journal(jid)

        # New accounts don't have passwords set - create a reset token for password.
        reset_token = uuid.uuid4().hex
        # Without PASSWORD_CREATE_TIMEOUT, allow 14 times the reset timeout
        # (14 days when PASSWORD_RESET_TIMEOUT is also at its 86400s default).
        default_timeout = app.config.get('PASSWORD_RESET_TIMEOUT', 86400) * 14
        a.set_reset_token(reset_token, app.config.get("PASSWORD_CREATE_TIMEOUT", default_timeout))
        return a

    @classmethod
    def pull_by_email(cls, email: str):
        """Return the single account whose email exactly equals ``email``, else None.

        None is returned when ``email`` is None, when the query does not report
        exactly one hit, or when the hit's email differs from the one supplied.
        """
        if email is None:
            return None
        res = cls.query(q='email:' + cls._query_phrase(email))
        if res.get('hits', {}).get('total', {}).get('value', 0) == 1:
            acc = cls(**res['hits']['hits'][0]['_source'])
            if acc.email == email:  # Only return the account if it was an exact match with supplied email
                return acc
        return None

    @classmethod
    def email_in_use(cls, email: str):
        """Report whether the email query matches any account.

        :return: None when ``email`` is None, otherwise True/False
        """
        if email is None:
            return None
        res = cls.query(q='email:' + cls._query_phrase(email))
        return res.get('hits', {}).get('total', {}).get('value', 0) > 0

    @classmethod
    def get_by_reset_token(cls, reset_token, not_expired=True):
        """Return the unique account holding ``reset_token``, else None.

        :param reset_token: token to look up; None yields None
        :param not_expired: when True, also require ``reset_expires`` to parse
            and to lie in the future (UTC)
        :return: the account, or None if there is not exactly one match, the
            match has no ``reset_expires``, or the expiry check fails
        """
        if reset_token is None:
            return None
        res = cls.query(q='reset_token.exact:' + cls._query_phrase(reset_token))
        obs = [hit.get("_source") for hit in res.get("hits", {}).get("hits", [])]
        if len(obs) != 1:
            return None
        expires = obs[0].get("reset_expires")
        if expires is None:
            return None
        if not_expired:
            try:
                ed = datetime.strptime(expires, cls._TIMESTAMP_FORMAT)
            except ValueError:
                return None
            # Compare in UTC: the stored value is written in UTC by
            # set_reset_token and is_reset_expired also compares against UTC.
            if ed < datetime.utcnow():
                return None
        return cls(**obs[0])

    @property
    def marketing_consent(self):
        """Stored marketing consent flag, or None if never set."""
        return self.data.get("marketing_consent")

    def set_marketing_consent(self, consent):
        """Store ``consent`` coerced to bool."""
        self.data["marketing_consent"] = bool(consent)

    @property
    def name(self):
        """Stored name, or None."""
        return self.data.get("name")

    def set_name(self, name):
        """Store ``name``."""
        self.data["name"] = name

    @property
    def email(self):
        """Stored email address, or None."""
        return self.data.get("email")

    def set_email(self, email):
        """Store ``email``."""
        self.data["email"] = email

    def set_password(self, password):
        """Store the hash of ``password`` (never the plain text)."""
        self.data['password'] = generate_password_hash(password)

    def clear_password(self):
        """Remove the stored password hash, if any."""
        # pop() also removes a present-but-falsy value, which a truthiness
        # check would leave behind.
        self.data.pop('password', None)

    def check_password(self, password):
        """Check ``password`` against the stored hash.

        :return: result of ``check_password_hash``
        :raises KeyError: when the account has no password field (also logged)
        """
        try:
            stored_hash = self.data['password']
        except KeyError:
            # .get() so a missing id cannot raise a second KeyError that masks this one
            app.logger.error("Problem with user '{}' account: no password field".format(self.data.get('id')))
            raise
        return check_password_hash(stored_hash, password)

    @property
    def journal(self):
        """List of associated journal ids, or None when none were ever added."""
        return self.data.get("journal")

    def add_journal(self, jid):
        """Associate journal ``jid`` with this account; duplicates are ignored."""
        journals = self.data.setdefault("journal", [])
        if jid not in journals:
            journals.append(jid)

    def remove_journal(self, jid):
        """Remove journal ``jid`` from this account; a no-op if it is not associated."""
        if "journal" not in self.data:
            return
        # Membership check keeps this consistent with remove_role, which also
        # tolerates removing something that is not there.
        if jid in self.data["journal"]:
            self.data["journal"].remove(jid)

    @property
    def reset_token(self):
        """Current password reset token, or None."""
        return self.data.get('reset_token')

    def set_reset_token(self, token, timeout):
        """Store ``token`` with an expiry ``timeout`` seconds from now (UTC).

        :param token: the reset token
        :param timeout: lifetime of the token in seconds
        """
        # UTC, to match the "Z" suffix of the stored format and the UTC
        # comparisons made when the token is checked.
        expires = datetime.utcnow() + timedelta(seconds=timeout)
        self.data["reset_token"] = token
        self.data["reset_expires"] = expires.strftime(self._TIMESTAMP_FORMAT)

    def remove_reset_token(self):
        """Drop the reset token and its expiry, if present."""
        self.data.pop("reset_token", None)
        self.data.pop("reset_expires", None)

    @property
    def reset_expires(self):
        """Reset token expiry as the stored string, or None."""
        return self.data.get("reset_expires")

    @property
    def reset_expires_timestamp(self):
        """Reset token expiry parsed to a naive datetime, or None if unset.

        :raises ValueError: when the stored string does not match the timestamp format
        """
        expires = self.reset_expires
        if expires is None:
            return None
        return datetime.strptime(expires, self._TIMESTAMP_FORMAT)

    def is_reset_expired(self):
        """Return True when there is no usable expiry or it is in the past (UTC)."""
        try:
            expires = self.reset_expires_timestamp
        except ValueError:
            # An unparseable expiry is treated as expired, matching how
            # get_by_reset_token rejects such records.
            return True
        if expires is None:
            return True
        return expires < datetime.utcnow()

    @property
    def is_super(self):
        """Whether this account's roles satisfy the configured SUPER_USER_ROLE."""
        return Authorise.has_role(app.config["SUPER_USER_ROLE"], self.data.get("role", []))

    def has_role(self, role):
        """Delegate to ``Authorise.has_role`` with this account's role list."""
        return Authorise.has_role(role, self.data.get("role", []))

    @classmethod
    def all_top_level_roles(cls):
        """Return ``Authorise.top_level_roles()``."""
        return Authorise.top_level_roles()

    def add_role(self, role):
        """Add ``role`` if not already held; adding 'api' also ensures an API key exists."""
        roles = self.data.setdefault("role", [])
        if role not in roles:
            roles.append(role)
        # If we're adding the API role, ensure we also have a key to validate
        if role == 'api' and not self.data.get('api_key', None):
            self.generate_api_key()

    def remove_role(self, role):
        """Remove ``role``; a no-op if the account does not hold it."""
        if "role" not in self.data:
            return
        if role in self.data["role"]:
            self.data["role"].remove(role)

    @property
    def role(self):
        """List of roles held (empty list when none)."""
        return self.data.get("role", [])

    def set_role(self, role):
        """Replace all roles with ``role``; a non-list value is wrapped in a list."""
        if not isinstance(role, list):
            role = [role]
        self.data["role"] = role

    def prep(self):
        """Stamp ``last_updated`` with the current UTC time."""
        # UTC, to match the "Z" suffix of the format.
        self.data['last_updated'] = datetime.utcnow().strftime(self._TIMESTAMP_FORMAT)

    @property
    def api_key(self):
        """The stored API key, exposed only while the account has the 'api' role."""
        if self.has_role('api'):
            return self.data.get('api_key', None)
        return None

    def generate_api_key(self):
        """Generate, store and return a new API key (uuid4 hex)."""
        k = uuid.uuid4().hex
        self.data['api_key'] = k
        return k

    @classmethod
    def pull_by_api_key(cls, key):
        """Find a user by their API key - only succeed if they currently have API access."""
        if key is None:
            return None
        res = cls.query(q='api_key.exact:' + cls._query_phrase(key))
        if res.get('hits', {}).get('total', {}).get('value', 0) == 1:
            usr = cls(**res['hits']['hits'][0]['_source'])
            if usr.has_role('api'):
                return usr
        return None

    @classmethod
    def new_short_uuid(cls):
        """ Generate a short UUID and check it's unique in this type """
        # Loop rather than recurse so repeated collisions cannot grow the stack.
        while True:
            trunc_uuid = str(uuid.uuid4())[:8]
            if cls.pull(trunc_uuid) is None:
                return trunc_uuid