Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/components/proxy.py: 0%
140 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-14 02:13 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-14 02:13 +0000
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import json
4import logging
5from typing import Dict, List
7import requests
8from marshmallow import fields, Schema
9from marshmallow.schema import SchemaMeta
10from flask import make_response
11from werkzeug.routing import Rule, Map
12from urllib.parse import urlparse
14from daiquiri.core import marshal
15from daiquiri.core.components import Component, ComponentResource
16from daiquiri.core.schema.proxy import METHODS, ProxyConfigSchema
17from daiquiri.resources.utils import get_resource_provider
logger = logging.getLogger(__name__)

# OpenAPI primitive type name -> marshmallow field class used to validate a
# proxied parameter (or schema property) of that type.
PARAM_MAP = dict(
    boolean=fields.Boolean,
    integer=fields.Int,
    number=fields.Float,
    string=fields.Str,
)
class Proxy(Component):
    """Proxy http/s requests from another service.

    For every configured proxy the OpenAPI spec is read, marshmallow input
    fields are built from its parameter definitions, and one
    ``ComponentResource`` subclass per route is registered that forwards
    requests to the configured ``target``.
    """

    _config_schema = ProxyConfigSchema()

    # Upstream response headers that are not copied back to the client.
    # Hop-by-hop headers (RFC 7230, section 6.1) only describe the upstream
    # connection. ``requests`` transparently decodes gzip/deflate bodies, so
    # ``response.content`` no longer matches the upstream Content-Encoding /
    # Content-Length; forwarding them would make clients try to decompress an
    # already-decoded body or cut it at the wrong length.
    _EXCLUDED_RESPONSE_HEADERS = frozenset(
        (
            "connection",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailer",
            "trailers",
            "transfer-encoding",
            "upgrade",
            "content-encoding",
            "content-length",
        )
    )

    # werkzeug converter used for a templated path segment, keyed by OpenAPI
    # parameter type. werkzeug has no "integer"/"number" converters; any other
    # type falls back to "string".
    _PATH_CONVERTERS = {"integer": "int", "number": "float"}

    def _retrieve(self, path, object):
        """Walk ``object`` along the keys in ``path`` and return the node reached.

        The caller's ``path`` list is not modified.

        Raises:
            KeyError: if a key along ``path`` is missing from a mapping.
        """
        node = object
        for key in path:
            node = node[key]
        return node

    def _resolve_schema(self, schema, spec):
        """Resolve a local ``$ref`` (e.g. ``#/definitions/Foo``) and parse it.

        Returns:
            dict of property name -> marshmallow field, see ``_parse_properties``.
        """
        ref = schema["$ref"]
        if ref.startswith("#/"):
            ref = ref[2:]
        # JSON pointer tokens escape "/" as "~1" and "~" as "~0" (RFC 6901);
        # "~1" has to be decoded before "~0".
        path = [
            token.replace("~1", "/").replace("~0", "~") for token in ref.split("/")
        ]
        node = self._retrieve(path, spec)
        return self._parse_properties(node, spec)

    @staticmethod
    def _make_field(type_desc, required=False):
        """Build a marshmallow field from an OpenAPI type description.

        Arguments:
            type_desc: mapping with a ``type`` key; for ``"array"`` its
                ``items`` mapping must carry a ``type`` as well.
            required: set on the outer field. For arrays it goes on the
                ``List`` itself: marshmallow only checks ``required`` when the
                key is missing, so on the item field it would never apply.

        Raises:
            KeyError: for types not listed in ``PARAM_MAP``.
        """
        if type_desc["type"] == "array":
            item_field = PARAM_MAP[type_desc["items"]["type"]]()
            return fields.List(item_field, required=required)
        return PARAM_MAP[type_desc["type"]](required=required)

    def _parse_properties(self, node, spec):
        """Convert the ``properties`` of a schema object into marshmallow fields.

        Arguments:
            node: schema object with a ``properties`` mapping and an optional
                ``required`` list of property names.
            spec: the full OpenAPI document, used to resolve nested ``$ref``.

        Returns:
            dict of property name -> marshmallow field. A property that cannot
            be converted is logged and skipped; the others are still processed.
        """
        spec_v3 = spec.get("openapi", "").startswith("3.")
        required_names = node.get("required", [])
        params = {}
        for property_name, prop in node["properties"].items():
            required = property_name in required_names
            try:
                if spec_v3:
                    # Use a wrapping "schema" object when present, otherwise the
                    # property itself, which carries "type" directly.
                    params[property_name] = self._make_field(prop.get("schema", prop))
                elif "type" not in prop:
                    schema = Schema.from_dict(
                        self._resolve_schema(prop, spec),
                        name=f"{property_name}Schema",
                    )
                    params[property_name] = fields.Nested(schema, required=required)
                else:
                    params[property_name] = self._make_field(prop, required=required)
            except Exception:
                # Best effort: log the bad property and keep the others.
                logger.error(
                    "Error while reading node '%s' param %s:\n%s",
                    node,
                    property_name,
                    prop,
                    exc_info=True,
                )
        return params

    def create_params_from_schema(
        self,
        url: str,
        method_name: str,
        openapi_spec: dict,
    ) -> Dict[str, fields.Field]:
        """Build marshmallow input fields for ``method_name`` on ``url``.

        Every spec path that ``url`` ends with contributes the parameters of
        its ``method_name`` operation, e.g. for Swagger 2::

            "paths": {
                "/api/chat": {
                    "get": {
                        "description": "Get the last n chat messages",
                        "parameters": [
                            {"in": "query", "name": "limit",
                             "required": false, "type": "integer"},
                            {"in": "query", "name": "offset",
                             "required": false, "type": "integer"}
                        ]
                    }
                }
            }

        Operations without ``parameters`` (or methods the spec does not
        describe) contribute nothing.

        Returns:
            A dict of parameter name -> marshmallow field, or, when a Swagger 2
            ``in: body`` parameter is present, the marshmallow ``Schema`` class
            built from the body definition (other parameters are then not part
            of the returned input).

        Raises:
            RuntimeError: if any parameter could not be converted; each failure
                is logged first.
        """
        spec_v3 = openapi_spec.get("openapi", "").startswith("3.")
        params = {}
        body_schema = None
        nb_errors = 0
        for path, path_definition in openapi_spec["paths"].items():
            # NOTE(review): suffix matching can select several spec paths (e.g.
            # "/chat" and "/api/chat") whose parameters are then merged, and a
            # templated path such as "/items/{id}" does not match a url that
            # carries werkzeug "<converter:name>" placeholders instead.
            if not url.endswith(path):
                continue
            operation = path_definition.get(method_name) or {}
            for parameter_id, parameter in enumerate(operation.get("parameters", ())):
                try:
                    if spec_v3:
                        # OpenAPI 3 keeps the type under "schema".
                        params[parameter["name"]] = self._make_field(
                            parameter["schema"]
                        )
                        continue

                    required = bool(parameter.get("required", False))
                    if "type" in parameter:
                        params[parameter["name"]] = self._make_field(
                            parameter, required=required
                        )
                        continue

                    # No "type": the parameter is described by a $ref schema.
                    fields_dict = self._resolve_schema(
                        parameter["schema"], openapi_spec
                    )
                    nested_schema = Schema.from_dict(
                        fields_dict,
                        name=f"{parameter['name']}Schema",
                    )
                    if parameter.get("in") == "body":
                        body_schema = nested_schema
                    else:
                        params[parameter["name"]] = fields.Nested(
                            nested_schema, required=required
                        )
                except Exception:
                    logger.error(
                        "Error while reading spec from path '%s' param %s:\n%s",
                        path,
                        parameter_id,
                        parameter,
                        exc_info=True,
                    )
                    nb_errors += 1

        if nb_errors:
            raise RuntimeError("Failed to read proxy api specification")

        if body_schema is not None:
            if params:
                logger.debug(
                    "Body schema for '%s' takes precedence over parameters %s",
                    url,
                    list(params),
                )
            return body_schema
        return params

    def create_proxy(
        self,
        method_name: str,
        url: str,
        name: str,
        openapi_spec: dict,
        headers=None,
    ) -> callable:
        """Create a proxied route handler.

        Arguments:
            method_name: lower-case HTTP method; selects the ``requests``
                function and becomes the handler's ``__name__``.
            url: target url; its path may contain werkzeug
                ``<converter:name>`` placeholders filled from the handler kwargs.
            name: route mapping, stored as the handler's ``__doc__``.
            openapi_spec: spec used to build the marshalled input.
            headers: optional request headers sent to the target.

        Returns:
            The handler wrapped with ``marshal``.
        """
        excluded_headers = self._EXCLUDED_RESPONSE_HEADERS

        def proxy(self, *args, **kwargs):
            method = getattr(requests, method_name)
            full_url = url

            if kwargs:
                url_parts = urlparse(url)
                url_rule = Rule(url_parts.path)
                # Adding the rule to a Map binds (compiles) it so it can build.
                Map([url_rule])
                # Path arguments fill the placeholders; every other value is
                # appended to the built url as a query string.
                built_url = url_rule.build(kwargs)
                full_url = f"{url_parts.scheme}://{url_parts.netloc}{built_url[1]}"

            logger.debug("Proxying to %s", full_url)
            if method_name == "post":
                logger.debug("Proxying post request with body: %s", kwargs)

            try:
                # kwargs are already encoded in full_url; passing them again as
                # ``params`` would duplicate every query argument.
                response = method(full_url, headers=headers)
            except requests.RequestException as e:
                logger.error("Proxy '%s' failed: %s", full_url, e)
                return make_response((f"Proxy not available: {e}", 404))

            forwarded_headers = [
                (key, value)
                for key, value in response.headers.items()
                if key.lower() not in excluded_headers
            ]
            return make_response(
                (response.content, response.status_code, forwarded_headers)
            )

        proxy.__name__ = method_name
        proxy.__qualname__ = method_name
        proxy.__doc__ = name

        schema = self.create_params_from_schema(url, method_name, openapi_spec)
        # A Swagger 2 body parameter yields a Schema class; otherwise a dict.
        inp = schema if isinstance(schema, SchemaMeta) else dict(schema)
        return marshal(inp=inp)(proxy)

    def setup(self, *args, **kwargs) -> None:
        """Set up every proxy listed under ``proxies`` in the component config."""
        for proxy in self._config.get("proxies", []):
            self.setup_proxy(proxy)

    def setup_proxy(self, proxy):
        """Setup a new proxy.

        Arguments:
            proxy: config entry read for ``openapi``, ``name``, ``target`` and
                optionally ``routes`` (each with ``name``, optional ``methods``
                defaulting to ``["get"]``, and optional ``headers``). Without
                ``routes`` the routes are derived from the OpenAPI spec.
        """
        openapi_spec = self._read_openapi_spec(proxy["openapi"])
        try:
            routes_description = proxy["routes"]
        except KeyError:  # Load routes from openapi file
            routes_description = self._get_openapi_routes(openapi_spec)

        routes = {}
        for route in routes_description:
            route_name = route["name"].lstrip("/")
            mapping = f"{proxy['name'].strip('/')}/{route_name}"

            for method in route.get("methods", ["get"]):
                routes.setdefault(mapping, {})[method] = self.create_proxy(
                    method,
                    f"{proxy['target'].rstrip('/')}/{route_name}",
                    mapping,
                    openapi_spec,
                    headers=route.get("headers"),
                )

        for path, methods in routes.items():
            handlers = list(methods.values())
            # Dynamic resource class whose attributes named after the HTTP
            # methods are the proxied handlers.
            proxied_route = type(
                f"{handlers[0].__doc__}Proxy",
                (ComponentResource,),
                methods,
            )
            self.register_route(proxied_route, f"/{path}")

    def _read_openapi_spec(self, url: str) -> dict:
        """Return an OpenAPI spec structure from a URI.

        Arguments:
            url: Use ``http[s]://`` to read the specification from a web
                service, ``file://`` (or a bare path) to read it from the file
                system, and ``res://`` to read it from Daiquiri resources.

        Raises:
            RuntimeError: if an HTTP(S) fetch does not answer with status 200.
        """
        url_parts = urlparse(url)
        location = url_parts.netloc + url_parts.path

        if url_parts.scheme in ("", "file"):
            # JSON text is UTF-8 (RFC 8259); don't depend on the locale.
            with open(location, "r", encoding="utf-8") as f:
                return json.load(f)

        # Equality, not `in ("res")`: that is a plain string, so `in` would be
        # a substring test accepting schemes such as "r" or "es".
        if url_parts.scheme == "res":
            provider = get_resource_provider()
            resource = provider.abs_resource(location)
            with provider.open_resource(resource, mode="t") as f:
                return json.load(f)

        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            return response.json()
        raise RuntimeError(f"Could not load OpenAPI spec {url}: {response.status_code}")

    def _get_openapi_routes(self, openapi_spec: dict) -> List[dict]:
        """Derive route descriptions from the ``paths`` of an OpenAPI spec.

        Templated segments such as ``{id}`` are rewritten as werkzeug
        placeholders (e.g. ``<int:id>``) using the type of the matching
        ``in: path`` parameter.

        Returns:
            list of ``{"name": path, "methods": [...]}`` dicts, where
            ``methods`` keeps only the operation keys found in ``METHODS``.
        """
        routes = []
        for name, path_item in openapi_spec["paths"].items():
            methods = [m for m in path_item.keys() if m in METHODS]

            # Path-level "parameters" apply to every operation; other keys
            # ("summary", "description", ...) are not operations.
            parameters = list(path_item.get("parameters", ()))
            for method in methods:
                parameters.extend(path_item[method].get("parameters", ()))

            # Path arguments conversion
            for param in parameters:
                if param.get("in") != "path":
                    continue
                # Swagger 2 puts "type" on the parameter, OpenAPI 3 under "schema".
                param_type = param.get("type") or param.get("schema", {}).get("type")
                converter = self._PATH_CONVERTERS.get(param_type, "string")
                name = name.replace(
                    f"{{{param['name']}}}", f"<{converter}:{param['name']}>"
                )

            routes.append(dict(name=name, methods=methods))
        return routes