Skip to content

Commit

Permalink
Ensure vespa_doc_id is unique in EmbeddingsReference
Browse files Browse the repository at this point in the history
- Add unique constraint to `vespa_doc_id` field
- Update `bulk_create` in `get_or_create_embedded_file` to handle `vespa_doc_id` conflicts
- Implement migration to deduplicate existing `vespa_doc_id` entries
- Modify `create_embeddings_in_search_db` to maintain unique `vespa_doc_id` references

This change enforces uniqueness of the Vespa document ID (`vespa_doc_id`) on `EmbeddingsReference` rows at the database level, preventing duplicate reference entries and ensuring data integrity.
  • Loading branch information
devxpy committed Dec 24, 2024
1 parent 6e72eed commit 6d0637d
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 5 deletions.
15 changes: 10 additions & 5 deletions daras_ai_v2/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,13 @@ def get_or_create_embedded_file(
)[0]
for ref in refs:
ref.embedded_file = embedded_file
EmbeddingsReference.objects.bulk_create(refs)
return embedded_file
EmbeddingsReference.objects.bulk_create(
refs,
update_conflicts=True,
update_fields=["url", "title", "snippet", "updated_at"],
unique_fields=["vespa_doc_id"],
)
return embedded_file


def create_embeddings_in_search_db(
Expand All @@ -456,7 +461,7 @@ def create_embeddings_in_search_db(
embedding_model: EmbeddingModels,
is_user_url: bool,
) -> list[EmbeddingsReference]:
refs = []
refs = {}
vespa = get_vespa_app()
for ref, embedding in get_embeds_for_doc(
f_url=f_url,
Expand All @@ -475,7 +480,7 @@ def create_embeddings_in_search_db(
title=ref["title"],
snippet=ref["snippet"],
)
refs.append(db_ref)
refs[db_ref.vespa_doc_id] = db_ref
vespa.feed_data_point(
schema=settings.VESPA_SCHEMA,
data_id=doc_id,
Expand All @@ -488,7 +493,7 @@ def create_embeddings_in_search_db(
),
operation_type="feed",
)
return refs
return list(refs.values())


def get_embeds_for_doc(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Generated by Django 5.1.3 on 2024-12-24 07:06

from django.db import migrations, models


def deduplicate_vespa_doc_id(apps, schema_editor):
    """Delete all but the newest EmbeddingsReference row per vespa_doc_id.

    Data-migration callable for ``migrations.RunPython``. It works on the
    historical model obtained from ``apps`` and on the database connection
    the migration is currently being applied to.

    For every ``vespa_doc_id`` that occurs on more than one row, the row
    with the greatest ``updated_at`` is kept and the others are deleted.
    """
    ref_model = apps.get_model("embeddings", "EmbeddingsReference")
    manager = ref_model.objects.using(schema_editor.connection.alias)

    # Count rows per vespa_doc_id and keep only the ids that occur repeatedly.
    per_doc_counts = manager.values_list("vespa_doc_id", flat=True).annotate(
        count=models.Count("id")
    )
    repeated_ids = per_doc_counts.filter(count__gt=1)

    for repeated_id in repeated_ids:
        # Newest first, so the first row of the group is the one to keep.
        group = manager.filter(vespa_doc_id=repeated_id).order_by("-updated_at")
        survivor = group.first()

        # Remove every other row that shares this vespa_doc_id.
        group.exclude(pk=survivor.pk).delete()


class Migration(migrations.Migration):
    """Make ``EmbeddingsReference.vespa_doc_id`` unique.

    Runs ``deduplicate_vespa_doc_id`` first, then alters the field to add
    ``unique=True``, so rows that already share a ``vespa_doc_id`` are
    removed before the constraint is created.
    """

    dependencies = [
        (
            "embeddings",
            "0002_embeddedfile_created_by_embeddedfile_last_query_at_and_more",
        ),
    ]

    operations = [
        # The deleted duplicates cannot be restored, so the reverse step is a
        # no-op. Supplying it keeps the migration reversible: unapplying it
        # simply drops the unique constraint via the AlterField below.
        migrations.RunPython(
            deduplicate_vespa_doc_id,
            reverse_code=migrations.RunPython.noop,
        ),
        migrations.AlterField(
            model_name="embeddingsreference",
            name="vespa_doc_id",
            field=models.CharField(
                help_text="The Document ID of this embedding in Vespa. A hash of the file metadata + the split snippet.",
                max_length=256,
                unique=True,
            ),
        ),
    ]
1 change: 1 addition & 0 deletions embeddings/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class EmbeddingsReference(models.Model):
vespa_doc_id = models.CharField(
max_length=256,
help_text="The Document ID of this embedding in Vespa. A hash of the file metadata + the split snippet.",
unique=True,
)
url = CustomURLField()
title = models.TextField()
Expand Down
18 changes: 18 additions & 0 deletions usage_costs/migrations/0023_alter_modelpricing_model_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 5.1.3 on 2024-12-24 07:06

from django.db import migrations, models


# Schema migration for the `usage_costs` app: re-declares the `model_name`
# field of the `modelpricing` model as a CharField with the choices,
# help text and max_length listed below.
class Migration(migrations.Migration):

# Must be applied after usage_costs migration 0022.
dependencies = [
('usage_costs', '0022_alter_modelpricing_model_name'),
]

operations = [
# NOTE(review): several keys (e.g. 'sd_2', 'dall_e', 'jack_qiao',
# 'openjourney_2') appear more than once in `choices`, some with different
# labels — presumably the list is concatenated from several model enums;
# confirm against the ModelPricing model definition.
migrations.AlterField(
model_name='modelpricing',
name='model_name',
field=models.CharField(choices=[('o1_preview', 'o1-preview (openai)'), ('o1_mini', 'o1-mini (openai)'), ('gpt_4_o', 'GPT-4o (openai)'), ('gpt_4_o_mini', 'GPT-4o-mini (openai)'), ('chatgpt_4_o', 'ChatGPT-4o (openai) 🧪'), ('gpt_4_turbo_vision', 'GPT-4 Turbo with Vision (openai)'), ('gpt_4_vision', 'GPT-4 Vision (openai) 🔻'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai) 🔻'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('gpt_3_5_turbo_instruct', 'GPT-3.5 Instruct (openai) 🔻'), ('llama3_3_70b', 'Llama 3.3 70B'), ('llama3_2_90b_vision', 'Llama 3.2 90B + Vision (Meta AI)'), ('llama3_2_11b_vision', 'Llama 3.2 11B + Vision (Meta AI)'), ('llama3_2_3b', 'Llama 3.2 3B (Meta AI)'), ('llama3_2_1b', 'Llama 3.2 1B (Meta AI)'), ('llama3_1_70b', 'Llama 3.1 70B (Meta AI)'), ('llama3_1_8b', 'Llama 3.1 8B (Meta AI)'), ('llama3_70b', 'Llama 3 70B (Meta AI)'), ('llama3_8b', 'Llama 3 8B (Meta AI)'), ('mixtral_8x7b_instruct_0_1', 'Mixtral 8x7b Instruct v0.1 (Mistral)'), ('gemma_2_9b_it', 'Gemma 2 9B (Google)'), ('gemma_7b_it', 'Gemma 7B (Google)'), ('gemini_1_5_flash', 'Gemini 1.5 Flash (Google)'), ('gemini_1_5_pro', 'Gemini 1.5 Pro (Google)'), ('gemini_1_pro_vision', 'Gemini 1.0 Pro Vision (Google)'), ('gemini_1_pro', 'Gemini 1.0 Pro (Google)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('claude_3_5_sonnet', 'Claude 3.5 Sonnet (Anthropic)'), ('claude_3_opus', 'Claude 3 Opus [L] (Anthropic)'), ('claude_3_sonnet', 'Claude 3 Sonnet [M] (Anthropic)'), ('claude_3_haiku', 'Claude 3 Haiku [S] (Anthropic)'), ('afrollama_v1', 'AfroLlama3 v1 (Jacaranda)'), ('llama3_8b_cpt_sea_lion_v2_1_instruct', 'Llama3 8B CPT SEA-LIONv2.1 Instruct (aisingapore)'), ('sarvam_2b', 'Sarvam 2B (sarvamai)'), ('llama_3_groq_70b_tool_use', 'Llama 3 Groq 70b Tool Use [Deprecated]'), ('llama_3_groq_8b_tool_use', 'Llama 3 Groq 8b Tool Use [Deprecated]'), ('llama2_70b_chat', 'Llama 2 70B 
Chat [Deprecated] (Meta AI)'), ('sea_lion_7b_instruct', 'SEA-LION-7B-Instruct [Deprecated] (aisingapore)'), ('llama3_8b_cpt_sea_lion_v2_instruct', 'Llama3 8B CPT SEA-LIONv2 Instruct [Deprecated] (aisingapore)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 [Deprecated] (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 [Deprecated] (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)'), ('text_curie_001', 'Curie [Deprecated] (openai)'), ('text_babbage_001', 'Babbage [Deprecated] (openai)'), ('text_ada_001', 'Ada [Deprecated] (openai)'), ('protogen_2_2', 'Protogen V2.2 (darkstorm2150)'), ('epicdream', 'epiCDream (epinikion)'), ('flux_1_dev', 'FLUX.1 [dev]'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'DALL·E 2 (OpenAI)'), ('dall_e_3', 'DALL·E 3 (OpenAI)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero)'), ('openjourney', 'Open Journey (PromptHero)'), ('analog_diffusion', 'Analog Diffusion (wavymulder)'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('deepfloyd_if', 'DeepFloyd IF [Deprecated] (stability.ai)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('instruct_pix2pix', '✨ InstructPix2Pix (Tim Brooks)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero) 🐢'), ('openjourney', 'Open Journey (PromptHero) 🐢'), ('analog_diffusion', 'Analog Diffusion (wavymulder) 🐢'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150) 🐢'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), 
('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('runway_ml', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('wav2lip', 'LipSync (wav2lip)'), ('sadtalker', 'LipSync (sadtalker)')], help_text='The name of the model. Only used for Display purposes.', max_length=255),
),
]

0 comments on commit 6d0637d

Please sign in to comment.