Coverage for portality/bll/services/query.py: 76%

177 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-07-22 15:59 +0100

1from portality.core import app, es_connection 

2from portality.util import ipt_prefix 

3from portality.bll import exceptions 

4from portality.lib import plugin 

5from copy import deepcopy 

6 

7import esprit 

8 

9 

class QueryService(object):
    """
    ~~Query:Service~~

    Runs configured searches and scrolls against a DAO class.

    Each request is described by a ``domain`` and an ``index_type``, which are
    looked up in the ``QUERY_ROUTE`` application config to obtain a route
    configuration dict.  That dict drives authorisation, query validation,
    query pre-filtering, the DAO class used, and result post-filtering.
    """

    def _get_config_for_search(self, domain, index_type, account):
        """
        Look up the route configuration and enforce its access rules.

        :param domain: top-level key into the ``QUERY_ROUTE`` config
        :param index_type: key into the configuration for ``domain``
        :param account: the requesting account, or None if there is none;
            must provide ``has_role(role)`` when the route configures a role
        :return: the configuration dict for ``domain``/``index_type``
        :raises exceptions.AuthoriseException: NOT_AUTHORISED if the route is
            not configured or an account is required and missing; WRONG_ROLE
            if the account lacks the configured role
        """
        # load the query route config and the path we are being requested for
        # ~~-> Query:Config~~
        qrs = app.config.get("QUERY_ROUTE", {})

        # get the configuration for this url route
        route_cfg = qrs.get(domain)
        if route_cfg is None:
            raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

        cfg = route_cfg.get(index_type)
        if cfg is None:
            raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

        # does the user have to be authenticated (the default is yes)
        if cfg.get("auth", True):
            if account is None:
                raise exceptions.AuthoriseException(exceptions.AuthoriseException.NOT_AUTHORISED)

            # if so, does the user require a role
            role = cfg.get("role")
            if role is not None and not account.has_role(role):
                raise exceptions.AuthoriseException(exceptions.AuthoriseException.WRONG_ROLE)

        return cfg

    def _load_configured_function(self, name, description):
        """
        Resolve an entry of the ``QUERY_FILTERS`` config to a callable.

        :param name: key into the ``QUERY_FILTERS`` config
        :param description: human-readable kind of function, used only in the
            error message (e.g. "query filter")
        :return: the function returned by ``plugin.load_function``
        :raises exceptions.ConfigurationException: if no function was loaded
        """
        filters = app.config.get("QUERY_FILTERS", {})
        fn = plugin.load_function(filters.get(name))
        if fn is None:
            msg = "Unable to load {d} for {x}".format(d=description, x=name)
            raise exceptions.ConfigurationException(msg)
        return fn

    def _validate_query(self, cfg, query):
        """
        Run the route's ``query_validator`` (if any) against the query.

        :param cfg: route configuration dict
        :param query: the Query object to validate
        :return: True when no validator is configured, otherwise whatever the
            validator function returns
        :raises exceptions.ConfigurationException: if the validator cannot be loaded
        """
        validator = cfg.get("query_validator")
        if validator is None:
            return True

        fn = self._load_configured_function(validator, "query validator")
        return fn(query)

    def _pre_filter_search_query(self, cfg, query):
        """
        Apply each of the route's ``query_filters`` to the query, in order.

        The filter functions are called for their effect on ``query``; their
        return values are ignored.

        :param cfg: route configuration dict
        :param query: the Query object to pass to every filter
        :return: the same ``query`` object that was passed in
        :raises exceptions.ConfigurationException: if a filter cannot be loaded
        """
        # now run the query through the filters
        for filter_name in cfg.get("query_filters", []):
            # because of back-compat, we have to do a few tricky things here...
            # filter may be the name of a filter in the list of query filters
            fn = self._load_configured_function(filter_name, "query filter")

            # run the filter
            fn(query)

        return query

    def _post_filter_search_results(self, cfg, res, unpacked=False):
        """
        Pipe the results through each of the route's ``result_filters``.

        Unlike the query filters, each result filter's return value replaces
        the results passed on to the next filter.

        :param cfg: route configuration dict
        :param res: the results to filter
        :param unpacked: forwarded to every filter as the ``unpacked`` keyword
        :return: the output of the last filter, or ``res`` if none are configured
        :raises exceptions.ConfigurationException: if a filter cannot be loaded
        """
        for result_filter_name in cfg.get("result_filters", []):
            fn = self._load_configured_function(result_filter_name, "result filter")

            # apply the result filter
            res = fn(res, unpacked=unpacked)

        return res

    def _get_query(self, cfg, raw_query):
        """
        Build, validate and pre-filter the Query for a request.

        :param cfg: route configuration dict
        :param raw_query: raw query dict, or None for the default Query
        :return: the validated and filtered Query object
        :raises exceptions.AuthoriseException: if the validator rejects the query
        """
        query = Query() if raw_query is None else Query(raw_query)

        # validate the query, to make sure it is of a permitted form
        if not self._validate_query(cfg, query):
            raise exceptions.AuthoriseException()

        # add any required filters to the query
        return self._pre_filter_search_query(cfg, query)

    def _get_dao_klass(self, cfg):
        """
        Load the class named by the route's ``dao`` setting.

        :param cfg: route configuration dict
        :return: the class returned by ``plugin.load_class``
        :raises exceptions.NoSuchObjectException: if no class was loaded
        """
        # get the name of the model that will handle this query, and then look up
        # the class that will handle it
        dao_name = cfg.get("dao")
        dao_klass = plugin.load_class(dao_name)
        if dao_klass is None:
            raise exceptions.NoSuchObjectException(dao_name)
        return dao_klass

    def search(self, domain, index_type, raw_query, account, additional_parameters):
        """
        Run a single query against the route's DAO and return filtered results.

        :param domain: top-level key into the ``QUERY_ROUTE`` config
        :param index_type: key into the configuration for ``domain``
        :param raw_query: raw query dict, or None for the default Query
        :param account: the requesting account, or None
        :param additional_parameters: mapping of extra request values, checked
            against the route's ``required_parameters``; None is treated as
            an empty mapping
        :return: the DAO's query response after the result filters have run
        :raises exceptions.AuthoriseException: if access is not permitted, a
            required parameter is missing or has a disallowed value, or the
            query fails validation
        """
        cfg = self._get_config_for_search(domain, index_type, account)

        # check that the request values permit a query to this endpoint
        required_parameters = cfg.get("required_parameters")
        if required_parameters is not None:
            # a caller that supplies no parameters at all simply fails the
            # check, rather than triggering an AttributeError on None
            if additional_parameters is None:
                additional_parameters = {}
            for k, vs in required_parameters.items():
                val = additional_parameters.get(k)
                if val is None or val not in vs:
                    raise exceptions.AuthoriseException()

        dao_klass = self._get_dao_klass(cfg)

        # get the query
        query = self._get_query(cfg, raw_query)

        # send the query
        res = dao_klass.query(q=query.as_dict())

        # filter the results as needed
        return self._post_filter_search_results(cfg, res)

    def scroll(self, domain, index_type, raw_query, account, page_size, scan=False):
        """
        Iterate over every result of a query, yielding each one post-filtered.

        This is a generator function, so the configuration and authorisation
        checks only run once iteration begins.

        :param domain: top-level key into the ``QUERY_ROUTE`` config
        :param index_type: key into the configuration for ``domain``
        :param raw_query: raw query dict, or None for the default Query
        :param account: the requesting account, or None
        :param page_size: page size to pass to the DAO; when None the route's
            ``page_size`` setting is used (default 1000)
        :param scan: accepted for interface compatibility; not used here
        :return: generator of result-filtered records
        :raises exceptions.AuthoriseException: if access is not permitted or
            the query fails validation
        """
        cfg = self._get_config_for_search(domain, index_type, account)

        dao_klass = self._get_dao_klass(cfg)

        # get the query
        query = self._get_query(cfg, raw_query)

        # get the scroll parameters
        if page_size is None:
            page_size = cfg.get("page_size", 1000)
        limit = cfg.get("limit", None)
        keepalive = cfg.get("keepalive", "1m")

        # ~~->Elasticsearch:Technology~~
        for result in dao_klass.iterate(q=query.as_dict(), page_size=page_size, limit=limit, wrap=False, keepalive=keepalive):
            yield self._post_filter_search_results(cfg, result, unpacked=True)

154 

155 

class Query(object):
    """
    ~~Query:Query -> Elasticsearch:Technology~~

    Wrapper around a raw query dict, with helpers for adding bool clauses,
    managing facets/aggregations, paging, source includes and sorting.

    The raw dict passed to the constructor is held by reference and modified
    in place by the mutating methods.
    """

    def __init__(self, raw=None, filtered=False):
        """
        :param raw: raw query dict; when None a match_all query with
            ``track_total_hits`` enabled is used
        :param filtered: set True to declare the query as a "filtered" query
        :raises Exception: if ``filtered`` is True or the query body contains
            a ``filtered`` key, as filtered queries are not supported
        """
        self.q = {"track_total_hits" : True, "query": {"match_all": {}}} if raw is None else raw
        self.filtered = filtered is True or self.q.get("query", {}).get("filtered") is not None
        if self.filtered:
            # FIXME: this is just to help us catch filtered queries during development.  Once we have them
            # all, all the filtering logic in this class can come out
            raise Exception("Filtered queries are no longer supported")

    def convert_to_bool(self):
        """
        Ensure the query body is a ``bool`` query.

        Does nothing if the query is flagged as filtered or is already a bool
        query.  Otherwise any existing, non-empty query body becomes the sole
        entry of the new bool query's ``must`` list.
        """
        if self.filtered is True:
            return

        current_query = None
        if "query" in self.q:
            if "bool" in self.q["query"]:
                return
            current_query = deepcopy(self.q["query"])
            # remove and re-add so the key is rebuilt from scratch
            del self.q["query"]

        bool_query = {}
        if current_query:
            # an empty query body is dropped rather than carried into "must"
            bool_query["must"] = [current_query]
        self.q["query"] = {"bool": bool_query}

    def _bool_clause(self, name):
        """
        Return the list held under ``name`` in the bool query, creating it
        (and converting the query to a bool query) if necessary.
        """
        self.convert_to_bool()
        return self.q["query"]["bool"].setdefault(name, [])

    def add_must(self, filter):
        """Append ``filter`` to the bool query's ``must`` list."""
        self._bool_clause("must").append(filter)

    def add_must_filter(self, filter):
        """Append ``filter`` to the bool query's ``filter`` list."""
        self._bool_clause("filter").append(filter)

    def add_must_not(self, filter):
        """Append ``filter`` to the bool query's ``must_not`` list."""
        self._bool_clause("must_not").append(filter)

    def clear_match_all(self):
        """
        Remove a top-level ``match_all`` from the query body, if present.

        A query with no ``query`` body at all is left untouched.
        """
        query = self.q.get("query")
        if isinstance(query, dict):
            query.pop("match_all", None)

    def has_facets(self):
        """Return True if the query has ``facets``, ``aggregations`` or ``aggs``."""
        return "facets" in self.q or "aggregations" in self.q or "aggs" in self.q

    def clear_facets(self):
        """Remove ``facets``, ``aggregations`` and ``aggs`` from the query."""
        for key in ("facets", "aggregations", "aggs"):
            self.q.pop(key, None)

    def size(self):
        """
        Return the query's ``size`` as an int.

        :return: the integer size, or 10 if it is absent or cannot be
            converted to an int
        """
        try:
            return int(self.q["size"])
        except (KeyError, TypeError, ValueError):
            # TypeError covers values such as None or a list, which int() rejects
            return 10

    def from_result(self):
        """
        Return the query's ``from`` offset as an int.

        :return: the integer offset, or 0 if it is absent or cannot be
            converted to an int
        """
        try:
            return int(self.q["from"])
        except (KeyError, TypeError, ValueError):
            return 0

    def as_dict(self):
        """Return the underlying query dict (the live object, not a copy)."""
        return self.q

    def add_include(self, fields):
        """
        Add one or more fields to ``_source.includes``.

        Duplicates are removed, and the order in which fields were first
        added is preserved.

        :param fields: a list of fields, or a single field
        """
        source = self.q.setdefault("_source", {})
        existing = source.setdefault("includes", [])
        if not isinstance(fields, list):
            fields = [fields]
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # resulting query is deterministic
        source["includes"] = list(dict.fromkeys(existing + fields))

    def sort(self):
        """Return the query's ``sort`` value, or None if there is none."""
        return self.q.get("sort")

    def set_sort(self, s):
        """Set the query's ``sort`` value to ``s``."""
        self.q["sort"] = s

261 

262 

class QueryFilterException(Exception):
    """
    Exception type declared alongside the query service.

    NOTE(review): presumably raised by query/result filter functions; it is
    not raised anywhere in this module - confirm against the filter code.
    """