Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/schema/union.py: 64%
33 statements
« prev ^ index » next coverage.py v7.6.5, created at 2024-11-15 02:12 +0000
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3import typing as t
5import marshmallow
6from marshmallow import fields, ValidationError
8# Stolen from:
9# https://github.com/adamboche/python-marshmallow-union
class MarshmallowUnionException(Exception):
    """Common base class for the exceptions defined by the marshmallow_union code.

    Subclasses (such as the exception group raised by ``Union``) derive from
    this so callers can catch them with a single ``except`` clause.
    """
class ExceptionGroup(MarshmallowUnionException):
    """Collection of (possibly several) errors gathered from multiple attempts.

    ``Union._serialize`` raises this once every candidate field has failed.

    NOTE(review): on Python 3.11+ this module-level name shadows the built-in
    ``ExceptionGroup`` inside this module; it is kept for API compatibility.

    Attributes:
        msg: Human-readable summary message.
        errors: The collected errors, exactly as passed by the raiser.
    """

    def __init__(self, msg: str, errors):
        # Both values are also forwarded to Exception so that ``args`` keeps
        # ``(msg, errors)``.
        super().__init__(msg, errors)
        self.msg, self.errors = msg, errors
class Union(fields.Field):
    """Field that accepts any one of multiple candidate fields.

    Each candidate is tried in turn until one succeeds.

    Args:
        fields: The candidate fields to try, in order.
        reverse_serialize_candidates: Whether to try the candidates in reverse
            order when serializing.
        **kwargs: Forwarded unchanged to ``marshmallow.fields.Field``.
    """

    def __init__(
        self,
        fields: t.List[fields.Field],
        reverse_serialize_candidates: bool = False,
        **kwargs
    ):
        # Keep a private copy: mutating the caller's sequence afterwards must
        # not change this field's candidates (this also accepts tuples).
        self._candidate_fields = list(fields)
        self._reverse_serialize_candidates = reverse_serialize_candidates
        super().__init__(**kwargs)

    def _bind_to_schema(self, field_name, schema):
        """Bind this field, then bind a private copy of every candidate to it.

        Without this the candidates never learn their ``parent``/``root``, so
        candidates that depend on them (e.g. ``Nested`` reading ``context``, or
        ``DateTime`` reading schema-level format options) misbehave inside a
        ``Union``. This follows the pattern marshmallow's own container fields
        (e.g. ``fields.List``) use for their inner field: copy first, so the
        candidate objects are not shared between schema instances, then bind
        with this ``Union`` as the parent.

        NOTE(review): a candidate whose own ``_bind_to_schema`` looks up names
        on its parent (e.g. ``fields.Method``) now sees this ``Union`` as its
        parent, exactly as it would inside ``fields.List``.
        """
        import copy

        super()._bind_to_schema(field_name, schema)
        bound_candidates = []
        for candidate in self._candidate_fields:
            candidate = copy.deepcopy(candidate)
            # pylint: disable=protected-access
            candidate._bind_to_schema(field_name, self)
            bound_candidates.append(candidate)
        self._candidate_fields = bound_candidates

    def _serialize(self, value: t.Any, attr: str, obj: str, **kwargs):
        """Serialize ``value`` with the first candidate field that succeeds.

        Candidates are tried in declaration order, or in reverse order when
        ``reverse_serialize_candidates`` was set.

        Args:
            value: The value to be serialized.
            attr: The attribute or key the value was taken from.
            obj: The object the value was taken from.
            kwargs: Field-specific keyword arguments, passed to each candidate.
                An ``error_store`` entry, if given, collects the failures.

        Returns:
            Whatever the first successful candidate's ``_serialize`` returns.

        Raises:
            ExceptionGroup: If every candidate raised ``TypeError`` or
                ``ValueError`` (or there are no candidates). Its ``errors``
                attribute holds the error store's collected errors, keyed by
                ``attr``.

        NOTE(review): only ``TypeError``/``ValueError`` move on to the next
        candidate. Any other exception raised by a candidate propagates
        immediately.
        """
        error_store = kwargs.pop("error_store", None)
        if error_store is None:
            error_store = marshmallow.error_store.ErrorStore()

        candidates = self._candidate_fields
        if self._reverse_serialize_candidates:
            candidates = list(reversed(candidates))

        for candidate_field in candidates:
            try:
                # pylint: disable=protected-access
                return candidate_field._serialize(
                    value, attr, obj, error_store=error_store, **kwargs
                )
            except (TypeError, ValueError) as e:
                # Record this candidate's failure and try the next one.
                error_store.store_error({attr: e})

        raise ExceptionGroup("All serializers raised exceptions.\n", error_store.errors)

    def _deserialize(self, value, attr=None, data=None, **kwargs):
        """Deserialize ``value`` with the first candidate field that accepts it.

        Each candidate's public ``deserialize`` is called, so its own
        validators and ``allow_none`` handling apply.

        Args:
            value: The value to be deserialized.
            attr: The attribute/key in ``data`` to be deserialized.
            data: The raw input data passed to ``Schema.load``.
            kwargs: Field-specific keyword arguments, passed to each candidate.

        Returns:
            The first successful candidate's deserialized result.

        Raises:
            ValidationError: If every candidate raised ``ValidationError``.
                Its messages are a list of each candidate's messages, in the
                order the candidates were tried.
        """
        errors = []
        for candidate_field in self._candidate_fields:
            try:
                return candidate_field.deserialize(value, attr, data, **kwargs)
            except ValidationError as exc:
                errors.append(exc.messages)
        raise ValidationError(message=errors, field_name=attr)