Skip to content

Commit

Permalink
Implement OpenSearch auth via boto credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
pudo committed Jul 17, 2024
1 parent 243eb7a commit c345a7a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
18 changes: 15 additions & 3 deletions yente/provider/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
from typing import Any, Dict, List, Optional, cast
from typing import AsyncIterator
from opensearchpy import AsyncOpenSearch
from opensearchpy import AsyncOpenSearch, AWSV4SignerAuth
from opensearchpy.helpers import async_bulk, BulkIndexError
from opensearchpy.exceptions import NotFoundError, TransportError

Expand Down Expand Up @@ -34,13 +34,25 @@ async def create(cls) -> "OpenSearchProvider":
kwargs["sniff_on_connection_fail"] = True
if settings.INDEX_USERNAME and settings.INDEX_PASSWORD:
auth = (settings.INDEX_USERNAME, settings.INDEX_PASSWORD)
kwargs["basic_auth"] = auth
kwargs["http_auth"] = auth
if settings.OPENSEARCH_REGION and settings.OPENSEARCH_SERVICE:
from boto3 import Session

service = settings.OPENSEARCH_SERVICE.lower().strip()
if service not in ["es", "aoss"]:
raise RuntimeError(f"Invalid OpenSearch service: {service}")
credentials = Session().get_credentials()
kwargs["http_auth"] = AWSV4SignerAuth(
credentials,
settings.OPENSEARCH_REGION,
settings.OPENSEARCH_SERVICE,
)
if settings.INDEX_CA_CERT:
kwargs["ca_certs"] = settings.INDEX_CA_CERT
for retry in range(2, 9):
try:
es = AsyncOpenSearch(**kwargs)
await es.cluster.health(wait_for_status="yellow")
await es.cluster.health(wait_for_status="yellow", timeout=5)
return OpenSearchProvider(es)
except (TransportError, ConnectionError) as exc:
log.error("Cannot connect to OpenSearch: %r" % exc)
Expand Down
10 changes: 7 additions & 3 deletions yente/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,6 @@ def random_cron() -> str:
_INDEX_URL = "http://localhost:9200"
INDEX_URL = env_legacy("YENTE_INDEX_URL", "YENTE_ELASTICSEARCH_URL", _INDEX_URL)

ES_CLOUD_ID = env_get("YENTE_ELASTICSEARCH_CLOUD_ID")
# TODO: https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-sdk.html

_INDEX_USERNAME = env_legacy("YENTE_INDEX_USERNAME", "YENTE_ELASTICSEARCH_USERNAME", "")
INDEX_USERNAME = None if _INDEX_USERNAME == "" else _INDEX_USERNAME
_INDEX_PASSWORD = env_legacy("YENTE_INDEX_PASSWORD", "YENTE_ELASTICSEARCH_PASSWORD", "")
Expand All @@ -187,6 +184,13 @@ def random_cron() -> str:
INDEX_VERSION = env_str("YENTE_INDEX_VERSION", "009")
assert len(INDEX_VERSION) == 3, "Index version must be 3 characters long."

# ElasticSearch-only options:
# Read from the YENTE_ELASTICSEARCH_CLOUD_ID environment variable.
ES_CLOUD_ID = env_get("YENTE_ELASTICSEARCH_CLOUD_ID")

# OpenSearch-only options:
# When BOTH values are set, the OpenSearch provider builds an AWS SigV4 signer
# from boto3 session credentials and uses it as the client's `http_auth`.
# The service name is lower-cased and stripped before validation and must then
# be either "es" or "aoss"; any other value raises a RuntimeError at connect time.
# NOTE(review): the validated (normalised) value is not the one handed to the
# signer -- the raw setting is -- so prefer setting it exactly as "es" or "aoss".
OPENSEARCH_REGION = env_get("YENTE_OPENSEARCH_REGION")
OPENSEARCH_SERVICE = env_get("YENTE_OPENSEARCH_SERVICE")

# Log output can be formatted as JSON:
LOG_JSON = as_bool(env_str("YENTE_LOG_JSON", "false"))
LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO
Expand Down

0 comments on commit c345a7a

Please sign in to comment.