Coverage for portality / bll / services / query.py: 84%
214 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-04 09:41 +0100
1from portality.core import app, es_connection
2from portality.util import ipt_prefix
3from portality.bll import exceptions
4from portality.lib import plugin
5from copy import deepcopy
class QueryService(object):
    """
    ~~Query:Service~~

    Resolves a (domain, index_type) pair against ``app.config["QUERY_ROUTE"]``,
    checks that the caller may use that route, validates and pre-filters the
    query, runs it through the route's DAO class and post-filters the results.
    """

    def _get_config_for_search(self, domain, index_type, account):
        """
        Look up the route config for ``domain``/``index_type`` and authorise ``account``.

        :param domain: top-level key in ``QUERY_ROUTE``
        :param index_type: key within that domain's route table
        :param account: the requesting account, or None for an anonymous request
        :return: the route config dict for this index type
        :raises exceptions.AuthoriseException: NOT_AUTHORISED if the route does not
            exist, or it requires authentication and ``account`` is None;
            WRONG_ROLE if the route names a role that ``account`` does not have
        """
        # load the query route config
        # ~~-> Query:Config~~
        # a missing or null QUERY_ROUTE setting behaves as an empty route table
        qrs = app.config.get("QUERY_ROUTE") or {}

        # get the configuration for this url route
        route_cfg = qrs.get(domain)
        if route_cfg is None:
            raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

        cfg = route_cfg.get(index_type)
        if cfg is None:
            raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

        # routes require an authenticated account unless they set "auth": False
        if cfg.get("auth", True):
            if account is None:
                raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

            # if so, does the user require a role
            role = cfg.get("role")
            if role is not None and not account.has_role(role):
                raise exceptions.AuthoriseException(exceptions.AuthoriseException.WRONG_ROLE)

        return cfg

    def _load_query_function(self, name, label):
        """
        Resolve a named entry in ``app.config["QUERY_FILTERS"]`` to a callable.

        :param name: key into ``QUERY_FILTERS``
        :param label: kind of function being loaded, used only in the error message
        :return: the function returned by ``plugin.load_function``
        :raises exceptions.ConfigurationException: if ``plugin.load_function`` returns None
        """
        filters = app.config.get("QUERY_FILTERS", {})
        fn = plugin.load_function(filters.get(name))
        if fn is None:
            msg = "Unable to load {label} for {x}".format(label=label, x=name)
            raise exceptions.ConfigurationException(msg)
        return fn

    def _validate_query(self, cfg, query):
        """
        Run the route's ``query_validators`` against ``query``.

        Validators are loaded and run one at a time, so checking stops at the
        first validator that rejects the query (later ones are not even loaded).

        :return: True if there are no validators or every validator accepts the
            query, otherwise False
        :raises exceptions.ConfigurationException: if a validator cannot be loaded
        """
        for validator in cfg.get("query_validators") or []:
            fn = self._load_query_function(validator, "query validator")
            if not fn(query):
                return False
        return True

    def _pre_filter_search_query(self, cfg, query):
        """
        Apply the route's ``query_filters`` to ``query``, in configured order.

        :return: the same ``query`` object that was passed in
        :raises exceptions.ConfigurationException: if a filter cannot be loaded
        """
        for filter_name in cfg.get("query_filters") or []:
            fn = self._load_query_function(filter_name, "query filter")
            # the filter's return value is ignored, so filters must act on
            # ``query`` in place
            fn(query)
        return query

    def _post_filter_search_results(self, cfg, res, unpacked=False):
        """
        Apply the route's ``result_filters`` to ``res``, in configured order.

        Each filter is called as ``fn(res, unpacked=unpacked)`` and receives the
        previous filter's output.

        :return: the output of the last filter, or ``res`` unchanged if there are none
        :raises exceptions.ConfigurationException: if a filter cannot be loaded
        """
        for result_filter_name in cfg.get("result_filters") or []:
            fn = self._load_query_function(result_filter_name, "result filter")
            res = fn(res, unpacked=unpacked)
        return res

    def _get_query(self, cfg, raw_query):
        """
        Wrap ``raw_query`` in a Query, validate it and apply the route's query filters.

        :param raw_query: a query dict, or None for the default match_all query
        :raises exceptions.AuthoriseException: if a query validator rejects the query
        """
        # Query(None) builds the default match_all query
        query = Query(raw_query)

        # validate the query, to make sure it is of a permitted form
        if not self._validate_query(cfg, query):
            raise exceptions.AuthoriseException()

        # add any required filters to the query
        return self._pre_filter_search_query(cfg, query)

    def _get_dao_klass(self, cfg):
        """
        Load the class named by the route's ``dao`` setting.

        :raises exceptions.NoSuchObjectException: if ``plugin.load_class`` returns None
        """
        dao_name = cfg.get("dao")
        dao_klass = plugin.load_class(dao_name)
        if dao_klass is None:
            raise exceptions.NoSuchObjectException(dao_name)
        return dao_klass

    def search(self, domain, index_type, raw_query, account, additional_parameters):
        """
        Authorise, build and run a single query, returning the filtered results.

        :param additional_parameters: mapping of request parameters, checked against
            the route's ``required_parameters`` (parameter -> permitted values);
            None is treated as "no parameters supplied"
        :raises exceptions.AuthoriseException: if the route is not permitted, a
            required parameter is missing or has a value not in its permitted set,
            or a validator rejects the query
        """
        cfg = self._get_config_for_search(domain, index_type, account)

        # check that the request values permit a query to this endpoint; a None
        # parameter mapping must be rejected, not crash with AttributeError
        params = {} if additional_parameters is None else additional_parameters
        required_parameters = cfg.get("required_parameters")
        if required_parameters is not None:
            for name, permitted in required_parameters.items():
                val = params.get(name)
                if val is None or val not in permitted:
                    raise exceptions.AuthoriseException()

        dao_klass = self._get_dao_klass(cfg)
        query = self._get_query(cfg, raw_query)

        # send the query, then filter the results as needed
        res = dao_klass.query(q=query.as_dict())
        return self._post_filter_search_results(cfg, res)

    def scroll(self, domain, index_type, raw_query, account, page_size, scan=False):
        """
        Iterate over every result of a query, yielding each one after result filtering.

        This is a generator: route lookup, authorisation, query validation and DAO
        loading all happen on the first iteration, not when ``scroll`` is called.
        NOTE(review): callers that need authorisation errors up front must start
        iterating inside their error handling.

        :param page_size: results per page; None falls back to the route's
            ``page_size`` (default 1000)
        :param scan: accepted for interface compatibility; not used here
        """
        cfg = self._get_config_for_search(domain, index_type, account)
        dao_klass = self._get_dao_klass(cfg)
        query = self._get_query(cfg, raw_query)

        # get the scroll parameters
        if page_size is None:
            page_size = cfg.get("page_size", 1000)
        limit = cfg.get("limit", None)
        keepalive = cfg.get("keepalive", "1m")

        # ~~->Elasticsearch:Technology~~
        for result in dao_klass.iterate(q=query.as_dict(), page_size=page_size, limit=limit, wrap=False, keepalive=keepalive):
            yield self._post_filter_search_results(cfg, result, unpacked=True)

    def make_actionable_query(self, domain, index_type, account, raw_query):
        """
        Authorise the route and return the validated, pre-filtered Query without running it.
        """
        cfg = self._get_config_for_search(domain, index_type, account)
        return self._get_query(cfg, raw_query)
class Query(object):
    """
    ~~Query:Query -> Elasticsearch:Technology~~

    Mutable wrapper around an Elasticsearch query body (a plain dict).

    The body is held by reference: the mutating methods edit the dict passed to
    the constructor in place, and ``as_dict`` returns that same object.
    """

    # top-level body keys which carry facet / aggregation definitions
    _FACET_KEYS = ("facets", "aggregations", "aggs")

    def __init__(self, raw=None, filtered=False):
        """
        :param raw: an existing query body to wrap, or None for a match_all query
            with ``track_total_hits`` enabled
        :param filtered: legacy flag; passing True is rejected
        :raises Exception: if ``filtered`` is True or ``raw`` has a ``query.filtered`` clause
        """
        self.q = {"track_total_hits" : True, "query": {"match_all": {}}} if raw is None else raw
        self.filtered = filtered is True or self.q.get("query", {}).get("filtered") is not None
        if self.filtered:
            # FIXME: this is just to help us catch filtered queries during development. Once we have them
            # all, all the filtering logic in this class can come out
            raise Exception("Filtered queries are no longer supported")

    def convert_to_bool(self):
        """
        Rewrite the top-level query in place so that it is a ``bool`` query.

        No-op if the query is already a ``bool`` query. Otherwise a copy of the
        existing non-empty query clause is moved into ``bool.must``.
        """
        if self.filtered is True:
            return

        current_query = None
        if "query" in self.q:
            if "bool" in self.q["query"]:
                return
            current_query = deepcopy(self.q["query"])
            del self.q["query"]
            if not current_query:
                # an empty clause ({}) is dropped rather than wrapped
                current_query = None

        bool_query = self.q.setdefault("query", {}).setdefault("bool", {})
        if current_query is not None:
            bool_query.setdefault("must", []).append(current_query)

    def _bool_clause(self, occurrence):
        """
        Return the list for ``bool.<occurrence>``, converting the query to a bool
        query and creating the list if needed.

        Elasticsearch also accepts a single clause as a plain object here; such a
        clause is wrapped in a list so that further clauses can be appended.
        """
        self.convert_to_bool()
        bool_query = self.q["query"]["bool"]
        clauses = bool_query.setdefault(occurrence, [])
        if isinstance(clauses, dict):
            clauses = [clauses]
            bool_query[occurrence] = clauses
        return clauses

    def add_must(self, filter):
        """Append ``filter`` to ``bool.must``."""
        self._bool_clause("must").append(filter)

    def get_field_context(self):
        """
        Get query string context.

        :return: the top-level ``query_string`` clause if there is one, else the
            value of ``bool.must`` (usually a list), else None (including when the
            body has no ``query`` at all)
        """
        query = self.q.get("query", {})
        if "query_string" in query:
            return query["query_string"]
        if "bool" in query:
            return query["bool"].get("must")
        return None

    def add_default_field(self, value: str):
        """
        Add a default field to the query string, if one is not already present.

        Applies to the top-level ``query_string`` clause, or else to the first
        ``query_string`` clause found in ``bool.must``. An existing
        ``default_field`` is never overwritten.
        """
        context = self.get_field_context()
        if not context:
            return

        if "query_string" in self.q.get("query", {}):
            # context is the query_string clause itself
            if isinstance(context, dict):
                context.setdefault("default_field", value)
            return

        # context is bool.must: a list of clauses, or a single clause given as a dict
        clauses = [context] if isinstance(context, dict) else context
        if not isinstance(clauses, list):
            return
        for item in clauses:
            if isinstance(item, dict) and "query_string" in item:
                item["query_string"].setdefault("default_field", value)
                break

    def add_should(self, filter, minimum_should_match=1):
        """
        Add ``filter`` (a single clause, or a list of clauses) to ``bool.should``.

        ``minimum_should_match`` is (re)set on every call.
        """
        should = self._bool_clause("should")
        if isinstance(filter, list):
            should.extend(filter)
        else:
            should.append(filter)
        self.q["query"]["bool"]["minimum_should_match"] = minimum_should_match

    def add_must_filter(self, filter):
        """Append ``filter`` to ``bool.filter``."""
        self._bool_clause("filter").append(filter)

    def add_must_not(self, filter):
        """Append ``filter`` to ``bool.must_not``."""
        self._bool_clause("must_not").append(filter)

    def clear_match_all(self):
        """Remove a top-level ``match_all`` clause, if present."""
        self.q.get("query", {}).pop("match_all", None)

    def has_facets(self):
        """Return True if the body defines facets or aggregations."""
        return any(key in self.q for key in self._FACET_KEYS)

    def clear_facets(self):
        """Remove any facet / aggregation definitions from the body."""
        for key in self._FACET_KEYS:
            self.q.pop(key, None)

    def _int_param(self, key, default):
        """
        Return ``self.q[key]`` as an int, or ``default`` if it is absent or not
        convertible (a warning is logged in the latter case).
        """
        if key in self.q:
            try:
                return int(self.q[key])
            except (ValueError, TypeError):
                app.logger.warning("Invalid {k} parameter in query: [{x}], "
                                   "expected integer value".format(k=key, x=self.q[key]))
        return default

    def size(self):
        """Return the requested page size, defaulting to 10."""
        return self._int_param("size", 10)

    def from_result(self):
        """Return the requested result offset, defaulting to 0."""
        return self._int_param("from", 0)

    def as_dict(self):
        """Return the underlying query body (the same object, not a copy)."""
        return self.q

    def add_include(self, fields):
        """
        Add ``fields`` (a field name or list of names) to ``_source.includes``.

        Duplicates are removed while keeping first-seen order, so the resulting
        body is deterministic. A ``_source`` given as a string or list of patterns
        is converted to the ``{"includes": [...]}`` object form first.
        """
        if not isinstance(fields, list):
            fields = [fields]

        source = self.q.get("_source")
        if isinstance(source, str):
            source = [source]
        if isinstance(source, list):
            source = {"includes": source}
        elif not isinstance(source, dict):
            source = {}
        self.q["_source"] = source

        existing = source.get("includes", [])
        if isinstance(existing, str):
            existing = [existing]
        source["includes"] = list(dict.fromkeys(list(existing) + fields))

    def sort(self):
        """Return the body's ``sort`` value, or None."""
        return self.q.get("sort")

    def set_sort(self, s):
        """Replace the body's ``sort`` value."""
        self.q["sort"] = s
class QueryFilterException(Exception):
    """
    Exception type for query-filter errors.

    Not raised anywhere in this module; presumably raised by query filter
    plugins - confirm against callers.
    """