Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple messenger interface support #332

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions dff/messengers/common/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class MessengerInterface(abc.ABC):
"""

@abc.abstractmethod
async def connect(self, pipeline_runner: PipelineRunnerFunction):
async def connect(self, pipeline_runner: PipelineRunnerFunction, iface_id: str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add this parameter when it is always self.name?

"""
Method invoked when message interface is instantiated and connection is established.
May be used for sending an introduction message or displaying general bot information.
Expand Down Expand Up @@ -91,6 +91,7 @@ async def _polling_loop(
async def connect(
self,
pipeline_runner: PipelineRunnerFunction,
iface_id: str,
loop: PollingInterfaceLoopFunction = lambda: True,
timeout: float = 0,
):
Expand All @@ -105,6 +106,7 @@ async def connect(
called in each cycle, should return `True` to continue polling or `False` to stop.
:param timeout: a time interval between polls (in seconds).
"""
self._interface_id = iface_id
while loop():
try:
await self._polling_loop(pipeline_runner, timeout)
Expand All @@ -122,8 +124,9 @@ class CallbackMessengerInterface(MessengerInterface):
def __init__(self):
self._pipeline_runner: Optional[PipelineRunnerFunction] = None

async def connect(self, pipeline_runner: PipelineRunnerFunction):
async def connect(self, pipeline_runner: PipelineRunnerFunction, iface_id: str):
self._pipeline_runner = pipeline_runner
self._interface_id = iface_id

async def on_request_async(
self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
Expand Down Expand Up @@ -165,12 +168,12 @@ def __init__(
self._descriptor: Optional[TextIO] = out_descriptor

def _request(self) -> List[Tuple[Message, Any]]:
return [(Message(input(self._prompt_request)), self._ctx_id)]
return [(Message(input(self._prompt_request), interface=self._interface_id), self._ctx_id)]

def _respond(self, responses: List[Context]):
print(f"{self._prompt_response}{responses[0].last_response.text}", file=self._descriptor)
print(f"{self._prompt_response}{responses[0].last_response_to(self._interface_id).text}", file=self._descriptor)

async def connect(self, pipeline_runner: PipelineRunnerFunction, **kwargs):
async def connect(self, pipeline_runner: PipelineRunnerFunction, iface_id: str, **kwargs):
"""
The CLIProvider generates new dialog id used to user identification on each `connect` call.

Expand All @@ -181,4 +184,4 @@ async def connect(self, pipeline_runner: PipelineRunnerFunction, **kwargs):
self._ctx_id = uuid.uuid4()
if self._intro is not None:
print(self._intro)
await super().connect(pipeline_runner, **kwargs)
await super().connect(pipeline_runner, iface_id, **kwargs)
7 changes: 6 additions & 1 deletion dff/pipeline/pipeline/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import copy

from dff.utils.turn_caching import cache_clear
from dff.script.core.types import ActorStage, NodeLabel2Type, NodeLabel3Type, LabelType
from dff.script.core.types import DEFAULT_INTERFACE_ID, ActorStage, NodeLabel2Type, NodeLabel3Type, LabelType
from dff.script.core.message import Message

from dff.script.core.context import Context
Expand Down Expand Up @@ -149,10 +149,15 @@ async def __call__(self, pipeline: Pipeline, ctx: Context):
await self._run_pre_response_processing(ctx, pipeline)
await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING)

last_interface = DEFAULT_INTERFACE_ID
if ctx.last_request is not None:
pseusys marked this conversation as resolved.
Show resolved Hide resolved
last_interface = ctx.last_request.interface

# create response
ctx.framework_states["actor"]["response"] = await self.run_response(
ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline
)
ctx.framework_states["actor"]["response"].interface = last_interface
await self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE)
ctx.add_response(ctx.framework_states["actor"]["response"])

Expand Down
32 changes: 22 additions & 10 deletions dff/pipeline/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

import asyncio
import logging
from typing import Union, List, Dict, Optional, Hashable, Callable
from typing import Iterable, Union, List, Dict, Optional, Hashable, Callable
from uuid import uuid4

from dff.context_storages import DBContextStorage
from dff.script import Script, Context, ActorStage
from dff.script import NodeLabel2Type, Message
from dff.script import NodeLabel2Type, Message, DEFAULT_INTERFACE_ID
from dff.utils.turn_caching import cache_clear

from dff.messengers.common import MessengerInterface, CLIMessengerInterface
Expand Down Expand Up @@ -62,7 +63,7 @@ class Pipeline:
- key: :py:class:`~dff.script.ActorStage` - Stage in which the handler is called.
- value: List[Callable] - The list of called handlers for each stage. Defaults to an empty `dict`.

:param messenger_interface: An `AbsMessagingInterface` instance for this pipeline.
:param messenger_interfaces: An `AbsMessagingInterface` instance for this pipeline.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo: fix docs

:param context_storage: An :py:class:`~.DBContextStorage` instance for this pipeline or
a dict to store dialog :py:class:`~.Context`.
:param services: (required) A :py:data:`~.ServiceGroupBuilder` object,
Expand Down Expand Up @@ -92,7 +93,7 @@ def __init__(
condition_handler: Optional[Callable] = None,
verbose: bool = True,
handlers: Optional[Dict[ActorStage, List[Callable]]] = None,
messenger_interface: Optional[MessengerInterface] = None,
messenger_interfaces: Optional[Union[MessengerInterface, Iterable[MessengerInterface], Dict[str, MessengerInterface]]] = None,
context_storage: Optional[Union[DBContextStorage, Dict]] = None,
before_handler: Optional[ExtraHandlerBuilder] = None,
after_handler: Optional[ExtraHandlerBuilder] = None,
Expand All @@ -101,7 +102,6 @@ def __init__(
parallelize_processing: bool = False,
):
self.actor: Actor = None
self.messenger_interface = CLIMessengerInterface() if messenger_interface is None else messenger_interface
self.context_storage = {} if context_storage is None else context_storage
self._services_pipeline = ServiceGroup(
components,
Expand All @@ -110,6 +110,16 @@ def __init__(
timeout=timeout,
)

if messenger_interfaces is not None and not isinstance(messenger_interfaces, MessengerInterface):
if isinstance(messenger_interfaces, Iterable):
self.messenger_interfaces = {str(uuid4()): iface for iface in messenger_interfaces}
elif isinstance(messenger_interfaces, Iterable):
self.messenger_interfaces = messenger_interfaces
else:
raise RuntimeError(f"Unexpected type of 'messenger_interfaces': {type(messenger_interfaces)}")
pseusys marked this conversation as resolved.
Show resolved Hide resolved
else:
self.messenger_interfaces = {DEFAULT_INTERFACE_ID: CLIMessengerInterface()}

self._services_pipeline.name = "pipeline"
self._services_pipeline.path = ".pipeline"
actor_exists = finalize_service_group(self._services_pipeline, path=self._services_pipeline.path)
Expand Down Expand Up @@ -188,7 +198,9 @@ def info_dict(self) -> dict:
"""
return {
"type": type(self).__name__,
"messenger_interface": f"Instance of {type(self.messenger_interface).__name__}",
"messenger_interfaces": {
k: f"Instance of {type(v).__name__}" for k, v in self.messenger_interfaces.items()
},
"context_storage": f"Instance of {type(self.context_storage).__name__}",
"services": [self._services_pipeline.info_dict],
}
Expand Down Expand Up @@ -217,7 +229,7 @@ def from_script(
parallelize_processing: bool = False,
handlers: Optional[Dict[ActorStage, List[Callable]]] = None,
context_storage: Optional[Union[DBContextStorage, Dict]] = None,
messenger_interface: Optional[MessengerInterface] = None,
messenger_interfaces: Optional[Union[MessengerInterface, Iterable[MessengerInterface], Dict[str, MessengerInterface]]] = None,
RLKRo marked this conversation as resolved.
Show resolved Hide resolved
pre_services: Optional[List[Union[ServiceBuilder, ServiceGroupBuilder]]] = None,
post_services: Optional[List[Union[ServiceBuilder, ServiceGroupBuilder]]] = None,
) -> "Pipeline":
Expand Down Expand Up @@ -249,7 +261,7 @@ def from_script(

:param context_storage: An :py:class:`~.DBContextStorage` instance for this pipeline
or a dict to store dialog :py:class:`~.Context`.
:param messenger_interface: An instance for this pipeline.
:param messenger_interfaces: An instance for this pipeline.
:param pre_services: List of :py:data:`~.ServiceBuilder` or
:py:data:`~.ServiceGroupBuilder` that will be executed before Actor.
:type pre_services: Optional[List[Union[ServiceBuilder, ServiceGroupBuilder]]]
Expand All @@ -270,7 +282,7 @@ def from_script(
verbose=verbose,
parallelize_processing=parallelize_processing,
handlers=handlers,
messenger_interface=messenger_interface,
messenger_interfaces=messenger_interfaces,
context_storage=context_storage,
components=[*pre_services, ACTOR, *post_services],
)
Expand Down Expand Up @@ -369,7 +381,7 @@ def run(self):
This method can be both blocking and non-blocking. It depends on current `messenger_interface` nature.
Message interfaces that run in a loop block current thread.
"""
asyncio.run(self.messenger_interface.connect(self._run_pipeline))
asyncio.run(asyncio.gather(*[iface.connect(self._run_pipeline, id) for id, iface in self.messenger_interfaces.items()]))

def __call__(
self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
Expand Down
2 changes: 1 addition & 1 deletion dff/pipeline/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class ExtraHandlerRuntimeInfo(BaseModel):
PipelineBuilder: TypeAlias = TypedDict(
"PipelineBuilder",
{
"messenger_interface": NotRequired[Optional["MessengerInterface"]],
"messenger_interfaces": NotRequired[Optional[Union["MessengerInterface", Iterable["MessengerInterface"], Dict[str, "MessengerInterface"]]]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update this.

"context_storage": NotRequired[Optional[Union[DBContextStorage, Dict]]],
"components": ServiceGroupBuilder,
"before_handler": NotRequired[Optional[ExtraHandlerBuilder]],
Expand Down
1 change: 1 addition & 0 deletions dff/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from .core.script import Node, Script
from .core.types import (
DEFAULT_INTERFACE_ID,
LabelType,
NodeLabel1Type,
NodeLabel2Type,
Expand Down
1 change: 1 addition & 0 deletions dff/script/conditions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
all,
negation,
has_last_labels,
from_interface,
true,
false,
agg,
Expand Down
16 changes: 15 additions & 1 deletion dff/script/conditions/std_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
or other factors that may affect the conversation flow.
"""

from typing import Callable, Pattern, Union, List, Optional
from typing import Callable, Pattern, Type, Union, List, Optional
import logging
import re

from pydantic import validate_call

from dff.messengers.common.interface import MessengerInterface
from dff.pipeline import Pipeline
from dff.script import NodeLabel2Type, Context, Message

Expand Down Expand Up @@ -193,6 +194,19 @@ def has_last_labels_condition_handler(ctx: Context, pipeline: Pipeline) -> bool:
return has_last_labels_condition_handler


def from_interface(iface: Type[MessengerInterface]) -> Callable[[Context, Pipeline], bool]:
pseusys marked this conversation as resolved.
Show resolved Hide resolved
def is_from_interface_type(ctx: Context, pipeline: Pipeline) -> bool:
if ctx.last_request is None:
return False
latest_interface = ctx.last_request.interface
for interface_name, interface_object in pipeline.messenger_interfaces.items():
if interface_name == latest_interface:
return isinstance(interface_object, iface)
return False

return is_from_interface_type


@validate_call
def true() -> Callable[[Context, Pipeline], bool]:
"""
Expand Down
7 changes: 7 additions & 0 deletions dff/script/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,13 @@ def last_label(self) -> Optional[NodeLabel2Type]:
last_index = get_last_index(self.labels)
return self.labels.get(last_index)

def last_response_to(self, interface_id: Optional[str]) -> Optional[Message]:
pseusys marked this conversation as resolved.
Show resolved Hide resolved
for index in list(self.responses)[::-1]:
response = self.responses.get(index)
if response is not None and response.interface == interface_id:
return response
return None

@property
def last_response(self) -> Optional[Message]:
"""
Expand Down
3 changes: 3 additions & 0 deletions dff/script/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from pydantic import field_validator, Field, FilePath, HttpUrl, BaseModel, model_validator

from .types import DEFAULT_INTERFACE_ID


class Session(Enum):
"""
Expand Down Expand Up @@ -196,6 +198,7 @@ class level variables to store message information.
attachments: Optional[Attachments] = None
annotations: Optional[dict] = None
misc: Optional[dict] = None
interface: str = DEFAULT_INTERFACE_ID
# commands and state options are required for integration with services
# that use an intermediate backend server, like Yandex's Alice
# state: Optional[Session] = Session.ACTIVE
Expand Down
3 changes: 3 additions & 0 deletions dff/script/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
# TODO: change example


DEFAULT_INTERFACE_ID = "default"


class ActorStage(Enum):
"""
The class which holds keys for the handlers. These keys are used
Expand Down
9 changes: 5 additions & 4 deletions tests/pipeline/test_messenger_interface.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should also add tests for launching pipeline with multiple http interfaces and running them at the same time on different ports.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import pathlib

from dff.script import RESPONSE, TRANSITIONS, Message
from dff.script import DEFAULT_INTERFACE_ID, RESPONSE, TRANSITIONS, Message
from dff.messengers.common import CLIMessengerInterface, CallbackMessengerInterface
from dff.pipeline import Pipeline
import dff.script.conditions as cnd
Expand Down Expand Up @@ -41,7 +41,8 @@ def test_cli_messenger_interface(monkeypatch):
monkeypatch.setattr("builtins.input", lambda _: "Ping")
sys.path.append(str(pathlib.Path(__file__).parent.absolute()))

pipeline.messenger_interface = CLIMessengerInterface(intro="Hi, it's DFF powered bot, let's chat!")
interface = CLIMessengerInterface(intro="Hi, it's DFF powered bot, let's chat!")
pipeline.messenger_interfaces = {DEFAULT_INTERFACE_ID: interface}

def loop() -> bool:
loop.runs_left -= 1
Expand All @@ -50,12 +51,12 @@ def loop() -> bool:
loop.runs_left = 5

# Literally what happens in pipeline.run()
asyncio.run(pipeline.messenger_interface.connect(pipeline._run_pipeline, loop=loop))
asyncio.run(interface.connect(pipeline._run_pipeline, DEFAULT_INTERFACE_ID, loop=loop))


def test_callback_messenger_interface(monkeypatch):
interface = CallbackMessengerInterface()
pipeline.messenger_interface = interface
pipeline.messenger_interfaces = {DEFAULT_INTERFACE_ID: interface}

pipeline.run()

Expand Down
Loading
Loading