Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/utils/__init__.py: 65%

209 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-15 02:12 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3import importlib 

4import threading 

5import time 

6import math 

7import logging 

8import inspect 

9import gevent 

10import cProfile 

11import pstats 

12from io import StringIO 

13from contextlib import contextmanager 

14from functools import wraps 

15import numpy as np 

16from pint import UnitRegistry 

17from copy import deepcopy 

18 

19logger = logging.getLogger(__name__) 

20ureg = UnitRegistry() 

21 

22 

@contextmanager
def profiler(sortby="cumtime", filename=None):
    """Profile the enclosed ``with`` body using a cProfile instance

    Profiling is switched on before the body runs and switched off again
    when it finishes, whether or not the body raised.

    Kwargs:
        sortby (str): sort key for the time profile (accepted but not
            referenced anywhere in this function)
        filename (str): if given, the collected stats are dumped to this file
    """
    profile = cProfile.Profile()
    profile.enable()
    try:
        yield
    finally:
        profile.disable()

        # The text stream is handed to pstats but nothing is printed to it here
        report_stream = StringIO()
        stats = pstats.Stats(profile, stream=report_stream)

        if filename:
            stats.dump_stats(filename)

43 

44 

def timed(fn):
    """Decorator that logs how long each call of ``fn`` took

    The duration is rounded to whole milliseconds and logged at debug level
    as ``<Class>::<function> took <n> ms``. When ``fn`` declares a ``self``
    parameter the class name of the first positional argument is used,
    otherwise the label is ``Static``.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        has_self = inspect.signature(fn).parameters.get("self")
        owner = args[0].__class__.__name__ if has_self else "Static"

        began = time.time()
        result = fn(*args, **kwargs)
        elapsed_ms = round((time.time() - began) * 1000)

        logger.debug(f"{owner}::{fn.__name__} took {elapsed_ms} ms")

        return result

    return wrapper

68 

69 

def dict_nd_to_list(dct):
    """Return a copy of ``dct`` in which every ndarray value is a plain list

    Values that are exactly of type ``dict`` are processed recursively;
    anything else that is not an ndarray is passed through untouched.
    """

    def _convert(value):
        if type(value) is dict:
            return dict_nd_to_list(value)
        if isinstance(value, np.ndarray):
            return value.tolist()
        return value

    return {key: _convert(value) for key, value in dct.items()}

84 

85 

def debounce(wait):
    """Decorator that rate-limits a function to one execution per ``wait`` seconds

    Adapted from https://gist.github.com/walkermatt/2871026

    Behaviour of the decorated function:

    * If at least ``wait`` seconds have passed since the last execution, the
      function runs immediately and its return value is returned.
    * Otherwise, if no timer is pending, a ``threading.Timer`` is started that
      runs the function (with the arguments of *this* call) once the remainder
      of the ``wait`` interval has elapsed. ``None`` is returned.
    * Otherwise a timer is already pending and the call is dropped.

    State is kept on the wrapper itself as ``_timer`` and ``_last_call``.

    Args:
        wait (float): minimum number of seconds between two executions
    """

    def decorator(fn):
        # Preserve the name / docstring of the wrapped function
        @wraps(fn)
        def debounced(*args, **kwargs):
            def call_it():
                debounced._timer = None
                debounced._last_call = time.time()
                return fn(*args, **kwargs)

            time_since_last_call = time.time() - debounced._last_call
            if time_since_last_call >= wait:
                return call_it()

            if debounced._timer is None:
                # Run once the rest of the wait interval has passed
                debounced._timer = threading.Timer(wait - time_since_last_call, call_it)
                debounced._timer.start()

        debounced._timer = None
        # 0 guarantees that the very first call executes immediately
        debounced._last_call = 0

        return debounced

    return decorator

114 

115 

def debounce_with_table(wait):
    """Decorator factory that debounces ``fn(self, obj, prop, value)`` per object/property

    Pending calls are tracked in a table keyed by ``obj.id() + prop``. The first
    call for a key starts a timer of ``wait`` seconds; calls arriving while the
    entry exists only replace the stored value, so the wrapped function is
    eventually invoked with the most recent value.

    Args:
        wait (float): delay in seconds before the wrapped function is invoked
    """
    # One table per `debounce_with_table(wait)` call; every function decorated
    # with the same returned decorator shares it
    last_call = {}

    def decorator(fn):
        def debounced(self, obj, prop, value):
            # Key identifying this object / property combination
            pid = obj.id() + prop

            def call_it(_pid):
                # Use the latest stored value, then drop the pending entry
                val = last_call[_pid]["value"]
                del last_call[_pid]
                return fn(self, obj, prop, val)

            if not (pid in last_call):
                # Nothing pending for this key: schedule a delayed call
                timer = threading.Timer(wait, call_it, [pid])
                timer.start()
                # NOTE(review): the entry is stored after the timer is started;
                # with a very small `wait` the timer could presumably fire
                # before the entry exists - confirm
                last_call[pid] = {"timer": timer, "last_call": time.time()}
                last_call[pid]["value"] = value

            else:
                data = last_call[pid]
                # "last_call" is only set when the entry is created, so this is
                # the time elapsed since the first pending call
                time_since_last_call = time.time() - data["last_call"]
                last_call[pid]["value"] = value

                if time_since_last_call >= wait:
                    # NOTE(review): the stored timer is not cancelled here, so
                    # it may still fire after the entry was deleted - confirm
                    return call_it(pid)

        return debounced

    return decorator

145 

146 

147# Thieved from http://code.activestate.com/recipes/577346-getattr-with-arbitrary-depth/ 

def get_nested_attr(obj, attr, **kw):
    """Resolve a dotted attribute path such as ``"a.b.c"`` on ``obj``

    Each path element is looked up with ``getattr``; when the result is
    callable it is called without arguments and the return value is used for
    the next step.

    Kwargs:
        default: returned if an ``AttributeError`` occurs along the path;
            without it the ``AttributeError`` is re-raised
    """
    current = obj
    for name in attr.split("."):
        try:
            current = getattr(current, name)
            if callable(current):
                current = current()
        except AttributeError:
            if "default" not in kw:
                raise
            return kw["default"]
    return current

161 

162 

def loader(base, postfix, module, *args, **kwargs):
    """Import "{base}.{module}", pick a class from it and instantiate it.

    The class is "Default" when the module defines that name, otherwise
    "{module.title()}{postfix}". For example "ExampleActor" from
    "daiquiri.implementors.examplecomponent".

    The module is reloaded after import, so every call re-executes it.

    :param str base: base module
    :param str postfix: add to module name to get class name
    :param str module: submodule, may itself be a dotted path
    :param args: for class instantiation
    :param kwargs: for class instantiation
    :returns Any: the new instance
    :raises ModuleNotFoundError: if the module cannot be found
    :raises AttributeError: if the module does not provide the class
    """
    # Everything before the last dot of `module` belongs to the base package
    parent, dot, leaf = module.rpartition(".")
    if dot:
        base = base + "." + parent
        module = leaf

    mod_file = base + "." + module
    try:
        mod = importlib.reload(importlib.import_module(mod_file))
    except ModuleNotFoundError:
        logger.error("Couldn't find module {}".format(mod_file))
        raise

    class_name = "Default" if hasattr(mod, "Default") else module.title() + postfix
    try:
        cls = getattr(mod, class_name)
    except AttributeError:
        logger.error("Couldn't import '{}' from {}".format(class_name, mod_file))
        raise

    instance = cls(*args, **kwargs)
    logger.debug(f"Instantiated '{class_name}' from {mod_file}")
    return instance

206 

207 

def format_eng(scalar):
    """Scale a number to engineering notation with an SI prefix

    The exponent is the largest multiple of 3 not exceeding the magnitude of
    ``scalar``, limited to the range of known prefixes (y ... Y). Zero is
    returned unscaled and negative values are scaled by their magnitude.

    Args:
        scalar (float): the value to format

    Returns:
        dict: ``scalar`` (the scaled value), ``prefix`` (SI prefix, empty for
        10^0) and ``multiplier`` (factor that was applied to the input)
    """
    power_prefix = {
        24: "Y",
        21: "Z",
        18: "E",
        15: "P",
        12: "T",
        9: "G",
        6: "M",
        3: "k",
        -3: "m",
        -6: "\u00B5",
        -9: "n",
        -12: "p",
        -15: "f",
        -18: "a",
        -21: "z",
        -24: "y",
    }

    magnitude = abs(scalar)
    if magnitude == 0:
        # log(0) is undefined; zero needs no prefix
        pow_10 = 0
    else:
        q = math.log(magnitude) / math.log(1e3)
        # ceil(q - 1) is floor(q) for non-integer q, exact powers keep q
        subtrhnd = not q.is_integer()
        pow_10 = 3 * math.ceil(q - subtrhnd)
        # Stay within the table of known prefixes
        pow_10 = max(min(power_prefix), min(max(power_prefix), pow_10))

    prefix = "" if pow_10 == 0 else power_prefix[pow_10]
    multiplier = math.pow(10, -pow_10)

    return {
        "scalar": scalar * multiplier,
        "prefix": prefix,
        "multiplier": multiplier,
    }

239 

240 

def get_start_end(kwargs, points=None, last=False, default_per_page=25):
    """Work out the slice limits for a paginated request

    Args:
        kwargs (dict): request arguments, ``per_page`` and ``page`` are read
        points (int): total number of items, used for the page count
        last (bool): when no page is requested, select the last page
            instead of the first
        default_per_page (int): page size if ``per_page`` is not in kwargs

    Returns:
        dict: ``st`` / ``en`` (start and end index), ``page``, ``pages``
        (``None`` if ``points`` is not given) and ``per_page``
    """
    per_page = kwargs.get("per_page", default_per_page)
    page = kwargs.get("page", None)
    pages = math.ceil(points / per_page) if points else None

    if page is not None:
        # Explicit page requested (1-based)
        st = per_page * (page - 1)
        en = per_page + st
    elif last:
        st = max(0, points - per_page)
        en = points
        page = pages
    else:
        st, en, page = 0, per_page, 1

    return {"st": st, "en": en, "page": page, "pages": pages, "per_page": per_page}

265 

266 

def to_wavelength(energy):
    """Convert an energy given in eV to the wavelength in angstrom (h * c / E)"""
    hc = ureg.planck_constant * ureg.c
    energy_joule = (energy * ureg.eV).to(ureg.J)
    wavelength = (hc / energy_joule).to(ureg.angstrom)
    return wavelength.magnitude

273 

274 

def to_energy(wavelength):
    """Convert a wavelength given in angstrom to the energy in eV (h * c / lambda)"""
    hc = ureg.planck_constant * ureg.c
    wavelength_angstrom = wavelength * ureg.angstrom
    energy = (hc / wavelength_angstrom).to(ureg.eV)
    return energy.magnitude

281 

282 

def make_json_safe(obj):
    """Make sure a dict is json encodable

    Converts np types to int, float, bool, str and list, recursing into
    nested dicts, lists and tuples. A new dict is returned, the input is
    not modified.

    Args:
        obj(dict): The dict to make safe

    Returns:
        dict: copy of ``obj`` with numpy values replaced by builtin types
    """
    return {k: _make_json_safe(v) for k, v in obj.items()}


def _make_json_safe(v):
    """Convert a single value (recursively) for ``make_json_safe``

    Values of any type not handled below are returned unchanged.
    """
    if isinstance(v, dict):
        return make_json_safe(v)

    if isinstance(v, list):
        return [_make_json_safe(i) for i in v]

    # Tuples are json encodable, but only if their items are
    if isinstance(v, tuple):
        return tuple(_make_json_safe(i) for i in v)

    if isinstance(v, np.integer):
        return int(v)

    if isinstance(v, np.floating):
        return float(v)

    if isinstance(v, np.ndarray):
        return v.tolist()

    if isinstance(v, np.dtype):
        return str(v)

    if isinstance(v, np.bool_):
        return bool(v)

    return v

321 

322 

def deep_diff(x, y, exclude_keys=()):
    """Take the deep diff of JSON-like dictionaries

    https://stackoverflow.com/a/62124626
    No warranties when keys, or values are None

    Args:
        x: the original value (dict, list or scalar)
        y: the new value
        exclude_keys: keys to ignore; applied to the keys of ``x`` / ``y``
            themselves only, it is not passed on to nested values

    Returns:
        ``None`` if nothing differs. For values of different type, or values
        that are neither list nor dict: ``y``. For dicts: a dict with an
        ``(old, None)`` / ``(None, new)`` tuple for keys present on one side
        only and the nested diff for keys present on both sides. For lists:
        a list with one entry per index, ``None`` where the items are equal.
    """
    if x == y:
        return None

    if type(x) is not type(y) or type(x) not in [list, dict]:
        return y  # (x, y)

    if isinstance(x, dict):
        d = {}
        # Keys on one side only
        for k in x.keys() ^ y.keys():
            if k in exclude_keys:
                continue
            if k in x:
                d[k] = (deepcopy(x[k]), None)
            else:
                d[k] = (None, deepcopy(y[k]))

        # Keys on both sides
        for k in x.keys() & y.keys():
            if k in exclude_keys:
                continue

            next_d = deep_diff(x[k], y[k])
            if next_d is None:
                continue

            d[k] = next_d

        return d if len(d) else None

    # must be list:
    d = [None] * max(len(x), len(y))
    flipped = False
    if len(x) > len(y):
        # Iterate over the shorter list, remember that the sides are swapped
        flipped = True
        x, y = y, x

    for i, x_val in enumerate(x):
        d[i] = deep_diff(y[i], x_val) if flipped else deep_diff(x_val, y[i])

    # Items that only exist in the longer list
    for i in range(len(x), len(y)):
        d[i] = (y[i], None) if flipped else (None, y[i])

    return None if all(item is None for item in d) else d

369 

370 

def worker(fn):
    """Run a function in a real thread

    The function is submitted to the thread pool of the current gevent hub
    and its result is waited for and returned.
    """
    hub = gevent.get_hub()
    pending = hub.threadpool.spawn(fn)
    return pending.get()

375 

376 

def to_unit(value: float, input_unit: str, output_unit: str) -> float:
    """Convert ``value`` from ``input_unit`` to ``output_unit``

    Both units are strings parsed by the module's pint unit registry; the
    bare magnitude of the converted quantity is returned.
    """
    quantity = value * ureg(input_unit)
    converted = quantity.to(ureg(output_unit))
    return converted.magnitude