implement some basic guards on our /search endpoint (#16812)
* implement some guards on our /search endpoint

- Move timeout from 2000ms to 500ms
- Add a simple rate limiter to the endpoint by client IP

* no need to explicitly pass empty tags list to metrics.increment
ewdurbin authored Sep 30, 2024
1 parent bed12d9 commit 126ac91
Showing 7 changed files with 146 additions and 13 deletions.
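
Taken together, the changes wire up in three steps: a default rate-limit string in configuration, a named IRateLimiter service registered by the search package, and a guard at the top of the search view keyed by client IP. A condensed Python sketch of that flow, assembled from the hunks below (the 429 message construction is elided here; the warehouse/views.py hunk shows it in full):

# 1. warehouse/config.py: default limit, overridable via SEARCH_RATELIMIT_STRING
settings["warehouse.search.ratelimit_string"] = "5 per second"

# 2. warehouse/search/__init__.py: register a limiter service named "search"
config.register_service_factory(
    RateLimit(settings["warehouse.search.ratelimit_string"]), IRateLimiter, name="search"
)

# 3. warehouse/views.py: every request counts (hit comes before test); once
#    over the limit, the view raises 429 instead of querying OpenSearch.
ratelimiter = request.find_service(IRateLimiter, name="search", context=None)
ratelimiter.hit(request.remote_addr)
if not ratelimiter.test(request.remote_addr):
    raise HTTPTooManyRequests("...")  # full message built in the views.py hunk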
11 changes: 10 additions & 1 deletion tests/unit/search/test_init.py
@@ -14,6 +14,7 @@
import pretend

from warehouse import search
from warehouse.rate_limiting import IRateLimiter, RateLimit

from ...common.db.packaging import ProjectFactory, ReleaseFactory

@@ -118,11 +119,15 @@ def test_includeme(monkeypatch):
"aws.key_id": "AAAAAAAAAAAA",
"aws.secret_key": "deadbeefdeadbeefdeadbeef",
"opensearch.url": opensearch_url,
"warehouse.search.ratelimit_string": "10 per second",
},
__setitem__=registry.__setitem__,
),
add_request_method=pretend.call_recorder(lambda *a, **kw: None),
add_periodic_task=pretend.call_recorder(lambda *a, **kw: None),
register_service_factory=pretend.call_recorder(
lambda factory, iface, name=None: None
),
)

search.includeme(config)
@@ -132,7 +137,7 @@ def test_includeme(monkeypatch):
]
assert len(opensearch_client_init.calls) == 1
assert opensearch_client_init.calls[0].kwargs["hosts"] == ["https://some.url"]
assert opensearch_client_init.calls[0].kwargs["timeout"] == 2
assert opensearch_client_init.calls[0].kwargs["timeout"] == 0.5
assert opensearch_client_init.calls[0].kwargs["retry_on_timeout"] is False
assert (
opensearch_client_init.calls[0].kwargs["connection_class"]
@@ -147,3 +152,7 @@ def test_includeme(monkeypatch):
assert config.add_request_method.calls == [
pretend.call(search.opensearch, name="opensearch", reify=True)
]

assert config.register_service_factory.calls == [
pretend.call(RateLimit("10 per second"), IRateLimiter, name="search")
]
1 change: 1 addition & 0 deletions tests/unit/test_config.py
@@ -333,6 +333,7 @@ def __init__(self):
"warehouse.manage.oidc.ip_registration_ratelimit_string": "100 per day",
"warehouse.packaging.project_create_user_ratelimit_string": "20 per hour",
"warehouse.packaging.project_create_ip_ratelimit_string": "40 per hour",
"warehouse.search.ratelimit_string": "5 per second",
"oidc.backend": "warehouse.oidc.services.OIDCPublisherService",
"integrity.backend": "warehouse.attestations.services.IntegrityService",
"warehouse.organizations.max_undecided_organization_applications": 3,
110 changes: 102 additions & 8 deletions tests/unit/test_views.py
@@ -23,13 +23,15 @@
HTTPRequestEntityTooLarge,
HTTPSeeOther,
HTTPServiceUnavailable,
HTTPTooManyRequests,
)
from trove_classifiers import sorted_classifiers
from webob.multidict import MultiDict

from warehouse import views
from warehouse.errors import WarehouseDenied
from warehouse.packaging.models import ProjectFactory as DBProjectFactory
from warehouse.rate_limiting.interfaces import IRateLimiter
from warehouse.utils.row_counter import compute_row_counts
from warehouse.views import (
SecurityKeyGiveaway,
@@ -476,12 +478,21 @@ def test_csi_sidebar_sponsor_logo():

class TestSearch:
@pytest.mark.parametrize("page", [None, 1, 5])
def test_with_a_query(self, monkeypatch, db_request, metrics, page):
def test_with_a_query(
self, monkeypatch, pyramid_services, db_request, metrics, page
):
params = MultiDict({"q": "foo bar"})
if page is not None:
params["page"] = page
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

db_request.opensearch = pretend.stub()
opensearch_query = pretend.stub()
get_opensearch_query = pretend.call_recorder(lambda *a, **kw: opensearch_query)
@@ -514,12 +525,21 @@ def test_with_a_query(self, monkeypatch, db_request, metrics, page):
]

@pytest.mark.parametrize("page", [None, 1, 5])
def test_with_classifiers(self, monkeypatch, db_request, metrics, page):
def test_with_classifiers(
self, monkeypatch, pyramid_services, db_request, metrics, page
):
params = MultiDict([("q", "foo bar"), ("c", "foo :: bar"), ("c", "fiz :: buz")])
if page is not None:
params["page"] = page
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

opensearch_query = pretend.stub()
db_request.opensearch = pretend.stub()
get_opensearch_query = pretend.call_recorder(lambda *a, **kw: opensearch_query)
@@ -562,6 +582,7 @@ def test_with_classifiers(self, monkeypatch, db_request, metrics, page):
assert page_cls.calls == [
pretend.call(opensearch_query, url_maker=url_maker, page=page or 1)
]

assert url_maker_factory.calls == [pretend.call(db_request)]
assert get_opensearch_query.calls == [
pretend.call(db_request.opensearch, params.get("q"), "", params.getall("c"))
@@ -570,10 +591,19 @@ def test_with_classifiers(self, monkeypatch, db_request, metrics, page):
pretend.call("warehouse.views.search.results", 1000)
]

def test_returns_404_with_pagenum_too_high(self, monkeypatch, db_request, metrics):
def test_returns_404_with_pagenum_too_high(
self, monkeypatch, pyramid_services, db_request, metrics
):
params = MultiDict({"page": 15})
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

opensearch_query = pretend.stub()
db_request.opensearch = pretend.stub(query=lambda *a, **kw: opensearch_query)

@@ -594,10 +624,19 @@ def test_returns_404_with_pagenum_too_high(self, monkeypatch, db_request, metrics):
assert url_maker_factory.calls == [pretend.call(db_request)]
assert metrics.histogram.calls == []

def test_raises_400_with_pagenum_type_str(self, monkeypatch, db_request, metrics):
def test_raises_400_with_pagenum_type_str(
self, monkeypatch, pyramid_services, db_request, metrics
):
params = MultiDict({"page": "abc"})
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

opensearch_query = pretend.stub()
db_request.opensearch = pretend.stub(query=lambda *a, **kw: opensearch_query)

@@ -615,23 +654,40 @@ def test_raises_400_with_pagenum_type_str(self, monkeypatch, db_request, metrics):
assert page_cls.calls == []
assert metrics.histogram.calls == []

def test_return_413_when_query_too_long(self, db_request, metrics):
def test_return_413_when_query_too_long(
self, pyramid_services, db_request, metrics
):
params = MultiDict({"q": "a" * 1001})
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

with pytest.raises(HTTPRequestEntityTooLarge):
search(db_request)

assert metrics.increment.calls == [
pretend.call("warehouse.views.search.error", tags=["error:query_too_long"])
pretend.call("warehouse.search.ratelimiter.hit"),
pretend.call("warehouse.views.search.error", tags=["error:query_too_long"]),
]

def test_returns_503_when_opensearch_unavailable(
self, monkeypatch, db_request, metrics
self, monkeypatch, pyramid_services, db_request, metrics
):
params = MultiDict({"page": 15})
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: True, hit=lambda *a: True, resets_in=lambda *a: None
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

opensearch_query = pretend.stub()
db_request.opensearch = pretend.stub(query=lambda *a, **kw: opensearch_query)

@@ -648,9 +704,47 @@ def raiser(*args, **kwargs):
search(db_request)

assert url_maker_factory.calls == [pretend.call(db_request)]
assert metrics.increment.calls == [pretend.call("warehouse.views.search.error")]
assert metrics.increment.calls == [
pretend.call("warehouse.search.ratelimiter.hit"),
pretend.call("warehouse.views.search.error"),
]
assert metrics.histogram.calls == []

@pytest.mark.parametrize("resets_in", [None, 1, 5])
def test_returns_429_when_ratelimited(
self, monkeypatch, pyramid_services, db_request, metrics, resets_in
):
params = MultiDict({"q": "foo bar"})
db_request.params = params

fake_rate_limiter = pretend.stub(
test=lambda *a: False,
hit=lambda *a: True,
resets_in=lambda *a: (
None
if resets_in is None
else pretend.stub(total_seconds=lambda *a: resets_in)
),
)
pyramid_services.register_service(
fake_rate_limiter, IRateLimiter, None, name="search"
)

with pytest.raises(HTTPTooManyRequests) as exc_info:
search(db_request)

message = (
"Your search query could not be performed because there were too "
"many requests by the client."
)
if resets_in is not None:
message += f" Limit may reset in {resets_in} seconds."

assert exc_info.value.args[0] == message
assert metrics.increment.calls == [
pretend.call("warehouse.search.ratelimiter.exceeded")
]


def test_classifiers(db_request):
assert list_classifiers(db_request) == {"classifiers": sorted_classifiers}
6 changes: 6 additions & 0 deletions warehouse/config.py
@@ -556,6 +556,12 @@ def configure(settings=None):
"PROJECT_CREATE_IP_RATELIMIT_STRING",
default="40 per hour",
)
maybe_set(
settings,
"warehouse.search.ratelimit_string",
"SEARCH_RATELIMIT_STRING",
default="5 per second",
)

# OIDC feature flags and settings
maybe_set(settings, "warehouse.oidc.audience", "OIDC_AUDIENCE")
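
For context, maybe_set (used for every setting in this block) resolves a value from the environment with a fallback default. A rough sketch of the behavior assumed here, not warehouse's exact implementation:

import os

def maybe_set(settings, name, envvar, default=None):
    # Prefer the environment variable when set; otherwise apply the
    # default without clobbering a value that is already present.
    if envvar in os.environ:
        settings.setdefault(name, os.environ[envvar])
    elif default is not None:
        settings.setdefault(name, default)

So deployments can tighten or loosen the search limit by exporting SEARCH_RATELIMIT_STRING (e.g. "10 per second", the value the test_init.py test uses) without a code change.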
6 changes: 3 additions & 3 deletions warehouse/locale/messages.pot
@@ -1,16 +1,16 @@
#: warehouse/views.py:147
#: warehouse/views.py:149
msgid ""
"You must verify your **primary** email address before you can perform "
"this action."
msgstr ""

#: warehouse/views.py:163
#: warehouse/views.py:165
msgid ""
"Two-factor authentication must be enabled on your account to perform this"
" action."
msgstr ""

#: warehouse/views.py:299
#: warehouse/views.py:301
msgid "Locale updated"
msgstr ""

8 changes: 7 additions & 1 deletion warehouse/search/__init__.py
@@ -21,6 +21,7 @@

from warehouse import db
from warehouse.packaging.models import Project, Release
from warehouse.rate_limiting import IRateLimiter, RateLimit
from warehouse.search.utils import get_index


@@ -79,13 +80,18 @@ def opensearch(request):


def includeme(config):
ratelimit_string = config.registry.settings.get("warehouse.search.ratelimit_string")
config.register_service_factory(
RateLimit(ratelimit_string), IRateLimiter, name="search"
)

p = parse_url(config.registry.settings["opensearch.url"])
qs = urllib.parse.parse_qs(p.query)
kwargs = {
"hosts": [urllib.parse.urlunparse((p.scheme, p.netloc) + ("",) * 4)],
"verify_certs": True,
"ca_certs": certifi.where(),
"timeout": 2,
"timeout": 0.5,
"retry_on_timeout": False,
"serializer": opensearchpy.serializer.serializer,
"max_retries": 1,
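
The other guard is the client timeout: 2 s drops to 500 ms, and a timed-out search is not retried. As a standalone sketch of the resulting fast-fail client, assuming opensearch-py's standard constructor keywords as suggested by the kwargs above (the host is the placeholder used in the tests):

from opensearchpy import OpenSearch

client = OpenSearch(
    hosts=["https://some.url"],  # placeholder host from the tests
    verify_certs=True,
    timeout=0.5,             # 500 ms budget per request, down from 2 s
    retry_on_timeout=False,  # a slow search fails instead of retrying
    max_retries=1,
)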
17 changes: 17 additions & 0 deletions warehouse/views.py
@@ -26,6 +26,7 @@
HTTPRequestEntityTooLarge,
HTTPSeeOther,
HTTPServiceUnavailable,
HTTPTooManyRequests,
exception_response,
)
from pyramid.i18n import make_localizer
@@ -60,6 +61,7 @@
Release,
ReleaseClassifiers,
)
from warehouse.rate_limiting import IRateLimiter
from warehouse.search.queries import SEARCH_FILTER_ORDER, get_opensearch_query
from warehouse.utils.cors import _CORS_HEADERS
from warehouse.utils.http import is_safe_url
@@ -322,8 +324,23 @@ def list_classifiers(request):
has_translations=True,
)
def search(request):
ratelimiter = request.find_service(IRateLimiter, name="search", context=None)
metrics = request.find_service(IMetricsService, context=None)

ratelimiter.hit(request.remote_addr)
if not ratelimiter.test(request.remote_addr):
metrics.increment("warehouse.search.ratelimiter.exceeded")
message = (
"Your search query could not be performed because there were too "
"many requests by the client."
)
_resets_in = ratelimiter.resets_in(request.remote_addr)
if _resets_in is not None:
_resets_in = max(1, int(_resets_in.total_seconds()))
message += f" Limit may reset in {_resets_in} seconds."
raise HTTPTooManyRequests(message)
metrics.increment("warehouse.search.ratelimiter.hit")

querystring = request.params.get("q", "").replace("'", '"')
# Bail early for really long queries before ES raises an error
if len(querystring) > 1000:
