#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import annotations
from typing import Any, Dict
import dateutil.parser

import gevent
import numpy
import logging
import lru
import time

from blissdata.scan import Scan
from blissdata.redis_engine.exceptions import NoScanAvailable
from blissdata.beacon.data import BeaconData
from blissdata.redis_engine.store import DataStore
from blissdata.redis_engine.exceptions import EndOfStream
from blissdata.redis_engine.scan import ScanState
from daiquiri.core.hardware.abstract.scansource import ScanSource, ScanStates
from daiquiri.core.utils import get_start_end, make_json_safe

# blissdata 2
# https://gitlab.esrf.fr/bliss/bliss/-/commit/dfbe24b9bee401eabd296664466df93b6f9844d1
try:
    from blissdata.streams.base import CursorGroup

    use_blissdata_2_api = True
# blissdata < 2
except ModuleNotFoundError:
    from blissdata.redis_engine.stream import StreamingClient
    from blissdata.redis_engine.scan import Scan as RedisScan

    use_blissdata_2_api = False

_logger = logging.getLogger(__name__)
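
# Mappings from blissdata scan states / end reasons to indices into the
# abstract ScanStates sequence (see _get_state below).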
STATE_MAPPING = {
    ScanState.CREATED: 0,
    ScanState.PREPARED: 1,
    ScanState.STARTED: 2,
    ScanState.STOPPED: 3,
    ScanState.CLOSED: 5,
}

TERMINATION_MAPPING = {
    "SUCCESS": 3,
    "USER_ABORT": 4,
    "FAILURE": 5,
    "DELETION": 5,
}


class BlissdataScans(ScanSource):
    UPDATE_PERIOD = 0.1
    """Period limiting the update rate, in seconds."""

    def __init__(self, config: dict, app=None, source_config={}):
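        """Connect to the blissdata Redis store (from `source_config` or,
        failing that, Beacon) and spawn the scan-listening greenlet."""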
        super().__init__(config, app=app, source_config=source_config)
        self._key_mapping: lru.LRU = lru.LRU(256)

        url = source_config.get("redis_url")
        if not url:
            url = BeaconData().get_redis_data_db()
        self._store: DataStore = DataStore(url)

        _logger.info("Initialising Blissdata Scan Watcher")

        self._service_running: bool = True
        self._task = gevent.spawn(self._listen_scans)
        self._task.link_exception(self._gevent_exception)

    def close(self):
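        """Stop the scan-listening greenlet and release the data store."""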
        self._service_running = False
        self._task.join()
        self._task = None
        self._store = None

    def _gevent_exception(self, greenlet):
        """Log the exception raised by a failed greenlet."""
        try:
            greenlet.get()
        except Exception:
            _logger.error(
                "Error while executing greenlet %s", greenlet.name, exc_info=True
            )

    def _listen_scans(self):
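        """Poll the data store for new scans and, for each scan belonging to
        this session, spawn a greenlet that follows its state."""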
        timestamp_cursor = None
        while self._service_running:
            try:
                timestamp_cursor, scan_key = self._store.get_next_scan(
                    since=timestamp_cursor, block=True, timeout=1
                )
            except NoScanAvailable:
                continue
            self._key_mapping[self._create_scanid(scan_key)] = scan_key
            scan = self._store.load_scan(scan_key)
            if scan.session == self._session_name:
                # Only watch scans from this session
                # FIXME: it would be good to close this greenlet safely at service termination
                g = gevent.spawn(self._listen_scan_state, scan)
                g.link_exception(self._gevent_exception)

    def _listen_scan_state(self, scan: Scan):
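        """Follow a scan's lifecycle: emit the start event, stream data
        events while the scan runs, then emit the end event."""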
        while scan.state < ScanState.STARTED:
            scan.update()

        self._scan_start(scan.key, scan.info)
        _data_greenlet = None
        try:
            if use_blissdata_2_api:
                _data_greenlet = gevent.spawn(self._listen_scan_streams2, scan)
            else:
                _data_greenlet = gevent.spawn(self._listen_scan_streams, scan)
            _data_greenlet.link_exception(self._gevent_exception)
            while scan.state < ScanState.CLOSED:
                scan.update()
        finally:
            if _data_greenlet is not None:
                _data_greenlet.join()
                _data_greenlet = None
            self._scan_end(scan.key, scan.info)

    def _listen_scan_streams2(self, scan: Scan):
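        """Read stream progress through the blissdata >= 2 CursorGroup API
        and emit scan data events until the streams end."""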
        raw_scan = self._store.load_scan(scan.key)
        cursor_group = CursorGroup(raw_scan.streams)
        scanid = self._create_scanid(scan.key)

        info = scan.info
        expected_scan_size = info.get("scan_info", {}).get("npoints", 1)

        while self._service_running:
            try:
                views = cursor_group.read(last_only=True)
            except EndOfStream:
                break

            for stream, view in views.items():
                channel_name = stream.name
                if stream.kind == "array":
                    ndim = len(stream.info["shape"])
                    channel_size = view.index + len(view)
                    if ndim == 0:
                        self._emit_new_scan_data_0d_event(
                            scanid,
                            channel_name=channel_name,
                            channel_size=channel_size,
                            continue_=True,
                        )
                elif stream.kind == "scan":
                    # Scan sequences should be handled at some point
                    continue
                else:
                    _logger.warning("Unsupported stream kind %s", stream.kind)
                    continue

                min_points = self._get_available_points(scan)

                try:
                    progress = 100 * min_points / expected_scan_size
                except ZeroDivisionError:
                    progress = 0

                try:
                    channel_progress = 100 * channel_size / expected_scan_size
                except ZeroDivisionError:
                    channel_progress = 0

                self._emit_new_scan_data_event(
                    scanid,
                    "root",
                    progress,
                    channel_name,
                    channel_size,
                    channel_progress,
                )

            # Rate limit requests to blissdata
            time.sleep(0.2)

    def _listen_scan_streams(self, scan: Scan):
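        """Read stream progress through the blissdata < 2 StreamingClient
        API and emit scan data events until the streams end."""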
        # Pure redis streams to enable parallel reading
        raw_scan = self._store.load_scan(scan.key, RedisScan)
        client = StreamingClient(raw_scan.streams)
        scanid = self._create_scanid(scan.key)

        info = scan.info
        expected_scan_size = info.get("scan_info", {}).get("npoints", 1)

        while self._service_running:
            try:
                data = client.read(count=-1)
            except EndOfStream:
                break

            for stream, entries in data.items():
                channel_name = stream.name
                encoding_type = stream.encoding["type"]
                index, messages = entries
                if encoding_type == "numeric":
                    ndim = len(stream.encoding["shape"])
                    channel_size = index + len(messages)
                    if ndim == 0:
                        self._emit_new_scan_data_0d_event(
                            scanid,
                            channel_name=channel_name,
                            channel_size=channel_size,
                            continue_=True,
                        )
                elif encoding_type == "json":
                    json_format = stream.info.get("format")
                    if json_format in ("lima_v1", "lima_v2"):
                        channel_size = messages[-1].get("last_index") + 1
                    elif json_format == "subscan":
                        # This should be handled at some point
                        continue
                    else:
                        _logger.warning("Unsupported json format type %s", json_format)
                        continue
                else:
                    _logger.warning("Unsupported encoding type %s", encoding_type)
                    continue

                min_points = self._get_available_points(scan)

                try:
                    progress = 100 * min_points / expected_scan_size
                except ZeroDivisionError:
                    progress = 0

                try:
                    channel_progress = 100 * channel_size / expected_scan_size
                except ZeroDivisionError:
                    channel_progress = 0

                self._emit_new_scan_data_event(
                    scanid,
                    "root",
                    progress,
                    channel_name,
                    channel_size,
                    channel_progress,
                )

            # Rate limit requests to blissdata
            time.sleep(0.2)

    def _scan_start(self, scan_key: str, info: dict):
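        """Emit the new-scan event for a freshly started scan."""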
        scanid = self._create_scanid(scan_key)
        self._emit_new_scan_event(
            scanid, info.get("type", "none"), info["title"], metadata=info
        )

    def _scan_end(self, scan_key: str, info: dict):
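        """Emit the end-scan event for a closed scan."""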
        scanid = self._create_scanid(scan_key)
        self._emit_end_scan_event(scanid, metadata=info)

    def _get_state(self, state: ScanState, scan_info: dict):
        """Convert a blissdata scan state to the abstract scan state"""
        if state is None:
            return "UNKNOWN"
        if state == ScanState.CLOSED:
            termination = scan_info.get("end_reason", None)
            iresult = TERMINATION_MAPPING.get(termination)
            if iresult is None:
                _logger.error(
                    f"No termination mapping for scan end reason {termination}"
                )
                return "UNKNOWN"
        else:
            iresult = STATE_MAPPING.get(state)
            if iresult is None:
                _logger.error(f"No state mapping for scan state {state}")
                return "UNKNOWN"
        return ScanStates[iresult]

    def _get_shape(self, info):
        """Return the scan shape

        This is very bliss-specific; maybe generalise it in the future
        """
        shape = {}
        for k in ["npoints1", "npoints2", "dim", "requests"]:
            shape[k] = info.get(k)
        return shape

    def get_scans(self, scanid: int | None = None, **kwargs):
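        """Return the scan list for this session, or a single scan when
        `scanid` is given. The list is paged via `kwargs`."""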
        def scan_to_result(scan: Scan):
            info = scan.info
            sobj: dict = {"scanid": self._create_scanid(scan.key)}
            for k in [
                "count_time",
                "npoints",
                "filename",
                "title",
                "type",
            ]:
                sobj[k] = info.get(k)

            sobj["scan_number"] = info.get("scan_nb")
            sobj["status"] = self._get_state(scan.state, info)
            sobj["shape"] = self._get_shape(info)
            sobj["start_timestamp"] = (
                dateutil.parser.parse(info["start_time"]).timestamp()
                if info.get("start_time")
                else None
            )
            sobj["end_timestamp"] = (
                dateutil.parser.parse(info["end_time"]).timestamp()
                if info.get("end_time")
                else None
            )

            if info.get("is_scan_sequence", False):
                sobj["group"] = True

                children: list[dict] = []
                """
                # FIXME: This can be reached from a dedicated channel
                for child in scan.walk(include_filter="node_ref_channel", wait=False):
                    child_scans = child.get(0, -1)

                    for child_scan in child_scans:
                        if child_scan is None:
                            # This can disappear in between
                            continue
                        child_scan_info = child_scan.info.get_all()
                        children.append(
                            {
                                "scanid": self._scanid(child_scan_info["node_name"]),
                                "type": child_scan_info["type"],
                            }
                        )
                """
                sobj["children"] = children

            else:
                try:
                    sobj["estimated_time"] = info["estimation"]["total_time"]
                except KeyError:
                    sobj["estimated_time"] = 0

            xs = []
            ys: dict = {"images": [], "scalars": [], "spectra": []}
            for el in info["acquisition_chain"].values():
                mast = el["master"]
                for t in ["images", "scalars", "spectra"]:
                    xs.extend(mast[t])

                for t in ["images", "scalars", "spectra"]:
                    ys[t].extend(el[t])

            sobj["axes"] = {"xs": xs, "ys": ys}
            return sobj

        if scanid is not None:
            # Special case for a single scan
            # FIXME: A dedicated `ScanSource.get_scan` API should be exposed instead
            scan = self._get_scan_from_scanid(scanid)
            if scan is None:
                return None
            sobj = scan_to_result(scan)
            return make_json_safe(sobj)

        _timestamp, scan_keys = self._store.search_existing_scans(
            session=self._session_name
        )
        # NOTE: sorting the keys is equivalent to sorting by creation datetime
        # NOTE: the `reversed` is copied from the bliss connector
        scan_keys = list(reversed(sorted(scan_keys)))
        paging = get_start_end(kwargs, points=len(scan_keys))
        filtered = scan_keys[paging["st"] : paging["en"]]

        scans = []
        for scan_key in filtered:
            scan = self._store.load_scan(scan_key)
            sobj = scan_to_result(scan)
            scans.append(make_json_safe(sobj))
        return {"total": len(filtered), "rows": scans}

    def _get_scankey_from_scanid(self, scanid: int) -> str | None:
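        """Resolve a daiquiri scanid to a blissdata scan key, using the LRU
        cache first and falling back to a search of existing scans."""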
        if scanid not in self._key_mapping:
            # FIXME: it would be better to use scan_key in the daiquiri client
            _timestamp, scan_keys = self._store.search_existing_scans(
                session=self._session_name
            )
            for scan_key in scan_keys:
                # We could use all of these keys to update the cache
                if self._create_scanid(scan_key) == scanid:
                    break
            else:
                # Not caching the None result could be a problem
                return None
            self._key_mapping[scanid] = scan_key
        else:
            scan_key = self._key_mapping[scanid]

        return scan_key

    def _get_scan_from_scanid(self, scanid: int) -> Scan | None:
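        """Load the scan for a daiquiri scanid, or None if it cannot be
        resolved."""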
        scan_key = self._get_scankey_from_scanid(scanid)
        if scan_key is None:
            return None
        return self._store.load_scan(scan_key)

    def get_scan_data(
        self, scanid, json_safe=True, scalars=None, all_scalars=False, **kwargs
    ):
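        """Return a page of scan data for `scanid`: per-stream metadata for
        every channel, plus the actual values for the selected 0d scalars."""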
        scan = self._get_scan_from_scanid(scanid)
        if scan is None:
            return {}

        sobj: Dict[str, Any] = {"data": {}, "info": {}}
        info = scan.info

        min_points = self._get_available_points(scan)
        sobj["npoints_avail"] = min_points

        paging = get_start_end(kwargs, points=min_points, last=True)
        sobj["page"] = paging["page"]
        sobj["pages"] = paging["pages"]
        sobj["per_page"] = paging["per_page"]
        sobj["scanid"] = scanid
        sobj["npoints"] = info.get("npoints")
        sobj["shape"] = self._get_shape(info)

        xs = []
        ys: dict = {"images": [], "scalars": [], "spectra": []}
        for el in info["acquisition_chain"].values():
            mast = el["master"]
            for t in ["images", "scalars", "spectra"]:
                xs.extend(mast[t])

            for t in ["images", "scalars", "spectra"]:
                ys[t].extend(el[t])

        sobj["axes"] = {"xs": xs, "ys": ys}

        scalarid = 0
        if scalars is None:
            scalars = []

        for stream in scan.streams.values():
            if stream.name == "SUBSCANS":
                # That's a special stream, better to skip it
                continue
            sobj["data"][stream.name] = {
                "name": stream.name,
                "shape": stream.info["shape"],
                "size": len(stream),
                "dtype": numpy.dtype(stream.info["dtype"]).str,
            }

            data = numpy.array([])
            if len(stream.info["shape"]) == 0 and paging["en"] > paging["st"]:
                if (
                    stream.name not in scalars
                    and stream.name not in xs
                    and not all_scalars
                    and (
                        (
                            len(scalars) == 0
                            and (stream.name.startswith("timer") or scalarid >= 2)
                        )
                        or len(scalars) > 0
                    )
                ):
                    continue

                # NOTE: The data size can be smaller than what is requested
                data = stream[paging["st"] : paging["en"]]

            # TODO: convert nan -> None
            # TODO: Make sure the browser doesn't interpret this as infinity (1e308)
            data = numpy.nan_to_num(data, posinf=1e200, neginf=-1e200)
            sobj["data"][stream.name]["data"] = data

            if not stream.name.startswith("timer") and stream.name not in xs:
                scalarid += 1

        if json_safe:
            sobj = make_json_safe(sobj)
        return sobj

    def _get_available_points(self, scan: Scan) -> int:
        """Return the smallest number of points available across all of the
        scan's streams.

        TODO: This is problematic because len(stream) actually has to
        retrieve the data; for a scan with ~2000 points it takes on the
        order of 500 ms.
        """
        min_points = None
        for stream in scan.streams.values():
            nb = len(stream)
            if min_points is None or nb < min_points:
                min_points = nb

        if min_points is None:
            return 0
        return min_points

    def get_scan_spectra(self, scanid, point=0, allpoints=False):
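        """Return the 1d (spectrum) channels of a scan, either a single
        point or all points."""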
        log_failure = _logger.exception if self._app.debug else _logger.info
        scan = self._get_scan_from_scanid(scanid)
        if scan is None:
            return None

        info = scan.info

        min_points = self._get_available_points(scan)

        spectra = {}
        for stream in scan.streams.values():
            shape = stream.info.get("shape")
            if shape is None:
                continue

            if len(shape) == 1:
                data = numpy.array([])
                try:
                    data = (
                        stream[0 : int(info.get("npoints", 0))]
                        if allpoints
                        else numpy.array([stream[point]])
                    )
                except (RuntimeError, IndexError, TypeError):
                    log_failure(
                        f"Couldn't get scan spectra for {stream.name}. Requested 0 to {info.get('npoints')}, stream length {len(stream)}"
                    )
                    return None

                spectra[stream.name] = {"data": data, "name": stream.name}

        return make_json_safe(
            {
                "scanid": scanid,
                "data": spectra,
                "npoints": info.get("npoints"),
                "npoints_avail": min_points,
                "conversion": self.get_conversion(),
            }
        )

    def get_scan_image(self, scanid: int, node_name: str, image_no: int):
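        """Return a single frame from an image stream, or None if the scan
        or stream cannot be found."""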
        scan = self._get_scan_from_scanid(scanid)
        if scan is None:
            return None
        stream = scan.streams.get(node_name)
        if stream is None:
            return None
        return stream[image_no]