Skip to content

Commit

Permalink
Fix state changes within a gr.render (#10095)
Browse files Browse the repository at this point in the history
* changes

* add changeset

---------

Co-authored-by: Ali Abid <[email protected]>
Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: Abubakar Abid <[email protected]>
  • Loading branch information
4 people authored Dec 4, 2024
1 parent de42c85 commit 97d647e
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .changeset/evil-streets-hunt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Fix state changes within a gr.render
2 changes: 1 addition & 1 deletion demo/state_change/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"Number A\")\n", " btn_a.click(lambda x: x + 1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"Number B\")\n", " btn_b.click(lambda x: x + 1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", "\n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " click_count = gr.Number(label=\"Clicks\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x + 1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(lambda x: [0] * len(x), inputs=list_state, outputs=list_state)\n", "\n", " gr.on(\n", " [count_to_3_btn.click, zero_all_btn.click],\n", " lambda x: x + 1,\n", " click_count,\n", " click_count,\n", " )\n", "\n", " async def increment(x):\n", " yield x + 1\n", "\n", " n_text = gr.State(0)\n", " add_btn = gr.Button(\"Iterator State Change\")\n", " add_btn.click(increment, n_text, n_text)\n", "\n", " @gr.render(inputs=n_text)\n", " def render_count(count):\n", " for i in range(int(count)):\n", " gr.Markdown(value = f\"Success Box {i} added\", key=i)\n", " \n", " class CustomState():\n", " def __init__(self, val):\n", " self.val = val\n", "\n", " def __hash__(self) -> int:\n", " return self.val\n", "\n", " custom_state = gr.State(CustomState(5))\n", " with gr.Row():\n", " btn_10 = gr.Button(\"Set State to 10\")\n", " custom_changes = gr.Number(0, label=\"Custom State Changes\")\n", " custom_clicks = gr.Number(0, label=\"Custom State Clicks\")\n", "\n", " custom_state.change(increment, custom_changes, custom_changes)\n", " def set_to_10(cs: CustomState):\n", " cs.val = 10\n", " return cs\n", "\n", " btn_10.click(set_to_10, custom_state, custom_state).then(\n", " increment, custom_clicks, custom_clicks\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_change"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "\n", " with gr.Row():\n", " state_a = gr.State(0)\n", " btn_a = gr.Button(\"Increment A\")\n", " value_a = gr.Number(label=\"Number A\")\n", " btn_a.click(lambda x: x + 1, state_a, state_a)\n", " state_a.change(lambda x: x, state_a, value_a)\n", " with gr.Row():\n", " state_b = gr.State(0)\n", " btn_b = gr.Button(\"Increment B\")\n", " value_b = gr.Number(label=\"Number B\")\n", " btn_b.click(lambda x: x + 1, state_b, state_b)\n", "\n", " @gr.on(inputs=state_b, outputs=value_b)\n", " def identity(x):\n", " return x\n", "\n", " @gr.render(inputs=[state_a, state_b])\n", " def render(a, b):\n", " for x in range(a):\n", " with gr.Row():\n", " for y in range(b):\n", " gr.Button(f\"Button {x}, {y}\")\n", "\n", " list_state = gr.State([])\n", " dict_state = gr.State(dict())\n", " nested_list_state = gr.State([])\n", " set_state = gr.State(set())\n", "\n", " def transform_list(x):\n", " return {n: n for n in x}, [x[:] for _ in range(len(x))], set(x)\n", "\n", " list_state.change(\n", " transform_list,\n", " inputs=list_state,\n", " outputs=[dict_state, nested_list_state, set_state],\n", " )\n", "\n", " all_textbox = gr.Textbox(label=\"Output\")\n", " click_count = gr.Number(label=\"Clicks\")\n", " change_count = gr.Number(label=\"Changes\")\n", " gr.on(\n", " inputs=[change_count, dict_state, nested_list_state, set_state],\n", " triggers=[dict_state.change, nested_list_state.change, set_state.change],\n", " fn=lambda x, *args: (x + 1, \"\\n\".join(str(arg) for arg in args)),\n", " outputs=[change_count, all_textbox],\n", " )\n", "\n", " count_to_3_btn = gr.Button(\"Count to 3\")\n", " count_to_3_btn.click(lambda: [1, 2, 3], outputs=list_state)\n", " zero_all_btn = gr.Button(\"Zero All\")\n", " zero_all_btn.click(lambda x: [0] * len(x), inputs=list_state, outputs=list_state)\n", "\n", " gr.on(\n", " [count_to_3_btn.click, zero_all_btn.click],\n", " lambda x: x + 1,\n", " click_count,\n", " click_count,\n", " )\n", "\n", " async def increment(x):\n", " yield x + 1\n", "\n", " n_text = gr.State(0)\n", " add_btn = gr.Button(\"Iterator State Change\")\n", " add_btn.click(increment, n_text, n_text)\n", "\n", " @gr.render(inputs=n_text)\n", " def render_count(count):\n", " for i in range(int(count)):\n", " gr.Markdown(value = f\"Success Box {i} added\", key=i)\n", " \n", " class CustomState():\n", " def __init__(self, val):\n", " self.val = val\n", "\n", " def __hash__(self) -> int:\n", " return self.val\n", "\n", " custom_state = gr.State(CustomState(5))\n", " with gr.Row():\n", " btn_10 = gr.Button(\"Set State to 10\")\n", " custom_changes = gr.Number(0, label=\"Custom State Changes\")\n", " custom_clicks = gr.Number(0, label=\"Custom State Clicks\")\n", "\n", " custom_state.change(increment, custom_changes, custom_changes)\n", " def set_to_10(cs: CustomState):\n", " cs.val = 10\n", " return cs\n", "\n", " btn_10.click(set_to_10, custom_state, custom_state).then(\n", " increment, custom_clicks, custom_clicks\n", " )\n", "\n", " @gr.render()\n", " def render_state_changes():\n", " with gr.Row():\n", " box1 = gr.Textbox(label=\"Start State\")\n", " state1 = gr.State()\n", " box2 = gr.Textbox()\n", " state2 = gr.State()\n", " box3 = gr.Textbox(label=\"End State\")\n", "\n", " iden = lambda x: x\n", " box1.change(iden, box1, state1)\n", " state1.change(iden, state1, box2)\n", " box2.change(iden, box2, state2)\n", " state2.change(iden, state2, box3)\n", " \n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
15 changes: 15 additions & 0 deletions demo/state_change/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,20 @@ def set_to_10(cs: CustomState):
increment, custom_clicks, custom_clicks
)

@gr.render()
def render_state_changes():
with gr.Row():
box1 = gr.Textbox(label="Start State")
state1 = gr.State()
box2 = gr.Textbox()
state2 = gr.State()
box3 = gr.Textbox(label="End State")

iden = lambda x: x
box1.change(iden, box1, state1)
state1.change(iden, state1, box2)
box2.change(iden, box2, state2)
state2.change(iden, state2, box3)

if __name__ == "__main__":
demo.launch()
3 changes: 2 additions & 1 deletion gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2114,7 +2114,8 @@ def get_state_ids_to_track(
hashed_values = []
for block in block_fn.outputs:
if block.stateful and any(
(block._id, "change") in fn.targets for fn in self.fns.values()
(block._id, "change") in fn.targets
for fn in state.blocks_config.fns.values()
):
value = state[block._id]
state_ids_to_track.append(block._id)
Expand Down
6 changes: 6 additions & 0 deletions js/spa/test/state_change.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,9 @@ test("test state change for custom hashes", async ({ page }) => {
"1"
);
});

test("test state changes work within gr.render", async ({ page }) => {
const textbox = await page.getByLabel("Start State");
await textbox.fill("test");
await expect(page.getByLabel("End State").first()).toHaveValue("test");
});

0 comments on commit 97d647e

Please sign in to comment.