From 91850ae25961cb2b9a3e4d5cec300dfaa23e03af Mon Sep 17 00:00:00 2001 From: Evelynn Chen Date: Fri, 27 Dec 2024 22:00:58 -0800 Subject: [PATCH 1/2] added metadata to RawFile, added url_loader.py and tests --- src/harmony/schemas/errors/base.py | 72 +++++---- src/harmony/schemas/requests/text.py | 3 +- src/harmony/util/url_loader.py | 220 +++++++++++++++++++++++++ tests/test_url_loader.py | 230 +++++++++++++++++++++++++++ 4 files changed, 493 insertions(+), 32 deletions(-) create mode 100644 src/harmony/util/url_loader.py create mode 100644 tests/test_url_loader.py diff --git a/src/harmony/schemas/errors/base.py b/src/harmony/schemas/errors/base.py index a7fb38a..4fa8556 100644 --- a/src/harmony/schemas/errors/base.py +++ b/src/harmony/schemas/errors/base.py @@ -25,34 +25,44 @@ ''' -from pydantic import BaseModel - - -class BadRequestError(BaseModel): - status_code = 400 - detail = "Bad request data" - - -class SomethingWrongError(BaseModel): - status_code = 500 - detail = "Something went wrong" - - -class UnauthorizedError(BaseModel): - status_code = 401 - message = "Unauthorized" - - -class ForbiddenError(BaseModel): - status_code = 403 - message = "Forbidden" - - -class ConflictError(BaseModel): - status_code = 409 - message = "Conflict" - - -class ResourceNotFoundError(BaseModel): - status_code = 404 - message = "Resource not found" +class BaseHarmonyError(Exception): + def __init__(self, message: str = None): + self.status_code = 500 + self.detail = message or "Something went wrong" + super().__init__(self.detail) + +class BadRequestError(BaseHarmonyError): + def __init__(self, message: str = None): + self.status_code = 400 + self.detail = message or "Bad request data" + super(Exception, self).__init__(self.detail) + +class SomethingWrongError(BaseHarmonyError): + def __init__(self, message: str = None): + self.status_code = 500 + self.detail = message or "Something went wrong" + super(Exception, self).__init__(self.detail) + +class UnauthorizedError(BaseHarmonyError): + def __init__(self, message: str = None): + self.status_code = 401 + self.detail = message or "Unauthorized" + super(Exception, self).__init__(self.detail) + +class ForbiddenError(BaseHarmonyError): + def __init__(self, message: str = None): + self.status_code = 403 + self.detail = message or "Forbidden" + super(Exception, self).__init__(self.detail) + +class ConflictError(BaseHarmonyError): + def __init__(self, message: str = None): + self.status_code = 409 + self.detail = message or "Conflict" + super(Exception, self).__init__(self.detail) + +class ResourceNotFoundError(BaseHarmonyError): + def __init__(self, message: str = None): + self.status_code = 404 + self.detail = message or "Resource not found" + super(Exception, self).__init__(self.detail) \ No newline at end of file diff --git a/src/harmony/schemas/requests/text.py b/src/harmony/schemas/requests/text.py index c22b84d..83fe39f 100644 --- a/src/harmony/schemas/requests/text.py +++ b/src/harmony/schemas/requests/text.py @@ -25,7 +25,7 @@ ''' -from typing import List, Optional +from typing import Any, Dict, List, Optional from pydantic import ConfigDict, BaseModel, Field @@ -45,6 +45,7 @@ class RawFile(BaseModel): content: str = Field(description="The raw file contents") text_content: Optional[str] = Field(None, description="The plain text content") tables: list = Field([], description="The tables in the file") + metadata: Optional[Dict[str, Any]] = Field(default=None, description="Optional metadata about the file") model_config = ConfigDict( json_schema_extra={ "example": { diff --git a/src/harmony/util/url_loader.py b/src/harmony/util/url_loader.py new file mode 100644 index 0000000..5215d4a --- /dev/null +++ b/src/harmony/util/url_loader.py @@ -0,0 +1,220 @@ +''' +MIT License + +Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk). +Project: Harmony (https://harmonydata.ac.uk) +Maintainer: Thomas Wood (https://fastdatascience.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' + +import urllib.parse +import base64 +import uuid +import ssl +import hashlib +from typing import List, Dict +from datetime import datetime, timedelta +from pathlib import Path + +import requests +from requests.adapters import HTTPAdapter + +from harmony.schemas.requests.text import RawFile, Instrument, FileType +from harmony.schemas.errors.base import BadRequestError, ForbiddenError, ConflictError, SomethingWrongError +from harmony.parsing.wrapper_all_parsers import convert_files_to_instruments + +MAX_FILE_SIZE = 50 * 1024 * 1024 #50MB +DOWNLOAD_TIMEOUT = 30 #seconds +MAX_REDIRECTS = 5 +ALLOWED_SCHEMES = {'https'} +RATE_LIMIT_REQUESTS = 60 #requests per min +RATE_LIMIT_WINDOW = 60 #seconds + +MIME_TO_FILE_TYPE = { + 'application/pdf': FileType.pdf, + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': FileType.xlsx, + 'text/plain': FileType.txt, + 'text/csv': FileType.csv, + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': FileType.docx +} + +EXT_TO_FILE_TYPE = { + '.pdf': FileType.pdf, + '.xlsx': FileType.xlsx, + '.txt': FileType.txt, + '.csv': FileType.csv, + '.docx': FileType.docx +} + +class URLDownloader: + def __init__(self): + self.rate_limit_storage: Dict[str, List[datetime]] = {} + self.session = requests.Session() + self.session.mount('https://', HTTPAdapter(max_retries=3)) + self.session.verify = True + + def _check_rate_limit(self, domain: str) -> None: + now = datetime.now() + if domain not in self.rate_limit_storage: + self.rate_limit_storage[domain] = [] + + self.rate_limit_storage[domain] = [ + ts for ts in self.rate_limit_storage[domain] + if ts > now - timedelta(seconds=RATE_LIMIT_WINDOW) + ] + + if len(self.rate_limit_storage[domain]) >= RATE_LIMIT_REQUESTS: + raise ConflictError("Rate limit exceeded") + + self.rate_limit_storage[domain].append(now) + + def _validate_url(self, url: str) -> None: + try: + parsed = urllib.parse.urlparse(url) + + if parsed.scheme not in ALLOWED_SCHEMES: + raise BadRequestError(f"URL must use HTTPS") + + if not parsed.netloc or '.' not in parsed.netloc: + raise BadRequestError("Invalid domain") + + if '..' in parsed.path or '//' in parsed.path: + raise ForbiddenError("Path traversal detected") + + if parsed.fragment: + raise BadRequestError("URL fragments not supported") + + blocked_domains = {'localhost', '127.0.0.1', '0.0.0.0'} + if parsed.netloc in blocked_domains: + raise ForbiddenError("Access to internal domains blocked") + + except Exception as e: + raise BadRequestError(f"Invalid URL: {str(e)}") + + def _validate_ssl(self, response: requests.Response) -> None: + cert = response.raw.connection.sock.getpeercert() + if not cert: + raise ForbiddenError("Invalid SSL certificate") + + not_after = ssl.cert_time_to_seconds(cert['notAfter']) + if datetime.fromtimestamp(not_after) < datetime.now(): + raise ForbiddenError("Expired SSL certificate") + + def _check_legal_headers(self, response: requests.Response) -> None: + if response.headers.get('X-Robots-Tag', '').lower() == 'noindex': + raise ForbiddenError("Access not allowed by robots directive") + + if 'X-Copyright' in response.headers: + raise ForbiddenError("Content is copyright protected") + + if 'X-Terms-Of-Service' in response.headers: + raise ForbiddenError("Terms of service acceptance required") + + def _validate_content_type(self, url: str, content_type: str) -> FileType: + try: + content_type = content_type.split(';')[0].lower() + + if content_type in MIME_TO_FILE_TYPE: + return MIME_TO_FILE_TYPE[content_type] + + ext = Path(urllib.parse.urlparse(url).path).suffix.lower() + if ext in EXT_TO_FILE_TYPE: + return EXT_TO_FILE_TYPE[ext] + + raise BadRequestError(f"Unsupported file type: {content_type}") + except BadRequestError: + raise + except Exception as e: + raise BadRequestError(f"Error validating content type: {str(e)}") + + def download(self, url: str) -> RawFile: + try: + self._validate_url(url) + domain = urllib.parse.urlparse(url).netloc + self._check_rate_limit(domain) + + response = self.session.get( + url, + timeout=DOWNLOAD_TIMEOUT, + stream=True, + verify=True, + allow_redirects=True, + headers={ + 'User-Agent': 'HarmonyBot/1.0 (+https://harmonydata.ac.uk)', + 'Accept': ', '.join(MIME_TO_FILE_TYPE.keys()) + } + ) + response.raise_for_status() + + self._validate_ssl(response) + self._check_legal_headers(response) + + content_length = response.headers.get('content-length') + if content_length and int(content_length) > MAX_FILE_SIZE: + raise ForbiddenError(f"File too large: {content_length} bytes (max {MAX_FILE_SIZE})") + + file_type = self._validate_content_type(url, response.headers.get('content-type', '')) + + hasher = hashlib.sha256() + content = b'' + for chunk in response.iter_content(chunk_size=8192): + hasher.update(chunk) + content += chunk + + if file_type in [FileType.pdf, FileType.xlsx, FileType.docx]: + content_str = f"data:{response.headers['content-type']};base64," + base64.b64encode(content).decode('ascii') + else: + content_str = content.decode('utf-8') + + return RawFile( + file_id=str(uuid.uuid4()), + file_name=Path(urllib.parse.urlparse(url).path).name or "downloaded_file", + file_type=file_type, + content=content_str, + metadata={ + 'content_hash': hasher.hexdigest(), + 'download_timestamp': datetime.now().isoformat(), + 'source_url': url + } + ) + + except (BadRequestError, ForbiddenError, ConflictError): + raise + except requests.Timeout: + raise SomethingWrongError("Download timeout") + except requests.TooManyRedirects: + raise ForbiddenError("Too many redirects") + except requests.RequestException as e: + if e.response is not None: + if e.response.status_code == 401: + raise ForbiddenError("Resource requires authentication") + elif e.response.status_code == 403: + raise ForbiddenError("Access forbidden") + elif e.response.status_code == 429: + raise ConflictError("Rate limit exceeded") + raise SomethingWrongError(f"Download error: {str(e)}") + except Exception as e: + raise SomethingWrongError(f"Unexpected error: {str(e)}") + +def load_instruments_from_url(url: str) -> List[Instrument]: + downloader = URLDownloader() + raw_file = downloader.download(url) + return convert_files_to_instruments([raw_file]) \ No newline at end of file diff --git a/tests/test_url_loader.py b/tests/test_url_loader.py new file mode 100644 index 0000000..c8de786 --- /dev/null +++ b/tests/test_url_loader.py @@ -0,0 +1,230 @@ +''' +MIT License + +Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk). +Project: Harmony (https://harmonydata.ac.uk) +Maintainer: Thomas Wood (https://fastdatascience.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +''' + +import sys +import unittest +from unittest.mock import patch, MagicMock +from datetime import datetime +import requests + +sys.path.append("../src") + +from harmony.util.url_loader import ( + URLDownloader, + load_instruments_from_url, + MAX_FILE_SIZE, + RATE_LIMIT_REQUESTS +) +from harmony.schemas.errors.base import ( + BadRequestError, + ForbiddenError, + ConflictError, + SomethingWrongError +) +from harmony.schemas.requests.text import FileType + +class TestURLLoader(unittest.TestCase): + def setUp(self): + self.downloader = URLDownloader() + self.valid_url = "https://example.com/test.pdf" + + self.downloader.rate_limit_storage.clear() + + self.mock_response = MagicMock() + self.mock_response.headers = { + 'content-type': 'application/pdf', + 'content-length': '1000' + } + self.mock_response.content = b'test content' + self.mock_response.raw = MagicMock() + self.mock_response.raw.connection = MagicMock() + self.mock_response.raw.connection.sock = MagicMock() + self.mock_response.raw.connection.sock.getpeercert.return_value = { + 'notAfter': 'Dec 31 23:59:59 2125 GMT' + } + + def mock_iter_content(chunk_size=None): + yield b'test content' + self.mock_response.iter_content = mock_iter_content + + def test_content_integrity(self): + with patch('requests.Session.get', return_value=self.mock_response): + raw_file = self.downloader.download(self.valid_url) + self.assertIsNotNone(raw_file.metadata) + self.assertIn('content_hash', raw_file.metadata) + expected_hash = '6ae8a75555209fd6c44157c0aed8016e763ff435a19cf186f76863140143ff72' + self.assertEqual(raw_file.metadata['content_hash'], expected_hash) + + def test_content_type_validation(self): + invalid_types = [ + "application/javascript", + "application/x-executable", + "application/octet-stream" + ] + + for content_type in invalid_types: + with self.subTest(content_type=content_type): + mock_response = MagicMock() + mock_response.headers = { + 'content-type': content_type, + } + mock_response.raw = self.mock_response.raw + mock_response.iter_content = self.mock_response.iter_content + mock_response.raise_for_status = lambda: None + + with patch('requests.Session.get', return_value=mock_response): + with self.assertRaises(BadRequestError) as cm: + self.downloader.download("https://example.com/test.unknown") + self.assertIn("Unsupported file type", str(cm.exception)) + + def test_file_size_limit(self): + mock_response = MagicMock() + mock_response.headers = { + 'content-type': 'application/pdf', + 'content-length': str(MAX_FILE_SIZE + 1) + } + mock_response.raw = self.mock_response.raw + mock_response.iter_content = self.mock_response.iter_content + + with patch('requests.Session.get', return_value=mock_response): + with self.assertRaises(ForbiddenError): + self.downloader.download(self.valid_url) + + def test_file_types(self): + test_files = { + 'test.pdf': (FileType.pdf, 'application/pdf'), + 'test.xlsx': (FileType.xlsx, 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'), + 'test.txt': (FileType.txt, 'text/plain'), + 'test.csv': (FileType.csv, 'text/csv'), + 'test.docx': (FileType.docx, 'application/vnd.openxmlformats-officedocument.wordprocessingml.document') + } + + for filename, (file_type, mime_type) in test_files.items(): + with self.subTest(file_type=file_type): + url = f"https://example.com/{filename}" + mock_response = MagicMock() + mock_response.headers = { + 'content-type': mime_type, + 'content-length': '1000' + } + mock_response.raw = self.mock_response.raw + mock_response.content = b'test content' + mock_response.iter_content = lambda chunk_size: [b'test content'] + + with patch('requests.Session.get', return_value=mock_response): + raw_file = self.downloader.download(url) + self.assertEqual(raw_file.file_type, file_type) + + def test_rate_limiting(self): + self.downloader.rate_limit_storage.clear() + + with patch('requests.Session.get', return_value=self.mock_response): + # initial request + self.downloader.download(self.valid_url) + + # block after too many requests + self.downloader.rate_limit_storage['example.com'] = [ + datetime.now() for _ in range(RATE_LIMIT_REQUESTS) + ] + + with self.assertRaises(ConflictError): + self.downloader.download(self.valid_url) + + def test_successful_instrument_loading(self): + self.downloader.rate_limit_storage.clear() + + self.mock_response.iter_content = lambda chunk_size: [b'test content'] + + with patch('requests.Session.get', return_value=self.mock_response): + instruments = load_instruments_from_url(self.valid_url) + self.assertIsInstance(instruments, list) + + def test_error_handling(self): + error_conditions = { + requests.Timeout: SomethingWrongError, + requests.TooManyRedirects: ForbiddenError, + requests.ConnectionError: SomethingWrongError + } + + for exception, expected_error in error_conditions.items(): + with self.subTest(error=exception.__name__): + with patch('requests.Session.get', side_effect=exception()): + with self.assertRaises(expected_error): + self.downloader.download(self.valid_url) + + def test_http_error_handling(self): + error_codes = { + 401: ForbiddenError, #unauthorized + 403: ForbiddenError, #forbidden + 429: ConflictError, #rate limit + 500: SomethingWrongError, #server error + } + + for status_code, expected_error in error_codes.items(): + with self.subTest(status_code=status_code): + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = requests.RequestException( + response=MagicMock(status_code=status_code) + ) + + with patch('requests.Session.get', return_value=mock_response): + with self.assertRaises(expected_error): + self.downloader.download(self.valid_url) + + def test_ssl_validation(self): + mock_response = MagicMock() + mock_response.headers = self.mock_response.headers + mock_response.content = self.mock_response.content + mock_response.iter_content = self.mock_response.iter_content + mock_response.raw = MagicMock() + mock_response.raw.connection = MagicMock() + mock_response.raw.connection.sock = MagicMock() + mock_response.raw.connection.sock.getpeercert.return_value = { + 'notAfter': 'Jan 1 00:00:00 2020 GMT' + } + + with patch('requests.Session.get', return_value=mock_response): + with self.assertRaises(ForbiddenError): + self.downloader.download(self.valid_url) + + def test_url_validation(self): + invalid_urls = [ + "not-a-url", + "http://example.com", #HTTP not allowed + "https://localhost", + "https://127.0.0.1", + "https://example.com/../test.pdf", #path traversing + "https://example.com/test.pdf#fragment" + ] + + for url in invalid_urls: + with self.subTest(url=url): + with self.assertRaises((BadRequestError, ForbiddenError)): + self.downloader.download(url) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 76647c093f1e6205d5ac1be3fe82789efb11692b Mon Sep 17 00:00:00 2001 From: Evelynn Chen Date: Fri, 27 Dec 2024 22:05:04 -0800 Subject: [PATCH 2/2] reformat with pycharm, addresses issue #38 --- src/harmony/schemas/errors/base.py | 9 +++- src/harmony/schemas/requests/text.py | 13 +++-- src/harmony/util/url_loader.py | 79 ++++++++++++++-------------- tests/test_url_loader.py | 61 +++++++++++---------- 4 files changed, 86 insertions(+), 76 deletions(-) diff --git a/src/harmony/schemas/errors/base.py b/src/harmony/schemas/errors/base.py index 4fa8556..9ff9c08 100644 --- a/src/harmony/schemas/errors/base.py +++ b/src/harmony/schemas/errors/base.py @@ -25,44 +25,51 @@ ''' + class BaseHarmonyError(Exception): def __init__(self, message: str = None): self.status_code = 500 self.detail = message or "Something went wrong" super().__init__(self.detail) + class BadRequestError(BaseHarmonyError): def __init__(self, message: str = None): self.status_code = 400 self.detail = message or "Bad request data" super(Exception, self).__init__(self.detail) + class SomethingWrongError(BaseHarmonyError): def __init__(self, message: str = None): self.status_code = 500 self.detail = message or "Something went wrong" super(Exception, self).__init__(self.detail) + class UnauthorizedError(BaseHarmonyError): def __init__(self, message: str = None): self.status_code = 401 self.detail = message or "Unauthorized" super(Exception, self).__init__(self.detail) + class ForbiddenError(BaseHarmonyError): def __init__(self, message: str = None): self.status_code = 403 self.detail = message or "Forbidden" super(Exception, self).__init__(self.detail) + class ConflictError(BaseHarmonyError): def __init__(self, message: str = None): self.status_code = 409 self.detail = message or "Conflict" super(Exception, self).__init__(self.detail) + class ResourceNotFoundError(BaseHarmonyError): def __init__(self, message: str = None): self.status_code = 404 self.detail = message or "Resource not found" - super(Exception, self).__init__(self.detail) \ No newline at end of file + super(Exception, self).__init__(self.detail) diff --git a/src/harmony/schemas/requests/text.py b/src/harmony/schemas/requests/text.py index 83fe39f..e6fee9c 100644 --- a/src/harmony/schemas/requests/text.py +++ b/src/harmony/schemas/requests/text.py @@ -25,14 +25,12 @@ ''' -from typing import Any, Dict, List, Optional - -from pydantic import ConfigDict, BaseModel, Field - from harmony.schemas.catalogue_instrument import CatalogueInstrument from harmony.schemas.catalogue_question import CatalogueQuestion from harmony.schemas.enums.file_types import FileType from harmony.schemas.enums.languages import Language +from pydantic import ConfigDict, BaseModel, Field +from typing import Any, Dict, List, Optional DEFAULT_FRAMEWORK = "huggingface" DEFAULT_MODEL = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2' @@ -66,7 +64,8 @@ class Question(BaseModel): instrument_id: Optional[str] = Field(None, description="Unique identifier for the instrument (UUID-4)") instrument_name: Optional[str] = Field(None, description="Human readable name for the instrument") topics_auto: Optional[list] = Field(None, description="Automated list of topics identified by model") - topics_strengths: Optional[dict] = Field(None, description="Automated list of topics identified by model with strength of topic") + topics_strengths: Optional[dict] = Field(None, + description="Automated list of topics identified by model with strength of topic") nearest_match_from_mhc_auto: Optional[dict] = Field(None, description="Automatically identified nearest MHC match") closest_catalogue_question_match: Optional[CatalogueQuestion] = Field( None, description="The closest question match in the catalogue for the question" @@ -96,7 +95,7 @@ class Instrument(BaseModel): study: Optional[str] = Field(None, description="The study") sweep: Optional[str] = Field(None, description="The sweep") metadata: Optional[dict] = Field(None, - description="Optional metadata about the instrument (URL, citation, DOI, copyright holder)") + description="Optional metadata about the instrument (URL, citation, DOI, copyright holder)") language: Language = Field(Language.English, description="The ISO 639-2 (alpha-2) encoding of the instrument language") questions: List[Question] = Field(description="The items inside the instrument") @@ -125,7 +124,7 @@ class Instrument(BaseModel): "source_page": 0 }] } - }) + }) class MatchParameters(BaseModel): diff --git a/src/harmony/util/url_loader.py b/src/harmony/util/url_loader.py index 5215d4a..0f3adc0 100644 --- a/src/harmony/util/url_loader.py +++ b/src/harmony/util/url_loader.py @@ -25,28 +25,26 @@ ''' -import urllib.parse import base64 -import uuid -import ssl import hashlib -from typing import List, Dict +import requests +import ssl +import urllib.parse +import uuid from datetime import datetime, timedelta +from harmony.parsing.wrapper_all_parsers import convert_files_to_instruments +from harmony.schemas.errors.base import BadRequestError, ForbiddenError, ConflictError, SomethingWrongError +from harmony.schemas.requests.text import RawFile, Instrument, FileType from pathlib import Path - -import requests from requests.adapters import HTTPAdapter +from typing import List, Dict -from harmony.schemas.requests.text import RawFile, Instrument, FileType -from harmony.schemas.errors.base import BadRequestError, ForbiddenError, ConflictError, SomethingWrongError -from harmony.parsing.wrapper_all_parsers import convert_files_to_instruments - -MAX_FILE_SIZE = 50 * 1024 * 1024 #50MB -DOWNLOAD_TIMEOUT = 30 #seconds +MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB +DOWNLOAD_TIMEOUT = 30 # seconds MAX_REDIRECTS = 5 ALLOWED_SCHEMES = {'https'} -RATE_LIMIT_REQUESTS = 60 #requests per min -RATE_LIMIT_WINDOW = 60 #seconds +RATE_LIMIT_REQUESTS = 60 # requests per min +RATE_LIMIT_WINDOW = 60 # seconds MIME_TO_FILE_TYPE = { 'application/pdf': FileType.pdf, @@ -64,6 +62,7 @@ '.docx': FileType.docx } + class URLDownloader: def __init__(self): self.rate_limit_storage: Dict[str, List[datetime]] = {} @@ -75,37 +74,37 @@ def _check_rate_limit(self, domain: str) -> None: now = datetime.now() if domain not in self.rate_limit_storage: self.rate_limit_storage[domain] = [] - + self.rate_limit_storage[domain] = [ ts for ts in self.rate_limit_storage[domain] if ts > now - timedelta(seconds=RATE_LIMIT_WINDOW) ] - + if len(self.rate_limit_storage[domain]) >= RATE_LIMIT_REQUESTS: raise ConflictError("Rate limit exceeded") - + self.rate_limit_storage[domain].append(now) def _validate_url(self, url: str) -> None: try: parsed = urllib.parse.urlparse(url) - + if parsed.scheme not in ALLOWED_SCHEMES: raise BadRequestError(f"URL must use HTTPS") - + if not parsed.netloc or '.' not in parsed.netloc: raise BadRequestError("Invalid domain") - + if '..' in parsed.path or '//' in parsed.path: raise ForbiddenError("Path traversal detected") - + if parsed.fragment: raise BadRequestError("URL fragments not supported") - + blocked_domains = {'localhost', '127.0.0.1', '0.0.0.0'} if parsed.netloc in blocked_domains: raise ForbiddenError("Access to internal domains blocked") - + except Exception as e: raise BadRequestError(f"Invalid URL: {str(e)}") @@ -113,7 +112,7 @@ def _validate_ssl(self, response: requests.Response) -> None: cert = response.raw.connection.sock.getpeercert() if not cert: raise ForbiddenError("Invalid SSL certificate") - + not_after = ssl.cert_time_to_seconds(cert['notAfter']) if datetime.fromtimestamp(not_after) < datetime.now(): raise ForbiddenError("Expired SSL certificate") @@ -121,24 +120,24 @@ def _validate_ssl(self, response: requests.Response) -> None: def _check_legal_headers(self, response: requests.Response) -> None: if response.headers.get('X-Robots-Tag', '').lower() == 'noindex': raise ForbiddenError("Access not allowed by robots directive") - + if 'X-Copyright' in response.headers: raise ForbiddenError("Content is copyright protected") - + if 'X-Terms-Of-Service' in response.headers: raise ForbiddenError("Terms of service acceptance required") def _validate_content_type(self, url: str, content_type: str) -> FileType: try: content_type = content_type.split(';')[0].lower() - + if content_type in MIME_TO_FILE_TYPE: return MIME_TO_FILE_TYPE[content_type] - + ext = Path(urllib.parse.urlparse(url).path).suffix.lower() if ext in EXT_TO_FILE_TYPE: return EXT_TO_FILE_TYPE[ext] - + raise BadRequestError(f"Unsupported file type: {content_type}") except BadRequestError: raise @@ -150,7 +149,7 @@ def download(self, url: str) -> RawFile: self._validate_url(url) domain = urllib.parse.urlparse(url).netloc self._check_rate_limit(domain) - + response = self.session.get( url, timeout=DOWNLOAD_TIMEOUT, @@ -163,27 +162,28 @@ def download(self, url: str) -> RawFile: } ) response.raise_for_status() - + self._validate_ssl(response) self._check_legal_headers(response) - + content_length = response.headers.get('content-length') if content_length and int(content_length) > MAX_FILE_SIZE: raise ForbiddenError(f"File too large: {content_length} bytes (max {MAX_FILE_SIZE})") - + file_type = self._validate_content_type(url, response.headers.get('content-type', '')) - + hasher = hashlib.sha256() content = b'' for chunk in response.iter_content(chunk_size=8192): hasher.update(chunk) content += chunk - + if file_type in [FileType.pdf, FileType.xlsx, FileType.docx]: - content_str = f"data:{response.headers['content-type']};base64," + base64.b64encode(content).decode('ascii') + content_str = f"data:{response.headers['content-type']};base64," + base64.b64encode(content).decode( + 'ascii') else: content_str = content.decode('utf-8') - + return RawFile( file_id=str(uuid.uuid4()), file_name=Path(urllib.parse.urlparse(url).path).name or "downloaded_file", @@ -195,7 +195,7 @@ def download(self, url: str) -> RawFile: 'source_url': url } ) - + except (BadRequestError, ForbiddenError, ConflictError): raise except requests.Timeout: @@ -214,7 +214,8 @@ def download(self, url: str) -> RawFile: except Exception as e: raise SomethingWrongError(f"Unexpected error: {str(e)}") + def load_instruments_from_url(url: str) -> List[Instrument]: downloader = URLDownloader() raw_file = downloader.download(url) - return convert_files_to_instruments([raw_file]) \ No newline at end of file + return convert_files_to_instruments([raw_file]) diff --git a/tests/test_url_loader.py b/tests/test_url_loader.py index c8de786..14b2fad 100644 --- a/tests/test_url_loader.py +++ b/tests/test_url_loader.py @@ -25,35 +25,36 @@ ''' +import requests import sys import unittest -from unittest.mock import patch, MagicMock from datetime import datetime -import requests +from unittest.mock import patch, MagicMock sys.path.append("../src") from harmony.util.url_loader import ( - URLDownloader, + URLDownloader, load_instruments_from_url, MAX_FILE_SIZE, RATE_LIMIT_REQUESTS ) from harmony.schemas.errors.base import ( BadRequestError, - ForbiddenError, + ForbiddenError, ConflictError, SomethingWrongError ) from harmony.schemas.requests.text import FileType + class TestURLLoader(unittest.TestCase): def setUp(self): self.downloader = URLDownloader() self.valid_url = "https://example.com/test.pdf" - + self.downloader.rate_limit_storage.clear() - + self.mock_response = MagicMock() self.mock_response.headers = { 'content-type': 'application/pdf', @@ -69,6 +70,7 @@ def setUp(self): def mock_iter_content(chunk_size=None): yield b'test content' + self.mock_response.iter_content = mock_iter_content def test_content_integrity(self): @@ -85,7 +87,7 @@ def test_content_type_validation(self): "application/x-executable", "application/octet-stream" ] - + for content_type in invalid_types: with self.subTest(content_type=content_type): mock_response = MagicMock() @@ -95,7 +97,7 @@ def test_content_type_validation(self): mock_response.raw = self.mock_response.raw mock_response.iter_content = self.mock_response.iter_content mock_response.raise_for_status = lambda: None - + with patch('requests.Session.get', return_value=mock_response): with self.assertRaises(BadRequestError) as cm: self.downloader.download("https://example.com/test.unknown") @@ -109,11 +111,11 @@ def test_file_size_limit(self): } mock_response.raw = self.mock_response.raw mock_response.iter_content = self.mock_response.iter_content - + with patch('requests.Session.get', return_value=mock_response): with self.assertRaises(ForbiddenError): self.downloader.download(self.valid_url) - + def test_file_types(self): test_files = { 'test.pdf': (FileType.pdf, 'application/pdf'), @@ -122,7 +124,7 @@ def test_file_types(self): 'test.csv': (FileType.csv, 'text/csv'), 'test.docx': (FileType.docx, 'application/vnd.openxmlformats-officedocument.wordprocessingml.document') } - + for filename, (file_type, mime_type) in test_files.items(): with self.subTest(file_type=file_type): url = f"https://example.com/{filename}" @@ -134,31 +136,31 @@ def test_file_types(self): mock_response.raw = self.mock_response.raw mock_response.content = b'test content' mock_response.iter_content = lambda chunk_size: [b'test content'] - + with patch('requests.Session.get', return_value=mock_response): raw_file = self.downloader.download(url) self.assertEqual(raw_file.file_type, file_type) def test_rate_limiting(self): self.downloader.rate_limit_storage.clear() - + with patch('requests.Session.get', return_value=self.mock_response): # initial request self.downloader.download(self.valid_url) - + # block after too many requests self.downloader.rate_limit_storage['example.com'] = [ datetime.now() for _ in range(RATE_LIMIT_REQUESTS) ] - + with self.assertRaises(ConflictError): self.downloader.download(self.valid_url) def test_successful_instrument_loading(self): self.downloader.rate_limit_storage.clear() - + self.mock_response.iter_content = lambda chunk_size: [b'test content'] - + with patch('requests.Session.get', return_value=self.mock_response): instruments = load_instruments_from_url(self.valid_url) self.assertIsInstance(instruments, list) @@ -169,7 +171,7 @@ def test_error_handling(self): requests.TooManyRedirects: ForbiddenError, requests.ConnectionError: SomethingWrongError } - + for exception, expected_error in error_conditions.items(): with self.subTest(error=exception.__name__): with patch('requests.Session.get', side_effect=exception()): @@ -178,19 +180,19 @@ def test_error_handling(self): def test_http_error_handling(self): error_codes = { - 401: ForbiddenError, #unauthorized - 403: ForbiddenError, #forbidden - 429: ConflictError, #rate limit - 500: SomethingWrongError, #server error + 401: ForbiddenError, # unauthorized + 403: ForbiddenError, # forbidden + 429: ConflictError, # rate limit + 500: SomethingWrongError, # server error } - + for status_code, expected_error in error_codes.items(): with self.subTest(status_code=status_code): mock_response = MagicMock() mock_response.raise_for_status.side_effect = requests.RequestException( response=MagicMock(status_code=status_code) ) - + with patch('requests.Session.get', return_value=mock_response): with self.assertRaises(expected_error): self.downloader.download(self.valid_url) @@ -206,7 +208,7 @@ def test_ssl_validation(self): mock_response.raw.connection.sock.getpeercert.return_value = { 'notAfter': 'Jan 1 00:00:00 2020 GMT' } - + with patch('requests.Session.get', return_value=mock_response): with self.assertRaises(ForbiddenError): self.downloader.download(self.valid_url) @@ -214,17 +216,18 @@ def test_ssl_validation(self): def test_url_validation(self): invalid_urls = [ "not-a-url", - "http://example.com", #HTTP not allowed + "http://example.com", # HTTP not allowed "https://localhost", "https://127.0.0.1", - "https://example.com/../test.pdf", #path traversing + "https://example.com/../test.pdf", # path traversing "https://example.com/test.pdf#fragment" ] - + for url in invalid_urls: with self.subTest(url=url): with self.assertRaises((BadRequestError, ForbiddenError)): self.downloader.download(url) + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()