Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/components/celery.py: 61%

256 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-14 02:13 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3from datetime import datetime 

4import logging 

5import os 

6import time 

7from typing import List, Dict, Union, Any 

8import uuid 

9import _queue 

10 

11import gevent 

12 

13from daiquiri.core import marshal 

14from daiquiri.core.components import ( 

15 actor, 

16 Component, 

17 ComponentResource, 

18 ComponentActor, 

19 ComponentActorSchema, 

20) 

21from daiquiri.core.schema import ErrorSchema, MessageSchema 

22from daiquiri.core.utils import get_start_end 

23from daiquiri.core.schema.celery import ( 

24 CeleryConfigSchema, 

25 TaskConfigSchema, 

26 WorkersSchema, 

27 ExecutedTasksSchema, 

28) 

29from daiquiri.core.schema.metadata import paginated 

30 

31from celery import Celery as CeleryApp 

32from celery.result import AsyncResult 

33from kombu import Queue, Connection 

34 

35try: 

36 from ewoksjob.events.readers import instantiate_reader 

37except ImportError: 

38 instantiate_reader = None 

39 

40try: 

41 from blissdata.beacon.files import read_config 

42except ImportError: 

43 read_config = None 

44 

45 

46logger = logging.getLogger(__name__) 

47 

48 

class TasksResource(ComponentResource):
    """Resource exposing the list of tasks configured for this component."""

    @marshal(out=[[200, paginated(TaskConfigSchema), "List of available tasks"]])
    def get(self):
        """Get a list of available tasks"""
        available = self._parent.get_tasks()
        return {"rows": available, "total": len(available)}

55 

56 

class ExecutedTasksResource(ComponentResource):
    """Resource exposing the (paginated) history of executed tasks."""

    @marshal(
        out=[[200, paginated(ExecutedTasksSchema), "List of executed tasks"]],
        paged=True,
    )
    def get(self, **kwargs):
        """Get a list of executed tasks"""
        # Pagination kwargs are handled by the parent component
        return self._parent.get_executed_tasks(**kwargs)

65 

66 

class ExecutedTaskResource(ComponentResource):
    """Resource exposing the details (including result) of a single executed task."""

    @marshal(
        out=[
            [200, ExecutedTasksSchema, "Details of executed task"],
            [404, ErrorSchema, "No such task"],
        ]
    )
    def get(self, job_id):
        """Get details of an executed task"""
        task = self._parent.get_task_info(job_id, get_result=True)
        if not task:
            return {"error": "No such task"}, 404
        return task

81 

82 

class WorkersResource(ComponentResource):
    """Resource exposing the connected celery workers and their task queues."""

    @marshal(
        out=[
            [200, paginated(WorkersSchema), "List of workers status and their tasks"],
        ],
    )
    def get(self):
        """Get a list of workers and their tasks"""
        rows = self._parent.get_workers()
        return {"total": len(rows), "rows": rows}

93 

94 

class RevokeTaskResource(ComponentResource):
    """Resource allowing a scheduled or running task to be revoked."""

    @marshal(
        out=[
            [200, MessageSchema(), "Task successfully revoked"],
            [400, ErrorSchema(), "Could not revoke task"],
        ],
    )
    def delete(self, task_id: str):
        """Revoke and terminate a running task"""
        try:
            self._parent.revoke(task_id)
        except Exception as e:
            return {"error": f"Could not revoke task id `{task_id}`: {str(e)}"}, 400
        return {"message": f"Task id `{task_id}` revoked"}, 200

109 

110 

class Celery(Component):
    """Monitor and execute celery tasks

    https://docs.celeryq.dev/en/master/userguide/workers.html#inspecting-workers
    """

    _config_schema = CeleryConfigSchema()
    # Time of the last websocket emit, used to debounce queue notifications
    _last_emit_time = 0
    # Pending `gevent.spawn_later` greenlet for a debounced emit (or None)
    _emit_timeout = None

    def setup(self, *args, **kwargs):
        """Register routes, build the celery app, and start background monitors."""
        self._tasks = {}

        self.register_route(WorkersResource, "/workers")
        self.register_route(RevokeTaskResource, "/revoke/<string:task_id>")
        self.register_route(TasksResource, "/tasks")
        self.register_route(ExecutedTasksResource, "/tasks/executed")
        self.register_route(ExecutedTaskResource, "/tasks/executed/<string:job_id>")

        self._celery_app = self._create_celery_app(self._config)

        # Optionally route submissions through a queue backed by a dead-letter
        # exchange so failed deliveries are not silently dropped
        self._default_queue = None
        if self._config.get("broker_dlq"):
            self._default_queue = Queue(
                self._config["broker_queue"],
                queue_arguments={
                    "x-dead-letter-exchange": "dlq",
                    "x-dead-letter-routing-key": f"dlq.{self._config['broker_queue']}",
                },
            )

        for task in self._config["tasks"]:
            if not self._config.get("actors"):
                self._config["actors"] = {}

            self._config["actors"][task["actor"]] = task["actor"]

            if task["actor"] in self._actors:
                continue

            self._actors.append(task["actor"])

            def preprocess(self, *args, **kwargs):
                # Inject the celery plumbing so the actor can submit tasks
                kwargs["send_task"] = self._parent.send_task
                kwargs["celery_app"] = self._parent._celery_app
                return kwargs

            def post(self, *args, **kwargs):
                pass

            # Build a resource class on the fly for each configured actor;
            # the `actor` decorator supplies the actual POST behaviour
            task_res = type(
                task["actor"],
                (ComponentResource,),
                {
                    "post": actor(task["actor"], preprocess=True, synchronous=True)(
                        post
                    ),
                    "preprocess": preprocess,
                },
            )
            self.register_actor_route(task_res, f"/tasks/{task['actor']}")

        if self._config.get("monitor"):
            gevent.spawn(self._enable_events)
            gevent.spawn(self._monitor_events)

        if self._config.get("beamline_queue"):
            gevent.spawn(self._subscribe_to_queue)

        self._ewoks_events_reader = None
        if instantiate_reader:
            if not self._config.get("ewoks_events_backend_url"):
                logger.error(
                    "`ewoks_events_backend_url` must be defined in the celery config to enable event reading"
                )
                return

            logger.info(
                f"Instantiating ewoks events reader on `{self._config['ewoks_events_backend_url']}`"
            )
            self._ewoks_events_reader = instantiate_reader(
                self._config["ewoks_events_backend_url"]
            )

    def _create_celery_app(self, config: Dict[str, Any]) -> CeleryApp:
        """Create the celery app from an ewoks config url or broker/backend urls.

        Raises:
            RuntimeError: If `ewoks_config` is requested but blissdata is not
                installed, or the remote config has no `celery` section.
        """
        if "ewoks_config" in config:
            if read_config is None:
                raise RuntimeError(
                    "`ewoks_config` defined but blissdata is not installed in this python env"
                )
            config_url = config["ewoks_config"]
            ewoks_config = read_config(config_url)
            celery_config = ewoks_config.get("celery", None)
            if celery_config is None:
                raise RuntimeError(f"`celery` config not found from `{config_url}`")
            broker = celery_config["broker_url"]
            backend = celery_config["result_backend"]
        else:
            # Environment variables take precedence over the static config
            broker = os.environ.get("CELERY_BROKER_URL", config["broker_url"])
            backend = os.environ.get("CELERY_BACKEND_URL", config["backend_url"])

        app = CeleryApp("tasks", broker=broker, backend=backend)
        return app

    def _subscribe_to_queue(self) -> None:
        """Subscribe to a beamline specific queue

        Allows daiquiri to receive notifications from on-going task execution
        """
        with Connection(self._celery_app.conf.broker_url) as conn:
            logger.info(f"Subscribing to queue {self._config['beamline_queue']}")
            queue = Queue(
                f"notification.{self._config['beamline_queue']}",
                exchange="notification",
                durable=False,
            )
            while True:
                with conn.SimpleBuffer(queue) as simple_queue:
                    try:
                        message = simple_queue.get_nowait()
                    except _queue.Empty:
                        pass
                    else:
                        logger.info(f"Received: {message.payload}")
                        self._queue_emit_message(message.headers, message.payload)
                        message.ack()
                time.sleep(0.1)

    def _queue_emit_message(self, headers, message) -> None:
        """Debounce event emission

        Try to not spam client
        """
        if self._emit_timeout is not None:
            self._emit_timeout.kill()
            self._emit_timeout = None

        now = time.time()
        if now - self._last_emit_time > 0.2:
            self._emit_message(headers, message)
        else:
            # Too soon after the last emit: defer and coalesce with any
            # further messages arriving within the window
            self._emit_timeout = gevent.spawn_later(
                0.2, self._emit_message, headers, message
            )

    def _emit_message(self, headers, message) -> None:
        """Emit a message to clients and record the emit time for debouncing."""
        self.emit("message", message)
        self._last_emit_time = time.time()

    def get_tasks(self) -> List[Dict[str, Any]]:
        """Return the configured task definitions.

        The config entries are dicts (each has at least an `actor` key), so the
        previous `List[str]` annotation was incorrect.
        """
        return self._config["tasks"]

    def _enable_events(self) -> None:
        """Periodically ask all workers to enable event emission (greenlet loop)."""
        logger.info("Starting periodic enabling of celery worker events")
        while True:
            self._celery_app.control.enable_events()
            time.sleep(5)

    def _on_event(self, event: dict) -> None:
        """Record task lifecycle events into `self._tasks`.

        See https://github.com/mher/flower/blob/master/flower/events.py#L65
        """
        if not event["type"].startswith("task"):
            return

        if event["uuid"] not in self._tasks:
            self._tasks[event["uuid"]] = {
                "args": event.get("args"),
                "kwargs": event.get("kwargs"),
                "name": event.get("name"),
            }

        task_info = self._tasks[event["uuid"]]

        if event["type"] == "task-received":
            task_info["received"] = datetime.fromtimestamp(event["timestamp"])

        if event["type"] == "task-started":
            task_info["started"] = datetime.fromtimestamp(event["timestamp"])

        if event["type"] in ("task-succeeded", "task-failed"):
            task_info["finished"] = datetime.fromtimestamp(event["timestamp"])

        if event["type"] == "task-failed":
            task_info["exception"] = event.get("exception")
            task_info["traceback"] = event.get("traceback")

    def _monitor_events(self) -> None:
        """Capture celery events forever and feed them to `_on_event` (greenlet loop)."""
        with self._celery_app.connection() as connection:
            logger.info("Starting celery event monitor")
            recv = self._celery_app.events.Receiver(
                connection, handlers={"*": self._on_event}
            )
            recv.capture(limit=None, timeout=None, wakeup=True)

    def get_executed_tasks(self, **kwargs) -> Dict[str, Any]:
        """Return a paginated view of the executed tasks seen by the monitor."""
        tasks = list(self._tasks.keys())
        limits = get_start_end(kwargs, points=len(tasks), last=True)
        return {
            "total": len(tasks),
            "rows": [
                self.get_task_info(job_id)
                for job_id in tasks[limits["st"] : limits["en"]]
            ],
        }

    def get_task_info(
        self, job_id: str, get_result: bool = False
    ) -> Union[Dict[str, Any], None]:
        """Return details of an executed task, or None if the job is unknown.

        Kwargs:
            get_result: Also resolve the task result from the result backend.
        """
        if job_id not in self._tasks:
            return None

        info = self._tasks[job_id]
        # Bind explicitly to our app rather than relying on the implicit
        # "current" celery app being the one created in setup()
        result = AsyncResult(job_id, app=self._celery_app)

        output_uris = {}
        # Resolve output h5 from ewoks jobs
        # TODO: This could go into ewoksjob
        if self._ewoks_events_reader:
            events = list(
                self._ewoks_events_reader.get_events(job_id=job_id, context="node")
            )
            for event in events:
                if event["output_uris"]:
                    for uri in event["output_uris"]:
                        parts = uri["value"].split("?")
                        if parts:
                            stripped_uri = parts[0]
                            # De-duplicate on the URI itself. The previous
                            # check compared the URI against the dict *keys*
                            # (task ids), which could never match, so
                            # duplicates were never filtered.
                            if stripped_uri not in output_uris.values():
                                output_uris[event["task_id"]] = stripped_uri

        return {
            "job_id": job_id,
            "name": info["name"],
            "result": result.result if get_result else None,
            "received": info["received"].timestamp() if info.get("received") else None,
            "started": info["started"].timestamp() if info.get("started") else None,
            "finished": info["finished"].timestamp() if info.get("finished") else None,
            "status": result.status,
            "args": info.get("args"),
            "kwargs": info.get("kwargs"),
            "uris": output_uris,
            "exception": info.get("exception"),
            "traceback": info.get("traceback"),
        }

    def get_workers(self) -> List[Dict]:
        """Return each worker's stats plus its reserved/pending/running tasks."""
        app_inspector = self._celery_app.control.inspect()

        workers_stats = app_inspector.stats()
        running_tasks = app_inspector.active()
        pending_tasks = app_inspector.scheduled()
        reserved_tasks = app_inspector.reserved()

        rows = {}
        if workers_stats:
            for worker, stats in workers_stats.items():
                rows[worker] = {
                    "host": worker,
                    "stats": {"uptime": stats["uptime"]},
                    "tasks": [],
                }

        def _append(tasks_by_worker, status, with_start=False):
            """Append each worker's tasks to its row with the given status."""
            if not tasks_by_worker:
                return
            for worker, tasks in tasks_by_worker.items():
                # A worker may report tasks without appearing in stats()
                # (the inspect calls are separate round-trips), which
                # previously raised KeyError; create the row lazily instead
                row = rows.setdefault(
                    worker, {"host": worker, "stats": {}, "tasks": []}
                )
                for task in tasks:
                    entry = {
                        "status": status,
                        "id": task["id"],
                        "name": task["name"],
                        "args": task["args"],
                        "kwargs": task["kwargs"],
                    }
                    if with_start:
                        entry["time_start"] = task["time_start"]
                    row["tasks"].append(entry)

        _append(reserved_tasks, "reserved")
        _append(pending_tasks, "pending")
        _append(running_tasks, "running", with_start=True)

        # Return a real list to match the annotation (was dict_values)
        return list(rows.values())

    def revoke(self, task_id: str, terminate: bool = True) -> None:
        """Revoke (and kill) a scheduled or running task"""
        # TODO: The manual says explicitly not to do this !
        self._celery_app.control.revoke(task_id, terminate=terminate, signal="SIGKILL")

    def send_task(self, task, *args, with_id=True, **kwargs) -> AsyncResult:
        """Start a celery task

        Kwargs:
            with_id: Inject a unique `daiquiri_id` into the task parameters.
        """
        if with_id:
            if "parameters" not in kwargs:
                kwargs["parameters"] = {}
            kwargs["parameters"]["daiquiri_id"] = str(uuid.uuid4())

        if self._default_queue:
            # NOTE(review): a `queue` kwarg is not popped in this branch and is
            # forwarded to the task itself — confirm this is intended
            future = self._celery_app.send_task(
                task, args=args, kwargs=kwargs, queue=self._default_queue
            )
        else:
            queue = kwargs.pop("queue", None)
            if queue:
                future = self._celery_app.send_task(
                    task, args=args, kwargs=kwargs, queue=queue
                )
            else:
                future = self._celery_app.send_task(task, args=args, kwargs=kwargs)

        return future

    def send_event(self, datacollectionid: int, event: str) -> Union[str, None]:
        """Send a processing event

        Helper function to send an event to mimas

        Returns:
            The celery task id, or None if submission failed.

        Raises:
            AttributeError: If `event` is not `start` or `end`.
        """
        if event not in ["start", "end"]:
            raise AttributeError(f"Unknown event {event}")

        task_name = self._config.get("mimas_task", "sidecar.celery.mimas.task.mimas")

        try:
            task = self.send_task(task_name, event, datacollectionid)

            self.emit(
                "message",
                {
                    "type": "event",
                    "event": event,
                    "datacollectionid": datacollectionid,
                    "job_id": task.id,
                },
            )

            return task.id
        except Exception:
            logger.exception("Could not send task to celery")

    def execute_graph(
        self,
        graph: str,
        dataCollectionId: int,
        parameters: Dict[str, Any] = None,
        wait: bool = False,
    ) -> Union[Any, str]:
        """Execute an ewoks graph

        Kwargs:
            graph: The graph to execute
            dataCollectionId: The datacollection to run against
            parameters: Any parameters to pass to the ewoks graph
            wait: Block and wait for future to resolve
        """
        # Avoid the shared mutable `{}` default argument
        if parameters is None:
            parameters = {}

        task_name = self._config.get(
            "ewoks_task", "sidecar.celery.ewoks.tasks.execute_graph"
        )
        task = self.send_task(
            task_name,
            graph,
            dataCollectionId=dataCollectionId,
            parameters=parameters,
        )

        if wait:
            task.wait()
            return task.result

        return task.id

    def reprocess_graph(
        self,
        processingJobId: int,
        parameters: Dict[str, Any] = None,
        overwrite: bool = False,
        wait: bool = False,
    ) -> Union[Any, str]:
        """Reprocess an ewoks graph

        Kwargs:
            processingJobId: The processingJobId to trigger
            parameters: Any parameters to pass to the ewoks graph
            overwrite: Overwrite an existing AutoProcProgram entry and pass this to the task
            wait: Block and wait for future to resolve
        """
        # Avoid the shared mutable `{}` default argument
        if parameters is None:
            parameters = {}

        task_name = self._config.get(
            "reprocess_ewoks_task", "sidecar.celery.ewoks.tasks.reprocess_graph"
        )
        task = self.send_task(
            task_name, processingJobId, parameters=parameters, overwrite=overwrite
        )

        if wait:
            task.wait()
            return task.result

        return task.id

526 

527 

class CeleryTaskActor(ComponentActor):
    """Actor that proxies its execution to a celery task."""

    # Name of the celery task to submit; subclasses must override
    task = None

    def method(self, **kwargs):
        """Submit `self.task` via the injected `send_task` and return its id.

        Raises:
            RuntimeError: If no task name has been defined on the class.
        """
        if self.task is None:
            raise RuntimeError("No task defined")

        # `send_task` / `celery_app` are injected by the component preprocess
        send_task = kwargs.pop("send_task")
        kwargs.pop("celery_app")

        future = send_task(self.task, **kwargs)
        return {"task_id": future.task_id}

543 

544 

class CeleryTaskSchema(ComponentActorSchema):
    """Actor schema for celery tasks; adds nothing beyond the base schema."""