diff --git a/tests/unit/search/test_init.py b/tests/unit/search/test_init.py index 712122b0be6e..286e865a45fc 100644 --- a/tests/unit/search/test_init.py +++ b/tests/unit/search/test_init.py @@ -14,6 +14,7 @@ import pretend from warehouse import search +from warehouse.rate_limiting import IRateLimiter, RateLimit from ...common.db.packaging import ProjectFactory, ReleaseFactory @@ -118,11 +119,15 @@ def test_includeme(monkeypatch): "aws.key_id": "AAAAAAAAAAAA", "aws.secret_key": "deadbeefdeadbeefdeadbeef", "opensearch.url": opensearch_url, + "warehouse.search.ratelimit_string": "10 per second", }, __setitem__=registry.__setitem__, ), add_request_method=pretend.call_recorder(lambda *a, **kw: None), add_periodic_task=pretend.call_recorder(lambda *a, **kw: None), + register_service_factory=pretend.call_recorder( + lambda factory, iface, name=None: None + ), ) search.includeme(config) @@ -132,7 +137,7 @@ def test_includeme(monkeypatch): ] assert len(opensearch_client_init.calls) == 1 assert opensearch_client_init.calls[0].kwargs["hosts"] == ["https://some.url"] - assert opensearch_client_init.calls[0].kwargs["timeout"] == 2 + assert opensearch_client_init.calls[0].kwargs["timeout"] == 0.5 assert opensearch_client_init.calls[0].kwargs["retry_on_timeout"] is False assert ( opensearch_client_init.calls[0].kwargs["connection_class"] @@ -147,3 +152,7 @@ def test_includeme(monkeypatch): assert config.add_request_method.calls == [ pretend.call(search.opensearch, name="opensearch", reify=True) ] + + assert config.register_service_factory.calls == [ + pretend.call(RateLimit("10 per second"), IRateLimiter, name="search") + ] diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index ba312034fefb..4c263fa5742e 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -333,6 +333,7 @@ def __init__(self): "warehouse.manage.oidc.ip_registration_ratelimit_string": "100 per day", "warehouse.packaging.project_create_user_ratelimit_string": "20 per hour", "warehouse.packaging.project_create_ip_ratelimit_string": "40 per hour", + "warehouse.search.ratelimit_string": "5 per second", "oidc.backend": "warehouse.oidc.services.OIDCPublisherService", "integrity.backend": "warehouse.attestations.services.IntegrityService", "warehouse.organizations.max_undecided_organization_applications": 3, diff --git a/tests/unit/test_views.py b/tests/unit/test_views.py index 1ac64553e988..0e5e460e03ad 100644 --- a/tests/unit/test_views.py +++ b/tests/unit/test_views.py @@ -23,6 +23,7 @@ HTTPRequestEntityTooLarge, HTTPSeeOther, HTTPServiceUnavailable, + HTTPTooManyRequests, ) from trove_classifiers import sorted_classifiers from webob.multidict import MultiDict @@ -30,6 +31,7 @@ from warehouse import views from warehouse.errors import WarehouseDenied from warehouse.packaging.models import ProjectFactory as DBProjectFactory +from warehouse.rate_limiting.interfaces import IRateLimiter from warehouse.utils.row_counter import compute_row_counts from warehouse.views import ( SecurityKeyGiveaway, @@ -476,12 +478,21 @@ def test_csi_sidebar_sponsor_logo(): class TestSearch: @pytest.mark.parametrize("page", [None, 1, 5]) - def test_with_a_query(self, monkeypatch, db_request, metrics, page): + def test_with_a_query( + self, monkeypatch, pyramid_services, db_request, metrics, page + ): params = MultiDict({"q": "foo bar"}) if page is not None: params["page"] = page db_request.params = params + fake_rate_limiter = pretend.stub( + test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None + ) + pyramid_services.register_service( + fake_rate_limiter, IRateLimiter, None, name="search" + ) + db_request.opensearch = pretend.stub() opensearch_query = pretend.stub() get_opensearch_query = pretend.call_recorder(lambda *a, **kw: opensearch_query) @@ -514,12 +525,21 @@ def test_with_a_query(self, monkeypatch, db_request, metrics, page): ] @pytest.mark.parametrize("page", [None, 1, 5]) - def test_with_classifiers(self, monkeypatch, db_request, metrics, page): + def test_with_classifiers( + self, monkeypatch, pyramid_services, db_request, metrics, page + ): params = MultiDict([("q", "foo bar"), ("c", "foo :: bar"), ("c", "fiz :: buz")]) if page is not None: params["page"] = page db_request.params = params + fake_rate_limiter = pretend.stub( + test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None + ) + pyramid_services.register_service( + fake_rate_limiter, IRateLimiter, None, name="search" + ) + opensearch_query = pretend.stub() db_request.opensearch = pretend.stub() get_opensearch_query = pretend.call_recorder(lambda *a, **kw: opensearch_query) @@ -562,6 +582,7 @@ def test_with_classifiers(self, monkeypatch, db_request, metrics, page): assert page_cls.calls == [ pretend.call(opensearch_query, url_maker=url_maker, page=page or 1) ] + assert url_maker_factory.calls == [pretend.call(db_request)] assert get_opensearch_query.calls == [ pretend.call(db_request.opensearch, params.get("q"), "", params.getall("c")) @@ -570,10 +591,19 @@ def test_with_classifiers(self, monkeypatch, db_request, metrics, page): pretend.call("warehouse.views.search.results", 1000) ] - def test_returns_404_with_pagenum_too_high(self, monkeypatch, db_request, metrics): + def test_returns_404_with_pagenum_too_high( + self, monkeypatch, pyramid_services, db_request, metrics + ): params = MultiDict({"page": 15}) db_request.params = params + fake_rate_limiter = pretend.stub( + test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None + ) + pyramid_services.register_service( + fake_rate_limiter, IRateLimiter, None, name="search" + ) + opensearch_query = pretend.stub() db_request.opensearch = pretend.stub(query=lambda *a, **kw: opensearch_query) @@ -594,10 +624,19 @@ def test_returns_404_with_pagenum_too_high(self, monkeypatch, db_request, metric assert url_maker_factory.calls == [pretend.call(db_request)] assert metrics.histogram.calls == [] - def test_raises_400_with_pagenum_type_str(self, monkeypatch, db_request, metrics): + def test_raises_400_with_pagenum_type_str( + self, monkeypatch, pyramid_services, db_request, metrics + ): params = MultiDict({"page": "abc"}) db_request.params = params + fake_rate_limiter = pretend.stub( + test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None + ) + pyramid_services.register_service( + fake_rate_limiter, IRateLimiter, None, name="search" + ) + opensearch_query = pretend.stub() db_request.opensearch = pretend.stub(query=lambda *a, **kw: opensearch_query) @@ -615,23 +654,40 @@ def test_raises_400_with_pagenum_type_str(self, monkeypatch, db_request, metrics assert page_cls.calls == [] assert metrics.histogram.calls == [] - def test_return_413_when_query_too_long(self, db_request, metrics): + def test_return_413_when_query_too_long( + self, pyramid_services, db_request, metrics + ): params = MultiDict({"q": "a" * 1001}) db_request.params = params + fake_rate_limiter = pretend.stub( + test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None + ) + pyramid_services.register_service( + fake_rate_limiter, IRateLimiter, None, name="search" + ) + with pytest.raises(HTTPRequestEntityTooLarge): search(db_request) assert metrics.increment.calls == [ - pretend.call("warehouse.views.search.error", tags=["error:query_too_long"]) + pretend.call("warehouse.search.ratelimiter.hit"), + pretend.call("warehouse.views.search.error", tags=["error:query_too_long"]), ] def test_returns_503_when_opensearch_unavailable( - self, monkeypatch, db_request, metrics + self, monkeypatch, pyramid_services, db_request, metrics ): params = MultiDict({"page": 15}) db_request.params = params + fake_rate_limiter = pretend.stub( + test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None + ) + pyramid_services.register_service( + fake_rate_limiter, IRateLimiter, None, name="search" + ) + opensearch_query = pretend.stub() db_request.opensearch = pretend.stub(query=lambda *a, **kw: opensearch_query) @@ -648,9 +704,47 @@ def raiser(*args, **kwargs): search(db_request) assert url_maker_factory.calls == [pretend.call(db_request)] - assert metrics.increment.calls == [pretend.call("warehouse.views.search.error")] + assert metrics.increment.calls == [ + pretend.call("warehouse.search.ratelimiter.hit"), + pretend.call("warehouse.views.search.error"), + ] assert metrics.histogram.calls == [] + @pytest.mark.parametrize("resets_in", [None, 1, 5]) + def test_returns_429_when_ratelimited( + self, monkeypatch, pyramid_services, db_request, metrics, resets_in + ): + params = MultiDict({"q": "foo bar"}) + db_request.params = params + + fake_rate_limiter = pretend.stub( + test=lambda *a: False, + hit=lambda *a: True, + resets_in=lambda *a: ( + None + if resets_in is None + else pretend.stub(total_seconds=lambda *a: resets_in) + ), + ) + pyramid_services.register_service( + fake_rate_limiter, IRateLimiter, None, name="search" + ) + + with pytest.raises(HTTPTooManyRequests) as exc_info: + search(db_request) + + message = ( + "Your search query could not be performed because there were too " + "many requests by the client." + ) + if resets_in is not None: + message += f" Limit may reset in {resets_in} seconds." + + assert exc_info.value.args[0] == message + assert metrics.increment.calls == [ + pretend.call("warehouse.search.ratelimiter.exceeded") + ] + def test_classifiers(db_request): assert list_classifiers(db_request) == {"classifiers": sorted_classifiers} diff --git a/warehouse/config.py b/warehouse/config.py index 78dc28eb859f..046d41164529 100644 --- a/warehouse/config.py +++ b/warehouse/config.py @@ -556,6 +556,12 @@ def configure(settings=None): "PROJECT_CREATE_IP_RATELIMIT_STRING", default="40 per hour", ) + maybe_set( + settings, + "warehouse.search.ratelimit_string", + "SEARCH_RATELIMIT_STRING", + default="5 per second", + ) # OIDC feature flags and settings maybe_set(settings, "warehouse.oidc.audience", "OIDC_AUDIENCE") diff --git a/warehouse/locale/messages.pot b/warehouse/locale/messages.pot index d074cfa7a6b2..c6a25e337279 100644 --- a/warehouse/locale/messages.pot +++ b/warehouse/locale/messages.pot @@ -1,16 +1,16 @@ -#: warehouse/views.py:147 +#: warehouse/views.py:149 msgid "" "You must verify your **primary** email address before you can perform " "this action." msgstr "" -#: warehouse/views.py:163 +#: warehouse/views.py:165 msgid "" "Two-factor authentication must be enabled on your account to perform this" " action." msgstr "" -#: warehouse/views.py:299 +#: warehouse/views.py:301 msgid "Locale updated" msgstr "" diff --git a/warehouse/search/__init__.py b/warehouse/search/__init__.py index a27b7179ec73..2207ed6c5b4a 100644 --- a/warehouse/search/__init__.py +++ b/warehouse/search/__init__.py @@ -21,6 +21,7 @@ from warehouse import db from warehouse.packaging.models import Project, Release +from warehouse.rate_limiting import IRateLimiter, RateLimit from warehouse.search.utils import get_index @@ -79,13 +80,18 @@ def opensearch(request): def includeme(config): + ratelimit_string = config.registry.settings.get("warehouse.search.ratelimit_string") + config.register_service_factory( + RateLimit(ratelimit_string), IRateLimiter, name="search" + ) + p = parse_url(config.registry.settings["opensearch.url"]) qs = urllib.parse.parse_qs(p.query) kwargs = { "hosts": [urllib.parse.urlunparse((p.scheme, p.netloc) + ("",) * 4)], "verify_certs": True, "ca_certs": certifi.where(), - "timeout": 2, + "timeout": 0.5, "retry_on_timeout": False, "serializer": opensearchpy.serializer.serializer, "max_retries": 1, diff --git a/warehouse/views.py b/warehouse/views.py index 8984957ee1ea..b539d0ef5e02 100644 --- a/warehouse/views.py +++ b/warehouse/views.py @@ -26,6 +26,7 @@ HTTPRequestEntityTooLarge, HTTPSeeOther, HTTPServiceUnavailable, + HTTPTooManyRequests, exception_response, ) from pyramid.i18n import make_localizer @@ -60,6 +61,7 @@ Release, ReleaseClassifiers, ) +from warehouse.rate_limiting import IRateLimiter from warehouse.search.queries import SEARCH_FILTER_ORDER, get_opensearch_query from warehouse.utils.cors import _CORS_HEADERS from warehouse.utils.http import is_safe_url @@ -322,8 +324,23 @@ def list_classifiers(request): has_translations=True, ) def search(request): + ratelimiter = request.find_service(IRateLimiter, name="search", context=None) metrics = request.find_service(IMetricsService, context=None) + ratelimiter.hit(request.remote_addr) + if not ratelimiter.test(request.remote_addr): + metrics.increment("warehouse.search.ratelimiter.exceeded") + message = ( + "Your search query could not be performed because there were too " + "many requests by the client." + ) + _resets_in = ratelimiter.resets_in(request.remote_addr) + if _resets_in is not None: + _resets_in = max(1, int(_resets_in.total_seconds())) + message += f" Limit may reset in {_resets_in} seconds." + raise HTTPTooManyRequests(message) + metrics.increment("warehouse.search.ratelimiter.hit") + querystring = request.params.get("q", "").replace("'", '"') # Bail early for really long queries before ES raises an error if len(querystring) > 1000: