Skip to content

Commit

Permalink
Use typed methods and typed API Responses to catch all the APIs retur…
Browse files Browse the repository at this point in the history
…ning invalid response types
  • Loading branch information
mythz committed Nov 25, 2024
1 parent 97466c8 commit 4dce797
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 142 deletions.
8 changes: 4 additions & 4 deletions AiServer.ServiceInterface/AudioServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace AiServer.ServiceInterface;

public class AudioServices(IBackgroundJobs jobs) : Service
{
public async Task<object> Any(ConvertAudio request)
public async Task<ArtifactGenerationResponse> Any(ConvertAudio request)
{
if (Request?.Files == null || Request.Files.Length == 0)
{
Expand All @@ -27,9 +27,9 @@ public async Task<object> Any(ConvertAudio request)
};

var transformService = base.ResolveService<MediaTransformProviderServices>();
return await transformRequest.ProcessTransform(jobs, transformService, sync: true);
return await transformRequest.ProcessSyncTransformAsync(jobs, transformService);
}
public async Task<object> Any(QueueConvertAudio request)
public async Task<QueueMediaTransformResponse> Any(QueueConvertAudio request)
{
if (Request?.Files == null || Request.Files.Length == 0)
{
Expand All @@ -52,7 +52,7 @@ public async Task<object> Any(QueueConvertAudio request)
};

var transformService = base.ResolveService<MediaTransformProviderServices>();
return await transformRequest.ProcessTransform(jobs, transformService);
return await transformRequest.ProcessQueuedTransformAsync(jobs, transformService);
}

private bool IsAudioFormat(MediaOutputFormat outputformat)
Expand Down
17 changes: 17 additions & 0 deletions AiServer.ServiceInterface/DtoExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using AiServer.ServiceModel;

namespace AiServer.ServiceInterface;

public static class DtoExtensions
{
public static TextGenerationResponse ToTextGenerationResponse(this GenerationResponse response) => response.TextOutputs?.Count > 0 ? new() {
Results = response.TextOutputs,
ResponseStatus = response.ResponseStatus
} : throw new Exception("Failed to generate any text outputs");

public static ArtifactGenerationResponse ToArtifactGenerationResponse(this GenerationResponse response) => response.Outputs?.Count > 0 ? new() {
Results = response.Outputs,
ResponseStatus = response.ResponseStatus
} : throw new Exception("Failed to generate any outputs");

}
66 changes: 45 additions & 21 deletions AiServer.ServiceInterface/GenerationServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public object Any(ActiveMediaModels request)
};
}

public async Task<object> Any(TextToImage request)
public async Task<ArtifactGenerationResponse> Any(TextToImage request)
{
var diffRequest = new CreateGeneration
{
Expand All @@ -88,11 +88,11 @@ public async Task<object> Any(TextToImage request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
var result = await diffRequest.ProcessGeneration(jobs, diffServices, sync: true) as GenerationResponse;
return result.ConvertTo<ArtifactGenerationResponse>();
var result = await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
return result.ToArtifactGenerationResponse();
}

public async Task<object> Any(ImageToImage request)
public async Task<ArtifactGenerationResponse> Any(ImageToImage request)
{
var diffRequest = new CreateGeneration
{
Expand All @@ -109,11 +109,11 @@ public async Task<object> Any(ImageToImage request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
var result = await diffRequest.ProcessGeneration(jobs, diffServices, sync: true) as GenerationResponse;
return result.ConvertTo<ArtifactGenerationResponse>();
var result = await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
return result.ToArtifactGenerationResponse();
}

public async Task<object> Any(ImageUpscale request)
public async Task<ArtifactGenerationResponse> Any(ImageUpscale request)
{
var diffRequest = new CreateGeneration
{
Expand All @@ -126,11 +126,11 @@ public async Task<object> Any(ImageUpscale request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
var result = await diffRequest.ProcessGeneration(jobs, diffServices, sync: true) as GenerationResponse;
return result.ConvertTo<ArtifactGenerationResponse>();
var result = await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
return result.ToArtifactGenerationResponse();
}

public async Task<object> Any(ImageWithMask request)
public async Task<ArtifactGenerationResponse> Any(ImageWithMask request)
{
var diffRequest = new CreateGeneration
{
Expand All @@ -146,11 +146,11 @@ public async Task<object> Any(ImageWithMask request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
var result = await diffRequest.ProcessGeneration(jobs, diffServices, sync: true) as GenerationResponse;
return result.ConvertTo<ArtifactGenerationResponse>();
var result = await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
return result.ToArtifactGenerationResponse();
}

public async Task<object> Any(ImageToText request)
public async Task<TextGenerationResponse> Any(ImageToText request)
{
var diffRequest = new CreateGeneration
{
Expand All @@ -162,15 +162,14 @@ public async Task<object> Any(ImageToText request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
var result = await diffRequest.ProcessGeneration(jobs, diffServices, sync: true) as GenerationResponse;
return result.ConvertTo<TextGenerationResponse>();
var result = await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
return result.ToTextGenerationResponse();
}
}

public static class GenerationServiceExtensions
{
public static async Task<object> ProcessGeneration(this CreateGeneration diffRequest, IBackgroundJobs jobs,
MediaProviderServices genProviderServices, bool sync = false)
public static async Task<QueueGenerationResponse> ProcessQueuedGenerationAsync(this CreateGeneration diffRequest, IBackgroundJobs jobs, MediaProviderServices genProviderServices)
{
CreateGenerationResponse? diffResponse = null;
try
Expand All @@ -184,7 +183,7 @@ public static async Task<object> ProcessGeneration(this CreateGeneration diffReq
throw;
}

if(diffResponse == null)
if (diffResponse == null)
throw new Exception("Failed to start generation");

var job = jobs.GetJob(diffResponse.Id);
Expand Down Expand Up @@ -217,12 +216,37 @@ public static async Task<object> ProcessGeneration(this CreateGeneration diffReq
throw new Exception($"Job failed: {job.Failed.Error}");
}

// If not a synchronous request, return immediately with job details
if (sync != true)
return queueResponse;
}

public static async Task<GenerationResponse> ProcessSyncGenerationAsync(this CreateGeneration diffRequest, IBackgroundJobs jobs, MediaProviderServices genProviderServices)
{
CreateGenerationResponse? diffResponse = null;
try
{
return queueResponse;
var response = genProviderServices.Any(diffRequest);
diffResponse = response as CreateGenerationResponse;
}
catch (Exception e)
{
Console.WriteLine(e);
throw;
}

if (diffResponse == null)
throw new Exception("Failed to start generation");

var job = jobs.GetJob(diffResponse.Id);
// For synchronous requests, wait for the job to be created
while (job == null)
{
await Task.Delay(1000);
job = jobs.GetJob(diffResponse.Id);
}

// We know at this point, we definitely have a job
JobResult? queuedJob = job;

var completedResponse = new GenerationResponse { };

// Wait for the job to complete max 1 minute
Expand Down
12 changes: 6 additions & 6 deletions AiServer.ServiceInterface/ImageServices.Generation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public partial class ImageServices(IBackgroundJobs jobs,
ILogger<ImageServices> log,
AppData appData) : Service
{
public async Task<object> Any(QueueTextToImage request)
public async Task<QueueGenerationResponse> Any(QueueTextToImage request)
{
if (!string.IsNullOrEmpty(request.Model) && !appData.ModelSupportsTask(request.Model, AiTaskType.TextToImage))
{
Expand Down Expand Up @@ -44,7 +44,7 @@ public async Task<object> Any(QueueTextToImage request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
return await diffRequest.ProcessGeneration(jobs, diffServices);
return await diffRequest.ProcessQueuedGenerationAsync(jobs, diffServices);
}

public async Task<object> Any(QueueImageUpscale request)
Expand All @@ -70,7 +70,7 @@ public async Task<object> Any(QueueImageUpscale request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
return await diffRequest.ProcessGeneration(jobs, diffServices);
return await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
}

public async Task<object> Any(QueueImageToImage request)
Expand Down Expand Up @@ -100,7 +100,7 @@ public async Task<object> Any(QueueImageToImage request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
return await diffRequest.ProcessGeneration(jobs, diffServices);
return await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
}

public async Task<object> Any(QueueImageWithMask request)
Expand Down Expand Up @@ -129,7 +129,7 @@ public async Task<object> Any(QueueImageWithMask request)
};

await using var diffServices = ResolveService<MediaProviderServices>();
return await diffRequest.ProcessGeneration(jobs, diffServices);
return await diffRequest.ProcessSyncGenerationAsync(jobs, diffServices);
}

public async Task<object> Any(QueueImageToText request)
Expand All @@ -147,7 +147,7 @@ public async Task<object> Any(QueueImageToText request)
};

await using var genServices = ResolveService<MediaProviderServices>();
return await diffRequest.ProcessGeneration(jobs, genServices);
return await diffRequest.ProcessSyncGenerationAsync(jobs, genServices);
}
}

Expand Down
Loading

0 comments on commit 4dce797

Please sign in to comment.