Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/hardware/bliss/scans.py: 2%

326 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-14 02:13 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3import gevent 

4import numpy as np 

5import math 

6 

7from bliss.config.settings import scan as rdsscan 

8 

9try: 

10 from blissdata.data.nodes.lima import LimaImageChannelDataNode 

11 from blissdata.data.node import get_or_create_node, get_nodes 

12 from blissdata.data.scan import watch_session_scans 

13 from blissdata.data.nodes.scan_group import GroupScanNode 

14except ImportError: 

15 from bliss.data.nodes.lima import LimaImageChannelDataNode 

16 from bliss.data.node import get_or_create_node, get_nodes 

17 from bliss.data.scan import watch_session_scans 

18 from bliss.data.nodes.scan_group import GroupScanNode 

19from bliss.scanning.scan import ScanState 

20 

21try: 

22 from silx.utils.retry import RetryTimeoutError 

23except ImportError: 

24 # For older silx version 

25 class RetryTimeoutError(Exception): 

26 pass 

27 

28 

29from daiquiri.core.hardware.bliss.helpers import get_data_from_file 

30 

31from daiquiri.core.hardware.abstract.scansource import ScanSource, ScanStates 

32from daiquiri.core.utils import get_start_end, make_json_safe 

33 

34import logging 

35 

logger = logging.getLogger(__name__)


# Translation table from the scan state found in a scan's info to an index
# into ``ScanStates`` (looked up by ``BlissScans._get_state``). Both the
# integer states and the ``ScanState`` member names are accepted.
STATE_MAPPING = dict(
    [
        # Integer states.
        # TODO: This should not be needed anymore with BLISS 1.11
        (0, 0),
        (1, 1),
        (2, 2),
        (3, 2),
        (4, 3),
        (5, 4),
        (6, 5),
        # ``ScanState`` member names (BLISS >= 1.11)
        (ScanState.IDLE.name, 0),
        (ScanState.PREPARING.name, 1),
        (ScanState.STARTING.name, 2),
        (ScanState.STOPPING.name, 2),
        (ScanState.DONE.name, 3),
        (ScanState.USER_ABORTED.name, 4),
        (ScanState.KILLED.name, 5),
    ]
)

57 

58 

class BlissScans(ScanSource):
    """Scan source reading the scans of a BLISS session from Redis.

    A greenlet watches the session and forwards scan start, data and end
    notifications through the ``ScanSource`` event emitters. The query
    methods (``get_scans``, ``get_scan_data``, ``get_scan_spectra``,
    ``get_scan_image``) walk the scan nodes stored in Redis and fall back to
    the scan's HDF5 file when the Redis channels are no longer available.
    """

    def __init__(
        self,
        config,
        app=None,
        source_config=None,
    ):
        """Start watching the BLISS session for new scans.

        Args:
            config: source configuration, forwarded to ``ScanSource``.
            app: optional application object; its ``debug`` flag selects how
                data retrieval failures are logged.
            source_config: extra source configuration forwarded to
                ``ScanSource`` (an empty dict when omitted).
        """
        # ``None`` instead of ``{}``: a mutable default would be shared by
        # every instance.
        if source_config is None:
            source_config = {}
        super().__init__(config, app=app, source_config=source_config)
        logger.info("Initialising Bliss Scan Watcher")

        # Points already notified per channel for the running scan. Created
        # here so a data callback arriving before ``_scan_start`` cannot fail
        # on a missing attribute.
        self._npoints = {}

        self._session = get_or_create_node(self._session_name, node_type="session")
        # Scans that already exist when the watcher starts are excluded
        self._task = gevent.spawn(
            watch_session_scans,
            self._session_name,
            self._scan_start,
            self._scan_child,
            self._scan_data,
            self._scan_end,
            watch_scan_group=True,
            exclude_existing_scans=True,
        )

    # ------------------------------------------------------------------
    # Session watcher callbacks
    # ------------------------------------------------------------------

    def _scan_start(self, info, *args, **kwargs):
        """Reset the per-channel point counters and emit a new-scan event."""
        self._npoints = {}
        scanid = self._create_scanid(info["node_name"])
        self._emit_new_scan_event(
            scanid, info.get("type", "none"), info["title"], metadata=info
        )

    def _scan_end(self, info, *args, **kwargs):
        """Emit an end-of-scan event carrying the scan info as metadata."""
        scanid = self._create_scanid(info["node_name"])
        self._emit_end_scan_event(scanid, metadata=info)

    def _scan_child(self, *args, **kwargs):
        """Child notifications are ignored."""
        pass

    @staticmethod
    def _percent(value, total):
        """Return ``value`` as a percentage of ``total``.

        Returns 0 when ``total`` is 0 or None (e.g. scans without a known
        number of points) instead of raising.
        """
        if not total:
            return 0
        return 100 * value / total

    def _scan_data(self, dims, master, details):
        """Translate a data notification into ``ScanSource`` data events.

        For ``dims == "0d"`` one 0d event is emitted per channel and the
        shortest channel gives the size. For other channels the size comes
        from the lima ``last_image_ready`` counter or from
        ``index + len(data)``. Updates that do not grow a channel are
        dropped. Otherwise a data event is emitted, with the overall
        progress (slowest channel) and the progress of this channel.
        """
        if "data" not in details:
            return

        scanid = self._create_scanid(details["scan_info"]["node_name"])

        if dims == "0d":
            # Data is accumulated
            channel_name = "0d"
            data = details["data"]
            # ``default=0``: an empty update is then dropped as a duplicate
            channel_size = min((len(arr) for arr in data.values()), default=0)
            for name, values in data.items():
                self._emit_new_scan_data_0d_event(
                    scanid, channel_name=name, channel_size=len(values)
                )
        else:
            channel_node = details["channel_data_node"]
            if isinstance(channel_node, LimaImageChannelDataNode):
                channel_size = details["description"]["last_image_ready"] + 1
            else:
                channel_size = details["index"] + len(details["data"])
            channel_name = details["channel_name"]

        if channel_size <= self._npoints.get(channel_name, 0):
            # Skip dup updates, could be:
            # - trigger from different 0d
            # - trigger from different lima field update last_image_acquired/save
            return
        self._npoints[channel_name] = channel_size
        min_points = min(self._npoints.values())

        try:
            expected_scan_size = details["scan_info"]["npoints"]
        except (KeyError, TypeError):
            expected_scan_size = 1

        progress = self._percent(min_points, expected_scan_size)
        channel_progress = self._percent(channel_size, expected_scan_size)

        self._emit_new_scan_data_event(
            scanid, master, progress, channel_name, channel_size, channel_progress
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _log_failure(self, message):
        """Log a data retrieval failure.

        In debug mode ``logger.exception`` is used (which includes the active
        traceback). Otherwise ``logger.info`` is used. A missing application
        (``app=None``) counts as non-debug.
        """
        debug = getattr(self._app, "debug", False)
        log = logger.exception if debug else logger.info
        log(message)

    def _get_scan_nodes(self):
        """Return a generator over the session's scan and scan-group nodes.

        The nodes are found by scanning Redis for the
        ``<session>:*_children_list`` keys.
        """
        db_names = rdsscan(
            f"{self._session.name}:*_children_list",
            count=1000000,
            connection=self._session.db_connection,
        )
        return (
            node
            for node in get_nodes(
                *(db_name.replace("_children_list", "") for db_name in db_names)
            )
            if node is not None and node.type in ["scan", "scan_group"]
        )

    def _get_state(self, info):
        """Convert bliss scans states to abstract scan state

        Returns "UNKNOWN" when the info has no state or the state is not in
        ``STATE_MAPPING`` (the latter is logged as an error).
        """
        state = info.get("state")
        if state is None:
            return "UNKNOWN"

        try:
            return ScanStates[STATE_MAPPING[state]]
        except KeyError:
            logger.error(f"No state mapping for scan state {state}")
            return "UNKNOWN"

    def _get_shape(self, info):
        """Return scan shape

        This is very bliss specific, maybe generalise in the future
        """
        return {k: info.get(k) for k in ["npoints1", "npoints2", "dim", "requests"]}

    @staticmethod
    def _get_axes(info):
        """Collect the channel names advertised by the acquisition chain.

        Returns ``{"xs": [...], "ys": {"images": [...], "scalars": [...],
        "spectra": [...]}}``. ``xs`` holds the masters' channels and ``ys``
        holds the other channels, in chain order. A scan without an
        acquisition chain yields empty axes.
        """
        xs = []
        ys = {"images": [], "scalars": [], "spectra": []}
        for element in (info.get("acquisition_chain") or {}).values():
            master = element["master"]
            for kind in ("images", "scalars", "spectra"):
                xs.extend(master[kind])
                ys[kind].extend(element[kind])
        return {"xs": xs, "ys": ys}

    def _get_group_children(self, scan):
        """Return ``[{"scanid": ..., "type": ...}]`` for a group's child scans.

        Children whose node or info has already disappeared are skipped.
        """
        children = []
        for child in scan.walk(include_filter="node_ref_channel", wait=False):
            for child_scan in child.get(0, -1):
                if child_scan is None:
                    # This can disappear in between
                    continue
                child_scan_info = child_scan.info.get_all()
                if "node_name" not in child_scan_info:
                    # Info can expire in between as well
                    continue
                children.append(
                    {
                        "scanid": self._create_scanid(child_scan_info["node_name"]),
                        "type": child_scan_info.get("type"),
                    }
                )
        return children

    @staticmethod
    def _apply_paging(sobj, kwargs, points):
        """Compute paging over ``points`` and record it in ``sobj``.

        Sets ``npoints_avail``, ``page``, ``pages`` and ``per_page`` so the
        reported paging always matches the slice that is returned. Returns
        the paging dict from ``get_start_end``.
        """
        paging = get_start_end(kwargs, points=points, last=True)
        sobj["npoints_avail"] = points
        sobj["page"] = paging["page"]
        sobj["pages"] = paging["pages"]
        sobj["per_page"] = paging["per_page"]
        return paging

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get_scans(self, scanid=None, **kwargs):
        """Return the session's scans, newest first.

        Args:
            scanid: when given, only the scan with this id is returned.
            **kwargs: paging arguments forwarded to ``get_start_end``.

        Returns:
            With ``scanid``: the JSON-safe scan dict, or None if not found.
            Otherwise: ``{"total": <number of scans>, "rows": [...]}`` for
            the requested page. An empty list is returned when there is no
            session.
        """
        scans = []

        if self._session is None:
            return scans

        nodes = sorted(
            self._get_scan_nodes(),
            key=lambda node: node.info.get("start_timestamp", 0),
        )
        nodes.reverse()
        paging = get_start_end(kwargs, points=len(nodes))

        # A single scan is searched among all nodes, the listing is paged
        filtered = nodes if scanid else nodes[paging["st"] : paging["en"]]

        for scan in filtered:
            if scanid and self._create_scanid(scan.db_name) != scanid:
                continue

            info = scan.info.get_all()
            try:
                info["node_name"]
            except KeyError:
                logger.exception(f"No node_name for scan {scan.db_name}")
                continue

            if scanid and self._create_scanid(info["node_name"]) != scanid:
                continue

            sobj = {"scanid": self._create_scanid(info["node_name"])}
            for k in [
                "count_time",
                "node_name",
                "npoints",
                "filename",
                "end_timestamp",
                "start_timestamp",
                "title",
                "type",
            ]:
                sobj[k] = info.get(k)

            sobj["status"] = self._get_state(info)
            sobj["shape"] = self._get_shape(info)
            sobj["scan_number"] = info.get("scan_nb")

            if isinstance(scan, GroupScanNode):
                sobj["group"] = True
                sobj["children"] = self._get_group_children(scan)
            else:
                try:
                    sobj["estimated_time"] = info["estimation"]["total_time"]
                except (KeyError, TypeError):
                    sobj["estimated_time"] = 0

                sobj["axes"] = self._get_axes(info)

            scans.append(make_json_safe(sobj))

            if scanid:
                # Only the first match is returned: stop reading Redis
                break

        if scanid:
            return scans[0] if scans else None
        return {"total": len(nodes), "rows": scans}

    def get_scan_data(
        self, scanid, json_safe=True, scalars=None, all_scalars=False, **kwargs
    ):
        """Return a page of channel data for a scan.

        Every channel gets an entry with name/shape/size/dtype. Only
        scalar (empty shape) channels get their data filled in, and only
        when they are selected:
        - axis channels (``xs``) are always selected;
        - ``all_scalars=True`` selects every scalar channel;
        - a non-empty ``scalars`` list selects exactly those names;
        - otherwise, non-timer channels are selected until two non-timer,
          non-axis channels have been processed.

        Data comes from Redis. On failure (or once the HDF5 file has been
        loaded) it comes from the scan's HDF5 file, with paging recomputed
        from the file's point count. ``nan`` is converted to 0 and
        infinities are clamped to +/-1e200.

        Returns:
            The scan data dict, or ``{}`` if there is no session, the scan is
            unknown, or it has no ``node_name``.
        """
        if self._session is None:
            return {}

        # get the expected scan
        scan = next(
            (
                node
                for node in self._get_scan_nodes()
                if self._create_scanid(node.db_name) == scanid
            ),
            None,
        )
        if scan is None:
            return {}

        sobj = {"data": {}, "info": {}}
        info = scan.info.get_all()

        try:
            info["node_name"]
        except KeyError:
            logger.exception(f"No node_name for scan {scan.db_name}")
            return {}

        min_points = self._get_available_points(scan)
        paging = self._apply_paging(sobj, kwargs, min_points)
        sobj["scanid"] = self._create_scanid(scan.db_name)
        sobj["npoints"] = info.get("npoints")
        sobj["shape"] = self._get_shape(info)

        axes = self._get_axes(info)
        xs = axes["xs"]
        sobj["axes"] = axes

        scalarid = 0
        if scalars is None:
            scalars = []
        hdf5_data = None

        for node in scan.walk(include_filter="channel", wait=False):
            name = node.name
            shape = node.info["shape"]
            sobj["data"][name] = {
                "name": name,
                "shape": shape,
                "size": len(node),
                "dtype": np.dtype(node.dtype).str,
            }

            data = np.array([])
            if len(shape) == 0 and paging["en"] > paging["st"]:
                selected = name in scalars or name in xs or all_scalars
                if not selected and (
                    scalars or name.startswith("timer") or scalarid >= 2
                ):
                    continue

                try:
                    if hdf5_data:
                        # The file is loaded: read every remaining channel
                        # from it so they all share the file-based paging
                        raise RuntimeError("Getting all data from hdf5")

                    data = node.get_as_array(paging["st"], paging["en"])
                except (RuntimeError, IndexError, TypeError):
                    if hdf5_data is None:
                        try:
                            hdf5_data, points = get_data_from_file(scan.name, info)
                        except (OSError, RetryTimeoutError):
                            # Empty (falsy) dict: the file is not retried and
                            # later channels still try Redis first
                            hdf5_data = {}
                            logger.exception("Could not read hdf5 file")
                        else:
                            logger.info("Retrieving data from hdf5")
                            logger.debug(f"Available keys: {hdf5_data.keys()}")
                            paging = self._apply_paging(sobj, kwargs, points)
                    if name in hdf5_data:
                        data = hdf5_data[name][paging["st"] : paging["en"]]
                    else:
                        self._log_failure(
                            f"Couldnt get paged scan data for {node.db_name}. Requested {paging['st']} to {paging['en']}, node length {len(node)}"
                        )

            # TODO: convert nan -> None
            # TODO: Make sure the browser doesnt interpret as infinity (1e308)
            data = np.nan_to_num(data, posinf=1e200, neginf=-1e200)
            sobj["data"][name]["data"] = data

            if not name.startswith("timer") and name not in xs:
                scalarid += 1

        # Channels are now deleted from redis after a short TTL (~15min)
        # In the case no channels are available in a scan try to read from h5
        if not sobj["data"]:
            try:
                hdf5_data, points = get_data_from_file(scan.name, info)
                logger.info("Retrieving data from hdf5")
                logger.debug(f"Available keys: {hdf5_data.keys()}")
            except (OSError, RetryTimeoutError):
                logger.exception("Could not read hdf5 file")
            else:
                paging = self._apply_paging(sobj, kwargs, points)
                for channel in hdf5_data.keys():
                    values = hdf5_data[channel]
                    sobj["data"][channel] = {
                        "name": channel,
                        "shape": [],  # scalar is an empty shape
                        "data": np.array([]),
                        "size": len(values),
                        "dtype": values.dtype.str,
                    }

                    if channel in xs or channel in scalars or all_scalars:
                        data = values[paging["st"] : paging["en"]]
                        data = np.nan_to_num(data, posinf=1e200, neginf=-1e200)
                        sobj["data"][channel]["data"] = data

        if json_safe:
            sobj = make_json_safe(sobj)
        return sobj

    def _get_available_points(self, scan):
        """Return the length of the scan's shortest channel (0 if none).

        TODO: This is problematic because len(node) actually has to
        retrieve the data, for a scan with ~2000 ish points it takes
        of the order of 500ms
        """
        import time

        start = time.perf_counter()
        shortest = None
        min_points = math.inf
        for node in scan.walk(include_filter="channel", wait=False):
            # len(node) is expensive (see above): evaluate it once per node
            length = len(node)
            if length < min_points:
                shortest = node.name
                min_points = length

        if min_points == math.inf:
            min_points = 0

        took = time.perf_counter() - start
        logger.debug(
            "_get_available_points %s %s took: %s s", shortest, min_points, took
        )

        return min_points

    @staticmethod
    def _select_spectra(values, point, allpoints):
        """Return all ``values``, or a one-element array with ``values[point]``.

        Returns None when a single point is requested at or beyond
        ``len(values)``.
        """
        if allpoints:
            return values
        if point < len(values):
            return np.array([values[point]])
        return None

    def get_scan_spectra(self, scanid, point=0, allpoints=False):
        """Return the spectra (channels with a non-empty shape) of a scan.

        Args:
            scanid: id of the scan.
            point: index of the point to return when ``allpoints`` is False.
            allpoints: return every point (up to ``npoints``) instead.

        Data comes from Redis, falling back to the scan's HDF5 file.

        Returns:
            A JSON-safe dict with ``scanid``, ``data`` (``{name: {"data",
            "name"}}``), ``npoints``, ``npoints_avail`` and ``conversion``.
            Returns None if the scan is not found or a requested spectrum is
            not available.
        """
        for scan in self._get_scan_nodes():
            if self._create_scanid(scan.db_name) != scanid:
                continue

            info = scan.info.get_all()
            try:
                info["node_name"]
            except KeyError:
                logger.exception(f"No node_name for scan {scan.db_name}")
                continue

            min_points = self._get_available_points(scan)

            hdf5_data = None
            spectra = {}
            for node in scan.walk(include_filter="channel", wait=False):
                # Scalars (empty or missing shape) are not spectra
                if not node.info.get("shape"):
                    continue

                data = np.array([])
                try:
                    if hdf5_data:
                        # The file is loaded: read every remaining spectrum
                        # from it
                        raise RuntimeError("Getting all spectra from hdf5")

                    if allpoints:
                        data = node.get_as_array(
                            0, to_index=int(info.get("npoints", 0))
                        )
                    else:
                        data = np.array([node.get_as_array(point)])
                except (RuntimeError, IndexError, TypeError):
                    if hdf5_data is None:
                        try:
                            hdf5_data, _points = get_data_from_file(
                                scan.name, info, type="spectrum"
                            )
                            logger.info(
                                f"Retrieving data from hdf5: {hdf5_data.keys()}"
                            )
                        except (OSError, RetryTimeoutError):
                            # Empty (falsy) dict: not retried, and later
                            # channels still try Redis first
                            hdf5_data = {}
                            logger.exception("Could not read hdf5 file")

                    if node.name not in hdf5_data:
                        self._log_failure(
                            f"Couldnt get scan spectra for {node.db_name}. Requested 0 to {info.get('npoints')}, node length {len(node)}"
                        )
                        return None

                    data = self._select_spectra(
                        hdf5_data[node.name], point, allpoints
                    )
                    if data is None:
                        self._log_failure(
                            f"Couldnt get scan spectra for {node.db_name}. Requested point {point} outside range {len(hdf5_data[node.name])}"
                        )
                        return None

                spectra[node.name] = {"data": data, "name": node.name}

            # Channels are now deleted from redis after a short TTL (~15min)
            # In the case no channels are available in a scan try to read from h5
            if not spectra:
                try:
                    hdf5_data, _ = get_data_from_file(scan.name, info, type="spectrum")
                except (OSError, RetryTimeoutError):
                    logger.exception("Could not read hdf5 file")
                else:
                    for channel in hdf5_data.keys():
                        data = self._select_spectra(
                            hdf5_data[channel], point, allpoints
                        )
                        if data is None:
                            self._log_failure(
                                f"Couldnt get scan spectra for {channel}. Requested point {point} outside range {len(hdf5_data[channel])}"
                            )
                            return None

                        spectra[channel] = {"data": data, "name": channel}

            return make_json_safe(
                {
                    "scanid": scanid,
                    "data": spectra,
                    "npoints": info.get("npoints"),
                    "npoints_avail": min_points,
                    "conversion": self.get_conversion(),
                }
            )

        return None

    def get_scan_image(self, scanid, node_name, image_no):
        """Return image ``image_no`` of channel ``node_name`` in scan ``scanid``.

        Lima channels are searched first, then plain channels with a 2D shape.

        Raises:
            RuntimeError: if no matching scan, channel or image is found.
        """
        for scan in self._get_scan_nodes():
            if self._create_scanid(scan.db_name) != scanid:
                continue

            info = scan.info.get_all()
            try:
                info["node_name"]
            except KeyError:
                logger.exception(f"No node_name for scan {scan.db_name}")
                continue

            for node in scan.walk(include_filter="lima", wait=False):
                if node_name != node.name:
                    continue

                view = node.get(image_no, image_no)
                return view.get_image(image_no)

            for node in scan.walk(include_filter="channel", wait=False):
                if node_name != node.name:
                    continue
                shape = node.info.get("shape")
                # A channel without a shape cannot be an image
                if shape is None or len(shape) != 2:
                    continue
                return node.get_as_array(image_no)

        raise RuntimeError(
            f"Data scanid={scanid} node_name={node_name} image_no={image_no} not found"
        )
568 )