Skip to content

Commit

Permalink
feat: regenerate in Chat, agent and Chatflow app (#7661)
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzuodong authored Sep 21, 2024
1 parent b32a771 commit 8c51d06
Show file tree
Hide file tree
Showing 51 changed files with 604 additions and 179 deletions.
1 change: 1 addition & 0 deletions api/constants/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
HIDDEN_VALUE = "[__HIDDEN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000"
1 change: 1 addition & 0 deletions api/controllers/console/app/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def post(self, app_model):
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("model_config", type=dict, required=True, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
args = parser.parse_args()
Expand Down
2 changes: 0 additions & 2 deletions api/controllers/console/app/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@ def get(self, app_model):
if rest_count > 0:
has_more = True

history_messages = list(reversed(history_messages))

return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)


Expand Down
2 changes: 2 additions & 0 deletions api/controllers/console/app/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def post(self, app_model: App):
parser.add_argument("query", type=str, required=True, location="json", default="")
parser.add_argument("files", type=list, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")

args = parser.parse_args()

try:
Expand Down
1 change: 1 addition & 0 deletions api/controllers/console/explore/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def post(self, installed_app):
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()

Expand Down
2 changes: 1 addition & 1 deletion api/controllers/console/explore/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get(self, installed_app):

try:
return MessageService.pagination_by_first_id(
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
app_model, current_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
Expand Down
1 change: 1 addition & 0 deletions api/controllers/service_api/app/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class MessageListApi(Resource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
Expand Down
1 change: 1 addition & 0 deletions api/controllers/web/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def post(self, app_model, end_user):
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")

args = parser.parse_args()
Expand Down
3 changes: 2 additions & 1 deletion api/controllers/web/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class MessageListApi(WebApiResource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
Expand Down Expand Up @@ -89,7 +90,7 @@ def get(self, app_model, end_user):

try:
return MessageService.pagination_by_first_id(
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
Expand Down
5 changes: 4 additions & 1 deletion api/core/agent/base_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from core.tools.entities.tool_entities import (
ToolParameter,
ToolRuntimeVariablePool,
Expand Down Expand Up @@ -441,10 +442,12 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P
.filter(
Message.conversation_id == self.message.conversation_id,
)
.order_by(Message.created_at.asc())
.order_by(Message.created_at.desc())
.all()
)

messages = list(reversed(extract_thread_messages(messages)))

for message in messages:
if message.id == self.message.id:
continue
Expand Down
1 change: 1 addition & 0 deletions api/core/app/apps/advanced_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def generate(
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
Expand Down
1 change: 1 addition & 0 deletions api/core/app/apps/agent_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def generate(
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
Expand Down
1 change: 1 addition & 0 deletions api/core/app/apps/chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def generate(
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
query=query,
files=file_objs,
parent_message_id=args.get("parent_message_id"),
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
Expand Down
1 change: 1 addition & 0 deletions api/core/app/apps/message_based_app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def _init_generate_records(
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
provider_response_latency=0,
total_price=0,
currency="USD",
Expand Down
3 changes: 3 additions & 0 deletions api/core/app/entities/app_invoke_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""

conversation_id: Optional[str] = None
parent_message_id: Optional[str] = None


class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
Expand All @@ -138,6 +139,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""

conversation_id: Optional[str] = None
parent_message_id: Optional[str] = None


class AdvancedChatAppGenerateEntity(AppGenerateEntity):
Expand All @@ -149,6 +151,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
app_config: WorkflowUIBasedAppConfig

conversation_id: Optional[str] = None
parent_message_id: Optional[str] = None
query: str

class SingleIterationRunEntity(BaseModel):
Expand Down
21 changes: 18 additions & 3 deletions api/core/memory/token_buffer_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
TextPromptMessageContent,
UserPromptMessage,
)
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import AppMode, Conversation, Message, MessageFile
from models.workflow import WorkflowRun
Expand All @@ -33,8 +34,17 @@ def get_history_prompt_messages(

# fetch limited messages, and return reversed
query = (
db.session.query(Message.id, Message.query, Message.answer, Message.created_at, Message.workflow_run_id)
.filter(Message.conversation_id == self.conversation.id, Message.answer != "")
db.session.query(
Message.id,
Message.query,
Message.answer,
Message.created_at,
Message.workflow_run_id,
Message.parent_message_id,
)
.filter(
Message.conversation_id == self.conversation.id,
)
.order_by(Message.created_at.desc())
)

Expand All @@ -45,7 +55,12 @@ def get_history_prompt_messages(

messages = query.limit(message_limit).all()

messages = list(reversed(messages))
# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message
thread_messages = extract_thread_messages(messages)
thread_messages.pop(0)
messages = list(reversed(thread_messages))

message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
prompt_messages = []
for message in messages:
Expand Down
22 changes: 22 additions & 0 deletions api/core/prompt/utils/extract_thread_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from constants import UUID_NIL


def extract_thread_messages(messages: list[dict]) -> list[dict]:
thread_messages = []
next_message = None

for message in messages:
if not message.parent_message_id:
# If the message is regenerated and does not have a parent message, it is the start of a new thread
thread_messages.append(message)
break

if not next_message:
thread_messages.append(message)
next_message = message.parent_message_id
else:
if next_message in {message.id, UUID_NIL}:
thread_messages.append(message)
next_message = message.parent_message_id

return thread_messages
1 change: 1 addition & 0 deletions api/fields/conversation_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def format(self, value):
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
}

feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
Expand Down
1 change: 1 addition & 0 deletions api/fields/message_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""add parent_message_id to messages
Revision ID: d57ba9ebb251
Revises: 675b5321501b
Create Date: 2024-09-11 10:12:45.826265
"""
import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = 'd57ba9ebb251'
down_revision = '675b5321501b'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.add_column(sa.Column('parent_message_id', models.types.StringUUID(), nullable=True))

# Set parent_message_id for existing messages to uuid_nil() to distinguish them from new messages with actual parent IDs or NULLs
op.execute('UPDATE messages SET parent_message_id = uuid_nil() WHERE parent_message_id IS NULL')

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.drop_column('parent_message_id')

# ### end Alembic commands ###
1 change: 1 addition & 0 deletions api/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ class Message(db.Model):
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
parent_message_id = db.Column(StringUUID, nullable=True)
provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_price = db.Column(db.Numeric(10, 7))
currency = db.Column(db.String(255), nullable=False)
Expand Down
4 changes: 3 additions & 1 deletion api/services/message_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def pagination_by_first_id(
conversation_id: str,
first_id: Optional[str],
limit: int,
order: str = "asc",
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
Expand Down Expand Up @@ -91,7 +92,8 @@ def pagination_by_first_id(
if rest_count > 0:
has_more = True

history_messages = list(reversed(history_messages))
if order == "asc":
history_messages = list(reversed(history_messages))

return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)

Expand Down
91 changes: 91 additions & 0 deletions api/tests/unit_tests/core/prompt/test_extract_thread_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from uuid import uuid4

from constants import UUID_NIL
from core.prompt.utils.extract_thread_messages import extract_thread_messages


class TestMessage:
def __init__(self, id, parent_message_id):
self.id = id
self.parent_message_id = parent_message_id

def __getitem__(self, item):
return getattr(self, item)


def test_extract_thread_messages_single_message():
messages = [TestMessage(str(uuid4()), UUID_NIL)]
result = extract_thread_messages(messages)
assert len(result) == 1
assert result[0] == messages[0]


def test_extract_thread_messages_linear_thread():
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id5, id4),
TestMessage(id4, id3),
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 5
assert [msg["id"] for msg in result] == [id5, id4, id3, id2, id1]


def test_extract_thread_messages_branched_thread():
id1, id2, id3, id4 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id4, id2),
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 3
assert [msg["id"] for msg in result] == [id4, id2, id1]


def test_extract_thread_messages_empty_list():
messages = []
result = extract_thread_messages(messages)
assert len(result) == 0


def test_extract_thread_messages_partially_loaded():
id0, id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id3, id2),
TestMessage(id2, id1),
TestMessage(id1, id0),
]
result = extract_thread_messages(messages)
assert len(result) == 3
assert [msg["id"] for msg in result] == [id3, id2, id1]


def test_extract_thread_messages_legacy_messages():
id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id3, UUID_NIL),
TestMessage(id2, UUID_NIL),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 3
assert [msg["id"] for msg in result] == [id3, id2, id1]


def test_extract_thread_messages_mixed_with_legacy_messages():
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
messages = [
TestMessage(id5, id4),
TestMessage(id4, id2),
TestMessage(id3, id2),
TestMessage(id2, UUID_NIL),
TestMessage(id1, UUID_NIL),
]
result = extract_thread_messages(messages)
assert len(result) == 4
assert [msg["id"] for msg in result] == [id5, id4, id2, id1]
Loading

0 comments on commit 8c51d06

Please sign in to comment.