Skip to content

Commit

Permalink
feat(TextToImage): Completion of text to image thought/skill using St…
Browse files Browse the repository at this point in the history
…ableDiffusion XL and code interpreter.
  • Loading branch information
frostaura committed Jan 13, 2024
1 parent 025bc18 commit 05fd2c6
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using FrostAura.Libraries.Semantic.Core.Thoughts.Chains.Cognitive;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;
using FrostAura.Libraries.Semantic.Core.Extensions.Configuration;

namespace Semantic.Core.Tests.Thoughts.Chains.Cognitive
{
public class TextToImageChainTests
{
[Fact]
public void Constructor_WithInvalidLogger_ShouldThrow()
{
ILogger<TextToImageChain> logger = null;
IServiceProvider serviceProvider = Substitute.For<IServiceProvider>();

var actual = Assert.Throws<ArgumentNullException>(() => new TextToImageChain(serviceProvider, logger));

Assert.Equal(nameof(logger), actual.ParamName);
}

[Fact]
public void Constructor_WithInvalidServiceProvider_ShouldThrow()
{
ILogger<TextToImageChain> logger = Substitute.For<ILogger<TextToImageChain>>();
IServiceProvider serviceProvider = null;

var actual = Assert.Throws<ArgumentNullException>(() => new TextToImageChain(serviceProvider, logger));

Assert.Equal(nameof(serviceProvider), actual.ParamName);
}

[Fact]
public void Constructor_WithValidParams_ShouldConstruct()
{
var serviceProvider = Substitute.For<IServiceProvider>();
var logger = Substitute.For<ILogger<TextToImageChain>>();

var actual = new TextToImageChain(serviceProvider, logger);

Assert.NotNull(actual);
Assert.NotEmpty(actual.QueryExample);
Assert.NotEmpty(actual.QueryInputExample);
Assert.NotEmpty(actual.Reasoning);
Assert.NotEmpty(actual.ChainOfThoughts);
}

[Fact]
public async Task GenerateImageAndGetFilePathAsync_WithInvalidInput_ShouldThrow()
{
var serviceCollection = new ServiceCollection()
.AddSemanticCore(Config.SEMANTIC_CONFIG);
var serviceProvider = serviceCollection.BuildServiceProvider();
var logger = Substitute.For<ILogger<TextToImageChain>>();
var instance = new TextToImageChain(serviceProvider, logger);
string prompt = default;

var actual = await Assert.ThrowsAsync<ArgumentNullException>(async () => await instance.GenerateImageAndGetFilePathAsync(prompt));

Assert.Equal(nameof(prompt), actual.ParamName);
}

[Fact]
public async Task GenerateImageAndGetFilePathAsync_WithValidInput_ShouldCallInvokeAsync()
{
var serviceCollection = new ServiceCollection()
.AddSemanticCore(Config.SEMANTIC_CONFIG);

var serviceProvider = serviceCollection.BuildServiceProvider();
var logger = Substitute.For<ILogger<TextToImageChain>>();
var instance = new TextToImageChain(serviceProvider, logger);
var prompt = "A surfer in a hurricane, fighting off sharks that are on fire, photo-realistic, motion blur.";

var actual = await instance.GenerateImageAndGetFilePathAsync(prompt);

Assert.NotNull(actual);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<package xmlns="http://schemas.microsoft.com/packaging/2012/06/nuspec.xsd">
<metadata>
<id>FrostAura.Libraries.Intelligence.Semantic.Core</id>
<version>0.1.1</version>
<version>1.0.1</version>
<title>FrostAura.Libraries.Intelligence.Semantic.Core</title>
<authors>Dean Martin</authors>
<owners>FrostAura</owners>
Expand Down
101 changes: 101 additions & 0 deletions src/Semantic.Core/Thoughts/Chains/Cognitive/TextToImageChain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
using System.ComponentModel;
using FrostAura.Libraries.Core.Extensions.Validation;
using FrostAura.Libraries.Semantic.Core.Models.Thoughts;
using FrostAura.Libraries.Semantic.Core.Thoughts.Cognitive;
using FrostAura.Libraries.Semantic.Core.Thoughts.IO;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;

namespace FrostAura.Libraries.Semantic.Core.Thoughts.Chains.Cognitive
{
public class TextToImageChain : BaseExecutableChain
{
/// <summary>
/// An example query that this chain example can be used to solve for.
/// </summary>
public override string QueryExample => "Generate a picture of a surfer in a hurricane, fighting off sharks that are on fire.";
/// <summary>
/// An example query input that this chain example can be used to solve for.
/// </summary>
public override string QueryInputExample => "A surfer in a hurricane, fighting off sharks that are on fire, photo-realistic, motion blur.";
/// The reasoning for the solution of the chain.
/// </summary>
public override string Reasoning => "I can use my code interpreter to create a script to use the StableDiffusion XL model to generate an image from a prompt and sae it to a .png file.";
/// <summary>
/// A collection of thoughts.
/// </summary>
public override List<Thought> ChainOfThoughts => new List<Thought>
{
new Thought
{
Action = $"{nameof(CodeInterpreterThoughts)}.{nameof(CodeInterpreterThoughts.InvokeAsync)}",
Reasoning = "I will use my code Python code interpreter to construct a script that can use the StableDiffusion XL model to generate an image for the given prompt, and finally return the path of the file.",
Critisism = "I need to ensure that I use the correct package versions so that the Python environment has the required dependencies installed and ensure that the prompt used for the SDXL model should be optimized.",
Arguments = new Dictionary<string, string>
{
{ "pythonVersion", "3.11.3" },
{ "pipDependencies", "diffusers transformers accelerate omegaconf==2.3.0" },
{ "condaDependencies", "ffmpeg" },
{ "code", """
def generate(prompt: str) -> str:
try:
import torch
from diffusers import StableDiffusionXLPipeline
import uuid
output_file_path: str = f'{str(uuid.uuid4())}.png'
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe = pipe.to("mps")
image = pipe(prompt).images[0]
image.save(output_file_path)
return output_file_path
except Exception as e:
print(e)
raise e
def main() -> str:
return generate('$input')
""" }
},
OutputKey = "1"
},
new Thought
{
Action = $"{nameof(OutputThoughts)}.{nameof(OutputThoughts.OutputTextAsync)}",
Reasoning = "I can simply proxy the response as a direct and response is appropriate for an exact transcription.",
Arguments = new Dictionary<string, string>
{
{ "output", "$1" }
},
OutputKey = "2"
}
};

/// <summary>
/// Overloaded constructor to provide dependencies.
/// </summary>
/// <param name="serviceProvider">The dependency service provider.</param>
/// <param name="logger">Instance logger.</param>
public TextToImageChain(IServiceProvider serviceProvider, ILogger<TextToImageChain> logger)
: base(serviceProvider, logger)
{ }

/// <summary>
/// Take input prompt and generate an image using the StableDiffusion XL model, save it to a .png file and return the path to the .png file.
/// </summary>
/// <param name="prompt">The prompt to use to generate an image.</param>
/// <param name="token">The token to use to request cancellation.</param>
/// <returns>The path to the .png file.</returns>
[KernelFunction, Description("Take input prompt and generate an image using the StableDiffusion XL model, save it to a .png file and return the path to the .png file.")]
public Task<string> GenerateImageAndGetFilePathAsync(
[Description("The prompt to use to generate an image.")] string prompt,
CancellationToken token = default)
{
return ExecuteChainAsync(prompt.ThrowIfNullOrWhitespace(nameof(prompt)), token: token);
}
}
}

0 comments on commit 05fd2c6

Please sign in to comment.