Skip to content

Commit

Permalink
Add SetAll operator to IUpdateBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
ddaspit committed Dec 5, 2024
1 parent 2fa330d commit 997ccdc
Show file tree
Hide file tree
Showing 9 changed files with 220 additions and 40 deletions.
1 change: 1 addition & 0 deletions src/DataAccess/src/SIL.DataAccess/ArrayPosition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ public static class ArrayPosition
{
public const int FirstMatching = int.MaxValue;
public const int All = int.MaxValue - 1;
internal const int ArrayFilter = int.MaxValue - 2;
}
11 changes: 9 additions & 2 deletions src/DataAccess/src/SIL.DataAccess/DataAccessFieldDefinition.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
namespace SIL.DataAccess;

public class DataAccessFieldDefinition<TDocument, TField>(Expression<Func<TDocument, TField>> expression)
: FieldDefinition<TDocument, TField>
public class DataAccessFieldDefinition<TDocument, TField>(
Expression<Func<TDocument, TField>> expression,
string arrayFilterId = ""
) : FieldDefinition<TDocument, TField>
{
private readonly ExpressionFieldDefinition<TDocument, TField> _internalDef = new(expression);
private readonly string _arrayFilterId = arrayFilterId;

public override RenderedFieldDefinition<TField> Render(
IBsonSerializer<TDocument> documentSerializer,
Expand All @@ -18,6 +21,10 @@ LinqProvider linqProvider
);
string fieldName = rendered.FieldName.Replace(ArrayPosition.All.ToString(CultureInfo.InvariantCulture), "$[]");
fieldName = fieldName.Replace(ArrayPosition.FirstMatching.ToString(CultureInfo.InvariantCulture), "$");
fieldName = fieldName.Replace(
ArrayPosition.ArrayFilter.ToString(CultureInfo.InvariantCulture),
$"$[{_arrayFilterId}]"
);
if (fieldName != rendered.FieldName)
{
return new RenderedFieldDefinition<TField>(
Expand Down
10 changes: 10 additions & 0 deletions src/DataAccess/src/SIL.DataAccess/ExpressionHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,14 @@ Expression expression
finder.Visit(expression);
return finder.Value;
}

public static Expression<Func<TIn, TOut>> Concatenate<TIn, TInter, TOut>(
Expression<Func<TIn, TInter>> left,
Expression<Func<TInter, TOut>> right
)
{
ParameterReplacer replacer = new(right.Parameters[0], left.Body);
Expression merged = replacer.Visit(right.Body);
return Expression.Lambda<Func<TIn, TOut>>(merged, left.Parameters[0]);
}
}
9 changes: 8 additions & 1 deletion src/DataAccess/src/SIL.DataAccess/IUpdateBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@ public interface IUpdateBuilder<T>

IUpdateBuilder<T> RemoveAll<TItem>(
Expression<Func<T, IEnumerable<TItem>?>> field,
Expression<Func<TItem, bool>> predicate
Expression<Func<TItem, bool>>? predicate = null
);

IUpdateBuilder<T> Remove<TItem>(Expression<Func<T, IEnumerable<TItem>?>> field, TItem value);

IUpdateBuilder<T> Add<TItem>(Expression<Func<T, IEnumerable<TItem>?>> field, TItem value);

IUpdateBuilder<T> SetAll<TItem, TField>(
Expression<Func<T, IEnumerable<TItem>?>> collectionField,
Expression<Func<TItem, TField>> itemField,
TField value,
Expression<Func<TItem, bool>>? predicate = null
);
}
92 changes: 66 additions & 26 deletions src/DataAccess/src/SIL.DataAccess/MemoryUpdateBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,20 @@ public class MemoryUpdateBuilder<T>(Expression<Func<T, bool>> filter, T entity,

public IUpdateBuilder<T> Set<TField>(Expression<Func<T, TField>> field, TField value)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(field);
object[]? indices = index == null ? null : [index];
foreach (object owner in owners)
prop.SetValue(owner, value, indices);
Set(_entity, _filter, field, value);
return this;
}

public IUpdateBuilder<T> SetOnInsert<TField>(Expression<Func<T, TField>> field, TField value)
{
if (_isInsert)
Set(field, value);
Set(_entity, _filter, field, value);
return this;
}

public IUpdateBuilder<T> Unset<TField>(Expression<Func<T, TField>> field)
{
(IEnumerable<object> owners, PropertyInfo prop, object? index) = GetFieldOwners(field);
(IEnumerable<object> owners, PropertyInfo prop, object? index) = GetFieldOwners(_entity, _filter, field);
if (index != null)
{
// remove value from a dictionary
Expand All @@ -49,7 +46,7 @@ public IUpdateBuilder<T> Unset<TField>(Expression<Func<T, TField>> field)

public IUpdateBuilder<T> Inc(Expression<Func<T, int>> field, int value = 1)
{
(IEnumerable<object> owners, PropertyInfo prop, object? index) = GetFieldOwners(field);
(IEnumerable<object> owners, PropertyInfo prop, object? index) = GetFieldOwners(_entity, _filter, field);
object[]? indices = index == null ? null : [index];
foreach (object owner in owners)
{
Expand All @@ -62,20 +59,20 @@ public IUpdateBuilder<T> Inc(Expression<Func<T, int>> field, int value = 1)

public IUpdateBuilder<T> RemoveAll<TItem>(
Expression<Func<T, IEnumerable<TItem>?>> field,
Expression<Func<TItem, bool>> predicate
Expression<Func<TItem, bool>>? predicate = null
)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(field);
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(_entity, _filter, field);
object[]? indices = index == null ? null : [index];
Func<TItem, bool> predicateFunc = predicate.Compile();
Func<TItem, bool>? predicateFunc = predicate?.Compile();
foreach (object owner in owners)
{
var collection = (IEnumerable<TItem>?)prop.GetValue(owner, indices);
MethodInfo? removeMethod = collection?.GetType().GetMethod("Remove");
if (collection is not null && removeMethod is not null)
{
// the collection is mutable, so use Remove method to remove item
TItem[] toRemove = collection.Where(predicateFunc).ToArray();
TItem[] toRemove = collection.Where(i => predicateFunc?.Invoke(i) ?? true).ToArray();
foreach (TItem item in toRemove)
removeMethod.Invoke(collection, [item]);
}
Expand All @@ -84,14 +81,17 @@ Expression<Func<TItem, bool>> predicate
if (prop.PropertyType.IsArray || prop.PropertyType.IsInterface)
{
// the collection type is an array or interface, so construct a new array and set property
TItem[] newValue = collection.Where(i => !predicateFunc(i)).ToArray();
TItem[] newValue = collection.Where(i => !(predicateFunc?.Invoke(i) ?? false)).ToArray();
prop.SetValue(owner, newValue, indices);
}
else
{
// the collection type is a collection class, so construct a new collection and set property
var newValue = (IEnumerable<TItem>?)
Activator.CreateInstance(prop.PropertyType, collection.Where(i => !predicateFunc(i)).ToArray());
Activator.CreateInstance(
prop.PropertyType,
collection.Where(i => !(predicateFunc?.Invoke(i) ?? false)).ToArray()
);
prop.SetValue(owner, newValue, indices);
}
}
Expand All @@ -101,7 +101,7 @@ Expression<Func<TItem, bool>> predicate

public IUpdateBuilder<T> Remove<TItem>(Expression<Func<T, IEnumerable<TItem>?>> field, TItem value)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(field);
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(_entity, _filter, field);
object[]? indices = index == null ? null : [index];
foreach (object owner in owners)
{
Expand Down Expand Up @@ -134,7 +134,7 @@ public IUpdateBuilder<T> Remove<TItem>(Expression<Func<T, IEnumerable<TItem>?>>

public IUpdateBuilder<T> Add<TItem>(Expression<Func<T, IEnumerable<TItem>?>> field, TItem value)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(field);
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(_entity, _filter, field);
object[]? indices = index == null ? null : [index];
foreach (object owner in owners)
{
Expand All @@ -147,7 +147,7 @@ public IUpdateBuilder<T> Add<TItem>(Expression<Func<T, IEnumerable<TItem>?>> fie
}
else
{
collection ??= Array.Empty<TItem>();
collection ??= [];
if (prop.PropertyType.IsArray || prop.PropertyType.IsInterface)
{
// the collection type is an array or interface, so construct a new array and set property
Expand All @@ -166,6 +166,47 @@ public IUpdateBuilder<T> Add<TItem>(Expression<Func<T, IEnumerable<TItem>?>> fie
return this;
}

public IUpdateBuilder<T> SetAll<TItem, TField>(
Expression<Func<T, IEnumerable<TItem>?>> collectionField,
Expression<Func<TItem, TField>> itemField,
TField value,
Expression<Func<TItem, bool>>? predicate = null
)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(
_entity,
_filter,
collectionField
);
object[]? indices = index == null ? null : [index];
Func<TItem, bool>? predicateFunc = predicate?.Compile();
foreach (object owner in owners)
{
var collection = (IEnumerable<TItem>?)prop.GetValue(owner, indices);
if (collection is null)
continue;
foreach (TItem item in collection)
{
if (predicateFunc == null || predicateFunc(item))
Set(item, i => true, itemField, value);
}
}
return this;
}

private static void Set<TEntity, TField>(
TEntity entity,
Expression<Func<TEntity, bool>> filter,
Expression<Func<TEntity, TField>> field,
TField value
)
{
(IEnumerable<object> owners, PropertyInfo? prop, object? index) = GetFieldOwners(entity, filter, field);
object[]? indices = index == null ? null : [index];
foreach (object owner in owners)
prop.SetValue(owner, value, indices);
}

private static bool IsAnyMethod(MethodInfo mi)
{
return mi.DeclaringType == typeof(Enumerable) && mi.Name == "Any";
Expand All @@ -180,8 +221,10 @@ private static MethodInfo GetFirstOrDefaultMethod(Type type)
.MakeGenericMethod(type);
}

private (IEnumerable<object> Owners, PropertyInfo Property, object? Index) GetFieldOwners<TField>(
Expression<Func<T, TField>> field
private static (IEnumerable<object> Owners, PropertyInfo Property, object? Index) GetFieldOwners<TEntity, TField>(
TEntity entity,
Expression<Func<TEntity, bool>> filter,
Expression<Func<TEntity, TField>> field
)
{
List<object>? owners = null;
Expand All @@ -192,8 +235,8 @@ Expression<Func<T, TField>> field
var newOwners = new List<object>();
if (owners == null)
{
if (_entity != null)
newOwners.Add(_entity);
if (entity != null)
newOwners.Add(entity);
}
else
{
Expand All @@ -206,17 +249,14 @@ Expression<Func<T, TField>> field
switch (index)
{
case ArrayPosition.FirstMatching:
foreach (Expression expression in ExpressionHelper.Flatten(_filter))
foreach (Expression expression in ExpressionHelper.Flatten(filter))
{
if (expression is MethodCallExpression callExpr && IsAnyMethod(callExpr.Method))
{
var predicate = (LambdaExpression)callExpr.Arguments[1];
Type itemType = predicate.Parameters[0].Type;
MethodInfo firstOrDefault = GetFirstOrDefaultMethod(itemType);
newOwner = firstOrDefault.Invoke(
null,
new object[] { owner, predicate.Compile() }
);
newOwner = firstOrDefault.Invoke(null, [owner, predicate.Compile()]);
if (newOwner != null)
newOwners.Add(newOwner);
break;
Expand Down Expand Up @@ -245,7 +285,7 @@ Expression<Func<T, TField>> field
}
else
{
newOwner = method.Invoke(owner, new object[] { index });
newOwner = method.Invoke(owner, [index]);
if (newOwner != null)
newOwners.Add(newOwner);
}
Expand Down
13 changes: 9 additions & 4 deletions src/DataAccess/src/SIL.DataAccess/MongoRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ await _collection
var updateBuilder = new MongoUpdateBuilder<T>();
update(updateBuilder);
updateBuilder.Inc(e => e.Revision, 1);
UpdateDefinition<T> updateDef = updateBuilder.Build();
(UpdateDefinition<T> updateDef, IReadOnlyList<ArrayFilterDefinition> arrayFilters) = updateBuilder.Build();
var options = new FindOneAndUpdateOptions<T>
{
IsUpsert = upsert,
Expand Down Expand Up @@ -160,20 +160,25 @@ public async Task<int> UpdateAllAsync(
var updateBuilder = new MongoUpdateBuilder<T>();
update(updateBuilder);
updateBuilder.Inc(e => e.Revision, 1);
UpdateDefinition<T> updateDef = updateBuilder.Build();
(UpdateDefinition<T> updateDef, IReadOnlyList<ArrayFilterDefinition> arrayFilters) = updateBuilder.Build();
UpdateOptions? updateOptions = null;
if (arrayFilters.Count > 0)
{
updateOptions = new UpdateOptions { ArrayFilters = arrayFilters };
}
UpdateResult result;
try
{
if (_context.Session is not null)
{
result = await _collection
.UpdateManyAsync(_context.Session, filter, updateDef, cancellationToken: cancellationToken)
.UpdateManyAsync(_context.Session, filter, updateDef, updateOptions, cancellationToken)
.ConfigureAwait(false);
}
else
{
result = await _collection
.UpdateManyAsync(filter, updateDef, cancellationToken: cancellationToken)
.UpdateManyAsync(filter, updateDef, updateOptions, cancellationToken)
.ConfigureAwait(false);
}
}
Expand Down
50 changes: 44 additions & 6 deletions src/DataAccess/src/SIL.DataAccess/MongoUpdateBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ public class MongoUpdateBuilder<T> : IUpdateBuilder<T>
{
private readonly UpdateDefinitionBuilder<T> _builder;
private readonly List<UpdateDefinition<T>> _defs;
private readonly List<ArrayFilterDefinition<BsonValue>> _arrayFilters;

public MongoUpdateBuilder()
{
_builder = Builders<T>.Update;
_defs = new List<UpdateDefinition<T>>();
_arrayFilters = new List<ArrayFilterDefinition<BsonValue>>();
}

public IUpdateBuilder<T> Set<TField>(Expression<Func<T, TField>> field, TField value)
Expand Down Expand Up @@ -38,7 +40,7 @@ public IUpdateBuilder<T> Inc(Expression<Func<T, int>> field, int value = 1)

public IUpdateBuilder<T> RemoveAll<TItem>(
Expression<Func<T, IEnumerable<TItem>?>> field,
Expression<Func<TItem, bool>> predicate
Expression<Func<TItem, bool>>? predicate = null
)
{
_defs.Add(_builder.PullFilter(ToFieldDefinition(field), Builders<TItem>.Filter.Where(predicate)));
Expand All @@ -57,15 +59,51 @@ public IUpdateBuilder<T> Add<TItem>(Expression<Func<T, IEnumerable<TItem>?>> fie
return this;
}

public UpdateDefinition<T> Build()
public IUpdateBuilder<T> SetAll<TItem, TField>(
Expression<Func<T, IEnumerable<TItem>?>> collectionField,
Expression<Func<TItem, TField>> itemField,
TField value,
Expression<Func<TItem, bool>>? predicate = null
)
{
Expression<Func<T, TItem>> itemExpr = ExpressionHelper.Concatenate(
collectionField,
(collection) => ((IReadOnlyList<TItem>?)collection)![ArrayPosition.ArrayFilter]
);
Expression<Func<T, TField>> fieldExpr = ExpressionHelper.Concatenate(itemExpr, itemField);
if (predicate != null)
{
string filterId = "f" + ObjectId.GenerateNewId().ToString();
_defs.Add(_builder.Set(ToFieldDefinition(fieldExpr, filterId), value));
ExpressionFilterDefinition<TItem> filter = new(predicate);
BsonDocument bsonDoc = filter.Render(
BsonSerializer.SerializerRegistry.GetSerializer<TItem>(),
BsonSerializer.SerializerRegistry,
LinqProvider.V2
);
_arrayFilters.Add(
new BsonDocument($"{filterId}.{bsonDoc.Elements.Single().Name}", bsonDoc.Elements.Single().Value)
);
}
else
{
_defs.Add(_builder.Set(ToFieldDefinition(fieldExpr), value));
}
return this;
}

public (UpdateDefinition<T>, IReadOnlyList<ArrayFilterDefinition>) Build()
{
if (_defs.Count == 1)
return _defs.Single();
return _builder.Combine(_defs);
return (_defs.Single(), _arrayFilters);
return (_builder.Combine(_defs), _arrayFilters);
}

private static FieldDefinition<T, TField> ToFieldDefinition<TField>(Expression<Func<T, TField>> field)
private static FieldDefinition<T, TField> ToFieldDefinition<TField>(
Expression<Func<T, TField>> field,
string arrayFilterId = ""
)
{
return new DataAccessFieldDefinition<T, TField>(field);
return new DataAccessFieldDefinition<T, TField>(field, arrayFilterId);
}
}
Loading

0 comments on commit 997ccdc

Please sign in to comment.