Skip to content

Commit

Permalink
feat(core): Upgrade pydantic to 2.x (eosphoros-ai#1428)
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc authored Apr 20, 2024
1 parent baa1e3f commit 57be1ec
Show file tree
Hide file tree
Showing 103 changed files with 1,145 additions and 533 deletions.
4 changes: 3 additions & 1 deletion dbgpt/_private/llm_metadata.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field

DEFAULT_CONTEXT_WINDOW = 3900
DEFAULT_NUM_OUTPUTS = 256


class LLMMetadata(BaseModel):
model_config = ConfigDict(protected_namespaces=())

context_window: int = Field(
default=DEFAULT_CONTEXT_WINDOW,
description=(
Expand Down
80 changes: 63 additions & 17 deletions dbgpt/_private/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,15 @@
from typing import get_origin

import pydantic

if pydantic.VERSION.startswith("1."):
PYDANTIC_VERSION = 1
from pydantic import (
BaseModel,
ConfigDict,
Extra,
Field,
NonNegativeFloat,
NonNegativeInt,
PositiveFloat,
PositiveInt,
PrivateAttr,
ValidationError,
root_validator,
validator,
)
raise NotImplementedError("pydantic 1.x is not supported, please upgrade to 2.x.")
else:
PYDANTIC_VERSION = 2
# pydantic 2.x
from pydantic.v1 import (
# Now we upgrade to pydantic 2.x
from pydantic import (
BaseModel,
ConfigDict,
Extra,
Expand All @@ -30,16 +20,72 @@
PositiveInt,
PrivateAttr,
ValidationError,
field_validator,
model_validator,
root_validator,
validator,
)

EXTRA_FORBID = "forbid"

def model_to_json(model, **kwargs):
"""Convert a pydantic model to json"""

def model_to_json(model, **kwargs) -> str:
"""Convert a pydantic model to json."""
if PYDANTIC_VERSION == 1:
return model.json(**kwargs)
else:
if "ensure_ascii" in kwargs:
del kwargs["ensure_ascii"]
return model.model_dump_json(**kwargs)


def model_to_dict(model, **kwargs) -> dict:
"""Convert a pydantic model to dict."""
if PYDANTIC_VERSION == 1:
return model.dict(**kwargs)
else:
return model.model_dump(**kwargs)


def model_fields(model):
"""Return the fields of a pydantic model."""
if PYDANTIC_VERSION == 1:
return model.__fields__
else:
return model.model_fields


def field_is_required(field) -> bool:
"""Return if a field is required."""
if PYDANTIC_VERSION == 1:
return field.required
else:
return field.is_required()


def field_outer_type(field):
"""Return the outer type of a field."""
if PYDANTIC_VERSION == 1:
return field.outer_type_
else:
# https://github.com/pydantic/pydantic/discussions/7217
origin = get_origin(field.annotation)
if origin is None:
return field.annotation
return origin


def field_description(field):
"""Return the description of a field."""
if PYDANTIC_VERSION == 1:
return field.field_info.description
else:
return field.description


def field_default(field):
"""Return the default value of a field."""
if PYDANTIC_VERSION == 1:
return field.field_info.default
else:
return field.default
25 changes: 18 additions & 7 deletions dbgpt/agent/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
get_origin,
)

from dbgpt._private.pydantic import BaseModel
from dbgpt._private.pydantic import (
BaseModel,
field_default,
field_description,
model_fields,
model_to_dict,
)
from dbgpt.util.json_utils import find_json_objects

from ...vis.base import Vis
Expand Down Expand Up @@ -45,6 +51,10 @@ def from_dict(
return None
return cls.parse_obj(param)

def to_dict(self) -> Dict[str, Any]:
"""Convert the object to a dictionary."""
return model_to_dict(self)


class Action(ABC, Generic[T]):
"""Base Action class for defining agent actions."""
Expand Down Expand Up @@ -85,12 +95,13 @@ def _create_example(
if origin is None:
example = {}
single_model_type = cast(Type[BaseModel], model_type)
for field_name, field in single_model_type.__fields__.items():
field_info = field.field_info
if field_info.description:
example[field_name] = field_info.description
elif field_info.default:
example[field_name] = field_info.default
for field_name, field in model_fields(single_model_type).items():
description = field_description(field)
default_value = field_default(field)
if description:
example[field_name] = description
elif default_value:
example[field_name] = default_value
else:
example[field_name] = ""
return example
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/agent/actions/chart_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import Optional

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field, model_to_json
from dbgpt.vis.tags.vis_chart import Vis, VisChart

from ..resource.resource_api import AgentResource, ResourceType
Expand Down Expand Up @@ -86,13 +86,13 @@ async def run(
if not self.render_protocol:
raise ValueError("The rendering protocol is not initialized!")
view = await self.render_protocol.display(
chart=json.loads(param.json()), data_df=data_df
chart=json.loads(model_to_json(param)), data_df=data_df
)
if not self.resource_need:
raise ValueError("The resource type is not found!")
return ActionOutput(
is_exe_success=True,
content=param.json(),
content=model_to_json(param),
view=view,
resource_type=self.resource_need.value,
resource_value=resource.value,
Expand Down
12 changes: 9 additions & 3 deletions dbgpt/agent/actions/dashboard_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import List, Optional

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
from dbgpt.vis.tags.vis_dashboard import Vis, VisDashboard

from ..resource.resource_api import AgentResource, ResourceType
Expand All @@ -30,6 +30,10 @@ class ChartItem(BaseModel):
)
thought: str = Field(..., description="Summary of thoughts to the user")

def to_dict(self):
"""Convert to dict."""
return model_to_dict(self)


class DashboardAction(Action[List[ChartItem]]):
"""Dashboard action class."""
Expand Down Expand Up @@ -101,7 +105,7 @@ async def run(
sql_df = await resource_db_client.query_to_df(
resource.value, chart_item.sql
)
chart_dict = chart_item.dict()
chart_dict = chart_item.to_dict()

chart_dict["data"] = sql_df
except Exception as e:
Expand All @@ -113,7 +117,9 @@ async def run(
view = await self.render_protocol.display(charts=chart_params)
return ActionOutput(
is_exe_success=True,
content=json.dumps([chart_item.dict() for chart_item in chart_items]),
content=json.dumps(
[chart_item.to_dict() for chart_item in chart_items]
),
view=view,
)
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/agent/core/agent_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_describe_by_name(self, name: str) -> str:
"""Return the description of an agent by name."""
return self._agents[name][1].desc

def all_agents(self):
def all_agents(self) -> Dict[str, str]:
"""Return a dictionary of all registered agents and their descriptions."""
result = {}
for name, value in self._agents.items():
Expand Down
19 changes: 9 additions & 10 deletions dbgpt/agent/core/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast

from dbgpt._private.pydantic import Field
from dbgpt._private.pydantic import ConfigDict, Field
from dbgpt.core import LLMClient, ModelMessageRoleType
from dbgpt.util.error_types import LLMChatError
from dbgpt.util.tracer import SpanType, root_tracer
Expand All @@ -27,6 +27,8 @@
class ConversableAgent(Role, Agent):
"""ConversableAgent is an agent that can communicate with other agents."""

model_config = ConfigDict(arbitrary_types_allowed=True)

agent_context: Optional[AgentContext] = Field(None, description="Agent context")
actions: List[Action] = Field(default_factory=list)
resources: List[AgentResource] = Field(default_factory=list)
Expand All @@ -38,11 +40,6 @@ class ConversableAgent(Role, Agent):
llm_client: Optional[AIWrapper] = None
oai_system_message: List[Dict] = Field(default_factory=list)

class Config:
"""Pydantic configuration."""

arbitrary_types_allowed = True

def __init__(self, **kwargs):
"""Create a new agent."""
Role.__init__(self, **kwargs)
Expand Down Expand Up @@ -377,8 +374,10 @@ async def generate_reply(
**act_extent_param,
)
if act_out:
reply_message.action_report = act_out.dict()
span.metadata["action_report"] = act_out.dict() if act_out else None
reply_message.action_report = act_out.to_dict()
span.metadata["action_report"] = (
act_out.to_dict() if act_out else None
)

with root_tracer.start_span(
"agent.generate_reply.verify",
Expand Down Expand Up @@ -496,7 +495,7 @@ async def act(
"recipient": self.get_name(),
"reviewer": reviewer.get_name() if reviewer else None,
"need_resource": need_resource.to_dict() if need_resource else None,
"rely_action_out": last_out.dict() if last_out else None,
"rely_action_out": last_out.to_dict() if last_out else None,
"conv_uid": self.not_null_agent_context.conv_id,
"action_index": i,
"total_action": len(self.actions),
Expand All @@ -508,7 +507,7 @@ async def act(
rely_action_out=last_out,
**kwargs,
)
span.metadata["action_out"] = last_out.dict() if last_out else None
span.metadata["action_out"] = last_out.to_dict() if last_out else None
return last_out

async def correctness_check(
Expand Down
16 changes: 5 additions & 11 deletions dbgpt/agent/core/base_team.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import Dict, List, Optional, Tuple, Union

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field

from ..actions.action import ActionOutput
from .agent import Agent, AgentMessage
Expand Down Expand Up @@ -68,16 +68,13 @@ def _content_str(content: Union[str, List, None]) -> str:
class Team(BaseModel):
"""Team class for managing a group of agents in a team chat."""

model_config = ConfigDict(arbitrary_types_allowed=True)

agents: List[Agent] = Field(default_factory=list)
messages: List[Dict] = Field(default_factory=list)
max_round: int = 100
is_team: bool = True

class Config:
"""Pydantic model configuration."""

arbitrary_types_allowed = True

def __init__(self, **kwargs):
"""Create a new Team instance."""
super().__init__(**kwargs)
Expand Down Expand Up @@ -122,6 +119,8 @@ def append(self, message: Dict):
class ManagerAgent(ConversableAgent, Team):
"""Manager Agent class."""

model_config = ConfigDict(arbitrary_types_allowed=True)

profile: str = "TeamManager"
goal: str = "manage all hired intelligent agents to complete mission objectives"
constraints: List[str] = []
Expand All @@ -132,11 +131,6 @@ class ManagerAgent(ConversableAgent, Team):
# of the agent has already been retried.
max_retry_count: int = 1

class Config:
"""Pydantic model configuration."""

arbitrary_types_allowed = True

def __init__(self, **kwargs):
"""Create a new ManagerAgent instance."""
ConversableAgent.__init__(self, **kwargs)
Expand Down
9 changes: 3 additions & 6 deletions dbgpt/agent/core/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Type

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt._private.pydantic import BaseModel, ConfigDict, Field
from dbgpt.core import LLMClient, ModelMetadata, ModelRequest

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -106,11 +106,8 @@ def register_llm_strategy(
class LLMConfig(BaseModel):
"""LLM configuration."""

model_config = ConfigDict(arbitrary_types_allowed=True)

llm_client: Optional[LLMClient] = Field(default_factory=LLMClient)
llm_strategy: LLMStrategyType = Field(default=LLMStrategyType.Default)
strategy_context: Optional[Any] = None

class Config:
"""Pydantic model config."""

arbitrary_types_allowed = True
Loading

0 comments on commit 57be1ec

Please sign in to comment.