diff --git a/src/SIL.Harmony.Tests/DataModelPerformanceTests.cs b/src/SIL.Harmony.Tests/DataModelPerformanceTests.cs index b5b4443..6aaf2af 100644 --- a/src/SIL.Harmony.Tests/DataModelPerformanceTests.cs +++ b/src/SIL.Harmony.Tests/DataModelPerformanceTests.cs @@ -125,8 +125,8 @@ internal static async Task BulkInsertChanges(DataModelTestBase dataModelTest, in }; commit.SetParentHash(parentHash); parentHash = commit.Hash; - dataModelTest.DbContext.Commits.Add(commit); - dataModelTest.DbContext.Snapshots.Add(new ObjectSnapshot(await change.NewEntity(commit, null!), commit, true)); + dataModelTest.DbContext.Add(commit); + dataModelTest.DbContext.Add(new ObjectSnapshot(await change.NewEntity(commit, null!), commit, true)); } await dataModelTest.DbContext.SaveChangesAsync(); diff --git a/src/SIL.Harmony.Tests/DbContextTests.cs b/src/SIL.Harmony.Tests/DbContextTests.cs index 33cc0a1..8341517 100644 --- a/src/SIL.Harmony.Tests/DbContextTests.cs +++ b/src/SIL.Harmony.Tests/DbContextTests.cs @@ -27,7 +27,7 @@ public async Task CanRoundTripDatesFromEf(int offset) ClientId = Guid.NewGuid(), HybridDateTime = new HybridDateTime(expectedDateTime, 0) }; - DbContext.Commits.Add(commit); + DbContext.Add(commit); await DbContext.SaveChangesAsync(); var actualCommit = await DbContext.Commits.AsNoTracking().SingleOrDefaultAsyncEF(c => c.Id == commitId); actualCommit!.HybridDateTime.DateTime.Should().Be(expectedDateTime, "EF"); @@ -46,7 +46,7 @@ public async Task CanRoundTripDatesFromLinq2Db(int offset) var commitId = Guid.NewGuid(); var expectedDateTime = new DateTimeOffset(2000, 1, 1, 1, 11, 11, TimeSpan.FromHours(offset)); - await DbContext.Commits.ToLinqToDBTable().AsValueInsertable() + await DbContext.Set().ToLinqToDBTable().AsValueInsertable() .Value(c => c.Id, commitId) .Value(c => c.ClientId, Guid.NewGuid()) .Value(c => c.HybridDateTime.DateTime, expectedDateTime) @@ -74,7 +74,7 @@ public async Task CanFilterCommitsByDateTime(double scale) for (int i = 0; i < 50; i++) { var offset = new TimeSpan((long)(i * scale)); - DbContext.Commits.Add(new Commit + DbContext.Add(new Commit { ClientId = Guid.NewGuid(), HybridDateTime = new HybridDateTime(baseDateTime.Add(offset), 0) diff --git a/src/SIL.Harmony/Db/CrdtRepository.cs b/src/SIL.Harmony/Db/CrdtRepository.cs index 85f57dc..82efc83 100644 --- a/src/SIL.Harmony/Db/CrdtRepository.cs +++ b/src/SIL.Harmony/Db/CrdtRepository.cs @@ -1,6 +1,7 @@ using SIL.Harmony.Core; using Microsoft.EntityFrameworkCore; using Microsoft.EntityFrameworkCore.ChangeTracking; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.Extensions.Options; using SIL.Harmony.Changes; @@ -9,34 +10,25 @@ namespace SIL.Harmony.Db; -internal class CrdtRepository(ICrdtDbContext _dbContext, IOptions crdtConfig, - Commit? ignoreChangesAfter = null -) +internal class CrdtRepository { - private IQueryable Snapshots => _dbContext.Snapshots.AsNoTracking(); + private readonly ICrdtDbContext _dbContext; + private readonly IOptions _crdtConfig; - private IQueryable SnapshotsFiltered + public CrdtRepository(ICrdtDbContext dbContext, IOptions crdtConfig, + Commit? ignoreChangesAfter = null) { - get - { - if (ignoreChangesAfter is not null) - { - return Snapshots.WhereBefore(ignoreChangesAfter, inclusive: true); - } - return Snapshots; - } - } - private IQueryable Commits - { - get - { - if (ignoreChangesAfter is not null) - { - return _dbContext.Commits.WhereBefore(ignoreChangesAfter, inclusive: true); - } - return _dbContext.Commits; - } + _crdtConfig = crdtConfig; + _dbContext = ignoreChangesAfter is not null ? new ScopedDbContext(dbContext, ignoreChangesAfter) : dbContext; + //we can't use the scoped db context is it prevents access to the DbSet for the Snapshots, + //but since we're using a custom query, we can use it directly and apply the scoped filters manually + _currentSnapshotsQueryable = MakeCurrentSnapshotsQuery(dbContext, ignoreChangesAfter); } + + + private IQueryable Snapshots => _dbContext.Snapshots.AsNoTracking(); + + private IQueryable Commits => _dbContext.Commits; public Task BeginTransactionAsync() { @@ -79,28 +71,34 @@ public IQueryable CurrentCommits() return Commits.DefaultOrder(); } - public IQueryable CurrentSnapshots() + private static IQueryable MakeCurrentSnapshotsQuery(ICrdtDbContext dbContext, Commit? ignoreChangesAfter) { var ignoreAfterDate = ignoreChangesAfter?.HybridDateTime.DateTime.UtcDateTime; var ignoreAfterCounter = ignoreChangesAfter?.HybridDateTime.Counter; var ignoreAfterCommitId = ignoreChangesAfter?.Id; - return _dbContext.Snapshots.FromSql( -$""" -WITH LatestSnapshots AS (SELECT first_value(s1.Id) - OVER ( - PARTITION BY "s1"."EntityId" - ORDER BY "c"."DateTime" DESC, "c"."Counter" DESC, "c"."Id" DESC - ) AS "LatestSnapshotId" - FROM "Snapshots" AS "s1" - INNER JOIN "Commits" AS "c" ON "s1"."CommitId" = "c"."Id" - WHERE {ignoreAfterDate} IS NULL - OR ("c"."DateTime" < {ignoreAfterDate} OR ("c"."DateTime" = {ignoreAfterDate} AND "c"."Counter" < {ignoreAfterCounter}) OR - ("c"."DateTime" = {ignoreAfterDate} AND "c"."Counter" = {ignoreAfterCounter} AND "c"."Id" < {ignoreAfterCommitId}) OR "c"."Id" = {ignoreAfterCommitId})) -SELECT * -FROM "Snapshots" AS "s" - INNER JOIN LatestSnapshots AS "ls" ON "s"."Id" = "ls"."LatestSnapshotId" -GROUP BY s.EntityId -""").AsNoTracking(); + return dbContext.Set().FromSql( + $""" + WITH LatestSnapshots AS (SELECT first_value(s1.Id) + OVER ( + PARTITION BY "s1"."EntityId" + ORDER BY "c"."DateTime" DESC, "c"."Counter" DESC, "c"."Id" DESC + ) AS "LatestSnapshotId" + FROM "Snapshots" AS "s1" + INNER JOIN "Commits" AS "c" ON "s1"."CommitId" = "c"."Id" + WHERE {ignoreAfterDate} IS NULL + OR ("c"."DateTime" < {ignoreAfterDate} OR ("c"."DateTime" = {ignoreAfterDate} AND "c"."Counter" < {ignoreAfterCounter}) OR + ("c"."DateTime" = {ignoreAfterDate} AND "c"."Counter" = {ignoreAfterCounter} AND "c"."Id" < {ignoreAfterCommitId}) OR "c"."Id" = {ignoreAfterCommitId})) + SELECT * + FROM "Snapshots" AS "s" + INNER JOIN LatestSnapshots AS "ls" ON "s"."Id" = "ls"."LatestSnapshotId" + GROUP BY s.EntityId + """).AsNoTracking(); + } + + private readonly IQueryable _currentSnapshotsQueryable; + public IQueryable CurrentSnapshots() + { + return _currentSnapshotsQueryable; } public IAsyncEnumerable CurrenSimpleSnapshots(bool includeDeleted = false) @@ -161,7 +159,7 @@ public async Task GetCommitsAfter(Commit? commit) public async Task FindSnapshot(Guid id, bool tracking = false) { - return await SnapshotsFiltered + return await Snapshots .AsTracking(tracking) .Include(s => s.Commit) .SingleOrDefaultAsync(s => s.Id == id); @@ -169,7 +167,7 @@ public async Task GetCommitsAfter(Commit? commit) public async Task GetCurrentSnapshotByObjectId(Guid objectId, bool tracking = false) { - return await SnapshotsFiltered + return await Snapshots .AsTracking(tracking) .Include(s => s.Commit) .DefaultOrder() @@ -178,7 +176,7 @@ public async Task GetCommitsAfter(Commit? commit) public async Task GetObjectBySnapshotId(Guid snapshotId) { - var entity = await SnapshotsFiltered + var entity = await Snapshots .Where(s => s.Id == snapshotId) .Select(s => s.Entity) .SingleOrDefaultAsync() @@ -194,7 +192,7 @@ public async Task GetObjectBySnapshotId(Guid snapshotId) public IQueryable GetCurrentObjects() where T : class { - if (crdtConfig.Value.EnableProjectedTables) + if (_crdtConfig.Value.EnableProjectedTables) { return _dbContext.Set(); } @@ -208,15 +206,14 @@ public async Task GetCurrentSyncState() public async Task> GetChanges(SyncState remoteState) { - var dbContextCommits = _dbContext.Commits; - return await dbContextCommits.GetChanges(remoteState); + return await _dbContext.Commits.GetChanges(remoteState); } public async Task AddSnapshots(IEnumerable snapshots) { foreach (var objectSnapshot in snapshots) { - _dbContext.Snapshots.Add(objectSnapshot); + _dbContext.Add(objectSnapshot); await SnapshotAdded(objectSnapshot); } @@ -228,7 +225,7 @@ public async ValueTask AddIfNew(IEnumerable snapshots) foreach (var snapshot in snapshots) { - if (_dbContext.Snapshots.Local.FindEntry(snapshot.Id) is not null) continue; + if (_dbContext.Set().Local.FindEntry(snapshot.Id) is not null) continue; _dbContext.Add(snapshot); await SnapshotAdded(snapshot); } @@ -238,7 +235,7 @@ public async ValueTask AddIfNew(IEnumerable snapshots) private async ValueTask SnapshotAdded(ObjectSnapshot objectSnapshot) { - if (!crdtConfig.Value.EnableProjectedTables) return; + if (!_crdtConfig.Value.EnableProjectedTables) return; if (objectSnapshot.IsRoot && objectSnapshot.EntityIsDeleted) return; //need to check if an entry exists already, even if this is the root commit it may have already been added to the db var existingEntry = await GetEntityEntry(objectSnapshot.Entity.DbObject.GetType(), objectSnapshot.EntityId); @@ -267,25 +264,25 @@ private async ValueTask SnapshotAdded(ObjectSnapshot objectSnapshot) private async ValueTask GetEntityEntry(Type entityType, Guid entityId) { - if (!crdtConfig.Value.EnableProjectedTables) return null; + if (!_crdtConfig.Value.EnableProjectedTables) return null; var entity = await _dbContext.FindAsync(entityType, entityId); return entity is not null ? _dbContext.Entry(entity) : null; } public CrdtRepository GetScopedRepository(Commit excludeChangesAfterCommit) { - return new CrdtRepository(_dbContext, crdtConfig, excludeChangesAfterCommit); + return new CrdtRepository(_dbContext, _crdtConfig, excludeChangesAfterCommit); } public async Task AddCommit(Commit commit) { - _dbContext.Commits.Add(commit); + _dbContext.Add(commit); await _dbContext.SaveChangesAsync(); } public async Task AddCommits(IEnumerable commits, bool save = true) { - _dbContext.Commits.AddRange(commits); + _dbContext.AddRange(commits); if (save) await _dbContext.SaveChangesAsync(); } @@ -298,3 +295,52 @@ public async Task AddCommits(IEnumerable commits, bool save = true) .FirstOrDefault(); } } + +internal class ScopedDbContext(ICrdtDbContext inner, Commit ignoreChangesAfter) : ICrdtDbContext +{ + public IQueryable Commits => inner.Commits.WhereBefore(ignoreChangesAfter, inclusive: true); + + public IQueryable Snapshots => inner.Snapshots.WhereBefore(ignoreChangesAfter, inclusive: true); + + public Task SaveChangesAsync(CancellationToken cancellationToken = default) + { + return inner.SaveChangesAsync(cancellationToken); + } + + public ValueTask FindAsync(Type entityType, params object?[]? keyValues) + { + throw new NotSupportedException("can not support FindAsync when using scoped db context"); + } + + public DbSet Set() where TEntity : class + { + throw new NotSupportedException("can not support Set when using scoped db context"); + } + + public DatabaseFacade Database => inner.Database; + + public EntityEntry Entry(TEntity entity) where TEntity : class + { + return inner.Entry(entity); + } + + public EntityEntry Entry(object entity) + { + return inner.Entry(entity); + } + + public EntityEntry Add(object entity) + { + return inner.Add(entity); + } + + public void AddRange(IEnumerable entities) + { + inner.AddRange(entities); + } + + public EntityEntry Remove(object entity) + { + return inner.Remove(entity); + } +} \ No newline at end of file diff --git a/src/SIL.Harmony/Db/ICrdtDbContext.cs b/src/SIL.Harmony/Db/ICrdtDbContext.cs index 9d40fe2..475e8e6 100644 --- a/src/SIL.Harmony/Db/ICrdtDbContext.cs +++ b/src/SIL.Harmony/Db/ICrdtDbContext.cs @@ -6,8 +6,8 @@ namespace SIL.Harmony.Db; public interface ICrdtDbContext { - DbSet Commits => Set(); - DbSet Snapshots => Set(); + IQueryable Commits => Set(); + IQueryable Snapshots => Set(); Task SaveChangesAsync(CancellationToken cancellationToken = default); ValueTask FindAsync(Type entityType, params object?[]? keyValues); DbSet Set() where TEntity : class; @@ -15,5 +15,6 @@ public interface ICrdtDbContext EntityEntry Entry(TEntity entity) where TEntity : class; EntityEntry Entry(object entity); EntityEntry Add(object entity); + void AddRange(IEnumerable entities); EntityEntry Remove(object entity); } \ No newline at end of file