Skip to content

Commit

Permalink
#40 fixed: refactoring cmd args parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolay-r committed Dec 10, 2024
1 parent 9e60a34 commit e7913ae
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ Just **three** simple steps:
!python -m bulk_chain.infer \
--schema "default.json" \
--adapter "dynamic:flan_t5.py:FlanT5" \
%% \
%%m \
--device "cpu" \
--temp 0.1
```
Expand Down
31 changes: 25 additions & 6 deletions bulk_chain/core/service_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,33 @@ def __release():
yield __release()

@staticmethod
def partition_list(lst, sep):
def __find_suffix_ind(lst, idx_from, end_prefix):
for i in range(idx_from, len(lst)):
if lst[i].startswith(end_prefix):
return i
return len(lst)

@staticmethod
def extract_native_args(lst, end_prefix):
return lst[:CmdArgsService.__find_suffix_ind(lst, idx_from=0, end_prefix=end_prefix)]

@staticmethod
def find_grouped_args(lst, starts_with, end_prefix):
"""Slices a list in two, cutting on index matching "sep"
"""
if sep in lst:
idx = lst.index(sep)
return (lst[:idx], lst[idx+1:])
else:
return (lst[:], None)

# Checking the presence of starts_with.
# We have to return empty content in the case of absence starts_with in the lst.
if starts_with not in lst:
return []

# Assigning start index.
idx_from = lst.index(starts_with) + 1

# Assigning end index.
idx_to = CmdArgsService.__find_suffix_ind(lst, idx_from=idx_from, end_prefix=end_prefix)

return lst[idx_from:idx_to]

@staticmethod
def args_to_dict(args):
Expand Down
15 changes: 10 additions & 5 deletions bulk_chain/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,16 @@ def optional_update_data_records(c, data):
parser.add_argument('--limit-prompt', dest="limit_prompt", type=int, default=None,
help="Optional trimming prompt by the specified amount of characters.")

native_args, model_args = CmdArgsService.partition_list(lst=sys.argv, sep="%%")

# Extract native arguments.
native_args = CmdArgsService.extract_native_args(sys.argv, end_prefix="%%")
args = parser.parse_args(args=native_args[1:])

# Initialize Large Language Model.
# Extract csv-related arguments.
csv_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%csv", end_prefix="%%")
csv_args_dict = CmdArgsService.args_to_dict(csv_args)

# Extract model-related arguments and Initialize Large Language Model.
model_args = CmdArgsService.find_grouped_args(lst=sys.argv, starts_with="%%m", end_prefix="%%")
model_args_dict = CmdArgsService.args_to_dict(model_args) | {"attempts": args.attempts}
llm, llm_model_name = init_llm(**model_args_dict)

Expand All @@ -128,8 +133,8 @@ def optional_update_data_records(c, data):
None: lambda _: chat_with_lm(llm, chain=schema.chain, model_name=llm_model_name),
"csv": lambda filepath: CsvService.read(src=filepath, row_id_key=args.id_col,
as_dict=True, skip_header=True,
delimiter=model_args_dict.get("delimiter", "\t"),
escapechar=model_args_dict.get("escapechar", None)),
delimiter=csv_args_dict.get("delimiter", "\t"),
escapechar=csv_args_dict.get("escapechar", None)),
"jsonl": lambda filepath: JsonlService.read(src=filepath, row_id_key=args.id_col)
}

Expand Down
3 changes: 0 additions & 3 deletions ext/openai_156.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ def __init__(self, api_key, model_name="gpt-4-1106-preview", temp=0.1, max_token
self.__freq_penalty = freq_penalty
self.__kwargs = {} if kwargs is None else kwargs

if "delimiter" in self.__kwargs:
del self.__kwargs["delimiter"]

if suppress_httpx_log:
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)
Expand Down
3 changes: 0 additions & 3 deletions ext/openai_o1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ def __init__(self, api_key, model_name="o1-preview-2024-09-12", assistant_prompt
self.__freq_penalty = freq_penalty
self.__kwargs = {} if kwargs is None else kwargs

if "delimiter" in self.__kwargs:
del self.__kwargs["delimiter"]

if suppress_httpx_log:
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)
Expand Down
19 changes: 15 additions & 4 deletions test/test_cmdargs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import sys

from bulk_chain.core.service_args import CmdArgsService


print(sys.argv)
string, args = CmdArgsService.partition_list(sys.argv[1:], "%%")
d = CmdArgsService.args_to_dict(args)
# Csv-related.
csv_args = CmdArgsService.find_grouped_args(sys.argv, starts_with="%%csv", end_prefix="%%")
print(csv_args)
csv_args = CmdArgsService.args_to_dict(csv_args)
print("csv\t", csv_args)

# Model-related.
m_args = CmdArgsService.find_grouped_args(sys.argv, starts_with="%%m", end_prefix="%%")
m_args = CmdArgsService.args_to_dict(m_args)
print("mod\t", m_args)

print(d)
# native.
n_args = CmdArgsService.extract_native_args(sys.argv, end_prefix="%%")
n_args = CmdArgsService.args_to_dict(n_args)
print("nat\t", n_args)

0 comments on commit e7913ae

Please sign in to comment.