diff --git a/main.py b/main.py index 2beaa76..0d5c6fc 100644 --- a/main.py +++ b/main.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -def generate_image_from_prompt(user_prompt: str, show_on_screen: bool = False) -> str: +def generate_image_from_prompt(user_prompt: str, show_on_screen: bool = False, shall_have_text: bool = True) -> str: """ Generates an image from a prompt and saves it to file and returns the image""" logger.info('Starting MarketingAI') @@ -23,7 +23,7 @@ def generate_image_from_prompt(user_prompt: str, show_on_screen: bool = False) - classification = run_agent(user_prompt) logger.info(f'Classification: {classification}') - image_prompt = get_image_template(user_prompt, classification) + image_prompt = get_image_template(user_prompt, classification, shall_have_text) logger.info('Generating Text on prompt') logger.info(f'Starting image generation based on prompt: {image_prompt}') @@ -41,6 +41,10 @@ def generate_image_from_prompt(user_prompt: str, show_on_screen: bool = False) - - template = f"This is a picture of {image_prompt}. Generate a short captivating and relevant caption for a poster. The response should not contain any other information than the caption." - result = request_chat_completion(None, 'system', template) + result: str = '' + if shall_have_text: + logger.info('Generating image text') + template = f"This is a picture of {image_prompt}. Generate a short captivating and relevant caption for a poster. The response should not contain any other information than the caption."
+ result = request_chat_completion(None, 'system', template) + # Assemble image logger.info('Assembling image Generated by image prompt: {result}') assemble_image(user_prompt, result, "arial.ttf", 20, chose_color(user_prompt), (0, 0), show_on_screen) diff --git a/src/function_calling/image_classifier.py b/src/function_calling/image_classifier.py index 082b6a7..d21a39a 100644 --- a/src/function_calling/image_classifier.py +++ b/src/function_calling/image_classifier.py @@ -6,11 +6,8 @@ from langchain.agents import initialize_agent from src.gpt.text_generator import request_chat_completion from src.config import Config - logger = logging.getLogger(__name__) - - -def get_image_template(user_prompt: str, classification: str) -> str: +def get_image_template(user_prompt: str, classification: str, shall_have_text: bool = True) -> str: """ Generate image template based on classification. Args: @@ -25,19 +22,19 @@ def get_image_template(user_prompt: str, classification: str) -> str: image_prompt = "Meme: " + user_prompt else: image_prompt = "Poster: " + user_prompt + + if not shall_have_text: + image_prompt = "Do not add any text. " + image_prompt return image_prompt - def classify_text(text: str) -> str: """Classify text into one of three categories: meme, propaganda, marketing.""" if not isinstance(text, str): raise TypeError("Text must be a string.") - # Use gpt to classify gpt_str = "Classify this text into one of three categories: meme, propaganda, marketing. \"" + text + "\". Response should be one of the three categories." result = request_chat_completion(previous_message={}, message=gpt_str) - return "Classify this text into one of three categories: meme, propaganda, marketing. \"" + result + "\". Response should be one of the three categories." + return result
- tools: list[StructuredTool] = [ StructuredTool.from_function( name= "Classify Text", @@ -45,10 +41,8 @@ def classify_text(text: str) -> str: description="Classify text into one of three categories: meme, propaganda, marketing.", ), ] - # Make a memory for the agent to use memory = ConversationBufferMemory(memory_key="chat_history") - llm = OpenAI(temperature=0, openai_api_key=Config().API_KEY) agent_chain = initialize_agent( tools, @@ -58,15 +52,13 @@ def classify_text(text: str) -> str: memory=memory, max_iterations=10, ) - def run_agent(prompt: str) -> str: """Run the agent.""" if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - if (len(prompt) < 1) or (len(prompt) > 1000): raise ValueError("Prompt must be at least 1 character or less than 1000 characters.") result = agent_chain.run(prompt) logger.info(f"Finished running langchain_function_calling.py, result: {result}") return result