Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez authored Dec 12, 2024
1 parent 770b4a1 commit a564fd2
Showing 1 changed file with 240 additions and 48 deletions.
288 changes: 240 additions & 48 deletions api/main.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,34 @@
import os
import secrets
import traceback
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import UUID, uuid4

import uvicorn
from dotenv import load_dotenv
from fastapi import (
BackgroundTasks,
Depends,
FastAPI,
Header,
HTTPException,
status,
Query,
BackgroundTasks,
Request,
status,
)
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any, List
from loguru import logger
import uvicorn
from datetime import datetime, timedelta
from uuid import UUID, uuid4
from enum import Enum
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
import traceback
from pydantic import BaseModel, Field

from swarms import Agent
from dotenv import load_dotenv

print ("starting")
# Load environment variables
load_dotenv()

# Configure Loguru
logger.add(
"logs/api_{time}.log",
rotation="500 MB",
retention="10 days",
level="INFO",
format="{time} {level} {message}",
backtrace=True,
diagnose=True,
)


class AgentStatus(str, Enum):
"""Enum for agent status."""
Expand All @@ -44,6 +37,28 @@ class AgentStatus(str, Enum):
PROCESSING = "processing"
ERROR = "error"
MAINTENANCE = "maintenance"


# Security configurations
API_KEY_LENGTH = 32 # Length of generated API keys

class APIKey(BaseModel):
key: str
name: str
created_at: datetime
last_used: datetime
is_active: bool = True

class APIKeyCreate(BaseModel):
name: str # A friendly name for the API key

class User(BaseModel):
id: UUID
username: str
is_active: bool = True
is_admin: bool = False
api_keys: Dict[str, APIKey] = {} # key -> APIKey object



class AgentConfig(BaseModel):
Expand Down Expand Up @@ -105,6 +120,7 @@ class AgentConfig(BaseModel):
)



class AgentUpdate(BaseModel):
"""Model for updating agent configuration."""

Expand Down Expand Up @@ -173,15 +189,66 @@ class AgentStore:
def __init__(self):
self.agents: Dict[UUID, Agent] = {}
self.agent_metadata: Dict[UUID, Dict[str, Any]] = {}
self.users: Dict[UUID, User] = {} # user_id -> User
self.api_keys: Dict[str, UUID] = {} # api_key -> user_id
self.user_agents: Dict[UUID, List[UUID]] = {} # user_id -> [agent_ids]
self.executor = ThreadPoolExecutor(max_workers=4)
self._ensure_directories()

def _ensure_directories(self):
"""Ensure required directories exist."""
Path("logs").mkdir(exist_ok=True)
Path("states").mkdir(exist_ok=True)

def create_api_key(self, user_id: UUID, key_name: str) -> APIKey:
"""Create a new API key for a user."""
if user_id not in self.users:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found"
)

async def create_agent(self, config: AgentConfig) -> UUID:
# Generate a secure random API key
api_key = secrets.token_urlsafe(API_KEY_LENGTH)

# Create the API key object
key_object = APIKey(
key=api_key,
name=key_name,
created_at=datetime.utcnow(),
last_used=datetime.utcnow()
)

# Store the API key
self.users[user_id].api_keys[api_key] = key_object
self.api_keys[api_key] = user_id

return key_object

async def verify_agent_access(self, agent_id: UUID, user_id: UUID) -> bool:
"""Verify if a user has access to an agent."""
if agent_id not in self.agents:
return False
return (
self.agent_metadata[agent_id]["owner_id"] == user_id
or self.users[user_id].is_admin
)

def validate_api_key(self, api_key: str) -> Optional[UUID]:
"""Validate an API key and return the associated user ID."""
user_id = self.api_keys.get(api_key)
if not user_id or api_key not in self.users[user_id].api_keys:
return None

key_object = self.users[user_id].api_keys[api_key]
if not key_object.is_active:
return None

# Update last used timestamp
key_object.last_used = datetime.utcnow()
return user_id

async def create_agent(self, config: AgentConfig, user_id: UUID) -> UUID:
"""Create a new agent with the given configuration."""
try:

Expand Down Expand Up @@ -220,7 +287,11 @@ async def create_agent(self, config: AgentConfig) -> UUID:
"successful_completions": 0,
}

logger.info(f"Created agent with ID: {agent_id}")
# Add to user's agents list
if user_id not in self.user_agents:
self.user_agents[user_id] = []
self.user_agents[user_id].append(agent_id)

return agent_id

except Exception as e:
Expand Down Expand Up @@ -465,6 +536,35 @@ async def process_completion(
finally:
metadata["status"] = AgentStatus.IDLE

class StoreManager:
_instance = None

@classmethod
def get_instance(cls) -> 'AgentStore':
if cls._instance is None:
cls._instance = AgentStore()
return cls._instance

# Modify the dependency function
def get_store() -> AgentStore:
"""Dependency to get the AgentStore instance."""
return StoreManager.get_instance()

# Security utility function using the new dependency
async def get_current_user(
api_key: str = Header(..., description="API key for authentication"),
store: AgentStore = Depends(get_store)
) -> User:
"""Validate API key and return current user."""
user_id = store.validate_api_key(api_key)
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired API key",
headers={"WWW-Authenticate": "ApiKey"},
)
return store.users[user_id]


class SwarmsAPI:
"""Enhanced API class for Swarms agent integration."""
Expand All @@ -477,7 +577,9 @@ def __init__(self):
docs_url="/v1/docs",
redoc_url="/v1/redoc",
)
self.store = AgentStore()
# Initialize the store using the singleton manager
self.store = StoreManager.get_instance()

# Configure CORS
self.app.add_middleware(
CORSMiddleware,
Expand All @@ -493,11 +595,102 @@ def __init__(self):

def _setup_routes(self):
"""Set up API routes."""

# In your API code
@self.app.post("/v1/users", response_model=Dict[str, Any])
async def create_user(request: Request):
"""Create a new user and initial API key."""
try:
body = await request.json()
username = body.get("username")
if not username or len(username) < 3:
raise HTTPException(status_code=400, detail="Invalid username")

user_id = uuid4()
user = User(id=user_id, username=username)
self.store.users[user_id] = user
initial_key = self.store.create_api_key(user_id, "Initial Key")
return {"user_id": user_id, "api_key": initial_key.key}
except Exception as e:
logger.error(f"Error creating user: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))



@self.app.post("/v1/users/{user_id}/api-keys", response_model=APIKey)
async def create_api_key(
user_id: UUID,
key_create: APIKeyCreate,
current_user: User = Depends(get_current_user)
):
"""Create a new API key for a user."""
if current_user.id != user_id and not current_user.is_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to create API keys for this user"
)

return self.store.create_api_key(user_id, key_create.name)

@self.app.get("/v1/users/{user_id}/api-keys", response_model=List[APIKey])
async def list_api_keys(
user_id: UUID,
current_user: User = Depends(get_current_user)
):
"""List all API keys for a user."""
if current_user.id != user_id and not current_user.is_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to view API keys for this user"
)

return list(self.store.users[user_id].api_keys.values())

@self.app.delete("/v1/users/{user_id}/api-keys/{key}")
async def revoke_api_key(
user_id: UUID,
key: str,
current_user: User = Depends(get_current_user)
):
"""Revoke an API key."""
if current_user.id != user_id and not current_user.is_admin:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to revoke API keys for this user"
)

if key in self.store.users[user_id].api_keys:
self.store.users[user_id].api_keys[key].is_active = False
del self.store.api_keys[key]
return {"status": "API key revoked"}

raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="API key not found"
)

@self.app.get("/v1/users/me/agents", response_model=List[AgentSummary])
async def list_user_agents(
current_user: User = Depends(get_current_user),
tags: Optional[List[str]] = Query(None),
status: Optional[AgentStatus] = None,
):
"""List all agents owned by the current user."""
user_agents = self.store.user_agents.get(current_user.id, [])
return [
agent for agent in await self.store.list_agents(tags, status)
if agent.agent_id in user_agents
]


# Modify existing routes to use API key authentication
@self.app.post("/v1/agent", response_model=Dict[str, UUID])
async def create_agent(config: AgentConfig):
async def create_agent(
config: AgentConfig,
current_user: User = Depends(get_current_user)
):
"""Create a new agent with the specified configuration."""
agent_id = await self.store.create_agent(config)
agent_id = await self.store.create_agent(config, current_user.id)
return {"agent_id": agent_id}

@self.app.get("/v1/agents", response_model=List[AgentSummary])
Expand Down Expand Up @@ -611,28 +804,27 @@ async def _cleanup_old_metrics(self, agent_id: UUID):
if k > cutoff
}


def create_app() -> FastAPI:
"""Create and configure the FastAPI application."""
print("create app")
logger.info("Creating FastAPI application")
api = SwarmsAPI()
return api.app
app = api.app
logger.info("FastAPI application created successfully")
return app

app = create_app()

#if __name__ == "__main__":
if __name__ == '__main__':
#freeze_support()
print("yes in main")
# Configure uvicorn logging
logger.info("API Starting")

uvicorn.run(
"main:create_app",
host="0.0.0.0",
port=8000,
# reload=True,
# workers=4,
)
else:
print("not in main")

try:
logger.info("Starting API server...")
print("Starting API server on http://0.0.0.0:8000")

uvicorn.run(
app, # Pass the app instance directly
host="0.0.0.0",
port=8000,
log_level="info"
)
except Exception as e:
logger.error(f"Failed to start API: {str(e)}")
print(f"Error starting server: {str(e)}")

0 comments on commit a564fd2

Please sign in to comment.