Coverage for portality / decorators.py: 66%

104 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-04 09:41 +0100

1import json, signal 

2import re 

3from functools import wraps 

4from flask import request, abort, redirect, flash, url_for, render_template, make_response 

5from flask_login import login_user, current_user 

6 

7from portality.core import app 

8from portality.lib import dates 

9from portality.models import Account 

10from portality.models.harvester import HarvesterProgressReport as Report 

11from portality.ui import templates 

12 

def swag(swag_summary, swag_spec):
    """
    ~~Swagger:Feature~~
    Build a decorator that attaches swagger information to an API function.

    The decorated function is returned unchanged apart from three new
    attributes:

    * ``summary``     - ``swag_summary`` with any span/div tags removed
    * ``description`` - ``swag_summary`` exactly as supplied
    * ``swag``        - ``swag_spec`` exactly as supplied

    :param swag_summary: summary text, which may contain span/div markup
    :param swag_spec: the swagger spec object to attach to the function
    :return: a decorator which annotates and returns the function it is given
    """
    def attach_swagger_info(view):
        # Strip opening and closing <span>/<div> tags (with any attributes)
        # to get a markup-free summary.
        tag_free_summary = re.sub('</?(span|div).*?>', '', swag_summary)

        view.description = swag_summary
        view.swag = swag_spec
        view.summary = tag_free_summary
        return view

    return attach_swagger_info

25 

26 

def api_key_required(fn):
    """
    ~~APIKey:Feature~~
    Decorator for API functions which must be called with a valid API key.

    The key is read from the ``api_key`` request value and used to look up
    an Account. The wrapped function only runs if an account is found and
    can be logged in; in every other case an Api401Error is raised.

    :param fn: the API function to protect
    :return: the wrapped function
    """
    @wraps(fn)
    def decorated_view(*args, **kwargs):
        supplied_key = request.values.get("api_key")

        # Only look the account up when a key was actually sent
        account = None
        if supplied_key is not None:
            account = Account.pull_by_api_key(supplied_key)

        if account is not None and login_user(account, remember=False):
            return fn(*args, **kwargs)

        # missing key, unknown key, or login refused
        from portality.api.common import Api401Error
        raise Api401Error("An API Key is required to access this.")

    return decorated_view

45 

46 

def api_key_optional(fn):
    """
    ~~APIKey:Feature~~
    Decorator for API functions where an API key is optional.

    If no key (or an empty one) is supplied in the ``api_key`` request value
    the wrapped function runs without any login. If a key is supplied it
    must resolve to an Account which can be logged in, otherwise the request
    is aborted with a 401.

    :param fn: the API function to wrap
    :return: the wrapped function
    """
    @wraps(fn)
    def decorated_view(*args, **kwargs):
        supplied_key = request.values.get("api_key")

        # no api key, which is ok
        if not supplied_key:
            return fn(*args, **kwargs)

        # a key was provided, so it has to be a good one
        account = Account.pull_by_api_key(supplied_key)
        if account is None or not login_user(account, remember=False):
            abort(401)

        return fn(*args, **kwargs)

    return decorated_view

67 

68 

def ssl_required(fn):
    """
    ~~SSLRequired:Feature~~
    Decorator for when a view f() should be served only over SSL

    When the ``SSL`` config option is truthy and the current request is not
    secure, the client is redirected to the same URL with the scheme switched
    to https. In all other cases the wrapped view is called as normal.

    :param fn: the view function to wrap
    :return: the wrapped view function
    """
    @wraps(fn)
    def decorated_view(*args, **kwargs):
        if app.config.get("SSL") and not request.is_secure:
            # Replace only the first occurrence (the scheme). An unbounded
            # replace would also rewrite any "http://" that appears later in
            # the URL, e.g. inside a query-string parameter.
            return redirect(request.url.replace("http://", "https://", 1))

        return fn(*args, **kwargs)

    return decorated_view

85 

86 

def restrict_to_role(role):
    """
    ~~Authorisation:Feature~~
    Check that the current user is logged in and holds the given role.

    Anonymous users get an error flash message and a redirect to the login
    page, with the current URL passed along as ``next``. Logged-in users
    without the role get an error flash message and a redirect to the home
    page.

    :param role: the role the current user must have
    :return: a redirect response if access is refused, otherwise None
    """
    if not current_user.is_anonymous:
        if current_user.has_role(role):
            # access is allowed, nothing for the caller to return
            return None

        flash('You do not have permission to access this area of the site.', 'error')
        home_url = url_for('doaj.home')
        return redirect(home_url)

    flash('You are trying to access a protected area. Please log in first.', 'error')
    login_url = url_for('account.login', next=request.url)
    return redirect(login_url)

100 

101 

def write_required(script=False, api=False, allowed_methods=None):
    """
    ~~ReadOnlyMode:Feature~~
    Decorator factory which blocks the wrapped function while the
    ``READ_ONLY_MODE`` config option is on.

    In read-only mode the wrapped function still runs if the request method
    is one of ``allowed_methods``. Otherwise the outcome depends on the flags:

    * ``script`` - a RuntimeError is raised
    * ``api``    - a 503 JSON response with a maintenance message is returned
    * neither    - the public read-only mode template is rendered

    Every wrapped function is tagged with a ``_write_required`` attribute
    set to True.

    :param script: True if the wrapped function is run as a script/task
    :param api: True if the wrapped function is an API route
    :param allowed_methods: request methods which may proceed in read-only
        mode; defaults to GET, HEAD and OPTIONS
    :return: the decorator to apply to the function
    """
    safe_methods = {"GET", "HEAD", "OPTIONS"} if allowed_methods is None else allowed_methods

    def decorator(fn):
        @wraps(fn)
        def decorated_view(*args, **kwargs):
            # Normal operation: nothing to check
            if not app.config.get("READ_ONLY_MODE", False):
                return fn(*args, **kwargs)

            # try to detect request method; if no request context, treat as non-safe unless `script` is True
            try:
                method = request.method
            except RuntimeError:
                method = None

            if method in safe_methods:
                return fn(*args, **kwargs)

            # TODO remove "script" argument from decorator.
            # Should be possible to detect if this is run in a web context or not.
            if script:
                raise RuntimeError('This task cannot run since the system is in read-only mode.')

            if not api:
                # FIXME: ideally, this would show a different page for each different user class
                return render_template(templates.PUBLIC_READ_ONLY_MODE)

            payload = json.dumps({"message" : "We are currently carrying out essential maintenance, and this route is temporarily unavailable"})
            resp = make_response(payload, 503)
            resp.mimetype = "application/json"
            return resp

        # Tag the wrapper so that protected functions can be recognised
        decorated_view._write_required = True
        return decorated_view

    return decorator

142 

143 

class CaughtTermException(Exception):
    """Raised by the SIGTERM handler so that the running function can exit gracefully."""
    pass


def _term_handler(signum, frame):
    """
    Signal handler which logs the signal number and raises CaughtTermException.

    :param signum: the number of the signal received
    :param frame: the current stack frame (unused)
    :raises CaughtTermException: always
    """
    app.logger.warning("Harvester terminated with signal " + str(signum))
    raise CaughtTermException


def capture_sigterm(fn):
    """
    ~~CaptureSigterm:Feature~~
    Decorator which allows graceful exit on SIGTERM

    If the wrapped function is interrupted by CaughtTermException (raised by
    the SIGTERM handler) or KeyboardInterrupt, a harvester progress report is
    written and logged, optionally emailed (when HARVESTER_EMAIL_ON_EVENT is
    set and HARVESTER_EMAIL_RECIPIENTS is configured), and the process exits
    with status 1. Otherwise the wrapped function's return value is passed
    back to the caller.

    :param fn: the function to wrap
    :return: the wrapped function
    :raises SystemExit: with code 1, after an interruption has been handled
    """

    # Register the SIGTERM handler to raise an exception, allowing graceful exit.
    # Note this happens when the decorator is applied, not on each call.
    signal.signal(signal.SIGTERM, _term_handler)

    @wraps(fn)
    def decorated_fn(*args, **kwargs):
        try:
            # Pass the wrapped function's result through to the caller
            # instead of silently dropping it.
            return fn(*args, **kwargs)
        except (CaughtTermException, KeyboardInterrupt):
            app.logger.warning(u"Harvester caught SIGTERM. Exiting.")
            report = Report.write_report()
            if app.config.get("HARVESTER_EMAIL_ON_EVENT", False):
                to = app.config.get("HARVESTER_EMAIL_RECIPIENTS", None)
                fro = app.config.get("SYSTEM_EMAIL_FROM")

                if to is not None:
                    from portality import app_email as mail
                    mail.send_mail(
                        to=to,
                        fro=fro,
                        subject="DOAJ Harvester caught SIGTERM at {0}".format(dates.now_str()),
                        msg_body=report
                    )
            app.logger.info(report)
            # The exit() builtin is injected by the site module for interactive
            # use and is not guaranteed to exist; raise SystemExit directly.
            raise SystemExit(1)

    return decorated_fn