Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/hardware/blissdata/scans.py: 68% (320 statements)
coverage.py v7.6.4, created at 2024-11-14 02:13 +0000

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import annotations
from typing import Any
import dateutil.parser

import gevent
import numpy
import logging
import lru
import time

from blissdata.scan import Scan
from blissdata.redis_engine.exceptions import NoScanAvailable, EndOfStream
from blissdata.beacon.data import BeaconData
from blissdata.redis_engine.store import DataStore
from blissdata.redis_engine.scan import ScanState
from daiquiri.core.hardware.abstract.scansource import ScanSource, ScanStates
from daiquiri.core.utils import get_start_end, make_json_safe

# blissdata 2
# https://gitlab.esrf.fr/bliss/bliss/-/commit/dfbe24b9bee401eabd296664466df93b6f9844d1
try:
    from blissdata.streams.base import CursorGroup

    use_blissdata_2_api = True
# blissdata < 2
except ModuleNotFoundError:
    from blissdata.redis_engine.stream import StreamingClient
    from blissdata.redis_engine.scan import Scan as RedisScan

    use_blissdata_2_api = False

_logger = logging.getLogger(__name__)
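
# NOTE: the integer values in the mappings below are indices into the
# abstract daiquiri `ScanStates` list used by `_get_state`; a closed scan
# is translated through TERMINATION_MAPPING from its "end_reason" instead.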

STATE_MAPPING = {
    ScanState.CREATED: 0,
    ScanState.PREPARED: 1,
    ScanState.STARTED: 2,
    ScanState.STOPPED: 3,
    ScanState.CLOSED: 5,
}

TERMINATION_MAPPING = {
    "SUCCESS": 3,
    "USER_ABORT": 4,
    "FAILURE": 5,
    "DELETION": 5,
}


class BlissdataScans(ScanSource):
    UPDATE_PERIOD = 0.1
    """Period, in seconds, intended to limit the update rate."""

    def __init__(self, config: dict, app=None, source_config: dict | None = None):
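        """Create the scan source and start listening for new scans.

        If ``source_config`` does not provide a ``redis_url``, the Redis data
        database is resolved through Beacon.
        """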

        source_config = source_config or {}
        super().__init__(config, app=app, source_config=source_config)
        self._key_mapping: lru.LRU = lru.LRU(256)

        url = source_config.get("redis_url")
        if not url:
            url = BeaconData().get_redis_data_db()
        self._store: DataStore = DataStore(url)

        _logger.info("Initialising Blissdata Scan Watcher")

        self._service_running: bool = True
        self._task = gevent.spawn(self._listen_scans)
        self._task.link_exception(self._gevent_exception)

    def close(self):
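        """Stop the scan listening task and release the data store."""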

        self._service_running = False
        self._task.join()
        self._task = None
        self._store = None

    def _gevent_exception(self, greenlet):
        """Log an exception raised from a greenlet."""
        try:
            greenlet.get()
        except Exception:
            _logger.error(
                "Error while executing greenlet %s", greenlet.name, exc_info=True
            )

    def _listen_scans(self):
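        """Poll the data store for new scans and watch those belonging to this session."""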

        timestamp_cursor = None
        while self._service_running:
            try:
                timestamp_cursor, scan_key = self._store.get_next_scan(
                    since=timestamp_cursor, block=True, timeout=1
                )
            except NoScanAvailable:
                continue
            self._key_mapping[self._create_scanid(scan_key)] = scan_key
            scan = self._store.load_scan(scan_key)
            if scan.session == self._session_name:
                # Only watch scans from this session
                # FIXME: it would be good to close this greenlet safely at service termination
                g = gevent.spawn(self._listen_scan_state, scan)
                g.link_exception(self._gevent_exception)

    def _listen_scan_state(self, scan: Scan):
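        """Follow a single scan: emit the start event, spawn a stream
        listener, and emit the end event once the scan is closed."""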

        while scan.state < ScanState.STARTED:
            scan.update()

        self._scan_start(scan.key, scan.info)
        # Initialise before the try block so the finally clause cannot hit an
        # unbound name if the spawn itself fails
        _data_greenlet = None
        try:
            if use_blissdata_2_api:
                _data_greenlet = gevent.spawn(self._listen_scan_streams2, scan)
            else:
                _data_greenlet = gevent.spawn(self._listen_scan_streams, scan)
            _data_greenlet.link_exception(self._gevent_exception)
            while scan.state < ScanState.CLOSED:
                scan.update()
        finally:
            if _data_greenlet is not None:
                _data_greenlet.join()
                _data_greenlet = None
            self._scan_end(scan.key, scan.info)

    def _listen_scan_streams2(self, scan: Scan):
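        """Emit data events for a running scan (blissdata >= 2 API).

        Reads the scan streams through a ``CursorGroup`` and emits 0d data
        and progress events until the end of the streams.
        """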

        raw_scan = self._store.load_scan(scan.key)
        cursor_group = CursorGroup(raw_scan.streams)
        scanid = self._create_scanid(scan.key)

        info = scan.info
        expected_scan_size = info.get("scan_info", {}).get("npoints", 1)

        while self._service_running:
            try:
                views = cursor_group.read(last_only=True)
            except EndOfStream:
                break

            for stream, view in views.items():
                channel_name = stream.name
                if stream.kind == "array":
                    ndim = len(stream.info["shape"])
                    channel_size = view.index + len(view)
                    if ndim == 0:
                        self._emit_new_scan_data_0d_event(
                            scanid,
                            channel_name=channel_name,
                            channel_size=channel_size,
                            continue_=True,
                        )
                elif stream.kind == "scan":
                    # Scan sequences should be handled at some point
                    continue
                else:
                    _logger.warning("Unsupported stream kind %s", stream.kind)
                    continue

                min_points = self._get_available_points(scan)

                try:
                    progress = 100 * min_points / expected_scan_size
                except ZeroDivisionError:
                    progress = 0

                try:
                    channel_progress = 100 * channel_size / expected_scan_size
                except ZeroDivisionError:
                    channel_progress = 0

                self._emit_new_scan_data_event(
                    scanid,
                    "root",
                    progress,
                    channel_name,
                    channel_size,
                    channel_progress,
                )

            # Rate limit requests to blissdata
            time.sleep(0.2)

    def _listen_scan_streams(self, scan: Scan):
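        """Emit data events for a running scan (blissdata < 2 API).

        Reads the raw Redis streams through a ``StreamingClient`` and emits
        0d data and progress events until the end of the streams.
        """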

        # Pure redis streams to enable parallel reading
        raw_scan = self._store.load_scan(scan.key, RedisScan)
        client = StreamingClient(raw_scan.streams)
        scanid = self._create_scanid(scan.key)

        info = scan.info
        expected_scan_size = info.get("scan_info", {}).get("npoints", 1)

        while self._service_running:
            try:
                data = client.read(count=-1)
            except EndOfStream:
                break

            for stream, entries in data.items():
                channel_name = stream.name
                encoding_type = stream.encoding["type"]
                index, messages = entries
                if encoding_type == "numeric":
                    ndim = len(stream.encoding["shape"])
                    channel_size = index + len(messages)
                    if ndim == 0:
                        self._emit_new_scan_data_0d_event(
                            scanid,
                            channel_name=channel_name,
                            channel_size=channel_size,
                            continue_=True,
                        )
                elif encoding_type == "json":
                    json_format = stream.info.get("format")
                    if json_format in ("lima_v1", "lima_v2"):
                        channel_size = messages[-1].get("last_index") + 1
                    elif json_format == "subscan":
                        # This should be handled at some point
                        continue
                    else:
                        _logger.warning("Unsupported json format type %s", json_format)
                        continue
                else:
                    _logger.warning("Unsupported encoding type %s", encoding_type)
                    continue

                min_points = self._get_available_points(scan)

                try:
                    progress = 100 * min_points / expected_scan_size
                except ZeroDivisionError:
                    progress = 0

                try:
                    channel_progress = 100 * channel_size / expected_scan_size
                except ZeroDivisionError:
                    channel_progress = 0

                self._emit_new_scan_data_event(
                    scanid,
                    "root",
                    progress,
                    channel_name,
                    channel_size,
                    channel_progress,
                )

            # Rate limit requests to blissdata
            time.sleep(0.2)

    def _scan_start(self, scan_key: str, info: dict):
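        """Emit the abstract new-scan event for a starting scan."""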

        scanid = self._create_scanid(scan_key)
        self._emit_new_scan_event(
            scanid, info.get("type", "none"), info["title"], metadata=info
        )

    def _scan_end(self, scan_key: str, info: dict):
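        """Emit the abstract end-scan event for a finished scan."""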

        scanid = self._create_scanid(scan_key)
        self._emit_end_scan_event(scanid, metadata=info)

    def _get_state(self, state: ScanState, scan_info: dict):
        """Convert a blissdata scan state to the abstract scan state."""
        if state is None:
            return "UNKNOWN"
        if state == ScanState.CLOSED:
            termination = scan_info.get("end_reason", None)
            iresult = TERMINATION_MAPPING.get(termination)
            if iresult is None:
                _logger.error(
                    "No termination mapping for scan end_reason %s", termination
                )
                return "UNKNOWN"
        else:
            iresult = STATE_MAPPING.get(state)
            if iresult is None:
                _logger.error("No state mapping for scan state %s", state)
                return "UNKNOWN"
        return ScanStates[iresult]

    def _get_shape(self, info):
        """Return the scan shape.

        This is very bliss specific; it may be generalised in the future.
        """
        shape = {}
        for k in ["npoints1", "npoints2", "dim", "requests"]:
            shape[k] = info.get(k)
        return shape

    def get_scans(self, scanid: int | None = None, **kwargs):
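        """Return a paged list of scans for this session, newest first, or a
        single scan when ``scanid`` is given."""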

        def scan_to_result(scan: Scan):
            info = scan.info
            sobj: dict = {"scanid": self._create_scanid(scan.key)}
            for k in [
                "count_time",
                "npoints",
                "filename",
                "title",
                "type",
            ]:
                sobj[k] = info.get(k)

            sobj["scan_number"] = info.get("scan_nb")
            sobj["status"] = self._get_state(scan.state, info)
            sobj["shape"] = self._get_shape(info)
            sobj["start_timestamp"] = (
                dateutil.parser.parse(info["start_time"]).timestamp()
                if info.get("start_time")
                else None
            )
            sobj["end_timestamp"] = (
                dateutil.parser.parse(info["end_time"]).timestamp()
                if info.get("end_time")
                else None
            )

            if info.get("is_scan_sequence", False):
                sobj["group"] = True

                children: list[dict] = []
                """
                # FIXME: This can be reached from a dedicated channel

                for child in scan.walk(include_filter="node_ref_channel", wait=False):
                    child_scans = child.get(0, -1)

                    for child_scan in child_scans:
                        if child_scan is None:
                            # This can disappear in between
                            continue
                        child_scan_info = child_scan.info.get_all()
                        children.append(
                            {
                                "scanid": self._scanid(child_scan_info["node_name"]),
                                "type": child_scan_info["type"],
                            }
                        )
                """
                sobj["children"] = children

            else:
                try:
                    sobj["estimated_time"] = info["estimation"]["total_time"]
                except KeyError:
                    sobj["estimated_time"] = 0

            xs = []
            ys: dict = {"images": [], "scalars": [], "spectra": []}
            for el in info["acquisition_chain"].values():
                mast = el["master"]
                for t in ["images", "scalars", "spectra"]:
                    xs.extend(mast[t])

                for t in ["images", "scalars", "spectra"]:
                    ys[t].extend(el[t])

            sobj["axes"] = {"xs": xs, "ys": ys}
            return sobj

        if scanid is not None:
            # Special case for a single scan
            # FIXME: A dedicated `ScanSource.get_scan` API should be exposed instead
            scan = self._get_scan_from_scanid(scanid)
            if scan is None:
                return None
            sobj = scan_to_result(scan)
            return make_json_safe(sobj)

        _timestamp, scan_keys = self._store.search_existing_scans(
            session=self._session_name
        )
        # NOTE: sorting keys is the same as sorting by creation datetime
        # NOTE: The `reversed` is a copy-paste from the bliss connector
        scan_keys = list(reversed(sorted(scan_keys)))
        paging = get_start_end(kwargs, points=len(scan_keys))
        filtered = scan_keys[paging["st"] : paging["en"]]

        scans = []
        for scan_key in filtered:
            scan = self._store.load_scan(scan_key)
            sobj = scan_to_result(scan)
            scans.append(make_json_safe(sobj))
        return {"total": len(filtered), "rows": scans}

    def _get_scankey_from_scanid(self, scanid: int) -> str | None:
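        """Resolve a daiquiri scan id to a blissdata scan key, trying the LRU
        cache first and falling back to a search of the data store."""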

        if scanid not in self._key_mapping:
            # FIXME: it would be better to use scan_key in the daiquiri client
            _timestamp, scan_keys = self._store.search_existing_scans(
                session=self._session_name
            )
            for scan_key in scan_keys:
                # We could use all these keys to update the cache
                if self._create_scanid(scan_key) == scanid:
                    break
            else:
                # Not updating the cache with None here could be a problem
                return None
            self._key_mapping[scanid] = scan_key
        else:
            scan_key = self._key_mapping[scanid]

        return scan_key

    def _get_scan_from_scanid(self, scanid: int) -> Scan | None:
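        """Load a scan from the data store given a daiquiri scan id."""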

        scan_key = self._get_scankey_from_scanid(scanid)
        if scan_key is None:
            return None
        return self._store.load_scan(scan_key)

    def get_scan_data(
        self, scanid, json_safe=True, scalars=None, all_scalars=False, **kwargs
    ):
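        """Return paged scan data together with per-channel metadata.

        Only selected 0d channels are populated with data: the requested
        ``scalars``, the scan axes, and otherwise the first two non-timer
        counters. Pass ``all_scalars=True`` to include every 0d channel.
        """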

        scan = self._get_scan_from_scanid(scanid)
        if scan is None:
            return {}

        sobj: dict[str, Any] = {"data": {}, "info": {}}
        info = scan.info

        min_points = self._get_available_points(scan)
        sobj["npoints_avail"] = min_points

        paging = get_start_end(kwargs, points=min_points, last=True)
        sobj["page"] = paging["page"]
        sobj["pages"] = paging["pages"]
        sobj["per_page"] = paging["per_page"]
        sobj["scanid"] = scanid
        sobj["npoints"] = info.get("npoints")
        sobj["shape"] = self._get_shape(info)

        xs = []
        ys: dict = {"images": [], "scalars": [], "spectra": []}
        for el in info["acquisition_chain"].values():
            mast = el["master"]
            for t in ["images", "scalars", "spectra"]:
                xs.extend(mast[t])

            for t in ["images", "scalars", "spectra"]:
                ys[t].extend(el[t])

        sobj["axes"] = {"xs": xs, "ys": ys}

        scalarid = 0
        if scalars is None:
            scalars = []

        for stream in scan.streams.values():
            if stream.name == "SUBSCANS":
                # That's a special stream, better to skip it
                continue
            sobj["data"][stream.name] = {
                "name": stream.name,
                "shape": stream.info["shape"],
                "size": len(stream),
                "dtype": numpy.dtype(stream.info["dtype"]).str,
            }

            data = numpy.array([])
            if len(stream.info["shape"]) == 0 and paging["en"] > paging["st"]:
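                # Skip 0d channels that were not requested: keep the entries
                # in `scalars`, the axes, everything when `all_scalars` is
                # set, and otherwise only the first two non-timer counters.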

                if (
                    stream.name not in scalars
                    and stream.name not in xs
                    and not all_scalars
                    and (
                        (
                            len(scalars) == 0
                            and (stream.name.startswith("timer") or scalarid >= 2)
                        )
                        or len(scalars) > 0
                    )
                ):
                    continue

                # NOTE: Data size could be smaller than what is requested
                data = stream[paging["st"] : paging["en"]]

                # TODO: convert nan -> None
                # TODO: Make sure the browser doesn't interpret it as infinity (1e308)
                data = numpy.nan_to_num(data, posinf=1e200, neginf=-1e200)
            sobj["data"][stream.name]["data"] = data

            if not stream.name.startswith("timer") and stream.name not in xs:
                scalarid += 1

        if json_safe:
            sobj = make_json_safe(sobj)
        return sobj

    def _get_available_points(self, scan: Scan) -> int:
489 """TODO: This is problematic because len(node) actually has to 

490 retrieve the data, for a scan with ~2000 ish points it takes 

491 of the order of 500ms 

492 """ 

493 min_points = None 

494 for stream in scan.streams.values(): 

495 nb = len(stream) 

496 if min_points is None or nb < min_points: 

497 min_points = nb 

498 

499 if min_points is None: 

500 return 0 

501 return min_points 

502 

    def get_scan_spectra(self, scanid, point=0, allpoints=False):
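        """Return the spectrum (1d) channels of a scan, either for a single
        point or for all points."""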

        log_failure = _logger.exception if self._app.debug else _logger.info
        scan = self._get_scan_from_scanid(scanid)
        if scan is None:
            return None

        info = scan.info

        min_points = self._get_available_points(scan)

        spectra = {}
        for stream in scan.streams.values():
            shape = stream.info.get("shape")
            if shape is None:
                continue

            if len(shape) == 1:
                data = numpy.array([])
                try:
                    data = (
                        stream[0 : int(info.get("npoints", 0))]
                        if allpoints
                        else numpy.array([stream[point]])
                    )
                except (RuntimeError, IndexError, TypeError):
                    log_failure(
                        f"Couldn't get scan spectra for {stream.name}. Requested 0 to {info.get('npoints')}, stream length {len(stream)}"
                    )
                    return None

                spectra[stream.name] = {"data": data, "name": stream.name}

        return make_json_safe(
            {
                "scanid": scanid,
                "data": spectra,
                "npoints": info.get("npoints"),
                "npoints_avail": min_points,
                "conversion": self.get_conversion(),
            }
        )

    def get_scan_image(self, scanid: int, node_name: str, image_no: int):
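        """Return a single frame from an image stream of a scan."""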

        scan = self._get_scan_from_scanid(scanid)
        if scan is None:
            return None
        stream = scan.streams.get(node_name)
        if stream is None:
            return None
        return stream[image_no]