diff --git a/docs/flux1.md b/docs/flux1.md index afb4a273..3d8f5ef9 100644 --- a/docs/flux1.md +++ b/docs/flux1.md @@ -2,19 +2,42 @@ This example demonstrates using the Flux1 model to perform tasks such as image generation and mask inpainting based on text prompts. -### Parameters - -The following parameters can be passed to the model: +### Parameters - task: The task to perform using the model - either image generation ("generation") or mask inpainting ("inpainting"). - prompt: The text prompt describing the desired modifications. -- height: The height in pixels of the generated image. Defaults to 1024 for best results. -- width: The width in pixels of the generated image. Defaults to 1024 for best results. +- config: The `Flux1Config` class allows you to configure the parameters for the Flux1 model. +- image (Image.Image): The original image to be modified - +used for the mask inpainting and image-to-image tasks. +- mask_image (Image.Image): The mask image indicating areas to be inpainted - used for the mask inpainting task. + +#### Flux1Config + +Below is an example of how to create and use a `Flux1Config` object: + +```python +from vision_agent_tools.models.flux1 import Flux1Config + +config = Flux1Config( + height=512, + width=512, + num_inference_steps=28, + guidance_scale=3.5, + num_images_per_prompt=1, + max_sequence_length=512, + seed=42 +) +``` + +- height: The height in pixels of the generated image. Defaults to 512. +- width: The width in pixels of the generated image. Defaults to 512. - num_inference_steps: The number of inference steps to perform. Defaults to 28. - guidance_scale: Guidance scale as defined in Classifier-Free Diffusion Guidance. Defaults to 3.5. - num_images_per_prompt: The number of images to generate per prompt. Defaults to 1. - max_sequence_length: Maximum sequence length to use with the prompt. Defaults to 512. - seed: Seed for the random number generator. If not provided, a random seed is used. 
+- strength: Indicates the extent to which the reference `image` is transformed. +Must be between 0 and 1. A value of 1 essentially ignores `image`. Defaults to 0.6. ## Perform image generation @@ -29,20 +52,14 @@ flux1 = Flux1() generated_image = flux_model( task=Flux1Task.IMAGE_GENERATION, # Image Generation Task prompt="A purple car in a futuristic cityscape", - height=1024, - width=1024, - num_inference_steps=10, - guidance_scale=3.5, - num_images_per_prompt=1, - max_sequence_length=512, - seed=42 + config=config ) generated_image.save("generated_car.png") ``` -------------------------------------------------------------------- -## Alternatively, perform mask inpainting +## Perform mask inpainting To perform mask inpainting, both the original image and the mask image need to be provided. These images have the same dimensions. The mask should clearly delineate the areas that you want to modify in the original image. Additionally, the inpainting process includes a strength parameter, which controls the intensity of the modifications applied to the masked areas. @@ -68,18 +85,41 @@ inpainted_image = flux_model( prompt=inpainting_prompt, image=image_to_edit, mask_image=mask_image, - height=1024, - width=1024, - strength=0.6, - num_inference_steps=10, - guidance_scale=3.5, - num_images_per_prompt=1, - max_sequence_length=512, - seed=42 + config=config ) inpainted_image.save("inpainted_dog_over_cat.png") ``` +-------------------------------------------------------------------- + +## Perform image-to-image generation + +To perform image-to-image generation, you need to provide an original image along with a text prompt describing the desired modifications. The original image serves as the base, and the model will generate a new image based on the prompt. 
+ +```python +import torch +from PIL import Image +from vision_agent_tools.models.flux1 import Flux1, Flux1Task + +# You have an original image named "original_image.jpg" that you want to use for image-to-image generation +original_image = Image.open("path/to/your/original_image.jpg").convert("RGB") # Original image + +# Set a new prompt for image-to-image generation +image_to_image_prompt = "A sunny beach with palm trees" + +# To perform image-to-image generation +flux_model = Flux1() + +generated_images = flux_model( + task=Flux1Task.IMAGE_TO_IMAGE, # Image-to-Image Generation Task + prompt=image_to_image_prompt, + image=original_image, + config=config +) + +generated_images[0].save("generated_beach.png") +``` + ::: vision_agent_tools.models.flux1 diff --git a/tests/models/test_flux1.py b/tests/models/test_flux1.py index 85910a09..2cd6120f 100644 --- a/tests/models/test_flux1.py +++ b/tests/models/test_flux1.py @@ -1,7 +1,8 @@ import pytest +from pydantic import ValidationError from PIL import Image -from vision_agent_tools.models.flux1 import Flux1, Flux1Task +from vision_agent_tools.models.flux1 import Flux1, Flux1Task, Flux1Config def test_image_mask_inpainting(model): @@ -9,17 +10,19 @@ def test_image_mask_inpainting(model): image = Image.open("tests/shared_data/images/chihuahua.png") mask_image = Image.open("tests/shared_data/images/chihuahua_mask.png") + config = Flux1Config( + height=32, + width=32, + num_inference_steps=1, + seed=42, + ) + result = model( task=Flux1Task.MASK_INPAINTING, prompt=prompt, image=image, mask_image=mask_image, - height=32, - width=32, - num_inference_steps=1, - guidance_scale=7, - strength=0.85, - seed=42, + config=config, ) assert result is not None @@ -32,16 +35,19 @@ def test_image_mask_inpainting(model): def test_image_generation(model): prompt = "cat wizard, Pixar style" - result = model( - task=Flux1Task.IMAGE_GENERATION, - prompt=prompt, + config = Flux1Config( height=32, width=32, - guidance_scale=0.5, num_inference_steps=1, 
seed=42, ) + result = model( + task=Flux1Task.IMAGE_GENERATION, + prompt=prompt, + config=config, + ) + assert result is not None assert len(result) == 1 image = result[0] @@ -52,22 +58,30 @@ def test_image_generation(model): def test_fail_image_generation_dimensions(model): prompt = "cat wizard, Pixar style" - height = 31 - width = 31 try: + config = Flux1Config( + height=31, + width=31, + num_inference_steps=1, + seed=42, + ) + model( task=Flux1Task.IMAGE_GENERATION, prompt=prompt, - height=height, - width=width, - num_inference_steps=1, - seed=42, + config=config, ) - except ValueError as e: + except ValidationError as e: assert ( - str(e) - == f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + repr(e.errors()[0]["msg"]) + == "'Assertion failed, height and width must be multiples of 8.'" ) + assert repr(e.errors()[0]["type"]) == "'assertion_error'" + assert ( + repr(e.errors()[1]["msg"]) + == "'Assertion failed, height and width must be multiples of 8.'" + ) + assert repr(e.errors()[1]["type"]) == "'assertion_error'" def test_fail_image_mask_size(model): @@ -76,16 +90,20 @@ def test_fail_image_mask_size(model): mask_image = Image.open("tests/shared_data/images/chihuahua_mask.png") mask_image = mask_image.resize((64, 64)) + config = Flux1Config( + height=32, + width=32, + num_inference_steps=1, + seed=42, + ) + try: model( task=Flux1Task.MASK_INPAINTING, prompt=prompt, image=image, mask_image=mask_image, - height=32, - width=32, - num_inference_steps=1, - seed=42, + config=config, ) except ValueError as e: assert str(e) == "The image and mask image should have the same size." 
@@ -97,18 +115,23 @@ def test_different_images_different_seeds(model): result_1 = model( task=Flux1Task.IMAGE_GENERATION, prompt=prompt, - height=32, - width=32, - num_inference_steps=1, - seed=42, + config=Flux1Config( + height=32, + width=32, + num_inference_steps=1, + seed=42, + ), ) result_2 = model( + task=Flux1Task.IMAGE_GENERATION, prompt=prompt, - height=32, - width=32, - num_inference_steps=1, - seed=0, + config=Flux1Config( + height=32, + width=32, + num_inference_steps=1, + seed=0, + ), ) assert result_1 is not None @@ -127,9 +150,7 @@ def test_different_images_different_seeds(model): def test_multiple_images_per_prompt(model): prompt = "cat wizard, Pixar style" - result = model( - task=Flux1Task.IMAGE_GENERATION, - prompt=prompt, + config = Flux1Config( height=32, width=32, num_inference_steps=1, @@ -137,6 +158,12 @@ def test_multiple_images_per_prompt(model): seed=42, ) + result = model( + task=Flux1Task.IMAGE_GENERATION, + prompt=prompt, + config=config, + ) + assert result is not None assert len(result) == 3 for image in result: @@ -144,6 +171,31 @@ def test_multiple_images_per_prompt(model): assert image.size == (32, 32) +def test_image_to_image(model): + prompt = "pixar style" + image = Image.open("tests/shared_data/images/chihuahua.png") + + config = Flux1Config( + height=32, + width=32, + num_inference_steps=1, + seed=42, + ) + + result = model( + task=Flux1Task.IMAGE_TO_IMAGE, + prompt=prompt, + image=image, + config=config, + ) + + assert result is not None + assert len(result) == 1 + image = result[0] + assert image.mode == "RGB" + assert image.size == (32, 32) + + @pytest.fixture(scope="session") def model(): return Flux1() diff --git a/vision_agent_tools/models/flux1.py b/vision_agent_tools/models/flux1.py index 4361a541..53a73eb9 100644 --- a/vision_agent_tools/models/flux1.py +++ b/vision_agent_tools/models/flux1.py @@ -1,39 +1,44 @@ import random -import logging from enum import Enum -from typing import List - +from typing import List, 
Annotated import torch from PIL import Image from pydantic import BaseModel, Field from pydantic import ConfigDict, validate_arguments -from diffusers import FluxPipeline, FluxInpaintPipeline +from pydantic.functional_validators import AfterValidator +from diffusers import FluxPipeline, FluxInpaintPipeline, FluxImg2ImgPipeline from vision_agent_tools.shared_types import BaseMLModel, Device -_LOGGER = logging.getLogger(__name__) + +class Flux1Task(str, Enum): + IMAGE_GENERATION = "generation" + MASK_INPAINTING = "inpainting" + IMAGE_TO_IMAGE = "img2img" + + +def _check_multiple_of_8(number: int) -> int: + assert number % 8 == 0, "height and width must be multiples of 8." + return number class Flux1Config(BaseModel): - hf_model: str = Field( - default="black-forest-labs/FLUX.1-schnell", - description="Name of the HuggingFace model", + """ + Configuration for the Flux1 model. + """ + + height: Annotated[int, AfterValidator(_check_multiple_of_8)] = Field( + ge=8, default=512 ) - device: Device = Field( - default=( - Device.GPU - if torch.cuda.is_available() - else Device.MPS - if torch.backends.mps.is_available() - else Device.CPU - ), - description="Device to run the model on. Options are 'cpu', 'gpu', and 'mps'. 
Default is the first available GPU.", + width: Annotated[int, AfterValidator(_check_multiple_of_8)] = Field( + ge=8, default=512 ) - - -class Flux1Task(str, Enum): - IMAGE_GENERATION = "generation" - MASK_INPAINTING = "inpainting" + num_inference_steps: int | None = Field(ge=1, default=10) + guidance_scale: float | None = Field(ge=0, default=3.5) + num_images_per_prompt: int | None = Field(ge=1, default=1) + max_sequence_length: int | None = Field(ge=0, le=512, default=512) + seed: int | None = None + strength: float | None = Field(ge=0, le=1, default=0.6) class Flux1(BaseMLModel): @@ -46,98 +51,92 @@ class Flux1(BaseMLModel): def __init__( self, - model_config: Flux1Config | None = None, + hf_model: str = "black-forest-labs/FLUX.1-schnell", + dtype: torch.dtype = torch.bfloat16, + enable_sequential_cpu_offload: bool = True, ): """ Initializes the Flux1 image generation tool. - Loads the pre-trained Flux1 model from HuggingFace and enables sequential CPU offload. + Loads the pre-trained Flux1 model from HuggingFace and sets model configurations. Args: - task (Flux1Task): The task to perform using the model: either image generation ("generation") or mask inpainting ("inpainting"). - model_config: The configuration for the model, hf_model, and device. + - dtype (torch.dtype): The data type to use for the model. + - enable_sequential_cpu_offload (bool): Whether to enable sequential CPU offload. 
""" - self.model_config = model_config or Flux1Config() - dtype = torch.bfloat16 self._pipeline_img_generation = FluxPipeline.from_pretrained( - self.model_config.hf_model, torch_dtype=dtype + hf_model, torch_dtype=dtype ) - self._pipeline_img_generation.enable_sequential_cpu_offload() + if enable_sequential_cpu_offload: + self._pipeline_img_generation.enable_sequential_cpu_offload() self._pipeline_mask_inpainting = FluxInpaintPipeline.from_pretrained( - self.model_config.hf_model, torch_dtype=dtype + hf_model, torch_dtype=dtype ) - self._pipeline_mask_inpainting.enable_sequential_cpu_offload() + if enable_sequential_cpu_offload: + self._pipeline_mask_inpainting.enable_sequential_cpu_offload() + + self._pipeline_img2img = FluxImg2ImgPipeline.from_pretrained( + hf_model, torch_dtype=dtype + ) + if enable_sequential_cpu_offload: + self._pipeline_img2img.enable_sequential_cpu_offload() @torch.inference_mode() @validate_arguments(config=config) def __call__( self, - prompt: str, + prompt: str = Field(max_length=512), task: Flux1Task = Flux1Task.IMAGE_GENERATION, + config: Flux1Config = Flux1Config(), image: Image.Image | None = None, mask_image: Image.Image | None = None, - height: int = 1024, - width: int = 1024, - strength: float | None = 0.6, - num_inference_steps: int | None = 28, - guidance_scale: float | None = 3.5, - num_images_per_prompt: int | None = 1, - max_sequence_length: int | None = 512, - seed: int | None = None, ) -> List[Image.Image] | None: """ Performs object detection on an image using the Flux1 model. Args: - prompt (str): The text prompt describing the desired modifications. + - task (Flux1Task): The task to perform using the model: + - image generation - "generation", + - mask inpainting - "inpainting", + - image-to-image generation - "img2img". + - config (Flux1Config): + - height (`int`, *optional*): + The height in pixels of the generated image. + This is set to 512 by default. 
+ - width (`int`, *optional*): + The width in pixels of the generated image. + This is set to 512 by default. + - num_inference_steps (`int`, *optional*, defaults to 10): The number of inference steps to perform. + - guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in Classifier-Free Diffusion Guidance. + A higher guidance scale encourages the model to generate images + that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + - num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + - max_sequence_length (`int`, defaults to 512): + Maximum sequence length to use with the `prompt`. + - seed (`int`, *optional*): The seed to use for the random number generator; + set it to make generation deterministic. + If not provided, a random seed is used. + - strength (`float`, *optional*, defaults to 0.6): + Indicates the extent to which the reference `image` is transformed. + Must be between 0 and 1. + A value of 1 essentially ignores `image`. - image (Image.Image): The original image to be modified. - mask_image (Image.Image): The mask image indicating areas to be inpainted. - - height (`int`, *optional*): - The height in pixels of the generated image. - This is set to 1024 by default for the best results. - - width (`int`, *optional*): - The width in pixels of the generated image. - This is set to 1024 by default for the best results. - - num_inference_steps (`int`, *optional*, defaults to 28): - - guidance_scale (`float`, *optional*, defaults to 3.5): - Guidance scale as defined in Classifier-Free Diffusion Guidance. - Higher guidance scale encourages to generate images - that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - - max_sequence_length (`int` defaults to 512): - Maximum sequence length to use with the `prompt`. - to make generation deterministic. 
- - strength (`float`, *optional*, defaults to 0.6): - Indicates extent to transform the reference `image`. - Must be between 0 and 1. - A value of 1 essentially ignores `image`. - - seed (`int`, *optional*): The seed to use for the random number generator. - If not provided, a random seed is used. Returns: - Image.Image | None: The output image if the Flux1 process is successful; - None if an error occurred. + List[Image.Image]: The list of generated image(s) if successful; None if an error occurred. """ - if height % 8 != 0 or width % 8 != 0: - raise ValueError( - f"`height` and `width` have to be divisible by 8 but are {height} and {width}." - ) - - if max_sequence_length is not None and max_sequence_length > 512: - raise ValueError( - f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}" - ) - - if strength < 0 or strength > 1: - raise ValueError( - f"The value of strength should in [0.0, 1.0] but is {strength}" - ) + seed = config.seed if seed is None: seed = random.randint(0, 2**32 - 1) @@ -148,42 +147,69 @@ def __call__( if task == Flux1Task.IMAGE_GENERATION: output = self._generate_image( prompt=prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, + height=config.height, + width=config.width, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + num_images_per_prompt=config.num_images_per_prompt, + max_sequence_length=config.max_sequence_length, generator=generator, - max_sequence_length=max_sequence_length, ) elif task == Flux1Task.MASK_INPAINTING: + if image is None or mask_image is None: + raise ValueError( + "Both image and mask image must be provided for inpainting." 
+ ) + if image.size != mask_image.size: raise ValueError("The image and mask image should have the same size.") - if height is None: - height = image.height - if width is None: - width = image.width + height, width = config.height, config.width + + if height is None or width is None: + height, width = image.size output = self._inpaint_image( - prompt=prompt, image=image, mask_image=mask_image, + prompt=prompt, height=height, width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + num_images_per_prompt=config.num_images_per_prompt, + max_sequence_length=config.max_sequence_length, + strength=config.strength, + generator=generator, + ) + elif task == Flux1Task.IMAGE_TO_IMAGE: + if image is None: + raise ValueError( + "Image must be provided for image-to-image generation." + ) + + height, width = config.height, config.width + + if height is None or width is None: + height, width = image.size + + output = self._image_to_image( + prompt=prompt, + image=image, + height=config.height, + width=config.width, + num_inference_steps=config.num_inference_steps, + guidance_scale=config.guidance_scale, + num_images_per_prompt=config.num_images_per_prompt, + max_sequence_length=config.max_sequence_length, generator=generator, - max_sequence_length=max_sequence_length, - strength=strength, ) else: raise ValueError( - f"Unsupported task: {self.task}. Supported tasks are: {', '.join([task.value for task in Flux1Task])}." + f"Unsupported task: {task}. Supported tasks are: {', '.join([task.value for task in Flux1Task])}." 
) - return output.images + return output def to(self, device: Device): raise NotImplementedError("This method is not supported for Flux1 model.") @@ -215,7 +241,7 @@ def _generate_image( max_sequence_length (`int`) Returns: - Optional[Image.Image]: The generated image if successful; None if an error occurred. + List[Image.Image]: The list of generated image(s) if successful; None if an error occurred. """ output = self._pipeline_img_generation( prompt=prompt, @@ -231,7 +257,7 @@ def _generate_image( if output is None: return None - return output + return output.images def _inpaint_image( self, @@ -267,7 +293,7 @@ def _inpaint_image( strength (`float`) Returns: - Optional[Image.Image]: The inpainted image if successful; None if an error occurred. + List[Image.Image]: The list of inpainted image(s) if successful; None if an error occurred. """ output = self._pipeline_mask_inpainting( prompt=prompt, @@ -286,4 +312,51 @@ def _inpaint_image( if output is None: return None - return output + return output.images + + def _image_to_image( + self, + prompt: str, + image: Image.Image, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + num_images_per_prompt: int, + generator: torch.Generator, + max_sequence_length: int, + ) -> List[Image.Image] | None: + """ + Generate an image from a given prompt + provided reference image. + + Image generation pipeline to create an image based on a provided textual prompt. + + Args: + prompt (`str`) + height (`int`) + width (`int`) + num_inference_steps (`int`) + guidance_scale (`float`) + num_images_per_prompt (`int`) + generator (`torch.Generator`) + max_sequence_length (`int`) + + Returns: + List[Image.Image]: The list of generated image(s) if successful; None if an error occurred. 
+ """ + output = self._pipeline_img2img( + prompt=prompt, + image=image, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + max_sequence_length=max_sequence_length, + ) + + if output is None: + return None + + return output.images