|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | + |
| 4 | +__all__ = ("AttrCondition", "FunctionCondition", "OperatorCondition") |
| 5 | + |
| 6 | +from functools import partial, reduce |
| 7 | +import typing |
| 8 | + |
| 9 | +from ...context import BaseProcessorContext |
| 10 | +from .operator_enums import OperatorEnum |
| 11 | + |
| 12 | + |
| 13 | +if typing.TYPE_CHECKING: |
| 14 | + from ...protocols import CallableConditionProtocol |
| 15 | + from ..mappers import ProcessAttr |
| 16 | + |
| 17 | +ContextT = typing.TypeVar("ContextT", bound=BaseProcessorContext) |
| 18 | +OperatorCallablesT = typing.Callable[[typing.Iterable[typing.Any]], bool] |
| 19 | +OperatorMapT = typing.Dict[OperatorEnum, OperatorCallablesT] |
| 20 | + |
| 21 | + |
| 22 | +class OperatorCondition(typing.Generic[ContextT]): |
| 23 | + operator_map: typing.ClassVar[OperatorMapT] = { |
| 24 | + OperatorEnum.AND: all, |
| 25 | + OperatorEnum.OR: any, |
| 26 | + OperatorEnum.XOR: partial(reduce, lambda itm, other: itm ^ other), |
| 27 | + } |
| 28 | + |
| 29 | + def __init__( |
| 30 | + self, |
| 31 | + conditions: typing.Iterable[CallableConditionProtocol], |
| 32 | + *, |
| 33 | + operator: OperatorEnum = OperatorEnum.AND, |
| 34 | + negated: bool = False, |
| 35 | + ): |
| 36 | + self.conditions = conditions |
| 37 | + self.negated = negated |
| 38 | + self.operator = operator |
| 39 | + |
| 40 | + def __call__(self, context: ContextT) -> bool: |
| 41 | + operator_f = self.operator_map[self.operator] |
| 42 | + result = operator_f(bool(condition(context)) for condition in self.conditions) |
| 43 | + return not result if self.negated else result |
| 44 | + |
| 45 | + def __invert__(self) -> OperatorCondition: |
| 46 | + return OperatorCondition([self], operator=self.operator, negated=not self.negated) |
| 47 | + |
| 48 | + def __and__(self, other: CallableConditionProtocol) -> OperatorCondition: |
| 49 | + return OperatorCondition([self, other], operator=OperatorEnum.AND) |
| 50 | + |
| 51 | + def __or__(self, other: CallableConditionProtocol) -> OperatorCondition: |
| 52 | + return OperatorCondition([self, other], operator=OperatorEnum.OR) |
| 53 | + |
| 54 | + def __xor__(self, other: CallableConditionProtocol) -> OperatorCondition: |
| 55 | + return OperatorCondition([self, other], operator=OperatorEnum.XOR) |
| 56 | + |
| 57 | + |
| 58 | +class AttrCondition(OperatorCondition[ContextT]): |
| 59 | + def __init__(self, process_attr: ProcessAttr, *, negated: bool = False): |
| 60 | + self.process_attr = process_attr |
| 61 | + super().__init__(operator=OperatorEnum.AND, conditions=[self], negated=negated) |
| 62 | + |
| 63 | + def __call__(self, context: ContextT) -> bool: |
| 64 | + value = self.process_attr.get_value(context) |
| 65 | + result = bool(value) |
| 66 | + return not result if self.negated else result |
| 67 | + |
| 68 | + |
| 69 | +class FunctionCondition(OperatorCondition[ContextT]): |
| 70 | + def __init__(self, func: CallableConditionProtocol, *, negated: bool = False): |
| 71 | + super().__init__(operator=OperatorEnum.AND, conditions=[func], negated=negated) |
0 commit comments