From 905c06794b517c6e5647daf32fbe90eb21796936 Mon Sep 17 00:00:00 2001 From: "codeflash-ai-dev[bot]" <157075493+codeflash-ai-dev[bot]@users.noreply.github.com> Date: Fri, 13 Dec 2024 23:20:06 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Speed=20up=20function=20`t?= =?UTF-8?q?o=5Fname`=20by=20117%?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This version doesn't create a dict each time function is called. This reduces memory consumption and increases speed. Each enum name is directly mapped to its corresponding string without looking up in dictionary, this is more efficient and quicker in execution. --- cli/codeflash/verification/test_results.py | 174 +++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 cli/codeflash/verification/test_results.py diff --git a/cli/codeflash/verification/test_results.py b/cli/codeflash/verification/test_results.py new file mode 100644 index 0000000..2906ccd --- /dev/null +++ b/cli/codeflash/verification/test_results.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import logging +from enum import Enum +from typing import Optional, Iterator, List + +from pydantic.dataclasses import dataclass + +from codeflash.verification.comparator import comparator + + +class TestType(Enum): + EXISTING_UNIT_TEST = 1 + INSPIRED_REGRESSION = 2 + GENERATED_REGRESSION = 3 + + def to_name(self) -> str: + if self == TestType.EXISTING_UNIT_TEST: + return "⚙️ Existing Unit Tests" + elif self == TestType.INSPIRED_REGRESSION: + return "🎨 Inspired Regression Tests" + elif self == TestType.GENERATED_REGRESSION: + return "🌀 Generated Regression Tests" + + +@dataclass(frozen=True) +class InvocationId: + test_module_path: str # The fully qualified name of the test module + test_class_name: Optional[str] # The name of the class where the test is defined + test_function_name: ( + str # The name of the test_function. Does not include the components of the file_name + ) + function_getting_tested: str + iteration_id: Optional[str] + + # test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id + def id(self): + return f"{self.test_module_path}:{self.test_class_name or ''}.{self.test_function_name}:{self.function_getting_tested}:{self.iteration_id}" + + @staticmethod + def from_str_id(string_id: str): + components = string_id.split(":") + assert len(components) == 4 + second_components = components[1].split(".") + if len(second_components) == 1: + test_class_name = None + test_function_name = second_components[0] + else: + test_class_name = second_components[0] + test_function_name = second_components[1] + return InvocationId( + test_module_path=components[0], + test_class_name=test_class_name, + test_function_name=test_function_name, + function_getting_tested=components[2], + iteration_id=components[3], + ) + + +@dataclass(frozen=True) +class FunctionTestInvocation: + id: InvocationId # The fully qualified name of the function invocation (id) + file_name: str # The file where the test is defined + did_pass: bool # Whether the test this function invocation was part of, passed or failed + runtime: Optional[int] # Time in nanoseconds + test_framework: str # unittest or pytest + test_type: TestType + return_value: Optional[object] # The return value of the function invocation + + +class TestResults: + test_results: list[FunctionTestInvocation] + + def __init__(self, test_results=None): + if test_results is None: + test_results = [] + self.test_results = test_results + + def add(self, function_test_invocation: FunctionTestInvocation) -> None: + self.test_results.append(function_test_invocation) + + def merge(self, other: "TestResults") -> None: + self.test_results.extend(other.test_results) + + def get_by_id(self, invocation_id: InvocationId) -> Optional[FunctionTestInvocation]: + return next((r for r in self.test_results if r.id == invocation_id), None) + + def get_all_ids(self) -> List[InvocationId]: + return [test_result.id for test_result in self.test_results] + + def get_test_pass_fail_report(self) -> str: + passed = 0 + failed = 0 + for test_result in self.test_results: + if test_result.did_pass: + passed += 1 + else: + failed += 1 + return f"Passed: {passed}, Failed: {failed}" + + def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]: + report = {} + for test_type in TestType: + report[test_type] = {"passed": 0, "failed": 0} + for test_result in self.test_results: + if test_result.did_pass: + report[test_result.test_type]["passed"] += 1 + else: + report[test_result.test_type]["failed"] += 1 + return report + + @staticmethod + def report_to_string(report: dict[TestType, dict[str, int]]) -> str: + return " ".join( + [ + f"{test_type.to_name()}- (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})" + for test_type in TestType + ] + ) + + def total_passed_runtime(self) -> int: + for result in self.test_results: + if result.did_pass and result.runtime is None: + logging.debug(f"Ignoring test case that passed but had no runtime -> {result.id}") + timing = sum( + [ + result.runtime + for result in self.test_results + if (result.did_pass and result.runtime is not None) + ] + ) + return timing + + def __iter__(self) -> Iterator[FunctionTestInvocation]: + return iter(self.test_results) + + def __len__(self) -> int: + return len(self.test_results) + + def __getitem__(self, index: int) -> FunctionTestInvocation: + return self.test_results[index] + + def __setitem__(self, index: int, value: FunctionTestInvocation) -> None: + self.test_results[index] = value + + def __delitem__(self, index: int) -> None: + del self.test_results[index] + + def __contains__(self, value: FunctionTestInvocation) -> bool: + return value in self.test_results + + def __bool__(self) -> bool: + return bool(self.test_results) + + def __eq__(self, other: TestResults): + # Unordered comparison + if type(self) != type(other): + return False + if len(self) != len(other): + return False + for test_result in self: + other_test_result = other.get_by_id(test_result.id) + if other_test_result is None: + return False + if ( + test_result.file_name != other_test_result.file_name + or test_result.did_pass != other_test_result.did_pass + or test_result.runtime != other_test_result.runtime + or test_result.test_framework != other_test_result.test_framework + or test_result.test_type != other_test_result.test_type + or not comparator(test_result.return_value, other_test_result.return_value) + ): + return False + return True