diff --git a/Amplify/Categories/DataStore/Model/Internal/Persistable.swift b/Amplify/Categories/DataStore/Model/Internal/Persistable.swift index b7a53acf5a..92fd149d8d 100644 --- a/Amplify/Categories/DataStore/Model/Internal/Persistable.swift +++ b/Amplify/Categories/DataStore/Model/Internal/Persistable.swift @@ -65,6 +65,12 @@ struct PersistableHelper { return lhs == rhs case let (lhs, rhs) as (String, String): return lhs == rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue == rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs == rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue == rhs.rawValue default: return false } @@ -94,6 +100,12 @@ struct PersistableHelper { return lhs == Double(rhs) case let (lhs, rhs) as (String, String): return lhs == rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue == rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs == rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue == rhs.rawValue default: return false } @@ -122,6 +134,12 @@ struct PersistableHelper { return lhs <= Double(rhs) case let (lhs, rhs) as (String, String): return lhs <= rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue <= rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs <= rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue <= rhs.rawValue default: return false } @@ -150,6 +168,12 @@ struct PersistableHelper { return lhs < Double(rhs) case let (lhs, rhs) as (String, String): return lhs < rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue < rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs < rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue < rhs.rawValue default: return false } @@ -178,6 +202,12 @@ struct PersistableHelper { return lhs >= Double(rhs) case let (lhs, rhs) as (String, String): return lhs >= rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue >= rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs >= rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue >= rhs.rawValue default: return false } @@ -206,6 +236,12 @@ struct PersistableHelper { return Double(lhs) > rhs case let (lhs, rhs) as (String, String): return lhs > rhs + case let (lhs, rhs) as (any EnumPersistable, String): + return lhs.rawValue > rhs + case let (lhs, rhs) as (String, any EnumPersistable): + return lhs > rhs.rawValue + case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable): + return lhs.rawValue > rhs.rawValue default: return false } diff --git a/Amplify/Categories/DataStore/Query/ModelKey.swift b/Amplify/Categories/DataStore/Query/ModelKey.swift index 8096bfc9c7..860cb17061 100644 --- a/Amplify/Categories/DataStore/Query/ModelKey.swift +++ b/Amplify/Categories/DataStore/Query/ModelKey.swift @@ -36,6 +36,11 @@ public protocol ModelKey: CodingKey, CaseIterable, QueryFieldOperation {} extension CodingKey where Self: ModelKey { + // MARK: - attributeExists + public func attributeExists(_ value: Bool) -> QueryPredicateOperation { + return field(stringValue).attributeExists(value) + } + // MARK: - beginsWith public func beginsWith(_ value: String) -> QueryPredicateOperation { return field(stringValue).beginsWith(value) diff --git a/Amplify/Categories/DataStore/Query/QueryField.swift b/Amplify/Categories/DataStore/Query/QueryField.swift index 9d29967569..afde18a94d 100644 --- a/Amplify/Categories/DataStore/Query/QueryField.swift +++ b/Amplify/Categories/DataStore/Query/QueryField.swift @@ -30,7 +30,7 @@ public func field(_ name: String) -> QueryField { /// - seealso: `ModelKey` public protocol QueryFieldOperation { // MARK: - Functions - + func attributeExists(_ value: Bool) -> QueryPredicateOperation func beginsWith(_ value: String) -> QueryPredicateOperation func between(start: Persistable, end: Persistable) -> QueryPredicateOperation func contains(_ value: String) -> QueryPredicateOperation @@ -61,6 +61,11 @@ public struct QueryField: QueryFieldOperation { public let name: String + // MARK: - attributeExists + public func attributeExists(_ value: Bool) -> QueryPredicateOperation { + return QueryPredicateOperation(field: name, operator: .attributeExists(value)) + } + // MARK: - beginsWith public func beginsWith(_ value: String) -> QueryPredicateOperation { return QueryPredicateOperation(field: name, operator: .beginsWith(value)) diff --git a/Amplify/Categories/DataStore/Query/QueryOperator+Equatable.swift b/Amplify/Categories/DataStore/Query/QueryOperator+Equatable.swift index 41ee77159b..c1907802b6 100644 --- a/Amplify/Categories/DataStore/Query/QueryOperator+Equatable.swift +++ b/Amplify/Categories/DataStore/Query/QueryOperator+Equatable.swift @@ -24,6 +24,8 @@ extension QueryOperator: Equatable { case let (.between(oneStart, oneEnd), .between(otherStart, otherEnd)): return PersistableHelper.isEqual(oneStart, otherStart) && PersistableHelper.isEqual(oneEnd, otherEnd) + case let (.attributeExists(one), .attributeExists(other)): + return one == other default: return false } diff --git a/Amplify/Categories/DataStore/Query/QueryOperator.swift b/Amplify/Categories/DataStore/Query/QueryOperator.swift index 2fcb50ccd2..e4897e5f0d 100644 --- a/Amplify/Categories/DataStore/Query/QueryOperator.swift +++ b/Amplify/Categories/DataStore/Query/QueryOperator.swift @@ -18,8 +18,9 @@ public enum QueryOperator: Encodable { case notContains(_ value: String) case between(start: Persistable, end: Persistable) case beginsWith(_ value: String) + case attributeExists(_ value: Bool) - public func evaluate(target: Any) -> Bool { + public func evaluate(target: Any?) -> Bool { switch self { case .notEqual(let predicateValue): return !PersistableHelper.isEqual(target, predicateValue) @@ -34,20 +35,26 @@ public enum QueryOperator: Encodable { case .greaterThan(let predicateValue): return PersistableHelper.isGreaterThan(target, predicateValue) case .contains(let predicateString): - if let targetString = target as? String { + if let targetString = target.flatMap({ $0 as? String }) { return targetString.contains(predicateString) } return false case .notContains(let predicateString): - if let targetString = target as? String { + if let targetString = target.flatMap({ $0 as? String }) { return !targetString.contains(predicateString) } case .between(let start, let end): return PersistableHelper.isBetween(start, end, target) case .beginsWith(let predicateValue): - if let targetString = target as? String { + if let targetString = target.flatMap({ $0 as? String }) { return targetString.starts(with: predicateValue) } + case .attributeExists(let predicateValue): + if case .some = target { + return predicateValue == true + } else { + return predicateValue == false + } } return false } @@ -105,6 +112,10 @@ public enum QueryOperator: Encodable { case .beginsWith(let value): try container.encode("beginsWith", forKey: .type) try container.encode(value, forKey: .value) + + case .attributeExists(let value): + try container.encode("attributeExists", forKey: .type) + try container.encode(value, forKey: .value) } } } diff --git a/Amplify/Categories/DataStore/Query/QueryPredicate.swift b/Amplify/Categories/DataStore/Query/QueryPredicate.swift index 78bdf9f051..222bd11c6e 100644 --- a/Amplify/Categories/DataStore/Query/QueryPredicate.swift +++ b/Amplify/Categories/DataStore/Query/QueryPredicate.swift @@ -155,34 +155,6 @@ public class QueryPredicateOperation: QueryPredicate, Encodable { } public func evaluate(target: Model) -> Bool { - guard let fieldValue = target[field] else { - return false - } - - guard let value = fieldValue else { - return false - } - - if let booleanValue = value as? Bool { - return self.operator.evaluate(target: booleanValue) - } - - if let doubleValue = value as? Double { - return self.operator.evaluate(target: doubleValue) - } - - if let intValue = value as? Int { - return self.operator.evaluate(target: intValue) - } - - if let timeValue = value as? Temporal.Time { - return self.operator.evaluate(target: timeValue) - } - - if let enumValue = value as? EnumPersistable { - return self.operator.evaluate(target: enumValue.rawValue) - } - - return self.operator.evaluate(target: value) + return self.operator.evaluate(target: target[field]?.flatMap { $0 }) } } diff --git a/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLModelBasedTests+List.swift b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLModelBasedTests+List.swift index c42916e6e6..032a9d9c67 100644 --- a/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLModelBasedTests+List.swift +++ b/AmplifyPlugins/API/Tests/APIHostApp/AWSAPIPluginFunctionalTests/GraphQLModelBasedTests+List.swift @@ -145,4 +145,56 @@ extension GraphQLModelBasedTests { XCTAssertNotNil(error) } } + + /** + - Given: API with Post schema and optional field 'draft' + - When: + - create a new post with optional field 'draft' value .none + - Then: + - query Posts with filter {eq : null} shouldn't include the post + */ + func test_listModelsWithNilOptionalField_failedWithEqFilter() async throws { + let post = Post(title: UUID().uuidString, content: UUID().uuidString, createdAt: .now()) + _ = try await Amplify.API.mutate(request: .create(post)) + let posts = try await list(.list( + Post.self, + where: Post.keys.draft == nil && Post.keys.createdAt >= post.createdAt + )) + + XCTAssertFalse(posts.map(\.id).contains(post.id)) + } + + /** + - Given: DataStore with Post schema and optional field 'draft' + - When: + - create a new post with optional field 'draft' value .none + - Then: + - query Posts with filter {attributeExists : false} should include the post + */ + func test_listModelsWithNilOptionalField_successWithAttributeExistsFilter() async throws { + let post = Post(title: UUID().uuidString, content: UUID().uuidString, createdAt: .now()) + _ = try await Amplify.API.mutate(request: .create(post)) + let listPosts = try await list( + .list( + Post.self, + where: Post.keys.draft.attributeExists(false) + && Post.keys.createdAt >= post.createdAt + ) + ) + + XCTAssertTrue(listPosts.map(\.id).contains(post.id)) + } + + func list(_ request: GraphQLRequest>) async throws -> [M] { + func getAllPages(_ list: List) async throws -> [M] { + if list.hasNextPage() { + return list.elements + (try await getAllPages(list.getNextPage())) + } else { + return list.elements + } + } + + return try await getAllPages(try await Amplify.API.query(request: request).get()) + } + } diff --git a/AmplifyPlugins/Core/AWSPluginsCore/Model/Support/QueryPredicate+GraphQL.swift b/AmplifyPlugins/Core/AWSPluginsCore/Model/Support/QueryPredicate+GraphQL.swift index f2e3a6f816..c7e6735920 100644 --- a/AmplifyPlugins/Core/AWSPluginsCore/Model/Support/QueryPredicate+GraphQL.swift +++ b/AmplifyPlugins/Core/AWSPluginsCore/Model/Support/QueryPredicate+GraphQL.swift @@ -187,6 +187,8 @@ extension QueryOperator { return "beginsWith" case .notContains: return "notContains" + case .attributeExists: + return "attributeExists" } } @@ -212,6 +214,8 @@ extension QueryOperator { return value case .notContains(let value): return value + case .attributeExists(let value): + return value } } } diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Model/GraphQLDocument/GraphQLListQueryTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Model/GraphQLDocument/GraphQLListQueryTests.swift index 7d26e52108..fc2cc77bbe 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Model/GraphQLDocument/GraphQLListQueryTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Model/GraphQLDocument/GraphQLListQueryTests.swift @@ -218,4 +218,88 @@ class GraphQLListQueryTests: XCTestCase { XCTAssertEqual(variables["limit"] as? Int, 1_000) XCTAssertNotNil(variables["filter"]) } + + /** + - Given: + - A Post schema with optional field 'draft' + - When: + - Using list query to filter records that either don't have 'draft' field or have 'null' value + - Then: + - the query document as expected + - the filter is encoded correctly + */ + func test_listQuery_withAttributeExistsFilter_correctlyBuildGraphQLQueryStatement() { + let post = Post.keys + let predicate = post.id.eq("id") + && (post.draft.attributeExists(false) || post.draft.eq(nil)) + + var documentBuilder = ModelBasedGraphQLDocumentBuilder(modelSchema: Post.schema, operationType: .query) + documentBuilder.add(decorator: DirectiveNameDecorator(type: .list)) + documentBuilder.add(decorator: PaginationDecorator()) + documentBuilder.add(decorator: FilterDecorator(filter: predicate.graphQLFilter(for: Post.schema))) + let document = documentBuilder.build() + let expectedQueryDocument = """ + query ListPosts($filter: ModelPostFilterInput, $limit: Int) { + listPosts(filter: $filter, limit: $limit) { + items { + id + content + createdAt + draft + rating + status + title + updatedAt + __typename + } + nextToken + } + } + """ + XCTAssertEqual(document.name, "listPosts") + XCTAssertEqual(document.stringValue, expectedQueryDocument) + guard let variables = document.variables else { + XCTFail("The document doesn't contain variables") + return + } + XCTAssertNotNil(variables["limit"]) + XCTAssertEqual(variables["limit"] as? Int, 1_000) + + guard let filter = variables["filter"] as? GraphQLFilter else { + XCTFail("variables should contain a valid filter") + return + } + + // Test filter for a valid JSON format + let filterJSON = try? JSONSerialization.data(withJSONObject: filter, + options: .prettyPrinted) + XCTAssertNotNil(filterJSON) + + let expectedFilterJSON = """ + { + "and" : [ + { + "id" : { + "eq" : "id" + } + }, + { + "or" : [ + { + "draft" : { + "attributeExists" : false + } + }, + { + "draft" : { + "eq" : null + } + } + ] + } + ] + } + """ + XCTAssertEqual(String(data: filterJSON!, encoding: .utf8), expectedFilterJSON) + } } diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedBoolTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedBoolTests.swift index e3e013d248..465ccfe256 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedBoolTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedBoolTests.swift @@ -41,7 +41,7 @@ class QueryPredicateEvaluateGeneratedBoolTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testBoolfalsenotEqualBooltrue() throws { @@ -70,7 +70,7 @@ class QueryPredicateEvaluateGeneratedBoolTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testBooltrueequalsBooltrue() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTests.swift index ae7c9c8f34..5055bf2230 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTests.swift @@ -60,7 +60,7 @@ class QueryPredicateEvaluateGeneratedDateTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTemporal_Date_now_addvalue1to_daynotEqualTemporalDateTemporal_Date_now() throws { @@ -109,7 +109,7 @@ class QueryPredicateEvaluateGeneratedDateTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTemporal_Date_now_addvalue2to_daynotEqualTemporalDateTemporal_Date_now() throws { @@ -158,7 +158,7 @@ class QueryPredicateEvaluateGeneratedDateTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTemporal_Date_now_addvalue3to_daynotEqualTemporalDateTemporal_Date_now() throws { @@ -207,7 +207,7 @@ class QueryPredicateEvaluateGeneratedDateTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTemporal_Date_nowequalsTemporalDateTemporal_Date_now() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTimeTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTimeTests.swift index 226b3d7908..14728e550a 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTimeTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDateTimeTests.swift @@ -66,7 +66,7 @@ class QueryPredicateEvaluateGeneratedDateTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTimeTemporal_DateTime_now_addvalue1to_hournotEqualTemporalDateTimeTemporal_DateTime_now() throws { @@ -120,7 +120,7 @@ class QueryPredicateEvaluateGeneratedDateTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTimeTemporal_DateTime_now_addvalue2to_hournotEqualTemporalDateTimeTemporal_DateTime_now() throws { @@ -174,7 +174,7 @@ class QueryPredicateEvaluateGeneratedDateTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTimeTemporal_DateTime_now_addvalue3to_hournotEqualTemporalDateTimeTemporal_DateTime_now() throws { @@ -228,7 +228,7 @@ class QueryPredicateEvaluateGeneratedDateTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalDateTimeTemporal_DateTime_nowequalsTemporalDateTimeTemporal_DateTime_now() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleIntTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleIntTests.swift index 31fe364268..8866439b17 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleIntTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleIntTests.swift @@ -50,7 +50,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble2_1notEqualInt1() throws { @@ -89,7 +89,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble3_1notEqualInt1() throws { @@ -128,7 +128,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble1notEqualInt1() throws { @@ -167,7 +167,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble2notEqualInt1() throws { @@ -206,7 +206,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble3notEqualInt1() throws { @@ -245,7 +245,7 @@ class QueryPredicateEvaluateGeneratedDoubleIntTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble1_1equalsInt1() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleTests.swift index fbc4c6566d..12f54572b6 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedDoubleTests.swift @@ -80,7 +80,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble2_1notEqualDouble1_1() throws { @@ -149,7 +149,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble3_1notEqualDouble1_1() throws { @@ -218,7 +218,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble1notEqualDouble1_1() throws { @@ -287,7 +287,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble2notEqualDouble1_1() throws { @@ -356,7 +356,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble3notEqualDouble1_1() throws { @@ -425,7 +425,7 @@ class QueryPredicateEvaluateGeneratedDoubleTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testDouble1_1equalsDouble1_1() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedIntTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedIntTests.swift index 1e5c4ec370..315c648ae1 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedIntTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedIntTests.swift @@ -54,7 +54,7 @@ class QueryPredicateEvaluateGeneratedIntBetweenTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testInt2notEqualInt1() throws { @@ -93,7 +93,7 @@ class QueryPredicateEvaluateGeneratedIntBetweenTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testInt3notEqualInt1() throws { @@ -132,7 +132,7 @@ class QueryPredicateEvaluateGeneratedIntBetweenTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testInt1equalsInt1() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedStringTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedStringTests.swift index 1ce8eb039c..9408557689 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedStringTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedStringTests.swift @@ -64,7 +64,7 @@ class QueryPredicateEvaluateGeneratedStringTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testStringbbnotEqualStringa() throws { @@ -113,7 +113,7 @@ class QueryPredicateEvaluateGeneratedStringTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testStringaanotEqualStringa() throws { @@ -162,7 +162,7 @@ class QueryPredicateEvaluateGeneratedStringTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testStringcnotEqualStringa() throws { @@ -211,7 +211,7 @@ class QueryPredicateEvaluateGeneratedStringTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testStringaequalsStringa() throws { diff --git a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedTimeTests.swift b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedTimeTests.swift index 0caa2f6566..9ef76a5e68 100644 --- a/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedTimeTests.swift +++ b/AmplifyPlugins/Core/AWSPluginsCoreTests/Query/QueryPredicateEvaluateGeneratedTimeTests.swift @@ -69,7 +69,7 @@ class QueryPredicateEvaluateGeneratedTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalTimeTemporal_Time_now_addvalue1to_hournotEqualTemporalTimeTemporal_Time_now() throws { @@ -123,7 +123,7 @@ class QueryPredicateEvaluateGeneratedTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalTimeTemporal_Time_now_addvalue2to_hournotEqualTemporalTimeTemporal_Time_now() throws { @@ -177,7 +177,7 @@ class QueryPredicateEvaluateGeneratedTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalTimeTemporal_Time_now_addvalue3to_hournotEqualTemporalTimeTemporal_Time_now() throws { @@ -231,7 +231,7 @@ class QueryPredicateEvaluateGeneratedTimeTests: XCTestCase { let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance) - XCTAssertFalse(evaluation) + XCTAssertTrue(evaluation) } func testTemporalTimeTemporal_Time_nowequalsTemporalTimeTemporal_Time_now() throws { diff --git a/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/SQLite/QueryPredicate+SQLite.swift b/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/SQLite/QueryPredicate+SQLite.swift index dcc528ae63..ff358da05d 100644 --- a/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/SQLite/QueryPredicate+SQLite.swift +++ b/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/SQLite/QueryPredicate+SQLite.swift @@ -33,6 +33,8 @@ extension QueryOperator { return "instr(\(column), ?) > 0" case .notContains: return "instr(\(column), ?) = 0" + case .attributeExists(let value): + return "\(column) is\(value ? " not" : "") null" } } @@ -51,6 +53,8 @@ extension QueryOperator { .beginsWith(let value), .notContains(let value): return [value.asBinding()] + case .attributeExists(let value): + return [value.asBinding()] } } } diff --git a/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/SQLite/SQLStatement+Condition.swift b/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/SQLite/SQLStatement+Condition.swift index 6d8d55136c..2a774992d3 100644 --- a/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/SQLite/SQLStatement+Condition.swift +++ b/AmplifyPlugins/DataStore/Sources/AWSDataStorePlugin/Storage/SQLite/SQLStatement+Condition.swift @@ -76,9 +76,66 @@ private func translateQueryPredicate(from modelSchema: ModelSchema, return operation.field.quoted() } + func optimizeQueryPredicateGroup(_ predicate: QueryPredicate) -> QueryPredicate { + func rewritePredicate(_ predicate: QueryPredicate) -> QueryPredicate { + if let operation = predicate as? QueryPredicateOperation { + switch operation.operator { + case .attributeExists(let bool): + return QueryPredicateOperation( + field: operation.field, + operator: bool ? .notEqual(nil) : .equals(nil) + ) + default: + return operation + } + } else if let group = predicate as? QueryPredicateGroup { + return optimizeQueryPredicateGroup(group) + } + + return predicate + } + + func removeDuplicatePredicate(_ predicates: [QueryPredicate]) -> [QueryPredicate] { + var result = [QueryPredicate]() + for predicate in predicates { + let hasSameExpression = result.reduce(false) { + if $0 { return $0 } + switch ($1, predicate) { + case let (lhs as QueryPredicateOperation, rhs as QueryPredicateOperation): + return lhs == rhs + case let (lhs as QueryPredicateGroup, rhs as QueryPredicateGroup): + return lhs == rhs + default: + return false + } + } + + if !hasSameExpression { + result.append(predicate) + } + } + return result + } + + switch predicate { + case let predicate as QueryPredicateGroup: + let optimizedPredicates = removeDuplicatePredicate(predicate.predicates.reduce([]) { + $0 + [rewritePredicate($1)] + }) + + if optimizedPredicates.count == 1 { + return optimizedPredicates.first! + } else { + return QueryPredicateGroup(type: predicate.type, predicates: optimizedPredicates) + } + default: + return predicate + } + } + // the very first `and` is always prepended, using -1 for if statement checking // the very first `and` is to connect `where` clause with translated QueryPredicate - translate(predicate, predicateIndex: -1, groupType: .and) + translate(optimizeQueryPredicateGroup(predicate), predicateIndex: -1, groupType: .and) return (sql.joined(separator: "\n"), bindings) } diff --git a/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Core/SQLStatementTests.swift b/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Core/SQLStatementTests.swift index 19806595c8..149edd2edb 100644 --- a/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Core/SQLStatementTests.swift +++ b/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Core/SQLStatementTests.swift @@ -1350,4 +1350,101 @@ class SQLStatementTests: XCTestCase { XCTAssertEqual(statement.stringValue, expectStatement) XCTAssertEqual(variables[0] as? String, expectedVariable) } + + + /// Given: a query predicate of attributeExists + /// When: the bind value is false + /// Then: generate the correct SQL query statement + func test_translateAttributeExistsFalseQueryPredicate() { + let post = Post.keys + + let predicate = post.id.attributeExists(false) + let statement = ConditionStatement(modelSchema: Post.schema, predicate: predicate, namespace: "root") + let expectedStatement = + """ + and "root"."id" is null + """ + XCTAssertEqual(statement.stringValue, expectedStatement) + } + + /// Given: a query predicate of attributeExists + /// When: the bind value is true + /// Then: generate the correct SQL query statement + func test_translateAttributeExistsTrueQueryPredicate() { + let post = Post.keys + + let predicate = post.id.attributeExists(true) + let statement = ConditionStatement(modelSchema: Post.schema, predicate: predicate, namespace: "root") + let expectedStatement = + """ + and "root"."id" is not null + """ + XCTAssertEqual(statement.stringValue, expectedStatement) + } + + /// Given: a combined query predicate of attributeExists and ne + /// When: attributeExists(true) && ne(nil) + /// Then: generate the correct SQL query statement + func test_translateCombinedQueryPredicateOfAttributeExistsTrueAndNeNil() { + let post = Post.keys + + let predicate = post.id.attributeExists(true) && post.id.ne(nil) + let statement = ConditionStatement(modelSchema: Post.schema, predicate: predicate, namespace: "root") + let expectedStatement = + """ + and "root"."id" is not null + """ + XCTAssertEqual(statement.stringValue, expectedStatement) + } + + /// Given: a combined query predicate of attributeExists and ne + /// When: attributeExists(false) && ne(nil) + /// Then: generate the correct SQL query statement + func test_translateCombinedQueryPredicateOfAttributeExistsFalseAndNeNil() { + let post = Post.keys + + let predicate = post.id.attributeExists(false) && post.id.ne(nil) + let statement = ConditionStatement(modelSchema: Post.schema, predicate: predicate, namespace: "root") + let expectedStatement = + """ + and ( + "root"."id" is null + and "root"."id" is not null + ) + """ + XCTAssertEqual(statement.stringValue, expectedStatement) + } + + /// Given: a combined query predicate of attributeExists and eq + /// When: attributeExists(false) || eq(nil) + /// Then: generate the correct SQL query statement + func test_translateCombinedQueryPredicateOfAttributeExistsFalseOrEqNil() { + let post = Post.keys + + let predicate = post.id.attributeExists(false) || post.id.eq(nil) + let statement = ConditionStatement(modelSchema: Post.schema, predicate: predicate, namespace: "root") + let expectedStatement = + """ + and "root"."id" is null + """ + XCTAssertEqual(statement.stringValue, expectedStatement) + } + + /// Given: a combined query predicate of attributeExists and eq + /// When: attributeExists(true) || eq(nil) + /// Then: generate the correct SQL query statement + func test_translateCombinedQueryPredicateOfAttributeExistsTrueOrEqNil() { + let post = Post.keys + + let predicate = post.id.attributeExists(true) || post.id.eq(nil) + let statement = ConditionStatement(modelSchema: Post.schema, predicate: predicate, namespace: "root") + let expectedStatement = + """ + and ( + "root"."id" is not null + or "root"."id" is null + ) + """ + XCTAssertEqual(statement.stringValue, expectedStatement) + } } diff --git a/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/Support/MutationEventQueryTests.swift b/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/Support/MutationEventQueryTests.swift index 9e1eb5df65..e982872609 100644 --- a/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/Support/MutationEventQueryTests.swift +++ b/AmplifyPlugins/DataStore/Tests/AWSDataStorePluginTests/Sync/Support/MutationEventQueryTests.swift @@ -58,7 +58,7 @@ class MutationEventQueryTests: BaseDataStoreTests { wait(for: [querySuccess], timeout: 1) } - func testQueryPendingMutationEventsForModelIds() { + func testQueryPendingMutationEventsForModelIds() async { let mutationEvent1 = generateRandomMutationEvent() let mutationEvent2 = generateRandomMutationEvent() @@ -70,7 +70,7 @@ class MutationEventQueryTests: BaseDataStoreTests { } saveMutationEvent1.fulfill() } - wait(for: [saveMutationEvent1], timeout: 1) + await fulfillment(of: [saveMutationEvent1], timeout: 1) let saveMutationEvent2 = expectation(description: "save mutationEvent1 success") storageAdapter.save(mutationEvent2) { result in @@ -80,7 +80,7 @@ class MutationEventQueryTests: BaseDataStoreTests { } saveMutationEvent2.fulfill() } - wait(for: [saveMutationEvent2], timeout: 1) + await fulfillment(of: [saveMutationEvent2], timeout: 1) let querySuccess = expectation(description: "query for metadata success") var mutationEvents = [mutationEvent1] @@ -98,7 +98,7 @@ class MutationEventQueryTests: BaseDataStoreTests { } } - wait(for: [querySuccess], timeout: 1) + await fulfillment(of: [querySuccess], timeout: 5) } private func generateRandomMutationEvent() -> MutationEvent {