From 1f347ed48ecaf99636495bbcd7f0b4587aef17e0 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sun, 22 Sep 2024 12:45:47 -0700 Subject: [PATCH 1/2] perf: Intern string constants --- libs/langgraph/langgraph/constants.py | 58 +++++----- .../langgraph/managed/shared_value.py | 6 +- libs/langgraph/langgraph/pregel/__init__.py | 100 ++++++++-------- libs/langgraph/langgraph/pregel/algo.py | 15 ++- libs/langgraph/langgraph/pregel/loop.py | 108 +++++++++--------- libs/langgraph/langgraph/pregel/read.py | 4 +- libs/langgraph/langgraph/pregel/retry.py | 10 +- libs/langgraph/langgraph/pregel/write.py | 4 +- libs/langgraph/langgraph/utils/config.py | 28 +++-- libs/langgraph/langgraph/utils/runnable.py | 6 +- 10 files changed, 180 insertions(+), 159 deletions(-) diff --git a/libs/langgraph/langgraph/constants.py b/libs/langgraph/langgraph/constants.py index 13d136d48..80713b839 100644 --- a/libs/langgraph/langgraph/constants.py +++ b/libs/langgraph/langgraph/constants.py @@ -1,3 +1,4 @@ +import sys from types import MappingProxyType from typing import Any, Mapping @@ -11,67 +12,69 @@ EMPTY_SEQ: tuple[str, ...] = tuple() # --- Public constants --- -TAG_HIDDEN = "langsmith:hidden" +TAG_HIDDEN = sys.intern("langsmith:hidden") # tag to hide a node/edge from certain tracing/streaming environments -START = "__start__" +START = sys.intern("__start__") # the first (maybe virtual) node in graph-style Pregel -END = "__end__" +END = sys.intern("__end__") # the last (maybe virtual) node in graph-style Pregel # --- Reserved write keys --- -INPUT = "__input__" +INPUT = sys.intern("__input__") # for values passed as input to the graph -INTERRUPT = "__interrupt__" +INTERRUPT = sys.intern("__interrupt__") # for dynamic interrupts raised by nodes -ERROR = "__error__" +ERROR = sys.intern("__error__") # for errors raised by nodes -NO_WRITES = "__no_writes__" +NO_WRITES = sys.intern("__no_writes__") # marker to signal node didn't write anything -SCHEDULED = "__scheduled__" +SCHEDULED = sys.intern("__scheduled__") # marker to signal node was scheduled (in distributed mode) -TASKS = "__pregel_tasks" +TASKS = sys.intern("__pregel_tasks") # for Send objects returned by nodes/edges, corresponds to PUSH below # --- Reserved config.configurable keys --- -CONFIG_KEY_SEND = "__pregel_send" +CONFIG_KEY_SEND = sys.intern("__pregel_send") # holds the `write` function that accepts writes to state/edges/reserved keys -CONFIG_KEY_READ = "__pregel_read" +CONFIG_KEY_READ = sys.intern("__pregel_read") # holds the `read` function that returns a copy of the current state -CONFIG_KEY_CHECKPOINTER = "__pregel_checkpointer" +CONFIG_KEY_CHECKPOINTER = sys.intern("__pregel_checkpointer") # holds a `BaseCheckpointSaver` passed from parent graph to child graphs -CONFIG_KEY_STREAM = "__pregel_stream" +CONFIG_KEY_STREAM = sys.intern("__pregel_stream") # holds a `StreamProtocol` passed from parent graph to child graphs -CONFIG_KEY_STREAM_WRITER = "__pregel_stream_writer" +CONFIG_KEY_STREAM_WRITER = sys.intern("__pregel_stream_writer") # holds a `StreamWriter` for stream_mode=custom -CONFIG_KEY_STORE = "__pregel_store" +CONFIG_KEY_STORE = sys.intern("__pregel_store") # holds a `BaseStore` made available to managed values -CONFIG_KEY_RESUMING = "__pregel_resuming" +CONFIG_KEY_RESUMING = sys.intern("__pregel_resuming") # holds a boolean indicating if subgraphs should resume from a previous checkpoint -CONFIG_KEY_TASK_ID = "__pregel_task_id" +CONFIG_KEY_TASK_ID = sys.intern("__pregel_task_id") # holds the task ID for the current task -CONFIG_KEY_DEDUPE_TASKS = "__pregel_dedupe_tasks" +CONFIG_KEY_DEDUPE_TASKS = sys.intern("__pregel_dedupe_tasks") # holds a boolean indicating if tasks should be deduplicated (for distributed mode) -CONFIG_KEY_ENSURE_LATEST = "__pregel_ensure_latest" +CONFIG_KEY_ENSURE_LATEST = sys.intern("__pregel_ensure_latest") # holds a boolean indicating whether to assert the requested checkpoint is the latest # (for distributed mode) -CONFIG_KEY_DELEGATE = "__pregel_delegate" +CONFIG_KEY_DELEGATE = sys.intern("__pregel_delegate") # holds a boolean indicating whether to delegate subgraphs (for distributed mode) -CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map" +CONFIG_KEY_CHECKPOINT_MAP = sys.intern("checkpoint_map") # holds a mapping of checkpoint_ns -> checkpoint_id for parent graphs -CONFIG_KEY_CHECKPOINT_ID = "checkpoint_id" +CONFIG_KEY_CHECKPOINT_ID = sys.intern("checkpoint_id") # holds the current checkpoint_id, if any -CONFIG_KEY_CHECKPOINT_NS = "checkpoint_ns" +CONFIG_KEY_CHECKPOINT_NS = sys.intern("checkpoint_ns") # holds the current checkpoint_ns, "" for root graph # --- Other constants --- -PUSH = "__pregel_push" +PUSH = sys.intern("__pregel_push") # denotes push-style tasks, ie. those created by Send objects -PULL = "__pregel_pull" +PULL = sys.intern("__pregel_pull") # denotes pull-style tasks, ie. those triggered by edges -NS_SEP = "|" +NS_SEP = sys.intern("|") # for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph) -NS_END = ":" +NS_END = sys.intern(":") # for checkpoint_ns, for each level, separates the namespace from the task_id +CONF = sys.intern("configurable") +# key for the configurable dict in RunnableConfig RESERVED = { TAG_HIDDEN, @@ -103,4 +106,5 @@ PULL, NS_SEP, NS_END, + CONF, } diff --git a/libs/langgraph/langgraph/managed/shared_value.py b/libs/langgraph/langgraph/managed/shared_value.py index 9a624c685..0f94dd1a1 100644 --- a/libs/langgraph/langgraph/managed/shared_value.py +++ b/libs/langgraph/langgraph/managed/shared_value.py @@ -13,7 +13,7 @@ from langchain_core.runnables import RunnableConfig from typing_extensions import NotRequired, Required, Self -from langgraph.constants import CONFIG_KEY_STORE +from langgraph.constants import CONF, CONFIG_KEY_STORE from langgraph.errors import InvalidUpdateError from langgraph.managed.base import ( ChannelKeyPlaceholder, @@ -83,10 +83,10 @@ def __init__( raise ValueError("SharedValue must be a dict") self.scope = scope self.value: Value = {} - self.store = cast(BaseStore, config["configurable"].get(CONFIG_KEY_STORE)) + self.store = cast(BaseStore, config[CONF].get(CONFIG_KEY_STORE)) if self.store is None: pass - elif scope_value := config["configurable"].get(self.scope): + elif scope_value := config[CONF].get(self.scope): self.ns = f"scoped:{scope}:{key}:{scope_value}" else: raise ValueError( diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 687b3ebb3..b642f8007 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -55,6 +55,8 @@ empty_checkpoint, ) from langgraph.constants import ( + CONF, + CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_CHECKPOINTER, CONFIG_KEY_READ, CONFIG_KEY_RESUMING, @@ -451,7 +453,7 @@ def _prepare_state_snapshot( ) # get the subgraphs subgraphs = dict(self.get_subgraphs()) - parent_ns = saved.config["configurable"].get("checkpoint_ns", "") + parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") task_states: dict[str, Union[RunnableConfig, StateSnapshot]] = {} for task in next_tasks.values(): if task.name not in subgraphs: @@ -463,19 +465,19 @@ def _prepare_state_snapshot( if not recurse: # set config as signal that subgraph checkpoints exist config = { - "configurable": { - "thread_id": saved.config["configurable"]["thread_id"], - "checkpoint_ns": task_ns, + CONF: { + "thread_id": saved.config[CONF]["thread_id"], + CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = config else: # get the state of the subgraph config = { - "configurable": { + CONF: { CONFIG_KEY_CHECKPOINTER: recurse, - "thread_id": saved.config["configurable"]["thread_id"], - "checkpoint_ns": task_ns, + "thread_id": saved.config[CONF]["thread_id"], + CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = subgraphs[task.name].get_state( @@ -527,7 +529,7 @@ async def _aprepare_state_snapshot( ) # get the subgraphs subgraphs = {n: g async for n, g in self.aget_subgraphs()} - parent_ns = saved.config["configurable"].get("checkpoint_ns", "") + parent_ns = saved.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") task_states: dict[str, Union[RunnableConfig, StateSnapshot]] = {} for task in next_tasks.values(): if task.name not in subgraphs: @@ -539,19 +541,19 @@ async def _aprepare_state_snapshot( if not recurse: # set config as signal that subgraph checkpoints exist config = { - "configurable": { - "thread_id": saved.config["configurable"]["thread_id"], - "checkpoint_ns": task_ns, + CONF: { + "thread_id": saved.config[CONF]["thread_id"], + CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = config else: # get the state of the subgraph config = { - "configurable": { + CONF: { CONFIG_KEY_CHECKPOINTER: recurse, - "thread_id": saved.config["configurable"]["thread_id"], - "checkpoint_ns": task_ns, + "thread_id": saved.config[CONF]["thread_id"], + CONFIG_KEY_CHECKPOINT_NS: task_ns, } } task_states[task.id] = await subgraphs[task.name].aget_state( @@ -572,15 +574,15 @@ def get_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: """Get the current state of the graph.""" - checkpointer: Optional[BaseCheckpointSaver] = config["configurable"].get( + checkpointer: Optional[BaseCheckpointSaver] = config[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( - checkpoint_ns := config["configurable"].get("checkpoint_ns", "") - ) and CONFIG_KEY_CHECKPOINTER not in config["configurable"]: + checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") + ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) @@ -607,15 +609,15 @@ async def aget_state( self, config: RunnableConfig, *, subgraphs: bool = False ) -> StateSnapshot: """Get the current state of the graph.""" - checkpointer: Optional[BaseCheckpointSaver] = config["configurable"].get( + checkpointer: Optional[BaseCheckpointSaver] = config[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( - checkpoint_ns := config["configurable"].get("checkpoint_ns", "") - ) and CONFIG_KEY_CHECKPOINTER not in config["configurable"]: + checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") + ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) @@ -647,15 +649,15 @@ def get_state_history( limit: Optional[int] = None, ) -> Iterator[StateSnapshot]: """Get the history of the state of the graph.""" - checkpointer: Optional[BaseCheckpointSaver] = config["configurable"].get( + checkpointer: Optional[BaseCheckpointSaver] = config[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( - checkpoint_ns := config["configurable"].get("checkpoint_ns", "") - ) and CONFIG_KEY_CHECKPOINTER not in config["configurable"]: + checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") + ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) @@ -676,7 +678,9 @@ def get_state_history( raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") config = merge_configs( - self.config, config, {"configurable": {"checkpoint_ns": checkpoint_ns}} + self.config, + config, + {CONF: {CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns}}, ) # eagerly consume list() to avoid holding up the db cursor for checkpoint_tuple in list( @@ -695,15 +699,15 @@ async def aget_state_history( limit: Optional[int] = None, ) -> AsyncIterator[StateSnapshot]: """Get the history of the state of the graph.""" - checkpointer: Optional[BaseCheckpointSaver] = config["configurable"].get( + checkpointer: Optional[BaseCheckpointSaver] = config[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: raise ValueError("No checkpointer set") if ( - checkpoint_ns := config["configurable"].get("checkpoint_ns", "") - ) and CONFIG_KEY_CHECKPOINTER not in config["configurable"]: + checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") + ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) @@ -725,7 +729,9 @@ async def aget_state_history( raise ValueError(f"Subgraph {recast_checkpoint_ns} not found") config = merge_configs( - self.config, config, {"configurable": {"checkpoint_ns": checkpoint_ns}} + self.config, + config, + {CONF: {CONFIG_KEY_CHECKPOINT_NS: checkpoint_ns}}, ) # eagerly consume list() to avoid holding up the db cursor for checkpoint_tuple in [ @@ -748,7 +754,7 @@ def update_state( node `as_node`. If `as_node` is not provided, it will be set to the last node that updated the state, if not ambiguous. """ - checkpointer: Optional[BaseCheckpointSaver] = config["configurable"].get( + checkpointer: Optional[BaseCheckpointSaver] = config[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: @@ -756,8 +762,8 @@ def update_state( # delegate to subgraph if ( - checkpoint_ns := config["configurable"].get("checkpoint_ns", "") - ) and CONFIG_KEY_CHECKPOINTER not in config["configurable"]: + checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") + ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) @@ -786,10 +792,10 @@ def update_state( # merge configurable fields with previous checkpoint config checkpoint_config = patch_configurable( config, - {"checkpoint_ns": config["configurable"].get("checkpoint_ns", "")}, + {CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")}, ) if saved: - checkpoint_config = patch_configurable(config, saved.config["configurable"]) + checkpoint_config = patch_configurable(config, saved.config[CONF]) # find last node that updated the state, if not provided if values is None and as_node is None: next_config = checkpointer.put( @@ -896,7 +902,7 @@ async def aupdate_state( values: dict[str, Any] | Any, as_node: Optional[str] = None, ) -> RunnableConfig: - checkpointer: Optional[BaseCheckpointSaver] = config["configurable"].get( + checkpointer: Optional[BaseCheckpointSaver] = config[CONF].get( CONFIG_KEY_CHECKPOINTER, self.checkpointer ) if not checkpointer: @@ -904,8 +910,8 @@ async def aupdate_state( # delegate to subgraph if ( - checkpoint_ns := config["configurable"].get("checkpoint_ns", "") - ) and CONFIG_KEY_CHECKPOINTER not in config["configurable"]: + checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "") + ) and CONFIG_KEY_CHECKPOINTER not in config[CONF]: # remove task_ids from checkpoint_ns recast_checkpoint_ns = NS_SEP.join( part.split(NS_END)[0] for part in checkpoint_ns.split(NS_SEP) @@ -934,10 +940,10 @@ async def aupdate_state( # merge configurable fields with previous checkpoint config checkpoint_config = patch_configurable( config, - {"checkpoint_ns": config["configurable"].get("checkpoint_ns", "")}, + {CONFIG_KEY_CHECKPOINT_NS: config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, "")}, ) if saved: - checkpoint_config = patch_configurable(config, saved.config["configurable"]) + checkpoint_config = patch_configurable(config, saved.config[CONF]) # find last node that updated the state, if not provided if values is None and as_node is None: next_config = await checkpointer.aput( @@ -1065,16 +1071,16 @@ def _defaults( stream_mode = stream_mode if stream_mode is not None else self.stream_mode if not isinstance(stream_mode, list): stream_mode = [stream_mode] - if CONFIG_KEY_TASK_ID in config.get("configurable", {}): + if CONFIG_KEY_TASK_ID in config.get(CONF, {}): # if being called as a node in another graph, always use values mode stream_mode = ["values"] if self.checkpointer is False: checkpointer: Optional[BaseCheckpointSaver] = None - elif CONFIG_KEY_CHECKPOINTER in config.get("configurable", {}): - checkpointer = config["configurable"][CONFIG_KEY_CHECKPOINTER] + elif CONFIG_KEY_CHECKPOINTER in config.get(CONF, {}): + checkpointer = config[CONF][CONFIG_KEY_CHECKPOINTER] else: checkpointer = self.checkpointer - if checkpointer and not config.get("configurable"): + if checkpointer and not config.get(CONF): raise ValueError( f"Checkpointer requires one or more of the following 'configurable' keys: {[s.id for s in checkpointer.config_specs]}" ) @@ -1216,7 +1222,7 @@ def output() -> Iterator: ) # set up custom stream mode if "custom" in stream_modes: - config["configurable"][CONFIG_KEY_STREAM_WRITER] = lambda c: stream.put( + config[CONF][CONFIG_KEY_STREAM_WRITER] = lambda c: stream.put( ((), "custom", c) ) with SyncPregelLoop( @@ -1238,7 +1244,7 @@ def output() -> Iterator: ) # enable subgraph streaming if subgraphs: - loop.config["configurable"][CONFIG_KEY_STREAM] = loop.stream + loop.config[CONF][CONFIG_KEY_STREAM] = loop.stream # enable concurrent streaming if subgraphs or "messages" in stream_modes or "custom" in stream_modes: # we are careful to have a single waiter live at any one time @@ -1431,8 +1437,8 @@ def output() -> Iterator: ) # set up custom stream mode if "custom" in stream_modes: - config["configurable"][CONFIG_KEY_STREAM_WRITER] = ( - lambda c: stream.put_nowait(((), "custom", c)) + config[CONF][CONFIG_KEY_STREAM_WRITER] = lambda c: stream.put_nowait( + ((), "custom", c) ) async with AsyncPregelLoop( input, @@ -1453,7 +1459,7 @@ def output() -> Iterator: ) # enable subgraph streaming if subgraphs: - loop.config["configurable"][CONFIG_KEY_STREAM] = loop.stream + loop.config[CONF][CONFIG_KEY_STREAM] = loop.stream # enable concurrent streaming if subgraphs or "messages" in stream_modes or "custom" in stream_modes: diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 9da421b95..99de1f130 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -28,7 +28,10 @@ copy_checkpoint, ) from langgraph.constants import ( + CONF, + CONFIG_KEY_CHECKPOINT_ID, CONFIG_KEY_CHECKPOINT_MAP, + CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_CHECKPOINTER, CONFIG_KEY_READ, CONFIG_KEY_SEND, @@ -360,8 +363,8 @@ def prepare_single_task( """Prepares a single task for the next Pregel step, given a task path, which uniquely identifies a PUSH or PULL task within the graph.""" checkpoint_id = UUID(checkpoint["id"]).bytes - configurable = config.get("configurable", {}) - parent_ns = configurable.get("checkpoint_ns", "") + configurable = config.get(CONF, {}) + parent_ns = configurable.get(CONFIG_KEY_CHECKPOINT_NS, "") if task_path[0] == PUSH: idx = int(task_path[1]) @@ -443,8 +446,8 @@ def prepare_single_task( **configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}), parent_ns: checkpoint["id"], }, - "checkpoint_id": None, - "checkpoint_ns": task_checkpoint_ns, + CONFIG_KEY_CHECKPOINT_ID: None, + CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns, }, ), triggers, @@ -550,8 +553,8 @@ def prepare_single_task( **configurable.get(CONFIG_KEY_CHECKPOINT_MAP, {}), parent_ns: checkpoint["id"], }, - "checkpoint_id": None, - "checkpoint_ns": task_checkpoint_ns, + CONFIG_KEY_CHECKPOINT_ID: None, + CONFIG_KEY_CHECKPOINT_NS: task_checkpoint_ns, }, ), triggers, diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 45b8798af..071f798d7 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -37,7 +37,10 @@ empty_checkpoint, ) from langgraph.constants import ( + CONF, + CONFIG_KEY_CHECKPOINT_ID, CONFIG_KEY_CHECKPOINT_MAP, + CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_DEDUPE_TASKS, CONFIG_KEY_DELEGATE, CONFIG_KEY_ENSURE_LATEST, @@ -209,43 +212,42 @@ def __init__( self.specs = specs self.output_keys = output_keys self.stream_keys = stream_keys - self.is_nested = CONFIG_KEY_TASK_ID in self.config.get("configurable", {}) + self.is_nested = CONFIG_KEY_TASK_ID in self.config.get(CONF, {}) self.skip_done_tasks = ( - "checkpoint_id" not in config["configurable"] - or CONFIG_KEY_DEDUPE_TASKS in config["configurable"] + CONFIG_KEY_CHECKPOINT_ID not in config[CONF] + or CONFIG_KEY_DEDUPE_TASKS in config[CONF] ) self.debug = debug - if self.stream is not None and CONFIG_KEY_STREAM in config["configurable"]: - self.stream = DuplexStream( - self.stream, config["configurable"][CONFIG_KEY_STREAM] - ) - if not self.is_nested and config["configurable"].get("checkpoint_ns"): + if self.stream is not None and CONFIG_KEY_STREAM in config[CONF]: + self.stream = DuplexStream(self.stream, config[CONF][CONFIG_KEY_STREAM]) + if not self.is_nested and config[CONF].get(CONFIG_KEY_CHECKPOINT_NS): self.config = patch_configurable( - self.config, {"checkpoint_ns": "", "checkpoint_id": None} + self.config, + {CONFIG_KEY_CHECKPOINT_NS: "", CONFIG_KEY_CHECKPOINT_ID: None}, ) if check_subgraphs and self.is_nested and self.checkpointer is not None: - if self.config["configurable"]["checkpoint_ns"] in _SEEN_CHECKPOINT_NS: + if self.config[CONF][CONFIG_KEY_CHECKPOINT_NS] in _SEEN_CHECKPOINT_NS: raise MultipleSubgraphsError else: - _SEEN_CHECKPOINT_NS.add(self.config["configurable"]["checkpoint_ns"]) + _SEEN_CHECKPOINT_NS.add(self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]) if ( - CONFIG_KEY_CHECKPOINT_MAP in self.config["configurable"] - and self.config["configurable"].get("checkpoint_ns") - in self.config["configurable"][CONFIG_KEY_CHECKPOINT_MAP] + CONFIG_KEY_CHECKPOINT_MAP in self.config[CONF] + and self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS) + in self.config[CONF][CONFIG_KEY_CHECKPOINT_MAP] ): self.checkpoint_config = patch_configurable( self.config, { - "checkpoint_id": config["configurable"][CONFIG_KEY_CHECKPOINT_MAP][ - self.config["configurable"]["checkpoint_ns"] + CONFIG_KEY_CHECKPOINT_ID: config[CONF][CONFIG_KEY_CHECKPOINT_MAP][ + self.config[CONF][CONFIG_KEY_CHECKPOINT_NS] ] }, ) else: self.checkpoint_config = config self.checkpoint_ns = ( - tuple(cast(str, self.config["configurable"]["checkpoint_ns"]).split(NS_SEP)) - if self.config["configurable"].get("checkpoint_ns") + tuple(cast(str, self.config[CONF][CONFIG_KEY_CHECKPOINT_NS]).split(NS_SEP)) + if self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS) else () ) @@ -272,12 +274,12 @@ def put_writes(self, task_id: str, writes: Sequence[tuple[str, Any]]) -> None: self.checkpointer_put_writes, { **self.checkpoint_config, - "configurable": { - **self.checkpoint_config["configurable"], - "checkpoint_ns": self.config["configurable"].get( - "checkpoint_ns", "" + CONF: { + **self.checkpoint_config[CONF], + CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get( + CONFIG_KEY_CHECKPOINT_NS, "" ), - "checkpoint_id": self.checkpoint["id"], + CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"], }, }, writes, @@ -392,7 +394,7 @@ def tick( return False # check if we should delegate (used by subgraphs in distributed mode) - if self.config["configurable"].get(CONFIG_KEY_DELEGATE): + if self.config[CONF].get(CONFIG_KEY_DELEGATE): assert self.input is INPUT_RESUMING raise GraphDelegate( { @@ -456,7 +458,7 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None: # resuming from previous checkpoint requires # - finding a previous checkpoint # - receiving None input (outer graph) or RESUMING flag (subgraph) - configurable = self.config.get("configurable", {}) + configurable = self.config.get(CONF, {}) is_resuming = bool(self.checkpoint["channel_versions"]) and bool( configurable.get(CONFIG_KEY_RESUMING, self.input is None) ) @@ -475,7 +477,7 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None: # map inputs to channel updates elif input_writes := deque(map_input(input_keys, self.input)): # check if we should delegate (used by subgraphs in distributed mode) - if self.config["configurable"].get(CONFIG_KEY_DELEGATE): + if self.config[CONF].get(CONFIG_KEY_DELEGATE): raise GraphDelegate( { "config": patch_configurable( @@ -518,9 +520,7 @@ def _first(self, *, input_keys: Union[str, Sequence[str]]) -> None: def _put_checkpoint(self, metadata: CheckpointMetadata) -> None: # assign step metadata["step"] = self.step - metadata["parents"] = self.config["configurable"].get( - CONFIG_KEY_CHECKPOINT_MAP, {} - ) + metadata["parents"] = self.config[CONF].get(CONFIG_KEY_CHECKPOINT_MAP, {}) # debug flag if self.debug: print_step_checkpoint( @@ -537,10 +537,10 @@ def _put_checkpoint(self, metadata: CheckpointMetadata) -> None: self.checkpoint_metadata = metadata self.checkpoint_config = { **self.checkpoint_config, - "configurable": { - **self.checkpoint_config["configurable"], - "checkpoint_ns": self.config["configurable"].get( - "checkpoint_ns", "" + CONF: { + **self.checkpoint_config[CONF], + CONFIG_KEY_CHECKPOINT_NS: self.config[CONF].get( + CONFIG_KEY_CHECKPOINT_NS, "" ), }, } @@ -564,9 +564,9 @@ def _put_checkpoint(self, metadata: CheckpointMetadata) -> None: ) self.checkpoint_config = { **self.checkpoint_config, - "configurable": { - **self.checkpoint_config["configurable"], - "checkpoint_id": self.checkpoint["id"], + CONF: { + **self.checkpoint_config[CONF], + CONFIG_KEY_CHECKPOINT_ID: self.checkpoint["id"], }, } # increment step @@ -689,20 +689,22 @@ def _update_mv(self, key: str, values: Sequence[Any]) -> None: # context manager def __enter__(self) -> Self: - if self.config.get("configurable", {}).get( + if self.config.get(CONF, {}).get( CONFIG_KEY_ENSURE_LATEST - ) and self.checkpoint_config["configurable"].get("checkpoint_id"): + ) and self.checkpoint_config[CONF].get(CONFIG_KEY_CHECKPOINT_ID): if self.checkpointer is None: raise RuntimeError( "Cannot ensure latest checkpoint without checkpointer" ) saved = self.checkpointer.get_tuple( - patch_configurable(self.checkpoint_config, {"checkpoint_id": None}) + patch_configurable( + self.checkpoint_config, {CONFIG_KEY_CHECKPOINT_ID: None} + ) ) if ( saved is None or saved.checkpoint["id"] - != self.checkpoint_config["configurable"]["checkpoint_id"] + != self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID] ): raise CheckpointNotLatest elif self.checkpointer: @@ -716,10 +718,10 @@ def __enter__(self) -> Self: self.checkpoint_config = { **self.config, **saved.config, - "configurable": { - "checkpoint_ns": "", - **self.config.get("configurable", {}), - **saved.config.get("configurable", {}), + CONF: { + CONFIG_KEY_CHECKPOINT_NS: "", + **self.config.get(CONF, {}), + **saved.config.get(CONF, {}), }, } self.checkpoint = saved.checkpoint @@ -815,20 +817,22 @@ def _update_mv(self, key: str, values: Sequence[Any]) -> None: # context manager async def __aenter__(self) -> Self: - if self.config.get("configurable", {}).get( + if self.config.get(CONF, {}).get( CONFIG_KEY_ENSURE_LATEST - ) and self.checkpoint_config["configurable"].get("checkpoint_id"): + ) and self.checkpoint_config[CONF].get(CONFIG_KEY_CHECKPOINT_ID): if self.checkpointer is None: raise RuntimeError( "Cannot ensure latest checkpoint without checkpointer" ) saved = await self.checkpointer.aget_tuple( - patch_configurable(self.checkpoint_config, {"checkpoint_id": None}) + patch_configurable( + self.checkpoint_config, {CONFIG_KEY_CHECKPOINT_ID: None} + ) ) if ( saved is None or saved.checkpoint["id"] - != self.checkpoint_config["configurable"]["checkpoint_id"] + != self.checkpoint_config[CONF][CONFIG_KEY_CHECKPOINT_ID] ): raise CheckpointNotLatest elif self.checkpointer: @@ -842,10 +846,10 @@ async def __aenter__(self) -> Self: self.checkpoint_config = { **self.config, **saved.config, - "configurable": { - "checkpoint_ns": "", - **self.config.get("configurable", {}), - **saved.config.get("configurable", {}), + CONF: { + CONFIG_KEY_CHECKPOINT_NS: "", + **self.config.get(CONF, {}), + **saved.config.get(CONF, {}), }, } self.checkpoint = saved.checkpoint diff --git a/libs/langgraph/langgraph/pregel/read.py b/libs/langgraph/langgraph/pregel/read.py index 097e76fa2..6733258dc 100644 --- a/libs/langgraph/langgraph/pregel/read.py +++ b/libs/langgraph/langgraph/pregel/read.py @@ -21,7 +21,7 @@ from langchain_core.runnables.base import Input, Other, coerce_to_runnable from langchain_core.runnables.utils import ConfigurableFieldSpec -from langgraph.constants import CONFIG_KEY_READ +from langgraph.constants import CONF, CONFIG_KEY_READ from langgraph.pregel.retry import RetryPolicy from langgraph.pregel.write import ChannelWrite from langgraph.utils.config import merge_configs @@ -95,7 +95,7 @@ def do_read( mapper: Optional[Callable[[Any], Any]] = None, ) -> Any: try: - read: READ_TYPE = config["configurable"][CONFIG_KEY_READ] + read: READ_TYPE = config[CONF][CONFIG_KEY_READ] except KeyError: raise RuntimeError( "Not configured with a read function" diff --git a/libs/langgraph/langgraph/pregel/retry.py b/libs/langgraph/langgraph/pregel/retry.py index 33c60d875..60057493d 100644 --- a/libs/langgraph/langgraph/pregel/retry.py +++ b/libs/langgraph/langgraph/pregel/retry.py @@ -4,7 +4,7 @@ import time from typing import Optional, Sequence -from langgraph.constants import CONFIG_KEY_RESUMING +from langgraph.constants import CONF, CONFIG_KEY_CHECKPOINT_NS, CONFIG_KEY_RESUMING from langgraph.errors import _SEEN_CHECKPOINT_NS, GraphInterrupt from langgraph.types import PregelExecutableTask, RetryPolicy from langgraph.utils.config import patch_configurable @@ -72,11 +72,11 @@ def run_with_retry( # signal subgraphs to resume (if available) config = patch_configurable(config, {CONFIG_KEY_RESUMING: True}) # clear checkpoint_ns seen (for subgraph detection) - if checkpoint_ns := config["configurable"].get("checkpoint_ns"): + if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS): _SEEN_CHECKPOINT_NS.discard(checkpoint_ns) finally: # clear checkpoint_ns seen (for subgraph detection) - if checkpoint_ns := config["configurable"].get("checkpoint_ns"): + if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS): _SEEN_CHECKPOINT_NS.discard(checkpoint_ns) @@ -145,9 +145,9 @@ async def arun_with_retry( # signal subgraphs to resume (if available) config = patch_configurable(config, {CONFIG_KEY_RESUMING: True}) # clear checkpoint_ns seen (for subgraph detection) - if checkpoint_ns := config["configurable"].get("checkpoint_ns"): + if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS): _SEEN_CHECKPOINT_NS.discard(checkpoint_ns) finally: # clear checkpoint_ns seen (for subgraph detection) - if checkpoint_ns := config["configurable"].get("checkpoint_ns"): + if checkpoint_ns := config[CONF].get(CONFIG_KEY_CHECKPOINT_NS): _SEEN_CHECKPOINT_NS.discard(checkpoint_ns) diff --git a/libs/langgraph/langgraph/pregel/write.py b/libs/langgraph/langgraph/pregel/write.py index c2795c67c..9975c7e5b 100644 --- a/libs/langgraph/langgraph/pregel/write.py +++ b/libs/langgraph/langgraph/pregel/write.py @@ -14,7 +14,7 @@ from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.runnables.utils import ConfigurableFieldSpec -from langgraph.constants import CONFIG_KEY_SEND, TASKS, Send +from langgraph.constants import CONF, CONFIG_KEY_SEND, TASKS, Send from langgraph.errors import InvalidUpdateError from langgraph.utils.runnable import RunnableCallable @@ -138,7 +138,7 @@ def do_write( raise InvalidUpdateError( f"Must write to at least one of {require_at_least_one_of}" ) - write: TYPE_SEND = config["configurable"][CONFIG_KEY_SEND] + write: TYPE_SEND = config[CONF][CONFIG_KEY_SEND] write(sends + filtered) @staticmethod diff --git a/libs/langgraph/langgraph/utils/config.py b/libs/langgraph/langgraph/utils/config.py index cc352cf07..4261fc47c 100644 --- a/libs/langgraph/langgraph/utils/config.py +++ b/libs/langgraph/langgraph/utils/config.py @@ -16,32 +16,36 @@ ) from langgraph.checkpoint.base import CheckpointMetadata -from langgraph.constants import CONFIG_KEY_CHECKPOINT_MAP +from langgraph.constants import ( + CONF, + CONFIG_KEY_CHECKPOINT_ID, + CONFIG_KEY_CHECKPOINT_MAP, + CONFIG_KEY_CHECKPOINT_NS, +) def patch_configurable( config: Optional[RunnableConfig], patch: dict[str, Any] ) -> RunnableConfig: if config is None: - return {"configurable": patch} - elif "configurable" not in config: - return {**config, "configurable": patch} + return {CONF: patch} + elif CONF not in config: + return {**config, CONF: patch} else: - return {**config, "configurable": {**config["configurable"], **patch}} + return {**config, CONF: {**config[CONF], **patch}} def patch_checkpoint_map( config: RunnableConfig, metadata: Optional[CheckpointMetadata] ) -> RunnableConfig: if parents := (metadata.get("parents") if metadata else None): + conf = config[CONF] return patch_configurable( config, { CONFIG_KEY_CHECKPOINT_MAP: { **parents, - config["configurable"]["checkpoint_ns"]: config["configurable"][ - "checkpoint_id" - ], + conf[CONFIG_KEY_CHECKPOINT_NS]: conf[CONFIG_KEY_CHECKPOINT_ID], }, }, ) @@ -77,7 +81,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: base[key] = [*base_value, *value] # type: ignore else: base[key] = value # type: ignore[literal-required] - elif key == "configurable": + elif key == CONF: if base_value := base.get(key): base[key] = {**base_value, **value} # type: ignore else: @@ -161,7 +165,7 @@ def patch_config( if run_name is not None: config["run_name"] = run_name if configurable is not None: - config["configurable"] = {**config.get("configurable", {}), **configurable} + config[CONF] = {**config.get(CONF, {}), **configurable} return config @@ -273,8 +277,8 @@ def ensure_config(*configs: Optional[RunnableConfig]) -> RunnableConfig: empty[k] = v # type: ignore[literal-required] for k, v in config.items(): if v is not None and k not in CONFIG_KEYS: - empty["configurable"][k] = v - for key, value in empty["configurable"].items(): + empty[CONF][k] = v + for key, value in empty[CONF].items(): if ( not key.startswith("__") and isinstance(value, (str, int, float, bool)) diff --git a/libs/langgraph/langgraph/utils/runnable.py b/libs/langgraph/langgraph/utils/runnable.py index d81f86e34..0a90217a0 100644 --- a/libs/langgraph/langgraph/utils/runnable.py +++ b/libs/langgraph/langgraph/utils/runnable.py @@ -34,7 +34,7 @@ from langchain_core.tracers._streaming import _StreamingCallbackHandler from typing_extensions import TypeGuard -from langgraph.constants import CONFIG_KEY_STREAM_WRITER +from langgraph.constants import CONF, CONFIG_KEY_STREAM_WRITER from langgraph.types import StreamWriter from langgraph.utils.config import ( ensure_config, @@ -142,7 +142,7 @@ def invoke( kwargs["config"] = config for kw, _, ck, defv in KWARGS_CONFIG_KEYS: if self.func_accepts[kw]: - kwargs[kw] = config["configurable"].get(ck, defv) + kwargs[kw] = config[CONF].get(ck, defv) context = copy_context() if self.trace: callback_manager = get_callback_manager_for_config(config, self.tags) @@ -181,7 +181,7 @@ async def ainvoke( kwargs["config"] = config for kw, _, ck, defv in KWARGS_CONFIG_KEYS: if self.func_accepts[kw]: - kwargs[kw] = config["configurable"].get(ck, defv) + kwargs[kw] = config[CONF].get(ck, defv) context = copy_context() if self.trace: callback_manager = get_async_callback_manager_for_config(config, self.tags) From 47fe321e31180a9dbab7e1ca7928597eff8eaf6a Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sun, 22 Sep 2024 12:51:06 -0700 Subject: [PATCH 2/2] Lint --- libs/langgraph/langgraph/constants.py | 4 ++-- libs/langgraph/langgraph/utils/config.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/langgraph/langgraph/constants.py b/libs/langgraph/langgraph/constants.py index 80713b839..4c6733ac5 100644 --- a/libs/langgraph/langgraph/constants.py +++ b/libs/langgraph/langgraph/constants.py @@ -1,6 +1,6 @@ import sys from types import MappingProxyType -from typing import Any, Mapping +from typing import Any, Literal, Mapping, cast from langgraph.types import Interrupt, Send # noqa: F401 @@ -73,7 +73,7 @@ # for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph) NS_END = sys.intern(":") # for checkpoint_ns, for each level, separates the namespace from the task_id -CONF = sys.intern("configurable") +CONF = cast(Literal["configurable"], sys.intern("configurable")) # key for the configurable dict in RunnableConfig RESERVED = { diff --git a/libs/langgraph/langgraph/utils/config.py b/libs/langgraph/langgraph/utils/config.py index 4261fc47c..6f75c64ba 100644 --- a/libs/langgraph/langgraph/utils/config.py +++ b/libs/langgraph/langgraph/utils/config.py @@ -83,9 +83,9 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: base[key] = value # type: ignore[literal-required] elif key == CONF: if base_value := base.get(key): - base[key] = {**base_value, **value} # type: ignore + base[key] = {**base_value, **value} # type: ignore[dict-item] else: - base[key] = value # type: ignore[literal-required] + base[key] = value elif key == "callbacks": base_callbacks = base.get("callbacks") # callbacks can be either None, list[handler] or manager