Skip to content

Commit

Permalink
Refactor: Make use get_default_text_generator instead of llm
Browse files Browse the repository at this point in the history
  • Loading branch information
SverreNystad committed Sep 19, 2023
1 parent 0b906be commit 560f0f8
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions src/npc_generation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from dataclasses import dataclass
import logging
import random
from langchain.llms import OpenAI
from enum import Enum
from src.text_generation.config import GPTConfig
from langchain.schema import HumanMessage


api_key = GPTConfig.API_KEY
llm: OpenAI = OpenAI(openai_api_key=api_key) if api_key is not None else None
from src.text_generation.text_generator import get_default_text_generator


# Create a logger instance for this script
Expand Down Expand Up @@ -135,7 +130,7 @@ def generate_npc() -> NPC:

def generate_name(race: Race, role: str) -> str:
name_template = f"What would be a good name for a {race.value} that has the role of {role} for a RPG?"
raw_name: str = llm.predict(name_template)
raw_name: str = get_default_text_generator().predict(name_template)
# Clean the name
name = raw_name.replace("\n", "")
return name
Expand All @@ -154,7 +149,7 @@ def generate_general_background(name: str, age: int, race: Race, role: str) -> s
"""

text = f"""Generate a backstory for a NPC of race: {race.value}, with the name: {name}, and age: {age}. The character should have the role: {role} in the story."""
background = llm.predict(text)
background = get_default_text_generator().predict(text)
return background

def generate_alignment(info:str=None, alignment_list:list[Alignment]=None) -> Alignment:
Expand All @@ -164,7 +159,7 @@ def generate_alignment(info:str=None, alignment_list:list[Alignment]=None) -> Al
alignment_list = [alignment.value for alignment in Alignment]

alignment_template = get_alignment_template(info, alignment_list)
raw_alignment = llm.predict(alignment_template)
raw_alignment = get_default_text_generator().predict(alignment_template)

# Clean the alignment
alignment = raw_alignment.replace("\n", "")
Expand All @@ -191,7 +186,7 @@ def generate_npc_relations(background:str) -> NPCRelations:
"""Generate the relations for an NPC."""
# Generate relations
relations_template = get_npc_relation_template(background)
raw_relations = llm.predict(relations_template)
raw_relations = get_default_text_generator().predict(relations_template)

# Clean the relations
raw_relations = raw_relations.strip()
Expand Down Expand Up @@ -231,7 +226,7 @@ def generate_npc_psychology(profile: NPCProfile, background: str, relations: NPC
"""Generate the psychology for an NPC."""
# Generate psychology
psychology_template = get_psychology_template(profile, background, relations)
raw_psychology = llm.predict(psychology_template)
raw_psychology = get_default_text_generator().predict(psychology_template)

# Clean the psychology
raw_psychology = raw_psychology.strip()
Expand Down Expand Up @@ -328,7 +323,7 @@ def generate_appearance(race: Race, age: int, role: str, backstory: str) -> str:

# Generate appearance
appearance_template = get_appearance_template(race, age, role, backstory)
raw_appearance = llm.predict(appearance_template)
raw_appearance = get_default_text_generator().predict(appearance_template)

# Parse the appearance
appearance = raw_appearance.strip()
Expand Down

0 comments on commit 560f0f8

Please sign in to comment.