From 37c2eb33be0d513f2640e6f72ab96d851bb91dc1 Mon Sep 17 00:00:00 2001 From: atuonufure Date: Mon, 28 Aug 2023 11:19:22 +0200 Subject: [PATCH] Refactor equivalence --- fhirpathpy/engine/invocations/equality.py | 48 +++++++++++++++-------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/fhirpathpy/engine/invocations/equality.py b/fhirpathpy/engine/invocations/equality.py index 6b23846..ea37106 100644 --- a/fhirpathpy/engine/invocations/equality.py +++ b/fhirpathpy/engine/invocations/equality.py @@ -19,6 +19,28 @@ def equality(ctx, x, y): return x == y +def normalize_string(s): + return " ".join(s.lower().split()) + + +def decimal_places(a): + match = str(a).split(".") + return len(match[1]) if len(match) > 1 else 0 + + +def round_to_decimal_places(a, n): + rounding_format = Decimal("10") ** -n + return Decimal(a).quantize(rounding_format) + + +def is_equivalent(a, b): + precision = min(decimal_places(a), decimal_places(b)) + if precision == 0: + return round(a) == round(b) + else: + return round_to_decimal_places(a, precision) == round_to_decimal_places(b, precision) + + def equivalence(ctx, x, y): if util.is_empty(x) and util.is_empty(y): return True @@ -26,25 +48,17 @@ def equivalence(ctx, x, y): if util.is_empty(x) or util.is_empty(y): return False - if type(x[0]) in DATETIME_NODES_LIST or type(y[0]) in DATETIME_NODES_LIST: + a = util.get_data(x[0]) + b = util.get_data(y[0]) + + if type(a) in DATETIME_NODES_LIST or type(b) in DATETIME_NODES_LIST: return datetime_equality(ctx, x, y) - # string: the strings must be the same - # while ignoring case and normalizing whitespace. - if isinstance(x[0], str) and isinstance(y[0], str): - return " ".join(x[0].lower().split()) == " ".join(y[0].lower().split()) - - # decimal: values must be equal, comparison is done on values rounded - # to the precision of the least precise operand. - # Trailing zeroes are ignored in determining precision. - if isinstance(x[0], Decimal) or isinstance(y[0], Decimal): - precision_x = len(str(x[0]).split(".")[1]) if "." in str(x[0]) else 0 - precision_y = len(str(y[0]).split(".")[1]) if "." in str(y[0]) else 0 - least_precision = min(precision_x, precision_y) - rounding_format = Decimal("10") ** -least_precision - rounded_x = Decimal(x[0]).quantize(rounding_format) - rounded_y = Decimal(y[0]).quantize(rounding_format) - return rounded_x == rounded_y + if isinstance(a, str) and isinstance(b, str): + return normalize_string(a) == normalize_string(b) + + if isinstance(a, Decimal) or isinstance(b, Decimal): + return is_equivalent(a, b) return x == y