Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/metadata/ispyalchemy/__init__.py: 69%

128 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-15 02:12 +0000

1# -*- coding: utf-8 -*- 

2import json 

3import os 

4from contextlib import contextmanager 

5from datetime import datetime, timedelta 

6from urllib.parse import quote 

7 

8from flask import g 

9 

10import sqlalchemy 

11 

12# from sqlalchemy import orm, event, func, distinct, and_ 

13from sqlalchemy import orm, event, func, and_ 

14from sqlalchemy.schema import Table, MetaData 

15 

16try: 

17 from sqlalchemy.orm import declarative_base 

18except ImportError: 

19 from sqlalchemy.ext.declarative import declarative_base 

20 

21from marshmallow import fields, Schema, EXCLUDE 

22from daiquiri.core.metadata import MetaDataHandler 

23from daiquiri.core.metadata.user import User 

24from daiquiri.core.metadata.ispyalchemy.sample import SampleHandler 

25from daiquiri.core.metadata.ispyalchemy.dc import DCHandler 

26from daiquiri.core.metadata.ispyalchemy.xrf import XRFHandler 

27from daiquiri.core.metadata.ispyalchemy.autoproc import AutoProcHandler 

28 

29 

class ISPyBConfigSchema(Schema):
    """Marshmallow schema for the ISPyB (ispyalchemy) metadata configuration.

    ``IspyalchemyMetaDataHandler.__init__`` loads the config through this
    schema with ``unknown=EXCLUDE``, so only keys declared here survive that
    load; every key the handler reads therefore has to be listed.
    """

    meta_url = fields.Str()
    meta_user = fields.Str()
    # The handler's ``setup()`` reads the password from ``meta_pass``;
    # ``meta_password`` is kept so existing configs using it still validate.
    meta_pass = fields.Str()
    meta_password = fields.Str()
    meta_beamline = fields.Str()
    meta_staff = fields.Str()
    meta_charset = fields.Str()
    # Optional keys the handler reads with ``.get()`` defaults:
    # connection pool sizing for ``sqlalchemy.create_engine``
    meta_pool = fields.Int()
    meta_overflow = fields.Int()
    # grace period in hours around session start/end (user / staff)
    meta_session_grace = fields.Float()
    meta_session_grace_staff = fields.Float()

37 

38 

@event.listens_for(Table, "column_reflect")
def column_reflect(inspector, table, column_info):
    """Tweak column metadata while any ``Table`` is being reflected.

    Registered globally on ``Table`` for the ``column_reflect`` event.
    Every column is given an attribute key equal to its lower-cased
    database name (a column ``FooBar`` is accessed as ``.foobar``), and the
    ``extraMetadata`` column of ``BLSample`` / ``BLSubSample`` is typed as
    ``sqlalchemy.JSON``.
    """
    db_name = column_info["name"]
    column_info["key"] = db_name.lower()

    # TODO: the Position table has virtual X/Y/Z columns which reflection
    # presumably does not mark as server-generated; a possible fix is setting
    # column_info["default"] = sqlalchemy.schema.FetchedValue() for them.

    json_tables = ("BLSample", "BLSubSample")
    if db_name == "extraMetadata" and table.name in json_tables:
        column_info["type"] = sqlalchemy.JSON

53 

54 

class IspyalchemyMetaDataHandler(MetaDataHandler):
    """ISPyB metadata handler built on SQLAlchemy table reflection.

    ``setup()`` connects to the ISPyB MySQL database and reflects the tables
    listed in ``_REFLECTED_TABLES``. Each reflected table is exposed in
    ``self._tables`` and as an attribute of the same name (``self.Person``,
    ``self.BLSession``, ...). It then instantiates the sample, data
    collection, XRF and autoprocessing sub-handlers and re-exposes their
    ``exported`` methods on this object.
    """

    # ISPyB tables reflected from the database in ``setup()``
    _REFLECTED_TABLES = (
        "Person",
        "Permission",
        "UserGroup",
        "UserGroup_has_Person",
        "UserGroup_has_Permission",
        "Proposal",
        "BLSession",
        "Session_has_Person",
        "DataCollectionGroup",
        "DataCollection",
        "DataCollectionFileAttachment",
        "DataCollectionComment",
        "GridInfo",
        "Protein",
        "Crystal",
        "Shipping",
        "Dewar",
        "Container",
        "ContainerHistory",
        "ContainerQueue",
        "ContainerQueueSample",
        "BLSample",
        "BLSample_has_DataCollectionPlan",
        "BLSubSample",
        "Position",
        "Positioner",
        "BLSubSample_has_Positioner",
        "DiffractionPlan",
        "BLSampleImage",
        "BLSampleImage_has_Positioner",
        "ContainerInspection",
        "XRFFluorescenceMapping",
        "XRFFluorescenceMappingROI",
        "XFEFluorescenceComposite",
        "RobotAction",
        "RobotActionPosition",
        "ImageQualityIndicators",
        "AutoProcProgram",
        "AutoProcProgramAttachment",
        "AutoProcProgramMessage",
        "ProcessingJob",
        "ProcessingJobParameter",
    )

    def __init__(self, *args, **kwargs):
        # Validate the config; keys not declared in the schema are dropped.
        # NOTE(review): the parent constructor receives the raw kwargs and
        # may replace ``self._config`` -- confirm against MetaDataHandler.
        self._config = ISPyBConfigSchema().load(
            kwargs.get("config", {}), unknown=EXCLUDE
        )
        super().__init__(*args, **kwargs)

    def _database_url(self):
        """Build the ``mysql+mysqlconnector`` URL from the handler config.

        The ``DAIQUIRI_META_URL`` environment variable, when set, replaces
        ``meta_url`` (which is then not required). The password is read from
        ``meta_pass``, falling back to ``meta_password``.

        Raises:
            KeyError: if ``meta_user``, the URL, or both password keys are
                missing.
        """
        charset = ""
        if self._config.get("meta_charset"):
            charset = f"?charset={self._config['meta_charset']}"

        # Only require meta_url from the config when the env override is absent
        url = os.environ.get("DAIQUIRI_META_URL")
        if url is None:
            url = self._config["meta_url"]

        password = self._config.get("meta_pass")
        if password is None:
            password = self._config.get("meta_password")
        if password is None:
            raise KeyError("meta_pass")

        return (
            f"mysql+mysqlconnector://{self._config['meta_user']}:"
            f"{quote(password)}@{url}{charset}"
        )

    def setup(self):
        """Create the engine, reflect the ISPyB tables and wire up sub-handlers."""
        self._engine = sqlalchemy.create_engine(
            self._database_url(),
            # Blobs get decoded as str without this resulting in TypeError: string argument without an encoding
            # https://stackoverflow.com/a/53468522
            connect_args={"use_pure": True},
            isolation_level="READ UNCOMMITTED",
            # https://docs.sqlalchemy.org/en/13/core/pooling.html#dealing-with-disconnects
            pool_pre_ping=True,
            pool_recycle=3600,
            # pooling
            # https://docs.sqlalchemy.org/en/13/errors.html#error-3o7r
            # maybe consider https://docs.sqlalchemy.org/en/13/core/pooling.html#sqlalchemy.pool.NullPool ?
            pool_size=self._config.get("meta_pool", 10),
            max_overflow=self._config.get("meta_overflow", 20),
        )

        # NOTE(review): this connection is opened but never used or closed in
        # this module, so it keeps one pooled connection checked out -- confirm
        # whether anything else relies on it.
        self._connection = self._engine.connect()
        self._metadata = MetaData()

        self._session_maker = orm.sessionmaker()
        self._session_maker.configure(bind=self._engine)

        # Reflect each table into its own declarative class named after it
        self._tables = {}
        for name in self._REFLECTED_TABLES:
            table = type(
                name,
                (declarative_base(),),
                {"__table__": Table(name, self._metadata, autoload_with=self._engine)},
            )
            self._tables[name] = table
            setattr(self, name, table)

        self._handlers = []
        for cls in [SampleHandler, DCHandler, XRFHandler, AutoProcHandler]:
            handler = cls(
                tables=self._tables,
                session_scope=self.session_scope,
                config=self._config,
            )
            self._handlers.append(handler)

            # Re-expose the sub-handler's public API on this object
            for m in handler.exported:
                setattr(self, m, getattr(handler, m))

        # Copy each class's exported attributes onto the other class so
        # SampleHandler and DCHandler instances can call each other's exported
        # methods. NOTE: this mutates the classes themselves.
        for m in SampleHandler.exported:
            setattr(DCHandler, m, getattr(SampleHandler, m))

        for m in DCHandler.exported:
            setattr(SampleHandler, m, getattr(DCHandler, m))

        super().setup()

    @contextmanager
    def session_scope(self):
        """Yield an ORM session wrapped in a transaction.

        The session is committed when the ``with`` body finishes normally,
        rolled back (and the exception re-raised) on ``Exception``, and
        always closed.
        """
        session = self._session_maker()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def _row_to_dict(self, row):
        """Return ``{column key: value}`` for every column of an ORM row."""
        return {
            column.key: getattr(row, column.key) for column in row.__table__.columns
        }

    def get_user(self, **kwargs):
        """Return the ``User`` for the current login (``g.login``), or None.

        The user's permissions are aggregated over all of their user groups;
        ``is_staff`` is True when ``meta_staff`` is one of them.
        """
        with self.session_scope() as ses:
            p = (
                ses.query(
                    self.Person.givenname,
                    self.Person.familyname,
                    self.Person.personid,
                    func.concat(
                        self.Person.givenname, " ", self.Person.familyname
                    ).label("fullname"),
                    func.group_concat(self.Permission.type).label("permissions"),
                )
                # Where no ON clause is given, SQLAlchemy infers it from the
                # reflected foreign keys
                .outerjoin(self.UserGroup_has_Person)
                .outerjoin(
                    self.UserGroup_has_Permission,
                    self.UserGroup_has_Permission.usergroupid
                    == self.UserGroup_has_Person.usergroupid,
                )
                .outerjoin(self.Permission)
                .filter(self.Person.login == g.login)
                .group_by(self.Person.personid)
                .first()
            )

            if not p:
                return None

            dct = p._asdict()
            # GROUP_CONCAT yields a comma separated string, or NULL for no rows
            dct["permissions"] = (
                dct["permissions"].split(",") if dct["permissions"] else []
            )
            # .get: a missing meta_staff means nobody is staff, not a KeyError
            dct["is_staff"] = self._config.get("meta_staff") in dct["permissions"]

            return User(**dct)

    def get_user_cache(self):
        """Return the current user's JSON-decoded ``cache`` column ({} if empty).

        Raises:
            RuntimeError: if no Person matches ``g.login``.
        """
        with self.session_scope() as ses:
            person = ses.query(self.Person).filter(self.Person.login == g.login).first()

            if not person:
                raise RuntimeError(f"Could not retrieve person: `{g.login}`")

            # Read inside the session: commit on exit expires the instance
            return json.loads(person.cache) if person.cache else {}

    def update_user_cache(self, cache):
        """Store ``cache`` as JSON in the current user's ``cache`` column.

        Returns:
            The ``cache`` argument, unchanged.

        Raises:
            RuntimeError: if no Person matches ``g.login``.
        """
        with self.session_scope() as ses:
            person = ses.query(self.Person).filter(self.Person.login == g.login).first()

            if not person:
                raise RuntimeError(f"Could not retrieve person: `{g.login}`")

            # session_scope commits when the block exits
            person.cache = json.dumps(cache)

        return cache

    def verify_session(self, session):
        """Return the session dict for visit name ``session``, or None if falsy/unknown."""
        if not session:
            return None

        return self.get_sessions(session=session)

    def _session_name_expr(self):
        """Return a fresh SQL expression ``proposalcode || proposalnumber || '-' || visit_number``."""
        return func.concat(
            self.Proposal.proposalcode,
            self.Proposal.proposalnumber,
            "-",
            self.BLSession.visit_number,
        )

    def get_sessions(self, session=None, **kwargs):
        """List sessions on ``meta_beamline`` that are current within a grace period.

        Sessions whose [startdate, enddate] window, widened by the grace
        period (hours), contains the current time are returned. Non-staff
        users only see sessions they are registered on, unless ``no_context``
        is passed.

        Args:
            session: visit name (``<code><number>-<visit>``) to look up.
            **kwargs: ``no_context`` (skip ``g.user`` checks) and
                ``sessionid`` (look up by id when ``session`` is not given).

        Returns:
            A single session dict (or None) when ``session`` or ``sessionid``
            is given, otherwise ``{"total": n, "rows": [...]}``.
        """
        no_context = kwargs.get("no_context")
        grace = self._config.get("meta_session_grace", 1)
        staff_grace = self._config.get("meta_session_grace_staff")
        # Without a request context we cannot tell user from staff, so the
        # maximum (staff) grace time is applied
        if staff_grace is not None and (no_context or g.user.staff()):
            grace = staff_grace

        # One timestamp so the "active" flag and the grace window agree
        now = datetime.now()
        window = timedelta(hours=grace)

        with self.session_scope() as ses:
            sessions = (
                ses.query(
                    self._session_name_expr().label("session"),
                    func.concat(
                        self.Proposal.proposalcode, self.Proposal.proposalnumber
                    ).label("proposal"),
                    self.Proposal.proposalid,
                    self.BLSession.sessionid,
                    self.BLSession.startdate,
                    self.BLSession.enddate,
                    self.BLSession.beamlinename,
                    func.IF(
                        and_(
                            self.BLSession.startdate <= now,
                            self.BLSession.enddate >= now,
                        ),
                        True,
                        False,
                    ).label("active"),
                )
                .join(
                    self.Proposal, self.Proposal.proposalid == self.BLSession.proposalid
                )
                .filter(self.BLSession.beamlinename == self._config["meta_beamline"])
                .filter(
                    and_(
                        self.BLSession.startdate <= now + window,
                        self.BLSession.enddate >= now - window,
                    )
                )
                .order_by(self.BLSession.startdate)
            )

            # Non-staff users only see sessions they are registered on
            if not no_context and not g.user.staff():
                sessions = sessions.join(
                    self.Session_has_Person,
                    self.Session_has_Person.sessionid == self.BLSession.sessionid,
                ).filter(self.Session_has_Person.personid == g.user["personid"])

            sessions = sessions.group_by(self.BLSession.sessionid)

            if session or kwargs.get("sessionid"):
                if session:
                    sessions = sessions.filter(self._session_name_expr() == session)
                else:
                    sessions = sessions.filter(
                        self.BLSession.sessionid == kwargs["sessionid"]
                    )

                row = sessions.first()
                return row._asdict() if row else None

            rows = [r._asdict() for r in sessions.all()]
            return {"total": len(rows), "rows": rows}