From a564fd27e4c7f8fc3f7947009974fbe5e7be0c21 Mon Sep 17 00:00:00 2001 From: Kye Gomez <98760976+kyegomez@users.noreply.github.com> Date: Thu, 12 Dec 2024 11:46:44 -0800 Subject: [PATCH] Update main.py --- api/main.py | 288 +++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 240 insertions(+), 48 deletions(-) diff --git a/api/main.py b/api/main.py index 768e8d962..cfc5e1b2f 100644 --- a/api/main.py +++ b/api/main.py @@ -1,41 +1,34 @@ import os +import secrets +import traceback +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timedelta +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional +from uuid import UUID, uuid4 + +import uvicorn +from dotenv import load_dotenv from fastapi import ( + BackgroundTasks, + Depends, FastAPI, + Header, HTTPException, - status, Query, - BackgroundTasks, + Request, + status, ) from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel, Field -from typing import Optional, Dict, Any, List from loguru import logger -import uvicorn -from datetime import datetime, timedelta -from uuid import UUID, uuid4 -from enum import Enum -from pathlib import Path -from concurrent.futures import ThreadPoolExecutor -import traceback +from pydantic import BaseModel, Field from swarms import Agent -from dotenv import load_dotenv -print ("starting") # Load environment variables load_dotenv() -# Configure Loguru -logger.add( - "logs/api_{time}.log", - rotation="500 MB", - retention="10 days", - level="INFO", - format="{time} {level} {message}", - backtrace=True, - diagnose=True, -) - class AgentStatus(str, Enum): """Enum for agent status.""" @@ -44,6 +37,28 @@ class AgentStatus(str, Enum): PROCESSING = "processing" ERROR = "error" MAINTENANCE = "maintenance" + + +# Security configurations +API_KEY_LENGTH = 32 # Length of generated API keys + +class APIKey(BaseModel): + key: str + name: str + created_at: datetime + last_used: datetime + is_active: bool = True + +class APIKeyCreate(BaseModel): + name: str # A friendly name for the API key + +class User(BaseModel): + id: UUID + username: str + is_active: bool = True + is_admin: bool = False + api_keys: Dict[str, APIKey] = {} # key -> APIKey object + class AgentConfig(BaseModel): @@ -105,6 +120,7 @@ class AgentConfig(BaseModel): ) + class AgentUpdate(BaseModel): """Model for updating agent configuration.""" @@ -173,6 +189,9 @@ class AgentStore: def __init__(self): self.agents: Dict[UUID, Agent] = {} self.agent_metadata: Dict[UUID, Dict[str, Any]] = {} + self.users: Dict[UUID, User] = {} # user_id -> User + self.api_keys: Dict[str, UUID] = {} # api_key -> user_id + self.user_agents: Dict[UUID, List[UUID]] = {} # user_id -> [agent_ids] self.executor = ThreadPoolExecutor(max_workers=4) self._ensure_directories() @@ -180,8 +199,56 @@ def _ensure_directories(self): """Ensure required directories exist.""" Path("logs").mkdir(exist_ok=True) Path("states").mkdir(exist_ok=True) + + def create_api_key(self, user_id: UUID, key_name: str) -> APIKey: + """Create a new API key for a user.""" + if user_id not in self.users: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found" + ) - async def create_agent(self, config: AgentConfig) -> UUID: + # Generate a secure random API key + api_key = secrets.token_urlsafe(API_KEY_LENGTH) + + # Create the API key object + key_object = APIKey( + key=api_key, + name=key_name, + created_at=datetime.utcnow(), + last_used=datetime.utcnow() + ) + + # Store the API key + self.users[user_id].api_keys[api_key] = key_object + self.api_keys[api_key] = user_id + + return key_object + + async def verify_agent_access(self, agent_id: UUID, user_id: UUID) -> bool: + """Verify if a user has access to an agent.""" + if agent_id not in self.agents: + return False + return ( + self.agent_metadata[agent_id]["owner_id"] == user_id + or self.users[user_id].is_admin + ) + + def validate_api_key(self, api_key: str) -> Optional[UUID]: + """Validate an API key and return the associated user ID.""" + user_id = self.api_keys.get(api_key) + if not user_id or api_key not in self.users[user_id].api_keys: + return None + + key_object = self.users[user_id].api_keys[api_key] + if not key_object.is_active: + return None + + # Update last used timestamp + key_object.last_used = datetime.utcnow() + return user_id + + async def create_agent(self, config: AgentConfig, user_id: UUID) -> UUID: """Create a new agent with the given configuration.""" try: @@ -220,7 +287,11 @@ async def create_agent(self, config: AgentConfig) -> UUID: "successful_completions": 0, } - logger.info(f"Created agent with ID: {agent_id}") + # Add to user's agents list + if user_id not in self.user_agents: + self.user_agents[user_id] = [] + self.user_agents[user_id].append(agent_id) + return agent_id except Exception as e: @@ -465,6 +536,35 @@ async def process_completion( finally: metadata["status"] = AgentStatus.IDLE +class StoreManager: + _instance = None + + @classmethod + def get_instance(cls) -> 'AgentStore': + if cls._instance is None: + cls._instance = AgentStore() + return cls._instance + +# Modify the dependency function +def get_store() -> AgentStore: + """Dependency to get the AgentStore instance.""" + return StoreManager.get_instance() + +# Security utility function using the new dependency +async def get_current_user( + api_key: str = Header(..., description="API key for authentication"), + store: AgentStore = Depends(get_store) +) -> User: + """Validate API key and return current user.""" + user_id = store.validate_api_key(api_key) + if not user_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired API key", + headers={"WWW-Authenticate": "ApiKey"}, + ) + return store.users[user_id] + class SwarmsAPI: """Enhanced API class for Swarms agent integration.""" @@ -477,7 +577,9 @@ def __init__(self): docs_url="/v1/docs", redoc_url="/v1/redoc", ) - self.store = AgentStore() + # Initialize the store using the singleton manager + self.store = StoreManager.get_instance() + # Configure CORS self.app.add_middleware( CORSMiddleware, @@ -493,11 +595,102 @@ def __init__(self): def _setup_routes(self): """Set up API routes.""" + + # In your API code + @self.app.post("/v1/users", response_model=Dict[str, Any]) + async def create_user(request: Request): + """Create a new user and initial API key.""" + try: + body = await request.json() + username = body.get("username") + if not username or len(username) < 3: + raise HTTPException(status_code=400, detail="Invalid username") + + user_id = uuid4() + user = User(id=user_id, username=username) + self.store.users[user_id] = user + initial_key = self.store.create_api_key(user_id, "Initial Key") + return {"user_id": user_id, "api_key": initial_key.key} + except Exception as e: + logger.error(f"Error creating user: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + + + + @self.app.post("/v1/users/{user_id}/api-keys", response_model=APIKey) + async def create_api_key( + user_id: UUID, + key_create: APIKeyCreate, + current_user: User = Depends(get_current_user) + ): + """Create a new API key for a user.""" + if current_user.id != user_id and not current_user.is_admin: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to create API keys for this user" + ) + + return self.store.create_api_key(user_id, key_create.name) + @self.app.get("/v1/users/{user_id}/api-keys", response_model=List[APIKey]) + async def list_api_keys( + user_id: UUID, + current_user: User = Depends(get_current_user) + ): + """List all API keys for a user.""" + if current_user.id != user_id and not current_user.is_admin: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to view API keys for this user" + ) + + return list(self.store.users[user_id].api_keys.values()) + + @self.app.delete("/v1/users/{user_id}/api-keys/{key}") + async def revoke_api_key( + user_id: UUID, + key: str, + current_user: User = Depends(get_current_user) + ): + """Revoke an API key.""" + if current_user.id != user_id and not current_user.is_admin: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to revoke API keys for this user" + ) + + if key in self.store.users[user_id].api_keys: + self.store.users[user_id].api_keys[key].is_active = False + del self.store.api_keys[key] + return {"status": "API key revoked"} + + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="API key not found" + ) + + @self.app.get("/v1/users/me/agents", response_model=List[AgentSummary]) + async def list_user_agents( + current_user: User = Depends(get_current_user), + tags: Optional[List[str]] = Query(None), + status: Optional[AgentStatus] = None, + ): + """List all agents owned by the current user.""" + user_agents = self.store.user_agents.get(current_user.id, []) + return [ + agent for agent in await self.store.list_agents(tags, status) + if agent.agent_id in user_agents + ] + + + # Modify existing routes to use API key authentication @self.app.post("/v1/agent", response_model=Dict[str, UUID]) - async def create_agent(config: AgentConfig): + async def create_agent( + config: AgentConfig, + current_user: User = Depends(get_current_user) + ): """Create a new agent with the specified configuration.""" - agent_id = await self.store.create_agent(config) + agent_id = await self.store.create_agent(config, current_user.id) return {"agent_id": agent_id} @self.app.get("/v1/agents", response_model=List[AgentSummary]) @@ -611,28 +804,27 @@ async def _cleanup_old_metrics(self, agent_id: UUID): if k > cutoff } - def create_app() -> FastAPI: """Create and configure the FastAPI application.""" - print("create app") + logger.info("Creating FastAPI application") api = SwarmsAPI() - return api.app + app = api.app + logger.info("FastAPI application created successfully") + return app +app = create_app() -#if __name__ == "__main__": if __name__ == '__main__': - #freeze_support() - print("yes in main") - # Configure uvicorn logging - logger.info("API Starting") - - uvicorn.run( - "main:create_app", - host="0.0.0.0", - port=8000, - # reload=True, - # workers=4, - ) -else: - print("not in main") - + try: + logger.info("Starting API server...") + print("Starting API server on http://0.0.0.0:8000") + + uvicorn.run( + app, # Pass the app instance directly + host="0.0.0.0", + port=8000, + log_level="info" + ) + except Exception as e: + logger.error(f"Failed to start API: {str(e)}") + print(f"Error starting server: {str(e)}")