Skip to content

Commit

Permalink
Feat:remove estimation of embedding cost (#7950)
Browse files Browse the repository at this point in the history
Co-authored-by: jyong <[email protected]>
  • Loading branch information
JzoNgKVO and JohnJyong authored Sep 4, 2024
1 parent 83e8486 commit 14af875
Show file tree
Hide file tree
Showing 14 changed files with 122 additions and 162 deletions.
41 changes: 2 additions & 39 deletions api/core/indexing_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
Expand Down Expand Up @@ -255,11 +253,8 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
tokens = 0
preview_texts = []
total_segments = 0
total_price = 0
currency = 'USD'
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
all_text_docs = []
Expand All @@ -286,54 +281,22 @@ def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSettin
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if indexing_technique == 'high_quality' or embedding_model_instance:
tokens += embedding_model_instance.get_text_embedding_num_tokens(
texts=[self.filter_string(document.page_content)]
)

if doc_form and doc_form == 'qa_model':
model_instance = self.model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM
)

model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
price_info = model_type_instance.get_price(
model=model_instance.model,
credentials=model_instance.credentials,
price_type=PriceType.INPUT,
tokens=total_segments * 2000,
)

return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(price_info.total_amount),
"currency": price_info.currency,
"qa_preview": document_qa_list,
"preview": preview_texts
}
if embedding_model_instance:
embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
embedding_price_info = embedding_model_type_instance.get_price(
model=embedding_model_instance.model,
credentials=embedding_model_instance.credentials,
price_type=PriceType.INPUT,
tokens=tokens
)
total_price = '{:f}'.format(embedding_price_info.total_amount)
currency = embedding_price_info.currency
return {
"total_segments": total_segments,
"tokens": tokens,
"total_price": total_price,
"currency": currency,
"preview": preview_texts
}

Expand Down
23 changes: 15 additions & 8 deletions api/core/rag/splitter/text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,17 @@ def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
else:
return text

def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
def _merge_splits(self, splits: Iterable[str], separator: str, lengths: list[int]) -> list[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
separator_len = self._length_function(separator)

docs = []
current_doc: list[str] = []
total = 0
index = 0
for d in splits:
_len = self._length_function(d)
_len = lengths[index]
if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
Expand Down Expand Up @@ -145,6 +146,7 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
index += 1
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
Expand Down Expand Up @@ -493,11 +495,10 @@ def __init__(
self._separators = separators or ["\n\n", "\n", " ", ""]

def _split_text(self, text: str, separators: list[str]) -> list[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = separators[-1]
new_separators = []

for i, _s in enumerate(separators):
if _s == "":
separator = _s
Expand All @@ -508,25 +509,31 @@ def _split_text(self, text: str, separators: list[str]) -> list[str]:
break

splits = _split_text_with_regex(text, separator, self._keep_separator)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
_separator = "" if self._keep_separator else separator

for s in splits:
if self._length_function(s) < self._chunk_size:
s_len = self._length_function(s)
if s_len < self._chunk_size:
_good_splits.append(s)
_good_splits_lengths.append(s_len)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
final_chunks.extend(merged_text)
_good_splits = []
_good_splits_lengths = []
if not new_separators:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)

if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
final_chunks.extend(merged_text)

return final_chunks

def split_text(self, text: str) -> list[str]:
Expand Down
7 changes: 3 additions & 4 deletions api/services/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,6 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun

DocumentService.check_documents_upload_quota(count, features)

embedding_model = None
dataset_collection_binding_id = None
retrieval_model = None
if document_data["indexing_technique"] == "high_quality":
Expand Down Expand Up @@ -1082,10 +1081,10 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun
tenant_id=tenant_id,
name="",
data_source_type=document_data["data_source"]["type"],
indexing_technique=document_data["indexing_technique"],
indexing_technique=document_data.get("indexing_technique", "high_quality"),
created_by=account.id,
embedding_model=embedding_model.model if embedding_model else None,
embedding_model_provider=embedding_model.provider if embedding_model else None,
embedding_model=document_data.get("embedding_model"),
embedding_model_provider=document_data.get("embedding_model_provider"),
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ import { FileSearch02 } from '@/app/components/base/icons/src/vender/solid/files
import { useProviderContext } from '@/context/provider-context'
import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import {
DEFAULT_WEIGHTED_SCORE,
RerankingModeEnum,
WeightedScoreEnum,
} from '@/models/datasets'

type Props = {
value: RetrievalConfig
Expand All @@ -32,6 +37,18 @@ const RetrievalMethodConfig: FC<Props> = ({
reranking_provider_name: rerankDefaultModel?.provider.provider || '',
reranking_model_name: rerankDefaultModel?.model || '',
},
reranking_mode: passValue.reranking_mode || (rerankDefaultModel ? RerankingModeEnum.RerankingModel : RerankingModeEnum.WeightedScore),
weights: passValue.weights || {
weight_type: WeightedScoreEnum.Customized,
vector_setting: {
vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic,
embedding_provider_name: '',
embedding_model_name: '',
},
keyword_setting: {
keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword,
},
},
}
}
return passValue
Expand Down
30 changes: 2 additions & 28 deletions web/app/components/datasets/create/embedding-process/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ import cn from '@/utils/classnames'
import { FieldInfo } from '@/app/components/datasets/documents/detail/metadata'
import Button from '@/app/components/base/button'
import type { FullDocumentDetail, IndexingStatusResponse, ProcessRuleResponse } from '@/models/datasets'
import { formatNumber } from '@/utils/format'
import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchIndexingEstimateBatch, fetchProcessRule } from '@/service/datasets'
import { fetchIndexingStatusBatch as doFetchIndexingStatus, fetchProcessRule } from '@/service/datasets'
import { DataSourceType } from '@/models/datasets'
import NotionIcon from '@/app/components/base/notion-icon'
import PriorityLabel from '@/app/components/billing/priority-label'
Expand Down Expand Up @@ -142,14 +141,6 @@ const EmbeddingProcess: FC<Props> = ({ datasetId, batchId, documents = [], index
}, apiParams => fetchProcessRule(omit(apiParams, 'action')), {
revalidateOnFocus: false,
})
// get cost
const { data: indexingEstimateDetail } = useSWR({
action: 'fetchIndexingEstimateBatch',
datasetId,
batchId,
}, apiParams => fetchIndexingEstimateBatch(omit(apiParams, 'action')), {
revalidateOnFocus: false,
})

const router = useRouter()
const navToDocumentList = () => {
Expand Down Expand Up @@ -190,28 +181,11 @@ const EmbeddingProcess: FC<Props> = ({ datasetId, batchId, documents = [], index

return (
<>
<div className='h-5 flex justify-between items-center mb-5'>
<div className='h-5 flex items-center mb-5'>
<div className={s.embeddingStatus}>
{isEmbedding && t('datasetDocuments.embedding.processing')}
{isEmbeddingCompleted && t('datasetDocuments.embedding.completed')}
</div>
<div className={s.cost}>
{indexingType === 'high_quality' && (
<div className='flex items-center'>
<div className={cn(s.commonIcon, s.highIcon)} />
{t('datasetDocuments.embedding.highQuality')} · {t('datasetDocuments.embedding.estimate')}
<span className={s.tokens}>{formatNumber(indexingEstimateDetail?.tokens || 0)}</span>tokens
(<span className={s.price}>${formatNumber(indexingEstimateDetail?.total_price || 0)}</span>)
</div>
)}
{indexingType === 'economy' && (
<div className='flex items-center'>
<div className={cn(s.commonIcon, s.economyIcon)} />
{t('datasetDocuments.embedding.economy')} · {t('datasetDocuments.embedding.estimate')}
<span className={s.tokens}>0</span>tokens
</div>
)}
</div>
</div>
{
enableBilling && plan.type !== Plan.team && (
Expand Down
8 changes: 2 additions & 6 deletions web/app/components/datasets/create/step-two/index.module.css
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
}

.indexItem {
min-height: 146px;
min-height: 126px;
}

.indexItem .disableMask {
Expand Down Expand Up @@ -121,10 +121,6 @@
@apply pb-1;
}

.radioItem.indexItem .typeHeader .tip {
@apply pb-3;
}

.radioItem .typeIcon {
position: absolute;
top: 18px;
Expand Down Expand Up @@ -264,7 +260,7 @@
}

.input {
@apply inline-flex h-9 w-full py-1 px-2 rounded-lg text-xs leading-normal;
@apply inline-flex h-9 w-full py-1 px-2 pr-14 rounded-lg text-xs leading-normal;
@apply bg-gray-100 caret-primary-600 hover:bg-gray-100 focus:ring-1 focus:ring-inset focus:ring-gray-200 focus-visible:outline-none focus:bg-white placeholder:text-gray-400;
}

Expand Down
Loading

0 comments on commit 14af875

Please sign in to comment.