Skip to content

Commit

Permalink
⚡️ Speed up function to_name by 117%
Browse files Browse the repository at this point in the history
This version doesn't create a dict each time function is called. This reduces memory consumption and increases speed. Each enum name is directly mapped to its corresponding string without looking up in dictionary, this is more efficient and quicker in execution.
  • Loading branch information
codeflash-ai-dev[bot] authored Dec 13, 2024
1 parent 97f70aa commit cc28949
Showing 1 changed file with 174 additions and 0 deletions.
174 changes: 174 additions & 0 deletions cli/codeflash/verification/test_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from __future__ import annotations

import logging
from enum import Enum
from typing import Optional, Iterator, List

from pydantic.dataclasses import dataclass

from codeflash.verification.comparator import comparator


class TestType(Enum):
EXISTING_UNIT_TEST = 1
INSPIRED_REGRESSION = 2
GENERATED_REGRESSION = 3

def to_name(self) -> str:
if self == TestType.EXISTING_UNIT_TEST:
return "⚙️ Existing Unit Tests"
elif self == TestType.INSPIRED_REGRESSION:
return "🎨 Inspired Regression Tests"
elif self == TestType.GENERATED_REGRESSION:
return "🌀 Generated Regression Tests"


@dataclass(frozen=True)
class InvocationId:
test_module_path: str # The fully qualified name of the test module
test_class_name: Optional[str] # The name of the class where the test is defined
test_function_name: (
str # The name of the test_function. Does not include the components of the file_name
)
function_getting_tested: str
iteration_id: Optional[str]

# test_module_path:TestSuiteClass.test_function_name:function_tested:iteration_id
def id(self):
return f"{self.test_module_path}:{self.test_class_name or ''}.{self.test_function_name}:{self.function_getting_tested}:{self.iteration_id}"

@staticmethod
def from_str_id(string_id: str):
components = string_id.split(":")
assert len(components) == 4
second_components = components[1].split(".")
if len(second_components) == 1:
test_class_name = None
test_function_name = second_components[0]
else:
test_class_name = second_components[0]
test_function_name = second_components[1]
return InvocationId(
test_module_path=components[0],
test_class_name=test_class_name,
test_function_name=test_function_name,
function_getting_tested=components[2],
iteration_id=components[3],
)


@dataclass(frozen=True)
class FunctionTestInvocation:
id: InvocationId # The fully qualified name of the function invocation (id)
file_name: str # The file where the test is defined
did_pass: bool # Whether the test this function invocation was part of, passed or failed
runtime: Optional[int] # Time in nanoseconds
test_framework: str # unittest or pytest
test_type: TestType
return_value: Optional[object] # The return value of the function invocation


class TestResults:
test_results: list[FunctionTestInvocation]

def __init__(self, test_results=None):
if test_results is None:
test_results = []
self.test_results = test_results

def add(self, function_test_invocation: FunctionTestInvocation) -> None:
self.test_results.append(function_test_invocation)

def merge(self, other: "TestResults") -> None:
self.test_results.extend(other.test_results)

def get_by_id(self, invocation_id: InvocationId) -> Optional[FunctionTestInvocation]:
return next((r for r in self.test_results if r.id == invocation_id), None)

def get_all_ids(self) -> List[InvocationId]:
return [test_result.id for test_result in self.test_results]

def get_test_pass_fail_report(self) -> str:
passed = 0
failed = 0
for test_result in self.test_results:
if test_result.did_pass:
passed += 1
else:
failed += 1
return f"Passed: {passed}, Failed: {failed}"

def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
report = {}
for test_type in TestType:
report[test_type] = {"passed": 0, "failed": 0}
for test_result in self.test_results:
if test_result.did_pass:
report[test_result.test_type]["passed"] += 1
else:
report[test_result.test_type]["failed"] += 1
return report

@staticmethod
def report_to_string(report: dict[TestType, dict[str, int]]) -> str:
return " ".join(
[
f"{test_type.to_name()}- (Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']})"
for test_type in TestType
]
)

def total_passed_runtime(self) -> int:
for result in self.test_results:
if result.did_pass and result.runtime is None:
logging.debug(f"Ignoring test case that passed but had no runtime -> {result.id}")
timing = sum(
[
result.runtime
for result in self.test_results
if (result.did_pass and result.runtime is not None)
]
)
return timing

def __iter__(self) -> Iterator[FunctionTestInvocation]:
return iter(self.test_results)

def __len__(self) -> int:
return len(self.test_results)

def __getitem__(self, index: int) -> FunctionTestInvocation:
return self.test_results[index]

def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
self.test_results[index] = value

def __delitem__(self, index: int) -> None:
del self.test_results[index]

def __contains__(self, value: FunctionTestInvocation) -> bool:
return value in self.test_results

def __bool__(self) -> bool:
return bool(self.test_results)

def __eq__(self, other: TestResults):
# Unordered comparison
if type(self) != type(other):
return False
if len(self) != len(other):
return False
for test_result in self:
other_test_result = other.get_by_id(test_result.id)
if other_test_result is None:
return False
if (
test_result.file_name != other_test_result.file_name
or test_result.did_pass != other_test_result.did_pass
or test_result.runtime != other_test_result.runtime
or test_result.test_framework != other_test_result.test_framework
or test_result.test_type != other_test_result.test_type
or not comparator(test_result.return_value, other_test_result.return_value)
):
return False
return True

0 comments on commit cc28949

Please sign in to comment.