diff --git a/app.py b/app.py index 7de746c..289c074 100644 --- a/app.py +++ b/app.py @@ -11,6 +11,7 @@ from gradio import ChatMessage import textwrap from tools import * +from db.connection import db load_dotenv() os.environ['LANGCHAIN_PROJECT'] = 'gradio-test' diff --git a/db/connection.py b/db/connection.py new file mode 100644 index 0000000..db250f0 --- /dev/null +++ b/db/connection.py @@ -0,0 +1,3 @@ +from langchain_community.utilities import SQLDatabase + +db = SQLDatabase.from_uri("sqlite:///db/Bahrain_2023_Q.db") diff --git a/tools/__init__.py b/tools/__init__.py index 120ec35..d1823c4 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -1,4 +1,3 @@ -from langchain_community.utilities import SQLDatabase from rich.console import Console from .driver_performance import GetDriverPerformance from .event_performance import GetEventPerformance @@ -8,8 +7,6 @@ console = Console(style="chartreuse1 on grey7") -db = SQLDatabase.from_uri("sqlite:///db/Bahrain_2023_Q.db") - __all__ = [ "GetDriverPerformance", "GetEventPerformance", diff --git a/tools/driver_performance.py b/tools/driver_performance.py index fba2ee4..baa18cb 100644 --- a/tools/driver_performance.py +++ b/tools/driver_performance.py @@ -1,6 +1,7 @@ from pydantic import BaseModel, Field from langchain_core.tools import BaseTool -from . import db, console +from db.connection import db +from . import console class GetDriverPerformanceOutput(BaseModel): diff --git a/tools/event_performance.py b/tools/event_performance.py index fbaa53a..de418e0 100644 --- a/tools/event_performance.py +++ b/tools/event_performance.py @@ -1,7 +1,8 @@ from pydantic import BaseModel, Field from typing import Type from langchain_core.tools import BaseTool -from . import db, console +from db.connection import db +from . import console class GetEventPerformanceOutput(BaseModel): diff --git a/tools/telemetry_analysis.py b/tools/telemetry_analysis.py index 5d0109c..9a54a9a 100644 --- a/tools/telemetry_analysis.py +++ b/tools/telemetry_analysis.py @@ -1,7 +1,8 @@ from pydantic import BaseModel, Field from typing import Type from langchain_core.tools import BaseTool -from . import db, console +from db.connection import db +from . import console class GetTelemetryAndWeatherInput(BaseModel): diff --git a/tools/tyre_performance.py b/tools/tyre_performance.py index 32e7eec..8fb1bc3 100644 --- a/tools/tyre_performance.py +++ b/tools/tyre_performance.py @@ -1,7 +1,8 @@ from pydantic import BaseModel, Field from typing import Type from langchain_core.tools import BaseTool -from . import db, console +from db.connection import db +from . import console class GetTyrePerformanceInput(BaseModel): diff --git a/tools/weather_impact.py b/tools/weather_impact.py index 2ed15f4..962d411 100644 --- a/tools/weather_impact.py +++ b/tools/weather_impact.py @@ -1,10 +1,8 @@ from pydantic import BaseModel, Field from typing import Type from langchain_core.tools import BaseTool -from rich.console import Console -from . import db - -console = Console(style="chartreuse1 on grey7") +from db.connection import db +from . import console class GetWeatherImpactInput(BaseModel):