diff --git a/apps/api/poetry.lock b/apps/api/poetry.lock index 65205dd0..5cf67919 100644 --- a/apps/api/poetry.lock +++ b/apps/api/poetry.lock @@ -3690,6 +3690,21 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "stackapi" +version = "0.3.0" +description = "Library for interacting with the Stack Exchange API" +optional = false +python-versions = "*" +files = [ + {file = "StackAPI-0.3.0-py3-none-any.whl", hash = "sha256:217f494aae3b4f267a0e4f8565e1761c4e55ec30f0c5a50a205632a52ca28481"}, + {file = "StackAPI-0.3.0.tar.gz", hash = "sha256:4147c9587f1c719d1ff9e01a70216290766821f9f7c1401e47b60ee89c329288"}, +] + +[package.dependencies] +requests = "*" +six = "*" + [[package]] name = "starlette" version = "0.37.2" @@ -4621,4 +4636,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "dfa697f06d9fd1971d71249073d3aa33a62d44eac25a4840697c375f6dc1bfef" +content-hash = "3be1bebffa9dce9a2cea13c7d5eea6ac68c50fc9f7b79a38b5d5011132cc15c5" diff --git a/apps/api/pyproject.toml b/apps/api/pyproject.toml index 35c753e4..09757254 100644 --- a/apps/api/pyproject.toml +++ b/apps/api/pyproject.toml @@ -39,6 +39,8 @@ selenium = "^4.19.0" mail = "^2.1.0" duckduckgo-search = "^5.2.2" langchain-exa = "^0.0.1" +stackapi = "^0.3.0" +mypy = "^1.9.0" [tool.poetry.group.dev.dependencies] mypy = "^1.7.0" diff --git a/apps/api/src/__init__.py b/apps/api/src/__init__.py index 13c7bf29..83122438 100644 --- a/apps/api/src/__init__.py +++ b/apps/api/src/__init__.py @@ -2,7 +2,7 @@ from uuid import UUID import autogen -from fastapi import Depends, FastAPI +from fastapi import Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse @@ -18,9 +18,20 @@ ) from .improver import PromptType, improve_prompt from .interfaces import db -from .models import CrewProcessed +from .models import Profile +from .routers import agents, api_key_types, api_keys from .routers import auth as auth_router -from .routers import agents, crews, messages, sessions, profiles, api_key_types, rest, api_keys +from .routers import ( + billing_information, + crews, + messages, + profiles, + rest, + sessions, + subscriptions, + tiers, + tools, +) logger = logging.getLogger("root") @@ -34,7 +45,11 @@ app.include_router(api_keys.router) app.include_router(auth_router.router) app.include_router(api_key_types.router) +app.include_router(tools.router) +app.include_router(subscriptions.router) app.include_router(rest.router) +app.include_router(tiers.router) +app.include_router(billing_information.router) app.add_middleware( CORSMiddleware, @@ -97,6 +112,5 @@ def auto_build_crew(general_task: str) -> str: @app.get("/me") -def get_profile_from_header(current_user=Depends(get_current_user)): +def get_profile_from_header(current_user=Depends(get_current_user)) -> Profile: return current_user - diff --git a/apps/api/src/auth.py b/apps/api/src/auth.py index 50b1fd83..6c76261e 100644 --- a/apps/api/src/auth.py +++ b/apps/api/src/auth.py @@ -5,13 +5,10 @@ import jwt from dotenv import load_dotenv from fastapi import Depends, FastAPI, HTTPException, status -from fastapi.security import ( - HTTPAuthorizationCredentials, - HTTPBearer, - -) +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from src.interfaces import db +from src.models import Profile load_dotenv() @@ -21,7 +18,7 @@ logger = logging.getLogger("root") -async def get_current_user(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): +async def get_current_user(token: HTTPAuthorizationCredentials = Depends(HTTPBearer())) -> Profile: credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -43,4 +40,4 @@ async def get_current_user(token: HTTPAuthorizationCredentials = Depends(HTTPBea if not user_id or not profile: raise credentials_exception - return profile \ No newline at end of file + return profile diff --git a/apps/api/src/crew.py b/apps/api/src/crew.py index 27ff23de..6cac7eb1 100644 --- a/apps/api/src/crew.py +++ b/apps/api/src/crew.py @@ -5,6 +5,7 @@ import autogen from autogen.cache import Cache +from langchain.tools import BaseTool from src.models.session import SessionStatus @@ -18,6 +19,7 @@ logger = logging.getLogger("root") + class AutogenCrew: def __init__( self, @@ -32,10 +34,8 @@ def __init__( self.profile_id = profile_id self.session = session self.on_reply = on_message - if not self._validate_crew_model(crew_model): - raise ValueError("composition is invalid") self.crew_model = crew_model - self.valid_tools = [] + self.valid_tools: list[BaseTool] = [] self.agents: list[autogen.ConversableAgent | autogen.Agent] = ( self._create_agents(crew_model) @@ -58,7 +58,7 @@ def __init__( ) if self.valid_tools else None - ) + ) self.user_proxy.register_reply([autogen.Agent, None], self._on_reply) self.base_config_list = autogen.config_list_from_json( @@ -131,26 +131,12 @@ async def _on_reply( ] ): logger.error( - f"on_reply: both ids are none, sender is not admin and recipient is not chat manager" + "on_reply: both ids are none, sender is not admin and recipient is not chat manager" ) await self.on_reply(recipient_id, sender_id, content, role) return False, None - def _validate_crew_model(self, crew_model: CrewProcessed) -> bool: - if len(crew_model.agents) == 0: - return False - - # Validate agents - for agent in crew_model.agents: - if agent.role == "": - return False - if agent.title == "": - return False - if agent.system_message == "": - return False - return True - def _extract_uuid(self, dictionary: dict[UUID, list[str]]) -> dict[UUID, list[str]]: new_dict = {} for key, value in dictionary.items(): @@ -190,7 +176,7 @@ def _create_agents( for agent in crew_model.agents: valid_agent_tools = [] - tool_schemas = {} + tool_schemas: list[dict] | None config_list = autogen.config_list_from_json( "OAI_CONFIG_LIST", filter_dict={ @@ -208,6 +194,7 @@ def _create_agents( tool, api_key_types, profile_api_keys ) except TypeError as e: + logger.error(f"tried to generate tool, got error: {e}") raise e ( ( diff --git a/apps/api/src/dependencies/__init__.py b/apps/api/src/dependencies/__init__.py index 5e163b3b..c551020e 100644 --- a/apps/api/src/dependencies/__init__.py +++ b/apps/api/src/dependencies/__init__.py @@ -20,6 +20,7 @@ if url is None or key is None: raise ValueError("SUPABASE_URL and SUPABASE_ANON_KEY must be set") + @dataclass class RateLimitResponse: limit: int diff --git a/apps/api/src/interfaces/db.py b/apps/api/src/interfaces/db.py index 451e34f0..4675324c 100644 --- a/apps/api/src/interfaces/db.py +++ b/apps/api/src/interfaces/db.py @@ -13,27 +13,38 @@ Agent, AgentInsertRequest, AgentUpdateModel, - CrewInsertRequest, + APIKey, + APIKeyInsertRequest, + APIKeyType, + APIKeyUpdateRequest, + Billing, + BillingInsertRequest, + BillingUpdateRequest, Crew, + CrewInsertRequest, CrewUpdateRequest, Message, + MessageInsertRequest, + MessageUpdateRequest, Profile, + ProfileInsertRequest, ProfileUpdateRequest, Session, SessionInsertRequest, - Session, SessionStatus, SessionUpdateRequest, - ProfileInsertRequest, - APIKeyInsertRequest, - APIKey, - APIKeyType, - APIKeyUpdateRequest, - APIKeyType, - MessageInsertRequest, - Message, - MessageUpdateRequest, + Subscription, + SubscriptionGetRequest, + SubscriptionInsertRequest, + SubscriptionUpdateRequest, + Tier, + TierInsertRequest, + TierUpdateRequest, + Tool, + ToolInsertRequest, + ToolUpdateRequest, ) +from src.models.tiers import TierGetRequest load_dotenv() url: str | None = os.environ.get("SUPABASE_URL") @@ -46,7 +57,7 @@ logger = logging.getLogger("root") -# keeping this function for now, since typing gets crazy with the sessions/run endpoint +# keeping this function for now, since typing gets crazy with the sessions/run endpoint # if it uses the "get_session_by_param" function def get_session(session_id: UUID) -> Session | None: """Get a session from the database.""" @@ -59,10 +70,10 @@ def get_session(session_id: UUID) -> Session | None: def get_sessions( - profile_id: UUID | None = None, + profile_id: UUID | None = None, crew_id: UUID | None = None, title: str | None = None, - status: str | None = None + status: str | None = None, ) -> list[Session]: """Gets session(s), filtered by what parameters are given""" supabase: Client = create_client(url, key) @@ -87,7 +98,7 @@ def get_sessions( def insert_session(content: SessionInsertRequest) -> Session: supabase: Client = create_client(url, key) - logger.info(f"inserting session") + logger.info("inserting session") response = ( supabase.table("sessions") .insert(json.loads(content.model_dump_json())) @@ -127,11 +138,11 @@ def get_messages( session_id: UUID | None = None, profile_id: UUID | None = None, recipient_id: UUID | None = None, - sender_id: UUID | None = None + sender_id: UUID | None = None, ) -> list[Message]: """Gets messages, filtered by what parameters are given""" supabase: Client = create_client(url, key) - logger.debug(f"Getting messages") + logger.debug("Getting messages") query = supabase.table("messages").select("*") if session_id: @@ -146,20 +157,20 @@ def get_messages( if sender_id: query = query.eq("sender_id", sender_id) - response = query.execute() return [Message(**data) for data in response.data] -def get_message(message_id: UUID) -> Message | None: + +def get_message(message_id: UUID) -> Message: """Get a message by its id""" supabase: Client = create_client(url, key) - response = supabase.table("messages").select("*").eq("id", message_id).execute() - if len(response.data) == 0: - return None + response = ( + supabase.table("messages").select("*").eq("id", message_id).single().execute() + ) + return Message(**response.data) + - return Message(**response.data[0]) - # TODO: combine this function with the insert_message one, or use this post_message for both the endpoint and internal operations def post_message(message: Message) -> None: """Post a message to the database.""" @@ -173,7 +184,11 @@ def post_message(message: Message) -> None: def insert_message(message: MessageInsertRequest) -> Message: """Posts a message like the post_message function, but uses a request model""" supabase: Client = create_client(url, key) - response = supabase.table("messages").insert(json.loads(message.model_dump_json(exclude_none=True))).execute() + response = ( + supabase.table("messages") + .insert(json.loads(message.model_dump_json(exclude_none=True))) + .execute() + ) return Message(**response.data[0]) @@ -202,6 +217,169 @@ def update_message(message_id: UUID, content: MessageUpdateRequest) -> Message | return Message(**response.data[0]) +def get_subscriptions( + profile_id: UUID | None = None, + stripe_subscription_id: str | None = None, +) -> list[Subscription]: + """Gets subscriptions, filtered by what parameters are given""" + supabase: Client = create_client(url, key) + logger.debug("Getting subscriptions") + query = supabase.table("subscriptions").select("*") + + if profile_id: + query = query.eq("profile_id", profile_id) + + if stripe_subscription_id: + query = query.eq("stripe_subscription_id", stripe_subscription_id) + + response = query.execute() + + return [Subscription(**data) for data in response.data] + + +def insert_subscription(subscription: SubscriptionInsertRequest) -> Subscription: + """Posts a subscription to the db""" + supabase: Client = create_client(url, key) + response = ( + supabase.table("subscriptions") + .insert(json.loads(subscription.model_dump_json(exclude_none=True))) + .execute() + ) + return Subscription(**response.data[0]) + + +def delete_subscription(profile_id: UUID) -> Subscription | None: + """Deletes a subscription by an id (the primary key)""" + supabase: Client = create_client(url, key) + response = ( + supabase.table("subscriptions").delete().eq("profile_id", profile_id).execute() + ) + if len(response.data) == 0: + return None + + return Subscription(**response.data[0]) + + +def update_subscription( + profile_id: UUID, content: SubscriptionUpdateRequest +) -> Subscription | None: + """Updates a subscription by an id""" + supabase: Client = create_client(url, key) + response = ( + supabase.table("subscriptions") + .update(json.loads(content.model_dump_json(exclude_none=True))) + .eq("profile_id", profile_id) + .execute() + ) + if len(response.data) == 0: + return None + + return Subscription(**response.data[0]) + + +def get_tier(id: UUID) -> Tier | None: + """Gets tiers, filtered by what parameters are given""" + supabase: Client = create_client(url, key) + response = supabase.table("tiers").select("*").eq("id", id).execute() + if len(response.data) == 0: + return None + + return Tier(**response.data[0]) + + +def insert_tier(tier: TierInsertRequest) -> Tier: + """Posts a tier to the db""" + supabase: Client = create_client(url, key) + response = ( + supabase.table("tiers") + .insert(json.loads(tier.model_dump_json(exclude_none=True))) + .execute() + ) + return Tier(**response.data[0]) + + +def delete_tier(id: UUID) -> Tier | None: + """Deletes a tier by an id (the primary key)""" + supabase: Client = create_client(url, key) + response = supabase.table("tiers").delete().eq("id", id).execute() + if len(response.data) == 0: + return None + + return Tier(**response.data[0]) + + +def update_tier(id: UUID, content: TierUpdateRequest) -> Tier | None: + """Updates a tier by an id""" + supabase: Client = create_client(url, key) + response = ( + supabase.table("tiers") + .update(json.loads(content.model_dump_json(exclude_none=True))) + .eq("id", id) + .execute() + ) + return Tier(**response.data[0]) + + +def get_billing( + profile_id: UUID, +) -> Billing | None: + """Gets billings, filtered by what parameters are given""" + supabase: Client = create_client(url, key) + logger.debug("Getting billings") + response = ( + supabase.table("billing_information") + .select("*") + .eq("profile_id", profile_id) + .execute() + ) + if len(response.data) == 0: + return None + + return Billing(**response.data[0]) + + +def insert_billing(billing: BillingInsertRequest) -> Billing: + """Posts a billing to the db""" + supabase: Client = create_client(url, key) + response = ( + supabase.table("billing_information") + .insert(json.loads(billing.model_dump_json(exclude_none=True))) + .execute() + ) + + return Billing(**response.data[0]) + + +def delete_billing(profile_id: UUID) -> Billing | None: + """Deletes a billing by an id (the primary key)""" + supabase: Client = create_client(url, key) + response = ( + supabase.table("billing_information") + .delete() + .eq("profile_id", profile_id) + .execute() + ) + if len(response.data) == 0: + return None + + return Billing(**response.data[0]) + + +def update_billing(profile_id: UUID, content: BillingUpdateRequest) -> Billing | None: + """Updates a billing by an id""" + supabase: Client = create_client(url, key) + response = ( + supabase.table("billing_information") + .update(json.loads(content.model_dump_json(exclude_none=True))) + .eq("profile_id", profile_id) + .execute() + ) + if len(response.data) == 0: + return None + + return Billing(**response.data[0]) + + def get_descriptions(agent_ids: list[UUID]) -> dict[UUID, list[str]] | None: """Get the description list for the given agent.""" supabase: Client = create_client(url, key) @@ -275,7 +453,7 @@ def get_crews( ) -> list[Crew]: """Gets crews, filtered by what parameters are given""" supabase: Client = create_client(url, key) - logger.debug(f"Getting crews") + logger.debug("Getting crews") query = supabase.table("crews").select("*") if profile_id: @@ -301,25 +479,6 @@ def delete_crew(crew_id: UUID) -> Crew: return Crew(**response.data[0]) -def get_tool_api_keys( - profile_id: UUID, api_key_type_ids: list[str] | None = None -) -> dict[str, str]: - """Gets all api keys for a profile id, if api_key_type_ids is given, only give api keys corresponding to those key types.""" - supabase: Client = create_client(url, key) - # casted_ids = [str(api_key_type_id) for api_key_type_id in api_key_type_ids] - query = ( - supabase.table("users_api_keys") - .select("api_key", "api_key_type_id") - .eq("profile_id", profile_id) - ) - - if api_key_type_ids: - query = query.in_("api_key_type_id", api_key_type_ids) - - response = query.execute() - return {data["api_key_type_id"]: data["api_key"] for data in response.data} - - def get_api_key(api_key_id: UUID) -> APIKey | None: supabase: Client = create_client(url, key) response = ( @@ -358,17 +517,26 @@ def get_api_keys( for data in response.data: api_key_type = APIKeyType(**data["api_key_types"]) api_keys.append(APIKey(**data, api_key_type=api_key_type)) - + return api_keys def insert_api_key(api_key: APIKeyInsertRequest) -> APIKey | None: supabase: Client = create_client(url, key) - type_response = supabase.table("api_key_types").select("*").eq("id", api_key.api_key_type_id).execute() + type_response = ( + supabase.table("api_key_types") + .select("*") + .eq("id", api_key.api_key_type_id) + .execute() + ) if len(type_response.data) == 0: return None - response = supabase.table("users_api_keys").insert(json.loads(api_key.model_dump_json())).execute() + response = ( + supabase.table("users_api_keys") + .insert(json.loads(api_key.model_dump_json())) + .execute() + ) api_key_type = APIKeyType(**type_response.data[0]) return APIKey(**response.data[0], api_key_type=api_key_type) @@ -380,15 +548,30 @@ def delete_api_key(api_key_id: UUID) -> APIKey | None: if not len(response.data): return None - type_response = supabase.table("api_key_types").select("*").eq("id", response.data[0]["api_key_type_id"]).execute() + type_response = ( + supabase.table("api_key_types") + .select("*") + .eq("id", response.data[0]["api_key_type_id"]) + .execute() + ) api_key_type = APIKeyType(**type_response.data[0]) return APIKey(**response.data[0], api_key_type=api_key_type) def update_api_key(api_key_id: UUID, api_key_update: APIKeyUpdateRequest) -> APIKey: supabase: Client = create_client(url, key) - response = supabase.table("users_api_keys").update(json.loads(api_key_update.model_dump_json())).eq("id", api_key_id).execute() - type_response = supabase.table("api_key_types").select("*").eq("id", response.data[0]["api_key_type_id"]).execute() + response = ( + supabase.table("users_api_keys") + .update(json.loads(api_key_update.model_dump_json())) + .eq("id", api_key_id) + .execute() + ) + type_response = ( + supabase.table("api_key_types") + .select("*") + .eq("id", response.data[0]["api_key_type_id"]) + .execute() + ) api_key_type = APIKeyType(**type_response.data[0]) return APIKey(**response.data[0], api_key_type=api_key_type) @@ -398,7 +581,7 @@ def get_api_key_types() -> list[APIKeyType]: supabase: Client = create_client(url, key) logger.debug("Getting all api key types") response = supabase.table("api_key_types").select("*").execute() - return [APIKeyType(**data) for data in response.data] + return [APIKeyType(**data) for data in response.data] def update_status(session_id: UUID, status: SessionStatus) -> None: @@ -407,10 +590,19 @@ def update_status(session_id: UUID, status: SessionStatus) -> None: supabase.table("sessions").update({"status": status}).eq("id", session_id).execute() +def get_agent(agent_id: UUID) -> Agent | None: + supabase: Client = create_client(url, key) + response = supabase.table("agents").select("*").eq("id", agent_id).execute() + if not response.data: + return None + + return Agent(**response.data[0]) + + def get_agents( profile_id: UUID | None = None, crew_id: UUID | None = None, - published: bool | None = None + published: bool | None = None, ) -> list[Agent] | None: """Gets agents, filtered by what parameters are given""" supabase: Client = create_client(url, key) @@ -427,9 +619,9 @@ def get_agents( response = get_agents_from_crew(crew_id) if not response: return None - + return response - + if published is not None: query = query.eq("published", published) @@ -438,15 +630,6 @@ def get_agents( return [Agent(**data) for data in response.data] -def get_agent(agent_id: UUID) -> Agent | None: - supabase: Client = create_client(url, key) - response = supabase.table("agents").select("*").eq("id", agent_id).execute() - if not response.data: - return None - - return Agent(**response.data[0]) - - def get_agents_from_crew(crew_id: UUID) -> list[Agent] | None: supabase: Client = create_client(url, key) nodes = supabase.table("crews").select("nodes").eq("id", crew_id).execute() @@ -484,18 +667,117 @@ def delete_agent(agent_id: UUID) -> Agent: return Agent(**response.data[0]) +def get_tool(tool_id: UUID) -> Tool | None: + supabase: Client = create_client(url, key) + response = supabase.table("tools").select("*").eq("id", tool_id).execute() + if len(response.data) == 0: + return None + + return Tool(**response.data[0]) + + +def get_tools( + name: str | None = None, + api_key_type_id: UUID | None = None, +) -> list[Tool]: + supabase: Client = create_client(url, key) + query = supabase.table("tools").select("*") + + if name: + query = query.eq("name", name) + + if api_key_type_id: + query = query.eq("api_key_type_id", api_key_type_id) + + response = query.execute() + + return [Tool(**data) for data in response.data] + + +def update_tool(tool_id: UUID, content: ToolUpdateRequest) -> Tool: + supabase: Client = create_client(url, key) + response = ( + supabase.table("tools") + .update(json.loads(content.model_dump_json(exclude_none=True))) + .eq("id", tool_id) + .execute() + ) + return Tool(**response.data[0]) + + +def insert_tool(tool: ToolInsertRequest) -> Tool: + supabase: Client = create_client(url, key) + response = ( + supabase.table("tools") + .insert(json.loads(tool.model_dump_json(exclude_none=True))) + .execute() + ) + return Tool(**response.data[0]) + + +def delete_tool(tool_id: UUID) -> Tool | None: + supabase: Client = create_client(url, key) + response = supabase.table("tools").delete().eq("id", tool_id).execute() + if len(response.data) == 0: + return None + + return Tool(**response.data[0]) + + +def update_agent_tool(agent_id: UUID, tool_id: UUID) -> Agent: + supabase: Client = create_client(url, key) + agent_tools = supabase.table("agents").select("tools").eq("id", agent_id).execute() + tool: dict = {"id": tool_id, "parameter": {}} + + agent_tools.data[0]["tools"].append(tool) + formatted_tools = agent_tools.data[0]["tools"] + response = ( + supabase.table("agents") + .update(json.loads(json.dumps(formatted_tools, default=str))) + .eq("id", agent_id) + .execute() + ) + return Agent(**response.data[0]) + + +def get_tool_api_keys( + profile_id: UUID, api_key_type_ids: list[str] | None = None +) -> dict[str, str]: + """Gets all api keys for a profile id, if api_key_type_ids is given, only give api keys corresponding to those key types.""" + supabase: Client = create_client(url, key) + query = ( + supabase.table("users_api_keys") + .select("api_key", "api_key_type_id") + .eq("profile_id", profile_id) + ) + + if api_key_type_ids: + query = query.in_("api_key_type_id", api_key_type_ids) + + response = query.execute() + return {data["api_key_type_id"]: data["api_key"] for data in response.data} + + +def get_profile(profile_id: UUID) -> Profile | None: + supabase: Client = create_client(url, key) + response = supabase.table("profiles").select("*").eq("id", profile_id).execute() + if len(response.data) == 0: + return None + return Profile(**response.data[0]) + + def get_profiles( tier_id: UUID | None = None, display_name: str | None = None, - stripe_customer_id: str | None = None + stripe_customer_id: str | None = None, ) -> list[Profile]: - """Gets profiles, filtered by what parameters are given""" + """Gets profiles, filtered by what parameters are given""" supabase: Client = create_client(url, key) query = supabase.table("profiles").select("*") if tier_id: query = query.eq("tier_id", tier_id) - + if display_name: query = query.eq("display_name", display_name) @@ -507,17 +789,7 @@ def get_profiles( return [Profile(**data) for data in response.data] -def get_profile(profile_id: UUID) -> Profile | None: - supabase: Client = create_client(url, key) - response = supabase.table("profiles").select("*").eq("id", profile_id).execute() - if len(response.data) == 0: - return None - return Profile(**response.data[0]) - - -def update_profile( - profile_id: UUID, content: ProfileUpdateRequest -) -> Profile: +def update_profile(profile_id: UUID, content: ProfileUpdateRequest) -> Profile: supabase: Client = create_client(url, key) response = ( supabase.table("profiles") @@ -530,7 +802,11 @@ def update_profile( def insert_profile(profile: ProfileInsertRequest) -> Profile: supabase: Client = create_client(url, key) - response = supabase.table("profiles").insert(json.loads(profile.model_dump_json(exclude_none=True))).execute() + response = ( + supabase.table("profiles") + .insert(json.loads(profile.model_dump_json(exclude_none=True))) + .execute() + ) return Profile(**response.data[0]) @@ -540,26 +816,7 @@ def delete_profile(profile_id: UUID) -> Profile: return Profile(**response.data[0]) -if __name__ == "__main__": +if __name__ == "__main__": from src.models import Session -# print( -# insert_session( -# SessionRequest( -# crew_id=UUID("1c11a9bf-748f-482b-9746-6196f136401a"), -# profile_id=UUID("070c1d2e-9d72-4854-a55e-52ade5a42071"), -# title="hello", -# ) -# ) -# ) -# - #print(get_crew(UUID("bf9f1cdc-fb63-45e1-b1ff-9a1989373ce3"))) - ##print(insert_message(MessageRequestModel( - # session_id=UUID("ec4a9ae1-f4de-46cf-946d-956b3081c432"), - # profile_id=UUID("070c1d2e-9d72-4854-a55e-52ade5a42071"), - # content="hello test message", - # recipient_id=UUID("7c707c30-2cfe-46a0-afa7-8bcc38f9687e"), - #))) - - #print(update_message(UUID("c3e4755b-141d-4f77-8ea8-924961ccf36d"), content=MessageUpdateRequest(content="wowzer"))) - #print(get_api_keys(api_key_type_id=UUID("3b64fe26-20b9-4064-907e-f2708b5f1656"))) - print(get_api_key_type_ids(["612ddae6-ecdd-4900-9314-1a2c9de6003d"])) \ No newline at end of file + + print(get_api_key_type_ids(["612ddae6-ecdd-4900-9314-1a2c9de6003d"])) diff --git a/apps/api/src/mock.py b/apps/api/src/mock.py index be09b92a..40600c41 100644 --- a/apps/api/src/mock.py +++ b/apps/api/src/mock.py @@ -1,5 +1,7 @@ from .tools import get_file_path_of_example +DATE = "2024-01-01T00:00:00.000Z" + fizz_buzz: dict = { "id": "00000000-0000-0000-0000-000000000000", "profile_id": "6fcde4e6-6592-471b-9d33-dbf7e2ecfab4", @@ -16,7 +18,7 @@ "title": "", "content": "Write a program that prints the numbers from 1 to 100. But for multiples of three print “Fizz” instead of the number and for the multiples of five print “Buzz”. For numbers which are multiples of both three and five print “FizzBuzz”.", }, - "created_at": "2024-01-01T00:00:00.000Z", + "created_at": DATE, } markdown_table: dict = { @@ -35,53 +37,53 @@ "title": "", "content": "Create a markdown table of the top 10 large language models comparing their abilities by researching on the internet.", }, - "created_at": "2024-01-01T00:00:00.000Z", + "created_at": DATE, } -read_file: dict = { - "id": "00000000-0000-0000-0000-000000000001", - "profile_id": "6fcde4e6-6592-471b-9d33-dbf7e2ecfab4", - "title": "Output file content", - "description": "Read content of file and output it in a nice format", - "receiver_id": "0c0f0b05-e4ff-4d9a-a103-96a9702248f4", - "published": False, - "nodes": [ - "8e26f947-a0e9-4e47-b86f-22930ea948fa", - # "0c0f0b05-e4ff-4d9a-a103-96a9702248f4", - "6e541720-b4ac-4c47-abf3-f17147c9a32a", - ], - "prompt": { - "id": "", - "title": "", - "content": f"Get the file content of the file '{get_file_path_of_example()}', the 'agent python software' can call what function it has been", - }, - "created_at": "2024-01-01T00:00:00.000Z", -} +# read_file: dict = { +# "id": "00000000-0000-0000-0000-000000000001", +# "profile_id": "6fcde4e6-6592-471b-9d33-dbf7e2ecfab4", +# "title": "Output file content", +# "description": "Read content of file and output it in a nice format", +# "receiver_id": "0c0f0b05-e4ff-4d9a-a103-96a9702248f4", +# "published": False, +# "nodes": [ +# "8e26f947-a0e9-4e47-b86f-22930ea948fa", +# # "0c0f0b05-e4ff-4d9a-a103-96a9702248f4", +# "6e541720-b4ac-4c47-abf3-f17147c9a32a", +# ], +# "prompt": { +# "id": "", +# "title": "", +# "content": f"Get the file content of the file '{get_file_path_of_example()}', the 'agent python software' can call what function it has been", +# }, +# "created_at": DATE, +# } # "6e541720-b4ac-4c47-abf3-f17147c9a32a", agent for code reviewing # "2ce0b7db-84f7-4d59-8c38-3fcc3fd7da98", agent for writing tables in markdown -move_file: dict = { - "id": "00000000-0000-0000-0000-000000000001", - "profile_id": "6fcde4e6-6592-471b-9d33-dbf7e2ecfab4", - "title": "Output file content", - "description": "Read content of file and output it in a nice format", - "receiver_id": "0c0f0b05-e4ff-4d9a-a103-96a9702248f4", - "published": False, - "nodes": [ - "8e26f947-a0e9-4e47-b86f-22930ea948fa", - # "0c0f0b05-e4ff-4d9a-a103-96a9702248f4", - "6e541720-b4ac-4c47-abf3-f17147c9a32a", - ], - "prompt": { - "id": "", - "title": "", - "content": f"Move the file: '{get_file_path_of_example()}' to the destination: {get_file_path_of_example().replace('.txt', '_2.txt')} the 'agent python software' can call what function it has been", - }, - "created_at": "2024-01-01T00:00:00.000Z", -} +# move_file: dict = { +# "id": "00000000-0000-0000-0000-000000000001", +# "profile_id": "6fcde4e6-6592-471b-9d33-dbf7e2ecfab4", +# "title": "Output file content", +# "description": "Read content of file and output it in a nice format", +# "receiver_id": "0c0f0b05-e4ff-4d9a-a103-96a9702248f4", +# "published": False, +# "nodes": [ +# "8e26f947-a0e9-4e47-b86f-22930ea948fa", +# # "0c0f0b05-e4ff-4d9a-a103-96a9702248f4", +# "6e541720-b4ac-4c47-abf3-f17147c9a32a", +# ], +# "prompt": { +# "id": "", +# "title": "", +# "content": f"Move the file: '{get_file_path_of_example()}' to the destination: {get_file_path_of_example().replace('.txt', '_2.txt')} the 'agent python software' can call what function it has been", +# }, +# "created_at": DATE, +# } -tool, prompt = "bing search tool", "what is openai? restrict the number of results to 3" +tool, prompt = "brave search tool", "what is openai?" test_tool: dict = { "id": "00000000-0000-0000-0000-000000000001", @@ -98,6 +100,8 @@ "title": "", "content": f"This is a tool testing environment, use the tool: {tool}, {prompt}. Suggest this function call", }, - "created_at": "2024-01-01T00:00:00.000Z", + "created_at": DATE, + "edges": [], + "updated_at": DATE, } crew_model = test_tool diff --git a/apps/api/src/models/__init__.py b/apps/api/src/models/__init__.py index 17425541..f9361e0c 100644 --- a/apps/api/src/models/__init__.py +++ b/apps/api/src/models/__init__.py @@ -1,48 +1,62 @@ from .agent_config import AgentConfig from .agent_model import ( Agent, + AgentGetRequest, AgentInsertRequest, AgentUpdateModel, - AgentGetRequest, +) +from .api_key import ( + APIKey, + APIKeyGetRequest, + APIKeyInsertRequest, + APIKeyType, + APIKeyUpdateRequest, +) +from .billing_information import ( + Billing, + BillingInsertRequest, + BillingUpdateRequest, ) from .code_execution_config import CodeExecutionConfig from .crew_model import ( - CrewProcessed, - CrewInsertRequest, Crew, - CrewUpdateRequest, CrewGetRequest, + CrewInsertRequest, + CrewProcessed, + CrewUpdateRequest, ) from .llm_config import LLMConfig from .message import ( - Message, - MessageInsertRequest, - MessageUpdateRequest, + Message, MessageGetRequest, + MessageInsertRequest, + MessageUpdateRequest, ) from .profile import ( - ProfileInsertRequest, Profile, - ProfileUpdateRequest, ProfileGetRequest, + ProfileInsertRequest, + ProfileUpdateRequest, ) from .session import ( - SessionRunRequest, - SessionRunResponse, Session, + SessionGetRequest, SessionInsertRequest, + SessionRunRequest, + SessionRunResponse, SessionStatus, SessionUpdateRequest, - SessionGetRequest, ) -from .api_key import( - APIKeyInsertRequest, - APIKey, - APIKeyType, - APIKeyUpdateRequest, - APIKeyGetRequest, +from .subscription import ( + Subscription, + SubscriptionGetRequest, + SubscriptionInsertRequest, + SubscriptionUpdateRequest, ) +from .tiers import Tier, TierGetRequest, TierInsertRequest, TierUpdateRequest +from .tool import Tool, ToolGetRequest, ToolInsertRequest, ToolUpdateRequest from .user import User + __all__ = [ "AgentConfig", "CodeExecutionConfig", @@ -77,4 +91,19 @@ "AgentGetRequest", "ProfileGetRequest", "APIKeyGetRequest", + "Subscription", + "SubscriptionInsertRequest", + "SubscriptionUpdateRequest", + "SubscriptionGetRequest", + "Tool", + "ToolInsertRequest", + "ToolUpdateRequest", + "ToolGetRequest", + "Tier", + "TierInsertRequest", + "TierUpdateRequest", + "TierGetRequest", + "Billing", + "BillingInsertRequest", + "BillingUpdateRequest", ] diff --git a/apps/api/src/models/agent_model.py b/apps/api/src/models/agent_model.py index 8a4c6a15..c4cc63f0 100644 --- a/apps/api/src/models/agent_model.py +++ b/apps/api/src/models/agent_model.py @@ -18,7 +18,7 @@ class Agent(BaseModel): model: Literal["gpt-3.5-turbo", "gpt-4-turbo-preview"] tools: list[dict] description: str | None = None - role: str + role: str version: str | None = None diff --git a/apps/api/src/models/api_key.py b/apps/api/src/models/api_key.py index 9ee5febe..b4c5a07d 100644 --- a/apps/api/src/models/api_key.py +++ b/apps/api/src/models/api_key.py @@ -1,6 +1,8 @@ from __future__ import annotations -from uuid import UUID, uuid4 + from datetime import datetime +from uuid import UUID, uuid4 + from pydantic import BaseModel @@ -32,4 +34,4 @@ class APIKeyType(BaseModel): id: UUID created_at: datetime name: str | None = None - description: str | None = None \ No newline at end of file + description: str | None = None diff --git a/apps/api/src/models/billing_information.py b/apps/api/src/models/billing_information.py new file mode 100644 index 00000000..95308c77 --- /dev/null +++ b/apps/api/src/models/billing_information.py @@ -0,0 +1,22 @@ +from datetime import UTC, datetime +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field + + +class Billing(BaseModel): + profile_id: UUID + stripe_payment_method: str | None = None + description: str | None = None + created_at: datetime + + +class BillingInsertRequest(BaseModel): + profile_id: UUID + stripe_payment_method: str | None = None + description: str | None = None + + +class BillingUpdateRequest(BaseModel): + stripe_payment_method: str | None = None + description: str | None = None diff --git a/apps/api/src/models/crew_model.py b/apps/api/src/models/crew_model.py index 65d274d1..08c99efe 100644 --- a/apps/api/src/models/crew_model.py +++ b/apps/api/src/models/crew_model.py @@ -11,12 +11,13 @@ class CrewProcessed(BaseModel): receiver_id: UUID - delegator_id: UUID | None = None + delegator_id: UUID | None = None # None means admin again, so its the original crew (has no parent crew) agents: list[Agent] - sub_crews: list[Crew] = [] + sub_crews: list[Crew] = [] # Must set delegator_id for each sub_crew in sub_crews + class Crew(BaseModel): id: UUID created_at: datetime @@ -58,4 +59,4 @@ class CrewGetRequest(BaseModel): profile_id: UUID | None = None receiver_id: UUID | None = None title: str | None = None - published: bool | None = None \ No newline at end of file + published: bool | None = None diff --git a/apps/api/src/models/edge.py b/apps/api/src/models/edge.py index 5ec54da4..e7a4c8cf 100644 --- a/apps/api/src/models/edge.py +++ b/apps/api/src/models/edge.py @@ -1,8 +1,10 @@ -from typing import Optional, Union, Generic, TypeVar +from typing import Generic, Optional, TypeVar, Union + from pydantic import BaseModel, Field T = TypeVar("T") + class Marker(BaseModel): type: str color: Optional[str] = None @@ -12,12 +14,14 @@ class Marker(BaseModel): orient: Optional[str] = None strokeWidth: Optional[float] = None + class PathOptions(BaseModel): offset: Optional[float] = None borderRadius: Optional[float] = None curvature: Optional[float] = None -class Edge(Generic[T], BaseModel): + +class Edge(BaseModel, Generic[T]): id: str type: Optional[str] = None source: str @@ -42,4 +46,4 @@ class Edge(Generic[T], BaseModel): pathOptions: Optional[PathOptions] = None class Config: - populate_by_name = True \ No newline at end of file + populate_by_name = True diff --git a/apps/api/src/models/message.py b/apps/api/src/models/message.py index deceb4c3..ba567b5a 100644 --- a/apps/api/src/models/message.py +++ b/apps/api/src/models/message.py @@ -5,14 +5,14 @@ class Message(BaseModel): - id: UUID = Field(default_factory=lambda: uuid4()) + id: UUID session_id: UUID profile_id: UUID sender_id: UUID | None = None # None means admin here recipient_id: UUID | None = None # None means admin here aswell content: str - role: str = "user" - created_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC)) + role: str + created_at: datetime class MessageInsertRequest(BaseModel): @@ -28,13 +28,13 @@ class MessageUpdateRequest(BaseModel): session_id: UUID | None = None content: str | None = None role: str | None = None - recipient_id: UUID | None = None - sender_id: UUID | None = None + recipient_id: UUID | None = None + sender_id: UUID | None = None profile_id: UUID | None = None class MessageGetRequest(BaseModel): session_id: UUID | None = None profile_id: UUID | None = None - recipient_id: UUID | None = None + recipient_id: UUID | None = None sender_id: UUID | None = None diff --git a/apps/api/src/models/profile.py b/apps/api/src/models/profile.py index a269234d..cad8ca11 100644 --- a/apps/api/src/models/profile.py +++ b/apps/api/src/models/profile.py @@ -1,9 +1,8 @@ +from datetime import datetime from uuid import UUID, uuid4 from pydantic import BaseModel, Field -from datetime import datetime -# = Field(default_factory=lambda: uuid4()) class Profile(BaseModel): id: UUID @@ -14,7 +13,7 @@ class Profile(BaseModel): class ProfileInsertRequest(BaseModel): - # user id needs to be passed since its created from some "auth" table in the db + # user id needs to be passed since its created from some "auth" table in the db user_id: UUID tier_id: UUID display_name: str diff --git a/apps/api/src/models/rest_comment.py b/apps/api/src/models/rest_comment.py deleted file mode 100644 index 195ac3f2..00000000 --- a/apps/api/src/models/rest_comment.py +++ /dev/null @@ -1,10 +0,0 @@ -from uuid import UUID -from pydantic import BaseModel - -# TODO: This is placed here since the openapi schema of this model cant be generated if its in the src/rest/models directory for some reason -# will move this later on but this works for now -class PublishCommentRequest(BaseModel): - lead_id: UUID - comment: str - reddit_username: str - reddit_password: str diff --git a/apps/api/src/models/session.py b/apps/api/src/models/session.py index 8b48eaf4..988ab3c6 100644 --- a/apps/api/src/models/session.py +++ b/apps/api/src/models/session.py @@ -12,16 +12,15 @@ class SessionStatus(StrEnum): IDLE = auto() - class Session(BaseModel): - id: UUID = Field(default_factory=lambda: uuid4()) - created_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC)) + id: UUID + created_at: datetime profile_id: UUID - reply: str = "" + reply: str crew_id: UUID - title: str = "Untitled" - last_opened_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC)) - status: SessionStatus = SessionStatus.RUNNING + title: str + last_opened_at: datetime + status: SessionStatus class SessionInsertRequest(BaseModel): @@ -55,4 +54,4 @@ class SessionGetRequest(BaseModel): profile_id: UUID | None = None crew_id: UUID | None = None title: str | None = None - status: SessionStatus | None = None \ No newline at end of file + status: SessionStatus | None = None diff --git a/apps/api/src/models/subscription.py b/apps/api/src/models/subscription.py new file mode 100644 index 00000000..3bbbd736 --- /dev/null +++ b/apps/api/src/models/subscription.py @@ -0,0 +1,24 @@ +from datetime import UTC, datetime +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field + + +class Subscription(BaseModel): + profile_id: UUID + stripe_subscription_id: str | None = None + created_at: datetime + + +class SubscriptionInsertRequest(BaseModel): + profile_id: UUID + stripe_subscription_id: str | None = None + + +class SubscriptionUpdateRequest(BaseModel): + stripe_subscription_id: str | None = None + + +class SubscriptionGetRequest(BaseModel): + profile_id: UUID | None = None + stripe_subscription_id: str | None = None diff --git a/apps/api/src/models/tiers.py b/apps/api/src/models/tiers.py new file mode 100644 index 00000000..12df0ae1 --- /dev/null +++ b/apps/api/src/models/tiers.py @@ -0,0 +1,42 @@ +from datetime import UTC, datetime +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field + + +class Tier(BaseModel): + id: UUID + created_at: datetime + period: int + limit: int + stripe_price_id: str | None = None + name: str | None = None + description: str | None = None + slug: str | None = None + image: str | None = None + + +class TierInsertRequest(BaseModel): + period: int | None = None + limit: int | None = None + stripe_price_id: str | None = None + name: str | None = None + description: str | None = None + slug: str | None = None + image: str | None = None + + +class TierUpdateRequest(BaseModel): + period: int | None = None + limit: int | None = None + stripe_price_id: str | None = None + name: str | None = None + description: str | None = None + slug: str | None = None + image: str | None = None + + +class TierGetRequest(BaseModel): + id: UUID + stripe_price_id: str | None = None + name: str | None = None diff --git a/apps/api/src/models/tool.py b/apps/api/src/models/tool.py new file mode 100644 index 00000000..2f13101a --- /dev/null +++ b/apps/api/src/models/tool.py @@ -0,0 +1,29 @@ +from datetime import UTC, datetime +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field + + +class Tool(BaseModel): + id: UUID + created_at: datetime + name: str + description: str + api_key_type_id: UUID | None = None + + +class ToolInsertRequest(BaseModel): + name: str + description: str + api_key_type_id: UUID | None = None + + +class ToolUpdateRequest(BaseModel): + name: str | None = None + description: str | None = None + api_key_type_id: UUID | None = None + + +class ToolGetRequest(BaseModel): + name: str | None = None + api_key_type_id: UUID | None = None diff --git a/apps/api/src/models/user.py b/apps/api/src/models/user.py index 50bdeb63..ad1649bd 100644 --- a/apps/api/src/models/user.py +++ b/apps/api/src/models/user.py @@ -4,4 +4,4 @@ class User: id: UUID name: str - email: str \ No newline at end of file + email: str diff --git a/apps/api/src/parser.py b/apps/api/src/parser.py index 70bb2c39..ba0e6aa7 100644 --- a/apps/api/src/parser.py +++ b/apps/api/src/parser.py @@ -46,12 +46,13 @@ def get_agents(agent_ids: list[UUID]) -> list[Agent]: response = supabase.table("agents").select("*").in_("id", agent_ids).execute() return [Agent(**agent) for agent in response.data] + def process_crew(crew: Crew) -> tuple[str, CrewProcessed]: logger.debug("Processing crew") agent_ids: list[UUID] = crew.nodes if not crew.receiver_id: raise HTTPException(400, "got no receiver id") - + receiver_id: UUID = crew.receiver_id crew_model = CrewProcessed( @@ -60,6 +61,16 @@ def process_crew(crew: Crew) -> tuple[str, CrewProcessed]: ) if not crew.prompt: raise HTTPException(400, "got no prompt") + if len(crew_model.agents) == 0: + raise ValueError("crew had no agents") + # Validate agents + for agent in crew_model.agents: + if agent.role == "": + raise ValueError(f"agent {agent.id} had no role") + if agent.title == "": + raise ValueError(f"agent {agent.id} had no title") + if agent.system_message == "": + raise ValueError(f"agent {agent.id} had no system message") message: str = crew.prompt["content"] return message, crew_model diff --git a/apps/api/src/rest/.gitignore b/apps/api/src/rest/.gitignore new file mode 100644 index 00000000..fc61eafa --- /dev/null +++ b/apps/api/src/rest/.gitignore @@ -0,0 +1 @@ +/cache/ \ No newline at end of file diff --git a/apps/api/src/rest/__init__.py b/apps/api/src/rest/__init__.py index 735ea8d6..364f2d9f 100644 --- a/apps/api/src/rest/__init__.py +++ b/apps/api/src/rest/__init__.py @@ -1,27 +1,21 @@ -import os -from dotenv import load_dotenv import logging -import diskcache as dc +import os import threading from uuid import uuid4 -from .saving import update_db_with_submission -from . import mail -from .reddit_utils import get_subreddits -from .relevance_bot import evaluate_relevance -from .interfaces import db -from . import comment_bot -from .models import ( - PublishCommentRequest, - GenerateCommentRequest, - FalseLead, -) -from .reddit_worker import RedditStreamWorker - +import diskcache as dc +from dotenv import load_dotenv from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse +from . import comment_bot, mail +from .interfaces import db +from .models import FalseLead, GenerateCommentRequest, PublishCommentRequest +from .reddit_utils import get_subreddits +from .reddit_worker import RedditStreamWorker +from .relevance_bot import evaluate_relevance +from .saving import update_db_with_submission # Relevant subreddits to Startino SUBREDDIT_NAMES = ( diff --git a/apps/api/src/rest/cache/cache.db b/apps/api/src/rest/cache/cache.db deleted file mode 100644 index 223eaf3a..00000000 Binary files a/apps/api/src/rest/cache/cache.db and /dev/null differ diff --git a/apps/api/src/rest/chat_bot.py b/apps/api/src/rest/chat_bot.py index b480b463..4c32a7e7 100644 --- a/apps/api/src/rest/chat_bot.py +++ b/apps/api/src/rest/chat_bot.py @@ -2,24 +2,24 @@ from selenium.webdriver.common.keys import Keys # Set the path to your Chrome driver executable -driver_path = '/path/to/chromedriver' +driver_path = "/path/to/chromedriver" # Create a new instance of the Chrome driver driver = webdriver.Chrome(driver_path) # Open chat.reddit.com -driver.get('https://chat.reddit.com') +driver.get("https://chat.reddit.com") # Find the login button and click it login_button = driver.find_element_by_xpath('//button[contains(text(), "Log in")]') login_button.click() # Find the username and password input fields and enter your credentials -username_input = driver.find_element_by_name('username') -username_input.send_keys('your_username') +username_input = driver.find_element_by_name("username") +username_input.send_keys("your_username") -password_input = driver.find_element_by_name('password') -password_input.send_keys('your_password') +password_input = driver.find_element_by_name("password") +password_input.send_keys("your_password") # Submit the login form -password_input.send_keys(Keys.RETURN) \ No newline at end of file +password_input.send_keys(Keys.RETURN) diff --git a/apps/api/src/rest/comment_bot.py b/apps/api/src/rest/comment_bot.py index 5ffd7395..121b9e36 100644 --- a/apps/api/src/rest/comment_bot.py +++ b/apps/api/src/rest/comment_bot.py @@ -1,16 +1,17 @@ import logging -from langchain_openai import ChatOpenAI -from langchain_core.prompts import PromptTemplate +import os + +from dotenv import load_dotenv from langchain_core.output_parsers import JsonOutputParser -from .models import EvaluatedSubmission, RedditComment, PublishCommentResponse -from .dummy_submissions import relevant_submissions, irrelevant_submissions -from .prompts import generate_comment_prompt +from langchain_core.prompts import PromptTemplate +from langchain_openai import ChatOpenAI + +from .dummy_submissions import irrelevant_submissions, relevant_submissions from .interfaces import db +from .models import EvaluatedSubmission, PublishCommentResponse, RedditComment +from .prompts import generate_comment_prompt from .reddit_utils import get_reddit_instance -from dotenv import load_dotenv -import os - # Load Enviornment variables load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") diff --git a/apps/api/src/rest/dm.py b/apps/api/src/rest/dm.py index 77853a82..76561033 100644 --- a/apps/api/src/rest/dm.py +++ b/apps/api/src/rest/dm.py @@ -1,9 +1,15 @@ -from reddit_utils import REDDIT_PASSWORD, get_subreddits, reply -from Reddit_ChatBot_Python import ChatBot, RedditAuthentication -from Reddit_ChatBot_Python import CustomType, Snoo, Reaction -from dotenv import load_dotenv import os +from dotenv import load_dotenv +from Reddit_ChatBot_Python import ( + ChatBot, + CustomType, + Reaction, + RedditAuthentication, + Snoo, +) +from reddit_utils import REDDIT_PASSWORD, get_subreddits, reply + load_dotenv() REDDIT_PASSWORD = os.getenv("REDDIT_PASSWORD") diff --git a/apps/api/src/rest/interfaces/db.py b/apps/api/src/rest/interfaces/db.py index a6e1745a..c558590b 100644 --- a/apps/api/src/rest/interfaces/db.py +++ b/apps/api/src/rest/interfaces/db.py @@ -1,6 +1,7 @@ import json import logging import os +from datetime import datetime, timedelta from typing import Literal from uuid import UUID @@ -8,10 +9,7 @@ from pydantic import ValidationError from supabase import Client, create_client -from src.rest.models import Lead, PublishCommentResponse -from datetime import datetime, timedelta - -from src.rest.models import SavedSubmission +from src.rest.models import Lead, PublishCommentResponse, SavedSubmission load_dotenv() diff --git a/apps/api/src/rest/logging_utils.py b/apps/api/src/rest/logging_utils.py index cf2dfdd3..fad61c9e 100644 --- a/apps/api/src/rest/logging_utils.py +++ b/apps/api/src/rest/logging_utils.py @@ -1,6 +1,7 @@ -from praw.models import Submission from datetime import datetime +from praw.models import Submission + def log_relevance_calculation( model: str, submission: Submission, is_relevant: bool, cost: float, reason: str @@ -18,4 +19,3 @@ def log_relevance_calculation( print(f"Is Relevant: {'Yes' if is_relevant else 'No'}") print(f"Reason: {reason}") print("\n\n") - diff --git a/apps/api/src/rest/mail.py b/apps/api/src/rest/mail.py index 5013c3a8..69327966 100644 --- a/apps/api/src/rest/mail.py +++ b/apps/api/src/rest/mail.py @@ -1,12 +1,14 @@ -import smtplib -from email.mime.text import MIMEText -from email.mime.multipart import MIMEMultipart import os -from dotenv import load_dotenv +import smtplib from datetime import datetime +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText + import diskcache as dc -from .models import EvaluatedSubmission import markdown +from dotenv import load_dotenv + +from .models import EvaluatedSubmission load_dotenv() diff --git a/apps/api/src/rest/models/__init__.py b/apps/api/src/rest/models/__init__.py index fbac0500..7077061e 100644 --- a/apps/api/src/rest/models/__init__.py +++ b/apps/api/src/rest/models/__init__.py @@ -1,13 +1,13 @@ -from .relevance_result import RelevanceResult from .dummy_submission import DummySubmission +from .evaluated_submission import EvaluatedSubmission +from .false_lead import FalseLead from .filter_output import FilterOutput from .filter_question import FilterQuestion -from .evaluated_submission import EvaluatedSubmission from .lead import Lead -from .reddit_comment import RedditComment, GenerateCommentRequest from .publish_comment import PublishCommentRequest, PublishCommentResponse +from .reddit_comment import GenerateCommentRequest, RedditComment +from .relevance_result import RelevanceResult from .saved_submission import SavedSubmission -from .false_lead import FalseLead __all__ = [ "RelevanceResult", diff --git a/apps/api/src/rest/models/dummy_submission.py b/apps/api/src/rest/models/dummy_submission.py index 2b2526cf..ebc03630 100644 --- a/apps/api/src/rest/models/dummy_submission.py +++ b/apps/api/src/rest/models/dummy_submission.py @@ -1,8 +1,9 @@ from pydantic import BaseModel, Field + class DummySubmission(BaseModel): id: str url: str created_utc: int title: str - selftext: str \ No newline at end of file + selftext: str diff --git a/apps/api/src/rest/models/evaluated_submission.py b/apps/api/src/rest/models/evaluated_submission.py index 8ae024fc..2f71eeef 100644 --- a/apps/api/src/rest/models/evaluated_submission.py +++ b/apps/api/src/rest/models/evaluated_submission.py @@ -1,4 +1,5 @@ from typing import Optional + from praw.models import Submission from pydantic import BaseModel, ConfigDict diff --git a/apps/api/src/rest/models/false_lead.py b/apps/api/src/rest/models/false_lead.py index e8f7a249..d643aef4 100644 --- a/apps/api/src/rest/models/false_lead.py +++ b/apps/api/src/rest/models/false_lead.py @@ -1,6 +1,7 @@ -from pydantic import BaseModel from uuid import UUID +from pydantic import BaseModel + class FalseLead(BaseModel): lead_id: UUID diff --git a/apps/api/src/rest/models/filter_output.py b/apps/api/src/rest/models/filter_output.py index 22fe05ec..7e3145e2 100644 --- a/apps/api/src/rest/models/filter_output.py +++ b/apps/api/src/rest/models/filter_output.py @@ -5,7 +5,10 @@ from langchain_core.pydantic_v1 import BaseModel, Field, validator from langchain_openai import ChatOpenAI + # Define your desired data structure. class FilterOutput(BaseModel): answer: bool = Field(description="Answer to the yes-no question.") - source: str = Field(description="Either the piece of text you used to answer the question or the logical reason behind it. Should be brief and only have the relevant information") + source: str = Field( + description="Either the piece of text you used to answer the question or the logical reason behind it. Should be brief and only have the relevant information" + ) diff --git a/apps/api/src/rest/models/filter_question.py b/apps/api/src/rest/models/filter_question.py index 66cb8609..36c8bb47 100644 --- a/apps/api/src/rest/models/filter_question.py +++ b/apps/api/src/rest/models/filter_question.py @@ -1,6 +1,6 @@ from pydantic import BaseModel + class FilterQuestion(BaseModel): question: str reject_on: bool - diff --git a/apps/api/src/rest/models/lead.py b/apps/api/src/rest/models/lead.py index 8082c90f..ce9c05c4 100644 --- a/apps/api/src/rest/models/lead.py +++ b/apps/api/src/rest/models/lead.py @@ -1,8 +1,9 @@ -from re import S -from pydantic import BaseModel, Field from datetime import UTC, datetime +from re import S from uuid import UUID, uuid4 +from pydantic import BaseModel, Field + class Lead(BaseModel): """ @@ -10,10 +11,9 @@ class Lead(BaseModel): """ id: UUID = Field(default_factory=lambda: uuid4()) - submission_id : UUID + submission_id: UUID reddit_id: str - discovered_at: datetime = Field( - default_factory=lambda: datetime.now(tz=UTC)) + discovered_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC)) last_contacted_at: datetime | None = None prospect_username: str source: str diff --git a/apps/api/src/rest/models/publish_comment.py b/apps/api/src/rest/models/publish_comment.py index 5633bfad..f6fbe816 100644 --- a/apps/api/src/rest/models/publish_comment.py +++ b/apps/api/src/rest/models/publish_comment.py @@ -1,13 +1,16 @@ from datetime import datetime from uuid import UUID + from pydantic import BaseModel + class PublishCommentRequest(BaseModel): lead_id: UUID comment: str reddit_username: str reddit_password: str + class PublishCommentDataObject(BaseModel): url: str body: str @@ -25,4 +28,4 @@ class PublishCommentResponse(BaseModel): data: PublishCommentDataObject | None = None last_event: str status: str - comment: str | None = None \ No newline at end of file + comment: str | None = None diff --git a/apps/api/src/rest/models/reddit_comment.py b/apps/api/src/rest/models/reddit_comment.py index b99184d7..b8823d42 100644 --- a/apps/api/src/rest/models/reddit_comment.py +++ b/apps/api/src/rest/models/reddit_comment.py @@ -4,11 +4,11 @@ from pydantic import BaseModel as PydanticBaseModel - class RedditComment(BaseModel): comment: str = Field(description="the text of the reddit comment") # Not sure if this should be a model or simply a string. + class GenerateCommentRequest(PydanticBaseModel): title: str selftext: str diff --git a/apps/api/src/rest/models/relevance_result.py b/apps/api/src/rest/models/relevance_result.py index 0c3caa0c..fc797ca7 100644 --- a/apps/api/src/rest/models/relevance_result.py +++ b/apps/api/src/rest/models/relevance_result.py @@ -1,6 +1,11 @@ from langchain_core.pydantic_v1 import BaseModel, Field + class RelevanceResult(BaseModel): - certainty: float = Field(description="A value between 0-1 to determine how certain you are that the is_relevant answer is factually correct.") + certainty: float = Field( + description="A value between 0-1 to determine how certain you are that the is_relevant answer is factually correct." + ) is_relevant: bool = Field(description="Determines if the post is relevant.") - reason: str = Field(description="Explain why you determined this post is relevant or irrelevant. Format: Post is [answer] because [reason]. Hence, it is not a lead and not relevant") + reason: str = Field( + description="Explain why you determined this post is relevant or irrelevant. Format: Post is [answer] because [reason]. Hence, it is not a lead and not relevant" + ) diff --git a/apps/api/src/rest/models/saved_submission.py b/apps/api/src/rest/models/saved_submission.py index 9e40fe8e..35b94d28 100644 --- a/apps/api/src/rest/models/saved_submission.py +++ b/apps/api/src/rest/models/saved_submission.py @@ -1,7 +1,8 @@ -from pydantic import BaseModel, Field -from datetime import datetime, UTC +from datetime import UTC, datetime from typing import Optional -from uuid import uuid4, UUID +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field class SavedSubmission(BaseModel): diff --git a/apps/api/src/rest/prompts/__init__.py b/apps/api/src/rest/prompts/__init__.py index 91ad6927..60e1a535 100644 --- a/apps/api/src/rest/prompts/__init__.py +++ b/apps/api/src/rest/prompts/__init__.py @@ -1,5 +1,12 @@ -from .relevance import calculate_relevance_prompt, purpose, ideal_customer_profile, context, good_examples, bad_examples from .commenting import generate_comment_prompt +from .relevance import ( + bad_examples, + calculate_relevance_prompt, + context, + good_examples, + ideal_customer_profile, + purpose, +) __all__ = [ "calculate_relevance_prompt", diff --git a/apps/api/src/rest/prompts/commenting.py b/apps/api/src/rest/prompts/commenting.py index 4bea145f..0e76dd8f 100644 --- a/apps/api/src/rest/prompts/commenting.py +++ b/apps/api/src/rest/prompts/commenting.py @@ -1,16 +1,21 @@ -from gptrim import trim import os +from gptrim import trim + # Get the directory of the current script file script_dir = os.path.dirname(os.path.realpath(__file__)) -with open(os.path.join(script_dir, "startino_business_plan.md"), "r", encoding='utf-8') as file: +with open( + os.path.join(script_dir, "startino_business_plan.md"), "r", encoding="utf-8" +) as file: company_context = file.read() -with open(os.path.join(script_dir, "good_comment_examples.md"), "r", encoding='utf-8') as file: +with open( + os.path.join(script_dir, "good_comment_examples.md"), "r", encoding="utf-8" +) as file: good_examples = file.read() -with open(os.path.join(script_dir, "bad_examples.md"), "r", encoding='utf-8') as file: +with open(os.path.join(script_dir, "bad_examples.md"), "r", encoding="utf-8") as file: bad_examples = file.read() purpose = """ @@ -53,7 +58,8 @@ writing comments that fulfill the purpose. """ -generate_comment_prompt = trim(f""" +generate_comment_prompt = trim( + f""" # INSTRUCTIONS {roleplay} # PURPOSE @@ -64,5 +70,5 @@ {context} # EXAMPLES {examples} -""") - +""" +) diff --git a/apps/api/src/rest/prompts/relevance.py b/apps/api/src/rest/prompts/relevance.py index f7ebafdc..c8152baf 100644 --- a/apps/api/src/rest/prompts/relevance.py +++ b/apps/api/src/rest/prompts/relevance.py @@ -1,16 +1,19 @@ -from gptrim import trim import os +from gptrim import trim + # Get the directory of the current script file script_dir = os.path.dirname(os.path.realpath(__file__)) -with open(os.path.join(script_dir, "startino_business_plan.md"), "r", encoding='utf-8') as file: +with open( + os.path.join(script_dir, "startino_business_plan.md"), "r", encoding="utf-8" +) as file: company_context = file.read() -with open(os.path.join(script_dir, "good_examples.md"), "r", encoding='utf-8') as file: +with open(os.path.join(script_dir, "good_examples.md"), "r", encoding="utf-8") as file: good_examples = file.read() -with open(os.path.join(script_dir, "bad_examples.md"), "r", encoding='utf-8') as file: +with open(os.path.join(script_dir, "bad_examples.md"), "r", encoding="utf-8") as file: bad_examples = file.read() purpose = """ @@ -70,7 +73,8 @@ relevant to look into for your boss. """ -calculate_relevance_prompt = trim(f""" +calculate_relevance_prompt = trim( + f""" # INSTRUCTIONS {roleplay} # PURPOSE @@ -85,5 +89,5 @@ {examples} -""") - +""" +) diff --git a/apps/api/src/rest/reddit_utils.py b/apps/api/src/rest/reddit_utils.py index 5190e1dc..f3d6182a 100644 --- a/apps/api/src/rest/reddit_utils.py +++ b/apps/api/src/rest/reddit_utils.py @@ -1,7 +1,8 @@ +import os + +from dotenv import load_dotenv from praw import Reddit from praw.models import Subreddits -from dotenv import load_dotenv -import os load_dotenv() diff --git a/apps/api/src/rest/reddit_worker.py b/apps/api/src/rest/reddit_worker.py index 20282fc4..15b74678 100644 --- a/apps/api/src/rest/reddit_worker.py +++ b/apps/api/src/rest/reddit_worker.py @@ -1,15 +1,15 @@ -from praw import Reddit -from praw.models import Subreddits, Submission import os -from dotenv import load_dotenv +from pathlib import Path from urllib.parse import quote_plus -import diskcache as dc -from pathlib import Path +import diskcache as dc +from dotenv import load_dotenv +from praw import Reddit +from praw.models import Submission, Subreddits +from .reddit_utils import get_reddit_instance, get_subreddits from .relevance_bot import evaluate_relevance from .saving import update_db_with_submission -from .reddit_utils import get_subreddits, get_reddit_instance load_dotenv() REDDIT_CLIENT_ID = os.getenv("REDDIT_CLIENT_ID") diff --git a/apps/api/src/rest/relevance_bot.py b/apps/api/src/rest/relevance_bot.py index 80f378f3..c8367ad9 100644 --- a/apps/api/src/rest/relevance_bot.py +++ b/apps/api/src/rest/relevance_bot.py @@ -1,21 +1,27 @@ +import os import time from typing import List -import os -from dotenv import load_dotenv +from dotenv import load_dotenv from gptrim import trim -from praw.models import Submission -from langchain_openai import ChatOpenAI -from langchain_core.prompts import PromptTemplate -from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_community.callbacks import get_openai_callback +from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser +from langchain_core.prompts import PromptTemplate +from langchain_openai import ChatOpenAI +from praw.models import Submission -from .models import EvaluatedSubmission, RelevanceResult, FilterOutput, FilterQuestion -from .prompts import calculate_relevance_prompt, context as company_context, purpose -from .dummy_submissions import relevant_submissions, irrelevant_submissions -from .utils import majority_vote, calculate_certainty_from_bools +from .dummy_submissions import irrelevant_submissions, relevant_submissions from .logging_utils import log_relevance_calculation - +from .models import ( + EvaluatedSubmission, + FilterOutput, + FilterQuestion, + RelevanceResult, +) +from .prompts import calculate_relevance_prompt +from .prompts import context as company_context +from .prompts import purpose +from .utils import calculate_certainty_from_bools, majority_vote # Load Enviornment variables load_dotenv() @@ -77,8 +83,8 @@ def invoke_chain(chain, submission: Submission) -> tuple[RelevanceResult, float] time.sleep(10) # Wait for 10 seconds before trying again raise RuntimeError( - "Failed to invoke chain after 3 attempts. Most likely no more credits left or usage limit has been reached." -) + "Failed to invoke chain after 3 attempts. Most likely no more credits left or usage limit has been reached." + ) def summarize_submission(submission: Submission) -> Submission: diff --git a/apps/api/src/rest/saving.py b/apps/api/src/rest/saving.py index 066c2360..4e7f5125 100644 --- a/apps/api/src/rest/saving.py +++ b/apps/api/src/rest/saving.py @@ -1,9 +1,8 @@ import os from . import comment_bot -from .models import Lead from .interfaces import db -from .models import EvaluatedSubmission, SavedSubmission +from .models import EvaluatedSubmission, Lead, SavedSubmission # Get the current file's directory current_dir = os.path.dirname(os.path.realpath(__file__)) diff --git a/apps/api/src/rest/utils.py b/apps/api/src/rest/utils.py index 833d776a..6a758d41 100644 --- a/apps/api/src/rest/utils.py +++ b/apps/api/src/rest/utils.py @@ -1,16 +1,15 @@ - - from typing import List def majority_vote(bool_list: List[bool]) -> bool: return sum(bool_list) > len(bool_list) / 2 + def calculate_certainty_from_bools(bool_list: List[bool]) -> float: length = len(bool_list) total = sum(bool_list) - + true_certainty = total / length - false_certainty = (length-total) / length + false_certainty = (length - total) / length - return max(true_certainty, false_certainty) \ No newline at end of file + return max(true_certainty, false_certainty) diff --git a/apps/api/src/routers/agents.py b/apps/api/src/routers/agents.py index 4401c44d..e1ccdc68 100644 --- a/apps/api/src/routers/agents.py +++ b/apps/api/src/routers/agents.py @@ -5,10 +5,10 @@ from src.interfaces import db from src.models import ( - AgentInsertRequest, - AgentUpdateModel, Agent, AgentGetRequest, + AgentInsertRequest, + AgentUpdateModel, ) router = APIRouter( @@ -24,7 +24,7 @@ def get_agents(q: AgentGetRequest = Depends()) -> list[Agent]: response = db.get_agents(q.profile_id, q.crew_id, q.published) if not response: raise HTTPException(404, "crew not found or crew has no agents") - + return response @@ -46,9 +46,7 @@ def insert_agent(agent_request: AgentInsertRequest) -> Agent: @router.patch("/{agent_id}") -def patch_agent( - agent_id: UUID, agent_update_request: AgentUpdateModel -) -> Agent: +def patch_agent(agent_id: UUID, agent_update_request: AgentUpdateModel) -> Agent: if not db.get_agent(agent_id): raise HTTPException(404, "agent not found") diff --git a/apps/api/src/routers/api_key_types.py b/apps/api/src/routers/api_key_types.py index 741bf75f..a7017e52 100644 --- a/apps/api/src/routers/api_key_types.py +++ b/apps/api/src/routers/api_key_types.py @@ -4,15 +4,13 @@ from fastapi import APIRouter, HTTPException from src.interfaces import db -from src.models import ( - APIKeyType, -) +from src.models import APIKeyType -router = APIRouter(prefix="/api_key_types", tags=["api key types"]) +router = APIRouter(prefix="/api-key-types", tags=["api key types"]) logger = logging.getLogger("root") @router.get("/") def get_all_api_key_types() -> list[APIKeyType]: - return db.get_api_key_types() \ No newline at end of file + return db.get_api_key_types() diff --git a/apps/api/src/routers/api_keys.py b/apps/api/src/routers/api_keys.py index 333b93ea..8c32f4e5 100644 --- a/apps/api/src/routers/api_keys.py +++ b/apps/api/src/routers/api_keys.py @@ -6,21 +6,20 @@ from src.interfaces import db from src.models import ( APIKey, + APIKeyGetRequest, APIKeyInsertRequest, APIKeyUpdateRequest, - APIKeyGetRequest, ) -router = APIRouter(prefix="/api_keys", tags=["api keys"]) +router = APIRouter(prefix="/api-keys", tags=["api keys"]) @router.get("/") def get_api_keys(q: APIKeyGetRequest = Depends()) -> list[APIKey]: """Returns api keys with the api key type as an object with the id, name, description etc.""" - if q.profile_id: - if not db.get_profile(q.profile_id): - raise HTTPException(404, "profile not found") - + if q.profile_id and not db.get_profile(q.profile_id): + raise HTTPException(404, "profile not found") + return db.get_api_keys(q.profile_id, q.api_key_type_id, q.api_key) @@ -41,6 +40,7 @@ def insert_api_key(api_key_request: APIKeyInsertRequest) -> APIKey: return response + @router.delete("/{api_key_id}") def delete_api_key(api_key_id: UUID) -> APIKey: deleted_key = db.delete_api_key(api_key_id) @@ -52,4 +52,4 @@ def delete_api_key(api_key_id: UUID) -> APIKey: @router.patch("/{api_key_id}") def update_api_key(api_key_id: UUID, api_key_update: APIKeyUpdateRequest) -> APIKey: - return db.update_api_key(api_key_id, api_key_update) \ No newline at end of file + return db.update_api_key(api_key_id, api_key_update) diff --git a/apps/api/src/routers/billing_information.py b/apps/api/src/routers/billing_information.py new file mode 100644 index 00000000..44f22b67 --- /dev/null +++ b/apps/api/src/routers/billing_information.py @@ -0,0 +1,49 @@ +import logging +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException + +from src.dependencies import ( + RateLimitResponse, + rate_limit, + rate_limit_profile, + rate_limit_tiered, +) +from src.interfaces import db +from src.models import Billing, BillingInsertRequest, BillingUpdateRequest + +router = APIRouter(prefix="/billing", tags=["billings"]) + +logger = logging.getLogger("root") + + +@router.get("/{id}") +def get_billings(id: UUID) -> Billing: + response = db.get_billing(id) + if not response: + raise HTTPException(404, "billing information not found") + + return response + + +@router.post("/") +def insert_billing(subscription: BillingInsertRequest) -> Billing: + return db.insert_billing(subscription) + + +@router.delete("/{profile_id}") +def delete_billing(profile_id: UUID) -> Billing: + response = db.delete_billing(profile_id) + if not response: + raise HTTPException(404, "stripe subscription id not found") + + return response + + +@router.patch("/{profile_id}") +def update_billing(profile_id: UUID, content: BillingUpdateRequest) -> Billing: + response = db.update_billing(profile_id, content) + if not response: + raise HTTPException(404, "message not found") + + return response diff --git a/apps/api/src/routers/crews.py b/apps/api/src/routers/crews.py index 3a18cc29..a837b48a 100644 --- a/apps/api/src/routers/crews.py +++ b/apps/api/src/routers/crews.py @@ -4,7 +4,12 @@ from fastapi import APIRouter, Depends, HTTPException from src.interfaces import db -from src.models import CrewInsertRequest, Crew, CrewUpdateRequest, CrewGetRequest +from src.models import ( + Crew, + CrewGetRequest, + CrewInsertRequest, + CrewUpdateRequest, +) router = APIRouter( prefix="/crews", diff --git a/apps/api/src/routers/messages.py b/apps/api/src/routers/messages.py index 2c2ab76e..62fa88ae 100644 --- a/apps/api/src/routers/messages.py +++ b/apps/api/src/routers/messages.py @@ -2,6 +2,7 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException +from postgrest.exceptions import APIError from src.dependencies import ( RateLimitResponse, @@ -10,8 +11,12 @@ rate_limit_tiered, ) from src.interfaces import db -from src.models import Message, MessageInsertRequest, Message, MessageUpdateRequest, MessageGetRequest -from postgrest.exceptions import APIError +from src.models import ( + Message, + MessageGetRequest, + MessageInsertRequest, + MessageUpdateRequest, +) router = APIRouter(prefix="/messages", tags=["messages"]) @@ -33,7 +38,7 @@ def delete_message(message_id: UUID) -> Message: response = db.delete_message(message_id) if not response: raise HTTPException(404, "message not found") - + return response @@ -51,5 +56,5 @@ def get_message(message_id: UUID) -> Message: response = db.get_message(message_id) if not response: raise HTTPException(404, "message not found") - - return response \ No newline at end of file + + return response diff --git a/apps/api/src/routers/profiles.py b/apps/api/src/routers/profiles.py index e59dcced..e2540dc3 100644 --- a/apps/api/src/routers/profiles.py +++ b/apps/api/src/routers/profiles.py @@ -6,9 +6,9 @@ from src.interfaces import db from src.models import ( Profile, - ProfileUpdateRequest, + ProfileGetRequest, ProfileInsertRequest, - ProfileGetRequest + ProfileUpdateRequest, ) router = APIRouter(prefix="/profiles", tags=["profiles"]) @@ -31,7 +31,7 @@ def get_profile_by_id(profile_id: UUID) -> Profile: raise HTTPException(404, "profile not found") return profile - + @router.delete("/{profile_id}") def delete_profile(profile_id: UUID) -> Profile: @@ -46,4 +46,3 @@ def update_profile( raise HTTPException(404, "profile not found") return db.update_profile(profile_id, profile_update_request) - diff --git a/apps/api/src/routers/rest.py b/apps/api/src/routers/rest.py index 48e5dcbf..fa928c59 100644 --- a/apps/api/src/routers/rest.py +++ b/apps/api/src/routers/rest.py @@ -1,57 +1,24 @@ -#import logging -#from uuid import UUID -# -#from fastapi import APIRouter, HTTPException -#from src.models import Crew, Message, Session -#from src.rest import comment_bot -#from src.rest.interfaces import db -#from src.rest.models import PublishCommentRequest, PublishCommentResponse -# -#router = APIRouter(prefix="/rest", tags=["rest"]) -# -#logger = logging.getLogger("root") -# -# -#@router.post("/") -#def publish_comment(publish_request: PublishCommentRequest): -# updated_content = comment_bot.publish_comment( -# publish_request.lead_id, -# publish_request.comment, -# publish_request.reddit_username, -# publish_request.reddit_password, -# ) -# if updated_content is None: -# raise HTTPException(404, "lead not found") -# -# return updated_content -# -#@router.get("/") -#def get_leads() -> list[PublishCommentResponse]: -# return db.get_all_leads() -# -import os -from dotenv import load_dotenv import logging -import diskcache as dc +import os import threading from uuid import uuid4 -from src.rest.saving import update_db_with_submission -from src.rest import mail -from src.rest.reddit_utils import get_subreddits -from src.rest.relevance_bot import evaluate_relevance +import diskcache as dc +from dotenv import load_dotenv +from fastapi import APIRouter, HTTPException +from fastapi.responses import RedirectResponse + +from src.rest import comment_bot, mail from src.rest.interfaces import db -from src.rest import comment_bot from src.rest.models import ( - PublishCommentRequest, - GenerateCommentRequest, FalseLead, + GenerateCommentRequest, + PublishCommentRequest, ) +from src.rest.reddit_utils import get_subreddits from src.rest.reddit_worker import RedditStreamWorker - -from fastapi import APIRouter, HTTPException -from fastapi.responses import RedirectResponse - +from src.rest.relevance_bot import evaluate_relevance +from src.rest.saving import update_db_with_submission # Relevant subreddits to Startino SUBREDDIT_NAMES = ( diff --git a/apps/api/src/routers/sessions.py b/apps/api/src/routers/sessions.py index 18ca18e4..b928e69b 100644 --- a/apps/api/src/routers/sessions.py +++ b/apps/api/src/routers/sessions.py @@ -1,6 +1,7 @@ import logging +from datetime import UTC, datetime from typing import cast -from uuid import UUID +from uuid import UUID, uuid4 from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException @@ -14,18 +15,18 @@ ) from src.interfaces import db from src.models import ( - CrewProcessed, Crew, + CrewProcessed, Message, + Session, + SessionGetRequest, SessionRunRequest, SessionRunResponse, - Session, - Session, + SessionStatus, SessionUpdateRequest, - SessionGetRequest, ) from src.models.session import SessionInsertRequest -from src.parser import process_crew, get_processed_crew_by_id +from src.parser import get_processed_crew_by_id, process_crew router = APIRouter( prefix="/sessions", @@ -36,9 +37,7 @@ @router.get("/") -def get_sessions( - q: SessionGetRequest = Depends() -) -> list[Session]: +def get_sessions(q: SessionGetRequest = Depends()) -> list[Session]: return db.get_sessions(q.profile_id, q.crew_id, q.title, q.status) @@ -47,13 +46,13 @@ def get_session(session_id: UUID) -> Session: response = db.get_session(session_id) if response is None: raise HTTPException(500, "failed validation") - # not sure if 500 is correct, but this is failed validation on the returned data, so + # not sure if 500 is correct, but this is failed validation on the returned data, so # it makes sense in my mind to raise a server error for that - + return response # pretty sure this response object will always be a session, so casting it to stop typing errors - - + + @router.patch("/{session_id}") def update_session(session_id: UUID, content: SessionUpdateRequest) -> Session: return db.update_session(session_id, content) @@ -93,6 +92,8 @@ async def run_crew( if mock: message, crew_model = process_crew(Crew(**mocks.crew_model)) + request.crew_id = UUID("1c11a9bf-748f-482b-9746-6196f136401a") + request.profile_id = UUID("070c1d2e-9d72-4854-a55e-52ade5a42071") else: message, crew_model = get_processed_crew_by_id(request.crew_id) @@ -118,12 +119,16 @@ async def run_crew( status_code=400, detail=f"Session with id {request.session_id} found, but has no messages", ) - if session is None: session = Session( + id=uuid4(), + created_at=datetime.now(tz=UTC), crew_id=request.crew_id, profile_id=request.profile_id, title=request.session_title, + reply="", + last_opened_at=datetime.now(tz=UTC), + status=SessionStatus.RUNNING, ) db.post_session(session) @@ -134,12 +139,14 @@ async def on_reply( role: str, ) -> None: message = Message( + id=uuid4(), session_id=session.id, profile_id=session.profile_id, recipient_id=recipient_id, sender_id=sender_id, content=content, role=role, + created_at=datetime.now(tz=UTC), ) logger.debug(f"on_reply: {message}") db.post_message(message) @@ -147,8 +154,9 @@ async def on_reply( try: crew = AutogenCrew(session.profile_id, session, crew_model, on_reply) except ValueError as e: + db.delete_session(session.id) logger.error(e) - raise HTTPException(400, "crew model bad input") + raise HTTPException(400, f"crew model bad input: {e}") background_tasks.add_task(crew.run, message, messages=cached_messages) diff --git a/apps/api/src/routers/subscriptions.py b/apps/api/src/routers/subscriptions.py new file mode 100644 index 00000000..d8e73b08 --- /dev/null +++ b/apps/api/src/routers/subscriptions.py @@ -0,0 +1,52 @@ +import logging +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException + +from src.dependencies import ( + RateLimitResponse, + rate_limit, + rate_limit_profile, + rate_limit_tiered, +) +from src.interfaces import db +from src.models import ( + Subscription, + SubscriptionGetRequest, + SubscriptionInsertRequest, + SubscriptionUpdateRequest, +) + +router = APIRouter(prefix="/subscriptions", tags=["subscriptions"]) + +logger = logging.getLogger("root") + + +@router.get("/") +def get_subscriptions(q: SubscriptionGetRequest = Depends()) -> list[Subscription]: + return db.get_subscriptions(q.profile_id, q.stripe_subscription_id) + + +@router.post("/", status_code=201) +def insert_subscription(subscription: SubscriptionInsertRequest) -> Subscription: + return db.insert_subscription(subscription) + + +@router.delete("/{profile_id}") +def delete_subscription(profile_id: UUID) -> Subscription: + response = db.delete_subscription(profile_id) + if not response: + raise HTTPException(404, "stripe subscription id not found") + + return response + + +@router.patch("/{profile_id}") +def update_subscription( + profile_id: UUID, content: SubscriptionUpdateRequest +) -> Subscription: + response = db.update_subscription(profile_id, content) + if not response: + raise HTTPException(404, "message not found") + + return response diff --git a/apps/api/src/routers/tiers.py b/apps/api/src/routers/tiers.py new file mode 100644 index 00000000..a6de01ac --- /dev/null +++ b/apps/api/src/routers/tiers.py @@ -0,0 +1,54 @@ +import logging +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException + +from src.dependencies import ( + RateLimitResponse, + rate_limit, + rate_limit_profile, + rate_limit_tiered, +) +from src.interfaces import db +from src.models import ( + Tier, + TierGetRequest, + TierInsertRequest, + TierUpdateRequest, +) + +router = APIRouter(prefix="/tiers", tags=["tiers"]) + +logger = logging.getLogger("root") + + +@router.get("/{id}") +def get_tier(id: UUID) -> Tier: + response = db.get_tier(id) + if not response: + raise HTTPException(404, "tiers information not found") + + return response + + +@router.post("/") +def insert_tier(tier: TierInsertRequest) -> Tier: + return db.insert_tier(tier) + + +@router.delete("/{id}") +def delete_tier(id: UUID) -> Tier: + response = db.delete_tier(id) + if not response: + raise HTTPException(404, "stripe tier id not found") + + return response + + +@router.patch("/{id}") +def update_tier(id: UUID, content: TierUpdateRequest) -> Tier: + response = db.update_tier(id, content) + if not response: + raise HTTPException(404, "message not found") + + return response diff --git a/apps/api/src/routers/tools.py b/apps/api/src/routers/tools.py new file mode 100644 index 00000000..eb430fd4 --- /dev/null +++ b/apps/api/src/routers/tools.py @@ -0,0 +1,62 @@ +import logging +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException + +from src.interfaces import db +from src.models import ( + Agent, + Tool, + ToolGetRequest, + ToolInsertRequest, + ToolUpdateRequest, +) + +router = APIRouter( + prefix="/tools", + tags=["tools"], +) + + +@router.get("/") +def get_tools(q: ToolGetRequest = Depends()) -> list[Tool]: + return db.get_tools(q.name, q.api_key_type_id) + + +@router.get("/{tool_id}") +def get_tool(tool_id: UUID) -> Tool: + response = db.get_tool(tool_id) + if not response: + raise HTTPException(404, "tool not found") + + return response + + +@router.post("/", status_code=201) +def insert_tool(tool: ToolInsertRequest) -> Tool: + return db.insert_tool(tool) + + +@router.delete("/{tool_id}") +def delete_tool(tool_id: UUID) -> Tool: + response = db.delete_tool(tool_id) + if not response: + raise HTTPException(404, "could not find tool") + + return response + + +@router.patch("/{tool_id}") +def update_profile(tool_id: UUID, tool_update_request: ToolUpdateRequest) -> Tool: + if not db.get_tool(tool_id): + raise HTTPException(404, "tool not found") + + return db.update_tool(tool_id, tool_update_request) + + +@router.patch("/{agent_id}") +def add_tool(agent_id: UUID, tool_id: UUID) -> Agent: + if not db.get_agent(agent_id): + raise HTTPException(404, "agent not found") + + return db.update_agent_tool(agent_id, tool_id) diff --git a/apps/api/src/tools/__init__.py b/apps/api/src/tools/__init__.py index b5eab4fb..3762d235 100644 --- a/apps/api/src/tools/__init__.py +++ b/apps/api/src/tools/__init__.py @@ -3,8 +3,8 @@ import os import random from typing import Any -from dotenv import load_dotenv +from dotenv import load_dotenv from langchain_core.tools import BaseTool from src.tools.alpha_vantage import ID as ALPHA_VANTAGE_TOOL_ID @@ -13,20 +13,23 @@ from src.tools.arxiv_tool import ArxivTool from src.tools.bing import ID as BING_SEARCH_TOOL_ID from src.tools.bing import BingTool +from src.tools.brave_search import ID as BRAVE_TOOL_ID +from src.tools.brave_search import BraveSearchTool +from src.tools.duckduckgo_tool import ID as DDGS_TOOL_ID +from src.tools.duckduckgo_tool import DuckDuckGoSearchTool +from src.tools.google_serper import RESULTS_ID as GOOGLE_SERPER_RESULTS_TOOL_ID +from src.tools.google_serper import RUN_ID as GOOGLE_SERPER_RUN_TOOL_ID +from src.tools.google_serper import GoogleSerperResultsTool, GoogleSerperRunTool from src.tools.move_file import ID as MOVE_TOOL_ID from src.tools.move_file import MoveFileTool from src.tools.read_file import ID as READ_TOOL_ID from src.tools.read_file import ReadFileTool from src.tools.scraper import ID as SCRAPER_TOOL_ID from src.tools.scraper import ScraperTool +from src.tools.stackapi_tool import ID as STACKAPI_ID +from src.tools.stackapi_tool import StackAPISearchTool from src.tools.wikipedia_tool import ID as WIKIPEDIA_TOOL_ID from src.tools.wikipedia_tool import WikipediaTool -from src.tools.duckduckgo_tool import ID as DDGS_TOOL_ID -from src.tools.duckduckgo_tool import DuckDuckGoSearchTool -from src.tools.google_serper import RUN_ID as GOOGLE_SERPER_RUN_TOOL_ID -from src.tools.google_serper import GoogleSerperRunTool -from src.tools.google_serper import RESULTS_ID as GOOGLE_SERPER_RESULTS_TOOL_ID -from src.tools.google_serper import GoogleSerperResultsTool tools: dict = { ARXIV_TOOL_ID: ArxivTool, @@ -39,12 +42,15 @@ DDGS_TOOL_ID: DuckDuckGoSearchTool, GOOGLE_SERPER_RUN_TOOL_ID: GoogleSerperRunTool, GOOGLE_SERPER_RESULTS_TOOL_ID: GoogleSerperResultsTool, + BRAVE_TOOL_ID: BraveSearchTool, + STACKAPI_ID: StackAPISearchTool, } logger = logging.getLogger("root") load_dotenv() -def get_file_path_of_example(): + +def get_file_path_of_example() -> str: current_dir = os.getcwd() target_folder = os.path.join(current_dir, "src/tools/test_files") @@ -76,14 +82,14 @@ def get_tool_ids_from_agent(tools: list[dict[str, Any]]) -> list[str]: return [tool["id"] for tool in tools] -def has_param(cls, param_name): +def has_param(cls, param_name) -> bool: init_signature = inspect.signature(cls.__init__) return param_name in init_signature.parameters def generate_tool_from_uuid( tool: str, api_key_types: dict[str, str], api_keys: dict[str, str] -): +) -> BaseTool | None: for tool_id in tools: if tool_id == tool: tool_key_type = "" @@ -92,19 +98,18 @@ def generate_tool_from_uuid( if tool in api_key_types.keys(): # set the api_key_type to the current tools api_key_type (the api_key_types dict has key "tool_id" and value "api_key_type_id") tool_key_type = api_key_types[tool] - - if tool_key_type in api_keys.keys(): - # set current api key that will be given to current tool (the api_keys dict has key "api_key_type_íd" and value "api_key") - api_key = api_keys[tool_key_type] + if tool_key_type in api_keys.keys(): + # set current api key that will be given to current tool (the api_keys dict has key "api_key_type_íd" and value "api_key") + api_key = api_keys[tool_key_type] if has_param(tool_cls, "api_key"): - logger.info(f"has parameter 'api_key'") + logger.info("has parameter 'api_key'") if not api_key: raise TypeError( "api key should not be none when passed to tool that needs api key" ) tool_object = tools[tool_id](api_key=api_key) - logger.info(f"creating tool") + logger.info("creating tool") return tool_object logger.info("making tool without api_key") @@ -118,15 +123,17 @@ def generate_tool_from_uuid( bing_key = os.environ.get("BING_SUBSCRIPTION_KEY") alphavantage_key = os.environ.get("ALPHAVANTAGE_API_KEY") google_search_key = os.environ.get("GOOGLE_SEARCH_API_KEY") + brave_search_key = os.environ.get("BRAVE_API_KEY") print(serpapi_key, bing_key, alphavantage_key, google_search_key) if not all([serpapi_key, bing_key, alphavantage_key, google_search_key]): raise TypeError("a key was not found in env variables") api_keys = { - '3b64fe26-20b9-4064-907e-f2708b5f1656': serpapi_key, - "5281bbc4-45ea-4f4b-b790-e92c62bbc019": bing_key, - "8a29840f-4748-4ce4-88e6-44e1ef5b7637": alphavantage_key, - "4d950712-8b4c-4cc0-a24d-7599638119f2": google_search_key, + "3b64fe26-20b9-4064-907e-f2708b5f1656": serpapi_key, + "5281bbc4-45ea-4f4b-b790-e92c62bbc019": bing_key, + "8a29840f-4748-4ce4-88e6-44e1ef5b7637": alphavantage_key, + "4d950712-8b4c-4cc0-a24d-7599638119f2": google_search_key, + "58dc6249-3a0c-496b-91f3-27cf0054bfb0": brave_search_key, } api_key_types = { "fa4c2568-00d9-4e3c-9ab7-44f76f3a0e3f": "8a29840f-4748-4ce4-88e6-44e1ef5b7637", # alpha vantage @@ -134,6 +141,7 @@ def generate_tool_from_uuid( "71e4ddcc-4475-46f2-9816-894173b1292e": "5281bbc4-45ea-4f4b-b790-e92c62bbc019", # bing search "3e2665a8-6d73-42ee-a64f-50ddcc0621c6": "4d950712-8b4c-4cc0-a24d-7599638119f2", # google search (run) "1046fefb-a540-498f-8b96-7292523559e0": "4d950712-8b4c-4cc0-a24d-7599638119f2", # google search (results) + "3c0d3635-80f4-4286-aab6-c359795e1ac4": "58dc6249-3a0c-496b-91f3-27cf0054bfb0", # brave search } agents_tools = [ "f57d47fd-5783-4aac-be34-17ba36bb6242", # Move File Tool @@ -145,10 +153,12 @@ def generate_tool_from_uuid( "7dc53d81-cdac-4320-8077-1a7ab9497551", # DuckDuckGoSearch Tool "3e2665a8-6d73-42ee-a64f-50ddcc0621c6", # Google Serper Run "1046fefb-a540-498f-8b96-7292523559e0", # Google Serper Results + "3c0d3635-80f4-4286-aab6-c359795e1ac4", # Brave search + "612ddae6-ecdd-4900-9314-1a2c9de6003d", # StackAPI ] generated_tools = [] for tool in agents_tools: - tool = generate_tool_from_uuid(tool, api_key_types, api_keys) # type: ignore + tool = generate_tool_from_uuid(tool, api_key_types, api_keys) # type: ignore if tool is None: print("fail") else: diff --git a/apps/api/src/tools/alpha_vantage.py b/apps/api/src/tools/alpha_vantage.py index 137cd34d..69a7a061 100644 --- a/apps/api/src/tools/alpha_vantage.py +++ b/apps/api/src/tools/alpha_vantage.py @@ -21,7 +21,7 @@ class AlphaVantageToolInput(BaseModel): class AlphaVantageTool(Tool, BaseTool): args_schema: Type[BaseModel] = AlphaVantageToolInput - def __init__(self, api_key): + def __init__(self, api_key: str) -> None: alpha_vantage = AlphaVantageAPIWrapper(alphavantage_api_key=api_key) super().__init__( name="alpha_vantage_tool", diff --git a/apps/api/src/tools/arxiv_tool.py b/apps/api/src/tools/arxiv_tool.py index 1caf2ff3..54bae5c1 100644 --- a/apps/api/src/tools/arxiv_tool.py +++ b/apps/api/src/tools/arxiv_tool.py @@ -18,7 +18,7 @@ class ArxivToolInput(BaseModel): class ArxivTool(Tool, BaseTool): args_schema: Type[BaseModel] = ArxivToolInput - def __init__(self): + def __init__(self) -> None: super().__init__( name="arxiv_tool", func=arxiv.run, @@ -27,6 +27,3 @@ def __init__(self): __all__ = ["ArxivTool"] - - - diff --git a/apps/api/src/tools/bing.py b/apps/api/src/tools/bing.py index a6a80912..97ffe3ae 100644 --- a/apps/api/src/tools/bing.py +++ b/apps/api/src/tools/bing.py @@ -2,25 +2,31 @@ from typing import Type from dotenv import load_dotenv -from langchain_community.tools import BingSearchRun -from langchain_community.utilities import BingSearchAPIWrapper from langchain.agents import Tool from langchain.pydantic_v1 import BaseModel, Field from langchain.tools import BaseTool +from langchain_community.tools import BingSearchRun from langchain_community.tools.bing_search.tool import BingSearchRun +from langchain_community.utilities import BingSearchAPIWrapper from langchain_community.utilities.bing_search import BingSearchAPIWrapper # TODO: Split this tool into 2 different tools, like I did with the Google Serper tool, so a BingSearchRun and a BingSearchResults -BING_SEARCH_URL="https://api.bing.microsoft.com/v7.0/search" +BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search" ID = "71e4ddcc-4475-46f2-9816-894173b1292e" class BingToolInput(BaseModel): - tool_input: str = Field(title="Query", description="Search query input to search bing") + tool_input: str = Field( + title="Query", description="Search query input to search bing" + ) - nr_of_results: int = Field(title="Number of results", description="The amount of returned results from the search", default=5) + nr_of_results: int = Field( + title="Number of results", + description="The amount of returned results from the search", + default=5, + ) class BingTool(Tool, BaseTool): @@ -28,7 +34,7 @@ class BingTool(Tool, BaseTool): api_key: str = "" # needs to be empty string or it throws validation errors - def __init__(self, api_key): + def __init__(self, api_key: str) -> None: super().__init__( name="bing_search", func=self._run, @@ -37,8 +43,8 @@ def __init__(self, api_key): Input should be a search query.""", ) self.api_key = api_key - - def _run(self, tool_input: str, nr_of_results: int = 5): + + def _run(self, tool_input: str, nr_of_results: int = 5) -> str: wrapper = BingSearchAPIWrapper( bing_subscription_key=self.api_key, bing_search_url=BING_SEARCH_URL, diff --git a/apps/api/src/tools/brave_search.py b/apps/api/src/tools/brave_search.py new file mode 100644 index 00000000..0bca14b0 --- /dev/null +++ b/apps/api/src/tools/brave_search.py @@ -0,0 +1,29 @@ +import logging +from typing import Callable, Optional, Type + +from langchain.agents import Tool +from langchain.pydantic_v1 import BaseModel, Field +from langchain.tools import BaseTool +from langchain_community.tools import BraveSearch + +ID = "3c0d3635-80f4-4286-aab6-c359795e1ac4" + +logger = logging.getLogger("root") + + +class BraveSearchToolInput(BaseModel): + tool_input: str = Field( + title="query", description="Search query input to look up on brave" + ) + + +class BraveSearchTool(Tool, BaseTool): + args_schema: Type[BaseModel] = BraveSearchToolInput + + def __init__(self, api_key: str) -> None: + tool = BraveSearch.from_api_key(api_key=api_key, search_kwargs={"count": 3}) + super().__init__( + name="brave_search", + func=tool.run, + description="""search the internet through the search engine brave""", + ) diff --git a/apps/api/src/tools/duckduckgo_tool.py b/apps/api/src/tools/duckduckgo_tool.py index 232bf0cf..dc794abe 100644 --- a/apps/api/src/tools/duckduckgo_tool.py +++ b/apps/api/src/tools/duckduckgo_tool.py @@ -4,13 +4,16 @@ from langchain.agents import Tool from langchain.pydantic_v1 import BaseModel, Field from langchain.tools import BaseTool -from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper from langchain_community.tools import DuckDuckGoSearchRun +from langchain_community.utilities.duckduckgo_search import ( + DuckDuckGoSearchAPIWrapper, +) -ID="7dc53d81-cdac-4320-8077-1a7ab9497551" +ID = "7dc53d81-cdac-4320-8077-1a7ab9497551" logger = logging.getLogger("root") + class DuckDuckGoSearchToolInput(BaseModel): tool_input: str = Field( title="query", description="Search query input to look up on duck duck go" @@ -19,25 +22,29 @@ class DuckDuckGoSearchToolInput(BaseModel): title="region", description="Region to use for the search", default="wt-wt" ) source: str = Field( - title="source", description="Source of information, ex 'text' or 'news'", default="text" + title="source", + description="Source of information, ex 'text' or 'news'", + default="text", ) class DuckDuckGoSearchTool(Tool, BaseTool): args_schema: Type[BaseModel] = DuckDuckGoSearchToolInput - def __init__(self): + def __init__(self) -> None: super().__init__( name="duck_duck_go_search", func=self._run, description="""search the internet through the search engine duck duck go""", ) - - def _run(self, tool_input: str, region: str = "wt-wt", source: str = "text") -> Callable: + + def _run( + self, tool_input: str, region: str = "wt-wt", source: str = "text" + ) -> str: """Method passed to agent so the agent can initialize the wrapper with additional args""" logger.debug("Creating DuckDuckGo wrapper") ddgs_tool = DuckDuckGoSearchRun( wrapper=DuckDuckGoSearchAPIWrapper(region=region, source=source) ) - return ddgs_tool.run(tool_input=tool_input) \ No newline at end of file + return ddgs_tool.run(tool_input=tool_input) diff --git a/apps/api/src/tools/google_serper.py b/apps/api/src/tools/google_serper.py index 37980c71..3319aec4 100644 --- a/apps/api/src/tools/google_serper.py +++ b/apps/api/src/tools/google_serper.py @@ -6,19 +6,30 @@ from langchain.tools import BaseTool from langchain_community.utilities.google_serper import GoogleSerperAPIWrapper -RUN_ID="3e2665a8-6d73-42ee-a64f-50ddcc0621c6" +RUN_ID = "3e2665a8-6d73-42ee-a64f-50ddcc0621c6" -RESULTS_ID="1046fefb-a540-498f-8b96-7292523559e0" +RESULTS_ID = "1046fefb-a540-498f-8b96-7292523559e0" logger = logging.getLogger("root") + class GoogleSerperRunToolInput(BaseModel): - query: str = Field(title="query", description="search query input, looks up on google search") + query: str = Field( + title="query", description="search query input, looks up on google search" + ) + class GoogleSerperResultsToolInput(BaseModel): - query: str = Field(title="query", description="search query input, looks up on google search and returns metadata") + query: str = Field( + title="query", + description="search query input, looks up on google search and returns metadata", + ) - nr_of_results: int = Field(title="number of results", description="number of results shown per page", default=10) + nr_of_results: int = Field( + title="number of results", + description="number of results shown per page", + default=10, + ) region: str = Field( title="region", @@ -27,7 +38,7 @@ class GoogleSerperResultsToolInput(BaseModel): ) language: str = Field( title="language", - description="sets the interface language of the search, given as a two letter code, for example English is 'en' and french is 'fr'", + description="sets the interface language of the search, given as a two letter code, for example English is 'en' and french is 'fr'", default="en", ) search_type: Literal["news", "search", "places", "images"] = Field( @@ -44,8 +55,8 @@ class GoogleSerperResultsToolInput(BaseModel): class GoogleSerperRunTool(Tool, BaseTool): args_schema: Type[BaseModel] = GoogleSerperRunToolInput - - def __init__(self, api_key): + + def __init__(self, api_key: str) -> None: search = GoogleSerperAPIWrapper(serper_api_key=api_key) super().__init__( name="google_serper_run_tool", @@ -53,12 +64,12 @@ def __init__(self, api_key): description="""search the web with serper's google search api""", ) - + class GoogleSerperResultsTool(Tool, BaseTool): args_schema: Type[BaseModel] = GoogleSerperResultsToolInput api_key: str = "" - def __init__(self, api_key): + def __init__(self, api_key: str) -> None: super().__init__( name="google_serper_results_tool", func=self._run, @@ -69,22 +80,21 @@ def __init__(self, api_key): def _run( self, query: str, - nr_of_results: int = 10, - region: str = "us", + nr_of_results: int = 10, + region: str = "us", language: str = "en", - search_type: Literal["news", "search", "places", "images"] = "search", + search_type: Literal["news", "search", "places", "images"] = "search", time_based_search: str | None = None, - ): + ) -> dict: """Method passed to the agent to allow it to pass additional optional parameters, similar to the DDG search tool""" search = GoogleSerperAPIWrapper( serper_api_key=self.api_key, - k=nr_of_results, - gl=region, - hl=language, - type=search_type, - tbs=time_based_search + k=nr_of_results, + gl=region, + hl=language, + type=search_type, + tbs=time_based_search, ) - return search.results(query) - \ No newline at end of file + return search.results(query) diff --git a/apps/api/src/tools/scraper.py b/apps/api/src/tools/scraper.py index eb2bb65a..b78991fa 100644 --- a/apps/api/src/tools/scraper.py +++ b/apps/api/src/tools/scraper.py @@ -12,8 +12,6 @@ ID = "4ac25953-dc41-42d5-b9f2-bcae3b2c1d9f" API_KEY_TYPE = "3b64fe26-20b9-4064-907e-f2708b5f1656" -# key = os.environ.get("SERPAPI_API_KEY") - class ScraperToolInput(BaseModel): query: str = Field( @@ -24,7 +22,7 @@ class ScraperToolInput(BaseModel): class ScraperTool(Tool, BaseTool): args_schema: Type[BaseModel] = ScraperToolInput - def __init__(self, api_key): + def __init__(self, api_key: str) -> None: web_scrape = SerpAPIWrapper(serpapi_api_key=api_key) super().__init__( name="scraper_tool", diff --git a/apps/api/src/tools/stackapi_tool.py b/apps/api/src/tools/stackapi_tool.py new file mode 100644 index 00000000..c5249281 --- /dev/null +++ b/apps/api/src/tools/stackapi_tool.py @@ -0,0 +1,30 @@ +import logging +from typing import Callable, Optional, Type + +from langchain.agents import Tool +from langchain.pydantic_v1 import BaseModel, Field +from langchain.tools import BaseTool +from langchain_community.tools.stackexchange.tool import StackExchangeTool +from langchain_community.utilities import StackExchangeAPIWrapper + +ID = "612ddae6-ecdd-4900-9314-1a2c9de6003d" + +logger = logging.getLogger("root") + + +class StackAPIToolInput(BaseModel): + query: str = Field( + title="query", description="Search query input to look up on Stack Exchange" + ) + + +class StackAPISearchTool(Tool, BaseTool): + args_schema: Type[BaseModel] = StackAPIToolInput + + def __init__(self) -> None: + tool = StackExchangeTool(api_wrapper=StackExchangeAPIWrapper()) + super().__init__( + name="stack_api_tool", + func=tool._run, + description="""StackAPI searches through a network of question-and-answer (Q&A) websites""", + ) diff --git a/apps/api/src/tools/wikipedia_tool.py b/apps/api/src/tools/wikipedia_tool.py index 2b402b10..a88e1014 100644 --- a/apps/api/src/tools/wikipedia_tool.py +++ b/apps/api/src/tools/wikipedia_tool.py @@ -18,7 +18,7 @@ class WikipediaToolInput(BaseModel): class WikipediaTool(Tool, BaseTool): args_schema: Type[BaseModel] = WikipediaToolInput - def __init__(self): + def __init__(self) -> None: wiki_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()) super().__init__( name="wikipedia",