Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple messenger interface support #332

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions dff/messengers/common/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ class MessengerInterface(abc.ABC):
It is responsible for connection between user and pipeline, as well as for request-response transactions.
"""

def __init__(self, name: Optional[str] = None):
self.name = name if name is not None else str(type(self))


@abc.abstractmethod
async def connect(self, pipeline_runner: PipelineRunnerFunction):
async def connect(self, pipeline_runner: PipelineRunnerFunction, iface_id: str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add this parameter when it is always self.name?

"""
Method invoked when message interface is instantiated and connection is established.
May be used for sending an introduction message or displaying general bot information.
Expand All @@ -44,6 +48,9 @@ class PollingMessengerInterface(MessengerInterface):
Polling message interface runs in a loop, constantly asking users for a new input.
"""

def __init__(self, name: Optional[str] = None):
MessengerInterface.__init__(self, name)

@abc.abstractmethod
def _request(self) -> List[Tuple[Message, Hashable]]:
"""
Expand Down Expand Up @@ -91,6 +98,7 @@ async def _polling_loop(
async def connect(
self,
pipeline_runner: PipelineRunnerFunction,
iface_id: str,
loop: PollingInterfaceLoopFunction = lambda: True,
timeout: float = 0,
):
Expand All @@ -105,6 +113,7 @@ async def connect(
called in each cycle, should return `True` to continue polling or `False` to stop.
:param timeout: a time interval between polls (in seconds).
"""
self._interface_id = iface_id
while loop():
try:
await self._polling_loop(pipeline_runner, timeout)
Expand All @@ -119,11 +128,13 @@ class CallbackMessengerInterface(MessengerInterface):
Callback message interface is waiting for user input and answers once it gets one.
"""

def __init__(self):
def __init__(self, name: Optional[str] = None):
self._pipeline_runner: Optional[PipelineRunnerFunction] = None
MessengerInterface.__init__(self, name)

async def connect(self, pipeline_runner: PipelineRunnerFunction):
async def connect(self, pipeline_runner: PipelineRunnerFunction, iface_id: str):
self._pipeline_runner = pipeline_runner
self._interface_id = iface_id

async def on_request_async(
self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
Expand Down Expand Up @@ -156,21 +167,22 @@ def __init__(
prompt_request: str = "request: ",
prompt_response: str = "response: ",
out_descriptor: Optional[TextIO] = None,
name: Optional[str] = None
):
super().__init__()
PollingMessengerInterface.__init__(self, name)
self._ctx_id: Optional[Hashable] = None
self._intro: Optional[str] = intro
self._prompt_request: str = prompt_request
self._prompt_response: str = prompt_response
self._descriptor: Optional[TextIO] = out_descriptor

def _request(self) -> List[Tuple[Message, Any]]:
return [(Message(input(self._prompt_request)), self._ctx_id)]
return [(Message(input(self._prompt_request), interface=self._interface_id), self._ctx_id)]

def _respond(self, responses: List[Context]):
print(f"{self._prompt_response}{responses[0].last_response.text}", file=self._descriptor)
print(f"{self._prompt_response}{responses[0].last_response_to(self._interface_id).text}", file=self._descriptor)

async def connect(self, pipeline_runner: PipelineRunnerFunction, **kwargs):
async def connect(self, pipeline_runner: PipelineRunnerFunction, iface_id: str, **kwargs):
"""
The CLIProvider generates new dialog id used to user identification on each `connect` call.

Expand All @@ -181,4 +193,4 @@ async def connect(self, pipeline_runner: PipelineRunnerFunction, **kwargs):
self._ctx_id = uuid.uuid4()
if self._intro is not None:
print(self._intro)
await super().connect(pipeline_runner, **kwargs)
await super().connect(pipeline_runner, iface_id, **kwargs)
2 changes: 2 additions & 0 deletions dff/messengers/telegram/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def __init__(
timeout: int = 20,
long_polling_timeout: int = 20,
messenger: Optional[TelegramMessenger] = None,
name: Optional[str] = None
):
super().__init__(name)
self.messenger = (
messenger if messenger is not None else TelegramMessenger(token, suppress_middleware_excepions=True)
)
Expand Down
3 changes: 3 additions & 0 deletions dff/pipeline/pipeline/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,13 @@ async def __call__(self, pipeline: Pipeline, ctx: Context):
await self._run_pre_response_processing(ctx, pipeline)
await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING)

last_interface = ctx.last_request.interface

# create response
ctx.framework_states["actor"]["response"] = await self.run_response(
ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline
)
ctx.framework_states["actor"]["response"].interface = last_interface
await self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE)
ctx.add_response(ctx.framework_states["actor"]["response"])

Expand Down
26 changes: 17 additions & 9 deletions dff/pipeline/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import asyncio
import logging
from typing import Union, List, Dict, Optional, Hashable, Callable
from typing import Iterable, Union, List, Dict, Optional, Hashable, Callable
from uuid import uuid4

from dff.context_storages import DBContextStorage
from dff.script import Script, Context, ActorStage
Expand Down Expand Up @@ -62,7 +63,7 @@ class Pipeline:
- key: :py:class:`~dff.script.ActorStage` - Stage in which the handler is called.
- value: List[Callable] - The list of called handlers for each stage. Defaults to an empty `dict`.

:param messenger_interface: An `AbsMessagingInterface` instance for this pipeline.
:param messenger_interfaces: An `AbsMessagingInterface` instance for this pipeline.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo: fix docs

:param context_storage: An :py:class:`~.DBContextStorage` instance for this pipeline or
a dict to store dialog :py:class:`~.Context`.
:param services: (required) A :py:data:`~.ServiceGroupBuilder` object,
Expand Down Expand Up @@ -92,7 +93,7 @@ def __init__(
condition_handler: Optional[Callable] = None,
verbose: bool = True,
handlers: Optional[Dict[ActorStage, List[Callable]]] = None,
messenger_interface: Optional[MessengerInterface] = None,
messenger_interfaces: Optional[Iterable[MessengerInterface]] = None,
context_storage: Optional[Union[DBContextStorage, Dict]] = None,
before_handler: Optional[ExtraHandlerBuilder] = None,
after_handler: Optional[ExtraHandlerBuilder] = None,
Expand All @@ -101,7 +102,6 @@ def __init__(
parallelize_processing: bool = False,
):
self.actor: Actor = None
self.messenger_interface = CLIMessengerInterface() if messenger_interface is None else messenger_interface
self.context_storage = {} if context_storage is None else context_storage
self._services_pipeline = ServiceGroup(
components,
Expand All @@ -110,6 +110,12 @@ def __init__(
timeout=timeout,
)

if messenger_interfaces is None:
interface = CLIMessengerInterface()
self.messenger_interfaces = {interface.name: interface}
else:
self.messenger_interfaces = {iface.name: iface for iface in messenger_interfaces}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check for duplicates.


self._services_pipeline.name = "pipeline"
self._services_pipeline.path = ".pipeline"
actor_exists = finalize_service_group(self._services_pipeline, path=self._services_pipeline.path)
Expand Down Expand Up @@ -188,7 +194,9 @@ def info_dict(self) -> dict:
"""
return {
"type": type(self).__name__,
"messenger_interface": f"Instance of {type(self.messenger_interface).__name__}",
"messenger_interfaces": {
k: f"Instance of {type(v).__name__}" for k, v in self.messenger_interfaces.items()
},
"context_storage": f"Instance of {type(self.context_storage).__name__}",
"services": [self._services_pipeline.info_dict],
}
Expand Down Expand Up @@ -217,7 +225,7 @@ def from_script(
parallelize_processing: bool = False,
handlers: Optional[Dict[ActorStage, List[Callable]]] = None,
context_storage: Optional[Union[DBContextStorage, Dict]] = None,
messenger_interface: Optional[MessengerInterface] = None,
messenger_interfaces: Optional[Iterable[MessengerInterface]] = None,
pre_services: Optional[List[Union[ServiceBuilder, ServiceGroupBuilder]]] = None,
post_services: Optional[List[Union[ServiceBuilder, ServiceGroupBuilder]]] = None,
) -> "Pipeline":
Expand Down Expand Up @@ -249,7 +257,7 @@ def from_script(

:param context_storage: An :py:class:`~.DBContextStorage` instance for this pipeline
or a dict to store dialog :py:class:`~.Context`.
:param messenger_interface: An instance for this pipeline.
:param messenger_interfaces: An instance for this pipeline.
:param pre_services: List of :py:data:`~.ServiceBuilder` or
:py:data:`~.ServiceGroupBuilder` that will be executed before Actor.
:type pre_services: Optional[List[Union[ServiceBuilder, ServiceGroupBuilder]]]
Expand All @@ -270,7 +278,7 @@ def from_script(
verbose=verbose,
parallelize_processing=parallelize_processing,
handlers=handlers,
messenger_interface=messenger_interface,
messenger_interfaces=messenger_interfaces,
context_storage=context_storage,
components=[*pre_services, ACTOR, *post_services],
)
Expand Down Expand Up @@ -369,7 +377,7 @@ def run(self):
This method can be both blocking and non-blocking. It depends on current `messenger_interface` nature.
Message interfaces that run in a loop block current thread.
"""
asyncio.run(self.messenger_interface.connect(self._run_pipeline))
asyncio.run(asyncio.gather(*[iface.connect(self._run_pipeline, id) for id, iface in self.messenger_interfaces.items()]))

def __call__(
self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
Expand Down
2 changes: 1 addition & 1 deletion dff/pipeline/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class ExtraHandlerRuntimeInfo(BaseModel):
PipelineBuilder: TypeAlias = TypedDict(
"PipelineBuilder",
{
"messenger_interface": NotRequired[Optional["MessengerInterface"]],
"messenger_interfaces": NotRequired[Optional[Union["MessengerInterface", Iterable["MessengerInterface"], Dict[str, "MessengerInterface"]]]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update this.

"context_storage": NotRequired[Optional[Union[DBContextStorage, Dict]]],
"components": ServiceGroupBuilder,
"before_handler": NotRequired[Optional[ExtraHandlerBuilder]],
Expand Down
1 change: 1 addition & 0 deletions dff/script/conditions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
all,
negation,
has_last_labels,
from_interface,
true,
false,
agg,
Expand Down
18 changes: 17 additions & 1 deletion dff/script/conditions/std_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
or other factors that may affect the conversation flow.
"""

from typing import Callable, Pattern, Union, List, Optional
from typing import Callable, Pattern, Type, Union, List, Optional
import logging
import re

from pydantic import validate_call

from dff.messengers.common.interface import MessengerInterface
from dff.pipeline import Pipeline
from dff.script import NodeLabel2Type, Context, Message

Expand Down Expand Up @@ -193,6 +194,21 @@ def has_last_labels_condition_handler(ctx: Context, pipeline: Pipeline) -> bool:
return has_last_labels_condition_handler


def from_interface(iface: Optional[Type[MessengerInterface]] = None, name: Optional[str] = None) -> Callable[[Context, Pipeline], bool]:
def is_from_interface_type(ctx: Context, pipeline: Pipeline) -> bool:
if ctx.last_request is None:
return False
latest_interface = ctx.last_request.interface
for interface_name, interface_object in pipeline.messenger_interfaces.items():
if interface_name == latest_interface:
name_match = name is None or interface_name == name
type_match = iface is None or isinstance(interface_object, iface)
return name_match and type_match
return False

return is_from_interface_type


@validate_call
def true() -> Callable[[Context, Pipeline], bool]:
"""
Expand Down
14 changes: 14 additions & 0 deletions dff/script/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,13 @@ def last_label(self) -> Optional[NodeLabel2Type]:
last_index = get_last_index(self.labels)
return self.labels.get(last_index)

def last_response_to(self, interface_id: Optional[str]) -> Optional[Message]:
pseusys marked this conversation as resolved.
Show resolved Hide resolved
for index in list(self.responses)[::-1]:
response = self.responses.get(index)
if response is not None and response.interface == interface_id:
return response
return None

@property
def last_response(self) -> Optional[Message]:
"""
Expand All @@ -232,6 +239,13 @@ def last_response(self) -> Optional[Message]:
last_index = get_last_index(self.responses)
return self.responses.get(last_index)

def last_request_from(self, interface_id: Optional[str]) -> Optional[Message]:
for index in list(self.requests)[::-1]:
request = self.requests.get(index)
if request is not None and request.interface == interface_id:
return request
return None

@last_response.setter
def last_response(self, response: Optional[Message]):
"""
Expand Down
4 changes: 3 additions & 1 deletion dff/script/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class level variables to store message information.
attachments: Optional[Attachments] = None
annotations: Optional[dict] = None
misc: Optional[dict] = None
interface: Optional[str] = None
# commands and state options are required for integration with services
# that use an intermediate backend server, like Yandex's Alice
# state: Optional[Session] = Session.ACTIVE
Expand All @@ -208,10 +209,11 @@ def __init__(
attachments: Optional[Attachments] = None,
annotations: Optional[dict] = None,
misc: Optional[dict] = None,
interface: Optional[str] = None,
**kwargs,
):
super().__init__(
text=text, commands=commands, attachments=attachments, annotations=annotations, misc=misc, **kwargs
text=text, commands=commands, attachments=attachments, annotations=annotations, misc=misc, interface=interface, **kwargs
)

def __eq__(self, other):
Expand Down
7 changes: 4 additions & 3 deletions tests/pipeline/test_messenger_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def test_cli_messenger_interface(monkeypatch):
monkeypatch.setattr("builtins.input", lambda _: "Ping")
sys.path.append(str(pathlib.Path(__file__).parent.absolute()))

pipeline.messenger_interface = CLIMessengerInterface(intro="Hi, it's DFF powered bot, let's chat!")
interface = CLIMessengerInterface(intro="Hi, it's DFF powered bot, let's chat!")
pipeline.messenger_interfaces = {interface.name: interface}

def loop() -> bool:
loop.runs_left -= 1
Expand All @@ -50,12 +51,12 @@ def loop() -> bool:
loop.runs_left = 5

# Literally what happens in pipeline.run()
asyncio.run(pipeline.messenger_interface.connect(pipeline._run_pipeline, loop=loop))
asyncio.run(interface.connect(pipeline._run_pipeline, interface.name, loop=loop))


def test_callback_messenger_interface(monkeypatch):
interface = CallbackMessengerInterface()
pipeline.messenger_interface = interface
pipeline.messenger_interfaces = {interface.name: interface}

pipeline.run()

Expand Down
Loading
Loading