diff --git a/sdks/python/apache_beam/metrics/cells.pxd b/sdks/python/apache_beam/metrics/cells.pxd
index 98bb5eff0977..6ba0fbbd3d33 100644
--- a/sdks/python/apache_beam/metrics/cells.pxd
+++ b/sdks/python/apache_beam/metrics/cells.pxd
@@ -33,6 +33,7 @@ cdef class CounterCell(MetricCell):
   cpdef bint update(self, value) except -1
 
 
+# Not using AbstractMetricCell so that data can be typed.
 cdef class DistributionCell(MetricCell):
   cdef readonly DistributionData data
 
@@ -40,14 +41,18 @@ cdef class DistributionCell(MetricCell):
   cdef inline bint _update(self, value) except -1
 
 
-cdef class GaugeCell(MetricCell):
+cdef class AbstractMetricCell(MetricCell):
+  cdef readonly object data_class
   cdef readonly object data
+  cdef bint _update_locked(self, value) except -1
 
 
-cdef class StringSetCell(MetricCell):
-  cdef readonly object data
+cdef class GaugeCell(AbstractMetricCell):
+  pass
 
-  cdef inline bint _update(self, value) except -1
+
+cdef class StringSetCell(AbstractMetricCell):
+  pass
 
 
 cdef class DistributionData(object):
@@ -55,3 +60,14 @@ cdef class DistributionData(object):
   cdef readonly libc.stdint.int64_t count
   cdef readonly libc.stdint.int64_t min
   cdef readonly libc.stdint.int64_t max
+
+
+cdef class _BoundedTrieNode(object):
+  cdef readonly libc.stdint.int64_t _size
+  cdef dict _children
+  cdef bint _truncated
+
+cdef class BoundedTrieData(object):
+  cdef readonly libc.stdint.int64_t _bound
+  cdef readonly object _singleton
+  cdef readonly _BoundedTrieNode _root
diff --git a/sdks/python/apache_beam/metrics/cells.py b/sdks/python/apache_beam/metrics/cells.py
index 63fc9f3f7cc9..0029a8a26978 100644
--- a/sdks/python/apache_beam/metrics/cells.py
+++ b/sdks/python/apache_beam/metrics/cells.py
@@ -23,6 +23,7 @@
 
 # pytype: skip-file
 
+import copy
 import logging
 import threading
 import time
@@ -43,11 +44,7 @@ class fake_cython:
   globals()['cython'] = fake_cython
 
 __all__ = [
-    'MetricAggregator',
-    'MetricCell',
-    'MetricCellFactory',
-    'DistributionResult',
-    'GaugeResult'
+    'MetricCell', 'MetricCellFactory', 'DistributionResult', 'GaugeResult'
 ]
 
 _LOGGER = logging.getLogger(__name__)
@@ -110,11 +107,11 @@ class CounterCell(MetricCell):
   """
   def __init__(self, *args):
     super().__init__(*args)
-    self.value = CounterAggregator.identity_element()
+    self.value = 0
 
   def reset(self):
     # type: () -> None
-    self.value = CounterAggregator.identity_element()
+    self.value = 0
 
   def combine(self, other):
     # type: (CounterCell) -> CounterCell
@@ -175,11 +172,11 @@ class DistributionCell(MetricCell):
   """
   def __init__(self, *args):
     super().__init__(*args)
-    self.data = DistributionAggregator.identity_element()
+    self.data = DistributionData.identity_element()
 
   def reset(self):
     # type: () -> None
-    self.data = DistributionAggregator.identity_element()
+    self.data = DistributionData.identity_element()
 
   def combine(self, other):
     # type: (DistributionCell) -> DistributionCell
@@ -221,47 +218,65 @@ def to_runner_api_monitoring_info_impl(self, name, transform_id):
         ptransform=transform_id)
 
 
-class GaugeCell(MetricCell):
+class AbstractMetricCell(MetricCell):
   """For internal use only; no backwards-compatibility guarantees.
 
-  Tracks the current value and delta for a gauge metric.
-
-  Each cell tracks the state of a metric independently per context per bundle.
-  Therefore, each metric has a different cell in each bundle, that is later
-  aggregated.
+  Tracks the current value and delta for a metric with a data class.
 
   This class is thread safe.
""" - def __init__(self, *args): - super().__init__(*args) - self.data = GaugeAggregator.identity_element() + def __init__(self, data_class): + super().__init__() + self.data_class = data_class + self.data = self.data_class.identity_element() def reset(self): - self.data = GaugeAggregator.identity_element() + self.data = self.data_class.identity_element() - def combine(self, other): - # type: (GaugeCell) -> GaugeCell - result = GaugeCell() + def combine(self, other: 'AbstractMetricCell') -> 'AbstractMetricCell': + result = type(self)() result.data = self.data.combine(other.data) return result def set(self, value): - self.update(value) + with self._lock: + self._update_locked(value) def update(self, value): - # type: (SupportsInt) -> None - value = int(value) with self._lock: - # Set the value directly without checking timestamp, because - # this value is naturally the latest value. - self.data.value = value - self.data.timestamp = time.time() + self._update_locked(value) + + def _update_locked(self, value): + raise NotImplementedError(type(self)) def get_cumulative(self): - # type: () -> GaugeData with self._lock: return self.data.get_cumulative() + def to_runner_api_monitoring_info_impl(self, name, transform_id): + raise NotImplementedError(type(self)) + + +class GaugeCell(AbstractMetricCell): + """For internal use only; no backwards-compatibility guarantees. + + Tracks the current value and delta for a gauge metric. + + Each cell tracks the state of a metric independently per context per bundle. + Therefore, each metric has a different cell in each bundle, that is later + aggregated. + + This class is thread safe. + """ + def __init__(self): + super().__init__(GaugeData) + + def _update_locked(self, value): + # Set the value directly without checking timestamp, because + # this value is naturally the latest value. + self.data.value = int(value) + self.data.timestamp = time.time() + def to_runner_api_monitoring_info_impl(self, name, transform_id): from apache_beam.metrics import monitoring_infos return monitoring_infos.int64_user_gauge( @@ -271,7 +286,7 @@ def to_runner_api_monitoring_info_impl(self, name, transform_id): ptransform=transform_id) -class StringSetCell(MetricCell): +class StringSetCell(AbstractMetricCell): """For internal use only; no backwards-compatibility guarantees. Tracks the current value for a StringSet metric. @@ -282,49 +297,51 @@ class StringSetCell(MetricCell): This class is thread safe. """ - def __init__(self, *args): - super().__init__(*args) - self.data = StringSetAggregator.identity_element() + def __init__(self): + super().__init__(StringSetData) def add(self, value): self.update(value) - def update(self, value): - # type: (str) -> None - if cython.compiled: - # We will hold the GIL throughout the entire _update. 
-      self._update(value)
-    else:
-      with self._lock:
-        self._update(value)
-
-  def _update(self, value):
+  def _update_locked(self, value):
     self.data.add(value)
 
-  def get_cumulative(self):
-    # type: () -> StringSetData
-    with self._lock:
-      return self.data.get_cumulative()
-
-  def combine(self, other):
-    # type: (StringSetCell) -> StringSetCell
-    combined = StringSetAggregator().combine(self.data, other.data)
-    result = StringSetCell()
-    result.data = combined
-    return result
-
   def to_runner_api_monitoring_info_impl(self, name, transform_id):
     from apache_beam.metrics import monitoring_infos
-
     return monitoring_infos.user_set_string(
         name.namespace,
         name.name,
         self.get_cumulative(),
         ptransform=transform_id)
 
-  def reset(self):
-    # type: () -> None
-    self.data = StringSetAggregator.identity_element()
+
+class BoundedTrieCell(AbstractMetricCell):
+  """For internal use only; no backwards-compatibility guarantees.
+
+  Tracks the current value for a BoundedTrie metric.
+
+  Each cell tracks the state of a metric independently per context per bundle.
+  Therefore, each metric has a different cell in each bundle, that is later
+  aggregated.
+
+  This class is thread safe.
+  """
+  def __init__(self):
+    super().__init__(BoundedTrieData)
+
+  def add(self, value):
+    self.update(value)
+
+  def _update_locked(self, value):
+    self.data.add(value)
+
+  def to_runner_api_monitoring_info_impl(self, name, transform_id):
+    from apache_beam.metrics import monitoring_infos
+    return monitoring_infos.user_bounded_trie(
+        name.namespace,
+        name.name,
+        self.get_cumulative(),
+        ptransform=transform_id)
 
 
 class DistributionResult(object):
@@ -449,6 +466,10 @@ def get_cumulative(self):
     # type: () -> GaugeData
     return GaugeData(self.value, timestamp=self.timestamp)
 
+  def get_result(self):
+    # type: () -> GaugeResult
+    return GaugeResult(self.get_cumulative())
+
   def combine(self, other):
     # type: (Optional[GaugeData]) -> GaugeData
     if other is None:
@@ -464,6 +485,11 @@ def singleton(value, timestamp=None):
     # type: (Optional[int], Optional[int]) -> GaugeData
     return GaugeData(value, timestamp=timestamp)
 
+  @staticmethod
+  def identity_element():
+    # type: () -> GaugeData
+    return GaugeData(0, timestamp=0)
+
 
 class DistributionData(object):
   """For internal use only; no backwards-compatibility guarantees.
@@ -510,6 +536,9 @@ def get_cumulative(self):
     # type: () -> DistributionData
     return DistributionData(self.sum, self.count, self.min, self.max)
 
+  def get_result(self):
+    return DistributionResult(self.get_cumulative())
+
   def combine(self, other):
     # type: (Optional[DistributionData]) -> DistributionData
     if other is None:
@@ -526,6 +555,11 @@ def singleton(value):
     # type: (int) -> DistributionData
     return DistributionData(value, 1, value, value)
 
+  @staticmethod
+  def identity_element():
+    # type: () -> DistributionData
+    return DistributionData(0, 0, 2**63 - 1, -2**63)
+
 
 class StringSetData(object):
   """For internal use only; no backwards-compatibility guarantees.
@@ -568,6 +602,9 @@ def __repr__(self) -> str:
   def get_cumulative(self) -> "StringSetData":
     return StringSetData(set(self.string_set), self.string_size)
 
+  def get_result(self) -> set[str]:
+    return set(self.string_set)
+
   def add(self, *strings):
     """
     Add strings into this StringSetData and return the result StringSetData.
@@ -585,6 +622,11 @@ def combine(self, other: "StringSetData") -> "StringSetData":
     if other is None:
       return self
 
+    if not other.string_set:
+      return self
+    elif not self.string_set:
+      return other
+
     combined = set(self.string_set)
     string_size = self.add_until_capacity(
         combined, self.string_size, other.string_set)
@@ -614,113 +656,179 @@
     return current_size
 
   @staticmethod
-  def singleton(value):
-    # type: (int) -> DistributionData
-    return DistributionData(value, 1, value, value)
-
-
-class MetricAggregator(object):
-  """For internal use only; no backwards-compatibility guarantees.
-
-  Base interface for aggregating metric data during pipeline execution."""
-  def identity_element(self):
-    # type: () -> Any
-
-    """Returns the identical element of an Aggregation.
-
-    For the identity element, it must hold that
-    Aggregator.combine(any_element, identity_element) == any_element.
-    """
-    raise NotImplementedError
+  def singleton(value: str) -> "StringSetData":
+    return StringSetData({value})
 
-  def combine(self, x, y):
-    # type: (Any, Any) -> Any
-    raise NotImplementedError
+  @staticmethod
+  def identity_element() -> "StringSetData":
+    return StringSetData()
 
-  def result(self, x):
-    # type: (Any) -> Any
-    raise NotImplementedError
 
-
-class CounterAggregator(MetricAggregator):
-  """For internal use only; no backwards-compatibility guarantees.
+class _BoundedTrieNode(object):
  def __init__(self):
+    # invariant: size = len(self.flattened()) = max(1, sum(size of children))
+    self._size = 1
+    self._children: dict[str, '_BoundedTrieNode'] = {}
+    self._truncated = False
+
+  def size(self):
+    return self._size
+
+  def add(self, segments) -> int:
+    if self._truncated or not segments:
+      return 0
+    head, *tail = segments
+    was_empty = not self._children
+    child = self._children.get(head, None)
+    if child is None:
+      child = self._children[head] = _BoundedTrieNode()
+      delta = not was_empty
+    else:
+      delta = 0
+    if tail:
+      delta += child.add(tail)
+    self._size += delta
+    return delta
+
+  def add_all(self, segments_iter):
+    return sum(self.add(segments) for segments in segments_iter)
+
+  def trim(self) -> int:
+    if not self._children:
+      return 0
+    max_child = max(self._children.values(), key=lambda child: child._size)
+    if max_child._size == 1:
+      delta = 1 - self._size
+      self._truncated = True
+      self._children = None
+    else:
+      delta = max_child.trim()
+    self._size += delta
+    return delta
+
+  def merge(self, other) -> int:
+    if self._truncated:
+      delta = 0
+    elif other._truncated:
+      delta = 1 - self._size
+      self._truncated = True
+      self._children = None
+    elif not other._children:
+      delta = 0
+    elif not self._children:
+      self._children = other._children
+      delta = other._size - self._size
+    else:
+      delta = 0
+      for prefix, other_child in other._children.items():
+        self_child = self._children.get(prefix, None)
+        if self_child is None:
+          self._children[prefix] = other_child
+          delta += other_child._size
+        else:
+          delta += self_child.merge(other_child)
+    self._size += delta
+    return delta
+
+  def flattened(self):
+    if self._truncated:
+      yield (True, )
+    elif not self._children:
+      yield (False, )
+    else:
+      for prefix, child in sorted(self._children.items()):
+        for flattened in child.flattened():
+          yield (prefix, ) + flattened
 
-  Aggregator for Counter metric data during pipeline execution.
+  def __hash__(self):
+    return self._truncated or hash(tuple(sorted(self._children.items())))
 
-  Values aggregated should be ``int`` objects.
- """ - @staticmethod - def identity_element(): - # type: () -> int - return 0 + def __eq__(self, other): + if isinstance(other, _BoundedTrieNode): + return self._children == other._children + else: + return False - def combine(self, x, y): - # type: (SupportsInt, SupportsInt) -> int - return int(x) + int(y) + def __repr__(self): + return repr(set(''.join(str(s) for s in t) for t in self.flattened())) - def result(self, x): - # type: (SupportsInt) -> int - return int(x) +class BoundedTrieData(object): + _DEFAULT_BOUND = 100 -class DistributionAggregator(MetricAggregator): - """For internal use only; no backwards-compatibility guarantees. + def __init__(self, * ,root=None, singleton=None, bound=_DEFAULT_BOUND): + self._singleton = singleton + self._root = root + self._bound = bound - Aggregator for Distribution metric data during pipeline execution. + def as_trie(self): + if self._root is not None: + return self._root + else: + root = _BoundedTrieNode() + if self._singleton is not None: + root.add(self._singleton) + return root - Values aggregated should be ``DistributionData`` objects. - """ - @staticmethod - def identity_element(): - # type: () -> DistributionData - return DistributionData(0, 0, 2**63 - 1, -2**63) + def __eq__(self, other: object) -> bool: + if isinstance(other, BoundedTrieData): + return self.as_trie() == other.as_trie() + else: + return False - def combine(self, x, y): - # type: (DistributionData, DistributionData) -> DistributionData - return x.combine(y) + def __hash__(self) -> int: + return hash(self.as_trie()) - def result(self, x): - # type: (DistributionData) -> DistributionResult - return DistributionResult(x.get_cumulative()) + def __repr__(self) -> str: + return 'BoundedTrieData({})'.format(self.as_trie()) + def get_cumulative(self) -> "BoundedTrieData": + return copy.deepcopy(self) -class GaugeAggregator(MetricAggregator): - """For internal use only; no backwards-compatibility guarantees. + def get_result(self) -> set[tuple]: + if self._root is None: + if self._singleton is None: + return set() + else: + return set([self._singleton + (False,)]) + else: + return set(self._root.flattened()) - Aggregator for Gauge metric data during pipeline execution. + def add(self, segments): + if self._root is None and self._singleton is None: + self._singleton = segments + else: + if self._root is None: + self._root = self.as_trie() + self._root.add(segments) + if self._root._size > self._bound: + self._root.trim() + + def combine(self, other: "BoundedTrieData") -> "BoundedTrieData": + if self._root is None and self._singleton is None: + return other + elif other._root is None and other._singleton is None: + return self + else: + if self._root is None and other._root is not None: + self, other = other, self + combined = copy.deepcopy(self.as_trie()) + if other._root is not None: + combined.merge(other._root) + else: + combined.add(other._singleton) + self._bound = min(self._bound, other._bound) + while combined._size > self._bound: + combined.trim() + return BoundedTrieData(root=combined) - Values aggregated should be ``GaugeData`` objects. 
- """ @staticmethod - def identity_element(): - # type: () -> GaugeData - return GaugeData(0, timestamp=0) + def singleton(value: str) -> "BoundedTrieData": + s = BoundedTrieData() + s.add(value) + return s - def combine(self, x, y): - # type: (GaugeData, GaugeData) -> GaugeData - result = x.combine(y) - return result - - def result(self, x): - # type: (GaugeData) -> GaugeResult - return GaugeResult(x.get_cumulative()) - - -class StringSetAggregator(MetricAggregator): @staticmethod - def identity_element(): - # type: () -> StringSetData - return StringSetData() - - def combine(self, x, y): - # type: (StringSetData, StringSetData) -> StringSetData - if len(x.string_set) == 0: - return y - elif len(y.string_set) == 0: - return x - else: - return x.combine(y) - - def result(self, x): - # type: (StringSetData) -> set - return set(x.string_set) + def identity_element() -> "BoundedTrieData": + return BoundedTrieData() diff --git a/sdks/python/apache_beam/metrics/cells_test.py b/sdks/python/apache_beam/metrics/cells_test.py index d1ee37b8ed82..be3563344e01 100644 --- a/sdks/python/apache_beam/metrics/cells_test.py +++ b/sdks/python/apache_beam/metrics/cells_test.py @@ -17,6 +17,9 @@ # pytype: skip-file +import copy +import itertools +import random import threading import unittest @@ -27,6 +30,8 @@ from apache_beam.metrics.cells import GaugeData from apache_beam.metrics.cells import StringSetCell from apache_beam.metrics.cells import StringSetData +from apache_beam.metrics.cells import _BoundedTrieNode +from apache_beam.metrics.cells import BoundedTrieData from apache_beam.metrics.metricbase import MetricName @@ -203,5 +208,218 @@ def test_add_size_tracked_correctly(self): self.assertEqual(s.data.string_size, 3) +class TestBoundedTrieNode(unittest.TestCase): + @classmethod + def random_segments_fixed_depth(cls, n, depth, overlap, rand): + if depth == 0: + yield from ((), ) * n + else: + seen = [] + to_string = lambda ix: chr(ord('a') + ix) if ix < 26 else f'z{ix}' + for suffix in cls.random_segments_fixed_depth(n, depth - 1, overlap, + rand): + if not seen or rand.random() > overlap: + prefix = to_string(len(seen)) + seen.append(prefix) + else: + prefix = rand.choice(seen) + yield (prefix, ) + suffix + + @classmethod + def random_segments(cls, n, min_depth, max_depth, overlap, rand): + for depth, segments in zip( + itertools.cycle(range(min_depth, max_depth + 1)), + cls.random_segments_fixed_depth(n, max_depth, overlap, rand)): + yield segments[:depth] + + def assert_covers(self, node, expected, max_truncated=0): + self.assert_covers_flattened(node.flattened(), expected, max_truncated) + + def assert_covers_flattened(self, flattened, expected, max_truncated=0): + expected = set(expected) + # Split node into the exact and truncated segments. + partitioned = {True: set(), False: set()} + for segments in flattened: + partitioned[segments[-1]].add(segments[:-1]) + exact, truncated = partitioned[False], partitioned[True] + # Check we cover both parts. 
diff --git a/sdks/python/apache_beam/metrics/cells_test.py b/sdks/python/apache_beam/metrics/cells_test.py
index d1ee37b8ed82..be3563344e01 100644
--- a/sdks/python/apache_beam/metrics/cells_test.py
+++ b/sdks/python/apache_beam/metrics/cells_test.py
@@ -17,6 +17,9 @@
 
 # pytype: skip-file
 
+import copy
+import itertools
+import random
 import threading
 import unittest
 
@@ -27,6 +30,8 @@
 from apache_beam.metrics.cells import GaugeData
 from apache_beam.metrics.cells import StringSetCell
 from apache_beam.metrics.cells import StringSetData
+from apache_beam.metrics.cells import _BoundedTrieNode
+from apache_beam.metrics.cells import BoundedTrieData
 from apache_beam.metrics.metricbase import MetricName
 
 
@@ -203,5 +208,218 @@ def test_add_size_tracked_correctly(self):
     self.assertEqual(s.data.string_size, 3)
 
 
+class TestBoundedTrieNode(unittest.TestCase):
+  @classmethod
+  def random_segments_fixed_depth(cls, n, depth, overlap, rand):
+    if depth == 0:
+      yield from ((), ) * n
+    else:
+      seen = []
+      to_string = lambda ix: chr(ord('a') + ix) if ix < 26 else f'z{ix}'
+      for suffix in cls.random_segments_fixed_depth(n, depth - 1, overlap,
+                                                    rand):
+        if not seen or rand.random() > overlap:
+          prefix = to_string(len(seen))
+          seen.append(prefix)
+        else:
+          prefix = rand.choice(seen)
+        yield (prefix, ) + suffix
+
+  @classmethod
+  def random_segments(cls, n, min_depth, max_depth, overlap, rand):
+    for depth, segments in zip(
+        itertools.cycle(range(min_depth, max_depth + 1)),
+        cls.random_segments_fixed_depth(n, max_depth, overlap, rand)):
+      yield segments[:depth]
+
+  def assert_covers(self, node, expected, max_truncated=0):
+    self.assert_covers_flattened(node.flattened(), expected, max_truncated)
+
+  def assert_covers_flattened(self, flattened, expected, max_truncated=0):
+    expected = set(expected)
+    # Split node into the exact and truncated segments.
+    partitioned = {True: set(), False: set()}
+    for segments in flattened:
+      partitioned[segments[-1]].add(segments[:-1])
+    exact, truncated = partitioned[False], partitioned[True]
+    # Check we cover both parts.
+    self.assertLessEqual(len(truncated), max_truncated, truncated)
+    self.assertTrue(exact.issubset(expected), exact - expected)
+    seen_truncated = set()
+    for segments in expected - exact:
+      found = 0
+      for ix in range(len(segments)):
+        if segments[:ix] in truncated:
+          seen_truncated.add(segments[:ix])
+          found += 1
+      if found != 1:
+        self.fail(
+            f"Expected exactly one prefix of {segments} "
+            f"to occur in {truncated}, found {found}")
+    self.assertEqual(seen_truncated, truncated, truncated - seen_truncated)
+
+  def run_covers_test(self, flattened, expected, max_truncated):
+    def parse(s):
+      return tuple(s.strip('*')) + (s.endswith('*'), )
+
+    self.assert_covers_flattened([parse(s) for s in flattened],
+                                 [tuple(s) for s in expected],
+                                 max_truncated)
+
+  def test_covers_exact(self):
+    self.run_covers_test(['ab', 'ac', 'cd'], ['ab', 'ac', 'cd'], 0)
+    with self.assertRaises(AssertionError):
+      self.run_covers_test(['ab', 'ac', 'cd'], ['ac', 'cd'], 0)
+    with self.assertRaises(AssertionError):
+      self.run_covers_test(['ab', 'ac'], ['ab', 'ac', 'cd'], 0)
+    with self.assertRaises(AssertionError):
+      self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'cd'], 0)
+
+  def test_covers_truncated(self):
+    self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'cd'], 1)
+    self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'abcde', 'cd'], 1)
+    with self.assertRaises(AssertionError):
+      self.run_covers_test(['ab', 'ac', 'cd'], ['ac', 'cd'], 1)
+    with self.assertRaises(AssertionError):
+      self.run_covers_test(['ab', 'ac'], ['ab', 'ac', 'cd'], 1)
+    with self.assertRaises(AssertionError):
+      self.run_covers_test(['a*', 'c*'], ['ab', 'ac', 'cd'], 1)
+    with self.assertRaises(AssertionError):
+      self.run_covers_test(['a*', 'c*'], ['ab', 'ac'], 1)
+
+  def run_test(self, to_add):
+    everything = list(set(to_add))
+    all_prefixes = set(
+        segments[:ix] for segments in everything
+        for ix in range(len(segments)))
+    everything_deduped = set(everything) - all_prefixes
+
+    # Check basic addition.
+    node = _BoundedTrieNode()
+    total_size = node.size()
+    self.assertEqual(total_size, 1)
+    for segments in everything:
+      total_size += node.add(segments)
+    self.assertEqual(node.size(), len(everything_deduped), node)
+    self.assertEqual(node.size(), total_size, node)
+    self.assert_covers(node, everything_deduped)
+
+    # Check merging.
+    node0 = _BoundedTrieNode()
+    node0.add_all(everything[0::2])
+    node1 = _BoundedTrieNode()
+    node1.add_all(everything[1::2])
+    pre_merge_size = node0.size()
+    merge_delta = node0.merge(node1)
+    self.assertEqual(node0.size(), pre_merge_size + merge_delta)
+    self.assertEqual(node0, node)
+
+    # Check trimming.
+    if node.size() > 1:
+      trim_delta = node.trim()
+      self.assertLess(trim_delta, 0, node)
+      self.assertEqual(node.size(), total_size + trim_delta)
+      self.assert_covers(node, everything_deduped, max_truncated=1)
+
+    if node.size() > 1:
+      trim2_delta = node.trim()
+      self.assertLess(trim2_delta, 0)
+      self.assertEqual(node.size(), total_size + trim_delta + trim2_delta)
+      self.assert_covers(node, everything_deduped, max_truncated=2)
+
+    # Adding after trimming should be a no-op.
+    node_copy = copy.deepcopy(node)
+    for segments in everything:
+      self.assertEqual(node.add(segments), 0)
+    self.assertEqual(node, node_copy)
+
+    # Merging after trimming should be a no-op.
+    self.assertEqual(node.merge(node0), 0)
+    self.assertEqual(node.merge(node1), 0)
+    self.assertEqual(node, node_copy)
+
+    if node._truncated:
+      expected_delta = 0
+    else:
+      expected_delta = 2
+
+    # Adding something new is not.
+    new_values = [('new1', ), ('new2', 'new2.1')]
+    self.assertEqual(node.add_all(new_values), expected_delta)
+    self.assert_covers(
+        node, list(everything_deduped) + new_values, max_truncated=2)
+
+    # Nor is merging something new.
+    new_values_node = _BoundedTrieNode()
+    new_values_node.add_all(new_values)
+    self.assertEqual(node_copy.merge(new_values_node), expected_delta)
+    self.assert_covers(
+        node_copy, list(everything_deduped) + new_values, max_truncated=2)
+
+  def run_fuzz(self, iterations=10, **params):
+    for _ in range(iterations):
+      seed = random.getrandbits(64)
+      segments = self.random_segments(**params, rand=random.Random(seed))
+      try:
+        self.run_test(segments)
+      except:
+        print("SEED", seed)
+        raise
+
+  def test_trivial(self):
+    self.run_test([('a', 'b'), ('a', 'c')])
+
+  def test_flat(self):
+    self.run_test([('a', 'a'), ('b', 'b'), ('c', 'c')])
+
+  def test_deep(self):
+    self.run_test([('a', ) * 10, ('b', ) * 12])
+
+  def test_small(self):
+    self.run_fuzz(n=5, min_depth=2, max_depth=3, overlap=0.5)
+
+  def test_medium(self):
+    self.run_fuzz(n=20, min_depth=2, max_depth=4, overlap=0.5)
+
+  def test_large_sparse(self):
+    self.run_fuzz(n=120, min_depth=2, max_depth=4, overlap=0.2)
+
+  def test_large_dense(self):
+    self.run_fuzz(n=120, min_depth=2, max_depth=4, overlap=0.8)
+
+  def test_bounded_trie_data_combine(self):
+    empty = BoundedTrieData()
+    # The merging here isn't complicated; we're just ensuring that
+    # BoundedTrieData invokes _BoundedTrieNode correctly.
+    singletonA = BoundedTrieData(singleton=('a', 'a'))
+    singletonB = BoundedTrieData(singleton=('b', 'b'))
+    lots_root = _BoundedTrieNode()
+    lots_root.add_all([('c', 'c'), ('d', 'd')])
+    lots = BoundedTrieData(root=lots_root)
+    self.assertEqual(empty.get_result(), set())
+    self.assertEqual(
+        empty.combine(singletonA).get_result(), set([('a', 'a', False)]))
+    self.assertEqual(
+        singletonA.combine(empty).get_result(), set([('a', 'a', False)]))
+    self.assertEqual(
+        singletonA.combine(singletonB).get_result(),
+        set([('a', 'a', False), ('b', 'b', False)]))
+    self.assertEqual(
+        singletonA.combine(lots).get_result(),
+        set([('a', 'a', False), ('c', 'c', False), ('d', 'd', False)]))
+    self.assertEqual(
+        lots.combine(singletonA).get_result(),
+        set([('a', 'a', False), ('c', 'c', False), ('d', 'd', False)]))
+
+  def test_bounded_trie_data_combine_trim(self):
+    left = _BoundedTrieNode()
+    left.add_all([('a', 'x'), ('b', 'd')])
+    right = _BoundedTrieNode()
+    right.add_all([('a', 'y'), ('c', 'd')])
+    self.assertEqual(
+        BoundedTrieData(root=left).combine(
+            BoundedTrieData(root=right, bound=3)).get_result(),
+        set([('a', True), ('b', 'd', False), ('c', 'd', False)]))
+
+
 if __name__ == '__main__':
   unittest.main()
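
For reference, the flattened() encoding these assertions rely on, in a minimal sketch (not part of the patch): each yielded tuple is a retained path whose last element is a bool, True exactly when the path is a truncated prefix:

    from apache_beam.metrics.cells import _BoundedTrieNode

    node = _BoundedTrieNode()
    node.add_all([('a', 'b'), ('a', 'c')])
    sorted(node.flattened())  # [('a', 'b', False), ('a', 'c', False)]
    node.trim()               # collapses the 'a' subtree; returns -1
    sorted(node.flattened())  # [('a', True)]
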
diff --git a/sdks/python/apache_beam/runners/direct/direct_metrics.py b/sdks/python/apache_beam/runners/direct/direct_metrics.py
index f715ce3bf521..693c0a64538e 100644
--- a/sdks/python/apache_beam/runners/direct/direct_metrics.py
+++ b/sdks/python/apache_beam/runners/direct/direct_metrics.py
@@ -25,22 +25,81 @@
 import threading
 from collections import defaultdict
 
-from apache_beam.metrics.cells import CounterAggregator
-from apache_beam.metrics.cells import DistributionAggregator
-from apache_beam.metrics.cells import GaugeAggregator
-from apache_beam.metrics.cells import StringSetAggregator
+from apache_beam.metrics.cells import DistributionData
+from apache_beam.metrics.cells import GaugeData
+from apache_beam.metrics.cells import StringSetData
 from apache_beam.metrics.execution import MetricKey
 from apache_beam.metrics.execution import MetricResult
 from apache_beam.metrics.metric import MetricResults
 
 
+class MetricAggregator(object):
+  """For internal use only; no backwards-compatibility guarantees.
+
+  Base interface for aggregating metric data during pipeline execution."""
+  def identity_element(self):
+    # type: () -> Any
+
+    """Returns the identity element of an Aggregation.
+
+    For the identity element, it must hold that
+    Aggregator.combine(any_element, identity_element) == any_element.
+    """
+    raise NotImplementedError
+
+  def combine(self, x, y):
+    # type: (Any, Any) -> Any
+    raise NotImplementedError
+
+  def result(self, x):
+    # type: (Any) -> Any
+    raise NotImplementedError
+
+
+class CounterAggregator(MetricAggregator):
+  """For internal use only; no backwards-compatibility guarantees.
+
+  Aggregator for Counter metric data during pipeline execution.
+
+  Values aggregated should be ``int`` objects.
+  """
+  @staticmethod
+  def identity_element():
+    # type: () -> int
+    return 0
+
+  def combine(self, x, y):
+    # type: (SupportsInt, SupportsInt) -> int
+    return int(x) + int(y)
+
+  def result(self, x):
+    # type: (SupportsInt) -> int
+    return int(x)
+
+
+class GenericAggregator(MetricAggregator):
+  def __init__(self, data_class):
+    self._data_class = data_class
+
+  def identity_element(self):
+    return self._data_class.identity_element()
+
+  def combine(self, x, y):
+    return x.combine(y)
+
+  def result(self, x):
+    return x.get_result()
+
+
 class DirectMetrics(MetricResults):
   def __init__(self):
     self._counters = defaultdict(lambda: DirectMetric(CounterAggregator()))
     self._distributions = defaultdict(
-        lambda: DirectMetric(DistributionAggregator()))
-    self._gauges = defaultdict(lambda: DirectMetric(GaugeAggregator()))
-    self._string_sets = defaultdict(lambda: DirectMetric(StringSetAggregator()))
+        lambda: DirectMetric(GenericAggregator(DistributionData)))
+    self._gauges = defaultdict(
+        lambda: DirectMetric(GenericAggregator(GaugeData)))
+    self._string_sets = defaultdict(
+        lambda: DirectMetric(GenericAggregator(StringSetData)))
 
   def _apply_operation(self, bundle, updates, op):
     for k, v in updates.counters.items():
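
The DirectMetrics change above collapses the per-type aggregators into a single GenericAggregator that delegates to the data class. A minimal sketch of the contract this assumes (identity_element and combine on the data class, get_result for extraction), using GenericAggregator as defined in direct_metrics.py:

    from apache_beam.metrics.cells import StringSetData

    agg = GenericAggregator(StringSetData)
    acc = agg.identity_element()          # empty StringSetData
    acc = agg.combine(acc, StringSetData.singleton('a'))
    acc = agg.combine(acc, StringSetData.singleton('b'))
    agg.result(acc)                       # {'a', 'b'}
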