Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/cli/server.py: 49%

220 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-15 02:12 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3try: 

4 import bliss # noqa F401 

5 

6 # Importing bliss patches gevent which in turn patches python 

7except ImportError: 

8 from gevent.monkey import patch_all 

9 

10 patch_all(thread=False) 

11 

12# TODO: Horrible hack to avoid: 

13# ImportError: dlopen: cannot load any more object with static TLS 

14# https://github.com/pytorch/pytorch/issues/2575 

15# https://github.com/scikit-learn/scikit-learn/issues/14485 

16from silx.math import colormap # noqa F401 

17 

18import time 

19import os 

20import json 

21import argparse 

22import pydantic 

23from ruamel.yaml import YAML 

24from flask import Flask, jsonify 

25from flask_socketio import SocketIO 

26from flask_apispec import FlaskApiSpec 

27from flask_restful import abort 

28from flask import send_from_directory 

29from webargs.flaskparser import parser 

30from apispec import APISpec 

31from apispec.ext.marshmallow import MarshmallowPlugin 

32 

33import daiquiri 

34from daiquiri.core.authenticator import Authenticator 

35from daiquiri.core.session import Session 

36from daiquiri.core.queue import Queue 

37from daiquiri.core.metadata import MetaData 

38from daiquiri.core.components import Components 

39from daiquiri.core.hardware import Hardware 

40from daiquiri.core.schema import Schema 

41from daiquiri.core.layout import Layout 

42from daiquiri.core.saving import Saving 

43from daiquiri.core.stomp import Stomp 

44from daiquiri.resources import utils 

45from daiquiri.core.logging import log 

46from daiquiri.core.responses import nocache 

47from daiquiri.core.options import ServerOptions 

48 

49import logging 

50 

# Logger named after this module
logger = logging.getLogger(name=__name__)

# Directory containing the installed ``daiquiri`` package
DAIQUIRI_ROOT = os.path.split(daiquiri.__file__)[0]

53 

54 

def init_server(
    options: ServerOptions,
    testing: bool = False,
):
    """Instantiate a Flask application

    Registers the resource folders, resolves the static folder, starts
    logging, configures Flask / the API spec / Socket.IO, then instantiates
    the core daiquiri components (session, authenticator, hardware, queue,
    layout, metadata, saving, stomp, components) and the static file routes.

    When ``options.save_spec_file`` is set, the API spec is written to that
    file and the process exits instead of returning.

    :param options: Options used to setup the server
    :param testing: Attaches core components to the `app` so that they can be
        retrieved in tests
    :returns: Flask, SocketIO
    :raises ValueError: if a resource folder, the static folder/resource or
        the hardware folder does not exist
    """
    start = time.time()

    # NOTE(review): iterating in reverse presumably lets the first listed
    # folder take priority over later ones in the resource provider — confirm
    for resource_folder in reversed(options.resource_folders):
        if not resource_folder:
            continue
        if not os.path.isdir(resource_folder):
            raise ValueError(f"Resource folder '{resource_folder}' does not exist")
        utils.add_resource_root(resource_folder)
        logger.info("Added resource folder: %s", resource_folder)

    # A static folder named "static.<name>" is resolved through the resource
    # provider; any other value is used as a plain directory path.
    static_folder = options.static_folder
    if static_folder.startswith("static."):
        provider = utils.get_resource_provider()
        try:
            static_folder = provider.get_resource_path(static_folder, "")
        except utils.ResourceNotAvailable:
            raise ValueError(
                f"Static resource '{static_folder}' does not exist"
            ) from None
    if not os.path.isdir(static_folder):
        raise ValueError(f"Static folder '{static_folder}' does not exist")

    if options.hardware_folder:
        if not os.path.isdir(options.hardware_folder):
            raise ValueError(
                f"Hardware folder '{options.hardware_folder}' does not exist"
            )
        # NOTE(review): presumably read by the hardware configuration loader
        os.environ["HWR_ROOT"] = options.hardware_folder

    config = utils.ConfigDict("app.yml")

    if not config.get("versions"):
        config["versions"] = []

    # Allow CLI to override implementors
    if options.implementors:
        config["implementors"] = options.implementors

    log.start(config=config)
    logger.info("Starting daiquiri version: %s", daiquiri.__version__)

    app = Flask(__name__, static_folder=static_folder, static_url_path="/")

    app.config["APISPEC_FORMAT_RESPONSE"] = None
    app.config["APISPEC_TITLE"] = "daiquiri"

    security_definitions = {
        "bearer": {"type": "apiKey", "in": "header", "name": "Authorization"}
    }

    app.config["APISPEC_SPEC"] = APISpec(
        title="daiquiri",
        version="v1",
        openapi_version="2.0",
        plugins=[MarshmallowPlugin()],
        securityDefinitions=security_definitions,
    )

    # The swagger JSON / UI endpoints are only exposed when enabled in app.yml
    if not config.get("swagger"):
        app.config["APISPEC_SWAGGER_URL"] = None
        app.config["APISPEC_SWAGGER_UI_URL"] = None

    docs = FlaskApiSpec(app)

    sio_kws = {}
    if config["cors"]:
        from flask_cors import CORS

        sio_kws["cors_allowed_origins"] = "*"
        CORS(app)

    app.config["SECRET_KEY"] = config["iosecret"]
    socketio = SocketIO(app, **sio_kws)

    log.init_sio(socketio)

    @app.errorhandler(404)
    def page_not_found(e):
        return jsonify(error=str(e)), 404

    @app.errorhandler(405)
    def method_not_allowed(e):
        return jsonify(message="The method is not allowed for the requested URL."), 405

    # Turn webargs parsing/validation failures into a 422 response carrying
    # the error messages
    @parser.error_handler
    def handle_request_parsing_error(
        err, req, schema, error_status_code, error_headers
    ):
        abort(422, description=err.messages)

    # In debug mode the Werkzeug reloader runs this code in a watcher process
    # too; it sets WERKZEUG_RUN_MAIN="true" only in the serving child, which is
    # the only process where the core components need to be built.
    if not app.debug or os.environ.get("WERKZEUG_RUN_MAIN") == "true":
        schema = Schema(app=app, docs=docs, socketio=socketio)
        ses = Session(
            config=config, app=app, docs=docs, socketio=socketio, schema=schema
        )
        schema.set_session(ses)

        Authenticator(config=config, app=app, session=ses, docs=docs, schema=schema)

        # TODO
        # This is BLISS specific, move to where bliss is handled ?
        if config.get("controls_session_type", None) == "bliss":
            # This is not very elegant but a solution until this section
            # have been moved to an appropriate place
            from daiquiri.core.hardware.bliss.session import BlissSession

            BlissSession(config["controls_session_name"])

        hardware = Hardware(
            base_config=config, app=app, socketio=socketio, docs=docs, schema=schema
        )

        queue = Queue(
            config=config,
            app=app,
            session=ses,
            docs=docs,
            socketio=socketio,
            schema=schema,
        )

        Layout(
            config, app=app, session=ses, docs=docs, socketio=socketio, schema=schema
        )

        metadata = MetaData(
            config, app=app, session=ses, docs=docs, socketio=socketio, schema=schema
        ).init()
        ses.set_metadata(metadata)

        saving = Saving(
            config,
            app=app,
            session=ses,
            docs=docs,
            socketio=socketio,
            schema=schema,
            metadata=metadata,
        ).init()

        stomp = None
        if config.get("stomp_host"):
            stomp = Stomp(
                config=config,
                app=app,
                session=ses,
                docs=docs,
                socketio=socketio,
                schema=schema,
            )

        components = Components(
            base_config=config,
            app=app,
            socketio=socketio,
            docs=docs,
            schema=schema,
            hardware=hardware,
            session=ses,
            metadata=metadata,
            queue=queue,
            saving=saving,
            stomp=stomp,
        )

        if options.save_spec_file is not None:
            logger.info("Writing API spec and exiting")
            save_spec_dir = os.path.dirname(options.save_spec_file)
            # dirname is "" for a bare file name (os.mkdir("") fails), and the
            # parent directories may be missing too: only create a directory
            # tree when there is one, and create it recursively.
            if save_spec_dir:
                os.makedirs(save_spec_dir, exist_ok=True)
            with open(options.save_spec_file, "w") as spec:
                json.dump(docs.spec.to_dict(), spec)

            # The `exit` builtin is injected by the `site` module and is not
            # guaranteed to exist; raise SystemExit directly instead.
            raise SystemExit()

        if config["debug"] is True:
            app.debug = True

        if testing:
            app.hardware = hardware
            app.queue = queue
            app.metadata = metadata
            app.components = components
            app.session = ses
            app.socketio = socketio
            app.saving = saving

            def close():
                components.close()

            app.close = close

    @app.route("/manifest.json")
    def manifest():
        return app.send_static_file("manifest.json")

    @app.route("/meta.json")
    @nocache
    def meta():
        return app.send_static_file("meta.json")

    @app.route("/favicon.ico")
    def favicon():
        return app.send_static_file("favicon.ico")

    if options.static_resources_folder:
        default_static_resources = os.path.join(static_folder, "resources")
        static_resources_folder = os.path.abspath(options.static_resources_folder)

        @app.route("/resources/<path:path>")
        def resources(path):
            """Serve a resource file from static_resources_folder if it exists.

            Else fall back to the original static directory.
            """
            if os.path.isfile(os.path.join(static_resources_folder, path)):
                return send_from_directory(static_resources_folder, path)
            return send_from_directory(default_static_resources, path)

    # Every other path serves index.html (uncached)
    @app.route("/", defaults={"path": ""})
    @app.route("/<string:path>")
    @app.route("/<path:path>")
    @nocache
    def index(path):
        return app.send_static_file("index.html")

    took = round(time.time() - start, 2)
    logger.info("Server ready, startup took %ss", took, extra={"startup_time": took})

    return app, socketio

297 

298 

def get_certs_file_path(resource):
    """Resolve a certificate reference to a file path.

    Accepted forms:
      - ``foobar.txt``      -> looked up in the "certs" resources of the
                               resource provider
      - ``/etc/foobar.txt`` -> absolute path, returned normalized

    :param resource: Resource name or absolute file path
    :returns: The path of the file
    """
    is_absolute = resource.startswith("/")
    if not is_absolute:
        return utils.get_resource_provider().get_resource_path("certs", resource)
    return os.path.abspath(resource)

310 

311 

def get_ssl_context(options: ServerOptions):
    """Build an ssl context for serving over HTTPS.

    The options get the priority over the ``app.yml`` configuration
    (``ssl``, ``ssl_cert`` and ``ssl_key`` keys).

    The configurations should be dropped at some point for security.

    :param options: Server options; ``ssl``, ``ssl_cert`` and ``ssl_key`` are read
    :returns: An ``ssl.SSLContext`` with the certificate chain loaded, or
        ``None`` when SSL is not enabled
    :raises Exception: whatever ``load_cert_chain`` raises, re-raised after
        being logged
    """
    config = utils.ConfigDict("app.yml")

    use_ssl = options.ssl or config.get("ssl", False)
    if not use_ssl:
        return None

    import ssl

    crt = get_certs_file_path(options.ssl_cert or config["ssl_cert"])
    key = get_certs_file_path(options.ssl_key or config["ssl_key"])

    # PROTOCOL_TLSv1_2 is deprecated since Python 3.10 and restricts the
    # server to TLS 1.2 only. Negotiate the best server-side protocol instead
    # (allowing TLS 1.3) while still refusing anything older than TLS 1.2.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.minimum_version = ssl.TLSVersion.TLSv1_2
    try:
        context.load_cert_chain(crt, key)
    except Exception:
        # logger.exception already attaches the traceback
        logger.exception("Could not load certificate chain")
        raise

    return context

338 

339 

def run_server(options: ServerOptions):
    """Run the REST server until it is stopped.

    The application is built with `init_server`; HTTPS is enabled when
    `get_ssl_context` returns a context. The server listens on all
    interfaces, on ``options.port``.

    :param options: Server options, also forwarded to `init_server` and
        `get_ssl_context`
    """
    app, socketio = init_server(options)

    ssl_context = get_ssl_context(options)
    run_kwargs = {"ssl_context": ssl_context} if ssl_context else {}

    # nosec: binding to every interface is deliberate
    socketio.run(app, host="0.0.0.0", port=options.port, **run_kwargs)  # nosec

354 

355 

def add_pydantic_model(
    parser: argparse.ArgumentParser, model: "type[pydantic.BaseModel]"
):
    """Add Pydantic model to an ArgumentParser

    Every field of the model becomes a ``--field-name`` option (underscores
    replaced by dashes). The field's ``json_schema_extra`` dict can tune it:

    - ``argparse_flag``: an additional flag (e.g. ``-p``)
    - ``argparse_name``: replaces the default ``--field-name`` option name
    - ``argparse_nargs``: forwarded as argparse ``nargs``
    - ``argparse_type``: forwarded as argparse ``type`` (defaults to the
      field annotation)

    All options default to ``None`` so that unset options can be filtered
    out and the model's own defaults apply later.

    :param parser: Parser to extend
    :param model: Pydantic model class (not an instance)
    """
    for name, field in model.model_fields.items():
        extra = field.json_schema_extra
        # json_schema_extra may be None, or a callable in pydantic v2;
        # only a dict carries argparse settings.
        if not isinstance(extra, dict):
            extra = {}

        # The tested keys must be the ones read: testing "flag"/"argname"
        # while reading "argparse_*" ignored the settings or raised KeyError.
        name_or_flags = []
        if "argparse_flag" in extra:
            name_or_flags.append(extra["argparse_flag"])
        if "argparse_name" in extra:
            name_or_flags.append(extra["argparse_name"])
        else:
            name_or_flags.append(f"--{name.replace('_', '-')}")

        parser.add_argument(
            *name_or_flags,
            dest=name,
            nargs=extra.get("argparse_nargs"),
            type=extra.get("argparse_type") or field.annotation,
            default=None,  # let pydantic handle the default later
            help=field.description,
        )

377 

378 

def parse_options() -> ServerOptions:
    """Parse the server options.

    Read the command line arguments and the optional YAML configuration file
    given with ``--config``. Command line arguments which are set override
    the values read from the configuration file.

    :returns: The validated server options
    """
    parser = argparse.ArgumentParser(description="REST server for a beamline GUI")
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default=None,
        help="File/resource which is used to setup the server (instead of the command line argument)",
    )
    add_pydantic_model(parser, ServerOptions)
    cmd_args = parser.parse_args()

    config_args = {}
    if cmd_args.config:
        logger.info("Read config file: %s", cmd_args.config)
        yaml = YAML(typ="safe")
        with open(cmd_args.config, mode="rt", encoding="utf-8") as f:
            # An empty YAML file loads as None
            config_args = yaml.load(f) or {}

    merged_args = {}
    merged_args.update(config_args)
    # Only override command line arguments which are set
    merged_args.update({k: v for k, v in vars(cmd_args).items() if v is not None})

    options = ServerOptions(**merged_args)

    # Flask resolves a relative static folder against the application root,
    # so make it absolute (relative to the cwd). Resource names ("static.*")
    # are resolved later by init_server and must be kept untouched, otherwise
    # its resource lookup can never match.
    static_folder = options.static_folder
    if static_folder is not None and not static_folder.startswith("static."):
        options.static_folder = os.path.abspath(static_folder)

    return options

414 

415 

def main():
    """Entry point: build the options from the CLI/config file and serve."""
    run_server(parse_options())


if __name__ == "__main__":
    main()