Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[omm] Seed data fixups #1506

Merged
merged 2 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 5 additions & 32 deletions open-media-match/src/OpenMediaMatch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from OpenMediaMatch.persistence import get_storage
from OpenMediaMatch.blueprints import development, hashing, matching, curation, ui
from OpenMediaMatch.storage.interface import BankConfig
from OpenMediaMatch.utils import dev_utils


def _is_debug_mode():
Expand Down Expand Up @@ -177,18 +178,9 @@ def site_map():
return routes

@app.cli.command("seed")
def seed_data():
"""Insert plausible-looking data into the database layer"""
from threatexchange.signal_type.pdq.signal import PdqSignal

bank_name = "SEED_BANK"

storage = get_storage()
storage.bank_update(BankConfig(name=bank_name, matching_enabled_ratio=1.0))

for st in (PdqSignal, VideoMD5Signal):
for example in st.get_examples():
storage.bank_add_content(bank_name, {st.get_name(): example})
def seed_data() -> None:
"""Add sample data API connection"""
dev_utils.seed_sample()

@app.cli.command("big-seed")
@click.option("-b", "--banks", default=100, show_default=True)
Expand All @@ -198,26 +190,7 @@ def seed_enourmous(banks: int, seeds: int) -> None:
Seed the database with a large number of banks and hashes
It will generate n banks and put n/m hashes on each bank
"""
storage = get_storage()

types: list[t.Type[CanGenerateRandomSignal]] = [PdqSignal, VideoMD5Signal]

for i in range(banks):
# create bank
bank = BankConfig(name=f"SEED_BANK_{i}", matching_enabled_ratio=1.0)
storage.bank_update(bank, create=True)

# Add hashes
for _ in range(seeds // banks):
# grab randomly either PDQ or MD5 signal
signal_type = random.choice(types)
random_hash = signal_type.get_random_signal()

storage.bank_add_content(
bank.name, {t.cast(t.Type[SignalType], signal_type): random_hash}
)

print("Finished adding hashes to", bank.name)
dev_utils.seed_banks_random(banks, seeds)

@app.cli.command("fetch")
def fetch():
Expand Down
13 changes: 13 additions & 0 deletions open-media-match/src/OpenMediaMatch/blueprints/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from OpenMediaMatch.blueprints import matching, curation, hashing
from OpenMediaMatch.persistence import get_storage
from OpenMediaMatch.utils import dev_utils
from OpenMediaMatch.storage.postgres.flask_utils import reset_tables
from OpenMediaMatch.storage.postgres.database import db
from OpenMediaMatch.utils.time_utils import duration_to_human_str
Expand Down Expand Up @@ -124,6 +125,18 @@ def upload():
return {"hashes": signals, "banks": sorted(banks)}


@bp.route("/seed_sample", methods=["POST"])
def seed_sample():
dev_utils.seed_sample()
return redirect("./")


@bp.route("/seed_banks", methods=["POST"])
def seed_banks():
dev_utils.seed_banks_random()
return redirect("./")


@bp.route("/factory_reset", methods=["POST"])
def factory_reset():
reset_tables()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
<div class="alert alert-primary" role="alert">
<div class="d-flex flex-row align-items-center">
<div class="me-2">Running server in development mode!</div>
<div class="me-2">
<form action="/ui/factory_reset" method="post" enctype="multipart/form-data">
<button type="submit" class="btn btn-danger">Factory Reset</button>
<div class="hstack gap-2">
<form action="/ui/seed_sample" method="post">
<button type="submit" class="btn btn-primary">Seed Sample API</button>
</form>
<form action="/ui/seed_banks" method="post">
<button type="submit" class="btn btn-primary">Seed Banks</button>
</form>
<form action="/ui/factory_reset" method="post">
<button type="submit" class="btn btn-outline-danger">Factory Reset</button>
</form>
</div>
</div>
</div>
</div>
</div>
48 changes: 48 additions & 0 deletions open-media-match/src/OpenMediaMatch/utils/dev_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import typing as t

from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.exchanges.collab_config import CollaborationConfigBase
from threatexchange.exchanges.impl.static_sample import StaticSampleSignalExchangeAPI
from threatexchange.signal_type.signal_base import SignalType, CanGenerateRandomSignal

from OpenMediaMatch import persistence
from OpenMediaMatch.storage.interface import BankConfig


def seed_sample() -> None:
storage = persistence.get_storage()
storage.exchange_update(
CollaborationConfigBase(
name="SEED_SAMPLE",
api=StaticSampleSignalExchangeAPI.get_name(),
enabled=True,
),
create=True,
)


def seed_banks_random(banks: int = 2, seeds: int = 10000) -> None:
"""
Seed the database with a large number of banks and hashes
It will generate n banks and put n/m hashes on each bank
"""
storage = persistence.get_storage()

types: list[t.Type[CanGenerateRandomSignal]] = [PdqSignal, VideoMD5Signal]

for i in range(banks):
# create bank
bank = BankConfig(name=f"SEED_BANK_{i}", matching_enabled_ratio=1.0)
storage.bank_update(bank, create=True)

# Add hashes
for i in range(seeds // banks):
signal_type = types[i % len(types)]
random_hash = signal_type.get_random_signal()

storage.bank_add_content(
bank.name, {t.cast(t.Type[SignalType], signal_type): random_hash}
)
Loading