Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow serialize to gather implicit nodes #24

Merged
merged 6 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions examples/intro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Introductory example: generate a text description, then one image per style.

Requires the SUBSTRATE_API_KEY environment variable to be set.
"""
import os
import subprocess
import sys
import json
from pathlib import Path

# add parent dir to sys.path to make 'substrate' importable
parent_dir = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(parent_dir))

# Fail early, before importing the package, when no API key is configured.
api_key = os.environ.get("SUBSTRATE_API_KEY")
if api_key is None:
    raise EnvironmentError("No SUBSTRATE_API_KEY set")

from substrate import Substrate, GenerateText, GenerateImage, sb

substrate = Substrate(api_key=api_key, timeout=60 * 5)

scene = GenerateText(prompt="description of a mythical forest creature: ")

# Each image prompt is built from a style prefix and the future text output of
# `scene`; only the image nodes are handed to `run` below.
styles = ["woodblock printed", "art nouveau poster"]
images = [
    GenerateImage(
        store="hosted",
        prompt=sb.concat(style, ": ", scene.future.text),
    )
    for style in styles
]

result = substrate.run(*images)

print(json.dumps(result.json, indent=2))

viz = Substrate.visualize(*images)
# Pass the target as a single argv entry instead of interpolating it into a
# shell command line, so characters such as `&`, `?` or spaces in it are not
# interpreted by a shell.
try:
    subprocess.run(["open", str(viz)], check=False)
except FileNotFoundError:
    # No `open` launcher on this platform; show the target instead of crashing.
    print(viz)
3 changes: 3 additions & 0 deletions scripts/sync_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ def exec_sync(command):
src_dir = "../substrate/sb_models/substratecore"
dest_dir = "substrate/core"
excluded_files = [
"client/future.py",
"corenode.py",
"future_directive.py",
"versions.py",
"jina_versions.py",
"mistral_versions.py",
Expand Down
4 changes: 3 additions & 1 deletion substrate/core/client/future.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""
CORE ꩜ SUBSTRATE

NOTE: this file is not copied from the main repo
"""
from typing import Union, TypeVar, Optional

Expand Down Expand Up @@ -60,7 +62,7 @@ def _on_access(self, key: Union[str, int, "Future"], accessor: TraceType):
operation = TraceOperation(key=key, accessor=accessor, future_id=None)
next_f = TracedFuture(
directive=TraceDirective(
origin_node_id=self.directive.origin_node_id,
origin_node=self.directive.origin_node,
op_stack=self.directive.op_stack + [operation],
),
FG=self.FutureG,
Expand Down
5 changes: 4 additions & 1 deletion substrate/core/corenode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""
CORE ꩜ SUBSTRATE

NOTE: this file is not copied from the main repo
"""
from typing import Any, List, Type, Generic, TypeVar, Optional

Expand Down Expand Up @@ -30,6 +32,7 @@ def __init__(self, out_type: Type[OT] = Type[Any], hide: bool = True, **attr):
self.args = {}
self.SG.add_node(self, **self.args)
self.futures_from_args: List[BaseFuture] = find_futures_client(attr)
self.referenced_nodes = [future.directive.origin_node for future in self.futures_from_args if isinstance(future, TracedFuture)]

@property
def out_type(self) -> Type[OT]:
Expand Down Expand Up @@ -58,7 +61,7 @@ def future(self) -> TracedFuture:
"""
Reference to future output of this node.
"""
return TracedFuture(directive=TraceDirective(op_stack=[], origin_node_id=self.id))
return TracedFuture(directive=TraceDirective(op_stack=[], origin_node=self))

def __repr__(self):
    """Render the node as ``<class name>(<id>)``."""
    class_name = self.__class__.__name__
    return "{}({})".format(class_name, self.id)
Expand Down
13 changes: 12 additions & 1 deletion substrate/core/future_directive.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""
CORE ꩜ SUBSTRATE

NOTE: this file is not copied from the main repo
"""
from abc import ABC
from typing import (
Any,
Dict,
List,
Union,
Expand Down Expand Up @@ -82,5 +85,13 @@ class TraceOperation:
@dataclass
class TraceDirective(BaseDirective):
op_stack: List[TraceOperation]
origin_node_id: Optional[str]
origin_node: Any # Should be CoreNode, but am running into circular import
type: Literal["trace"] = "trace"

def to_dict(self) -> Dict:
    """Return this directive as a plain dictionary.

    Each ``op_stack`` entry is expanded with ``dataclasses.asdict``; the
    origin node is represented only by its ``id`` under ``origin_node_id``.
    """
    # noinspection PyDataclass
    serialized_ops = list(map(asdict, self.op_stack))
    return dict(
        op_stack=serialized_ops,
        origin_node_id=self.origin_node.id,
        type=self.type,
    )
12 changes: 11 additions & 1 deletion substrate/substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,18 @@ def serialize(*nodes):
"""
Serializes the given nodes.
"""
graph = Graph()

all_nodes = set()
def collect_nodes(node):
    """Add ``node`` and every node reachable through ``referenced_nodes`` to ``all_nodes``.

    Nodes already present in ``all_nodes`` are skipped, so a node shared by
    several downstream nodes is expanded only once and a reference cycle
    cannot loop forever. The walk uses an explicit stack rather than
    recursion so long dependency chains do not hit the recursion limit.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current in all_nodes:
            # Already collected; its references were queued at that time.
            continue
        all_nodes.add(current)
        pending.extend(current.referenced_nodes)

for node in nodes:
collect_nodes(node)

graph = Graph()
for node in all_nodes:
graph.add_node(node)
graph_serialized = graph.to_dict()
return graph_serialized
38 changes: 38 additions & 0 deletions tests/python-3-9/tests/test_substrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import sys
from pathlib import Path

import pytest
from pydantic import BaseModel

# add parent dir to sys.path to make 'substrate' importable
parent_dir = Path(__file__).resolve().parent.parent.parent.parent
sys.path.insert(0, str(parent_dir))

from substrate import Substrate
from substrate.core.corenode import CoreNode


class MockOutput(BaseModel):
    # Minimal output model with a single string field `y`; passed as
    # `out_type` to the CoreNode instances built in the test in this file.
    y: str


@pytest.mark.unit
class TestSubstrate:
    """Unit tests for ``Substrate.serialize``."""

    def test_serialize(self):
        """Check that a chain a -> b -> c serializes all three nodes.

        The same result is expected whether every node is passed explicitly
        or only the terminal node ``c`` is passed.
        """
        a = CoreNode(x="y", out_type=MockOutput)
        b = CoreNode(x=a.future.y, out_type=MockOutput)
        c = CoreNode(x=b.future.y, out_type=MockOutput)

        # Sort the expected ids too: the serialized ids are compared in sorted
        # order, so the assertion must not depend on the ids of a, b and c
        # happening to be generated in lexicographic order.
        expected_ids = sorted([a.id, b.id, c.id])

        # when the nodes are explicitly passed in
        result = Substrate.serialize(a, b, c)

        node_ids = sorted(d["id"] for d in result["nodes"])
        assert node_ids == expected_ids
        assert len(result["futures"]) == 2

        # when only the terminal node is passed in
        result = Substrate.serialize(c)

        node_ids = sorted(d["id"] for d in result["nodes"])
        assert node_ids == expected_ids
        assert len(result["futures"]) == 2
Loading