Coverage for /opt/conda/envs/apienv/lib/python3.10/site-packages/daiquiri/core/schema/union.py: 64%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.5, created at 2024-11-15 02:12 +0000

1#!/usr/bin/env python 

2# -*- coding: utf-8 -*- 

3import typing as t 

4 

5import marshmallow 

6from marshmallow import fields, ValidationError 

7 

8# Stolen from: 

9# https://github.com/adamboche/python-marshmallow-union 

10 

11 

class MarshmallowUnionException(Exception):
    """Common ancestor of every error type defined for the ``Union`` field.

    Catch this to handle any union-specific failure without also catching
    unrelated exceptions.
    """

14 

15 

class ExceptionGroup(MarshmallowUnionException):
    """Carrier for one or more errors gathered while trying candidates.

    NOTE: inside this module the name shadows the ``ExceptionGroup`` builtin
    introduced in Python 3.11; it is kept for backward compatibility.

    Attributes:
        msg: Human-readable summary message.
        errors: The collected errors, stored exactly as given.
    """

    def __init__(self, msg: str, errors):
        # ``Exception.__init__`` only records ``args``; the explicit
        # attributes below give callers named access to the same values.
        super().__init__(msg, errors)
        self.msg, self.errors = msg, errors

23 

24 

class Union(fields.Field):
    """Field that accepts any one of multiple candidate fields.

    Each candidate is tried in order until one succeeds.

    Args:
        fields: The candidate fields to try. Any iterable is accepted; it is
            copied into a private list.
        reverse_serialize_candidates: Whether to try the candidates in reverse
            order when serializing.
        **kwargs: Forwarded unchanged to ``marshmallow.fields.Field``.
    """

    def __init__(
        self,
        fields: t.List[fields.Field],
        reverse_serialize_candidates: bool = False,
        **kwargs
    ):
        # Copy so that later mutation of the caller's list (or exhaustion of
        # a one-shot iterator) cannot change the set of candidates.
        self._candidate_fields = list(fields)
        self._reverse_serialize_candidates = reverse_serialize_candidates
        super().__init__(**kwargs)

    def _bind_to_schema(self, field_name, schema):
        """Bind this field, then bind a private copy of each candidate to it.

        Without this the candidates are never bound, so their ``parent`` /
        ``root`` stay unset and they cannot see the owning schema's options
        or context. Marshmallow's own container fields (e.g. ``List``) bind
        their inner fields the same way.
        """
        import copy  # local import: only needed when a schema is built

        super()._bind_to_schema(field_name, schema)
        bound_candidates = []
        for candidate in self._candidate_fields:
            # Copy before binding so we never mutate candidate instances that
            # may be shared (e.g. between schema instances).
            bound = copy.deepcopy(candidate)
            # pylint: disable=protected-access
            bound._bind_to_schema(field_name, self)
            bound_candidates.append(bound)
        self._candidate_fields = bound_candidates

    def _serialize(self, value: t.Any, attr: str, obj: t.Any, **kwargs):
        """Serialize ``value`` with the first candidate that accepts it.

        Args:
            value: The value to be serialized.
            attr: The attribute or key the value was read from.
            obj: The object the value was pulled from.
            kwargs: Field-specific keyword arguments. An optional
                ``error_store`` is consumed here and shared with candidates.

        Raises:
            ExceptionGroup: If every candidate raises ``TypeError`` or
                ``ValueError``. Its ``errors`` holds the collected exceptions
                keyed by ``attr``. Any other exception raised by a candidate
                propagates unchanged.
        """
        # Only build a fresh store when none (or ``None``) was supplied.
        error_store = kwargs.pop("error_store", None)
        if error_store is None:
            error_store = marshmallow.error_store.ErrorStore()

        candidates = self._candidate_fields
        if self._reverse_serialize_candidates:
            candidates = list(reversed(candidates))

        for candidate_field in candidates:
            try:
                # pylint: disable=protected-access
                return candidate_field._serialize(
                    value, attr, obj, error_store=error_store, **kwargs
                )
            except (TypeError, ValueError) as e:
                error_store.store_error({attr: e})

        raise ExceptionGroup("All serializers raised exceptions.\n", error_store.errors)

    def _deserialize(self, value, attr=None, data=None, **kwargs):
        """Deserialize ``value`` with the first candidate that accepts it.

        Candidates are always tried in declaration order.

        Args:
            value: The value to be deserialized.
            attr: The attribute/key in ``data`` to be deserialized.
            data: The raw input data passed to ``Schema.load``.
            kwargs: Field-specific keyword arguments.

        Raises:
            ValidationError: If every candidate rejects the value. The
                messages are a list with one entry per candidate, in order.
        """
        errors = []
        for candidate_field in self._candidate_fields:
            try:
                return candidate_field.deserialize(value, attr, data, **kwargs)
            except ValidationError as exc:
                errors.append(exc.messages)
        raise ValidationError(message=errors, field_name=attr)