Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/schema/__init__.py: 90%

178 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-15 02:12 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3 

4from __future__ import annotations 

5import logging 

6import typing 

7 

8from marshmallow import Schema as MarshmallowSchema, fields, ValidationError, INCLUDE 

9from marshmallow_jsonschema import JSONSchema 

10 

11from flask import g 

12 

13from daiquiri.core import ( 

14 CoreBase, 

15 CoreResource, 

16 marshal, 

17 require_control, 

18) 

19from daiquiri.core.schema.metadata import paginated 

20from daiquiri.core.schema.union import Union 

21 

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

23 

24 

class ErrorSchema(MarshmallowSchema):
    """Response body carrying a single human readable ``error`` string."""

    error = fields.String()

27 

28 

class MessageSchema(MarshmallowSchema):
    """Response body carrying a single informational ``message`` string."""

    message = fields.String()

31 

32 

class ValidationErrorSchema(MarshmallowSchema):
    """Response body carrying a free-form ``messages`` mapping of validation errors."""

    # Arbitrary mapping; keys/values are not constrained by this schema.
    messages = fields.Dict()

35 

36 

class SchemasListSchema(MarshmallowSchema):
    """One row of the schema listing: just the registered schema ``name``."""

    name = fields.String()

39 

40 

class SchemaSchema(MarshmallowSchema):
    """Serialisation of a single schema as returned by ``Schema.get``.

    It carries the ``JSONSchema`` dump of a registered marshmallow schema
    together with the extra UI / behaviour keys that ``Schema.get`` adds.
    Declaration order is kept stable on purpose.
    """

    # JSON schema definitions as produced by marshmallow_jsonschema
    definitions = fields.Dict()

    # Behaviour flags
    asyncValidate = fields.Boolean(
        metadata={"description": "Whether the schema can be asynchronously validated"}
    )
    cache = fields.Boolean(
        metadata={"description": "Whether values for this schema can be cached"}
    )

    # Submission target
    method = fields.String(metadata={"description": "Http method for this schema"})
    url = fields.String(metadata={"description": "Url to submit this schema to"})

    # UI hints
    uiorder = fields.List(fields.String(), metadata={"description": "Field order"})
    uischema = fields.Dict(
        metadata={"description": "Override the default field widgets"}
    )
    uigroups = fields.List(
        Union(fields=[fields.Dict(), fields.String()]),
        metadata={"description": "Field grouping"},
    )

    # Presets
    presets = fields.Dict(metadata={"description": "Presets for this schema"})
    save_presets = fields.Boolean(
        metadata={"description": "Whether this schema can save presets"}
    )
    auto_load_preset = fields.Boolean(
        metadata={"description": "Whether this schema will auto load the last preset"}
    )

    # Copied from the registered schema's `exception` / `traceback`
    # attributes when it has them
    exception = fields.String()
    traceback = fields.String()

    class Meta:
        # Also emit these keys of the dumped object in addition to the
        # declared fields
        additional = ["$ref", "$schema"]

71 

72 

class SchemasResource(CoreResource):
    """Resource returning every schema visible to the current user."""

    def get(self):
        """All schemas in a JSON API spec format"""
        # Mapping of schema name -> rendered schema, built by the parent handler
        everything = self._parent.schemas()
        return everything, 200

77 

78 

class SchemaListResource(CoreResource):
    """Resource returning the names of the schemas visible to the current user."""

    @marshal(
        out=[[200, paginated(SchemasListSchema), "A list of available schema names"]]
    )
    def get(self):
        """List of schema names"""
        rows = []
        for schema_name in self._parent.list():
            rows.append({"name": schema_name})
        # Shaped for the paginated() wrapper: total count plus the rows
        return {"total": len(rows), "rows": rows}

87 

88 

class SchemaResource(CoreResource):
    """Resource returning one rendered schema by name."""

    @marshal(
        out=[
            [200, SchemaSchema(), "The requested schema"],
            [404, ErrorSchema(), "No such schema"],
        ]
    )
    def get(self, name):
        """Get a single schema in JSON API Spec format"""
        rendered = self._parent.get(name)
        # Unknown or hidden schemas come back falsy from the parent handler
        if not rendered:
            return {"error": "No such schema"}, 404
        return rendered, 200

103 

104 

class SchemaValidateResource(CoreResource):
    """Resource running the parent's asynchronous validation for a schema.

    Guarded by ``require_control``.
    """

    @require_control
    @marshal(
        inp={"data": fields.Dict(required=True), "name": fields.Str()},
        out=[[404, ErrorSchema(), "No such schema"]],
    )
    def post(self, name, **kwargs):
        """Validate and compute calculated parameters for a schema"""
        payload = kwargs["data"]
        outcome = self._parent.validate(name, payload)
        # The parent returns None for an unknown schema name
        if not outcome:
            return {"error": "No such schema"}, 404
        return outcome, 200

118 

119 

class SavePresetResource(CoreResource):
    """Resource saving a named preset for a schema.

    Guarded by ``require_control``.
    """

    @require_control
    @marshal(
        inp={
            "data": fields.Dict(required=True),
            "name": fields.Str(),
            "preset": fields.Str(),
        },
        out=[
            [200, MessageSchema(), "Schema preset saved"],
            [400, ErrorSchema(), "Could not save schema preset"],
        ],
    )
    def post(self, name, **kwargs):
        """Save a schema preset"""
        # "preset" is not a required input, so use .get() for logging; the
        # save call below still indexes it so a missing preset yields the
        # same 400 response as before.
        preset = kwargs.get("preset")
        try:
            self._parent.save_preset(name, kwargs["preset"], kwargs["data"])
        except Exception as e:
            # Log before returning: previously this call sat after `return`
            # and was unreachable, so failures were never logged.
            logger.exception(f"Could not save preset {preset} for {name}")
            return {"error": f"Could not save preset, {str(e)}"}, 400
        return {"message": f"Preset {kwargs['preset']} saved for {name}"}, 200

141 

142 

class Schema(CoreBase):
    """The schema handler

    Allows schemas to be registered, and retrieved via a flask resource.
    Registered marshmallow schemas are rendered with ``JSONSchema`` and
    decorated with UI hints read from each schema's ``Meta`` class.
    """

    _require_session = True
    # NOTE: class attribute that is mutated in place, so this registry is
    # shared by every Schema instance (and subclass) in the process.
    _schemas = {}

    # Meta attributes copied onto the rendered JSON for the UI.
    _UI_META_KEYS = ("uiorder", "uischema", "uigroups", "presets", "cache")

    def setup(self):
        """Register the flask resources exposing the schema handler."""
        self._schema = self
        self.register_route(SchemasResource, "")
        self.register_route(SchemaListResource, "/list")
        self.register_route(SchemaResource, "/<string:name>")
        self.register_route(SchemaValidateResource, "/validate/<string:name>")
        self.register_route(SavePresetResource, "/preset/<string:name>")

    def set_session(self, session):
        self._session = session

    def register(self, schema, url=None, method=None):
        """Register a schema instance under its class name

        Only the first instance registered for a given class name is kept.
        When both ``url`` and ``method`` are given they are attached to the
        registered instance (even if it was registered earlier).

        Args:
            schema: A marshmallow schema instance
            url (str): Optional url the schema is submitted to
            method (str): Optional http method used with ``url``
        """
        cls = schema.__class__.__name__
        registered = self._schemas.setdefault(cls, schema)

        if url and method:
            registered.url = {"url": url, "method": method}

    def schemas(self):
        """Get all schemas visible to the current user, keyed by name"""
        return {key: self.get(key) for key in self._filter_schemas()}

    def list(self):
        """Get a list of registered schemas visible to the current user"""
        return self._filter_schemas()

    def _filter_schemas(self):
        """Return registered schema names the current user may see.

        Schemas that define a ``_require_staff`` attribute (whatever its
        value) are only included when ``g.user.staff()`` is truthy.
        """
        return [
            key
            for key, schema in self._schemas.items()
            if not hasattr(schema, "_require_staff") or g.user.staff()
        ]

    def iterate_schema(self, flds, root, _seen=None):
        """Attach UI hints from nested schemas' ``Meta`` to their definitions

        For every ``Nested`` field (or ``List`` of ``Nested``) in ``flds``, the
        attributes listed in ``_UI_META_KEYS`` found on the nested schema's
        ``Meta`` are copied onto ``root[<nested class name>]``, and the nested
        schema's own fields are scanned recursively.

        Args:
            flds (dict): Mapping of field name -> marshmallow field
            root (dict): The flat ``definitions`` dict from ``JSONSchema.dump``
            _seen (set): Internal; class names already processed, so that
                self-referencing / cyclic schemas do not recurse forever
        """
        if _seen is None:
            _seen = set()

        for field in flds.values():
            if isinstance(field, fields.List):
                field = field.inner
            if not isinstance(field, fields.Nested):
                continue

            nested = field.nested
            # Nested may be given as a schema instance; definitions are keyed
            # by class name, so work with the class.
            if isinstance(nested, MarshmallowSchema):
                nested = nested.__class__
            # Strings ("self", registry names) and callables have no matching
            # definition name; skip them rather than crash.
            name = getattr(nested, "__name__", None)
            if name is None or name not in root or name in _seen:
                continue
            _seen.add(name)

            definition = root[name]
            meta = getattr(nested, "Meta", None)
            if meta:
                for key in self._UI_META_KEYS:
                    if hasattr(meta, key):
                        definition[key] = getattr(meta, key)

            # Definitions are flat, so keep looking deeper nested schemas up in
            # the same top-level `root` (previously the per-definition dict was
            # passed, so deeper nesting never matched).
            nested_fields = getattr(nested, "_declared_fields", None)
            if nested_fields:
                self.iterate_schema(nested_fields, root, _seen)

    def get(self, name):
        """Get a specific schema

        Args:
            name (str): The schema to retrieve

        Returns:
            dict | None: The ``JSONSchema`` dump of the schema extended with
            url/method, exception/traceback, ``Meta`` UI hints, presets and the
            ``save_presets`` / ``asyncValidate`` flags. ``None`` when ``name``
            is not registered or not visible to the current user.
        """
        if name not in self.list():
            return None

        schema = self._schemas[name]
        rendered = JSONSchema().dump(schema)

        if hasattr(schema, "url"):
            rendered["url"] = schema.url["url"]
            rendered["method"] = schema.url["method"]

        if hasattr(schema, "exception"):
            rendered["exception"] = schema.exception
            # Don't fail if `traceback` was not set alongside `exception`
            rendered["traceback"] = getattr(schema, "traceback", None)

        # Try to attach ui:schema,order to nested schemas
        self.iterate_schema(schema.fields, rendered["definitions"])

        if schema.Meta:
            for key in self._UI_META_KEYS:
                if hasattr(schema.Meta, key):
                    rendered[key] = getattr(schema.Meta, key)

        # Dynamic presets override any static Meta.presets
        if hasattr(schema, "get_presets"):
            rendered["presets"] = schema.get_presets()

        rendered["save_presets"] = hasattr(schema, "save_preset")

        # NOTE(review): validate() also runs a "warnings" hook that is not
        # considered here; confirm whether it should enable asyncValidate.
        rendered["asyncValidate"] = any(
            hasattr(schema, hook)
            for hook in ("calculated", "schema_validate", "time_estimate")
        )

        return rendered

    def validate(self, name: str, data: dict[str, typing.Any]):
        """Async Schema Validation

        This allows for schemas to make asynchronous validation, i.e. to check
        combinations of beamline parameters.

        It will also compute any calculated parameters if the schema defines
        them and also a time estimate if defined.

        Note that ``data`` is updated in place with the calculated values
        before the schema is loaded and the remaining hooks are run.

        Args:
            name: Schema name
            data: Kwargs to validate

        Returns:
            dict | None: ``None`` if ``name`` is not registered, otherwise a
            dict with:
                errors (dict): Schema-level (``_schema``) errors only
                warnings (dict): Result of the schema's ``warnings`` hook
                calculated (dict): Result of the schema's ``calculated`` hook
                time_estimate (float): Result of the ``time_estimate`` hook
        """
        # NOTE(review): looks the schema up directly, so the staff-only
        # filtering applied by list()/get() is not applied here.
        sch = self._schemas.get(name)
        if sch is None:
            logger.error(f"Schema name {name} does not exist")
            return None

        debug = getattr(sch, "DEBUG", False)
        validated = {
            "errors": {},
            "warnings": {},
            "calculated": {},
            "time_estimate": None,
        }

        def run_hook(func_name: str):
            """Call ``sch.<func_name>(data)`` if defined, storing a non-None result.

            Hook failures are logged (error when DEBUG, debug otherwise) and
            otherwise ignored, so a broken hook never breaks validation.
            """
            hook = getattr(sch, func_name, None)
            if hook is None:
                return
            if debug:
                logger.info(f"Execute {name}.{func_name}")
            try:
                result = hook(data)
            except Exception as e:
                log = logger.error if debug else logger.debug
                log(f"{name}.{func_name} failed with: {e}", exc_info=True)
                return
            if debug:
                # No exc_info here: there is no active exception to attach
                logger.info(f"Result: {result}")
            if result is not None:
                validated[func_name] = result

        run_hook("calculated")
        # Make calculated values visible to load() and the later hooks
        data.update(validated["calculated"])

        try:
            sch.load(data, unknown=INCLUDE)
        except ValidationError as err:
            if debug:
                logger.error(f"{name} validation failed with: {err}", exc_info=True)
            # Only schema-level errors are reported; per-field errors are dropped
            if "_schema" in err.messages:
                validated["errors"] = {"schema": err.messages["_schema"]}

        run_hook("warnings")
        run_hook("time_estimate")

        return validated

    def save_preset(self, name, preset, data):
        """Save a preset for a schema

        Args:
            name (str): The schema name
            preset (str): The preset name
            data (dict): The preset data

        Returns:
            bool: True if the schema exists and its ``save_preset`` was
            called, False if the schema is unknown or cannot save presets
        """
        if name not in self._schemas:
            return False

        sch = self._schemas[name]
        if not hasattr(sch, "save_preset"):
            return False

        sch.save_preset(preset, data)
        return True