Cannot output a non-serialisable Artifact using Pydantic IO #1191

Open
elliotgunton opened this issue Sep 5, 2024 · 2 comments
Labels
type:enhancement A general enhancement

Comments

@elliotgunton
Collaborator

elliotgunton commented Sep 5, 2024

This is a complete blocker to using the new decorators: I have no way to output a bytes Artifact from a template.

Using

class ModelTrainingInput(Input):
    X_train: Annotated[list, Artifact(name="X_train", loader=ArtifactLoader.json)]
    y_train: Annotated[dict, Artifact(name="y_train", loader=ArtifactLoader.json)]
    model: Annotated[Path, Artifact(name="model", output=True)]  # Note the `output=True`

gets the following error when building the workflow

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/elliot/projects/ds-blog/ds_blog/__main__.py", line 14, in <module>
    from ds_blog.workflow import w
  File "/Users/elliot/projects/ds-blog/ds_blog/workflow.py", line 138, in <module>
    @w.dag()
     ^^^^^^^
  File "/Users/elliot/Library/Caches/pypoetry/virtualenvs/ds-blog-IWZkzs9u-py3.12/lib/python3.12/site-packages/hera/workflows/_meta_mixins.py", line 826, in decorator
    func_return = func(input_obj)
                  ^^^^^^^^^^^^^^^
  File "/Users/elliot/projects/ds-blog/ds_blog/workflow.py", line 144, in run_training
    model_training(
  File "/Users/elliot/Library/Caches/pypoetry/virtualenvs/ds-blog-IWZkzs9u-py3.12/lib/python3.12/site-packages/hera/workflows/_meta_mixins.py", line 670, in script_call_wrapper
    return self._create_subnode(subnode_name, func, script_template, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/elliot/Library/Caches/pypoetry/virtualenvs/ds-blog-IWZkzs9u-py3.12/lib/python3.12/site-packages/hera/workflows/_meta_mixins.py", line 550, in _create_subnode
    subnode_args = args[0]._get_as_arguments()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/elliot/Library/Caches/pypoetry/virtualenvs/ds-blog-IWZkzs9u-py3.12/lib/python3.12/site-packages/hera/workflows/io/_io_mixins.py", line 152, in _get_as_arguments
    templated_value = serialize(self_dict[field])
                                ~~~~~~~~~^^^^^^^
KeyError: 'model'

And using

class ModelTrainingOutput(Output):
    model: Annotated[bytes, Artifact(name="model", archive=NoneArchiveStrategy())]

@w.script()
def model_training(model_training_input: ModelTrainingInput) -> ModelTrainingOutput:
    X_train = np.array(model_training_input.X_train)
    y_train = pd.Series(model_training_input.y_train)
    model = LogisticRegression(random_state=42)
    model.fit(X_train, y_train)
    return ModelTrainingOutput(model=pickle.dumps(model))

gets the following error when running on the cluster

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/util.py", line 222, in _runner
    output = _save_annotated_return_outputs(function(**kwargs), output_annotations)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/script_annotations_util.py", line 250, in _save_annotated_return_outputs
    _write_to_path(path, value)
  File "/usr/local/lib/python3.12/site-packages/hera/workflows/_runner/script_annotations_util.py", line 326, in _write_to_path
    output_string = serialize(output_value)
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/hera/shared/serialization.py", line 56, in serialize
    return json.dumps(value, cls=PydanticEncoder)  # None serialized as `null`
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/json/__init__.py", line 238, in dumps
    **kw).encode(obj)
          ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/json/encoder.py", line 200, in encode
    chunks = self.iterencode(o, _one_shot=True)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/json/encoder.py", line 258, in iterencode
    return _iterencode(o, 0)
           ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/hera/shared/serialization.py", line 42, in default
    return super().default(o)
           ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/json/encoder.py", line 180, in default
    raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type bytes is not JSON serializable
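
The final frames of this traceback can be reproduced without Hera or a cluster. The following is a minimal sketch using only the standard library; it calls plain json.dumps rather than Hera's serialize with PydanticEncoder, and the payload is a hypothetical stand-in for pickle.dumps(model). The helper name try_serialize is made up for illustration.

```python
import json
import pickle

# Hypothetical stand-in for pickle.dumps(model): any pickled object is bytes.
payload = pickle.dumps({"coef": [0.1, 0.2]})


def try_serialize(value):
    """Attempt JSON serialisation and report the error instead of raising."""
    try:
        return json.dumps(value)
    except TypeError as exc:
        return f"TypeError: {exc}"


result = try_serialize(payload)
print(result)
```

Since the stock JSON encoder has no representation for bytes, any output path that always goes through JSON serialisation will fail this way for a bytes value.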

The workaround is to use the old syntax with an "output" artifact in the function inputs, i.e.

@script(constructor="runner")
def model_training(
    X_train: Annotated[list, Artifact(name="X_train", loader=ArtifactLoader.json)],
    y_train: Annotated[dict, Artifact(name="y_train", loader=ArtifactLoader.json)],
    model_path: Annotated[Path, Artifact(name="model", archive=NoneArchiveStrategy(), output=True)],
):

and then, in the function body, doing

    model_path.write_bytes(pickle.dumps(model))
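
The write step of the workaround can be checked in isolation. This is a rough sketch, assuming a temporary directory in place of the artifact path that Hera would supply as model_path, and a plain dict in place of the trained model; save_model and fake_model are hypothetical names used only here.

```python
import pickle
import tempfile
from pathlib import Path


def save_model(model, model_path: Path) -> None:
    """Write the pickled model as raw bytes, bypassing JSON serialisation."""
    model_path.write_bytes(pickle.dumps(model))


with tempfile.TemporaryDirectory() as tmp:
    # Stand-in for the output artifact path provided by the runner.
    model_path = Path(tmp) / "model"
    fake_model = {"coef": [0.1, 0.2], "intercept": 0.5}

    save_model(fake_model, model_path)

    # Read the bytes back to confirm the round trip is lossless.
    restored = pickle.loads(model_path.read_bytes())

print(restored == fake_model)
```

Writing to the path directly works because the bytes never pass through the JSON encoder.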

Originally posted by @elliotgunton in #1166 (comment)

@alicederyn

This comment was marked as resolved.

@elliotgunton
Collaborator Author

Oops, I think I copied from the wrong workflow after trying a few ways to get around it. I have updated the original post. The end result is still the same: TypeError: Object of type bytes is not JSON serializable

@elliotgunton elliotgunton added the type:enhancement A general enhancement label Nov 27, 2024