Skip to content

Commit

Permalink
Fix serialization issue
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Aug 9, 2023
1 parent ea0cb4b commit 9403ba6
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 8 deletions.
14 changes: 10 additions & 4 deletions libs/langchain/langchain/load/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,16 @@ def default(obj: Any) -> Any:

def dumps(obj: Any, *, pretty: bool = False) -> str:
"""Return a json string representation of an object."""
if pretty:
return json.dumps(obj, default=default, indent=2)
else:
return json.dumps(obj, default=default)
try:
if pretty:
return json.dumps(obj, default=default, indent=2)
else:
return json.dumps(obj, default=default)
except TypeError:
if pretty:
return json.dumps(to_json_not_implemented(obj), indent=2)
else:
return json.dumps(to_json_not_implemented(obj))


def dumpd(obj: Any) -> Dict[str, Any]:
Expand Down
10 changes: 7 additions & 3 deletions libs/langchain/langchain/schema/prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _format_value(formatters: FormattersType, value: Any) -> Any:
class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
"""Base class for all prompt templates, returning a prompt."""

formatters: FormattersType = PROMPT_DEFAULT_FORMATTERS
formatters: Optional[FormattersType] = None
"""A mapping of types to functions that format them into a string.
The functions should take a single argument, the value to format, and
return a string. If the function takes two arguments, the second argument
Expand Down Expand Up @@ -147,7 +147,12 @@ def _prepare_variables(self, **kwargs: Any) -> Dict[str, Any]:
for k, v in self.partial_variables.items()
}
all_variables = {**partial_kwargs, **kwargs}
return {k: _format_value(self.formatters, v) for k, v in all_variables.items()}
formatters = (
self.formatters
if self.formatters is not None
else PROMPT_DEFAULT_FORMATTERS
)
return {k: _format_value(formatters, v) for k, v in all_variables.items()}

@abstractmethod
def format(self, **kwargs: Any) -> str:
Expand All @@ -174,7 +179,6 @@ def _prompt_type(self) -> str:
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt."""
prompt_dict = super().dict(**kwargs)
del prompt_dict["formatters"]
prompt_dict["_type"] = self._prompt_type
return prompt_dict

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# serializer version: 1
# name: test_default_formatters
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"foo"
],
"template": "This is a {foo} test.",
"template_format": "f-string",
"partial_variables": {}
}
}
'''
# ---
# name: test_default_formatters.1
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"foo"
],
"template": "This is a {foo} test.",
"template_format": "f-string",
"partial_variables": {},
"formatters": {}
}
}
'''
# ---
# name: test_default_formatters.2
'''
{
"lc": 1,
"type": "not_implemented",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
]
}
'''
# ---
16 changes: 15 additions & 1 deletion libs/langchain/tests/unit_tests/prompts/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Test functionality related to prompts."""
from langchain.load.dump import dumps
import pytest

from syrupy import SnapshotAssertion

from langchain.prompts.prompt import PromptTemplate
from langchain.schema.document import Document

Expand Down Expand Up @@ -123,7 +126,7 @@ def test_partial_init_string() -> None:
assert result == "This is a 1 test."


def test_default_formatters() -> None:
def test_default_formatters(snapshot: SnapshotAssertion) -> None:
"""Test prompt can be initialized with partial variables."""
template = "This is a {foo} test."
prompt = PromptTemplate.from_template(template)
Expand All @@ -132,12 +135,23 @@ def test_default_formatters() -> None:

foo = [Document(page_content="Hello there", metadata={"some": "key"})]
assert prompt.format(foo=foo) == "This is a ['Hello there'] test."
assert dumps(prompt, pretty=True) == snapshot

prompt_no_formatters = PromptTemplate.from_template(template, formatters={})
assert (
prompt_no_formatters.format(foo=foo)
== "This is a [Document(page_content='Hello there', metadata={'some': 'key'})] test." # noqa: E501
)
assert dumps(prompt_no_formatters, pretty=True) == snapshot

prompt_custom_formatters = PromptTemplate.from_template(
template, formatters={Document: lambda x: x.metadata}
)
assert (
prompt_custom_formatters.format(foo=foo[0])
== "This is a {'some': 'key'} test." # noqa: E501
)
assert dumps(prompt_custom_formatters, pretty=True) == snapshot


def test_partial_init_func() -> None:
Expand Down

0 comments on commit 9403ba6

Please sign in to comment.