|
@@ -0,0 +1,92 @@
|
|
|
|
+import logging
|
|
|
|
+import math
|
|
|
|
+import dataclasses
|
|
|
|
+from .common import MatchSeries
|
|
|
|
+from structures.measurement import Measurement24v, Measurement480v
|
|
|
|
+from structures.plant import S7State, CompactLogixState
|
|
|
|
+from structures.correlated import CorrelatedMeasurements
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
+"""
|
|
|
|
+This middleware filters the measurements by series and yields if any given field matches.
|
|
|
|
+"""
|
|
|
|
+class MatchAny(MatchSeries):
|
|
|
|
+ def __init__(self, parent, series, **kwargs) -> None:
|
|
|
|
+ super().__init__(series)
|
|
|
|
+ self._fields = kwargs
|
|
|
|
+
|
|
|
|
+ def execute(self, values):
|
|
|
|
+ for measurement in values:
|
|
|
|
+ dataset = self.get_series(measurement)
|
|
|
|
+
|
|
|
|
+ if not dataset:
|
|
|
|
+ continue
|
|
|
|
+
|
|
|
|
+ if not self._fields:
|
|
|
|
+ yield measurement
|
|
|
|
+ continue
|
|
|
|
+
|
|
|
|
+ # check if any field matches
|
|
|
|
+ for field, value in self._fields.items():
|
|
|
|
+ if getattr(dataset, field, None) == value:
|
|
|
|
+ yield measurement
|
|
|
|
+ break
|
|
|
|
+
|
|
|
|
+"""
|
|
|
|
+This middleware filters the measurements by series and yields if all given fields match.
|
|
|
|
+"""
|
|
|
|
+class MatchAll(MatchSeries):
|
|
|
|
+ def __init__(self, parent, series, **kwargs) -> None:
|
|
|
|
+ super().__init__(series)
|
|
|
|
+ self._fields = kwargs
|
|
|
|
+
|
|
|
|
+ def execute(self, values):
|
|
|
|
+ for measurement in values:
|
|
|
|
+ dataset = self.get_series(measurement)
|
|
|
|
+ if not dataset:
|
|
|
|
+ continue
|
|
|
|
+
|
|
|
|
+ # check if all fields match
|
|
|
|
+ success = True
|
|
|
|
+ for field, value in self._fields.items():
|
|
|
|
+ if getattr(dataset, field, None) != value:
|
|
|
|
+ success = False
|
|
|
|
+ break
|
|
|
|
+ if success:
|
|
|
|
+ yield measurement
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ALLOWED_NAMES = \
|
|
|
|
+ [x.name for x in dataclasses.fields(Measurement24v)] + \
|
|
|
|
+ [x.name for x in dataclasses.fields(Measurement480v)] + \
|
|
|
|
+ [x.name for x in dataclasses.fields(CompactLogixState)] + \
|
|
|
|
+ [x.name for x in dataclasses.fields(S7State)] + \
|
|
|
|
+ [x.name for x in dataclasses.fields(CorrelatedMeasurements)] + \
|
|
|
|
+ ['sum', 'min', 'max', 'avg', 'count', 'last']
|
|
|
|
+
|
|
|
|
+ALLOWED_NAMES = set([name for name in ALLOWED_NAMES if not name.startswith('_')])
|
|
|
|
+
|
|
|
|
+class ComplexFilter():
|
|
|
|
+ def __init__(self, parent, predicate) -> None:
|
|
|
|
+ self._predicate = predicate
|
|
|
|
+ self._compiled = compile(predicate, "<string>", "eval")
|
|
|
|
+ # Validate allowed names
|
|
|
|
+ for name in self._compiled.co_names:
|
|
|
|
+ if name not in ALLOWED_NAMES:
|
|
|
|
+ raise NameError(f"The use of '{name}' is not allowed in '{predicate}'")
|
|
|
|
+
|
|
|
|
+ def execute(self, values):
|
|
|
|
+ for measurement in values:
|
|
|
|
+ try:
|
|
|
|
+ if eval(self._compiled, {"__builtins__": {
|
|
|
|
+ 'sum': sum,
|
|
|
|
+ 'min': min,
|
|
|
|
+ 'max': max,
|
|
|
|
+ 'avg': lambda x: sum(x) / len(x),
|
|
|
|
+ 'count': len,
|
|
|
|
+ 'last': lambda x: x[-1],
|
|
|
|
+ }}, measurement.__dict__):
|
|
|
|
+ yield measurement
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Error while evaluating predicate '{self._predicate}': {e}")
|