Skip to content

Commit

Permalink
Adds transform support for runnables (#8762)
Browse files Browse the repository at this point in the history
<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->

---------

Co-authored-by: jacoblee93 <[email protected]>
Co-authored-by: Eugene Yurtsev <[email protected]>
Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
4 people authored Aug 9, 2023
1 parent 4d72288 commit b8df15c
Show file tree
Hide file tree
Showing 11 changed files with 717 additions and 59 deletions.
15 changes: 14 additions & 1 deletion libs/langchain/langchain/callbacks/tracers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def _start_trace(self, run: Run) -> None:
parent_run = self.run_map[str(run.parent_run_id)]
if parent_run:
self._add_child_run(parent_run, run)
parent_run.child_execution_order = max(
parent_run.child_execution_order, run.child_execution_order
)
else:
logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
self.run_map[str(run.id)] = run
Expand Down Expand Up @@ -254,7 +257,12 @@ def on_chain_start(
self._on_chain_start(chain_run)

def on_chain_end(
self, outputs: Dict[str, Any], *, run_id: UUID, **kwargs: Any
self,
outputs: Dict[str, Any],
*,
run_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""End a trace for a chain run."""
if not run_id:
Expand All @@ -266,13 +274,16 @@ def on_chain_end(
chain_run.outputs = outputs
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "end", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = inputs
self._end_trace(chain_run)
self._on_chain_end(chain_run)

def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
*,
inputs: Optional[Dict[str, Any]] = None,
run_id: UUID,
**kwargs: Any,
) -> None:
Expand All @@ -286,6 +297,8 @@ def on_chain_error(
chain_run.error = repr(error)
chain_run.end_time = datetime.utcnow()
chain_run.events.append({"name": "error", "time": chain_run.end_time})
if inputs is not None:
chain_run.inputs = inputs
self._end_trace(chain_run)
self._on_chain_error(chain_run)

Expand Down
20 changes: 15 additions & 5 deletions libs/langchain/langchain/chains/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Base interface that all chains should implement."""
import asyncio
import inspect
import json
import logging
import warnings
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -55,18 +57,26 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
"""

def invoke(
self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
return self(input, **(config or {}))
return self(input, **(config or {}), **kwargs)

async def ainvoke(
self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
self,
input: Dict[str, Any],
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Dict[str, Any]:
if type(self)._acall == Chain._acall:
# If the chain does not implement async, fall back to default implementation
return await super().ainvoke(input, config)
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, **kwargs)
)

return await self.acall(input, **(config or {}))
return await self.acall(input, **(config or {}), **kwargs)

memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.
Expand Down
16 changes: 10 additions & 6 deletions libs/langchain/langchain/chains/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional

from pydantic import Field

from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Expand All @@ -27,9 +29,11 @@ class TransformChain(Chain):
"""The keys expected by the transform's input dictionary."""
output_variables: List[str]
"""The keys returned by the transform's output dictionary."""
transform: Callable[[Dict[str, str]], Dict[str, str]]
transform_cb: Callable[[Dict[str, str]], Dict[str, str]] = Field(alias="transform")
"""The transform function."""
atransform: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None
atransform_cb: Optional[
Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
] = Field(None, alias="atransform")
"""The async coroutine transform function."""

@staticmethod
Expand Down Expand Up @@ -62,18 +66,18 @@ def _call(
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
return self.transform(inputs)
return self.transform_cb(inputs)

async def _acall(
    self,
    inputs: Dict[str, Any],
    run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
    """Apply the transform asynchronously.

    Prefers the async callable when one was supplied; otherwise logs a
    one-time warning and falls back to the synchronous transform.
    """
    if self.atransform_cb is None:
        # No coroutine provided — degrade gracefully to the sync path.
        self._log_once(
            "TransformChain's atransform is not provided, falling"
            " back to synchronous transform"
        )
        return self.transform_cb(inputs)
    return await self.atransform_cb(inputs)
40 changes: 37 additions & 3 deletions libs/langchain/langchain/chat_models/fake.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Fake ChatModel for testing purposes."""
from typing import Any, Dict, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import SimpleChatModel
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import AIMessageChunk, BaseMessage
from langchain.schema.output import ChatGenerationChunk


class FakeListChatModel(SimpleChatModel):
Expand Down Expand Up @@ -31,6 +35,36 @@ def _call(
self.i = 0
return response

def _stream(
    self,
    messages: List[BaseMessage],
    stop: Union[List[str], None] = None,
    run_manager: Union[CallbackManagerForLLMRun, None] = None,
    **kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
    """Stream the current canned response one character per chunk."""
    text = self.responses[self.i]
    # Advance to the next canned response, wrapping to the start.
    self.i = self.i + 1 if self.i < len(self.responses) - 1 else 0
    for token in text:
        yield ChatGenerationChunk(message=AIMessageChunk(content=token))

async def _astream(
    self,
    messages: List[BaseMessage],
    stop: Union[List[str], None] = None,
    run_manager: Union[AsyncCallbackManagerForLLMRun, None] = None,
    **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
    """Async twin of ``_stream``: yields the canned response char by char."""
    text = self.responses[self.i]
    # Advance to the next canned response, wrapping to the start.
    self.i = self.i + 1 if self.i < len(self.responses) - 1 else 0
    for token in text:
        yield ChatGenerationChunk(message=AIMessageChunk(content=token))

@property
def _identifying_params(self) -> Dict[str, Any]:
    """Parameters that identify this fake chat model instance."""
    return dict(responses=self.responses)
30 changes: 29 additions & 1 deletion libs/langchain/langchain/llms/fake.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any, List, Mapping, Optional
from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.runnable import RunnableConfig


class FakeListLLM(LLM):
Expand Down Expand Up @@ -51,3 +53,29 @@ async def _acall(
@property
def _identifying_params(self) -> Mapping[str, Any]:
    """Parameters that identify this fake LLM instance."""
    return dict(responses=self.responses)


class FakeStreamingListLLM(FakeListLLM):
    """Fake LLM that streams its canned response one character at a time."""

    def stream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        """Yield the ``invoke`` result character by character."""
        # Iterating a str yields its characters, matching the original loop.
        yield from self.invoke(input, config)

    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Yield the ``ainvoke`` result character by character."""
        text = await self.ainvoke(input, config)
        for ch in text:
            yield ch
59 changes: 54 additions & 5 deletions libs/langchain/langchain/schema/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@

import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from typing import (
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
List,
Optional,
TypeVar,
Union,
)

from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage
Expand Down Expand Up @@ -47,7 +57,7 @@ class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
):
def invoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(
Expand Down Expand Up @@ -115,7 +125,7 @@ def _type(self) -> str:
""" # noqa: E501

def invoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:
if isinstance(input, BaseMessage):
return self._call_with_config(
Expand Down Expand Up @@ -242,8 +252,47 @@ def dict(self, **kwargs: Any) -> Dict:
return output_parser_dict


class StrOutputParser(BaseOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string.."""
class BaseTransformOutputParser(BaseOutputParser[T]):
    """Base class for an output parser that can handle streaming input."""

    def _parse_chunk(self, chunk: Union[str, BaseMessage]) -> T:
        # Wrap the raw chunk in the Generation type parse_result expects.
        if isinstance(chunk, BaseMessage):
            return self.parse_result([ChatGeneration(message=chunk)])
        return self.parse_result([Generation(text=chunk)])

    def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[T]:
        # Parse every incoming chunk independently as it arrives.
        for chunk in input:
            yield self._parse_chunk(chunk)

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[T]:
        # Async twin of _transform.
        async for chunk in input:
            yield self._parse_chunk(chunk)

    def transform(
        self,
        input: Iterator[Union[str, BaseMessage]],
        config: Optional[RunnableConfig] = None,
    ) -> Iterator[T]:
        """Public streaming entry point with callback/config plumbing."""
        yield from self._transform_stream_with_config(
            input, self._transform, config, run_type="parser"
        )

    async def atransform(
        self,
        input: AsyncIterator[Union[str, BaseMessage]],
        config: Optional[RunnableConfig] = None,
    ) -> AsyncIterator[T]:
        """Async public streaming entry point with callback/config plumbing."""
        async for item in self._atransform_stream_with_config(
            input, self._atransform, config, run_type="parser"
        ):
            yield item


class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string."""

@property
def lc_serializable(self) -> bool:
Expand Down
Loading

0 comments on commit b8df15c

Please sign in to comment.