Coverage for portality / bll / services / query.py: 84%

214 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 00:09 +0100

1from portality.core import app, es_connection 

2from portality.util import ipt_prefix 

3from portality.bll import exceptions 

4from portality.lib import plugin 

5from copy import deepcopy 

6 

class QueryService(object):
    """
    ~~Query:Service~~

    Authorises and runs queries for the routes configured in
    ``app.config["QUERY_ROUTE"]``. A route config is a dict that may name a
    DAO class (``dao``), an authentication requirement (``auth``, default
    True) and ``role``, ``query_validators``, ``query_filters``,
    ``result_filters``, ``required_parameters``, and scroll settings
    (``page_size``, ``limit``, ``keepalive``).
    """

    def _get_config_for_search(self, domain, index_type, account):
        """Return the route config for ``domain``/``index_type`` after checking access.

        :param domain: the top-level key in ``QUERY_ROUTE``
        :param index_type: the key within that domain's route config
        :param account: the requesting account, or None for an anonymous request
        :return: the route config dict
        :raises exceptions.AuthoriseException: NOT_AUTHORISED if the route does not
            exist, or if it requires authentication and ``account`` is None;
            WRONG_ROLE if it requires a role that ``account`` does not have
        """
        # load the query route config
        # ~~-> Query:Config~~
        qrs = app.config.get("QUERY_ROUTE", {})

        # get the configuration for this url route (direct lookup, no key scan needed)
        route_cfg = qrs.get(domain)
        if route_cfg is None:
            raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

        cfg = route_cfg.get(index_type)
        if cfg is None:
            raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

        # does the user have to be authenticated (required unless "auth" is falsy)
        if cfg.get("auth", True):
            if account is None:
                raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

            # if so, does the user require a role
            role = cfg.get("role")
            if role is not None and not account.has_role(role):
                raise exceptions.AuthoriseException(exceptions.AuthoriseException.WRONG_ROLE)

        return cfg

    def _validate_query(self, cfg, query):
        """Run the route's query validators against ``query``.

        :param cfg: the route config; its optional ``query_validators`` lists names
            that are resolved through ``app.config["QUERY_FILTERS"]``
        :param query: the Query to validate
        :return: False as soon as any validator returns a falsy value, otherwise True
        :raises exceptions.ConfigurationException: if a validator cannot be loaded
        """
        validators = cfg.get("query_validators")
        if not validators:
            return True

        # the filter registry is the same for every validator, so read it once
        filters = app.config.get("QUERY_FILTERS", {})
        for validator in validators:
            fn = plugin.load_function(filters.get(validator))
            if fn is None:
                msg = "Unable to load query validator for {x}".format(x=validator)
                raise exceptions.ConfigurationException(msg)

            if not fn(query):
                return False
        return True

    def _pre_filter_search_query(self, cfg, query):
        """Apply the route's ``query_filters`` to ``query``.

        Each filter is called with the query, and its return value is ignored.
        Filters are therefore expected to modify the query in place.

        :param cfg: the route config
        :param query: the Query to filter
        :return: the same ``query`` object
        :raises exceptions.ConfigurationException: if a filter cannot be loaded
        """
        filters = app.config.get("QUERY_FILTERS", {})
        for filter_name in cfg.get("query_filters", []):
            # filter_name is resolved through the QUERY_FILTERS registry (kept for back-compat)
            fn = plugin.load_function(filters.get(filter_name))
            if fn is None:
                msg = "Unable to load query filter for {x}".format(x=filter_name)
                raise exceptions.ConfigurationException(msg)

            # run the filter
            fn(query)

        return query

    def _post_filter_search_results(self, cfg, res, unpacked=False):
        """Pass ``res`` through each of the route's ``result_filters`` in turn.

        :param cfg: the route config
        :param res: the raw result from the DAO
        :param unpacked: forwarded to every result filter; ``scroll`` passes True
            because it iterates with ``wrap=False``
        :return: the output of the last filter, or ``res`` unchanged if there are none
        :raises exceptions.ConfigurationException: if a result filter cannot be loaded
        """
        filters = app.config.get("QUERY_FILTERS", {})
        for result_filter_name in cfg.get("result_filters", []):
            fn = plugin.load_function(filters.get(result_filter_name))
            if fn is None:
                msg = "Unable to load result filter for {x}".format(x=result_filter_name)
                raise exceptions.ConfigurationException(msg)

            # apply the result filter
            res = fn(res, unpacked=unpacked)

        return res

    def _get_query(self, cfg, raw_query):
        """Wrap ``raw_query`` in a Query, then validate and pre-filter it.

        :param cfg: the route config
        :param raw_query: the query body dict, or None for the default match_all query
        :return: the validated, filtered Query
        :raises exceptions.AuthoriseException: if a query validator rejects the query
        """
        # Query(None) builds the default match_all query
        query = Query(raw_query)

        # validate the query, to make sure it is of a permitted form
        if not self._validate_query(cfg, query):
            raise exceptions.AuthoriseException()

        # add any required filters to the query
        return self._pre_filter_search_query(cfg, query)

    def _get_dao_klass(self, cfg):
        """Load the DAO class named by the route's ``dao`` setting.

        :raises exceptions.NoSuchObjectException: if the class cannot be loaded
        """
        dao_name = cfg.get("dao")
        dao_klass = plugin.load_class(dao_name)
        if dao_klass is None:
            raise exceptions.NoSuchObjectException(dao_name)
        return dao_klass

    def search(self, domain, index_type, raw_query, account, additional_parameters):
        """Authorise, run and post-filter a single search.

        :param domain: the top-level route key
        :param index_type: the route key within ``domain``
        :param raw_query: the query body dict, or None
        :param account: the requesting account, or None
        :param additional_parameters: mapping of request values checked against the
            route's ``required_parameters``; None is treated as empty
        :return: the DAO query result after the route's result filters
        :raises exceptions.AuthoriseException: if access is denied, or a required
            parameter is missing or not one of its permitted values
        """
        cfg = self._get_config_for_search(domain, index_type, account)

        # check that the request values permit a query to this endpoint
        required_parameters = cfg.get("required_parameters")
        if required_parameters is not None:
            # a request carrying no parameters cannot satisfy the requirement; refuse it
            # as unauthorised instead of failing with AttributeError on None
            supplied = {} if additional_parameters is None else additional_parameters
            for k, vs in required_parameters.items():
                val = supplied.get(k)
                if val is None or val not in vs:
                    raise exceptions.AuthoriseException()

        dao_klass = self._get_dao_klass(cfg)

        # get the query
        query = self._get_query(cfg, raw_query)

        # send the query
        res = dao_klass.query(q=query.as_dict())

        # filter the results as needed
        return self._post_filter_search_results(cfg, res)

    def scroll(self, domain, index_type, raw_query, account, page_size, scan=False):
        """Yield every matching record, each passed through the route's result filters.

        This is a generator. Authorisation, configuration loading and query
        validation happen only when iteration starts, not when ``scroll`` is called.

        :param page_size: records per scroll page; None uses the route's ``page_size``
            (default 1000)
        :param scan: accepted for interface compatibility; not used here
        """
        cfg = self._get_config_for_search(domain, index_type, account)

        dao_klass = self._get_dao_klass(cfg)

        # get the query
        query = self._get_query(cfg, raw_query)

        # get the scroll parameters
        if page_size is None:
            page_size = cfg.get("page_size", 1000)
        limit = cfg.get("limit", None)
        keepalive = cfg.get("keepalive", "1m")

        # ~~->Elasticsearch:Technology~~
        for result in dao_klass.iterate(q=query.as_dict(), page_size=page_size, limit=limit, wrap=False, keepalive=keepalive):
            yield self._post_filter_search_results(cfg, result, unpacked=True)

    def make_actionable_query(self, domain, index_type, account, raw_query):
        """Authorise the route and return the validated, filtered Query without running it."""
        cfg = self._get_config_for_search(domain, index_type, account)
        return self._get_query(cfg, raw_query)

158 

class Query(object):
    """
    ~~Query:Query -> Elasticsearch:Technology~~

    Mutable wrapper around an Elasticsearch query body, held as a plain dict in
    ``self.q``. A ``raw`` dict is wrapped by reference, not copied, so changes
    made through this object are visible to whoever supplied it.
    """

    def __init__(self, raw=None, filtered=False):
        """
        :param raw: the query body to wrap; if None, a fresh match_all query with
            ``track_total_hits`` set to True is created
        :param filtered: legacy flag; True is rejected
        :raises Exception: if ``filtered`` is True or the body has a legacy
            ``query.filtered`` clause
        """
        self.q = {"track_total_hits": True, "query": {"match_all": {}}} if raw is None else raw
        self.filtered = filtered is True or self.q.get("query", {}).get("filtered") is not None
        if self.filtered:
            # FIXME: this is just to help us catch filtered queries during development. Once we have them
            # all, all the filtering logic in this class can come out
            raise Exception("Filtered queries are no longer supported")

    def convert_to_bool(self):
        """Ensure the top-level query is a ``bool`` query.

        An existing bool query is left alone. Any other non-empty query is moved
        into ``bool.must``. A missing or empty query becomes an empty bool query.
        """
        if self.filtered is True:
            return

        current_query = None
        if "query" in self.q:
            if "bool" in self.q["query"]:
                return
            current_query = deepcopy(self.q["query"])
            del self.q["query"]
            if len(current_query) == 0:
                current_query = None

        bool_query = self.q.setdefault("query", {}).setdefault("bool", {})
        if current_query is not None:
            bool_query.setdefault("must", []).append(current_query)

    def _bool_clauses(self, occurrence):
        """Return the mutable clause list for ``occurrence`` (e.g. "must") in the bool query."""
        self.convert_to_bool()
        context = self.q["query"]["bool"]
        clauses = context.get(occurrence)
        if clauses is None:
            clauses = context[occurrence] = []
        elif isinstance(clauses, dict):
            # Elasticsearch also accepts a single clause object here; normalise it to a
            # list so more clauses can be appended
            clauses = context[occurrence] = [clauses]
        return clauses

    def add_must(self, filter):
        """Add ``filter`` as a ``bool.must`` clause, converting the query to bool first."""
        self._bool_clauses("must").append(filter)

    def get_field_context(self):
        """Get query string context

        :return: the ``query_string`` body (a dict) if the top-level query is a
            query_string query; the list of ``bool.must`` clauses if it is a bool
            query with ``must`` (a single-object must is wrapped in a list);
            otherwise None
        """
        query = self.q.get("query") or {}
        if "query_string" in query:
            return query["query_string"]
        if "bool" in query and "must" in query["bool"]:
            must = query["bool"]["must"]
            # never hand back a lone must clause as if it were a query_string body
            return [must] if isinstance(must, dict) else must
        return None

    def add_default_field(self, value: str):
        """ Add a default field to the query string, if one is not already present

        Only the top-level query_string, or the first query_string clause in
        ``bool.must``, is considered.
        """
        context = self.get_field_context()
        if not context:
            return
        if isinstance(context, dict):
            context.setdefault("default_field", value)
        elif isinstance(context, list):
            for item in context:
                if "query_string" in item:
                    item["query_string"].setdefault("default_field", value)
                    break

    def add_should(self, filter, minimum_should_match=1):
        """Add one clause, or a list of clauses, to ``bool.should``.

        :param filter: a clause dict, or a list of clause dicts
        :param minimum_should_match: written to the bool query on every call
        """
        clauses = self._bool_clauses("should")
        if isinstance(filter, list):
            clauses.extend(filter)
        else:
            clauses.append(filter)
        self.q["query"]["bool"]["minimum_should_match"] = minimum_should_match

    def add_must_filter(self, filter):
        """Add ``filter`` as a ``bool.filter`` clause."""
        self._bool_clauses("filter").append(filter)

    def add_must_not(self, filter):
        """Add ``filter`` as a ``bool.must_not`` clause."""
        self._bool_clauses("must_not").append(filter)

    def clear_match_all(self):
        """Remove a top-level ``match_all`` from the query, if present."""
        query = self.q.get("query") or {}
        query.pop("match_all", None)

    def has_facets(self):
        """Return True if the body has any of ``facets``, ``aggregations`` or ``aggs``."""
        return any(key in self.q for key in ("facets", "aggregations", "aggs"))

    def clear_facets(self):
        """Remove ``facets``, ``aggregations`` and ``aggs`` from the body."""
        for key in ("facets", "aggregations", "aggs"):
            self.q.pop(key, None)

    def _int_param(self, key, default):
        """Return ``self.q[key]`` as an int; log and return ``default`` if it is unusable."""
        if key not in self.q:
            return default
        raw = self.q[key]
        try:
            return int(raw)
        except (TypeError, ValueError):
            # TypeError covers None, lists, dicts, etc., which int() cannot parse either
            app.logger.warning("Invalid {k} parameter in query: [{x}], "
                               "expected integer value".format(k=key, x=raw))
        return default

    def size(self):
        """Return the query's ``size`` as an int, or 10 if absent or invalid."""
        return self._int_param("size", 10)

    def from_result(self):
        """Return the query's ``from`` offset as an int, or 0 if absent or invalid."""
        return self._int_param("from", 0)

    def as_dict(self):
        """Return the underlying query body (the live dict, not a copy)."""
        return self.q

    def add_include(self, fields):
        """Add ``fields`` to ``_source.includes``, dropping duplicates and keeping first-seen order.

        :param fields: a single field name/pattern, or a list/tuple/set of them
        """
        if not isinstance(fields, (list, tuple, set)):
            fields = [fields]

        source = self.q.get("_source")
        if source is None:
            source = {}
        elif isinstance(source, (str, list)):
            # _source may also be given as a pattern or list of patterns, which is
            # equivalent to {"includes": ...}
            source = {"includes": source}
        self.q["_source"] = source

        includes = source.get("includes", [])
        if isinstance(includes, str):
            includes = [includes]
        # dict.fromkeys removes duplicates but keeps order, so output is deterministic
        source["includes"] = list(dict.fromkeys(list(includes) + list(fields)))

    def sort(self):
        """Return the query's ``sort`` value, or None if unset."""
        return self.q.get("sort")

    def set_sort(self, s):
        """Replace the query's ``sort`` value with ``s``."""
        self.q["sort"] = s

306 

307 

class QueryFilterException(Exception):
    """Exception type for query-filter failures.

    Nothing in the visible part of this module raises it. It is presumably raised
    by query/result filter plugins; verify against the callers.
    """