Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/hardware/blissdata/scans.py: 68%
324 statements
coverage.py v7.6.10, created at 2025-02-06 02:13 +0000

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import annotations
from typing import Any, Dict
import dateutil.parser

import gevent
import numpy
import logging
import lru
import time

from blissdata.scan import Scan
from blissdata.redis_engine.exceptions import NoScanAvailable
from blissdata.beacon.data import BeaconData
from blissdata.redis_engine.store import DataStore
from blissdata.redis_engine.exceptions import EndOfStream
from blissdata.redis_engine.scan import ScanState
from daiquiri.core.hardware.abstract.scansource import ScanSource, ScanStates
from daiquiri.core.utils import get_start_end, make_json_safe

# blissdata 2
# https://gitlab.esrf.fr/bliss/bliss/-/commit/dfbe24b9bee401eabd296664466df93b6f9844d1
try:
    from blissdata.streams.base import CursorGroup

    use_blissdata_2_api = True
# blissdata < 2
except ModuleNotFoundError:
    from blissdata.redis_engine.stream import StreamingClient
    from blissdata.redis_engine.scan import Scan as RedisScan

    use_blissdata_2_api = False

_logger = logging.getLogger(__name__)
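
# Integer values are indices into daiquiri's abstract ScanStates; CLOSED scans
# are instead mapped via TERMINATION_MAPPING from the scan's `end_reason`
# (see `_get_state` below)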
STATE_MAPPING = {
    ScanState.CREATED: 0,
    ScanState.PREPARED: 1,
    ScanState.STARTED: 2,
    ScanState.STOPPED: 3,
    ScanState.CLOSED: 5,
}

TERMINATION_MAPPING = {
    "SUCCESS": 3,
    "USER_ABORT": 4,
    "FAILURE": 5,
    "DELETION": 5,
}


class BlissdataScans(ScanSource):
    UPDATE_PERIOD = 0.1
    """Limiting update rate."""

    def __init__(self, config: dict, app=None, source_config: dict | None = None):
        if source_config is None:
            source_config = {}
        super().__init__(config, app=app, source_config=source_config)
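        # LRU cache mapping daiquiri scan ids (derived from scan keys by
        # `_create_scanid` on the ScanSource base class) back to Redis scan keys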
        self._key_mapping: lru.LRU = lru.LRU(256)

        url = source_config.get("redis_url")
        if not url:
            url = BeaconData().get_redis_data_db()
        self._store: DataStore = DataStore(url)

        _logger.info("Initialising Blissdata Scan Watcher")

        self._service_running: bool = True
        self._task = gevent.spawn(self._listen_scans)
        self._task.link_exception(self._gevent_exception)

    def close(self):
        self._service_running = False
        self._task.join()
        self._task = None
        self._store = None

    def _gevent_exception(self, greenlet):
        """Process gevent exception"""
        try:
            greenlet.get()
        except Exception:
            _logger.error(
                "Error while executing greenlet %s", greenlet.name, exc_info=True
            )

    def _listen_scans(self):
        timestamp_cursor = None
        while self._service_running:
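            # Block for at most 1 second per call so the loop can notice
            # `_service_running` being cleared at shutdown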
            try:
                timestamp_cursor, scan_key = self._store.get_next_scan(
                    since=timestamp_cursor, block=True, timeout=1
                )
            except NoScanAvailable:
                continue
            self._key_mapping[self._create_scanid(scan_key)] = scan_key
            scan = self._store.load_scan(scan_key)
            if scan.session == self._session_name:
                # Only watch scans from this session
                # FIXME: would be good to close this greenlet safely at service termination
                g = gevent.spawn(self._listen_scan_state, scan)
                g.link_exception(self._gevent_exception)

    def _listen_scan_state(self, scan: Scan):
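        # Wait until the scan has actually started before emitting the start
        # event; scan.update() refreshes the scan state from Redis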
        while scan.state < ScanState.STARTED:
            scan.update()

        self._scan_start(scan.key, scan.info)
        _data_greenlet = None
        try:
            if use_blissdata_2_api:
                _data_greenlet = gevent.spawn(self._listen_scan_streams2, scan)
            else:
                _data_greenlet = gevent.spawn(self._listen_scan_streams, scan)
            _data_greenlet.link_exception(self._gevent_exception)
            while scan.state < ScanState.CLOSED:
                scan.update()
        finally:
            if _data_greenlet is not None:
                _data_greenlet.join()
                _data_greenlet = None
            self._scan_end(scan.key, scan.info)

    def _listen_scan_streams2(self, scan: Scan):
        raw_scan = self._store.load_scan(scan.key)
        cursor_group = CursorGroup(raw_scan.streams)
        scanid = self._create_scanid(scan.key)

        info = scan.info
        expected_scan_size = info.get("scan_info", {}).get("npoints", 1)
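
        # `last_only=True`: only the most recent view per stream is needed,
        # since the channel size is computed from the latest index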
        while self._service_running:
            try:
                views = cursor_group.read(last_only=True)
            except EndOfStream:
                break

            for stream, view in views.items():
                channel_name = stream.name
                if stream.kind == "array":
                    ndim = len(stream.info["shape"])
                    channel_size = view.index + len(view)
                    if ndim == 0:
                        self._emit_new_scan_data_0d_event(
                            scanid,
                            channel_name=channel_name,
                            channel_size=channel_size,
                            continue_=True,
                        )
                elif stream.kind == "scan":
                    # Scan sequences should be handled at some point
                    continue
                else:
                    _logger.warning("Unsupported stream kind %s", stream.kind)
                    continue
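
                # Overall progress follows the slowest stream: the number of
                # complete points is the minimum length across all streams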
                min_points = self._get_available_points(scan)

                try:
                    progress = 100 * min_points / expected_scan_size
                except ZeroDivisionError:
                    progress = 0

                try:
                    channel_progress = 100 * channel_size / expected_scan_size
                except ZeroDivisionError:
                    channel_progress = 0

                self._emit_new_scan_data_event(
                    scanid,
                    "root",
                    progress,
                    channel_name,
                    channel_size,
                    channel_progress,
                )

            # Rate limit requests to blissdata
            time.sleep(0.2)

    def _listen_scan_streams(self, scan: Scan):
        # Pure redis streams to enable parallel reading
        raw_scan = self._store.load_scan(scan.key, RedisScan)
        client = StreamingClient(raw_scan.streams)
        scanid = self._create_scanid(scan.key)

        info = scan.info
        expected_scan_size = info.get("scan_info", {}).get("npoints", 1)

        while self._service_running:
            try:
                data = client.read(count=-1)
            except EndOfStream:
                break
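
            # Each entry value is a tuple of (index of the first new message,
            # list of messages); the channel size is that index plus the
            # batch length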
            for stream, entries in data.items():
                channel_name = stream.name
                encoding_type = stream.encoding["type"]
                index, messages = entries
                if encoding_type == "numeric":
                    ndim = len(stream.encoding["shape"])
                    channel_size = index + len(messages)
                    if ndim == 0:
                        self._emit_new_scan_data_0d_event(
                            scanid,
                            channel_name=channel_name,
                            channel_size=channel_size,
                            continue_=True,
                        )
                elif encoding_type == "json":
                    json_format = stream.info.get("format")
                    if json_format in ("lima_v1", "lima_v2"):
                        channel_size = messages[-1].get("last_index") + 1
                    elif json_format == "subscan":
                        # This should be handled at some point
                        continue
                    else:
                        _logger.warning("Unsupported json format type %s", json_format)
                        continue
                else:
                    _logger.warning("Unsupported encoding type %s", encoding_type)
                    continue

                min_points = self._get_available_points(scan)

                try:
                    progress = 100 * min_points / expected_scan_size
                except ZeroDivisionError:
                    progress = 0

                try:
                    channel_progress = 100 * channel_size / expected_scan_size
                except ZeroDivisionError:
                    channel_progress = 0

                self._emit_new_scan_data_event(
                    scanid,
                    "root",
                    progress,
                    channel_name,
                    channel_size,
                    channel_progress,
                )

            # Rate limit requests to blissdata
            time.sleep(0.2)

    def _scan_start(self, scan_key: str, info: dict):
        scanid = self._create_scanid(scan_key)
        self._emit_new_scan_event(
            scanid, info.get("type", "none"), info["title"], metadata=info
        )

    def _scan_end(self, scan_key: str, info: dict):
        scanid = self._create_scanid(scan_key)
        self._emit_end_scan_event(scanid, metadata=info)

    def _get_state(self, state: ScanState, scan_info: dict):
        """Convert blissdata scan states to the abstract scan state"""
        if state is None:
            return "UNKNOWN"
        if state == ScanState.CLOSED:
            termination = scan_info.get("end_reason", None)
            iresult = TERMINATION_MAPPING.get(termination)
            if iresult is None:
                _logger.error(
                    "No termination mapping for scan end reason %s", termination
                )
                return "UNKNOWN"
        else:
            iresult = STATE_MAPPING.get(state)
            if iresult is None:
                _logger.error("No state mapping for scan state %s", state)
                return "UNKNOWN"
        return ScanStates[iresult]

    def _get_shape(self, info):
        """Return the scan shape

        This is very bliss specific, maybe generalise in the future
        """
        shape = {}
        for k in ["npoints1", "npoints2", "dim", "requests"]:
            shape[k] = info.get(k)
        return shape

    def get_scans(self, scanid: int | None = None, **kwargs):
        def scan_to_result(scan: Scan):
            info = scan.info
            sobj: dict = {"scanid": self._create_scanid(scan.key)}
            for k in [
                "count_time",
                "npoints",
                "filename",
                "title",
                "type",
            ]:
                sobj[k] = info.get(k)

            sobj["scan_number"] = info.get("scan_nb")
            sobj["status"] = self._get_state(scan.state, info)
            sobj["shape"] = self._get_shape(info)
            sobj["start_timestamp"] = (
                dateutil.parser.parse(info["start_time"]).timestamp()
                if info.get("start_time")
                else None
            )
            sobj["end_timestamp"] = (
                dateutil.parser.parse(info["end_time"]).timestamp()
                if info.get("end_time")
                else None
            )

            if info.get("is_scan_sequence", False):
                sobj["group"] = True

                children: list[dict] = []
                child_channel_values = list(scan.streams["SUBSCANS"])
                for child_scan in child_channel_values:
                    # In bliss 2.2, blissdata automatically resolves child scans
                    # to their `Scan` object
                    if not isinstance(child_scan, Scan):
                        child_scan = self._store.load_scan(child_scan["key"])
                    children.append(
                        {
                            "scanid": self._create_scanid(child_scan.key),
                            "type": child_scan.info.get("type"),
                            "node": child_scan.key,
                        }
                    )
                sobj["children"] = children

            else:
                try:
                    sobj["estimated_time"] = info["estimation"]["total_time"]
                except KeyError:
                    sobj["estimated_time"] = 0

            xs = []
            ys: dict = {"images": [], "scalars": [], "spectra": []}
            for el in info["acquisition_chain"].values():
                mast = el["master"]
                for t in ["images", "scalars", "spectra"]:
                    xs.extend(mast.get(t, []))

                for t in ["images", "scalars", "spectra"]:
                    ys[t].extend(el.get(t, []))

            sobj["axes"] = {"xs": xs, "ys": ys}
            return sobj

        if scanid is not None:
            # Special case for a single scan
            # FIXME: A dedicated `ScanSource.get_scan` API should be exposed instead
            scan = self._get_scan_from_scanid(scanid)
            if scan is None:
                return None
            sobj = scan_to_result(scan)
            return make_json_safe(sobj)

        _timestamp, scan_keys = self._store.search_existing_scans(
            session=self._session_name
        )
        # NOTE: Sorting the keys is equivalent to sorting by creation time
        # NOTE: The `reversed` is copied from the bliss connector
        scan_keys = list(reversed(sorted(scan_keys)))
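        # `get_start_end` converts the page/per_page style paging kwargs into
        # start/end slice indices over the sorted key list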
        paging = get_start_end(kwargs, points=len(scan_keys))
        filtered = scan_keys[paging["st"] : paging["en"]]

        scans = []
        for scan_key in filtered:
            scan = self._store.load_scan(scan_key)
            sobj = scan_to_result(scan)
            scans.append(make_json_safe(sobj))
        return {"total": len(filtered), "rows": scans}

    def _get_scankey_from_scanid(self, scanid: int) -> str | None:
        if scanid not in self._key_mapping:
            # FIXME: It would be better to use scan_key in the daiquiri client
            _timestamp, scan_keys = self._store.search_existing_scans(
                session=self._session_name
            )
            for scan_key in scan_keys:
                # We could use all these keys to update the cache
                if self._create_scanid(scan_key) == scanid:
                    break
            else:
                # Not caching the None result here could be a problem
                return None
            self._key_mapping[scanid] = scan_key
        else:
            scan_key = self._key_mapping[scanid]

        return scan_key

    def _get_scan_from_scanid(self, scanid: int) -> Scan | None:
        scan_key = self._get_scankey_from_scanid(scanid)
        if scan_key is None:
            return None
        return self._store.load_scan(scan_key)

    def get_scan_data(
        self, scanid, json_safe=True, scalars=None, all_scalars=False, **kwargs
    ):
        scan = self._get_scan_from_scanid(scanid)
        if scan is None:
            return {}

        sobj: Dict[str, Any] = {"data": {}, "info": {}}
        info = scan.info

        min_points = self._get_available_points(scan)
        sobj["npoints_avail"] = min_points

        paging = get_start_end(kwargs, points=min_points, last=True)
        sobj["page"] = paging["page"]
        sobj["pages"] = paging["pages"]
        sobj["per_page"] = paging["per_page"]
        sobj["scanid"] = scanid
        sobj["npoints"] = info.get("npoints")
        sobj["shape"] = self._get_shape(info)

        xs = []
        ys: dict = {"images": [], "scalars": [], "spectra": []}
        for el in info["acquisition_chain"].values():
            mast = el["master"]
            for t in ["images", "scalars", "spectra"]:
                xs.extend(mast.get(t, []))

            for t in ["images", "scalars", "spectra"]:
                ys[t].extend(el.get(t, []))

        sobj["axes"] = {"xs": xs, "ys": ys}

        scalarid = 0
        if scalars is None:
            scalars = []

        for stream in scan.streams.values():
            if stream.name == "SUBSCANS":
                # That's a special stream, better to skip it
                continue
            sobj["data"][stream.name] = {
                "name": stream.name,
                "shape": stream.info.get("shape"),
                "size": len(stream),
                "dtype": (
                    numpy.dtype(stream.info["dtype"]).str
                    if stream.info.get("dtype")
                    else None
                ),
            }

            data = numpy.array([])
            if len(stream.info["shape"]) == 0 and paging["en"] > paging["st"]:
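                # Skip scalar channels that were not requested: if an explicit
                # `scalars` list is given, only those (plus the x axes) are
                # returned; otherwise timers are skipped and only the first
                # two non-timer scalars are included by default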
                if (
                    stream.name not in scalars
                    and stream.name not in xs
                    and not all_scalars
                    and (
                        (
                            len(scalars) == 0
                            and (stream.name.startswith("timer") or scalarid >= 2)
                        )
                        or len(scalars) > 0
                    )
                ):
                    continue

                # NOTE: The data size could be smaller than what is requested
                data = stream[paging["st"] : paging["en"]]

            # TODO: convert nan -> None
            # TODO: Make sure the browser doesn't interpret it as infinity (1e308)
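            # Replace NaN and clamp ±inf to large finite values so the JSON
            # encoder emits numbers the browser can parse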
            data = numpy.nan_to_num(data, posinf=1e200, neginf=-1e200)
            sobj["data"][stream.name]["data"] = data

            if not stream.name.startswith("timer") and stream.name not in xs:
                scalarid += 1

        if json_safe:
            sobj = make_json_safe(sobj)
        return sobj

    def _get_available_points(self, scan: Scan) -> int:
        """TODO: This is problematic because len(stream) actually has to
        retrieve the data; for a scan with ~2000 points it takes on the
        order of 500ms
        """
        min_points = None
        for stream in scan.streams.values():
            nb = len(stream)
            if min_points is None or nb < min_points:
                min_points = nb

        if min_points is None:
            return 0
        return min_points

    def get_scan_spectra(self, scanid, point=0, allpoints=False):
        log_failure = _logger.exception if self._app.debug else _logger.info
        scan = self._get_scan_from_scanid(scanid)
        if scan is None:
            return None

        info = scan.info

        min_points = self._get_available_points(scan)

        spectra = {}
        for stream in scan.streams.values():
            shape = stream.info.get("shape")
            if shape is None:
                continue
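
            # Only 1D streams (e.g. MCA channels) are treated as spectra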
            if len(shape) == 1:
                data = numpy.array([])
                try:
                    data = (
                        stream[0 : int(info.get("npoints", 0))]
                        if allpoints
                        else numpy.array([stream[point]])
                    )
                except (RuntimeError, IndexError, TypeError):
                    log_failure(
                        f"Couldn't get scan spectra for {stream.name}. "
                        f"Requested 0 to {info.get('npoints')}, stream length {len(stream)}"
                    )
                    return None

                spectra[stream.name] = {"data": data, "name": stream.name}

        return make_json_safe(
            {
                "scanid": scanid,
                "data": spectra,
                "npoints": info.get("npoints"),
                "npoints_avail": min_points,
                "conversion": self.get_conversion(),
            }
        )
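
    # Fetch a single frame by index from an image stream (e.g. a lima detector)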
    def get_scan_image(self, scanid: int, node_name: str, image_no: int):
        scan = self._get_scan_from_scanid(scanid)
        if scan is None:
            return None
        stream = scan.streams.get(node_name)
        if stream is None:
            return None
        return stream[image_no]