Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): add support for GraphQL filter attributeExists #3838

Merged
merged 3 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions Amplify/Categories/DataStore/Model/Internal/Persistable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ struct PersistableHelper {
return lhs == rhs
case let (lhs, rhs) as (String, String):
return lhs == rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue == rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs == rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue == rhs.rawValue
default:
return false
}
Expand Down Expand Up @@ -94,6 +100,12 @@ struct PersistableHelper {
return lhs == Double(rhs)
case let (lhs, rhs) as (String, String):
return lhs == rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue == rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs == rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue == rhs.rawValue
default:
return false
}
Expand Down Expand Up @@ -122,6 +134,12 @@ struct PersistableHelper {
return lhs <= Double(rhs)
case let (lhs, rhs) as (String, String):
return lhs <= rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue <= rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs <= rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue <= rhs.rawValue
default:
return false
}
Expand Down Expand Up @@ -150,6 +168,12 @@ struct PersistableHelper {
return lhs < Double(rhs)
case let (lhs, rhs) as (String, String):
return lhs < rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue < rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs < rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue < rhs.rawValue
default:
return false
}
Expand Down Expand Up @@ -178,6 +202,12 @@ struct PersistableHelper {
return lhs >= Double(rhs)
case let (lhs, rhs) as (String, String):
return lhs >= rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue >= rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs >= rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue >= rhs.rawValue
default:
return false
}
Expand Down Expand Up @@ -206,6 +236,12 @@ struct PersistableHelper {
return Double(lhs) > rhs
case let (lhs, rhs) as (String, String):
return lhs > rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue > rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs > rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue > rhs.rawValue
default:
return false
}
Expand Down
5 changes: 5 additions & 0 deletions Amplify/Categories/DataStore/Query/ModelKey.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public protocol ModelKey: CodingKey, CaseIterable, QueryFieldOperation {}

extension CodingKey where Self: ModelKey {

// MARK: - attributeExists
public func attributeExists(_ value: Bool) -> QueryPredicateOperation {
return field(stringValue).attributeExists(value)
}

// MARK: - beginsWith
public func beginsWith(_ value: String) -> QueryPredicateOperation {
return field(stringValue).beginsWith(value)
Expand Down
7 changes: 6 additions & 1 deletion Amplify/Categories/DataStore/Query/QueryField.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public func field(_ name: String) -> QueryField {
/// - seealso: `ModelKey`
public protocol QueryFieldOperation {
// MARK: - Functions

func attributeExists(_ value: Bool) -> QueryPredicateOperation
func beginsWith(_ value: String) -> QueryPredicateOperation
func between(start: Persistable, end: Persistable) -> QueryPredicateOperation
func contains(_ value: String) -> QueryPredicateOperation
Expand Down Expand Up @@ -61,6 +61,11 @@ public struct QueryField: QueryFieldOperation {

public let name: String

// MARK: - attributeExists
public func attributeExists(_ value: Bool) -> QueryPredicateOperation {
return QueryPredicateOperation(field: name, operator: .attributeExists(value))
}

// MARK: - beginsWith
public func beginsWith(_ value: String) -> QueryPredicateOperation {
return QueryPredicateOperation(field: name, operator: .beginsWith(value))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ extension QueryOperator: Equatable {
case let (.between(oneStart, oneEnd), .between(otherStart, otherEnd)):
return PersistableHelper.isEqual(oneStart, otherStart)
&& PersistableHelper.isEqual(oneEnd, otherEnd)
case let (.attributeExists(one), .attributeExists(other)):
return one == other
default:
return false
}
Expand Down
19 changes: 15 additions & 4 deletions Amplify/Categories/DataStore/Query/QueryOperator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ public enum QueryOperator: Encodable {
case notContains(_ value: String)
case between(start: Persistable, end: Persistable)
case beginsWith(_ value: String)
case attributeExists(_ value: Bool)

public func evaluate(target: Any) -> Bool {
public func evaluate(target: Any?) -> Bool {
switch self {
case .notEqual(let predicateValue):
return !PersistableHelper.isEqual(target, predicateValue)
Expand All @@ -34,20 +35,26 @@ public enum QueryOperator: Encodable {
case .greaterThan(let predicateValue):
return PersistableHelper.isGreaterThan(target, predicateValue)
case .contains(let predicateString):
if let targetString = target as? String {
if let targetString = target.flatMap({ $0 as? String }) {
return targetString.contains(predicateString)
}
return false
case .notContains(let predicateString):
if let targetString = target as? String {
if let targetString = target.flatMap({ $0 as? String }) {
return !targetString.contains(predicateString)
}
case .between(let start, let end):
return PersistableHelper.isBetween(start, end, target)
case .beginsWith(let predicateValue):
if let targetString = target as? String {
if let targetString = target.flatMap({ $0 as? String }) {
return targetString.starts(with: predicateValue)
}
case .attributeExists(let predicateValue):
if case .some = target {
return predicateValue == true
} else {
return predicateValue == false
}
}
return false
}
Expand Down Expand Up @@ -105,6 +112,10 @@ public enum QueryOperator: Encodable {
case .beginsWith(let value):
try container.encode("beginsWith", forKey: .type)
try container.encode(value, forKey: .value)

case .attributeExists(let value):
try container.encode("attributeExists", forKey: .type)
try container.encode(value, forKey: .value)
}
}
}
30 changes: 1 addition & 29 deletions Amplify/Categories/DataStore/Query/QueryPredicate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -155,34 +155,6 @@ public class QueryPredicateOperation: QueryPredicate, Encodable {
}

public func evaluate(target: Model) -> Bool {
guard let fieldValue = target[field] else {
return false
}

guard let value = fieldValue else {
return false
}

if let booleanValue = value as? Bool {
return self.operator.evaluate(target: booleanValue)
}

if let doubleValue = value as? Double {
return self.operator.evaluate(target: doubleValue)
}

if let intValue = value as? Int {
return self.operator.evaluate(target: intValue)
}

if let timeValue = value as? Temporal.Time {
return self.operator.evaluate(target: timeValue)
}

if let enumValue = value as? EnumPersistable {
return self.operator.evaluate(target: enumValue.rawValue)
}

return self.operator.evaluate(target: value)
return self.operator.evaluate(target: target[field]?.flatMap { $0 })
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,56 @@ extension GraphQLModelBasedTests {
XCTAssertNotNil(error)
}
}

/**
- Given: API with Post schema and optional field 'draft'
- When:
- create a new post with optional field 'draft' value .none
- Then:
- query Posts with filter {eq : null} shouldn't include the post
*/
func test_listModelsWithNilOptionalField_failedWithEqFilter() async throws {
let post = Post(title: UUID().uuidString, content: UUID().uuidString, createdAt: .now())
_ = try await Amplify.API.mutate(request: .create(post))
let posts = try await list(.list(
Post.self,
where: Post.keys.draft == nil && Post.keys.createdAt >= post.createdAt
))

XCTAssertFalse(posts.map(\.id).contains(post.id))
}

/**
- Given: DataStore with Post schema and optional field 'draft'
- When:
- create a new post with optional field 'draft' value .none
- Then:
- query Posts with filter {attributeExists : false} should include the post
*/
func test_listModelsWithNilOptionalField_successWithAttributeExistsFilter() async throws {
let post = Post(title: UUID().uuidString, content: UUID().uuidString, createdAt: .now())
_ = try await Amplify.API.mutate(request: .create(post))
let listPosts = try await list(
.list(
Post.self,
where: Post.keys.draft.attributeExists(false)
&& Post.keys.createdAt >= post.createdAt
)
)

XCTAssertTrue(listPosts.map(\.id).contains(post.id))
}

func list<M: Model>(_ request: GraphQLRequest<List<M>>) async throws -> [M] {
func getAllPages(_ list: List<M>) async throws -> [M] {
if list.hasNextPage() {
return list.elements + (try await getAllPages(list.getNextPage()))
} else {
return list.elements
}
}

return try await getAllPages(try await Amplify.API.query(request: request).get())
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ extension QueryOperator {
return "beginsWith"
case .notContains:
return "notContains"
case .attributeExists:
return "attributeExists"
}
}

Expand All @@ -212,6 +214,8 @@ extension QueryOperator {
return value
case .notContains(let value):
return value
case .attributeExists(let value):
return value
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,88 @@ class GraphQLListQueryTests: XCTestCase {
XCTAssertEqual(variables["limit"] as? Int, 1_000)
XCTAssertNotNil(variables["filter"])
}

/**
- Given:
- A Post schema with optional field 'draft'
- When:
- Using list query to filter records that either don't have 'draft' field or have 'null' value
- Then:
- the query document as expected
- the filter is encoded correctly
*/
func test_listQuery_withAttributeExistsFilter_correctlyBuildGraphQLQueryStatement() {
let post = Post.keys
let predicate = post.id.eq("id")
&& (post.draft.attributeExists(false) || post.draft.eq(nil))

var documentBuilder = ModelBasedGraphQLDocumentBuilder(modelSchema: Post.schema, operationType: .query)
documentBuilder.add(decorator: DirectiveNameDecorator(type: .list))
documentBuilder.add(decorator: PaginationDecorator())
documentBuilder.add(decorator: FilterDecorator(filter: predicate.graphQLFilter(for: Post.schema)))
let document = documentBuilder.build()
let expectedQueryDocument = """
query ListPosts($filter: ModelPostFilterInput, $limit: Int) {
listPosts(filter: $filter, limit: $limit) {
items {
id
content
createdAt
draft
rating
status
title
updatedAt
__typename
}
nextToken
}
}
"""
XCTAssertEqual(document.name, "listPosts")
XCTAssertEqual(document.stringValue, expectedQueryDocument)
guard let variables = document.variables else {
XCTFail("The document doesn't contain variables")
return
}
XCTAssertNotNil(variables["limit"])
XCTAssertEqual(variables["limit"] as? Int, 1_000)

guard let filter = variables["filter"] as? GraphQLFilter else {
XCTFail("variables should contain a valid filter")
return
}

// Test filter for a valid JSON format
let filterJSON = try? JSONSerialization.data(withJSONObject: filter,
options: .prettyPrinted)
XCTAssertNotNil(filterJSON)

let expectedFilterJSON = """
{
"and" : [
{
"id" : {
"eq" : "id"
}
},
{
"or" : [
{
"draft" : {
"attributeExists" : false
}
},
{
"draft" : {
"eq" : null
}
}
]
}
]
}
"""
XCTAssertEqual(String(data: filterJSON!, encoding: .utf8), expectedFilterJSON)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class QueryPredicateEvaluateGeneratedBoolTests: XCTestCase {

let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance)

XCTAssertFalse(evaluation)
XCTAssertTrue(evaluation)
}

func testBoolfalsenotEqualBooltrue() throws {
Expand Down Expand Up @@ -70,7 +70,7 @@ class QueryPredicateEvaluateGeneratedBoolTests: XCTestCase {

let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance)

XCTAssertFalse(evaluation)
XCTAssertTrue(evaluation)
}

func testBooltrueequalsBooltrue() throws {
Expand Down
Loading
Loading