Coverage for /opt/conda/envs/apienv/lib/python3.11/site-packages/daiquiri/core/hardware/bliss/scans.py: 2% (329 statements)
coverage.py v7.7.1, created at 2025-03-29 02:12 +0000

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import gevent
import numpy as np
import math

from bliss.config.settings import scan as rdsscan

try:
    from blissdata.data.nodes.lima import LimaImageChannelDataNode
    from blissdata.data.node import get_or_create_node, get_nodes
    from blissdata.data.scan import watch_session_scans
    from blissdata.data.nodes.scan_group import GroupScanNode
except ImportError:
    from bliss.data.nodes.lima import LimaImageChannelDataNode
    from bliss.data.node import get_or_create_node, get_nodes
    from bliss.data.scan import watch_session_scans
    from bliss.data.nodes.scan_group import GroupScanNode
from bliss.scanning.scan import ScanState

try:
    from silx.utils.retry import RetryTimeoutError
except ImportError:
    # For older silx versions
    class RetryTimeoutError(Exception):
        pass


from daiquiri.core.hardware.bliss.helpers import get_data_from_file

from daiquiri.core.hardware.abstract.scansource import ScanSource, ScanStates
from daiquiri.core.utils import get_start_end, make_json_safe

import logging

logger = logging.getLogger(__name__)


STATE_MAPPING = {
    # TODO: This should not be needed anymore with BLISS 1.11
    0: 0,
    1: 1,
    2: 2,
    3: 2,
    4: 3,
    5: 4,
    6: 5,
    # BLISS >= 1.11
    ScanState.IDLE.name: 0,
    ScanState.PREPARING.name: 1,
    ScanState.STARTING.name: 2,
    ScanState.STOPPING.name: 2,
    ScanState.DONE.name: 3,
    ScanState.USER_ABORTED.name: 4,
    ScanState.KILLED.name: 5,
}
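
# NOTE: both the legacy integer states (BLISS < 1.11) and the ScanState names
# (BLISS >= 1.11) map onto the same abstract ScanStates indices; STARTING and
# STOPPING deliberately collapse onto one abstract state.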


class BlissScans(ScanSource):
    def __init__(
        self,
        config,
        app=None,
        source_config={},
    ):
        super().__init__(config, app=app, source_config=source_config)
        logger.info("Initialising Bliss Scan Watcher")

        self._session = get_or_create_node(self._session_name, node_type="session")
        self._task = gevent.spawn(
            watch_session_scans,
            self._session_name,
            self._scan_start,
            self._scan_child,
            self._scan_data,
            self._scan_end,
            watch_scan_group=True,
            exclude_existing_scans=True,
        )
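        # watch_session_scans runs in a greenlet and dispatches session events
        # to the callbacks below: _scan_start/_scan_end for the scan lifecycle,
        # _scan_child for new child nodes and _scan_data for streamed data.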

    def _scan_start(self, info, *args, **kwargs):
        self._npoints = {}
        scanid = self._create_scanid(info["node_name"])
        self._emit_new_scan_event(
            scanid, info.get("type", "none"), info["title"], metadata=info
        )

    def _scan_end(self, info, *args, **kwargs):
        scanid = self._create_scanid(info["node_name"])
        self._emit_end_scan_event(scanid, metadata=info)

    def _scan_child(self, *args, **kwargs):
        pass
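
    # Child-node events are currently ignored; per-channel updates are
    # handled by _scan_data instead.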

    def _scan_data(self, dims, master, details):
        if "data" not in details:
            return

        scanid = self._create_scanid(details["scan_info"]["node_name"])

        if dims == "0d":
            # Data is accumulated
            channel_name = "0d"
            channel_size = min(len(arr) for arr in details["data"].values())
            for k, v in details["data"].items():
                self._emit_new_scan_data_0d_event(
                    scanid, channel_name=k, channel_size=len(v)
                )
        else:
            channel_node = details["channel_data_node"]
            if isinstance(channel_node, LimaImageChannelDataNode):
                channel_size = details["description"]["last_image_ready"] + 1
            else:
                channel_size = details["index"] + len(details["data"])
            channel_name = details["channel_name"]

        if channel_size <= self._npoints.get(channel_name, 0):
            # Skip duplicate updates, which can be:
            #   - a trigger from a different 0d channel
            #   - a trigger from a different lima field update
            #     (last_image_acquired/saved)
            return
        self._npoints[channel_name] = channel_size
        min_points = min(self._npoints.values())

        try:
            expected_scan_size = details["scan_info"]["npoints"]
        except Exception:
            expected_scan_size = 1

        try:
            progress = 100 * min_points / expected_scan_size
        except ZeroDivisionError:
            progress = 0

        try:
            channel_progress = 100 * channel_size / expected_scan_size
        except ZeroDivisionError:
            channel_progress = 0

        self._emit_new_scan_data_event(
            scanid, master, progress, channel_name, channel_size, channel_progress
        )
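
    # Example: with scan_info["npoints"] == 100 and the slowest channel at 25
    # points, progress == 100 * 25 / 100 == 25.0 (percent complete).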

    def _get_scan_nodes(self):
        db_names = rdsscan(
            f"{self._session.name}:*_children_list",
            count=1000000,
            connection=self._session.db_connection,
        )
        return (
            node
            for node in get_nodes(
                *(db_name.replace("_children_list", "") for db_name in db_names)
            )
            if node is not None and node.type in ["scan", "scan_group"]
        )
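
    # The key scan above matches the session's "*_children_list" Redis keys;
    # stripping the "_children_list" suffix recovers the node names of scans
    # and scan groups.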

    def _get_state(self, info):
        """Convert bliss scan states to the abstract scan state"""
        state = info.get("state")
        if state is None:
            return "UNKNOWN"

        try:
            return ScanStates[STATE_MAPPING[state]]
        except KeyError:
            logger.error(f"No state mapping for scan state {state}")
            return "UNKNOWN"

    def _get_shape(self, info):
        """Return the scan shape

        This is very bliss specific; maybe generalise in the future.
        """
        shape = {}
        for k in ["npoints1", "npoints2", "dim", "requests"]:
            shape[k] = info.get(k)
        return shape

    def get_scans(self, scanid=None, **kwargs):
        scans = []

        if self._session is None:
            return scans

        nodes = list(self._get_scan_nodes())
        nodes = sorted(nodes, key=lambda k: k.info.get("start_timestamp", 0))
        nodes.reverse()
        paging = get_start_end(kwargs, points=len(nodes))

        if scanid:
            filtered = nodes
        else:
            filtered = nodes[paging["st"] : paging["en"]]

        for scan in filtered:
            if scanid:
                if self._create_scanid(scan.db_name) != scanid:
                    continue

            info = scan.info.get_all()
            try:
                info["node_name"]
            except KeyError:
                logger.exception(f"No node_name for scan {scan.db_name}")
                continue

            if scanid:
                if self._create_scanid(info["node_name"]) != scanid:
                    continue

            sobj = {"scanid": self._create_scanid(info["node_name"])}
            for k in [
                "count_time",
                "node_name",
                "npoints",
                "filename",
                "end_timestamp",
                "start_timestamp",
                "title",
                "type",
            ]:
                sobj[k] = info.get(k)

            sobj["status"] = self._get_state(info)
            sobj["shape"] = self._get_shape(info)
            sobj["scan_number"] = info.get("scan_nb")

            if isinstance(scan, GroupScanNode):
                sobj["group"] = True

                children: list[dict] = []
                # Iterate over a custom channel on the sequence called "child_nodes"
                child_channel = "child_nodes"
                data = self.get_scan_data(sobj["scanid"], scalars=[child_channel])
                if child_channel in data["data"]:
                    child_node_names = [
                        # bytes from h5, str from redis?
                        s.decode("utf-8") if isinstance(s, bytes) else s
                        for s in data["data"][child_channel]["data"]
                    ]
                    if child_node_names:
                        for child_scan in nodes:
                            if child_scan.db_name in child_node_names:
                                child_scan_info = child_scan.info.get_all()
                                children.append(
                                    {
                                        "scanid": self._create_scanid(child_scan.db_name),
                                        "type": child_scan_info.get("type"),
                                        "node": child_scan.db_name,
                                    }
                                )

                        children = sorted(
                            children,
                            key=lambda child: child_node_names.index(child["node"]),
                        )

                sobj["children"] = children

            else:
                try:
                    sobj["estimated_time"] = info["estimation"]["total_time"]
                except KeyError:
                    sobj["estimated_time"] = 0

            xs = []
            ys = {"images": [], "scalars": [], "spectra": []}
            for k, el in info["acquisition_chain"].items():
                mast = el["master"]
                for t in ["images", "scalars", "spectra"]:
                    xs.extend(mast.get(t, []))

                for t in ["images", "scalars", "spectra"]:
                    ys[t].extend(el.get(t, []))

            sobj["axes"] = {"xs": xs, "ys": ys}

            scans.append(make_json_safe(sobj))

        if scanid:
            if scans:
                return scans[0]
        else:
            return {"total": len(nodes), "rows": scans}
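
    # get_scans returns a single scan object when called with a scanid (or
    # implicitly None if it is not found), otherwise a paged
    # {"total": ..., "rows": [...]} listing.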

    def get_scan_data(
        self, scanid, json_safe=True, scalars=None, all_scalars=False, **kwargs
    ):
        if self._session is None:
            return {}

        # Find the expected scan
        for scan in self._get_scan_nodes():
            if self._create_scanid(scan.db_name) == scanid:
                break
        else:
            return {}

        sobj = {"data": {}, "info": {}}
        info = scan.info.get_all()

        try:
            info["node_name"]
        except KeyError:
            logger.exception(f"No node_name for scan {scan.db_name}")
            return {}

        min_points = self._get_available_points(scan)
        sobj["npoints_avail"] = min_points

        paging = get_start_end(kwargs, points=min_points, last=True)
        sobj["page"] = paging["page"]
        sobj["pages"] = paging["pages"]
        sobj["per_page"] = paging["per_page"]
        sobj["scanid"] = self._create_scanid(scan.db_name)
        sobj["npoints"] = info.get("npoints")
        sobj["shape"] = self._get_shape(info)

        xs = []
        ys = {"images": [], "scalars": [], "spectra": []}
        for el in info["acquisition_chain"].values():
            mast = el["master"]
            for t in ["images", "scalars", "spectra"]:
                xs.extend(mast.get(t, []))

            for t in ["images", "scalars", "spectra"]:
                ys[t].extend(el.get(t, []))

        sobj["axes"] = {"xs": xs, "ys": ys}

        scalarid = 0
        if scalars is None:
            scalars = []
        hdf5_data = None

        for node in scan.walk(include_filter="channel", wait=False):
            sobj["data"][node.name] = {
                "name": node.name,
                "shape": node.info["shape"],
                "size": len(node),
                "dtype": np.dtype(node.dtype).str,
            }

            data = np.array([])
            if len(node.info["shape"]) == 0 and paging["en"] > paging["st"]:
                if (
                    node.name not in scalars
                    and node.name not in xs
                    and not all_scalars
                    and (
                        (
                            len(scalars) == 0
                            and (node.name.startswith("timer") or scalarid >= 2)
                        )
                        or len(scalars) > 0
                    )
                ):
                    continue

                try:
                    if hdf5_data is not None:
                        raise RuntimeError("Getting all data from hdf5")

                    data = node.get_as_array(paging["st"], paging["en"])
                except (RuntimeError, IndexError, TypeError):
                    if hdf5_data is None:
                        try:
                            hdf5_data, points = get_data_from_file(scan.name, info)
                            logger.info("Retrieving data from hdf5")
                            logger.debug(f"Available keys: {hdf5_data.keys()}")
                            sobj["npoints_avail"] = points
                            paging = get_start_end(kwargs, points=points, last=True)
                        except (OSError, RetryTimeoutError):
                            hdf5_data = {}
                            logger.exception("Could not read hdf5 file")
                    if node.name in hdf5_data:
                        data = hdf5_data[node.name][paging["st"] : paging["en"]]
                    else:
                        log_failure = (
                            logger.exception if self._app.debug else logger.info
                        )
                        log_failure(
                            f"Couldn't get paged scan data for {node.db_name}. Requested {paging['st']} to {paging['en']}, node length {len(node)}"
                        )

            # TODO: convert nan -> None
            # TODO: Make sure the browser doesn't interpret this as infinity (1e308)
            data = np.nan_to_num(data, posinf=1e200, neginf=-1e200)
            sobj["data"][node.name]["data"] = data

            if not node.name.startswith("timer") and node.name not in xs:
                scalarid += 1

        # Channels are now deleted from redis after a short TTL (~15min).
        # If no channels are available for a scan, try to read from h5.
        if not sobj["data"]:
            try:
                hdf5_data, points = get_data_from_file(scan.name, info)
                logger.info("Retrieving data from hdf5")
                logger.debug(f"Available keys: {hdf5_data.keys()}")

            except (OSError, RetryTimeoutError):
                hdf5_data = {}
                logger.exception("Could not read hdf5 file")
            else:
                paging = get_start_end(kwargs, points=points, last=True)
                for channel in hdf5_data.keys():
                    sobj["data"][channel] = {
                        "name": channel,
                        "shape": [],  # a scalar has an empty shape
                        "data": np.array([]),
                        "size": len(hdf5_data[channel]),
                        "dtype": hdf5_data[channel].dtype.str,
                    }

                    if channel in xs or channel in scalars or all_scalars:
                        data = hdf5_data[channel][paging["st"] : paging["en"]]
                        data = np.nan_to_num(data, posinf=1e200, neginf=-1e200)
                        sobj["data"][channel]["data"] = data

        if json_safe:
            sobj = make_json_safe(sobj)
        return sobj
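
    # Note: with no explicit `scalars` selection, only axis channels and the
    # first two non-timer scalar channels are paged from Redis; timers and any
    # further scalars are skipped (scalarid >= 2).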

    def _get_available_points(self, scan):
        """TODO: This is problematic because len(node) actually has to
        retrieve the data; for a scan with ~2000 points this takes on
        the order of 500 ms.
        """
        import time

        start = time.time()
        shortest = None
        min_points = math.inf
        for node in scan.walk(include_filter="channel", wait=False):
            if len(node) < min_points:
                shortest = node.name
                min_points = len(node)

        if min_points == math.inf:
            min_points = 0

        took = time.time() - start
        logger.debug(f"_get_available_points {shortest} {min_points} took: {took} s")

        return min_points

    def get_scan_spectra(self, scanid, point=0, allpoints=False):
        log_failure = logger.exception if self._app.debug else logger.info
        for scan in self._get_scan_nodes():
            if self._create_scanid(scan.db_name) != scanid:
                continue

            info = scan.info.get_all()
            try:
                info["node_name"]
            except KeyError:
                logger.exception(f"No node_name for scan {scan.db_name}")
                continue

            min_points = self._get_available_points(scan)

            hdf5_data = None
            spectra = {}
            for node in scan.walk(include_filter="channel", wait=False):
                if not node.info.get("shape"):
                    continue

                if len(node.info["shape"]) > 0:
                    data = np.array([])
                    try:
                        if hdf5_data is not None:
                            raise RuntimeError("Getting all spectra from hdf5")

                        data = (
                            node.get_as_array(0, to_index=int(info.get("npoints", 0)))
                            if allpoints
                            else np.array([node.get_as_array(point)])
                        )
                    except (RuntimeError, IndexError, TypeError):
                        if hdf5_data is None:
                            try:
                                hdf5_data, _points = get_data_from_file(
                                    scan.name, info, type="spectrum"
                                )
                                logger.info(
                                    f"Retrieving data from hdf5: {hdf5_data.keys()}"
                                )
                            except (OSError, RetryTimeoutError):
                                hdf5_data = {}
                                logger.exception("Could not read hdf5 file")

                        if node.name in hdf5_data:
                            if allpoints:
                                data = hdf5_data[node.name]
                            else:
                                if point < len(hdf5_data[node.name]):
                                    data = np.array([hdf5_data[node.name][point]])
                                else:
                                    log_failure(
                                        f"Couldn't get scan spectra for {node.db_name}. Requested point {point} outside range {len(hdf5_data[node.name])}"
                                    )
                                    return None
                        else:
                            log_failure(
                                f"Couldn't get scan spectra for {node.db_name}. Requested 0 to {info.get('npoints')}, node length {len(node)}"
                            )
                            return None

                    spectra[node.name] = {"data": data, "name": node.name}

            # Channels are now deleted from redis after a short TTL (~15min).
            # If no channels are available for a scan, try to read from h5.
            if not spectra:
                try:
                    hdf5_data, _ = get_data_from_file(scan.name, info, type="spectrum")
                except (OSError, RetryTimeoutError):
                    hdf5_data = {}
                    logger.exception("Could not read hdf5 file")
                else:
                    for channel in hdf5_data.keys():
                        data = {}
                        if allpoints:
                            data = hdf5_data[channel]
                        else:
                            if point < len(hdf5_data[channel]):
                                data = np.array([hdf5_data[channel][point]])
                            else:
                                log_failure(
                                    f"Couldn't get scan spectra for {channel}. Requested point {point} outside range {len(hdf5_data[channel])}"
                                )
                                return None

                        spectra[channel] = {"data": data, "name": channel}

            return make_json_safe(
                {
                    "scanid": scanid,
                    "data": spectra,
                    "npoints": info.get("npoints"),
                    "npoints_avail": min_points,
                    "conversion": self.get_conversion(),
                }
            )

    def get_scan_image(self, scanid, node_name, image_no):
        for scan in self._get_scan_nodes():
            if self._create_scanid(scan.db_name) != scanid:
                continue

            info = scan.info.get_all()
            try:
                info["node_name"]
            except KeyError:
                logger.exception(f"No node_name for scan {scan.db_name}")
                continue

            for node in scan.walk(include_filter="lima", wait=False):
                if node_name != node.name:
                    continue

                view = node.get(image_no, image_no)
                return view.get_image(image_no)

            for node in scan.walk(include_filter="channel", wait=False):
                if node_name != node.name:
                    continue
                shape = node.info.get("shape")
                if not shape or len(shape) != 2:
                    continue
                return node.get_as_array(image_no)

        raise RuntimeError(
            f"Data scanid={scanid} node_name={node_name} image_no={image_no} not found"
        )