-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(TextToImage): Completion of text to image thought/skill using St…
…ableDiffusion XL and code interpreter.
- Loading branch information
Showing
3 changed files
with
182 additions
and
1 deletion.
There are no files selected for viewing
80 changes: 80 additions & 0 deletions
80
src/Semantic.Core.Tests/Thoughts/Chains/Cognitive/TextToImageChainTests.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
using FrostAura.Libraries.Semantic.Core.Thoughts.Chains.Cognitive; | ||
using Microsoft.Extensions.DependencyInjection; | ||
using Microsoft.Extensions.Logging; | ||
using NSubstitute; | ||
using Xunit; | ||
using FrostAura.Libraries.Semantic.Core.Extensions.Configuration; | ||
|
||
namespace Semantic.Core.Tests.Thoughts.Chains.Cognitive | ||
{ | ||
/// <summary>
/// Tests for <see cref="TextToImageChain"/>: constructor argument validation, the example metadata it exposes and the image generation entry point.
/// </summary>
public class TextToImageChainTests
{
    /// <summary>
    /// A null logger must be rejected with an <see cref="ArgumentNullException"/> naming the logger parameter.
    /// </summary>
    [Fact]
    public void Constructor_WithInvalidLogger_ShouldThrow()
    {
        ILogger<TextToImageChain> logger = null;
        IServiceProvider serviceProvider = Substitute.For<IServiceProvider>();

        var actual = Assert.Throws<ArgumentNullException>(() => new TextToImageChain(serviceProvider, logger));

        Assert.Equal(nameof(logger), actual.ParamName);
    }

    /// <summary>
    /// A null service provider must be rejected with an <see cref="ArgumentNullException"/> naming the service provider parameter.
    /// </summary>
    [Fact]
    public void Constructor_WithInvalidServiceProvider_ShouldThrow()
    {
        ILogger<TextToImageChain> logger = Substitute.For<ILogger<TextToImageChain>>();
        IServiceProvider serviceProvider = null;

        var actual = Assert.Throws<ArgumentNullException>(() => new TextToImageChain(serviceProvider, logger));

        Assert.Equal(nameof(serviceProvider), actual.ParamName);
    }

    /// <summary>
    /// With valid dependencies the chain constructs and exposes non-empty example metadata.
    /// </summary>
    [Fact]
    public void Constructor_WithValidParams_ShouldConstruct()
    {
        var serviceProvider = Substitute.For<IServiceProvider>();
        var logger = Substitute.For<ILogger<TextToImageChain>>();

        var actual = new TextToImageChain(serviceProvider, logger);

        Assert.NotNull(actual);
        Assert.NotEmpty(actual.QueryExample);
        Assert.NotEmpty(actual.QueryInputExample);
        Assert.NotEmpty(actual.Reasoning);
        Assert.NotEmpty(actual.ChainOfThoughts);
    }

    /// <summary>
    /// A missing prompt must be rejected with an <see cref="ArgumentNullException"/> naming the prompt parameter.
    /// </summary>
    [Fact]
    public async Task GenerateImageAndGetFilePathAsync_WithInvalidInput_ShouldThrow()
    {
        var serviceCollection = new ServiceCollection()
            .AddSemanticCore(Config.SEMANTIC_CONFIG);
        // Dispose the container (and every disposable service it resolved) once the test completes.
        await using var serviceProvider = serviceCollection.BuildServiceProvider();
        var logger = Substitute.For<ILogger<TextToImageChain>>();
        var instance = new TextToImageChain(serviceProvider, logger);
        string prompt = default;

        var actual = await Assert.ThrowsAsync<ArgumentNullException>(async () => await instance.GenerateImageAndGetFilePathAsync(prompt));

        Assert.Equal(nameof(prompt), actual.ParamName);
    }

    /// <summary>
    /// A valid prompt must produce a non-null result.
    /// </summary>
    /// <remarks>
    /// NOTE(review): no substitutes are registered here, so this appears to run the real chain end-to-end
    /// (presumably needing a working Python environment and the model weights) - confirm, and consider
    /// tagging it as an integration test so it can be excluded from fast unit test runs.
    /// </remarks>
    [Fact]
    public async Task GenerateImageAndGetFilePathAsync_WithValidInput_ShouldCallInvokeAsync()
    {
        var serviceCollection = new ServiceCollection()
            .AddSemanticCore(Config.SEMANTIC_CONFIG);

        // Dispose the container (and every disposable service it resolved) once the test completes.
        await using var serviceProvider = serviceCollection.BuildServiceProvider();
        var logger = Substitute.For<ILogger<TextToImageChain>>();
        var instance = new TextToImageChain(serviceProvider, logger);
        var prompt = "A surfer in a hurricane, fighting off sharks that are on fire, photo-realistic, motion blur.";

        var actual = await instance.GenerateImageAndGetFilePathAsync(prompt);

        Assert.NotNull(actual);
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
101 changes: 101 additions & 0 deletions
101
src/Semantic.Core/Thoughts/Chains/Cognitive/TextToImageChain.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
using System.ComponentModel; | ||
using FrostAura.Libraries.Core.Extensions.Validation; | ||
using FrostAura.Libraries.Semantic.Core.Models.Thoughts; | ||
using FrostAura.Libraries.Semantic.Core.Thoughts.Cognitive; | ||
using FrostAura.Libraries.Semantic.Core.Thoughts.IO; | ||
using Microsoft.Extensions.Logging; | ||
using Microsoft.SemanticKernel; | ||
|
||
namespace FrostAura.Libraries.Semantic.Core.Thoughts.Chains.Cognitive | ||
{ | ||
/// <summary>
/// A chain example that turns a text prompt into an image: a Python script using the StableDiffusion XL model is run through the code interpreter, and the path of the saved .png file is then output.
/// </summary>
public class TextToImageChain : BaseExecutableChain
{
    /// <summary>
    /// An example query that this chain example can be used to solve for.
    /// </summary>
    public override string QueryExample => "Generate a picture of a surfer in a hurricane, fighting off sharks that are on fire.";
    /// <summary>
    /// An example query input that this chain example can be used to solve for.
    /// </summary>
    public override string QueryInputExample => "A surfer in a hurricane, fighting off sharks that are on fire, photo-realistic, motion blur.";
    /// <summary>
    /// The reasoning for the solution of the chain.
    /// </summary>
    public override string Reasoning => "I can use my code interpreter to create a script to use the StableDiffusion XL model to generate an image from a prompt and save it to a .png file.";
    /// <summary>
    /// A collection of thoughts: first generate the image via the code interpreter (output key "1"), then output the resulting file path (output key "2").
    /// </summary>
    public override List<Thought> ChainOfThoughts => new List<Thought>
    {
        new Thought
        {
            Action = $"{nameof(CodeInterpreterThoughts)}.{nameof(CodeInterpreterThoughts.InvokeAsync)}",
            Reasoning = "I will use my Python code interpreter to construct a script that can use the StableDiffusion XL model to generate an image for the given prompt, and finally return the path of the file.",
            Critisism = "I need to ensure that I use the correct package versions so that the Python environment has the required dependencies installed and ensure that the prompt used for the SDXL model should be optimized.",
            Arguments = new Dictionary<string, string>
            {
                { "pythonVersion", "3.11.3" },
                { "pipDependencies", "diffusers transformers accelerate omegaconf==2.3.0" },
                { "condaDependencies", "ffmpeg" },
                // NOTE(review): '$input' appears to be substituted textually into the script below; a prompt
                // containing a single quote or a newline would break (or inject into) the Python source.
                // Confirm that the executor escapes the value before substitution.
                { "code", """
                    def generate(prompt: str) -> str:
                        try:
                            import torch
                            from diffusers import StableDiffusionXLPipeline
                            import uuid
                            # Use the best available accelerator instead of assuming Apple Metal (mps).
                            if torch.cuda.is_available():
                                device = "cuda"
                            elif torch.backends.mps.is_available():
                                device = "mps"
                            else:
                                device = "cpu"
                            # Half precision is not reliably supported on the CPU.
                            dtype = torch.float32 if device == "cpu" else torch.float16
                            output_file_path: str = f'{str(uuid.uuid4())}.png'
                            pipe = StableDiffusionXLPipeline.from_pretrained(
                                "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype
                            )
                            pipe = pipe.to(device)
                            image = pipe(prompt).images[0]
                            image.save(output_file_path)
                            return output_file_path
                        except Exception as e:
                            print(e)
                            raise
                    def main() -> str:
                        return generate('$input')
                    """ }
            },
            OutputKey = "1"
        },
        new Thought
        {
            Action = $"{nameof(OutputThoughts)}.{nameof(OutputThoughts.OutputTextAsync)}",
            Reasoning = "I can simply proxy the file path from the previous step, as a direct response is appropriate for returning the generated image.",
            Arguments = new Dictionary<string, string>
            {
                // "$1" refers to the output key of the image generation thought above.
                { "output", "$1" }
            },
            OutputKey = "2"
        }
    };

    /// <summary>
    /// Overloaded constructor to provide dependencies.
    /// </summary>
    /// <param name="serviceProvider">The dependency service provider.</param>
    /// <param name="logger">Instance logger.</param>
    public TextToImageChain(IServiceProvider serviceProvider, ILogger<TextToImageChain> logger)
        : base(serviceProvider, logger)
    { }

    /// <summary>
    /// Take input prompt and generate an image using the StableDiffusion XL model, save it to a .png file and return the path to the .png file.
    /// </summary>
    /// <param name="prompt">The prompt to use to generate an image.</param>
    /// <param name="token">The token to use to request cancellation.</param>
    /// <returns>The path to the .png file.</returns>
    [KernelFunction, Description("Take input prompt and generate an image using the StableDiffusion XL model, save it to a .png file and return the path to the .png file.")]
    public Task<string> GenerateImageAndGetFilePathAsync(
        [Description("The prompt to use to generate an image.")] string prompt,
        CancellationToken token = default)
    {
        // The prompt is validated first, so a missing prompt is rejected before the chain is executed.
        return ExecuteChainAsync(prompt.ThrowIfNullOrWhitespace(nameof(prompt)), token: token);
    }
}
} |