Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PB-737: Fix CORS issues when working on localhost #69

Merged
merged 2 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.default
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ AWS_SECRET_ACCESS_KEY=dummy123
AWS_ENDPOINT_URL=http://localhost:8080
AWS_DEFAULT_REGION=eu-central-1
AWS_DYNAMODB_TABLE_NAME=test-db
ALLOWED_DOMAINS=.*localhost((:[0-9]*)?|\/)?,.*admin\.ch,.*bgdi\.ch
ALLOWED_DOMAINS=localhost,.*\.geo\.admin\.ch,.*\.bgdi\.ch
STAGING=local
2 changes: 1 addition & 1 deletion .env.testing
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ALLOWED_DOMAINS=.*\.geo\.admin\.ch,.*\.bgdi\.ch,http://localhost((:[0-9]*)?|\/)?
ALLOWED_DOMAINS=localhost,.*\.geo\.admin\.ch,.*\.bgdi\.ch
AWS_ACCESS_KEY_ID=testing
AWS_SECRET_ACCESS_KEY=testing
AWS_SECURITY_TOKEN=testing
Expand Down
6 changes: 1 addition & 5 deletions app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

from app.helpers.utils import get_redirect_param
from app.helpers.utils import get_registered_method
from app.helpers.utils import is_domain_allowed
from app.helpers.utils import make_error_msg
from app.settings import ALLOWED_DOMAINS_PATTERN
from app.settings import CACHE_CONTROL
from app.settings import CACHE_CONTROL_4XX

Expand All @@ -25,10 +25,6 @@
app.config.from_mapping({"TRAP_HTTP_EXCEPTIONS": True})


def is_domain_allowed(domain):
return re.fullmatch(ALLOWED_DOMAINS_PATTERN, domain) is not None


@app.before_request
# Add quick log of the routes used to all request.
# Important: this should be the first before_request method, to ensure
Expand Down
17 changes: 15 additions & 2 deletions app/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,12 @@ def get_url():
f"The url given as parameter was too long. (limit is 2046 "
f"characters, {len(url)} given)"
)
if not re.fullmatch(ALLOWED_DOMAINS_PATTERN, urlparse(url).netloc):
logger.error('URL(%s) given as a parameter is not allowed', url)
if not is_domain_allowed(url):
logger.error(
'URL(%s) given as a parameter is not allowed, test pattern %s',
url,
ALLOWED_DOMAINS_PATTERN
)
abort(400, 'URL given as a parameter is not allowed.')

return url
Expand All @@ -132,3 +136,12 @@ def strtobool(value) -> bool:
if value in ('n', 'no', 'f', 'false', 'off', '0'):
return False
raise ValueError(f"invalid truth value \'{value}\'")


def is_domain_allowed(url):
"""Check if the url contain a domain that is allowed
"""
domain = urlparse(url).hostname
if domain:
return re.fullmatch(ALLOWED_DOMAINS_PATTERN, domain) is not None
return False
24 changes: 13 additions & 11 deletions tests/unit_tests/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import re
import unittest
from urllib.parse import urlparse

import boto3

Expand Down Expand Up @@ -83,18 +84,19 @@ def setUp(self):
def tearDown(self):
self.table.delete()

def assertCors(
self,
response,
expected_allowed_methods,
origin_pattern=ALLOWED_DOMAINS_PATTERN
): # pylint: disable=invalid-name
def assertCors(self, response, expected_allowed_methods, all_origin=False): # pylint: disable=invalid-name
self.assertIn('Access-Control-Allow-Origin', response.headers)
self.assertIsNotNone(
re.fullmatch(origin_pattern, response.headers['Access-Control-Allow-Origin']),
msg=f"Access-Control-Allow-Origin={response.headers['Access-Control-Allow-Origin']}"
f" doesn't match {origin_pattern}"
)
if all_origin:
self.assertEqual(response.headers['Access-Control-Allow-Origin'], '*')
else:
allow_origin_domain = urlparse(response.headers['Access-Control-Allow-Origin']).hostname
self.assertIsNotNone(
re.fullmatch(
ALLOWED_DOMAINS_PATTERN, allow_origin_domain if allow_origin_domain else ''
),
msg=f"Access-Control-Allow-Origin={response.headers['Access-Control-Allow-Origin']}"
f" doesn't match {ALLOWED_DOMAINS_PATTERN}"
)
self.assertIn('Access-Control-Allow-Methods', response.headers)
self.assertListEqual(
sorted(expected_allowed_methods),
Expand Down
54 changes: 29 additions & 25 deletions tests/unit_tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class TestRoutes(BaseShortlinkTestCase):

def test_checker_ok(self):
# checker
response = self.app.get(url_for('checker'), headers={"Origin": "map.geo.admin.ch"})
response = self.app.get(url_for('checker'), headers={"Origin": "https://map.geo.admin.ch"})
self.assertEqual(response.status_code, 200)
self.assertNotIn('Cache-Control', response.headers)
self.assertEqual(response.content_type, "application/json; charset=utf-8")
Expand All @@ -27,7 +27,9 @@ def test_checker_ok(self):
def test_create_shortlink_ok(self):
url = "https://map.geo.admin.ch/#/map?lang=en&center=2647850.83,1120124.2&z=1.812&bgLayer=ch.swisstopo.pixelkarte-farbe&top" # pylint: disable=line-too-long
response = self.app.post(
url_for('create_shortlink'), json={"url": url}, headers={"Origin": "map.geo.admin.ch"}
url_for('create_shortlink'),
json={"url": url},
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 201)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -49,7 +51,9 @@ def test_create_shortlink_ok(self):
)
# Check that second call returns 200 and the same short url
response = self.app.post(
url_for('create_shortlink'), json={"url": url}, headers={"Origin": "map.geo.admin.ch"}
url_for('create_shortlink'),
json={"url": url},
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 200)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -59,7 +63,7 @@ def test_create_shortlink_ok(self):

def test_create_shortlink_no_json(self):
response = self.app.post(
url_for('create_shortlink'), headers={"Origin": "map.geo.admin.ch"}
url_for('create_shortlink'), headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(415, response.status_code)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -77,7 +81,7 @@ def test_create_shortlink_no_json(self):

def test_create_shortlink_no_url(self):
response = self.app.post(
url_for('create_shortlink'), json={}, headers={"Origin": "map.geo.admin.ch"}
url_for('create_shortlink'), json={}, headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(400, response.status_code)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -97,7 +101,7 @@ def test_create_shortlink_no_hostname(self):
response = self.app.post(
url_for('create_shortlink'),
json={"url": f"{wrong_url}"},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 400)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -116,7 +120,7 @@ def test_create_shortlink_non_allowed_hostname(self):
response = self.app.post(
url_for('create_shortlink'),
json={"url": "https://non-allowed.hostname.ch/test"},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 400)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -135,7 +139,7 @@ def test_create_shortlink_non_allowed_hostname_containing_admin_address(self):
response = self.app.post(
url_for('create_shortlink'),
json={"url": "https://map.geo.admin.ch.non-allowed.hostname.ch/test"},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 400)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -156,7 +160,7 @@ def test_create_shortlink_url_too_long(self):
url_for('create_shortlink'),
json={"url": url},
content_type="application/json",
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 400)
self.assertCors(response, ['POST', 'OPTIONS'])
Expand All @@ -178,7 +182,7 @@ def test_redirect_shortlink_ok(self):
for short_id, url in self.uuid_to_url_dict.items():
response = self.app.get(url_for('get_shortlink', shortlink_id=short_id))
self.assertEqual(response.status_code, 301)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$")
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True)
self.assertIn('Cache-Control', response.headers)
self.assertIn('max-age=', response.headers['Cache-Control'])
self.assertEqual(response.content_type, "text/html; charset=utf-8")
Expand All @@ -192,7 +196,7 @@ def test_redirect_shortlink_ok_with_query(self):
headers={"Origin": "www.example.com"}
)
self.assertEqual(response.status_code, 301)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$")
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True)
self.assertIn('Cache-Control', response.headers)
self.assertIn('max-age=', response.headers['Cache-Control'])
self.assertEqual(response.content_type, "text/html; charset=utf-8")
Expand All @@ -204,7 +208,7 @@ def test_shortlink_fetch_nok_invalid_redirect_parameter(self):
url_for('get_shortlink', shortlink_id=short_id),
query_string={'redirect': 'banana'},
content_type="text/html",
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
expected_json = {
'success': False,
Expand All @@ -226,7 +230,7 @@ def test_shortlink_fetch_nok_invalid_redirect_parameter(self):
def test_redirect_shortlink_url_not_found(self):
response = self.app.get(
url_for('get_shortlink', shortlink_id='nonexistent'),
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
expected_json = {
'success': False,
Expand All @@ -235,7 +239,7 @@ def test_redirect_shortlink_url_not_found(self):
}
}
self.assertEqual(response.status_code, 404)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$")
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True)
self.assertIn('Cache-Control', response.headers)
self.assertIn('max-age=3600', response.headers['Cache-Control'])
self.assertIn('application/json', response.content_type)
Expand All @@ -246,7 +250,7 @@ def test_fetch_full_url_from_shortlink_ok(self):
response = self.app.get(
url_for('get_shortlink', shortlink_id=short_id),
query_string={'redirect': 'false'},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 200)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'])
Expand All @@ -262,7 +266,7 @@ def test_fetch_full_url_from_shortlink_ok_explicit_parameter(self):
response = self.app.get(
url_for('get_shortlink', shortlink_id=short_id),
query_string={'redirect': 'false'},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 200)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'])
Expand All @@ -277,7 +281,7 @@ def test_fetch_full_url_from_shortlink_url_not_found(self):
response = self.app.get(
url_for('get_shortlink', shortlink_id='nonexistent'),
query_string={'redirect': 'false'},
headers={"Origin": "map.geo.admin.ch"}
headers={"Origin": "https://map.geo.admin.ch"}
)
self.assertEqual(response.status_code, 404)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'])
Expand Down Expand Up @@ -325,12 +329,12 @@ def test_create_shortlink_origin_not_allowed(self, headers):
)

@params(
{'Origin': 'map.geo.admin.ch'},
{'Origin': 'https://map.geo.admin.ch'},
{
'Origin': 'map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site'
'Origin': 'https://map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site'
},
{
'Origin': 's.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin'
'Origin': 'https://s.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin'
},
{
'Origin': 'http://localhost', 'Sec-Fetch-Site': 'cross-site'
Expand Down Expand Up @@ -389,19 +393,19 @@ def test_get_shortlink_redirect_origin_allowed(self, headers):
headers=headers
)
self.assertEqual(response.status_code, 301)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$")
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True)

response = self.app.get(url_for('get_shortlink', shortlink_id=short_id), headers=headers)
self.assertEqual(response.status_code, 301)
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], origin_pattern=r"^\*$")
self.assertCors(response, ['GET', 'HEAD', 'OPTIONS'], all_origin=True)

@params(
{'Origin': 'map.geo.admin.ch'},
{'Origin': 'https://map.geo.admin.ch'},
{
'Origin': 'map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site'
'Origin': 'https://map.geo.admin.ch', 'Sec-Fetch-Site': 'same-site'
},
{
'Origin': 's.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin'
'Origin': 'https://s.geo.admin.ch', 'Sec-Fetch-Site': 'same-origin'
},
{
'Origin': 'http://localhost', 'Sec-Fetch-Site': 'cross-site'
Expand Down