Skip to content

Commit

Permalink
Merge pull request #86 from CogitoNTNU/85-image-generation-with-no-text
Browse files: browse the repository at this point in the history
85 image generation with no text
  • Loading branch information
SverreNystad authored Jan 7, 2024
2 parents 25a452d + 74c3109 commit 1d94227
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
10 changes: 8 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

logger = logging.getLogger(__name__)

def generate_image_from_prompt(user_prompt: str, show_on_screen: bool = False) -> str:
def generate_image_from_prompt(user_prompt: str, show_on_screen: bool = False, shall_have_text: bool = True) -> str:
""" Generates an image from a prompt and saves it to file and returns the image"""
logger.info('Starting MarketingAI')

Expand All @@ -23,7 +23,7 @@ def generate_image_from_prompt(user_prompt: str, show_on_screen: bool = False) -
classification = run_agent(user_prompt)
logger.info(f'Classification: {classification}')

image_prompt = get_image_template(user_prompt, classification)
image_prompt = get_image_template(user_prompt, classification, shall_have_text)

logger.info('Generating Text on prompt')
logger.info(f'Starting image generation based on prompt: {image_prompt}')
Expand All @@ -41,6 +41,12 @@ def generate_image_from_prompt(user_prompt: str, show_on_screen: bool = False) -
template = f"This is a picture of {image_prompt}. Generate a short captivating and relevant caption for a poster. The response should not contain any other information than the caption."
result = request_chat_completion(None, 'system', template)

result: str = ''
if shall_have_text:
logger.info('Generating image text')
template = f"This is a picture of {image_prompt}. Generate a short captivating and relevant caption for a poster. The response should not contain any other information than the caption."
result = request_chat_completion(None, 'system', template)

# Assemble image
logger.info('Assembling image Generated by image prompt: {result}')
assemble_image(user_prompt, result, "arial.ttf", 20, chose_color(user_prompt), (0, 0), show_on_screen)
Expand Down
18 changes: 5 additions & 13 deletions src/function_calling/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
from langchain.agents import initialize_agent
from src.gpt.text_generator import request_chat_completion
from src.config import Config

logger = logging.getLogger(__name__)


def get_image_template(user_prompt: str, classification: str) -> str:
def get_image_template(user_prompt: str, classification: str, shall_have_text: bool = True) -> str:
"""
Generate image template based on classification.
Args:
Expand All @@ -25,30 +22,27 @@ def get_image_template(user_prompt: str, classification: str) -> str:
image_prompt = "Meme: " + user_prompt
else:
image_prompt = "Poster: " + user_prompt

if not shall_have_text:
image_prompt = "Do not add any text. " + image_prompt
return image_prompt

def classify_text(text: str) -> str:
    """Ask the chat model to sort ``text`` into meme, propaganda or marketing.

    Args:
        text: The text to classify.

    Returns:
        The classification instruction rebuilt around the model's reply: the
        reply is quoted inside the instruction text, not returned on its own.

    Raises:
        TypeError: If ``text`` is not a string.
    """
    if not isinstance(text, str):
        raise TypeError("Text must be a string.")

    # The same wrapper text surrounds both the outgoing request and the
    # returned value, so it is spelled out once here.
    prefix = 'Classify this text into one of three categories: meme, propaganda, marketing. "'
    suffix = '". Response should be one of the three categories.'

    # Use gpt to classify
    reply = request_chat_completion(previous_message={}, message=prefix + text + suffix)

    # NOTE(review): the reply is wrapped in the instruction again instead of
    # being returned directly; confirm the agent relies on this shape.
    return prefix + reply + suffix

# Tools the agent may call; currently only the text classifier.
tools: list[StructuredTool] = [
    StructuredTool.from_function(
        func=classify_text,
        name="Classify Text",
        description="Classify text into one of three categories: meme, propaganda, marketing.",
    ),
]

# Conversation memory handed to the agent, stored under the "chat_history" key.
memory = ConversationBufferMemory(memory_key="chat_history")

llm = OpenAI(temperature=0, openai_api_key=Config().API_KEY)
agent_chain = initialize_agent(
tools,
Expand All @@ -58,15 +52,13 @@ def classify_text(text: str) -> str:
memory=memory,
max_iterations=10,
)

def run_agent(prompt: str) -> str:
    """Run the agent chain on a user prompt.

    Args:
        prompt: The user's text; must be a string of 1 to 1000 characters.

    Returns:
        Whatever ``agent_chain.run`` returns for the prompt.

    Raises:
        TypeError: If ``prompt`` is not a string.
        ValueError: If ``prompt`` is empty or longer than 1000 characters.
    """
    if not isinstance(prompt, str):
        raise TypeError("Prompt must be a string.")

    # Reject empty and oversized prompts before spending an agent call on them.
    if not 1 <= len(prompt) <= 1000:
        raise ValueError("Prompt must be between 1 and 1000 characters long.")

    result = agent_chain.run(prompt)
    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info("Finished running langchain_function_calling.py, result: %s", result)
    return result

0 comments on commit 1d94227

Please sign in to comment.