Coverage for portality/bll/services/query.py: 76%
177 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-22 15:59 +0100
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-22 15:59 +0100
1from portality.core import app, es_connection
2from portality.util import ipt_prefix
3from portality.bll import exceptions
4from portality.lib import plugin
5from copy import deepcopy
7import esprit
class QueryService(object):
    """
    ~~Query:Service~~

    Runs search and scroll requests against a DAO class selected by
    configuration.  Each request is resolved to a route configuration taken
    from ``app.config["QUERY_ROUTE"]`` (keyed by domain, then by index type).
    That configuration names the DAO, the optional query validator, and the
    query/result filters, which are themselves looked up by name in
    ``app.config["QUERY_FILTERS"]``.
    """

    def _get_config_for_search(self, domain, index_type, account):
        """
        Look up the route configuration for ``domain``/``index_type`` and
        check that ``account`` is allowed to use it.

        :param domain: key into the ``QUERY_ROUTE`` configuration
        :param index_type: key into the configuration found for ``domain``
        :param account: the requesting account, or None if there is none; must
            provide ``has_role(role)`` when the route configures a role
        :return: the configuration dict for the route
        :raises AuthoriseException: NOT_AUTHORISED if the domain or index type
            is not configured, or if the route needs authentication and no
            account was supplied; WRONG_ROLE if the account lacks the
            configured role
        """
        # load the query route config and the path we are being requested for
        # ~~-> Query:Config~~
        qrs = app.config.get("QUERY_ROUTE", {})

        # get the configuration for this url route
        route_cfg = qrs.get(domain)
        if route_cfg is None:
            raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

        cfg = route_cfg.get(index_type)
        if cfg is None:
            raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

        # does the user have to be authenticated (the default when "auth" is not set)
        if cfg.get("auth", True):
            if account is None:
                raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

            # if so, does the user require a role
            role = cfg.get("role")
            if role is not None and not account.has_role(role):
                raise exceptions.AuthoriseException(exceptions.AuthoriseException.WRONG_ROLE)

        return cfg

    def _load_configured_function(self, name, description):
        """
        Resolve ``name`` through ``app.config["QUERY_FILTERS"]`` and load the
        function it points to.

        :param name: the key in ``QUERY_FILTERS``
        :param description: what is being loaded ("query validator",
            "query filter" or "result filter"); used only in the error message
        :return: the loaded callable
        :raises ConfigurationException: if ``name`` is not configured or the
            configured path cannot be loaded
        """
        path = app.config.get("QUERY_FILTERS", {}).get(name)
        # An unconfigured name is a configuration error in its own right; do not
        # hand None to the plugin loader, whose handling of it is not visible here.
        fn = plugin.load_function(path) if path is not None else None
        if fn is None:
            msg = "Unable to load {d} for {x}".format(d=description, x=name)
            raise exceptions.ConfigurationException(msg)
        return fn

    def _validate_query(self, cfg, query):
        """
        Run the route's configured validator, if any, over ``query``.

        :param cfg: route configuration; ``query_validator`` is read from it
        :param query: the Query object to check
        :return: True when no validator is configured, otherwise whatever the
            validator function returns
        :raises ConfigurationException: if the configured validator cannot be loaded
        """
        validator = cfg.get("query_validator")
        if validator is None:
            return True

        fn = self._load_configured_function(validator, "query validator")
        return fn(query)

    def _pre_filter_search_query(self, cfg, query):
        """
        Apply each of the route's ``query_filters`` to ``query``, in order.

        The return value of each filter is ignored, so filters are expected to
        modify ``query`` in place.

        :param cfg: route configuration; ``query_filters`` is read from it
        :param query: the Query object to filter
        :return: the same ``query`` object that was passed in
        :raises ConfigurationException: if a configured filter cannot be loaded
        """
        for filter_name in cfg.get("query_filters", []):
            fn = self._load_configured_function(filter_name, "query filter")
            # run the filter
            fn(query)

        return query

    def _post_filter_search_results(self, cfg, res, unpacked=False):
        """
        Pass ``res`` through each of the route's ``result_filters``, in order.

        Unlike the query filters, each result filter's return value replaces
        ``res`` for the next one.

        :param cfg: route configuration; ``result_filters`` is read from it
        :param res: the result to filter
        :param unpacked: forwarded to every filter as the ``unpacked`` keyword
        :return: the output of the last filter, or ``res`` if none are configured
        :raises ConfigurationException: if a configured filter cannot be loaded
        """
        for result_filter_name in cfg.get("result_filters", []):
            fn = self._load_configured_function(result_filter_name, "result filter")
            # apply the result filter
            res = fn(res, unpacked=unpacked)

        return res

    def _get_query(self, cfg, raw_query):
        """
        Wrap ``raw_query`` in a Query object, validate it and apply the query filters.

        :param cfg: route configuration
        :param raw_query: the query body as a dict, or None for Query's default
        :return: the validated and filtered Query object
        :raises AuthoriseException: if the configured validator rejects the query
        """
        query = Query() if raw_query is None else Query(raw_query)

        # validate the query, to make sure it is of a permitted form
        if not self._validate_query(cfg, query):
            raise exceptions.AuthoriseException()

        # add any required filters to the query
        return self._pre_filter_search_query(cfg, query)

    def _get_dao_klass(self, cfg):
        """
        Load the DAO class named by the ``dao`` entry of the route configuration.

        :param cfg: route configuration
        :return: the loaded class
        :raises NoSuchObjectException: if ``plugin.load_class`` returns None
        """
        # get the name of the model that will handle this query, and then look up
        # the class that will handle it
        dao_name = cfg.get("dao")
        dao_klass = plugin.load_class(dao_name)
        if dao_klass is None:
            raise exceptions.NoSuchObjectException(dao_name)
        return dao_klass

    def search(self, domain, index_type, raw_query, account, additional_parameters):
        """
        Run a single query against the DAO configured for the route.

        :param domain: key into the ``QUERY_ROUTE`` configuration
        :param index_type: key into the configuration found for ``domain``
        :param raw_query: the query body as a dict, or None for the default query
        :param account: the requesting account, or None
        :param additional_parameters: mapping of extra request values, checked
            against the route's ``required_parameters``; None is treated as empty
        :return: the DAO's query result after the result filters have run
        :raises AuthoriseException: if the route is not permitted for the
            account, a required parameter is missing or has a value that is
            not allowed, or the query fails validation
        """
        cfg = self._get_config_for_search(domain, index_type, account)

        # check that the request values permit a query to this endpoint
        required_parameters = cfg.get("required_parameters")
        if required_parameters is not None:
            # no parameters at all must be refused like a missing one, not crash
            supplied = {} if additional_parameters is None else additional_parameters
            for k, vs in required_parameters.items():
                val = supplied.get(k)
                if val is None or val not in vs:
                    raise exceptions.AuthoriseException()

        dao_klass = self._get_dao_klass(cfg)

        # get the query
        query = self._get_query(cfg, raw_query)

        # send the query
        res = dao_klass.query(q=query.as_dict())

        # filter the results as needed
        return self._post_filter_search_results(cfg, res)

    def scroll(self, domain, index_type, raw_query, account, page_size, scan=False):
        """
        Iterate over every result of a query, yielding each one after the
        result filters have run (with ``unpacked=True``).

        This is a generator function, so nothing runs until the first item is
        requested.  That includes the authorisation and validation checks, so
        their exceptions surface on iteration rather than on the call.

        :param domain: key into the ``QUERY_ROUTE`` configuration
        :param index_type: key into the configuration found for ``domain``
        :param raw_query: the query body as a dict, or None for the default query
        :param account: the requesting account, or None
        :param page_size: page size handed to the DAO; when None the route's
            ``page_size`` setting is used (default 1000)
        :param scan: accepted but not used by this method
        :return: generator of filtered results
        :raises AuthoriseException: (on iteration) as for ``search``, except
            that no required-parameter check is made here
        """
        cfg = self._get_config_for_search(domain, index_type, account)

        dao_klass = self._get_dao_klass(cfg)

        # get the query
        query = self._get_query(cfg, raw_query)

        # get the scroll parameters
        if page_size is None:
            page_size = cfg.get("page_size", 1000)
        limit = cfg.get("limit", None)
        keepalive = cfg.get("keepalive", "1m")

        # ~~->Elasticsearch:Technology~~
        results = dao_klass.iterate(q=query.as_dict(), page_size=page_size, limit=limit,
                                    wrap=False, keepalive=keepalive)
        for result in results:
            yield self._post_filter_search_results(cfg, result, unpacked=True)
class Query(object):
    """
    ~~Query:Query -> Elasticsearch:Technology~~

    Thin wrapper around a query body held as a dict in ``self.q``.  The
    helpers modify that dict in place.  When a ``raw`` dict is supplied to the
    constructor it is used directly, not copied, so the caller's dict is the
    one that gets modified.
    """

    # top-level keys under which facet/aggregation definitions may appear
    _FACET_KEYS = ("facets", "aggregations", "aggs")

    # values reported when the body has no usable "size" / "from"
    _DEFAULT_SIZE = 10
    _DEFAULT_FROM = 0

    def __init__(self, raw=None, filtered=False):
        """
        :param raw: the query body as a dict; when None a match_all query with
            ``track_total_hits`` enabled is used
        :param filtered: pass True to mark the query as a "filtered" query
        :raises Exception: if ``filtered`` is True or the body's ``query``
            contains a ``filtered`` key
        """
        self.q = {"track_total_hits": True, "query": {"match_all": {}}} if raw is None else raw
        self.filtered = filtered is True or self.q.get("query", {}).get("filtered") is not None
        if self.filtered:
            # FIXME: this is just to help us catch filtered queries during development.  Once we have them
            # all, all the filtering logic in this class can come out
            raise Exception("Filtered queries are no longer supported")

    def convert_to_bool(self):
        """
        Make sure ``self.q["query"]`` is a ``bool`` query.

        A query that already has a ``bool`` key is left alone.  Any other
        non-empty query is moved into the new bool query's ``must`` list; an
        absent or empty query becomes an empty bool query.
        """
        if self.filtered is True:
            return

        current_query = None
        if "query" in self.q:
            if "bool" in self.q["query"]:
                return
            # detach the existing clause (copied, so the bool query never shares
            # structure with the original); an empty clause counts as absent
            current_query = deepcopy(self.q.pop("query")) or None

        bool_clause = {}
        if current_query is not None:
            bool_clause["must"] = [current_query]
        self.q["query"] = {"bool": bool_clause}

    def _append_clause(self, occurrence, clause):
        """Append ``clause`` to the ``occurrence`` list of the bool query, creating the list if needed."""
        self.convert_to_bool()
        self.q["query"]["bool"].setdefault(occurrence, []).append(clause)

    # NOTE: the parameter name ``filter`` shadows the builtin, but it is kept
    # in the three methods below because callers may pass it by keyword.

    def add_must(self, filter):
        """Add ``filter`` to the bool query's ``must`` list."""
        self._append_clause("must", filter)

    def add_must_filter(self, filter):
        """Add ``filter`` to the bool query's ``filter`` list."""
        self._append_clause("filter", filter)

    def add_must_not(self, filter):
        """Add ``filter`` to the bool query's ``must_not`` list."""
        self._append_clause("must_not", filter)

    def clear_match_all(self):
        """Remove a top-level ``match_all`` clause from the query, if there is one."""
        query = self.q.get("query")
        # a body without a "query" has nothing to clear
        if query is not None and "match_all" in query:
            del query["match_all"]

    def has_facets(self):
        """Return True if the body has any of ``facets``, ``aggregations`` or ``aggs``."""
        return any(key in self.q for key in self._FACET_KEYS)

    def clear_facets(self):
        """Remove ``facets``, ``aggregations`` and ``aggs`` from the body."""
        for key in self._FACET_KEYS:
            self.q.pop(key, None)

    def size(self):
        """
        Return the body's ``size`` as an int.

        Returns 10 when ``size`` is absent or cannot be converted to an int
        (for example a non-numeric string or None).
        """
        if "size" not in self.q:
            return self._DEFAULT_SIZE
        try:
            return int(self.q["size"])
        except (TypeError, ValueError):
            return self._DEFAULT_SIZE

    def from_result(self):
        """
        Return the body's ``from`` offset as an int.

        Returns 0 when ``from`` is absent or cannot be converted to an int,
        mirroring the way ``size`` falls back to its default.
        """
        if "from" not in self.q:
            return self._DEFAULT_FROM
        try:
            return int(self.q["from"])
        except (TypeError, ValueError):
            return self._DEFAULT_FROM

    def as_dict(self):
        """Return the underlying query dict itself (not a copy)."""
        return self.q

    @staticmethod
    def _as_field_list(value):
        """Normalise None, a single value, or a list/tuple/set of values to a list."""
        if value is None:
            return []
        if isinstance(value, (list, tuple, set, frozenset)):
            return list(value)
        return [value]

    def add_include(self, fields):
        """
        Add ``fields`` to ``_source.includes``, without duplicates.

        Existing entries keep their position and new ones are appended in the
        order given, so the resulting list is deterministic.

        :param fields: a single field name or a list/tuple/set of field names
        """
        source = self.q.get("_source")
        if isinstance(source, str):
            # NOTE(review): assumes the bare-string form of _source names a field to include
            source = {"includes": [source]}
        elif isinstance(source, (list, tuple)):
            # NOTE(review): assumes the list form of _source lists fields to include
            source = {"includes": list(source)}
        elif not isinstance(source, dict):
            # absent, or a form (such as a bool) that cannot carry includes
            source = {}
        self.q["_source"] = source

        merged = self._as_field_list(source.get("includes")) + self._as_field_list(fields)
        # dict.fromkeys removes duplicates while keeping first-seen order
        source["includes"] = list(dict.fromkeys(merged))

    def sort(self):
        """Return the body's ``sort`` value, or None if it has none."""
        return self.q.get("sort")

    def set_sort(self, s):
        """Set the body's ``sort`` value to ``s``."""
        self.q["sort"] = s
class QueryFilterException(Exception):
    """
    Exception type for query filter errors.

    Nothing in this module raises it; presumably it is raised by the filter
    functions loaded through QUERY_FILTERS -- confirm against those modules.
    """