Skip to content

Commit

Permalink
Fix: Validate model name and improve path security for file download …
Browse files Browse the repository at this point in the history
…endpoint
  • Loading branch information
yihong1120 committed Nov 16, 2024
1 parent 5ed7d3a commit 3b6817d
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions examples/YOLO_server_api/model_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fastapi import HTTPException
from fastapi.responses import FileResponse
from fastapi_limiter.depends import RateLimiter
from werkzeug.utils import secure_filename

models_router = APIRouter()
MODELS_DIRECTORY = Path('models/pt/')
Expand Down Expand Up @@ -38,10 +39,17 @@ async def download_model(model_name: str):
if model_name not in ALLOWED_MODELS:
raise HTTPException(status_code=404, detail='Model not found')

# Ensure the model name is sanitized
sanitized_model_name = secure_filename(model_name)
if sanitized_model_name != model_name:
raise HTTPException(status_code=400, detail='Invalid model name')

try:
MODEL_URL = (
f"http://changdar-server.mooo.com:28000/models/{model_name}"
f"http://changdar-server.mooo.com:28000/"
f"models/{sanitized_model_name}"
)

# Asynchronously request headers information
async with httpx.AsyncClient() as client:
response = await client.head(MODEL_URL)
Expand All @@ -51,13 +59,15 @@ async def download_model(model_name: str):
response.headers['Last-Modified'],
'%a, %d %b %Y %H:%M:%S GMT',
)
local_file_path = MODELS_DIRECTORY / model_name
local_file_path = MODELS_DIRECTORY / sanitized_model_name
try:
# Resolve the local file path and ensure it is within the
# models directory
local_file_path = local_file_path.resolve().relative_to(
MODELS_DIRECTORY.resolve(),
)
resolved_path = local_file_path.resolve()
if not resolved_path.parent == MODELS_DIRECTORY.resolve():
raise HTTPException(
status_code=400, detail='Invalid model name',
)
except ValueError:
raise HTTPException(
status_code=400, detail='Invalid model name',
Expand All @@ -74,9 +84,11 @@ async def download_model(model_name: str):
# Return file response
return FileResponse(
local_file_path,
filename=model_name,
filename=sanitized_model_name,
headers={
'Content-Disposition': f"attachment; filename={model_name}",
'Content-Disposition': (
f"attachment; filename={sanitized_model_name}"
),
},
)

Expand Down

0 comments on commit 3b6817d

Please sign in to comment.