Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial notebooks #7

Merged
merged 5 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ print(summary_out.text)
To run the above example as a notebook, navigate to the `examples/notebooks` directory and run:

```sh
make ensure # install dependencies
make basic # run the notebook
make ensure # install dependencies
poetry run marimo edit basic.py # run the notebook
```
5 changes: 0 additions & 5 deletions examples/notebooks/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ help:
@echo "Substrate notebooks"
@echo ""
@echo " ensure Install dependencies"
@echo " basic Basic notebook with the example from the README"
@echo " update Update substrate-python to the latest version"

poetry.lock: pyproject.toml
Expand All @@ -20,7 +19,3 @@ ensure: poetry.lock
.PHONY: update
update:
poetry cache clear pypi --all && poetry update substrate

.PHONY: basic
basic:
poetry run marimo edit basic.py
100 changes: 100 additions & 0 deletions examples/notebooks/gen-image-chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import marimo

__generated_with = "0.3.12"
app = marimo.App(width="medium")


@app.cell
def __():
import os
import json
import base64
import marimo as mo
import substrate as sb

api_key = os.environ.get("SUBSTRATE_API_KEY")
api_key = api_key or "YOUR_API_KEY"
mo.md(f"`{api_key}`")
return api_key, base64, json, mo, os, sb


@app.cell
def __(api_key, sb):
substrate = sb.Substrate(
api_key=api_key,
backend="v1",
)
return substrate,


@app.cell
def __(mo):
prompt = mo.ui.text(
placeholder="prompt",
value="A bowl of fruit",
full_width=True,
).form()
prompt
return prompt,


@app.cell
def __(prompt, sb):
image = sb.GenerateImage(
{
"prompt": prompt.value,
}
)
return image,


@app.cell
def __(image, sb):
rmbg = sb.RemoveBackground({
"image_uri": image.future.image_uri
})
return rmbg,


@app.cell
def __(image, sb):
upscale = sb.UpscaleImage({
"image_uri": image.future.image_uri
})
return upscale,


@app.cell
def __(image, mo, rmbg, substrate, upscale):
res = substrate.run(image, rmbg, upscale)
viz = substrate.visualize(image, rmbg, upscale)
mo.md(f"[visualize]({viz})")
return res, viz


@app.cell
def __(image, mo, res):
image_out = res.get(image)
mo.image(src=image_out.image_uri)
return image_out,


@app.cell
def __(mo, res, rmbg):
rmbg_out = res.get(rmbg)
mo.image(src=rmbg_out.image_uri)
return rmbg_out,


@app.cell
def __(mo, res, upscale):
upscale_out = res.get(upscale)
mo.download(
data=upscale_out.image_uri,
filename="upscaled.jpeg",
)
return upscale_out,


if __name__ == "__main__":
app.run()
223 changes: 223 additions & 0 deletions examples/notebooks/gen-image-embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import marimo

__generated_with = "0.3.12"
app = marimo.App(width="medium")


@app.cell
def __():
import os
import json
import string
import random
import base64
import marimo as mo
import substrate as sb

api_key = os.environ.get("SUBSTRATE_API_KEY")
api_key = api_key or "YOUR_API_KEY"
mo.md(f"`{api_key}`")
characters = string.ascii_letters
random_string = "".join(random.choice(characters) for i in range(3))
collection_name = f"image_test_a3"
collection_name
return (
api_key,
base64,
characters,
collection_name,
json,
mo,
os,
random,
random_string,
sb,
string,
)


@app.cell
def __(api_key, sb):
substrate = sb.Substrate(
api_key=api_key,
backend="v1",
)
return substrate,


@app.cell
def __(collection_name, mo, sb, substrate):
# create the vector store
create_vstore = sb.CreateVectorStore(
{"model": "clip", "collection_name": collection_name}
)
create_res = substrate.run(create_vstore)
mo.tree(create_res.json)
return create_res, create_vstore


@app.cell
def __(mo):
prompt = mo.ui.text(
placeholder="prompt",
value="A bowl of fruit",
full_width=True,
).form()
prompt
return prompt,


@app.cell
def __(collection_name, prompt, sb):
image = sb.GenerateImage(
{
"prompt": prompt.value,
}
)
embed_prompt = sb.EmbedText(
{
"text": prompt.value,
"collection_name": collection_name,
}
)
embed = sb.EmbedImage(
{
"image_uri": image.future.image_uri,
"collection_name": collection_name,
}
)
return embed, embed_prompt, image


@app.cell
def __(embed, embed_prompt, image, mo, substrate):
res = substrate.run(image, embed, embed_prompt)
mo.tree(res.json)
return res,


@app.cell
def __(embed_prompt, res):
prompt_doc_id = res.get(embed_prompt).embedding.doc_id
print(prompt_doc_id)
return prompt_doc_id,


@app.cell
def __(embed, res):
image1_doc_id = res.get(embed).embedding.doc_id
print(image1_doc_id)
return image1_doc_id,


@app.cell
def __(mo):
prompt2 = mo.ui.text(
placeholder="prompt",
value="A bowl of chocolate",
full_width=True,
).form()
prompt2
return prompt2,


@app.cell
def __(collection_name, prompt2, sb):
image2 = sb.GenerateImage(
{
"prompt": prompt2.value,
}
)
embed2 = sb.EmbedImage(
{
"image_uri": image2.future.image_uri,
"collection_name": collection_name,
}
)
return embed2, image2


@app.cell
def __(embed2, image2, mo, substrate):
res2 = substrate.run(image2, embed2)
mo.tree(res2.json)
return res2,


@app.cell
def __(embed2, res2):
image2_doc_id = res2.get(embed2).embedding.doc_id
print(image2_doc_id)
return image2_doc_id,


@app.cell
def __(collection_name, image, res, sb):
query = sb.QueryVectorStore(
{
"model": "clip",
"collection_name": collection_name,
"query_image_uris": [res.get(image).image_uri],
"top_k": 100,
"ef_search": 64,
# "query_strings": [prompt.value],
}
)
return query,


@app.cell
def __(query, substrate):
query_res = substrate.run(query)
return query_res,


@app.cell
def __(query, query_res):
results = query_res.get(query).results
return results,


@app.cell
def __(mo, results):
mo.tree(results)
return


@app.cell
def __(image2_doc_id, mo, results):
image2_distance = None
for r in results[0]:
if r.id == image2_doc_id:
image2_distance = r.distance

mo.tree(
{
"image2_distance": image2_distance,
}
)
return image2_distance, r


@app.cell
def __(image, image2, image2_distance, mo, res, res2):
mo.hstack(
[
mo.vstack(
[
mo.image(src=res.get(image).image_uri),
]
),
mo.vstack(
[
mo.image(src=res2.get(image2).image_uri),
mo.md(f"distance: {image2_distance}"),
]
),
]
)
return


if __name__ == "__main__":
app.run()
Loading
Loading