Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/components/proxy.py: 0%

140 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-14 02:13 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3import json 

4import logging 

5from typing import Dict, List 

6 

7import requests 

8from marshmallow import fields, Schema 

9from marshmallow.schema import SchemaMeta 

10from flask import make_response 

11from werkzeug.routing import Rule, Map 

12from urllib.parse import urlparse 

13 

14from daiquiri.core import marshal 

15from daiquiri.core.components import Component, ComponentResource 

16from daiquiri.core.schema.proxy import METHODS, ProxyConfigSchema 

17from daiquiri.resources.utils import get_resource_provider 

18 

logger = logging.getLogger(__name__)

# Maps an OpenAPI primitive type name to the marshmallow field class used to
# validate a proxied request parameter of that type. Types not listed here
# (such as "array") are handled separately by the Proxy component.
PARAM_MAP = dict(
    boolean=fields.Boolean,
    integer=fields.Int,
    number=fields.Float,
    string=fields.Str,
)

27 

28 

class Proxy(Component):
    """Proxy http/s requests to another service.

    Each entry of the ``proxies`` config list is exposed under its ``name``
    and forwarded to its ``target``. Routes come from the entry's ``routes``
    list or, when that key is absent, from the ``paths`` of its OpenAPI
    specification. The same specification is used to build the marshmallow
    input of every proxied route.
    """

    _config_schema = ProxyConfigSchema()

    # Upstream response headers that are not copied to the client.
    # ``requests`` transparently decompresses gzip/deflate bodies, so the
    # upstream Content-Encoding / Content-Length no longer describe
    # ``response.content``. The remaining names are hop-by-hop headers that
    # only apply to the upstream connection.
    _EXCLUDED_RESPONSE_HEADERS = frozenset(
        (
            "connection",
            "content-encoding",
            "content-length",
            "keep-alive",
            "proxy-authenticate",
            "proxy-authorization",
            "te",
            "trailer",
            "transfer-encoding",
            "upgrade",
        )
    )

    def _retrieve(self, path, object):
        """Return the node of ``object`` reached by following the keys in ``path``.

        An empty ``path`` returns ``object`` itself. Lookup errors for a
        missing key propagate to the caller. ``path`` is not modified.
        """
        node = object
        for key in path:
            node = node[key]
        return node

    def _resolve_schema(self, schema, spec):
        """Resolve a local ``{"$ref": "#/..."}`` schema and convert its properties.

        Arguments:
            schema: schema object holding a ``$ref`` into ``spec``.
            spec: the whole parsed OpenAPI document.

        Returns:
            The dict of marshmallow fields built by ``_parse_properties``.
        """
        # JSON pointer tokens escape "/" as "~1" and "~" as "~0" (RFC 6901).
        path = [
            token.replace("~1", "/").replace("~0", "~")
            for token in schema["$ref"].replace("#/", "").split("/")
        ]
        node = self._retrieve(path, spec)
        return self._parse_properties(node, spec)

    def _parse_properties(self, node, spec):
        """Convert the ``properties`` of a schema object into marshmallow fields.

        Properties without a ``type`` are treated as ``$ref`` schemas and
        become ``fields.Nested`` of a generated ``Schema``. ``array``
        properties become ``fields.List``. Names listed in the node's
        ``required`` array are marked as required.

        Conversion errors are logged, and the fields built so far are
        returned.
        """
        spec_v3 = spec.get("openapi", "").startswith("3.")
        properties = node["properties"]
        required = node.get("required", [])
        params = {}
        # Defined up front so the error log below never hits unbound names.
        prop_name = prop = None
        try:
            for prop_name, prop in properties.items():
                is_required = prop_name in required
                if spec_v3:
                    # NOTE(review): OpenAPI 3 schema properties carry "type"
                    # directly rather than under "schema"; this branch is only
                    # reachable from Swagger 2 lookups today — confirm before use.
                    params[prop_name] = PARAM_MAP[prop["schema"]["type"]]()
                elif "type" not in prop:
                    nested = Schema.from_dict(
                        self._resolve_schema(prop, spec),
                        name=f"{prop_name}Schema",
                    )
                    params[prop_name] = fields.Nested(nested, required=is_required)
                elif prop["type"] == "array":
                    params[prop_name] = fields.List(
                        PARAM_MAP[prop["items"]["type"]](required=is_required)
                    )
                else:
                    params[prop_name] = PARAM_MAP[prop["type"]](required=is_required)
        except Exception:
            logger.error(
                "Error while reading node '%s' param %s:\n%s",
                node,
                prop_name,
                prop,
                exc_info=True,
            )

        return params

    def create_params_from_schema(
        self,
        url: str,
        method_name: str,
        openapi_spec: dict,
    ) -> Dict[str, fields.Field]:
        """Build the marshmallow input of a proxied operation from an OpenAPI spec.

        Every path of ``openapi_spec["paths"]`` that ``url`` ends with
        contributes the ``parameters`` of its ``method_name`` operation.
        Paths that do not define that operation, and operations without a
        ``parameters`` list, contribute nothing.

        Arguments:
            url: target URL of the proxied route.
            method_name: lower-case HTTP method of the operation.
            openapi_spec: parsed OpenAPI (Swagger 2 or OpenAPI 3) document.

        Returns:
            Usually a dict mapping parameter names to marshmallow fields.
            When a Swagger 2 ``"in": "body"`` parameter is found, it is
            instead a ``Schema`` class generated from the referenced body
            definition.

        Raises:
            RuntimeError: if any parameter could not be converted. Each
                failure is logged with its traceback.

        Example of the Swagger 2 structure that is read::

            "paths": {
                "/api/chat": {
                    "get": {
                        "description": "Get the last n chat messages",
                        "parameters": [
                            {
                                "in": "query",
                                "name": "limit",
                                "required": false,
                                "type": "integer"
                            },
                            {
                                "in": "query",
                                "name": "offset",
                                "required": false,
                                "type": "integer"
                            }
                        ],
                    }
                }
            }
        """
        params = {}
        nb_errors = 0
        spec_v3 = openapi_spec.get("openapi", "").startswith("3.")
        # NOTE(review): spec paths keep templates such as "{id}", whereas urls
        # built from `_get_openapi_routes` use "<type:id>", so templated paths
        # never match here and get no input validation — confirm intended.
        for path, path_definition in openapi_spec["paths"].items():
            if not url.endswith(path):
                continue

            operation = path_definition.get(method_name)
            if operation is None:
                # Suffix matching can select several paths ("/chat" and
                # "/api/chat"); not all of them define this method.
                continue

            # "parameters" is optional on an OpenAPI operation.
            parameters = operation.get("parameters", [])
            # Defined up front so the error log below never hits unbound names.
            parameter_id = parameter = None
            try:
                for parameter_id, parameter in enumerate(parameters):
                    if spec_v3:
                        params[parameter["name"]] = PARAM_MAP[
                            parameter["schema"]["type"]
                        ]()
                    elif "type" not in parameter:
                        fields_dict = self._resolve_schema(
                            parameter["schema"], openapi_spec
                        )
                        if parameter.get("in") == "body":
                            # The body schema replaces the whole input definition.
                            params = Schema.from_dict(
                                fields_dict,
                                name=f"{parameter['name']}Schema",
                            )
                        else:
                            params[parameter["name"]] = fields_dict
                    elif parameter["type"] == "array":
                        params[parameter["name"]] = fields.List(
                            PARAM_MAP[parameter["items"]["type"]](
                                required=parameter.get("required")
                            )
                        )
                    else:
                        params[parameter["name"]] = PARAM_MAP[parameter["type"]](
                            required=parameter.get("required")
                        )
            except Exception:
                logger.error(
                    "Error while reading spec from path '%s' param %s:\n%s",
                    path,
                    parameter_id,
                    parameter,
                    exc_info=True,
                )
                nb_errors += 1

        if nb_errors:
            raise RuntimeError("Failed to read proxy api specification")

        return params

    def create_proxy(
        self,
        method_name: str,
        url: str,
        name: str,
        openapi_spec: dict,
        headers=None,
    ) -> callable:
        """Create a proxied route handler for one HTTP method.

        Arguments:
            method_name: lower-case HTTP method. It is also the name of the
                ``requests`` function used to forward the call (e.g. "get",
                "post").
            url: absolute target URL. Its path may contain werkzeug rule
                variables (``<int:id>``), which are filled from the handler's
                keyword arguments.
            name: route name, stored as the handler's docstring.
            openapi_spec: specification used to build the handler's input.
            headers: extra headers sent with every upstream request.

        Returns:
            The handler wrapped with ``marshal``. The handler replies with
            the upstream body, status and (filtered) headers. If the upstream
            request raises, it replies with a 404 "Proxy not available"
            response instead.
        """
        excluded_headers = self._EXCLUDED_RESPONSE_HEADERS

        def proxy(self, *args, **kwargs):
            method = getattr(requests, method_name)
            full_url = url

            if kwargs:
                url_parts = urlparse(url)
                url_rule = Rule(url_parts.path)
                # Binding the rule to a map compiles it so that it can be built.
                Map([url_rule])
                # POST arguments are forwarded through ``params`` below, so
                # they must not also be appended to the URL as a query string.
                built_url = url_rule.build(
                    kwargs, append_unknown=(method_name != "post")
                )
                full_url = f"{url_parts.scheme}://{url_parts.netloc}{built_url[1]}"

            logger.debug("Proxying to %s", full_url)

            try:
                if method_name == "post":
                    logger.debug("Proxying post request with body: %s", kwargs)
                    response = method(full_url, params=kwargs, headers=headers)
                else:
                    response = method(full_url, headers=headers)
            except Exception as e:
                # str(e) also works for exceptions raised without arguments.
                reason = str(e)
                logger.error("Proxy '%s' failed: %s", full_url, reason)
                return make_response((f"Proxy not available: {reason}", 404))

            forwarded_headers = [
                (key, value)
                for key, value in response.headers.items()
                if key.lower() not in excluded_headers
            ]
            return make_response(
                (response.content, response.status_code, forwarded_headers)
            )

        proxy.__name__ = method_name
        proxy.__qualname__ = method_name
        proxy.__doc__ = name

        schema = self.create_params_from_schema(url, method_name, openapi_spec)
        inp = schema if isinstance(schema, SchemaMeta) else dict(schema)

        return marshal(inp=inp)(proxy)

    def setup(self, *args, **kwargs) -> None:
        """Register every proxy listed under ``proxies`` in the component config."""
        for proxy in self._config.get("proxies", []):
            self.setup_proxy(proxy)

    def setup_proxy(self, proxy):
        """Register the routes of one configured proxy entry.

        Reads ``openapi``, ``name`` and ``target`` from the entry. It also
        reads the optional ``routes`` list, whose items may set ``methods``
        (default ``["get"]``) and ``headers``. One resource class is
        registered per route, holding one handler per method.
        """
        openapi_spec = self._read_openapi_spec(proxy["openapi"])
        try:
            routes_description = proxy["routes"]
        except KeyError:  # Load routes from openapi file
            routes_description = self._get_openapi_routes(openapi_spec)

        routes = {}
        for route in routes_description:
            route_name = route["name"].lstrip("/")
            mapping = f"{proxy['name'].strip('/')}/{route_name}"

            for method in route.get("methods", ["get"]):
                routes.setdefault(mapping, {})[method] = self.create_proxy(
                    method,
                    f"{proxy['target'].rstrip('/')}/{route_name}",
                    mapping,
                    openapi_spec,
                    headers=route.get("headers"),
                )

        for path, methods in routes.items():
            methods_ = list(methods.values())
            proxied_route = type(
                f"{methods_[0].__doc__}Proxy",
                (ComponentResource,),
                methods,
            )
            self.register_route(proxied_route, f"/{path}")

    def _read_openapi_spec(self, url: str) -> dict:
        """
        Returns an OpenAPI spec structure from an URI.

        Arguments:
            url: Use `http[s]://` to read the specification from a web
                service, `file://` (or no scheme) to read it from the file
                system, or `res://` to read it from Daiquiri resources.

        Raises:
            RuntimeError: if a web service answers with a status other than 200.
        """
        url_parts = urlparse(url)
        if url_parts.scheme in ("", "file"):
            # JSON documents are UTF-8; do not depend on the locale encoding.
            with open(url_parts.netloc + url_parts.path, "r", encoding="utf-8") as f:
                return json.load(f)

        # Exact comparison: `scheme in ("res")` was a substring test.
        if url_parts.scheme == "res":
            provider = get_resource_provider()
            resource = provider.abs_resource(url_parts.netloc + url_parts.path)
            with provider.open_resource(resource, mode="t") as f:
                return json.load(f)

        response = requests.get(url, timeout=10)
        if response.status_code == 200:
            return response.json()
        else:
            raise RuntimeError(
                f"Could not load OpenAPI spec {url}: {response.status_code}"
            )

    def _get_openapi_routes(self, openapi_spec: dict) -> List[dict]:
        """List the routes (``name`` and ``methods``) described by an OpenAPI spec.

        Path templates such as ``/items/{id}`` are rewritten to werkzeug rule
        syntax (``/items/<type:id>``) using the declared parameter type. Only
        path-item keys present in ``METHODS`` are listed as methods.
        """
        routes = []
        for name, path_item in openapi_spec["paths"].items():
            # Operation-level parameters come first because they override
            # path-level ones (and the first replacement of a template wins).
            # Non-dict entries ("parameters", "summary", "servers", "$ref")
            # are not operations and are skipped.
            path_params = []
            for operation in path_item.values():
                if isinstance(operation, dict):
                    path_params.extend(operation.get("parameters", ()))
            path_params.extend(path_item.get("parameters", ()))

            # Path arguments conversion
            for param in path_params:
                if param.get("in") != "path":
                    continue
                # Swagger 2 declares "type" on the parameter, OpenAPI 3 under "schema".
                param_type = (
                    param["type"] if "type" in param else param["schema"]["type"]
                )
                # NOTE(review): the OpenAPI type name ("integer", "number", ...)
                # is used verbatim as converter name; werkzeug's built-in
                # converters are "int"/"float" — confirm a matching converter is
                # registered by the application.
                name = name.replace(
                    f"{{{param['name']}}}", f"<{param_type}:{param['name']}>"
                )

            methods = [m for m in path_item.keys() if m in METHODS]
            routes.append(dict(name=name, methods=methods))
        return routes