Skip to content

Commit

Permalink
change how CrdtRepository restricts its queries to avoid writing code…
Browse files Browse the repository at this point in the history
… that won't work in scoped contexts
  • Loading branch information
hahn-kev committed Nov 6, 2024
1 parent b9d7d7a commit 476f128
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 62 deletions.
4 changes: 2 additions & 2 deletions src/SIL.Harmony.Tests/DataModelPerformanceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ internal static async Task BulkInsertChanges(DataModelTestBase dataModelTest, in
};
commit.SetParentHash(parentHash);
parentHash = commit.Hash;
dataModelTest.DbContext.Commits.Add(commit);
dataModelTest.DbContext.Snapshots.Add(new ObjectSnapshot(await change.NewEntity(commit, null!), commit, true));
dataModelTest.DbContext.Add(commit);
dataModelTest.DbContext.Add(new ObjectSnapshot(await change.NewEntity(commit, null!), commit, true));
}

await dataModelTest.DbContext.SaveChangesAsync();
Expand Down
6 changes: 3 additions & 3 deletions src/SIL.Harmony.Tests/DbContextTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public async Task CanRoundTripDatesFromEf(int offset)
ClientId = Guid.NewGuid(),
HybridDateTime = new HybridDateTime(expectedDateTime, 0)
};
DbContext.Commits.Add(commit);
DbContext.Add(commit);
await DbContext.SaveChangesAsync();
var actualCommit = await DbContext.Commits.AsNoTracking().SingleOrDefaultAsyncEF(c => c.Id == commitId);
actualCommit!.HybridDateTime.DateTime.Should().Be(expectedDateTime, "EF");
Expand All @@ -46,7 +46,7 @@ public async Task CanRoundTripDatesFromLinq2Db(int offset)
var commitId = Guid.NewGuid();
var expectedDateTime = new DateTimeOffset(2000, 1, 1, 1, 11, 11, TimeSpan.FromHours(offset));

await DbContext.Commits.ToLinqToDBTable().AsValueInsertable()
await DbContext.Set<Commit>().ToLinqToDBTable().AsValueInsertable()
.Value(c => c.Id, commitId)
.Value(c => c.ClientId, Guid.NewGuid())
.Value(c => c.HybridDateTime.DateTime, expectedDateTime)
Expand Down Expand Up @@ -74,7 +74,7 @@ public async Task CanFilterCommitsByDateTime(double scale)
for (int i = 0; i < 50; i++)
{
var offset = new TimeSpan((long)(i * scale));
DbContext.Commits.Add(new Commit
DbContext.Add(new Commit
{
ClientId = Guid.NewGuid(),
HybridDateTime = new HybridDateTime(baseDateTime.Add(offset), 0)
Expand Down
156 changes: 101 additions & 55 deletions src/SIL.Harmony/Db/CrdtRepository.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using SIL.Harmony.Core;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.Extensions.Options;
using SIL.Harmony.Changes;
Expand All @@ -9,34 +10,25 @@

namespace SIL.Harmony.Db;

internal class CrdtRepository(ICrdtDbContext _dbContext, IOptions<CrdtConfig> crdtConfig,
Commit? ignoreChangesAfter = null
)
internal class CrdtRepository
{
private IQueryable<ObjectSnapshot> Snapshots => _dbContext.Snapshots.AsNoTracking();
private readonly ICrdtDbContext _dbContext;
private readonly IOptions<CrdtConfig> _crdtConfig;

private IQueryable<ObjectSnapshot> SnapshotsFiltered
public CrdtRepository(ICrdtDbContext dbContext, IOptions<CrdtConfig> crdtConfig,
Commit? ignoreChangesAfter = null)
{
get
{
if (ignoreChangesAfter is not null)
{
return Snapshots.WhereBefore(ignoreChangesAfter, inclusive: true);
}
return Snapshots;
}
}
private IQueryable<Commit> Commits
{
get
{
if (ignoreChangesAfter is not null)
{
return _dbContext.Commits.WhereBefore(ignoreChangesAfter, inclusive: true);
}
return _dbContext.Commits;
}
_crdtConfig = crdtConfig;
_dbContext = ignoreChangesAfter is not null ? new ScopedDbContext(dbContext, ignoreChangesAfter) : dbContext;
//we can't use the scoped db context is it prevents access to the DbSet for the Snapshots,
//but since we're using a custom query, we can use it directly and apply the scoped filters manually
_currentSnapshotsQueryable = MakeCurrentSnapshotsQuery(dbContext, ignoreChangesAfter);
}


private IQueryable<ObjectSnapshot> Snapshots => _dbContext.Snapshots.AsNoTracking();

private IQueryable<Commit> Commits => _dbContext.Commits;

public Task<IDbContextTransaction> BeginTransactionAsync()
{
Expand Down Expand Up @@ -79,28 +71,34 @@ public IQueryable<Commit> CurrentCommits()
return Commits.DefaultOrder();
}

public IQueryable<ObjectSnapshot> CurrentSnapshots()
private static IQueryable<ObjectSnapshot> MakeCurrentSnapshotsQuery(ICrdtDbContext dbContext, Commit? ignoreChangesAfter)
{
var ignoreAfterDate = ignoreChangesAfter?.HybridDateTime.DateTime.UtcDateTime;
var ignoreAfterCounter = ignoreChangesAfter?.HybridDateTime.Counter;
var ignoreAfterCommitId = ignoreChangesAfter?.Id;
return _dbContext.Snapshots.FromSql(
$"""
WITH LatestSnapshots AS (SELECT first_value(s1.Id)
OVER (
PARTITION BY "s1"."EntityId"
ORDER BY "c"."DateTime" DESC, "c"."Counter" DESC, "c"."Id" DESC
) AS "LatestSnapshotId"
FROM "Snapshots" AS "s1"
INNER JOIN "Commits" AS "c" ON "s1"."CommitId" = "c"."Id"
WHERE {ignoreAfterDate} IS NULL
OR ("c"."DateTime" < {ignoreAfterDate} OR ("c"."DateTime" = {ignoreAfterDate} AND "c"."Counter" < {ignoreAfterCounter}) OR
("c"."DateTime" = {ignoreAfterDate} AND "c"."Counter" = {ignoreAfterCounter} AND "c"."Id" < {ignoreAfterCommitId}) OR "c"."Id" = {ignoreAfterCommitId}))
SELECT *
FROM "Snapshots" AS "s"
INNER JOIN LatestSnapshots AS "ls" ON "s"."Id" = "ls"."LatestSnapshotId"
GROUP BY s.EntityId
""").AsNoTracking();
return dbContext.Set<ObjectSnapshot>().FromSql(
$"""
WITH LatestSnapshots AS (SELECT first_value(s1.Id)
OVER (
PARTITION BY "s1"."EntityId"
ORDER BY "c"."DateTime" DESC, "c"."Counter" DESC, "c"."Id" DESC
) AS "LatestSnapshotId"
FROM "Snapshots" AS "s1"
INNER JOIN "Commits" AS "c" ON "s1"."CommitId" = "c"."Id"
WHERE {ignoreAfterDate} IS NULL
OR ("c"."DateTime" < {ignoreAfterDate} OR ("c"."DateTime" = {ignoreAfterDate} AND "c"."Counter" < {ignoreAfterCounter}) OR
("c"."DateTime" = {ignoreAfterDate} AND "c"."Counter" = {ignoreAfterCounter} AND "c"."Id" < {ignoreAfterCommitId}) OR "c"."Id" = {ignoreAfterCommitId}))
SELECT *
FROM "Snapshots" AS "s"
INNER JOIN LatestSnapshots AS "ls" ON "s"."Id" = "ls"."LatestSnapshotId"
GROUP BY s.EntityId
""").AsNoTracking();
}

private readonly IQueryable<ObjectSnapshot> _currentSnapshotsQueryable;
public IQueryable<ObjectSnapshot> CurrentSnapshots()
{
return _currentSnapshotsQueryable;
}

public IAsyncEnumerable<SimpleSnapshot> CurrenSimpleSnapshots(bool includeDeleted = false)
Expand Down Expand Up @@ -161,15 +159,15 @@ public async Task<Commit[]> GetCommitsAfter(Commit? commit)

public async Task<ObjectSnapshot?> FindSnapshot(Guid id, bool tracking = false)
{
return await SnapshotsFiltered
return await Snapshots
.AsTracking(tracking)
.Include(s => s.Commit)
.SingleOrDefaultAsync(s => s.Id == id);
}

public async Task<ObjectSnapshot?> GetCurrentSnapshotByObjectId(Guid objectId, bool tracking = false)
{
return await SnapshotsFiltered
return await Snapshots
.AsTracking(tracking)
.Include(s => s.Commit)
.DefaultOrder()
Expand All @@ -178,7 +176,7 @@ public async Task<Commit[]> GetCommitsAfter(Commit? commit)

public async Task<T> GetObjectBySnapshotId<T>(Guid snapshotId)
{
var entity = await SnapshotsFiltered
var entity = await Snapshots
.Where(s => s.Id == snapshotId)
.Select(s => s.Entity)
.SingleOrDefaultAsync()
Expand All @@ -194,7 +192,7 @@ public async Task<T> GetObjectBySnapshotId<T>(Guid snapshotId)

public IQueryable<T> GetCurrentObjects<T>() where T : class
{
if (crdtConfig.Value.EnableProjectedTables)
if (_crdtConfig.Value.EnableProjectedTables)
{
return _dbContext.Set<T>();
}
Expand All @@ -208,15 +206,14 @@ public async Task<SyncState> GetCurrentSyncState()

public async Task<ChangesResult<Commit>> GetChanges(SyncState remoteState)
{
var dbContextCommits = _dbContext.Commits;
return await dbContextCommits.GetChanges<Commit, IChange>(remoteState);
return await _dbContext.Commits.GetChanges<Commit, IChange>(remoteState);
}

public async Task AddSnapshots(IEnumerable<ObjectSnapshot> snapshots)
{
foreach (var objectSnapshot in snapshots)
{
_dbContext.Snapshots.Add(objectSnapshot);
_dbContext.Add(objectSnapshot);
await SnapshotAdded(objectSnapshot);
}

Expand All @@ -228,7 +225,7 @@ public async ValueTask AddIfNew(IEnumerable<ObjectSnapshot> snapshots)
foreach (var snapshot in snapshots)
{

if (_dbContext.Snapshots.Local.FindEntry(snapshot.Id) is not null) continue;
if (_dbContext.Set<ObjectSnapshot>().Local.FindEntry(snapshot.Id) is not null) continue;
_dbContext.Add(snapshot);
await SnapshotAdded(snapshot);
}
Expand All @@ -238,7 +235,7 @@ public async ValueTask AddIfNew(IEnumerable<ObjectSnapshot> snapshots)

private async ValueTask SnapshotAdded(ObjectSnapshot objectSnapshot)
{
if (!crdtConfig.Value.EnableProjectedTables) return;
if (!_crdtConfig.Value.EnableProjectedTables) return;
if (objectSnapshot.IsRoot && objectSnapshot.EntityIsDeleted) return;
//need to check if an entry exists already, even if this is the root commit it may have already been added to the db
var existingEntry = await GetEntityEntry(objectSnapshot.Entity.DbObject.GetType(), objectSnapshot.EntityId);
Expand Down Expand Up @@ -267,25 +264,25 @@ private async ValueTask SnapshotAdded(ObjectSnapshot objectSnapshot)

private async ValueTask<EntityEntry?> GetEntityEntry(Type entityType, Guid entityId)
{
if (!crdtConfig.Value.EnableProjectedTables) return null;
if (!_crdtConfig.Value.EnableProjectedTables) return null;
var entity = await _dbContext.FindAsync(entityType, entityId);
return entity is not null ? _dbContext.Entry(entity) : null;
}

public CrdtRepository GetScopedRepository(Commit excludeChangesAfterCommit)
{
return new CrdtRepository(_dbContext, crdtConfig, excludeChangesAfterCommit);
return new CrdtRepository(_dbContext, _crdtConfig, excludeChangesAfterCommit);
}

public async Task AddCommit(Commit commit)
{
_dbContext.Commits.Add(commit);
_dbContext.Add(commit);
await _dbContext.SaveChangesAsync();
}

public async Task AddCommits(IEnumerable<Commit> commits, bool save = true)
{
_dbContext.Commits.AddRange(commits);
_dbContext.AddRange(commits);
if (save) await _dbContext.SaveChangesAsync();
}

Expand All @@ -298,3 +295,52 @@ public async Task AddCommits(IEnumerable<Commit> commits, bool save = true)
.FirstOrDefault();
}
}

internal class ScopedDbContext(ICrdtDbContext inner, Commit ignoreChangesAfter) : ICrdtDbContext
{
public IQueryable<Commit> Commits => inner.Commits.WhereBefore(ignoreChangesAfter, inclusive: true);

public IQueryable<ObjectSnapshot> Snapshots => inner.Snapshots.WhereBefore(ignoreChangesAfter, inclusive: true);

public Task<int> SaveChangesAsync(CancellationToken cancellationToken = default)
{
return inner.SaveChangesAsync(cancellationToken);
}

public ValueTask<object?> FindAsync(Type entityType, params object?[]? keyValues)
{
throw new NotSupportedException("can not support FindAsync when using scoped db context");
}

public DbSet<TEntity> Set<TEntity>() where TEntity : class
{
throw new NotSupportedException("can not support Set<T> when using scoped db context");
}

public DatabaseFacade Database => inner.Database;

public EntityEntry<TEntity> Entry<TEntity>(TEntity entity) where TEntity : class
{
return inner.Entry(entity);
}

public EntityEntry Entry(object entity)
{
return inner.Entry(entity);
}

public EntityEntry Add(object entity)
{
return inner.Add(entity);
}

public void AddRange(IEnumerable<object> entities)
{
inner.AddRange(entities);
}

public EntityEntry Remove(object entity)
{
return inner.Remove(entity);
}
}
5 changes: 3 additions & 2 deletions src/SIL.Harmony/Db/ICrdtDbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ namespace SIL.Harmony.Db;

public interface ICrdtDbContext
{
DbSet<Commit> Commits => Set<Commit>();
DbSet<ObjectSnapshot> Snapshots => Set<ObjectSnapshot>();
IQueryable<Commit> Commits => Set<Commit>();
IQueryable<ObjectSnapshot> Snapshots => Set<ObjectSnapshot>();
Task<int> SaveChangesAsync(CancellationToken cancellationToken = default);
ValueTask<object?> FindAsync(Type entityType, params object?[]? keyValues);
DbSet<TEntity> Set<TEntity>() where TEntity : class;
DatabaseFacade Database { get; }
EntityEntry<TEntity> Entry<TEntity>(TEntity entity) where TEntity : class;
EntityEntry Entry(object entity);
EntityEntry Add(object entity);
void AddRange(IEnumerable<object> entities);
EntityEntry Remove(object entity);
}

0 comments on commit 476f128

Please sign in to comment.