Skip to content

Commit

Permalink
[ENH] better use of relative directories for paths
Browse files Browse the repository at this point in the history
  • Loading branch information
YannDubs committed Aug 26, 2024
1 parent 66e0772 commit 37393dd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/alpaca_eval/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def get_completions(configs, df: pd.DataFrame, old_output_path: Optional[Path] =
if len(curr_outputs) > 0:
prompts, _ = utils.make_prompts(
curr_outputs,
template=utils.read_or_return(base_dir / configs["prompt_template"]),
template=utils.read_or_return(configs["prompt_template"], relative_to=base_dir),
)
fn_completions = decoders.get_fn_completions(configs["fn_completions"])
completions = fn_completions(prompts=prompts, **configs["completions_kwargs"])["completions"]
Expand Down
52 changes: 34 additions & 18 deletions src/alpaca_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@
DUMMY_EXAMPLE = dict(instruction="1+1=", output_1="2", input="", output_2="3")


def read_or_return(to_read: Union[AnyPath, str], **kwargs):
def read_or_return(to_read: Union[AnyPath, str], relative_to: Optional[str] = None, **kwargs):
"""Read a file or return the input if it is already a string."""
is_str_input = isinstance(to_read, str)

if relative_to is not None:
to_read = Path(relative_to) / to_read

try:
with open(Path(to_read), **kwargs) as f:
out = f.read()
except FileNotFoundError as e:
if Path(to_read).is_absolute():
if is_str_input and Path(to_read).is_absolute():
# The path is not absolute, so it's not just a string
raise e

Expand Down Expand Up @@ -235,7 +240,7 @@ def check_imports(modules: Sequence[str], to_use: str = "this fnction"):


def check_pkg_atleast_version(package, atleast_version):
curr_version = pkg_resources.get_distribution(package).version
curr_version = get_package_version(package)
return pkg_resources.parse_version(curr_version) > pkg_resources.parse_version(atleast_version)


Expand Down Expand Up @@ -310,33 +315,39 @@ def __exit__(self, a, b, c):
logging.disable(logging.NOTSET)


def contains_list(text):
"""Check if the text contains a list / bullet points...."""

# Bullet points or '*' list items
bullet_point_pattern = r"(\s*•\s*|\s*\*\s*)(\w+)"
def contains_list(text: str, is_return_count: bool = False) -> Union[bool, int]:
"""
Check if the text contains list items or count the number of list items.
# Numbered lists with '.' or ')'
number_list_pattern = r"(\s*\d+\.|\s*\d+\))\s*(\w+)"
Parameters:
- text (str): The input text to check.
- is_return_count (bool): If True, return the count of list items. If False, return a boolean indicating presence of list items.
# Alphabetic lists with '.' or ')'
alpha_list_pattern = r"(\s*[a-zA-Z]\.|\s*[a-zA-Z]\))\s*(\w+)"
Returns:
- int or bool: Returns the count of list items if is_return_count is True, otherwise returns True/False.
"""

# List items starting with a dash '-'
dash_list_pattern = r"(\s*-\s*)(\w+)"
# Patterns to match different types of list items
bullet_point_pattern = r"(?m)^\s*[\*\•\-]\s*[^\w]*\w+" # Matches bullets like "* item" or "• item" or "- item"
number_list_pattern = r"(?m)^\s*\d+[\.\)]\s*[^\w]*\w+" # Matches numbered lists like "1. item" or "1) item"
alpha_list_pattern = r"(?m)^\s*[a-zA-Z][\.\)]\s+[^\w]*\w+" # Matches lettered lists like "a. item" or "A) item"

patterns = [
bullet_point_pattern,
number_list_pattern,
alpha_list_pattern,
dash_list_pattern,
]

count = 0

for pattern in patterns:
if re.search(pattern, text):
return True
matches = re.findall(pattern, text)
count += len(matches)

return False
if is_return_count:
return count
else:
return count > 0


def prioritize_elements(lst: list, elements: Sequence) -> list:
Expand Down Expand Up @@ -653,3 +664,8 @@ def _string_to_dict(to_convert):
def get_package_version(package_name: str) -> str:
"""Get the version of a package."""
return pkg_resources.get_distribution(package_name).version


def get_multi_package_version(package_names: Sequence[str]) -> str:
"""Get the version of multiple packages."""
return " ".join([f"{p}=={get_package_version(p)}" for p in package_names])

0 comments on commit 37393dd

Please sign in to comment.