Skip to content

Commit

Permalink
Auto-moderation & remaining v1 fixes (#1089)
Browse files Browse the repository at this point in the history
* add an expiry date for tasks with periodic removal of expired tasks; fix sibling ranking counts when purging a user's messages

* add auto-moderation feature

* fix doc strings

* fix bad message query

* add debug log on insert message

* fix >= comparison
  • Loading branch information
andreaskoepf authored Feb 3, 2023
1 parent 323432c commit ad6c39b
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 6 deletions.
10 changes: 9 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.models import message_tree_state
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
from oasst_backend.prompt_repository import PromptRepository, UserRepository
from oasst_backend.task_repository import TaskRepository, delete_expired_tasks
from oasst_backend.tree_manager import TreeManager
from oasst_backend.user_repository import User
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
Expand Down Expand Up @@ -318,6 +319,13 @@ def update_user_streak(session: Session) -> None:
return


@app.on_event("startup")
@repeat_every(seconds=3600)  # 60 * 60: once per hour
@managed_tx_function(auto_commit=CommitMode.COMMIT)
def cronjob_delete_expired_tasks(session: Session) -> None:
    """Periodic job that removes expired tasks from the database.

    The function is registered as a ``startup`` event handler and wrapped
    with ``repeat_every`` using a one-hour interval. The ``session`` argument
    is expected to be supplied by ``managed_tx_function``
    (NOTE(review): presumably it also commits on success because of
    ``CommitMode.COMMIT`` -- confirm against ``database_utils``).

    The number of deleted rows returned by ``delete_expired_tasks`` is
    intentionally discarded; that helper already logs it.
    """
    delete_expired_tasks(session)


app.include_router(api_router, prefix=settings.API_V1_STR)


Expand Down
19 changes: 19 additions & 0 deletions backend/oasst_backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,29 @@ class TreeManagerConfiguration(BaseModel):
goal_tree_size: int = 12
"""Total number of messages to gather per tree."""

random_goal_tree_size: bool = False
"""If set to true goal tree sizes will be generated randomly within range [min_goal_tree_size, goal_tree_size]."""

min_goal_tree_size: int = 5
"""Minimum tree size for random goal sizes."""

num_reviews_initial_prompt: int = 3
"""Number of peer review checks to collect in INITIAL_PROMPT_REVIEW state."""

num_reviews_reply: int = 3
"""Number of peer review checks to collect per reply (other than initial_prompt)."""

auto_mod_enabled: bool = True
"""Flag to enable/disable auto moderation."""

auto_mod_max_skip_reply: int = 25
"""Automatically set tree state to `halted_by_moderator` when at least the specified number
of users skip replying to a message. (auto moderation)"""

auto_mod_red_flags: int = 3
"""Delete messages that receive at least this number of red flags if they are replies, or
set the tree to `aborted_low_grade` when the flagged message is a prompt. (auto moderation)"""

p_full_labeling_review_prompt: float = 1.0
"""Probability of full text-labeling (instead of mandatory only) for initial prompts."""

Expand Down Expand Up @@ -222,6 +239,8 @@ def validate_user_stats_intervals(cls, v: int):
RATE_LIMIT_TASK_API_TIMES: int = 10_000
RATE_LIMIT_TASK_API_MINUTES: int = 1

TASK_VALIDITY_MINUTES: int = 60 * 24 * 2 # tasks expire after 2 days

class Config:
env_file = ".env"
env_file_encoding = "utf-8"
Expand Down
6 changes: 4 additions & 2 deletions backend/oasst_backend/prompt_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,6 @@ def insert_message(
review_result=review_result,
)
self.db.add(message)

# self.db.refresh(message)
return message

def _validate_task(
Expand Down Expand Up @@ -288,6 +286,10 @@ def store_text_reply(
task.done = True
self.db.add(task)
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
logger.debug(
f"Inserted message id={user_message.id}, tree={user_message.message_tree_id}, user_id={user_message.user_id}, "
f"text[:100]='{user_message.text[:100]}', role='{user_message.role}', lang='{user_message.lang}'"
)
return user_message

@managed_tx_method(CommitMode.FLUSH)
Expand Down
24 changes: 22 additions & 2 deletions backend/oasst_backend/task_repository.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from datetime import timedelta
from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID

import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.models import ApiClient, Task
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.user_repository import UserRepository
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session, func, or_
from oasst_shared.utils import utcnow
from sqlmodel import Session, delete, func, or_
from starlette.status import HTTP_404_NOT_FOUND


Expand All @@ -24,6 +26,13 @@ def validate_frontend_message_id(message_id: str) -> None:
raise OasstError("message_id must not be empty", OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)


def delete_expired_tasks(session: Session) -> int:
    """Delete all tasks whose ``expiry_date`` is earlier than the current UTC time.

    The statement is executed on the given session; no commit is issued
    here, so transaction handling is left to the caller.

    Args:
        session: Database session used to execute the DELETE statement.

    Returns:
        The ``rowcount`` reported for the executed DELETE statement.
    """
    is_expired = Task.expiry_date < utcnow()
    outcome = session.exec(delete(Task).where(is_expired))
    num_deleted = outcome.rowcount
    logger.info(f"Deleted {num_deleted} expired tasks.")
    return num_deleted


class TaskRepository:
def __init__(
self,
Expand Down Expand Up @@ -118,12 +127,18 @@ def store_task(
case _:
raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)

if not collective and settings.TASK_VALIDITY_MINUTES > 0:
expiry_date = utcnow() + timedelta(minutes=settings.TASK_VALIDITY_MINUTES)
else:
expiry_date = None

task_model = self.insert_task(
payload=payload,
id=task.id,
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
expiry_date=expiry_date,
)
assert task_model.id == task.id
return task_model
Expand Down Expand Up @@ -175,6 +190,7 @@ def insert_task(
message_tree_id: UUID = None,
parent_message_id: UUID = None,
collective: bool = False,
expiry_date: datetime = None,
) -> Task:
c = PayloadContainer(payload=payload)
task = Task(
Expand All @@ -186,6 +202,7 @@ def insert_task(
message_tree_id=message_tree_id,
parent_message_id=parent_message_id,
collective=collective,
expiry_date=expiry_date,
)
logger.debug(f"inserting {task=}")
self.db.add(task)
Expand Down Expand Up @@ -218,3 +235,6 @@ def fetch_recent_reply_tasks(
if limit:
qry = qry.limit(limit)
return qry.all()

def delete_expired_tasks(self) -> int:
    """Delete expired tasks through this repository's database session.

    Thin convenience wrapper that forwards ``self.db`` to the module-level
    ``delete_expired_tasks`` function.

    Returns:
        The value returned by the module-level function (its deleted row count).
    """
    session = self.db
    return delete_expired_tasks(session)
92 changes: 91 additions & 1 deletion backend/oasst_backend/tree_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import pydantic
import sqlalchemy as sa
from fastapi.encoders import jsonable_encoder
from loguru import logger
from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversation_message_list
Expand All @@ -31,6 +32,7 @@
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.utils import utcnow
from sqlalchemy.sql.functions import coalesce
from sqlmodel import Session, and_, func, not_, or_, text, update


Expand Down Expand Up @@ -269,13 +271,39 @@ def _prompt_lottery(self, lang: str) -> int:
self._enter_state(mts, message_tree_state.State.GROWING)
self.db.flush()

def _auto_moderation(self, lang: str) -> None:
    """Run the automatic moderation pass for messages of one language.

    Does nothing when ``cfg.auto_mod_enabled`` is false. Otherwise every
    message returned by ``query_moderation_bad_messages`` is checked twice:

    * Red flags: when the ``red_flag`` emoji count is at least
      ``cfg.auto_mod_red_flags``, a message without parent (initial prompt)
      moves its tree into the low-grade state, while any other message is
      marked deleted recursively.
    * Skipped replies: when the ``skip_reply`` emoji count is at least
      ``cfg.auto_mod_max_skip_reply``, the tree is halted via ``halt_tree``.

    The two checks are independent, so both may trigger for one message.

    Args:
        lang: Language tag forwarded to ``query_moderation_bad_messages``.
    """
    cfg = self.cfg
    if not cfg.auto_mod_enabled:
        return

    red_flag_key = protocol_schema.EmojiCode.red_flag
    skip_reply_key = protocol_schema.EmojiCode.skip_reply

    # NOTE: the loop variable must stay `m`; the `{m.id=}` f-strings below
    # embed the expression text in the emitted log message.
    for m in self.query_moderation_bad_messages(lang=lang):
        red_flags = m.emojis.get(red_flag_key)
        too_many_red_flags = red_flags is not None and red_flags >= cfg.auto_mod_red_flags

        if too_many_red_flags and m.parent_id is None:
            logger.warning(
                f"[AUTO MOD] Halting tree {m.message_tree_id}, inital prompt got too many red flags ({m.emojis})."
            )
            self.enter_low_grade_state(m.message_tree_id)
        elif too_many_red_flags:
            logger.warning(f"[AUTO MOD] Deleting message {m.id=}, it received too many red flags ({m.emojis}).")
            self.pr.mark_messages_deleted(m.id, recursive=True)

        # Read the skip count only after the red-flag handling, as before.
        skipped = m.emojis.get(skip_reply_key)
        if skipped is None or skipped < cfg.auto_mod_max_skip_reply:
            continue

        logger.warning(
            f"[AUTO MOD] Halting tree {m.message_tree_id} due to high skip-reply count of message {m.id=} ({m.emojis})."
        )
        self.halt_tree(m.id, halt=True)

def determine_task_availability(self, lang: str) -> dict[protocol_schema.TaskRequestType, int]:
self.pr.ensure_user_is_enabled()

if not lang:
lang = "en"
logger.warning("Task availability request without lang tag received, assuming lang='en'.")

self._auto_moderation(lang=lang)
num_missing_prompts = self._prompt_lottery(lang=lang)
extendible_parents, _ = self.query_extendible_parents(lang=lang)
prompts_need_review = self.query_prompts_need_review(lang=lang)
Expand Down Expand Up @@ -313,6 +341,7 @@ def next_task(
lang = "en"
logger.warning("Task request without lang tag received, assuming 'en'.")

self._auto_moderation(lang=lang)
num_missing_prompts = self._prompt_lottery(lang=lang)

prompts_need_review = self.query_prompts_need_review(lang=lang)
Expand Down Expand Up @@ -1254,6 +1283,37 @@ def query_reviews_for_message(self, message_id: UUID) -> list[TextLabels]:
)
return qry.all()

def query_moderation_bad_messages(self, lang: str) -> list[Message]:
    """Fetch messages that have reached an auto-moderation threshold.

    A message is returned when all of the following hold:

    * its tree state row is active and in ``INITIAL_PROMPT_REVIEW`` or
      ``GROWING``;
    * it is an initial prompt (no parent), has a truthy ``review_result``,
      or is a reply with fewer than ``cfg.num_reviews_reply`` reviews;
    * it is not marked deleted;
    * its ``red_flag`` emoji count is at least ``cfg.auto_mod_red_flags`` or
      its ``skip_reply`` emoji count is at least
      ``cfg.auto_mod_max_skip_reply`` (missing counts are treated as 0).

    Args:
        lang: Language tag to restrict the result to; ``None`` disables the
            language filter.

    Returns:
        List of matching ``Message`` rows.
    """
    states = message_tree_state.State
    emoji_codes = protocol_schema.EmojiCode

    tree_in_moderated_state = or_(
        MessageTreeState.state == states.INITIAL_PROMPT_REVIEW,
        MessageTreeState.state == states.GROWING,
    )
    message_is_candidate = or_(
        Message.parent_id.is_(None),
        Message.review_result,
        and_(Message.parent_id.is_not(None), Message.review_count < self.cfg.num_reviews_reply),
    )

    # Emoji counters live in a JSON column; absent keys count as zero.
    red_flag_count = coalesce(Message.emojis[emoji_codes.red_flag].cast(sa.Integer), 0)
    skip_reply_count = coalesce(Message.emojis[emoji_codes.skip_reply].cast(sa.Integer), 0)
    threshold_reached = or_(
        red_flag_count >= self.cfg.auto_mod_red_flags,
        skip_reply_count >= self.cfg.auto_mod_max_skip_reply,
    )

    qry = (
        self.db.query(Message)
        .select_from(MessageTreeState)
        .join(Message, MessageTreeState.message_tree_id == Message.message_tree_id)
        .filter(
            MessageTreeState.active,
            tree_in_moderated_state,
            message_is_candidate,
            not_(Message.deleted),
            threshold_reached,
        )
    )

    if lang is not None:
        qry = qry.filter(Message.lang == lang)

    return qry.all()

@managed_tx_method(CommitMode.FLUSH)
def _insert_tree_state(
self,
Expand Down Expand Up @@ -1281,10 +1341,17 @@ def _insert_default_state(
self,
root_message_id: UUID,
state: message_tree_state.State = message_tree_state.State.INITIAL_PROMPT_REVIEW,
*,
goal_tree_size: int = None,
) -> MessageTreeState:
if goal_tree_size is None:
if self.cfg.random_goal_tree_size and self.cfg.min_goal_tree_size < self.cfg.goal_tree_size:
goal_tree_size = random.randint(self.cfg.min_goal_tree_size, self.cfg.goal_tree_size)
else:
goal_tree_size = self.cfg.goal_tree_size
return self._insert_tree_state(
root_message_id=root_message_id,
goal_tree_size=self.cfg.goal_tree_size,
goal_tree_size=goal_tree_size,
max_depth=self.cfg.max_tree_depth,
max_children_count=self.cfg.max_children_count,
state=state,
Expand Down Expand Up @@ -1379,9 +1446,32 @@ def _purge_message_internal(self, message_id: UUID) -> None:
DELETE FROM task t WHERE t.parent_message_id = :message_id;
DELETE FROM message WHERE id = :message_id;
"""
parent_id = self.pr.fetch_message(message_id=message_id).parent_id
r = self.db.execute(text(sql_purge_message), {"message_id": message_id})
logger.debug(f"purge_message({message_id=}): {r.rowcount} rows.")

sql_update_ranking_counts = """
WITH r AS (
-- find ranking results and count per child
SELECT c.id,
count(*) FILTER (
WHERE mr.payload#>'{payload, ranked_message_ids}' ? CAST(c.id AS varchar)
) AS ranking_count
FROM message c
LEFT JOIN message_reaction mr ON mr.payload_type = 'RankingReactionPayload'
AND mr.message_id = c.parent_id
WHERE c.parent_id = :parent_id
GROUP BY c.id
)
UPDATE message m SET ranking_count = r.ranking_count
FROM r WHERE m.id = r.id AND m.ranking_count != r.ranking_count;
"""

if parent_id is not None:
# update ranking counts of remaining children
r = self.db.execute(text(sql_update_ranking_counts), {"parent_id": parent_id})
logger.debug(f"ranking_count updated for {r.rowcount} rows.")

def purge_message_tree(self, message_tree_id: UUID) -> None:
sql_purge_message_tree = """
DELETE FROM journal j USING message m WHERE j.message_id = m.Id AND m.message_tree_id = :message_tree_id;
Expand Down

0 comments on commit ad6c39b

Please sign in to comment.