feat: use new sem index, gen ctx extractors
Signed-off-by: Panos Vagenas <[email protected]>
vagenas committed May 7, 2024
1 parent 52ed664 commit 7b60438
Showing 5 changed files with 179 additions and 243 deletions.
26 changes: 18 additions & 8 deletions deepsearch/cps/client/api.py
@@ -14,6 +14,7 @@
from deepsearch.core.client.settings import ProfileSettings
from deepsearch.core.client.settings_manager import SettingsManager
from deepsearch.cps.apis import public as sw_client
from deepsearch.cps.apis import public_v2 as sw_client_v2
from deepsearch.cps.client.components import (
CpsApiDataCatalogs,
CpsApiDataIndices,
@@ -44,28 +45,37 @@ def __init__(self, config: DeepSearchConfig) -> None:

auth = f"Bearer {self.bearer_token_auth.bearer_token}"

################################
# configure v1 public API client
################################
sw_config = sw_client.Configuration(
host=f"{self.config.host}/api/cps/public/v1",
api_key={"Authorization": auth},
)
sw_config.verify_ssl = self.config.verify_ssl

# Disable client-side validation, because our APIs lie.
sw_config.client_side_validation = False

# print(sw_config, sw_config.client_side_validation)

self.swagger_client = sw_client.ApiClient(sw_config)

################################
# configure v2 public API client
################################
sw_config_v2 = sw_client_v2.Configuration(
host=f"{self.config.host}/api/cps/public/v2",
)
sw_config_v2.api_key["Bearer"] = auth
sw_config_v2.verify_ssl = self.config.verify_ssl
sw_config_v2.client_side_validation = False
self.swagger_client_v2 = sw_client_v2.ApiClient(sw_config_v2)

##############################
# configure v1 user API client
##############################
sw_user_conf = deepsearch.cps.apis.user.Configuration(
host=f"{self.config.host}/api/cps/user/v1",
api_key={"Authorization": auth},
)
sw_user_conf.verify_ssl = self.config.verify_ssl

# Disable client-side validation, because our APIs lie.
sw_user_conf.client_side_validation = False

self.user_swagger_client = deepsearch.cps.apis.user.ApiClient(sw_user_conf)

self.session = requests.Session()
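For context, a minimal standalone sketch of the v2 public API client setup introduced above. The helper name and the host/token arguments are illustrative; the /api/cps/public/v2 path, the "Bearer" api_key entry, and the disabled SSL verification and client-side validation mirror the diff.

from deepsearch.cps.apis import public_v2 as sw_client_v2

def make_v2_client(host: str, bearer_token: str, verify_ssl: bool = True):
    # Sketch only: mirrors the v2 client wiring from the diff above.
    config = sw_client_v2.Configuration(host=f"{host}/api/cps/public/v2")
    # Unlike the v1 client, which takes api_key={"Authorization": ...} in the
    # constructor, the v2 client expects the token under the "Bearer" key.
    config.api_key["Bearer"] = f"Bearer {bearer_token}"
    config.verify_ssl = verify_ssl
    # Client-side validation stays disabled, as for the v1 client.
    config.client_side_validation = False
    return sw_client_v2.ApiClient(config)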
116 changes: 47 additions & 69 deletions deepsearch/cps/client/components/documents.py
@@ -3,19 +3,28 @@
import base64
import json
import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from pydantic.v1 import BaseModel, Field
from typing_extensions import Annotated
from pydantic.v1 import BaseModel

from deepsearch.cps.apis import public as sw_client
from deepsearch.cps.apis.public.models.semantic_ingest_req_params import (
from deepsearch.cps.apis.public_v2 import SemanticApi
from deepsearch.cps.apis.public_v2.models.cps_task import CpsTask
from deepsearch.cps.apis.public_v2.models.semantic_ingest_req_params import (
SemanticIngestReqParams,
)
from deepsearch.cps.apis.public.models.semantic_ingest_request import (
from deepsearch.cps.apis.public_v2.models.semantic_ingest_request import (
SemanticIngestRequest,
)
from deepsearch.cps.apis.public.models.task import Task
from deepsearch.cps.apis.public_v2.models.semantic_ingest_source_private_data_collection import (
SemanticIngestSourcePrivateDataCollection,
)
from deepsearch.cps.apis.public_v2.models.semantic_ingest_source_private_data_document import (
SemanticIngestSourcePrivateDataDocument,
)
from deepsearch.cps.apis.public_v2.models.semantic_ingest_source_public_data_document import (
SemanticIngestSourcePublicDataDocument,
)
from deepsearch.cps.apis.public_v2.models.source1 import Source1
from deepsearch.cps.client.components.data_indices import (
ElasticProjectDataCollectionSource,
)
@@ -26,107 +35,76 @@
from deepsearch.cps.client import CpsApi


class SemIngestPublicDataDocumentSource(BaseModel):
class PublicDataDocumentSource(BaseModel):
source: ElasticDataCollectionSource
document_hash: str


class SemIngestPrivateDataDocumentSource(BaseModel):
class PrivateDataDocumentSource(BaseModel):
source: ElasticProjectDataCollectionSource
document_hash: str


class SemIngestPrivateDataCollectionSource(BaseModel):
class PrivateDataCollectionSource(BaseModel):
source: ElasticProjectDataCollectionSource


SemIngestSource = Union[
SemIngestPublicDataDocumentSource,
SemIngestPrivateDataDocumentSource,
SemIngestPrivateDataCollectionSource,
]


class _APISemanticIngestSourceUrl(BaseModel):
type: Literal["url"] = "url"
url: str


class _APISemanticIngestSourcePublicDataDocument(BaseModel):
type: Literal["public_data_document"] = "public_data_document"
elastic_id: str
index_key: str
document_hash: str


class _APISemanticIngestSourcePrivateDataDocument(BaseModel):
type: Literal["private_data_document"] = "private_data_document"
proj_key: str
index_key: str
document_hash: str


class _APISemanticIngestSourcePrivateDataCollection(BaseModel):
type: Literal["private_data_collection"] = "private_data_collection"
proj_key: str
index_key: str


_APISemanticIngestSourceType = Annotated[
Union[
_APISemanticIngestSourceUrl,
_APISemanticIngestSourcePublicDataDocument,
_APISemanticIngestSourcePrivateDataDocument,
_APISemanticIngestSourcePrivateDataCollection,
],
Field(discriminator="type"),
DataSource = Union[
PublicDataDocumentSource,
PrivateDataDocumentSource,
PrivateDataCollectionSource,
]


class DSApiDocuments:
def __init__(self, api: CpsApi) -> None:
self.api = api
self.semantic_api = sw_client.SemanticApi(self.api.client.swagger_client)
self.semantic_api = SemanticApi(self.api.client.swagger_client_v2)

def semantic_ingest(
self,
project: Union[Project, str],
data_source: SemIngestSource,
data_source: DataSource,
skip_ingested_docs: bool = True,
) -> Task:
) -> CpsTask:

proj_key = project.key if isinstance(project, Project) else project
api_src_data: _APISemanticIngestSourceType
if isinstance(data_source, SemIngestPublicDataDocumentSource):
api_src_data = _APISemanticIngestSourcePublicDataDocument(
api_src_data: Any
if isinstance(data_source, PublicDataDocumentSource):
api_src_data = SemanticIngestSourcePublicDataDocument(
type="public_data_document",
elastic_id=data_source.source.elastic_id,
index_key=data_source.source.index_key,
document_hash=data_source.document_hash,
)
elif isinstance(data_source, SemIngestPrivateDataDocumentSource):
api_src_data = _APISemanticIngestSourcePrivateDataDocument(
elif isinstance(data_source, PrivateDataDocumentSource):
api_src_data = SemanticIngestSourcePrivateDataDocument(
type="private_data_document",
proj_key=data_source.source.proj_key,
index_key=data_source.source.index_key,
document_hash=data_source.document_hash,
)
elif isinstance(data_source, SemIngestPrivateDataCollectionSource):
api_src_data = _APISemanticIngestSourcePrivateDataCollection(
elif isinstance(data_source, PrivateDataCollectionSource):
api_src_data = SemanticIngestSourcePrivateDataCollection(
type="private_data_collection",
proj_key=data_source.source.proj_key,
index_key=data_source.source.index_key,
)
else:
raise RuntimeError("Unknown data source format for ingest_for_qa")
req_params = SemanticIngestReqParams(
skip_ingested_docs=skip_ingested_docs,
raise RuntimeError("Unknown data source format for semantic_ingest")

semantic_ingest_request = SemanticIngestRequest(
source=Source1(
actual_instance=api_src_data,
),
parameters=SemanticIngestReqParams(
skip_ingested_docs=skip_ingested_docs,
),
)
task: Task = self.semantic_api.ingest(
task = self.semantic_api.ingest(
proj_key=proj_key,
body=SemanticIngestRequest(
source=api_src_data.dict(),
parameters=req_params.to_dict(),
),
semantic_ingest_request=semantic_ingest_request,
)

return task

def generate_url(
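A hedged usage sketch of the renamed source models and the reworked semantic_ingest call. The placeholder keys and document hash are made up, and both CpsApi.from_env() and the api.documents accessor are assumptions not shown in this diff.

from deepsearch.cps.client import CpsApi
from deepsearch.cps.client.components.data_indices import (
    ElasticProjectDataCollectionSource,
)
from deepsearch.cps.client.components.documents import PrivateDataDocumentSource

api = CpsApi.from_env()  # assumed constructor; not part of this diff

# Ingest a single document from a private project data index into the
# semantic index; with the v2 API the call returns a CpsTask.
task = api.documents.semantic_ingest(
    project="<proj_key>",
    data_source=PrivateDataDocumentSource(
        source=ElasticProjectDataCollectionSource(
            proj_key="<proj_key>",
            index_key="<index_key>",
        ),
        document_hash="<doc_hash>",
    ),
    skip_ingested_docs=True,
)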
13 changes: 7 additions & 6 deletions deepsearch/cps/client/components/projects.py
@@ -2,7 +2,7 @@

from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Literal, Optional, Union

from pydantic.v1 import BaseModel

@@ -89,12 +89,13 @@ class Project:


class SemanticBackendResource(BaseModel):
type: Literal["semantic_backend_genai_runner"] = "semantic_backend_genai_runner"
proj_key: str
index_key: str

def to_resource(self):
return {
"type": "semantic_backend_genai_runner",
"proj_key": self.proj_key,
"index_key": self.index_key,
}
return self.dict()


class SemanticBackendPublicResource(SemanticBackendResource):
elastic_id: str
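As a small illustration of the projects.py change: with type now a Literal field on the model, to_resource() can simply return self.dict(). The key values below are placeholders.

from deepsearch.cps.client.components.projects import SemanticBackendResource

res = SemanticBackendResource(proj_key="<proj_key>", index_key="<index_key>")
print(res.to_resource())
# {'type': 'semantic_backend_genai_runner', 'proj_key': '<proj_key>', 'index_key': '<index_key>'}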