Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempting to change the descriptions to use policy #1

Open
wants to merge 1 commit into
base: product-descriptions
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ policies:
epochs: 120
learning_rate: 0.02
constrain_similarities: true
- name: pipeline.product_description_policy.ProductDescriptionPolicy
69 changes: 69 additions & 0 deletions pipeline/product_description_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Dict, Text, Any, Optional

from rasa.core.policies.policy import Policy
from rasa.core.policies.policy import PolicyPrediction
from rasa.shared.core.domain import Domain
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.core.featurizers.precomputation import MessageContainerForCoreFeaturization
from rasa.engine.storage.resource import Resource
from rasa.engine.graph import ExecutionContext, GraphComponent
from rasa.engine.storage.storage import ModelStorage
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
from rasa.core.constants import (
POLICY_PRIORITY,
POLICY_MAX_HISTORY,
)

@DefaultV1Recipe.register(
[DefaultV1Recipe.ComponentType.POLICY_WITHOUT_END_TO_END_SUPPORT], is_trainable=False
)
class ProductDescriptionPolicy(Policy):
def __init__(
self,
config: Dict[Text, Any],
name: Text,
model_storage: ModelStorage,
resource: Resource,
) -> None:
super().__init__(config, name, model_storage, resource)
self.config = config
self.name = name
self.model_storage = model_storage
self.resource = resource

@classmethod
def create(
cls,
config: Dict[Text, Any],
model_storage: ModelStorage,
resource: Resource,
execution_context: ExecutionContext,
) -> GraphComponent:
return cls(config, execution_context.node_name, model_storage, resource)

@staticmethod
def get_default_config() -> Dict[Text, Any]:
return {
# Determines the importance of policies, higher values take precedence
POLICY_PRIORITY: 0.0,
POLICY_MAX_HISTORY: 10,
}

def predict_action_probabilities(
self,
tracker: DialogueStateTracker,
domain: Domain,
rule_only_data: Optional[Dict[Text, Any]] = None,
precomputations: Optional[MessageContainerForCoreFeaturization] = None,
**kwargs: Any,
) -> PolicyPrediction:
# check if we are in an active loop
if tracker.active_loop:
print("ProductDescriptionPolicy: Active loop detected")
# print the latest intent detected
print("ProductDescriptionPolicy: Latest intent detected: ", tracker.latest_message.get("intent", {}).get("name"))
entites = tracker.latest_message.get("entities", [])
if any(entity["entity"] == "product" for entity in entites):
print("ProductDescriptionPolicy: Product entity detected")
print("ProductDescriptionPolicy: Product entity value: ", next(entity["value"] for entity in entites if entity["entity"] == "product"))
return PolicyPrediction([0.0 for _ in range(domain.num_actions)], "ProductDescriptionPolicy")