Skip to content

Commit

Permalink
Reformat the code
Browse files Browse the repository at this point in the history
  • Loading branch information
nik committed Nov 30, 2023
1 parent e73e317 commit d7c2e23
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions adala/runtimes/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def check_if_new_openai_version():

@retry(wait=wait_random(min=5, max=10), stop=stop_after_attempt(3))
def chat_completion_call(model, messages):
return openai.ChatCompletion.create(model=model, messages=messages, timeout=120, request_timeout=120)
return openai.ChatCompletion.create(
model=model, messages=messages, timeout=120, request_timeout=120
)


class OpenAIChatRuntime(Runtime):
Expand Down Expand Up @@ -161,9 +163,12 @@ def record_to_record(
completion_text = self.execute(messages)

field_schema = field_schema or {}
if output_field_name in field_schema and field_schema[output_field_name]["type"] == "array":
if (
output_field_name in field_schema
and field_schema[output_field_name]["type"] == "array"
):
# expected output is one item from the array
expected_items = field_schema[output_field_name]['items']['enum']
expected_items = field_schema[output_field_name]["items"]["enum"]
completion_text = self._match_items(completion_text, expected_items)

return {output_field_name: completion_text}
Expand All @@ -176,7 +181,12 @@ def _match_items(self, query: str, items: List[str]) -> str:
filtered_items = items

# soft constraint: find the most similar item to the query
scores = list(map(lambda item: difflib.SequenceMatcher(None, query, item).ratio(), filtered_items))
scores = list(
map(
lambda item: difflib.SequenceMatcher(None, query, item).ratio(),
filtered_items,
)
)
matched_item = filtered_items[scores.index(max(scores))]
return matched_item

Expand Down

0 comments on commit d7c2e23

Please sign in to comment.