Skip to content

Commit

Permalink
reformat with pycharm, addresses issue harmonydata#38
Browse files Browse the repository at this point in the history
  • Loading branch information
evelynnchen-cmu committed Dec 28, 2024
1 parent 91850ae commit 76647c0
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 76 deletions.
9 changes: 8 additions & 1 deletion src/harmony/schemas/errors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,44 +25,51 @@
'''


class BaseHarmonyError(Exception):
def __init__(self, message: str = None):
self.status_code = 500
self.detail = message or "Something went wrong"
super().__init__(self.detail)


class BadRequestError(BaseHarmonyError):
def __init__(self, message: str = None):
self.status_code = 400
self.detail = message or "Bad request data"
super(Exception, self).__init__(self.detail)


class SomethingWrongError(BaseHarmonyError):
def __init__(self, message: str = None):
self.status_code = 500
self.detail = message or "Something went wrong"
super(Exception, self).__init__(self.detail)


class UnauthorizedError(BaseHarmonyError):
def __init__(self, message: str = None):
self.status_code = 401
self.detail = message or "Unauthorized"
super(Exception, self).__init__(self.detail)


class ForbiddenError(BaseHarmonyError):
def __init__(self, message: str = None):
self.status_code = 403
self.detail = message or "Forbidden"
super(Exception, self).__init__(self.detail)


class ConflictError(BaseHarmonyError):
def __init__(self, message: str = None):
self.status_code = 409
self.detail = message or "Conflict"
super(Exception, self).__init__(self.detail)


class ResourceNotFoundError(BaseHarmonyError):
def __init__(self, message: str = None):
self.status_code = 404
self.detail = message or "Resource not found"
super(Exception, self).__init__(self.detail)
super(Exception, self).__init__(self.detail)
13 changes: 6 additions & 7 deletions src/harmony/schemas/requests/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@
'''

from typing import Any, Dict, List, Optional

from pydantic import ConfigDict, BaseModel, Field

from harmony.schemas.catalogue_instrument import CatalogueInstrument
from harmony.schemas.catalogue_question import CatalogueQuestion
from harmony.schemas.enums.file_types import FileType
from harmony.schemas.enums.languages import Language
from pydantic import ConfigDict, BaseModel, Field
from typing import Any, Dict, List, Optional

DEFAULT_FRAMEWORK = "huggingface"
DEFAULT_MODEL = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
Expand Down Expand Up @@ -66,7 +64,8 @@ class Question(BaseModel):
instrument_id: Optional[str] = Field(None, description="Unique identifier for the instrument (UUID-4)")
instrument_name: Optional[str] = Field(None, description="Human readable name for the instrument")
topics_auto: Optional[list] = Field(None, description="Automated list of topics identified by model")
topics_strengths: Optional[dict] = Field(None, description="Automated list of topics identified by model with strength of topic")
topics_strengths: Optional[dict] = Field(None,
description="Automated list of topics identified by model with strength of topic")
nearest_match_from_mhc_auto: Optional[dict] = Field(None, description="Automatically identified nearest MHC match")
closest_catalogue_question_match: Optional[CatalogueQuestion] = Field(
None, description="The closest question match in the catalogue for the question"
Expand Down Expand Up @@ -96,7 +95,7 @@ class Instrument(BaseModel):
study: Optional[str] = Field(None, description="The study")
sweep: Optional[str] = Field(None, description="The sweep")
metadata: Optional[dict] = Field(None,
description="Optional metadata about the instrument (URL, citation, DOI, copyright holder)")
description="Optional metadata about the instrument (URL, citation, DOI, copyright holder)")
language: Language = Field(Language.English,
description="The ISO 639-2 (alpha-2) encoding of the instrument language")
questions: List[Question] = Field(description="The items inside the instrument")
Expand Down Expand Up @@ -125,7 +124,7 @@ class Instrument(BaseModel):
"source_page": 0
}]
}
})
})


class MatchParameters(BaseModel):
Expand Down
79 changes: 40 additions & 39 deletions src/harmony/util/url_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,26 @@
'''

import urllib.parse
import base64
import uuid
import ssl
import hashlib
from typing import List, Dict
import requests
import ssl
import urllib.parse
import uuid
from datetime import datetime, timedelta
from harmony.parsing.wrapper_all_parsers import convert_files_to_instruments
from harmony.schemas.errors.base import BadRequestError, ForbiddenError, ConflictError, SomethingWrongError
from harmony.schemas.requests.text import RawFile, Instrument, FileType
from pathlib import Path

import requests
from requests.adapters import HTTPAdapter
from typing import List, Dict

from harmony.schemas.requests.text import RawFile, Instrument, FileType
from harmony.schemas.errors.base import BadRequestError, ForbiddenError, ConflictError, SomethingWrongError
from harmony.parsing.wrapper_all_parsers import convert_files_to_instruments

MAX_FILE_SIZE = 50 * 1024 * 1024 #50MB
DOWNLOAD_TIMEOUT = 30 #seconds
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
DOWNLOAD_TIMEOUT = 30 # seconds
MAX_REDIRECTS = 5
ALLOWED_SCHEMES = {'https'}
RATE_LIMIT_REQUESTS = 60 #requests per min
RATE_LIMIT_WINDOW = 60 #seconds
RATE_LIMIT_REQUESTS = 60 # requests per min
RATE_LIMIT_WINDOW = 60 # seconds

MIME_TO_FILE_TYPE = {
'application/pdf': FileType.pdf,
Expand All @@ -64,6 +62,7 @@
'.docx': FileType.docx
}


class URLDownloader:
def __init__(self):
self.rate_limit_storage: Dict[str, List[datetime]] = {}
Expand All @@ -75,70 +74,70 @@ def _check_rate_limit(self, domain: str) -> None:
now = datetime.now()
if domain not in self.rate_limit_storage:
self.rate_limit_storage[domain] = []

self.rate_limit_storage[domain] = [
ts for ts in self.rate_limit_storage[domain]
if ts > now - timedelta(seconds=RATE_LIMIT_WINDOW)
]

if len(self.rate_limit_storage[domain]) >= RATE_LIMIT_REQUESTS:
raise ConflictError("Rate limit exceeded")

self.rate_limit_storage[domain].append(now)

def _validate_url(self, url: str) -> None:
try:
parsed = urllib.parse.urlparse(url)

if parsed.scheme not in ALLOWED_SCHEMES:
raise BadRequestError(f"URL must use HTTPS")

if not parsed.netloc or '.' not in parsed.netloc:
raise BadRequestError("Invalid domain")

if '..' in parsed.path or '//' in parsed.path:
raise ForbiddenError("Path traversal detected")

if parsed.fragment:
raise BadRequestError("URL fragments not supported")

blocked_domains = {'localhost', '127.0.0.1', '0.0.0.0'}
if parsed.netloc in blocked_domains:
raise ForbiddenError("Access to internal domains blocked")

except Exception as e:
raise BadRequestError(f"Invalid URL: {str(e)}")

def _validate_ssl(self, response: requests.Response) -> None:
cert = response.raw.connection.sock.getpeercert()
if not cert:
raise ForbiddenError("Invalid SSL certificate")

not_after = ssl.cert_time_to_seconds(cert['notAfter'])
if datetime.fromtimestamp(not_after) < datetime.now():
raise ForbiddenError("Expired SSL certificate")

def _check_legal_headers(self, response: requests.Response) -> None:
if response.headers.get('X-Robots-Tag', '').lower() == 'noindex':
raise ForbiddenError("Access not allowed by robots directive")

if 'X-Copyright' in response.headers:
raise ForbiddenError("Content is copyright protected")

if 'X-Terms-Of-Service' in response.headers:
raise ForbiddenError("Terms of service acceptance required")

def _validate_content_type(self, url: str, content_type: str) -> FileType:
try:
content_type = content_type.split(';')[0].lower()

if content_type in MIME_TO_FILE_TYPE:
return MIME_TO_FILE_TYPE[content_type]

ext = Path(urllib.parse.urlparse(url).path).suffix.lower()
if ext in EXT_TO_FILE_TYPE:
return EXT_TO_FILE_TYPE[ext]

raise BadRequestError(f"Unsupported file type: {content_type}")
except BadRequestError:
raise
Expand All @@ -150,7 +149,7 @@ def download(self, url: str) -> RawFile:
self._validate_url(url)
domain = urllib.parse.urlparse(url).netloc
self._check_rate_limit(domain)

response = self.session.get(
url,
timeout=DOWNLOAD_TIMEOUT,
Expand All @@ -163,27 +162,28 @@ def download(self, url: str) -> RawFile:
}
)
response.raise_for_status()

self._validate_ssl(response)
self._check_legal_headers(response)

content_length = response.headers.get('content-length')
if content_length and int(content_length) > MAX_FILE_SIZE:
raise ForbiddenError(f"File too large: {content_length} bytes (max {MAX_FILE_SIZE})")

file_type = self._validate_content_type(url, response.headers.get('content-type', ''))

hasher = hashlib.sha256()
content = b''
for chunk in response.iter_content(chunk_size=8192):
hasher.update(chunk)
content += chunk

if file_type in [FileType.pdf, FileType.xlsx, FileType.docx]:
content_str = f"data:{response.headers['content-type']};base64," + base64.b64encode(content).decode('ascii')
content_str = f"data:{response.headers['content-type']};base64," + base64.b64encode(content).decode(
'ascii')
else:
content_str = content.decode('utf-8')

return RawFile(
file_id=str(uuid.uuid4()),
file_name=Path(urllib.parse.urlparse(url).path).name or "downloaded_file",
Expand All @@ -195,7 +195,7 @@ def download(self, url: str) -> RawFile:
'source_url': url
}
)

except (BadRequestError, ForbiddenError, ConflictError):
raise
except requests.Timeout:
Expand All @@ -214,7 +214,8 @@ def download(self, url: str) -> RawFile:
except Exception as e:
raise SomethingWrongError(f"Unexpected error: {str(e)}")


def load_instruments_from_url(url: str) -> List[Instrument]:
downloader = URLDownloader()
raw_file = downloader.download(url)
return convert_files_to_instruments([raw_file])
return convert_files_to_instruments([raw_file])
Loading

0 comments on commit 76647c0

Please sign in to comment.