Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/schema/__init__.py: 90%
178 statements
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-15 02:12 +0000
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-15 02:12 +0000
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4from __future__ import annotations
5import logging
6import typing
8from marshmallow import Schema as MarshmallowSchema, fields, ValidationError, INCLUDE
9from marshmallow_jsonschema import JSONSchema
11from flask import g
13from daiquiri.core import (
14 CoreBase,
15 CoreResource,
16 marshal,
17 require_control,
18)
19from daiquiri.core.schema.metadata import paginated
20from daiquiri.core.schema.union import Union
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
# Failure payload of the form ``{"error": "<reason>"}``; the resources in this
# module return it with 404 / 400 status codes.
class ErrorSchema(MarshmallowSchema):
    error = fields.String()
# Success payload of the form ``{"message": "<text>"}``.
class MessageSchema(MarshmallowSchema):
    message = fields.String()
# Payload wrapping a free-form ``messages`` dict. It is not referenced in this
# module, so it is presumably used by callers elsewhere (e.g. to report
# validation errors) — confirm before changing its shape.
class ValidationErrorSchema(MarshmallowSchema):
    messages = fields.Dict()
# One row of the paginated schema listing: ``{"name": "<schema class name>"}``.
class SchemasListSchema(MarshmallowSchema):
    name = fields.String()
# Response shape for a single schema: the dumped JSON schema (``definitions``,
# ``$ref``, ``$schema``) plus the ui / submission hints the handler adds.
class SchemaSchema(MarshmallowSchema):
    definitions = fields.Dict()
    asyncValidate = fields.Boolean(
        metadata={"description": "Whether the schema can be asynchronously validated"}
    )
    cache = fields.Boolean(
        metadata={"description": "Whether values for this schema can be cached"}
    )
    method = fields.String(metadata={"description": "Http method for this schema"})
    url = fields.String(metadata={"description": "Url to submit this schema to"})
    uiorder = fields.List(fields.String(), metadata={"description": "Field order"})
    uischema = fields.Dict(
        metadata={"description": "Override the default field widgets"}
    )
    uigroups = fields.List(
        Union(fields=[fields.Dict(), fields.String()]),
        metadata={"description": "Field grouping"},
    )
    presets = fields.Dict(metadata={"description": "Presets for this schema"})
    save_presets = fields.Boolean(
        metadata={"description": "Whether this schema can save presets"}
    )
    auto_load_preset = fields.Boolean(
        metadata={"description": "Whether this schema will auto load the last preset"}
    )
    # Copied from the registered schema's ``exception`` / ``traceback``
    # attributes when it defines them
    exception = fields.String()
    traceback = fields.String()

    class Meta:
        # ``$ref`` / ``$schema`` are not valid Python identifiers, so they are
        # listed here instead of being declared as field attributes
        additional = ["$ref", "$schema"]
class SchemasResource(CoreResource):
    def get(self):
        """All schemas in a JSON API spec format"""
        # ``_parent`` is presumably the ``Schema`` handler that registered
        # this route; its ``schemas()`` maps schema name -> JSON schema.
        all_schemas = self._parent.schemas()
        return all_schemas, 200
class SchemaListResource(CoreResource):
    @marshal(
        out=[[200, paginated(SchemasListSchema), "A list of available schema names"]]
    )
    def get(self):
        """List of schema names"""
        # Everything is returned in one page: ``total`` is just the row count
        names = self._parent.list()
        rows = [dict(name=schema_name) for schema_name in names]
        return {"total": len(rows), "rows": rows}
class SchemaResource(CoreResource):
    @marshal(
        out=[
            [200, SchemaSchema(), "The requested schema"],
            [404, ErrorSchema(), "No such schema"],
        ]
    )
    def get(self, name):
        """Get a single schema in JSON API Spec format"""
        schema = self._parent.get(name)
        # Unknown (or hidden) schemas come back falsy from the handler
        if not schema:
            return {"error": "No such schema"}, 404
        return schema, 200
class SchemaValidateResource(CoreResource):
    @require_control
    @marshal(
        inp={"data": fields.Dict(required=True), "name": fields.Str()},
        out=[[404, ErrorSchema(), "No such schema"]],
    )
    def post(self, name, **kwargs):
        """Validate and compute calculated parameters for a schema"""
        outcome = self._parent.validate(name, kwargs["data"])
        # A falsy outcome means the handler did not find the schema
        return (outcome, 200) if outcome else ({"error": "No such schema"}, 404)
class SavePresetResource(CoreResource):
    @require_control
    @marshal(
        inp={
            "data": fields.Dict(required=True),
            "name": fields.Str(),
            "preset": fields.Str(),
        },
        out=[
            [200, MessageSchema(), "Schema preset saved"],
            [400, ErrorSchema(), "Could not save schema preset"],
        ],
    )
    def post(self, name, **kwargs):
        """Save a schema preset"""
        try:
            self._parent.save_preset(name, kwargs["preset"], kwargs["data"])
        except Exception as e:
            # Log before returning so the traceback is kept. ``.get`` because a
            # missing ``preset`` key can itself be the failure being handled,
            # and it must not raise a second KeyError inside this handler.
            logger.exception(
                "Could not save preset %s for %s", kwargs.get("preset"), name
            )
            return {"error": f"Could not save preset, {str(e)}"}, 400
        return {"message": f"Preset {kwargs['preset']} saved for {name}"}, 200
class Schema(CoreBase):
    """The schema handler

    Allows schemas to be registered, and retrieved via a flask resource
    """

    _require_session = True
    # Registry of schema instances keyed by class name.
    # NOTE(review): this is a class attribute, so it is shared by every
    # ``Schema`` instance; presumably intended (one handler per app) — confirm
    # before creating several handlers.
    _schemas = {}

    # ``Meta`` options copied from a nested schema onto its JSON definition
    _NESTED_META_KEYS = ("uiorder", "uischema", "presets", "cache", "uigroups")
    # ``Meta`` options copied from the top-level schema onto the JSON output
    _ROOT_META_KEYS = ("uiorder", "uischema", "uigroups", "presets", "cache")
    # Defining any of these marks a schema as supporting async validation.
    # NOTE(review): ``validate`` also runs a ``warnings`` hook, which is not
    # listed here — confirm whether that is intentional.
    _ASYNC_VALIDATE_METHODS = ("calculated", "schema_validate", "time_estimate")

    def setup(self):
        """Register the REST resources served by the schema handler"""
        self._schema = self
        self.register_route(SchemasResource, "")
        self.register_route(SchemaListResource, "/list")
        self.register_route(SchemaResource, "/<string:name>")
        self.register_route(SchemaValidateResource, "/validate/<string:name>")
        self.register_route(SavePresetResource, "/preset/<string:name>")

    def set_session(self, session):
        """Store the session object on the handler"""
        self._session = session

    def register(self, schema, url=None, method=None):
        """Register a schema under its class name

        Only the first instance registered for a given class name is kept;
        later registrations of the same class are ignored.

        Args:
            schema: The marshmallow schema instance to register
            url (str): Optional url the schema is submitted to
            method (str): Optional http method for ``url``; both ``url`` and
                ``method`` must be truthy for them to be recorded
        """
        registered = self._schemas.setdefault(schema.__class__.__name__, schema)
        if url and method:
            # Attached to whichever instance is registered under this name
            registered.url = {"url": url, "method": method}

    def schemas(self):
        """Get all schemas visible to the current user, keyed by name"""
        return {key: self.get(key) for key in self._filter_schemas()}

    def list(self):
        """Get the names of registered schemas visible to the current user"""
        return self._filter_schemas()

    def _filter_schemas(self):
        """Return the registered schema names the current user may see

        A schema that has a ``_require_staff`` attribute (whatever its value)
        is only included when ``g.user.staff()`` is truthy.
        """
        # ``g.user`` is only consulted for staff-only schemas
        return [
            key
            for key, schema in self._schemas.items()
            if not hasattr(schema, "_require_staff") or g.user.staff()
        ]

    def iterate_schema(self, flds, root):
        """Attach ui ``Meta`` options of nested schemas to their JSON definitions

        Handles both ``fields.Nested`` and ``fields.List(fields.Nested(...))``
        and recurses into each nested schema's declared fields.

        Args:
            flds (dict): Mapping of field name to marshmallow field
            root (dict): Definitions dict keyed by schema class name; it is
                updated in place
        """
        for field in flds.values():
            # A list of nested schemas is treated like a single nested schema
            target = field.inner if isinstance(field, fields.List) else field
            if isinstance(target, fields.Nested):
                self._apply_nested_meta(target.nested, root)

    def _apply_nested_meta(self, nested, root):
        """Copy ``nested.Meta`` ui options into ``root[nested.__name__]`` and recurse"""
        name = nested.__name__
        if name not in root:
            return

        definition = root[name]
        meta = nested.Meta
        if meta:
            for key in self._NESTED_META_KEYS:
                if hasattr(meta, key):
                    definition[key] = getattr(meta, key)

        # NOTE(review): recursion descends into this schema's own definition
        # dict rather than ``root``. If the dumped ``definitions`` are flat,
        # deeper nested schemas are never matched — confirm against the
        # marshmallow_jsonschema output.
        self.iterate_schema(nested._declared_fields, definition)

    def get(self, name):
        """Get a specific schema as a JSON schema dict

        Args:
            name (str): The schema to retrieve

        Returns:
            dict | None: The JSON schema with daiquiri ui / submission hints,
            or None if ``name`` is not registered or not visible to the
            current user
        """
        if name not in self.list():
            return None

        schema = self._schemas[name]
        output = JSONSchema().dump(schema)

        if hasattr(schema, "url"):
            output["url"] = schema.url["url"]
            output["method"] = schema.url["method"]

        if hasattr(schema, "exception"):
            output["exception"] = schema.exception
            output["traceback"] = schema.traceback

        # Try to attach ui:schema,order to nested schemas
        self.iterate_schema(schema.fields, output["definitions"])

        if schema.Meta:
            for key in self._ROOT_META_KEYS:
                if hasattr(schema.Meta, key):
                    output[key] = getattr(schema.Meta, key)

        # Dynamic presets take precedence over any ``Meta.presets``
        if hasattr(schema, "get_presets"):
            output["presets"] = schema.get_presets()

        output["save_presets"] = hasattr(schema, "save_preset")
        output["asyncValidate"] = any(
            hasattr(schema, method) for method in self._ASYNC_VALIDATE_METHODS
        )
        return output

    def validate(
        self, name: str, data: dict[str, typing.Any]
    ) -> dict[str, typing.Any] | None:
        """Async Schema Validation

        This allows for schemas to make asynchronous validation, i.e. to check
        combinations of beamline parameters.

        It will also compute any calculated parameters if the schema defines them
        and also a time estimate if defined.

        NOTE(review): unlike ``get``, this does not apply the staff filter of
        ``_filter_schemas`` — confirm that is intended.

        Args:
            name: Schema name
            data: Kwargs to validate; updated in place with any calculated
                parameters

        Returns:
            None if the schema is not registered, otherwise a dict with:
                errors (dict): Dictionary of errors
                warnings (dict): Dictionary of warnings
                calculated (dict): Dictionary of calculated params
                time_estimate (float): Time estimate for these params
        """
        sch = self._schemas.get(name)
        if sch is None:
            logger.error(f"Schema name {name} does not exist")
            return None

        debug = getattr(sch, "DEBUG", False)
        validated = {
            "errors": {},
            "warnings": {},
            "calculated": {},
            "time_estimate": None,
        }

        def validate_func(func_name: str):
            """Run ``sch.<func_name>(data)`` if defined, storing a non-None result

            Exceptions are logged (error level when the schema sets ``DEBUG``,
            debug level otherwise) and swallowed; the remaining hooks still run.
            """
            validator = getattr(sch, func_name, None)
            if validator is None:
                return
            if debug:
                logger.info(f"Execute {name}.{func_name}")
            try:
                result = validator(data)
                if debug:
                    # No exception is being handled here, so no ``exc_info``
                    logger.info(f"Result: {result}")
                if result is not None:
                    validated[func_name] = result
            except Exception as e:
                log = logger.error if debug else logger.debug
                log(f"{name}.{func_name} failed with: {e}", exc_info=True)

        # Calculated values are merged into ``data`` (in place) so that the
        # load below and the later hooks see them
        validate_func("calculated")
        data.update(validated["calculated"])

        try:
            sch.load(data, unknown=INCLUDE)
        except ValidationError as err:
            if debug:
                logger.error(f"{name} validation failed with: {err}", exc_info=True)
            # Only schema-level errors are reported; per-field messages are dropped
            if "_schema" in err.messages:
                validated["errors"] = {"schema": err.messages["_schema"]}

        validate_func("warnings")
        validate_func("time_estimate")

        return validated

    def save_preset(self, name, preset, data):
        """Save a preset for a schema

        Args:
            name (str): The schema name
            preset (str): The preset name
            data (dict): The preset data

        Returns:
            bool: True if the schema exists and implements ``save_preset``
            (and it did not raise), False otherwise
        """
        sch = self._schemas.get(name)
        if sch is None or not hasattr(sch, "save_preset"):
            return False

        sch.save_preset(preset, data)
        return True