Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/utils/__init__.py: 65%
209 statements
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-15 02:12 +0000
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-15 02:12 +0000
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import importlib
4import threading
5import time
6import math
7import logging
8import inspect
9import gevent
10import cProfile
11import pstats
12from io import StringIO
13from contextlib import contextmanager
14from functools import wraps
15import numpy as np
16from pint import UnitRegistry
17from copy import deepcopy
# Module-wide logger, and one pint UnitRegistry instance that the
# unit-conversion helpers in this module (to_wavelength, to_energy, to_unit)
# all share.
logger = logging.getLogger(name=__name__)
ureg = UnitRegistry()
@contextmanager
def profiler(sortby="cumtime", filename=None):
    """Wrap the context in a ``cProfile`` instance.

    Profiling is stopped when the context exits, even if it raised.

    If ``filename`` is given, the raw profile data is written there with
    ``pstats.Stats.dump_stats``; it can be reloaded later with
    ``pstats.Stats(filename)``. Otherwise the statistics are sorted by
    ``sortby``, rendered as text and emitted on this module's logger at
    DEBUG level.

    Kwargs:
        sortby (str): ``pstats`` sort key for the text report
            (e.g. ``"cumtime"``, ``"tottime"``)
        filename (str): file to save the raw profile data to
    """
    pr = cProfile.Profile()
    pr.enable()
    try:
        yield
    finally:
        pr.disable()

        if filename:
            pstats.Stats(pr).dump_stats(filename)
        else:
            # Render the sorted report into a buffer so it can be logged
            # instead of being written to stdout.
            report = StringIO()
            pstats.Stats(pr, stream=report).sort_stats(sortby).print_stats()
            logger.debug("Profile (sorted by %s):\n%s", sortby, report.getvalue())
def timed(fn):
    """Decorator that logs how long each call to *fn* takes.

    The duration is logged at DEBUG level as ``"<Owner>::<name> took <ms> ms"``.
    ``<Owner>`` is worked out as follows:

    * *fn* declares a ``self`` parameter: the class name of the first
      positional argument.
    * *fn* declares a ``cls`` parameter and the first positional argument is
      a class: that class's name.
    * Otherwise: ``"Static"``.

    Should work with instance, class and static methods.
    """
    # Inspecting the signature is comparatively expensive and never changes,
    # so do it once at decoration time rather than on every call.
    params = inspect.signature(fn).parameters
    has_self = "self" in params
    has_cls = "cls" in params

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if has_self:
            cln = args[0].__class__.__name__
        elif has_cls and args and isinstance(args[0], type):
            cln = args[0].__name__
        else:
            cln = "Static"

        # perf_counter is monotonic and high resolution, unlike time.time
        start = time.perf_counter()
        ret = fn(*args, **kwargs)
        took = round((time.perf_counter() - start) * 1000)

        logger.debug("%s::%s took %s ms", cln, fn.__name__, took)

        return ret

    return wrapper
def dict_nd_to_list(dct):
    """Return a copy of *dct* in which every ``np.ndarray`` value is a list.

    Values whose type is exactly ``dict`` are processed recursively. Values
    of any other type, including dict subclasses, are copied over unchanged.
    """

    def _convert(value):
        if type(value) is dict:
            return dict_nd_to_list(value)
        if isinstance(value, np.ndarray):
            return value.tolist()
        return value

    return {key: _convert(value) for key, value in dct.items()}
def debounce(wait):
    """Decorator that limits a function to one execution per *wait* seconds.

    If at least ``wait`` seconds have passed since the last execution, the
    call runs immediately and its return value is returned.

    Otherwise the call is deferred and the caller gets ``None``. A
    ``threading.Timer`` runs the function once the ``wait`` window has
    elapsed, using the arguments of the most recent deferred call. At most
    one timer is pending at a time.

    Adapted from https://gist.github.com/walkermatt/2871026

    Args:
        wait (float): minimum interval in seconds between executions
    """

    def decorator(fn):
        def run(call_args, call_kwargs):
            debounced._timer = None
            debounced._last_call = time.time()
            return fn(*call_args, **call_kwargs)

        def run_pending():
            # Read the arguments when the timer fires rather than when it
            # was created, so calls made later in the window are not lost.
            pending_args, pending_kwargs = debounced._pending
            return run(pending_args, pending_kwargs)

        @wraps(fn)
        def debounced(*args, **kwargs):
            time_since_last_call = time.time() - debounced._last_call
            if time_since_last_call >= wait:
                return run(args, kwargs)

            debounced._pending = (args, kwargs)
            if debounced._timer is None:
                debounced._timer = threading.Timer(
                    wait - time_since_last_call, run_pending
                )
                debounced._timer.start()

        debounced._timer = None
        debounced._last_call = 0
        debounced._pending = ((), {})

        return debounced

    return decorator
def debounce_with_table(wait):
    """Debounce a ``(self, obj, prop, value)`` method per object/property.

    Calls are keyed by ``obj.id() + prop``.

    * The first call for a key starts a ``threading.Timer`` of ``wait``
      seconds.
    * Later calls for the same key, made before the timer fires, only
      replace the buffered value.
    * When the timer fires, the wrapped function is called once with the
      latest value.
    * If a call finds the buffered entry older than ``wait`` (the timer is
      overdue), the timer is cancelled and the function is called directly.
      In that case its return value is returned; otherwise ``None`` is
      returned.

    Args:
        wait (float): delay in seconds before the buffered value is applied
    """
    last_call = {}

    def decorator(fn):
        @wraps(fn)
        def debounced(self, obj, prop, value):
            pid = obj.id() + prop

            def call_it(_pid):
                # pop() rather than del: the timer and the overdue path below
                # may both try to flush the same key.
                data = last_call.pop(_pid, None)
                if data is None:
                    return None
                return fn(self, obj, prop, data["value"])

            data = last_call.get(pid)
            if data is None:
                timer = threading.Timer(wait, call_it, [pid])
                # Register the entry before starting the timer, so a very
                # short wait cannot fire before the entry exists.
                last_call[pid] = {
                    "timer": timer,
                    "last_call": time.time(),
                    "value": value,
                }
                timer.start()
                return None

            data["value"] = value
            if time.time() - data["last_call"] >= wait:
                # The timer is overdue: flush now and stop it firing again.
                data["timer"].cancel()
                return call_it(pid)
            return None

        return debounced

    return decorator
# Adapted from http://code.activestate.com/recipes/577346-getattr-with-arbitrary-depth/
def get_nested_attr(obj, attr, **kw):
    """Follow a dotted attribute path (e.g. ``"a.b.c"``) starting at *obj*.

    Each attribute found along the path that is callable is called with no
    arguments, and its result is used for the next step.

    Args:
        obj: object to start from
        attr (str): dot-separated attribute names
        default (keyword, optional): returned instead of raising when an
            ``AttributeError`` occurs while resolving or calling a step

    Raises:
        AttributeError: if a step fails and no ``default`` was given
    """
    current = obj
    for name in attr.split("."):
        try:
            current = getattr(current, name)
            if callable(current):
                current = current()
        except AttributeError:
            if "default" not in kw:
                raise
            return kw["default"]
    return current
def loader(base, postfix, module, *args, **kwargs):
    """Import ``{base}.{module}`` (always reloading it) and instantiate a class.

    The class is ``Default`` if the module defines one. Otherwise it is
    ``module.title() + postfix``. For example,
    ``loader("daiquiri.implementors.examplecomponent", "Actor", "example")``
    instantiates ``ExampleActor`` from
    ``daiquiri.implementors.examplecomponent.example``.

    Any dotted prefix of ``module`` is appended to ``base``, so only the last
    component is used to build the class name.

    :param str base: base module
    :param str postfix: add to module name to get class name
    :param str module: submodule
    :param args: for class instantiation
    :param kwargs: for class instantiation
    :returns Any: the new instance
    :raises ModuleNotFoundError: if the module cannot be found (logged first)
    :raises AttributeError: if the class is missing from the module (logged first)
    """
    if "." in module:
        # Everything before the final dot is part of the base package
        package, _, module = module.rpartition(".")
        base = ".".join((base, package))

    mod_file = ".".join((base, module))
    try:
        mod = importlib.reload(importlib.import_module(mod_file))
    except ModuleNotFoundError:
        logger.error(f"Couldn't find module {mod_file}")
        raise

    class_name = "Default" if hasattr(mod, "Default") else module.title() + postfix
    try:
        cls = getattr(mod, class_name)
    except AttributeError:
        logger.error(f"Couldn't import '{class_name}' from {mod_file}")
        raise

    instance = cls(*args, **kwargs)
    logger.debug(f"Instantiated '{class_name}' from {mod_file}")
    return instance
def format_eng(scalar):
    """Express *scalar* in engineering notation with an SI prefix.

    Chooses a power of ten ``pow_10`` (a multiple of 3) such that the
    rescaled value ``scalar * 10**-pow_10`` has magnitude in ``[1, 1000)``.
    ``pow_10`` is clamped to the available prefixes (``y`` = 1e-24 to
    ``Y`` = 1e24). Negative values are scaled by their magnitude. Zero and
    non-finite values are returned unscaled.

    Args:
        scalar (float): value to format

    Returns:
        dict: with keys
            * ``scalar``: the rescaled value
            * ``prefix``: the SI prefix, ``""`` when no scaling is applied
            * ``multiplier``: the factor ``10**-pow_10`` that was applied
    """
    power_prefix = {
        24: "Y",
        21: "Z",
        18: "E",
        15: "P",
        12: "T",
        9: "G",
        6: "M",
        3: "k",
        -3: "m",
        -6: "\u00B5",
        -9: "n",
        -12: "p",
        -15: "f",
        -18: "a",
        -21: "z",
        -24: "y",
    }
    max_pow, min_pow = 24, -24

    magnitude = abs(scalar)
    if magnitude == 0 or not math.isfinite(magnitude):
        # The logarithm is undefined here; leave the value unscaled
        pow_10 = 0
    else:
        pow_10 = 3 * math.floor(math.log10(magnitude) / 3)
        pow_10 = max(min_pow, min(max_pow, pow_10))
        # log10 can be off by an ulp for exact powers of 1000 (e.g. 1e-6),
        # which would select the neighbouring prefix. Correct using the
        # actual rescaled value.
        rescaled = magnitude * math.pow(10, -pow_10)
        if rescaled >= 1000 and pow_10 < max_pow:
            pow_10 += 3
        elif rescaled < 1 and pow_10 > min_pow:
            pow_10 -= 3

    multiplier = math.pow(10, -pow_10)
    return {
        "scalar": scalar * multiplier,
        "prefix": "" if pow_10 == 0 else power_prefix[pow_10],
        "multiplier": multiplier,
    }
def get_start_end(kwargs, points=None, last=False, default_per_page=25):
    """Compute slice bounds for a paginated listing.

    Args:
        kwargs (dict): may contain ``per_page`` and ``page`` (1-based)
        points: total number of items, if known
        last (bool): when no ``page`` is requested, select the final page
            instead of the first (requires ``points``)
        default_per_page (int): page size used when ``per_page`` is absent

    Returns:
        dict: ``st``/``en`` slice bounds, ``page``, ``pages`` (``None`` when
        ``points`` is falsy) and ``per_page``
    """
    per_page = kwargs.get("per_page", default_per_page)
    page = kwargs.get("page", None)
    pages = math.ceil(points / per_page) if points else None

    if page is not None:
        st = per_page * (page - 1)
        en = st + per_page
    elif last:
        st, en, page = max(0, points - per_page), points, pages
    else:
        st, en, page = 0, per_page, 1

    return {"st": st, "en": en, "page": page, "pages": pages, "per_page": per_page}
def to_wavelength(energy):
    """Convert a photon energy to a wavelength via ``lambda = h * c / E``.

    Args:
        energy (float): photon energy in eV

    Returns:
        float: wavelength in angstrom
    """
    hc = ureg.planck_constant * ureg.c
    energy_joules = (energy * ureg.eV).to(ureg.J)
    wavelength = (hc / energy_joules).to(ureg.angstrom)
    return wavelength.magnitude
def to_energy(wavelength):
    """Convert a wavelength to a photon energy via ``E = h * c / lambda``.

    Args:
        wavelength (float): wavelength in angstrom

    Returns:
        float: photon energy in eV
    """
    hc = ureg.planck_constant * ureg.c
    wavelength_quantity = wavelength * ureg.angstrom
    energy = (hc / wavelength_quantity).to(ureg.eV)
    return energy.magnitude
def make_json_safe(obj):
    """Make sure a dict is json encodable

    Returns a new dict in which numpy values are converted to native Python
    types:

    * integer and floating scalars become ``int``/``float``;
    * ``np.bool_`` becomes ``bool``;
    * arrays become lists (``ndarray.tolist``);
    * dtypes become their string name.

    Nested dicts, lists and tuples are processed recursively. Dict keys are
    converted too, so numpy scalar keys such as ``np.int64`` become native.

    Args:
        obj(dict): The dict to make safe

    Returns:
        dict: the converted copy
    """
    return {_make_json_safe(k): _make_json_safe(v) for k, v in obj.items()}


def _make_json_safe(v):
    """Convert a single value (recursively) to a JSON-encodable equivalent.

    Values of unhandled types are returned unchanged.
    """
    if isinstance(v, dict):
        return make_json_safe(v)

    elif isinstance(v, list):
        return [_make_json_safe(i) for i in v]

    elif isinstance(v, tuple):
        # json encodes tuples as arrays, so their items need converting too.
        # Note: tuple subclasses (e.g. namedtuples) come back as plain tuples.
        return tuple(_make_json_safe(i) for i in v)

    elif isinstance(v, np.integer):
        return int(v)

    elif isinstance(v, np.floating):
        return float(v)

    elif isinstance(v, np.ndarray):
        return v.tolist()

    elif isinstance(v, np.dtype):
        return str(v)

    elif isinstance(v, np.bool_):
        return bool(v)

    return v
def deep_diff(x, y, exclude_keys=None):
    """Take the deep diff of JSON-like dictionaries

    Returns ``None`` when ``x == y``. Otherwise the result depends on the
    types involved:

    * Different types, or types other than ``list``/``dict``: returns ``y``.
    * Dicts: a dict of the differing keys only. A key present on one side
      maps to ``(x_value, None)`` or ``(None, y_value)`` (deep copies). A
      key present on both sides maps to its recursive diff.
    * Lists: one entry per index. Indices present in both lists hold the
      recursive diff. The extra tail of the longer list holds
      ``(x_value, None)`` or ``(None, y_value)`` (deep copies).

    Based on https://stackoverflow.com/a/62124626
    No warranties when keys, or values are None

    Args:
        x: original value
        y: new value
        exclude_keys: keys ignored in the top-level dict (default: none)

    Returns:
        the diff as described above, or ``None`` if nothing differs
    """
    if exclude_keys is None:
        exclude_keys = ()

    if x == y:
        return None

    if type(x) is not type(y) or type(x) not in (list, dict):
        return y  # (x, y)

    if isinstance(x, dict):
        d = {}
        for k in x.keys() ^ y.keys():
            if k in exclude_keys:
                continue
            d[k] = (deepcopy(x[k]), None) if k in x else (None, deepcopy(y[k]))

        for k in x.keys() & y.keys():
            if k in exclude_keys:
                continue
            # exclude_keys only applies at this level and is not forwarded.
            # NOTE(review): confirm whether nested exclusion is wanted.
            next_d = deep_diff(x[k], y[k])
            if next_d is not None:
                d[k] = next_d

        return d or None

    # Must be a list: diff the overlapping indices, then report the tail of
    # whichever list is longer.
    d = [deep_diff(x_val, y_val) for x_val, y_val in zip(x, y)]
    d.extend((deepcopy(x_val), None) for x_val in x[len(y):])
    d.extend((None, deepcopy(y_val)) for y_val in y[len(x):])

    # Mirror the dict branch: report None rather than a list of Nones
    return None if all(v is None for v in d) else d
def worker(fn):
    """Run *fn* in a real thread from the gevent hub's thread pool.

    Waits on the spawned task via ``.get()`` and returns its result.
    """
    hub = gevent.get_hub()
    task = hub.threadpool.spawn(fn)
    return task.get()
def to_unit(value: float, input_unit: str, output_unit: str) -> float:
    """Convert *value* from *input_unit* to *output_unit*.

    Both unit strings are passed to the module's pint ``UnitRegistry``.
    Returns the magnitude of the converted quantity.
    """
    quantity = value * ureg(input_unit)
    target = ureg(output_unit)
    return quantity.to(target).magnitude