diff --git a/.gitignore b/.gitignore index 68bc17f..2dc53ca 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,4 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ diff --git a/README.md b/README.md index 31011f6..1da1f40 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ print(valid) Creating a repeatable jq query for extracitng data from identically formatted input JSONs ```python -jq_query = jaiqu.translate_schema(input_json, schema, key_hints, max_retries=30) +jq_query = translate_schema(input_json, schema, key_hints, max_retries=30) >>>'{"id": .attributes["call.id"], "date": .datetime}' ``` @@ -137,22 +137,8 @@ pip install jaiqu ## Architecture -Unraveling the Jaiqu agentic workflow pattern -```mermaid -flowchart TD - A[Start translate_schema] --> B{Validate input schema} - B -- Valid --> C[For each key, create a jq filter query] - B -- Invalid --> D[Throw RuntimeError] - C --> E[Compile and Test jq Filter] - E -- Success --> F[Validate JSON] - E -- Fail --> G[Retry Create jq Filter] - G -- Success --> E - G -- Fail n times--> H[Throw RuntimeError] - F -- Success --> I[Return jq query string] - F -- Fail --> J[Retry Validate JSON] - J -- Success --> I - J -- Fail n times --> K[Throw RuntimeError] -``` +Unraveling the Jaiqu agentic workflow pattern: +![jaiqu](jaiqu_app.png) ## Running tests diff --git a/jaiqu/__init__.py b/jaiqu/__init__.py index 5741c0e..f2b881e 100644 --- a/jaiqu/__init__.py +++ b/jaiqu/__init__.py @@ -1 +1 @@ -from .jaiqu import validate_schema, translate_schema +from .jaiqu import translate_schema, validate_schema diff --git a/jaiqu/cli.py b/jaiqu/cli.py index 44bf3b5..05facd2 100644 --- a/jaiqu/cli.py +++ b/jaiqu/cli.py @@ -30,6 +30,12 @@ def jaiqu( "--max-retries", help="Max number of retries for the ai to complete the task", ), + save_trace: bool = Option( + False, + "-s", + "--save-trace", + help="Saves a trace for introspection using the Burr UI.", + ), ): """ Validate and translate a json schema to jq filter @@ -53,7 +59,8 @@ def jaiqu( input_json=input_json, key_hints=key_hints, max_retries=max_retries, - quiet=quiet + quiet=quiet, + save_trace=save_trace, ) print(query) diff --git a/jaiqu/helpers.py b/jaiqu/helpers.py index a577283..f3de5d9 100644 --- a/jaiqu/helpers.py +++ b/jaiqu/helpers.py @@ -67,7 +67,7 @@ def identify_key(key, value, input_schema, openai_api_key=None, key_hints=None) return to_key(completion), completion -def create_jq_string(input_schema, key, value, openai_api_key) -> str: +def create_jq_string(input_schema, key, value, openai_api_key=None) -> str: messages: list[ChatCompletionMessageParam] = [{ "role": "system", "content": f"""You are a perfect jq engineer designed to validate and extract data from JSON files using jq. Only reply with code. Do NOT use any natural language. Do NOT use markdown i.e. ```. diff --git a/jaiqu/jaiqu.py b/jaiqu/jaiqu.py index 57c2210..273bd7f 100644 --- a/jaiqu/jaiqu.py +++ b/jaiqu/jaiqu.py @@ -1,12 +1,20 @@ +import logging +import os + +import burr.core +from burr.core import ApplicationBuilder, State, default, expr, when +from burr.core.action import action import jq -import json from jsonschema import validate from tqdm.auto import tqdm # Use the auto submodule for notebook-friendly output if necessary from .helpers import identify_key, create_jq_string, repair_query, dict_to_jq_filter +logger = logging.getLogger(__name__) + -def validate_schema(input_json: dict, output_schema: dict, openai_api_key: str | None = None, key_hints=None, quiet=False) -> tuple[dict, bool]: +def validate_schema(input_json: dict, output_schema: dict, openai_api_key: str | None = None, key_hints=None, + quiet=False) -> tuple[dict, bool]: """Validates the schema of the input JSON against the output schema. Args: input_json (dict): The input JSON parsed into a dictionary. @@ -15,9 +23,9 @@ def validate_schema(input_json: dict, output_schema: dict, openai_api_key: str | key_hints (any, optional): Key hints to assist in identifying keys. Defaults to None. Returns: - tuple[dict, bool]: A tuple containing the results of the validation and a boolean indicating if the validation was successful. + tuple[dict, bool]: A tuple containing the results of the validation and a boolean indicating if the + validation was successful. """ - results = {} valid = True with tqdm(total=len(output_schema['properties']), desc="Validating schema", disable=quiet) as pbar: @@ -41,42 +49,45 @@ def validate_schema(input_json: dict, output_schema: dict, openai_api_key: str | results[key]['required'] = False pbar.update(1) - return results, valid - - -def translate_schema(input_json, output_schema, openai_api_key: str | None = None, key_hints=None, max_retries=10, quiet=False) -> str: - """ - Translate the input JSON schema into a filtering query using jq. - - Args: - input_json (dict): The input JSON to be reformatted. - output_schema (dict): The desired output schema using standard schema formatting. - openai_api_key (str, optional): OpenAI API key. Defaults to None. - key_hints (None, optional): Hints for translating keys. Defaults to None. - max_retries (int, optional): Maximum number of retries for creating a valid jq filter. Defaults to 10. - - Returns: - str: The filtering query in jq syntax. - - Raises: - RuntimeError: If the input JSON does not contain the required data to satisfy the output schema. - RuntimeError: If failed to create a valid jq filter after maximum retries. - RuntimeError: If failed to validate the jq filter after maximum retries. - """ - - schema_properties, is_valid = validate_schema(input_json, output_schema, key_hints=key_hints, openai_api_key=openai_api_key, quiet=quiet) - if not is_valid: - raise RuntimeError( - f"The input JSON does not contain the required data to satisfy the output schema: \n\n{json.dumps(schema_properties, indent=2)}") - + return results, valid + + +@action( + reads=["input_json", "output_schema", "key_hints", "quiet"], + writes=["valid_schema", "schema_properties"] +) +def validate_schema_action(state: State) -> tuple[dict, State]: + """Action to validate the provided input schema.""" + output_schema = state["output_schema"] + input_json = state["input_json"] + key_hints = state["key_hints"] + quiet = state.get("quiet", False) + results, valid = validate_schema(input_json, output_schema, key_hints=key_hints, openai_api_key=None, quiet=quiet) + state = state.update(valid_schema=valid, schema_properties=results) + return results, state + + +@action( + reads=["input_json", "schema_properties", "max_retries", "quiet"], + writes=["max_retries_hit", "jq_filter"] +) +def create_jq_filter_query(state: State) -> tuple[dict, State]: + """Creates the JQ filter query.""" + schema_properties = state["schema_properties"] + input_json = state["input_json"] + max_retries = state["max_retries"] filtered_schema = {k: v for k, v in schema_properties.items() if v['identified'] == True} - + quiet = state.get("quiet", False) filter_query = {} - with tqdm(total=len(filtered_schema), desc="Translating schema", disable=quiet) as pbar, tqdm(total=max_retries, desc="Retry attempts", disable=quiet) as pbar_retries: + with tqdm(total=len(filtered_schema), + desc="Translating schema", + disable=quiet) as pbar, tqdm(total=max_retries, + desc="Retry attempts", + disable=quiet) as pbar_retries: for key, value in filtered_schema.items(): pbar.set_postfix_str(f"Key: {key}", refresh=True) - jq_string = create_jq_string(input_json, key, value, openai_api_key) + jq_string = create_jq_string(input_json, key, value) if jq_string == "None": # If the response is empty, skip the key pbar.update(1) @@ -85,20 +96,36 @@ def translate_schema(input_json, output_schema, openai_api_key: str | None = Non tries = 0 while True: try: - key_query = jq.compile(jq_string).input(input_json).all() + jq.compile(jq_string).input(input_json).all() break except Exception as e: tries += 1 pbar_retries.update(1) - jq_string = repair_query(jq_string, str(e), input_json, openai_api_key) + jq_string = repair_query(jq_string, str(e), input_json, None) if tries >= max_retries: - raise RuntimeError( - f"Failed to create a valid jq filter for key '{key}' after {max_retries} retries.") + state = state.update(max_retries_hit=True, jq_filter=None) + return { + "error": f"Failed to create a valid jq filter for key '{key}' after {max_retries} retries."}, state pbar.update(1) filter_query[key] = jq_string pbar.close() pbar_retries.close() complete_filter = dict_to_jq_filter(filter_query) + state = state.update(jq_filter=complete_filter, max_retries_hit=False) + return {"filter": complete_filter}, state + + +@action( + reads=["input_json", "jq_filter", "output_schema", "max_retries", "quiet"], + writes=["max_retries_hit", "valid_json", "complete_filter"] +) +def validate_json(state: State) -> tuple[dict, State]: + """Validates the filter JSON.""" + output_schema = state["output_schema"] + complete_filter = state["jq_filter"] + input_json = state["input_json"] + max_retries = state["max_retries"] + quiet = state.get("quiet", False) # Validate JSON tries = 0 with tqdm(total=max_retries, desc="Validation attempts", disable=quiet) as pbar_validation: @@ -112,7 +139,114 @@ def translate_schema(input_json, output_schema, openai_api_key: str | None = Non tries += 1 pbar_validation.update(1) if tries >= max_retries: - raise RuntimeError(f"Failed to validate the jq filter after {max_retries} retries.") - complete_filter = repair_query(complete_filter, str(e), input_json, openai_api_key) - pbar.close() - return complete_filter + state = state.update(max_retries_hit=True, valid_json=False, complete_filter=complete_filter) + return { + "error": f"Failed to validate the jq filter after {max_retries} retries."}, state + + complete_filter = repair_query(complete_filter, str(e), input_json, None) + state = state.update(complete_filter=complete_filter, max_retries_hit=False, valid_json=True) + return {"complete_filter": complete_filter}, state + + +def translate_schema(input_json: dict, output_schema: dict, openai_api_key: str | None = None, key_hints: str = None, + max_retries: int = 10, quiet: bool = False, save_trace: bool = False) -> str: + """ + Translate the input JSON schema into a filtering query using jq. + + Args: + input_json (dict): The input JSON to be reformatted. + output_schema (dict): The desired output schema using standard schema formatting. + openai_api_key (str, optional): OpenAI API key. Defaults to None. + key_hints (str, optional): Hints for translating keys. Defaults to None. + max_retries (int, optional): Maximum number of retries for creating a valid jq filter. Defaults to 10. + quiet (bool, optional): Quiet mode to turn off TQDM progress bars. Defaults to False. + save_trace (bool, optional): turn on Burr tracking to debug jaiqu runs. Defaults to False. + + Returns: + str: The filtering query in jq syntax. + + Raises: + RuntimeError: If the input JSON does not contain the required data to satisfy the output schema. + RuntimeError: If failed to create a valid jq filter after maximum retries. + RuntimeError: If failed to validate the jq filter after maximum retries. + """ + if openai_api_key is not None: + os.environ["OPENAI_API_KEY"] = openai_api_key + app = build_application(input_json, output_schema, + key_hints=key_hints, max_retries=max_retries, quiet=quiet, save_trace=save_trace) + last_action, result, state = app.run(halt_after=["error_state", "good_result"]) + if last_action == "error_state": + raise RuntimeError(result) + return result["complete_filter"] + + +def build_application(input_json, + output_schema, + key_hints: str = None, + max_retries: int = 10, + quiet: bool = False, + save_trace: bool = False, + visualize: bool = False): + """ + Builds the application for translating the input JSON schema into a filtering query using jq. + + Args: + input_json (dict): The input JSON to be reformatted. + output_schema (dict): The desired output schema using standard schema formatting. + key_hints (str, optional): Hints for translating keys. Defaults to None. + max_retries (int, optional): Maximum number of retries for creating a valid jq filter. Defaults to 10. + quiet (bool, optional): Quiet mode to turn off TQDM progress bars. Defaults to False. + save_trace (bool, optional): Turn on Burr tracking to debug jaiqu runs. Defaults to False. + visualize (bool, optional): If set to True, visualizes the application flow. Defaults to False. + + Returns: + Application: The built application with the specified state, actions, and transitions. + """ + _app = ( + ApplicationBuilder() + .with_state( + **{ + "input_json": input_json, + "output_schema": output_schema, + "key_hints": key_hints, + "max_retries": max_retries, + "quiet": quiet, + } + ) + .with_actions( + # bind the vector store to the AI conversational step + validate_schema=validate_schema_action, + create_jq_filter_query=create_jq_filter_query, + validate_json=validate_json, + error_state=burr.core.Result("complete_filter"), + good_result=burr.core.Result("complete_filter"), + ) + .with_transitions( + ("validate_schema", "create_jq_filter_query", default), + ("create_jq_filter_query", "validate_json", when(max_retries_hit=False)), + ("create_jq_filter_query", "error_state", when(max_retries_hit=True)), + ("validate_json", "good_result", when(valid_json=True)), + ("validate_json", "error_state", when(valid_json=False)), + ("validate_schema", "good_result", when(valid_schema=True)), + ("validate_schema", "error_state", when(valid_schema=False)), + ) + .with_entrypoint("validate_schema") + ) + if save_trace: + logger.warning("To see trace information, start `burr` (pip install \"burr[start]\") in a separate terminal" + " and go to http://localhost:7241") + _app = _app.with_tracker(project="jaiqu") + _app = _app.build() + if visualize: + _app.visualize( + output_file_path="jaiqu_app", include_conditions=True, view=False, format="png" + ) + return _app + + +if __name__ == '__main__': + """Recreate the image easily""" + app = build_application("", "") + app.visualize( + output_file_path="jaiqu_app", include_conditions=True, view=False, format="png" + ) diff --git a/jaiqu_app.png b/jaiqu_app.png new file mode 100644 index 0000000..655ce66 Binary files /dev/null and b/jaiqu_app.png differ diff --git a/pyproject.toml b/pyproject.toml index dba106d..a56b409 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "jaiqu" -version = "0.0.5" +version = "0.0.6" authors = [ { name = "Alex Reibman", email = "areibman@gmail.com" }, { name = "Howard Gil", email = "howardbgil@gmail.com" }, @@ -23,6 +23,7 @@ dependencies = [ "openai~=1.12.0", "jsonschema==4.21.1", "typer==0.9.0", + "burr~=0.*", ] [project.optional-dependencies] @@ -30,6 +31,7 @@ dev = [ "pytest>=7.4.4", "flake8>=3.1.0", "coverage[toml]>=7.4.0", + "burr[start]~=0.*", ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 5f10958..e06be91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ jq==1.6.0 openai>=1.12.0,<2.0.0 jsonschema==4.21.1 -typer==0.9.0 \ No newline at end of file +typer==0.9.0 +burr~=0.* # install everything until 1.0.0 \ No newline at end of file diff --git a/jaiqu/tests/__init__.py b/tests/__init__.py similarity index 100% rename from jaiqu/tests/__init__.py rename to tests/__init__.py diff --git a/jaiqu/tests/calendar/event.schema.json b/tests/calendar/event.schema.json similarity index 100% rename from jaiqu/tests/calendar/event.schema.json rename to tests/calendar/event.schema.json diff --git a/jaiqu/tests/calendar/gcal/input.json b/tests/calendar/gcal/input.json similarity index 100% rename from jaiqu/tests/calendar/gcal/input.json rename to tests/calendar/gcal/input.json diff --git a/jaiqu/tests/calendar/outlook/input.json b/tests/calendar/outlook/input.json similarity index 100% rename from jaiqu/tests/calendar/outlook/input.json rename to tests/calendar/outlook/input.json diff --git a/jaiqu/tests/llms/anthropic/input.json b/tests/llms/anthropic/input.json similarity index 100% rename from jaiqu/tests/llms/anthropic/input.json rename to tests/llms/anthropic/input.json diff --git a/jaiqu/tests/llms/arize_openetelemetry/input.json b/tests/llms/arize_openetelemetry/input.json similarity index 100% rename from jaiqu/tests/llms/arize_openetelemetry/input.json rename to tests/llms/arize_openetelemetry/input.json diff --git a/jaiqu/tests/llms/errors.schema.json b/tests/llms/errors.schema.json similarity index 100% rename from jaiqu/tests/llms/errors.schema.json rename to tests/llms/errors.schema.json diff --git a/jaiqu/tests/llms/llms.schema.json b/tests/llms/llms.schema.json similarity index 100% rename from jaiqu/tests/llms/llms.schema.json rename to tests/llms/llms.schema.json diff --git a/jaiqu/tests/llms/openai/schema.json b/tests/llms/openai/schema.json similarity index 100% rename from jaiqu/tests/llms/openai/schema.json rename to tests/llms/openai/schema.json diff --git a/tests/test_jaiqu.py b/tests/test_jaiqu.py new file mode 100644 index 0000000..1e3da40 --- /dev/null +++ b/tests/test_jaiqu.py @@ -0,0 +1,56 @@ +import jq +from jaiqu import translate_schema + + +def test_translate_schema(): + schema = { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "id": { + "type": ["string", "null"], + "description": "A unique identifier for the record." + }, + "date": { + "type": "string", + "description": "A string describing the date." + }, + "model": { + "type": "string", + "description": "A text field representing the model used." + } + }, + "required": [ + "id", + "date" + ] + } + + # Provided data + input_json = { + "call.id": "123", + "datetime": "2022-01-01", + "timestamp": 1640995200, + "Address": "123 Main St", + "user": { + "name": "John Doe", + "age": 30, + "contact": "john@email.com" + } + } + + # (Optional) Create hints so the agent knows what to look for in the input + key_hints = "We are processing outputs of an containing an id, a date, and a model. All the required fields should be present in this input, but the names might be different." + + new_schema = translate_schema(input_json, schema, key_hints=key_hints) + # there are many permutations possible so we check the result rather than the schema... + # some_possible_schemas = [ + # '{ "id": (.["call.id"] // "None"), "date": (.datetime // "None") }', + # '{ "id": .["call.id"], "date": (.datetime // "None") }', + # '{ "id": .["call.id"] // "None", "date": .datetime // null }', + # '{ "id": (."call.id"? // "None"), "date": (.datetime? // "None") }', + # '{ "id": .["call.id"] // "None", "date": .datetime // "None" }', + # '{ "id": (.["call.id"] // "None"), "date": (.datetime // null) }' + # ] + actual = jq.compile(new_schema).input(input_json).all() + assert actual == [{'id': '123', 'date': '2022-01-01'}]