Skip to content

Commit

Permalink
Merge pull request #5 from Ceilidh-Team/parallel
Browse files Browse the repository at this point in the history
Parallelization
  • Loading branch information
OrionNebula authored Sep 2, 2018
2 parents b60538e + a45a6be commit 7bc3cea
Show file tree
Hide file tree
Showing 8 changed files with 335 additions and 42 deletions.
156 changes: 156 additions & 0 deletions ProjectCeilidh.Cobble.Tests/AsyncCobbleContextTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;
using ProjectCeilidh.Cobble.Generator;
using Xunit;

namespace ProjectCeilidh.Cobble.Tests
{
public class AsyncCobbleContextTests
{
private readonly CobbleContext _context;

public AsyncCobbleContextTests()
{
_context = new CobbleContext();
}

[Fact]
public async Task BasicLoad()
{
_context.AddManaged<TestUnit>();
await _context.ExecuteAsync();

Assert.True(_context.TryGetSingleton<TestUnit>(out _));
Assert.True(_context.TryGetSingleton<ITestUnit>(out _));
}

[Fact]
public async Task BasicDeps()
{
_context.AddManaged<TestUnit>();
_context.AddManaged<BasicDependUnit>();
await _context.ExecuteAsync();

Assert.True(_context.TryGetSingleton<TestUnit>(out _));
Assert.True(_context.TryGetSingleton<ITestUnit>(out _));
Assert.True(_context.TryGetSingleton<BasicDependUnit>(out _));
}

[Fact]
public async Task MediumDeps()
{
_context.AddManaged<TestUnit>();
_context.AddManaged<MediumDependUnit>();
await _context.ExecuteAsync();

Assert.True(_context.TryGetSingleton<TestUnit>(out _));
Assert.True(_context.TryGetSingleton<ITestUnit>(out _));
Assert.True(_context.TryGetSingleton<MediumDependUnit>(out _));
}

[Fact]
public async Task AdvancedDeps()
{
_context.AddManaged<TestUnit>();
_context.AddManaged<AdvancedDependUnit>();
await _context.ExecuteAsync();

_context.AddManaged<TestUnit>();

Assert.True(_context.TryGetImplementations<TestUnit>(out var testSet) && testSet.Count() == 2);
Assert.True(_context.TryGetSingleton<AdvancedDependUnit>(out var adv) && adv.TestUnits.Count == 2);
}

[Fact]
public async Task DuplicateResolver()
{
var exec = false;

_context.DuplicateResolver = (dependencyType, instances) => {
exec = true;
return instances[0];
};

_context.AddManaged<TestUnit>();
_context.AddManaged<TestUnit>();
_context.AddManaged<BasicDependUnit>();

await _context.ExecuteAsync();

Assert.True(exec);
}

[Fact]
public async Task DuplicateException()
{
_context.AddManaged<TestUnit>();
_context.AddManaged<TestUnit>();
_context.AddManaged<BasicDependUnit>();

await Assert.ThrowsAsync<AmbiguousDependencyException>(() => _context.ExecuteAsync());
}

[Fact]
public async Task DictInstanceGenerator()
{
_context.AddManaged<TestUnit>();
_context.AddManaged(new DictionaryInstanceGenerator(typeof(ITestUnit), new Func<TestUnit, object>(x => x), new Dictionary<MethodInfo, Delegate>
{
[typeof(ITestUnit).GetMethod("get_TestValue")] = new Func<TestUnit, string>(x => x.TestValue)
}));

await _context.ExecuteAsync();
Assert.True(_context.TryGetImplementations<ITestUnit>(out var units));
foreach (var testUnit in units)
Assert.Equal("Hi", testUnit.TestValue);
}

private interface ITestUnit
{
string TestValue { get; }
}

private class TestUnit : ITestUnit
{
public string TestValue => "Hi";
}

private class BasicDependUnit
{
public BasicDependUnit(TestUnit unit)
{
Assert.NotNull(unit);
Assert.Equal("Hi", unit.TestValue);
}
}

private class MediumDependUnit
{
public MediumDependUnit(IEnumerable<ITestUnit> units)
{
Assert.NotNull(units);
Assert.NotEmpty(units);
}
}

private class AdvancedDependUnit : ILateInject<ITestUnit>
{
public readonly List<ITestUnit> TestUnits;

public AdvancedDependUnit(IEnumerable<ITestUnit> units)
{
TestUnits = new List<ITestUnit>(units);

Assert.NotEmpty(TestUnits);
}

public void UnitLoaded(ITestUnit unit)
{
TestUnits.Add(unit);
}
}
}
}
35 changes: 34 additions & 1 deletion ProjectCeilidh.Cobble.Tests/DirectedGraphTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Linq;
using System;
using System.Linq;
using Xunit;
using ProjectCeilidh.Cobble.Data;

Expand All @@ -23,6 +24,25 @@ void InitialInspector(int value)
}
}

[Fact]
public void ParallelTopologicalSort()
{
var graph = new DirectedGraph<int>(Enumerable.Range(0, 5));
graph.Link(0, 1);
graph.Link(2, 1);
graph.Link(1, 3);
graph.Link(4, 3);

Assert.Collection(graph.ParallelTopologicalSort(), InitialInspector, x => Assert.Equal(new []{ 1 }, x), x => Assert.Equal(new []{ 3 }, x));

void InitialInspector(int[] value)
{
Array.Sort(value);

Assert.Equal(new []{ 0, 2, 4 }, value);
}
}

[Fact]
public void CircularDependency()
{
Expand All @@ -35,5 +55,18 @@ public void CircularDependency()

Assert.Throws<DirectedGraph<int>.CyclicGraphException>(() => graph.TopologicalSort().ToList());
}

[Fact]
public void ParallelCircularDependency()
{
var graph = new DirectedGraph<int>(Enumerable.Range(0, 5));
graph.Link(0, 1);
graph.Link(2, 1);
graph.Link(1, 3);
graph.Link(4, 3);
graph.Link(3, 0);

Assert.Throws<DirectedGraph<int>.CyclicGraphException>(() => graph.ParallelTopologicalSort().ToList());
}
}
}
101 changes: 79 additions & 22 deletions ProjectCeilidh.Cobble/CobbleContext.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using System.Collections.Generic;
using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Threading.Tasks;
using ProjectCeilidh.Cobble.Data;
using ProjectCeilidh.Cobble.Generator;

Expand All @@ -18,17 +20,17 @@ public sealed class CobbleContext
private bool _firstStage;

private readonly List<IInstanceGenerator> _instanceGenerators;
private readonly Dictionary<Type, HashSet<object>> _lateInjectInstances;
private readonly Dictionary<Type, HashSet<object>> _implementations;
private readonly ConcurrentDictionary<Type, HashSet<object>> _lateInjectInstances;
private readonly ConcurrentDictionary<Type, HashSet<object>> _implementations;

/// <summary>
/// Construct a new CobbleContext.
/// </summary>
public CobbleContext()
{
_instanceGenerators = new List<IInstanceGenerator>();
_lateInjectInstances = new Dictionary<Type, HashSet<object>>();
_implementations = new Dictionary<Type, HashSet<object>>();
_lateInjectInstances = new ConcurrentDictionary<Type, HashSet<object>>();
_implementations = new ConcurrentDictionary<Type, HashSet<object>>();

AddUnmanaged(this);
}
Expand Down Expand Up @@ -138,9 +140,9 @@ public void Execute()

var graph = new DirectedGraph<IInstanceGenerator>(_instanceGenerators);

foreach(var gen in _instanceGenerators) // Create links in the DirectedGraph between dependencies and the generators which provide them
foreach (var gen in _instanceGenerators) // Create links in the DirectedGraph between dependencies and the generators which provide them
{
foreach(var dep in gen.Dependencies)
foreach (var dep in gen.Dependencies)
{
var depType = dep;

Expand All @@ -156,20 +158,74 @@ public void Execute()
{
foreach (var gen in graph.TopologicalSort()) // Sort the dependency graph topologically - all dependencies should be satisfied by the time we get to each unit
{
var inst = CreateInstance(gen, _implementations);
var obj = CreateInstance(gen, _implementations);
PushInstanceProvides(gen, obj, _implementations);

if (!(gen is ILateInstanceGenerator late)) continue;

// If the generator supports late injection, we need to add it to our list
foreach (var lateDep in late.LateDependencies)
_lateInjectInstances.AddOrUpdate(lateDep, x => new HashSet<object>(new[] {obj}),
(a, b) =>
{
b.Add(a);
return b;
});
}
}
catch (DirectedGraph<IInstanceGenerator>.CyclicGraphException)
{
throw new CircularDependencyException();
}
}

PushInstanceProvides(gen, inst, _implementations);
public async Task ExecuteAsync()
{
if (_firstStage) throw new Exception("You cannot execute a CobbleContext twice.");

if (gen is ILateInstanceGenerator late) // If the generator supports late injection, we need to add it to our list
_firstStage = true;

// Create a lookup which maps provided type to the set off all generators that provide it.
var implMap = _instanceGenerators
.SelectMany(x => x.Provides.Select(y => (Type: y, Generator: x)))
.ToLookup(x => x.Type, x => x.Generator);

var graph = new DirectedGraph<IInstanceGenerator>(_instanceGenerators);

foreach(var gen in _instanceGenerators) // Create links in the DirectedGraph between dependencies and the generators which provide them
{
foreach(var dep in gen.Dependencies)
{
var depType = dep;

if (dep.IsConstructedGenericType && dep.GetGenericTypeDefinition() == typeof(IEnumerable<>))
depType = dep.GetGenericArguments()[0];

foreach (var impl in implMap[depType])
graph.Link(impl, gen);
}
}

try
{
foreach (var level in graph.ParallelTopologicalSort()) // Sort the dependency graph topologically - all dependencies should be satisfied by the time we get to each unit
{
await Task.WhenAll(level.Select(gen => Task.Run(() =>
{
var obj = CreateInstance(gen, _implementations);
PushInstanceProvides(gen, obj, _implementations);
if (!(gen is ILateInstanceGenerator late)) return;
// If the generator supports late injection, we need to add it to our list
foreach (var lateDep in late.LateDependencies)
{
if (_lateInjectInstances.TryGetValue(lateDep, out var lateSet))
lateSet.Add(inst);
else
_lateInjectInstances[lateDep] = new HashSet<object>(new[] { inst });
}
}
_lateInjectInstances.AddOrUpdate(lateDep, x => new HashSet<object>(new[] { obj }),
(a, b) =>
{
b.Add(obj);
return b;
});
})));
}
}
catch (DirectedGraph<IInstanceGenerator>.CyclicGraphException) {
Expand All @@ -183,13 +239,14 @@ public void Execute()
/// <param name="gen">The instance generator that produced the instance.</param>
/// <param name="instance">The instance that was produced.</param>
/// <param name="instances">A dictionary mapping provided types to a set of instances.</param>
private static void PushInstanceProvides(IInstanceGenerator gen, object instance, Dictionary<Type, HashSet<object>> instances)
private static void PushInstanceProvides(IInstanceGenerator gen, object instance, ConcurrentDictionary<Type, HashSet<object>> instances)
{
foreach (var prov in gen.Provides)
if (instances.TryGetValue(prov, out var set))
set.Add(instance);
else
instances[prov] = new HashSet<object>(new[] { instance });
instances.AddOrUpdate(prov, x => new HashSet<object>(new[] {instance}), (a, b) =>
{
b.Add(instance);
return b;
});
}

/// <summary>
Expand All @@ -198,7 +255,7 @@ private static void PushInstanceProvides(IInstanceGenerator gen, object instance
/// <returns>The created object.</returns>
/// <param name="gen">The generator instance.</param>
/// <param name="instances">A dictionary mapping provided types to a set of instances.</param>
private object CreateInstance(IInstanceGenerator gen, Dictionary<Type, HashSet<object>> instances)
private object CreateInstance(IInstanceGenerator gen, IDictionary<Type, HashSet<object>> instances)
{
var args = new object[gen.Dependencies.Count()];
var i = 0;
Expand Down
Loading

0 comments on commit 7bc3cea

Please sign in to comment.