Commit 5ec2e899 authored by Rolf H. B. van Kleef's avatar Rolf H. B. van Kleef

Initial commit

\ No newline at end of file
\ No newline at end of file
from abc import ABC, abstractmethod
class Discriminator(ABC):
def check(self, d: dict):
class KeyValueDiscriminator:
def __init__(self, key, value, has_value=True):
self.key = key
self.value = value
self.has_value = has_value
def __repr__(self):
if self.has_value:
return 'KeyValueDiscriminator(key={}, value={})'.format(self.key, self.value)
return 'KeyValueDiscriminator(key={})'.format(self.key)
def check(self, d: dict):
if self.key not in d:
return False
if self.has_value:
return d[self.key] == self.value
return True
class FunctionDiscriminator:
def __init__(self, matcher):
self.matcher = matcher
def check(self, d: dict):
return self.matcher(d)
sentinel = object()
def discriminate(key=None, value=sentinel, matcher=None):
def _inner(cls):
dc = None
if key is not None:
dc = KeyValueDiscriminator(key, value, value is not sentinel)
elif matcher is not None:
dc = FunctionDiscriminator(matcher)
if dc is None:
return cls
return cls
return _inner
def abstract(cls):
cls._abstract = True
return cls
\ No newline at end of file
from typing import Optional
from typeguard import check_type
class Rule:
def to_rule(tpe):
if isinstance(tpe, Rule):
return tpe
return Rule(tpe)
def __init__(self, type, default=None):
self.type = type
self.default = default
def __repr__(self):
return "Rule(type={}, default={})".format(self.type, self.default)
def validate(self, key, value):
check_type(key, value, self.type)
if value is None:
value = self.default
return value
class BaseMeta(type):
def __new__(mcs, name, bases, namespace):
namespace['_discriminators'] = []
namespace['_abstract'] = False
cls = type.__new__(mcs, name, bases, namespace)
for b in bases:
if b is object:
if hasattr(b, '__annotations__'):
annotations = dict(b.__annotations__)
cls.__annotations__ = annotations
return cls
def rbase(cls, ls=None):
if ls is None:
ls = []
if len(cls.__bases__) > 0:
for k in cls.__bases__:
rbase(k, ls)
return ls
def _is_valid(key: str, value):
return not key.startswith('_') and not callable(value) and not isinstance(value, classmethod) and\
not isinstance(value, staticmethod) and not isinstance(value, property)
class Deserializable(metaclass=BaseMeta):
def get_attrs(cls):
fields = {}
defaults = {}
rl = list(reversed(rbase(cls)))
for c in rl:
for k in c.__dict__:
if _is_valid(k, c.__dict__[k]):
defaults[k] = c.__dict__[k]
fields[k] = Rule(Optional[type(defaults[k])], default=defaults[k])
for k in cls.__annotations__:
if k in defaults and not _is_valid(k, defaults[k]):
rule = Rule.to_rule(cls.__annotations__[k])
if k in defaults:
rule.default = defaults[k]
fields[k] = rule
return fields
def deserialize(t, d, try_all=False):
for sc in t.__subclasses__():
if hasattr(sc, '_discriminators'):
for discriminator in sc._discriminators:
if not discriminator.check(d):
return deserialize(sc, d)
except TypeError as e:
if not try_all:
raise e
if hasattr(t, '_abstract') and t._abstract:
raise TypeError('Cannot deserialize into {}: is abstract.'.format(t.__name__))
instance = t()
for k, rule in t.get_attrs().items():
v = rule.validate(k, d[k] if k in d else None)
setattr(instance, k, v)
return instance
from time import sleep
from typing import Optional
from serializer_utils.deserializer import Deserializable
from serializer_utils.annotations import abstract, discriminate
class ReceiptLine(Deserializable):
pk: Optional[int]
name: str
t = 0
def __repr__(self):
return 'ReceiptLine(pk={}, name={}, t={})'\
.format(,, self.t)
@discriminate('type', 'transaction')
class TransactionLine(ReceiptLine):
article_id: int
amount: int
def __repr__(self):
return 'TransactionLine(super={}, article_id={}, amount={})'\
.format(super(TransactionLine, self).__repr__(), self.article_id, self.amount)
@discriminate('type', 'refund')
class RefundLine(ReceiptLine):
def __repr__(self):
return 'RefundLine(super={})'.format(super(RefundLine, self).__repr__())
if __name__ == '__main__':
import sys
import traceback
from serializer_utils.deserializer import deserialize
print(deserialize(ReceiptLine, {'type': 'transaction', 'name': 'asdf', 'article_id': 5, 'amount': 5}))
print(deserialize(ReceiptLine, {'type': 'refund', 'name': 'asdf', 'pk': 5}))
print(deserialize(ReceiptLine, {'type': 'other', 'name': 'asdf'}))
except TypeError as err:
print(traceback.format_exc(), file=sys.stderr)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment