From dabaadd1f8cda3f3f55a9f7cc683ffbe6aa6081b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Alleyne?= Date: Wed, 16 Oct 2024 16:07:14 -0400 Subject: [PATCH] IO Implementation using Go CDK MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Loïc Alleyne --- .asf.yaml | 92 +- .gitattributes | 44 +- .github/ISSUE_TEMPLATE/iceberg_bug_report.yml | 80 +- .../ISSUE_TEMPLATE/iceberg_improvement.yml | 54 +- .github/ISSUE_TEMPLATE/iceberg_question.yml | 62 +- .github/dependabot.yml | 48 +- .github/labeler.yml | 50 +- .github/workflows/go-ci.yml | 122 +- .github/workflows/go-integration.yml | 142 +- .github/workflows/labeler.yml | 64 +- .github/workflows/license_check.yml | 54 +- .github/workflows/rc.yml | 254 +- .gitignore | 116 +- .pre-commit-config.yaml | 52 +- LICENSE | 630 +-- NOTICE | 16 +- README.md | 170 +- catalog/README.md | 244 +- catalog/catalog.go | 374 +- catalog/glue.go | 510 +-- catalog/glue_test.go | 398 +- catalog/rest.go | 1424 +++---- catalog/rest_internal_test.go | 246 +- catalog/rest_test.go | 1634 ++++---- cmd/iceberg/main.go | 686 ++-- cmd/iceberg/output.go | 440 +- cmd/iceberg/output_test.go | 390 +- dev/Dockerfile | 48 +- dev/check-license | 166 +- dev/docker-compose.yml | 182 +- dev/provision.py | 756 ++-- dev/release/README.md | 210 +- dev/release/check_rat_report.py | 118 +- dev/release/rat_exclude_files.txt | 46 +- dev/release/release.sh | 166 +- dev/release/release_rc.sh | 280 +- dev/release/run_rat.sh | 108 +- dev/release/verify_rc.sh | 412 +- errors.go | 64 +- exprs.go | 2058 +++++----- exprs_test.go | 1484 +++---- go.mod | 43 +- go.sum | 202 +- internal/avro_schemas.go | 1136 +++--- internal/mock_fs.go | 170 +- io/blob.go | 333 ++ io/gcs_cdk.go | 63 + io/io.go | 514 +-- io/local.go | 70 +- io/s3.go | 256 +- io/s3_cdk.go | 45 + literals.go | 2340 +++++------ literals_test.go | 2008 ++++----- manifest.go | 1936 ++++----- manifest_test.go | 1544 +++---- operation_string.go | 82 +- 
partitions.go | 464 +-- partitions_test.go | 286 +- predicates.go | 276 +- schema.go | 2334 +++++------ schema_test.go | 1516 +++---- table/arrow_utils.go | 824 ++-- table/arrow_utils_test.go | 742 ++-- table/evaluators.go | 2250 +++++------ table/evaluators_test.go | 3598 ++++++++--------- table/metadata.go | 934 ++--- table/metadata_test.go | 988 ++--- table/name_mapping.go | 592 +-- table/name_mapping_test.go | 290 +- table/refs.go | 136 +- table/refs_test.go | 144 +- table/scanner.go | 790 ++-- table/scanner_test.go | 242 +- table/snapshots.go | 392 +- table/snapshots_test.go | 230 +- table/sorting.go | 354 +- table/sorting_test.go | 220 +- table/table.go | 224 +- table/table_test.go | 260 +- transforms.go | 1742 ++++---- transforms_test.go | 178 +- types.go | 1278 +++--- types_test.go | 472 +-- utils.go | 396 +- visitors.go | 794 ++-- visitors_test.go | 1214 +++--- 86 files changed, 25049 insertions(+), 24347 deletions(-) create mode 100644 io/blob.go create mode 100644 io/gcs_cdk.go create mode 100644 io/s3_cdk.go diff --git a/.asf.yaml b/.asf.yaml index e735371..f210e47 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -1,46 +1,46 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# The format of this file is documented at -# https://cwiki.apache.org/confluence/display/INFRA/Git+-+.asf.yaml+features - -github: - description: "Apache Iceberg - Go" - homepage: https://iceberg.apache.org/ - labels: - - iceberg - - apache - - golang - features: - issues: true - - enabled_merge_buttons: - squash: true - merge: false - rebase: false - - protected_branches: - main: - required_pull_request_reviews: - required_approving_review_count: 1 - - required_linear_history: true - -notifications: - commits: commits@iceberg.apache.org - issues: issues@iceberg.apache.org - pullrequests: issues@iceberg.apache.org +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# The format of this file is documented at +# https://cwiki.apache.org/confluence/display/INFRA/Git+-+.asf.yaml+features + +github: + description: "Apache Iceberg - Go" + homepage: https://iceberg.apache.org/ + labels: + - iceberg + - apache + - golang + features: + issues: true + + enabled_merge_buttons: + squash: true + merge: false + rebase: false + + protected_branches: + main: + required_pull_request_reviews: + required_approving_review_count: 1 + + required_linear_history: true + +notifications: + commits: commits@iceberg.apache.org + issues: issues@iceberg.apache.org + pullrequests: issues@iceberg.apache.org diff --git a/.gitattributes b/.gitattributes index 2ca0612..0cfe36d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,23 +1,23 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Files marked as export-ignore will be ignored from the release's -# built via `git archive`. This simplifies the release script and -# uses an industry standard as opposed to possibly hard to read -# shell scripts with many flags. 
Unfortunately, directories themselves -# won't recursively ignore, so we need the top level directories +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Files marked as export-ignore will be ignored from the release's +# built via `git archive`. This simplifies the release script and +# uses an industry standard as opposed to possibly hard to read +# shell scripts with many flags. Unfortunately, directories themselves +# won't recursively ignore, so we need the top level directories # as well as their files. \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/iceberg_bug_report.yml b/.github/ISSUE_TEMPLATE/iceberg_bug_report.yml index 512bebc..e7cff76 100644 --- a/.github/ISSUE_TEMPLATE/iceberg_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/iceberg_bug_report.yml @@ -1,40 +1,40 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - ---- -name: Iceberg Bug report 🐞 -description: Problems, bugs and issues with Apache Iceberg -labels: ["kind:bug"] -body: - - type: dropdown - attributes: - label: Apache Iceberg version - description: What Apache Iceberg version are you using? - multiple: false - options: - - "main (development)" - validations: - required: false - - type: textarea - attributes: - label: Please describe the bug 🐞 - description: > - Please describe the problem, what to expect, and how to reproduce. - Feel free to include stacktraces and the Iceberg catalog configuration. - You can include files by dragging and dropping them here. - validations: - required: true +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +--- +name: Iceberg Bug report 🐞 +description: Problems, bugs and issues with Apache Iceberg +labels: ["kind:bug"] +body: + - type: dropdown + attributes: + label: Apache Iceberg version + description: What Apache Iceberg version are you using? + multiple: false + options: + - "main (development)" + validations: + required: false + - type: textarea + attributes: + label: Please describe the bug 🐞 + description: > + Please describe the problem, what to expect, and how to reproduce. + Feel free to include stacktraces and the Iceberg catalog configuration. + You can include files by dragging and dropping them here. + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/iceberg_improvement.yml b/.github/ISSUE_TEMPLATE/iceberg_improvement.yml index 60eddb5..8e68940 100644 --- a/.github/ISSUE_TEMPLATE/iceberg_improvement.yml +++ b/.github/ISSUE_TEMPLATE/iceberg_improvement.yml @@ -1,28 +1,28 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- ---- -name: Iceberg Improvement / Feature Request -description: New features with Apache Iceberg -labels: ["kind:feature request"] -body: - - type: textarea - attributes: - label: Feature Request / Improvement - description: Please describe the feature and elaborate on the use case and motivation behind it - validations: +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +name: Iceberg Improvement / Feature Request +description: New features with Apache Iceberg +labels: ["kind:feature request"] +body: + - type: textarea + attributes: + label: Feature Request / Improvement + description: Please describe the feature and elaborate on the use case and motivation behind it + validations: required: true \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/iceberg_question.yml b/.github/ISSUE_TEMPLATE/iceberg_question.yml index a6111bb..b8d441d 100644 --- a/.github/ISSUE_TEMPLATE/iceberg_question.yml +++ b/.github/ISSUE_TEMPLATE/iceberg_question.yml @@ -1,31 +1,31 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - ---- -name: Iceberg Question -description: Questions around Apache Iceberg -labels: ["kind:question"] -body: - - type: markdown - attributes: - value: "Feel free to ask your question on [Slack](https://join.slack.com/t/apache-iceberg/shared_invite/zt-1uva9gyp1-TrLQl7o~nZ5PsTVgl6uoEQ) as well." - - type: textarea - attributes: - label: Question - description: What is your question? - validations: - required: true +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +--- +name: Iceberg Question +description: Questions around Apache Iceberg +labels: ["kind:question"] +body: + - type: markdown + attributes: + value: "Feel free to ask your question on [Slack](https://join.slack.com/t/apache-iceberg/shared_invite/zt-1uva9gyp1-TrLQl7o~nZ5PsTVgl6uoEQ) as well." + - type: textarea + attributes: + label: Question + description: What is your question? + validations: + required: true diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 14f8e04..4d80f74 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,24 +1,24 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -version: 2 -updates: - - package-ecosystem: gomod - directory: / - schedule: - interval: "weekly" - day: "sunday" +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +version: 2 +updates: + - package-ecosystem: gomod + directory: / + schedule: + interval: "weekly" + day: "sunday" diff --git a/.github/labeler.yml b/.github/labeler.yml index 4cd7825..fd89e16 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,25 +1,25 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Pull request labeler Github Action Configuration: https://github.com/marketplace/actions/labeler -INFRA: - - .asf.yaml - - .gitattributes - - .gitignore - - .github/**/* - - .pre-commit-config.yaml - - dev/**/* +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Pull request labeler Github Action Configuration: https://github.com/marketplace/actions/labeler +INFRA: + - .asf.yaml + - .gitattributes + - .gitignore + - .github/**/* + - .pre-commit-config.yaml + - dev/**/* diff --git a/.github/workflows/go-ci.yml b/.github/workflows/go-ci.yml index 2b158b5..f400467 100644 --- a/.github/workflows/go-ci.yml +++ b/.github/workflows/go-ci.yml @@ -1,61 +1,61 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -name: Go - -on: - push: - branches: - - 'main' - tags: - - 'v**' - pull_request: - -concurrency: - group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} - cancel-in-progress: ${{ github.event_name == 'pull_request' }} - -permissions: - contents: read - -jobs: - lint-and-test: - name: ${{ matrix.os }} go${{ matrix.go }} - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - go: [ '1.22', '1.23' ] - os: [ 'ubuntu-latest', 'windows-latest', 'macos-latest' ] - steps: - - uses: actions/checkout@v4 - - name: Install Go - uses: actions/setup-go@v4 - with: - go-version: ${{ matrix.go }} - cache: true - cache-dependency-path: go.sum - - name: Install staticcheck - if: matrix.go == '1.22' - run: go install honnef.co/go/tools/cmd/staticcheck@v0.4.7 - - name: Install staticcheck - if: matrix.go == '1.23' - run: go install honnef.co/go/tools/cmd/staticcheck@v0.5.1 - - name: Lint - run: staticcheck ./... - - name: Run tests - run: go test -v ./... +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: Go + +on: + push: + branches: + - 'main' + tags: + - 'v**' + pull_request: + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + +permissions: + contents: read + +jobs: + lint-and-test: + name: ${{ matrix.os }} go${{ matrix.go }} + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + go: [ '1.22', '1.23' ] + os: [ 'ubuntu-latest', 'windows-latest', 'macos-latest' ] + steps: + - uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v4 + with: + go-version: ${{ matrix.go }} + cache: true + cache-dependency-path: go.sum + - name: Install staticcheck + if: matrix.go == '1.22' + run: go install honnef.co/go/tools/cmd/staticcheck@v0.4.7 + - name: Install staticcheck + if: matrix.go == '1.23' + run: go install honnef.co/go/tools/cmd/staticcheck@v0.5.1 + - name: Lint + run: staticcheck ./... + - name: Run tests + run: go test -v ./... diff --git a/.github/workflows/go-integration.yml b/.github/workflows/go-integration.yml index 7dd5e75..ac9faa8 100644 --- a/.github/workflows/go-integration.yml +++ b/.github/workflows/go-integration.yml @@ -1,71 +1,71 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -name: "Go Integration" - -on: - push: - branches: - - 'main' - tags: - - 'v**' - pull_request: - -concurrency: - group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} - cancel-in-progress: ${{ github.event_name == 'pull_request' }} - -permissions: - contents: read - -jobs: - integration-test: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 2 - - name: Install Go - uses: actions/setup-go@v4 - with: - go-version: 1.23 - cache: true - cache-dependency-path: go.sum - - - name: Start docker - run: | - docker compose -f dev/docker-compose.yml up -d - sleep 10 - - name: Provision Tables - run: | - docker compose -f dev/docker-compose.yml exec -T spark-iceberg ipython ./provision.py - sleep 10 - - - name: Get minio container IP - run: | - echo "AWS_S3_ENDPOINT=http://$(docker inspect -f '{{range.NetworkSettings.Networks}}{{.IPAddress}}{{end}}' minio):9000" >> $GITHUB_ENV - - - name: Run integration tests - env: - AWS_S3_ENDPOINT: "${{ env.AWS_S3_ENDPOINT }}" - AWS_REGION: "us-east-1" - run: | - go test -tags integration -v -run="^TestScanner" ./table - - - name: Show debug logs - if: ${{ failure() }} - run: docker compose -f dev/docker-compose.yml logs +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: "Go Integration" + +on: + push: + branches: + - 'main' + tags: + - 'v**' + pull_request: + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: ${{ github.event_name == 'pull_request' }} + +permissions: + contents: read + +jobs: + integration-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 2 + - name: Install Go + uses: actions/setup-go@v4 + with: + go-version: 1.23 + cache: true + cache-dependency-path: go.sum + + - name: Start docker + run: | + docker compose -f dev/docker-compose.yml up -d + sleep 10 + - name: Provision Tables + run: | + docker compose -f dev/docker-compose.yml exec -T spark-iceberg ipython ./provision.py + sleep 10 + + - name: Get minio container IP + run: | + echo "AWS_S3_ENDPOINT=http://$(docker inspect -f '{{range.NetworkSettings.Networks}}{{.IPAddress}}{{end}}' minio):9000" >> $GITHUB_ENV + + - name: Run integration tests + env: + AWS_S3_ENDPOINT: "${{ env.AWS_S3_ENDPOINT }}" + AWS_REGION: "us-east-1" + run: | + go test -tags integration -v -run="^TestScanner" ./table + + - name: Show debug logs + if: ${{ failure() }} + run: docker compose -f dev/docker-compose.yml logs diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 4a5b455..97daa64 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -1,32 +1,32 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: "Pull Request Labeler" -on: pull_request_target - -permissions: - contents: read - pull-requests: write - -jobs: - triage: - runs-on: ubuntu-22.04 - steps: - - uses: actions/labeler@v4 - with: - repo-token: "${{ secrets.GITHUB_TOKEN }}" - sync-labels: true +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: "Pull Request Labeler" +on: pull_request_target + +permissions: + contents: read + pull-requests: write + +jobs: + triage: + runs-on: ubuntu-22.04 + steps: + - uses: actions/labeler@v4 + with: + repo-token: "${{ secrets.GITHUB_TOKEN }}" + sync-labels: true diff --git a/.github/workflows/license_check.yml b/.github/workflows/license_check.yml index d727084..86ae21f 100644 --- a/.github/workflows/license_check.yml +++ b/.github/workflows/license_check.yml @@ -1,27 +1,27 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -name: "Run License Check" -on: pull_request - -jobs: - rat: - runs-on: ubuntu-22.04 - steps: - - uses: actions/checkout@v3 - - run: | - dev/check-license +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: "Run License Check" +on: pull_request + +jobs: + rat: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + - run: | + dev/check-license diff --git a/.github/workflows/rc.yml b/.github/workflows/rc.yml index 5a407d8..df350db 100644 --- a/.github/workflows/rc.yml +++ b/.github/workflows/rc.yml @@ -1,127 +1,127 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -name: RC -on: - push: - branches: - - '**' - - '!dependabot/**' - tags: - - '*-rc*' - pull_request: - -concurrency: - group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} - cancel-in-progress: true -permissions: - contents: read - -jobs: - archive: - name: Archive - runs-on: ubuntu-latest - timeout-minutes: 5 - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Prepare for tag - if: github.ref_type == 'tag' - run: | - version=${GITHUB_REF_NAME%-rc} - version=${version#v} - rc=${GITHUB_REF_NAME#*-rc} - echo "VERSION=${version}" >> ${GITHUB_ENV} - echo "RC=${rc}" >> ${GITHUB_ENV} - - name: Prepare for branch - if: github.ref_type == 'branch' - run: | - rc=100 - echo "VERSION=${version}" >> ${GITHUB_ENV} - echo "RC=${rc}" >> ${GITHUB_ENV} - - name: Archive - run: | - id="apache-iceberg-go-${VERSION}" - tar_gz="${id}.tar.gz" - echo "TAR_GZ=${tar_gz}" >> ${GITHUB_ENV} - git archive HEAD --prefix "${id}/" --output "${tar_gz}" - sha256sum "${tar_gz}" > "${tar_gz}.sha256" - sha512sum "${tar_gz}" > "${tar_gz}.sha512" - - name: Audit - run: | - dev/release/run_rat.sh "${TAR_GZ}" - - uses: actions/upload-artifact@v4 - with: - name: archive - path: | - apache-iceberg-go-* - - verify: - name: Verify - needs: - - archive - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: - - macos-latest - - ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - - uses: actions/download-artifact@v4 - with: - name: archive - - name: Verify - run: | - tar_gz=$(echo apache-iceberg-go-*.tar.gz) - version=${tar_gz#apache-iceberg-go-} - version=${version%.tar.gz} - if [ "${GITHUB_REF_TYPE}" = "tag" ]; then - rc="${GITHUB_REF_NAME#*-rc}" - else - rc=100 - fi - VERIFY_DEFAULT=0 dev/release/verify_rc.sh "${version}" "${rc}" - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - upload: - name: upload - if: github.ref_type == 'tag' - needs: - - verify - runs-on: ubuntu-latest - permissions: - contents: write - 
steps: - - name: Checkout - uses: actions/checkout@v4 - - uses: actions/download-artifact@v4 - with: - name: archive - - name: Upload - run: | - # TODO: Add support for release notes - gh release create ${GITHUB_REF_NAME} \ - --prerelease \ - --title "Apache Iceberg Go ${GITHUB_REF_NAME}" \ - --verify-tag \ - apache-iceberg-go-*.tar.gz \ - apache-iceberg-go-*.tar.gz.sha* - env: - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: RC +on: + push: + branches: + - '**' + - '!dependabot/**' + tags: + - '*-rc*' + pull_request: + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true +permissions: + contents: read + +jobs: + archive: + name: Archive + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Prepare for tag + if: github.ref_type == 'tag' + run: | + version=${GITHUB_REF_NAME%-rc} + version=${version#v} + rc=${GITHUB_REF_NAME#*-rc} + echo "VERSION=${version}" >> ${GITHUB_ENV} + echo "RC=${rc}" >> ${GITHUB_ENV} + - name: Prepare for branch + if: github.ref_type == 'branch' + run: | + rc=100 + echo "VERSION=${version}" >> ${GITHUB_ENV} + echo "RC=${rc}" >> ${GITHUB_ENV} + - name: Archive + run: | + id="apache-iceberg-go-${VERSION}" + tar_gz="${id}.tar.gz" + echo "TAR_GZ=${tar_gz}" >> ${GITHUB_ENV} + git archive HEAD --prefix "${id}/" --output "${tar_gz}" + sha256sum "${tar_gz}" > "${tar_gz}.sha256" + sha512sum "${tar_gz}" > "${tar_gz}.sha512" + - name: Audit + run: | + dev/release/run_rat.sh "${TAR_GZ}" + - uses: actions/upload-artifact@v4 + with: + name: archive + path: | + apache-iceberg-go-* + + verify: + name: Verify + needs: + - archive + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: + - macos-latest + - ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 + with: + name: archive + - name: Verify + run: | + tar_gz=$(echo apache-iceberg-go-*.tar.gz) + version=${tar_gz#apache-iceberg-go-} + version=${version%.tar.gz} + if [ "${GITHUB_REF_TYPE}" = "tag" ]; then + rc="${GITHUB_REF_NAME#*-rc}" + else + rc=100 + fi + VERIFY_DEFAULT=0 dev/release/verify_rc.sh "${version}" "${rc}" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + upload: + name: upload + if: github.ref_type == 'tag' + needs: + - verify + runs-on: ubuntu-latest + permissions: + contents: write + 
steps: + - name: Checkout + uses: actions/checkout@v4 + - uses: actions/download-artifact@v4 + with: + name: archive + - name: Upload + run: | + # TODO: Add support for release notes + gh release create ${GITHUB_REF_NAME} \ + --prerelease \ + --title "Apache Iceberg Go ${GITHUB_REF_NAME}" \ + --verify-tag \ + apache-iceberg-go-*.tar.gz \ + apache-iceberg-go-*.tar.gz.sha* + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 99c3976..d7da271 100644 --- a/.gitignore +++ b/.gitignore @@ -1,58 +1,58 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -.DS_Store -.cache -tmp/ - -.vscode - -# Binaries from programs and plugins -*.exe -*.exe~ -*.dll -*.so -*.dylib - -# rat check -build/ -lib/ - -# Test binary, build with `go test -c` -*.test - -# Output of the go coverage tool -*.out - -# intellij files -.idea/ -.idea_modules/ -*.ipr -*.iws -*.iml - -.envrc* - -# local catalog environment via docker -dev/notebooks -dev/warehouse - -/apache-iceberg-go-*.tar.gz -/apache-iceberg-go-*.tar.gz.asc -/dev/release/apache-rat-*.jar -/dev/release/filtered_rat.txt -/dev/release/rat.xml +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +.DS_Store +.cache +tmp/ + +.vscode + +# Binaries from programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# rat check +build/ +lib/ + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool +*.out + +# intellij files +.idea/ +.idea_modules/ +*.ipr +*.iws +*.iml + +.envrc* + +# local catalog environment via docker +dev/notebooks +dev/warehouse + +/apache-iceberg-go-*.tar.gz +/apache-iceberg-go-*.tar.gz.asc +/dev/release/apache-rat-*.jar +/dev/release/filtered_rat.txt +/dev/release/rat.xml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d1bf3ab..ae6a8b0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,26 +1,26 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. ---- -files: ^go/ - -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: - - id: golangci-lint - entry: bash -c 'cd iceberg && golangci-lint run --fix --timeout 5m' - types_or: [go] +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +--- +files: ^go/ + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: golangci-lint + entry: bash -c 'cd iceberg && golangci-lint run --fix --timeout 5m' + types_or: [go] diff --git a/LICENSE b/LICENSE index 515fd54..5727b72 100644 --- a/LICENSE +++ b/LICENSE @@ -1,315 +1,315 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. 
- - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. 
- - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of 
the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. 
Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
- - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - --------------------------------------------------------------------------------- - -This product includes a gradle wrapper. - -* gradlew and gradle/wrapper/gradle-wrapper.properties - -Copyright: 2010-2019 Gradle Authors. -Home page: https://github.com/gradle/gradle -License: https://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This product includes code from Apache Avro. - -* Conversion in DecimalWriter is based on Avro's Conversions.DecimalConversion. - -Copyright: 2014-2017 The Apache Software Foundation. -Home page: https://avro.apache.org/ -License: https://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This product includes code from Apache Parquet. 
- -* DynMethods.java -* DynConstructors.java -* AssertHelpers.java -* IOUtil.java readFully and tests -* ByteBufferInputStream implementations and tests - -Copyright: 2014-2017 The Apache Software Foundation. -Home page: https://parquet.apache.org/ -License: https://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This product includes code from Cloudera Kite. - -* SchemaVisitor and visit methods - -Copyright: 2013-2017 Cloudera Inc. -Home page: https://kitesdk.org/ -License: https://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This product includes code from Presto. - -* Retry wait and jitter logic in Tasks.java -* S3FileIO logic derived from PrestoS3FileSystem.java in S3InputStream.java - and S3OutputStream.java -* SQL grammar rules for parsing CALL statements in IcebergSqlExtensions.g4 -* some aspects of handling stored procedures - -Copyright: 2016 Facebook and contributors -Home page: https://prestodb.io/ -License: https://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This product includes code from Apache iBATIS. - -* Hive ScriptRunner.java - -Copyright: 2004 Clinton Begin -Home page: https://ibatis.apache.org/ -License: https://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This product includes code from Apache Hive. - -* Hive metastore derby schema in hive-schema-3.1.0.derby.sql - -Copyright: 2011-2018 The Apache Software Foundation -Home page: https://hive.apache.org/ -License: https://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This product includes code from Apache Spark. 
- -* dev/check-license script -* vectorized reading of definition levels in BaseVectorizedParquetValuesReader.java -* portions of the extensions parser -* casting logic in AssignmentAlignmentSupport -* implementation of SetAccumulator. -* Connector expressions. - -Copyright: 2011-2018 The Apache Software Foundation -Home page: https://spark.apache.org/ -License: https://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This product includes code from Delta Lake. - -* AssignmentAlignmentSupport is an independent development but UpdateExpressionsSupport in Delta was used as a reference. - -Copyright: 2020 The Delta Lake Project Authors. -Home page: https://delta.io/ -License: https://www.apache.org/licenses/LICENSE-2.0 - --------------------------------------------------------------------------------- - -This product includes code from Apache Commons. - -* Core ArrayUtil. - -Copyright: 2020 The Apache Software Foundation -Home page: https://commons.apache.org/ -License: https://www.apache.org/licenses/LICENSE-2.0 + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +-------------------------------------------------------------------------------- + +This product includes a gradle wrapper. + +* gradlew and gradle/wrapper/gradle-wrapper.properties + +Copyright: 2010-2019 Gradle Authors. +Home page: https://github.com/gradle/gradle +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Apache Avro. + +* Conversion in DecimalWriter is based on Avro's Conversions.DecimalConversion. + +Copyright: 2014-2017 The Apache Software Foundation. +Home page: https://avro.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Apache Parquet. + +* DynMethods.java +* DynConstructors.java +* AssertHelpers.java +* IOUtil.java readFully and tests +* ByteBufferInputStream implementations and tests + +Copyright: 2014-2017 The Apache Software Foundation. +Home page: https://parquet.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Cloudera Kite. + +* SchemaVisitor and visit methods + +Copyright: 2013-2017 Cloudera Inc. 
+Home page: https://kitesdk.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Presto. + +* Retry wait and jitter logic in Tasks.java +* S3FileIO logic derived from PrestoS3FileSystem.java in S3InputStream.java + and S3OutputStream.java +* SQL grammar rules for parsing CALL statements in IcebergSqlExtensions.g4 +* some aspects of handling stored procedures + +Copyright: 2016 Facebook and contributors +Home page: https://prestodb.io/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Apache iBATIS. + +* Hive ScriptRunner.java + +Copyright: 2004 Clinton Begin +Home page: https://ibatis.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Apache Hive. + +* Hive metastore derby schema in hive-schema-3.1.0.derby.sql + +Copyright: 2011-2018 The Apache Software Foundation +Home page: https://hive.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Apache Spark. + +* dev/check-license script +* vectorized reading of definition levels in BaseVectorizedParquetValuesReader.java +* portions of the extensions parser +* casting logic in AssignmentAlignmentSupport +* implementation of SetAccumulator. +* Connector expressions. + +Copyright: 2011-2018 The Apache Software Foundation +Home page: https://spark.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Delta Lake. 
+ +* AssignmentAlignmentSupport is an independent development but UpdateExpressionsSupport in Delta was used as a reference. + +Copyright: 2020 The Delta Lake Project Authors. +Home page: https://delta.io/ +License: https://www.apache.org/licenses/LICENSE-2.0 + +-------------------------------------------------------------------------------- + +This product includes code from Apache Commons. + +* Core ArrayUtil. + +Copyright: 2020 The Apache Software Foundation +Home page: https://commons.apache.org/ +License: https://www.apache.org/licenses/LICENSE-2.0 diff --git a/NOTICE b/NOTICE index 48df82e..9e8943f 100644 --- a/NOTICE +++ b/NOTICE @@ -1,8 +1,8 @@ - -Apache Iceberg -Copyright 2023 The Apache Software Foundation - -This product includes software developed at -The Apache Software Foundation (http://www.apache.org/). - --------------------------------------------------------------------------------- + +Apache Iceberg +Copyright 2023 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +-------------------------------------------------------------------------------- diff --git a/README.md b/README.md index aeb249d..f2983d0 100644 --- a/README.md +++ b/README.md @@ -1,86 +1,86 @@ - - -# Iceberg Golang - -[![Go Reference](https://pkg.go.dev/badge/github.com/apache/iceberg-go.svg)](https://pkg.go.dev/github.com/apache/iceberg-go) - -`iceberg` is a Golang implementation of the [Iceberg table spec](https://iceberg.apache.org/spec/). - -## Build From Source - -### Prerequisites - -* Go 1.21 or later - -### Build - -```shell -$ git clone https://github.com/apache/iceberg-go.git -$ cd iceberg-go/cmd/iceberg && go build . 
-``` - -## Feature Support / Roadmap - -### FileSystem Support - -| Filesystem Type | Supported | -| :------------------: | :-------: | -| S3 | X | -| Google Cloud Storage | | -| Azure Blob Storage | | -| Local Filesystem | X | - -### Metadata - -| Operation | Supported | -| :----------------------- | :-------: | -| Get Schema | X | -| Get Snapshots | X | -| Get Sort Orders | X | -| Get Partition Specs | X | -| Get Manifests | X | -| Create New Manifests | X | -| Plan Scan | | -| Plan Scan for Snapshot | | - -### Catalog Support - -| Operation | REST | Hive | DynamoDB | Glue | -| :----------------------- | :--: | :--: | :------: | :--: | -| Load Table | | | | X | -| List Tables | | | | X | -| Create Table | | | | | -| Update Current Snapshot | | | | | -| Create New Snapshot | | | | | -| Rename Table | | | | | -| Drop Table | | | | | -| Alter Table | | | | | -| Set Table Properties | | | | | -| Create Namespace | | | | | -| Drop Namespace | | | | | -| Set Namespace Properties | | | | | - -### Read/Write Data Support - -* No intrinsic support for reading/writing data yet -* Data can be manually read currently by retrieving data files via Manifests. -* Plan to add [Apache Arrow](https://pkg.go.dev/github.com/apache/arrow/go/v14@v14.0.0) support eventually. - -# Get in Touch - + + +# Iceberg Golang + +[![Go Reference](https://pkg.go.dev/badge/github.com/apache/iceberg-go.svg)](https://pkg.go.dev/github.com/apache/iceberg-go) + +`iceberg` is a Golang implementation of the [Iceberg table spec](https://iceberg.apache.org/spec/). + +## Build From Source + +### Prerequisites + +* Go 1.21 or later + +### Build + +```shell +$ git clone https://github.com/apache/iceberg-go.git +$ cd iceberg-go/cmd/iceberg && go build . 
+``` + +## Feature Support / Roadmap + +### FileSystem Support + +| Filesystem Type | Supported | +| :------------------: | :-------: | +| S3 | X | +| Google Cloud Storage | | +| Azure Blob Storage | | +| Local Filesystem | X | + +### Metadata + +| Operation | Supported | +| :----------------------- | :-------: | +| Get Schema | X | +| Get Snapshots | X | +| Get Sort Orders | X | +| Get Partition Specs | X | +| Get Manifests | X | +| Create New Manifests | X | +| Plan Scan | | +| Plan Scan for Snapshot | | + +### Catalog Support + +| Operation | REST | Hive | DynamoDB | Glue | +| :----------------------- | :--: | :--: | :------: | :--: | +| Load Table | | | | X | +| List Tables | | | | X | +| Create Table | | | | | +| Update Current Snapshot | | | | | +| Create New Snapshot | | | | | +| Rename Table | | | | | +| Drop Table | | | | | +| Alter Table | | | | | +| Set Table Properties | | | | | +| Create Namespace | | | | | +| Drop Namespace | | | | | +| Set Namespace Properties | | | | | + +### Read/Write Data Support + +* No intrinsic support for reading/writing data yet +* Data can be manually read currently by retrieving data files via Manifests. +* Plan to add [Apache Arrow](https://pkg.go.dev/github.com/apache/arrow/go/v14@v14.0.0) support eventually. + +# Get in Touch + - [Iceberg community](https://iceberg.apache.org/community/) \ No newline at end of file diff --git a/catalog/README.md b/catalog/README.md index 561cccb..216b15a 100644 --- a/catalog/README.md +++ b/catalog/README.md @@ -1,123 +1,123 @@ - - -# Catalog Implementations - -## Integration Testing - -The Catalog implementations can be manually tested using the CLI implemented -in the `cmd/iceberg` folder. - -### REST Catalog - -To test the REST catalog implementation, we have a docker configuration -for a Minio container and tabluario/iceberg-rest container. - -You can spin up the local catalog by going to the `dev/` folder and running -`docker-compose up`. 
You can then follow the steps of the Iceberg [Quickstart](https://iceberg.apache.org/spark-quickstart/#creating-a-table) -tutorial, which we've summarized below. - -#### Setup your Iceberg catalog - -First launch a pyspark console by running: - -```bash -docker exec -it spark-iceberg pyspark -``` - -Once in the pyspark shell, we create a simple table with a namespace of -"demo.nyc" called "taxis": - -```python -from pyspark.sql.types import DoubleType, FloatType, LongType, StructType,StructField, StringType -schema = StructType([ - StructField("vendor_id", LongType(), True), - StructField("trip_id", LongType(), True), - StructField("trip_distance", FloatType(), True), - StructField("fare_amount", DoubleType(), True), - StructField("store_and_fwd_flag", StringType(), True) -]) - -df = spark.createDataFrame([], schema) -df.writeTo("demo.nyc.taxis").create() -``` - -Finally, we write another data-frame to the table to add new files: - -```python -schema = spark.table("demo.nyc.taxis").schema -data = [ - (1, 1000371, 1.8, 15.32, "N"), - (2, 1000372, 2.5, 22.15, "N"), - (2, 1000373, 0.9, 9.01, "N"), - (1, 1000374, 8.4, 42.13, "Y") - ] -df = spark.createDataFrame(data, schema) -df.writeTo("demo.nyc.taxis").append() -``` - -#### Testing with the CLI - -Now that we have a table in the catalog which is running. You can use the -CLI which is implemented in the `cmd/iceberg` folder. 
You will need to set -the following environment variables (which can also be found in the -docker-compose.yml): - -``` -AWS_S3_ENDPOINT=http://localhost:9000 -AWS_REGION=us-east-1 -AWS_ACCESS_KEY_ID=admin -AWS_SECRET_ACCESS_KEY=password -``` - -With those environment variables set you can now run the CLI: - -```bash -$ go run ./cmd/iceberg list --catalog rest --uri http://localhost:8181 -┌──────┐ -| IDs | -| ---- | -| demo | -└──────┘ -``` - -You can retrieve the schema of the table: - -```bash -$ go run ./cmd/iceberg schema --catalog rest --uri http://localhost:8181 demo.nyc.taxis -Current Schema, id=0 -├──1: vendor_id: optional long -├──2: trip_id: optional long -├──3: trip_distance: optional float -├──4: fare_amount: optional double -└──5: store_and_fwd_flag: optional string -``` - -You can get the file list: - -```bash -$ go run ./cmd/iceberg files --catalog rest --uri http://localhost:8181 demo.nyc.taxis -Snapshots: rest.demo.nyc.taxis -└─┬Snapshot 7004656639550124164, schema 0: s3://warehouse/demo/nyc/taxis/metadata/snap-7004656639550124164-1-0d533cd4-f0c1-45a6-a691-f2be3abe5491.avro - └─┬Manifest: s3://warehouse/demo/nyc/taxis/metadata/0d533cd4-f0c1-45a6-a691-f2be3abe5491-m0.avro - ├──Datafile: s3://warehouse/demo/nyc/taxis/data/00004-24-244255d4-8bf6-41bd-8885-bf7d2136fddf-00001.parquet - ├──Datafile: s3://warehouse/demo/nyc/taxis/data/00009-29-244255d4-8bf6-41bd-8885-bf7d2136fddf-00001.parquet - ├──Datafile: s3://warehouse/demo/nyc/taxis/data/00014-34-244255d4-8bf6-41bd-8885-bf7d2136fddf-00001.parquet - └──Datafile: s3://warehouse/demo/nyc/taxis/data/00019-39-244255d4-8bf6-41bd-8885-bf7d2136fddf-00001.parquet -``` - + + +# Catalog Implementations + +## Integration Testing + +The Catalog implementations can be manually tested using the CLI implemented +in the `cmd/iceberg` folder. + +### REST Catalog + +To test the REST catalog implementation, we have a docker configuration +for a Minio container and tabluario/iceberg-rest container. 
+ +You can spin up the local catalog by going to the `dev/` folder and running +`docker-compose up`. You can then follow the steps of the Iceberg [Quickstart](https://iceberg.apache.org/spark-quickstart/#creating-a-table) +tutorial, which we've summarized below. + +#### Setup your Iceberg catalog + +First launch a pyspark console by running: + +```bash +docker exec -it spark-iceberg pyspark +``` + +Once in the pyspark shell, we create a simple table with a namespace of +"demo.nyc" called "taxis": + +```python +from pyspark.sql.types import DoubleType, FloatType, LongType, StructType,StructField, StringType +schema = StructType([ + StructField("vendor_id", LongType(), True), + StructField("trip_id", LongType(), True), + StructField("trip_distance", FloatType(), True), + StructField("fare_amount", DoubleType(), True), + StructField("store_and_fwd_flag", StringType(), True) +]) + +df = spark.createDataFrame([], schema) +df.writeTo("demo.nyc.taxis").create() +``` + +Finally, we write another data-frame to the table to add new files: + +```python +schema = spark.table("demo.nyc.taxis").schema +data = [ + (1, 1000371, 1.8, 15.32, "N"), + (2, 1000372, 2.5, 22.15, "N"), + (2, 1000373, 0.9, 9.01, "N"), + (1, 1000374, 8.4, 42.13, "Y") + ] +df = spark.createDataFrame(data, schema) +df.writeTo("demo.nyc.taxis").append() +``` + +#### Testing with the CLI + +Now that we have a table in the catalog which is running. You can use the +CLI which is implemented in the `cmd/iceberg` folder. 
You will need to set +the following environment variables (which can also be found in the +docker-compose.yml): + +``` +AWS_S3_ENDPOINT=http://localhost:9000 +AWS_REGION=us-east-1 +AWS_ACCESS_KEY_ID=admin +AWS_SECRET_ACCESS_KEY=password +``` + +With those environment variables set you can now run the CLI: + +```bash +$ go run ./cmd/iceberg list --catalog rest --uri http://localhost:8181 +┌──────┐ +| IDs | +| ---- | +| demo | +└──────┘ +``` + +You can retrieve the schema of the table: + +```bash +$ go run ./cmd/iceberg schema --catalog rest --uri http://localhost:8181 demo.nyc.taxis +Current Schema, id=0 +├──1: vendor_id: optional long +├──2: trip_id: optional long +├──3: trip_distance: optional float +├──4: fare_amount: optional double +└──5: store_and_fwd_flag: optional string +``` + +You can get the file list: + +```bash +$ go run ./cmd/iceberg files --catalog rest --uri http://localhost:8181 demo.nyc.taxis +Snapshots: rest.demo.nyc.taxis +└─┬Snapshot 7004656639550124164, schema 0: s3://warehouse/demo/nyc/taxis/metadata/snap-7004656639550124164-1-0d533cd4-f0c1-45a6-a691-f2be3abe5491.avro + └─┬Manifest: s3://warehouse/demo/nyc/taxis/metadata/0d533cd4-f0c1-45a6-a691-f2be3abe5491-m0.avro + ├──Datafile: s3://warehouse/demo/nyc/taxis/data/00004-24-244255d4-8bf6-41bd-8885-bf7d2136fddf-00001.parquet + ├──Datafile: s3://warehouse/demo/nyc/taxis/data/00009-29-244255d4-8bf6-41bd-8885-bf7d2136fddf-00001.parquet + ├──Datafile: s3://warehouse/demo/nyc/taxis/data/00014-34-244255d4-8bf6-41bd-8885-bf7d2136fddf-00001.parquet + └──Datafile: s3://warehouse/demo/nyc/taxis/data/00019-39-244255d4-8bf6-41bd-8885-bf7d2136fddf-00001.parquet +``` + and so on, for the various options available in the CLI. \ No newline at end of file diff --git a/catalog/catalog.go b/catalog/catalog.go index d6d7f1e..72d31be 100644 --- a/catalog/catalog.go +++ b/catalog/catalog.go @@ -1,187 +1,187 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package catalog - -import ( - "context" - "crypto/tls" - "errors" - "net/url" - - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/table" - "github.com/aws/aws-sdk-go-v2/aws" -) - -type CatalogType string - -const ( - REST CatalogType = "rest" - Hive CatalogType = "hive" - Glue CatalogType = "glue" - DynamoDB CatalogType = "dynamodb" - SQL CatalogType = "sql" -) - -var ( - // ErrNoSuchTable is returned when a table does not exist in the catalog. - ErrNoSuchTable = errors.New("table does not exist") - ErrNoSuchNamespace = errors.New("namespace does not exist") - ErrNamespaceAlreadyExists = errors.New("namespace already exists") -) - -// WithAwsConfig sets the AWS configuration for the catalog. 
-func WithAwsConfig(cfg aws.Config) Option[GlueCatalog] { - return func(o *options) { - o.awsConfig = cfg - } -} - -func WithCredential(cred string) Option[RestCatalog] { - return func(o *options) { - o.credential = cred - } -} - -func WithOAuthToken(token string) Option[RestCatalog] { - return func(o *options) { - o.oauthToken = token - } -} - -func WithTLSConfig(config *tls.Config) Option[RestCatalog] { - return func(o *options) { - o.tlsConfig = config - } -} - -func WithWarehouseLocation(loc string) Option[RestCatalog] { - return func(o *options) { - o.warehouseLocation = loc - } -} - -func WithMetadataLocation(loc string) Option[RestCatalog] { - return func(o *options) { - o.metadataLocation = loc - } -} - -func WithSigV4() Option[RestCatalog] { - return func(o *options) { - o.enableSigv4 = true - o.sigv4Service = "execute-api" - } -} - -func WithSigV4RegionSvc(region, service string) Option[RestCatalog] { - return func(o *options) { - o.enableSigv4 = true - o.sigv4Region = region - - if service == "" { - o.sigv4Service = "execute-api" - } else { - o.sigv4Service = service - } - } -} - -func WithAuthURI(uri *url.URL) Option[RestCatalog] { - return func(o *options) { - o.authUri = uri - } -} - -func WithPrefix(prefix string) Option[RestCatalog] { - return func(o *options) { - o.prefix = prefix - } -} - -type Option[T GlueCatalog | RestCatalog] func(*options) - -type options struct { - awsConfig aws.Config - - tlsConfig *tls.Config - credential string - oauthToken string - warehouseLocation string - metadataLocation string - enableSigv4 bool - sigv4Region string - sigv4Service string - prefix string - authUri *url.URL -} - -type PropertiesUpdateSummary struct { - Removed []string `json:"removed"` - Updated []string `json:"updated"` - Missing []string `json:"missing"` -} - -// Catalog for iceberg table operations like create, drop, load, list and others. -type Catalog interface { - // CatalogType returns the type of the catalog. 
- CatalogType() CatalogType - - // ListTables returns a list of table identifiers in the catalog, with the returned - // identifiers containing the information required to load the table via that catalog. - ListTables(ctx context.Context, namespace table.Identifier) ([]table.Identifier, error) - // LoadTable loads a table from the catalog and returns a Table with the metadata. - LoadTable(ctx context.Context, identifier table.Identifier, props iceberg.Properties) (*table.Table, error) - // DropTable tells the catalog to drop the table entirely - DropTable(ctx context.Context, identifier table.Identifier) error - // RenameTable tells the catalog to rename a given table by the identifiers - // provided, and then loads and returns the destination table - RenameTable(ctx context.Context, from, to table.Identifier) (*table.Table, error) - // ListNamespaces returns the list of available namespaces, optionally filtering by a - // parent namespace - ListNamespaces(ctx context.Context, parent table.Identifier) ([]table.Identifier, error) - // CreateNamespace tells the catalog to create a new namespace with the given properties - CreateNamespace(ctx context.Context, namespace table.Identifier, props iceberg.Properties) error - // DropNamespace tells the catalog to drop the namespace and all tables in that namespace - DropNamespace(ctx context.Context, namespace table.Identifier) error - // LoadNamespaceProperties returns the current properties in the catalog for - // a given namespace - LoadNamespaceProperties(ctx context.Context, namespace table.Identifier) (iceberg.Properties, error) - // UpdateNamespaceProperties allows removing, adding, and/or updating properties of a namespace - UpdateNamespaceProperties(ctx context.Context, namespace table.Identifier, - removals []string, updates iceberg.Properties) (PropertiesUpdateSummary, error) -} - -const ( - keyOauthToken = "token" - keyWarehouseLocation = "warehouse" - keyMetadataLocation = "metadata_location" - 
keyOauthCredential = "credential" -) - -func TableNameFromIdent(ident table.Identifier) string { - if len(ident) == 0 { - return "" - } - - return ident[len(ident)-1] -} - -func NamespaceFromIdent(ident table.Identifier) table.Identifier { - return ident[:len(ident)-1] -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package catalog + +import ( + "context" + "crypto/tls" + "errors" + "net/url" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/table" + "github.com/aws/aws-sdk-go-v2/aws" +) + +type CatalogType string + +const ( + REST CatalogType = "rest" + Hive CatalogType = "hive" + Glue CatalogType = "glue" + DynamoDB CatalogType = "dynamodb" + SQL CatalogType = "sql" +) + +var ( + // ErrNoSuchTable is returned when a table does not exist in the catalog. + ErrNoSuchTable = errors.New("table does not exist") + ErrNoSuchNamespace = errors.New("namespace does not exist") + ErrNamespaceAlreadyExists = errors.New("namespace already exists") +) + +// WithAwsConfig sets the AWS configuration for the catalog. 
+func WithAwsConfig(cfg aws.Config) Option[GlueCatalog] { + return func(o *options) { + o.awsConfig = cfg + } +} + +func WithCredential(cred string) Option[RestCatalog] { + return func(o *options) { + o.credential = cred + } +} + +func WithOAuthToken(token string) Option[RestCatalog] { + return func(o *options) { + o.oauthToken = token + } +} + +func WithTLSConfig(config *tls.Config) Option[RestCatalog] { + return func(o *options) { + o.tlsConfig = config + } +} + +func WithWarehouseLocation(loc string) Option[RestCatalog] { + return func(o *options) { + o.warehouseLocation = loc + } +} + +func WithMetadataLocation(loc string) Option[RestCatalog] { + return func(o *options) { + o.metadataLocation = loc + } +} + +func WithSigV4() Option[RestCatalog] { + return func(o *options) { + o.enableSigv4 = true + o.sigv4Service = "execute-api" + } +} + +func WithSigV4RegionSvc(region, service string) Option[RestCatalog] { + return func(o *options) { + o.enableSigv4 = true + o.sigv4Region = region + + if service == "" { + o.sigv4Service = "execute-api" + } else { + o.sigv4Service = service + } + } +} + +func WithAuthURI(uri *url.URL) Option[RestCatalog] { + return func(o *options) { + o.authUri = uri + } +} + +func WithPrefix(prefix string) Option[RestCatalog] { + return func(o *options) { + o.prefix = prefix + } +} + +type Option[T GlueCatalog | RestCatalog] func(*options) + +type options struct { + awsConfig aws.Config + + tlsConfig *tls.Config + credential string + oauthToken string + warehouseLocation string + metadataLocation string + enableSigv4 bool + sigv4Region string + sigv4Service string + prefix string + authUri *url.URL +} + +type PropertiesUpdateSummary struct { + Removed []string `json:"removed"` + Updated []string `json:"updated"` + Missing []string `json:"missing"` +} + +// Catalog for iceberg table operations like create, drop, load, list and others. +type Catalog interface { + // CatalogType returns the type of the catalog. 
+ CatalogType() CatalogType + + // ListTables returns a list of table identifiers in the catalog, with the returned + // identifiers containing the information required to load the table via that catalog. + ListTables(ctx context.Context, namespace table.Identifier) ([]table.Identifier, error) + // LoadTable loads a table from the catalog and returns a Table with the metadata. + LoadTable(ctx context.Context, identifier table.Identifier, props iceberg.Properties) (*table.Table, error) + // DropTable tells the catalog to drop the table entirely + DropTable(ctx context.Context, identifier table.Identifier) error + // RenameTable tells the catalog to rename a given table by the identifiers + // provided, and then loads and returns the destination table + RenameTable(ctx context.Context, from, to table.Identifier) (*table.Table, error) + // ListNamespaces returns the list of available namespaces, optionally filtering by a + // parent namespace + ListNamespaces(ctx context.Context, parent table.Identifier) ([]table.Identifier, error) + // CreateNamespace tells the catalog to create a new namespace with the given properties + CreateNamespace(ctx context.Context, namespace table.Identifier, props iceberg.Properties) error + // DropNamespace tells the catalog to drop the namespace and all tables in that namespace + DropNamespace(ctx context.Context, namespace table.Identifier) error + // LoadNamespaceProperties returns the current properties in the catalog for + // a given namespace + LoadNamespaceProperties(ctx context.Context, namespace table.Identifier) (iceberg.Properties, error) + // UpdateNamespaceProperties allows removing, adding, and/or updating properties of a namespace + UpdateNamespaceProperties(ctx context.Context, namespace table.Identifier, + removals []string, updates iceberg.Properties) (PropertiesUpdateSummary, error) +} + +const ( + keyOauthToken = "token" + keyWarehouseLocation = "warehouse" + keyMetadataLocation = "metadata_location" + 
keyOauthCredential = "credential" +) + +func TableNameFromIdent(ident table.Identifier) string { + if len(ident) == 0 { + return "" + } + + return ident[len(ident)-1] +} + +func NamespaceFromIdent(ident table.Identifier) table.Identifier { + return ident[:len(ident)-1] +} diff --git a/catalog/glue.go b/catalog/glue.go index 91b21ff..f6b0cec 100644 --- a/catalog/glue.go +++ b/catalog/glue.go @@ -1,255 +1,255 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package catalog - -import ( - "context" - "errors" - "fmt" - - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/io" - "github.com/apache/iceberg-go/table" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/glue" - "github.com/aws/aws-sdk-go-v2/service/glue/types" -) - -const glueTypeIceberg = "ICEBERG" - -var ( - _ Catalog = (*GlueCatalog)(nil) -) - -type glueAPI interface { - GetTable(ctx context.Context, params *glue.GetTableInput, optFns ...func(*glue.Options)) (*glue.GetTableOutput, error) - GetTables(ctx context.Context, params *glue.GetTablesInput, optFns ...func(*glue.Options)) (*glue.GetTablesOutput, error) - GetDatabases(ctx context.Context, params *glue.GetDatabasesInput, optFns ...func(*glue.Options)) (*glue.GetDatabasesOutput, error) -} - -type GlueCatalog struct { - glueSvc glueAPI -} - -func NewGlueCatalog(opts ...Option[GlueCatalog]) *GlueCatalog { - glueOps := &options{} - - for _, o := range opts { - o(glueOps) - } - - return &GlueCatalog{ - glueSvc: glue.NewFromConfig(glueOps.awsConfig), - } -} - -// ListTables returns a list of iceberg tables in the given Glue database. -// -// The namespace should just contain the Glue database name. -func (c *GlueCatalog) ListTables(ctx context.Context, namespace table.Identifier) ([]table.Identifier, error) { - database, err := identifierToGlueDatabase(namespace) - if err != nil { - return nil, err - } - - params := &glue.GetTablesInput{DatabaseName: aws.String(database)} - - var icebergTables []table.Identifier - - for { - tblsRes, err := c.glueSvc.GetTables(ctx, params) - if err != nil { - return nil, fmt.Errorf("failed to list tables in namespace %s: %w", database, err) - } - - icebergTables = append(icebergTables, - filterTableListByType(database, tblsRes.TableList, glueTypeIceberg)...) 
- - if tblsRes.NextToken == nil { - break - } - - params.NextToken = tblsRes.NextToken - } - - return icebergTables, nil -} - -// LoadTable loads a table from the catalog table details. -// -// The identifier should contain the Glue database name, then glue table name. -func (c *GlueCatalog) LoadTable(ctx context.Context, identifier table.Identifier, props iceberg.Properties) (*table.Table, error) { - database, tableName, err := identifierToGlueTable(identifier) - if err != nil { - return nil, err - } - - if props == nil { - props = map[string]string{} - } - - location, err := c.getTable(ctx, database, tableName) - if err != nil { - return nil, err - } - - // TODO: consider providing a way to directly access the S3 iofs to enable testing of the catalog. - iofs, err := io.LoadFS(props, location) - if err != nil { - return nil, fmt.Errorf("failed to load table %s.%s: %w", database, tableName, err) - } - - icebergTable, err := table.NewFromLocation([]string{tableName}, location, iofs) - if err != nil { - return nil, fmt.Errorf("failed to create table from location %s.%s: %w", database, tableName, err) - } - - return icebergTable, nil -} - -func (c *GlueCatalog) CatalogType() CatalogType { - return Glue -} - -func (c *GlueCatalog) DropTable(ctx context.Context, identifier table.Identifier) error { - return fmt.Errorf("%w: [Glue Catalog] drop table", iceberg.ErrNotImplemented) -} - -func (c *GlueCatalog) RenameTable(ctx context.Context, from, to table.Identifier) (*table.Table, error) { - return nil, fmt.Errorf("%w: [Glue Catalog] rename table", iceberg.ErrNotImplemented) -} - -func (c *GlueCatalog) CreateNamespace(ctx context.Context, namespace table.Identifier, props iceberg.Properties) error { - return fmt.Errorf("%w: [Glue Catalog] create namespace", iceberg.ErrNotImplemented) -} - -func (c *GlueCatalog) DropNamespace(ctx context.Context, namespace table.Identifier) error { - return fmt.Errorf("%w: [Glue Catalog] drop namespace", iceberg.ErrNotImplemented) -} - 
-func (c *GlueCatalog) LoadNamespaceProperties(ctx context.Context, namespace table.Identifier) (iceberg.Properties, error) { - return nil, fmt.Errorf("%w: [Glue Catalog] load namespace properties", iceberg.ErrNotImplemented) -} - -func (c *GlueCatalog) UpdateNamespaceProperties(ctx context.Context, namespace table.Identifier, - removals []string, updates iceberg.Properties) (PropertiesUpdateSummary, error) { - return PropertiesUpdateSummary{}, fmt.Errorf("%w: [Glue Catalog] update namespace properties", iceberg.ErrNotImplemented) -} - -// ListNamespaces returns a list of Iceberg namespaces from the given Glue catalog. -func (c *GlueCatalog) ListNamespaces(ctx context.Context, parent table.Identifier) ([]table.Identifier, error) { - params := &glue.GetDatabasesInput{} - - if parent != nil { - return nil, fmt.Errorf("hierarchical namespace is not supported") - } - - var icebergNamespaces []table.Identifier - - for { - databasesResp, err := c.glueSvc.GetDatabases(ctx, params) - if err != nil { - return nil, fmt.Errorf("failed to list databases: %w", err) - } - - icebergNamespaces = append(icebergNamespaces, - filterDatabaseListByType(databasesResp.DatabaseList, glueTypeIceberg)...) - - if databasesResp.NextToken == nil { - break - } - - params.NextToken = databasesResp.NextToken - } - - return icebergNamespaces, nil -} - -// GetTable loads a table from the Glue Catalog using the given database and table name. 
-func (c *GlueCatalog) getTable(ctx context.Context, database, tableName string) (string, error) { - tblRes, err := c.glueSvc.GetTable(ctx, - &glue.GetTableInput{ - DatabaseName: aws.String(database), - Name: aws.String(tableName), - }, - ) - if err != nil { - if errors.Is(err, &types.EntityNotFoundException{}) { - return "", fmt.Errorf("failed to get table %s.%s: %w", database, tableName, ErrNoSuchTable) - } - return "", fmt.Errorf("failed to get table %s.%s: %w", database, tableName, err) - } - - if tblRes.Table.Parameters["table_type"] != "ICEBERG" { - return "", errors.New("table is not an iceberg table") - } - - return tblRes.Table.Parameters["metadata_location"], nil -} - -func identifierToGlueTable(identifier table.Identifier) (string, string, error) { - if len(identifier) != 2 { - return "", "", fmt.Errorf("invalid identifier, missing database name: %v", identifier) - } - - return identifier[0], identifier[1], nil -} - -func identifierToGlueDatabase(identifier table.Identifier) (string, error) { - if len(identifier) != 1 { - return "", fmt.Errorf("invalid identifier, missing database name: %v", identifier) - } - - return identifier[0], nil -} - -// GlueTableIdentifier returns a glue table identifier for an iceberg table in the format [database, table]. -func GlueTableIdentifier(database string, tableName string) table.Identifier { - return []string{database, tableName} -} - -// GlueDatabaseIdentifier returns a database identifier for a Glue database in the format [database]. 
-func GlueDatabaseIdentifier(database string) table.Identifier { - return []string{database} -} - -func filterTableListByType(database string, tableList []types.Table, tableType string) []table.Identifier { - var filtered []table.Identifier - - for _, tbl := range tableList { - if tbl.Parameters["table_type"] != tableType { - continue - } - filtered = append(filtered, GlueTableIdentifier(database, aws.ToString(tbl.Name))) - } - - return filtered -} - -func filterDatabaseListByType(databases []types.Database, databaseType string) []table.Identifier { - var filtered []table.Identifier - - for _, database := range databases { - if database.Parameters["database_type"] != databaseType { - continue - } - filtered = append(filtered, GlueDatabaseIdentifier(aws.ToString(database.Name))) - } - - return filtered -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package catalog + +import ( + "context" + "errors" + "fmt" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/io" + "github.com/apache/iceberg-go/table" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/glue" + "github.com/aws/aws-sdk-go-v2/service/glue/types" +) + +const glueTypeIceberg = "ICEBERG" + +var ( + _ Catalog = (*GlueCatalog)(nil) +) + +type glueAPI interface { + GetTable(ctx context.Context, params *glue.GetTableInput, optFns ...func(*glue.Options)) (*glue.GetTableOutput, error) + GetTables(ctx context.Context, params *glue.GetTablesInput, optFns ...func(*glue.Options)) (*glue.GetTablesOutput, error) + GetDatabases(ctx context.Context, params *glue.GetDatabasesInput, optFns ...func(*glue.Options)) (*glue.GetDatabasesOutput, error) +} + +type GlueCatalog struct { + glueSvc glueAPI +} + +func NewGlueCatalog(opts ...Option[GlueCatalog]) *GlueCatalog { + glueOps := &options{} + + for _, o := range opts { + o(glueOps) + } + + return &GlueCatalog{ + glueSvc: glue.NewFromConfig(glueOps.awsConfig), + } +} + +// ListTables returns a list of iceberg tables in the given Glue database. +// +// The namespace should just contain the Glue database name. +func (c *GlueCatalog) ListTables(ctx context.Context, namespace table.Identifier) ([]table.Identifier, error) { + database, err := identifierToGlueDatabase(namespace) + if err != nil { + return nil, err + } + + params := &glue.GetTablesInput{DatabaseName: aws.String(database)} + + var icebergTables []table.Identifier + + for { + tblsRes, err := c.glueSvc.GetTables(ctx, params) + if err != nil { + return nil, fmt.Errorf("failed to list tables in namespace %s: %w", database, err) + } + + icebergTables = append(icebergTables, + filterTableListByType(database, tblsRes.TableList, glueTypeIceberg)...) 
+ + if tblsRes.NextToken == nil { + break + } + + params.NextToken = tblsRes.NextToken + } + + return icebergTables, nil +} + +// LoadTable loads a table from the catalog table details. +// +// The identifier should contain the Glue database name, then glue table name. +func (c *GlueCatalog) LoadTable(ctx context.Context, identifier table.Identifier, props iceberg.Properties) (*table.Table, error) { + database, tableName, err := identifierToGlueTable(identifier) + if err != nil { + return nil, err + } + + if props == nil { + props = map[string]string{} + } + + location, err := c.getTable(ctx, database, tableName) + if err != nil { + return nil, err + } + + // TODO: consider providing a way to directly access the S3 iofs to enable testing of the catalog. + iofs, err := io.LoadFS(props, location) + if err != nil { + return nil, fmt.Errorf("failed to load table %s.%s: %w", database, tableName, err) + } + + icebergTable, err := table.NewFromLocation([]string{tableName}, location, iofs) + if err != nil { + return nil, fmt.Errorf("failed to create table from location %s.%s: %w", database, tableName, err) + } + + return icebergTable, nil +} + +func (c *GlueCatalog) CatalogType() CatalogType { + return Glue +} + +func (c *GlueCatalog) DropTable(ctx context.Context, identifier table.Identifier) error { + return fmt.Errorf("%w: [Glue Catalog] drop table", iceberg.ErrNotImplemented) +} + +func (c *GlueCatalog) RenameTable(ctx context.Context, from, to table.Identifier) (*table.Table, error) { + return nil, fmt.Errorf("%w: [Glue Catalog] rename table", iceberg.ErrNotImplemented) +} + +func (c *GlueCatalog) CreateNamespace(ctx context.Context, namespace table.Identifier, props iceberg.Properties) error { + return fmt.Errorf("%w: [Glue Catalog] create namespace", iceberg.ErrNotImplemented) +} + +func (c *GlueCatalog) DropNamespace(ctx context.Context, namespace table.Identifier) error { + return fmt.Errorf("%w: [Glue Catalog] drop namespace", iceberg.ErrNotImplemented) +} + 
+func (c *GlueCatalog) LoadNamespaceProperties(ctx context.Context, namespace table.Identifier) (iceberg.Properties, error) { + return nil, fmt.Errorf("%w: [Glue Catalog] load namespace properties", iceberg.ErrNotImplemented) +} + +func (c *GlueCatalog) UpdateNamespaceProperties(ctx context.Context, namespace table.Identifier, + removals []string, updates iceberg.Properties) (PropertiesUpdateSummary, error) { + return PropertiesUpdateSummary{}, fmt.Errorf("%w: [Glue Catalog] update namespace properties", iceberg.ErrNotImplemented) +} + +// ListNamespaces returns a list of Iceberg namespaces from the given Glue catalog. +func (c *GlueCatalog) ListNamespaces(ctx context.Context, parent table.Identifier) ([]table.Identifier, error) { + params := &glue.GetDatabasesInput{} + + if parent != nil { + return nil, fmt.Errorf("hierarchical namespace is not supported") + } + + var icebergNamespaces []table.Identifier + + for { + databasesResp, err := c.glueSvc.GetDatabases(ctx, params) + if err != nil { + return nil, fmt.Errorf("failed to list databases: %w", err) + } + + icebergNamespaces = append(icebergNamespaces, + filterDatabaseListByType(databasesResp.DatabaseList, glueTypeIceberg)...) + + if databasesResp.NextToken == nil { + break + } + + params.NextToken = databasesResp.NextToken + } + + return icebergNamespaces, nil +} + +// GetTable loads a table from the Glue Catalog using the given database and table name. 
+func (c *GlueCatalog) getTable(ctx context.Context, database, tableName string) (string, error) { + tblRes, err := c.glueSvc.GetTable(ctx, + &glue.GetTableInput{ + DatabaseName: aws.String(database), + Name: aws.String(tableName), + }, + ) + if err != nil { + if errors.Is(err, &types.EntityNotFoundException{}) { + return "", fmt.Errorf("failed to get table %s.%s: %w", database, tableName, ErrNoSuchTable) + } + return "", fmt.Errorf("failed to get table %s.%s: %w", database, tableName, err) + } + + if tblRes.Table.Parameters["table_type"] != "ICEBERG" { + return "", errors.New("table is not an iceberg table") + } + + return tblRes.Table.Parameters["metadata_location"], nil +} + +func identifierToGlueTable(identifier table.Identifier) (string, string, error) { + if len(identifier) != 2 { + return "", "", fmt.Errorf("invalid identifier, missing database name: %v", identifier) + } + + return identifier[0], identifier[1], nil +} + +func identifierToGlueDatabase(identifier table.Identifier) (string, error) { + if len(identifier) != 1 { + return "", fmt.Errorf("invalid identifier, missing database name: %v", identifier) + } + + return identifier[0], nil +} + +// GlueTableIdentifier returns a glue table identifier for an iceberg table in the format [database, table]. +func GlueTableIdentifier(database string, tableName string) table.Identifier { + return []string{database, tableName} +} + +// GlueDatabaseIdentifier returns a database identifier for a Glue database in the format [database]. 
+func GlueDatabaseIdentifier(database string) table.Identifier { + return []string{database} +} + +func filterTableListByType(database string, tableList []types.Table, tableType string) []table.Identifier { + var filtered []table.Identifier + + for _, tbl := range tableList { + if tbl.Parameters["table_type"] != tableType { + continue + } + filtered = append(filtered, GlueTableIdentifier(database, aws.ToString(tbl.Name))) + } + + return filtered +} + +func filterDatabaseListByType(databases []types.Database, databaseType string) []table.Identifier { + var filtered []table.Identifier + + for _, database := range databases { + if database.Parameters["database_type"] != databaseType { + continue + } + filtered = append(filtered, GlueDatabaseIdentifier(aws.ToString(database.Name))) + } + + return filtered +} diff --git a/catalog/glue_test.go b/catalog/glue_test.go index 5889537..49410a2 100644 --- a/catalog/glue_test.go +++ b/catalog/glue_test.go @@ -1,199 +1,199 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package catalog - -import ( - "context" - "os" - "testing" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/glue" - "github.com/aws/aws-sdk-go-v2/service/glue/types" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -type mockGlueClient struct { - mock.Mock -} - -func (m *mockGlueClient) GetTable(ctx context.Context, params *glue.GetTableInput, optFns ...func(*glue.Options)) (*glue.GetTableOutput, error) { - args := m.Called(ctx, params, optFns) - return args.Get(0).(*glue.GetTableOutput), args.Error(1) -} - -func (m *mockGlueClient) GetTables(ctx context.Context, params *glue.GetTablesInput, optFns ...func(*glue.Options)) (*glue.GetTablesOutput, error) { - args := m.Called(ctx, params, optFns) - return args.Get(0).(*glue.GetTablesOutput), args.Error(1) -} - -func (m *mockGlueClient) GetDatabases(ctx context.Context, params *glue.GetDatabasesInput, optFns ...func(*glue.Options)) (*glue.GetDatabasesOutput, error) { - args := m.Called(ctx, params, optFns) - return args.Get(0).(*glue.GetDatabasesOutput), args.Error(1) -} - -func TestGlueGetTable(t *testing.T) { - assert := require.New(t) - - mockGlueSvc := &mockGlueClient{} - - mockGlueSvc.On("GetTable", mock.Anything, &glue.GetTableInput{ - DatabaseName: aws.String("test_database"), - Name: aws.String("test_table"), - }, mock.Anything).Return(&glue.GetTableOutput{ - Table: &types.Table{ - Parameters: map[string]string{ - "table_type": "ICEBERG", - "metadata_location": "s3://test-bucket/test_table/metadata/abc123-123.metadata.json", - }, - }, - }, nil) - - glueCatalog := &GlueCatalog{ - glueSvc: mockGlueSvc, - } - - location, err := glueCatalog.getTable(context.TODO(), "test_database", "test_table") - assert.NoError(err) - assert.Equal("s3://test-bucket/test_table/metadata/abc123-123.metadata.json", location) -} - -func TestGlueListTables(t *testing.T) { - assert := require.New(t) - - mockGlueSvc := 
&mockGlueClient{} - - mockGlueSvc.On("GetTables", mock.Anything, &glue.GetTablesInput{ - DatabaseName: aws.String("test_database"), - }, mock.Anything).Return(&glue.GetTablesOutput{ - TableList: []types.Table{ - { - Name: aws.String("test_table"), - Parameters: map[string]string{ - "table_type": "ICEBERG", - "metadata_location": "s3://test-bucket/test_table/metadata/abc123-123.metadata.json", - }, - }, - { - Name: aws.String("other_table"), - Parameters: map[string]string{ - "metadata_location": "s3://test-bucket/other_table/", - }, - }, - }, - }, nil).Once() - - glueCatalog := &GlueCatalog{ - glueSvc: mockGlueSvc, - } - - tables, err := glueCatalog.ListTables(context.TODO(), GlueDatabaseIdentifier("test_database")) - assert.NoError(err) - assert.Len(tables, 1) - assert.Equal([]string{"test_database", "test_table"}, tables[0]) -} - -func TestGlueListNamespaces(t *testing.T) { - assert := require.New(t) - - mockGlueSvc := &mockGlueClient{} - - mockGlueSvc.On("GetDatabases", mock.Anything, &glue.GetDatabasesInput{}, mock.Anything).Return(&glue.GetDatabasesOutput{ - DatabaseList: []types.Database{ - { - Name: aws.String("test_database"), - Parameters: map[string]string{ - "database_type": "ICEBERG", - }, - }, - { - Name: aws.String("other_database"), - Parameters: map[string]string{}, - }, - }, - }, nil).Once() - - glueCatalog := &GlueCatalog{ - glueSvc: mockGlueSvc, - } - - databases, err := glueCatalog.ListNamespaces(context.TODO(), nil) - assert.NoError(err) - assert.Len(databases, 1) - assert.Equal([]string{"test_database"}, databases[0]) -} - -func TestGlueListTablesIntegration(t *testing.T) { - if os.Getenv("TEST_DATABASE_NAME") == "" { - t.Skip() - } - if os.Getenv("TEST_TABLE_NAME") == "" { - t.Skip() - } - assert := require.New(t) - - awscfg, err := config.LoadDefaultConfig(context.TODO(), config.WithClientLogMode(aws.LogRequest|aws.LogResponse)) - assert.NoError(err) - - catalog := NewGlueCatalog(WithAwsConfig(awscfg)) - - tables, err := 
catalog.ListTables(context.TODO(), GlueDatabaseIdentifier(os.Getenv("TEST_DATABASE_NAME"))) - assert.NoError(err) - assert.Equal([]string{os.Getenv("TEST_DATABASE_NAME"), os.Getenv("TEST_TABLE_NAME")}, tables[1]) -} - -func TestGlueLoadTableIntegration(t *testing.T) { - if os.Getenv("TEST_DATABASE_NAME") == "" { - t.Skip() - } - if os.Getenv("TEST_TABLE_NAME") == "" { - t.Skip() - } - if os.Getenv("TEST_TABLE_LOCATION") == "" { - t.Skip() - } - - assert := require.New(t) - - awscfg, err := config.LoadDefaultConfig(context.TODO(), config.WithClientLogMode(aws.LogRequest|aws.LogResponse)) - assert.NoError(err) - - catalog := NewGlueCatalog(WithAwsConfig(awscfg)) - - table, err := catalog.LoadTable(context.TODO(), []string{os.Getenv("TEST_DATABASE_NAME"), os.Getenv("TEST_TABLE_NAME")}, nil) - assert.NoError(err) - assert.Equal([]string{os.Getenv("TEST_TABLE_NAME")}, table.Identifier()) -} - -func TestGlueListNamespacesIntegration(t *testing.T) { - if os.Getenv("TEST_DATABASE_NAME") == "" { - t.Skip() - } - assert := require.New(t) - - awscfg, err := config.LoadDefaultConfig(context.TODO(), config.WithClientLogMode(aws.LogRequest|aws.LogResponse)) - assert.NoError(err) - - catalog := NewGlueCatalog(WithAwsConfig(awscfg)) - - namespaces, err := catalog.ListNamespaces(context.TODO(), nil) - assert.NoError(err) - assert.Contains(namespaces, []string{os.Getenv("TEST_DATABASE_NAME")}) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package catalog + +import ( + "context" + "os" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/glue" + "github.com/aws/aws-sdk-go-v2/service/glue/types" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockGlueClient struct { + mock.Mock +} + +func (m *mockGlueClient) GetTable(ctx context.Context, params *glue.GetTableInput, optFns ...func(*glue.Options)) (*glue.GetTableOutput, error) { + args := m.Called(ctx, params, optFns) + return args.Get(0).(*glue.GetTableOutput), args.Error(1) +} + +func (m *mockGlueClient) GetTables(ctx context.Context, params *glue.GetTablesInput, optFns ...func(*glue.Options)) (*glue.GetTablesOutput, error) { + args := m.Called(ctx, params, optFns) + return args.Get(0).(*glue.GetTablesOutput), args.Error(1) +} + +func (m *mockGlueClient) GetDatabases(ctx context.Context, params *glue.GetDatabasesInput, optFns ...func(*glue.Options)) (*glue.GetDatabasesOutput, error) { + args := m.Called(ctx, params, optFns) + return args.Get(0).(*glue.GetDatabasesOutput), args.Error(1) +} + +func TestGlueGetTable(t *testing.T) { + assert := require.New(t) + + mockGlueSvc := &mockGlueClient{} + + mockGlueSvc.On("GetTable", mock.Anything, &glue.GetTableInput{ + DatabaseName: aws.String("test_database"), + Name: aws.String("test_table"), + }, mock.Anything).Return(&glue.GetTableOutput{ + Table: &types.Table{ + Parameters: map[string]string{ + "table_type": "ICEBERG", + "metadata_location": 
"s3://test-bucket/test_table/metadata/abc123-123.metadata.json", + }, + }, + }, nil) + + glueCatalog := &GlueCatalog{ + glueSvc: mockGlueSvc, + } + + location, err := glueCatalog.getTable(context.TODO(), "test_database", "test_table") + assert.NoError(err) + assert.Equal("s3://test-bucket/test_table/metadata/abc123-123.metadata.json", location) +} + +func TestGlueListTables(t *testing.T) { + assert := require.New(t) + + mockGlueSvc := &mockGlueClient{} + + mockGlueSvc.On("GetTables", mock.Anything, &glue.GetTablesInput{ + DatabaseName: aws.String("test_database"), + }, mock.Anything).Return(&glue.GetTablesOutput{ + TableList: []types.Table{ + { + Name: aws.String("test_table"), + Parameters: map[string]string{ + "table_type": "ICEBERG", + "metadata_location": "s3://test-bucket/test_table/metadata/abc123-123.metadata.json", + }, + }, + { + Name: aws.String("other_table"), + Parameters: map[string]string{ + "metadata_location": "s3://test-bucket/other_table/", + }, + }, + }, + }, nil).Once() + + glueCatalog := &GlueCatalog{ + glueSvc: mockGlueSvc, + } + + tables, err := glueCatalog.ListTables(context.TODO(), GlueDatabaseIdentifier("test_database")) + assert.NoError(err) + assert.Len(tables, 1) + assert.Equal([]string{"test_database", "test_table"}, tables[0]) +} + +func TestGlueListNamespaces(t *testing.T) { + assert := require.New(t) + + mockGlueSvc := &mockGlueClient{} + + mockGlueSvc.On("GetDatabases", mock.Anything, &glue.GetDatabasesInput{}, mock.Anything).Return(&glue.GetDatabasesOutput{ + DatabaseList: []types.Database{ + { + Name: aws.String("test_database"), + Parameters: map[string]string{ + "database_type": "ICEBERG", + }, + }, + { + Name: aws.String("other_database"), + Parameters: map[string]string{}, + }, + }, + }, nil).Once() + + glueCatalog := &GlueCatalog{ + glueSvc: mockGlueSvc, + } + + databases, err := glueCatalog.ListNamespaces(context.TODO(), nil) + assert.NoError(err) + assert.Len(databases, 1) + assert.Equal([]string{"test_database"}, 
databases[0]) +} + +func TestGlueListTablesIntegration(t *testing.T) { + if os.Getenv("TEST_DATABASE_NAME") == "" { + t.Skip() + } + if os.Getenv("TEST_TABLE_NAME") == "" { + t.Skip() + } + assert := require.New(t) + + awscfg, err := config.LoadDefaultConfig(context.TODO(), config.WithClientLogMode(aws.LogRequest|aws.LogResponse)) + assert.NoError(err) + + catalog := NewGlueCatalog(WithAwsConfig(awscfg)) + + tables, err := catalog.ListTables(context.TODO(), GlueDatabaseIdentifier(os.Getenv("TEST_DATABASE_NAME"))) + assert.NoError(err) + assert.Equal([]string{os.Getenv("TEST_DATABASE_NAME"), os.Getenv("TEST_TABLE_NAME")}, tables[1]) +} + +func TestGlueLoadTableIntegration(t *testing.T) { + if os.Getenv("TEST_DATABASE_NAME") == "" { + t.Skip() + } + if os.Getenv("TEST_TABLE_NAME") == "" { + t.Skip() + } + if os.Getenv("TEST_TABLE_LOCATION") == "" { + t.Skip() + } + + assert := require.New(t) + + awscfg, err := config.LoadDefaultConfig(context.TODO(), config.WithClientLogMode(aws.LogRequest|aws.LogResponse)) + assert.NoError(err) + + catalog := NewGlueCatalog(WithAwsConfig(awscfg)) + + table, err := catalog.LoadTable(context.TODO(), []string{os.Getenv("TEST_DATABASE_NAME"), os.Getenv("TEST_TABLE_NAME")}, nil) + assert.NoError(err) + assert.Equal([]string{os.Getenv("TEST_TABLE_NAME")}, table.Identifier()) +} + +func TestGlueListNamespacesIntegration(t *testing.T) { + if os.Getenv("TEST_DATABASE_NAME") == "" { + t.Skip() + } + assert := require.New(t) + + awscfg, err := config.LoadDefaultConfig(context.TODO(), config.WithClientLogMode(aws.LogRequest|aws.LogResponse)) + assert.NoError(err) + + catalog := NewGlueCatalog(WithAwsConfig(awscfg)) + + namespaces, err := catalog.ListNamespaces(context.TODO(), nil) + assert.NoError(err) + assert.Contains(namespaces, []string{os.Getenv("TEST_DATABASE_NAME")}) +} diff --git a/catalog/rest.go b/catalog/rest.go index ef9c332..6642aa4 100644 --- a/catalog/rest.go +++ b/catalog/rest.go @@ -1,712 +1,712 @@ -// Licensed to the Apache 
Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package catalog - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/json" - "errors" - "fmt" - "hash" - "io" - "maps" - "net/http" - "net/url" - "strings" - "time" - - "github.com/apache/iceberg-go" - iceio "github.com/apache/iceberg-go/io" - "github.com/apache/iceberg-go/table" - "github.com/aws/aws-sdk-go-v2/aws" - v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" - "github.com/aws/aws-sdk-go-v2/config" -) - -var ( - _ Catalog = (*RestCatalog)(nil) -) - -const ( - authorizationHeader = "Authorization" - bearerPrefix = "Bearer" - namespaceSeparator = "\x1F" - keyPrefix = "prefix" - - icebergRestSpecVersion = "0.14.1" - - keyRestSigV4 = "rest.sigv4-enabled" - keyRestSigV4Region = "rest.signing-region" - keyRestSigV4Service = "rest.signing-name" - keyAuthUrl = "rest.authorization-url" -) - -var ( - ErrRESTError = errors.New("REST error") - ErrBadRequest = fmt.Errorf("%w: bad request", ErrRESTError) - ErrForbidden = fmt.Errorf("%w: forbidden", ErrRESTError) - ErrUnauthorized = fmt.Errorf("%w: unauthorized", ErrRESTError) - ErrAuthorizationExpired = fmt.Errorf("%w: authorization expired", ErrRESTError) - ErrServiceUnavailable = fmt.Errorf("%w: service unavailable", ErrRESTError) - 
ErrServerError = fmt.Errorf("%w: server error", ErrRESTError) - ErrCommitFailed = fmt.Errorf("%w: commit failed, refresh and try again", ErrRESTError) - ErrCommitStateUnknown = fmt.Errorf("%w: commit failed due to unknown reason", ErrRESTError) - ErrOAuthError = fmt.Errorf("%w: oauth error", ErrRESTError) -) - -type errorResponse struct { - Message string `json:"message"` - Type string `json:"type"` - Code int `json:"code"` - - wrapping error -} - -func (e errorResponse) Unwrap() error { return e.wrapping } -func (e errorResponse) Error() string { - return e.Type + ": " + e.Message -} - -type oauthTokenResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - Scope string `json:"scope"` - RefreshToken string `json:"refresh_token"` -} - -type oauthErrorResponse struct { - Err string `json:"error"` - ErrDesc string `json:"error_description"` - ErrURI string `json:"error_uri"` -} - -func (o oauthErrorResponse) Unwrap() error { return ErrOAuthError } -func (o oauthErrorResponse) Error() string { - msg := o.Err - if o.ErrDesc != "" { - msg += ": " + o.ErrDesc - } - - if o.ErrURI != "" { - msg += " (" + o.ErrURI + ")" - } - return msg -} - -type configResponse struct { - Defaults iceberg.Properties `json:"defaults"` - Overrides iceberg.Properties `json:"overrides"` -} - -type sessionTransport struct { - http.Transport - - defaultHeaders http.Header - signer v4.HTTPSigner - cfg aws.Config - service string - h hash.Hash -} - -// from https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws/signer/v4#Signer.SignHTTP -const emptyStringHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - -func (s *sessionTransport) RoundTrip(r *http.Request) (*http.Response, error) { - for k, v := range s.defaultHeaders { - for _, hdr := range v { - r.Header.Add(k, hdr) - } - } - - if s.signer != nil { - var payloadHash string - if r.Body == nil { - payloadHash = emptyStringHash - } else { - rdr, 
err := r.GetBody() - if err != nil { - return nil, err - } - - if _, err = io.Copy(s.h, rdr); err != nil { - return nil, err - } - - payloadHash = string(s.h.Sum(nil)) - s.h.Reset() - } - - creds, err := s.cfg.Credentials.Retrieve(r.Context()) - if err != nil { - return nil, err - } - - // modifies the request in place - err = s.signer.SignHTTP(r.Context(), creds, r, payloadHash, - s.service, s.cfg.Region, time.Now()) - if err != nil { - return nil, err - } - } - - return s.Transport.RoundTrip(r) -} - -func do[T any](ctx context.Context, method string, baseURI *url.URL, path []string, cl *http.Client, override map[int]error, allowNoContent bool) (ret T, err error) { - var ( - req *http.Request - rsp *http.Response - ) - - uri := baseURI.JoinPath(path...).String() - if req, err = http.NewRequestWithContext(ctx, method, uri, nil); err != nil { - return - } - - if rsp, err = cl.Do(req); err != nil { - return - } - - if allowNoContent && rsp.StatusCode == http.StatusNoContent { - return - } - - if rsp.StatusCode != http.StatusOK { - return ret, handleNon200(rsp, override) - } - - defer rsp.Body.Close() - if err = json.NewDecoder(rsp.Body).Decode(&ret); err != nil { - return ret, fmt.Errorf("%w: error decoding json payload: `%s`", ErrRESTError, err.Error()) - } - - return -} - -func doGet[T any](ctx context.Context, baseURI *url.URL, path []string, cl *http.Client, override map[int]error) (ret T, err error) { - return do[T](ctx, http.MethodGet, baseURI, path, cl, override, false) -} - -func doDelete[T any](ctx context.Context, baseURI *url.URL, path []string, cl *http.Client, override map[int]error) (ret T, err error) { - return do[T](ctx, http.MethodDelete, baseURI, path, cl, override, true) -} - -func doPost[Payload, Result any](ctx context.Context, baseURI *url.URL, path []string, payload Payload, cl *http.Client, override map[int]error) (ret Result, err error) { - var ( - req *http.Request - rsp *http.Response - data []byte - ) - - uri := 
baseURI.JoinPath(path...).String() - data, err = json.Marshal(payload) - if err != nil { - return - } - - req, err = http.NewRequestWithContext(ctx, http.MethodPost, uri, bytes.NewReader(data)) - if err != nil { - return - } - - req.Header.Add("Content-Type", "application/json") - - rsp, err = cl.Do(req) - if err != nil { - return - } - - if rsp.StatusCode != http.StatusOK { - return ret, handleNon200(rsp, override) - } - - if rsp.ContentLength == 0 { - return - } - - defer rsp.Body.Close() - if err = json.NewDecoder(rsp.Body).Decode(&ret); err != nil { - return ret, fmt.Errorf("%w: error decoding json payload: `%s`", ErrRESTError, err.Error()) - } - - return -} - -func handleNon200(rsp *http.Response, override map[int]error) error { - var e errorResponse - - dec := json.NewDecoder(rsp.Body) - dec.Decode(&struct { - Error *errorResponse `json:"error"` - }{Error: &e}) - - if override != nil { - if err, ok := override[rsp.StatusCode]; ok { - e.wrapping = err - return e - } - } - - switch rsp.StatusCode { - case http.StatusBadRequest: - e.wrapping = ErrBadRequest - case http.StatusUnauthorized: - e.wrapping = ErrUnauthorized - case http.StatusForbidden: - e.wrapping = ErrForbidden - case http.StatusUnprocessableEntity: - e.wrapping = ErrRESTError - case 419: - e.wrapping = ErrAuthorizationExpired - case http.StatusNotImplemented: - e.wrapping = iceberg.ErrNotImplemented - case http.StatusServiceUnavailable: - e.wrapping = ErrServiceUnavailable - default: - if 500 <= rsp.StatusCode && rsp.StatusCode < 600 { - e.wrapping = ErrServerError - } else { - e.wrapping = ErrRESTError - } - } - - return e -} - -func ToRestIdentifier(ident ...string) table.Identifier { - if len(ident) == 1 { - if ident[0] == "" { - return nil - } - return table.Identifier(strings.Split(ident[0], ".")) - } - - return table.Identifier(ident) -} - -func fromProps(props iceberg.Properties) *options { - o := &options{} - for k, v := range props { - switch k { - case keyOauthToken: - o.oauthToken = v - 
case keyWarehouseLocation: - o.warehouseLocation = v - case keyMetadataLocation: - o.metadataLocation = v - case keyRestSigV4: - o.enableSigv4 = strings.ToLower(v) == "true" - case keyRestSigV4Region: - o.sigv4Region = v - case keyRestSigV4Service: - o.sigv4Service = v - case keyAuthUrl: - u, err := url.Parse(v) - if err != nil { - continue - } - o.authUri = u - case keyOauthCredential: - o.credential = v - case keyPrefix: - o.prefix = v - } - } - return o -} - -func toProps(o *options) iceberg.Properties { - props := iceberg.Properties{} - - setIf := func(key, v string) { - if v != "" { - props[key] = v - } - } - - setIf(keyOauthCredential, o.credential) - setIf(keyOauthToken, o.oauthToken) - setIf(keyWarehouseLocation, o.warehouseLocation) - setIf(keyMetadataLocation, o.metadataLocation) - if o.enableSigv4 { - props[keyRestSigV4] = "true" - setIf(keyRestSigV4Region, o.sigv4Region) - setIf(keyRestSigV4Service, o.sigv4Service) - } - - setIf(keyPrefix, o.prefix) - if o.authUri != nil { - setIf(keyAuthUrl, o.authUri.String()) - } - return props -} - -type RestCatalog struct { - baseURI *url.URL - cl *http.Client - - name string - props iceberg.Properties -} - -func NewRestCatalog(name, uri string, opts ...Option[RestCatalog]) (*RestCatalog, error) { - ops := &options{} - for _, o := range opts { - o(ops) - } - - baseuri, err := url.Parse(uri) - if err != nil { - return nil, err - } - - r := &RestCatalog{ - name: name, - baseURI: baseuri.JoinPath("v1"), - } - - if ops, err = r.fetchConfig(ops); err != nil { - return nil, err - } - - cl, err := r.createSession(ops) - if err != nil { - return nil, err - } - - r.cl = cl - if ops.prefix != "" { - r.baseURI = r.baseURI.JoinPath(ops.prefix) - } - r.props = toProps(ops) - return r, nil -} - -func (r *RestCatalog) fetchAccessToken(cl *http.Client, creds string, opts *options) (string, error) { - clientID, clientSecret, hasID := strings.Cut(creds, ":") - if !hasID { - clientID, clientSecret = "", clientID - } - - data := 
url.Values{ - "grant_type": {"client_credentials"}, - "client_id": {clientID}, - "client_secret": {clientSecret}, - "scope": {"catalog"}, - } - - uri := opts.authUri - if uri == nil { - uri = r.baseURI.JoinPath("oauth/tokens") - } - - rsp, err := cl.PostForm(uri.String(), data) - if err != nil { - return "", err - } - - if rsp.StatusCode == http.StatusOK { - defer rsp.Body.Close() - dec := json.NewDecoder(rsp.Body) - var tok oauthTokenResponse - if err := dec.Decode(&tok); err != nil { - return "", fmt.Errorf("failed to decode oauth token response: %w", err) - } - - return tok.AccessToken, nil - } - - switch rsp.StatusCode { - case http.StatusUnauthorized, http.StatusBadRequest: - defer rsp.Request.GetBody() - dec := json.NewDecoder(rsp.Body) - var oauthErr oauthErrorResponse - if err := dec.Decode(&oauthErr); err != nil { - return "", fmt.Errorf("failed to decode oauth error: %w", err) - } - - return "", oauthErr - default: - return "", handleNon200(rsp, nil) - } -} - -func (r *RestCatalog) createSession(opts *options) (*http.Client, error) { - session := &sessionTransport{ - Transport: http.Transport{TLSClientConfig: opts.tlsConfig}, - defaultHeaders: http.Header{}, - } - cl := &http.Client{Transport: session} - - token := opts.oauthToken - if token == "" && opts.credential != "" { - var err error - if token, err = r.fetchAccessToken(cl, opts.credential, opts); err != nil { - return nil, fmt.Errorf("auth error: %w", err) - } - } - - if token != "" { - session.defaultHeaders.Set(authorizationHeader, bearerPrefix+" "+token) - } - - session.defaultHeaders.Set("X-Client-Version", icebergRestSpecVersion) - session.defaultHeaders.Set("Content-Type", "application/json") - session.defaultHeaders.Set("User-Agent", "GoIceberg/"+iceberg.Version()) - session.defaultHeaders.Set("X-Iceberg-Access-Delegation", "vended-credentials") - - if opts.enableSigv4 { - cfg, err := config.LoadDefaultConfig(context.Background()) - if err != nil { - return nil, err - } - - if 
opts.sigv4Region != "" { - cfg.Region = opts.sigv4Region - } - - session.cfg, session.service = cfg, opts.sigv4Service - session.signer, session.h = v4.NewSigner(), sha256.New() - } - - return cl, nil -} - -func (r *RestCatalog) fetchConfig(opts *options) (*options, error) { - params := url.Values{} - if opts.warehouseLocation != "" { - params.Set(keyWarehouseLocation, opts.warehouseLocation) - } - - route := r.baseURI.JoinPath("config") - route.RawQuery = params.Encode() - - sess, err := r.createSession(opts) - if err != nil { - return nil, err - } - - rsp, err := doGet[configResponse](context.Background(), route, []string{}, sess, nil) - if err != nil { - return nil, err - } - - cfg := rsp.Defaults - maps.Copy(cfg, toProps(opts)) - maps.Copy(cfg, rsp.Overrides) - - o := fromProps(cfg) - o.awsConfig = opts.awsConfig - o.tlsConfig = opts.tlsConfig - - if uri, ok := cfg["uri"]; ok { - r.baseURI, err = url.Parse(uri) - if err != nil { - return nil, err - } - r.baseURI = r.baseURI.JoinPath("v1") - } - - return o, nil -} - -func (r *RestCatalog) CatalogType() CatalogType { return REST } - -func checkValidNamespace(ident table.Identifier) error { - if len(ident) < 1 { - return fmt.Errorf("%w: empty namespace identifier", ErrNoSuchNamespace) - } - return nil -} - -func (r *RestCatalog) ListTables(ctx context.Context, namespace table.Identifier) ([]table.Identifier, error) { - if err := checkValidNamespace(namespace); err != nil { - return nil, err - } - - ns := strings.Join(namespace, namespaceSeparator) - path := []string{"namespaces", ns, "tables"} - - type resp struct { - Identifiers []struct { - Namespace []string `json:"namespace"` - Name string `json:"name"` - } `json:"identifiers"` - } - - rsp, err := doGet[resp](ctx, r.baseURI, path, r.cl, map[int]error{http.StatusNotFound: ErrNoSuchNamespace}) - if err != nil { - return nil, err - } - - out := make([]table.Identifier, len(rsp.Identifiers)) - for i, id := range rsp.Identifiers { - out[i] = append(id.Namespace, 
id.Name) - } - return out, nil -} - -func splitIdentForPath(ident table.Identifier) (string, string, error) { - if len(ident) < 1 { - return "", "", fmt.Errorf("%w: missing namespace or invalid identifier %v", - ErrNoSuchTable, strings.Join(ident, ".")) - } - - return strings.Join(NamespaceFromIdent(ident), namespaceSeparator), TableNameFromIdent(ident), nil -} - -type tblResponse struct { - MetadataLoc string `json:"metadata-location"` - RawMetadata json.RawMessage `json:"metadata"` - Config iceberg.Properties `json:"config"` - Metadata table.Metadata `json:"-"` -} - -func (t *tblResponse) UnmarshalJSON(b []byte) (err error) { - type Alias tblResponse - if err = json.Unmarshal(b, (*Alias)(t)); err != nil { - return err - } - - t.Metadata, err = table.ParseMetadataBytes(t.RawMetadata) - return -} - -func (r *RestCatalog) LoadTable(ctx context.Context, identifier table.Identifier, props iceberg.Properties) (*table.Table, error) { - ns, tbl, err := splitIdentForPath(identifier) - if err != nil { - return nil, err - } - - if props == nil { - props = iceberg.Properties{} - } - - ret, err := doGet[tblResponse](ctx, r.baseURI, []string{"namespaces", ns, "tables", tbl}, - r.cl, map[int]error{http.StatusNotFound: ErrNoSuchTable}) - if err != nil { - return nil, err - } - - id := identifier - if r.name != "" { - id = append([]string{r.name}, identifier...) 
- } - - tblProps := maps.Clone(r.props) - maps.Copy(tblProps, props) - maps.Copy(tblProps, ret.Metadata.Properties()) - for k, v := range ret.Config { - tblProps[k] = v - } - - iofs, err := iceio.LoadFS(tblProps, ret.MetadataLoc) - if err != nil { - return nil, err - } - return table.New(id, ret.Metadata, ret.MetadataLoc, iofs), nil -} - -func (r *RestCatalog) DropTable(ctx context.Context, identifier table.Identifier) error { - return fmt.Errorf("%w: [Rest Catalog] drop table", iceberg.ErrNotImplemented) -} - -func (r *RestCatalog) RenameTable(ctx context.Context, from, to table.Identifier) (*table.Table, error) { - return nil, fmt.Errorf("%w: [Rest Catalog] rename table", iceberg.ErrNotImplemented) -} - -func (r *RestCatalog) CreateNamespace(ctx context.Context, namespace table.Identifier, props iceberg.Properties) error { - if err := checkValidNamespace(namespace); err != nil { - return err - } - - _, err := doPost[map[string]any, struct{}](ctx, r.baseURI, []string{"namespaces"}, - map[string]any{"namespace": namespace, "properties": props}, r.cl, map[int]error{ - http.StatusNotFound: ErrNoSuchNamespace, http.StatusConflict: ErrNamespaceAlreadyExists}) - return err -} - -func (r *RestCatalog) DropNamespace(ctx context.Context, namespace table.Identifier) error { - if err := checkValidNamespace(namespace); err != nil { - return err - } - - _, err := doDelete[struct{}](ctx, r.baseURI, []string{"namespaces", strings.Join(namespace, namespaceSeparator)}, - r.cl, map[int]error{http.StatusNotFound: ErrNoSuchNamespace}) - - return err -} - -func (r *RestCatalog) ListNamespaces(ctx context.Context, parent table.Identifier) ([]table.Identifier, error) { - uri := r.baseURI.JoinPath("namespaces") - if len(parent) != 0 { - v := url.Values{} - v.Set("parent", strings.Join(parent, namespaceSeparator)) - uri.RawQuery = v.Encode() - } - - type rsptype struct { - Namespaces []table.Identifier `json:"namespaces"` - } - - rsp, err := doGet[rsptype](ctx, uri, []string{}, r.cl, 
map[int]error{http.StatusNotFound: ErrNoSuchNamespace}) - if err != nil { - return nil, err - } - - return rsp.Namespaces, nil -} - -func (r *RestCatalog) LoadNamespaceProperties(ctx context.Context, namespace table.Identifier) (iceberg.Properties, error) { - if err := checkValidNamespace(namespace); err != nil { - return nil, err - } - - type nsresponse struct { - Namespace table.Identifier `json:"namespace"` - Props iceberg.Properties `json:"properties"` - } - - rsp, err := doGet[nsresponse](ctx, r.baseURI, []string{"namespaces", strings.Join(namespace, namespaceSeparator)}, - r.cl, map[int]error{http.StatusNotFound: ErrNoSuchNamespace}) - if err != nil { - return nil, err - } - - return rsp.Props, nil -} - -func (r *RestCatalog) UpdateNamespaceProperties(ctx context.Context, namespace table.Identifier, - removals []string, updates iceberg.Properties) (PropertiesUpdateSummary, error) { - - if err := checkValidNamespace(namespace); err != nil { - return PropertiesUpdateSummary{}, err - } - - type payload struct { - Remove []string `json:"removals"` - Updates iceberg.Properties `json:"updates"` - } - - ns := strings.Join(namespace, namespaceSeparator) - return doPost[payload, PropertiesUpdateSummary](ctx, r.baseURI, []string{"namespaces", ns, "properties"}, - payload{Remove: removals, Updates: updates}, r.cl, map[int]error{http.StatusNotFound: ErrNoSuchNamespace}) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package catalog + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/json" + "errors" + "fmt" + "hash" + "io" + "maps" + "net/http" + "net/url" + "strings" + "time" + + "github.com/apache/iceberg-go" + iceio "github.com/apache/iceberg-go/io" + "github.com/apache/iceberg-go/table" + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" +) + +var ( + _ Catalog = (*RestCatalog)(nil) +) + +const ( + authorizationHeader = "Authorization" + bearerPrefix = "Bearer" + namespaceSeparator = "\x1F" + keyPrefix = "prefix" + + icebergRestSpecVersion = "0.14.1" + + keyRestSigV4 = "rest.sigv4-enabled" + keyRestSigV4Region = "rest.signing-region" + keyRestSigV4Service = "rest.signing-name" + keyAuthUrl = "rest.authorization-url" +) + +var ( + ErrRESTError = errors.New("REST error") + ErrBadRequest = fmt.Errorf("%w: bad request", ErrRESTError) + ErrForbidden = fmt.Errorf("%w: forbidden", ErrRESTError) + ErrUnauthorized = fmt.Errorf("%w: unauthorized", ErrRESTError) + ErrAuthorizationExpired = fmt.Errorf("%w: authorization expired", ErrRESTError) + ErrServiceUnavailable = fmt.Errorf("%w: service unavailable", ErrRESTError) + ErrServerError = fmt.Errorf("%w: server error", ErrRESTError) + ErrCommitFailed = fmt.Errorf("%w: commit failed, refresh and try again", ErrRESTError) + ErrCommitStateUnknown = fmt.Errorf("%w: commit failed due to unknown reason", ErrRESTError) + ErrOAuthError = fmt.Errorf("%w: oauth error", ErrRESTError) +) + +type errorResponse struct { + Message string 
`json:"message"` + Type string `json:"type"` + Code int `json:"code"` + + wrapping error +} + +func (e errorResponse) Unwrap() error { return e.wrapping } +func (e errorResponse) Error() string { + return e.Type + ": " + e.Message +} + +type oauthTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` + RefreshToken string `json:"refresh_token"` +} + +type oauthErrorResponse struct { + Err string `json:"error"` + ErrDesc string `json:"error_description"` + ErrURI string `json:"error_uri"` +} + +func (o oauthErrorResponse) Unwrap() error { return ErrOAuthError } +func (o oauthErrorResponse) Error() string { + msg := o.Err + if o.ErrDesc != "" { + msg += ": " + o.ErrDesc + } + + if o.ErrURI != "" { + msg += " (" + o.ErrURI + ")" + } + return msg +} + +type configResponse struct { + Defaults iceberg.Properties `json:"defaults"` + Overrides iceberg.Properties `json:"overrides"` +} + +type sessionTransport struct { + http.Transport + + defaultHeaders http.Header + signer v4.HTTPSigner + cfg aws.Config + service string + h hash.Hash +} + +// from https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws/signer/v4#Signer.SignHTTP +const emptyStringHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + +func (s *sessionTransport) RoundTrip(r *http.Request) (*http.Response, error) { + for k, v := range s.defaultHeaders { + for _, hdr := range v { + r.Header.Add(k, hdr) + } + } + + if s.signer != nil { + var payloadHash string + if r.Body == nil { + payloadHash = emptyStringHash + } else { + rdr, err := r.GetBody() + if err != nil { + return nil, err + } + + if _, err = io.Copy(s.h, rdr); err != nil { + return nil, err + } + + payloadHash = string(s.h.Sum(nil)) + s.h.Reset() + } + + creds, err := s.cfg.Credentials.Retrieve(r.Context()) + if err != nil { + return nil, err + } + + // modifies the request in place + err = 
s.signer.SignHTTP(r.Context(), creds, r, payloadHash, + s.service, s.cfg.Region, time.Now()) + if err != nil { + return nil, err + } + } + + return s.Transport.RoundTrip(r) +} + +func do[T any](ctx context.Context, method string, baseURI *url.URL, path []string, cl *http.Client, override map[int]error, allowNoContent bool) (ret T, err error) { + var ( + req *http.Request + rsp *http.Response + ) + + uri := baseURI.JoinPath(path...).String() + if req, err = http.NewRequestWithContext(ctx, method, uri, nil); err != nil { + return + } + + if rsp, err = cl.Do(req); err != nil { + return + } + + if allowNoContent && rsp.StatusCode == http.StatusNoContent { + return + } + + if rsp.StatusCode != http.StatusOK { + return ret, handleNon200(rsp, override) + } + + defer rsp.Body.Close() + if err = json.NewDecoder(rsp.Body).Decode(&ret); err != nil { + return ret, fmt.Errorf("%w: error decoding json payload: `%s`", ErrRESTError, err.Error()) + } + + return +} + +func doGet[T any](ctx context.Context, baseURI *url.URL, path []string, cl *http.Client, override map[int]error) (ret T, err error) { + return do[T](ctx, http.MethodGet, baseURI, path, cl, override, false) +} + +func doDelete[T any](ctx context.Context, baseURI *url.URL, path []string, cl *http.Client, override map[int]error) (ret T, err error) { + return do[T](ctx, http.MethodDelete, baseURI, path, cl, override, true) +} + +func doPost[Payload, Result any](ctx context.Context, baseURI *url.URL, path []string, payload Payload, cl *http.Client, override map[int]error) (ret Result, err error) { + var ( + req *http.Request + rsp *http.Response + data []byte + ) + + uri := baseURI.JoinPath(path...).String() + data, err = json.Marshal(payload) + if err != nil { + return + } + + req, err = http.NewRequestWithContext(ctx, http.MethodPost, uri, bytes.NewReader(data)) + if err != nil { + return + } + + req.Header.Add("Content-Type", "application/json") + + rsp, err = cl.Do(req) + if err != nil { + return + } + + if 
rsp.StatusCode != http.StatusOK { + return ret, handleNon200(rsp, override) + } + + if rsp.ContentLength == 0 { + return + } + + defer rsp.Body.Close() + if err = json.NewDecoder(rsp.Body).Decode(&ret); err != nil { + return ret, fmt.Errorf("%w: error decoding json payload: `%s`", ErrRESTError, err.Error()) + } + + return +} + +func handleNon200(rsp *http.Response, override map[int]error) error { + var e errorResponse + + dec := json.NewDecoder(rsp.Body) + dec.Decode(&struct { + Error *errorResponse `json:"error"` + }{Error: &e}) + + if override != nil { + if err, ok := override[rsp.StatusCode]; ok { + e.wrapping = err + return e + } + } + + switch rsp.StatusCode { + case http.StatusBadRequest: + e.wrapping = ErrBadRequest + case http.StatusUnauthorized: + e.wrapping = ErrUnauthorized + case http.StatusForbidden: + e.wrapping = ErrForbidden + case http.StatusUnprocessableEntity: + e.wrapping = ErrRESTError + case 419: + e.wrapping = ErrAuthorizationExpired + case http.StatusNotImplemented: + e.wrapping = iceberg.ErrNotImplemented + case http.StatusServiceUnavailable: + e.wrapping = ErrServiceUnavailable + default: + if 500 <= rsp.StatusCode && rsp.StatusCode < 600 { + e.wrapping = ErrServerError + } else { + e.wrapping = ErrRESTError + } + } + + return e +} + +func ToRestIdentifier(ident ...string) table.Identifier { + if len(ident) == 1 { + if ident[0] == "" { + return nil + } + return table.Identifier(strings.Split(ident[0], ".")) + } + + return table.Identifier(ident) +} + +func fromProps(props iceberg.Properties) *options { + o := &options{} + for k, v := range props { + switch k { + case keyOauthToken: + o.oauthToken = v + case keyWarehouseLocation: + o.warehouseLocation = v + case keyMetadataLocation: + o.metadataLocation = v + case keyRestSigV4: + o.enableSigv4 = strings.ToLower(v) == "true" + case keyRestSigV4Region: + o.sigv4Region = v + case keyRestSigV4Service: + o.sigv4Service = v + case keyAuthUrl: + u, err := url.Parse(v) + if err != nil { + continue 
+ } + o.authUri = u + case keyOauthCredential: + o.credential = v + case keyPrefix: + o.prefix = v + } + } + return o +} + +func toProps(o *options) iceberg.Properties { + props := iceberg.Properties{} + + setIf := func(key, v string) { + if v != "" { + props[key] = v + } + } + + setIf(keyOauthCredential, o.credential) + setIf(keyOauthToken, o.oauthToken) + setIf(keyWarehouseLocation, o.warehouseLocation) + setIf(keyMetadataLocation, o.metadataLocation) + if o.enableSigv4 { + props[keyRestSigV4] = "true" + setIf(keyRestSigV4Region, o.sigv4Region) + setIf(keyRestSigV4Service, o.sigv4Service) + } + + setIf(keyPrefix, o.prefix) + if o.authUri != nil { + setIf(keyAuthUrl, o.authUri.String()) + } + return props +} + +type RestCatalog struct { + baseURI *url.URL + cl *http.Client + + name string + props iceberg.Properties +} + +func NewRestCatalog(name, uri string, opts ...Option[RestCatalog]) (*RestCatalog, error) { + ops := &options{} + for _, o := range opts { + o(ops) + } + + baseuri, err := url.Parse(uri) + if err != nil { + return nil, err + } + + r := &RestCatalog{ + name: name, + baseURI: baseuri.JoinPath("v1"), + } + + if ops, err = r.fetchConfig(ops); err != nil { + return nil, err + } + + cl, err := r.createSession(ops) + if err != nil { + return nil, err + } + + r.cl = cl + if ops.prefix != "" { + r.baseURI = r.baseURI.JoinPath(ops.prefix) + } + r.props = toProps(ops) + return r, nil +} + +func (r *RestCatalog) fetchAccessToken(cl *http.Client, creds string, opts *options) (string, error) { + clientID, clientSecret, hasID := strings.Cut(creds, ":") + if !hasID { + clientID, clientSecret = "", clientID + } + + data := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {clientID}, + "client_secret": {clientSecret}, + "scope": {"catalog"}, + } + + uri := opts.authUri + if uri == nil { + uri = r.baseURI.JoinPath("oauth/tokens") + } + + rsp, err := cl.PostForm(uri.String(), data) + if err != nil { + return "", err + } + + if rsp.StatusCode == 
http.StatusOK { + defer rsp.Body.Close() + dec := json.NewDecoder(rsp.Body) + var tok oauthTokenResponse + if err := dec.Decode(&tok); err != nil { + return "", fmt.Errorf("failed to decode oauth token response: %w", err) + } + + return tok.AccessToken, nil + } + + switch rsp.StatusCode { + case http.StatusUnauthorized, http.StatusBadRequest: + defer rsp.Request.GetBody() + dec := json.NewDecoder(rsp.Body) + var oauthErr oauthErrorResponse + if err := dec.Decode(&oauthErr); err != nil { + return "", fmt.Errorf("failed to decode oauth error: %w", err) + } + + return "", oauthErr + default: + return "", handleNon200(rsp, nil) + } +} + +func (r *RestCatalog) createSession(opts *options) (*http.Client, error) { + session := &sessionTransport{ + Transport: http.Transport{TLSClientConfig: opts.tlsConfig}, + defaultHeaders: http.Header{}, + } + cl := &http.Client{Transport: session} + + token := opts.oauthToken + if token == "" && opts.credential != "" { + var err error + if token, err = r.fetchAccessToken(cl, opts.credential, opts); err != nil { + return nil, fmt.Errorf("auth error: %w", err) + } + } + + if token != "" { + session.defaultHeaders.Set(authorizationHeader, bearerPrefix+" "+token) + } + + session.defaultHeaders.Set("X-Client-Version", icebergRestSpecVersion) + session.defaultHeaders.Set("Content-Type", "application/json") + session.defaultHeaders.Set("User-Agent", "GoIceberg/"+iceberg.Version()) + session.defaultHeaders.Set("X-Iceberg-Access-Delegation", "vended-credentials") + + if opts.enableSigv4 { + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + return nil, err + } + + if opts.sigv4Region != "" { + cfg.Region = opts.sigv4Region + } + + session.cfg, session.service = cfg, opts.sigv4Service + session.signer, session.h = v4.NewSigner(), sha256.New() + } + + return cl, nil +} + +func (r *RestCatalog) fetchConfig(opts *options) (*options, error) { + params := url.Values{} + if opts.warehouseLocation != "" { + 
params.Set(keyWarehouseLocation, opts.warehouseLocation) + } + + route := r.baseURI.JoinPath("config") + route.RawQuery = params.Encode() + + sess, err := r.createSession(opts) + if err != nil { + return nil, err + } + + rsp, err := doGet[configResponse](context.Background(), route, []string{}, sess, nil) + if err != nil { + return nil, err + } + + cfg := rsp.Defaults + maps.Copy(cfg, toProps(opts)) + maps.Copy(cfg, rsp.Overrides) + + o := fromProps(cfg) + o.awsConfig = opts.awsConfig + o.tlsConfig = opts.tlsConfig + + if uri, ok := cfg["uri"]; ok { + r.baseURI, err = url.Parse(uri) + if err != nil { + return nil, err + } + r.baseURI = r.baseURI.JoinPath("v1") + } + + return o, nil +} + +func (r *RestCatalog) CatalogType() CatalogType { return REST } + +func checkValidNamespace(ident table.Identifier) error { + if len(ident) < 1 { + return fmt.Errorf("%w: empty namespace identifier", ErrNoSuchNamespace) + } + return nil +} + +func (r *RestCatalog) ListTables(ctx context.Context, namespace table.Identifier) ([]table.Identifier, error) { + if err := checkValidNamespace(namespace); err != nil { + return nil, err + } + + ns := strings.Join(namespace, namespaceSeparator) + path := []string{"namespaces", ns, "tables"} + + type resp struct { + Identifiers []struct { + Namespace []string `json:"namespace"` + Name string `json:"name"` + } `json:"identifiers"` + } + + rsp, err := doGet[resp](ctx, r.baseURI, path, r.cl, map[int]error{http.StatusNotFound: ErrNoSuchNamespace}) + if err != nil { + return nil, err + } + + out := make([]table.Identifier, len(rsp.Identifiers)) + for i, id := range rsp.Identifiers { + out[i] = append(id.Namespace, id.Name) + } + return out, nil +} + +func splitIdentForPath(ident table.Identifier) (string, string, error) { + if len(ident) < 1 { + return "", "", fmt.Errorf("%w: missing namespace or invalid identifier %v", + ErrNoSuchTable, strings.Join(ident, ".")) + } + + return strings.Join(NamespaceFromIdent(ident), namespaceSeparator), 
TableNameFromIdent(ident), nil +} + +type tblResponse struct { + MetadataLoc string `json:"metadata-location"` + RawMetadata json.RawMessage `json:"metadata"` + Config iceberg.Properties `json:"config"` + Metadata table.Metadata `json:"-"` +} + +func (t *tblResponse) UnmarshalJSON(b []byte) (err error) { + type Alias tblResponse + if err = json.Unmarshal(b, (*Alias)(t)); err != nil { + return err + } + + t.Metadata, err = table.ParseMetadataBytes(t.RawMetadata) + return +} + +func (r *RestCatalog) LoadTable(ctx context.Context, identifier table.Identifier, props iceberg.Properties) (*table.Table, error) { + ns, tbl, err := splitIdentForPath(identifier) + if err != nil { + return nil, err + } + + if props == nil { + props = iceberg.Properties{} + } + + ret, err := doGet[tblResponse](ctx, r.baseURI, []string{"namespaces", ns, "tables", tbl}, + r.cl, map[int]error{http.StatusNotFound: ErrNoSuchTable}) + if err != nil { + return nil, err + } + + id := identifier + if r.name != "" { + id = append([]string{r.name}, identifier...) 
+ } + + tblProps := maps.Clone(r.props) + maps.Copy(tblProps, props) + maps.Copy(tblProps, ret.Metadata.Properties()) + for k, v := range ret.Config { + tblProps[k] = v + } + + iofs, err := iceio.LoadFS(tblProps, ret.MetadataLoc) + if err != nil { + return nil, err + } + return table.New(id, ret.Metadata, ret.MetadataLoc, iofs), nil +} + +func (r *RestCatalog) DropTable(ctx context.Context, identifier table.Identifier) error { + return fmt.Errorf("%w: [Rest Catalog] drop table", iceberg.ErrNotImplemented) +} + +func (r *RestCatalog) RenameTable(ctx context.Context, from, to table.Identifier) (*table.Table, error) { + return nil, fmt.Errorf("%w: [Rest Catalog] rename table", iceberg.ErrNotImplemented) +} + +func (r *RestCatalog) CreateNamespace(ctx context.Context, namespace table.Identifier, props iceberg.Properties) error { + if err := checkValidNamespace(namespace); err != nil { + return err + } + + _, err := doPost[map[string]any, struct{}](ctx, r.baseURI, []string{"namespaces"}, + map[string]any{"namespace": namespace, "properties": props}, r.cl, map[int]error{ + http.StatusNotFound: ErrNoSuchNamespace, http.StatusConflict: ErrNamespaceAlreadyExists}) + return err +} + +func (r *RestCatalog) DropNamespace(ctx context.Context, namespace table.Identifier) error { + if err := checkValidNamespace(namespace); err != nil { + return err + } + + _, err := doDelete[struct{}](ctx, r.baseURI, []string{"namespaces", strings.Join(namespace, namespaceSeparator)}, + r.cl, map[int]error{http.StatusNotFound: ErrNoSuchNamespace}) + + return err +} + +func (r *RestCatalog) ListNamespaces(ctx context.Context, parent table.Identifier) ([]table.Identifier, error) { + uri := r.baseURI.JoinPath("namespaces") + if len(parent) != 0 { + v := url.Values{} + v.Set("parent", strings.Join(parent, namespaceSeparator)) + uri.RawQuery = v.Encode() + } + + type rsptype struct { + Namespaces []table.Identifier `json:"namespaces"` + } + + rsp, err := doGet[rsptype](ctx, uri, []string{}, r.cl, 
map[int]error{http.StatusNotFound: ErrNoSuchNamespace}) + if err != nil { + return nil, err + } + + return rsp.Namespaces, nil +} + +func (r *RestCatalog) LoadNamespaceProperties(ctx context.Context, namespace table.Identifier) (iceberg.Properties, error) { + if err := checkValidNamespace(namespace); err != nil { + return nil, err + } + + type nsresponse struct { + Namespace table.Identifier `json:"namespace"` + Props iceberg.Properties `json:"properties"` + } + + rsp, err := doGet[nsresponse](ctx, r.baseURI, []string{"namespaces", strings.Join(namespace, namespaceSeparator)}, + r.cl, map[int]error{http.StatusNotFound: ErrNoSuchNamespace}) + if err != nil { + return nil, err + } + + return rsp.Props, nil +} + +func (r *RestCatalog) UpdateNamespaceProperties(ctx context.Context, namespace table.Identifier, + removals []string, updates iceberg.Properties) (PropertiesUpdateSummary, error) { + + if err := checkValidNamespace(namespace); err != nil { + return PropertiesUpdateSummary{}, err + } + + type payload struct { + Remove []string `json:"removals"` + Updates iceberg.Properties `json:"updates"` + } + + ns := strings.Join(namespace, namespaceSeparator) + return doPost[payload, PropertiesUpdateSummary](ctx, r.baseURI, []string{"namespaces", ns, "properties"}, + payload{Remove: removals, Updates: updates}, r.cl, map[int]error{http.StatusNotFound: ErrNoSuchNamespace}) +} diff --git a/catalog/rest_internal_test.go b/catalog/rest_internal_test.go index a03e2a5..a8b204c 100644 --- a/catalog/rest_internal_test.go +++ b/catalog/rest_internal_test.go @@ -1,123 +1,123 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package catalog - -import ( - "encoding/json" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestAuthHeader(t *testing.T) { - mux := http.NewServeMux() - srv := httptest.NewServer(mux) - - mux.HandleFunc("/v1/config", func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(map[string]any{ - "defaults": map[string]any{}, "overrides": map[string]any{}}) - }) - - mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { - assert.Equal(t, http.MethodPost, req.Method) - - assert.Equal(t, req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") - - require.NoError(t, req.ParseForm()) - values := req.PostForm - assert.Equal(t, values.Get("grant_type"), "client_credentials") - assert.Equal(t, values.Get("client_id"), "client") - assert.Equal(t, values.Get("client_secret"), "secret") - assert.Equal(t, values.Get("scope"), "catalog") - - w.WriteHeader(http.StatusOK) - - json.NewEncoder(w).Encode(map[string]any{ - "access_token": "some_jwt_token", - "token_type": "Bearer", - "expires_in": 86400, - "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", - }) - }) - - cat, err := NewRestCatalog("rest", srv.URL, - WithCredential("client:secret")) - require.NoError(t, err) - assert.NotNil(t, cat) - - require.IsType(t, (*sessionTransport)(nil), cat.cl.Transport) - assert.Equal(t, http.Header{ - "Authorization": {"Bearer some_jwt_token"}, - "Content-Type": {"application/json"}, - 
"User-Agent": {"GoIceberg/(unknown version)"}, - "X-Client-Version": {icebergRestSpecVersion}, - "X-Iceberg-Access-Delegation": {"vended-credentials"}, - }, cat.cl.Transport.(*sessionTransport).defaultHeaders) -} - -func TestAuthUriHeader(t *testing.T) { - mux := http.NewServeMux() - srv := httptest.NewServer(mux) - - mux.HandleFunc("/v1/config", func(w http.ResponseWriter, r *http.Request) { - json.NewEncoder(w).Encode(map[string]any{ - "defaults": map[string]any{}, "overrides": map[string]any{}}) - }) - - mux.HandleFunc("/auth-token-url", func(w http.ResponseWriter, req *http.Request) { - assert.Equal(t, http.MethodPost, req.Method) - - assert.Equal(t, req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") - - require.NoError(t, req.ParseForm()) - values := req.PostForm - assert.Equal(t, values.Get("grant_type"), "client_credentials") - assert.Equal(t, values.Get("client_id"), "client") - assert.Equal(t, values.Get("client_secret"), "secret") - assert.Equal(t, values.Get("scope"), "catalog") - - w.WriteHeader(http.StatusOK) - - json.NewEncoder(w).Encode(map[string]any{ - "access_token": "some_jwt_token", - "token_type": "Bearer", - "expires_in": 86400, - "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", - }) - }) - - authUri, err := url.Parse(srv.URL) - require.NoError(t, err) - cat, err := NewRestCatalog("rest", srv.URL, - WithCredential("client:secret"), WithAuthURI(authUri.JoinPath("auth-token-url"))) - require.NoError(t, err) - assert.NotNil(t, cat) - - require.IsType(t, (*sessionTransport)(nil), cat.cl.Transport) - assert.Equal(t, http.Header{ - "Authorization": {"Bearer some_jwt_token"}, - "Content-Type": {"application/json"}, - "User-Agent": {"GoIceberg/(unknown version)"}, - "X-Client-Version": {icebergRestSpecVersion}, - "X-Iceberg-Access-Delegation": {"vended-credentials"}, - }, cat.cl.Transport.(*sessionTransport).defaultHeaders) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor 
license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package catalog + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAuthHeader(t *testing.T) { + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + + mux.HandleFunc("/v1/config", func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "defaults": map[string]any{}, "overrides": map[string]any{}}) + }) + + mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { + assert.Equal(t, http.MethodPost, req.Method) + + assert.Equal(t, req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") + + require.NoError(t, req.ParseForm()) + values := req.PostForm + assert.Equal(t, values.Get("grant_type"), "client_credentials") + assert.Equal(t, values.Get("client_id"), "client") + assert.Equal(t, values.Get("client_secret"), "secret") + assert.Equal(t, values.Get("scope"), "catalog") + + w.WriteHeader(http.StatusOK) + + json.NewEncoder(w).Encode(map[string]any{ + "access_token": "some_jwt_token", + "token_type": "Bearer", + "expires_in": 86400, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + }) + }) + + cat, err := 
NewRestCatalog("rest", srv.URL, + WithCredential("client:secret")) + require.NoError(t, err) + assert.NotNil(t, cat) + + require.IsType(t, (*sessionTransport)(nil), cat.cl.Transport) + assert.Equal(t, http.Header{ + "Authorization": {"Bearer some_jwt_token"}, + "Content-Type": {"application/json"}, + "User-Agent": {"GoIceberg/(unknown version)"}, + "X-Client-Version": {icebergRestSpecVersion}, + "X-Iceberg-Access-Delegation": {"vended-credentials"}, + }, cat.cl.Transport.(*sessionTransport).defaultHeaders) +} + +func TestAuthUriHeader(t *testing.T) { + mux := http.NewServeMux() + srv := httptest.NewServer(mux) + + mux.HandleFunc("/v1/config", func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "defaults": map[string]any{}, "overrides": map[string]any{}}) + }) + + mux.HandleFunc("/auth-token-url", func(w http.ResponseWriter, req *http.Request) { + assert.Equal(t, http.MethodPost, req.Method) + + assert.Equal(t, req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") + + require.NoError(t, req.ParseForm()) + values := req.PostForm + assert.Equal(t, values.Get("grant_type"), "client_credentials") + assert.Equal(t, values.Get("client_id"), "client") + assert.Equal(t, values.Get("client_secret"), "secret") + assert.Equal(t, values.Get("scope"), "catalog") + + w.WriteHeader(http.StatusOK) + + json.NewEncoder(w).Encode(map[string]any{ + "access_token": "some_jwt_token", + "token_type": "Bearer", + "expires_in": 86400, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + }) + }) + + authUri, err := url.Parse(srv.URL) + require.NoError(t, err) + cat, err := NewRestCatalog("rest", srv.URL, + WithCredential("client:secret"), WithAuthURI(authUri.JoinPath("auth-token-url"))) + require.NoError(t, err) + assert.NotNil(t, cat) + + require.IsType(t, (*sessionTransport)(nil), cat.cl.Transport) + assert.Equal(t, http.Header{ + "Authorization": {"Bearer some_jwt_token"}, + "Content-Type": {"application/json"}, 
+ "User-Agent": {"GoIceberg/(unknown version)"}, + "X-Client-Version": {icebergRestSpecVersion}, + "X-Iceberg-Access-Delegation": {"vended-credentials"}, + }, cat.cl.Transport.(*sessionTransport).defaultHeaders) +} diff --git a/catalog/rest_test.go b/catalog/rest_test.go index 618c5e0..a7747c2 100644 --- a/catalog/rest_test.go +++ b/catalog/rest_test.go @@ -1,817 +1,817 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package catalog_test - -import ( - "context" - "crypto/tls" - "crypto/x509" - "encoding/json" - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/catalog" - "github.com/apache/iceberg-go/table" - "github.com/stretchr/testify/suite" -) - -const ( - TestCreds = "client:secret" - TestToken = "some_jwt_token" -) - -var ( - TestHeaders = http.Header{ - "X-Client-Version": {"0.14.1"}, - "User-Agent": {"GoIceberg/(unknown version)"}, - "Authorization": {"Bearer " + TestToken}, - } - OAuthTestHeaders = http.Header{ - "Content-Type": {"application/x-www-form-urlencoded"}, - } -) - -type RestCatalogSuite struct { - suite.Suite - - srv *httptest.Server - mux *http.ServeMux - - configVals url.Values -} - -func (r *RestCatalogSuite) SetupTest() { - r.mux = http.NewServeMux() - - r.mux.HandleFunc("/v1/config", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodGet, req.Method) - r.configVals = req.URL.Query() - - json.NewEncoder(w).Encode(map[string]any{ - "defaults": map[string]any{}, - "overrides": map[string]any{}, - }) - }) - - r.srv = httptest.NewServer(r.mux) -} - -func (r *RestCatalogSuite) TearDownTest() { - r.srv.Close() - r.srv = nil - r.mux = nil - r.configVals = nil -} - -func (r *RestCatalogSuite) TestToken200() { - r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { - r.Equal(http.MethodPost, req.Method) - - r.Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") - - r.Require().NoError(req.ParseForm()) - values := req.PostForm - r.Equal(values.Get("grant_type"), "client_credentials") - r.Equal(values.Get("client_id"), "client") - r.Equal(values.Get("client_secret"), "secret") - r.Equal(values.Get("scope"), "catalog") - - w.WriteHeader(http.StatusOK) - - json.NewEncoder(w).Encode(map[string]any{ - "access_token": TestToken, - "token_type": "Bearer", - "expires_in": 86400, - "issued_token_type": 
"urn:ietf:params:oauth:token-type:access_token", - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, - catalog.WithWarehouseLocation("s3://some-bucket"), - catalog.WithCredential(TestCreds)) - r.Require().NoError(err) - - r.NotNil(cat) - r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") -} - -func (r *RestCatalogSuite) TestToken400() { - r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { - r.Equal(http.MethodPost, req.Method) - - r.Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") - - w.WriteHeader(http.StatusBadRequest) - - json.NewEncoder(w).Encode(map[string]any{ - "error": "invalid_client", - "error_description": "credentials for key invalid_key do not match", - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithCredential(TestCreds)) - r.Nil(cat) - - r.ErrorIs(err, catalog.ErrRESTError) - r.ErrorIs(err, catalog.ErrOAuthError) - r.ErrorContains(err, "invalid_client: credentials for key invalid_key do not match") -} - -func (r *RestCatalogSuite) TestToken200AuthUrl() { - r.mux.HandleFunc("/auth-token-url", func(w http.ResponseWriter, req *http.Request) { - r.Equal(http.MethodPost, req.Method) - - r.Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") - - r.Require().NoError(req.ParseForm()) - values := req.PostForm - r.Equal(values.Get("grant_type"), "client_credentials") - r.Equal(values.Get("client_id"), "client") - r.Equal(values.Get("client_secret"), "secret") - r.Equal(values.Get("scope"), "catalog") - - w.WriteHeader(http.StatusOK) - - json.NewEncoder(w).Encode(map[string]any{ - "access_token": TestToken, - "token_type": "Bearer", - "expires_in": 86400, - "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", - }) - }) - - authUri, err := url.Parse(r.srv.URL) - r.Require().NoError(err) - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, - catalog.WithWarehouseLocation("s3://some-bucket"), - 
catalog.WithCredential(TestCreds), catalog.WithAuthURI(authUri.JoinPath("auth-token-url"))) - - r.Require().NoError(err) - - r.NotNil(cat) - r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") -} - -func (r *RestCatalogSuite) TestToken401() { - r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { - r.Equal(http.MethodPost, req.Method) - - r.Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") - - w.WriteHeader(http.StatusUnauthorized) - - json.NewEncoder(w).Encode(map[string]any{ - "error": "invalid_client", - "error_description": "credentials for key invalid_key do not match", - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithCredential(TestCreds)) - r.Nil(cat) - - r.ErrorIs(err, catalog.ErrRESTError) - r.ErrorIs(err, catalog.ErrOAuthError) - r.ErrorContains(err, "invalid_client: credentials for key invalid_key do not match") -} - -func (r *RestCatalogSuite) TestListTables200() { - namespace := "examples" - r.mux.HandleFunc("/v1/namespaces/"+namespace+"/tables", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodGet, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - json.NewEncoder(w).Encode(map[string]any{ - "identifiers": []any{ - map[string]any{ - "namespace": []string{namespace}, - "name": "fooshare", - }, - }, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - tables, err := cat.ListTables(context.Background(), catalog.ToRestIdentifier(namespace)) - r.Require().NoError(err) - r.Equal([]table.Identifier{{"examples", "fooshare"}}, tables) -} - -func (r *RestCatalogSuite) TestListTablesPrefixed200() { - r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { - r.Equal(http.MethodPost, req.Method) - - r.Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") - - 
r.Require().NoError(req.ParseForm()) - values := req.PostForm - r.Equal(values.Get("grant_type"), "client_credentials") - r.Equal(values.Get("client_id"), "client") - r.Equal(values.Get("client_secret"), "secret") - r.Equal(values.Get("scope"), "catalog") - - w.WriteHeader(http.StatusOK) - - json.NewEncoder(w).Encode(map[string]any{ - "access_token": TestToken, - "token_type": "Bearer", - "expires_in": 86400, - "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", - }) - }) - - namespace := "examples" - r.mux.HandleFunc("/v1/prefix/namespaces/"+namespace+"/tables", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodGet, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - json.NewEncoder(w).Encode(map[string]any{ - "identifiers": []any{ - map[string]any{ - "namespace": []string{namespace}, - "name": "fooshare", - }, - }, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, - catalog.WithPrefix("prefix"), - catalog.WithWarehouseLocation("s3://some-bucket"), - catalog.WithCredential(TestCreds)) - r.Require().NoError(err) - - r.NotNil(cat) - r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") - - tables, err := cat.ListTables(context.Background(), catalog.ToRestIdentifier(namespace)) - r.Require().NoError(err) - r.Equal([]table.Identifier{{"examples", "fooshare"}}, tables) -} - -func (r *RestCatalogSuite) TestListTables404() { - namespace := "examples" - r.mux.HandleFunc("/v1/namespaces/"+namespace+"/tables", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodGet, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]any{ - "error": map[string]any{ - "message": "Namespace does not exist: personal in warehouse 8bcb0838-50fc-472d-9ddb-8feb89ef5f1e", - "type": "NoSuchNamespaceException", - "code": 404, - }, - }) - }) - - cat, err 
:= catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - _, err = cat.ListTables(context.Background(), catalog.ToRestIdentifier(namespace)) - r.ErrorIs(err, catalog.ErrNoSuchNamespace) - r.ErrorContains(err, "Namespace does not exist: personal in warehouse 8bcb0838-50fc-472d-9ddb-8feb89ef5f1e") -} - -func (r *RestCatalogSuite) TestListNamespaces200() { - r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodGet, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - json.NewEncoder(w).Encode(map[string]any{ - "namespaces": []table.Identifier{ - {"default"}, {"examples"}, {"fokko"}, {"system"}, - }, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - results, err := cat.ListNamespaces(context.Background(), nil) - r.Require().NoError(err) - - r.Equal([]table.Identifier{{"default"}, {"examples"}, {"fokko"}, {"system"}}, results) -} - -func (r *RestCatalogSuite) TestListNamespaceWithParent200() { - r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodGet, req.Method) - r.Require().Equal("accounting", req.URL.Query().Get("parent")) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - json.NewEncoder(w).Encode(map[string]any{ - "namespaces": []table.Identifier{ - {"accounting", "tax"}, - }, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - results, err := cat.ListNamespaces(context.Background(), catalog.ToRestIdentifier("accounting")) - r.Require().NoError(err) - - r.Equal([]table.Identifier{{"accounting", "tax"}}, results) -} - -func (r *RestCatalogSuite) TestListNamespaces400() { - r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { - 
r.Require().Equal(http.MethodGet, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]any{ - "error": map[string]any{ - "message": "Namespace does not exist: personal in warehouse 8bcb0838-50fc-472d-9ddb-8feb89ef5f1e", - "type": "NoSuchNamespaceException", - "code": 404, - }, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - _, err = cat.ListNamespaces(context.Background(), catalog.ToRestIdentifier("accounting")) - r.ErrorIs(err, catalog.ErrNoSuchNamespace) - r.ErrorContains(err, "Namespace does not exist: personal in warehouse 8bcb0838-50fc-472d-9ddb-8feb89ef5f1e") -} - -func (r *RestCatalogSuite) TestCreateNamespace200() { - r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodPost, req.Method) - r.Require().Equal("application/json", req.Header.Get("Content-Type")) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - defer req.Body.Close() - dec := json.NewDecoder(req.Body) - body := struct { - Namespace table.Identifier `json:"namespace"` - Props iceberg.Properties `json:"properties"` - }{} - - r.Require().NoError(dec.Decode(&body)) - r.Equal(table.Identifier{"leden"}, body.Namespace) - r.Empty(body.Props) - - json.NewEncoder(w).Encode(map[string]any{ - "namespace": []string{"leden"}, "properties": map[string]any{}, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - r.Require().NoError(cat.CreateNamespace(context.Background(), catalog.ToRestIdentifier("leden"), nil)) -} - -func (r *RestCatalogSuite) TestCreateNamespaceWithProps200() { - r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodPost, req.Method) - 
r.Require().Equal("application/json", req.Header.Get("Content-Type")) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - defer req.Body.Close() - dec := json.NewDecoder(req.Body) - body := struct { - Namespace table.Identifier `json:"namespace"` - Props iceberg.Properties `json:"properties"` - }{} - - r.Require().NoError(dec.Decode(&body)) - r.Equal(table.Identifier{"leden"}, body.Namespace) - r.Equal(iceberg.Properties{"foo": "bar", "super": "duper"}, body.Props) - - json.NewEncoder(w).Encode(map[string]any{ - "namespace": []string{"leden"}, "properties": body.Props, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - r.Require().NoError(cat.CreateNamespace(context.Background(), catalog.ToRestIdentifier("leden"), iceberg.Properties{"foo": "bar", "super": "duper"})) -} - -func (r *RestCatalogSuite) TestCreateNamespace409() { - r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodPost, req.Method) - r.Require().Equal("application/json", req.Header.Get("Content-Type")) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - defer req.Body.Close() - dec := json.NewDecoder(req.Body) - body := struct { - Namespace table.Identifier `json:"namespace"` - Props iceberg.Properties `json:"properties"` - }{} - - r.Require().NoError(dec.Decode(&body)) - r.Equal(table.Identifier{"fokko"}, body.Namespace) - r.Empty(body.Props) - - w.WriteHeader(http.StatusConflict) - json.NewEncoder(w).Encode(map[string]any{ - "error": map[string]any{ - "message": "Namespace already exists: fokko in warehouse 8bcb0838-50fc-472d-9ddb-8feb89ef5f1e", - "type": "AlreadyExistsException", - "code": 409, - }, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - err = cat.CreateNamespace(context.Background(), 
catalog.ToRestIdentifier("fokko"), nil) - r.ErrorIs(err, catalog.ErrNamespaceAlreadyExists) - r.ErrorContains(err, "fokko in warehouse") -} - -func (r *RestCatalogSuite) TestDropNamespace204() { - r.mux.HandleFunc("/v1/namespaces/examples", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodDelete, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - w.WriteHeader(http.StatusNoContent) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - r.NoError(cat.DropNamespace(context.Background(), catalog.ToRestIdentifier("examples"))) -} - -func (r *RestCatalogSuite) TestDropNamespace404() { - r.mux.HandleFunc("/v1/namespaces/examples", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodDelete, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]any{ - "error": map[string]any{ - "message": "Namespace does not exist: examples in warehouse", - "type": "NoSuchNamespaceException", - "code": 404, - }, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - err = cat.DropNamespace(context.Background(), catalog.ToRestIdentifier("examples")) - r.ErrorIs(err, catalog.ErrNoSuchNamespace) - r.ErrorContains(err, "examples in warehouse") -} - -func (r *RestCatalogSuite) TestLoadNamespaceProps200() { - r.mux.HandleFunc("/v1/namespaces/leden", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodGet, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]any{ - "namespace": []string{"fokko"}, - "properties": map[string]any{"prop": "yes"}, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, 
catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - props, err := cat.LoadNamespaceProperties(context.Background(), catalog.ToRestIdentifier("leden")) - r.Require().NoError(err) - r.Equal(iceberg.Properties{"prop": "yes"}, props) -} - -func (r *RestCatalogSuite) TestLoadNamespaceProps404() { - r.mux.HandleFunc("/v1/namespaces/leden", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodGet, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]any{ - "error": map[string]any{ - "message": "Namespace does not exist: fokko22 in warehouse", - "type": "NoSuchNamespaceException", - "code": 404, - }, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - _, err = cat.LoadNamespaceProperties(context.Background(), catalog.ToRestIdentifier("leden")) - r.ErrorIs(err, catalog.ErrNoSuchNamespace) - r.ErrorContains(err, "Namespace does not exist: fokko22 in warehouse") -} - -func (r *RestCatalogSuite) TestUpdateNamespaceProps200() { - r.mux.HandleFunc("/v1/namespaces/fokko/properties", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodPost, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - json.NewEncoder(w).Encode(map[string]any{ - "removed": []string{}, - "updated": []string{"prop"}, - "missing": []string{"abc"}, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - summary, err := cat.UpdateNamespaceProperties(context.Background(), table.Identifier([]string{"fokko"}), - []string{"abc"}, iceberg.Properties{"prop": "yes"}) - r.Require().NoError(err) - - r.Equal(catalog.PropertiesUpdateSummary{ - Removed: []string{}, - Updated: []string{"prop"}, - Missing: []string{"abc"}, - }, summary) -} - -func (r 
*RestCatalogSuite) TestUpdateNamespaceProps404() { - r.mux.HandleFunc("/v1/namespaces/fokko/properties", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodPost, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(map[string]any{ - "error": map[string]any{ - "message": "Namespace does not exist: does_not_exist in warehouse", - "type": "NoSuchNamespaceException", - "code": 404, - }, - }) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - _, err = cat.UpdateNamespaceProperties(context.Background(), - table.Identifier{"fokko"}, []string{"abc"}, iceberg.Properties{"prop": "yes"}) - r.ErrorIs(err, catalog.ErrNoSuchNamespace) - r.ErrorContains(err, "Namespace does not exist: does_not_exist in warehouse") -} - -func (r *RestCatalogSuite) TestLoadTable200() { - r.mux.HandleFunc("/v1/namespaces/fokko/tables/table", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodGet, req.Method) - - for k, v := range TestHeaders { - r.Equal(v, req.Header.Values(k)) - } - - w.Write([]byte(`{ - "metadata-location": "s3://warehouse/database/table/metadata/00001-5f2f8166-244c-4eae-ac36-384ecdec81fc.gz.metadata.json", - "metadata": { - "format-version": 1, - "table-uuid": "b55d9dda-6561-423a-8bfc-787980ce421f", - "location": "s3://warehouse/database/table", - "last-updated-ms": 1646787054459, - "last-column-id": 2, - "schema": { - "type": "struct", - "schema-id": 0, - "fields": [ - {"id": 1, "name": "id", "required": false, "type": "int"}, - {"id": 2, "name": "data", "required": false, "type": "string"} - ] - }, - "current-schema-id": 0, - "schemas": [ - { - "type": "struct", - "schema-id": 0, - "fields": [ - {"id": 1, "name": "id", "required": false, "type": "int"}, - {"id": 2, "name": "data", "required": false, "type": "string"} - ] - } - ], - 
"partition-spec": [], - "default-spec-id": 0, - "partition-specs": [{"spec-id": 0, "fields": []}], - "last-partition-id": 999, - "default-sort-order-id": 0, - "sort-orders": [{"order-id": 0, "fields": []}], - "properties": {"owner": "bryan", "write.metadata.compression-codec": "gzip"}, - "current-snapshot-id": 3497810964824022504, - "refs": {"main": {"snapshot-id": 3497810964824022504, "type": "branch"}}, - "snapshots": [ - { - "snapshot-id": 3497810964824022504, - "timestamp-ms": 1646787054459, - "summary": { - "operation": "append", - "spark.app.id": "local-1646787004168", - "added-data-files": "1", - "added-records": "1", - "added-files-size": "697", - "changed-partition-count": "1", - "total-records": "1", - "total-files-size": "697", - "total-data-files": "1", - "total-delete-files": "0", - "total-position-deletes": "0", - "total-equality-deletes": "0" - }, - "manifest-list": "s3://warehouse/database/table/metadata/snap-3497810964824022504-1-c4f68204-666b-4e50-a9df-b10c34bf6b82.avro", - "schema-id": 0 - } - ], - "snapshot-log": [{"timestamp-ms": 1646787054459, "snapshot-id": 3497810964824022504}], - "metadata-log": [ - { - "timestamp-ms": 1646787031514, - "metadata-file": "s3://warehouse/database/table/metadata/00000-88484a1c-00e5-4a07-a787-c0e7aeffa805.gz.metadata.json" - } - ] - } - }`)) - }) - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Require().NoError(err) - - tbl, err := cat.LoadTable(context.Background(), catalog.ToRestIdentifier("fokko", "table"), nil) - r.Require().NoError(err) - - r.Equal(catalog.ToRestIdentifier("rest", "fokko", "table"), tbl.Identifier()) - r.Equal("s3://warehouse/database/table/metadata/00001-5f2f8166-244c-4eae-ac36-384ecdec81fc.gz.metadata.json", tbl.MetadataLocation()) - r.EqualValues(1, tbl.Metadata().Version()) - r.Equal("b55d9dda-6561-423a-8bfc-787980ce421f", tbl.Metadata().TableUUID().String()) - r.EqualValues(1646787054459, tbl.Metadata().LastUpdatedMillis()) - r.Equal(2, 
tbl.Metadata().LastColumnID()) - r.Zero(tbl.Schema().ID) - r.Zero(tbl.Metadata().DefaultPartitionSpec()) - r.Equal(999, *tbl.Metadata().LastPartitionSpecID()) - r.Equal(table.UnsortedSortOrder, tbl.SortOrder()) - r.EqualValues(3497810964824022504, tbl.CurrentSnapshot().SnapshotID) - zero := 0 - r.True(tbl.SnapshotByName("main").Equals(table.Snapshot{ - SnapshotID: 3497810964824022504, - TimestampMs: 1646787054459, - SchemaID: &zero, - ManifestList: "s3://warehouse/database/table/metadata/snap-3497810964824022504-1-c4f68204-666b-4e50-a9df-b10c34bf6b82.avro", - Summary: &table.Summary{ - Operation: table.OpAppend, - Properties: map[string]string{ - "spark.app.id": "local-1646787004168", - "added-data-files": "1", - "added-records": "1", - "added-files-size": "697", - "changed-partition-count": "1", - "total-records": "1", - "total-files-size": "697", - "total-data-files": "1", - "total-delete-files": "0", - "total-position-deletes": "0", - "total-equality-deletes": "0", - }, - }, - })) -} - -type RestTLSCatalogSuite struct { - suite.Suite - - srv *httptest.Server - mux *http.ServeMux - - configVals url.Values -} - -func (r *RestTLSCatalogSuite) SetupTest() { - r.mux = http.NewServeMux() - - r.mux.HandleFunc("/v1/config", func(w http.ResponseWriter, req *http.Request) { - r.Require().Equal(http.MethodGet, req.Method) - r.configVals = req.URL.Query() - - json.NewEncoder(w).Encode(map[string]any{ - "defaults": map[string]any{}, - "overrides": map[string]any{}, - }) - }) - - r.srv = httptest.NewTLSServer(r.mux) -} - -func (r *RestTLSCatalogSuite) TearDownTest() { - r.srv.Close() - r.srv = nil - r.mux = nil - r.configVals = nil -} - -func (r *RestTLSCatalogSuite) TestSSLFail() { - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) - r.Nil(cat) - - r.ErrorContains(err, "tls: failed to verify certificate") -} - -func (r *RestTLSCatalogSuite) TestSSLConfig() { - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, 
catalog.WithOAuthToken(TestToken), - catalog.WithWarehouseLocation("s3://some-bucket"), - catalog.WithTLSConfig(&tls.Config{InsecureSkipVerify: true})) - r.NoError(err) - - r.NotNil(cat) - r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") -} - -func (r *RestTLSCatalogSuite) TestSSLCerts() { - certs := x509.NewCertPool() - for _, c := range r.srv.TLS.Certificates { - roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) - r.Require().NoError(err) - for _, root := range roots { - certs.AddCert(root) - } - } - - cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken), - catalog.WithWarehouseLocation("s3://some-bucket"), - catalog.WithTLSConfig(&tls.Config{RootCAs: certs})) - r.NoError(err) - - r.NotNil(cat) - r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") -} - -func TestRestCatalog(t *testing.T) { - suite.Run(t, new(RestCatalogSuite)) - suite.Run(t, new(RestTLSCatalogSuite)) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package catalog_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/catalog" + "github.com/apache/iceberg-go/table" + "github.com/stretchr/testify/suite" +) + +const ( + TestCreds = "client:secret" + TestToken = "some_jwt_token" +) + +var ( + TestHeaders = http.Header{ + "X-Client-Version": {"0.14.1"}, + "User-Agent": {"GoIceberg/(unknown version)"}, + "Authorization": {"Bearer " + TestToken}, + } + OAuthTestHeaders = http.Header{ + "Content-Type": {"application/x-www-form-urlencoded"}, + } +) + +type RestCatalogSuite struct { + suite.Suite + + srv *httptest.Server + mux *http.ServeMux + + configVals url.Values +} + +func (r *RestCatalogSuite) SetupTest() { + r.mux = http.NewServeMux() + + r.mux.HandleFunc("/v1/config", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodGet, req.Method) + r.configVals = req.URL.Query() + + json.NewEncoder(w).Encode(map[string]any{ + "defaults": map[string]any{}, + "overrides": map[string]any{}, + }) + }) + + r.srv = httptest.NewServer(r.mux) +} + +func (r *RestCatalogSuite) TearDownTest() { + r.srv.Close() + r.srv = nil + r.mux = nil + r.configVals = nil +} + +func (r *RestCatalogSuite) TestToken200() { + r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { + r.Equal(http.MethodPost, req.Method) + + r.Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") + + r.Require().NoError(req.ParseForm()) + values := req.PostForm + r.Equal(values.Get("grant_type"), "client_credentials") + r.Equal(values.Get("client_id"), "client") + r.Equal(values.Get("client_secret"), "secret") + r.Equal(values.Get("scope"), "catalog") + + w.WriteHeader(http.StatusOK) + + json.NewEncoder(w).Encode(map[string]any{ + "access_token": TestToken, + "token_type": "Bearer", + "expires_in": 86400, + "issued_token_type": 
"urn:ietf:params:oauth:token-type:access_token", + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, + catalog.WithWarehouseLocation("s3://some-bucket"), + catalog.WithCredential(TestCreds)) + r.Require().NoError(err) + + r.NotNil(cat) + r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") +} + +func (r *RestCatalogSuite) TestToken400() { + r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { + r.Equal(http.MethodPost, req.Method) + + r.Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") + + w.WriteHeader(http.StatusBadRequest) + + json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_client", + "error_description": "credentials for key invalid_key do not match", + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithCredential(TestCreds)) + r.Nil(cat) + + r.ErrorIs(err, catalog.ErrRESTError) + r.ErrorIs(err, catalog.ErrOAuthError) + r.ErrorContains(err, "invalid_client: credentials for key invalid_key do not match") +} + +func (r *RestCatalogSuite) TestToken200AuthUrl() { + r.mux.HandleFunc("/auth-token-url", func(w http.ResponseWriter, req *http.Request) { + r.Equal(http.MethodPost, req.Method) + + r.Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") + + r.Require().NoError(req.ParseForm()) + values := req.PostForm + r.Equal(values.Get("grant_type"), "client_credentials") + r.Equal(values.Get("client_id"), "client") + r.Equal(values.Get("client_secret"), "secret") + r.Equal(values.Get("scope"), "catalog") + + w.WriteHeader(http.StatusOK) + + json.NewEncoder(w).Encode(map[string]any{ + "access_token": TestToken, + "token_type": "Bearer", + "expires_in": 86400, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + }) + }) + + authUri, err := url.Parse(r.srv.URL) + r.Require().NoError(err) + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, + catalog.WithWarehouseLocation("s3://some-bucket"), + 
catalog.WithCredential(TestCreds), catalog.WithAuthURI(authUri.JoinPath("auth-token-url"))) + + r.Require().NoError(err) + + r.NotNil(cat) + r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") +} + +func (r *RestCatalogSuite) TestToken401() { + r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { + r.Equal(http.MethodPost, req.Method) + + r.Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") + + w.WriteHeader(http.StatusUnauthorized) + + json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_client", + "error_description": "credentials for key invalid_key do not match", + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithCredential(TestCreds)) + r.Nil(cat) + + r.ErrorIs(err, catalog.ErrRESTError) + r.ErrorIs(err, catalog.ErrOAuthError) + r.ErrorContains(err, "invalid_client: credentials for key invalid_key do not match") +} + +func (r *RestCatalogSuite) TestListTables200() { + namespace := "examples" + r.mux.HandleFunc("/v1/namespaces/"+namespace+"/tables", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodGet, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + json.NewEncoder(w).Encode(map[string]any{ + "identifiers": []any{ + map[string]any{ + "namespace": []string{namespace}, + "name": "fooshare", + }, + }, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + tables, err := cat.ListTables(context.Background(), catalog.ToRestIdentifier(namespace)) + r.Require().NoError(err) + r.Equal([]table.Identifier{{"examples", "fooshare"}}, tables) +} + +func (r *RestCatalogSuite) TestListTablesPrefixed200() { + r.mux.HandleFunc("/v1/oauth/tokens", func(w http.ResponseWriter, req *http.Request) { + r.Equal(http.MethodPost, req.Method) + + r.Equal(req.Header.Get("Content-Type"), "application/x-www-form-urlencoded") + + 
r.Require().NoError(req.ParseForm()) + values := req.PostForm + r.Equal(values.Get("grant_type"), "client_credentials") + r.Equal(values.Get("client_id"), "client") + r.Equal(values.Get("client_secret"), "secret") + r.Equal(values.Get("scope"), "catalog") + + w.WriteHeader(http.StatusOK) + + json.NewEncoder(w).Encode(map[string]any{ + "access_token": TestToken, + "token_type": "Bearer", + "expires_in": 86400, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + }) + }) + + namespace := "examples" + r.mux.HandleFunc("/v1/prefix/namespaces/"+namespace+"/tables", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodGet, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + json.NewEncoder(w).Encode(map[string]any{ + "identifiers": []any{ + map[string]any{ + "namespace": []string{namespace}, + "name": "fooshare", + }, + }, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, + catalog.WithPrefix("prefix"), + catalog.WithWarehouseLocation("s3://some-bucket"), + catalog.WithCredential(TestCreds)) + r.Require().NoError(err) + + r.NotNil(cat) + r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") + + tables, err := cat.ListTables(context.Background(), catalog.ToRestIdentifier(namespace)) + r.Require().NoError(err) + r.Equal([]table.Identifier{{"examples", "fooshare"}}, tables) +} + +func (r *RestCatalogSuite) TestListTables404() { + namespace := "examples" + r.mux.HandleFunc("/v1/namespaces/"+namespace+"/tables", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodGet, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": "Namespace does not exist: personal in warehouse 8bcb0838-50fc-472d-9ddb-8feb89ef5f1e", + "type": "NoSuchNamespaceException", + "code": 404, + }, + }) + }) + + cat, err 
:= catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + _, err = cat.ListTables(context.Background(), catalog.ToRestIdentifier(namespace)) + r.ErrorIs(err, catalog.ErrNoSuchNamespace) + r.ErrorContains(err, "Namespace does not exist: personal in warehouse 8bcb0838-50fc-472d-9ddb-8feb89ef5f1e") +} + +func (r *RestCatalogSuite) TestListNamespaces200() { + r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodGet, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + json.NewEncoder(w).Encode(map[string]any{ + "namespaces": []table.Identifier{ + {"default"}, {"examples"}, {"fokko"}, {"system"}, + }, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + results, err := cat.ListNamespaces(context.Background(), nil) + r.Require().NoError(err) + + r.Equal([]table.Identifier{{"default"}, {"examples"}, {"fokko"}, {"system"}}, results) +} + +func (r *RestCatalogSuite) TestListNamespaceWithParent200() { + r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodGet, req.Method) + r.Require().Equal("accounting", req.URL.Query().Get("parent")) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + json.NewEncoder(w).Encode(map[string]any{ + "namespaces": []table.Identifier{ + {"accounting", "tax"}, + }, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + results, err := cat.ListNamespaces(context.Background(), catalog.ToRestIdentifier("accounting")) + r.Require().NoError(err) + + r.Equal([]table.Identifier{{"accounting", "tax"}}, results) +} + +func (r *RestCatalogSuite) TestListNamespaces400() { + r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { + 
r.Require().Equal(http.MethodGet, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": "Namespace does not exist: personal in warehouse 8bcb0838-50fc-472d-9ddb-8feb89ef5f1e", + "type": "NoSuchNamespaceException", + "code": 404, + }, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + _, err = cat.ListNamespaces(context.Background(), catalog.ToRestIdentifier("accounting")) + r.ErrorIs(err, catalog.ErrNoSuchNamespace) + r.ErrorContains(err, "Namespace does not exist: personal in warehouse 8bcb0838-50fc-472d-9ddb-8feb89ef5f1e") +} + +func (r *RestCatalogSuite) TestCreateNamespace200() { + r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodPost, req.Method) + r.Require().Equal("application/json", req.Header.Get("Content-Type")) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + defer req.Body.Close() + dec := json.NewDecoder(req.Body) + body := struct { + Namespace table.Identifier `json:"namespace"` + Props iceberg.Properties `json:"properties"` + }{} + + r.Require().NoError(dec.Decode(&body)) + r.Equal(table.Identifier{"leden"}, body.Namespace) + r.Empty(body.Props) + + json.NewEncoder(w).Encode(map[string]any{ + "namespace": []string{"leden"}, "properties": map[string]any{}, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + r.Require().NoError(cat.CreateNamespace(context.Background(), catalog.ToRestIdentifier("leden"), nil)) +} + +func (r *RestCatalogSuite) TestCreateNamespaceWithProps200() { + r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodPost, req.Method) + 
r.Require().Equal("application/json", req.Header.Get("Content-Type")) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + defer req.Body.Close() + dec := json.NewDecoder(req.Body) + body := struct { + Namespace table.Identifier `json:"namespace"` + Props iceberg.Properties `json:"properties"` + }{} + + r.Require().NoError(dec.Decode(&body)) + r.Equal(table.Identifier{"leden"}, body.Namespace) + r.Equal(iceberg.Properties{"foo": "bar", "super": "duper"}, body.Props) + + json.NewEncoder(w).Encode(map[string]any{ + "namespace": []string{"leden"}, "properties": body.Props, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + r.Require().NoError(cat.CreateNamespace(context.Background(), catalog.ToRestIdentifier("leden"), iceberg.Properties{"foo": "bar", "super": "duper"})) +} + +func (r *RestCatalogSuite) TestCreateNamespace409() { + r.mux.HandleFunc("/v1/namespaces", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodPost, req.Method) + r.Require().Equal("application/json", req.Header.Get("Content-Type")) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + defer req.Body.Close() + dec := json.NewDecoder(req.Body) + body := struct { + Namespace table.Identifier `json:"namespace"` + Props iceberg.Properties `json:"properties"` + }{} + + r.Require().NoError(dec.Decode(&body)) + r.Equal(table.Identifier{"fokko"}, body.Namespace) + r.Empty(body.Props) + + w.WriteHeader(http.StatusConflict) + json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": "Namespace already exists: fokko in warehouse 8bcb0838-50fc-472d-9ddb-8feb89ef5f1e", + "type": "AlreadyExistsException", + "code": 409, + }, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + err = cat.CreateNamespace(context.Background(), 
catalog.ToRestIdentifier("fokko"), nil) + r.ErrorIs(err, catalog.ErrNamespaceAlreadyExists) + r.ErrorContains(err, "fokko in warehouse") +} + +func (r *RestCatalogSuite) TestDropNamespace204() { + r.mux.HandleFunc("/v1/namespaces/examples", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodDelete, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + w.WriteHeader(http.StatusNoContent) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + r.NoError(cat.DropNamespace(context.Background(), catalog.ToRestIdentifier("examples"))) +} + +func (r *RestCatalogSuite) TestDropNamespace404() { + r.mux.HandleFunc("/v1/namespaces/examples", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodDelete, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": "Namespace does not exist: examples in warehouse", + "type": "NoSuchNamespaceException", + "code": 404, + }, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + err = cat.DropNamespace(context.Background(), catalog.ToRestIdentifier("examples")) + r.ErrorIs(err, catalog.ErrNoSuchNamespace) + r.ErrorContains(err, "examples in warehouse") +} + +func (r *RestCatalogSuite) TestLoadNamespaceProps200() { + r.mux.HandleFunc("/v1/namespaces/leden", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodGet, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]any{ + "namespace": []string{"fokko"}, + "properties": map[string]any{"prop": "yes"}, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, 
catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + props, err := cat.LoadNamespaceProperties(context.Background(), catalog.ToRestIdentifier("leden")) + r.Require().NoError(err) + r.Equal(iceberg.Properties{"prop": "yes"}, props) +} + +func (r *RestCatalogSuite) TestLoadNamespaceProps404() { + r.mux.HandleFunc("/v1/namespaces/leden", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodGet, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": "Namespace does not exist: fokko22 in warehouse", + "type": "NoSuchNamespaceException", + "code": 404, + }, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + _, err = cat.LoadNamespaceProperties(context.Background(), catalog.ToRestIdentifier("leden")) + r.ErrorIs(err, catalog.ErrNoSuchNamespace) + r.ErrorContains(err, "Namespace does not exist: fokko22 in warehouse") +} + +func (r *RestCatalogSuite) TestUpdateNamespaceProps200() { + r.mux.HandleFunc("/v1/namespaces/fokko/properties", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodPost, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + json.NewEncoder(w).Encode(map[string]any{ + "removed": []string{}, + "updated": []string{"prop"}, + "missing": []string{"abc"}, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + summary, err := cat.UpdateNamespaceProperties(context.Background(), table.Identifier([]string{"fokko"}), + []string{"abc"}, iceberg.Properties{"prop": "yes"}) + r.Require().NoError(err) + + r.Equal(catalog.PropertiesUpdateSummary{ + Removed: []string{}, + Updated: []string{"prop"}, + Missing: []string{"abc"}, + }, summary) +} + +func (r 
*RestCatalogSuite) TestUpdateNamespaceProps404() { + r.mux.HandleFunc("/v1/namespaces/fokko/properties", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodPost, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "message": "Namespace does not exist: does_not_exist in warehouse", + "type": "NoSuchNamespaceException", + "code": 404, + }, + }) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + _, err = cat.UpdateNamespaceProperties(context.Background(), + table.Identifier{"fokko"}, []string{"abc"}, iceberg.Properties{"prop": "yes"}) + r.ErrorIs(err, catalog.ErrNoSuchNamespace) + r.ErrorContains(err, "Namespace does not exist: does_not_exist in warehouse") +} + +func (r *RestCatalogSuite) TestLoadTable200() { + r.mux.HandleFunc("/v1/namespaces/fokko/tables/table", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodGet, req.Method) + + for k, v := range TestHeaders { + r.Equal(v, req.Header.Values(k)) + } + + w.Write([]byte(`{ + "metadata-location": "s3://warehouse/database/table/metadata/00001-5f2f8166-244c-4eae-ac36-384ecdec81fc.gz.metadata.json", + "metadata": { + "format-version": 1, + "table-uuid": "b55d9dda-6561-423a-8bfc-787980ce421f", + "location": "s3://warehouse/database/table", + "last-updated-ms": 1646787054459, + "last-column-id": 2, + "schema": { + "type": "struct", + "schema-id": 0, + "fields": [ + {"id": 1, "name": "id", "required": false, "type": "int"}, + {"id": 2, "name": "data", "required": false, "type": "string"} + ] + }, + "current-schema-id": 0, + "schemas": [ + { + "type": "struct", + "schema-id": 0, + "fields": [ + {"id": 1, "name": "id", "required": false, "type": "int"}, + {"id": 2, "name": "data", "required": false, "type": "string"} + ] + } + ], + 
"partition-spec": [], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": []}], + "last-partition-id": 999, + "default-sort-order-id": 0, + "sort-orders": [{"order-id": 0, "fields": []}], + "properties": {"owner": "bryan", "write.metadata.compression-codec": "gzip"}, + "current-snapshot-id": 3497810964824022504, + "refs": {"main": {"snapshot-id": 3497810964824022504, "type": "branch"}}, + "snapshots": [ + { + "snapshot-id": 3497810964824022504, + "timestamp-ms": 1646787054459, + "summary": { + "operation": "append", + "spark.app.id": "local-1646787004168", + "added-data-files": "1", + "added-records": "1", + "added-files-size": "697", + "changed-partition-count": "1", + "total-records": "1", + "total-files-size": "697", + "total-data-files": "1", + "total-delete-files": "0", + "total-position-deletes": "0", + "total-equality-deletes": "0" + }, + "manifest-list": "s3://warehouse/database/table/metadata/snap-3497810964824022504-1-c4f68204-666b-4e50-a9df-b10c34bf6b82.avro", + "schema-id": 0 + } + ], + "snapshot-log": [{"timestamp-ms": 1646787054459, "snapshot-id": 3497810964824022504}], + "metadata-log": [ + { + "timestamp-ms": 1646787031514, + "metadata-file": "s3://warehouse/database/table/metadata/00000-88484a1c-00e5-4a07-a787-c0e7aeffa805.gz.metadata.json" + } + ] + } + }`)) + }) + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Require().NoError(err) + + tbl, err := cat.LoadTable(context.Background(), catalog.ToRestIdentifier("fokko", "table"), nil) + r.Require().NoError(err) + + r.Equal(catalog.ToRestIdentifier("rest", "fokko", "table"), tbl.Identifier()) + r.Equal("s3://warehouse/database/table/metadata/00001-5f2f8166-244c-4eae-ac36-384ecdec81fc.gz.metadata.json", tbl.MetadataLocation()) + r.EqualValues(1, tbl.Metadata().Version()) + r.Equal("b55d9dda-6561-423a-8bfc-787980ce421f", tbl.Metadata().TableUUID().String()) + r.EqualValues(1646787054459, tbl.Metadata().LastUpdatedMillis()) + r.Equal(2, 
tbl.Metadata().LastColumnID()) + r.Zero(tbl.Schema().ID) + r.Zero(tbl.Metadata().DefaultPartitionSpec()) + r.Equal(999, *tbl.Metadata().LastPartitionSpecID()) + r.Equal(table.UnsortedSortOrder, tbl.SortOrder()) + r.EqualValues(3497810964824022504, tbl.CurrentSnapshot().SnapshotID) + zero := 0 + r.True(tbl.SnapshotByName("main").Equals(table.Snapshot{ + SnapshotID: 3497810964824022504, + TimestampMs: 1646787054459, + SchemaID: &zero, + ManifestList: "s3://warehouse/database/table/metadata/snap-3497810964824022504-1-c4f68204-666b-4e50-a9df-b10c34bf6b82.avro", + Summary: &table.Summary{ + Operation: table.OpAppend, + Properties: map[string]string{ + "spark.app.id": "local-1646787004168", + "added-data-files": "1", + "added-records": "1", + "added-files-size": "697", + "changed-partition-count": "1", + "total-records": "1", + "total-files-size": "697", + "total-data-files": "1", + "total-delete-files": "0", + "total-position-deletes": "0", + "total-equality-deletes": "0", + }, + }, + })) +} + +type RestTLSCatalogSuite struct { + suite.Suite + + srv *httptest.Server + mux *http.ServeMux + + configVals url.Values +} + +func (r *RestTLSCatalogSuite) SetupTest() { + r.mux = http.NewServeMux() + + r.mux.HandleFunc("/v1/config", func(w http.ResponseWriter, req *http.Request) { + r.Require().Equal(http.MethodGet, req.Method) + r.configVals = req.URL.Query() + + json.NewEncoder(w).Encode(map[string]any{ + "defaults": map[string]any{}, + "overrides": map[string]any{}, + }) + }) + + r.srv = httptest.NewTLSServer(r.mux) +} + +func (r *RestTLSCatalogSuite) TearDownTest() { + r.srv.Close() + r.srv = nil + r.mux = nil + r.configVals = nil +} + +func (r *RestTLSCatalogSuite) TestSSLFail() { + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken)) + r.Nil(cat) + + r.ErrorContains(err, "tls: failed to verify certificate") +} + +func (r *RestTLSCatalogSuite) TestSSLConfig() { + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, 
catalog.WithOAuthToken(TestToken), + catalog.WithWarehouseLocation("s3://some-bucket"), + catalog.WithTLSConfig(&tls.Config{InsecureSkipVerify: true})) + r.NoError(err) + + r.NotNil(cat) + r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") +} + +func (r *RestTLSCatalogSuite) TestSSLCerts() { + certs := x509.NewCertPool() + for _, c := range r.srv.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + r.Require().NoError(err) + for _, root := range roots { + certs.AddCert(root) + } + } + + cat, err := catalog.NewRestCatalog("rest", r.srv.URL, catalog.WithOAuthToken(TestToken), + catalog.WithWarehouseLocation("s3://some-bucket"), + catalog.WithTLSConfig(&tls.Config{RootCAs: certs})) + r.NoError(err) + + r.NotNil(cat) + r.Equal(r.configVals.Get("warehouse"), "s3://some-bucket") +} + +func TestRestCatalog(t *testing.T) { + suite.Run(t, new(RestCatalogSuite)) + suite.Run(t, new(RestTLSCatalogSuite)) +} diff --git a/cmd/iceberg/main.go b/cmd/iceberg/main.go index fb25618..155fde4 100644 --- a/cmd/iceberg/main.go +++ b/cmd/iceberg/main.go @@ -1,343 +1,343 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package main - -import ( - "context" - "errors" - "fmt" - "log" - "os" - "strings" - - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/catalog" - "github.com/apache/iceberg-go/table" - "github.com/docopt/docopt-go" -) - -const usage = `iceberg. - -Usage: - iceberg list [options] [PARENT] - iceberg describe [options] [namespace | table] IDENTIFIER - iceberg (schema | spec | uuid | location) [options] TABLE_ID - iceberg drop [options] (namespace | table) IDENTIFIER - iceberg files [options] TABLE_ID [--history] - iceberg rename [options] - iceberg properties [options] get (namespace | table) IDENTIFIER [PROPNAME] - iceberg properties [options] set (namespace | table) IDENTIFIER PROPNAME VALUE - iceberg properties [options] remove (namespace | table) IDENTIFIER PROPNAME - iceberg -h | --help | --version - -Arguments: - PARENT Catalog parent namespace - IDENTIFIER fully qualified namespace or table - TABLE_ID full path to a table - PROPNAME name of a property - VALUE value to set - -Options: - -h --help show this helpe messages and exit - --catalog TEXT specify the catalog type [default: rest] - --uri TEXT specify the catalog URI - --output TYPE output type (json/text) [default: text] - --credential TEXT specify credentials for the catalog - --warehouse TEXT specify the warehouse to use` - -func main() { - args, err := docopt.ParseArgs(usage, os.Args[1:], iceberg.Version()) - if err != nil { - log.Fatal(err) - } - - cfg := struct { - List bool `docopt:"list"` - Describe bool `docopt:"describe"` - Schema bool `docopt:"schema"` - Spec bool `docopt:"spec"` - Uuid bool `docopt:"uuid"` - Location bool `docopt:"location"` - Props bool `docopt:"properties"` - Drop bool `docopt:"drop"` - Files bool `docopt:"files"` - Rename bool `docopt:"rename"` - - Get bool `docopt:"get"` - Set bool `docopt:"set"` - Remove bool `docopt:"remove"` - - Namespace bool `docopt:"namespace"` - Table bool `docopt:"table"` - - RenameFrom string `docopt:""` - RenameTo string 
`docopt:""` - - Parent string `docopt:"PARENT"` - Ident string `docopt:"IDENTIFIER"` - TableID string `docopt:"TABLE_ID"` - PropName string `docopt:"PROPNAME"` - Value string `docopt:"VALUE"` - - Catalog string `docopt:"--catalog"` - URI string `docopt:"--uri"` - Output string `docopt:"--output"` - History bool `docopt:"--history"` - Cred string `docopt:"--credential"` - Warehouse string `docopt:"--warehouse"` - }{} - - if err := args.Bind(&cfg); err != nil { - log.Fatal(err) - } - - var output Output - switch strings.ToLower(cfg.Output) { - case "text": - output = text{} - case "json": - fallthrough - default: - log.Fatal("unimplemented output type") - } - - var cat catalog.Catalog - switch catalog.CatalogType(cfg.Catalog) { - case catalog.REST: - opts := []catalog.Option[catalog.RestCatalog]{} - if len(cfg.Cred) > 0 { - opts = append(opts, catalog.WithCredential(cfg.Cred)) - } - - if len(cfg.Warehouse) > 0 { - opts = append(opts, catalog.WithWarehouseLocation(cfg.Warehouse)) - } - - if cat, err = catalog.NewRestCatalog("rest", cfg.URI, opts...); err != nil { - log.Fatal(err) - } - default: - log.Fatal("unrecognized catalog type") - } - - switch { - case cfg.List: - list(output, cat, cfg.Parent) - case cfg.Describe: - entityType := "any" - if cfg.Namespace { - entityType = "ns" - } else if cfg.Table { - entityType = "tbl" - } - - describe(output, cat, cfg.Ident, entityType) - case cfg.Schema: - tbl := loadTable(output, cat, cfg.TableID) - output.Schema(tbl.Schema()) - case cfg.Spec: - tbl := loadTable(output, cat, cfg.TableID) - output.Spec(tbl.Spec()) - case cfg.Location: - tbl := loadTable(output, cat, cfg.TableID) - output.Text(tbl.Location()) - case cfg.Uuid: - tbl := loadTable(output, cat, cfg.TableID) - output.Uuid(tbl.Metadata().TableUUID()) - case cfg.Props: - properties(output, cat, propCmd{ - get: cfg.Get, set: cfg.Set, remove: cfg.Remove, - namespace: cfg.Namespace, table: cfg.Table, - identifier: cfg.Ident, - propname: cfg.PropName, - value: cfg.Value, 
- }) - case cfg.Rename: - _, err := cat.RenameTable(context.Background(), - catalog.ToRestIdentifier(cfg.RenameFrom), catalog.ToRestIdentifier(cfg.RenameTo)) - if err != nil { - output.Error(err) - os.Exit(1) - } - - output.Text("Renamed table from " + cfg.RenameFrom + " to " + cfg.RenameTo) - case cfg.Drop: - switch { - case cfg.Namespace: - err := cat.DropNamespace(context.Background(), catalog.ToRestIdentifier(cfg.Ident)) - if err != nil { - output.Error(err) - os.Exit(1) - } - case cfg.Table: - err := cat.DropTable(context.Background(), catalog.ToRestIdentifier(cfg.Ident)) - if err != nil { - output.Error(err) - os.Exit(1) - } - } - case cfg.Files: - tbl := loadTable(output, cat, cfg.TableID) - output.Files(tbl, cfg.History) - } -} - -func list(output Output, cat catalog.Catalog, parent string) { - prnt := catalog.ToRestIdentifier(parent) - - ids, err := cat.ListNamespaces(context.Background(), prnt) - if err != nil { - output.Error(err) - os.Exit(1) - } - - if len(ids) == 0 && parent != "" { - ids, err = cat.ListTables(context.Background(), prnt) - if err != nil { - output.Error(err) - os.Exit(1) - } - } - output.Identifiers(ids) -} - -func describe(output Output, cat catalog.Catalog, id string, entityType string) { - ctx := context.Background() - - ident := catalog.ToRestIdentifier(id) - - isNS, isTbl := false, false - if (entityType == "any" || entityType == "ns") && len(ident) > 0 { - nsprops, err := cat.LoadNamespaceProperties(ctx, ident) - if err != nil { - if errors.Is(err, catalog.ErrNoSuchNamespace) { - if entityType != "any" || len(ident) == 1 { - output.Error(err) - os.Exit(1) - } - } else { - output.Error(err) - os.Exit(1) - } - } else { - isNS = true - output.DescribeProperties(nsprops) - } - } - - if (entityType == "any" || entityType == "tbl") && len(ident) > 1 { - tbl, err := cat.LoadTable(ctx, ident, nil) - if err != nil { - if !errors.Is(err, catalog.ErrNoSuchTable) || entityType != "any" { - output.Error(err) - os.Exit(1) - } - } else { - 
isTbl = true - output.DescribeTable(tbl) - } - } - - if !isNS && !isTbl { - output.Error(fmt.Errorf("%w: table or namespace does not exist: %s", - catalog.ErrNoSuchNamespace, ident)) - os.Exit(1) - } -} - -func loadTable(output Output, cat catalog.Catalog, id string) *table.Table { - tbl, err := cat.LoadTable(context.Background(), catalog.ToRestIdentifier(id), nil) - if err != nil { - output.Error(err) - os.Exit(1) - } - - return tbl -} - -type propCmd struct { - get, set, remove bool - namespace, table bool - - identifier, propname, value string -} - -func properties(output Output, cat catalog.Catalog, args propCmd) { - ctx, ident := context.Background(), catalog.ToRestIdentifier(args.identifier) - - switch { - case args.get: - var props iceberg.Properties - switch { - case args.namespace: - var err error - props, err = cat.LoadNamespaceProperties(ctx, ident) - if err != nil { - output.Error(err) - os.Exit(1) - } - case args.table: - tbl := loadTable(output, cat, args.identifier) - props = tbl.Metadata().Properties() - } - - if args.propname == "" { - output.DescribeProperties(props) - return - } - - if val, ok := props[args.propname]; ok { - output.Text(val) - } else { - output.Error(errors.New("could not find property " + args.propname + " on namespace " + args.identifier)) - os.Exit(1) - } - case args.set: - switch { - case args.namespace: - _, err := cat.UpdateNamespaceProperties(ctx, ident, - nil, iceberg.Properties{args.propname: args.value}) - if err != nil { - output.Error(err) - os.Exit(1) - } - - output.Text("updated " + args.propname + " on " + args.identifier) - case args.table: - loadTable(output, cat, args.identifier) - output.Text("Setting " + args.propname + "=" + args.value + " on " + args.identifier) - output.Error(errors.New("not implemented: Writing is WIP")) - } - case args.remove: - switch { - case args.namespace: - _, err := cat.UpdateNamespaceProperties(ctx, ident, - []string{args.propname}, nil) - if err != nil { - output.Error(err) - 
os.Exit(1) - } - - output.Text("removing " + args.propname + " from " + args.identifier) - case args.table: - loadTable(output, cat, args.identifier) - output.Text("Setting " + args.propname + "=" + args.value + " on " + args.identifier) - output.Error(errors.New("not implemented: Writing is WIP")) - } - } -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package main + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "strings" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/catalog" + "github.com/apache/iceberg-go/table" + "github.com/docopt/docopt-go" +) + +const usage = `iceberg. 
+ +Usage: + iceberg list [options] [PARENT] + iceberg describe [options] [namespace | table] IDENTIFIER + iceberg (schema | spec | uuid | location) [options] TABLE_ID + iceberg drop [options] (namespace | table) IDENTIFIER + iceberg files [options] TABLE_ID [--history] + iceberg rename [options] + iceberg properties [options] get (namespace | table) IDENTIFIER [PROPNAME] + iceberg properties [options] set (namespace | table) IDENTIFIER PROPNAME VALUE + iceberg properties [options] remove (namespace | table) IDENTIFIER PROPNAME + iceberg -h | --help | --version + +Arguments: + PARENT Catalog parent namespace + IDENTIFIER fully qualified namespace or table + TABLE_ID full path to a table + PROPNAME name of a property + VALUE value to set + +Options: + -h --help show this helpe messages and exit + --catalog TEXT specify the catalog type [default: rest] + --uri TEXT specify the catalog URI + --output TYPE output type (json/text) [default: text] + --credential TEXT specify credentials for the catalog + --warehouse TEXT specify the warehouse to use` + +func main() { + args, err := docopt.ParseArgs(usage, os.Args[1:], iceberg.Version()) + if err != nil { + log.Fatal(err) + } + + cfg := struct { + List bool `docopt:"list"` + Describe bool `docopt:"describe"` + Schema bool `docopt:"schema"` + Spec bool `docopt:"spec"` + Uuid bool `docopt:"uuid"` + Location bool `docopt:"location"` + Props bool `docopt:"properties"` + Drop bool `docopt:"drop"` + Files bool `docopt:"files"` + Rename bool `docopt:"rename"` + + Get bool `docopt:"get"` + Set bool `docopt:"set"` + Remove bool `docopt:"remove"` + + Namespace bool `docopt:"namespace"` + Table bool `docopt:"table"` + + RenameFrom string `docopt:""` + RenameTo string `docopt:""` + + Parent string `docopt:"PARENT"` + Ident string `docopt:"IDENTIFIER"` + TableID string `docopt:"TABLE_ID"` + PropName string `docopt:"PROPNAME"` + Value string `docopt:"VALUE"` + + Catalog string `docopt:"--catalog"` + URI string `docopt:"--uri"` + Output 
string `docopt:"--output"` + History bool `docopt:"--history"` + Cred string `docopt:"--credential"` + Warehouse string `docopt:"--warehouse"` + }{} + + if err := args.Bind(&cfg); err != nil { + log.Fatal(err) + } + + var output Output + switch strings.ToLower(cfg.Output) { + case "text": + output = text{} + case "json": + fallthrough + default: + log.Fatal("unimplemented output type") + } + + var cat catalog.Catalog + switch catalog.CatalogType(cfg.Catalog) { + case catalog.REST: + opts := []catalog.Option[catalog.RestCatalog]{} + if len(cfg.Cred) > 0 { + opts = append(opts, catalog.WithCredential(cfg.Cred)) + } + + if len(cfg.Warehouse) > 0 { + opts = append(opts, catalog.WithWarehouseLocation(cfg.Warehouse)) + } + + if cat, err = catalog.NewRestCatalog("rest", cfg.URI, opts...); err != nil { + log.Fatal(err) + } + default: + log.Fatal("unrecognized catalog type") + } + + switch { + case cfg.List: + list(output, cat, cfg.Parent) + case cfg.Describe: + entityType := "any" + if cfg.Namespace { + entityType = "ns" + } else if cfg.Table { + entityType = "tbl" + } + + describe(output, cat, cfg.Ident, entityType) + case cfg.Schema: + tbl := loadTable(output, cat, cfg.TableID) + output.Schema(tbl.Schema()) + case cfg.Spec: + tbl := loadTable(output, cat, cfg.TableID) + output.Spec(tbl.Spec()) + case cfg.Location: + tbl := loadTable(output, cat, cfg.TableID) + output.Text(tbl.Location()) + case cfg.Uuid: + tbl := loadTable(output, cat, cfg.TableID) + output.Uuid(tbl.Metadata().TableUUID()) + case cfg.Props: + properties(output, cat, propCmd{ + get: cfg.Get, set: cfg.Set, remove: cfg.Remove, + namespace: cfg.Namespace, table: cfg.Table, + identifier: cfg.Ident, + propname: cfg.PropName, + value: cfg.Value, + }) + case cfg.Rename: + _, err := cat.RenameTable(context.Background(), + catalog.ToRestIdentifier(cfg.RenameFrom), catalog.ToRestIdentifier(cfg.RenameTo)) + if err != nil { + output.Error(err) + os.Exit(1) + } + + output.Text("Renamed table from " + cfg.RenameFrom + 
" to " + cfg.RenameTo) + case cfg.Drop: + switch { + case cfg.Namespace: + err := cat.DropNamespace(context.Background(), catalog.ToRestIdentifier(cfg.Ident)) + if err != nil { + output.Error(err) + os.Exit(1) + } + case cfg.Table: + err := cat.DropTable(context.Background(), catalog.ToRestIdentifier(cfg.Ident)) + if err != nil { + output.Error(err) + os.Exit(1) + } + } + case cfg.Files: + tbl := loadTable(output, cat, cfg.TableID) + output.Files(tbl, cfg.History) + } +} + +func list(output Output, cat catalog.Catalog, parent string) { + prnt := catalog.ToRestIdentifier(parent) + + ids, err := cat.ListNamespaces(context.Background(), prnt) + if err != nil { + output.Error(err) + os.Exit(1) + } + + if len(ids) == 0 && parent != "" { + ids, err = cat.ListTables(context.Background(), prnt) + if err != nil { + output.Error(err) + os.Exit(1) + } + } + output.Identifiers(ids) +} + +func describe(output Output, cat catalog.Catalog, id string, entityType string) { + ctx := context.Background() + + ident := catalog.ToRestIdentifier(id) + + isNS, isTbl := false, false + if (entityType == "any" || entityType == "ns") && len(ident) > 0 { + nsprops, err := cat.LoadNamespaceProperties(ctx, ident) + if err != nil { + if errors.Is(err, catalog.ErrNoSuchNamespace) { + if entityType != "any" || len(ident) == 1 { + output.Error(err) + os.Exit(1) + } + } else { + output.Error(err) + os.Exit(1) + } + } else { + isNS = true + output.DescribeProperties(nsprops) + } + } + + if (entityType == "any" || entityType == "tbl") && len(ident) > 1 { + tbl, err := cat.LoadTable(ctx, ident, nil) + if err != nil { + if !errors.Is(err, catalog.ErrNoSuchTable) || entityType != "any" { + output.Error(err) + os.Exit(1) + } + } else { + isTbl = true + output.DescribeTable(tbl) + } + } + + if !isNS && !isTbl { + output.Error(fmt.Errorf("%w: table or namespace does not exist: %s", + catalog.ErrNoSuchNamespace, ident)) + os.Exit(1) + } +} + +func loadTable(output Output, cat catalog.Catalog, id string) 
*table.Table { + tbl, err := cat.LoadTable(context.Background(), catalog.ToRestIdentifier(id), nil) + if err != nil { + output.Error(err) + os.Exit(1) + } + + return tbl +} + +type propCmd struct { + get, set, remove bool + namespace, table bool + + identifier, propname, value string +} + +func properties(output Output, cat catalog.Catalog, args propCmd) { + ctx, ident := context.Background(), catalog.ToRestIdentifier(args.identifier) + + switch { + case args.get: + var props iceberg.Properties + switch { + case args.namespace: + var err error + props, err = cat.LoadNamespaceProperties(ctx, ident) + if err != nil { + output.Error(err) + os.Exit(1) + } + case args.table: + tbl := loadTable(output, cat, args.identifier) + props = tbl.Metadata().Properties() + } + + if args.propname == "" { + output.DescribeProperties(props) + return + } + + if val, ok := props[args.propname]; ok { + output.Text(val) + } else { + output.Error(errors.New("could not find property " + args.propname + " on namespace " + args.identifier)) + os.Exit(1) + } + case args.set: + switch { + case args.namespace: + _, err := cat.UpdateNamespaceProperties(ctx, ident, + nil, iceberg.Properties{args.propname: args.value}) + if err != nil { + output.Error(err) + os.Exit(1) + } + + output.Text("updated " + args.propname + " on " + args.identifier) + case args.table: + loadTable(output, cat, args.identifier) + output.Text("Setting " + args.propname + "=" + args.value + " on " + args.identifier) + output.Error(errors.New("not implemented: Writing is WIP")) + } + case args.remove: + switch { + case args.namespace: + _, err := cat.UpdateNamespaceProperties(ctx, ident, + []string{args.propname}, nil) + if err != nil { + output.Error(err) + os.Exit(1) + } + + output.Text("removing " + args.propname + " from " + args.identifier) + case args.table: + loadTable(output, cat, args.identifier) + output.Text("Setting " + args.propname + "=" + args.value + " on " + args.identifier) + output.Error(errors.New("not 
implemented: Writing is WIP")) + } + } +} diff --git a/cmd/iceberg/output.go b/cmd/iceberg/output.go index 03ad798..6f26f1d 100644 --- a/cmd/iceberg/output.go +++ b/cmd/iceberg/output.go @@ -1,220 +1,220 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package main - -import ( - "fmt" - "log" - "os" - "strconv" - "strings" - - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/table" - - "github.com/google/uuid" - "github.com/pterm/pterm" - "github.com/pterm/pterm/putils" -) - -type Output interface { - Identifiers([]table.Identifier) - DescribeTable(*table.Table) - Files(tbl *table.Table, history bool) - DescribeProperties(iceberg.Properties) - Text(string) - Schema(*iceberg.Schema) - Spec(iceberg.PartitionSpec) - Uuid(uuid.UUID) - Error(error) -} - -type text struct{} - -func (text) Identifiers(idlist []table.Identifier) { - data := pterm.TableData{[]string{"IDs"}} - for _, ids := range idlist { - data = append(data, []string{strings.Join(ids, ".")}) - } - - pterm.DefaultTable. - WithBoxed(true). - WithHasHeader(true). - WithHeaderRowSeparator("-"). 
- WithData(data).Render() -} - -func (t text) DescribeTable(tbl *table.Table) { - propData := pterm.TableData{{"key", "value"}} - for k, v := range tbl.Metadata().Properties() { - propData = append(propData, []string{k, v}) - } - propTable := pterm.DefaultTable. - WithHasHeader(true). - WithHeaderRowSeparator("-"). - WithData(propData) - - snapshotList := pterm.LeveledList{} - for _, s := range tbl.Metadata().Snapshots() { - var manifest string - if s.ManifestList != "" { - manifest = ": " + s.ManifestList - } - - snapshotList = append(snapshotList, pterm.LeveledListItem{ - Level: 0, Text: fmt.Sprintf("Snapshot %d, schema %d%s", - s.SnapshotID, *s.SchemaID, manifest), - }) - } - - snapshotTreeNode := putils.TreeFromLeveledList(snapshotList) - snapshotTreeNode.Text = "Snapshots" - - pterm.DefaultTable. - WithData(pterm.TableData{ - {"Table format version", strconv.Itoa(tbl.Metadata().Version())}, - {"Metadata location", tbl.MetadataLocation()}, - {"Table UUID", tbl.Metadata().TableUUID().String()}, - {"Last updated", strconv.Itoa(int(tbl.Metadata().LastUpdatedMillis()))}, - {"Sort Order", tbl.SortOrder().String()}, - {"Partition Spec", tbl.Spec().String()}, - }).Render() - - t.Schema(tbl.Schema()) - snap := "" - if tbl.CurrentSnapshot() != nil { - snap = tbl.CurrentSnapshot().String() - } - pterm.DefaultTable. 
- WithData(pterm.TableData{ - {"Current Snapshot", snap}, - }).Render() - pterm.DefaultTree.WithRoot(snapshotTreeNode).Render() - pterm.Println("Properties") - propTable.Render() -} - -func (t text) Files(tbl *table.Table, history bool) { - var snapshots []table.Snapshot - if history { - snapshots = tbl.Metadata().Snapshots() - } else { - snap := tbl.CurrentSnapshot() - if snap != nil { - snapshots = []table.Snapshot{*snap} - } - } - - snapshotTree := pterm.LeveledList{} - for _, snap := range snapshots { - manifest := snap.ManifestList - if manifest != "" { - manifest = ": " + manifest - } - - snapshotTree = append(snapshotTree, pterm.LeveledListItem{ - Level: 0, - Text: fmt.Sprintf("Snapshot %d, schema %d%s", - snap.SnapshotID, *snap.SchemaID, manifest), - }) - - manifestList, err := snap.Manifests(tbl.FS()) - if err != nil { - t.Error(err) - os.Exit(1) - } - - for _, m := range manifestList { - snapshotTree = append(snapshotTree, pterm.LeveledListItem{ - Level: 1, Text: "Manifest: " + m.FilePath(), - }) - datafiles, err := m.FetchEntries(tbl.FS(), false) - if err != nil { - t.Error(err) - os.Exit(1) - } - for _, e := range datafiles { - snapshotTree = append(snapshotTree, pterm.LeveledListItem{ - Level: 2, Text: "Datafile: " + e.DataFile().FilePath(), - }) - } - } - } - - node := putils.TreeFromLeveledList(snapshotTree) - node.Text = "Snapshots: " + strings.Join(tbl.Identifier(), ".") - pterm.DefaultTree.WithRoot(node).Render() -} - -func (text) DescribeProperties(props iceberg.Properties) { - data := pterm.TableData{[]string{"Key", "Value"}} - for k, v := range props { - data = append(data, []string{k, v}) - } - - pterm.DefaultTable. - WithBoxed(true). - WithHasHeader(true). - WithHeaderRowSeparator("-"). 
- WithData(data).Render() -} - -func (text) Text(val string) { - fmt.Println(val) -} - -func (text) Schema(schema *iceberg.Schema) { - schemaTree := pterm.LeveledList{} - var addChildren func(iceberg.NestedField, int) - addChildren = func(nf iceberg.NestedField, depth int) { - if nested, ok := nf.Type.(iceberg.NestedType); ok { - for _, n := range nested.Fields() { - schemaTree = append(schemaTree, pterm.LeveledListItem{ - Level: depth, Text: n.String(), - }) - addChildren(n, depth+1) - } - } - } - - for _, f := range schema.Fields() { - schemaTree = append(schemaTree, pterm.LeveledListItem{ - Level: 0, Text: f.String(), - }) - addChildren(f, 1) - } - schemaTreeNode := putils.TreeFromLeveledList(schemaTree) - schemaTreeNode.Text = "Current Schema, id=" + strconv.Itoa(schema.ID) - pterm.DefaultTree.WithRoot(schemaTreeNode).Render() -} - -func (text) Spec(spec iceberg.PartitionSpec) { - fmt.Println(spec) -} - -func (text) Uuid(u uuid.UUID) { - if u.String() != "" { - fmt.Println(u.String()) - } else { - fmt.Println("missing") - } -} - -func (text) Error(err error) { - log.Fatal(err) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package main + +import ( + "fmt" + "log" + "os" + "strconv" + "strings" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/table" + + "github.com/google/uuid" + "github.com/pterm/pterm" + "github.com/pterm/pterm/putils" +) + +type Output interface { + Identifiers([]table.Identifier) + DescribeTable(*table.Table) + Files(tbl *table.Table, history bool) + DescribeProperties(iceberg.Properties) + Text(string) + Schema(*iceberg.Schema) + Spec(iceberg.PartitionSpec) + Uuid(uuid.UUID) + Error(error) +} + +type text struct{} + +func (text) Identifiers(idlist []table.Identifier) { + data := pterm.TableData{[]string{"IDs"}} + for _, ids := range idlist { + data = append(data, []string{strings.Join(ids, ".")}) + } + + pterm.DefaultTable. + WithBoxed(true). + WithHasHeader(true). + WithHeaderRowSeparator("-"). + WithData(data).Render() +} + +func (t text) DescribeTable(tbl *table.Table) { + propData := pterm.TableData{{"key", "value"}} + for k, v := range tbl.Metadata().Properties() { + propData = append(propData, []string{k, v}) + } + propTable := pterm.DefaultTable. + WithHasHeader(true). + WithHeaderRowSeparator("-"). + WithData(propData) + + snapshotList := pterm.LeveledList{} + for _, s := range tbl.Metadata().Snapshots() { + var manifest string + if s.ManifestList != "" { + manifest = ": " + s.ManifestList + } + + snapshotList = append(snapshotList, pterm.LeveledListItem{ + Level: 0, Text: fmt.Sprintf("Snapshot %d, schema %d%s", + s.SnapshotID, *s.SchemaID, manifest), + }) + } + + snapshotTreeNode := putils.TreeFromLeveledList(snapshotList) + snapshotTreeNode.Text = "Snapshots" + + pterm.DefaultTable. 
+ WithData(pterm.TableData{ + {"Table format version", strconv.Itoa(tbl.Metadata().Version())}, + {"Metadata location", tbl.MetadataLocation()}, + {"Table UUID", tbl.Metadata().TableUUID().String()}, + {"Last updated", strconv.Itoa(int(tbl.Metadata().LastUpdatedMillis()))}, + {"Sort Order", tbl.SortOrder().String()}, + {"Partition Spec", tbl.Spec().String()}, + }).Render() + + t.Schema(tbl.Schema()) + snap := "" + if tbl.CurrentSnapshot() != nil { + snap = tbl.CurrentSnapshot().String() + } + pterm.DefaultTable. + WithData(pterm.TableData{ + {"Current Snapshot", snap}, + }).Render() + pterm.DefaultTree.WithRoot(snapshotTreeNode).Render() + pterm.Println("Properties") + propTable.Render() +} + +func (t text) Files(tbl *table.Table, history bool) { + var snapshots []table.Snapshot + if history { + snapshots = tbl.Metadata().Snapshots() + } else { + snap := tbl.CurrentSnapshot() + if snap != nil { + snapshots = []table.Snapshot{*snap} + } + } + + snapshotTree := pterm.LeveledList{} + for _, snap := range snapshots { + manifest := snap.ManifestList + if manifest != "" { + manifest = ": " + manifest + } + + snapshotTree = append(snapshotTree, pterm.LeveledListItem{ + Level: 0, + Text: fmt.Sprintf("Snapshot %d, schema %d%s", + snap.SnapshotID, *snap.SchemaID, manifest), + }) + + manifestList, err := snap.Manifests(tbl.FS()) + if err != nil { + t.Error(err) + os.Exit(1) + } + + for _, m := range manifestList { + snapshotTree = append(snapshotTree, pterm.LeveledListItem{ + Level: 1, Text: "Manifest: " + m.FilePath(), + }) + datafiles, err := m.FetchEntries(tbl.FS(), false) + if err != nil { + t.Error(err) + os.Exit(1) + } + for _, e := range datafiles { + snapshotTree = append(snapshotTree, pterm.LeveledListItem{ + Level: 2, Text: "Datafile: " + e.DataFile().FilePath(), + }) + } + } + } + + node := putils.TreeFromLeveledList(snapshotTree) + node.Text = "Snapshots: " + strings.Join(tbl.Identifier(), ".") + pterm.DefaultTree.WithRoot(node).Render() +} + +func (text) 
DescribeProperties(props iceberg.Properties) { + data := pterm.TableData{[]string{"Key", "Value"}} + for k, v := range props { + data = append(data, []string{k, v}) + } + + pterm.DefaultTable. + WithBoxed(true). + WithHasHeader(true). + WithHeaderRowSeparator("-"). + WithData(data).Render() +} + +func (text) Text(val string) { + fmt.Println(val) +} + +func (text) Schema(schema *iceberg.Schema) { + schemaTree := pterm.LeveledList{} + var addChildren func(iceberg.NestedField, int) + addChildren = func(nf iceberg.NestedField, depth int) { + if nested, ok := nf.Type.(iceberg.NestedType); ok { + for _, n := range nested.Fields() { + schemaTree = append(schemaTree, pterm.LeveledListItem{ + Level: depth, Text: n.String(), + }) + addChildren(n, depth+1) + } + } + } + + for _, f := range schema.Fields() { + schemaTree = append(schemaTree, pterm.LeveledListItem{ + Level: 0, Text: f.String(), + }) + addChildren(f, 1) + } + schemaTreeNode := putils.TreeFromLeveledList(schemaTree) + schemaTreeNode.Text = "Current Schema, id=" + strconv.Itoa(schema.ID) + pterm.DefaultTree.WithRoot(schemaTreeNode).Render() +} + +func (text) Spec(spec iceberg.PartitionSpec) { + fmt.Println(spec) +} + +func (text) Uuid(u uuid.UUID) { + if u.String() != "" { + fmt.Println(u.String()) + } else { + fmt.Println("missing") + } +} + +func (text) Error(err error) { + log.Fatal(err) +} diff --git a/cmd/iceberg/output_test.go b/cmd/iceberg/output_test.go index 5907660..311481c 100644 --- a/cmd/iceberg/output_test.go +++ b/cmd/iceberg/output_test.go @@ -1,195 +1,195 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package main - -import ( - "bytes" - "testing" - - "github.com/apache/iceberg-go/table" - "github.com/pterm/pterm" - "github.com/stretchr/testify/assert" -) - -var testArgs = []struct { - meta string - expected string -}{ - {`{ - "format-version": 2, - "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", - "location": "s3://bucket/test/location", - "last-sequence-number": 0, - "last-updated-ms": 1602638573590, - "last-column-id": 3, - "current-schema-id": 0, - "schemas": [ - {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": true, "type": "long"}]}, - { - "type": "struct", - "fields": [ - {"id": 1, "name": "x", "required": true, "type": "long"} - ] - } - ], - "default-spec-id": 0, - "partition-specs": [{"spec-id": 0, "fields": []}], - "last-partition-id": 1000, - "default-sort-order-id": 0, - "sort-orders": [ - { - "order-id": 0, - "fields": [ ] - } - ], - "properties": {"read.split.target.size": "134217728"}, - "current-snapshot-id": -1, - "snapshots": [ ], - "snapshot-log": [ ], - "metadata-log": [ ], - "refs": { } -}`, -`Table format version | 2 -Metadata location | -Table UUID | 9c12d441-03fe-4693-9a96-a0705ddf69c1 -Last updated | 1602638573590 -Sort Order | 0: [] -Partition Spec | [] - -Current Schema, id=0 -└──1: x: required long - -Current Snapshot | - -Snapshots - -Properties -key | value ----------------------------------- -read.split.target.size | 134217728 - -`}, - {`{ - "format-version": 2, - "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", - "location": "s3://bucket/test/location", - 
"last-sequence-number": 34, - "last-updated-ms": 1602638573590, - "last-column-id": 3, - "current-schema-id": 1, - "schemas": [ - {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": true, "type": "long"}]}, - { - "type": "struct", - "schema-id": 1, - "identifier-field-ids": [1, 2], - "fields": [ - {"id": 1, "name": "x", "required": true, "type": "long"}, - {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": true, "type": "long"} - ] - } - ], - "default-spec-id": 0, - "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], - "last-partition-id": 1000, - "default-sort-order-id": 3, - "sort-orders": [ - { - "order-id": 3, - "fields": [ - {"transform": "identity", "source-id": 2, "direction": "asc", "null-order": "nulls-first"}, - {"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"} - ] - } - ], - "properties": {"read.split.target.size": "134217728"}, - "current-snapshot-id": 3055729675574597004, - "snapshots": [ - { - "snapshot-id": 3051729675574597004, - "timestamp-ms": 1515100955770, - "sequence-number": 0, - "summary": {"operation": "append"}, - "manifest-list": "s3://a/b/1.avro", - "schema-id": 1 - }, - { - "snapshot-id": 3055729675574597004, - "parent-snapshot-id": 3051729675574597004, - "timestamp-ms": 1555100955770, - "sequence-number": 1, - "summary": {"operation": "append"}, - "manifest-list": "s3://a/b/2.avro", - "schema-id": 1 - } - ], - "snapshot-log": [ - {"snapshot-id": 3051729675574597004, "timestamp-ms": 1515100955770}, - {"snapshot-id": 3055729675574597004, "timestamp-ms": 1555100955770} - ], - "metadata-log": [{"metadata-file": "s3://bucket/.../v1.json", "timestamp-ms": 1515100}], - "refs": {"test": {"snapshot-id": 3051729675574597004, "type": "tag", "max-ref-age-ms": 10000000}} -}`, -`Table format version | 2 -Metadata location | -Table UUID | 
9c12d441-03fe-4693-9a96-a0705ddf69c1 -Last updated | 1602638573590 -Sort Order | 3: [ - | 2 asc nulls-first - | bucket[4](3) desc nulls-last - | ] -Partition Spec | [ - | 1000: x: identity(1) - | ] - -Current Schema, id=1 -├──1: x: required long -├──2: y: required long (comment) -└──3: z: required long - -Current Snapshot | append, {}: id=3055729675574597004, parent_id=3051729675574597004, schema_id=1, sequence_number=1, timestamp_ms=1555100955770, manifest_list=s3://a/b/2.avro - -Snapshots -├──Snapshot 3051729675574597004, schema 1: s3://a/b/1.avro -└──Snapshot 3055729675574597004, schema 1: s3://a/b/2.avro - -Properties -key | value ----------------------------------- -read.split.target.size | 134217728 - -`}, -} - - -func TestDescribeTable(t *testing.T) { - var buf bytes.Buffer - pterm.SetDefaultOutput(&buf) - pterm.DisableColor() - - for _, tt := range testArgs { - meta, _ := table.ParseMetadataBytes([]byte(tt.meta)) - table := table.New([]string{"t"}, meta, "", nil) - buf.Reset() - - text{}.DescribeTable(table) - - assert.Equal(t, tt.expected, buf.String()) - } -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package main + +import ( + "bytes" + "testing" + + "github.com/apache/iceberg-go/table" + "github.com/pterm/pterm" + "github.com/stretchr/testify/assert" +) + +var testArgs = []struct { + meta string + expected string +}{ + {`{ + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "s3://bucket/test/location", + "last-sequence-number": 0, + "last-updated-ms": 1602638573590, + "last-column-id": 3, + "current-schema-id": 0, + "schemas": [ + {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": true, "type": "long"}]}, + { + "type": "struct", + "fields": [ + {"id": 1, "name": "x", "required": true, "type": "long"} + ] + } + ], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": []}], + "last-partition-id": 1000, + "default-sort-order-id": 0, + "sort-orders": [ + { + "order-id": 0, + "fields": [ ] + } + ], + "properties": {"read.split.target.size": "134217728"}, + "current-snapshot-id": -1, + "snapshots": [ ], + "snapshot-log": [ ], + "metadata-log": [ ], + "refs": { } +}`, +`Table format version | 2 +Metadata location | +Table UUID | 9c12d441-03fe-4693-9a96-a0705ddf69c1 +Last updated | 1602638573590 +Sort Order | 0: [] +Partition Spec | [] + +Current Schema, id=0 +└──1: x: required long + +Current Snapshot | + +Snapshots + +Properties +key | value +---------------------------------- +read.split.target.size | 134217728 + +`}, + {`{ + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "s3://bucket/test/location", + "last-sequence-number": 34, + "last-updated-ms": 1602638573590, + "last-column-id": 3, + "current-schema-id": 1, + "schemas": [ + {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": true, "type": "long"}]}, + { + "type": "struct", + "schema-id": 1, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": true, "type": "long"}, + {"id": 2, "name": "y", "required": true, "type": 
"long", "doc": "comment"}, + {"id": 3, "name": "z", "required": true, "type": "long"} + ] + } + ], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], + "last-partition-id": 1000, + "default-sort-order-id": 3, + "sort-orders": [ + { + "order-id": 3, + "fields": [ + {"transform": "identity", "source-id": 2, "direction": "asc", "null-order": "nulls-first"}, + {"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"} + ] + } + ], + "properties": {"read.split.target.size": "134217728"}, + "current-snapshot-id": 3055729675574597004, + "snapshots": [ + { + "snapshot-id": 3051729675574597004, + "timestamp-ms": 1515100955770, + "sequence-number": 0, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/b/1.avro", + "schema-id": 1 + }, + { + "snapshot-id": 3055729675574597004, + "parent-snapshot-id": 3051729675574597004, + "timestamp-ms": 1555100955770, + "sequence-number": 1, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/b/2.avro", + "schema-id": 1 + } + ], + "snapshot-log": [ + {"snapshot-id": 3051729675574597004, "timestamp-ms": 1515100955770}, + {"snapshot-id": 3055729675574597004, "timestamp-ms": 1555100955770} + ], + "metadata-log": [{"metadata-file": "s3://bucket/.../v1.json", "timestamp-ms": 1515100}], + "refs": {"test": {"snapshot-id": 3051729675574597004, "type": "tag", "max-ref-age-ms": 10000000}} +}`, +`Table format version | 2 +Metadata location | +Table UUID | 9c12d441-03fe-4693-9a96-a0705ddf69c1 +Last updated | 1602638573590 +Sort Order | 3: [ + | 2 asc nulls-first + | bucket[4](3) desc nulls-last + | ] +Partition Spec | [ + | 1000: x: identity(1) + | ] + +Current Schema, id=1 +├──1: x: required long +├──2: y: required long (comment) +└──3: z: required long + +Current Snapshot | append, {}: id=3055729675574597004, parent_id=3051729675574597004, schema_id=1, sequence_number=1, 
timestamp_ms=1555100955770, manifest_list=s3://a/b/2.avro + +Snapshots +├──Snapshot 3051729675574597004, schema 1: s3://a/b/1.avro +└──Snapshot 3055729675574597004, schema 1: s3://a/b/2.avro + +Properties +key | value +---------------------------------- +read.split.target.size | 134217728 + +`}, +} + + +func TestDescribeTable(t *testing.T) { + var buf bytes.Buffer + pterm.SetDefaultOutput(&buf) + pterm.DisableColor() + + for _, tt := range testArgs { + meta, _ := table.ParseMetadataBytes([]byte(tt.meta)) + table := table.New([]string{"t"}, meta, "", nil) + buf.Reset() + + text{}.DescribeTable(table) + + assert.Equal(t, tt.expected, buf.String()) + } +} diff --git a/dev/Dockerfile b/dev/Dockerfile index c8b9e65..6e22cc1 100644 --- a/dev/Dockerfile +++ b/dev/Dockerfile @@ -1,25 +1,25 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -FROM tabulario/spark-iceberg - -RUN pip3 install -q ipython -RUN pip3 install pyiceberg[s3fs,hive] -RUN pip3 install pyarrow - -COPY provision.py . - -ENTRYPOINT ["./entrypoint.sh"] +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM tabulario/spark-iceberg + +RUN pip3 install -q ipython +RUN pip3 install pyiceberg[s3fs,hive] +RUN pip3 install pyarrow + +COPY provision.py . + +ENTRYPOINT ["./entrypoint.sh"] CMD ["notebook"] \ No newline at end of file diff --git a/dev/check-license b/dev/check-license index 7ba2d9a..dee146d 100755 --- a/dev/check-license +++ b/dev/check-license @@ -1,83 +1,83 @@ -#!/usr/bin/env bash - -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -acquire_rat_jar () { - - URL="https://repo.maven.apache.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" - - JAR="$rat_jar" - - # Download rat launch jar if it hasn't been downloaded yet - if [ ! 
-f "$JAR" ]; then - # Download - printf "Attempting to fetch rat\n" - JAR_DL="${JAR}.part" - if [ $(command -v curl) ]; then - curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" - elif [ $(command -v wget) ]; then - wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" - else - printf "You do not have curl or wget installed, please install rat manually.\n" - exit -1 - fi - fi - - unzip -tq "$JAR" &> /dev/null - if [ $? -ne 0 ]; then - # We failed to download - rm "$JAR" - printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" - exit -1 - fi -} - -# Go to the Spark project root directory -FWDIR="$(cd "`dirname "$0"`"/..; pwd)" -cd "$FWDIR" - -if test -x "$JAVA_HOME/bin/java"; then - declare java_cmd="$JAVA_HOME/bin/java" -else - declare java_cmd=java -fi - -export RAT_VERSION=0.15 -export rat_jar="$FWDIR"/lib/apache-rat-${RAT_VERSION}.jar -mkdir -p "$FWDIR"/lib - -[[ -f "$rat_jar" ]] || acquire_rat_jar || { - echo "Download failed. Obtain the rat jar manually and place it at $rat_jar" - exit 1 -} - -mkdir -p build -$java_cmd -jar "$rat_jar" -E "$FWDIR"/dev/release/rat_exclude_files.txt -d "$FWDIR" > build/rat-results.txt - -if [ $? -ne 0 ]; then - echo "RAT exited abnormally" - exit 1 -fi - -ERRORS="$(cat build/rat-results.txt | grep -e "??")" - -if test ! -z "$ERRORS"; then - echo "Could not find Apache license headers in the following files:" - echo "$ERRORS" - exit 1 -else - echo -e "RAT checks passed." -fi +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +acquire_rat_jar () { + + URL="https://repo.maven.apache.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" + + JAR="$rat_jar" + + # Download rat launch jar if it hasn't been downloaded yet + if [ ! -f "$JAR" ]; then + # Download + printf "Attempting to fetch rat\n" + JAR_DL="${JAR}.part" + if [ $(command -v curl) ]; then + curl -L --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" + elif [ $(command -v wget) ]; then + wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" + else + printf "You do not have curl or wget installed, please install rat manually.\n" + exit -1 + fi + fi + + unzip -tq "$JAR" &> /dev/null + if [ $? -ne 0 ]; then + # We failed to download + rm "$JAR" + printf "Our attempt to download rat locally to ${JAR} failed. Please install rat manually.\n" + exit -1 + fi +} + +# Go to the Spark project root directory +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +cd "$FWDIR" + +if test -x "$JAVA_HOME/bin/java"; then + declare java_cmd="$JAVA_HOME/bin/java" +else + declare java_cmd=java +fi + +export RAT_VERSION=0.15 +export rat_jar="$FWDIR"/lib/apache-rat-${RAT_VERSION}.jar +mkdir -p "$FWDIR"/lib + +[[ -f "$rat_jar" ]] || acquire_rat_jar || { + echo "Download failed. Obtain the rat jar manually and place it at $rat_jar" + exit 1 +} + +mkdir -p build +$java_cmd -jar "$rat_jar" -E "$FWDIR"/dev/release/rat_exclude_files.txt -d "$FWDIR" > build/rat-results.txt + +if [ $? -ne 0 ]; then + echo "RAT exited abnormally" + exit 1 +fi + +ERRORS="$(cat build/rat-results.txt | grep -e "??")" + +if test ! 
-z "$ERRORS"; then + echo "Could not find Apache license headers in the following files:" + echo "$ERRORS" + exit 1 +else + echo -e "RAT checks passed." +fi diff --git a/dev/docker-compose.yml b/dev/docker-compose.yml index 63821cd..12e89c6 100644 --- a/dev/docker-compose.yml +++ b/dev/docker-compose.yml @@ -1,91 +1,91 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -version: "3" - -services: - spark-iceberg: - image: pyiceberg-spark - container_name: spark-iceberg - build: . 
- networks: - iceberg_net: - depends_on: - - rest - - minio - volumes: - - ./warehouse:/home/iceberg/warehouse - - ./notebooks:/home/iceberg/notebooks/notebooks - environment: - - AWS_ACCESS_KEY_ID=admin - - AWS_SECRET_ACCESS_KEY=password - - AWS_REGION=us-east-1 - ports: - - 8888:8888 - - 8080:8080 - - 10000:10000 - - 10001:10001 - rest: - image: tabulario/iceberg-rest - container_name: iceberg-rest - networks: - iceberg_net: - ports: - - 8181:8181 - environment: - - AWS_ACCESS_KEY_ID=admin - - AWS_SECRET_ACCESS_KEY=password - - AWS_REGION=us-east-1 - - CATALOG_WAREHOUSE=s3://warehouse/ - - CATALOG_IO__IMPL=org.apache.iceberg.aws.s3.S3FileIO - - CATALOG_S3_ENDPOINT=http://minio:9000 - minio: - image: minio/minio - container_name: minio - environment: - - MINIO_ROOT_USER=admin - - MINIO_ROOT_PASSWORD=password - - MINIO_DOMAIN=minio - networks: - iceberg_net: - aliases: - - warehouse.minio - ports: - - 9001:9001 - - 9000:9000 - command: ["server", "/data", "--console-address", ":9001"] - mc: - depends_on: - - minio - image: minio/mc - container_name: mc - networks: - iceberg_net: - environment: - - AWS_ACCESS_KEY_ID=admin - - AWS_SECRET_ACCESS_KEY=password - - AWS_REGION=us-east-1 - entrypoint: > - /bin/sh -c " - until (/usr/bin/mc config host add minio http://minio:9000 admin password) do echo '...waiting...' && sleep 1; done; - /usr/bin/mc rm -r --force minio/warehouse; - /usr/bin/mc mb minio/warehouse; - /usr/bin/mc policy set public minio/warehouse; - tail -f /dev/null - " - -networks: - iceberg_net: +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +version: "3" + +services: + spark-iceberg: + image: pyiceberg-spark + container_name: spark-iceberg + build: . + networks: + iceberg_net: + depends_on: + - rest + - minio + volumes: + - ./warehouse:/home/iceberg/warehouse + - ./notebooks:/home/iceberg/notebooks/notebooks + environment: + - AWS_ACCESS_KEY_ID=admin + - AWS_SECRET_ACCESS_KEY=password + - AWS_REGION=us-east-1 + ports: + - 8888:8888 + - 8080:8080 + - 10000:10000 + - 10001:10001 + rest: + image: tabulario/iceberg-rest + container_name: iceberg-rest + networks: + iceberg_net: + ports: + - 8181:8181 + environment: + - AWS_ACCESS_KEY_ID=admin + - AWS_SECRET_ACCESS_KEY=password + - AWS_REGION=us-east-1 + - CATALOG_WAREHOUSE=s3://warehouse/ + - CATALOG_IO__IMPL=org.apache.iceberg.aws.s3.S3FileIO + - CATALOG_S3_ENDPOINT=http://minio:9000 + minio: + image: minio/minio + container_name: minio + environment: + - MINIO_ROOT_USER=admin + - MINIO_ROOT_PASSWORD=password + - MINIO_DOMAIN=minio + networks: + iceberg_net: + aliases: + - warehouse.minio + ports: + - 9001:9001 + - 9000:9000 + command: ["server", "/data", "--console-address", ":9001"] + mc: + depends_on: + - minio + image: minio/mc + container_name: mc + networks: + iceberg_net: + environment: + - AWS_ACCESS_KEY_ID=admin + - AWS_SECRET_ACCESS_KEY=password + - AWS_REGION=us-east-1 + entrypoint: > + /bin/sh -c " + until (/usr/bin/mc config host add minio http://minio:9000 admin password) do echo '...waiting...' 
&& sleep 1; done; + /usr/bin/mc rm -r --force minio/warehouse; + /usr/bin/mc mb minio/warehouse; + /usr/bin/mc policy set public minio/warehouse; + tail -f /dev/null + " + +networks: + iceberg_net: diff --git a/dev/provision.py b/dev/provision.py index 77906c1..34f54c1 100644 --- a/dev/provision.py +++ b/dev/provision.py @@ -1,379 +1,379 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from pyspark.sql import SparkSession -from pyspark.sql.functions import current_date, date_add, expr - -from pyiceberg.catalog import load_catalog -from pyiceberg.schema import Schema -from pyiceberg.types import FixedType, NestedField, UUIDType - -spark = SparkSession.builder.getOrCreate() - -catalogs = { - 'rest': load_catalog( - "rest", - **{ - "type": "rest", - "uri": "http://rest:8181", - "s3.endpoint": "http://minio:9000", - "s3.access-key-id": "admin", - "s3.secret-access-key": "password", - }, - ), -} - -for catalog_name, catalog in catalogs.items(): - spark.sql( - f""" - CREATE DATABASE IF NOT EXISTS default; - """ - ) - - schema = Schema( - NestedField(field_id=1, name="uuid_col", field_type=UUIDType(), required=False), - NestedField(field_id=2, name="fixed_col", field_type=FixedType(25), required=False), - ) - - catalog.create_table(identifier=f"default.test_uuid_and_fixed_unpartitioned", schema=schema) - - spark.sql( - f""" - INSERT INTO default.test_uuid_and_fixed_unpartitioned VALUES - ('102cb62f-e6f8-4eb0-9973-d9b012ff0967', CAST('1234567890123456789012345' AS BINARY)), - ('ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226', CAST('1231231231231231231231231' AS BINARY)), - ('639cccce-c9d2-494a-a78c-278ab234f024', CAST('12345678901234567ass12345' AS BINARY)), - ('c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b', CAST('asdasasdads12312312312111' AS BINARY)), - ('923dae77-83d6-47cd-b4b0-d383e64ee57e', CAST('qweeqwwqq1231231231231111' AS BINARY)); - """ - ) - - spark.sql( - f""" - CREATE OR REPLACE TABLE default.test_null_nan - USING iceberg - AS SELECT - 1 AS idx, - float('NaN') AS col_numeric - UNION ALL SELECT - 2 AS idx, - null AS col_numeric - UNION ALL SELECT - 3 AS idx, - 1 AS col_numeric; - """ - ) - - spark.sql( - f""" - CREATE OR REPLACE TABLE default.test_null_nan_rewritten - USING iceberg - AS SELECT * FROM default.test_null_nan; - """ - ) - - spark.sql( - f""" - CREATE OR REPLACE TABLE default.test_limit as - SELECT * LATERAL VIEW explode(ARRAY(1, 2, 3, 4, 5, 6, 
7, 8, 9, 10)) AS idx; - """ - ) - - spark.sql( - f""" - CREATE OR REPLACE TABLE default.test_positional_mor_deletes ( - dt date, - number integer, - letter string - ) - USING iceberg - TBLPROPERTIES ( - 'write.delete.mode'='merge-on-read', - 'write.update.mode'='merge-on-read', - 'write.merge.mode'='merge-on-read', - 'format-version'='2' - ); - """ - ) - - # Partitioning is not really needed, but there is a bug: - # https://github.com/apache/iceberg/pull/7685 - spark.sql(f"ALTER TABLE default.test_positional_mor_deletes ADD PARTITION FIELD years(dt) AS dt_years") - - spark.sql( - f""" - INSERT INTO default.test_positional_mor_deletes - VALUES - (CAST('2023-03-01' AS date), 1, 'a'), - (CAST('2023-03-02' AS date), 2, 'b'), - (CAST('2023-03-03' AS date), 3, 'c'), - (CAST('2023-03-04' AS date), 4, 'd'), - (CAST('2023-03-05' AS date), 5, 'e'), - (CAST('2023-03-06' AS date), 6, 'f'), - (CAST('2023-03-07' AS date), 7, 'g'), - (CAST('2023-03-08' AS date), 8, 'h'), - (CAST('2023-03-09' AS date), 9, 'i'), - (CAST('2023-03-10' AS date), 10, 'j'), - (CAST('2023-03-11' AS date), 11, 'k'), - (CAST('2023-03-12' AS date), 12, 'l'); - """ - ) - - spark.sql(f"ALTER TABLE default.test_positional_mor_deletes CREATE TAG tag_12") - - spark.sql(f"ALTER TABLE default.test_positional_mor_deletes CREATE BRANCH without_5") - - spark.sql(f"DELETE FROM default.test_positional_mor_deletes.branch_without_5 WHERE number = 5") - - spark.sql(f"DELETE FROM default.test_positional_mor_deletes WHERE number = 9") - - spark.sql( - f""" - CREATE OR REPLACE TABLE default.test_positional_mor_double_deletes ( - dt date, - number integer, - letter string - ) - USING iceberg - TBLPROPERTIES ( - 'write.delete.mode'='merge-on-read', - 'write.update.mode'='merge-on-read', - 'write.merge.mode'='merge-on-read', - 'format-version'='2' - ); - """ - ) - - spark.sql(f"ALTER TABLE default.test_positional_mor_double_deletes ADD PARTITION FIELD years(dt) AS dt_years") - - spark.sql( - f""" - INSERT INTO 
default.test_positional_mor_double_deletes - VALUES - (CAST('2023-03-01' AS date), 1, 'a'), - (CAST('2023-03-02' AS date), 2, 'b'), - (CAST('2023-03-03' AS date), 3, 'c'), - (CAST('2023-03-04' AS date), 4, 'd'), - (CAST('2023-03-05' AS date), 5, 'e'), - (CAST('2023-03-06' AS date), 6, 'f'), - (CAST('2023-03-07' AS date), 7, 'g'), - (CAST('2023-03-08' AS date), 8, 'h'), - (CAST('2023-03-09' AS date), 9, 'i'), - (CAST('2023-03-10' AS date), 10, 'j'), - (CAST('2023-03-11' AS date), 11, 'k'), - (CAST('2023-03-12' AS date), 12, 'l'); - """ - ) - - spark.sql(f"DELETE FROM default.test_positional_mor_double_deletes WHERE number = 9") - - spark.sql(f"DELETE FROM default.test_positional_mor_double_deletes WHERE letter == 'f'") - - all_types_dataframe = ( - spark.range(0, 5, 1, 5) - .withColumnRenamed("id", "longCol") - .withColumn("intCol", expr("CAST(longCol AS INT)")) - .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) - .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) - .withColumn("dateCol", date_add(current_date(), 1)) - .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) - .withColumn("stringCol", expr("CAST(dateCol AS STRING)")) - .withColumn("booleanCol", expr("longCol > 5")) - .withColumn("binaryCol", expr("CAST(longCol AS BINARY)")) - .withColumn("byteCol", expr("CAST(longCol AS BYTE)")) - .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(10, 2))")) - .withColumn("shortCol", expr("CAST(longCol AS SHORT)")) - .withColumn("mapCol", expr("MAP(longCol, decimalCol)")) - .withColumn("arrayCol", expr("ARRAY(longCol)")) - .withColumn("structCol", expr("STRUCT(mapCol, arrayCol)")) - ) - - all_types_dataframe.writeTo(f"default.test_all_types").tableProperty("format-version", "2").partitionedBy( - "intCol" - ).createOrReplace() - - for table_name, partition in [ - ("test_partitioned_by_identity", "ts"), - ("test_partitioned_by_years", "years(dt)"), - ("test_partitioned_by_months", "months(dt)"), - ("test_partitioned_by_days", "days(ts)"), - 
("test_partitioned_by_hours", "hours(ts)"), - ("test_partitioned_by_truncate", "truncate(1, letter)"), - ("test_partitioned_by_bucket", "bucket(16, number)"), - ]: - spark.sql( - f""" - CREATE OR REPLACE TABLE default.{table_name} ( - dt date, - ts timestamp, - number integer, - letter string - ) - USING iceberg; - """ - ) - - spark.sql(f"ALTER TABLE default.{table_name} ADD PARTITION FIELD {partition}") - - spark.sql( - f""" - INSERT INTO default.{table_name} - VALUES - (CAST('2022-03-01' AS date), CAST('2022-03-01 01:22:00' AS timestamp), 1, 'a'), - (CAST('2022-03-02' AS date), CAST('2022-03-02 02:22:00' AS timestamp), 2, 'b'), - (CAST('2022-03-03' AS date), CAST('2022-03-03 03:22:00' AS timestamp), 3, 'c'), - (CAST('2022-03-04' AS date), CAST('2022-03-04 04:22:00' AS timestamp), 4, 'd'), - (CAST('2023-03-05' AS date), CAST('2023-03-05 05:22:00' AS timestamp), 5, 'e'), - (CAST('2023-03-06' AS date), CAST('2023-03-06 06:22:00' AS timestamp), 6, 'f'), - (CAST('2023-03-07' AS date), CAST('2023-03-07 07:22:00' AS timestamp), 7, 'g'), - (CAST('2023-03-08' AS date), CAST('2023-03-08 08:22:00' AS timestamp), 8, 'h'), - (CAST('2023-03-09' AS date), CAST('2023-03-09 09:22:00' AS timestamp), 9, 'i'), - (CAST('2023-03-10' AS date), CAST('2023-03-10 10:22:00' AS timestamp), 10, 'j'), - (CAST('2023-03-11' AS date), CAST('2023-03-11 11:22:00' AS timestamp), 11, 'k'), - (CAST('2023-03-12' AS date), CAST('2023-03-12 12:22:00' AS timestamp), 12, 'l'); - """ - ) - - # There is an issue with CREATE OR REPLACE - # https://github.com/apache/iceberg/issues/8756 - spark.sql(f"DROP TABLE IF EXISTS default.test_table_version") - - spark.sql( - f""" - CREATE TABLE default.test_table_version ( - dt date, - number integer, - letter string - ) - USING iceberg - TBLPROPERTIES ( - 'format-version'='1' - ); - """ - ) - - spark.sql( - f""" - CREATE TABLE default.test_table_sanitized_character ( - `letter/abc` string - ) - USING iceberg - TBLPROPERTIES ( - 'format-version'='1' - ); - """ - ) - - 
spark.sql( - f""" - INSERT INTO default.test_table_sanitized_character - VALUES - ('123') - """ - ) - - spark.sql( - f""" - INSERT INTO default.test_table_sanitized_character - VALUES - ('123') - """ - ) - - spark.sql( - f""" - CREATE TABLE default.test_table_add_column ( - a string - ) - USING iceberg - """ - ) - - spark.sql(f"INSERT INTO default.test_table_add_column VALUES ('1')") - - spark.sql(f"ALTER TABLE default.test_table_add_column ADD COLUMN b string") - - spark.sql(f"INSERT INTO default.test_table_add_column VALUES ('2', '2')") - - spark.sql( - f""" - CREATE TABLE default.test_table_empty_list_and_map ( - col_list array, - col_map map, - col_list_with_struct array> - ) - USING iceberg - TBLPROPERTIES ( - 'format-version'='1' - ); - """ - ) - - spark.sql( - f""" - INSERT INTO default.test_table_empty_list_and_map - VALUES (null, null, null), - (array(), map(), array(struct(1))) - """ - ) - - spark.sql( - f""" - CREATE OR REPLACE TABLE default.test_table_snapshot_operations ( - number integer - ) - USING iceberg - TBLPROPERTIES ( - 'format-version'='2' - ); - """ - ) - - spark.sql( - f""" - INSERT INTO default.test_table_snapshot_operations - VALUES (1) - """ - ) - - spark.sql( - f""" - INSERT INTO default.test_table_snapshot_operations - VALUES (2) - """ - ) - - spark.sql( - f""" - DELETE FROM default.test_table_snapshot_operations - WHERE number = 2 - """ - ) - - spark.sql( - f""" - INSERT INTO default.test_table_snapshot_operations - VALUES (3) - """ - ) - - spark.sql( - f""" - INSERT INTO default.test_table_snapshot_operations - VALUES (4) - """ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark.sql import SparkSession +from pyspark.sql.functions import current_date, date_add, expr + +from pyiceberg.catalog import load_catalog +from pyiceberg.schema import Schema +from pyiceberg.types import FixedType, NestedField, UUIDType + +spark = SparkSession.builder.getOrCreate() + +catalogs = { + 'rest': load_catalog( + "rest", + **{ + "type": "rest", + "uri": "http://rest:8181", + "s3.endpoint": "http://minio:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + }, + ), +} + +for catalog_name, catalog in catalogs.items(): + spark.sql( + f""" + CREATE DATABASE IF NOT EXISTS default; + """ + ) + + schema = Schema( + NestedField(field_id=1, name="uuid_col", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="fixed_col", field_type=FixedType(25), required=False), + ) + + catalog.create_table(identifier=f"default.test_uuid_and_fixed_unpartitioned", schema=schema) + + spark.sql( + f""" + INSERT INTO default.test_uuid_and_fixed_unpartitioned VALUES + ('102cb62f-e6f8-4eb0-9973-d9b012ff0967', CAST('1234567890123456789012345' AS BINARY)), + ('ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226', CAST('1231231231231231231231231' AS BINARY)), + ('639cccce-c9d2-494a-a78c-278ab234f024', CAST('12345678901234567ass12345' AS BINARY)), + ('c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b', CAST('asdasasdads12312312312111' AS BINARY)), + ('923dae77-83d6-47cd-b4b0-d383e64ee57e', CAST('qweeqwwqq1231231231231111' AS BINARY)); + """ + ) + + spark.sql( + f""" + CREATE OR REPLACE TABLE default.test_null_nan + USING iceberg + AS SELECT + 1 AS idx, + 
float('NaN') AS col_numeric + UNION ALL SELECT + 2 AS idx, + null AS col_numeric + UNION ALL SELECT + 3 AS idx, + 1 AS col_numeric; + """ + ) + + spark.sql( + f""" + CREATE OR REPLACE TABLE default.test_null_nan_rewritten + USING iceberg + AS SELECT * FROM default.test_null_nan; + """ + ) + + spark.sql( + f""" + CREATE OR REPLACE TABLE default.test_limit as + SELECT * LATERAL VIEW explode(ARRAY(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) AS idx; + """ + ) + + spark.sql( + f""" + CREATE OR REPLACE TABLE default.test_positional_mor_deletes ( + dt date, + number integer, + letter string + ) + USING iceberg + TBLPROPERTIES ( + 'write.delete.mode'='merge-on-read', + 'write.update.mode'='merge-on-read', + 'write.merge.mode'='merge-on-read', + 'format-version'='2' + ); + """ + ) + + # Partitioning is not really needed, but there is a bug: + # https://github.com/apache/iceberg/pull/7685 + spark.sql(f"ALTER TABLE default.test_positional_mor_deletes ADD PARTITION FIELD years(dt) AS dt_years") + + spark.sql( + f""" + INSERT INTO default.test_positional_mor_deletes + VALUES + (CAST('2023-03-01' AS date), 1, 'a'), + (CAST('2023-03-02' AS date), 2, 'b'), + (CAST('2023-03-03' AS date), 3, 'c'), + (CAST('2023-03-04' AS date), 4, 'd'), + (CAST('2023-03-05' AS date), 5, 'e'), + (CAST('2023-03-06' AS date), 6, 'f'), + (CAST('2023-03-07' AS date), 7, 'g'), + (CAST('2023-03-08' AS date), 8, 'h'), + (CAST('2023-03-09' AS date), 9, 'i'), + (CAST('2023-03-10' AS date), 10, 'j'), + (CAST('2023-03-11' AS date), 11, 'k'), + (CAST('2023-03-12' AS date), 12, 'l'); + """ + ) + + spark.sql(f"ALTER TABLE default.test_positional_mor_deletes CREATE TAG tag_12") + + spark.sql(f"ALTER TABLE default.test_positional_mor_deletes CREATE BRANCH without_5") + + spark.sql(f"DELETE FROM default.test_positional_mor_deletes.branch_without_5 WHERE number = 5") + + spark.sql(f"DELETE FROM default.test_positional_mor_deletes WHERE number = 9") + + spark.sql( + f""" + CREATE OR REPLACE TABLE 
default.test_positional_mor_double_deletes ( + dt date, + number integer, + letter string + ) + USING iceberg + TBLPROPERTIES ( + 'write.delete.mode'='merge-on-read', + 'write.update.mode'='merge-on-read', + 'write.merge.mode'='merge-on-read', + 'format-version'='2' + ); + """ + ) + + spark.sql(f"ALTER TABLE default.test_positional_mor_double_deletes ADD PARTITION FIELD years(dt) AS dt_years") + + spark.sql( + f""" + INSERT INTO default.test_positional_mor_double_deletes + VALUES + (CAST('2023-03-01' AS date), 1, 'a'), + (CAST('2023-03-02' AS date), 2, 'b'), + (CAST('2023-03-03' AS date), 3, 'c'), + (CAST('2023-03-04' AS date), 4, 'd'), + (CAST('2023-03-05' AS date), 5, 'e'), + (CAST('2023-03-06' AS date), 6, 'f'), + (CAST('2023-03-07' AS date), 7, 'g'), + (CAST('2023-03-08' AS date), 8, 'h'), + (CAST('2023-03-09' AS date), 9, 'i'), + (CAST('2023-03-10' AS date), 10, 'j'), + (CAST('2023-03-11' AS date), 11, 'k'), + (CAST('2023-03-12' AS date), 12, 'l'); + """ + ) + + spark.sql(f"DELETE FROM default.test_positional_mor_double_deletes WHERE number = 9") + + spark.sql(f"DELETE FROM default.test_positional_mor_double_deletes WHERE letter == 'f'") + + all_types_dataframe = ( + spark.range(0, 5, 1, 5) + .withColumnRenamed("id", "longCol") + .withColumn("intCol", expr("CAST(longCol AS INT)")) + .withColumn("floatCol", expr("CAST(longCol AS FLOAT)")) + .withColumn("doubleCol", expr("CAST(longCol AS DOUBLE)")) + .withColumn("dateCol", date_add(current_date(), 1)) + .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)")) + .withColumn("stringCol", expr("CAST(dateCol AS STRING)")) + .withColumn("booleanCol", expr("longCol > 5")) + .withColumn("binaryCol", expr("CAST(longCol AS BINARY)")) + .withColumn("byteCol", expr("CAST(longCol AS BYTE)")) + .withColumn("decimalCol", expr("CAST(longCol AS DECIMAL(10, 2))")) + .withColumn("shortCol", expr("CAST(longCol AS SHORT)")) + .withColumn("mapCol", expr("MAP(longCol, decimalCol)")) + .withColumn("arrayCol", expr("ARRAY(longCol)")) 
+ .withColumn("structCol", expr("STRUCT(mapCol, arrayCol)")) + ) + + all_types_dataframe.writeTo(f"default.test_all_types").tableProperty("format-version", "2").partitionedBy( + "intCol" + ).createOrReplace() + + for table_name, partition in [ + ("test_partitioned_by_identity", "ts"), + ("test_partitioned_by_years", "years(dt)"), + ("test_partitioned_by_months", "months(dt)"), + ("test_partitioned_by_days", "days(ts)"), + ("test_partitioned_by_hours", "hours(ts)"), + ("test_partitioned_by_truncate", "truncate(1, letter)"), + ("test_partitioned_by_bucket", "bucket(16, number)"), + ]: + spark.sql( + f""" + CREATE OR REPLACE TABLE default.{table_name} ( + dt date, + ts timestamp, + number integer, + letter string + ) + USING iceberg; + """ + ) + + spark.sql(f"ALTER TABLE default.{table_name} ADD PARTITION FIELD {partition}") + + spark.sql( + f""" + INSERT INTO default.{table_name} + VALUES + (CAST('2022-03-01' AS date), CAST('2022-03-01 01:22:00' AS timestamp), 1, 'a'), + (CAST('2022-03-02' AS date), CAST('2022-03-02 02:22:00' AS timestamp), 2, 'b'), + (CAST('2022-03-03' AS date), CAST('2022-03-03 03:22:00' AS timestamp), 3, 'c'), + (CAST('2022-03-04' AS date), CAST('2022-03-04 04:22:00' AS timestamp), 4, 'd'), + (CAST('2023-03-05' AS date), CAST('2023-03-05 05:22:00' AS timestamp), 5, 'e'), + (CAST('2023-03-06' AS date), CAST('2023-03-06 06:22:00' AS timestamp), 6, 'f'), + (CAST('2023-03-07' AS date), CAST('2023-03-07 07:22:00' AS timestamp), 7, 'g'), + (CAST('2023-03-08' AS date), CAST('2023-03-08 08:22:00' AS timestamp), 8, 'h'), + (CAST('2023-03-09' AS date), CAST('2023-03-09 09:22:00' AS timestamp), 9, 'i'), + (CAST('2023-03-10' AS date), CAST('2023-03-10 10:22:00' AS timestamp), 10, 'j'), + (CAST('2023-03-11' AS date), CAST('2023-03-11 11:22:00' AS timestamp), 11, 'k'), + (CAST('2023-03-12' AS date), CAST('2023-03-12 12:22:00' AS timestamp), 12, 'l'); + """ + ) + + # There is an issue with CREATE OR REPLACE + # https://github.com/apache/iceberg/issues/8756 + 
spark.sql(f"DROP TABLE IF EXISTS default.test_table_version") + + spark.sql( + f""" + CREATE TABLE default.test_table_version ( + dt date, + number integer, + letter string + ) + USING iceberg + TBLPROPERTIES ( + 'format-version'='1' + ); + """ + ) + + spark.sql( + f""" + CREATE TABLE default.test_table_sanitized_character ( + `letter/abc` string + ) + USING iceberg + TBLPROPERTIES ( + 'format-version'='1' + ); + """ + ) + + spark.sql( + f""" + INSERT INTO default.test_table_sanitized_character + VALUES + ('123') + """ + ) + + spark.sql( + f""" + INSERT INTO default.test_table_sanitized_character + VALUES + ('123') + """ + ) + + spark.sql( + f""" + CREATE TABLE default.test_table_add_column ( + a string + ) + USING iceberg + """ + ) + + spark.sql(f"INSERT INTO default.test_table_add_column VALUES ('1')") + + spark.sql(f"ALTER TABLE default.test_table_add_column ADD COLUMN b string") + + spark.sql(f"INSERT INTO default.test_table_add_column VALUES ('2', '2')") + + spark.sql( + f""" + CREATE TABLE default.test_table_empty_list_and_map ( + col_list array, + col_map map, + col_list_with_struct array> + ) + USING iceberg + TBLPROPERTIES ( + 'format-version'='1' + ); + """ + ) + + spark.sql( + f""" + INSERT INTO default.test_table_empty_list_and_map + VALUES (null, null, null), + (array(), map(), array(struct(1))) + """ + ) + + spark.sql( + f""" + CREATE OR REPLACE TABLE default.test_table_snapshot_operations ( + number integer + ) + USING iceberg + TBLPROPERTIES ( + 'format-version'='2' + ); + """ + ) + + spark.sql( + f""" + INSERT INTO default.test_table_snapshot_operations + VALUES (1) + """ + ) + + spark.sql( + f""" + INSERT INTO default.test_table_snapshot_operations + VALUES (2) + """ + ) + + spark.sql( + f""" + DELETE FROM default.test_table_snapshot_operations + WHERE number = 2 + """ + ) + + spark.sql( + f""" + INSERT INTO default.test_table_snapshot_operations + VALUES (3) + """ + ) + + spark.sql( + f""" + INSERT INTO default.test_table_snapshot_operations + 
VALUES (4) + """ ) \ No newline at end of file diff --git a/dev/release/README.md b/dev/release/README.md index 1f9285f..19a538e 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -1,105 +1,105 @@ - - -# Release - -## Overview - - 1. Test the revision to be released - 2. Prepare RC and vote (detailed later) - 3. Publish (detailed later) - -### Prepare RC and vote - -Run `dev/release/release_rc.sh` on a working copy of -`git@github.com:apache/iceberg-go` not from your fork: - -```console -$ git clone git@github.com:apache/iceberg-go.git -$ dev/release/release_rc.sh ${VERSION} ${RC} -(Send a vote email to dev@iceberg.apache.org. - You can use a draft shown by release_rc.sh for the email.) -``` - -Here is an example to release RC1: - -```console -$ GH_TOKEN=${YOUR_GITHUB_TOKEN} dev/release/release_rc.sh 1.0.0 1 -``` - -The arguments of `release_rc.sh` are the version and the RC number. If RC1 has a problem, we'll increment the RC number such as RC2, RC3 and so on. - -Requirements to run `release_rc.sh`: - - * You must be an Apache Iceberg committer or PMC member - * You must prepare your PGP key for signing - -If you don't have a PGP key, https://infra.apache.org/release-signing.html#generate -may be helpful. - -Your PGP key must be registered to the following: - - * https://dist.apache.org/repos/dist/dev/iceberg/KEYS - * https://dist.apache.org/repos/dist/release/iceberg/KEYS - -See the header comment of them for how to add a PGP key. - -Apache Iceberg committers can update them by Subversion client with their ASF account. 
-e.g.: - -```console -$ svn co https://dist.apache.org/repos/dist/dev/iceberg -$ cd iceberg -$ editor KEYS -$ svn ci KEYS -``` - -### Publish - -We need to do the following to publish a new release: - - * Publish to apache.org - -Run `dev/release/release.sh` to publish to apache.org: - -```console -$ GH_TOKEN=${YOUR_GITHUB_TOKEN} dev/release/release.sh ${VERSION} ${RC} -``` - -Add the release to ASF's report database via [Apache Committee Report Helper](https://reporter.apache.org/addrelease.html?iceberg) - -### Verify - -We have a script for verifying a RC. - -You must install the following to run the script: - - * `curl` - * `gpg` - * `shasum` or `sha256sum`/`sha512sum` - * `tar` - -You don't need to have Go installed, if it isn't on the system the latest Go will be -automatically downloaded and used only for verification. - -To verify a RC, run the following: - -```console -$ dev/release/verify_rc.sh ${VERSION} ${RC} -``` - -If the verification is successful, the message `RC looks good!` is shown. + + +# Release + +## Overview + + 1. Test the revision to be released + 2. Prepare RC and vote (detailed later) + 3. Publish (detailed later) + +### Prepare RC and vote + +Run `dev/release/release_rc.sh` on a working copy of +`git@github.com:apache/iceberg-go` not from your fork: + +```console +$ git clone git@github.com:apache/iceberg-go.git +$ dev/release/release_rc.sh ${VERSION} ${RC} +(Send a vote email to dev@iceberg.apache.org. + You can use a draft shown by release_rc.sh for the email.) +``` + +Here is an example to release RC1: + +```console +$ GH_TOKEN=${YOUR_GITHUB_TOKEN} dev/release/release_rc.sh 1.0.0 1 +``` + +The arguments of `release_rc.sh` are the version and the RC number. If RC1 has a problem, we'll increment the RC number such as RC2, RC3 and so on. 
+ +Requirements to run `release_rc.sh`: + + * You must be an Apache Iceberg committer or PMC member + * You must prepare your PGP key for signing + +If you don't have a PGP key, https://infra.apache.org/release-signing.html#generate +may be helpful. + +Your PGP key must be registered to the following: + + * https://dist.apache.org/repos/dist/dev/iceberg/KEYS + * https://dist.apache.org/repos/dist/release/iceberg/KEYS + +See the header comment of them for how to add a PGP key. + +Apache Iceberg committers can update them by Subversion client with their ASF account. +e.g.: + +```console +$ svn co https://dist.apache.org/repos/dist/dev/iceberg +$ cd iceberg +$ editor KEYS +$ svn ci KEYS +``` + +### Publish + +We need to do the following to publish a new release: + + * Publish to apache.org + +Run `dev/release/release.sh` to publish to apache.org: + +```console +$ GH_TOKEN=${YOUR_GITHUB_TOKEN} dev/release/release.sh ${VERSION} ${RC} +``` + +Add the release to ASF's report database via [Apache Committee Report Helper](https://reporter.apache.org/addrelease.html?iceberg) + +### Verify + +We have a script for verifying a RC. + +You must install the following to run the script: + + * `curl` + * `gpg` + * `shasum` or `sha256sum`/`sha512sum` + * `tar` + +You don't need to have Go installed, if it isn't on the system the latest Go will be +automatically downloaded and used only for verification. + +To verify a RC, run the following: + +```console +$ dev/release/verify_rc.sh ${VERSION} ${RC} +``` + +If the verification is successful, the message `RC looks good!` is shown. diff --git a/dev/release/check_rat_report.py b/dev/release/check_rat_report.py index c45baa0..94620be 100755 --- a/dev/release/check_rat_report.py +++ b/dev/release/check_rat_report.py @@ -1,59 +1,59 @@ -#!/usr/bin/env python3 -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import fnmatch -import re -import sys -import xml.etree.ElementTree as ET - -if len(sys.argv) != 3: - sys.stderr.write("Usage: %s exclude_globs.lst rat_report.xml\n" % - sys.argv[0]) - sys.exit(1) - -exclude_globs_filename = sys.argv[1] -xml_filename = sys.argv[2] - -globs = [line.strip() for line in open(exclude_globs_filename, "r")] - -tree = ET.parse(xml_filename) -root = tree.getroot() -resources = root.findall('resource') - -all_ok = True -for r in resources: - approvals = r.findall('license-approval') - if not approvals or approvals[0].attrib['name'] == 'true': - continue - clean_name = re.sub('^[^/]+/', '', r.attrib['name']) - excluded = False - for g in globs: - if fnmatch.fnmatch(clean_name, g): - excluded = True - break - if not excluded: - sys.stdout.write("NOT APPROVED: %s (%s): %s\n" % ( - clean_name, r.attrib['name'], approvals[0].attrib['name'])) - all_ok = False - -if not all_ok: - sys.exit(1) - -print('OK') -sys.exit(0) +#!/usr/bin/env python3 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import fnmatch +import re +import sys +import xml.etree.ElementTree as ET + +if len(sys.argv) != 3: + sys.stderr.write("Usage: %s exclude_globs.lst rat_report.xml\n" % + sys.argv[0]) + sys.exit(1) + +exclude_globs_filename = sys.argv[1] +xml_filename = sys.argv[2] + +globs = [line.strip() for line in open(exclude_globs_filename, "r")] + +tree = ET.parse(xml_filename) +root = tree.getroot() +resources = root.findall('resource') + +all_ok = True +for r in resources: + approvals = r.findall('license-approval') + if not approvals or approvals[0].attrib['name'] == 'true': + continue + clean_name = re.sub('^[^/]+/', '', r.attrib['name']) + excluded = False + for g in globs: + if fnmatch.fnmatch(clean_name, g): + excluded = True + break + if not excluded: + sys.stdout.write("NOT APPROVED: %s (%s): %s\n" % ( + clean_name, r.attrib['name'], approvals[0].attrib['name'])) + all_ok = False + +if not all_ok: + sys.exit(1) + +print('OK') +sys.exit(0) diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index 08ab050..b524bc2 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -1,24 +1,24 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -.gitignore -LICENSE -NOTICE -go.sum -build -rat-results.txt +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +.gitignore +LICENSE +NOTICE +go.sum +build +rat-results.txt operation_string.go \ No newline at end of file diff --git a/dev/release/release.sh b/dev/release/release.sh index d67264e..c4ee9a3 100755 --- a/dev/release/release.sh +++ b/dev/release/release.sh @@ -1,83 +1,83 @@ -#!/usr/bin/env bash -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -set -eu - -if [ "$#" -ne 2 ]; then - echo "Usage: $0 " - echo " e.g.: $0 1.0.0 1" - exit 1 -fi - -version=$1 -rc=$2 - -git_origin_url="$(git remote get-url origin)" -repository="${git_origin_url#*github.com?}" -repository="${repository%.git}" -if [ "${git_origin_url}" != "git@github.com:apache/iceberg-go.git" ]; then - echo "This script must be ran with a working copy of apache/iceberg-go." - echo "The origin's URL: ${git_origin_url}" - exit 1 -fi - -tag="v${version}" -rc_tag="${tag}-rc${rc}" -echo "Tagging for release: ${tag}" -git tag "${tag}" "${rc_tag}" -git push origin "${tag}" - -dist_url="https://dist.apache.org/repos/dist/release/iceberg" -dist_dir="dev/release/dist" -echo "Checking out ${dist_url}" -rm -rf "${dist_dir}" -svn co --depth=empty "${dist_url}" "${dist_dir}" -gh release download "${rc_tag}" \ - --repo "${repository}" \ - --dir "${dist_dir}" \ - --skip-existing - -release_id="apache-iceberg-go-${version}" -echo "Uploading to release/" -pushd "${dist_dir}" -svn add . 
-svn ci -m "Apache Iceberg Go ${version}" -popd -rm -rf "${dist_dir}" - -echo "Keep only the latest versions" -old_releases=$( - svn ls https://dist.apache.org/repos/dist/release/iceberg/ | - grep -E '^apache-iceberg-go-' | - sort --version-sort --reverse | - tail -n +2 -) -for old_release_version in ${old_releases}; do - echo "Remove old release ${old_release_version}" - svn \ - delete \ - -m "Remove old Apache Iceberg Go release: ${old_release_version}" \ - "https://dist.apache.org/repos/dist/release/iceberg/${old_release_version}" -done - -echo "Success! The release is available here:" -echo " https://dist.apache.org/repos/dist/release/iceberg/${release_id}" -echo -echo "Add this release to ASF's report database:" -echo " https://reporter.apache.org/addrelease.html?iceberg" +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +set -eu + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + echo " e.g.: $0 1.0.0 1" + exit 1 +fi + +version=$1 +rc=$2 + +git_origin_url="$(git remote get-url origin)" +repository="${git_origin_url#*github.com?}" +repository="${repository%.git}" +if [ "${git_origin_url}" != "git@github.com:apache/iceberg-go.git" ]; then + echo "This script must be ran with a working copy of apache/iceberg-go." + echo "The origin's URL: ${git_origin_url}" + exit 1 +fi + +tag="v${version}" +rc_tag="${tag}-rc${rc}" +echo "Tagging for release: ${tag}" +git tag "${tag}" "${rc_tag}" +git push origin "${tag}" + +dist_url="https://dist.apache.org/repos/dist/release/iceberg" +dist_dir="dev/release/dist" +echo "Checking out ${dist_url}" +rm -rf "${dist_dir}" +svn co --depth=empty "${dist_url}" "${dist_dir}" +gh release download "${rc_tag}" \ + --repo "${repository}" \ + --dir "${dist_dir}" \ + --skip-existing + +release_id="apache-iceberg-go-${version}" +echo "Uploading to release/" +pushd "${dist_dir}" +svn add . +svn ci -m "Apache Iceberg Go ${version}" +popd +rm -rf "${dist_dir}" + +echo "Keep only the latest versions" +old_releases=$( + svn ls https://dist.apache.org/repos/dist/release/iceberg/ | + grep -E '^apache-iceberg-go-' | + sort --version-sort --reverse | + tail -n +2 +) +for old_release_version in ${old_releases}; do + echo "Remove old release ${old_release_version}" + svn \ + delete \ + -m "Remove old Apache Iceberg Go release: ${old_release_version}" \ + "https://dist.apache.org/repos/dist/release/iceberg/${old_release_version}" +done + +echo "Success! 
The release is available here:" +echo " https://dist.apache.org/repos/dist/release/iceberg/${release_id}" +echo +echo "Add this release to ASF's report database:" +echo " https://reporter.apache.org/addrelease.html?iceberg" diff --git a/dev/release/release_rc.sh b/dev/release/release_rc.sh index 3c93f84..9864547 100755 --- a/dev/release/release_rc.sh +++ b/dev/release/release_rc.sh @@ -1,140 +1,140 @@ -#!/usr/bin/env bash -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -set -eu - -SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SOURCE_TOP_DIR="$(cd "${SOURCE_DIR}/../../" && pwd)" - -if [ "$#" -ne 1 ]; then - echo "Usage: $0 " - echo " e.g.: $0 1.0.0 1" - exit 1 -fi - -# TODO: possibly use go-semantic-release to auto generate the version? 
-version=$1 -rc=$2 - -: "${RELEASE_DEFAULT:=1}" -: "${RELEASE_PULL:=${RELEASE_DEFAULT}}" -: "${RELEASE_PUSH_TAG:=${RELEASE_DEFAULT}}" -: "${RELEASE_SIGN:=${RELEASE_DEFAULT}}" -: "${RELEASE_UPLOAD:=${RELEASE_DEFAULT}}" - - -cd "${SOURCE_TOP_DIR}" - -if [ "${RELEASE_PULL}" -gt 0 ] || [ "${RELEASE_PUSH_TAG}" -gt 0 ]; then - git_origin_url="$(git remote get-url origin)" - if [ "${git_origin_url}" != "git@github.com:apache/iceberg-go.git" ]; then - echo "This script must be ran with working copy of apache/iceberg-go." - echo "The origin's URL: ${git_origin_url}" - exit 1 - fi -fi - -if [ "${RELEASE_PULL}" -gt 0 ]; then - echo "Ensure using the latest commit" - git checkout main - git pull --rebase --prune -fi - -rc_tag="v${version}-rc${rc}" -if [ "${RELEASE_PUSH_TAG}" -gt 0 ]; then - echo "Tagging for RC: ${rc_tag}" - git tag -a -m "${version} RC${rc}" "${rc_tag}" - git push origin "${rc_tag}" -fi - -rc_hash="$(git rev-list --max-count=1 "${rc_tag}")" - -id="apache-iceberg-go-${version}" -tar_gz="${id}.tar.gz" - -if [ "${RELEASE_SIGN}" -gt 0 ]; then - git_origin_url="$(git remote get-url origin)" - repository="${git_origin_url#*github.com?}" - repository="${repository%.git}" - - echo "Looking for GitHub Actions workflow on ${repository}:${rc_tag}" - run_id="" - while [ -z "${run_id}" ]; do - echo "Waiting for run to start..." - run_id=$(gh run list \ - --repo "${repository}" \ - --workflow=rc.yml \ - --json 'databaseId,event,headBranch,status' \ - --jq ".[] | select(.event == \"push\" and .headBranch == \"${rc_tag}\") | .databaseId") - sleep 1 - done - - echo "Found GitHub Actions workflow with ID: ${run_id}" - gh run watch --repo "${repository}" --exit-status "${run_id}" - - echo "Downloading .tar.gz from GitHub Releases" - gh release download "${rc_tag}" \ - --dir . 
\ - --pattern "${tar_gz}" \ - --repo "${repository}" \ - --skip-existing - - echo "Signing tar.gz and creating checksums" - gpg --armor --output "${tar_gz}.asc" --detach-sig "${tar_gz}" -fi - -if [ "${RELEASE_UPLOAD}" -gt 0 ]; then - echo "Uploading signature" - gh release upload "${rc_tag}" \ - --clobber \ - --repo "${repository}" \ - "${tar_gz}.asc" -fi - -echo "Draft email for dev@iceberg.apache.org mailing list" -echo "" -echo "---------------------------------------------------------" -cat < " + echo " e.g.: $0 1.0.0 1" + exit 1 +fi + +# TODO: possibly use go-semantic-release to auto generate the version? +version=$1 +rc=$2 + +: "${RELEASE_DEFAULT:=1}" +: "${RELEASE_PULL:=${RELEASE_DEFAULT}}" +: "${RELEASE_PUSH_TAG:=${RELEASE_DEFAULT}}" +: "${RELEASE_SIGN:=${RELEASE_DEFAULT}}" +: "${RELEASE_UPLOAD:=${RELEASE_DEFAULT}}" + + +cd "${SOURCE_TOP_DIR}" + +if [ "${RELEASE_PULL}" -gt 0 ] || [ "${RELEASE_PUSH_TAG}" -gt 0 ]; then + git_origin_url="$(git remote get-url origin)" + if [ "${git_origin_url}" != "git@github.com:apache/iceberg-go.git" ]; then + echo "This script must be ran with working copy of apache/iceberg-go." + echo "The origin's URL: ${git_origin_url}" + exit 1 + fi +fi + +if [ "${RELEASE_PULL}" -gt 0 ]; then + echo "Ensure using the latest commit" + git checkout main + git pull --rebase --prune +fi + +rc_tag="v${version}-rc${rc}" +if [ "${RELEASE_PUSH_TAG}" -gt 0 ]; then + echo "Tagging for RC: ${rc_tag}" + git tag -a -m "${version} RC${rc}" "${rc_tag}" + git push origin "${rc_tag}" +fi + +rc_hash="$(git rev-list --max-count=1 "${rc_tag}")" + +id="apache-iceberg-go-${version}" +tar_gz="${id}.tar.gz" + +if [ "${RELEASE_SIGN}" -gt 0 ]; then + git_origin_url="$(git remote get-url origin)" + repository="${git_origin_url#*github.com?}" + repository="${repository%.git}" + + echo "Looking for GitHub Actions workflow on ${repository}:${rc_tag}" + run_id="" + while [ -z "${run_id}" ]; do + echo "Waiting for run to start..." 
+ run_id=$(gh run list \ + --repo "${repository}" \ + --workflow=rc.yml \ + --json 'databaseId,event,headBranch,status' \ + --jq ".[] | select(.event == \"push\" and .headBranch == \"${rc_tag}\") | .databaseId") + sleep 1 + done + + echo "Found GitHub Actions workflow with ID: ${run_id}" + gh run watch --repo "${repository}" --exit-status "${run_id}" + + echo "Downloading .tar.gz from GitHub Releases" + gh release download "${rc_tag}" \ + --dir . \ + --pattern "${tar_gz}" \ + --repo "${repository}" \ + --skip-existing + + echo "Signing tar.gz and creating checksums" + gpg --armor --output "${tar_gz}.asc" --detach-sig "${tar_gz}" +fi + +if [ "${RELEASE_UPLOAD}" -gt 0 ]; then + echo "Uploading signature" + gh release upload "${rc_tag}" \ + --clobber \ + --repo "${repository}" \ + "${tar_gz}.asc" +fi + +echo "Draft email for dev@iceberg.apache.org mailing list" +echo "" +echo "---------------------------------------------------------" +cat < \ - "${FILTERED_RAT_TXT}"; then - echo "No unapproved licenses" -else - cat "${FILTERED_RAT_TXT}" - N_UNAPPROVED=$(grep -c "NOT APPROVED" "${FILTERED_RAT_TXT}") - echo "${N_UNAPPROVED} unapproved licenses. Check Rat report: ${RAT_XML}" - exit 1 -fi +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +set -eu + +RELEASE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +RAT_VERSION=0.16.1 + +RAT_JAR="${RELEASE_DIR}/apache-rat-${RAT_VERSION}.jar" +if [ ! -f "${RAT_JAR}" ]; then + curl \ + --fail \ + --output "${RAT_JAR}" \ + --show-error \ + --silent \ + https://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar +fi + +RAT_XML="${RELEASE_DIR}/rat.xml" +java \ + -jar "${RAT_JAR}" \ + --out "${RAT_XML}" \ + --xml \ + "$1" +FILTERED_RAT_TXT="${RELEASE_DIR}/filtered_rat.txt" +if ${PYTHON:-python3} \ + "${RELEASE_DIR}/check_rat_report.py" \ + "${RELEASE_DIR}/rat_exclude_files.txt" \ + "${RAT_XML}" > \ + "${FILTERED_RAT_TXT}"; then + echo "No unapproved licenses" +else + cat "${FILTERED_RAT_TXT}" + N_UNAPPROVED=$(grep -c "NOT APPROVED" "${FILTERED_RAT_TXT}") + echo "${N_UNAPPROVED} unapproved licenses. Check Rat report: ${RAT_XML}" + exit 1 +fi diff --git a/dev/release/verify_rc.sh b/dev/release/verify_rc.sh index a84af66..159ab5c 100755 --- a/dev/release/verify_rc.sh +++ b/dev/release/verify_rc.sh @@ -1,206 +1,206 @@ -#!/usr/bin/env bash -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. - -set -eu - -SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -TOP_SOURCE_DIR="$(dirname "$(dirname "${SOURCE_DIR}")")" - -if [ "$#" -ne 2 ]; then - echo "Usage: $0 " - echo " e.g.: $0 18.0.0 1" - exit 1 -fi - -set -o pipefail -set -x - -VERSION="$1" -RC="$2" - -ICEBERG_DIST_BASE_URL="https://dist.apache.org/repos/dist/dev/iceberg" -DOWNLOAD_RC_BASE_URL="https://github.com/apache/iceberg-go/releases/download/v${VERSION}-rc${RC}" -ARCHIVE_BASE_NAME="apache-iceberg-go-${VERSION}" - -: "${VERIFY_DEFAULT:=1}" -: "${VERIFY_DOWNLOAD:=${VERIFY_DEFAULT}}" -: "${VERIFY_FORCE_USE_GO_BINARY:=0}" -: "${VERIFY_SIGN:=${VERIFY_DEFAULT}}" - -VERIFY_SUCCESS=no - -setup_tmpdir() { - cleanup() { - go clean -modcache || : - if [ "${VERIFY_SUCCESS}" = "yes" ]; then - rm -rf "${VERIFY_TMPDIR}" - else - echo "Failed to verify release candidate. See ${VERIFY_TMPDIR} for details." - fi - } - - if [ -z "${VERIFY_TMPDIR:-}" ]; then - VERIFY_TMPDIR="$(mktemp -d -t "$1.XXXXX")" - trap cleanup EXIT - else - mkdir -p "${VERIFY_TMPDIR}" - fi -} - -download() { - curl \ - --fail \ - --location \ - --remote-name \ - --show-error \ - --silent \ - "$1" -} - -download_rc_file() { - if [ "${VERIFY_DOWNLOAD}" -gt 0 ]; then - download "${DOWNLOAD_RC_BASE_URL}/$1" - else - cp "${TOP_SOURCE_DIR}/$1" "$1" - fi -} - -import_gpg_keys() { - if [ "${VERIFY_SIGN}" -gt 0 ]; then - download "${ICEBERG_DIST_BASE_URL}/KEYS" - gpg --import KEYS - fi -} - -if type shasum >/dev/null 2>&1; then - sha256_verify="shasum -a 256 -c" - sha512_verify="shasum -a 512 -c" -else - sha256_verify="sha256sum -c" - sha512_verify="sha512sum -c" -fi - -fetch_archive() { - download_rc_file "${ARCHIVE_BASE_NAME}.tar.gz" - if [ "${VERIFY_SIGN}" -gt 0 ]; then - download_rc_file "${ARCHIVE_BASE_NAME}.tar.gz.asc" - gpg --verify "${ARCHIVE_BASE_NAME}.tar.gz.asc" "${ARCHIVE_BASE_NAME}.tar.gz" - fi - download_rc_file 
"${ARCHIVE_BASE_NAME}.tar.gz.sha256" - ${sha256_verify} "${ARCHIVE_BASE_NAME}.tar.gz.sha256" - download_rc_file "${ARCHIVE_BASE_NAME}.tar.gz.sha512" - ${sha512_verify} "${ARCHIVE_BASE_NAME}.tar.gz.sha512" -} - -ensure_source_directory() { - tar xf "${ARCHIVE_BASE_NAME}".tar.gz -} - -latest_go_version() { - local -a options - options=( - --fail - --location - --show-error - --silent - ) - if [ -n "${GITHUB_TOKEN:-}" ]; then - options+=("--header" "Authorization: Bearer ${GITHUB_TOKEN}") - fi - curl \ - "${options[@]}" \ - https://api.github.com/repos/golang/go/git/matching-refs/tags/go | - grep -o '"ref": "refs/tags/go.*"' | - tail -n 1 | - sed \ - -e 's,^"ref": "refs/tags/go,,g' \ - -e 's/"$//g' -} - -ensure_go() { - if [ "${VERIFY_FORCE_USE_GO_BINARY}" -le 0 ]; then - if go version; then - GOPATH="${VERIFY_TMPDIR}/gopath" - export GOPATH - mkdir -p "${GOPATH}" - return - fi - fi - - local go_version - go_version=$(latest_go_version) - local go_os - go_os="$(uname)" - case "${go_os}" in - Darwin) - go_os="darwin" - ;; - Linux) - go_os="linux" - ;; - esac - local go_arch - go_arch="$(arch)" - case "${go_arch}" in - i386 | x86_64) - go_arch="amd64" - ;; - aarch64) - go_arch="arm64" - ;; - esac - local go_binary_tar_gz - go_binary_tar_gz="go${go_version}.${go_os}-${go_arch}.tar.gz" - local go_binary_url - go_binary_url="https://go.dev/dl/${go_binary_tar_gz}" - curl \ - --fail \ - --location \ - --output "${go_binary_tar_gz}" \ - --show-error \ - --silent \ - "${go_binary_url}" - tar xf "${go_binary_tar_gz}" - GOROOT="$(pwd)/go" - export GOROOT - GOPATH="$(pwd)/gopath" - export GOPATH - mkdir -p "${GOPATH}" - PATH="${GOROOT}/bin:${GOPATH}/bin:${PATH}" -} - -test_source_distribution() { - go test -v ./... 
- # TODO: run integration tests -} - -setup_tmpdir "iceberg-go-${VERSION}-${RC}" -echo "Working in sandbox ${VERIFY_TMPDIR}" -cd "${VERIFY_TMPDIR}" - -import_gpg_keys -fetch_archive -ensure_source_directory -ensure_go -pushd "${ARCHIVE_BASE_NAME}" -test_source_distribution -popd - -VERIFY_SUCCESS=yes -echo "RC looks good!" +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -eu + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TOP_SOURCE_DIR="$(dirname "$(dirname "${SOURCE_DIR}")")" + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + echo " e.g.: $0 18.0.0 1" + exit 1 +fi + +set -o pipefail +set -x + +VERSION="$1" +RC="$2" + +ICEBERG_DIST_BASE_URL="https://dist.apache.org/repos/dist/dev/iceberg" +DOWNLOAD_RC_BASE_URL="https://github.com/apache/iceberg-go/releases/download/v${VERSION}-rc${RC}" +ARCHIVE_BASE_NAME="apache-iceberg-go-${VERSION}" + +: "${VERIFY_DEFAULT:=1}" +: "${VERIFY_DOWNLOAD:=${VERIFY_DEFAULT}}" +: "${VERIFY_FORCE_USE_GO_BINARY:=0}" +: "${VERIFY_SIGN:=${VERIFY_DEFAULT}}" + +VERIFY_SUCCESS=no + +setup_tmpdir() { + cleanup() { + go clean -modcache || : + if [ "${VERIFY_SUCCESS}" = "yes" ]; then + rm -rf "${VERIFY_TMPDIR}" + else + echo "Failed to verify release candidate. 
See ${VERIFY_TMPDIR} for details." + fi + } + + if [ -z "${VERIFY_TMPDIR:-}" ]; then + VERIFY_TMPDIR="$(mktemp -d -t "$1.XXXXX")" + trap cleanup EXIT + else + mkdir -p "${VERIFY_TMPDIR}" + fi +} + +download() { + curl \ + --fail \ + --location \ + --remote-name \ + --show-error \ + --silent \ + "$1" +} + +download_rc_file() { + if [ "${VERIFY_DOWNLOAD}" -gt 0 ]; then + download "${DOWNLOAD_RC_BASE_URL}/$1" + else + cp "${TOP_SOURCE_DIR}/$1" "$1" + fi +} + +import_gpg_keys() { + if [ "${VERIFY_SIGN}" -gt 0 ]; then + download "${ICEBERG_DIST_BASE_URL}/KEYS" + gpg --import KEYS + fi +} + +if type shasum >/dev/null 2>&1; then + sha256_verify="shasum -a 256 -c" + sha512_verify="shasum -a 512 -c" +else + sha256_verify="sha256sum -c" + sha512_verify="sha512sum -c" +fi + +fetch_archive() { + download_rc_file "${ARCHIVE_BASE_NAME}.tar.gz" + if [ "${VERIFY_SIGN}" -gt 0 ]; then + download_rc_file "${ARCHIVE_BASE_NAME}.tar.gz.asc" + gpg --verify "${ARCHIVE_BASE_NAME}.tar.gz.asc" "${ARCHIVE_BASE_NAME}.tar.gz" + fi + download_rc_file "${ARCHIVE_BASE_NAME}.tar.gz.sha256" + ${sha256_verify} "${ARCHIVE_BASE_NAME}.tar.gz.sha256" + download_rc_file "${ARCHIVE_BASE_NAME}.tar.gz.sha512" + ${sha512_verify} "${ARCHIVE_BASE_NAME}.tar.gz.sha512" +} + +ensure_source_directory() { + tar xf "${ARCHIVE_BASE_NAME}".tar.gz +} + +latest_go_version() { + local -a options + options=( + --fail + --location + --show-error + --silent + ) + if [ -n "${GITHUB_TOKEN:-}" ]; then + options+=("--header" "Authorization: Bearer ${GITHUB_TOKEN}") + fi + curl \ + "${options[@]}" \ + https://api.github.com/repos/golang/go/git/matching-refs/tags/go | + grep -o '"ref": "refs/tags/go.*"' | + tail -n 1 | + sed \ + -e 's,^"ref": "refs/tags/go,,g' \ + -e 's/"$//g' +} + +ensure_go() { + if [ "${VERIFY_FORCE_USE_GO_BINARY}" -le 0 ]; then + if go version; then + GOPATH="${VERIFY_TMPDIR}/gopath" + export GOPATH + mkdir -p "${GOPATH}" + return + fi + fi + + local go_version + go_version=$(latest_go_version) + local go_os + 
go_os="$(uname)" + case "${go_os}" in + Darwin) + go_os="darwin" + ;; + Linux) + go_os="linux" + ;; + esac + local go_arch + go_arch="$(arch)" + case "${go_arch}" in + i386 | x86_64) + go_arch="amd64" + ;; + aarch64) + go_arch="arm64" + ;; + esac + local go_binary_tar_gz + go_binary_tar_gz="go${go_version}.${go_os}-${go_arch}.tar.gz" + local go_binary_url + go_binary_url="https://go.dev/dl/${go_binary_tar_gz}" + curl \ + --fail \ + --location \ + --output "${go_binary_tar_gz}" \ + --show-error \ + --silent \ + "${go_binary_url}" + tar xf "${go_binary_tar_gz}" + GOROOT="$(pwd)/go" + export GOROOT + GOPATH="$(pwd)/gopath" + export GOPATH + mkdir -p "${GOPATH}" + PATH="${GOROOT}/bin:${GOPATH}/bin:${PATH}" +} + +test_source_distribution() { + go test -v ./... + # TODO: run integration tests +} + +setup_tmpdir "iceberg-go-${VERSION}-${RC}" +echo "Working in sandbox ${VERIFY_TMPDIR}" +cd "${VERIFY_TMPDIR}" + +import_gpg_keys +fetch_archive +ensure_source_directory +ensure_go +pushd "${ARCHIVE_BASE_NAME}" +test_source_distribution +popd + +VERIFY_SUCCESS=yes +echo "RC looks good!" diff --git a/errors.go b/errors.go index f4fc986..eef808f 100644 --- a/errors.go +++ b/errors.go @@ -1,32 +1,32 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -package iceberg - -import "errors" - -var ( - ErrInvalidTypeString = errors.New("invalid type") - ErrNotImplemented = errors.New("not implemented") - ErrInvalidArgument = errors.New("invalid argument") - ErrInvalidSchema = errors.New("invalid schema") - ErrInvalidTransform = errors.New("invalid transform syntax") - ErrType = errors.New("type error") - ErrBadCast = errors.New("could not cast value") - ErrBadLiteral = errors.New("invalid literal value") - ErrInvalidBinSerialization = errors.New("invalid binary serialization") -) +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package iceberg + +import "errors" + +var ( + ErrInvalidTypeString = errors.New("invalid type") + ErrNotImplemented = errors.New("not implemented") + ErrInvalidArgument = errors.New("invalid argument") + ErrInvalidSchema = errors.New("invalid schema") + ErrInvalidTransform = errors.New("invalid transform syntax") + ErrType = errors.New("type error") + ErrBadCast = errors.New("could not cast value") + ErrBadLiteral = errors.New("invalid literal value") + ErrInvalidBinSerialization = errors.New("invalid binary serialization") +) diff --git a/exprs.go b/exprs.go index 4bdb8c1..23d1d69 100644 --- a/exprs.go +++ b/exprs.go @@ -1,1029 +1,1029 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package iceberg - -import ( - "fmt" - "reflect" - - "github.com/google/uuid" -) - -//go:generate stringer -type=Operation -linecomment - -// Operation is an enum used for constants to define what operation a given -// expression or predicate is going to execute. -type Operation int - -const ( - // do not change the order of these enum constants. 
- // they are grouped for quick validation of operation type by - // using <= and >= of the first/last operation in a group - - OpTrue Operation = iota // True - OpFalse // False - // unary ops - OpIsNull // IsNull - OpNotNull // NotNull - OpIsNan // IsNaN - OpNotNan // NotNaN - // literal ops - OpLT // LessThan - OpLTEQ // LessThanEqual - OpGT // GreaterThan - OpGTEQ // GreaterThanEqual - OpEQ // Equal - OpNEQ // NotEqual - OpStartsWith // StartsWith - OpNotStartsWith // NotStartsWith - // set ops - OpIn // In - OpNotIn // NotIn - // boolean ops - OpNot // Not - OpAnd // And - OpOr // Or -) - -// Negate returns the inverse operation for a given op -func (op Operation) Negate() Operation { - switch op { - case OpIsNull: - return OpNotNull - case OpNotNull: - return OpIsNull - case OpIsNan: - return OpNotNan - case OpNotNan: - return OpIsNan - case OpLT: - return OpGTEQ - case OpLTEQ: - return OpGT - case OpGT: - return OpLTEQ - case OpGTEQ: - return OpLT - case OpEQ: - return OpNEQ - case OpNEQ: - return OpEQ - case OpIn: - return OpNotIn - case OpNotIn: - return OpIn - case OpStartsWith: - return OpNotStartsWith - case OpNotStartsWith: - return OpStartsWith - default: - panic("no negation for operation " + op.String()) - } -} - -// FlipLR returns the correct operation to use if the left and right operands -// are flipped. -func (op Operation) FlipLR() Operation { - switch op { - case OpLT: - return OpGT - case OpLTEQ: - return OpGTEQ - case OpGT: - return OpLT - case OpGTEQ: - return OpLTEQ - case OpAnd: - return OpAnd - case OpOr: - return OpOr - default: - panic("no left-right flip for operation: " + op.String()) - } -} - -// BooleanExpression represents a full expression which will evaluate to a -// boolean value such as GreaterThan or StartsWith, etc. 
-type BooleanExpression interface { - fmt.Stringer - Op() Operation - Negate() BooleanExpression - Equals(BooleanExpression) bool -} - -// AlwaysTrue is the boolean expression "True" -type AlwaysTrue struct{} - -func (AlwaysTrue) String() string { return "AlwaysTrue()" } -func (AlwaysTrue) Op() Operation { return OpTrue } -func (AlwaysTrue) Negate() BooleanExpression { return AlwaysFalse{} } -func (AlwaysTrue) Equals(other BooleanExpression) bool { - _, ok := other.(AlwaysTrue) - return ok -} - -// AlwaysFalse is the boolean expression "False" -type AlwaysFalse struct{} - -func (AlwaysFalse) String() string { return "AlwaysFalse()" } -func (AlwaysFalse) Op() Operation { return OpFalse } -func (AlwaysFalse) Negate() BooleanExpression { return AlwaysTrue{} } -func (AlwaysFalse) Equals(other BooleanExpression) bool { - _, ok := other.(AlwaysFalse) - return ok -} - -type NotExpr struct { - child BooleanExpression -} - -// NewNot creates a BooleanExpression representing a "Not" operation on the given -// argument. It will optimize slightly though: -// -// If the argument is AlwaysTrue or AlwaysFalse, the appropriate inverse expression -// will be returned directly. If the argument is itself a NotExpr, then the child -// will be returned rather than NotExpr(NotExpr(child)). 
-func NewNot(child BooleanExpression) BooleanExpression { - if child == nil { - panic(fmt.Errorf("%w: cannot create NotExpr with nil child", - ErrInvalidArgument)) - } - - switch t := child.(type) { - case NotExpr: - return t.child - case AlwaysTrue: - return AlwaysFalse{} - case AlwaysFalse: - return AlwaysTrue{} - } - - return NotExpr{child: child} -} - -func (n NotExpr) String() string { return "Not(child=" + n.child.String() + ")" } -func (NotExpr) Op() Operation { return OpNot } -func (n NotExpr) Negate() BooleanExpression { return n.child } -func (n NotExpr) Equals(other BooleanExpression) bool { - rhs, ok := other.(NotExpr) - if !ok { - return false - } - return n.child.Equals(rhs.child) -} - -type AndExpr struct { - left, right BooleanExpression -} - -func newAnd(left, right BooleanExpression) BooleanExpression { - if left == nil || right == nil { - panic(fmt.Errorf("%w: cannot construct AndExpr with nil arguments", - ErrInvalidArgument)) - } - - switch { - case left == AlwaysFalse{} || right == AlwaysFalse{}: - return AlwaysFalse{} - case left == AlwaysTrue{}: - return right - case right == AlwaysTrue{}: - return left - } - - return AndExpr{left: left, right: right} -} - -// NewAnd will construct a new AndExpr, allowing the caller to provide potentially -// more than just two arguments which will be folded to create an appropriate expression -// tree. i.e. NewAnd(a, b, c, d) becomes AndExpr(a, AndExpr(b, AndExpr(c, d))) -// -// Slight optimizations are performed on creation if either argument is AlwaysFalse -// or AlwaysTrue by performing reductions. If any argument is AlwaysFalse, then everything -// will get folded to a return of AlwaysFalse. If an argument is AlwaysTrue, then the other -// argument will be returned directly rather than creating an AndExpr. 
-// -// Will panic if any argument is nil -func NewAnd(left, right BooleanExpression, addl ...BooleanExpression) BooleanExpression { - folded := newAnd(left, right) - for _, a := range addl { - folded = newAnd(folded, a) - } - return folded -} - -func (a AndExpr) String() string { - return "And(left=" + a.left.String() + ", right=" + a.right.String() + ")" -} - -func (AndExpr) Op() Operation { return OpAnd } - -func (a AndExpr) Equals(other BooleanExpression) bool { - rhs, ok := other.(AndExpr) - if !ok { - return false - } - - return (a.left.Equals(rhs.left) && a.right.Equals(rhs.right)) || - (a.left.Equals(rhs.right) && a.right.Equals(rhs.left)) -} - -func (a AndExpr) Negate() BooleanExpression { - return NewOr(a.left.Negate(), a.right.Negate()) -} - -type OrExpr struct { - left, right BooleanExpression -} - -func newOr(left, right BooleanExpression) BooleanExpression { - if left == nil || right == nil { - panic(fmt.Errorf("%w: cannot construct OrExpr with nil arguments", - ErrInvalidArgument)) - } - - switch { - case left == AlwaysTrue{} || right == AlwaysTrue{}: - return AlwaysTrue{} - case left == AlwaysFalse{}: - return right - case right == AlwaysFalse{}: - return left - } - - return OrExpr{left: left, right: right} -} - -// NewOr will construct a new OrExpr, allowing the caller to provide potentially -// more than just two arguments which will be folded to create an appropriate expression -// tree. i.e. NewOr(a, b, c, d) becomes OrExpr(a, OrExpr(b, OrExpr(c, d))) -// -// Slight optimizations are performed on creation if either argument is AlwaysFalse -// or AlwaysTrue by performing reductions. If any argument is AlwaysTrue, then everything -// will get folded to a return of AlwaysTrue. If an argument is AlwaysFalse, then the other -// argument will be returned directly rather than creating an OrExpr. 
-// -// Will panic if any argument is nil -func NewOr(left, right BooleanExpression, addl ...BooleanExpression) BooleanExpression { - folded := newOr(left, right) - for _, a := range addl { - folded = newOr(folded, a) - } - return folded -} - -func (o OrExpr) String() string { - return "Or(left=" + o.left.String() + ", right=" + o.right.String() + ")" -} - -func (OrExpr) Op() Operation { return OpOr } - -func (o OrExpr) Equals(other BooleanExpression) bool { - rhs, ok := other.(OrExpr) - if !ok { - return false - } - - return (o.left.Equals(rhs.left) && o.right.Equals(rhs.right)) || - (o.left.Equals(rhs.right) && o.right.Equals(rhs.left)) -} - -func (o OrExpr) Negate() BooleanExpression { - return NewAnd(o.left.Negate(), o.right.Negate()) -} - -// A Term is a simple expression that evaluates to a value -type Term interface { - fmt.Stringer - // requiring this method ensures that only types we define can be used - // as a term. - isTerm() -} - -// UnboundTerm is an expression that evaluates to a value that isn't yet bound -// to a schema, thus it isn't yet known what the type will be. -type UnboundTerm interface { - Term - - Equals(UnboundTerm) bool - Bind(schema *Schema, caseSensitive bool) (BoundTerm, error) -} - -// BoundTerm is a simple expression (typically a reference) that evaluates to a -// value and has been bound to a schema. -type BoundTerm interface { - Term - - Equals(BoundTerm) bool - Ref() BoundReference - Type() Type - - evalToLiteral(structLike) Optional[Literal] - evalIsNull(structLike) bool -} - -// unbound is a generic interface representing something that is not yet bound -// to a particular type. -type unbound[B any] interface { - Bind(schema *Schema, caseSensitive bool) (B, error) -} - -// An UnboundPredicate represents a boolean predicate expression which has not -// yet been bound to a schema. Binding it will produce a BooleanExpression. 
-// -// BooleanExpression is used for the binding result because we may optimize and -// return AlwaysTrue / AlwaysFalse in some scenarios during binding which are -// not considered to be "Bound" as they do not have a bound Term or Reference. -type UnboundPredicate interface { - BooleanExpression - unbound[BooleanExpression] - Term() UnboundTerm -} - -// BoundPredicate is a boolean predicate expression which has been bound to a schema. -// The underlying reference and term can be retrieved from it. -type BoundPredicate interface { - BooleanExpression - Ref() BoundReference - Term() BoundTerm -} - -// Reference is a field name not yet bound to a particular field in a schema -type Reference string - -func (r Reference) String() string { - return "Reference(name='" + string(r) + "')" -} - -func (Reference) isTerm() {} -func (r Reference) Equals(other UnboundTerm) bool { - rhs, ok := other.(Reference) - if !ok { - return false - } - - return r == rhs -} - -func (r Reference) Bind(s *Schema, caseSensitive bool) (BoundTerm, error) { - var ( - field NestedField - found bool - ) - - if caseSensitive { - field, found = s.FindFieldByName(string(r)) - } else { - field, found = s.FindFieldByNameCaseInsensitive(string(r)) - } - if !found { - return nil, fmt.Errorf("%w: could not bind reference '%s', caseSensitive=%t", - ErrInvalidSchema, string(r), caseSensitive) - } - - acc, ok := s.accessorForField(field.ID) - if !ok { - return nil, ErrInvalidSchema - } - - return createBoundRef(field, acc), nil -} - -// BoundReference is a named reference that has been bound to a particular field -// in a given schema. 
-type BoundReference interface { - BoundTerm - - Field() NestedField - Pos() int -} - -type boundRef[T LiteralType] struct { - field NestedField - acc accessor -} - -func createBoundRef(field NestedField, acc accessor) BoundReference { - switch field.Type.(type) { - case BooleanType: - return &boundRef[bool]{field: field, acc: acc} - case Int32Type: - return &boundRef[int32]{field: field, acc: acc} - case Int64Type: - return &boundRef[int64]{field: field, acc: acc} - case Float32Type: - return &boundRef[float32]{field: field, acc: acc} - case Float64Type: - return &boundRef[float64]{field: field, acc: acc} - case DateType: - return &boundRef[Date]{field: field, acc: acc} - case TimeType: - return &boundRef[Time]{field: field, acc: acc} - case TimestampType, TimestampTzType: - return &boundRef[Timestamp]{field: field, acc: acc} - case StringType: - return &boundRef[string]{field: field, acc: acc} - case FixedType, BinaryType: - return &boundRef[[]byte]{field: field, acc: acc} - case DecimalType: - return &boundRef[Decimal]{field: field, acc: acc} - case UUIDType: - return &boundRef[uuid.UUID]{field: field, acc: acc} - } - panic("unhandled bound reference type: " + field.Type.String()) -} - -func (b *boundRef[T]) Pos() int { return b.acc.pos } - -func (*boundRef[T]) isTerm() {} - -func (b *boundRef[T]) String() string { - return fmt.Sprintf("BoundReference(field=%s, accessor=%s)", b.field, &b.acc) -} - -func (b *boundRef[T]) Equals(other BoundTerm) bool { - rhs, ok := other.(*boundRef[T]) - if !ok { - return false - } - - return b.field.Equals(rhs.field) -} - -func (b *boundRef[T]) Ref() BoundReference { return b } -func (b *boundRef[T]) Field() NestedField { return b.field } -func (b *boundRef[T]) Type() Type { return b.field.Type } - -func (b *boundRef[T]) eval(st structLike) Optional[T] { - switch v := b.acc.Get(st).(type) { - case nil: - return Optional[T]{} - case T: - return Optional[T]{Valid: true, Val: v} - default: - var z T - typ, val := reflect.TypeOf(z), 
reflect.ValueOf(v) - if !val.CanConvert(typ) { - panic(fmt.Errorf("%w: cannot convert value '%+v' to expected type %s", - ErrInvalidSchema, val.Interface(), typ.String())) - } - - return Optional[T]{ - Valid: true, - Val: val.Convert(typ).Interface().(T), - } - } -} - -func (b *boundRef[T]) evalToLiteral(st structLike) Optional[Literal] { - v := b.eval(st) - if !v.Valid { - return Optional[Literal]{} - } - - lit := NewLiteral[T](v.Val) - if !lit.Type().Equals(b.field.Type) { - lit, _ = lit.To(b.field.Type) - } - return Optional[Literal]{Val: lit, Valid: true} -} - -func (b *boundRef[T]) evalIsNull(st structLike) bool { - v := b.eval(st) - return !v.Valid -} - -// UnaryPredicate creates and returns an unbound predicate for the provided unary operation. -// Will panic if op is not a unary operation. -func UnaryPredicate(op Operation, t UnboundTerm) UnboundPredicate { - if op < OpIsNull || op > OpNotNan { - panic(fmt.Errorf("%w: invalid operation for unary predicate: %s", - ErrInvalidArgument, op)) - } - - if t == nil { - panic(fmt.Errorf("%w: cannot create unary predicate with nil term", - ErrInvalidArgument)) - } - - return &unboundUnaryPredicate{op: op, term: t} -} - -type unboundUnaryPredicate struct { - op Operation - term UnboundTerm -} - -func (up *unboundUnaryPredicate) String() string { - return fmt.Sprintf("%s(term=%s)", up.op, up.term) -} - -func (up *unboundUnaryPredicate) Equals(other BooleanExpression) bool { - rhs, ok := other.(*unboundUnaryPredicate) - if !ok { - return false - } - - return up.op == rhs.op && up.term.Equals(rhs.term) -} - -func (up *unboundUnaryPredicate) Op() Operation { return up.op } -func (up *unboundUnaryPredicate) Negate() BooleanExpression { - return &unboundUnaryPredicate{op: up.op.Negate(), term: up.term} -} - -func (up *unboundUnaryPredicate) Term() UnboundTerm { return up.term } -func (up *unboundUnaryPredicate) Bind(schema *Schema, caseSensitive bool) (BooleanExpression, error) { - bound, err := up.term.Bind(schema, 
caseSensitive) - if err != nil { - return nil, err - } - - // fast case optimizations - switch up.op { - case OpIsNull: - if bound.Ref().Field().Required && !schema.FieldHasOptionalParent(bound.Ref().Field().ID) { - return AlwaysFalse{}, nil - } - case OpNotNull: - if bound.Ref().Field().Required && !schema.FieldHasOptionalParent(bound.Ref().Field().ID) { - return AlwaysTrue{}, nil - } - case OpIsNan: - if !bound.Type().Equals(PrimitiveTypes.Float32) && !bound.Type().Equals(PrimitiveTypes.Float64) { - return AlwaysFalse{}, nil - } - case OpNotNan: - if !bound.Type().Equals(PrimitiveTypes.Float32) && !bound.Type().Equals(PrimitiveTypes.Float64) { - return AlwaysTrue{}, nil - } - } - - return createBoundUnaryPredicate(up.op, bound), nil -} - -// BoundUnaryPredicate is a bound predicate expression that has no arguments -type BoundUnaryPredicate interface { - BoundPredicate - - AsUnbound(Reference) UnboundPredicate -} - -type bound[T LiteralType] interface { - BoundTerm - - eval(structLike) Optional[T] -} - -func newBoundUnaryPred[T LiteralType](op Operation, term BoundTerm) BoundUnaryPredicate { - return &boundUnaryPredicate[T]{op: op, term: term.(bound[T])} -} - -func createBoundUnaryPredicate(op Operation, term BoundTerm) BoundUnaryPredicate { - switch term.Type().(type) { - case BooleanType: - return newBoundUnaryPred[bool](op, term) - case Int32Type: - return newBoundUnaryPred[int32](op, term) - case Int64Type: - return newBoundUnaryPred[int64](op, term) - case Float32Type: - return newBoundUnaryPred[float32](op, term) - case Float64Type: - return newBoundUnaryPred[float64](op, term) - case DateType: - return newBoundUnaryPred[Date](op, term) - case TimeType: - return newBoundUnaryPred[Time](op, term) - case TimestampType, TimestampTzType: - return newBoundUnaryPred[Timestamp](op, term) - case StringType: - return newBoundUnaryPred[string](op, term) - case FixedType, BinaryType: - return newBoundUnaryPred[[]byte](op, term) - case DecimalType: - return 
newBoundUnaryPred[Decimal](op, term) - case UUIDType: - return newBoundUnaryPred[uuid.UUID](op, term) - } - panic("unhandled bound reference type: " + term.Type().String()) -} - -type boundUnaryPredicate[T LiteralType] struct { - op Operation - term bound[T] -} - -func (bp *boundUnaryPredicate[T]) AsUnbound(r Reference) UnboundPredicate { - return &unboundUnaryPredicate{op: bp.op, term: r} -} - -func (bp *boundUnaryPredicate[T]) Equals(other BooleanExpression) bool { - rhs, ok := other.(*boundUnaryPredicate[T]) - if !ok { - return false - } - - return bp.op == rhs.op && bp.term.Equals(rhs.term) -} - -func (bp *boundUnaryPredicate[T]) Op() Operation { return bp.op } -func (bp *boundUnaryPredicate[T]) Negate() BooleanExpression { - return &boundUnaryPredicate[T]{op: bp.op.Negate(), term: bp.term} -} - -func (bp *boundUnaryPredicate[T]) Term() BoundTerm { return bp.term } -func (bp *boundUnaryPredicate[T]) Ref() BoundReference { return bp.term.Ref() } -func (bp *boundUnaryPredicate[T]) String() string { - return fmt.Sprintf("Bound%s(term=%s)", bp.op, bp.term) -} - -// LiteralPredicate constructs an unbound predicate for an operation that requires -// a single literal argument, such as LessThan or StartsWith. -// -// Panics if the operation provided is not a valid Literal operation, -// if the term is nil or if the literal is nil. 
-func LiteralPredicate(op Operation, t UnboundTerm, lit Literal) UnboundPredicate { - switch { - case op < OpLT || op > OpNotStartsWith: - panic(fmt.Errorf("%w: invalid operation for LiteralPredicate: %s", - ErrInvalidArgument, op)) - case t == nil: - panic(fmt.Errorf("%w: cannot create literal predicate with nil term", - ErrInvalidArgument)) - case lit == nil: - panic(fmt.Errorf("%w: cannot create literal predicate with nil literal", - ErrInvalidArgument)) - } - - return &unboundLiteralPredicate{op: op, term: t, lit: lit} -} - -type unboundLiteralPredicate struct { - op Operation - term UnboundTerm - lit Literal -} - -func (ul *unboundLiteralPredicate) String() string { - return fmt.Sprintf("%s(term=%s, literal=%s)", ul.op, ul.term, ul.lit) -} - -func (ul *unboundLiteralPredicate) Equals(other BooleanExpression) bool { - rhs, ok := other.(*unboundLiteralPredicate) - if !ok { - return false - } - - return ul.op == rhs.op && ul.term.Equals(rhs.term) && ul.lit.Equals(rhs.lit) -} - -func (ul *unboundLiteralPredicate) Op() Operation { return ul.op } -func (ul *unboundLiteralPredicate) Negate() BooleanExpression { - return &unboundLiteralPredicate{op: ul.op.Negate(), term: ul.term, lit: ul.lit} -} -func (ul *unboundLiteralPredicate) Term() UnboundTerm { return ul.term } -func (ul *unboundLiteralPredicate) Bind(schema *Schema, caseSensitive bool) (BooleanExpression, error) { - bound, err := ul.term.Bind(schema, caseSensitive) - if err != nil { - return nil, err - } - - if (ul.op == OpStartsWith || ul.op == OpNotStartsWith) && - !(bound.Type().Equals(PrimitiveTypes.String) || bound.Type().Equals(PrimitiveTypes.Binary)) { - return nil, fmt.Errorf("%w: StartsWith and NotStartsWith must bind to String type, not %s", - ErrType, bound.Type()) - } - - lit, err := ul.lit.To(bound.Type()) - if err != nil { - return nil, err - } - - switch lit.(type) { - case AboveMaxLiteral: - switch ul.op { - case OpLT, OpLTEQ, OpNEQ: - return AlwaysTrue{}, nil - case OpGT, OpGTEQ, OpEQ: - 
return AlwaysFalse{}, nil - } - case BelowMinLiteral: - switch ul.op { - case OpLT, OpLTEQ, OpEQ: - return AlwaysFalse{}, nil - case OpGT, OpGTEQ, OpNEQ: - return AlwaysTrue{}, nil - } - } - - return createBoundLiteralPredicate(ul.op, bound, lit) -} - -// BoundLiteralPredicate represents a bound boolean expression that utilizes a single -// literal as an argument, such as Equals or StartsWith. -type BoundLiteralPredicate interface { - BoundPredicate - - Literal() Literal - AsUnbound(Reference, Literal) UnboundPredicate -} - -func newBoundLiteralPredicate[T LiteralType](op Operation, term BoundTerm, lit Literal) BoundPredicate { - return &boundLiteralPredicate[T]{op: op, term: term.(bound[T]), - lit: lit.(TypedLiteral[T])} -} - -func createBoundLiteralPredicate(op Operation, term BoundTerm, lit Literal) (BoundPredicate, error) { - finalLit, err := lit.To(term.Type()) - if err != nil { - return nil, err - } - - switch term.Type().(type) { - case BooleanType: - return newBoundLiteralPredicate[bool](op, term, finalLit), nil - case Int32Type: - return newBoundLiteralPredicate[int32](op, term, finalLit), nil - case Int64Type: - return newBoundLiteralPredicate[int64](op, term, finalLit), nil - case Float32Type: - return newBoundLiteralPredicate[float32](op, term, finalLit), nil - case Float64Type: - return newBoundLiteralPredicate[float64](op, term, finalLit), nil - case DateType: - return newBoundLiteralPredicate[Date](op, term, finalLit), nil - case TimeType: - return newBoundLiteralPredicate[Time](op, term, finalLit), nil - case TimestampType, TimestampTzType: - return newBoundLiteralPredicate[Timestamp](op, term, finalLit), nil - case StringType: - return newBoundLiteralPredicate[string](op, term, finalLit), nil - case FixedType, BinaryType: - return newBoundLiteralPredicate[[]byte](op, term, finalLit), nil - case DecimalType: - return newBoundLiteralPredicate[Decimal](op, term, finalLit), nil - case UUIDType: - return newBoundLiteralPredicate[uuid.UUID](op, term, 
finalLit), nil - } - return nil, fmt.Errorf("%w: could not create bound literal predicate for term type %s", - ErrInvalidArgument, term.Type()) -} - -type boundLiteralPredicate[T LiteralType] struct { - op Operation - term bound[T] - lit TypedLiteral[T] -} - -func (blp *boundLiteralPredicate[T]) Equals(other BooleanExpression) bool { - rhs, ok := other.(*boundLiteralPredicate[T]) - if !ok { - return false - } - - return blp.op == rhs.op && blp.term.Equals(rhs.term) && blp.lit.Equals(rhs.lit) -} - -func (blp *boundLiteralPredicate[T]) Op() Operation { return blp.op } -func (blp *boundLiteralPredicate[T]) Negate() BooleanExpression { - return &boundLiteralPredicate[T]{op: blp.op.Negate(), term: blp.term, lit: blp.lit} -} -func (blp *boundLiteralPredicate[T]) Term() BoundTerm { return blp.term } -func (blp *boundLiteralPredicate[T]) Ref() BoundReference { return blp.term.Ref() } -func (blp *boundLiteralPredicate[T]) String() string { - return fmt.Sprintf("Bound%s(term=%s, literal=%s)", blp.op, blp.term, blp.lit) -} -func (blp *boundLiteralPredicate[T]) Literal() Literal { return blp.lit } -func (blp *boundLiteralPredicate[T]) AsUnbound(r Reference, l Literal) UnboundPredicate { - return &unboundLiteralPredicate{op: blp.op, term: r, lit: l} -} - -// SetPredicate creates a boolean expression representing a predicate that uses a set -// of literals as the argument, like In or NotIn. Duplicate literals will be folded -// into a set, only maintaining the unique literals. 
-// -// Will panic if op is not a valid Set operation -func SetPredicate(op Operation, t UnboundTerm, lits []Literal) BooleanExpression { - if op < OpIn || op > OpNotIn { - panic(fmt.Errorf("%w: invalid operation for SetPredicate: %s", - ErrInvalidArgument, op)) - } - - if t == nil { - panic(fmt.Errorf("%w: cannot create set predicate with nil term", - ErrInvalidArgument)) - } - - switch len(lits) { - case 0: - if op == OpIn { - return AlwaysFalse{} - } else if op == OpNotIn { - return AlwaysTrue{} - } - case 1: - if op == OpIn { - return LiteralPredicate(OpEQ, t, lits[0]) - } else if op == OpNotIn { - return LiteralPredicate(OpNEQ, t, lits[0]) - } - } - - return &unboundSetPredicate{op: op, term: t, lits: newLiteralSet(lits...)} -} - -type unboundSetPredicate struct { - op Operation - term UnboundTerm - lits Set[Literal] -} - -func (usp *unboundSetPredicate) String() string { - return fmt.Sprintf("%s(term=%s, {%v})", usp.op, usp.term, usp.lits.Members()) -} - -func (usp *unboundSetPredicate) Equals(other BooleanExpression) bool { - rhs, ok := other.(*unboundSetPredicate) - if !ok { - return false - } - - return usp.op == rhs.op && usp.term.Equals(rhs.term) && - usp.lits.Equals(rhs.lits) -} - -func (usp *unboundSetPredicate) Op() Operation { return usp.op } -func (usp *unboundSetPredicate) Negate() BooleanExpression { - return &unboundSetPredicate{op: usp.op.Negate(), term: usp.term, lits: usp.lits} -} - -func (usp *unboundSetPredicate) Term() UnboundTerm { return usp.term } -func (usp *unboundSetPredicate) Bind(schema *Schema, caseSensitive bool) (BooleanExpression, error) { - bound, err := usp.term.Bind(schema, caseSensitive) - if err != nil { - return nil, err - } - - return createBoundSetPredicate(usp.op, bound, usp.lits) -} - -// BoundSetPredicate is a bound expression that utilizes a set of literals such as In or NotIn -type BoundSetPredicate interface { - BoundPredicate - - Literals() Set[Literal] - AsUnbound(Reference, []Literal) UnboundPredicate -} - -func 
createBoundSetPredicate(op Operation, term BoundTerm, lits Set[Literal]) (BooleanExpression, error) { - boundType := term.Type() - - typedSet := newLiteralSet() - for _, v := range lits.Members() { - casted, err := v.To(boundType) - if err != nil { - return nil, err - } - typedSet.Add(casted) - } - - switch typedSet.Len() { - case 0: - if op == OpIn { - return AlwaysFalse{}, nil - } else if op == OpNotIn { - return AlwaysTrue{}, nil - } - case 1: - if op == OpIn { - return createBoundLiteralPredicate(OpEQ, term, typedSet.Members()[0]) - } else if op == OpNotIn { - return createBoundLiteralPredicate(OpNEQ, term, typedSet.Members()[0]) - } - } - - switch term.Type().(type) { - case BooleanType: - return newBoundSetPredicate[bool](op, term, typedSet), nil - case Int32Type: - return newBoundSetPredicate[int32](op, term, typedSet), nil - case Int64Type: - return newBoundSetPredicate[int64](op, term, typedSet), nil - case Float32Type: - return newBoundSetPredicate[float32](op, term, typedSet), nil - case Float64Type: - return newBoundSetPredicate[float64](op, term, typedSet), nil - case DateType: - return newBoundSetPredicate[Date](op, term, typedSet), nil - case TimeType: - return newBoundSetPredicate[Time](op, term, typedSet), nil - case TimestampType, TimestampTzType: - return newBoundSetPredicate[Timestamp](op, term, typedSet), nil - case StringType: - return newBoundSetPredicate[string](op, term, typedSet), nil - case BinaryType, FixedType: - return newBoundSetPredicate[[]byte](op, term, typedSet), nil - case DecimalType: - return newBoundSetPredicate[Decimal](op, term, typedSet), nil - case UUIDType: - return newBoundSetPredicate[uuid.UUID](op, term, typedSet), nil - } - - return nil, fmt.Errorf("%w: invalid bound type for set predicate - %s", - ErrType, term.Type()) -} - -func newBoundSetPredicate[T LiteralType](op Operation, term BoundTerm, lits Set[Literal]) *boundSetPredicate[T] { - return &boundSetPredicate[T]{op: op, term: term.(bound[T]), lits: lits} -} - 
-type boundSetPredicate[T LiteralType] struct { - op Operation - term bound[T] - lits Set[Literal] -} - -func (bsp *boundSetPredicate[T]) Equals(other BooleanExpression) bool { - rhs, ok := other.(*boundSetPredicate[T]) - if !ok { - return false - } - - return bsp.op == rhs.op && bsp.term.Equals(rhs.term) && - bsp.lits.Equals(rhs.lits) -} - -func (bsp *boundSetPredicate[T]) Op() Operation { return bsp.op } -func (bsp *boundSetPredicate[T]) Negate() BooleanExpression { - return &boundSetPredicate[T]{op: bsp.op.Negate(), term: bsp.term, - lits: bsp.lits} -} -func (bsp *boundSetPredicate[T]) Term() BoundTerm { return bsp.term } -func (bsp *boundSetPredicate[T]) Ref() BoundReference { return bsp.term.Ref() } -func (bsp *boundSetPredicate[T]) String() string { - return fmt.Sprintf("Bound%s(term=%s, {%v})", bsp.op, bsp.term, bsp.lits.Members()) -} -func (bsp *boundSetPredicate[T]) AsUnbound(r Reference, lits []Literal) UnboundPredicate { - litSet := newLiteralSet(lits...) - if litSet.Len() == 1 { - switch bsp.op { - case OpIn: - return LiteralPredicate(OpEQ, r, lits[0]) - case OpNotIn: - return LiteralPredicate(OpNEQ, r, lits[0]) - } - } - - return &unboundSetPredicate{op: bsp.op, term: r, lits: litSet} -} - -func (bsp *boundSetPredicate[T]) Literals() Set[Literal] { - return bsp.lits -} - -type BoundTransform struct { - transform Transform - term BoundTerm -} - -func (*BoundTransform) isTerm() {} -func (b *BoundTransform) String() string { - return fmt.Sprintf("BoundTransform(transform=%s, term=%s)", - b.transform, b.term) -} - -func (b *BoundTransform) Ref() BoundReference { return b.term.Ref() } -func (b *BoundTransform) Type() Type { return b.transform.ResultType(b.term.Type()) } - -func (b *BoundTransform) Equals(other BoundTerm) bool { - rhs, ok := other.(*BoundTransform) - if !ok { - return false - } - - return b.transform.Equals(rhs.transform) && b.term.Equals(rhs.term) -} - -func (b *BoundTransform) evalToLiteral(st structLike) Optional[Literal] { - return 
b.transform.Apply(b.term.evalToLiteral(st)) -} - -func (b *BoundTransform) evalIsNull(st structLike) bool { - return !b.evalToLiteral(st).Valid -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package iceberg + +import ( + "fmt" + "reflect" + + "github.com/google/uuid" +) + +//go:generate stringer -type=Operation -linecomment + +// Operation is an enum used for constants to define what operation a given +// expression or predicate is going to execute. +type Operation int + +const ( + // do not change the order of these enum constants. 
+ // they are grouped for quick validation of operation type by + // using <= and >= of the first/last operation in a group + + OpTrue Operation = iota // True + OpFalse // False + // unary ops + OpIsNull // IsNull + OpNotNull // NotNull + OpIsNan // IsNaN + OpNotNan // NotNaN + // literal ops + OpLT // LessThan + OpLTEQ // LessThanEqual + OpGT // GreaterThan + OpGTEQ // GreaterThanEqual + OpEQ // Equal + OpNEQ // NotEqual + OpStartsWith // StartsWith + OpNotStartsWith // NotStartsWith + // set ops + OpIn // In + OpNotIn // NotIn + // boolean ops + OpNot // Not + OpAnd // And + OpOr // Or +) + +// Negate returns the inverse operation for a given op +func (op Operation) Negate() Operation { + switch op { + case OpIsNull: + return OpNotNull + case OpNotNull: + return OpIsNull + case OpIsNan: + return OpNotNan + case OpNotNan: + return OpIsNan + case OpLT: + return OpGTEQ + case OpLTEQ: + return OpGT + case OpGT: + return OpLTEQ + case OpGTEQ: + return OpLT + case OpEQ: + return OpNEQ + case OpNEQ: + return OpEQ + case OpIn: + return OpNotIn + case OpNotIn: + return OpIn + case OpStartsWith: + return OpNotStartsWith + case OpNotStartsWith: + return OpStartsWith + default: + panic("no negation for operation " + op.String()) + } +} + +// FlipLR returns the correct operation to use if the left and right operands +// are flipped. +func (op Operation) FlipLR() Operation { + switch op { + case OpLT: + return OpGT + case OpLTEQ: + return OpGTEQ + case OpGT: + return OpLT + case OpGTEQ: + return OpLTEQ + case OpAnd: + return OpAnd + case OpOr: + return OpOr + default: + panic("no left-right flip for operation: " + op.String()) + } +} + +// BooleanExpression represents a full expression which will evaluate to a +// boolean value such as GreaterThan or StartsWith, etc. 
+type BooleanExpression interface { + fmt.Stringer + Op() Operation + Negate() BooleanExpression + Equals(BooleanExpression) bool +} + +// AlwaysTrue is the boolean expression "True" +type AlwaysTrue struct{} + +func (AlwaysTrue) String() string { return "AlwaysTrue()" } +func (AlwaysTrue) Op() Operation { return OpTrue } +func (AlwaysTrue) Negate() BooleanExpression { return AlwaysFalse{} } +func (AlwaysTrue) Equals(other BooleanExpression) bool { + _, ok := other.(AlwaysTrue) + return ok +} + +// AlwaysFalse is the boolean expression "False" +type AlwaysFalse struct{} + +func (AlwaysFalse) String() string { return "AlwaysFalse()" } +func (AlwaysFalse) Op() Operation { return OpFalse } +func (AlwaysFalse) Negate() BooleanExpression { return AlwaysTrue{} } +func (AlwaysFalse) Equals(other BooleanExpression) bool { + _, ok := other.(AlwaysFalse) + return ok +} + +type NotExpr struct { + child BooleanExpression +} + +// NewNot creates a BooleanExpression representing a "Not" operation on the given +// argument. It will optimize slightly though: +// +// If the argument is AlwaysTrue or AlwaysFalse, the appropriate inverse expression +// will be returned directly. If the argument is itself a NotExpr, then the child +// will be returned rather than NotExpr(NotExpr(child)). 
+func NewNot(child BooleanExpression) BooleanExpression { + if child == nil { + panic(fmt.Errorf("%w: cannot create NotExpr with nil child", + ErrInvalidArgument)) + } + + switch t := child.(type) { + case NotExpr: + return t.child + case AlwaysTrue: + return AlwaysFalse{} + case AlwaysFalse: + return AlwaysTrue{} + } + + return NotExpr{child: child} +} + +func (n NotExpr) String() string { return "Not(child=" + n.child.String() + ")" } +func (NotExpr) Op() Operation { return OpNot } +func (n NotExpr) Negate() BooleanExpression { return n.child } +func (n NotExpr) Equals(other BooleanExpression) bool { + rhs, ok := other.(NotExpr) + if !ok { + return false + } + return n.child.Equals(rhs.child) +} + +type AndExpr struct { + left, right BooleanExpression +} + +func newAnd(left, right BooleanExpression) BooleanExpression { + if left == nil || right == nil { + panic(fmt.Errorf("%w: cannot construct AndExpr with nil arguments", + ErrInvalidArgument)) + } + + switch { + case left == AlwaysFalse{} || right == AlwaysFalse{}: + return AlwaysFalse{} + case left == AlwaysTrue{}: + return right + case right == AlwaysTrue{}: + return left + } + + return AndExpr{left: left, right: right} +} + +// NewAnd will construct a new AndExpr, allowing the caller to provide potentially +// more than just two arguments which will be folded to create an appropriate expression +// tree. i.e. NewAnd(a, b, c, d) becomes AndExpr(a, AndExpr(b, AndExpr(c, d))) +// +// Slight optimizations are performed on creation if either argument is AlwaysFalse +// or AlwaysTrue by performing reductions. If any argument is AlwaysFalse, then everything +// will get folded to a return of AlwaysFalse. If an argument is AlwaysTrue, then the other +// argument will be returned directly rather than creating an AndExpr. 
+// +// Will panic if any argument is nil +func NewAnd(left, right BooleanExpression, addl ...BooleanExpression) BooleanExpression { + folded := newAnd(left, right) + for _, a := range addl { + folded = newAnd(folded, a) + } + return folded +} + +func (a AndExpr) String() string { + return "And(left=" + a.left.String() + ", right=" + a.right.String() + ")" +} + +func (AndExpr) Op() Operation { return OpAnd } + +func (a AndExpr) Equals(other BooleanExpression) bool { + rhs, ok := other.(AndExpr) + if !ok { + return false + } + + return (a.left.Equals(rhs.left) && a.right.Equals(rhs.right)) || + (a.left.Equals(rhs.right) && a.right.Equals(rhs.left)) +} + +func (a AndExpr) Negate() BooleanExpression { + return NewOr(a.left.Negate(), a.right.Negate()) +} + +type OrExpr struct { + left, right BooleanExpression +} + +func newOr(left, right BooleanExpression) BooleanExpression { + if left == nil || right == nil { + panic(fmt.Errorf("%w: cannot construct OrExpr with nil arguments", + ErrInvalidArgument)) + } + + switch { + case left == AlwaysTrue{} || right == AlwaysTrue{}: + return AlwaysTrue{} + case left == AlwaysFalse{}: + return right + case right == AlwaysFalse{}: + return left + } + + return OrExpr{left: left, right: right} +} + +// NewOr will construct a new OrExpr, allowing the caller to provide potentially +// more than just two arguments which will be folded to create an appropriate expression +// tree. i.e. NewOr(a, b, c, d) becomes OrExpr(a, OrExpr(b, OrExpr(c, d))) +// +// Slight optimizations are performed on creation if either argument is AlwaysFalse +// or AlwaysTrue by performing reductions. If any argument is AlwaysTrue, then everything +// will get folded to a return of AlwaysTrue. If an argument is AlwaysFalse, then the other +// argument will be returned directly rather than creating an OrExpr. 
+// +// Will panic if any argument is nil +func NewOr(left, right BooleanExpression, addl ...BooleanExpression) BooleanExpression { + folded := newOr(left, right) + for _, a := range addl { + folded = newOr(folded, a) + } + return folded +} + +func (o OrExpr) String() string { + return "Or(left=" + o.left.String() + ", right=" + o.right.String() + ")" +} + +func (OrExpr) Op() Operation { return OpOr } + +func (o OrExpr) Equals(other BooleanExpression) bool { + rhs, ok := other.(OrExpr) + if !ok { + return false + } + + return (o.left.Equals(rhs.left) && o.right.Equals(rhs.right)) || + (o.left.Equals(rhs.right) && o.right.Equals(rhs.left)) +} + +func (o OrExpr) Negate() BooleanExpression { + return NewAnd(o.left.Negate(), o.right.Negate()) +} + +// A Term is a simple expression that evaluates to a value +type Term interface { + fmt.Stringer + // requiring this method ensures that only types we define can be used + // as a term. + isTerm() +} + +// UnboundTerm is an expression that evaluates to a value that isn't yet bound +// to a schema, thus it isn't yet known what the type will be. +type UnboundTerm interface { + Term + + Equals(UnboundTerm) bool + Bind(schema *Schema, caseSensitive bool) (BoundTerm, error) +} + +// BoundTerm is a simple expression (typically a reference) that evaluates to a +// value and has been bound to a schema. +type BoundTerm interface { + Term + + Equals(BoundTerm) bool + Ref() BoundReference + Type() Type + + evalToLiteral(structLike) Optional[Literal] + evalIsNull(structLike) bool +} + +// unbound is a generic interface representing something that is not yet bound +// to a particular type. +type unbound[B any] interface { + Bind(schema *Schema, caseSensitive bool) (B, error) +} + +// An UnboundPredicate represents a boolean predicate expression which has not +// yet been bound to a schema. Binding it will produce a BooleanExpression. 
+// +// BooleanExpression is used for the binding result because we may optimize and +// return AlwaysTrue / AlwaysFalse in some scenarios during binding which are +// not considered to be "Bound" as they do not have a bound Term or Reference. +type UnboundPredicate interface { + BooleanExpression + unbound[BooleanExpression] + Term() UnboundTerm +} + +// BoundPredicate is a boolean predicate expression which has been bound to a schema. +// The underlying reference and term can be retrieved from it. +type BoundPredicate interface { + BooleanExpression + Ref() BoundReference + Term() BoundTerm +} + +// Reference is a field name not yet bound to a particular field in a schema +type Reference string + +func (r Reference) String() string { + return "Reference(name='" + string(r) + "')" +} + +func (Reference) isTerm() {} +func (r Reference) Equals(other UnboundTerm) bool { + rhs, ok := other.(Reference) + if !ok { + return false + } + + return r == rhs +} + +func (r Reference) Bind(s *Schema, caseSensitive bool) (BoundTerm, error) { + var ( + field NestedField + found bool + ) + + if caseSensitive { + field, found = s.FindFieldByName(string(r)) + } else { + field, found = s.FindFieldByNameCaseInsensitive(string(r)) + } + if !found { + return nil, fmt.Errorf("%w: could not bind reference '%s', caseSensitive=%t", + ErrInvalidSchema, string(r), caseSensitive) + } + + acc, ok := s.accessorForField(field.ID) + if !ok { + return nil, ErrInvalidSchema + } + + return createBoundRef(field, acc), nil +} + +// BoundReference is a named reference that has been bound to a particular field +// in a given schema. 
+type BoundReference interface { + BoundTerm + + Field() NestedField + Pos() int +} + +type boundRef[T LiteralType] struct { + field NestedField + acc accessor +} + +func createBoundRef(field NestedField, acc accessor) BoundReference { + switch field.Type.(type) { + case BooleanType: + return &boundRef[bool]{field: field, acc: acc} + case Int32Type: + return &boundRef[int32]{field: field, acc: acc} + case Int64Type: + return &boundRef[int64]{field: field, acc: acc} + case Float32Type: + return &boundRef[float32]{field: field, acc: acc} + case Float64Type: + return &boundRef[float64]{field: field, acc: acc} + case DateType: + return &boundRef[Date]{field: field, acc: acc} + case TimeType: + return &boundRef[Time]{field: field, acc: acc} + case TimestampType, TimestampTzType: + return &boundRef[Timestamp]{field: field, acc: acc} + case StringType: + return &boundRef[string]{field: field, acc: acc} + case FixedType, BinaryType: + return &boundRef[[]byte]{field: field, acc: acc} + case DecimalType: + return &boundRef[Decimal]{field: field, acc: acc} + case UUIDType: + return &boundRef[uuid.UUID]{field: field, acc: acc} + } + panic("unhandled bound reference type: " + field.Type.String()) +} + +func (b *boundRef[T]) Pos() int { return b.acc.pos } + +func (*boundRef[T]) isTerm() {} + +func (b *boundRef[T]) String() string { + return fmt.Sprintf("BoundReference(field=%s, accessor=%s)", b.field, &b.acc) +} + +func (b *boundRef[T]) Equals(other BoundTerm) bool { + rhs, ok := other.(*boundRef[T]) + if !ok { + return false + } + + return b.field.Equals(rhs.field) +} + +func (b *boundRef[T]) Ref() BoundReference { return b } +func (b *boundRef[T]) Field() NestedField { return b.field } +func (b *boundRef[T]) Type() Type { return b.field.Type } + +func (b *boundRef[T]) eval(st structLike) Optional[T] { + switch v := b.acc.Get(st).(type) { + case nil: + return Optional[T]{} + case T: + return Optional[T]{Valid: true, Val: v} + default: + var z T + typ, val := reflect.TypeOf(z), 
reflect.ValueOf(v) + if !val.CanConvert(typ) { + panic(fmt.Errorf("%w: cannot convert value '%+v' to expected type %s", + ErrInvalidSchema, val.Interface(), typ.String())) + } + + return Optional[T]{ + Valid: true, + Val: val.Convert(typ).Interface().(T), + } + } +} + +func (b *boundRef[T]) evalToLiteral(st structLike) Optional[Literal] { + v := b.eval(st) + if !v.Valid { + return Optional[Literal]{} + } + + lit := NewLiteral[T](v.Val) + if !lit.Type().Equals(b.field.Type) { + lit, _ = lit.To(b.field.Type) + } + return Optional[Literal]{Val: lit, Valid: true} +} + +func (b *boundRef[T]) evalIsNull(st structLike) bool { + v := b.eval(st) + return !v.Valid +} + +// UnaryPredicate creates and returns an unbound predicate for the provided unary operation. +// Will panic if op is not a unary operation. +func UnaryPredicate(op Operation, t UnboundTerm) UnboundPredicate { + if op < OpIsNull || op > OpNotNan { + panic(fmt.Errorf("%w: invalid operation for unary predicate: %s", + ErrInvalidArgument, op)) + } + + if t == nil { + panic(fmt.Errorf("%w: cannot create unary predicate with nil term", + ErrInvalidArgument)) + } + + return &unboundUnaryPredicate{op: op, term: t} +} + +type unboundUnaryPredicate struct { + op Operation + term UnboundTerm +} + +func (up *unboundUnaryPredicate) String() string { + return fmt.Sprintf("%s(term=%s)", up.op, up.term) +} + +func (up *unboundUnaryPredicate) Equals(other BooleanExpression) bool { + rhs, ok := other.(*unboundUnaryPredicate) + if !ok { + return false + } + + return up.op == rhs.op && up.term.Equals(rhs.term) +} + +func (up *unboundUnaryPredicate) Op() Operation { return up.op } +func (up *unboundUnaryPredicate) Negate() BooleanExpression { + return &unboundUnaryPredicate{op: up.op.Negate(), term: up.term} +} + +func (up *unboundUnaryPredicate) Term() UnboundTerm { return up.term } +func (up *unboundUnaryPredicate) Bind(schema *Schema, caseSensitive bool) (BooleanExpression, error) { + bound, err := up.term.Bind(schema, 
caseSensitive) + if err != nil { + return nil, err + } + + // fast case optimizations + switch up.op { + case OpIsNull: + if bound.Ref().Field().Required && !schema.FieldHasOptionalParent(bound.Ref().Field().ID) { + return AlwaysFalse{}, nil + } + case OpNotNull: + if bound.Ref().Field().Required && !schema.FieldHasOptionalParent(bound.Ref().Field().ID) { + return AlwaysTrue{}, nil + } + case OpIsNan: + if !bound.Type().Equals(PrimitiveTypes.Float32) && !bound.Type().Equals(PrimitiveTypes.Float64) { + return AlwaysFalse{}, nil + } + case OpNotNan: + if !bound.Type().Equals(PrimitiveTypes.Float32) && !bound.Type().Equals(PrimitiveTypes.Float64) { + return AlwaysTrue{}, nil + } + } + + return createBoundUnaryPredicate(up.op, bound), nil +} + +// BoundUnaryPredicate is a bound predicate expression that has no arguments +type BoundUnaryPredicate interface { + BoundPredicate + + AsUnbound(Reference) UnboundPredicate +} + +type bound[T LiteralType] interface { + BoundTerm + + eval(structLike) Optional[T] +} + +func newBoundUnaryPred[T LiteralType](op Operation, term BoundTerm) BoundUnaryPredicate { + return &boundUnaryPredicate[T]{op: op, term: term.(bound[T])} +} + +func createBoundUnaryPredicate(op Operation, term BoundTerm) BoundUnaryPredicate { + switch term.Type().(type) { + case BooleanType: + return newBoundUnaryPred[bool](op, term) + case Int32Type: + return newBoundUnaryPred[int32](op, term) + case Int64Type: + return newBoundUnaryPred[int64](op, term) + case Float32Type: + return newBoundUnaryPred[float32](op, term) + case Float64Type: + return newBoundUnaryPred[float64](op, term) + case DateType: + return newBoundUnaryPred[Date](op, term) + case TimeType: + return newBoundUnaryPred[Time](op, term) + case TimestampType, TimestampTzType: + return newBoundUnaryPred[Timestamp](op, term) + case StringType: + return newBoundUnaryPred[string](op, term) + case FixedType, BinaryType: + return newBoundUnaryPred[[]byte](op, term) + case DecimalType: + return 
newBoundUnaryPred[Decimal](op, term) + case UUIDType: + return newBoundUnaryPred[uuid.UUID](op, term) + } + panic("unhandled bound reference type: " + term.Type().String()) +} + +type boundUnaryPredicate[T LiteralType] struct { + op Operation + term bound[T] +} + +func (bp *boundUnaryPredicate[T]) AsUnbound(r Reference) UnboundPredicate { + return &unboundUnaryPredicate{op: bp.op, term: r} +} + +func (bp *boundUnaryPredicate[T]) Equals(other BooleanExpression) bool { + rhs, ok := other.(*boundUnaryPredicate[T]) + if !ok { + return false + } + + return bp.op == rhs.op && bp.term.Equals(rhs.term) +} + +func (bp *boundUnaryPredicate[T]) Op() Operation { return bp.op } +func (bp *boundUnaryPredicate[T]) Negate() BooleanExpression { + return &boundUnaryPredicate[T]{op: bp.op.Negate(), term: bp.term} +} + +func (bp *boundUnaryPredicate[T]) Term() BoundTerm { return bp.term } +func (bp *boundUnaryPredicate[T]) Ref() BoundReference { return bp.term.Ref() } +func (bp *boundUnaryPredicate[T]) String() string { + return fmt.Sprintf("Bound%s(term=%s)", bp.op, bp.term) +} + +// LiteralPredicate constructs an unbound predicate for an operation that requires +// a single literal argument, such as LessThan or StartsWith. +// +// Panics if the operation provided is not a valid Literal operation, +// if the term is nil or if the literal is nil. 
+func LiteralPredicate(op Operation, t UnboundTerm, lit Literal) UnboundPredicate { + switch { + case op < OpLT || op > OpNotStartsWith: + panic(fmt.Errorf("%w: invalid operation for LiteralPredicate: %s", + ErrInvalidArgument, op)) + case t == nil: + panic(fmt.Errorf("%w: cannot create literal predicate with nil term", + ErrInvalidArgument)) + case lit == nil: + panic(fmt.Errorf("%w: cannot create literal predicate with nil literal", + ErrInvalidArgument)) + } + + return &unboundLiteralPredicate{op: op, term: t, lit: lit} +} + +type unboundLiteralPredicate struct { + op Operation + term UnboundTerm + lit Literal +} + +func (ul *unboundLiteralPredicate) String() string { + return fmt.Sprintf("%s(term=%s, literal=%s)", ul.op, ul.term, ul.lit) +} + +func (ul *unboundLiteralPredicate) Equals(other BooleanExpression) bool { + rhs, ok := other.(*unboundLiteralPredicate) + if !ok { + return false + } + + return ul.op == rhs.op && ul.term.Equals(rhs.term) && ul.lit.Equals(rhs.lit) +} + +func (ul *unboundLiteralPredicate) Op() Operation { return ul.op } +func (ul *unboundLiteralPredicate) Negate() BooleanExpression { + return &unboundLiteralPredicate{op: ul.op.Negate(), term: ul.term, lit: ul.lit} +} +func (ul *unboundLiteralPredicate) Term() UnboundTerm { return ul.term } +func (ul *unboundLiteralPredicate) Bind(schema *Schema, caseSensitive bool) (BooleanExpression, error) { + bound, err := ul.term.Bind(schema, caseSensitive) + if err != nil { + return nil, err + } + + if (ul.op == OpStartsWith || ul.op == OpNotStartsWith) && + !(bound.Type().Equals(PrimitiveTypes.String) || bound.Type().Equals(PrimitiveTypes.Binary)) { + return nil, fmt.Errorf("%w: StartsWith and NotStartsWith must bind to String type, not %s", + ErrType, bound.Type()) + } + + lit, err := ul.lit.To(bound.Type()) + if err != nil { + return nil, err + } + + switch lit.(type) { + case AboveMaxLiteral: + switch ul.op { + case OpLT, OpLTEQ, OpNEQ: + return AlwaysTrue{}, nil + case OpGT, OpGTEQ, OpEQ: + 
return AlwaysFalse{}, nil + } + case BelowMinLiteral: + switch ul.op { + case OpLT, OpLTEQ, OpEQ: + return AlwaysFalse{}, nil + case OpGT, OpGTEQ, OpNEQ: + return AlwaysTrue{}, nil + } + } + + return createBoundLiteralPredicate(ul.op, bound, lit) +} + +// BoundLiteralPredicate represents a bound boolean expression that utilizes a single +// literal as an argument, such as Equals or StartsWith. +type BoundLiteralPredicate interface { + BoundPredicate + + Literal() Literal + AsUnbound(Reference, Literal) UnboundPredicate +} + +func newBoundLiteralPredicate[T LiteralType](op Operation, term BoundTerm, lit Literal) BoundPredicate { + return &boundLiteralPredicate[T]{op: op, term: term.(bound[T]), + lit: lit.(TypedLiteral[T])} +} + +func createBoundLiteralPredicate(op Operation, term BoundTerm, lit Literal) (BoundPredicate, error) { + finalLit, err := lit.To(term.Type()) + if err != nil { + return nil, err + } + + switch term.Type().(type) { + case BooleanType: + return newBoundLiteralPredicate[bool](op, term, finalLit), nil + case Int32Type: + return newBoundLiteralPredicate[int32](op, term, finalLit), nil + case Int64Type: + return newBoundLiteralPredicate[int64](op, term, finalLit), nil + case Float32Type: + return newBoundLiteralPredicate[float32](op, term, finalLit), nil + case Float64Type: + return newBoundLiteralPredicate[float64](op, term, finalLit), nil + case DateType: + return newBoundLiteralPredicate[Date](op, term, finalLit), nil + case TimeType: + return newBoundLiteralPredicate[Time](op, term, finalLit), nil + case TimestampType, TimestampTzType: + return newBoundLiteralPredicate[Timestamp](op, term, finalLit), nil + case StringType: + return newBoundLiteralPredicate[string](op, term, finalLit), nil + case FixedType, BinaryType: + return newBoundLiteralPredicate[[]byte](op, term, finalLit), nil + case DecimalType: + return newBoundLiteralPredicate[Decimal](op, term, finalLit), nil + case UUIDType: + return newBoundLiteralPredicate[uuid.UUID](op, term, 
finalLit), nil + } + return nil, fmt.Errorf("%w: could not create bound literal predicate for term type %s", + ErrInvalidArgument, term.Type()) +} + +type boundLiteralPredicate[T LiteralType] struct { + op Operation + term bound[T] + lit TypedLiteral[T] +} + +func (blp *boundLiteralPredicate[T]) Equals(other BooleanExpression) bool { + rhs, ok := other.(*boundLiteralPredicate[T]) + if !ok { + return false + } + + return blp.op == rhs.op && blp.term.Equals(rhs.term) && blp.lit.Equals(rhs.lit) +} + +func (blp *boundLiteralPredicate[T]) Op() Operation { return blp.op } +func (blp *boundLiteralPredicate[T]) Negate() BooleanExpression { + return &boundLiteralPredicate[T]{op: blp.op.Negate(), term: blp.term, lit: blp.lit} +} +func (blp *boundLiteralPredicate[T]) Term() BoundTerm { return blp.term } +func (blp *boundLiteralPredicate[T]) Ref() BoundReference { return blp.term.Ref() } +func (blp *boundLiteralPredicate[T]) String() string { + return fmt.Sprintf("Bound%s(term=%s, literal=%s)", blp.op, blp.term, blp.lit) +} +func (blp *boundLiteralPredicate[T]) Literal() Literal { return blp.lit } +func (blp *boundLiteralPredicate[T]) AsUnbound(r Reference, l Literal) UnboundPredicate { + return &unboundLiteralPredicate{op: blp.op, term: r, lit: l} +} + +// SetPredicate creates a boolean expression representing a predicate that uses a set +// of literals as the argument, like In or NotIn. Duplicate literals will be folded +// into a set, only maintaining the unique literals. 
+// +// Will panic if op is not a valid Set operation +func SetPredicate(op Operation, t UnboundTerm, lits []Literal) BooleanExpression { + if op < OpIn || op > OpNotIn { + panic(fmt.Errorf("%w: invalid operation for SetPredicate: %s", + ErrInvalidArgument, op)) + } + + if t == nil { + panic(fmt.Errorf("%w: cannot create set predicate with nil term", + ErrInvalidArgument)) + } + + switch len(lits) { + case 0: + if op == OpIn { + return AlwaysFalse{} + } else if op == OpNotIn { + return AlwaysTrue{} + } + case 1: + if op == OpIn { + return LiteralPredicate(OpEQ, t, lits[0]) + } else if op == OpNotIn { + return LiteralPredicate(OpNEQ, t, lits[0]) + } + } + + return &unboundSetPredicate{op: op, term: t, lits: newLiteralSet(lits...)} +} + +type unboundSetPredicate struct { + op Operation + term UnboundTerm + lits Set[Literal] +} + +func (usp *unboundSetPredicate) String() string { + return fmt.Sprintf("%s(term=%s, {%v})", usp.op, usp.term, usp.lits.Members()) +} + +func (usp *unboundSetPredicate) Equals(other BooleanExpression) bool { + rhs, ok := other.(*unboundSetPredicate) + if !ok { + return false + } + + return usp.op == rhs.op && usp.term.Equals(rhs.term) && + usp.lits.Equals(rhs.lits) +} + +func (usp *unboundSetPredicate) Op() Operation { return usp.op } +func (usp *unboundSetPredicate) Negate() BooleanExpression { + return &unboundSetPredicate{op: usp.op.Negate(), term: usp.term, lits: usp.lits} +} + +func (usp *unboundSetPredicate) Term() UnboundTerm { return usp.term } +func (usp *unboundSetPredicate) Bind(schema *Schema, caseSensitive bool) (BooleanExpression, error) { + bound, err := usp.term.Bind(schema, caseSensitive) + if err != nil { + return nil, err + } + + return createBoundSetPredicate(usp.op, bound, usp.lits) +} + +// BoundSetPredicate is a bound expression that utilizes a set of literals such as In or NotIn +type BoundSetPredicate interface { + BoundPredicate + + Literals() Set[Literal] + AsUnbound(Reference, []Literal) UnboundPredicate +} + +func 
createBoundSetPredicate(op Operation, term BoundTerm, lits Set[Literal]) (BooleanExpression, error) { + boundType := term.Type() + + typedSet := newLiteralSet() + for _, v := range lits.Members() { + casted, err := v.To(boundType) + if err != nil { + return nil, err + } + typedSet.Add(casted) + } + + switch typedSet.Len() { + case 0: + if op == OpIn { + return AlwaysFalse{}, nil + } else if op == OpNotIn { + return AlwaysTrue{}, nil + } + case 1: + if op == OpIn { + return createBoundLiteralPredicate(OpEQ, term, typedSet.Members()[0]) + } else if op == OpNotIn { + return createBoundLiteralPredicate(OpNEQ, term, typedSet.Members()[0]) + } + } + + switch term.Type().(type) { + case BooleanType: + return newBoundSetPredicate[bool](op, term, typedSet), nil + case Int32Type: + return newBoundSetPredicate[int32](op, term, typedSet), nil + case Int64Type: + return newBoundSetPredicate[int64](op, term, typedSet), nil + case Float32Type: + return newBoundSetPredicate[float32](op, term, typedSet), nil + case Float64Type: + return newBoundSetPredicate[float64](op, term, typedSet), nil + case DateType: + return newBoundSetPredicate[Date](op, term, typedSet), nil + case TimeType: + return newBoundSetPredicate[Time](op, term, typedSet), nil + case TimestampType, TimestampTzType: + return newBoundSetPredicate[Timestamp](op, term, typedSet), nil + case StringType: + return newBoundSetPredicate[string](op, term, typedSet), nil + case BinaryType, FixedType: + return newBoundSetPredicate[[]byte](op, term, typedSet), nil + case DecimalType: + return newBoundSetPredicate[Decimal](op, term, typedSet), nil + case UUIDType: + return newBoundSetPredicate[uuid.UUID](op, term, typedSet), nil + } + + return nil, fmt.Errorf("%w: invalid bound type for set predicate - %s", + ErrType, term.Type()) +} + +func newBoundSetPredicate[T LiteralType](op Operation, term BoundTerm, lits Set[Literal]) *boundSetPredicate[T] { + return &boundSetPredicate[T]{op: op, term: term.(bound[T]), lits: lits} +} + 
+type boundSetPredicate[T LiteralType] struct { + op Operation + term bound[T] + lits Set[Literal] +} + +func (bsp *boundSetPredicate[T]) Equals(other BooleanExpression) bool { + rhs, ok := other.(*boundSetPredicate[T]) + if !ok { + return false + } + + return bsp.op == rhs.op && bsp.term.Equals(rhs.term) && + bsp.lits.Equals(rhs.lits) +} + +func (bsp *boundSetPredicate[T]) Op() Operation { return bsp.op } +func (bsp *boundSetPredicate[T]) Negate() BooleanExpression { + return &boundSetPredicate[T]{op: bsp.op.Negate(), term: bsp.term, + lits: bsp.lits} +} +func (bsp *boundSetPredicate[T]) Term() BoundTerm { return bsp.term } +func (bsp *boundSetPredicate[T]) Ref() BoundReference { return bsp.term.Ref() } +func (bsp *boundSetPredicate[T]) String() string { + return fmt.Sprintf("Bound%s(term=%s, {%v})", bsp.op, bsp.term, bsp.lits.Members()) +} +func (bsp *boundSetPredicate[T]) AsUnbound(r Reference, lits []Literal) UnboundPredicate { + litSet := newLiteralSet(lits...) + if litSet.Len() == 1 { + switch bsp.op { + case OpIn: + return LiteralPredicate(OpEQ, r, lits[0]) + case OpNotIn: + return LiteralPredicate(OpNEQ, r, lits[0]) + } + } + + return &unboundSetPredicate{op: bsp.op, term: r, lits: litSet} +} + +func (bsp *boundSetPredicate[T]) Literals() Set[Literal] { + return bsp.lits +} + +type BoundTransform struct { + transform Transform + term BoundTerm +} + +func (*BoundTransform) isTerm() {} +func (b *BoundTransform) String() string { + return fmt.Sprintf("BoundTransform(transform=%s, term=%s)", + b.transform, b.term) +} + +func (b *BoundTransform) Ref() BoundReference { return b.term.Ref() } +func (b *BoundTransform) Type() Type { return b.transform.ResultType(b.term.Type()) } + +func (b *BoundTransform) Equals(other BoundTerm) bool { + rhs, ok := other.(*BoundTransform) + if !ok { + return false + } + + return b.transform.Equals(rhs.transform) && b.term.Equals(rhs.term) +} + +func (b *BoundTransform) evalToLiteral(st structLike) Optional[Literal] { + return 
b.transform.Apply(b.term.evalToLiteral(st)) +} + +func (b *BoundTransform) evalIsNull(st structLike) bool { + return !b.evalToLiteral(st).Valid +} diff --git a/exprs_test.go b/exprs_test.go index 3ea5257..b88fc35 100644 --- a/exprs_test.go +++ b/exprs_test.go @@ -1,742 +1,742 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package iceberg_test - -import ( - "math" - "strconv" - "testing" - - "github.com/apache/iceberg-go" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type ExprA struct{} - -func (ExprA) String() string { return "ExprA" } -func (ExprA) Op() iceberg.Operation { return iceberg.OpFalse } -func (ExprA) Negate() iceberg.BooleanExpression { return ExprB{} } -func (ExprA) Equals(o iceberg.BooleanExpression) bool { - _, ok := o.(ExprA) - return ok -} - -type ExprB struct{} - -func (ExprB) String() string { return "ExprB" } -func (ExprB) Op() iceberg.Operation { return iceberg.OpTrue } -func (ExprB) Negate() iceberg.BooleanExpression { return ExprA{} } -func (ExprB) Equals(o iceberg.BooleanExpression) bool { - _, ok := o.(ExprB) - return ok -} - -func TestUnaryExpr(t *testing.T) { - assert.PanicsWithError(t, "invalid argument: invalid operation for unary predicate: LessThan", func() { - iceberg.UnaryPredicate(iceberg.OpLT, iceberg.Reference("a")) - }) - - assert.PanicsWithError(t, "invalid argument: cannot create unary predicate with nil term", func() { - iceberg.UnaryPredicate(iceberg.OpIsNull, nil) - }) - - t.Run("negate", func(t *testing.T) { - n := iceberg.IsNull(iceberg.Reference("a")).Negate() - exp := iceberg.NotNull(iceberg.Reference("a")) - - assert.Equal(t, exp, n) - assert.True(t, exp.Equals(n)) - assert.True(t, n.Equals(exp)) - }) - - sc := iceberg.NewSchema(1, iceberg.NestedField{ - ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Int32}) - sc2 := iceberg.NewSchema(1, iceberg.NestedField{ - ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Float64}) - sc3 := iceberg.NewSchema(1, iceberg.NestedField{ - ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Int32, Required: true}) - sc4 := iceberg.NewSchema(1, iceberg.NestedField{ - ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Float32, Required: true}) - - t.Run("isnull and notnull", func(t *testing.T) { - t.Run("bind", func(t *testing.T) { - n, err := 
iceberg.IsNull(iceberg.Reference("a")).Bind(sc, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpIsNull, n.Op()) - assert.Implements(t, (*iceberg.BoundUnaryPredicate)(nil), n) - p := n.(iceberg.BoundUnaryPredicate) - assert.IsType(t, iceberg.PrimitiveTypes.Int32, p.Term().Type()) - assert.Same(t, p.Ref(), p.Term().Ref()) - assert.Same(t, p.Ref(), p.Ref().Ref()) - - f := p.Ref().Field() - assert.True(t, f.Equals(sc.Field(0))) - }) - - t.Run("negate and bind", func(t *testing.T) { - n1, err := iceberg.IsNull(iceberg.Reference("a")).Bind(sc, true) - require.NoError(t, err) - - n2, err := iceberg.NotNull(iceberg.Reference("a")).Bind(sc, true) - require.NoError(t, err) - - assert.True(t, n1.Negate().Equals(n2)) - assert.True(t, n2.Negate().Equals(n1)) - }) - - t.Run("null bind required", func(t *testing.T) { - n1, err := iceberg.IsNull(iceberg.Reference("a")).Bind(sc3, true) - require.NoError(t, err) - - n2, err := iceberg.NotNull(iceberg.Reference("a")).Bind(sc3, true) - require.NoError(t, err) - - assert.True(t, n1.Equals(iceberg.AlwaysFalse{})) - assert.True(t, n2.Equals(iceberg.AlwaysTrue{})) - }) - }) - - t.Run("isnan notnan", func(t *testing.T) { - t.Run("negate and bind", func(t *testing.T) { - n1, err := iceberg.IsNaN(iceberg.Reference("a")).Bind(sc2, true) - require.NoError(t, err) - - n2, err := iceberg.NotNaN(iceberg.Reference("a")).Bind(sc2, true) - require.NoError(t, err) - - assert.True(t, n1.Negate().Equals(n2)) - assert.True(t, n2.Negate().Equals(n1)) - }) - - t.Run("bind float", func(t *testing.T) { - n, err := iceberg.IsNaN(iceberg.Reference("a")).Bind(sc4, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpIsNan, n.Op()) - assert.Implements(t, (*iceberg.BoundUnaryPredicate)(nil), n) - p := n.(iceberg.BoundUnaryPredicate) - assert.IsType(t, iceberg.PrimitiveTypes.Float32, p.Term().Type()) - - n2, err := iceberg.NotNaN(iceberg.Reference("a")).Bind(sc4, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpNotNan, n2.Op()) 
- assert.Implements(t, (*iceberg.BoundUnaryPredicate)(nil), n2) - p2 := n2.(iceberg.BoundUnaryPredicate) - assert.IsType(t, iceberg.PrimitiveTypes.Float32, p2.Term().Type()) - }) - - t.Run("bind double", func(t *testing.T) { - n, err := iceberg.IsNaN(iceberg.Reference("a")).Bind(sc2, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpIsNan, n.Op()) - assert.Implements(t, (*iceberg.BoundUnaryPredicate)(nil), n) - p := n.(iceberg.BoundUnaryPredicate) - assert.IsType(t, iceberg.PrimitiveTypes.Float64, p.Term().Type()) - - n2, err := iceberg.NotNaN(iceberg.Reference("a")).Bind(sc2, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpNotNan, n2.Op()) - assert.Implements(t, (*iceberg.BoundUnaryPredicate)(nil), n2) - p2 := n2.(iceberg.BoundUnaryPredicate) - assert.IsType(t, iceberg.PrimitiveTypes.Float64, p2.Term().Type()) - }) - - t.Run("bind non floating", func(t *testing.T) { - n1, err := iceberg.IsNaN(iceberg.Reference("a")).Bind(sc, true) - require.NoError(t, err) - - n2, err := iceberg.NotNaN(iceberg.Reference("a")).Bind(sc, true) - require.NoError(t, err) - - assert.True(t, n1.Equals(iceberg.AlwaysFalse{})) - assert.True(t, n2.Equals(iceberg.AlwaysTrue{})) - }) - }) -} - -func TestRefBindingCaseSensitive(t *testing.T) { - ref1, ref2 := iceberg.Reference("foo"), iceberg.Reference("Foo") - - bound1, err := ref1.Bind(tableSchemaSimple, true) - require.NoError(t, err) - assert.True(t, bound1.Type().Equals(iceberg.PrimitiveTypes.String)) - - _, err = ref2.Bind(tableSchemaSimple, true) - assert.ErrorIs(t, err, iceberg.ErrInvalidSchema) - assert.ErrorContains(t, err, "could not bind reference 'Foo', caseSensitive=true") - - bound2, err := ref2.Bind(tableSchemaSimple, false) - require.NoError(t, err) - assert.True(t, bound1.Equals(bound2)) - - _, err = iceberg.Reference("foot").Bind(tableSchemaSimple, false) - assert.ErrorIs(t, err, iceberg.ErrInvalidSchema) - assert.ErrorContains(t, err, "could not bind reference 'foot', caseSensitive=false") -} - 
-func TestRefTypes(t *testing.T) { - sc := iceberg.NewSchema(1, - iceberg.NestedField{ID: 1, Name: "a", Type: iceberg.PrimitiveTypes.Bool}, - iceberg.NestedField{ID: 2, Name: "b", Type: iceberg.PrimitiveTypes.Int32}, - iceberg.NestedField{ID: 3, Name: "c", Type: iceberg.PrimitiveTypes.Int64}, - iceberg.NestedField{ID: 4, Name: "d", Type: iceberg.PrimitiveTypes.Float32}, - iceberg.NestedField{ID: 5, Name: "e", Type: iceberg.PrimitiveTypes.Float64}, - iceberg.NestedField{ID: 6, Name: "f", Type: iceberg.PrimitiveTypes.Date}, - iceberg.NestedField{ID: 7, Name: "g", Type: iceberg.PrimitiveTypes.Time}, - iceberg.NestedField{ID: 8, Name: "h", Type: iceberg.PrimitiveTypes.Timestamp}, - iceberg.NestedField{ID: 9, Name: "i", Type: iceberg.DecimalTypeOf(9, 2)}, - iceberg.NestedField{ID: 10, Name: "j", Type: iceberg.PrimitiveTypes.String}, - iceberg.NestedField{ID: 11, Name: "k", Type: iceberg.PrimitiveTypes.Binary}, - iceberg.NestedField{ID: 12, Name: "l", Type: iceberg.PrimitiveTypes.UUID}, - iceberg.NestedField{ID: 13, Name: "m", Type: iceberg.FixedTypeOf(5)}) - - t.Run("bind term", func(t *testing.T) { - for i := 0; i < sc.NumFields(); i++ { - fld := sc.Field(i) - t.Run(fld.Type.String(), func(t *testing.T) { - ref, err := iceberg.Reference(fld.Name).Bind(sc, true) - require.NoError(t, err) - - assert.True(t, ref.Type().Equals(fld.Type)) - assert.True(t, fld.Equals(ref.Ref().Field())) - }) - } - }) - - t.Run("bind unary", func(t *testing.T) { - for i := 0; i < sc.NumFields(); i++ { - fld := sc.Field(i) - t.Run(fld.Type.String(), func(t *testing.T) { - b, err := iceberg.IsNull(iceberg.Reference(fld.Name)).Bind(sc, true) - require.NoError(t, err) - - assert.True(t, b.(iceberg.BoundUnaryPredicate).Ref().Type().Equals(fld.Type)) - - un := b.(iceberg.BoundUnaryPredicate).AsUnbound(iceberg.Reference("foo")) - assert.Equal(t, b.Op(), un.Op()) - }) - } - }) - - t.Run("bind literal", func(t *testing.T) { - t.Run("bool", func(t *testing.T) { - b1, err := 
iceberg.EqualTo(iceberg.Reference("a"), true).Bind(sc, true) - require.NoError(t, err) - assert.Equal(t, iceberg.OpEQ, b1.Op()) - assert.True(t, b1.(iceberg.BoundLiteralPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.Bool)) - }) - - for i := 1; i < 9; i++ { - fld := sc.Field(i) - t.Run(fld.Type.String(), func(t *testing.T) { - b, err := iceberg.EqualTo(iceberg.Reference(fld.Name), int32(5)).Bind(sc, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpEQ, b.Op()) - assert.True(t, b.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(fld.Type)) - assert.True(t, b.(iceberg.BoundLiteralPredicate).Ref().Type().Equals(fld.Type)) - }) - } - - t.Run("string-binary", func(t *testing.T) { - str, err := iceberg.EqualTo(iceberg.Reference("j"), "foobar").Bind(sc, true) - require.NoError(t, err) - - bin, err := iceberg.EqualTo(iceberg.Reference("k"), []byte("foobar")).Bind(sc, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpEQ, str.Op()) - assert.True(t, str.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.PrimitiveTypes.String)) - assert.Equal(t, iceberg.OpEQ, bin.Op()) - assert.True(t, bin.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.PrimitiveTypes.Binary)) - }) - - t.Run("fixed", func(t *testing.T) { - fx, err := iceberg.EqualTo(iceberg.Reference("m"), []byte{0, 1, 2, 3, 4}).Bind(sc, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpEQ, fx.Op()) - assert.True(t, fx.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.FixedTypeOf(5))) - }) - - t.Run("uuid", func(t *testing.T) { - uid, err := iceberg.EqualTo(iceberg.Reference("l"), uuid.New().String()).Bind(sc, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpEQ, uid.Op()) - assert.True(t, uid.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.PrimitiveTypes.UUID)) - }) - }) - - t.Run("bind set", func(t *testing.T) { - t.Run("bool", func(t *testing.T) { - b, err := iceberg.IsIn(iceberg.Reference("a"), true, 
false).(iceberg.UnboundPredicate).Bind(sc, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpIn, b.Op()) - }) - - for i := 1; i < 9; i++ { - fld := sc.Field(i) - t.Run(fld.Type.String(), func(t *testing.T) { - b, err := iceberg.IsIn(iceberg.Reference(fld.Name), int32(10), int32(5), int32(5)).(iceberg.UnboundPredicate). - Bind(sc, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpIn, b.Op()) - assert.True(t, b.(iceberg.BoundSetPredicate).Ref().Type().Equals(fld.Type)) - for _, v := range b.(iceberg.BoundSetPredicate).Literals().Members() { - assert.True(t, v.Type().Equals(fld.Type)) - } - }) - } - - t.Run("string-binary", func(t *testing.T) { - str, err := iceberg.IsIn(iceberg.Reference("j"), "hello", "foobar").(iceberg.UnboundPredicate). - Bind(sc, true) - require.NoError(t, err) - - bin, err := iceberg.IsIn(iceberg.Reference("k"), []byte("baz"), []byte("foobar")).(iceberg.UnboundPredicate). - Bind(sc, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpIn, str.Op()) - assert.Equal(t, iceberg.OpIn, bin.Op()) - - assert.True(t, str.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.String)) - for _, v := range str.(iceberg.BoundSetPredicate).Literals().Members() { - assert.True(t, v.Type().Equals(iceberg.PrimitiveTypes.String)) - } - - assert.True(t, bin.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.Binary)) - for _, v := range bin.(iceberg.BoundSetPredicate).Literals().Members() { - assert.True(t, v.Type().Equals(iceberg.PrimitiveTypes.Binary)) - } - }) - - t.Run("fixed", func(t *testing.T) { - fx, err := iceberg.IsIn(iceberg.Reference("m"), []byte{4, 5, 6, 7, 8}, []byte{0, 1, 2, 3, 4}).(iceberg.UnboundPredicate). 
- Bind(sc, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpIn, fx.Op()) - assert.True(t, fx.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.FixedTypeOf(5))) - for _, v := range fx.(iceberg.BoundSetPredicate).Literals().Members() { - assert.True(t, v.Type().Equals(iceberg.FixedTypeOf(5))) - } - }) - - t.Run("uuid", func(t *testing.T) { - uid, err := iceberg.IsIn(iceberg.Reference("l"), uuid.New().String(), uuid.New().String()).(iceberg.UnboundPredicate). - Bind(sc, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpIn, uid.Op()) - assert.True(t, uid.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.UUID)) - for _, v := range uid.(iceberg.BoundSetPredicate).Literals().Members() { - assert.True(t, v.Type().Equals(iceberg.PrimitiveTypes.UUID)) - } - }) - }) -} - -func TestInNotInSimplifications(t *testing.T) { - assert.PanicsWithError(t, "invalid argument: invalid operation for SetPredicate: LessThan", - func() { iceberg.SetPredicate(iceberg.OpLT, iceberg.Reference("x"), nil) }) - assert.PanicsWithError(t, "invalid argument: cannot create set predicate with nil term", - func() { iceberg.SetPredicate(iceberg.OpIn, nil, nil) }) - assert.NotPanics(t, func() { iceberg.SetPredicate(iceberg.OpIn, iceberg.Reference("x"), nil) }) - - t.Run("in to eq", func(t *testing.T) { - a := iceberg.IsIn(iceberg.Reference("x"), 34.56) - b := iceberg.EqualTo(iceberg.Reference("x"), 34.56) - assert.True(t, a.Equals(b)) - }) - - t.Run("notin to notequal", func(t *testing.T) { - a := iceberg.NotIn(iceberg.Reference("x"), 34.56) - b := iceberg.NotEqualTo(iceberg.Reference("x"), 34.56) - assert.True(t, a.Equals(b)) - }) - - t.Run("empty", func(t *testing.T) { - a := iceberg.IsIn[float32](iceberg.Reference("x")) - b := iceberg.NotIn[float32](iceberg.Reference("x")) - - assert.Equal(t, iceberg.AlwaysFalse{}, a) - assert.Equal(t, iceberg.AlwaysTrue{}, b) - }) - - t.Run("bind and negate", func(t *testing.T) { - inexp := 
iceberg.IsIn(iceberg.Reference("foo"), "hello", "world") - notin := iceberg.NotIn(iceberg.Reference("foo"), "hello", "world") - assert.True(t, inexp.Negate().Equals(notin)) - assert.True(t, notin.Negate().Equals(inexp)) - assert.Equal(t, iceberg.OpIn, inexp.Op()) - assert.Equal(t, iceberg.OpNotIn, notin.Op()) - - boundin, err := inexp.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true) - require.NoError(t, err) - - boundnot, err := notin.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true) - require.NoError(t, err) - - assert.True(t, boundin.Negate().Equals(boundnot)) - assert.True(t, boundnot.Negate().Equals(boundin)) - }) - - t.Run("bind dedup", func(t *testing.T) { - isin := iceberg.IsIn(iceberg.Reference("foo"), "hello", "world", "world") - bound, err := isin.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true) - require.NoError(t, err) - - assert.Implements(t, (*iceberg.BoundSetPredicate)(nil), bound) - bsp := bound.(iceberg.BoundSetPredicate) - assert.Equal(t, 2, bsp.Literals().Len()) - assert.True(t, bsp.Literals().Contains(iceberg.NewLiteral("hello"))) - assert.True(t, bsp.Literals().Contains(iceberg.NewLiteral("world"))) - }) - - t.Run("bind dedup to eq", func(t *testing.T) { - isin := iceberg.IsIn(iceberg.Reference("foo"), "world", "world") - bound, err := isin.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true) - require.NoError(t, err) - - assert.Equal(t, iceberg.OpEQ, bound.Op()) - assert.Equal(t, iceberg.NewLiteral("world"), - bound.(iceberg.BoundLiteralPredicate).Literal()) - }) -} - -func TestLiteralPredicateErrors(t *testing.T) { - assert.PanicsWithError(t, "invalid argument: invalid operation for LiteralPredicate: In", - func() { iceberg.LiteralPredicate(iceberg.OpIn, iceberg.Reference("foo"), iceberg.NewLiteral("hello")) }) - assert.PanicsWithError(t, "invalid argument: cannot create literal predicate with nil term", - func() { iceberg.LiteralPredicate(iceberg.OpLT, nil, iceberg.NewLiteral("hello")) }) - 
assert.PanicsWithError(t, "invalid argument: cannot create literal predicate with nil literal", - func() { iceberg.LiteralPredicate(iceberg.OpLT, iceberg.Reference("foo"), nil) }) -} - -func TestNegations(t *testing.T) { - ref := iceberg.Reference("foo") - - tests := []struct { - name string - ex1, ex2 iceberg.UnboundPredicate - }{ - {"equal-not", iceberg.EqualTo(ref, "hello"), iceberg.NotEqualTo(ref, "hello")}, - {"greater-equal-less", iceberg.GreaterThanEqual(ref, "hello"), iceberg.LessThan(ref, "hello")}, - {"greater-less-equal", iceberg.GreaterThan(ref, "hello"), iceberg.LessThanEqual(ref, "hello")}, - {"starts-with", iceberg.StartsWith(ref, "hello"), iceberg.NotStartsWith(ref, "hello")}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.False(t, tt.ex1.Equals(tt.ex2)) - assert.False(t, tt.ex2.Equals(tt.ex1)) - assert.True(t, tt.ex1.Negate().Equals(tt.ex2)) - assert.True(t, tt.ex2.Negate().Equals(tt.ex1)) - - b1, err := tt.ex1.Bind(tableSchemaSimple, true) - require.NoError(t, err) - b2, err := tt.ex2.Bind(tableSchemaSimple, true) - require.NoError(t, err) - - assert.False(t, b1.Equals(b2)) - assert.False(t, b2.Equals(b1)) - assert.True(t, b1.Negate().Equals(b2)) - assert.True(t, b2.Negate().Equals(b1)) - }) - } -} - -func TestBoolExprEQ(t *testing.T) { - tests := []struct { - exp, testexpra, testexprb iceberg.BooleanExpression - }{ - {iceberg.NewAnd(ExprA{}, ExprB{}), - iceberg.NewAnd(ExprA{}, ExprB{}), - iceberg.NewOr(ExprA{}, ExprB{})}, - {iceberg.NewOr(ExprA{}, ExprB{}), - iceberg.NewOr(ExprA{}, ExprB{}), - iceberg.NewAnd(ExprA{}, ExprB{})}, - {iceberg.NewAnd(ExprA{}, ExprB{}), - iceberg.NewAnd(ExprB{}, ExprA{}), - iceberg.NewOr(ExprB{}, ExprA{})}, - {iceberg.NewOr(ExprA{}, ExprB{}), - iceberg.NewOr(ExprB{}, ExprA{}), - iceberg.NewAnd(ExprB{}, ExprA{})}, - {iceberg.NewNot(ExprA{}), iceberg.NewNot(ExprA{}), ExprB{}}, - {ExprA{}, ExprA{}, ExprB{}}, - {ExprB{}, ExprB{}, ExprA{}}, - {iceberg.IsIn(iceberg.Reference("foo"), 
"hello", "world"), - iceberg.IsIn(iceberg.Reference("foo"), "hello", "world"), - iceberg.IsIn(iceberg.Reference("not_foo"), "hello", "world")}, - {iceberg.IsIn(iceberg.Reference("foo"), "hello", "world"), - iceberg.IsIn(iceberg.Reference("foo"), "hello", "world"), - iceberg.IsIn(iceberg.Reference("foo"), "goodbye", "world")}, - } - - for i, tt := range tests { - t.Run(strconv.Itoa(i), func(t *testing.T) { - assert.True(t, tt.exp.Equals(tt.testexpra)) - assert.False(t, tt.exp.Equals(tt.testexprb)) - }) - } -} - -func TestBoolExprNegate(t *testing.T) { - tests := []struct { - lhs, rhs iceberg.BooleanExpression - }{ - {iceberg.NewAnd(ExprA{}, ExprB{}), iceberg.NewOr(ExprB{}, ExprA{})}, - {iceberg.NewOr(ExprB{}, ExprA{}), iceberg.NewAnd(ExprA{}, ExprB{})}, - {iceberg.NewNot(ExprA{}), ExprA{}}, - {iceberg.IsIn(iceberg.Reference("foo"), "hello", "world"), - iceberg.NotIn(iceberg.Reference("foo"), "hello", "world")}, - {iceberg.NotIn(iceberg.Reference("foo"), "hello", "world"), - iceberg.IsIn(iceberg.Reference("foo"), "hello", "world")}, - {iceberg.GreaterThan(iceberg.Reference("foo"), int32(5)), - iceberg.LessThanEqual(iceberg.Reference("foo"), int32(5))}, - {iceberg.LessThan(iceberg.Reference("foo"), int32(5)), - iceberg.GreaterThanEqual(iceberg.Reference("foo"), int32(5))}, - {iceberg.EqualTo(iceberg.Reference("foo"), int32(5)), - iceberg.NotEqualTo(iceberg.Reference("foo"), int32(5))}, - {ExprA{}, ExprB{}}, - } - - for _, tt := range tests { - assert.True(t, tt.lhs.Negate().Equals(tt.rhs)) - } -} - -func TestBoolExprPanics(t *testing.T) { - assert.PanicsWithError(t, "invalid argument: cannot construct AndExpr with nil arguments", - func() { iceberg.NewAnd(nil, ExprA{}) }) - assert.PanicsWithError(t, "invalid argument: cannot construct AndExpr with nil arguments", - func() { iceberg.NewAnd(ExprA{}, nil) }) - assert.PanicsWithError(t, "invalid argument: cannot construct AndExpr with nil arguments", - func() { iceberg.NewAnd(ExprA{}, ExprA{}, nil) }) - - 
assert.PanicsWithError(t, "invalid argument: cannot construct OrExpr with nil arguments", - func() { iceberg.NewOr(nil, ExprA{}) }) - assert.PanicsWithError(t, "invalid argument: cannot construct OrExpr with nil arguments", - func() { iceberg.NewOr(ExprA{}, nil) }) - assert.PanicsWithError(t, "invalid argument: cannot construct OrExpr with nil arguments", - func() { iceberg.NewOr(ExprA{}, ExprA{}, nil) }) - - assert.PanicsWithError(t, "invalid argument: cannot create NotExpr with nil child", - func() { iceberg.NewNot(nil) }) -} - -func TestExprFolding(t *testing.T) { - tests := []struct { - lhs, rhs iceberg.BooleanExpression - }{ - {iceberg.NewAnd(ExprA{}, ExprB{}, ExprA{}), - iceberg.NewAnd(iceberg.NewAnd(ExprA{}, ExprB{}), ExprA{})}, - {iceberg.NewOr(ExprA{}, ExprB{}, ExprA{}), - iceberg.NewOr(iceberg.NewOr(ExprA{}, ExprB{}), ExprA{})}, - {iceberg.NewNot(iceberg.NewNot(ExprA{})), ExprA{}}, - } - - for _, tt := range tests { - assert.True(t, tt.lhs.Equals(tt.rhs)) - } -} - -func TestBaseAlwaysTrueAlwaysFalse(t *testing.T) { - tests := []struct { - lhs, rhs iceberg.BooleanExpression - }{ - {iceberg.NewAnd(iceberg.AlwaysTrue{}, ExprB{}), ExprB{}}, - {iceberg.NewAnd(iceberg.AlwaysFalse{}, ExprB{}), iceberg.AlwaysFalse{}}, - {iceberg.NewAnd(ExprB{}, iceberg.AlwaysTrue{}), ExprB{}}, - {iceberg.NewOr(iceberg.AlwaysTrue{}, ExprB{}), iceberg.AlwaysTrue{}}, - {iceberg.NewOr(iceberg.AlwaysFalse{}, ExprB{}), ExprB{}}, - {iceberg.NewOr(ExprA{}, iceberg.AlwaysFalse{}), ExprA{}}, - {iceberg.NewNot(iceberg.NewNot(ExprA{})), ExprA{}}, - {iceberg.NewNot(iceberg.AlwaysTrue{}), iceberg.AlwaysFalse{}}, - {iceberg.NewNot(iceberg.AlwaysFalse{}), iceberg.AlwaysTrue{}}, - } - - for _, tt := range tests { - assert.True(t, tt.lhs.Equals(tt.rhs)) - } -} - -func TestNegateAlways(t *testing.T) { - assert.Equal(t, iceberg.OpTrue, iceberg.AlwaysTrue{}.Op()) - assert.Equal(t, iceberg.OpFalse, iceberg.AlwaysFalse{}.Op()) - - assert.Equal(t, iceberg.AlwaysTrue{}, iceberg.AlwaysFalse{}.Negate()) - 
assert.Equal(t, iceberg.AlwaysFalse{}, iceberg.AlwaysTrue{}.Negate()) -} - -func TestBoundReferenceToString(t *testing.T) { - ref, err := iceberg.Reference("foo").Bind(tableSchemaSimple, true) - require.NoError(t, err) - - assert.Equal(t, "BoundReference(field=1: foo: optional string, accessor=Accessor(position=0, inner=))", - ref.String()) -} - -func TestToString(t *testing.T) { - schema := iceberg.NewSchema(1, - iceberg.NestedField{ID: 1, Name: "a", Type: iceberg.PrimitiveTypes.String}, - iceberg.NestedField{ID: 2, Name: "b", Type: iceberg.PrimitiveTypes.String}, - iceberg.NestedField{ID: 3, Name: "c", Type: iceberg.PrimitiveTypes.String}, - iceberg.NestedField{ID: 4, Name: "d", Type: iceberg.PrimitiveTypes.Int32}, - iceberg.NestedField{ID: 5, Name: "e", Type: iceberg.PrimitiveTypes.Int32}, - iceberg.NestedField{ID: 6, Name: "f", Type: iceberg.PrimitiveTypes.Int32}, - iceberg.NestedField{ID: 7, Name: "g", Type: iceberg.PrimitiveTypes.Float32}, - iceberg.NestedField{ID: 8, Name: "h", Type: iceberg.DecimalTypeOf(8, 4)}, - iceberg.NestedField{ID: 9, Name: "i", Type: iceberg.PrimitiveTypes.UUID}, - iceberg.NestedField{ID: 10, Name: "j", Type: iceberg.PrimitiveTypes.Bool}, - iceberg.NestedField{ID: 11, Name: "k", Type: iceberg.PrimitiveTypes.Bool}, - iceberg.NestedField{ID: 12, Name: "l", Type: iceberg.PrimitiveTypes.Binary}) - - null := iceberg.IsNull(iceberg.Reference("a")) - nan := iceberg.IsNaN(iceberg.Reference("g")) - boundNull, _ := null.Bind(schema, true) - boundNan, _ := nan.Bind(schema, true) - - equal := iceberg.EqualTo(iceberg.Reference("c"), "a") - grtequal := iceberg.GreaterThanEqual(iceberg.Reference("a"), "a") - greater := iceberg.GreaterThan(iceberg.Reference("a"), "a") - startsWith := iceberg.StartsWith(iceberg.Reference("b"), "foo") - - boundEqual, _ := equal.Bind(schema, true) - boundGrtEqual, _ := grtequal.Bind(schema, true) - boundGreater, _ := greater.Bind(schema, true) - boundStarts, _ := startsWith.Bind(schema, true) - - tests := []struct { - 
e iceberg.BooleanExpression - expected string - }{ - {iceberg.NewAnd(null, nan), - "And(left=IsNull(term=Reference(name='a')), right=IsNaN(term=Reference(name='g')))"}, - {iceberg.NewOr(null, nan), - "Or(left=IsNull(term=Reference(name='a')), right=IsNaN(term=Reference(name='g')))"}, - {iceberg.NewNot(null), - "Not(child=IsNull(term=Reference(name='a')))"}, - {iceberg.AlwaysTrue{}, "AlwaysTrue()"}, - {iceberg.AlwaysFalse{}, "AlwaysFalse()"}, - {boundNull, - "BoundIsNull(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)))"}, - {boundNull.Negate(), - "BoundNotNull(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)))"}, - {boundNan, - "BoundIsNaN(term=BoundReference(field=7: g: optional float, accessor=Accessor(position=6, inner=)))"}, - {boundNan.Negate(), - "BoundNotNaN(term=BoundReference(field=7: g: optional float, accessor=Accessor(position=6, inner=)))"}, - {equal, - "Equal(term=Reference(name='c'), literal=a)"}, - {equal.Negate(), - "NotEqual(term=Reference(name='c'), literal=a)"}, - {grtequal, - "GreaterThanEqual(term=Reference(name='a'), literal=a)"}, - {grtequal.Negate(), - "LessThan(term=Reference(name='a'), literal=a)"}, - {greater, - "GreaterThan(term=Reference(name='a'), literal=a)"}, - {greater.Negate(), - "LessThanEqual(term=Reference(name='a'), literal=a)"}, - {startsWith, - "StartsWith(term=Reference(name='b'), literal=foo)"}, - {startsWith.Negate(), - "NotStartsWith(term=Reference(name='b'), literal=foo)"}, - {boundEqual, - "BoundEqual(term=BoundReference(field=3: c: optional string, accessor=Accessor(position=2, inner=)), literal=a)"}, - {boundEqual.Negate(), - "BoundNotEqual(term=BoundReference(field=3: c: optional string, accessor=Accessor(position=2, inner=)), literal=a)"}, - {boundGreater, - "BoundGreaterThan(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)), literal=a)"}, - {boundGreater.Negate(), - 
"BoundLessThanEqual(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)), literal=a)"}, - {boundGrtEqual, - "BoundGreaterThanEqual(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)), literal=a)"}, - {boundGrtEqual.Negate(), - "BoundLessThan(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)), literal=a)"}, - {boundStarts, - "BoundStartsWith(term=BoundReference(field=2: b: optional string, accessor=Accessor(position=1, inner=)), literal=foo)"}, - {boundStarts.Negate(), - "BoundNotStartsWith(term=BoundReference(field=2: b: optional string, accessor=Accessor(position=1, inner=)), literal=foo)"}, - } - - for _, tt := range tests { - assert.Equal(t, tt.expected, tt.e.String()) - } -} - -func TestBindAboveBelowIntMax(t *testing.T) { - sc := iceberg.NewSchema(1, - iceberg.NestedField{ID: 1, Name: "a", Type: iceberg.PrimitiveTypes.Int32}, - iceberg.NestedField{ID: 2, Name: "b", Type: iceberg.PrimitiveTypes.Float32}, - ) - - ref, ref2 := iceberg.Reference("a"), iceberg.Reference("b") - above, below := int64(math.MaxInt32)+1, int64(math.MinInt32)-1 - above2, below2 := float64(math.MaxFloat32)+1e37, float64(-math.MaxFloat32)-1e37 - - tests := []struct { - pred iceberg.UnboundPredicate - exp iceberg.BooleanExpression - }{ - {iceberg.EqualTo(ref, above), iceberg.AlwaysFalse{}}, - {iceberg.EqualTo(ref, below), iceberg.AlwaysFalse{}}, - {iceberg.NotEqualTo(ref, above), iceberg.AlwaysTrue{}}, - {iceberg.NotEqualTo(ref, below), iceberg.AlwaysTrue{}}, - {iceberg.LessThan(ref, above), iceberg.AlwaysTrue{}}, - {iceberg.LessThan(ref, below), iceberg.AlwaysFalse{}}, - {iceberg.LessThanEqual(ref, above), iceberg.AlwaysTrue{}}, - {iceberg.LessThanEqual(ref, below), iceberg.AlwaysFalse{}}, - {iceberg.GreaterThan(ref, above), iceberg.AlwaysFalse{}}, - {iceberg.GreaterThan(ref, below), iceberg.AlwaysTrue{}}, - {iceberg.GreaterThanEqual(ref, above), iceberg.AlwaysFalse{}}, - 
{iceberg.GreaterThanEqual(ref, below), iceberg.AlwaysTrue{}}, - - {iceberg.EqualTo(ref2, above2), iceberg.AlwaysFalse{}}, - {iceberg.EqualTo(ref2, below2), iceberg.AlwaysFalse{}}, - {iceberg.NotEqualTo(ref2, above2), iceberg.AlwaysTrue{}}, - {iceberg.NotEqualTo(ref2, below2), iceberg.AlwaysTrue{}}, - {iceberg.LessThan(ref2, above2), iceberg.AlwaysTrue{}}, - {iceberg.LessThan(ref2, below2), iceberg.AlwaysFalse{}}, - {iceberg.LessThanEqual(ref2, above2), iceberg.AlwaysTrue{}}, - {iceberg.LessThanEqual(ref2, below2), iceberg.AlwaysFalse{}}, - {iceberg.GreaterThan(ref2, above2), iceberg.AlwaysFalse{}}, - {iceberg.GreaterThan(ref2, below2), iceberg.AlwaysTrue{}}, - {iceberg.GreaterThanEqual(ref2, above2), iceberg.AlwaysFalse{}}, - {iceberg.GreaterThanEqual(ref2, below2), iceberg.AlwaysTrue{}}, - } - - for _, tt := range tests { - t.Run(tt.pred.String(), func(t *testing.T) { - b, err := tt.pred.Bind(sc, true) - require.NoError(t, err) - assert.Equal(t, tt.exp, b) - }) - } -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package iceberg_test + +import ( + "math" + "strconv" + "testing" + + "github.com/apache/iceberg-go" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type ExprA struct{} + +func (ExprA) String() string { return "ExprA" } +func (ExprA) Op() iceberg.Operation { return iceberg.OpFalse } +func (ExprA) Negate() iceberg.BooleanExpression { return ExprB{} } +func (ExprA) Equals(o iceberg.BooleanExpression) bool { + _, ok := o.(ExprA) + return ok +} + +type ExprB struct{} + +func (ExprB) String() string { return "ExprB" } +func (ExprB) Op() iceberg.Operation { return iceberg.OpTrue } +func (ExprB) Negate() iceberg.BooleanExpression { return ExprA{} } +func (ExprB) Equals(o iceberg.BooleanExpression) bool { + _, ok := o.(ExprB) + return ok +} + +func TestUnaryExpr(t *testing.T) { + assert.PanicsWithError(t, "invalid argument: invalid operation for unary predicate: LessThan", func() { + iceberg.UnaryPredicate(iceberg.OpLT, iceberg.Reference("a")) + }) + + assert.PanicsWithError(t, "invalid argument: cannot create unary predicate with nil term", func() { + iceberg.UnaryPredicate(iceberg.OpIsNull, nil) + }) + + t.Run("negate", func(t *testing.T) { + n := iceberg.IsNull(iceberg.Reference("a")).Negate() + exp := iceberg.NotNull(iceberg.Reference("a")) + + assert.Equal(t, exp, n) + assert.True(t, exp.Equals(n)) + assert.True(t, n.Equals(exp)) + }) + + sc := iceberg.NewSchema(1, iceberg.NestedField{ + ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Int32}) + sc2 := iceberg.NewSchema(1, iceberg.NestedField{ + ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Float64}) + sc3 := iceberg.NewSchema(1, iceberg.NestedField{ + ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Int32, Required: true}) + sc4 := iceberg.NewSchema(1, iceberg.NestedField{ + ID: 2, Name: "a", Type: iceberg.PrimitiveTypes.Float32, Required: true}) + + t.Run("isnull and notnull", func(t *testing.T) { + t.Run("bind", func(t *testing.T) { + n, err := 
iceberg.IsNull(iceberg.Reference("a")).Bind(sc, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpIsNull, n.Op()) + assert.Implements(t, (*iceberg.BoundUnaryPredicate)(nil), n) + p := n.(iceberg.BoundUnaryPredicate) + assert.IsType(t, iceberg.PrimitiveTypes.Int32, p.Term().Type()) + assert.Same(t, p.Ref(), p.Term().Ref()) + assert.Same(t, p.Ref(), p.Ref().Ref()) + + f := p.Ref().Field() + assert.True(t, f.Equals(sc.Field(0))) + }) + + t.Run("negate and bind", func(t *testing.T) { + n1, err := iceberg.IsNull(iceberg.Reference("a")).Bind(sc, true) + require.NoError(t, err) + + n2, err := iceberg.NotNull(iceberg.Reference("a")).Bind(sc, true) + require.NoError(t, err) + + assert.True(t, n1.Negate().Equals(n2)) + assert.True(t, n2.Negate().Equals(n1)) + }) + + t.Run("null bind required", func(t *testing.T) { + n1, err := iceberg.IsNull(iceberg.Reference("a")).Bind(sc3, true) + require.NoError(t, err) + + n2, err := iceberg.NotNull(iceberg.Reference("a")).Bind(sc3, true) + require.NoError(t, err) + + assert.True(t, n1.Equals(iceberg.AlwaysFalse{})) + assert.True(t, n2.Equals(iceberg.AlwaysTrue{})) + }) + }) + + t.Run("isnan notnan", func(t *testing.T) { + t.Run("negate and bind", func(t *testing.T) { + n1, err := iceberg.IsNaN(iceberg.Reference("a")).Bind(sc2, true) + require.NoError(t, err) + + n2, err := iceberg.NotNaN(iceberg.Reference("a")).Bind(sc2, true) + require.NoError(t, err) + + assert.True(t, n1.Negate().Equals(n2)) + assert.True(t, n2.Negate().Equals(n1)) + }) + + t.Run("bind float", func(t *testing.T) { + n, err := iceberg.IsNaN(iceberg.Reference("a")).Bind(sc4, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpIsNan, n.Op()) + assert.Implements(t, (*iceberg.BoundUnaryPredicate)(nil), n) + p := n.(iceberg.BoundUnaryPredicate) + assert.IsType(t, iceberg.PrimitiveTypes.Float32, p.Term().Type()) + + n2, err := iceberg.NotNaN(iceberg.Reference("a")).Bind(sc4, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpNotNan, n2.Op()) 
+ assert.Implements(t, (*iceberg.BoundUnaryPredicate)(nil), n2) + p2 := n2.(iceberg.BoundUnaryPredicate) + assert.IsType(t, iceberg.PrimitiveTypes.Float32, p2.Term().Type()) + }) + + t.Run("bind double", func(t *testing.T) { + n, err := iceberg.IsNaN(iceberg.Reference("a")).Bind(sc2, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpIsNan, n.Op()) + assert.Implements(t, (*iceberg.BoundUnaryPredicate)(nil), n) + p := n.(iceberg.BoundUnaryPredicate) + assert.IsType(t, iceberg.PrimitiveTypes.Float64, p.Term().Type()) + + n2, err := iceberg.NotNaN(iceberg.Reference("a")).Bind(sc2, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpNotNan, n2.Op()) + assert.Implements(t, (*iceberg.BoundUnaryPredicate)(nil), n2) + p2 := n2.(iceberg.BoundUnaryPredicate) + assert.IsType(t, iceberg.PrimitiveTypes.Float64, p2.Term().Type()) + }) + + t.Run("bind non floating", func(t *testing.T) { + n1, err := iceberg.IsNaN(iceberg.Reference("a")).Bind(sc, true) + require.NoError(t, err) + + n2, err := iceberg.NotNaN(iceberg.Reference("a")).Bind(sc, true) + require.NoError(t, err) + + assert.True(t, n1.Equals(iceberg.AlwaysFalse{})) + assert.True(t, n2.Equals(iceberg.AlwaysTrue{})) + }) + }) +} + +func TestRefBindingCaseSensitive(t *testing.T) { + ref1, ref2 := iceberg.Reference("foo"), iceberg.Reference("Foo") + + bound1, err := ref1.Bind(tableSchemaSimple, true) + require.NoError(t, err) + assert.True(t, bound1.Type().Equals(iceberg.PrimitiveTypes.String)) + + _, err = ref2.Bind(tableSchemaSimple, true) + assert.ErrorIs(t, err, iceberg.ErrInvalidSchema) + assert.ErrorContains(t, err, "could not bind reference 'Foo', caseSensitive=true") + + bound2, err := ref2.Bind(tableSchemaSimple, false) + require.NoError(t, err) + assert.True(t, bound1.Equals(bound2)) + + _, err = iceberg.Reference("foot").Bind(tableSchemaSimple, false) + assert.ErrorIs(t, err, iceberg.ErrInvalidSchema) + assert.ErrorContains(t, err, "could not bind reference 'foot', caseSensitive=false") +} + 
+func TestRefTypes(t *testing.T) { + sc := iceberg.NewSchema(1, + iceberg.NestedField{ID: 1, Name: "a", Type: iceberg.PrimitiveTypes.Bool}, + iceberg.NestedField{ID: 2, Name: "b", Type: iceberg.PrimitiveTypes.Int32}, + iceberg.NestedField{ID: 3, Name: "c", Type: iceberg.PrimitiveTypes.Int64}, + iceberg.NestedField{ID: 4, Name: "d", Type: iceberg.PrimitiveTypes.Float32}, + iceberg.NestedField{ID: 5, Name: "e", Type: iceberg.PrimitiveTypes.Float64}, + iceberg.NestedField{ID: 6, Name: "f", Type: iceberg.PrimitiveTypes.Date}, + iceberg.NestedField{ID: 7, Name: "g", Type: iceberg.PrimitiveTypes.Time}, + iceberg.NestedField{ID: 8, Name: "h", Type: iceberg.PrimitiveTypes.Timestamp}, + iceberg.NestedField{ID: 9, Name: "i", Type: iceberg.DecimalTypeOf(9, 2)}, + iceberg.NestedField{ID: 10, Name: "j", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 11, Name: "k", Type: iceberg.PrimitiveTypes.Binary}, + iceberg.NestedField{ID: 12, Name: "l", Type: iceberg.PrimitiveTypes.UUID}, + iceberg.NestedField{ID: 13, Name: "m", Type: iceberg.FixedTypeOf(5)}) + + t.Run("bind term", func(t *testing.T) { + for i := 0; i < sc.NumFields(); i++ { + fld := sc.Field(i) + t.Run(fld.Type.String(), func(t *testing.T) { + ref, err := iceberg.Reference(fld.Name).Bind(sc, true) + require.NoError(t, err) + + assert.True(t, ref.Type().Equals(fld.Type)) + assert.True(t, fld.Equals(ref.Ref().Field())) + }) + } + }) + + t.Run("bind unary", func(t *testing.T) { + for i := 0; i < sc.NumFields(); i++ { + fld := sc.Field(i) + t.Run(fld.Type.String(), func(t *testing.T) { + b, err := iceberg.IsNull(iceberg.Reference(fld.Name)).Bind(sc, true) + require.NoError(t, err) + + assert.True(t, b.(iceberg.BoundUnaryPredicate).Ref().Type().Equals(fld.Type)) + + un := b.(iceberg.BoundUnaryPredicate).AsUnbound(iceberg.Reference("foo")) + assert.Equal(t, b.Op(), un.Op()) + }) + } + }) + + t.Run("bind literal", func(t *testing.T) { + t.Run("bool", func(t *testing.T) { + b1, err := 
iceberg.EqualTo(iceberg.Reference("a"), true).Bind(sc, true) + require.NoError(t, err) + assert.Equal(t, iceberg.OpEQ, b1.Op()) + assert.True(t, b1.(iceberg.BoundLiteralPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.Bool)) + }) + + for i := 1; i < 9; i++ { + fld := sc.Field(i) + t.Run(fld.Type.String(), func(t *testing.T) { + b, err := iceberg.EqualTo(iceberg.Reference(fld.Name), int32(5)).Bind(sc, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpEQ, b.Op()) + assert.True(t, b.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(fld.Type)) + assert.True(t, b.(iceberg.BoundLiteralPredicate).Ref().Type().Equals(fld.Type)) + }) + } + + t.Run("string-binary", func(t *testing.T) { + str, err := iceberg.EqualTo(iceberg.Reference("j"), "foobar").Bind(sc, true) + require.NoError(t, err) + + bin, err := iceberg.EqualTo(iceberg.Reference("k"), []byte("foobar")).Bind(sc, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpEQ, str.Op()) + assert.True(t, str.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.PrimitiveTypes.String)) + assert.Equal(t, iceberg.OpEQ, bin.Op()) + assert.True(t, bin.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.PrimitiveTypes.Binary)) + }) + + t.Run("fixed", func(t *testing.T) { + fx, err := iceberg.EqualTo(iceberg.Reference("m"), []byte{0, 1, 2, 3, 4}).Bind(sc, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpEQ, fx.Op()) + assert.True(t, fx.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.FixedTypeOf(5))) + }) + + t.Run("uuid", func(t *testing.T) { + uid, err := iceberg.EqualTo(iceberg.Reference("l"), uuid.New().String()).Bind(sc, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpEQ, uid.Op()) + assert.True(t, uid.(iceberg.BoundLiteralPredicate).Literal().Type().Equals(iceberg.PrimitiveTypes.UUID)) + }) + }) + + t.Run("bind set", func(t *testing.T) { + t.Run("bool", func(t *testing.T) { + b, err := iceberg.IsIn(iceberg.Reference("a"), true, 
false).(iceberg.UnboundPredicate).Bind(sc, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpIn, b.Op()) + }) + + for i := 1; i < 9; i++ { + fld := sc.Field(i) + t.Run(fld.Type.String(), func(t *testing.T) { + b, err := iceberg.IsIn(iceberg.Reference(fld.Name), int32(10), int32(5), int32(5)).(iceberg.UnboundPredicate). + Bind(sc, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpIn, b.Op()) + assert.True(t, b.(iceberg.BoundSetPredicate).Ref().Type().Equals(fld.Type)) + for _, v := range b.(iceberg.BoundSetPredicate).Literals().Members() { + assert.True(t, v.Type().Equals(fld.Type)) + } + }) + } + + t.Run("string-binary", func(t *testing.T) { + str, err := iceberg.IsIn(iceberg.Reference("j"), "hello", "foobar").(iceberg.UnboundPredicate). + Bind(sc, true) + require.NoError(t, err) + + bin, err := iceberg.IsIn(iceberg.Reference("k"), []byte("baz"), []byte("foobar")).(iceberg.UnboundPredicate). + Bind(sc, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpIn, str.Op()) + assert.Equal(t, iceberg.OpIn, bin.Op()) + + assert.True(t, str.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.String)) + for _, v := range str.(iceberg.BoundSetPredicate).Literals().Members() { + assert.True(t, v.Type().Equals(iceberg.PrimitiveTypes.String)) + } + + assert.True(t, bin.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.Binary)) + for _, v := range bin.(iceberg.BoundSetPredicate).Literals().Members() { + assert.True(t, v.Type().Equals(iceberg.PrimitiveTypes.Binary)) + } + }) + + t.Run("fixed", func(t *testing.T) { + fx, err := iceberg.IsIn(iceberg.Reference("m"), []byte{4, 5, 6, 7, 8}, []byte{0, 1, 2, 3, 4}).(iceberg.UnboundPredicate). 
+ Bind(sc, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpIn, fx.Op()) + assert.True(t, fx.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.FixedTypeOf(5))) + for _, v := range fx.(iceberg.BoundSetPredicate).Literals().Members() { + assert.True(t, v.Type().Equals(iceberg.FixedTypeOf(5))) + } + }) + + t.Run("uuid", func(t *testing.T) { + uid, err := iceberg.IsIn(iceberg.Reference("l"), uuid.New().String(), uuid.New().String()).(iceberg.UnboundPredicate). + Bind(sc, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpIn, uid.Op()) + assert.True(t, uid.(iceberg.BoundSetPredicate).Ref().Type().Equals(iceberg.PrimitiveTypes.UUID)) + for _, v := range uid.(iceberg.BoundSetPredicate).Literals().Members() { + assert.True(t, v.Type().Equals(iceberg.PrimitiveTypes.UUID)) + } + }) + }) +} + +func TestInNotInSimplifications(t *testing.T) { + assert.PanicsWithError(t, "invalid argument: invalid operation for SetPredicate: LessThan", + func() { iceberg.SetPredicate(iceberg.OpLT, iceberg.Reference("x"), nil) }) + assert.PanicsWithError(t, "invalid argument: cannot create set predicate with nil term", + func() { iceberg.SetPredicate(iceberg.OpIn, nil, nil) }) + assert.NotPanics(t, func() { iceberg.SetPredicate(iceberg.OpIn, iceberg.Reference("x"), nil) }) + + t.Run("in to eq", func(t *testing.T) { + a := iceberg.IsIn(iceberg.Reference("x"), 34.56) + b := iceberg.EqualTo(iceberg.Reference("x"), 34.56) + assert.True(t, a.Equals(b)) + }) + + t.Run("notin to notequal", func(t *testing.T) { + a := iceberg.NotIn(iceberg.Reference("x"), 34.56) + b := iceberg.NotEqualTo(iceberg.Reference("x"), 34.56) + assert.True(t, a.Equals(b)) + }) + + t.Run("empty", func(t *testing.T) { + a := iceberg.IsIn[float32](iceberg.Reference("x")) + b := iceberg.NotIn[float32](iceberg.Reference("x")) + + assert.Equal(t, iceberg.AlwaysFalse{}, a) + assert.Equal(t, iceberg.AlwaysTrue{}, b) + }) + + t.Run("bind and negate", func(t *testing.T) { + inexp := 
iceberg.IsIn(iceberg.Reference("foo"), "hello", "world") + notin := iceberg.NotIn(iceberg.Reference("foo"), "hello", "world") + assert.True(t, inexp.Negate().Equals(notin)) + assert.True(t, notin.Negate().Equals(inexp)) + assert.Equal(t, iceberg.OpIn, inexp.Op()) + assert.Equal(t, iceberg.OpNotIn, notin.Op()) + + boundin, err := inexp.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true) + require.NoError(t, err) + + boundnot, err := notin.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true) + require.NoError(t, err) + + assert.True(t, boundin.Negate().Equals(boundnot)) + assert.True(t, boundnot.Negate().Equals(boundin)) + }) + + t.Run("bind dedup", func(t *testing.T) { + isin := iceberg.IsIn(iceberg.Reference("foo"), "hello", "world", "world") + bound, err := isin.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true) + require.NoError(t, err) + + assert.Implements(t, (*iceberg.BoundSetPredicate)(nil), bound) + bsp := bound.(iceberg.BoundSetPredicate) + assert.Equal(t, 2, bsp.Literals().Len()) + assert.True(t, bsp.Literals().Contains(iceberg.NewLiteral("hello"))) + assert.True(t, bsp.Literals().Contains(iceberg.NewLiteral("world"))) + }) + + t.Run("bind dedup to eq", func(t *testing.T) { + isin := iceberg.IsIn(iceberg.Reference("foo"), "world", "world") + bound, err := isin.(iceberg.UnboundPredicate).Bind(tableSchemaSimple, true) + require.NoError(t, err) + + assert.Equal(t, iceberg.OpEQ, bound.Op()) + assert.Equal(t, iceberg.NewLiteral("world"), + bound.(iceberg.BoundLiteralPredicate).Literal()) + }) +} + +func TestLiteralPredicateErrors(t *testing.T) { + assert.PanicsWithError(t, "invalid argument: invalid operation for LiteralPredicate: In", + func() { iceberg.LiteralPredicate(iceberg.OpIn, iceberg.Reference("foo"), iceberg.NewLiteral("hello")) }) + assert.PanicsWithError(t, "invalid argument: cannot create literal predicate with nil term", + func() { iceberg.LiteralPredicate(iceberg.OpLT, nil, iceberg.NewLiteral("hello")) }) + 
assert.PanicsWithError(t, "invalid argument: cannot create literal predicate with nil literal", + func() { iceberg.LiteralPredicate(iceberg.OpLT, iceberg.Reference("foo"), nil) }) +} + +func TestNegations(t *testing.T) { + ref := iceberg.Reference("foo") + + tests := []struct { + name string + ex1, ex2 iceberg.UnboundPredicate + }{ + {"equal-not", iceberg.EqualTo(ref, "hello"), iceberg.NotEqualTo(ref, "hello")}, + {"greater-equal-less", iceberg.GreaterThanEqual(ref, "hello"), iceberg.LessThan(ref, "hello")}, + {"greater-less-equal", iceberg.GreaterThan(ref, "hello"), iceberg.LessThanEqual(ref, "hello")}, + {"starts-with", iceberg.StartsWith(ref, "hello"), iceberg.NotStartsWith(ref, "hello")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.False(t, tt.ex1.Equals(tt.ex2)) + assert.False(t, tt.ex2.Equals(tt.ex1)) + assert.True(t, tt.ex1.Negate().Equals(tt.ex2)) + assert.True(t, tt.ex2.Negate().Equals(tt.ex1)) + + b1, err := tt.ex1.Bind(tableSchemaSimple, true) + require.NoError(t, err) + b2, err := tt.ex2.Bind(tableSchemaSimple, true) + require.NoError(t, err) + + assert.False(t, b1.Equals(b2)) + assert.False(t, b2.Equals(b1)) + assert.True(t, b1.Negate().Equals(b2)) + assert.True(t, b2.Negate().Equals(b1)) + }) + } +} + +func TestBoolExprEQ(t *testing.T) { + tests := []struct { + exp, testexpra, testexprb iceberg.BooleanExpression + }{ + {iceberg.NewAnd(ExprA{}, ExprB{}), + iceberg.NewAnd(ExprA{}, ExprB{}), + iceberg.NewOr(ExprA{}, ExprB{})}, + {iceberg.NewOr(ExprA{}, ExprB{}), + iceberg.NewOr(ExprA{}, ExprB{}), + iceberg.NewAnd(ExprA{}, ExprB{})}, + {iceberg.NewAnd(ExprA{}, ExprB{}), + iceberg.NewAnd(ExprB{}, ExprA{}), + iceberg.NewOr(ExprB{}, ExprA{})}, + {iceberg.NewOr(ExprA{}, ExprB{}), + iceberg.NewOr(ExprB{}, ExprA{}), + iceberg.NewAnd(ExprB{}, ExprA{})}, + {iceberg.NewNot(ExprA{}), iceberg.NewNot(ExprA{}), ExprB{}}, + {ExprA{}, ExprA{}, ExprB{}}, + {ExprB{}, ExprB{}, ExprA{}}, + {iceberg.IsIn(iceberg.Reference("foo"), 
"hello", "world"), + iceberg.IsIn(iceberg.Reference("foo"), "hello", "world"), + iceberg.IsIn(iceberg.Reference("not_foo"), "hello", "world")}, + {iceberg.IsIn(iceberg.Reference("foo"), "hello", "world"), + iceberg.IsIn(iceberg.Reference("foo"), "hello", "world"), + iceberg.IsIn(iceberg.Reference("foo"), "goodbye", "world")}, + } + + for i, tt := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + assert.True(t, tt.exp.Equals(tt.testexpra)) + assert.False(t, tt.exp.Equals(tt.testexprb)) + }) + } +} + +func TestBoolExprNegate(t *testing.T) { + tests := []struct { + lhs, rhs iceberg.BooleanExpression + }{ + {iceberg.NewAnd(ExprA{}, ExprB{}), iceberg.NewOr(ExprB{}, ExprA{})}, + {iceberg.NewOr(ExprB{}, ExprA{}), iceberg.NewAnd(ExprA{}, ExprB{})}, + {iceberg.NewNot(ExprA{}), ExprA{}}, + {iceberg.IsIn(iceberg.Reference("foo"), "hello", "world"), + iceberg.NotIn(iceberg.Reference("foo"), "hello", "world")}, + {iceberg.NotIn(iceberg.Reference("foo"), "hello", "world"), + iceberg.IsIn(iceberg.Reference("foo"), "hello", "world")}, + {iceberg.GreaterThan(iceberg.Reference("foo"), int32(5)), + iceberg.LessThanEqual(iceberg.Reference("foo"), int32(5))}, + {iceberg.LessThan(iceberg.Reference("foo"), int32(5)), + iceberg.GreaterThanEqual(iceberg.Reference("foo"), int32(5))}, + {iceberg.EqualTo(iceberg.Reference("foo"), int32(5)), + iceberg.NotEqualTo(iceberg.Reference("foo"), int32(5))}, + {ExprA{}, ExprB{}}, + } + + for _, tt := range tests { + assert.True(t, tt.lhs.Negate().Equals(tt.rhs)) + } +} + +func TestBoolExprPanics(t *testing.T) { + assert.PanicsWithError(t, "invalid argument: cannot construct AndExpr with nil arguments", + func() { iceberg.NewAnd(nil, ExprA{}) }) + assert.PanicsWithError(t, "invalid argument: cannot construct AndExpr with nil arguments", + func() { iceberg.NewAnd(ExprA{}, nil) }) + assert.PanicsWithError(t, "invalid argument: cannot construct AndExpr with nil arguments", + func() { iceberg.NewAnd(ExprA{}, ExprA{}, nil) }) + + 
assert.PanicsWithError(t, "invalid argument: cannot construct OrExpr with nil arguments", + func() { iceberg.NewOr(nil, ExprA{}) }) + assert.PanicsWithError(t, "invalid argument: cannot construct OrExpr with nil arguments", + func() { iceberg.NewOr(ExprA{}, nil) }) + assert.PanicsWithError(t, "invalid argument: cannot construct OrExpr with nil arguments", + func() { iceberg.NewOr(ExprA{}, ExprA{}, nil) }) + + assert.PanicsWithError(t, "invalid argument: cannot create NotExpr with nil child", + func() { iceberg.NewNot(nil) }) +} + +func TestExprFolding(t *testing.T) { + tests := []struct { + lhs, rhs iceberg.BooleanExpression + }{ + {iceberg.NewAnd(ExprA{}, ExprB{}, ExprA{}), + iceberg.NewAnd(iceberg.NewAnd(ExprA{}, ExprB{}), ExprA{})}, + {iceberg.NewOr(ExprA{}, ExprB{}, ExprA{}), + iceberg.NewOr(iceberg.NewOr(ExprA{}, ExprB{}), ExprA{})}, + {iceberg.NewNot(iceberg.NewNot(ExprA{})), ExprA{}}, + } + + for _, tt := range tests { + assert.True(t, tt.lhs.Equals(tt.rhs)) + } +} + +func TestBaseAlwaysTrueAlwaysFalse(t *testing.T) { + tests := []struct { + lhs, rhs iceberg.BooleanExpression + }{ + {iceberg.NewAnd(iceberg.AlwaysTrue{}, ExprB{}), ExprB{}}, + {iceberg.NewAnd(iceberg.AlwaysFalse{}, ExprB{}), iceberg.AlwaysFalse{}}, + {iceberg.NewAnd(ExprB{}, iceberg.AlwaysTrue{}), ExprB{}}, + {iceberg.NewOr(iceberg.AlwaysTrue{}, ExprB{}), iceberg.AlwaysTrue{}}, + {iceberg.NewOr(iceberg.AlwaysFalse{}, ExprB{}), ExprB{}}, + {iceberg.NewOr(ExprA{}, iceberg.AlwaysFalse{}), ExprA{}}, + {iceberg.NewNot(iceberg.NewNot(ExprA{})), ExprA{}}, + {iceberg.NewNot(iceberg.AlwaysTrue{}), iceberg.AlwaysFalse{}}, + {iceberg.NewNot(iceberg.AlwaysFalse{}), iceberg.AlwaysTrue{}}, + } + + for _, tt := range tests { + assert.True(t, tt.lhs.Equals(tt.rhs)) + } +} + +func TestNegateAlways(t *testing.T) { + assert.Equal(t, iceberg.OpTrue, iceberg.AlwaysTrue{}.Op()) + assert.Equal(t, iceberg.OpFalse, iceberg.AlwaysFalse{}.Op()) + + assert.Equal(t, iceberg.AlwaysTrue{}, iceberg.AlwaysFalse{}.Negate()) + 
assert.Equal(t, iceberg.AlwaysFalse{}, iceberg.AlwaysTrue{}.Negate()) +} + +func TestBoundReferenceToString(t *testing.T) { + ref, err := iceberg.Reference("foo").Bind(tableSchemaSimple, true) + require.NoError(t, err) + + assert.Equal(t, "BoundReference(field=1: foo: optional string, accessor=Accessor(position=0, inner=))", + ref.String()) +} + +func TestToString(t *testing.T) { + schema := iceberg.NewSchema(1, + iceberg.NestedField{ID: 1, Name: "a", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 2, Name: "b", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 3, Name: "c", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 4, Name: "d", Type: iceberg.PrimitiveTypes.Int32}, + iceberg.NestedField{ID: 5, Name: "e", Type: iceberg.PrimitiveTypes.Int32}, + iceberg.NestedField{ID: 6, Name: "f", Type: iceberg.PrimitiveTypes.Int32}, + iceberg.NestedField{ID: 7, Name: "g", Type: iceberg.PrimitiveTypes.Float32}, + iceberg.NestedField{ID: 8, Name: "h", Type: iceberg.DecimalTypeOf(8, 4)}, + iceberg.NestedField{ID: 9, Name: "i", Type: iceberg.PrimitiveTypes.UUID}, + iceberg.NestedField{ID: 10, Name: "j", Type: iceberg.PrimitiveTypes.Bool}, + iceberg.NestedField{ID: 11, Name: "k", Type: iceberg.PrimitiveTypes.Bool}, + iceberg.NestedField{ID: 12, Name: "l", Type: iceberg.PrimitiveTypes.Binary}) + + null := iceberg.IsNull(iceberg.Reference("a")) + nan := iceberg.IsNaN(iceberg.Reference("g")) + boundNull, _ := null.Bind(schema, true) + boundNan, _ := nan.Bind(schema, true) + + equal := iceberg.EqualTo(iceberg.Reference("c"), "a") + grtequal := iceberg.GreaterThanEqual(iceberg.Reference("a"), "a") + greater := iceberg.GreaterThan(iceberg.Reference("a"), "a") + startsWith := iceberg.StartsWith(iceberg.Reference("b"), "foo") + + boundEqual, _ := equal.Bind(schema, true) + boundGrtEqual, _ := grtequal.Bind(schema, true) + boundGreater, _ := greater.Bind(schema, true) + boundStarts, _ := startsWith.Bind(schema, true) + + tests := []struct { + 
e iceberg.BooleanExpression + expected string + }{ + {iceberg.NewAnd(null, nan), + "And(left=IsNull(term=Reference(name='a')), right=IsNaN(term=Reference(name='g')))"}, + {iceberg.NewOr(null, nan), + "Or(left=IsNull(term=Reference(name='a')), right=IsNaN(term=Reference(name='g')))"}, + {iceberg.NewNot(null), + "Not(child=IsNull(term=Reference(name='a')))"}, + {iceberg.AlwaysTrue{}, "AlwaysTrue()"}, + {iceberg.AlwaysFalse{}, "AlwaysFalse()"}, + {boundNull, + "BoundIsNull(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)))"}, + {boundNull.Negate(), + "BoundNotNull(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)))"}, + {boundNan, + "BoundIsNaN(term=BoundReference(field=7: g: optional float, accessor=Accessor(position=6, inner=)))"}, + {boundNan.Negate(), + "BoundNotNaN(term=BoundReference(field=7: g: optional float, accessor=Accessor(position=6, inner=)))"}, + {equal, + "Equal(term=Reference(name='c'), literal=a)"}, + {equal.Negate(), + "NotEqual(term=Reference(name='c'), literal=a)"}, + {grtequal, + "GreaterThanEqual(term=Reference(name='a'), literal=a)"}, + {grtequal.Negate(), + "LessThan(term=Reference(name='a'), literal=a)"}, + {greater, + "GreaterThan(term=Reference(name='a'), literal=a)"}, + {greater.Negate(), + "LessThanEqual(term=Reference(name='a'), literal=a)"}, + {startsWith, + "StartsWith(term=Reference(name='b'), literal=foo)"}, + {startsWith.Negate(), + "NotStartsWith(term=Reference(name='b'), literal=foo)"}, + {boundEqual, + "BoundEqual(term=BoundReference(field=3: c: optional string, accessor=Accessor(position=2, inner=)), literal=a)"}, + {boundEqual.Negate(), + "BoundNotEqual(term=BoundReference(field=3: c: optional string, accessor=Accessor(position=2, inner=)), literal=a)"}, + {boundGreater, + "BoundGreaterThan(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)), literal=a)"}, + {boundGreater.Negate(), + 
"BoundLessThanEqual(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)), literal=a)"}, + {boundGrtEqual, + "BoundGreaterThanEqual(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)), literal=a)"}, + {boundGrtEqual.Negate(), + "BoundLessThan(term=BoundReference(field=1: a: optional string, accessor=Accessor(position=0, inner=)), literal=a)"}, + {boundStarts, + "BoundStartsWith(term=BoundReference(field=2: b: optional string, accessor=Accessor(position=1, inner=)), literal=foo)"}, + {boundStarts.Negate(), + "BoundNotStartsWith(term=BoundReference(field=2: b: optional string, accessor=Accessor(position=1, inner=)), literal=foo)"}, + } + + for _, tt := range tests { + assert.Equal(t, tt.expected, tt.e.String()) + } +} + +func TestBindAboveBelowIntMax(t *testing.T) { + sc := iceberg.NewSchema(1, + iceberg.NestedField{ID: 1, Name: "a", Type: iceberg.PrimitiveTypes.Int32}, + iceberg.NestedField{ID: 2, Name: "b", Type: iceberg.PrimitiveTypes.Float32}, + ) + + ref, ref2 := iceberg.Reference("a"), iceberg.Reference("b") + above, below := int64(math.MaxInt32)+1, int64(math.MinInt32)-1 + above2, below2 := float64(math.MaxFloat32)+1e37, float64(-math.MaxFloat32)-1e37 + + tests := []struct { + pred iceberg.UnboundPredicate + exp iceberg.BooleanExpression + }{ + {iceberg.EqualTo(ref, above), iceberg.AlwaysFalse{}}, + {iceberg.EqualTo(ref, below), iceberg.AlwaysFalse{}}, + {iceberg.NotEqualTo(ref, above), iceberg.AlwaysTrue{}}, + {iceberg.NotEqualTo(ref, below), iceberg.AlwaysTrue{}}, + {iceberg.LessThan(ref, above), iceberg.AlwaysTrue{}}, + {iceberg.LessThan(ref, below), iceberg.AlwaysFalse{}}, + {iceberg.LessThanEqual(ref, above), iceberg.AlwaysTrue{}}, + {iceberg.LessThanEqual(ref, below), iceberg.AlwaysFalse{}}, + {iceberg.GreaterThan(ref, above), iceberg.AlwaysFalse{}}, + {iceberg.GreaterThan(ref, below), iceberg.AlwaysTrue{}}, + {iceberg.GreaterThanEqual(ref, above), iceberg.AlwaysFalse{}}, + 
{iceberg.GreaterThanEqual(ref, below), iceberg.AlwaysTrue{}}, + + {iceberg.EqualTo(ref2, above2), iceberg.AlwaysFalse{}}, + {iceberg.EqualTo(ref2, below2), iceberg.AlwaysFalse{}}, + {iceberg.NotEqualTo(ref2, above2), iceberg.AlwaysTrue{}}, + {iceberg.NotEqualTo(ref2, below2), iceberg.AlwaysTrue{}}, + {iceberg.LessThan(ref2, above2), iceberg.AlwaysTrue{}}, + {iceberg.LessThan(ref2, below2), iceberg.AlwaysFalse{}}, + {iceberg.LessThanEqual(ref2, above2), iceberg.AlwaysTrue{}}, + {iceberg.LessThanEqual(ref2, below2), iceberg.AlwaysFalse{}}, + {iceberg.GreaterThan(ref2, above2), iceberg.AlwaysFalse{}}, + {iceberg.GreaterThan(ref2, below2), iceberg.AlwaysTrue{}}, + {iceberg.GreaterThanEqual(ref2, above2), iceberg.AlwaysFalse{}}, + {iceberg.GreaterThanEqual(ref2, below2), iceberg.AlwaysTrue{}}, + } + + for _, tt := range tests { + t.Run(tt.pred.String(), func(t *testing.T) { + b, err := tt.pred.Bind(sc, true) + require.NoError(t, err) + assert.Equal(t, tt.exp, b) + }) + } +} diff --git a/go.mod b/go.mod index 04bfa2f..46046e8 100644 --- a/go.mod +++ b/go.mod @@ -34,17 +34,27 @@ require ( github.com/stretchr/testify v1.9.0 github.com/twmb/murmur3 v1.1.8 github.com/wolfeidau/s3iofs v1.5.2 + gocloud.dev v0.40.0 golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 + google.golang.org/api v0.201.0 ) require ( atomicgo.dev/cursor v0.2.0 // indirect atomicgo.dev/keyboard v0.2.9 // indirect atomicgo.dev/schedule v0.1.0 // indirect + cloud.google.com/go v0.116.0 // indirect + cloud.google.com/go/auth v0.9.8 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect + cloud.google.com/go/compute/metadata v0.5.2 // indirect + cloud.google.com/go/iam v1.2.1 // indirect + cloud.google.com/go/storage v1.43.0 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/apache/thrift v0.20.0 // indirect + github.com/aws/aws-sdk-go v1.55.5 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 // indirect 
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14 // indirect + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.10 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect @@ -58,10 +68,19 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 // indirect github.com/containerd/console v1.0.3 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect github.com/goccy/go-json v0.10.3 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect + github.com/google/s2a-go v0.1.8 // indirect + github.com/google/wire v0.6.0 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect + github.com/googleapis/gax-go/v2 v2.13.0 // indirect github.com/gookit/color v1.5.4 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/asmfmt v1.3.2 // indirect github.com/klauspost/compress v1.17.9 // indirect @@ -79,13 +98,27 @@ require ( github.com/stretchr/objx v0.5.2 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/zeebo/xxh3 v1.0.2 // indirect + go.opencensus.io v0.24.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect + go.opentelemetry.io/otel v1.29.0 // indirect + go.opentelemetry.io/otel/metric v1.29.0 // indirect + go.opentelemetry.io/otel/trace v1.29.0 // indirect + golang.org/x/crypto v0.28.0 // indirect golang.org/x/mod v0.21.0 // indirect - 
golang.org/x/net v0.29.0 // indirect + golang.org/x/net v0.30.0 // indirect + golang.org/x/oauth2 v0.23.0 // indirect golang.org/x/sync v0.8.0 // indirect - golang.org/x/sys v0.25.0 // indirect - golang.org/x/term v0.24.0 // indirect - golang.org/x/text v0.18.0 // indirect + golang.org/x/sys v0.26.0 // indirect + golang.org/x/term v0.25.0 // indirect + golang.org/x/text v0.19.0 // indirect + golang.org/x/time v0.7.0 // indirect golang.org/x/tools v0.25.0 // indirect - golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect + golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect + google.golang.org/genproto v0.0.0-20241007155032-5fefd90f89a9 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240930140551-af27646dc61f // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 // indirect + google.golang.org/grpc v1.67.1 // indirect + google.golang.org/protobuf v1.35.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 7bf89e5..cfecf04 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,22 @@ atomicgo.dev/keyboard v0.2.9 h1:tOsIid3nlPLZ3lwgG8KZMp/SFmr7P0ssEN5JUsm78K8= atomicgo.dev/keyboard v0.2.9/go.mod h1:BC4w9g00XkxH/f1HXhW2sXmJFOCWbKn9xrOunSFtExQ= atomicgo.dev/schedule v0.1.0 h1:nTthAbhZS5YZmgYbb2+DH8uQIZcTlIrd4eYr3UQxEjs= atomicgo.dev/schedule v0.1.0/go.mod h1:xeUa3oAkiuHYh8bKiQBRojqAMq3PXXbJujjb0hw8pEU= +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= +cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= +cloud.google.com/go/auth v0.9.8 h1:+CSJ0Gw9iVeSENVCKJoLHhdUykDgXSc4Qn+gu2BRtR8= +cloud.google.com/go/auth v0.9.8/go.mod h1:xxA5AqpDrvS+Gkmo9RqrGGRh6WSNKKOXhY3zNOr38tI= +cloud.google.com/go/auth/oauth2adapt v0.2.4 h1:0GWE/FUsXhf6C+jAkWgYm7X9tK8cuEIfy19DBn6B6bY= +cloud.google.com/go/auth/oauth2adapt v0.2.4/go.mod 
h1:jC/jOpwFP6JBxhB3P5Rr0a9HLMC/Pe3eaL4NmdvqPtc= +cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo= +cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k= +cloud.google.com/go/iam v1.2.1 h1:QFct02HRb7H12J/3utj0qf5tobFh9V4vR6h9eX5EBRU= +cloud.google.com/go/iam v1.2.1/go.mod h1:3VUIJDPpwT6p/amXRC5GY8fCCh70lxPygguVtI0Z4/g= +cloud.google.com/go/longrunning v0.6.1 h1:lOLTFxYpr8hcRtcwWir5ITh1PAKUD/sG2lKrTSYjyMc= +cloud.google.com/go/longrunning v0.6.1/go.mod h1:nHISoOZpBcmlwbJmiVk5oDRz0qG/ZxPynEGs1iZ79s0= +cloud.google.com/go/storage v1.43.0 h1:CcxnSohZwizt4LCzQHWvBf1/kvtHUn7gk9QERXPyXFs= +cloud.google.com/go/storage v1.43.0/go.mod h1:ajvxEa7WmZS1PxvKRq4bq0tFT3vMd502JwstCcYv0Q0= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/MarvinJWendt/testza v0.1.0/go.mod h1:7AxNvlfeHP7Z/hDQ5JtE3OKYT3XFUeLCDE2DQninSqs= @@ -24,6 +40,8 @@ github.com/apache/arrow-go/v18 v18.0.0-20240924011512-14844aea3205/go.mod h1:MXq github.com/apache/thrift v0.20.0 h1:631+KvYbsBZxmuJjYwhezVsrfc/TbqtZV4QcxOX1fOI= github.com/apache/thrift v0.20.0/go.mod h1:hOk1BQqcp2OLzGsyVXdfMk7YFlMxK3aoEVhjD06QhB8= github.com/atomicgo/cursor v0.0.1/go.mod h1:cBON2QmmrysudxNBFthvMtN32r3jxVRIvzkUiF/RuIk= +github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU= +github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/aws/aws-sdk-go-v2 v1.32.2 h1:AkNLZEyYMLnx/Q/mSKkcMqwNFXMAvFto9bNsHqcTduI= github.com/aws/aws-sdk-go-v2 v1.32.2/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 
h1:70PVAiL15/aBMh5LThwgXdSQorVr91L127ttckI9QQU= @@ -34,6 +52,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.17.37 h1:G2aOH01yW8X373JK419THj5QVqu github.com/aws/aws-sdk-go-v2/credentials v1.17.37/go.mod h1:0ecCjlb7htYCptRD45lXJ6aJDQac6D2NlKGpZqyTG6A= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14 h1:C/d03NAmh8C4BZXhuRNboF/DqhBkBCeDiJDcaqIT5pA= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.14/go.mod h1:7I0Ju7p9mCIdlrfS+JCgqcYD0VXz/N4yozsox+0o078= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.10 h1:zeN9UtUlA6FTx0vFSayxSX32HDw73Yb6Hh2izDSFxXY= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.10/go.mod h1:3HKuexPDcwLWPaqpW2UR/9n8N/u/3CKcGAzSs8p8u8g= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21 h1:UAsR3xA31QGf79WzpG/ixT9FZvQlh5HY1NRqSHBNOCk= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.21/go.mod h1:JNr43NFf5L9YaG3eKTm7HQzls9J+A9YYcGI5Quh1r2Y= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.21 h1:6jZVETqmYCadGFvrYEQfC5fAQmlo80CeL5psbno6r0s= @@ -62,6 +82,9 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 h1:VzudTFrDCIDakXtemR7l6Qzt2+JY github.com/aws/aws-sdk-go-v2/service/sts v1.31.3/go.mod h1:yMWe0F+XG0DkRZK5ODZhG7BEFYhLXi2dqGsv6tX0cgI= github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM= github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARubLw= github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -69,21 +92,76 @@ 
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815 h1:bWDMxwH3px2JBh6AyO7hdCn/PkvCZXii8TGj7sbtEbQ= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache 
v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI= github.com/google/flatbuffers v24.3.25+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp 
v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-replayers/grpcreplay v1.3.0 h1:1Keyy0m1sIpqstQmgz307zhiJ1pV4uIlFds5weTmxbo= +github.com/google/go-replayers/grpcreplay v1.3.0/go.mod h1:v6NgKtkijC0d3e3RW8il6Sy5sqRVUwoQa4mHOGEy8DI= +github.com/google/go-replayers/httpreplay v1.2.0 h1:VM1wEyyjaoU53BwrOnaf9VhAyQQEEioJvFYxYcLRKzk= +github.com/google/go-replayers/httpreplay v1.2.0/go.mod h1:WahEFFZZ7a1P4VM1qEeHy+tME4bwyqPcwWbNlUI1Mcg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= +github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= +github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= +github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/wire v0.6.0 h1:HBkoIh4BdSxoyo9PveV8giw7ZsaBOvzWKfcg/6MrVwI= +github.com/google/wire v0.6.0/go.mod h1:F4QhpQ9EDIdJ1Mbop/NZBRB+5yrR6qg3BnctaoUk6NA= +github.com/googleapis/enterprise-certificate-proxy v0.3.4 h1:XYIDZApgAnrN1c855gTgghdIA6Stxb52D5RnLI1SLyw= +github.com/googleapis/enterprise-certificate-proxy v0.3.4/go.mod h1:YKe7cfqYXjKGpGvmSg28/fFvhNzinZQm8DGnaburhGA= +github.com/googleapis/gax-go/v2 v2.13.0 h1:yitjD5f7jQHhyDsnhKEBU52NdvvdSeGzlAnDPT0hH1s= 
+github.com/googleapis/gax-go/v2 v2.13.0/go.mod h1:Z/fvTZXF8/uw7Xu5GuslPw+bplx6SS338j1Is2S+B7A= github.com/gookit/color v1.4.2/go.mod h1:fqRyamkC1W8uxl+lxCQxOT09l/vYfZ+QeiX3rKQHCoQ= github.com/gookit/color v1.5.0/go.mod h1:43aQb+Zerm/BWh2GnrgOQm7ffz7tvQXEKV6BFMl7wAo= github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= github.com/hamba/avro/v2 v2.26.0 h1:IaT5l6W3zh7K67sMrT2+RreJyDTllBGVJm4+Hedk9qE= github.com/hamba/avro/v2 v2.26.0/go.mod h1:I8glyswHnpED3Nlx2ZdUe+4LJnCOOyiCzLMno9i/Uu0= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= @@ -121,6 +199,7 @@ github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/pterm/pterm v0.12.27/go.mod h1:PhQ89w4i95rhgE+xedAoqous6K9X+r6aSOI2eFF7DZI= github.com/pterm/pterm v0.12.29/go.mod h1:WI3qxgvoQFFGKGjGnJR849gU0TsEOvKn5Q8LlY1U7lg= github.com/pterm/pterm v0.12.30/go.mod 
h1:MOqLIyMOgmTDz9yorcYbcw+HsgoZo3BQfg2wtl3HEFE= @@ -136,12 +215,17 @@ github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/twmb/murmur3 v1.1.8 h1:8Yt9taO/WN3l08xErzjeschgZU2QSrwm1kclYq+0aRg= @@ -156,26 +240,72 @@ github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.opencensus.io 
v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= +go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= +go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= +go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= +go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= +go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= +go.opentelemetry.io/otel/sdk v1.29.0 h1:vkqKjk7gwhS8VaWb0POZKmIEDimRCMsopNYnriHyryo= +go.opentelemetry.io/otel/sdk v1.29.0/go.mod h1:pM8Dx5WKnvxLCb+8lG1PRNIDxu9g9b9g59Qr7hfAAok= +go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= +go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= +gocloud.dev v0.40.0 h1:f8LgP+4WDqOG/RXoUcyLpeIAGOcAbZrZbDQCUee10ng= +gocloud.dev v0.40.0/go.mod h1:drz+VyYNBvrMTW0KZiBAYEdl8lbNZx+OQ7oQvdrFmSQ= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.18.0/go.mod 
h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod 
h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= -golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= +golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.8.0 
h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -185,39 +315,91 @@ golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod 
h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= -golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= +golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= -golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 h1:LLhsEBxRTBLuKlQxFBYUOU8xyFgXv6cOTp2HASDlsDk= +golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gonum.org/v1/gonum v0.15.1 h1:FNy7N6OUZVUaWG9pTiD+jlhdQ3lMP+/LcTpJ6+a8sQ0= gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o= +google.golang.org/api v0.201.0 h1:+7AD9JNM3tREtawRMu8sOjSbb8VYcYXJG/2eEOmfDu0= +google.golang.org/api 
v0.201.0/go.mod h1:HVY0FCHVs89xIW9fzf/pBvOEm+OolHa86G/txFezyq4= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/genproto v0.0.0-20241007155032-5fefd90f89a9 h1:nFS3IivktIU5Mk6KQa+v6RKkHUpdQpphqGNLxqNnbEk= +google.golang.org/genproto v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:tEzYTYZxbmVNOu0OAFH9HzdJtLn6h4Aj89zzlBCdHms= +google.golang.org/genproto/googleapis/api v0.0.0-20240930140551-af27646dc61f h1:jTm13A2itBi3La6yTGqn8bVSrc3ZZ1r8ENHlIXBfnRA= +google.golang.org/genproto/googleapis/api v0.0.0-20240930140551-af27646dc61f/go.mod h1:CLGoBuH1VHxAUXVPP8FfPwPEVJB6lz3URE5mY2SuayE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 h1:QCqS/PdaHTSWGvupk2F/ehwHtGc0/GYkT+3GAcR1CCc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= +google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E= +google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA= +google.golang.org/protobuf 
v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= +google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 
h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/internal/avro_schemas.go b/internal/avro_schemas.go index 893b7b2..1541768 100644 --- a/internal/avro_schemas.go +++ b/internal/avro_schemas.go @@ -1,568 +1,568 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package internal - -import "github.com/hamba/avro/v2" - -const ( - ManifestListV1Key = "manifest-list-v1" - ManifestListV2Key = "manifest-list-v2" - ManifestEntryV1Key = "manifest-entry-v1" - ManifestEntryV2Key = "manifest-entry-v2" -) - -var ( - AvroSchemaCache avro.SchemaCache -) - -func init() { - AvroSchemaCache.Add(ManifestListV1Key, avro.MustParse(`{ - "type": "record", - "name": "manifest_file", - "fields": [ - {"name": "manifest_path", "type": "string", "doc": "Location URI with FS scheme", "field-id": 500}, - {"name": "manifest_length", "type": "long", "doc": "Total file size in bytes", "field-id": 501}, - {"name": "partition_spec_id", "type": "int", "doc": "Spec ID used to write", "field-id": 502}, - { - "name": "added_snapshot_id", - "type": "long", - "doc": "Snapshot ID that added the manifest", - "field-id": 503 - }, - { - "name": "added_data_files_count", - "type": ["null", "int"], - "doc": "Added entry count", - "field-id": 504 - }, - { - "name": "existing_data_files_count", - "type": ["null", "int"], - "doc": "Existing entry count", - "field-id": 505 - }, - { - "name": "deleted_data_files_count", - "type": ["null", "int"], - "doc": "Deleted entry count", - "field-id": 506 - }, - { - "name": "partitions", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "r508", - "fields": [ - { - "name": "contains_null", - "type": "boolean", - "doc": "True if any file has a null partition value", - "field-id": 509 - }, - { - "name": "contains_nan", - "type": ["null", "boolean"], - "doc": "True if any file has a nan partition value", - "field-id": 518 - }, - { - "name": "lower_bound", - "type": ["null", "bytes"], - "doc": "Partition lower bound for all files", - "field-id": 510 - }, - { - "name": "upper_bound", - "type": ["null", "bytes"], - "doc": "Partition upper bound for all files", - "field-id": 511 - } - ] - }, - "element-id": 508 - } - ], - "doc": "Summary for each partition", - "field-id": 507 - }, - {"name": 
"added_rows_count", "type": ["null", "long"], "doc": "Added rows count", "field-id": 512}, - { - "name": "existing_rows_count", - "type": ["null", "long"], - "doc": "Existing rows count", - "field-id": 513 - }, - { - "name": "deleted_rows_count", - "type": ["null", "long"], - "doc": "Deleted rows count", - "field-id": 514 - } - ] - }`)) - - AvroSchemaCache.Add(ManifestListV2Key, avro.MustParse(`{ - "type": "record", - "name": "manifest_file", - "fields": [ - {"name": "manifest_path", "type": "string", "doc": "Location URI with FS scheme", "field-id": 500}, - {"name": "manifest_length", "type": "long", "doc": "Total file size in bytes", "field-id": 501}, - {"name": "partition_spec_id", "type": "int", "doc": "Spec ID used to write", "field-id": 502}, - {"name": "content", "type": "int", "doc": "Contents of the manifest: 0=data, 1=deletes", "field-id": 517}, - { - "name": "sequence_number", - "type": "long", - "doc": "Sequence number when the manifest was added", - "field-id": 515 - }, - { - "name": "min_sequence_number", - "type": "long", - "doc": "Lowest sequence number in the manifest", - "field-id": 516 - }, - {"name": "added_snapshot_id", "type": "long", "doc": "Snapshot ID that added the manifest", "field-id": 503}, - {"name": "added_files_count", "type": "int", "doc": "Added entry count", "field-id": 504}, - {"name": "existing_files_count", "type": "int", "doc": "Existing entry count", "field-id": 505}, - {"name": "deleted_files_count", "type": "int", "doc": "Deleted entry count", "field-id": 506}, - {"name": "added_rows_count", "type": "long", "doc": "Added rows count", "field-id": 512}, - {"name": "existing_rows_count", "type": "long", "doc": "Existing rows count", "field-id": 513}, - {"name": "deleted_rows_count", "type": "long", "doc": "Deleted rows count", "field-id": 514}, - { - "name": "partitions", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "r508", - "fields": [ - { - "name": "contains_null", - "type": 
"boolean", - "doc": "True if any file has a null partition value", - "field-id": 509 - }, - { - "name": "contains_nan", - "type": ["null", "boolean"], - "doc": "True if any file has a nan partition value", - "field-id": 518 - }, - { - "name": "lower_bound", - "type": ["null", "bytes"], - "doc": "Partition lower bound for all files", - "field-id": 510 - }, - { - "name": "upper_bound", - "type": ["null", "bytes"], - "doc": "Partition upper bound for all files", - "field-id": 511 - } - ] - }, - "element-id": 508 - } - ], - "doc": "Summary for each partition", - "field-id": 507 - } - ] - }`)) - - AvroSchemaCache.Add(ManifestEntryV1Key, avro.MustParse(`{ - "type": "record", - "name": "manifest_entry", - "fields": [ - {"name": "status", "type": "int", "field-id": 0}, - {"name": "snapshot_id", "type": "long", "field-id": 1}, - { - "name": "data_file", - "type": { - "type": "record", - "name": "r2", - "fields": [ - {"name": "file_path", "type": "string", "doc": "Location URI with FS scheme", "field-id": 100}, - { - "name": "file_format", - "type": "string", - "doc": "File format name: avro, orc, or parquet", - "field-id": 101 - }, - { - "name": "partition", - "type": { - "type": "record", - "name": "r102", - "fields": [ - {"field-id": 1000, "name": "VendorID", "type": ["null", "int"]}, - { - "field-id": 1001, - "name": "tpep_pickup_datetime", - "type": ["null", {"type": "int", "logicalType": "date"}] - } - ] - }, - "field-id": 102 - }, - {"name": "record_count", "type": "long", "doc": "Number of records in the file", "field-id": 103}, - {"name": "file_size_in_bytes", "type": "long", "doc": "Total file size in bytes", "field-id": 104}, - {"name": "block_size_in_bytes", "type": "long", "field-id": 105}, - { - "name": "column_sizes", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k117_v118", - "fields": [ - {"name": "key", "type": "int", "field-id": 117}, - {"name": "value", "type": "long", "field-id": 118} - ] - }, - "logicalType": 
"map" - } - ], - "doc": "Map of column id to total size on disk", - "field-id": 108 - }, - { - "name": "value_counts", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k119_v120", - "fields": [ - {"name": "key", "type": "int", "field-id": 119}, - {"name": "value", "type": "long", "field-id": 120} - ] - }, - "logicalType": "map" - } - ], - "doc": "Map of column id to total count, including null and NaN", - "field-id": 109 - }, - { - "name": "null_value_counts", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k121_v122", - "fields": [ - {"name": "key", "type": "int", "field-id": 121}, - {"name": "value", "type": "long", "field-id": 122} - ] - }, - "logicalType": "map" - } - ], - "doc": "Map of column id to null value count", - "field-id": 110 - }, - { - "name": "nan_value_counts", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k138_v139", - "fields": [ - {"name": "key", "type": "int", "field-id": 138}, - {"name": "value", "type": "long", "field-id": 139} - ] - }, - "logicalType": "map" - } - ], - "doc": "Map of column id to number of NaN values in the column", - "field-id": 137 - }, - { - "name": "lower_bounds", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k126_v127", - "fields": [ - {"name": "key", "type": "int", "field-id": 126}, - {"name": "value", "type": "bytes", "field-id": 127} - ] - }, - "logicalType": "map" - } - ], - "doc": "Map of column id to lower bound", - "field-id": 125 - }, - { - "name": "upper_bounds", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k129_v130", - "fields": [ - {"name": "key", "type": "int", "field-id": 129}, - {"name": "value", "type": "bytes", "field-id": 130} - ] - }, - "logicalType": "map" - } - ], - "doc": "Map of column id to upper bound", - "field-id": 128 - }, - { - "name": "key_metadata", - "type": ["null", "bytes"], - 
"doc": "Encryption key metadata blob", - "field-id": 131 - }, - { - "name": "split_offsets", - "type": ["null", {"type": "array", "items": "long", "element-id": 133}], - "doc": "Splittable offsets", - "field-id": 132 - }, - { - "name": "sort_order_id", - "type": ["null", "int"], - "doc": "Sort order ID", - "field-id": 140 - } - ] - }, - "field-id": 2 - } - ] - }`)) - - AvroSchemaCache.Add(ManifestEntryV2Key, avro.MustParse(`{ - "type": "record", - "name": "manifest_entry", - "fields": [ - {"name": "status", "type": "int", "field-id": 0}, - {"name": "snapshot_id", "type": ["null", "long"], "field-id": 1}, - {"name": "sequence_number", "type": ["null", "long"], "field-id": 3}, - {"name": "file_sequence_number", "type": ["null", "long"], "field-id": 4}, - { - "name": "data_file", - "type": { - "type": "record", - "name": "r2", - "fields": [ - {"name": "content", "type": "int", "doc": "Type of content stored by the data file", "field-id": 134}, - {"name": "file_path", "type": "string", "doc": "Location URI with FS scheme", "field-id": 100}, - { - "name": "file_format", - "type": "string", - "doc": "File format name: avro, orc, or parquet", - "field-id": 101 - }, - { - "name": "partition", - "type": { - "type": "record", - "name": "r102", - "fields": [ - {"field-id": 1000, "name": "VendorID", "type": ["null", "int"]}, - { - "field-id": 1001, - "name": "tpep_pickup_datetime", - "type": ["null", {"type": "int", "logicalType": "date"}] - } - ] - }, - "field-id": 102 - }, - {"name": "record_count", "type": "long", "doc": "Number of records in the file", "field-id": 103}, - {"name": "file_size_in_bytes", "type": "long", "doc": "Total file size in bytes", "field-id": 104}, - { - "name": "column_sizes", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k117_v118", - "fields": [ - {"name": "key", "type": "int", "field-id": 117}, - {"name": "value", "type": "long", "field-id": 118} - ] - }, - "logicalType": "map" - } - ], - "doc": "Map of 
column id to total size on disk", - "field-id": 108 - }, - { - "name": "value_counts", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k119_v120", - "fields": [ - {"name": "key", "type": "int", "field-id": 119}, - {"name": "value", "type": "long", "field-id": 120} - ] - }, - "logicalType": "map" - } - ], - "doc": "Map of column id to total count, including null and NaN", - "field-id": 109 - }, - { - "name": "null_value_counts", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k121_v122", - "fields": [ - {"name": "key", "type": "int", "field-id": 121}, - {"name": "value", "type": "long", "field-id": 122} - ] - }, - "logicalType": "map" - } - ], - "doc": "Map of column id to null value count", - "field-id": 110 - }, - { - "name": "nan_value_counts", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k138_v139", - "fields": [ - {"name": "key", "type": "int", "field-id": 138}, - {"name": "value", "type": "long", "field-id": 139} - ] - }, - "logicalType": "map" - } - ], - "doc": "Map of column id to number of NaN values in the column", - "field-id": 137 - }, - { - "name": "lower_bounds", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k126_v127", - "fields": [ - {"name": "key", "type": "int", "field-id": 126}, - {"name": "value", "type": "bytes", "field-id": 127} - ] - }, - "logicalType": "map" - } - ], - "doc": "Map of column id to lower bound", - "field-id": 125 - }, - { - "name": "upper_bounds", - "type": [ - "null", - { - "type": "array", - "items": { - "type": "record", - "name": "k129_v130", - "fields": [ - {"name": "key", "type": "int", "field-id": 129}, - {"name": "value", "type": "bytes", "field-id": 130} - ] - }, - "logicalType": "map" - } - ], - "doc": "Map of column id to upper bound", - "field-id": 128 - }, - { - "name": "key_metadata", - "type": ["null", "bytes"], - "doc": "Encryption key metadata 
blob", - "field-id": 131 - }, - { - "name": "split_offsets", - "type": ["null", {"type": "array", "items": "long", "element-id": 133}], - "doc": "Splittable offsets", - "field-id": 132 - }, - { - "name": "equality_ids", - "type": ["null", {"type": "array", "items": "int", "element-id": 136}], - "doc": "Field ids used to determine row equality for delete files", - "field-id": 135 - }, - { - "name": "sort_order_id", - "type": ["null", "int"], - "doc": "Sort order ID", - "field-id": 140 - } - ] - }, - "field-id": 2 - } - ] - }`)) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package internal + +import "github.com/hamba/avro/v2" + +const ( + ManifestListV1Key = "manifest-list-v1" + ManifestListV2Key = "manifest-list-v2" + ManifestEntryV1Key = "manifest-entry-v1" + ManifestEntryV2Key = "manifest-entry-v2" +) + +var ( + AvroSchemaCache avro.SchemaCache +) + +func init() { + AvroSchemaCache.Add(ManifestListV1Key, avro.MustParse(`{ + "type": "record", + "name": "manifest_file", + "fields": [ + {"name": "manifest_path", "type": "string", "doc": "Location URI with FS scheme", "field-id": 500}, + {"name": "manifest_length", "type": "long", "doc": "Total file size in bytes", "field-id": 501}, + {"name": "partition_spec_id", "type": "int", "doc": "Spec ID used to write", "field-id": 502}, + { + "name": "added_snapshot_id", + "type": "long", + "doc": "Snapshot ID that added the manifest", + "field-id": 503 + }, + { + "name": "added_data_files_count", + "type": ["null", "int"], + "doc": "Added entry count", + "field-id": 504 + }, + { + "name": "existing_data_files_count", + "type": ["null", "int"], + "doc": "Existing entry count", + "field-id": 505 + }, + { + "name": "deleted_data_files_count", + "type": ["null", "int"], + "doc": "Deleted entry count", + "field-id": 506 + }, + { + "name": "partitions", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "r508", + "fields": [ + { + "name": "contains_null", + "type": "boolean", + "doc": "True if any file has a null partition value", + "field-id": 509 + }, + { + "name": "contains_nan", + "type": ["null", "boolean"], + "doc": "True if any file has a nan partition value", + "field-id": 518 + }, + { + "name": "lower_bound", + "type": ["null", "bytes"], + "doc": "Partition lower bound for all files", + "field-id": 510 + }, + { + "name": "upper_bound", + "type": ["null", "bytes"], + "doc": "Partition upper bound for all files", + "field-id": 511 + } + ] + }, + "element-id": 508 + } + ], + "doc": "Summary for each partition", + "field-id": 507 + }, + {"name": 
"added_rows_count", "type": ["null", "long"], "doc": "Added rows count", "field-id": 512}, + { + "name": "existing_rows_count", + "type": ["null", "long"], + "doc": "Existing rows count", + "field-id": 513 + }, + { + "name": "deleted_rows_count", + "type": ["null", "long"], + "doc": "Deleted rows count", + "field-id": 514 + } + ] + }`)) + + AvroSchemaCache.Add(ManifestListV2Key, avro.MustParse(`{ + "type": "record", + "name": "manifest_file", + "fields": [ + {"name": "manifest_path", "type": "string", "doc": "Location URI with FS scheme", "field-id": 500}, + {"name": "manifest_length", "type": "long", "doc": "Total file size in bytes", "field-id": 501}, + {"name": "partition_spec_id", "type": "int", "doc": "Spec ID used to write", "field-id": 502}, + {"name": "content", "type": "int", "doc": "Contents of the manifest: 0=data, 1=deletes", "field-id": 517}, + { + "name": "sequence_number", + "type": "long", + "doc": "Sequence number when the manifest was added", + "field-id": 515 + }, + { + "name": "min_sequence_number", + "type": "long", + "doc": "Lowest sequence number in the manifest", + "field-id": 516 + }, + {"name": "added_snapshot_id", "type": "long", "doc": "Snapshot ID that added the manifest", "field-id": 503}, + {"name": "added_files_count", "type": "int", "doc": "Added entry count", "field-id": 504}, + {"name": "existing_files_count", "type": "int", "doc": "Existing entry count", "field-id": 505}, + {"name": "deleted_files_count", "type": "int", "doc": "Deleted entry count", "field-id": 506}, + {"name": "added_rows_count", "type": "long", "doc": "Added rows count", "field-id": 512}, + {"name": "existing_rows_count", "type": "long", "doc": "Existing rows count", "field-id": 513}, + {"name": "deleted_rows_count", "type": "long", "doc": "Deleted rows count", "field-id": 514}, + { + "name": "partitions", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "r508", + "fields": [ + { + "name": "contains_null", + "type": 
"boolean", + "doc": "True if any file has a null partition value", + "field-id": 509 + }, + { + "name": "contains_nan", + "type": ["null", "boolean"], + "doc": "True if any file has a nan partition value", + "field-id": 518 + }, + { + "name": "lower_bound", + "type": ["null", "bytes"], + "doc": "Partition lower bound for all files", + "field-id": 510 + }, + { + "name": "upper_bound", + "type": ["null", "bytes"], + "doc": "Partition upper bound for all files", + "field-id": 511 + } + ] + }, + "element-id": 508 + } + ], + "doc": "Summary for each partition", + "field-id": 507 + } + ] + }`)) + + AvroSchemaCache.Add(ManifestEntryV1Key, avro.MustParse(`{ + "type": "record", + "name": "manifest_entry", + "fields": [ + {"name": "status", "type": "int", "field-id": 0}, + {"name": "snapshot_id", "type": "long", "field-id": 1}, + { + "name": "data_file", + "type": { + "type": "record", + "name": "r2", + "fields": [ + {"name": "file_path", "type": "string", "doc": "Location URI with FS scheme", "field-id": 100}, + { + "name": "file_format", + "type": "string", + "doc": "File format name: avro, orc, or parquet", + "field-id": 101 + }, + { + "name": "partition", + "type": { + "type": "record", + "name": "r102", + "fields": [ + {"field-id": 1000, "name": "VendorID", "type": ["null", "int"]}, + { + "field-id": 1001, + "name": "tpep_pickup_datetime", + "type": ["null", {"type": "int", "logicalType": "date"}] + } + ] + }, + "field-id": 102 + }, + {"name": "record_count", "type": "long", "doc": "Number of records in the file", "field-id": 103}, + {"name": "file_size_in_bytes", "type": "long", "doc": "Total file size in bytes", "field-id": 104}, + {"name": "block_size_in_bytes", "type": "long", "field-id": 105}, + { + "name": "column_sizes", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k117_v118", + "fields": [ + {"name": "key", "type": "int", "field-id": 117}, + {"name": "value", "type": "long", "field-id": 118} + ] + }, + "logicalType": 
"map" + } + ], + "doc": "Map of column id to total size on disk", + "field-id": 108 + }, + { + "name": "value_counts", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k119_v120", + "fields": [ + {"name": "key", "type": "int", "field-id": 119}, + {"name": "value", "type": "long", "field-id": 120} + ] + }, + "logicalType": "map" + } + ], + "doc": "Map of column id to total count, including null and NaN", + "field-id": 109 + }, + { + "name": "null_value_counts", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k121_v122", + "fields": [ + {"name": "key", "type": "int", "field-id": 121}, + {"name": "value", "type": "long", "field-id": 122} + ] + }, + "logicalType": "map" + } + ], + "doc": "Map of column id to null value count", + "field-id": 110 + }, + { + "name": "nan_value_counts", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k138_v139", + "fields": [ + {"name": "key", "type": "int", "field-id": 138}, + {"name": "value", "type": "long", "field-id": 139} + ] + }, + "logicalType": "map" + } + ], + "doc": "Map of column id to number of NaN values in the column", + "field-id": 137 + }, + { + "name": "lower_bounds", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k126_v127", + "fields": [ + {"name": "key", "type": "int", "field-id": 126}, + {"name": "value", "type": "bytes", "field-id": 127} + ] + }, + "logicalType": "map" + } + ], + "doc": "Map of column id to lower bound", + "field-id": 125 + }, + { + "name": "upper_bounds", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k129_v130", + "fields": [ + {"name": "key", "type": "int", "field-id": 129}, + {"name": "value", "type": "bytes", "field-id": 130} + ] + }, + "logicalType": "map" + } + ], + "doc": "Map of column id to upper bound", + "field-id": 128 + }, + { + "name": "key_metadata", + "type": ["null", "bytes"], + 
"doc": "Encryption key metadata blob", + "field-id": 131 + }, + { + "name": "split_offsets", + "type": ["null", {"type": "array", "items": "long", "element-id": 133}], + "doc": "Splittable offsets", + "field-id": 132 + }, + { + "name": "sort_order_id", + "type": ["null", "int"], + "doc": "Sort order ID", + "field-id": 140 + } + ] + }, + "field-id": 2 + } + ] + }`)) + + AvroSchemaCache.Add(ManifestEntryV2Key, avro.MustParse(`{ + "type": "record", + "name": "manifest_entry", + "fields": [ + {"name": "status", "type": "int", "field-id": 0}, + {"name": "snapshot_id", "type": ["null", "long"], "field-id": 1}, + {"name": "sequence_number", "type": ["null", "long"], "field-id": 3}, + {"name": "file_sequence_number", "type": ["null", "long"], "field-id": 4}, + { + "name": "data_file", + "type": { + "type": "record", + "name": "r2", + "fields": [ + {"name": "content", "type": "int", "doc": "Type of content stored by the data file", "field-id": 134}, + {"name": "file_path", "type": "string", "doc": "Location URI with FS scheme", "field-id": 100}, + { + "name": "file_format", + "type": "string", + "doc": "File format name: avro, orc, or parquet", + "field-id": 101 + }, + { + "name": "partition", + "type": { + "type": "record", + "name": "r102", + "fields": [ + {"field-id": 1000, "name": "VendorID", "type": ["null", "int"]}, + { + "field-id": 1001, + "name": "tpep_pickup_datetime", + "type": ["null", {"type": "int", "logicalType": "date"}] + } + ] + }, + "field-id": 102 + }, + {"name": "record_count", "type": "long", "doc": "Number of records in the file", "field-id": 103}, + {"name": "file_size_in_bytes", "type": "long", "doc": "Total file size in bytes", "field-id": 104}, + { + "name": "column_sizes", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k117_v118", + "fields": [ + {"name": "key", "type": "int", "field-id": 117}, + {"name": "value", "type": "long", "field-id": 118} + ] + }, + "logicalType": "map" + } + ], + "doc": "Map of 
column id to total size on disk", + "field-id": 108 + }, + { + "name": "value_counts", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k119_v120", + "fields": [ + {"name": "key", "type": "int", "field-id": 119}, + {"name": "value", "type": "long", "field-id": 120} + ] + }, + "logicalType": "map" + } + ], + "doc": "Map of column id to total count, including null and NaN", + "field-id": 109 + }, + { + "name": "null_value_counts", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k121_v122", + "fields": [ + {"name": "key", "type": "int", "field-id": 121}, + {"name": "value", "type": "long", "field-id": 122} + ] + }, + "logicalType": "map" + } + ], + "doc": "Map of column id to null value count", + "field-id": 110 + }, + { + "name": "nan_value_counts", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k138_v139", + "fields": [ + {"name": "key", "type": "int", "field-id": 138}, + {"name": "value", "type": "long", "field-id": 139} + ] + }, + "logicalType": "map" + } + ], + "doc": "Map of column id to number of NaN values in the column", + "field-id": 137 + }, + { + "name": "lower_bounds", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k126_v127", + "fields": [ + {"name": "key", "type": "int", "field-id": 126}, + {"name": "value", "type": "bytes", "field-id": 127} + ] + }, + "logicalType": "map" + } + ], + "doc": "Map of column id to lower bound", + "field-id": 125 + }, + { + "name": "upper_bounds", + "type": [ + "null", + { + "type": "array", + "items": { + "type": "record", + "name": "k129_v130", + "fields": [ + {"name": "key", "type": "int", "field-id": 129}, + {"name": "value", "type": "bytes", "field-id": 130} + ] + }, + "logicalType": "map" + } + ], + "doc": "Map of column id to upper bound", + "field-id": 128 + }, + { + "name": "key_metadata", + "type": ["null", "bytes"], + "doc": "Encryption key metadata 
blob", + "field-id": 131 + }, + { + "name": "split_offsets", + "type": ["null", {"type": "array", "items": "long", "element-id": 133}], + "doc": "Splittable offsets", + "field-id": 132 + }, + { + "name": "equality_ids", + "type": ["null", {"type": "array", "items": "int", "element-id": 136}], + "doc": "Field ids used to determine row equality for delete files", + "field-id": 135 + }, + { + "name": "sort_order_id", + "type": ["null", "int"], + "doc": "Sort order ID", + "field-id": 140 + } + ] + }, + "field-id": 2 + } + ] + }`)) +} diff --git a/internal/mock_fs.go b/internal/mock_fs.go index 95f6c3f..94496a8 100644 --- a/internal/mock_fs.go +++ b/internal/mock_fs.go @@ -1,85 +1,85 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package internal - -import ( - "bytes" - "errors" - "io/fs" - - "github.com/apache/iceberg-go/io" - "github.com/stretchr/testify/mock" -) - -type MockFS struct { - mock.Mock -} - -func (m *MockFS) Open(name string) (io.File, error) { - args := m.Called(name) - return args.Get(0).(io.File), args.Error(1) -} - -func (m *MockFS) Remove(name string) error { - return m.Called(name).Error(0) -} - -type MockFSReadFile struct { - MockFS -} - -func (m *MockFSReadFile) ReadFile(name string) ([]byte, error) { - args := m.Called(name) - return args.Get(0).([]byte), args.Error(1) -} - -type MockFile struct { - Contents *bytes.Reader - - closed bool -} - -func (m *MockFile) Stat() (fs.FileInfo, error) { - return nil, nil -} - -func (m *MockFile) Read(p []byte) (int, error) { - return m.Contents.Read(p) -} - -func (m *MockFile) Close() error { - if m.closed { - return errors.New("already closed") - } - m.closed = true - return nil -} - -func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) { - if m.closed { - return 0, errors.New("already closed") - } - return m.Contents.ReadAt(p, off) -} - -func (m *MockFile) Seek(offset int64, whence int) (n int64, err error) { - if m.closed { - return 0, errors.New("already closed") - } - return m.Contents.Seek(offset, whence) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package internal + +import ( + "bytes" + "errors" + "io/fs" + + "github.com/apache/iceberg-go/io" + "github.com/stretchr/testify/mock" +) + +type MockFS struct { + mock.Mock +} + +func (m *MockFS) Open(name string) (io.File, error) { + args := m.Called(name) + return args.Get(0).(io.File), args.Error(1) +} + +func (m *MockFS) Remove(name string) error { + return m.Called(name).Error(0) +} + +type MockFSReadFile struct { + MockFS +} + +func (m *MockFSReadFile) ReadFile(name string) ([]byte, error) { + args := m.Called(name) + return args.Get(0).([]byte), args.Error(1) +} + +type MockFile struct { + Contents *bytes.Reader + + closed bool +} + +func (m *MockFile) Stat() (fs.FileInfo, error) { + return nil, nil +} + +func (m *MockFile) Read(p []byte) (int, error) { + return m.Contents.Read(p) +} + +func (m *MockFile) Close() error { + if m.closed { + return errors.New("already closed") + } + m.closed = true + return nil +} + +func (m *MockFile) ReadAt(p []byte, off int64) (n int, err error) { + if m.closed { + return 0, errors.New("already closed") + } + return m.Contents.ReadAt(p, off) +} + +func (m *MockFile) Seek(offset int64, whence int) (n int64, err error) { + if m.closed { + return 0, errors.New("already closed") + } + return m.Contents.Seek(offset, whence) +} diff --git a/io/blob.go b/io/blob.go new file mode 100644 index 0000000..474ba25 --- /dev/null +++ b/io/blob.go @@ -0,0 +1,333 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package io + +import ( + "context" + "io" + "io/fs" + "net/url" + "path/filepath" + "strings" + "time" + + "gocloud.dev/blob" + "gocloud.dev/blob/memblob" +) + +// iofsFileInfo describes a single file in an io/fs.FS. +// It implements fs.FileInfo and fs.DirEntry. +// Copied from `gocloud.dev/blob.iofsDir` because it is private +type iofsFileInfo struct { + lo *blob.ListObject + name string +} + +func (f *iofsFileInfo) Name() string { return f.name } +func (f *iofsFileInfo) Size() int64 { return f.lo.Size } +func (f *iofsFileInfo) Mode() fs.FileMode { return fs.ModeIrregular } +func (f *iofsFileInfo) ModTime() time.Time { return f.lo.ModTime } +func (f *iofsFileInfo) IsDir() bool { return false } +func (f *iofsFileInfo) Sys() interface{} { return f.lo } +func (f *iofsFileInfo) Info() (fs.FileInfo, error) { return f, nil } +func (f *iofsFileInfo) Type() fs.FileMode { return fs.ModeIrregular } + +// iofsDir describes a single directory in an `iceberg-go/io.FS`. +// It implements `io/fs.FileInfo`, `io/fs.DirEntry`, and `io/fs.File`. +// Copied from `gocloud.dev/blob.iofsDir`, but modified to use `iceberg-go/io.File` instead of `io/fs.File` +type iofsDir struct { + b *BlobFileIO + key string + name string + // If opened is true, we've read entries via openOnce(). 
+	opened  bool
+	entries []fs.DirEntry
+	offset  int
+}
+
+// newDir returns a directory handle for key in b. The listing is not
+// fetched here; it is loaded by openOnce.
+func newDir(b *BlobFileIO, key, name string) *iofsDir {
+	return &iofsDir{b: b, key: key, name: name}
+}
+
+// iofsDir acts as its own fs.FileInfo and fs.DirEntry. Byte-oriented
+// operations (Read, ReadAt, Seek) are rejected with fs.ErrInvalid.
+
+func (d *iofsDir) Name() string               { return d.name }
+func (d *iofsDir) Size() int64                { return 0 }
+func (d *iofsDir) Mode() fs.FileMode          { return fs.ModeDir }
+func (d *iofsDir) Type() fs.FileMode          { return fs.ModeDir }
+func (d *iofsDir) ModTime() time.Time         { return time.Time{} }
+func (d *iofsDir) IsDir() bool                { return true }
+func (d *iofsDir) Sys() interface{}           { return d }
+func (d *iofsDir) Info() (fs.FileInfo, error) { return d, nil }
+func (d *iofsDir) Stat() (fs.FileInfo, error) { return d, nil }
+func (d *iofsDir) Read([]byte) (int, error) {
+	return 0, &fs.PathError{Op: "read", Path: d.key, Err: fs.ErrInvalid}
+}
+func (d *iofsDir) ReadAt(p []byte, off int64) (int, error) {
+	return 0, &fs.PathError{Op: "readAt", Path: d.key, Err: fs.ErrInvalid}
+}
+func (d *iofsDir) Seek(offset int64, whence int) (int64, error) {
+	return 0, &fs.PathError{Op: "seek", Path: d.key, Err: fs.ErrInvalid}
+}
+func (d *iofsDir) Close() error { return nil }
+
+// ReadDir returns up to count entries, continuing from where the previous
+// call stopped. With count <= 0 it returns all remaining entries. With
+// count > 0 it returns io.EOF once no entries remain.
+func (d *iofsDir) ReadDir(count int) ([]fs.DirEntry, error) {
+	if err := d.openOnce(); err != nil {
+		return nil, err
+	}
+	n := len(d.entries) - d.offset
+	if n == 0 && count > 0 {
+		return nil, io.EOF
+	}
+	if count > 0 && n > count {
+		n = count
+	}
+	list := make([]fs.DirEntry, n)
+	for i := range list {
+		list[i] = d.entries[d.offset+i]
+	}
+	d.offset += n
+	return list, nil
+}
+
+// openOnce lists the directory on its first call and caches the entries;
+// later calls return nil without listing again. It returns fs.ErrNotExist
+// when the listing is empty.
+func (d *iofsDir) openOnce() error {
+	if d.opened {
+		return nil
+	}
+	d.opened = true
+
+	// blob expects directories to end in the delimiter, except at the top level.
+	prefix := d.key
+	if prefix != "" {
+		prefix += "/"
+	}
+	listOpts := blob.ListOptions{
+		Prefix:    prefix,
+		Delimiter: "/",
+	}
+	ctx := d.b.ctx
+
+	// Fetch all the directory entries.
+	// Conceivably we could only fetch a few here, and fetch the rest lazily
+	// on demand, but that would add significant complexity.
+	iter := d.b.List(&listOpts)
+	for {
+		item, err := iter.Next(ctx)
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return err
+		}
+		// NOTE(review): keys are "/"-delimited (see Delimiter above) while
+		// filepath.Base follows the host OS separator rules; path.Base may
+		// be the better fit - confirm.
+		name := filepath.Base(item.Key)
+		if item.IsDir {
+			d.entries = append(d.entries, newDir(d.b, item.Key, name))
+		} else {
+			d.entries = append(d.entries, &iofsFileInfo{item, name})
+		}
+	}
+	// There is no such thing as an empty directory in Bucket, so if
+	// we didn't find anything, it doesn't exist.
+	if len(d.entries) == 0 {
+		return fs.ErrNotExist
+	}
+	return nil
+}
+
+// blobOpenFile describes a single open blob as a File.
+// It implements the iceberg-go/io.File interface.
+// It is based on gocloud.dev/blob.iofsOpenFile which:
+//   - Doesn't support the `io.ReaderAt` interface
+//   - Is not externally accessible, so copied here
+type blobOpenFile struct {
+	*blob.Reader
+	name string
+}
+
+// ReadAt reads len(p) bytes starting at byte offset off.
+//
+// Following the io.ReaderAt contract, the returned count is never negative,
+// and a read that yields fewer than len(p) bytes returns a non-nil error
+// (io.EOF when the blob ends first).
+//
+// The read is implemented as Seek followed by Read on the shared reader, so
+// it moves the reader's offset; it must not be mixed with concurrent Read,
+// Seek or ReadAt calls on the same file.
+func (f *blobOpenFile) ReadAt(p []byte, off int64) (int, error) {
+	finalOff, err := f.Reader.Seek(off, io.SeekStart)
+	if err != nil {
+		return 0, err
+	}
+	if finalOff != off {
+		return 0, io.ErrUnexpectedEOF
+	}
+
+	// A single Read may legitimately return fewer bytes than requested
+	// without an error, so keep reading until p is full or the blob ends.
+	n, err := io.ReadFull(f.Reader, p)
+	if err == io.ErrUnexpectedEOF {
+		err = io.EOF
+	}
+	return n, err
+}
+
+// Functions to implement the `Stat()` function in the `io/fs.File` interface
+
+func (f *blobOpenFile) Name() string               { return f.name }
+func (f *blobOpenFile) Mode() fs.FileMode          { return fs.ModeIrregular }
+func (f *blobOpenFile) Sys() interface{}           { return f.Reader }
+func (f *blobOpenFile) IsDir() bool                { return false }
+func (f *blobOpenFile) Stat() (fs.FileInfo, error) { return f, nil }
+
+// BlobFileIO represents a file system backed by a bucket in object store. It implements the `iceberg-go/io.FileIO` interface.
+type BlobFileIO struct {
+	*blob.Bucket
+	ctx    context.Context
+	opts   *blob.ReaderOptions
+	prefix string
+}
+
+// preprocess turns a location into a key relative to the bucket: it drops a
+// leading "scheme://", then the configured prefix, then a leading "/".
+// A remainder of exactly "/" maps to ".", which Open treats as the root.
+func (bfs *BlobFileIO) preprocess(n string) string {
+	_, after, found := strings.Cut(n, "://")
+	if found {
+		n = after
+	}
+
+	out := strings.TrimPrefix(n, bfs.prefix)
+	if out == "/" {
+		out = "."
+	} else {
+		out = strings.TrimPrefix(out, "/")
+	}
+
+	return out
+}
+
+// Open a Blob from a Bucket using the BlobFileIO. Note this
+// function is copied from blob.Bucket.Open, but extended to
+// return a iceberg-go/io.File instance instead of io/fs.File
+//
+// A key that does not exist as a blob is treated as a directory and listed.
+// Errors are returned as *fs.PathError with Op "open".
+func (bfs *BlobFileIO) Open(path string) (File, error) {
+	if _, err := url.Parse(path); err != nil {
+		return nil, &fs.PathError{Op: "open", Path: path, Err: fs.ErrInvalid}
+	}
+	path = bfs.preprocess(path)
+
+	var isDir bool
+	var key, name string // name is the last part of the path
+	if path == "." {
+		// Root is always a directory, but blob doesn't want the "." in the key.
+		isDir = true
+		key, name = "", "."
+	} else {
+		// Report a failed existence check instead of silently treating the
+		// key as a directory.
+		exists, err := bfs.Bucket.Exists(bfs.ctx, path)
+		if err != nil {
+			return nil, &fs.PathError{Op: "open", Path: path, Err: err}
+		}
+		isDir = !exists
+		key, name = path, filepath.Base(path)
+	}
+
+	// If it's a directory, list the directory contents. We can't do this lazily
+	// because we need to error out here if it doesn't exist.
+	if isDir {
+		dir := newDir(bfs, key, name)
+		err := dir.openOnce()
+		if err != nil {
+			if err == fs.ErrNotExist && path == "." {
+				// The root directory must exist.
+				return dir, nil
+			}
+			return nil, &fs.PathError{Op: "open", Path: path, Err: err}
+		}
+		return dir, nil
+	}
+
+	// It's a file; open it and return a wrapper.
+	r, err := bfs.Bucket.NewReader(bfs.ctx, path, bfs.opts)
+	if err != nil {
+		return nil, &fs.PathError{Op: "open", Path: path, Err: err}
+	}
+
+	return &blobOpenFile{Reader: r, name: name}, nil
+}
+
+// Remove a Blob from a Bucket using the BlobFileIO.
+//
+// The name is validated after preprocessing, so both bucket-relative keys
+// and the URL-style locations accepted by Open can be removed. The root and
+// names that are not a valid path yield a *fs.PathError wrapping
+// fs.ErrInvalid.
+func (bfs *BlobFileIO) Remove(path string) error {
+	key := bfs.preprocess(path)
+	if key == "." || !fs.ValidPath(key) {
+		return &fs.PathError{Op: "remove", Path: path, Err: fs.ErrInvalid}
+	}
+
+	return bfs.Bucket.Delete(bfs.ctx, key)
+}
+
+// NewWriter returns a Writer that writes to the blob stored at path.
+// A nil WriterOptions is treated the same as the zero value.
+//
+// If overwrite is disabled and a blob with this path already exists,
+// an error will be returned. A failure of the existence check itself is
+// also returned rather than ignored.
+//
+// The writer is created with the context held by the BlobFileIO.
+//
+// The caller must call Close on the returned Writer, even if the write is
+// aborted.
+func (bfs *BlobFileIO) NewWriter(path string, overwrite bool, opts *blob.WriterOptions) (w *BlobWriteFile, err error) {
+	// Validate the bucket-relative key so URL-style locations are accepted.
+	key := bfs.preprocess(path)
+	if key == "." || !fs.ValidPath(key) {
+		return nil, &fs.PathError{Op: "new writer", Path: path, Err: fs.ErrInvalid}
+	}
+	path = key
+	if !overwrite {
+		exists, err := bfs.Bucket.Exists(bfs.ctx, path)
+		if err != nil {
+			return nil, &fs.PathError{Op: "new writer", Path: path, Err: err}
+		}
+		if exists {
+			return nil, &fs.PathError{Op: "new writer", Path: path, Err: fs.ErrInvalid}
+		}
+	}
+	bw, err := bfs.Bucket.NewWriter(bfs.ctx, path, opts)
+	if err != nil {
+		return nil, err
+	}
+	return &BlobWriteFile{
+			Writer: bw,
+			name:   path,
+			opts:   opts},
+		nil
+}
+
+// urlToBucketPath returns the host and path components of parsed.
+func urlToBucketPath(parsed *url.URL) (string, string) {
+	return parsed.Host, parsed.Path
+}
+
+// CreateBlobFileIO creates a new BlobFileIO instance for the bucket named by
+// parsed. Supported schemes are "mem", "s3"/"s3a"/"s3n" and "gs"; any other
+// scheme yields a *fs.PathError wrapping fs.ErrInvalid.
+//
+// A non-root URL path is applied as a key prefix to the bucket.
+func CreateBlobFileIO(parsed *url.URL, props map[string]string) (*BlobFileIO, error) {
+	ctx := context.Background()
+
+	var bucket *blob.Bucket
+	var err error
+	switch parsed.Scheme {
+	case "mem":
+		// memblob doesn't use the URL host or path
+		bucket = memblob.OpenBucket(nil)
+	case "s3", "s3a", "s3n":
+		bucket, err = createS3Bucket(ctx, parsed, props)
+	case "gs":
+		bucket, err = createGCSBucket(ctx, parsed, props)
+	default:
+		// Without this the bucket would stay nil and be dereferenced below.
+		return nil, &fs.PathError{Op: "create blob io", Path: parsed.String(), Err: fs.ErrInvalid}
+	}
+
+	if err != nil {
+		return nil, err
+	}
+
+	// The prefix is prepended to keys verbatim, so it has to end in "/" for
+	// "/warehouse" + "a/b" to become "warehouse/a/b" rather than "warehousea/b".
+	if p := strings.Trim(parsed.Path, "/"); p != "" {
+		bucket = blob.PrefixedBucket(bucket, p+"/")
+	}
+
+	return &BlobFileIO{Bucket: bucket, ctx: ctx, opts: &blob.ReaderOptions{}, prefix: parsed.Host + parsed.Path}, nil
+}
+
+// BlobWriteFile wraps a blob.Writer together with the key it writes to and
+// the options it was created with.
+type BlobWriteFile struct {
+	*blob.Writer
+	name string
+	opts *blob.WriterOptions
+}
+
+func (f *BlobWriteFile) Name() string                { return f.name }
+func (f *BlobWriteFile) Sys() interface{}            { return f.Writer }
+func (f *BlobWriteFile) Close() error                { return f.Writer.Close() }
+func (f *BlobWriteFile) Write(p []byte) (int, error) { return f.Writer.Write(p) }
diff --git a/io/gcs_cdk.go b/io/gcs_cdk.go
new file mode 100644
index 0000000..232306f
--- /dev/null
+++ b/io/gcs_cdk.go
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package io
+
+import (
+	"context"
+	"net/url"
+
+	"gocloud.dev/blob"
+	"gocloud.dev/blob/gcsblob"
+	"gocloud.dev/gcp"
+	"google.golang.org/api/option"
+)
+
+// Constants for GCS configuration options
+const (
+	GCSEndpoint = "gcs.endpoint"
+	GCSKeyPath  = "gcs.keypath"
+	GCSJSONKey  = "gcs.jsonkey"
+)
+
+// ParseGCSConfig builds gcsblob options from props. Each of the GCSEndpoint,
+// GCSJSONKey and GCSKeyPath properties that is set to a non-empty value adds
+// the matching client option, in that order.
+func ParseGCSConfig(props map[string]string) *gcsblob.Options {
+	var o []option.ClientOption
+	if url := props[GCSEndpoint]; url != "" {
+		o = append(o, option.WithEndpoint(url))
+	}
+	if key := props[GCSJSONKey]; key != "" {
+		o = append(o, option.WithCredentialsJSON([]byte(key)))
+	}
+	if path := props[GCSKeyPath]; path != "" {
+		o = append(o, option.WithCredentialsFile(path))
+	}
+	return &gcsblob.Options{
+		ClientOptions: o,
+	}
+}
+
+// Construct a GCS bucket from a URL. The bucket name is the URL host.
+//
+// NOTE(review): the HTTP client is always created with
+// gcp.NewAnonymousHTTPClient; confirm that credentials supplied through the
+// options from ParseGCSConfig still take effect with this client.
+func createGCSBucket(ctx context.Context, parsed *url.URL, props map[string]string) (*blob.Bucket, error) {
+	gcscfg := ParseGCSConfig(props)
+	client := gcp.NewAnonymousHTTPClient(gcp.DefaultTransport())
+	bucket, err := gcsblob.OpenBucket(ctx, client, parsed.Host, gcscfg)
+	if err != nil {
+		return nil, err
+	}
+
+	return bucket, nil
+}
diff --git a/io/io.go b/io/io.go
index abe5971..6aeee25 100644
--- a/io/io.go
+++ b/io/io.go
@@ -1,248 +1,266 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.
See the License for the -// specific language governing permissions and limitations -// under the License. - -package io - -import ( - "errors" - "fmt" - "io" - "io/fs" - "net/url" - "strings" -) - -// IO is an interface to a hierarchical file system. -// -// The IO interface is the minimum implementation required for a file -// system to utilize an iceberg table. A file system may implement -// additional interfaces, such as ReadFileIO, to provide additional or -// optimized functionality. -type IO interface { - // Open opens the named file. - // - // When Open returns an error, it should be of type *PathError - // with the Op field set to "open", the Path field set to name, - // and the Err field describing the problem. - // - // Open should reject attempts to open names that do not satisfy - // fs.ValidPath(name), returning a *PathError with Err set to - // ErrInvalid or ErrNotExist. - Open(name string) (File, error) - - // Remove removes the named file or (empty) directory. - // - // If there is an error, it will be of type *PathError. - Remove(name string) error -} - -// ReadFileIO is the interface implemented by a file system that -// provides an optimized implementation of ReadFile. -type ReadFileIO interface { - IO - - // ReadFile reads the named file and returns its contents. - // A successful call returns a nil error, not io.EOF. - // (Because ReadFile reads the whole file, the expected EOF - // from the final Read is not treated as an error to be reported.) - // - // The caller is permitted to modify the returned byte slice. - // This method should return a copy of the underlying data. - ReadFile(name string) ([]byte, error) -} - -// A File provides access to a single file. The File interface is the -// minimum implementation required for Iceberg to interact with a file. 
-// Directory files should also implement -type File interface { - fs.File - io.ReadSeekCloser - io.ReaderAt -} - -// A ReadDirFile is a directory file whose entries can be read with the -// ReadDir method. Every directory file should implement this interface. -// (It is permissible for any file to implement this interface, but -// if so ReadDir should return an error for non-directories.) -type ReadDirFile interface { - File - - // ReadDir read the contents of the directory and returns a slice - // of up to n DirEntry values in directory order. Subsequent calls - // on the same file will yield further DirEntry values. - // - // If n > 0, ReadDir returns at most n DirEntry structures. In this - // case, if ReadDir returns an empty slice, it will return a non-nil - // error explaining why. - // - // At the end of a directory, the error is io.EOF. (ReadDir must return - // io.EOF itself, not an error wrapping io.EOF.) - // - // If n <= 0, ReadDir returns all the DirEntry values from the directory - // in a single slice. In this case, if ReadDir succeeds (reads all the way - // to the end of the directory), it returns the slice and a nil error. - // If it encounters an error before the end of the directory, ReadDir - // returns the DirEntry list read until that point and a non-nil error. - ReadDir(n int) ([]fs.DirEntry, error) -} - -// FS wraps an io/fs.FS as an IO interface. -func FS(fsys fs.FS) IO { - if _, ok := fsys.(fs.ReadFileFS); ok { - return readFileFS{ioFS{fsys, nil}} - } - return ioFS{fsys, nil} -} - -// FSPreProcName wraps an io/fs.FS like FS, only if fn is non-nil then -// it is called to preprocess any filenames before they are passed to -// the underlying fsys. 
-func FSPreProcName(fsys fs.FS, fn func(string) string) IO { - if _, ok := fsys.(fs.ReadFileFS); ok { - return readFileFS{ioFS{fsys, fn}} - } - return ioFS{fsys, fn} -} - -type readFileFS struct { - ioFS -} - -func (r readFileFS) ReadFile(name string) ([]byte, error) { - if r.preProcessName != nil { - name = r.preProcessName(name) - } - - rfs, ok := r.fsys.(fs.ReadFileFS) - if !ok { - return nil, errMissingReadFile - } - return rfs.ReadFile(name) -} - -type ioFS struct { - fsys fs.FS - - preProcessName func(string) string -} - -func (f ioFS) Open(name string) (File, error) { - if f.preProcessName != nil { - name = f.preProcessName(name) - } - - if name == "/" { - name = "." - } else { - name = strings.TrimPrefix(name, "/") - } - file, err := f.fsys.Open(name) - if err != nil { - return nil, err - } - - return ioFile{file}, nil -} - -func (f ioFS) Remove(name string) error { - r, ok := f.fsys.(interface{ Remove(name string) error }) - if !ok { - return errMissingRemove - } - return r.Remove(name) -} - -var ( - errMissingReadDir = errors.New("fs.File directory missing ReadDir method") - errMissingSeek = errors.New("fs.File missing Seek method") - errMissingReadAt = errors.New("fs.File missing ReadAt") - errMissingRemove = errors.New("fs.FS missing Remove method") - errMissingReadFile = errors.New("fs.FS missing ReadFile method") -) - -type ioFile struct { - file fs.File -} - -func (f ioFile) Close() error { return f.file.Close() } -func (f ioFile) Read(b []byte) (int, error) { return f.file.Read(b) } -func (f ioFile) Stat() (fs.FileInfo, error) { return f.file.Stat() } -func (f ioFile) Seek(offset int64, whence int) (int64, error) { - s, ok := f.file.(io.Seeker) - if !ok { - return 0, errMissingSeek - } - return s.Seek(offset, whence) -} - -func (f ioFile) ReadAt(p []byte, off int64) (n int, err error) { - r, ok := f.file.(io.ReaderAt) - if !ok { - return 0, errMissingReadAt - } - return r.ReadAt(p, off) -} - -func (f ioFile) ReadDir(count int) ([]fs.DirEntry, error) 
{ - d, ok := f.file.(fs.ReadDirFile) - if !ok { - return nil, errMissingReadDir - } - - return d.ReadDir(count) -} - -func inferFileIOFromSchema(path string, props map[string]string) (IO, error) { - parsed, err := url.Parse(path) - if err != nil { - return nil, err - } - - switch parsed.Scheme { - case "s3", "s3a", "s3n": - return createS3FileIO(parsed, props) - case "file", "": - return LocalFS{}, nil - default: - return nil, fmt.Errorf("IO for file '%s' not implemented", path) - } -} - -// LoadFS takes a map of properties and an optional URI location -// and attempts to infer an IO object from it. -// -// A schema of "file://" or an empty string will result in a LocalFS -// implementation. Otherwise this will return an error if the schema -// does not yet have an implementation here. -// -// Currently only LocalFS and S3 are implemented. -func LoadFS(props map[string]string, location string) (IO, error) { - if location == "" { - location = props["warehouse"] - } - - iofs, err := inferFileIOFromSchema(location, props) - if err != nil { - return nil, err - } - - if iofs == nil { - iofs = LocalFS{} - } - - return iofs, nil -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+package io
+
+import (
+	"errors"
+	"fmt"
+	"io"
+	"io/fs"
+	"net/url"
+	"strings"
+)
+
+// IO is an interface to a hierarchical file system.
+//
+// The IO interface is the minimum implementation required for a file
+// system to utilize an iceberg table. A file system may implement
+// additional interfaces, such as ReadFileIO, to provide additional or
+// optimized functionality.
+type IO interface {
+	// Open opens the named file.
+	//
+	// When Open returns an error, it should be of type *PathError
+	// with the Op field set to "open", the Path field set to name,
+	// and the Err field describing the problem.
+	//
+	// Open should reject attempts to open names that do not satisfy
+	// fs.ValidPath(name), returning a *PathError with Err set to
+	// ErrInvalid or ErrNotExist.
+	Open(name string) (File, error)
+
+	// Remove removes the named file or (empty) directory.
+	//
+	// If there is an error, it will be of type *PathError.
+	Remove(name string) error
+}
+
+// ReadFileIO is the interface implemented by a file system that
+// provides an optimized implementation of ReadFile.
+type ReadFileIO interface {
+	IO
+
+	// ReadFile reads the named file and returns its contents.
+	// A successful call returns a nil error, not io.EOF.
+	// (Because ReadFile reads the whole file, the expected EOF
+	// from the final Read is not treated as an error to be reported.)
+	//
+	// The caller is permitted to modify the returned byte slice.
+	// This method should return a copy of the underlying data.
+	ReadFile(name string) ([]byte, error)
+}
+
+// WriteFileIO is the interface implemented by a file system that
+// provides an optimized implementation of writing a whole file.
+type WriteFileIO interface {
+	IO
+
+	// Write writes p to the named file.
+	// An error will be returned if the file already exists.
+	Write(name string, p []byte) error
+	Close() error
+}
+
+// A File provides access to a single file. The File interface is the
+// minimum implementation required for Iceberg to interact with a file.
+// Directory files should also implement ReadDirFile.
+type File interface {
+	fs.File
+	io.ReadSeekCloser
+	io.ReaderAt
+}
+
+// A ReadDirFile is a directory file whose entries can be read with the
+// ReadDir method. Every directory file should implement this interface.
+// (It is permissible for any file to implement this interface, but
+// if so ReadDir should return an error for non-directories.)
+type ReadDirFile interface {
+	File
+
+	// ReadDir read the contents of the directory and returns a slice
+	// of up to n DirEntry values in directory order. Subsequent calls
+	// on the same file will yield further DirEntry values.
+	//
+	// If n > 0, ReadDir returns at most n DirEntry structures. In this
+	// case, if ReadDir returns an empty slice, it will return a non-nil
+	// error explaining why.
+	//
+	// At the end of a directory, the error is io.EOF. (ReadDir must return
+	// io.EOF itself, not an error wrapping io.EOF.)
+	//
+	// If n <= 0, ReadDir returns all the DirEntry values from the directory
+	// in a single slice. In this case, if ReadDir succeeds (reads all the way
+	// to the end of the directory), it returns the slice and a nil error.
+	// If it encounters an error before the end of the directory, ReadDir
+	// returns the DirEntry list read until that point and a non-nil error.
+	ReadDir(n int) ([]fs.DirEntry, error)
+}
+
+// FS wraps an io/fs.FS as an IO interface.
+func FS(fsys fs.FS) IO {
+	if _, ok := fsys.(fs.ReadFileFS); ok {
+		return readFileFS{ioFS{fsys, nil}}
+	}
+	return ioFS{fsys, nil}
+}
+
+// FSPreProcName wraps an io/fs.FS like FS, only if fn is non-nil then
+// it is called to preprocess any filenames before they are passed to
+// the underlying fsys.
+func FSPreProcName(fsys fs.FS, fn func(string) string) IO {
+	if _, ok := fsys.(fs.ReadFileFS); ok {
+		return readFileFS{ioFS{fsys, fn}}
+	}
+	return ioFS{fsys, fn}
+}
+
+// readFileFS is an ioFS whose underlying fs.FS also implements
+// fs.ReadFileFS.
+type readFileFS struct {
+	ioFS
+}
+
+// ReadFile applies the name preprocessor, if any, and delegates to the
+// underlying fs.ReadFileFS.
+func (r readFileFS) ReadFile(name string) ([]byte, error) {
+	if r.preProcessName != nil {
+		name = r.preProcessName(name)
+	}
+
+	rfs, ok := r.fsys.(fs.ReadFileFS)
+	if !ok {
+		return nil, errMissingReadFile
+	}
+	return rfs.ReadFile(name)
+}
+
+// ioFS adapts an fs.FS to the IO interface, optionally rewriting names
+// with preProcessName before they reach fsys.
+type ioFS struct {
+	fsys fs.FS
+
+	preProcessName func(string) string
+}
+
+// normalize applies the optional name preprocessor and then converts the
+// result to the unrooted form used with fsys: "/" becomes "." and a leading
+// "/" is dropped.
+func (f ioFS) normalize(name string) string {
+	if f.preProcessName != nil {
+		name = f.preProcessName(name)
+	}
+
+	if name == "/" {
+		return "."
+	}
+	return strings.TrimPrefix(name, "/")
+}
+
+// Open opens the named file from the wrapped fs.FS.
+func (f ioFS) Open(name string) (File, error) {
+	file, err := f.fsys.Open(f.normalize(name))
+	if err != nil {
+		return nil, err
+	}
+
+	return ioFile{file}, nil
+}
+
+// Remove removes the named file through the wrapped fs.FS, which must
+// provide a Remove method; otherwise errMissingRemove is returned.
+//
+// The name goes through the same normalization as in Open, so a location
+// that can be opened refers to the same entry when it is removed.
+func (f ioFS) Remove(name string) error {
+	r, ok := f.fsys.(interface{ Remove(name string) error })
+	if !ok {
+		return errMissingRemove
+	}
+	return r.Remove(f.normalize(name))
+}
+
+var (
+	errMissingReadDir  = errors.New("fs.File directory missing ReadDir method")
+	errMissingSeek     = errors.New("fs.File missing Seek method")
+	errMissingReadAt   = errors.New("fs.File missing ReadAt")
+	errMissingRemove   = errors.New("fs.FS missing Remove method")
+	errMissingReadFile = errors.New("fs.FS missing ReadFile method")
+)
+
+// ioFile adapts an fs.File to File. Seek, ReadAt and ReadDir are forwarded
+// when the wrapped file supports them and otherwise return the matching
+// errMissing* error.
+type ioFile struct {
+	file fs.File
+}
+
+func (f ioFile) Close() error               { return f.file.Close() }
+func (f ioFile) Read(b []byte) (int, error) { return f.file.Read(b) }
+func (f ioFile) Stat() (fs.FileInfo, error) { return f.file.Stat() }
+func (f ioFile) Seek(offset int64, whence int) (int64, error) {
+	s, ok := f.file.(io.Seeker)
+	if !ok {
+		return 0, errMissingSeek
+	}
+	return s.Seek(offset, whence)
+}
+
+func (f ioFile) ReadAt(p []byte, off int64) (n int, err error) {
+	r, ok := f.file.(io.ReaderAt)
+	if !ok {
+		return 0, errMissingReadAt
+	}
+	return r.ReadAt(p, off)
+}
+
+func (f ioFile) ReadDir(count int) ([]fs.DirEntry, error) {
+	d, ok := f.file.(fs.ReadDirFile)
+	if !ok {
+		return nil, errMissingReadDir
+	}
+
+	return d.ReadDir(count)
+}
+
+// inferFileIOFromSchema selects an IO implementation from the URL scheme of
+// path. S3 schemes use the blob-backed implementation only when the
+// "s3.use-cdk" property is "true".
+func inferFileIOFromSchema(path string, props map[string]string) (IO, error) {
+	parsed, err := url.Parse(path)
+	if err != nil {
+		return nil, err
+	}
+
+	switch parsed.Scheme {
+	case "s3", "s3a", "s3n":
+		if props["s3.use-cdk"] == "true" {
+			return CreateBlobFileIO(parsed, props)
+		}
+		return createS3FileIO(parsed, props)
+	case "gs", "mem":
+		return CreateBlobFileIO(parsed, props)
+	case "file", "":
+		return LocalFS{}, nil
+	default:
+		return nil, fmt.Errorf("IO for file '%s' not implemented", path)
+	}
+}
+
+// LoadFS takes a map of properties and an optional URI location
+// and attempts to infer an IO object from it. When location is empty the
+// "warehouse" property is used instead.
+//
+// A scheme of "file://" or an empty string will result in a LocalFS
+// implementation. Otherwise this will return an error if the scheme
+// does not yet have an implementation here.
+//
+// Currently local, S3, GCS, and In-Memory FSs are implemented.
+func LoadFS(props map[string]string, location string) (IO, error) {
+	if location == "" {
+		location = props["warehouse"]
+	}
+
+	iofs, err := inferFileIOFromSchema(location, props)
+	if err != nil {
+		return nil, err
+	}
+
+	if iofs == nil {
+		iofs = LocalFS{}
+	}
+
+	return iofs, nil
+}
diff --git a/io/local.go b/io/local.go
index 560d9be..cb80e2a 100644
--- a/io/local.go
+++ b/io/local.go
@@ -1,35 +1,35 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package io - -import ( - "os" - "strings" -) - -// LocalFS is an implementation of IO that implements interaction with -// the local file system. -type LocalFS struct{} - -func (LocalFS) Open(name string) (File, error) { - return os.Open(strings.TrimPrefix(name, "file://")) -} - -func (LocalFS) Remove(name string) error { - return os.Remove(name) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package io + +import ( + "os" + "strings" +) + +// LocalFS is an implementation of IO that implements interaction with +// the local file system. 
+type LocalFS struct{}
+
+// Open opens the named local file; a leading "file://" is stripped first.
+func (LocalFS) Open(name string) (File, error) {
+	return os.Open(strings.TrimPrefix(name, "file://"))
+}
+
+// Remove removes the named local file. A leading "file://" is stripped, as
+// in Open, so the same location string works for both operations.
+func (LocalFS) Remove(name string) error {
+	return os.Remove(strings.TrimPrefix(name, "file://"))
+}
diff --git a/io/s3.go b/io/s3.go
index 7396130..31fb074 100644
--- a/io/s3.go
+++ b/io/s3.go
@@ -1,114 +1,142 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
- -package io - -import ( - "context" - "fmt" - "net/http" - "net/url" - "os" - "strings" - - awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go-v2/service/s3" - "github.com/aws/smithy-go/auth/bearer" - "github.com/wolfeidau/s3iofs" -) - -// Constants for S3 configuration options -const ( - S3Region = "s3.region" - S3SessionToken = "s3.session-token" - S3SecretAccessKey = "s3.secret-access-key" - S3AccessKeyID = "s3.access-key-id" - S3EndpointURL = "s3.endpoint" - S3ProxyURI = "s3.proxy-uri" -) - -func createS3FileIO(parsed *url.URL, props map[string]string) (IO, error) { - cfgOpts := []func(*config.LoadOptions) error{} - opts := []func(*s3.Options){} - - endpoint, ok := props[S3EndpointURL] - if !ok { - endpoint = os.Getenv("AWS_S3_ENDPOINT") - } - - if endpoint != "" { - opts = append(opts, func(o *s3.Options) { - o.BaseEndpoint = &endpoint - }) - } - - if tok, ok := props["token"]; ok { - cfgOpts = append(cfgOpts, config.WithBearerAuthTokenProvider( - &bearer.StaticTokenProvider{Token: bearer.Token{Value: tok}})) - } - - if region, ok := props[S3Region]; ok { - opts = append(opts, func(o *s3.Options) { - o.Region = region - }) - } else if region, ok := props["client.region"]; ok { - opts = append(opts, func(o *s3.Options) { - o.Region = region - }) - } - - accessKey, secretAccessKey := props[S3AccessKeyID], props[S3SecretAccessKey] - token := props[S3SessionToken] - if accessKey != "" || secretAccessKey != "" || token != "" { - opts = append(opts, func(o *s3.Options) { - o.Credentials = credentials.NewStaticCredentialsProvider( - props[S3AccessKeyID], props[S3SecretAccessKey], props[S3SessionToken]) - }) - } - - if proxy, ok := props[S3ProxyURI]; ok { - proxyURL, err := url.Parse(proxy) - if err != nil { - return nil, fmt.Errorf("invalid s3 proxy url '%s'", proxy) - } - - opts = append(opts, func(o *s3.Options) { - o.HTTPClient = 
awshttp.NewBuildableClient().WithTransportOptions( - func(t *http.Transport) { t.Proxy = http.ProxyURL(proxyURL) }) - }) - } - - awscfg, err := config.LoadDefaultConfig(context.Background(), cfgOpts...) - if err != nil { - return nil, err - } - - s3Client := s3.NewFromConfig(awscfg, opts...) - preprocess := func(n string) string { - _, after, found := strings.Cut(n, "://") - if found { - n = after - } - - return strings.TrimPrefix(n, parsed.Host) - } - - s3fs := s3iofs.NewWithClient(parsed.Host, s3Client) - return FSPreProcName(s3fs, preprocess), nil -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+package io
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"net/url"
+	"os"
+	"slices"
+	"strings"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
+	"github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/credentials"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+	"github.com/aws/smithy-go/auth/bearer"
+	"github.com/wolfeidau/s3iofs"
+)
+
+// Constants for S3 configuration options
+const (
+	S3Region          = "s3.region"
+	S3SessionToken    = "s3.session-token"
+	S3SecretAccessKey = "s3.secret-access-key"
+	S3AccessKeyID     = "s3.access-key-id"
+	S3EndpointURL     = "s3.endpoint"
+	S3ProxyURI        = "s3.proxy-uri"
+	S3ConnectTimeout  = "s3.connect-timeout"
+	S3SignerUri       = "s3.signer.uri"
+)
+
+// unsupportedS3Props lists properties that are recognized but rejected by
+// ParseAWSConfig.
+var unsupportedS3Props = []string{
+	S3ConnectTimeout,
+	S3SignerUri,
+}
+
+// ParseAWSConfig builds an aws.Config from props on top of the default AWS
+// configuration.
+//
+// It returns an error when props contains an unsupported property, when the
+// proxy URI cannot be parsed, or when the default configuration fails to
+// load. The S3 endpoint falls back to the AWS_S3_ENDPOINT environment
+// variable when the S3EndpointURL property is absent.
+func ParseAWSConfig(props map[string]string) (*aws.Config, error) {
+	// If any unsupported properties are set, return an error.
+	for k := range props {
+		if slices.Contains(unsupportedS3Props, k) {
+			return nil, fmt.Errorf("unsupported S3 property %q", k)
+		}
+	}
+
+	opts := []func(*config.LoadOptions) error{}
+	endpoint, ok := props[S3EndpointURL]
+	if !ok {
+		endpoint = os.Getenv("AWS_S3_ENDPOINT")
+	}
+
+	if endpoint != "" {
+		opts = append(opts, config.WithEndpointResolverWithOptions(aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
+			if service != s3.ServiceID {
+				// fallback to default resolution for the service
+				return aws.Endpoint{}, &aws.EndpointNotFoundError{}
+			}
+
+			return aws.Endpoint{
+				URL:               endpoint,
+				SigningRegion:     region,
+				HostnameImmutable: true,
+			}, nil
+		})))
+	}
+
+	if tok, ok := props["token"]; ok {
+		opts = append(opts, config.WithBearerAuthTokenProvider(
+			&bearer.StaticTokenProvider{Token: bearer.Token{Value: tok}}))
+	}
+
+	// S3Region takes precedence over the generic "client.region" property.
+	if region, ok := props[S3Region]; ok {
+		opts = append(opts, config.WithRegion(region))
+	} else if region, ok := props["client.region"]; ok {
+		opts = append(opts, config.WithRegion(region))
+	}
+
+	// Static credentials are installed as soon as any one of the three
+	// credential properties is non-empty.
+	accessKey, secretAccessKey := props[S3AccessKeyID], props[S3SecretAccessKey]
+	token := props[S3SessionToken]
+	if accessKey != "" || secretAccessKey != "" || token != "" {
+		opts = append(opts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(
+			props[S3AccessKeyID], props[S3SecretAccessKey], props[S3SessionToken])))
+	}
+
+	if proxy, ok := props[S3ProxyURI]; ok {
+		proxyURL, err := url.Parse(proxy)
+		if err != nil {
+			// Keep the parse error in the chain so callers can see why
+			// the URL was rejected.
+			return nil, fmt.Errorf("invalid s3 proxy url '%s': %w", proxy, err)
+		}
+
+		opts = append(opts, config.WithHTTPClient(awshttp.NewBuildableClient().WithTransportOptions(
+			func(t *http.Transport) {
+				t.Proxy = http.ProxyURL(proxyURL)
+			},
+		)))
+	}
+
+	awscfg, err := config.LoadDefaultConfig(context.Background(), opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	return &awscfg, nil
+}
+
+// createS3FileIO returns an IO backed by the S3 bucket named by the URL
+// host. Names passed to it have a leading "scheme://" and the bucket name
+// stripped before they reach the underlying file system.
+func createS3FileIO(parsed *url.URL, props map[string]string) (IO, error) {
+	awscfg, err := ParseAWSConfig(props)
+	if err != nil {
+		return nil, err
+	}
+
+	preprocess := func(n string) string {
+		_, after, found := strings.Cut(n, "://")
+		if found {
+			n = after
+		}
+
+		return strings.TrimPrefix(n, parsed.Host)
+	}
+
+	s3fs := s3iofs.New(parsed.Host, *awscfg)
+	return FSPreProcName(s3fs, preprocess), nil
+}
diff --git a/io/s3_cdk.go b/io/s3_cdk.go
new file mode 100644
index 0000000..b712bb7
--- /dev/null
+++ b/io/s3_cdk.go
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package io + +import ( + "context" + "net/url" + + "github.com/aws/aws-sdk-go-v2/service/s3" + "gocloud.dev/blob" + "gocloud.dev/blob/s3blob" +) + +// Construct a S3 bucket from a URL +func createS3Bucket(ctx context.Context, parsed *url.URL, props map[string]string) (*blob.Bucket, error) { + awscfg, err := ParseAWSConfig(props) + if err != nil { + return nil, err + } + + client := s3.NewFromConfig(*awscfg) + + // Create a *blob.Bucket. + bucket, err := s3blob.OpenBucketV2(ctx, client, parsed.Host, nil) + if err != nil { + return nil, err + } + + return bucket, nil +} diff --git a/literals.go b/literals.go index 2e16d02..c3ca1fc 100644 --- a/literals.go +++ b/literals.go @@ -1,1170 +1,1170 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package iceberg - -import ( - "bytes" - "cmp" - "encoding" - "encoding/binary" - "errors" - "fmt" - "math" - "math/big" - "reflect" - "strconv" - "time" - "unsafe" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/google/uuid" -) - -// LiteralType is a generic type constraint for the explicit Go types that we allow -// for literal values. This represents the actual primitive types that exist in Iceberg -type LiteralType interface { - bool | int32 | int64 | float32 | float64 | Date | - Time | Timestamp | string | []byte | uuid.UUID | Decimal -} - -// Comparator is a comparison function for specific literal types: -// -// returns 0 if v1 == v2 -// returns <0 if v1 < v2 -// returns >0 if v1 > v2 -type Comparator[T LiteralType] func(v1, v2 T) int - -// Literal is a non-null literal value. It can be casted using To and be checked for -// equality against other literals. -type Literal interface { - fmt.Stringer - encoding.BinaryMarshaler - - Type() Type - To(Type) (Literal, error) - Equals(Literal) bool -} - -// TypedLiteral is a generic interface for Literals so that you can retrieve the value. -// This is based on the physical representative type, which means that FixedLiteral and -// BinaryLiteral will both return []byte, etc. 
-type TypedLiteral[T LiteralType] interface { - Literal - - Value() T - Comparator() Comparator[T] -} - -type NumericLiteral interface { - Literal - Increment() Literal - Decrement() Literal -} - -// NewLiteral provides a literal based on the type of T -func NewLiteral[T LiteralType](val T) Literal { - switch v := any(val).(type) { - case bool: - return BoolLiteral(v) - case int32: - return Int32Literal(v) - case int64: - return Int64Literal(v) - case float32: - return Float32Literal(v) - case float64: - return Float64Literal(v) - case Date: - return DateLiteral(v) - case Time: - return TimeLiteral(v) - case Timestamp: - return TimestampLiteral(v) - case string: - return StringLiteral(v) - case []byte: - return BinaryLiteral(v) - case uuid.UUID: - return UUIDLiteral(v) - case Decimal: - return DecimalLiteral(v) - } - panic("can't happen due to literal type constraint") -} - -// LiteralFromBytes uses the defined Iceberg spec for how to serialize a value of -// a the provided type and returns the appropriate Literal value from it. -// -// If you already have a value of the desired Literal type, you could alternatively -// call UnmarshalBinary on it yourself manually. -// -// This is primarily used for retrieving stat values. 
-func LiteralFromBytes(typ Type, data []byte) (Literal, error) { - if data == nil { - return nil, ErrInvalidBinSerialization - } - - switch t := typ.(type) { - case BooleanType: - var v BoolLiteral - err := v.UnmarshalBinary(data) - return v, err - case Int32Type: - var v Int32Literal - err := v.UnmarshalBinary(data) - return v, err - case Int64Type: - var v Int64Literal - err := v.UnmarshalBinary(data) - return v, err - case Float32Type: - var v Float32Literal - err := v.UnmarshalBinary(data) - return v, err - case Float64Type: - var v Float64Literal - err := v.UnmarshalBinary(data) - return v, err - case StringType: - var v StringLiteral - err := v.UnmarshalBinary(data) - return v, err - case BinaryType: - var v BinaryLiteral - err := v.UnmarshalBinary(data) - return v, err - case FixedType: - if len(data) != t.Len() { - return nil, fmt.Errorf("%w: expected length %d for type %s, got %d", - ErrInvalidBinSerialization, t.Len(), t, len(data)) - } - var v FixedLiteral - err := v.UnmarshalBinary(data) - return v, err - case DecimalType: - v := DecimalLiteral{Scale: t.scale} - err := v.UnmarshalBinary(data) - return v, err - case DateType: - var v DateLiteral - err := v.UnmarshalBinary(data) - return v, err - case TimeType: - var v TimeLiteral - err := v.UnmarshalBinary(data) - return v, err - case TimestampType, TimestampTzType: - var v TimestampLiteral - err := v.UnmarshalBinary(data) - return v, err - case UUIDType: - var v UUIDLiteral - err := v.UnmarshalBinary(data) - return v, err - } - - return nil, ErrType -} - -// convenience to avoid repreating this pattern for primitive types -func literalEq[L interface { - comparable - LiteralType -}, T TypedLiteral[L]](lhs T, other Literal) bool { - rhs, ok := other.(T) - if !ok { - return false - } - - return lhs.Value() == rhs.Value() -} - -// AboveMaxLiteral represents values that are above the maximum for their type -// such as values > math.MaxInt32 for an Int32Literal -type AboveMaxLiteral interface { - Literal - - 
aboveMax() -} - -// BelowMinLiteral represents values that are below the minimum for their type -// such as values < math.MinInt32 for an Int32Literal -type BelowMinLiteral interface { - Literal - - belowMin() -} - -type aboveMaxLiteral[T int32 | int64 | float32 | float64] struct { - value T -} - -func (ab aboveMaxLiteral[T]) MarshalBinary() (data []byte, err error) { - return nil, fmt.Errorf("%w: cannot marshal above max literal", - ErrInvalidBinSerialization) -} - -func (ab aboveMaxLiteral[T]) aboveMax() {} - -func (ab aboveMaxLiteral[T]) Type() Type { - var z T - switch any(z).(type) { - case int32: - return PrimitiveTypes.Int32 - case int64: - return PrimitiveTypes.Int64 - case float32: - return PrimitiveTypes.Float32 - case float64: - return PrimitiveTypes.Float64 - default: - panic("should never happen") - } -} - -func (ab aboveMaxLiteral[T]) To(t Type) (Literal, error) { - if ab.Type().Equals(t) { - return ab, nil - } - return nil, fmt.Errorf("%w: cannot change type of AboveMax%sLiteral", - ErrBadCast, reflect.TypeOf(T(0)).String()) -} - -func (ab aboveMaxLiteral[T]) Value() T { return ab.value } - -func (ab aboveMaxLiteral[T]) String() string { return "AboveMax" } -func (ab aboveMaxLiteral[T]) Equals(other Literal) bool { - // AboveMaxLiteral isn't comparable and thus isn't even equal to itself - return false -} - -type belowMinLiteral[T int32 | int64 | float32 | float64] struct { - value T -} - -func (bm belowMinLiteral[T]) MarshalBinary() (data []byte, err error) { - return nil, fmt.Errorf("%w: cannot marshal above max literal", - ErrInvalidBinSerialization) -} - -func (bm belowMinLiteral[T]) belowMin() {} - -func (bm belowMinLiteral[T]) Type() Type { - var z T - switch any(z).(type) { - case int32: - return PrimitiveTypes.Int32 - case int64: - return PrimitiveTypes.Int64 - case float32: - return PrimitiveTypes.Float32 - case float64: - return PrimitiveTypes.Float64 - default: - panic("should never happen") - } -} - -func (bm belowMinLiteral[T]) To(t 
Type) (Literal, error) { - if bm.Type().Equals(t) { - return bm, nil - } - return nil, fmt.Errorf("%w: cannot change type of BelowMin%sLiteral", - ErrBadCast, reflect.TypeOf(T(0)).String()) -} - -func (bm belowMinLiteral[T]) Value() T { return bm.value } - -func (bm belowMinLiteral[T]) String() string { return "BelowMin" } -func (bm belowMinLiteral[T]) Equals(other Literal) bool { - // BelowMinLiteral isn't comparable and thus isn't even equal to itself - return false -} - -func Int32AboveMaxLiteral() Literal { - return aboveMaxLiteral[int32]{value: math.MaxInt32} -} - -func Int64AboveMaxLiteral() Literal { - return aboveMaxLiteral[int64]{value: math.MaxInt64} -} - -func Float32AboveMaxLiteral() Literal { - return aboveMaxLiteral[float32]{value: math.MaxFloat32} -} - -func Float64AboveMaxLiteral() Literal { - return aboveMaxLiteral[float64]{value: math.MaxFloat64} -} - -func Int32BelowMinLiteral() Literal { - return belowMinLiteral[int32]{value: math.MinInt32} -} - -func Int64BelowMinLiteral() Literal { - return belowMinLiteral[int64]{value: math.MinInt64} -} - -func Float32BelowMinLiteral() Literal { - return belowMinLiteral[float32]{value: -math.MaxFloat32} -} - -func Float64BelowMinLiteral() Literal { - return belowMinLiteral[float64]{value: -math.MaxFloat64} -} - -type BoolLiteral bool - -func (BoolLiteral) Comparator() Comparator[bool] { - return func(v1, v2 bool) int { - if v1 { - if v2 { - return 0 - } - return 1 - } - return -1 - } -} - -func (b BoolLiteral) Type() Type { return PrimitiveTypes.Bool } -func (b BoolLiteral) Value() bool { return bool(b) } -func (b BoolLiteral) String() string { return strconv.FormatBool(bool(b)) } -func (b BoolLiteral) To(t Type) (Literal, error) { - switch t.(type) { - case BooleanType: - return b, nil - } - return nil, fmt.Errorf("%w: BoolLiteral to %s", ErrBadCast, t) -} - -func (b BoolLiteral) Equals(l Literal) bool { - return literalEq(b, l) -} - -var ( - falseBin, trueBin = [1]byte{0x0}, [1]byte{0x1} -) - -func (b 
BoolLiteral) MarshalBinary() (data []byte, err error) { - // stored as 0x00 for false, and anything non-zero for True - if b { - return trueBin[:], nil - } - return falseBin[:], nil -} - -func (b *BoolLiteral) UnmarshalBinary(data []byte) error { - // stored as 0x00 for false and anything non-zero for True - if len(data) < 1 { - return fmt.Errorf("%w: expected at least 1 byte for bool", ErrInvalidBinSerialization) - } - *b = data[0] != 0 - return nil -} - -type Int32Literal int32 - -func (Int32Literal) Comparator() Comparator[int32] { return cmp.Compare[int32] } -func (i Int32Literal) Type() Type { return PrimitiveTypes.Int32 } -func (i Int32Literal) Value() int32 { return int32(i) } -func (i Int32Literal) String() string { return strconv.FormatInt(int64(i), 10) } -func (i Int32Literal) To(t Type) (Literal, error) { - switch t := t.(type) { - case Int32Type: - return i, nil - case Int64Type: - return Int64Literal(i), nil - case Float32Type: - return Float32Literal(i), nil - case Float64Type: - return Float64Literal(i), nil - case DateType: - return DateLiteral(i), nil - case TimeType: - return TimeLiteral(i), nil - case TimestampType: - return TimestampLiteral(i), nil - case TimestampTzType: - return TimestampLiteral(i), nil - case DecimalType: - unscaled := Decimal{Val: decimal128.FromI64(int64(i)), Scale: 0} - if t.scale == 0 { - return DecimalLiteral(unscaled), nil - } - out, err := unscaled.Val.Rescale(0, int32(t.scale)) - if err != nil { - return nil, fmt.Errorf("%w: failed to cast to DecimalType: %s", ErrBadCast, err.Error()) - } - return DecimalLiteral{Val: out, Scale: t.scale}, nil - } - - return nil, fmt.Errorf("%w: Int32Literal to %s", ErrBadCast, t) -} - -func (i Int32Literal) Equals(other Literal) bool { - return literalEq(i, other) -} - -func (i Int32Literal) Increment() Literal { - if i == math.MaxInt32 { - return Int32AboveMaxLiteral() - } - - return Int32Literal(i + 1) -} - -func (i Int32Literal) Decrement() Literal { - if i == math.MinInt32 { - 
return Int32BelowMinLiteral() - } - - return Int32Literal(i - 1) -} - -func (i Int32Literal) MarshalBinary() (data []byte, err error) { - // stored as 4 bytes in little endian order - data = make([]byte, 4) - binary.LittleEndian.PutUint32(data, uint32(i)) - return -} - -func (i *Int32Literal) UnmarshalBinary(data []byte) error { - // stored as 4 bytes little endian - if len(data) != 4 { - return fmt.Errorf("%w: expected 4 bytes for int32 value, got %d", - ErrInvalidBinSerialization, len(data)) - } - - *i = Int32Literal(binary.LittleEndian.Uint32(data)) - return nil -} - -type Int64Literal int64 - -func (Int64Literal) Comparator() Comparator[int64] { return cmp.Compare[int64] } -func (i Int64Literal) Type() Type { return PrimitiveTypes.Int64 } -func (i Int64Literal) Value() int64 { return int64(i) } -func (i Int64Literal) String() string { return strconv.FormatInt(int64(i), 10) } -func (i Int64Literal) To(t Type) (Literal, error) { - switch t := t.(type) { - case Int32Type: - if math.MaxInt32 < i { - return Int32AboveMaxLiteral(), nil - } else if math.MinInt32 > i { - return Int32BelowMinLiteral(), nil - } - return Int32Literal(i), nil - case Int64Type: - return i, nil - case Float32Type: - return Float32Literal(i), nil - case Float64Type: - return Float64Literal(i), nil - case DateType: - return DateLiteral(i), nil - case TimeType: - return TimeLiteral(i), nil - case TimestampType: - return TimestampLiteral(i), nil - case TimestampTzType: - return TimestampLiteral(i), nil - case DecimalType: - unscaled := Decimal{Val: decimal128.FromI64(int64(i)), Scale: 0} - if t.scale == 0 { - return DecimalLiteral(unscaled), nil - } - out, err := unscaled.Val.Rescale(0, int32(t.scale)) - if err != nil { - return nil, fmt.Errorf("%w: failed to cast to DecimalType: %s", ErrBadCast, err.Error()) - } - return DecimalLiteral{Val: out, Scale: t.scale}, nil - } - - return nil, fmt.Errorf("%w: Int64Literal to %s", ErrBadCast, t) -} - -func (i Int64Literal) Equals(other Literal) bool { - 
return literalEq(i, other) -} - -func (i Int64Literal) Increment() Literal { - if i == math.MaxInt64 { - return Int64AboveMaxLiteral() - } - - return Int64Literal(i + 1) -} - -func (i Int64Literal) Decrement() Literal { - if i == math.MinInt64 { - return Int64BelowMinLiteral() - } - - return Int64Literal(i - 1) -} - -func (i Int64Literal) MarshalBinary() (data []byte, err error) { - // stored as 8 byte little-endian - data = make([]byte, 8) - binary.LittleEndian.PutUint64(data, uint64(i)) - return -} - -func (i *Int64Literal) UnmarshalBinary(data []byte) error { - // stored as 8 byte little-endian - if len(data) != 8 { - return fmt.Errorf("%w: expected 8 bytes for int64 value, got %d", - ErrInvalidBinSerialization, len(data)) - } - *i = Int64Literal(binary.LittleEndian.Uint64(data)) - return nil -} - -type Float32Literal float32 - -func (Float32Literal) Comparator() Comparator[float32] { return cmp.Compare[float32] } -func (f Float32Literal) Type() Type { return PrimitiveTypes.Float32 } -func (f Float32Literal) Value() float32 { return float32(f) } -func (f Float32Literal) String() string { return strconv.FormatFloat(float64(f), 'g', -1, 32) } -func (f Float32Literal) To(t Type) (Literal, error) { - switch t := t.(type) { - case Float32Type: - return f, nil - case Float64Type: - return Float64Literal(f), nil - case DecimalType: - v, err := decimal128.FromFloat32(float32(f), int32(t.precision), int32(t.scale)) - if err != nil { - return nil, err - } - return DecimalLiteral{Val: v, Scale: t.scale}, nil - } - - return nil, fmt.Errorf("%w: Float32Literal to %s", ErrBadCast, t) -} - -func (f Float32Literal) Equals(other Literal) bool { - return literalEq(f, other) -} - -func (f Float32Literal) MarshalBinary() (data []byte, err error) { - // stored as 4 bytes little endian - data = make([]byte, 4) - binary.LittleEndian.PutUint32(data, math.Float32bits(float32(f))) - return -} - -func (f *Float32Literal) UnmarshalBinary(data []byte) error { - // stored as 4 bytes little 
endian - if len(data) != 4 { - return fmt.Errorf("%w: expected 4 bytes for float32 value, got %d", - ErrInvalidBinSerialization, len(data)) - } - *f = Float32Literal(math.Float32frombits(binary.LittleEndian.Uint32(data))) - return nil -} - -type Float64Literal float64 - -func (Float64Literal) Comparator() Comparator[float64] { return cmp.Compare[float64] } -func (f Float64Literal) Type() Type { return PrimitiveTypes.Float64 } -func (f Float64Literal) Value() float64 { return float64(f) } -func (f Float64Literal) String() string { return strconv.FormatFloat(float64(f), 'g', -1, 64) } -func (f Float64Literal) To(t Type) (Literal, error) { - switch t := t.(type) { - case Float32Type: - if math.MaxFloat32 < f { - return Float32AboveMaxLiteral(), nil - } else if -math.MaxFloat32 > f { - return Float32BelowMinLiteral(), nil - } - return Float32Literal(f), nil - case Float64Type: - return f, nil - case DecimalType: - v, err := decimal128.FromFloat64(float64(f), int32(t.precision), int32(t.scale)) - if err != nil { - return nil, err - } - return DecimalLiteral{Val: v, Scale: t.scale}, nil - } - - return nil, fmt.Errorf("%w: Float64Literal to %s", ErrBadCast, t) -} - -func (f Float64Literal) Equals(other Literal) bool { - return literalEq(f, other) -} - -func (f Float64Literal) MarshalBinary() (data []byte, err error) { - // stored as 8 bytes little endian - data = make([]byte, 8) - binary.LittleEndian.PutUint64(data, math.Float64bits(float64(f))) - return -} - -func (f *Float64Literal) UnmarshalBinary(data []byte) error { - // stored as 8 bytes in little endian - if len(data) != 8 { - return fmt.Errorf("%w: expected 8 bytes for float64 value, got %d", - ErrInvalidBinSerialization, len(data)) - } - *f = Float64Literal(math.Float64frombits(binary.LittleEndian.Uint64(data))) - return nil -} - -type DateLiteral Date - -func (DateLiteral) Comparator() Comparator[Date] { return cmp.Compare[Date] } -func (d DateLiteral) Type() Type { return PrimitiveTypes.Date } -func (d 
DateLiteral) Value() Date { return Date(d) } -func (d DateLiteral) String() string { - t := Date(d).ToTime() - return t.Format("2006-01-02") -} -func (d DateLiteral) To(t Type) (Literal, error) { - switch t.(type) { - case DateType: - return d, nil - } - return nil, fmt.Errorf("%w: DateLiteral to %s", ErrBadCast, t) -} -func (d DateLiteral) Equals(other Literal) bool { - return literalEq(d, other) -} - -func (d DateLiteral) Increment() Literal { return DateLiteral(d + 1) } -func (d DateLiteral) Decrement() Literal { return DateLiteral(d - 1) } - -func (d DateLiteral) MarshalBinary() (data []byte, err error) { - // stored as 4 byte little endian - data = make([]byte, 4) - binary.LittleEndian.PutUint32(data, uint32(d)) - return -} - -func (d *DateLiteral) UnmarshalBinary(data []byte) error { - // stored as 4 byte little endian - if len(data) != 4 { - return fmt.Errorf("%w: expected 4 bytes for date value, got %d", - ErrInvalidBinSerialization, len(data)) - } - *d = DateLiteral(binary.LittleEndian.Uint32(data)) - return nil -} - -type TimeLiteral Time - -func (TimeLiteral) Comparator() Comparator[Time] { return cmp.Compare[Time] } -func (t TimeLiteral) Type() Type { return PrimitiveTypes.Time } -func (t TimeLiteral) Value() Time { return Time(t) } -func (t TimeLiteral) String() string { - tm := time.UnixMicro(int64(t)).UTC() - return tm.Format("15:04:05.000000") -} -func (t TimeLiteral) To(typ Type) (Literal, error) { - switch typ.(type) { - case TimeType: - return t, nil - } - return nil, fmt.Errorf("%w: TimeLiteral to %s", ErrBadCast, typ) - -} -func (t TimeLiteral) Equals(other Literal) bool { - return literalEq(t, other) -} - -func (t TimeLiteral) MarshalBinary() (data []byte, err error) { - // stored as 8 byte little-endian - data = make([]byte, 8) - binary.LittleEndian.PutUint64(data, uint64(t)) - return -} - -func (t *TimeLiteral) UnmarshalBinary(data []byte) error { - // stored as 8 byte little-endian representing microseconds from midnight - if len(data) != 8 
{ - return fmt.Errorf("%w: expected 8 bytes for time value, got %d", - ErrInvalidBinSerialization, len(data)) - } - *t = TimeLiteral(binary.LittleEndian.Uint64(data)) - return nil -} - -type TimestampLiteral Timestamp - -func (TimestampLiteral) Comparator() Comparator[Timestamp] { return cmp.Compare[Timestamp] } -func (t TimestampLiteral) Type() Type { return PrimitiveTypes.Timestamp } -func (t TimestampLiteral) Value() Timestamp { return Timestamp(t) } -func (t TimestampLiteral) String() string { - tm := Timestamp(t).ToTime() - return tm.Format("2006-01-02 15:04:05.000000") -} -func (t TimestampLiteral) To(typ Type) (Literal, error) { - switch typ.(type) { - case TimestampType: - return t, nil - case TimestampTzType: - return t, nil - case DateType: - return DateLiteral(Timestamp(t).ToDate()), nil - } - return nil, fmt.Errorf("%w: TimestampLiteral to %s", ErrBadCast, typ) -} -func (t TimestampLiteral) Equals(other Literal) bool { - return literalEq(t, other) -} - -func (t TimestampLiteral) Increment() Literal { return TimestampLiteral(t + 1) } -func (t TimestampLiteral) Decrement() Literal { return TimestampLiteral(t - 1) } - -func (t TimestampLiteral) MarshalBinary() (data []byte, err error) { - // stored as 8 byte little endian - data = make([]byte, 8) - binary.LittleEndian.PutUint64(data, uint64(t)) - return -} - -func (t *TimestampLiteral) UnmarshalBinary(data []byte) error { - // stored as 8 byte little endian value representing microseconds since epoch - if len(data) != 8 { - return fmt.Errorf("%w: expected 8 bytes for timestamp value, got %d", - ErrInvalidBinSerialization, len(data)) - } - *t = TimestampLiteral(binary.LittleEndian.Uint64(data)) - return nil -} - -type StringLiteral string - -func (StringLiteral) Comparator() Comparator[string] { return cmp.Compare[string] } -func (s StringLiteral) Type() Type { return PrimitiveTypes.String } -func (s StringLiteral) Value() string { return string(s) } -func (s StringLiteral) String() string { return 
string(s) } -func (s StringLiteral) To(typ Type) (Literal, error) { - switch t := typ.(type) { - case StringType: - return s, nil - case Int32Type: - n, err := strconv.ParseInt(string(s), 10, 64) - if err != nil { - return nil, fmt.Errorf("%w: casting '%s' to %s", - errors.Join(ErrBadCast, err), s, typ) - } - - if math.MaxInt32 < n { - return Int32AboveMaxLiteral(), nil - } else if math.MinInt32 > n { - return Int32BelowMinLiteral(), nil - } - - return Int32Literal(n), nil - case Int64Type: - n, err := strconv.ParseInt(string(s), 10, 64) - if err != nil { - return nil, fmt.Errorf("%w: casting '%s' to %s", - errors.Join(ErrBadCast, err), s, typ) - } - - return Int64Literal(n), nil - case Float32Type: - n, err := strconv.ParseFloat(string(s), 32) - if err != nil { - return nil, fmt.Errorf("%w: casting '%s' to %s", - errors.Join(ErrBadCast, err), s, typ) - } - return Float32Literal(n), nil - case Float64Type: - n, err := strconv.ParseFloat(string(s), 64) - if err != nil { - return nil, fmt.Errorf("%w: casting '%s' to %s", - errors.Join(ErrBadCast, err), s, typ) - } - return Float64Literal(n), nil - case DateType: - tm, err := time.Parse("2006-01-02", string(s)) - if err != nil { - return nil, fmt.Errorf("%w: casting '%s' to %s - %s", - ErrBadCast, s, typ, err.Error()) - } - return DateLiteral(tm.Truncate(24*time.Hour).Unix() / int64((time.Hour * 24).Seconds())), nil - case TimeType: - val, err := arrow.Time64FromString(string(s), arrow.Microsecond) - if err != nil { - return nil, fmt.Errorf("%w: casting '%s' to %s - %s", - ErrBadCast, s, typ, err.Error()) - } - - return TimeLiteral(val), nil - case TimestampType: - // requires RFC3339 with no time zone - tm, err := time.Parse("2006-01-02T15:04:05", string(s)) - if err != nil { - return nil, fmt.Errorf("%w: invalid Timestamp format for casting from string '%s': %s", - ErrBadCast, s, err.Error()) - } - - return TimestampLiteral(Timestamp(tm.UTC().UnixMicro())), nil - case TimestampTzType: - // requires RFC3339 format 
WITH time zone - tm, err := time.Parse(time.RFC3339, string(s)) - if err != nil { - return nil, fmt.Errorf("%w: invalid TimestampTz format for casting from string '%s': %s", - ErrBadCast, s, err.Error()) - } - - return TimestampLiteral(Timestamp(tm.UTC().UnixMicro())), nil - case UUIDType: - val, err := uuid.Parse(string(s)) - if err != nil { - return nil, fmt.Errorf("%w: casting '%s' to %s - %s", - ErrBadCast, s, typ, err.Error()) - } - return UUIDLiteral(val), nil - case DecimalType: - n, err := decimal128.FromString(string(s), int32(t.precision), int32(t.scale)) - if err != nil { - return nil, fmt.Errorf("%w: casting '%s' to %s - %s", - ErrBadCast, s, typ, err.Error()) - } - return DecimalLiteral{Val: n, Scale: t.scale}, nil - case BooleanType: - val, err := strconv.ParseBool(string(s)) - if err != nil { - return nil, fmt.Errorf("%w: casting '%s' to %s - %s", - ErrBadCast, s, typ, err.Error()) - } - return BoolLiteral(val), nil - case BinaryType: - return BinaryLiteral(s), nil - case FixedType: - if len(s) != t.len { - return nil, fmt.Errorf("%w: cast '%s' to %s - wrong length", - ErrBadCast, s, t) - } - return FixedLiteral(s), nil - } - return nil, fmt.Errorf("%w: StringLiteral to %s", ErrBadCast, typ) -} - -func (s StringLiteral) Equals(other Literal) bool { - return literalEq(s, other) -} - -func (s StringLiteral) MarshalBinary() (data []byte, err error) { - // stored as UTF-8 bytes without length - // avoid copying by just returning a slice of the raw bytes - data = unsafe.Slice(unsafe.StringData(string(s)), len(s)) - return -} - -func (s *StringLiteral) UnmarshalBinary(data []byte) error { - // stored as UTF-8 bytes without length - // avoid copy, but this means that the passed in slice is being given - // to the literal for ownership - *s = StringLiteral(unsafe.String(unsafe.SliceData(data), len(data))) - return nil -} - -type BinaryLiteral []byte - -func (BinaryLiteral) Comparator() Comparator[[]byte] { - return bytes.Compare -} -func (b BinaryLiteral) 
Type() Type { return PrimitiveTypes.Binary } -func (b BinaryLiteral) Value() []byte { return []byte(b) } -func (b BinaryLiteral) String() string { return string(b) } -func (b BinaryLiteral) To(typ Type) (Literal, error) { - switch t := typ.(type) { - case UUIDType: - val, err := uuid.FromBytes(b) - if err != nil { - return nil, fmt.Errorf("%w: cannot convert BinaryLiteral to UUID", - errors.Join(ErrBadCast, err)) - } - return UUIDLiteral(val), nil - case FixedType: - if len(b) == t.len { - return FixedLiteral(b), nil - } - - return nil, fmt.Errorf("%w: cannot convert BinaryLiteral to %s, different length - %d <> %d", - ErrBadCast, typ, len(b), t.len) - case BinaryType: - return b, nil - } - - return nil, fmt.Errorf("%w: BinaryLiteral to %s", ErrBadCast, typ) -} -func (b BinaryLiteral) Equals(other Literal) bool { - rhs, ok := other.(BinaryLiteral) - if !ok { - return false - } - - return bytes.Equal([]byte(b), rhs) -} - -func (b BinaryLiteral) MarshalBinary() (data []byte, err error) { - // stored directly as is - data = b - return -} - -func (b *BinaryLiteral) UnmarshalBinary(data []byte) error { - // stored directly as is - *b = BinaryLiteral(data) - return nil -} - -type FixedLiteral []byte - -func (FixedLiteral) Comparator() Comparator[[]byte] { return bytes.Compare } -func (f FixedLiteral) Type() Type { return FixedTypeOf(len(f)) } -func (f FixedLiteral) Value() []byte { return []byte(f) } -func (f FixedLiteral) String() string { return string(f) } -func (f FixedLiteral) To(typ Type) (Literal, error) { - switch t := typ.(type) { - case UUIDType: - val, err := uuid.FromBytes(f) - if err != nil { - return nil, fmt.Errorf("%w: cannot convert FixedLiteral to UUID - %s", - ErrBadCast, err.Error()) - } - return UUIDLiteral(val), nil - case FixedType: - if len(f) == t.len { - return FixedLiteral(f), nil - } - - return nil, fmt.Errorf("%w: cannot convert FixedLiteral to %s, different length - %d <> %d", - ErrBadCast, typ, len(f), t.len) - case BinaryType: - return f, 
nil - } - - return nil, fmt.Errorf("%w: FixedLiteral[%d] to %s", - ErrBadCast, len(f), typ) -} -func (f FixedLiteral) Equals(other Literal) bool { - rhs, ok := other.(FixedLiteral) - if !ok { - return false - } - - return bytes.Equal([]byte(f), rhs) -} - -func (f FixedLiteral) MarshalBinary() (data []byte, err error) { - // stored directly as is - data = f - return -} - -func (f *FixedLiteral) UnmarshalBinary(data []byte) error { - // stored directly as is - *f = FixedLiteral(data) - return nil -} - -type UUIDLiteral uuid.UUID - -func (UUIDLiteral) Comparator() Comparator[uuid.UUID] { - return func(v1, v2 uuid.UUID) int { - return bytes.Compare(v1[:], v2[:]) - } -} - -func (UUIDLiteral) Type() Type { return PrimitiveTypes.UUID } -func (u UUIDLiteral) Value() uuid.UUID { return uuid.UUID(u) } -func (u UUIDLiteral) String() string { return uuid.UUID(u).String() } -func (u UUIDLiteral) To(typ Type) (Literal, error) { - switch t := typ.(type) { - case UUIDType: - return u, nil - case FixedType: - if len(u) == t.len { - v, _ := uuid.UUID(u).MarshalBinary() - return FixedLiteral(v), nil - } - - return nil, fmt.Errorf("%w: cannot convert UUIDLiteral to %s, different length - %d <> %d", - ErrBadCast, typ, len(u), t.len) - case BinaryType: - v, _ := uuid.UUID(u).MarshalBinary() - return BinaryLiteral(v), nil - } - - return nil, fmt.Errorf("%w: UUIDLiteral to %s", ErrBadCast, typ) -} -func (u UUIDLiteral) Equals(other Literal) bool { - rhs, ok := other.(UUIDLiteral) - if !ok { - return false - } - - return uuid.UUID(u) == uuid.UUID(rhs) -} - -func (u UUIDLiteral) MarshalBinary() (data []byte, err error) { - return uuid.UUID(u).MarshalBinary() -} - -func (u *UUIDLiteral) UnmarshalBinary(data []byte) error { - // stored as 16-byte big-endian value - out, err := uuid.FromBytes(data) - if err != nil { - return err - } - *u = UUIDLiteral(out) - return nil -} - -type DecimalLiteral Decimal - -func (DecimalLiteral) Comparator() Comparator[Decimal] { - return func(v1, v2 Decimal) 
int { - if v1.Scale == v2.Scale { - return v1.Val.Cmp(v2.Val) - } - - rescaled, err := v2.Val.Rescale(int32(v2.Scale), int32(v1.Scale)) - if err != nil { - return -1 - } - - return v1.Val.Cmp(rescaled) - } -} -func (d DecimalLiteral) Type() Type { return DecimalTypeOf(9, d.Scale) } -func (d DecimalLiteral) Value() Decimal { return Decimal(d) } -func (d DecimalLiteral) String() string { - return d.Val.ToString(int32(d.Scale)) -} - -func (d DecimalLiteral) To(t Type) (Literal, error) { - switch t := t.(type) { - case DecimalType: - if d.Scale == t.scale { - return d, nil - } - - return nil, fmt.Errorf("%w: could not convert %v to %s", - ErrBadCast, d, t) - case Int32Type: - v := d.Val.BigInt().Int64() - if v > math.MaxInt32 { - return Int32AboveMaxLiteral(), nil - } else if v < math.MinInt32 { - return Int32BelowMinLiteral(), nil - } - - return Int32Literal(int32(v)), nil - case Int64Type: - v := d.Val.BigInt() - if !v.IsInt64() { - if v.Sign() > 0 { - return Int64AboveMaxLiteral(), nil - } else if v.Sign() < 0 { - return Int64BelowMinLiteral(), nil - } - } - - return Int64Literal(v.Int64()), nil - case Float32Type: - v := d.Val.ToFloat64(int32(d.Scale)) - if v > math.MaxFloat32 { - return Float32AboveMaxLiteral(), nil - } else if v < -math.MaxFloat32 { - return Float32BelowMinLiteral(), nil - } - - return Float32Literal(float32(v)), nil - case Float64Type: - return Float64Literal(d.Val.ToFloat64(int32(d.Scale))), nil - } - - return nil, fmt.Errorf("%w: DecimalLiteral to %s", ErrBadCast, t) -} - -func (d DecimalLiteral) Equals(other Literal) bool { - rhs, ok := other.(DecimalLiteral) - if !ok { - return false - } - - rescaled, err := rhs.Val.Rescale(int32(rhs.Scale), int32(d.Scale)) - if err != nil { - return false - } - return d.Val == rescaled -} - -func (d DecimalLiteral) Increment() Literal { - d.Val = d.Val.Add(decimal128.FromU64(1)) - return d -} - -func (d DecimalLiteral) Decrement() Literal { - d.Val = d.Val.Sub(decimal128.FromU64(1)) - return d -} - -func (d 
DecimalLiteral) MarshalBinary() (data []byte, err error) { - // stored as unscaled value in two's compliment big-endian values - // using the minimum number of bytes for the values - n := decimal128.Num(d.Val).BigInt() - // bytes gives absolute value as big-endian bytes - data = n.Bytes() - if n.Sign() < 0 { - // convert to 2's complement for negative value - for i, v := range data { - data[i] = ^v - } - data[len(data)-1] += 1 - } - return -} - -func (d *DecimalLiteral) UnmarshalBinary(data []byte) error { - // stored as unscaled value in two's complement - // big-endian values using the minimum number of bytes - if len(data) == 0 { - d.Val = decimal128.Num{} - return nil - } - - if int8(data[0]) >= 0 { - // not negative - d.Val = decimal128.FromBigInt((&big.Int{}).SetBytes(data)) - return nil - } - - // convert two's complement and remember it's negative - out := make([]byte, len(data)) - for i, b := range data { - out[i] = ^b - } - out[len(out)-1] += 1 - - value := (&big.Int{}).SetBytes(out) - d.Val = decimal128.FromBigInt(value.Neg(value)) - return nil -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package iceberg + +import ( + "bytes" + "cmp" + "encoding" + "encoding/binary" + "errors" + "fmt" + "math" + "math/big" + "reflect" + "strconv" + "time" + "unsafe" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/google/uuid" +) + +// LiteralType is a generic type constraint for the explicit Go types that we allow +// for literal values. This represents the actual primitive types that exist in Iceberg +type LiteralType interface { + bool | int32 | int64 | float32 | float64 | Date | + Time | Timestamp | string | []byte | uuid.UUID | Decimal +} + +// Comparator is a comparison function for specific literal types: +// +// returns 0 if v1 == v2 +// returns <0 if v1 < v2 +// returns >0 if v1 > v2 +type Comparator[T LiteralType] func(v1, v2 T) int + +// Literal is a non-null literal value. It can be casted using To and be checked for +// equality against other literals. +type Literal interface { + fmt.Stringer + encoding.BinaryMarshaler + + Type() Type + To(Type) (Literal, error) + Equals(Literal) bool +} + +// TypedLiteral is a generic interface for Literals so that you can retrieve the value. +// This is based on the physical representative type, which means that FixedLiteral and +// BinaryLiteral will both return []byte, etc. 
+type TypedLiteral[T LiteralType] interface { + Literal + + Value() T + Comparator() Comparator[T] +} + +type NumericLiteral interface { + Literal + Increment() Literal + Decrement() Literal +} + +// NewLiteral provides a literal based on the type of T +func NewLiteral[T LiteralType](val T) Literal { + switch v := any(val).(type) { + case bool: + return BoolLiteral(v) + case int32: + return Int32Literal(v) + case int64: + return Int64Literal(v) + case float32: + return Float32Literal(v) + case float64: + return Float64Literal(v) + case Date: + return DateLiteral(v) + case Time: + return TimeLiteral(v) + case Timestamp: + return TimestampLiteral(v) + case string: + return StringLiteral(v) + case []byte: + return BinaryLiteral(v) + case uuid.UUID: + return UUIDLiteral(v) + case Decimal: + return DecimalLiteral(v) + } + panic("can't happen due to literal type constraint") +} + +// LiteralFromBytes uses the defined Iceberg spec for how to serialize a value of +// a the provided type and returns the appropriate Literal value from it. +// +// If you already have a value of the desired Literal type, you could alternatively +// call UnmarshalBinary on it yourself manually. +// +// This is primarily used for retrieving stat values. 
+func LiteralFromBytes(typ Type, data []byte) (Literal, error) { + if data == nil { + return nil, ErrInvalidBinSerialization + } + + switch t := typ.(type) { + case BooleanType: + var v BoolLiteral + err := v.UnmarshalBinary(data) + return v, err + case Int32Type: + var v Int32Literal + err := v.UnmarshalBinary(data) + return v, err + case Int64Type: + var v Int64Literal + err := v.UnmarshalBinary(data) + return v, err + case Float32Type: + var v Float32Literal + err := v.UnmarshalBinary(data) + return v, err + case Float64Type: + var v Float64Literal + err := v.UnmarshalBinary(data) + return v, err + case StringType: + var v StringLiteral + err := v.UnmarshalBinary(data) + return v, err + case BinaryType: + var v BinaryLiteral + err := v.UnmarshalBinary(data) + return v, err + case FixedType: + if len(data) != t.Len() { + return nil, fmt.Errorf("%w: expected length %d for type %s, got %d", + ErrInvalidBinSerialization, t.Len(), t, len(data)) + } + var v FixedLiteral + err := v.UnmarshalBinary(data) + return v, err + case DecimalType: + v := DecimalLiteral{Scale: t.scale} + err := v.UnmarshalBinary(data) + return v, err + case DateType: + var v DateLiteral + err := v.UnmarshalBinary(data) + return v, err + case TimeType: + var v TimeLiteral + err := v.UnmarshalBinary(data) + return v, err + case TimestampType, TimestampTzType: + var v TimestampLiteral + err := v.UnmarshalBinary(data) + return v, err + case UUIDType: + var v UUIDLiteral + err := v.UnmarshalBinary(data) + return v, err + } + + return nil, ErrType +} + +// convenience to avoid repreating this pattern for primitive types +func literalEq[L interface { + comparable + LiteralType +}, T TypedLiteral[L]](lhs T, other Literal) bool { + rhs, ok := other.(T) + if !ok { + return false + } + + return lhs.Value() == rhs.Value() +} + +// AboveMaxLiteral represents values that are above the maximum for their type +// such as values > math.MaxInt32 for an Int32Literal +type AboveMaxLiteral interface { + Literal + + 
aboveMax() +} + +// BelowMinLiteral represents values that are below the minimum for their type +// such as values < math.MinInt32 for an Int32Literal +type BelowMinLiteral interface { + Literal + + belowMin() +} + +type aboveMaxLiteral[T int32 | int64 | float32 | float64] struct { + value T +} + +func (ab aboveMaxLiteral[T]) MarshalBinary() (data []byte, err error) { + return nil, fmt.Errorf("%w: cannot marshal above max literal", + ErrInvalidBinSerialization) +} + +func (ab aboveMaxLiteral[T]) aboveMax() {} + +func (ab aboveMaxLiteral[T]) Type() Type { + var z T + switch any(z).(type) { + case int32: + return PrimitiveTypes.Int32 + case int64: + return PrimitiveTypes.Int64 + case float32: + return PrimitiveTypes.Float32 + case float64: + return PrimitiveTypes.Float64 + default: + panic("should never happen") + } +} + +func (ab aboveMaxLiteral[T]) To(t Type) (Literal, error) { + if ab.Type().Equals(t) { + return ab, nil + } + return nil, fmt.Errorf("%w: cannot change type of AboveMax%sLiteral", + ErrBadCast, reflect.TypeOf(T(0)).String()) +} + +func (ab aboveMaxLiteral[T]) Value() T { return ab.value } + +func (ab aboveMaxLiteral[T]) String() string { return "AboveMax" } +func (ab aboveMaxLiteral[T]) Equals(other Literal) bool { + // AboveMaxLiteral isn't comparable and thus isn't even equal to itself + return false +} + +type belowMinLiteral[T int32 | int64 | float32 | float64] struct { + value T +} + +func (bm belowMinLiteral[T]) MarshalBinary() (data []byte, err error) { + return nil, fmt.Errorf("%w: cannot marshal above max literal", + ErrInvalidBinSerialization) +} + +func (bm belowMinLiteral[T]) belowMin() {} + +func (bm belowMinLiteral[T]) Type() Type { + var z T + switch any(z).(type) { + case int32: + return PrimitiveTypes.Int32 + case int64: + return PrimitiveTypes.Int64 + case float32: + return PrimitiveTypes.Float32 + case float64: + return PrimitiveTypes.Float64 + default: + panic("should never happen") + } +} + +func (bm belowMinLiteral[T]) To(t 
Type) (Literal, error) { + if bm.Type().Equals(t) { + return bm, nil + } + return nil, fmt.Errorf("%w: cannot change type of BelowMin%sLiteral", + ErrBadCast, reflect.TypeOf(T(0)).String()) +} + +func (bm belowMinLiteral[T]) Value() T { return bm.value } + +func (bm belowMinLiteral[T]) String() string { return "BelowMin" } +func (bm belowMinLiteral[T]) Equals(other Literal) bool { + // BelowMinLiteral isn't comparable and thus isn't even equal to itself + return false +} + +func Int32AboveMaxLiteral() Literal { + return aboveMaxLiteral[int32]{value: math.MaxInt32} +} + +func Int64AboveMaxLiteral() Literal { + return aboveMaxLiteral[int64]{value: math.MaxInt64} +} + +func Float32AboveMaxLiteral() Literal { + return aboveMaxLiteral[float32]{value: math.MaxFloat32} +} + +func Float64AboveMaxLiteral() Literal { + return aboveMaxLiteral[float64]{value: math.MaxFloat64} +} + +func Int32BelowMinLiteral() Literal { + return belowMinLiteral[int32]{value: math.MinInt32} +} + +func Int64BelowMinLiteral() Literal { + return belowMinLiteral[int64]{value: math.MinInt64} +} + +func Float32BelowMinLiteral() Literal { + return belowMinLiteral[float32]{value: -math.MaxFloat32} +} + +func Float64BelowMinLiteral() Literal { + return belowMinLiteral[float64]{value: -math.MaxFloat64} +} + +type BoolLiteral bool + +func (BoolLiteral) Comparator() Comparator[bool] { + return func(v1, v2 bool) int { + if v1 { + if v2 { + return 0 + } + return 1 + } + return -1 + } +} + +func (b BoolLiteral) Type() Type { return PrimitiveTypes.Bool } +func (b BoolLiteral) Value() bool { return bool(b) } +func (b BoolLiteral) String() string { return strconv.FormatBool(bool(b)) } +func (b BoolLiteral) To(t Type) (Literal, error) { + switch t.(type) { + case BooleanType: + return b, nil + } + return nil, fmt.Errorf("%w: BoolLiteral to %s", ErrBadCast, t) +} + +func (b BoolLiteral) Equals(l Literal) bool { + return literalEq(b, l) +} + +var ( + falseBin, trueBin = [1]byte{0x0}, [1]byte{0x1} +) + +func (b 
BoolLiteral) MarshalBinary() (data []byte, err error) { + // stored as 0x00 for false, and anything non-zero for True + if b { + return trueBin[:], nil + } + return falseBin[:], nil +} + +func (b *BoolLiteral) UnmarshalBinary(data []byte) error { + // stored as 0x00 for false and anything non-zero for True + if len(data) < 1 { + return fmt.Errorf("%w: expected at least 1 byte for bool", ErrInvalidBinSerialization) + } + *b = data[0] != 0 + return nil +} + +type Int32Literal int32 + +func (Int32Literal) Comparator() Comparator[int32] { return cmp.Compare[int32] } +func (i Int32Literal) Type() Type { return PrimitiveTypes.Int32 } +func (i Int32Literal) Value() int32 { return int32(i) } +func (i Int32Literal) String() string { return strconv.FormatInt(int64(i), 10) } +func (i Int32Literal) To(t Type) (Literal, error) { + switch t := t.(type) { + case Int32Type: + return i, nil + case Int64Type: + return Int64Literal(i), nil + case Float32Type: + return Float32Literal(i), nil + case Float64Type: + return Float64Literal(i), nil + case DateType: + return DateLiteral(i), nil + case TimeType: + return TimeLiteral(i), nil + case TimestampType: + return TimestampLiteral(i), nil + case TimestampTzType: + return TimestampLiteral(i), nil + case DecimalType: + unscaled := Decimal{Val: decimal128.FromI64(int64(i)), Scale: 0} + if t.scale == 0 { + return DecimalLiteral(unscaled), nil + } + out, err := unscaled.Val.Rescale(0, int32(t.scale)) + if err != nil { + return nil, fmt.Errorf("%w: failed to cast to DecimalType: %s", ErrBadCast, err.Error()) + } + return DecimalLiteral{Val: out, Scale: t.scale}, nil + } + + return nil, fmt.Errorf("%w: Int32Literal to %s", ErrBadCast, t) +} + +func (i Int32Literal) Equals(other Literal) bool { + return literalEq(i, other) +} + +func (i Int32Literal) Increment() Literal { + if i == math.MaxInt32 { + return Int32AboveMaxLiteral() + } + + return Int32Literal(i + 1) +} + +func (i Int32Literal) Decrement() Literal { + if i == math.MinInt32 { + 
return Int32BelowMinLiteral() + } + + return Int32Literal(i - 1) +} + +func (i Int32Literal) MarshalBinary() (data []byte, err error) { + // stored as 4 bytes in little endian order + data = make([]byte, 4) + binary.LittleEndian.PutUint32(data, uint32(i)) + return +} + +func (i *Int32Literal) UnmarshalBinary(data []byte) error { + // stored as 4 bytes little endian + if len(data) != 4 { + return fmt.Errorf("%w: expected 4 bytes for int32 value, got %d", + ErrInvalidBinSerialization, len(data)) + } + + *i = Int32Literal(binary.LittleEndian.Uint32(data)) + return nil +} + +type Int64Literal int64 + +func (Int64Literal) Comparator() Comparator[int64] { return cmp.Compare[int64] } +func (i Int64Literal) Type() Type { return PrimitiveTypes.Int64 } +func (i Int64Literal) Value() int64 { return int64(i) } +func (i Int64Literal) String() string { return strconv.FormatInt(int64(i), 10) } +func (i Int64Literal) To(t Type) (Literal, error) { + switch t := t.(type) { + case Int32Type: + if math.MaxInt32 < i { + return Int32AboveMaxLiteral(), nil + } else if math.MinInt32 > i { + return Int32BelowMinLiteral(), nil + } + return Int32Literal(i), nil + case Int64Type: + return i, nil + case Float32Type: + return Float32Literal(i), nil + case Float64Type: + return Float64Literal(i), nil + case DateType: + return DateLiteral(i), nil + case TimeType: + return TimeLiteral(i), nil + case TimestampType: + return TimestampLiteral(i), nil + case TimestampTzType: + return TimestampLiteral(i), nil + case DecimalType: + unscaled := Decimal{Val: decimal128.FromI64(int64(i)), Scale: 0} + if t.scale == 0 { + return DecimalLiteral(unscaled), nil + } + out, err := unscaled.Val.Rescale(0, int32(t.scale)) + if err != nil { + return nil, fmt.Errorf("%w: failed to cast to DecimalType: %s", ErrBadCast, err.Error()) + } + return DecimalLiteral{Val: out, Scale: t.scale}, nil + } + + return nil, fmt.Errorf("%w: Int64Literal to %s", ErrBadCast, t) +} + +func (i Int64Literal) Equals(other Literal) bool { + 
return literalEq(i, other) +} + +func (i Int64Literal) Increment() Literal { + if i == math.MaxInt64 { + return Int64AboveMaxLiteral() + } + + return Int64Literal(i + 1) +} + +func (i Int64Literal) Decrement() Literal { + if i == math.MinInt64 { + return Int64BelowMinLiteral() + } + + return Int64Literal(i - 1) +} + +func (i Int64Literal) MarshalBinary() (data []byte, err error) { + // stored as 8 byte little-endian + data = make([]byte, 8) + binary.LittleEndian.PutUint64(data, uint64(i)) + return +} + +func (i *Int64Literal) UnmarshalBinary(data []byte) error { + // stored as 8 byte little-endian + if len(data) != 8 { + return fmt.Errorf("%w: expected 8 bytes for int64 value, got %d", + ErrInvalidBinSerialization, len(data)) + } + *i = Int64Literal(binary.LittleEndian.Uint64(data)) + return nil +} + +type Float32Literal float32 + +func (Float32Literal) Comparator() Comparator[float32] { return cmp.Compare[float32] } +func (f Float32Literal) Type() Type { return PrimitiveTypes.Float32 } +func (f Float32Literal) Value() float32 { return float32(f) } +func (f Float32Literal) String() string { return strconv.FormatFloat(float64(f), 'g', -1, 32) } +func (f Float32Literal) To(t Type) (Literal, error) { + switch t := t.(type) { + case Float32Type: + return f, nil + case Float64Type: + return Float64Literal(f), nil + case DecimalType: + v, err := decimal128.FromFloat32(float32(f), int32(t.precision), int32(t.scale)) + if err != nil { + return nil, err + } + return DecimalLiteral{Val: v, Scale: t.scale}, nil + } + + return nil, fmt.Errorf("%w: Float32Literal to %s", ErrBadCast, t) +} + +func (f Float32Literal) Equals(other Literal) bool { + return literalEq(f, other) +} + +func (f Float32Literal) MarshalBinary() (data []byte, err error) { + // stored as 4 bytes little endian + data = make([]byte, 4) + binary.LittleEndian.PutUint32(data, math.Float32bits(float32(f))) + return +} + +func (f *Float32Literal) UnmarshalBinary(data []byte) error { + // stored as 4 bytes little 
endian + if len(data) != 4 { + return fmt.Errorf("%w: expected 4 bytes for float32 value, got %d", + ErrInvalidBinSerialization, len(data)) + } + *f = Float32Literal(math.Float32frombits(binary.LittleEndian.Uint32(data))) + return nil +} + +type Float64Literal float64 + +func (Float64Literal) Comparator() Comparator[float64] { return cmp.Compare[float64] } +func (f Float64Literal) Type() Type { return PrimitiveTypes.Float64 } +func (f Float64Literal) Value() float64 { return float64(f) } +func (f Float64Literal) String() string { return strconv.FormatFloat(float64(f), 'g', -1, 64) } +func (f Float64Literal) To(t Type) (Literal, error) { + switch t := t.(type) { + case Float32Type: + if math.MaxFloat32 < f { + return Float32AboveMaxLiteral(), nil + } else if -math.MaxFloat32 > f { + return Float32BelowMinLiteral(), nil + } + return Float32Literal(f), nil + case Float64Type: + return f, nil + case DecimalType: + v, err := decimal128.FromFloat64(float64(f), int32(t.precision), int32(t.scale)) + if err != nil { + return nil, err + } + return DecimalLiteral{Val: v, Scale: t.scale}, nil + } + + return nil, fmt.Errorf("%w: Float64Literal to %s", ErrBadCast, t) +} + +func (f Float64Literal) Equals(other Literal) bool { + return literalEq(f, other) +} + +func (f Float64Literal) MarshalBinary() (data []byte, err error) { + // stored as 8 bytes little endian + data = make([]byte, 8) + binary.LittleEndian.PutUint64(data, math.Float64bits(float64(f))) + return +} + +func (f *Float64Literal) UnmarshalBinary(data []byte) error { + // stored as 8 bytes in little endian + if len(data) != 8 { + return fmt.Errorf("%w: expected 8 bytes for float64 value, got %d", + ErrInvalidBinSerialization, len(data)) + } + *f = Float64Literal(math.Float64frombits(binary.LittleEndian.Uint64(data))) + return nil +} + +type DateLiteral Date + +func (DateLiteral) Comparator() Comparator[Date] { return cmp.Compare[Date] } +func (d DateLiteral) Type() Type { return PrimitiveTypes.Date } +func (d 
DateLiteral) Value() Date { return Date(d) } +func (d DateLiteral) String() string { + t := Date(d).ToTime() + return t.Format("2006-01-02") +} +func (d DateLiteral) To(t Type) (Literal, error) { + switch t.(type) { + case DateType: + return d, nil + } + return nil, fmt.Errorf("%w: DateLiteral to %s", ErrBadCast, t) +} +func (d DateLiteral) Equals(other Literal) bool { + return literalEq(d, other) +} + +func (d DateLiteral) Increment() Literal { return DateLiteral(d + 1) } +func (d DateLiteral) Decrement() Literal { return DateLiteral(d - 1) } + +func (d DateLiteral) MarshalBinary() (data []byte, err error) { + // stored as 4 byte little endian + data = make([]byte, 4) + binary.LittleEndian.PutUint32(data, uint32(d)) + return +} + +func (d *DateLiteral) UnmarshalBinary(data []byte) error { + // stored as 4 byte little endian + if len(data) != 4 { + return fmt.Errorf("%w: expected 4 bytes for date value, got %d", + ErrInvalidBinSerialization, len(data)) + } + *d = DateLiteral(binary.LittleEndian.Uint32(data)) + return nil +} + +type TimeLiteral Time + +func (TimeLiteral) Comparator() Comparator[Time] { return cmp.Compare[Time] } +func (t TimeLiteral) Type() Type { return PrimitiveTypes.Time } +func (t TimeLiteral) Value() Time { return Time(t) } +func (t TimeLiteral) String() string { + tm := time.UnixMicro(int64(t)).UTC() + return tm.Format("15:04:05.000000") +} +func (t TimeLiteral) To(typ Type) (Literal, error) { + switch typ.(type) { + case TimeType: + return t, nil + } + return nil, fmt.Errorf("%w: TimeLiteral to %s", ErrBadCast, typ) + +} +func (t TimeLiteral) Equals(other Literal) bool { + return literalEq(t, other) +} + +func (t TimeLiteral) MarshalBinary() (data []byte, err error) { + // stored as 8 byte little-endian + data = make([]byte, 8) + binary.LittleEndian.PutUint64(data, uint64(t)) + return +} + +func (t *TimeLiteral) UnmarshalBinary(data []byte) error { + // stored as 8 byte little-endian representing microseconds from midnight + if len(data) != 8 
{ + return fmt.Errorf("%w: expected 8 bytes for time value, got %d", + ErrInvalidBinSerialization, len(data)) + } + *t = TimeLiteral(binary.LittleEndian.Uint64(data)) + return nil +} + +type TimestampLiteral Timestamp + +func (TimestampLiteral) Comparator() Comparator[Timestamp] { return cmp.Compare[Timestamp] } +func (t TimestampLiteral) Type() Type { return PrimitiveTypes.Timestamp } +func (t TimestampLiteral) Value() Timestamp { return Timestamp(t) } +func (t TimestampLiteral) String() string { + tm := Timestamp(t).ToTime() + return tm.Format("2006-01-02 15:04:05.000000") +} +func (t TimestampLiteral) To(typ Type) (Literal, error) { + switch typ.(type) { + case TimestampType: + return t, nil + case TimestampTzType: + return t, nil + case DateType: + return DateLiteral(Timestamp(t).ToDate()), nil + } + return nil, fmt.Errorf("%w: TimestampLiteral to %s", ErrBadCast, typ) +} +func (t TimestampLiteral) Equals(other Literal) bool { + return literalEq(t, other) +} + +func (t TimestampLiteral) Increment() Literal { return TimestampLiteral(t + 1) } +func (t TimestampLiteral) Decrement() Literal { return TimestampLiteral(t - 1) } + +func (t TimestampLiteral) MarshalBinary() (data []byte, err error) { + // stored as 8 byte little endian + data = make([]byte, 8) + binary.LittleEndian.PutUint64(data, uint64(t)) + return +} + +func (t *TimestampLiteral) UnmarshalBinary(data []byte) error { + // stored as 8 byte little endian value representing microseconds since epoch + if len(data) != 8 { + return fmt.Errorf("%w: expected 8 bytes for timestamp value, got %d", + ErrInvalidBinSerialization, len(data)) + } + *t = TimestampLiteral(binary.LittleEndian.Uint64(data)) + return nil +} + +type StringLiteral string + +func (StringLiteral) Comparator() Comparator[string] { return cmp.Compare[string] } +func (s StringLiteral) Type() Type { return PrimitiveTypes.String } +func (s StringLiteral) Value() string { return string(s) } +func (s StringLiteral) String() string { return 
string(s) } +func (s StringLiteral) To(typ Type) (Literal, error) { + switch t := typ.(type) { + case StringType: + return s, nil + case Int32Type: + n, err := strconv.ParseInt(string(s), 10, 64) + if err != nil { + return nil, fmt.Errorf("%w: casting '%s' to %s", + errors.Join(ErrBadCast, err), s, typ) + } + + if math.MaxInt32 < n { + return Int32AboveMaxLiteral(), nil + } else if math.MinInt32 > n { + return Int32BelowMinLiteral(), nil + } + + return Int32Literal(n), nil + case Int64Type: + n, err := strconv.ParseInt(string(s), 10, 64) + if err != nil { + return nil, fmt.Errorf("%w: casting '%s' to %s", + errors.Join(ErrBadCast, err), s, typ) + } + + return Int64Literal(n), nil + case Float32Type: + n, err := strconv.ParseFloat(string(s), 32) + if err != nil { + return nil, fmt.Errorf("%w: casting '%s' to %s", + errors.Join(ErrBadCast, err), s, typ) + } + return Float32Literal(n), nil + case Float64Type: + n, err := strconv.ParseFloat(string(s), 64) + if err != nil { + return nil, fmt.Errorf("%w: casting '%s' to %s", + errors.Join(ErrBadCast, err), s, typ) + } + return Float64Literal(n), nil + case DateType: + tm, err := time.Parse("2006-01-02", string(s)) + if err != nil { + return nil, fmt.Errorf("%w: casting '%s' to %s - %s", + ErrBadCast, s, typ, err.Error()) + } + return DateLiteral(tm.Truncate(24*time.Hour).Unix() / int64((time.Hour * 24).Seconds())), nil + case TimeType: + val, err := arrow.Time64FromString(string(s), arrow.Microsecond) + if err != nil { + return nil, fmt.Errorf("%w: casting '%s' to %s - %s", + ErrBadCast, s, typ, err.Error()) + } + + return TimeLiteral(val), nil + case TimestampType: + // requires RFC3339 with no time zone + tm, err := time.Parse("2006-01-02T15:04:05", string(s)) + if err != nil { + return nil, fmt.Errorf("%w: invalid Timestamp format for casting from string '%s': %s", + ErrBadCast, s, err.Error()) + } + + return TimestampLiteral(Timestamp(tm.UTC().UnixMicro())), nil + case TimestampTzType: + // requires RFC3339 format 
WITH time zone + tm, err := time.Parse(time.RFC3339, string(s)) + if err != nil { + return nil, fmt.Errorf("%w: invalid TimestampTz format for casting from string '%s': %s", + ErrBadCast, s, err.Error()) + } + + return TimestampLiteral(Timestamp(tm.UTC().UnixMicro())), nil + case UUIDType: + val, err := uuid.Parse(string(s)) + if err != nil { + return nil, fmt.Errorf("%w: casting '%s' to %s - %s", + ErrBadCast, s, typ, err.Error()) + } + return UUIDLiteral(val), nil + case DecimalType: + n, err := decimal128.FromString(string(s), int32(t.precision), int32(t.scale)) + if err != nil { + return nil, fmt.Errorf("%w: casting '%s' to %s - %s", + ErrBadCast, s, typ, err.Error()) + } + return DecimalLiteral{Val: n, Scale: t.scale}, nil + case BooleanType: + val, err := strconv.ParseBool(string(s)) + if err != nil { + return nil, fmt.Errorf("%w: casting '%s' to %s - %s", + ErrBadCast, s, typ, err.Error()) + } + return BoolLiteral(val), nil + case BinaryType: + return BinaryLiteral(s), nil + case FixedType: + if len(s) != t.len { + return nil, fmt.Errorf("%w: cast '%s' to %s - wrong length", + ErrBadCast, s, t) + } + return FixedLiteral(s), nil + } + return nil, fmt.Errorf("%w: StringLiteral to %s", ErrBadCast, typ) +} + +func (s StringLiteral) Equals(other Literal) bool { + return literalEq(s, other) +} + +func (s StringLiteral) MarshalBinary() (data []byte, err error) { + // stored as UTF-8 bytes without length + // avoid copying by just returning a slice of the raw bytes + data = unsafe.Slice(unsafe.StringData(string(s)), len(s)) + return +} + +func (s *StringLiteral) UnmarshalBinary(data []byte) error { + // stored as UTF-8 bytes without length + // avoid copy, but this means that the passed in slice is being given + // to the literal for ownership + *s = StringLiteral(unsafe.String(unsafe.SliceData(data), len(data))) + return nil +} + +type BinaryLiteral []byte + +func (BinaryLiteral) Comparator() Comparator[[]byte] { + return bytes.Compare +} +func (b BinaryLiteral) 
Type() Type { return PrimitiveTypes.Binary } +func (b BinaryLiteral) Value() []byte { return []byte(b) } +func (b BinaryLiteral) String() string { return string(b) } +func (b BinaryLiteral) To(typ Type) (Literal, error) { + switch t := typ.(type) { + case UUIDType: + val, err := uuid.FromBytes(b) + if err != nil { + return nil, fmt.Errorf("%w: cannot convert BinaryLiteral to UUID", + errors.Join(ErrBadCast, err)) + } + return UUIDLiteral(val), nil + case FixedType: + if len(b) == t.len { + return FixedLiteral(b), nil + } + + return nil, fmt.Errorf("%w: cannot convert BinaryLiteral to %s, different length - %d <> %d", + ErrBadCast, typ, len(b), t.len) + case BinaryType: + return b, nil + } + + return nil, fmt.Errorf("%w: BinaryLiteral to %s", ErrBadCast, typ) +} +func (b BinaryLiteral) Equals(other Literal) bool { + rhs, ok := other.(BinaryLiteral) + if !ok { + return false + } + + return bytes.Equal([]byte(b), rhs) +} + +func (b BinaryLiteral) MarshalBinary() (data []byte, err error) { + // stored directly as is + data = b + return +} + +func (b *BinaryLiteral) UnmarshalBinary(data []byte) error { + // stored directly as is + *b = BinaryLiteral(data) + return nil +} + +type FixedLiteral []byte + +func (FixedLiteral) Comparator() Comparator[[]byte] { return bytes.Compare } +func (f FixedLiteral) Type() Type { return FixedTypeOf(len(f)) } +func (f FixedLiteral) Value() []byte { return []byte(f) } +func (f FixedLiteral) String() string { return string(f) } +func (f FixedLiteral) To(typ Type) (Literal, error) { + switch t := typ.(type) { + case UUIDType: + val, err := uuid.FromBytes(f) + if err != nil { + return nil, fmt.Errorf("%w: cannot convert FixedLiteral to UUID - %s", + ErrBadCast, err.Error()) + } + return UUIDLiteral(val), nil + case FixedType: + if len(f) == t.len { + return FixedLiteral(f), nil + } + + return nil, fmt.Errorf("%w: cannot convert FixedLiteral to %s, different length - %d <> %d", + ErrBadCast, typ, len(f), t.len) + case BinaryType: + return f, 
nil + } + + return nil, fmt.Errorf("%w: FixedLiteral[%d] to %s", + ErrBadCast, len(f), typ) +} +func (f FixedLiteral) Equals(other Literal) bool { + rhs, ok := other.(FixedLiteral) + if !ok { + return false + } + + return bytes.Equal([]byte(f), rhs) +} + +func (f FixedLiteral) MarshalBinary() (data []byte, err error) { + // stored directly as is + data = f + return +} + +func (f *FixedLiteral) UnmarshalBinary(data []byte) error { + // stored directly as is + *f = FixedLiteral(data) + return nil +} + +type UUIDLiteral uuid.UUID + +func (UUIDLiteral) Comparator() Comparator[uuid.UUID] { + return func(v1, v2 uuid.UUID) int { + return bytes.Compare(v1[:], v2[:]) + } +} + +func (UUIDLiteral) Type() Type { return PrimitiveTypes.UUID } +func (u UUIDLiteral) Value() uuid.UUID { return uuid.UUID(u) } +func (u UUIDLiteral) String() string { return uuid.UUID(u).String() } +func (u UUIDLiteral) To(typ Type) (Literal, error) { + switch t := typ.(type) { + case UUIDType: + return u, nil + case FixedType: + if len(u) == t.len { + v, _ := uuid.UUID(u).MarshalBinary() + return FixedLiteral(v), nil + } + + return nil, fmt.Errorf("%w: cannot convert UUIDLiteral to %s, different length - %d <> %d", + ErrBadCast, typ, len(u), t.len) + case BinaryType: + v, _ := uuid.UUID(u).MarshalBinary() + return BinaryLiteral(v), nil + } + + return nil, fmt.Errorf("%w: UUIDLiteral to %s", ErrBadCast, typ) +} +func (u UUIDLiteral) Equals(other Literal) bool { + rhs, ok := other.(UUIDLiteral) + if !ok { + return false + } + + return uuid.UUID(u) == uuid.UUID(rhs) +} + +func (u UUIDLiteral) MarshalBinary() (data []byte, err error) { + return uuid.UUID(u).MarshalBinary() +} + +func (u *UUIDLiteral) UnmarshalBinary(data []byte) error { + // stored as 16-byte big-endian value + out, err := uuid.FromBytes(data) + if err != nil { + return err + } + *u = UUIDLiteral(out) + return nil +} + +type DecimalLiteral Decimal + +func (DecimalLiteral) Comparator() Comparator[Decimal] { + return func(v1, v2 Decimal) 
int { + if v1.Scale == v2.Scale { + return v1.Val.Cmp(v2.Val) + } + + rescaled, err := v2.Val.Rescale(int32(v2.Scale), int32(v1.Scale)) + if err != nil { + return -1 + } + + return v1.Val.Cmp(rescaled) + } +} +func (d DecimalLiteral) Type() Type { return DecimalTypeOf(9, d.Scale) } +func (d DecimalLiteral) Value() Decimal { return Decimal(d) } +func (d DecimalLiteral) String() string { + return d.Val.ToString(int32(d.Scale)) +} + +func (d DecimalLiteral) To(t Type) (Literal, error) { + switch t := t.(type) { + case DecimalType: + if d.Scale == t.scale { + return d, nil + } + + return nil, fmt.Errorf("%w: could not convert %v to %s", + ErrBadCast, d, t) + case Int32Type: + v := d.Val.BigInt().Int64() + if v > math.MaxInt32 { + return Int32AboveMaxLiteral(), nil + } else if v < math.MinInt32 { + return Int32BelowMinLiteral(), nil + } + + return Int32Literal(int32(v)), nil + case Int64Type: + v := d.Val.BigInt() + if !v.IsInt64() { + if v.Sign() > 0 { + return Int64AboveMaxLiteral(), nil + } else if v.Sign() < 0 { + return Int64BelowMinLiteral(), nil + } + } + + return Int64Literal(v.Int64()), nil + case Float32Type: + v := d.Val.ToFloat64(int32(d.Scale)) + if v > math.MaxFloat32 { + return Float32AboveMaxLiteral(), nil + } else if v < -math.MaxFloat32 { + return Float32BelowMinLiteral(), nil + } + + return Float32Literal(float32(v)), nil + case Float64Type: + return Float64Literal(d.Val.ToFloat64(int32(d.Scale))), nil + } + + return nil, fmt.Errorf("%w: DecimalLiteral to %s", ErrBadCast, t) +} + +func (d DecimalLiteral) Equals(other Literal) bool { + rhs, ok := other.(DecimalLiteral) + if !ok { + return false + } + + rescaled, err := rhs.Val.Rescale(int32(rhs.Scale), int32(d.Scale)) + if err != nil { + return false + } + return d.Val == rescaled +} + +func (d DecimalLiteral) Increment() Literal { + d.Val = d.Val.Add(decimal128.FromU64(1)) + return d +} + +func (d DecimalLiteral) Decrement() Literal { + d.Val = d.Val.Sub(decimal128.FromU64(1)) + return d +} + +func (d 
DecimalLiteral) MarshalBinary() (data []byte, err error) { + // stored as unscaled value in two's compliment big-endian values + // using the minimum number of bytes for the values + n := decimal128.Num(d.Val).BigInt() + // bytes gives absolute value as big-endian bytes + data = n.Bytes() + if n.Sign() < 0 { + // convert to 2's complement for negative value + for i, v := range data { + data[i] = ^v + } + data[len(data)-1] += 1 + } + return +} + +func (d *DecimalLiteral) UnmarshalBinary(data []byte) error { + // stored as unscaled value in two's complement + // big-endian values using the minimum number of bytes + if len(data) == 0 { + d.Val = decimal128.Num{} + return nil + } + + if int8(data[0]) >= 0 { + // not negative + d.Val = decimal128.FromBigInt((&big.Int{}).SetBytes(data)) + return nil + } + + // convert two's complement and remember it's negative + out := make([]byte, len(data)) + for i, b := range data { + out[i] = ^b + } + out[len(out)-1] += 1 + + value := (&big.Int{}).SetBytes(out) + d.Val = decimal128.FromBigInt(value.Neg(value)) + return nil +} diff --git a/literals_test.go b/literals_test.go index 4dbb7f2..c7ec370 100644 --- a/literals_test.go +++ b/literals_test.go @@ -1,1004 +1,1004 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -package iceberg_test - -import ( - "math" - "strconv" - "testing" - "time" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/apache/iceberg-go" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNumericLiteralCompare(t *testing.T) { - smallLit := iceberg.NewLiteral(int32(10)).(iceberg.Int32Literal) - bigLit := iceberg.NewLiteral(int32(1000)).(iceberg.Int32Literal) - - assert.False(t, smallLit.Equals(bigLit)) - assert.True(t, smallLit.Equals(iceberg.NewLiteral(int32(10)))) - - cmp := smallLit.Comparator() - - assert.Equal(t, -1, cmp(smallLit.Value(), bigLit.Value())) - assert.Equal(t, 1, cmp(bigLit.Value(), smallLit.Value())) - assert.True(t, smallLit.Type().Equals(iceberg.PrimitiveTypes.Int32)) -} - -func TestIntConversion(t *testing.T) { - lit := iceberg.NewLiteral(int32(34)) - - t.Run("to int64", func(t *testing.T) { - longLit, err := lit.To(iceberg.PrimitiveTypes.Int64) - assert.NoError(t, err) - assert.IsType(t, iceberg.Int64Literal(0), longLit) - assert.EqualValues(t, 34, longLit) - }) - - t.Run("to float32", func(t *testing.T) { - floatLit, err := lit.To(iceberg.PrimitiveTypes.Float32) - assert.NoError(t, err) - assert.IsType(t, iceberg.Float32Literal(0), floatLit) - assert.EqualValues(t, 34, floatLit) - }) - - t.Run("to float64", func(t *testing.T) { - dblLit, err := lit.To(iceberg.PrimitiveTypes.Float64) - assert.NoError(t, err) - assert.IsType(t, iceberg.Float64Literal(0), dblLit) - assert.EqualValues(t, 34, dblLit) - }) -} - -func TestIntToDecimalConversion(t *testing.T) { - tests := []struct { - ty iceberg.DecimalType - val iceberg.Decimal - }{ - {iceberg.DecimalTypeOf(9, 0), - iceberg.Decimal{Val: decimal128.FromI64(34), Scale: 0}}, - {iceberg.DecimalTypeOf(9, 2), - iceberg.Decimal{Val: decimal128.FromI64(3400), Scale: 
2}}, - {iceberg.DecimalTypeOf(9, 4), - iceberg.Decimal{Val: decimal128.FromI64(340000), Scale: 4}}, - } - - for _, tt := range tests { - t.Run(tt.ty.String(), func(t *testing.T) { - lit := iceberg.Int32Literal(34) - - dec, err := lit.To(tt.ty) - require.NoError(t, err) - assert.IsType(t, iceberg.DecimalLiteral(tt.val), dec) - assert.EqualValues(t, tt.val, dec) - }) - } -} - -func TestIntToDateConversion(t *testing.T) { - oneDay, _ := time.Parse("2006-01-02", "2022-03-28") - val := int32(arrow.Date32FromTime(oneDay)) - dateLit, err := iceberg.NewLiteral(val).To(iceberg.PrimitiveTypes.Date) - require.NoError(t, err) - assert.True(t, dateLit.Type().Equals(iceberg.PrimitiveTypes.Date)) - assert.EqualValues(t, val, dateLit) - - lit := iceberg.Int32Literal(34) - tm, err := lit.To(iceberg.PrimitiveTypes.Time) - require.NoError(t, err) - assert.EqualValues(t, lit, tm) - - tm, err = lit.To(iceberg.PrimitiveTypes.Timestamp) - require.NoError(t, err) - assert.EqualValues(t, lit, tm) - - tm, err = lit.To(iceberg.PrimitiveTypes.TimestampTz) - require.NoError(t, err) - assert.EqualValues(t, lit, tm) -} - -func TestInt64Conversions(t *testing.T) { - tests := []struct { - from iceberg.Int64Literal - to iceberg.Literal - }{ - {iceberg.Int64Literal(34), iceberg.NewLiteral(int32(34))}, - {iceberg.Int64Literal(34), iceberg.NewLiteral(float32(34))}, - {iceberg.Int64Literal(34), iceberg.NewLiteral(float64(34))}, - {iceberg.Int64Literal(19709), iceberg.NewLiteral(iceberg.Date(19709))}, - {iceberg.Int64Literal(51661919000), iceberg.NewLiteral(iceberg.Time(51661919000))}, - {iceberg.Int64Literal(1647305201), iceberg.NewLiteral(iceberg.Timestamp(1647305201))}, - {iceberg.Int64Literal(34), - iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(34), Scale: 0})}, - {iceberg.Int64Literal(34), - iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(3400), Scale: 2})}, - {iceberg.Int64Literal(34), - iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(340000), Scale: 4})}, - } - 
- for _, tt := range tests { - t.Run(tt.to.Type().String(), func(t *testing.T) { - got, err := tt.from.To(tt.to.Type()) - require.NoError(t, err) - assert.True(t, tt.to.Equals(got)) - }) - } -} - -func TestInt64ToInt32OutsideBound(t *testing.T) { - bigLit := iceberg.NewLiteral(int64(math.MaxInt32 + 1)) - aboveMax, err := bigLit.To(iceberg.PrimitiveTypes.Int32) - require.NoError(t, err) - assert.Implements(t, (*iceberg.AboveMaxLiteral)(nil), aboveMax) - assert.Equal(t, iceberg.Int32AboveMaxLiteral(), aboveMax) - assert.Equal(t, iceberg.PrimitiveTypes.Int32, aboveMax.Type()) - - smallLit := iceberg.NewLiteral(int64(math.MinInt32 - 1)) - belowMin, err := smallLit.To(iceberg.PrimitiveTypes.Int32) - require.NoError(t, err) - assert.Implements(t, (*iceberg.BelowMinLiteral)(nil), belowMin) - assert.Equal(t, iceberg.Int32BelowMinLiteral(), belowMin) - assert.Equal(t, iceberg.PrimitiveTypes.Int32, belowMin.Type()) -} - -func TestFloatConversions(t *testing.T) { - n1, _ := decimal128.FromFloat32(34.56, 9, 1) - n2, _ := decimal128.FromFloat32(34.56, 9, 2) - n3, _ := decimal128.FromFloat32(34.56, 9, 4) - - tests := []struct { - from iceberg.Float32Literal - to iceberg.Literal - }{ - {iceberg.Float32Literal(34.5), iceberg.NewLiteral(float64(34.5))}, - {iceberg.Float32Literal(34.56), - iceberg.NewLiteral(iceberg.Decimal{Val: n1, Scale: 1})}, - {iceberg.Float32Literal(34.56), - iceberg.NewLiteral(iceberg.Decimal{Val: n2, Scale: 2})}, - {iceberg.Float32Literal(34.56), - iceberg.NewLiteral(iceberg.Decimal{Val: n3, Scale: 4})}, - } - - for _, tt := range tests { - t.Run(tt.to.Type().String(), func(t *testing.T) { - got, err := tt.from.To(tt.to.Type()) - require.NoError(t, err) - assert.Truef(t, tt.to.Equals(got), "expected: %s, got: %s", tt.to, got) - }) - } -} - -func TestFloat64Conversions(t *testing.T) { - n1, _ := decimal128.FromFloat64(34.56, 9, 1) - n2, _ := decimal128.FromFloat64(34.56, 9, 2) - n3, _ := decimal128.FromFloat64(34.56, 9, 4) - - tests := []struct { - from 
iceberg.Float64Literal - to iceberg.Literal - }{ - {iceberg.Float64Literal(34.5), iceberg.NewLiteral(float32(34.5))}, - {iceberg.Float64Literal(34.56), - iceberg.NewLiteral(iceberg.Decimal{Val: n1, Scale: 1})}, - {iceberg.Float64Literal(34.56), - iceberg.NewLiteral(iceberg.Decimal{Val: n2, Scale: 2})}, - {iceberg.Float64Literal(34.56), - iceberg.NewLiteral(iceberg.Decimal{Val: n3, Scale: 4})}, - } - - for _, tt := range tests { - t.Run(tt.to.Type().String(), func(t *testing.T) { - got, err := tt.from.To(tt.to.Type()) - require.NoError(t, err) - assert.Truef(t, tt.to.Equals(got), "expected: %s, got: %s", tt.to, got) - }) - } -} - -func TestFloat64toFloat32OutsideBounds(t *testing.T) { - bigLit := iceberg.NewLiteral(float64(math.MaxFloat32 + 1.0e37)) - aboveMax, err := bigLit.To(iceberg.PrimitiveTypes.Float32) - require.NoError(t, err) - assert.Equal(t, iceberg.Float32AboveMaxLiteral(), aboveMax) - - smallLit := iceberg.NewLiteral(float64(-math.MaxFloat32 - 1.0e37)) - belowMin, err := smallLit.To(iceberg.PrimitiveTypes.Float32) - require.NoError(t, err) - assert.Equal(t, iceberg.Float32BelowMinLiteral(), belowMin) -} - -func TestDecimalToDecimalConversion(t *testing.T) { - lit := iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(3411), Scale: 2}) - - v, err := lit.To(iceberg.DecimalTypeOf(9, 2)) - require.NoError(t, err) - assert.Equal(t, lit, v) - - v, err = lit.To(iceberg.DecimalTypeOf(11, 2)) - require.NoError(t, err) - assert.Equal(t, lit, v) - - _, err = lit.To(iceberg.DecimalTypeOf(9, 0)) - assert.ErrorIs(t, err, iceberg.ErrBadCast) - assert.ErrorContains(t, err, "could not convert 34.11 to decimal(9, 0)") - - _, err = lit.To(iceberg.DecimalTypeOf(9, 1)) - assert.ErrorIs(t, err, iceberg.ErrBadCast) - assert.ErrorContains(t, err, "could not convert 34.11 to decimal(9, 1)") - - _, err = lit.To(iceberg.DecimalTypeOf(9, 3)) - assert.ErrorIs(t, err, iceberg.ErrBadCast) - assert.ErrorContains(t, err, "could not convert 34.11 to decimal(9, 3)") -} - -func 
TestDecimalLiteralConversions(t *testing.T) { - n1 := iceberg.Decimal{Val: decimal128.FromI64(1234), Scale: 2} - n2 := iceberg.Decimal{Val: decimal128.FromI64(math.MaxInt32 + 1), Scale: 0} - n3 := iceberg.Decimal{Val: decimal128.FromI64(math.MinInt32 - 1), Scale: 10} - - tests := []struct { - from iceberg.DecimalLiteral - to iceberg.Literal - }{ - {iceberg.DecimalLiteral(n1), iceberg.NewLiteral(int32(1234))}, - {iceberg.DecimalLiteral(n1), iceberg.NewLiteral(int64(1234))}, - {iceberg.DecimalLiteral(n2), iceberg.NewLiteral(int64(math.MaxInt32 + 1))}, - {iceberg.DecimalLiteral(n1), iceberg.NewLiteral(float32(12.34))}, - {iceberg.DecimalLiteral(n1), iceberg.NewLiteral(float64(12.34))}, - {iceberg.DecimalLiteral(n3), iceberg.NewLiteral(int64(math.MinInt32 - 1))}, - } - - for _, tt := range tests { - t.Run(tt.to.Type().String(), func(t *testing.T) { - got, err := tt.from.To(tt.to.Type()) - require.NoError(t, err) - assert.Truef(t, tt.to.Equals(got), "expected: %s, got: %s", tt.to, got) - }) - } - - above, err := iceberg.DecimalLiteral(n2).To(iceberg.PrimitiveTypes.Int32) - require.NoError(t, err) - assert.Equal(t, iceberg.Int32AboveMaxLiteral(), above) - assert.Equal(t, iceberg.PrimitiveTypes.Int32, above.Type()) - - below, err := iceberg.DecimalLiteral(n3).To(iceberg.PrimitiveTypes.Int32) - require.NoError(t, err) - assert.Equal(t, iceberg.Int32BelowMinLiteral(), below) - assert.Equal(t, iceberg.PrimitiveTypes.Int32, below.Type()) - - n4 := iceberg.Decimal{Val: decimal128.FromU64(math.MaxInt64 + 1), Scale: 0} - n5 := iceberg.Decimal{Val: decimal128.FromU64(math.MaxUint64).Negate(), Scale: 20} - - above, err = iceberg.DecimalLiteral(n4).To(iceberg.PrimitiveTypes.Int64) - require.NoError(t, err) - assert.Equal(t, iceberg.Int64AboveMaxLiteral(), above) - assert.Equal(t, iceberg.PrimitiveTypes.Int64, above.Type()) - - below, err = iceberg.DecimalLiteral(n5).To(iceberg.PrimitiveTypes.Int64) - require.NoError(t, err) - assert.Equal(t, iceberg.Int64BelowMinLiteral(), below) - 
assert.Equal(t, iceberg.PrimitiveTypes.Int64, below.Type()) - - v, err := decimal128.FromFloat64(math.MaxFloat32+1e37, 38, -1) - require.NoError(t, err) - above, err = iceberg.DecimalLiteral(iceberg.Decimal{Val: v, Scale: -1}). - To(iceberg.PrimitiveTypes.Float32) - require.NoError(t, err) - assert.Equal(t, iceberg.Float32AboveMaxLiteral(), above) - assert.Equal(t, iceberg.PrimitiveTypes.Float32, above.Type()) - - below, err = iceberg.DecimalLiteral(iceberg.Decimal{Val: v.Negate(), Scale: -1}). - To(iceberg.PrimitiveTypes.Float32) - require.NoError(t, err) - assert.Equal(t, iceberg.Float32BelowMinLiteral(), below) - assert.Equal(t, iceberg.PrimitiveTypes.Float32, below.Type()) -} - -func TestLiteralTimestampToDate(t *testing.T) { - v, _ := arrow.TimestampFromString("1970-01-01T00:00:00.000000+00:00", arrow.Microsecond) - tsLit := iceberg.NewLiteral(iceberg.Timestamp(v)) - dateLit, err := tsLit.To(iceberg.PrimitiveTypes.Date) - require.NoError(t, err) - assert.Zero(t, dateLit) -} - -func TestStringLiterals(t *testing.T) { - sqrt2 := iceberg.NewLiteral("1.414") - pi := iceberg.NewLiteral("3.141") - piStr := iceberg.StringLiteral("3.141") - piDbl := iceberg.NewLiteral(float64(3.141)) - - v, err := pi.To(iceberg.PrimitiveTypes.Float64) - require.NoError(t, err) - assert.Equal(t, piDbl, v) - - assert.False(t, sqrt2.Equals(pi)) - assert.True(t, pi.Equals(piStr)) - assert.False(t, pi.Equals(piDbl)) - assert.Equal(t, "3.141", pi.String()) - - cmp := piStr.Comparator() - assert.Equal(t, -1, cmp(sqrt2.(iceberg.StringLiteral).Value(), piStr.Value())) - assert.Equal(t, 1, cmp(piStr.Value(), sqrt2.(iceberg.StringLiteral).Value())) - - v, err = pi.To(iceberg.PrimitiveTypes.String) - require.NoError(t, err) - assert.Equal(t, pi, v) -} - -func TestStringLiteralConversion(t *testing.T) { - tm, _ := time.Parse("2006-01-02", "2017-08-18") - expected := uuid.New() - - tests := []struct { - from iceberg.StringLiteral - to iceberg.Literal - }{ - {iceberg.StringLiteral("2017-08-18"), - 
iceberg.NewLiteral(iceberg.Date(arrow.Date32FromTime(tm)))}, - {iceberg.StringLiteral("14:21:01.919"), - iceberg.NewLiteral(iceberg.Time(51661919000))}, - {iceberg.StringLiteral("2017-08-18T14:21:01.919234"), - iceberg.NewLiteral(iceberg.Timestamp(1503066061919234))}, - {iceberg.StringLiteral(expected.String()), iceberg.NewLiteral(expected)}, - {iceberg.StringLiteral("34.560"), - iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(34560), Scale: 3})}, - {iceberg.StringLiteral("true"), iceberg.NewLiteral(true)}, - {iceberg.StringLiteral("True"), iceberg.NewLiteral(true)}, - {iceberg.StringLiteral("false"), iceberg.NewLiteral(false)}, - {iceberg.StringLiteral("False"), iceberg.NewLiteral(false)}, - {iceberg.StringLiteral("12345"), iceberg.NewLiteral(int32(12345))}, - {iceberg.StringLiteral("12345123456"), iceberg.NewLiteral(int64(12345123456))}, - {iceberg.StringLiteral("3.14"), iceberg.NewLiteral(float32(3.14))}, - } - - for _, tt := range tests { - t.Run(tt.to.Type().String(), func(t *testing.T) { - got, err := tt.from.To(tt.to.Type()) - require.NoError(t, err) - assert.Truef(t, tt.to.Equals(got), "expected: %s, got: %s", tt.to, got) - }) - } - - lit := iceberg.StringLiteral("2017-08-18T14:21:01.919234-07:00") - casted, err := lit.To(iceberg.PrimitiveTypes.TimestampTz) - require.NoError(t, err) - expectedTimestamp := iceberg.NewLiteral(iceberg.Timestamp(1503091261919234)) - assert.Truef(t, casted.Equals(expectedTimestamp), "expected: %s, got: %s", - expectedTimestamp, casted) - - _, err = lit.To(iceberg.PrimitiveTypes.Timestamp) - require.Error(t, err) - assert.ErrorIs(t, err, iceberg.ErrBadCast) - assert.ErrorContains(t, err, `parsing time "2017-08-18T14:21:01.919234-07:00": extra text: "-07:00"`) - assert.ErrorContains(t, err, "invalid Timestamp format for casting from string") - - _, err = iceberg.StringLiteral("2017-08-18T14:21:01.919234").To(iceberg.PrimitiveTypes.TimestampTz) - require.Error(t, err) - assert.ErrorIs(t, err, iceberg.ErrBadCast) - 
assert.ErrorContains(t, err, `cannot parse "" as "Z07:00"`) -} - -func TestLiteralIdentityConversions(t *testing.T) { - fixedLit, _ := iceberg.NewLiteral([]byte{0x01, 0x02, 0x03}).To(iceberg.FixedTypeOf(3)) - - tests := []struct { - lit iceberg.Literal - typ iceberg.PrimitiveType - }{ - {iceberg.NewLiteral(true), iceberg.PrimitiveTypes.Bool}, - {iceberg.NewLiteral(int32(34)), iceberg.PrimitiveTypes.Int32}, - {iceberg.NewLiteral(int64(340000000)), iceberg.PrimitiveTypes.Int64}, - {iceberg.NewLiteral(float32(34.11)), iceberg.PrimitiveTypes.Float32}, - {iceberg.NewLiteral(float64(3.5028235e38)), iceberg.PrimitiveTypes.Float64}, - {iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(3455), Scale: 2}), - iceberg.DecimalTypeOf(9, 2)}, - {iceberg.NewLiteral(iceberg.Date(19079)), iceberg.PrimitiveTypes.Date}, - {iceberg.NewLiteral(iceberg.Timestamp(1503091261919234)), - iceberg.PrimitiveTypes.Timestamp}, - {iceberg.NewLiteral("abc"), iceberg.PrimitiveTypes.String}, - {iceberg.NewLiteral(uuid.New()), iceberg.PrimitiveTypes.UUID}, - {fixedLit, iceberg.FixedTypeOf(3)}, - {iceberg.NewLiteral([]byte{0x01, 0x02, 0x03}), iceberg.PrimitiveTypes.Binary}, - } - - for _, tt := range tests { - t.Run(tt.typ.String(), func(t *testing.T) { - expected, err := tt.lit.To(tt.typ) - require.NoError(t, err) - assert.Equal(t, expected, tt.lit) - }) - } -} - -func TestFixedLiteral(t *testing.T) { - fixedLit012 := iceberg.FixedLiteral{0x00, 0x01, 0x02} - fixedLit013 := iceberg.FixedLiteral{0x00, 0x01, 0x03} - assert.True(t, fixedLit012.Equals(fixedLit012)) - assert.False(t, fixedLit012.Equals(fixedLit013)) - - cmp := fixedLit012.Comparator() - assert.Equal(t, -1, cmp(fixedLit012, fixedLit013)) - assert.Equal(t, 1, cmp(fixedLit013, fixedLit012)) - assert.Equal(t, 0, cmp(fixedLit013, fixedLit013)) - - testUuid := uuid.New() - lit, err := iceberg.NewLiteral(testUuid[:]).To(iceberg.FixedTypeOf(16)) - require.NoError(t, err) - uuidLit, err := lit.To(iceberg.PrimitiveTypes.UUID) - 
require.NoError(t, err) - - assert.EqualValues(t, uuidLit, testUuid) - - fixedUuid, err := uuidLit.To(iceberg.FixedTypeOf(16)) - require.NoError(t, err) - assert.EqualValues(t, testUuid[:], fixedUuid) - - binUuid, err := uuidLit.To(iceberg.PrimitiveTypes.Binary) - require.NoError(t, err) - assert.EqualValues(t, testUuid[:], binUuid) - - binlit, err := fixedLit012.To(iceberg.PrimitiveTypes.Binary) - require.NoError(t, err) - assert.EqualValues(t, fixedLit012, binlit) -} - -func TestBinaryLiteral(t *testing.T) { - binLit012 := iceberg.NewLiteral([]byte{0x00, 0x01, 0x02}).(iceberg.BinaryLiteral) - binLit013 := iceberg.NewLiteral([]byte{0x00, 0x01, 0x03}).(iceberg.BinaryLiteral) - assert.True(t, binLit012.Equals(binLit012)) - assert.False(t, binLit012.Equals(binLit013)) - - cmp := binLit012.Comparator() - assert.Equal(t, -1, cmp(binLit012, binLit013)) - assert.Equal(t, 1, cmp(binLit013, binLit012)) - assert.Equal(t, 0, cmp(binLit013, binLit013)) -} - -func TestBinaryLiteralConversions(t *testing.T) { - binLit012 := iceberg.NewLiteral([]byte{0x00, 0x01, 0x02}) - fixed, err := binLit012.To(iceberg.FixedTypeOf(3)) - require.NoError(t, err) - assert.Equal(t, iceberg.FixedLiteral{0x00, 0x01, 0x02}, fixed) - - _, err = binLit012.To(iceberg.FixedTypeOf(4)) - assert.ErrorIs(t, err, iceberg.ErrBadCast) - assert.ErrorContains(t, err, "cannot convert BinaryLiteral to fixed[4], different length - 3 <> 4") - - _, err = binLit012.To(iceberg.FixedTypeOf(2)) - assert.ErrorIs(t, err, iceberg.ErrBadCast) - assert.ErrorContains(t, err, "cannot convert BinaryLiteral to fixed[2], different length - 3 <> 2") - - testUuid := uuid.New() - lit := iceberg.NewLiteral(testUuid[:]) - uuidLit, err := lit.To(iceberg.PrimitiveTypes.UUID) - require.NoError(t, err) - assert.EqualValues(t, testUuid, uuidLit) - - _, err = binLit012.To(iceberg.PrimitiveTypes.UUID) - assert.ErrorIs(t, err, iceberg.ErrBadCast) - assert.ErrorContains(t, err, "cannot convert BinaryLiteral to UUID") -} - -func 
testInvalidLiteralConversions(t *testing.T, lit iceberg.Literal, typs []iceberg.Type) { - t.Run(lit.Type().String(), func(t *testing.T) { - for _, tt := range typs { - t.Run(tt.String(), func(t *testing.T) { - _, err := lit.To(tt) - assert.ErrorIs(t, err, iceberg.ErrBadCast) - }) - } - }) -} - -func TestInvalidBoolLiteralConversions(t *testing.T) { - testInvalidLiteralConversions(t, iceberg.NewLiteral(true), []iceberg.Type{ - iceberg.PrimitiveTypes.Int32, - iceberg.PrimitiveTypes.Int64, - iceberg.PrimitiveTypes.Float32, - iceberg.PrimitiveTypes.Float64, - iceberg.PrimitiveTypes.Date, - iceberg.PrimitiveTypes.Time, - iceberg.PrimitiveTypes.Timestamp, - iceberg.PrimitiveTypes.TimestampTz, - iceberg.DecimalTypeOf(9, 2), - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.UUID, - iceberg.PrimitiveTypes.Binary, - iceberg.FixedTypeOf(2), - }) -} - -func TestInvalidNumericConversions(t *testing.T) { - testInvalidLiteralConversions(t, iceberg.NewLiteral(int32(34)), []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.UUID, - iceberg.FixedTypeOf(1), - iceberg.PrimitiveTypes.Binary, - }) - - testInvalidLiteralConversions(t, iceberg.NewLiteral(int64(34)), []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.UUID, - iceberg.FixedTypeOf(1), - iceberg.PrimitiveTypes.Binary, - }) - - testInvalidLiteralConversions(t, iceberg.NewLiteral(float32(34)), []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.Int32, - iceberg.PrimitiveTypes.Int64, - iceberg.PrimitiveTypes.Date, - iceberg.PrimitiveTypes.Time, - iceberg.PrimitiveTypes.Timestamp, - iceberg.PrimitiveTypes.TimestampTz, - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.UUID, - iceberg.FixedTypeOf(1), - iceberg.PrimitiveTypes.Binary, - }) - - testInvalidLiteralConversions(t, iceberg.NewLiteral(float64(34)), []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.Int32, - 
iceberg.PrimitiveTypes.Int64, - iceberg.PrimitiveTypes.Date, - iceberg.PrimitiveTypes.Time, - iceberg.PrimitiveTypes.Timestamp, - iceberg.PrimitiveTypes.TimestampTz, - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.UUID, - iceberg.FixedTypeOf(1), - iceberg.PrimitiveTypes.Binary, - }) - - testInvalidLiteralConversions(t, iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(3411), Scale: 2}), - []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.Date, - iceberg.PrimitiveTypes.Time, - iceberg.PrimitiveTypes.Timestamp, - iceberg.PrimitiveTypes.TimestampTz, - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.UUID, - iceberg.FixedTypeOf(1), - iceberg.PrimitiveTypes.Binary, - }) -} - -func TestInvalidDateTimeLiteralConversions(t *testing.T) { - lit, _ := iceberg.NewLiteral("2017-08-18").To(iceberg.PrimitiveTypes.Date) - testInvalidLiteralConversions(t, lit, []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.Int32, - iceberg.PrimitiveTypes.Int64, - iceberg.PrimitiveTypes.Float32, - iceberg.PrimitiveTypes.Float64, - iceberg.PrimitiveTypes.Time, - iceberg.PrimitiveTypes.Timestamp, - iceberg.PrimitiveTypes.TimestampTz, - iceberg.DecimalTypeOf(9, 2), - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.UUID, - iceberg.FixedTypeOf(1), - iceberg.PrimitiveTypes.Binary, - }) - - lit, _ = iceberg.NewLiteral("14:21:01.919").To(iceberg.PrimitiveTypes.Time) - testInvalidLiteralConversions(t, lit, []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.Int32, - iceberg.PrimitiveTypes.Int64, - iceberg.PrimitiveTypes.Float32, - iceberg.PrimitiveTypes.Float64, - iceberg.PrimitiveTypes.Date, - iceberg.PrimitiveTypes.Timestamp, - iceberg.PrimitiveTypes.TimestampTz, - iceberg.DecimalTypeOf(9, 2), - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.UUID, - iceberg.FixedTypeOf(1), - iceberg.PrimitiveTypes.Binary, - }) - - lit, _ = 
iceberg.NewLiteral("2017-08-18T14:21:01.919").To(iceberg.PrimitiveTypes.Timestamp) - testInvalidLiteralConversions(t, lit, []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.Int32, - iceberg.PrimitiveTypes.Int64, - iceberg.PrimitiveTypes.Float32, - iceberg.PrimitiveTypes.Float64, - iceberg.PrimitiveTypes.Time, - iceberg.DecimalTypeOf(9, 2), - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.UUID, - iceberg.FixedTypeOf(1), - iceberg.PrimitiveTypes.Binary, - }) -} - -func TestInvalidStringLiteralConversions(t *testing.T) { - testInvalidLiteralConversions(t, iceberg.NewLiteral("abc"), []iceberg.Type{ - iceberg.FixedTypeOf(1), - }) -} - -func TestInvalidBinaryLiteralConversions(t *testing.T) { - testInvalidLiteralConversions(t, iceberg.NewLiteral(uuid.New()), []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.Int32, - iceberg.PrimitiveTypes.Int64, - iceberg.PrimitiveTypes.Float32, - iceberg.PrimitiveTypes.Float64, - iceberg.PrimitiveTypes.Date, - iceberg.PrimitiveTypes.Time, - iceberg.PrimitiveTypes.Timestamp, - iceberg.PrimitiveTypes.TimestampTz, - iceberg.DecimalTypeOf(9, 2), - iceberg.PrimitiveTypes.String, - iceberg.FixedTypeOf(1), - }) - - lit, _ := iceberg.NewLiteral([]byte{0x00, 0x01, 0x02}).To(iceberg.FixedTypeOf(3)) - testInvalidLiteralConversions(t, lit, []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.Int32, - iceberg.PrimitiveTypes.Int64, - iceberg.PrimitiveTypes.Float32, - iceberg.PrimitiveTypes.Float64, - iceberg.PrimitiveTypes.Date, - iceberg.PrimitiveTypes.Time, - iceberg.PrimitiveTypes.Timestamp, - iceberg.PrimitiveTypes.TimestampTz, - iceberg.DecimalTypeOf(9, 2), - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.UUID, - }) - - testInvalidLiteralConversions(t, iceberg.NewLiteral([]byte{0x00, 0x01, 0x02}), []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.Int32, - iceberg.PrimitiveTypes.Int64, - iceberg.PrimitiveTypes.Float32, - 
iceberg.PrimitiveTypes.Float64, - iceberg.PrimitiveTypes.Date, - iceberg.PrimitiveTypes.Time, - iceberg.PrimitiveTypes.Timestamp, - iceberg.PrimitiveTypes.TimestampTz, - iceberg.DecimalTypeOf(9, 2), - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.UUID, - }) -} - -func TestBadStringLiteralCasts(t *testing.T) { - tests := []iceberg.Type{ - iceberg.PrimitiveTypes.Int32, - iceberg.PrimitiveTypes.Int64, - iceberg.PrimitiveTypes.Float32, - iceberg.PrimitiveTypes.Float64, - iceberg.PrimitiveTypes.Date, - iceberg.PrimitiveTypes.Time, - iceberg.PrimitiveTypes.Timestamp, - iceberg.PrimitiveTypes.TimestampTz, - iceberg.PrimitiveTypes.Bool, - iceberg.DecimalTypeOf(9, 2), - iceberg.PrimitiveTypes.UUID, - } - - for _, tt := range tests { - t.Run(tt.String(), func(t *testing.T) { - _, err := iceberg.NewLiteral("abc").To(tt) - assert.ErrorIs(t, err, iceberg.ErrBadCast) - }) - } -} - -func TestStringLiteralToIntMaxMinValue(t *testing.T) { - above, err := iceberg.NewLiteral(strconv.FormatInt(math.MaxInt32+1, 10)). - To(iceberg.PrimitiveTypes.Int32) - require.NoError(t, err) - assert.Equal(t, iceberg.Int32AboveMaxLiteral(), above) - - below, err := iceberg.NewLiteral(strconv.FormatInt(math.MinInt32-1, 10)). 
- To(iceberg.PrimitiveTypes.Int32) - require.NoError(t, err) - assert.Equal(t, iceberg.Int32BelowMinLiteral(), below) -} - -func TestUnmarshalBinary(t *testing.T) { - tests := []struct { - typ iceberg.Type - data []byte - result iceberg.Literal - }{ - {iceberg.PrimitiveTypes.Bool, []byte{0x0}, iceberg.BoolLiteral(false)}, - {iceberg.PrimitiveTypes.Bool, []byte{0x1}, iceberg.BoolLiteral(true)}, - {iceberg.PrimitiveTypes.Int32, []byte{0xd2, 0x04, 0x00, 0x00}, iceberg.Int32Literal(1234)}, - {iceberg.PrimitiveTypes.Int64, []byte{0xd2, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - iceberg.Int64Literal(1234)}, - {iceberg.PrimitiveTypes.Float32, []byte{0x00, 0x00, 0x90, 0xc0}, iceberg.Float32Literal(-4.5)}, - {iceberg.PrimitiveTypes.Float64, []byte{0x8d, 0x97, 0x6e, 0x12, 0x83, 0xc0, 0xf3, 0x3f}, - iceberg.Float64Literal(1.2345)}, - {iceberg.PrimitiveTypes.Date, []byte{0xe8, 0x03, 0x00, 0x00}, iceberg.DateLiteral(1000)}, - {iceberg.PrimitiveTypes.Date, []byte{0xd2, 0x04, 0x00, 0x00}, iceberg.DateLiteral(1234)}, - {iceberg.PrimitiveTypes.Time, []byte{0x10, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - iceberg.TimeLiteral(10000)}, - {iceberg.PrimitiveTypes.Time, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, - iceberg.TimeLiteral(100000000000)}, - {iceberg.PrimitiveTypes.TimestampTz, []byte{0x80, 0x1a, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00}, - iceberg.TimestampLiteral(400000)}, - {iceberg.PrimitiveTypes.TimestampTz, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, - iceberg.TimestampLiteral(100000000000)}, - {iceberg.PrimitiveTypes.Timestamp, []byte{0x80, 0x1a, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00}, - iceberg.TimestampLiteral(400000)}, - {iceberg.PrimitiveTypes.Timestamp, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, - iceberg.TimestampLiteral(100000000000)}, - {iceberg.PrimitiveTypes.String, []byte("ABC"), iceberg.StringLiteral("ABC")}, - {iceberg.PrimitiveTypes.String, []byte("foo"), iceberg.StringLiteral("foo")}, - {iceberg.PrimitiveTypes.UUID, - 
[]byte{0xf7, 0x9c, 0x3e, 0x09, 0x67, 0x7c, 0x4b, 0xbd, 0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7}, - iceberg.UUIDLiteral(uuid.UUID{0xf7, 0x9c, 0x3e, 0x09, 0x67, 0x7c, 0x4b, 0xbd, 0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7})}, - {iceberg.FixedTypeOf(3), []byte("foo"), iceberg.FixedLiteral([]byte("foo"))}, - {iceberg.PrimitiveTypes.Binary, []byte("foo"), iceberg.BinaryLiteral([]byte("foo"))}, - {iceberg.DecimalTypeOf(5, 2), []byte{0x30, 0x39}, - iceberg.DecimalLiteral{Scale: 2, Val: decimal128.FromU64(12345)}}, - {iceberg.DecimalTypeOf(7, 4), []byte{0x12, 0xd6, 0x87}, - iceberg.DecimalLiteral{Scale: 4, Val: decimal128.FromU64(1234567)}}, - {iceberg.DecimalTypeOf(7, 4), []byte{0xff, 0xed, 0x29, 0x79}, - iceberg.DecimalLiteral{Scale: 4, Val: decimal128.FromI64(-1234567)}}, - } - - for _, tt := range tests { - t.Run(tt.typ.String(), func(t *testing.T) { - lit, err := iceberg.LiteralFromBytes(tt.typ, tt.data) - require.NoError(t, err) - - assert.Truef(t, tt.result.Equals(lit), "expected: %s, got: %s", tt.result, lit) - }) - } -} - -func TestRoundTripLiteralBinary(t *testing.T) { - tests := []struct { - typ iceberg.Type - b []byte - result iceberg.Literal - }{ - {iceberg.PrimitiveTypes.Bool, []byte{0x0}, iceberg.BoolLiteral(false)}, - {iceberg.PrimitiveTypes.Bool, []byte{0x1}, iceberg.BoolLiteral(true)}, - {iceberg.PrimitiveTypes.Int32, []byte{0xd2, 0x04, 0x00, 0x00}, iceberg.Int32Literal(1234)}, - {iceberg.PrimitiveTypes.Int64, []byte{0xd2, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - iceberg.Int64Literal(1234)}, - {iceberg.PrimitiveTypes.Float32, []byte{0x00, 0x00, 0x90, 0xc0}, iceberg.Float32Literal(-4.5)}, - {iceberg.PrimitiveTypes.Float32, []byte{0x19, 0x04, 0x9e, 0x3f}, iceberg.Float32Literal(1.2345)}, - {iceberg.PrimitiveTypes.Float64, []byte{0x8d, 0x97, 0x6e, 0x12, 0x83, 0xc0, 0xf3, 0x3f}, - iceberg.Float64Literal(1.2345)}, - {iceberg.PrimitiveTypes.Date, []byte{0xe8, 0x03, 0x00, 0x00}, iceberg.DateLiteral(1000)}, - {iceberg.PrimitiveTypes.Date, 
[]byte{0xd2, 0x04, 0x00, 0x00}, iceberg.DateLiteral(1234)}, - {iceberg.PrimitiveTypes.Time, []byte{0x10, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - iceberg.TimeLiteral(10000)}, - {iceberg.PrimitiveTypes.Time, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, - iceberg.TimeLiteral(100000000000)}, - {iceberg.PrimitiveTypes.TimestampTz, []byte{0x80, 0x1a, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00}, - iceberg.TimestampLiteral(400000)}, - {iceberg.PrimitiveTypes.TimestampTz, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, - iceberg.TimestampLiteral(100000000000)}, - {iceberg.PrimitiveTypes.Timestamp, []byte{0x80, 0x1a, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00}, - iceberg.TimestampLiteral(400000)}, - {iceberg.PrimitiveTypes.Timestamp, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, - iceberg.TimestampLiteral(100000000000)}, - {iceberg.PrimitiveTypes.String, []byte("ABC"), iceberg.StringLiteral("ABC")}, - {iceberg.PrimitiveTypes.String, []byte("foo"), iceberg.StringLiteral("foo")}, - {iceberg.PrimitiveTypes.UUID, - []byte{0xf7, 0x9c, 0x3e, 0x09, 0x67, 0x7c, 0x4b, 0xbd, 0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7}, - iceberg.UUIDLiteral(uuid.UUID{0xf7, 0x9c, 0x3e, 0x09, 0x67, 0x7c, 0x4b, 0xbd, 0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7})}, - {iceberg.FixedTypeOf(3), []byte("foo"), iceberg.FixedLiteral([]byte("foo"))}, - {iceberg.PrimitiveTypes.Binary, []byte("foo"), iceberg.BinaryLiteral([]byte("foo"))}, - {iceberg.DecimalTypeOf(5, 2), []byte{0x30, 0x39}, - iceberg.DecimalLiteral{Scale: 2, Val: decimal128.FromU64(12345)}}, - // decimal on 3-bytes to test that we use the minimum number of bytes and not a power of 2 - // 1234567 is 00010010|11010110|10000111 in binary - // 00010010 -> 18, 11010110 -> 214, 10000111 -> 135 - {iceberg.DecimalTypeOf(7, 4), []byte{0x12, 0xd6, 0x87}, - iceberg.DecimalLiteral{Scale: 4, Val: decimal128.FromU64(1234567)}}, - // negative decimal to test two's complement - // -1234567 is 11101101|00101001|01111001 in binary - // 
11101101 -> 237, 00101001 -> 41, 01111001 -> 121 - {iceberg.DecimalTypeOf(7, 4), []byte{0xed, 0x29, 0x79}, - iceberg.DecimalLiteral{Scale: 4, Val: decimal128.FromI64(-1234567)}}, - // test empty byte in decimal - // 11 is 00001011 in binary - // 00001011 -> 11 - {iceberg.DecimalTypeOf(10, 3), []byte{0x0b}, iceberg.DecimalLiteral{Scale: 3, Val: decimal128.FromU64(11)}}, - {iceberg.DecimalTypeOf(4, 2), []byte{0x04, 0xd2}, iceberg.DecimalLiteral{Scale: 2, Val: decimal128.FromU64(1234)}}, - } - - for _, tt := range tests { - t.Run(tt.result.String(), func(t *testing.T) { - lit, err := iceberg.LiteralFromBytes(tt.typ, tt.b) - require.NoError(t, err) - - assert.True(t, lit.Equals(tt.result)) - - data, err := lit.MarshalBinary() - require.NoError(t, err) - - assert.Equal(t, tt.b, data) - }) - } -} - -func TestLargeDecimalRoundTrip(t *testing.T) { - tests := []struct { - typ iceberg.DecimalType - b []byte - val string - }{ - {iceberg.DecimalTypeOf(38, 21), - []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x18, 0x30, 0x73, 0xb9, 0x1e, - 0x7e, 0xa2, 0xb3, 0x6a, 0x83}, - "12345678912345678.123456789123456789123"}, - {iceberg.DecimalTypeOf(38, 22), - []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x16, 0xbb, 0x01, 0x2f, - 0x4c, 0xc3, 0x2b, 0x42, 0x29, 0x22}, - "1234567891234567.1234567891234567891234"}, - {iceberg.DecimalTypeOf(38, 23), - []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x0a, 0x42, 0xa1, 0xad, - 0xe5, 0x2b, 0x33, 0x15, 0x9b, 0x59}, - "123456789123456.12345678912345678912345"}, - {iceberg.DecimalTypeOf(38, 24), - []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe8, 0xa2, 0xbb, 0xe9, 0x67, - 0xba, 0x86, 0x77, 0xd8, 0x11, 0x80}, - "12345678912345.123456789123456789123456"}, - {iceberg.DecimalTypeOf(38, 25), - []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe5, 0x6b, 0x3a, 0xd2, 0x78, - 0xdd, 0x04, 0xc8, 0x70, 0xaf, 0x07}, - "1234567891234.1234567891234567891234567"}, - {iceberg.DecimalTypeOf(38, 26), - []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xcd, 0x85, 0xc5, 0x03, 0x38, 0x37, - 0x3c, 0x38, 0x66, 
0xd6, 0x4e}, - "123456789123.12345678912345678912345678"}, - {iceberg.DecimalTypeOf(38, 27), - []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0x31, 0x46, 0xfd, 0xc7, 0x79, - 0xca, 0x39, 0x7c, 0x04, 0x5f, 0x15}, - "12345678912.123456789123456789123456789"}, - {iceberg.DecimalTypeOf(38, 28), - []byte{0x09, 0x49, 0xb0, 0xf7, 0x10, 0x52, 0x01, 0x72, 0x11, 0xda, - 0x08, 0x5b, 0x08, 0x2b, 0xb6, 0xd3}, - "1234567891.1234567891234567891234567891"}, - {iceberg.DecimalTypeOf(38, 29), - []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x18, 0x5b, 0x37, 0xc1, - 0x78, 0x0b, 0x91, 0xb5, 0x24, 0x40}, - "123456789.12345678912345678912345678912"}, - {iceberg.DecimalTypeOf(38, 30), - []byte{0x09, 0x49, 0xb0, 0xed, 0x1e, 0xdf, 0x80, 0x03, 0x47, 0x3b, - 0x16, 0x9b, 0xf1, 0x13, 0x6a, 0x83}, - "12345678.123456789123456789123456789123"}, - {iceberg.DecimalTypeOf(38, 31), - []byte{0x09, 0x49, 0xb0, 0x96, 0x2b, 0xac, 0x29, 0x64, 0x28, 0x70, - 0x36, 0x29, 0xea, 0xc2, 0x29, 0x22}, - "1234567.1234567891234567891234567891234"}, - {iceberg.DecimalTypeOf(38, 32), - []byte{0x09, 0x49, 0xad, 0xae, 0xe3, 0x68, 0xe7, 0x4f, 0xb5, 0x14, - 0xbc, 0xdc, 0x2b, 0x95, 0x9b, 0x59}, - "123456.12345678912345678912345678912345"}, - {iceberg.DecimalTypeOf(38, 33), - []byte{0x09, 0x49, 0x95, 0x94, 0x3e, 0x35, 0x93, 0xde, 0xb9, 0x2e, - 0xef, 0x53, 0xb3, 0xd8, 0x11, 0x80}, - "12345.123456789123456789123456789123456"}, - {iceberg.DecimalTypeOf(38, 34), - []byte{0x09, 0x48, 0xd5, 0xd7, 0x90, 0x78, 0xdf, 0x08, 0x1a, 0xf6, - 0x43, 0x09, 0x06, 0x70, 0xaf, 0x07}, - "1234.1234567891234567891234567891234567"}, - {iceberg.DecimalTypeOf(38, 35), - []byte{0x09, 0x43, 0x45, 0x82, 0x85, 0xc7, 0x56, 0x66, 0x24, 0x4d, - 0x16, 0x82, 0x40, 0x66, 0xd6, 0x4e}, - "123.12345678912345678912345678912345678"}, - {iceberg.DecimalTypeOf(21, 16), - []byte{0x06, 0xb1, 0x3a, 0xe3, 0xc4, 0x4e, 0x94, 0xaf, 0x07}, - "12345.1234567891234567"}, - {iceberg.DecimalTypeOf(22, 17), - []byte{0x42, 0xec, 0x4c, 0xe5, 0xab, 0x11, 0xce, 0xd6, 0x4e}, - 
"12345.12345678912345678"}, - {iceberg.DecimalTypeOf(23, 18), - []byte{0x02, 0x9d, 0x3b, 0x00, 0xf8, 0xae, 0xb2, 0x14, 0x5f, 0x15}, - "12345.123456789123456789"}, - {iceberg.DecimalTypeOf(24, 19), - []byte{0x1a, 0x24, 0x4e, 0x09, 0xb6, 0xd2, 0xf4, 0xcb, 0xb6, 0xd3}, - "12345.1234567891234567891"}, - {iceberg.DecimalTypeOf(25, 20), - []byte{0x01, 0x05, 0x6b, 0x0c, 0x61, 0x24, 0x3d, 0x8f, 0xf5, 0x24, 0x40}, - "12345.12345678912345678912"}, - {iceberg.DecimalTypeOf(26, 21), - []byte{0x0a, 0x36, 0x2e, 0x7b, 0xcb, 0x6a, 0x67, 0x9f, 0x93, 0x6a, 0x83}, - "12345.123456789123456789123"}, - {iceberg.DecimalTypeOf(27, 22), - []byte{0x66, 0x1d, 0xd0, 0xd5, 0xf2, 0x28, 0x0c, 0x3b, 0xc2, 0x29, 0x22}, - "12345.1234567891234567891234"}, - {iceberg.DecimalTypeOf(28, 23), - []byte{0x03, 0xfd, 0x2a, 0x28, 0x5b, 0x75, 0x90, 0x7a, 0x55, 0x95, 0x9b, 0x59}, - "12345.12345678912345678912345"}, - {iceberg.DecimalTypeOf(29, 24), - []byte{0x27, 0xe3, 0xa5, 0x93, 0x92, 0x97, 0xa4, 0xc7, 0x57, 0xd8, 0x11, 0x80}, - "12345.123456789123456789123456"}, - {iceberg.DecimalTypeOf(30, 25), - []byte{0x01, 0x8e, 0xe4, 0x77, 0xc3, 0xb9, 0xec, 0x6f, 0xc9, 0x6e, 0x70, 0xaf, 0x07}, - "12345.1234567891234567891234567"}, - {iceberg.DecimalTypeOf(31, 26), - []byte{0x0f, 0x94, 0xec, 0xad, 0xa5, 0x43, 0x3c, 0x5d, 0xde, 0x50, 0x66, 0xd6, 0x4e}, - "12345.12345678912345678912345678"}, - } - - for _, tt := range tests { - t.Run(tt.val, func(t *testing.T) { - lit, err := iceberg.LiteralFromBytes(tt.typ, tt.b) - require.NoError(t, err) - - v, err := decimal128.FromString(tt.val, int32(tt.typ.Precision()), int32(tt.typ.Scale())) - require.NoError(t, err) - - assert.True(t, lit.Equals(iceberg.DecimalLiteral{Scale: tt.typ.Scale(), Val: v})) - - data, err := lit.MarshalBinary() - require.NoError(t, err) - - assert.Equal(t, tt.b, data) - }) - } -} - -func TestDecimalMaxMinRoundTrip(t *testing.T) { - tests := []struct { - typ iceberg.DecimalType - v string - }{ - {iceberg.DecimalTypeOf(6, 2), "9999.99"}, - 
{iceberg.DecimalTypeOf(10, 10), ".9999999999"}, - {iceberg.DecimalTypeOf(2, 1), "9.9"}, - {iceberg.DecimalTypeOf(38, 37), "9.9999999999999999999999999999999999999"}, - {iceberg.DecimalTypeOf(20, 1), "9999999999999999999.9"}, - {iceberg.DecimalTypeOf(6, 2), "-9999.99"}, - {iceberg.DecimalTypeOf(10, 10), "-.9999999999"}, - {iceberg.DecimalTypeOf(2, 1), "-9.9"}, - {iceberg.DecimalTypeOf(38, 37), "-9.9999999999999999999999999999999999999"}, - {iceberg.DecimalTypeOf(20, 1), "-9999999999999999999.9"}, - } - - for _, tt := range tests { - t.Run(tt.v, func(t *testing.T) { - v, err := decimal128.FromString(tt.v, int32(tt.typ.Precision()), int32(tt.typ.Scale())) - require.NoError(t, err) - - lit := iceberg.DecimalLiteral{Scale: tt.typ.Scale(), Val: v} - b, err := lit.MarshalBinary() - require.NoError(t, err) - val, err := iceberg.LiteralFromBytes(tt.typ, b) - require.NoError(t, err) - - assert.True(t, val.Equals(lit)) - }) - } -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package iceberg_test + +import ( + "math" + "strconv" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/iceberg-go" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNumericLiteralCompare(t *testing.T) { + smallLit := iceberg.NewLiteral(int32(10)).(iceberg.Int32Literal) + bigLit := iceberg.NewLiteral(int32(1000)).(iceberg.Int32Literal) + + assert.False(t, smallLit.Equals(bigLit)) + assert.True(t, smallLit.Equals(iceberg.NewLiteral(int32(10)))) + + cmp := smallLit.Comparator() + + assert.Equal(t, -1, cmp(smallLit.Value(), bigLit.Value())) + assert.Equal(t, 1, cmp(bigLit.Value(), smallLit.Value())) + assert.True(t, smallLit.Type().Equals(iceberg.PrimitiveTypes.Int32)) +} + +func TestIntConversion(t *testing.T) { + lit := iceberg.NewLiteral(int32(34)) + + t.Run("to int64", func(t *testing.T) { + longLit, err := lit.To(iceberg.PrimitiveTypes.Int64) + assert.NoError(t, err) + assert.IsType(t, iceberg.Int64Literal(0), longLit) + assert.EqualValues(t, 34, longLit) + }) + + t.Run("to float32", func(t *testing.T) { + floatLit, err := lit.To(iceberg.PrimitiveTypes.Float32) + assert.NoError(t, err) + assert.IsType(t, iceberg.Float32Literal(0), floatLit) + assert.EqualValues(t, 34, floatLit) + }) + + t.Run("to float64", func(t *testing.T) { + dblLit, err := lit.To(iceberg.PrimitiveTypes.Float64) + assert.NoError(t, err) + assert.IsType(t, iceberg.Float64Literal(0), dblLit) + assert.EqualValues(t, 34, dblLit) + }) +} + +func TestIntToDecimalConversion(t *testing.T) { + tests := []struct { + ty iceberg.DecimalType + val iceberg.Decimal + }{ + {iceberg.DecimalTypeOf(9, 0), + iceberg.Decimal{Val: decimal128.FromI64(34), Scale: 0}}, + {iceberg.DecimalTypeOf(9, 2), + iceberg.Decimal{Val: decimal128.FromI64(3400), Scale: 2}}, + {iceberg.DecimalTypeOf(9, 4), + iceberg.Decimal{Val: decimal128.FromI64(340000), Scale: 4}}, + } + 
+ for _, tt := range tests { + t.Run(tt.ty.String(), func(t *testing.T) { + lit := iceberg.Int32Literal(34) + + dec, err := lit.To(tt.ty) + require.NoError(t, err) + assert.IsType(t, iceberg.DecimalLiteral(tt.val), dec) + assert.EqualValues(t, tt.val, dec) + }) + } +} + +func TestIntToDateConversion(t *testing.T) { + oneDay, _ := time.Parse("2006-01-02", "2022-03-28") + val := int32(arrow.Date32FromTime(oneDay)) + dateLit, err := iceberg.NewLiteral(val).To(iceberg.PrimitiveTypes.Date) + require.NoError(t, err) + assert.True(t, dateLit.Type().Equals(iceberg.PrimitiveTypes.Date)) + assert.EqualValues(t, val, dateLit) + + lit := iceberg.Int32Literal(34) + tm, err := lit.To(iceberg.PrimitiveTypes.Time) + require.NoError(t, err) + assert.EqualValues(t, lit, tm) + + tm, err = lit.To(iceberg.PrimitiveTypes.Timestamp) + require.NoError(t, err) + assert.EqualValues(t, lit, tm) + + tm, err = lit.To(iceberg.PrimitiveTypes.TimestampTz) + require.NoError(t, err) + assert.EqualValues(t, lit, tm) +} + +func TestInt64Conversions(t *testing.T) { + tests := []struct { + from iceberg.Int64Literal + to iceberg.Literal + }{ + {iceberg.Int64Literal(34), iceberg.NewLiteral(int32(34))}, + {iceberg.Int64Literal(34), iceberg.NewLiteral(float32(34))}, + {iceberg.Int64Literal(34), iceberg.NewLiteral(float64(34))}, + {iceberg.Int64Literal(19709), iceberg.NewLiteral(iceberg.Date(19709))}, + {iceberg.Int64Literal(51661919000), iceberg.NewLiteral(iceberg.Time(51661919000))}, + {iceberg.Int64Literal(1647305201), iceberg.NewLiteral(iceberg.Timestamp(1647305201))}, + {iceberg.Int64Literal(34), + iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(34), Scale: 0})}, + {iceberg.Int64Literal(34), + iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(3400), Scale: 2})}, + {iceberg.Int64Literal(34), + iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(340000), Scale: 4})}, + } + + for _, tt := range tests { + t.Run(tt.to.Type().String(), func(t *testing.T) { + got, err := 
tt.from.To(tt.to.Type()) + require.NoError(t, err) + assert.True(t, tt.to.Equals(got)) + }) + } +} + +func TestInt64ToInt32OutsideBound(t *testing.T) { + bigLit := iceberg.NewLiteral(int64(math.MaxInt32 + 1)) + aboveMax, err := bigLit.To(iceberg.PrimitiveTypes.Int32) + require.NoError(t, err) + assert.Implements(t, (*iceberg.AboveMaxLiteral)(nil), aboveMax) + assert.Equal(t, iceberg.Int32AboveMaxLiteral(), aboveMax) + assert.Equal(t, iceberg.PrimitiveTypes.Int32, aboveMax.Type()) + + smallLit := iceberg.NewLiteral(int64(math.MinInt32 - 1)) + belowMin, err := smallLit.To(iceberg.PrimitiveTypes.Int32) + require.NoError(t, err) + assert.Implements(t, (*iceberg.BelowMinLiteral)(nil), belowMin) + assert.Equal(t, iceberg.Int32BelowMinLiteral(), belowMin) + assert.Equal(t, iceberg.PrimitiveTypes.Int32, belowMin.Type()) +} + +func TestFloatConversions(t *testing.T) { + n1, _ := decimal128.FromFloat32(34.56, 9, 1) + n2, _ := decimal128.FromFloat32(34.56, 9, 2) + n3, _ := decimal128.FromFloat32(34.56, 9, 4) + + tests := []struct { + from iceberg.Float32Literal + to iceberg.Literal + }{ + {iceberg.Float32Literal(34.5), iceberg.NewLiteral(float64(34.5))}, + {iceberg.Float32Literal(34.56), + iceberg.NewLiteral(iceberg.Decimal{Val: n1, Scale: 1})}, + {iceberg.Float32Literal(34.56), + iceberg.NewLiteral(iceberg.Decimal{Val: n2, Scale: 2})}, + {iceberg.Float32Literal(34.56), + iceberg.NewLiteral(iceberg.Decimal{Val: n3, Scale: 4})}, + } + + for _, tt := range tests { + t.Run(tt.to.Type().String(), func(t *testing.T) { + got, err := tt.from.To(tt.to.Type()) + require.NoError(t, err) + assert.Truef(t, tt.to.Equals(got), "expected: %s, got: %s", tt.to, got) + }) + } +} + +func TestFloat64Conversions(t *testing.T) { + n1, _ := decimal128.FromFloat64(34.56, 9, 1) + n2, _ := decimal128.FromFloat64(34.56, 9, 2) + n3, _ := decimal128.FromFloat64(34.56, 9, 4) + + tests := []struct { + from iceberg.Float64Literal + to iceberg.Literal + }{ + {iceberg.Float64Literal(34.5), 
iceberg.NewLiteral(float32(34.5))}, + {iceberg.Float64Literal(34.56), + iceberg.NewLiteral(iceberg.Decimal{Val: n1, Scale: 1})}, + {iceberg.Float64Literal(34.56), + iceberg.NewLiteral(iceberg.Decimal{Val: n2, Scale: 2})}, + {iceberg.Float64Literal(34.56), + iceberg.NewLiteral(iceberg.Decimal{Val: n3, Scale: 4})}, + } + + for _, tt := range tests { + t.Run(tt.to.Type().String(), func(t *testing.T) { + got, err := tt.from.To(tt.to.Type()) + require.NoError(t, err) + assert.Truef(t, tt.to.Equals(got), "expected: %s, got: %s", tt.to, got) + }) + } +} + +func TestFloat64toFloat32OutsideBounds(t *testing.T) { + bigLit := iceberg.NewLiteral(float64(math.MaxFloat32 + 1.0e37)) + aboveMax, err := bigLit.To(iceberg.PrimitiveTypes.Float32) + require.NoError(t, err) + assert.Equal(t, iceberg.Float32AboveMaxLiteral(), aboveMax) + + smallLit := iceberg.NewLiteral(float64(-math.MaxFloat32 - 1.0e37)) + belowMin, err := smallLit.To(iceberg.PrimitiveTypes.Float32) + require.NoError(t, err) + assert.Equal(t, iceberg.Float32BelowMinLiteral(), belowMin) +} + +func TestDecimalToDecimalConversion(t *testing.T) { + lit := iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(3411), Scale: 2}) + + v, err := lit.To(iceberg.DecimalTypeOf(9, 2)) + require.NoError(t, err) + assert.Equal(t, lit, v) + + v, err = lit.To(iceberg.DecimalTypeOf(11, 2)) + require.NoError(t, err) + assert.Equal(t, lit, v) + + _, err = lit.To(iceberg.DecimalTypeOf(9, 0)) + assert.ErrorIs(t, err, iceberg.ErrBadCast) + assert.ErrorContains(t, err, "could not convert 34.11 to decimal(9, 0)") + + _, err = lit.To(iceberg.DecimalTypeOf(9, 1)) + assert.ErrorIs(t, err, iceberg.ErrBadCast) + assert.ErrorContains(t, err, "could not convert 34.11 to decimal(9, 1)") + + _, err = lit.To(iceberg.DecimalTypeOf(9, 3)) + assert.ErrorIs(t, err, iceberg.ErrBadCast) + assert.ErrorContains(t, err, "could not convert 34.11 to decimal(9, 3)") +} + +func TestDecimalLiteralConversions(t *testing.T) { + n1 := iceberg.Decimal{Val: 
decimal128.FromI64(1234), Scale: 2} + n2 := iceberg.Decimal{Val: decimal128.FromI64(math.MaxInt32 + 1), Scale: 0} + n3 := iceberg.Decimal{Val: decimal128.FromI64(math.MinInt32 - 1), Scale: 10} + + tests := []struct { + from iceberg.DecimalLiteral + to iceberg.Literal + }{ + {iceberg.DecimalLiteral(n1), iceberg.NewLiteral(int32(1234))}, + {iceberg.DecimalLiteral(n1), iceberg.NewLiteral(int64(1234))}, + {iceberg.DecimalLiteral(n2), iceberg.NewLiteral(int64(math.MaxInt32 + 1))}, + {iceberg.DecimalLiteral(n1), iceberg.NewLiteral(float32(12.34))}, + {iceberg.DecimalLiteral(n1), iceberg.NewLiteral(float64(12.34))}, + {iceberg.DecimalLiteral(n3), iceberg.NewLiteral(int64(math.MinInt32 - 1))}, + } + + for _, tt := range tests { + t.Run(tt.to.Type().String(), func(t *testing.T) { + got, err := tt.from.To(tt.to.Type()) + require.NoError(t, err) + assert.Truef(t, tt.to.Equals(got), "expected: %s, got: %s", tt.to, got) + }) + } + + above, err := iceberg.DecimalLiteral(n2).To(iceberg.PrimitiveTypes.Int32) + require.NoError(t, err) + assert.Equal(t, iceberg.Int32AboveMaxLiteral(), above) + assert.Equal(t, iceberg.PrimitiveTypes.Int32, above.Type()) + + below, err := iceberg.DecimalLiteral(n3).To(iceberg.PrimitiveTypes.Int32) + require.NoError(t, err) + assert.Equal(t, iceberg.Int32BelowMinLiteral(), below) + assert.Equal(t, iceberg.PrimitiveTypes.Int32, below.Type()) + + n4 := iceberg.Decimal{Val: decimal128.FromU64(math.MaxInt64 + 1), Scale: 0} + n5 := iceberg.Decimal{Val: decimal128.FromU64(math.MaxUint64).Negate(), Scale: 20} + + above, err = iceberg.DecimalLiteral(n4).To(iceberg.PrimitiveTypes.Int64) + require.NoError(t, err) + assert.Equal(t, iceberg.Int64AboveMaxLiteral(), above) + assert.Equal(t, iceberg.PrimitiveTypes.Int64, above.Type()) + + below, err = iceberg.DecimalLiteral(n5).To(iceberg.PrimitiveTypes.Int64) + require.NoError(t, err) + assert.Equal(t, iceberg.Int64BelowMinLiteral(), below) + assert.Equal(t, iceberg.PrimitiveTypes.Int64, below.Type()) + + v, err := 
decimal128.FromFloat64(math.MaxFloat32+1e37, 38, -1) + require.NoError(t, err) + above, err = iceberg.DecimalLiteral(iceberg.Decimal{Val: v, Scale: -1}). + To(iceberg.PrimitiveTypes.Float32) + require.NoError(t, err) + assert.Equal(t, iceberg.Float32AboveMaxLiteral(), above) + assert.Equal(t, iceberg.PrimitiveTypes.Float32, above.Type()) + + below, err = iceberg.DecimalLiteral(iceberg.Decimal{Val: v.Negate(), Scale: -1}). + To(iceberg.PrimitiveTypes.Float32) + require.NoError(t, err) + assert.Equal(t, iceberg.Float32BelowMinLiteral(), below) + assert.Equal(t, iceberg.PrimitiveTypes.Float32, below.Type()) +} + +func TestLiteralTimestampToDate(t *testing.T) { + v, _ := arrow.TimestampFromString("1970-01-01T00:00:00.000000+00:00", arrow.Microsecond) + tsLit := iceberg.NewLiteral(iceberg.Timestamp(v)) + dateLit, err := tsLit.To(iceberg.PrimitiveTypes.Date) + require.NoError(t, err) + assert.Zero(t, dateLit) +} + +func TestStringLiterals(t *testing.T) { + sqrt2 := iceberg.NewLiteral("1.414") + pi := iceberg.NewLiteral("3.141") + piStr := iceberg.StringLiteral("3.141") + piDbl := iceberg.NewLiteral(float64(3.141)) + + v, err := pi.To(iceberg.PrimitiveTypes.Float64) + require.NoError(t, err) + assert.Equal(t, piDbl, v) + + assert.False(t, sqrt2.Equals(pi)) + assert.True(t, pi.Equals(piStr)) + assert.False(t, pi.Equals(piDbl)) + assert.Equal(t, "3.141", pi.String()) + + cmp := piStr.Comparator() + assert.Equal(t, -1, cmp(sqrt2.(iceberg.StringLiteral).Value(), piStr.Value())) + assert.Equal(t, 1, cmp(piStr.Value(), sqrt2.(iceberg.StringLiteral).Value())) + + v, err = pi.To(iceberg.PrimitiveTypes.String) + require.NoError(t, err) + assert.Equal(t, pi, v) +} + +func TestStringLiteralConversion(t *testing.T) { + tm, _ := time.Parse("2006-01-02", "2017-08-18") + expected := uuid.New() + + tests := []struct { + from iceberg.StringLiteral + to iceberg.Literal + }{ + {iceberg.StringLiteral("2017-08-18"), + iceberg.NewLiteral(iceberg.Date(arrow.Date32FromTime(tm)))}, + 
{iceberg.StringLiteral("14:21:01.919"), + iceberg.NewLiteral(iceberg.Time(51661919000))}, + {iceberg.StringLiteral("2017-08-18T14:21:01.919234"), + iceberg.NewLiteral(iceberg.Timestamp(1503066061919234))}, + {iceberg.StringLiteral(expected.String()), iceberg.NewLiteral(expected)}, + {iceberg.StringLiteral("34.560"), + iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(34560), Scale: 3})}, + {iceberg.StringLiteral("true"), iceberg.NewLiteral(true)}, + {iceberg.StringLiteral("True"), iceberg.NewLiteral(true)}, + {iceberg.StringLiteral("false"), iceberg.NewLiteral(false)}, + {iceberg.StringLiteral("False"), iceberg.NewLiteral(false)}, + {iceberg.StringLiteral("12345"), iceberg.NewLiteral(int32(12345))}, + {iceberg.StringLiteral("12345123456"), iceberg.NewLiteral(int64(12345123456))}, + {iceberg.StringLiteral("3.14"), iceberg.NewLiteral(float32(3.14))}, + } + + for _, tt := range tests { + t.Run(tt.to.Type().String(), func(t *testing.T) { + got, err := tt.from.To(tt.to.Type()) + require.NoError(t, err) + assert.Truef(t, tt.to.Equals(got), "expected: %s, got: %s", tt.to, got) + }) + } + + lit := iceberg.StringLiteral("2017-08-18T14:21:01.919234-07:00") + casted, err := lit.To(iceberg.PrimitiveTypes.TimestampTz) + require.NoError(t, err) + expectedTimestamp := iceberg.NewLiteral(iceberg.Timestamp(1503091261919234)) + assert.Truef(t, casted.Equals(expectedTimestamp), "expected: %s, got: %s", + expectedTimestamp, casted) + + _, err = lit.To(iceberg.PrimitiveTypes.Timestamp) + require.Error(t, err) + assert.ErrorIs(t, err, iceberg.ErrBadCast) + assert.ErrorContains(t, err, `parsing time "2017-08-18T14:21:01.919234-07:00": extra text: "-07:00"`) + assert.ErrorContains(t, err, "invalid Timestamp format for casting from string") + + _, err = iceberg.StringLiteral("2017-08-18T14:21:01.919234").To(iceberg.PrimitiveTypes.TimestampTz) + require.Error(t, err) + assert.ErrorIs(t, err, iceberg.ErrBadCast) + assert.ErrorContains(t, err, `cannot parse "" as "Z07:00"`) +} + 
+func TestLiteralIdentityConversions(t *testing.T) { + fixedLit, _ := iceberg.NewLiteral([]byte{0x01, 0x02, 0x03}).To(iceberg.FixedTypeOf(3)) + + tests := []struct { + lit iceberg.Literal + typ iceberg.PrimitiveType + }{ + {iceberg.NewLiteral(true), iceberg.PrimitiveTypes.Bool}, + {iceberg.NewLiteral(int32(34)), iceberg.PrimitiveTypes.Int32}, + {iceberg.NewLiteral(int64(340000000)), iceberg.PrimitiveTypes.Int64}, + {iceberg.NewLiteral(float32(34.11)), iceberg.PrimitiveTypes.Float32}, + {iceberg.NewLiteral(float64(3.5028235e38)), iceberg.PrimitiveTypes.Float64}, + {iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(3455), Scale: 2}), + iceberg.DecimalTypeOf(9, 2)}, + {iceberg.NewLiteral(iceberg.Date(19079)), iceberg.PrimitiveTypes.Date}, + {iceberg.NewLiteral(iceberg.Timestamp(1503091261919234)), + iceberg.PrimitiveTypes.Timestamp}, + {iceberg.NewLiteral("abc"), iceberg.PrimitiveTypes.String}, + {iceberg.NewLiteral(uuid.New()), iceberg.PrimitiveTypes.UUID}, + {fixedLit, iceberg.FixedTypeOf(3)}, + {iceberg.NewLiteral([]byte{0x01, 0x02, 0x03}), iceberg.PrimitiveTypes.Binary}, + } + + for _, tt := range tests { + t.Run(tt.typ.String(), func(t *testing.T) { + expected, err := tt.lit.To(tt.typ) + require.NoError(t, err) + assert.Equal(t, expected, tt.lit) + }) + } +} + +func TestFixedLiteral(t *testing.T) { + fixedLit012 := iceberg.FixedLiteral{0x00, 0x01, 0x02} + fixedLit013 := iceberg.FixedLiteral{0x00, 0x01, 0x03} + assert.True(t, fixedLit012.Equals(fixedLit012)) + assert.False(t, fixedLit012.Equals(fixedLit013)) + + cmp := fixedLit012.Comparator() + assert.Equal(t, -1, cmp(fixedLit012, fixedLit013)) + assert.Equal(t, 1, cmp(fixedLit013, fixedLit012)) + assert.Equal(t, 0, cmp(fixedLit013, fixedLit013)) + + testUuid := uuid.New() + lit, err := iceberg.NewLiteral(testUuid[:]).To(iceberg.FixedTypeOf(16)) + require.NoError(t, err) + uuidLit, err := lit.To(iceberg.PrimitiveTypes.UUID) + require.NoError(t, err) + + assert.EqualValues(t, uuidLit, testUuid) + + 
fixedUuid, err := uuidLit.To(iceberg.FixedTypeOf(16)) + require.NoError(t, err) + assert.EqualValues(t, testUuid[:], fixedUuid) + + binUuid, err := uuidLit.To(iceberg.PrimitiveTypes.Binary) + require.NoError(t, err) + assert.EqualValues(t, testUuid[:], binUuid) + + binlit, err := fixedLit012.To(iceberg.PrimitiveTypes.Binary) + require.NoError(t, err) + assert.EqualValues(t, fixedLit012, binlit) +} + +func TestBinaryLiteral(t *testing.T) { + binLit012 := iceberg.NewLiteral([]byte{0x00, 0x01, 0x02}).(iceberg.BinaryLiteral) + binLit013 := iceberg.NewLiteral([]byte{0x00, 0x01, 0x03}).(iceberg.BinaryLiteral) + assert.True(t, binLit012.Equals(binLit012)) + assert.False(t, binLit012.Equals(binLit013)) + + cmp := binLit012.Comparator() + assert.Equal(t, -1, cmp(binLit012, binLit013)) + assert.Equal(t, 1, cmp(binLit013, binLit012)) + assert.Equal(t, 0, cmp(binLit013, binLit013)) +} + +func TestBinaryLiteralConversions(t *testing.T) { + binLit012 := iceberg.NewLiteral([]byte{0x00, 0x01, 0x02}) + fixed, err := binLit012.To(iceberg.FixedTypeOf(3)) + require.NoError(t, err) + assert.Equal(t, iceberg.FixedLiteral{0x00, 0x01, 0x02}, fixed) + + _, err = binLit012.To(iceberg.FixedTypeOf(4)) + assert.ErrorIs(t, err, iceberg.ErrBadCast) + assert.ErrorContains(t, err, "cannot convert BinaryLiteral to fixed[4], different length - 3 <> 4") + + _, err = binLit012.To(iceberg.FixedTypeOf(2)) + assert.ErrorIs(t, err, iceberg.ErrBadCast) + assert.ErrorContains(t, err, "cannot convert BinaryLiteral to fixed[2], different length - 3 <> 2") + + testUuid := uuid.New() + lit := iceberg.NewLiteral(testUuid[:]) + uuidLit, err := lit.To(iceberg.PrimitiveTypes.UUID) + require.NoError(t, err) + assert.EqualValues(t, testUuid, uuidLit) + + _, err = binLit012.To(iceberg.PrimitiveTypes.UUID) + assert.ErrorIs(t, err, iceberg.ErrBadCast) + assert.ErrorContains(t, err, "cannot convert BinaryLiteral to UUID") +} + +func testInvalidLiteralConversions(t *testing.T, lit iceberg.Literal, typs []iceberg.Type) { + 
t.Run(lit.Type().String(), func(t *testing.T) { + for _, tt := range typs { + t.Run(tt.String(), func(t *testing.T) { + _, err := lit.To(tt) + assert.ErrorIs(t, err, iceberg.ErrBadCast) + }) + } + }) +} + +func TestInvalidBoolLiteralConversions(t *testing.T) { + testInvalidLiteralConversions(t, iceberg.NewLiteral(true), []iceberg.Type{ + iceberg.PrimitiveTypes.Int32, + iceberg.PrimitiveTypes.Int64, + iceberg.PrimitiveTypes.Float32, + iceberg.PrimitiveTypes.Float64, + iceberg.PrimitiveTypes.Date, + iceberg.PrimitiveTypes.Time, + iceberg.PrimitiveTypes.Timestamp, + iceberg.PrimitiveTypes.TimestampTz, + iceberg.DecimalTypeOf(9, 2), + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.UUID, + iceberg.PrimitiveTypes.Binary, + iceberg.FixedTypeOf(2), + }) +} + +func TestInvalidNumericConversions(t *testing.T) { + testInvalidLiteralConversions(t, iceberg.NewLiteral(int32(34)), []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.UUID, + iceberg.FixedTypeOf(1), + iceberg.PrimitiveTypes.Binary, + }) + + testInvalidLiteralConversions(t, iceberg.NewLiteral(int64(34)), []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.UUID, + iceberg.FixedTypeOf(1), + iceberg.PrimitiveTypes.Binary, + }) + + testInvalidLiteralConversions(t, iceberg.NewLiteral(float32(34)), []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, + iceberg.PrimitiveTypes.Int32, + iceberg.PrimitiveTypes.Int64, + iceberg.PrimitiveTypes.Date, + iceberg.PrimitiveTypes.Time, + iceberg.PrimitiveTypes.Timestamp, + iceberg.PrimitiveTypes.TimestampTz, + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.UUID, + iceberg.FixedTypeOf(1), + iceberg.PrimitiveTypes.Binary, + }) + + testInvalidLiteralConversions(t, iceberg.NewLiteral(float64(34)), []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, + iceberg.PrimitiveTypes.Int32, + iceberg.PrimitiveTypes.Int64, + iceberg.PrimitiveTypes.Date, + iceberg.PrimitiveTypes.Time, + 
iceberg.PrimitiveTypes.Timestamp, + iceberg.PrimitiveTypes.TimestampTz, + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.UUID, + iceberg.FixedTypeOf(1), + iceberg.PrimitiveTypes.Binary, + }) + + testInvalidLiteralConversions(t, iceberg.NewLiteral(iceberg.Decimal{Val: decimal128.FromI64(3411), Scale: 2}), + []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, + iceberg.PrimitiveTypes.Date, + iceberg.PrimitiveTypes.Time, + iceberg.PrimitiveTypes.Timestamp, + iceberg.PrimitiveTypes.TimestampTz, + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.UUID, + iceberg.FixedTypeOf(1), + iceberg.PrimitiveTypes.Binary, + }) +} + +func TestInvalidDateTimeLiteralConversions(t *testing.T) { + lit, _ := iceberg.NewLiteral("2017-08-18").To(iceberg.PrimitiveTypes.Date) + testInvalidLiteralConversions(t, lit, []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, + iceberg.PrimitiveTypes.Int32, + iceberg.PrimitiveTypes.Int64, + iceberg.PrimitiveTypes.Float32, + iceberg.PrimitiveTypes.Float64, + iceberg.PrimitiveTypes.Time, + iceberg.PrimitiveTypes.Timestamp, + iceberg.PrimitiveTypes.TimestampTz, + iceberg.DecimalTypeOf(9, 2), + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.UUID, + iceberg.FixedTypeOf(1), + iceberg.PrimitiveTypes.Binary, + }) + + lit, _ = iceberg.NewLiteral("14:21:01.919").To(iceberg.PrimitiveTypes.Time) + testInvalidLiteralConversions(t, lit, []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, + iceberg.PrimitiveTypes.Int32, + iceberg.PrimitiveTypes.Int64, + iceberg.PrimitiveTypes.Float32, + iceberg.PrimitiveTypes.Float64, + iceberg.PrimitiveTypes.Date, + iceberg.PrimitiveTypes.Timestamp, + iceberg.PrimitiveTypes.TimestampTz, + iceberg.DecimalTypeOf(9, 2), + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.UUID, + iceberg.FixedTypeOf(1), + iceberg.PrimitiveTypes.Binary, + }) + + lit, _ = iceberg.NewLiteral("2017-08-18T14:21:01.919").To(iceberg.PrimitiveTypes.Timestamp) + testInvalidLiteralConversions(t, lit, []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, 
+ iceberg.PrimitiveTypes.Int32, + iceberg.PrimitiveTypes.Int64, + iceberg.PrimitiveTypes.Float32, + iceberg.PrimitiveTypes.Float64, + iceberg.PrimitiveTypes.Time, + iceberg.DecimalTypeOf(9, 2), + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.UUID, + iceberg.FixedTypeOf(1), + iceberg.PrimitiveTypes.Binary, + }) +} + +func TestInvalidStringLiteralConversions(t *testing.T) { + testInvalidLiteralConversions(t, iceberg.NewLiteral("abc"), []iceberg.Type{ + iceberg.FixedTypeOf(1), + }) +} + +func TestInvalidBinaryLiteralConversions(t *testing.T) { + testInvalidLiteralConversions(t, iceberg.NewLiteral(uuid.New()), []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, + iceberg.PrimitiveTypes.Int32, + iceberg.PrimitiveTypes.Int64, + iceberg.PrimitiveTypes.Float32, + iceberg.PrimitiveTypes.Float64, + iceberg.PrimitiveTypes.Date, + iceberg.PrimitiveTypes.Time, + iceberg.PrimitiveTypes.Timestamp, + iceberg.PrimitiveTypes.TimestampTz, + iceberg.DecimalTypeOf(9, 2), + iceberg.PrimitiveTypes.String, + iceberg.FixedTypeOf(1), + }) + + lit, _ := iceberg.NewLiteral([]byte{0x00, 0x01, 0x02}).To(iceberg.FixedTypeOf(3)) + testInvalidLiteralConversions(t, lit, []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, + iceberg.PrimitiveTypes.Int32, + iceberg.PrimitiveTypes.Int64, + iceberg.PrimitiveTypes.Float32, + iceberg.PrimitiveTypes.Float64, + iceberg.PrimitiveTypes.Date, + iceberg.PrimitiveTypes.Time, + iceberg.PrimitiveTypes.Timestamp, + iceberg.PrimitiveTypes.TimestampTz, + iceberg.DecimalTypeOf(9, 2), + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.UUID, + }) + + testInvalidLiteralConversions(t, iceberg.NewLiteral([]byte{0x00, 0x01, 0x02}), []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, + iceberg.PrimitiveTypes.Int32, + iceberg.PrimitiveTypes.Int64, + iceberg.PrimitiveTypes.Float32, + iceberg.PrimitiveTypes.Float64, + iceberg.PrimitiveTypes.Date, + iceberg.PrimitiveTypes.Time, + iceberg.PrimitiveTypes.Timestamp, + iceberg.PrimitiveTypes.TimestampTz, + iceberg.DecimalTypeOf(9, 
2), + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.UUID, + }) +} + +func TestBadStringLiteralCasts(t *testing.T) { + tests := []iceberg.Type{ + iceberg.PrimitiveTypes.Int32, + iceberg.PrimitiveTypes.Int64, + iceberg.PrimitiveTypes.Float32, + iceberg.PrimitiveTypes.Float64, + iceberg.PrimitiveTypes.Date, + iceberg.PrimitiveTypes.Time, + iceberg.PrimitiveTypes.Timestamp, + iceberg.PrimitiveTypes.TimestampTz, + iceberg.PrimitiveTypes.Bool, + iceberg.DecimalTypeOf(9, 2), + iceberg.PrimitiveTypes.UUID, + } + + for _, tt := range tests { + t.Run(tt.String(), func(t *testing.T) { + _, err := iceberg.NewLiteral("abc").To(tt) + assert.ErrorIs(t, err, iceberg.ErrBadCast) + }) + } +} + +func TestStringLiteralToIntMaxMinValue(t *testing.T) { + above, err := iceberg.NewLiteral(strconv.FormatInt(math.MaxInt32+1, 10)). + To(iceberg.PrimitiveTypes.Int32) + require.NoError(t, err) + assert.Equal(t, iceberg.Int32AboveMaxLiteral(), above) + + below, err := iceberg.NewLiteral(strconv.FormatInt(math.MinInt32-1, 10)). 
+ To(iceberg.PrimitiveTypes.Int32) + require.NoError(t, err) + assert.Equal(t, iceberg.Int32BelowMinLiteral(), below) +} + +func TestUnmarshalBinary(t *testing.T) { + tests := []struct { + typ iceberg.Type + data []byte + result iceberg.Literal + }{ + {iceberg.PrimitiveTypes.Bool, []byte{0x0}, iceberg.BoolLiteral(false)}, + {iceberg.PrimitiveTypes.Bool, []byte{0x1}, iceberg.BoolLiteral(true)}, + {iceberg.PrimitiveTypes.Int32, []byte{0xd2, 0x04, 0x00, 0x00}, iceberg.Int32Literal(1234)}, + {iceberg.PrimitiveTypes.Int64, []byte{0xd2, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + iceberg.Int64Literal(1234)}, + {iceberg.PrimitiveTypes.Float32, []byte{0x00, 0x00, 0x90, 0xc0}, iceberg.Float32Literal(-4.5)}, + {iceberg.PrimitiveTypes.Float64, []byte{0x8d, 0x97, 0x6e, 0x12, 0x83, 0xc0, 0xf3, 0x3f}, + iceberg.Float64Literal(1.2345)}, + {iceberg.PrimitiveTypes.Date, []byte{0xe8, 0x03, 0x00, 0x00}, iceberg.DateLiteral(1000)}, + {iceberg.PrimitiveTypes.Date, []byte{0xd2, 0x04, 0x00, 0x00}, iceberg.DateLiteral(1234)}, + {iceberg.PrimitiveTypes.Time, []byte{0x10, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + iceberg.TimeLiteral(10000)}, + {iceberg.PrimitiveTypes.Time, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, + iceberg.TimeLiteral(100000000000)}, + {iceberg.PrimitiveTypes.TimestampTz, []byte{0x80, 0x1a, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00}, + iceberg.TimestampLiteral(400000)}, + {iceberg.PrimitiveTypes.TimestampTz, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, + iceberg.TimestampLiteral(100000000000)}, + {iceberg.PrimitiveTypes.Timestamp, []byte{0x80, 0x1a, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00}, + iceberg.TimestampLiteral(400000)}, + {iceberg.PrimitiveTypes.Timestamp, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, + iceberg.TimestampLiteral(100000000000)}, + {iceberg.PrimitiveTypes.String, []byte("ABC"), iceberg.StringLiteral("ABC")}, + {iceberg.PrimitiveTypes.String, []byte("foo"), iceberg.StringLiteral("foo")}, + {iceberg.PrimitiveTypes.UUID, + 
[]byte{0xf7, 0x9c, 0x3e, 0x09, 0x67, 0x7c, 0x4b, 0xbd, 0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7}, + iceberg.UUIDLiteral(uuid.UUID{0xf7, 0x9c, 0x3e, 0x09, 0x67, 0x7c, 0x4b, 0xbd, 0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7})}, + {iceberg.FixedTypeOf(3), []byte("foo"), iceberg.FixedLiteral([]byte("foo"))}, + {iceberg.PrimitiveTypes.Binary, []byte("foo"), iceberg.BinaryLiteral([]byte("foo"))}, + {iceberg.DecimalTypeOf(5, 2), []byte{0x30, 0x39}, + iceberg.DecimalLiteral{Scale: 2, Val: decimal128.FromU64(12345)}}, + {iceberg.DecimalTypeOf(7, 4), []byte{0x12, 0xd6, 0x87}, + iceberg.DecimalLiteral{Scale: 4, Val: decimal128.FromU64(1234567)}}, + {iceberg.DecimalTypeOf(7, 4), []byte{0xff, 0xed, 0x29, 0x79}, + iceberg.DecimalLiteral{Scale: 4, Val: decimal128.FromI64(-1234567)}}, + } + + for _, tt := range tests { + t.Run(tt.typ.String(), func(t *testing.T) { + lit, err := iceberg.LiteralFromBytes(tt.typ, tt.data) + require.NoError(t, err) + + assert.Truef(t, tt.result.Equals(lit), "expected: %s, got: %s", tt.result, lit) + }) + } +} + +func TestRoundTripLiteralBinary(t *testing.T) { + tests := []struct { + typ iceberg.Type + b []byte + result iceberg.Literal + }{ + {iceberg.PrimitiveTypes.Bool, []byte{0x0}, iceberg.BoolLiteral(false)}, + {iceberg.PrimitiveTypes.Bool, []byte{0x1}, iceberg.BoolLiteral(true)}, + {iceberg.PrimitiveTypes.Int32, []byte{0xd2, 0x04, 0x00, 0x00}, iceberg.Int32Literal(1234)}, + {iceberg.PrimitiveTypes.Int64, []byte{0xd2, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + iceberg.Int64Literal(1234)}, + {iceberg.PrimitiveTypes.Float32, []byte{0x00, 0x00, 0x90, 0xc0}, iceberg.Float32Literal(-4.5)}, + {iceberg.PrimitiveTypes.Float32, []byte{0x19, 0x04, 0x9e, 0x3f}, iceberg.Float32Literal(1.2345)}, + {iceberg.PrimitiveTypes.Float64, []byte{0x8d, 0x97, 0x6e, 0x12, 0x83, 0xc0, 0xf3, 0x3f}, + iceberg.Float64Literal(1.2345)}, + {iceberg.PrimitiveTypes.Date, []byte{0xe8, 0x03, 0x00, 0x00}, iceberg.DateLiteral(1000)}, + {iceberg.PrimitiveTypes.Date, 
[]byte{0xd2, 0x04, 0x00, 0x00}, iceberg.DateLiteral(1234)}, + {iceberg.PrimitiveTypes.Time, []byte{0x10, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + iceberg.TimeLiteral(10000)}, + {iceberg.PrimitiveTypes.Time, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, + iceberg.TimeLiteral(100000000000)}, + {iceberg.PrimitiveTypes.TimestampTz, []byte{0x80, 0x1a, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00}, + iceberg.TimestampLiteral(400000)}, + {iceberg.PrimitiveTypes.TimestampTz, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, + iceberg.TimestampLiteral(100000000000)}, + {iceberg.PrimitiveTypes.Timestamp, []byte{0x80, 0x1a, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00}, + iceberg.TimestampLiteral(400000)}, + {iceberg.PrimitiveTypes.Timestamp, []byte{0x00, 0xe8, 0x76, 0x48, 0x17, 0x00, 0x00, 0x00}, + iceberg.TimestampLiteral(100000000000)}, + {iceberg.PrimitiveTypes.String, []byte("ABC"), iceberg.StringLiteral("ABC")}, + {iceberg.PrimitiveTypes.String, []byte("foo"), iceberg.StringLiteral("foo")}, + {iceberg.PrimitiveTypes.UUID, + []byte{0xf7, 0x9c, 0x3e, 0x09, 0x67, 0x7c, 0x4b, 0xbd, 0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7}, + iceberg.UUIDLiteral(uuid.UUID{0xf7, 0x9c, 0x3e, 0x09, 0x67, 0x7c, 0x4b, 0xbd, 0xa4, 0x79, 0x3f, 0x34, 0x9c, 0xb7, 0x85, 0xe7})}, + {iceberg.FixedTypeOf(3), []byte("foo"), iceberg.FixedLiteral([]byte("foo"))}, + {iceberg.PrimitiveTypes.Binary, []byte("foo"), iceberg.BinaryLiteral([]byte("foo"))}, + {iceberg.DecimalTypeOf(5, 2), []byte{0x30, 0x39}, + iceberg.DecimalLiteral{Scale: 2, Val: decimal128.FromU64(12345)}}, + // decimal on 3-bytes to test that we use the minimum number of bytes and not a power of 2 + // 1234567 is 00010010|11010110|10000111 in binary + // 00010010 -> 18, 11010110 -> 214, 10000111 -> 135 + {iceberg.DecimalTypeOf(7, 4), []byte{0x12, 0xd6, 0x87}, + iceberg.DecimalLiteral{Scale: 4, Val: decimal128.FromU64(1234567)}}, + // negative decimal to test two's complement + // -1234567 is 11101101|00101001|01111001 in binary + // 
11101101 -> 237, 00101001 -> 41, 01111001 -> 121 + {iceberg.DecimalTypeOf(7, 4), []byte{0xed, 0x29, 0x79}, + iceberg.DecimalLiteral{Scale: 4, Val: decimal128.FromI64(-1234567)}}, + // test empty byte in decimal + // 11 is 00001011 in binary + // 00001011 -> 11 + {iceberg.DecimalTypeOf(10, 3), []byte{0x0b}, iceberg.DecimalLiteral{Scale: 3, Val: decimal128.FromU64(11)}}, + {iceberg.DecimalTypeOf(4, 2), []byte{0x04, 0xd2}, iceberg.DecimalLiteral{Scale: 2, Val: decimal128.FromU64(1234)}}, + } + + for _, tt := range tests { + t.Run(tt.result.String(), func(t *testing.T) { + lit, err := iceberg.LiteralFromBytes(tt.typ, tt.b) + require.NoError(t, err) + + assert.True(t, lit.Equals(tt.result)) + + data, err := lit.MarshalBinary() + require.NoError(t, err) + + assert.Equal(t, tt.b, data) + }) + } +} + +func TestLargeDecimalRoundTrip(t *testing.T) { + tests := []struct { + typ iceberg.DecimalType + b []byte + val string + }{ + {iceberg.DecimalTypeOf(38, 21), + []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x18, 0x30, 0x73, 0xb9, 0x1e, + 0x7e, 0xa2, 0xb3, 0x6a, 0x83}, + "12345678912345678.123456789123456789123"}, + {iceberg.DecimalTypeOf(38, 22), + []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x16, 0xbb, 0x01, 0x2f, + 0x4c, 0xc3, 0x2b, 0x42, 0x29, 0x22}, + "1234567891234567.1234567891234567891234"}, + {iceberg.DecimalTypeOf(38, 23), + []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x0a, 0x42, 0xa1, 0xad, + 0xe5, 0x2b, 0x33, 0x15, 0x9b, 0x59}, + "123456789123456.12345678912345678912345"}, + {iceberg.DecimalTypeOf(38, 24), + []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe8, 0xa2, 0xbb, 0xe9, 0x67, + 0xba, 0x86, 0x77, 0xd8, 0x11, 0x80}, + "12345678912345.123456789123456789123456"}, + {iceberg.DecimalTypeOf(38, 25), + []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe5, 0x6b, 0x3a, 0xd2, 0x78, + 0xdd, 0x04, 0xc8, 0x70, 0xaf, 0x07}, + "1234567891234.1234567891234567891234567"}, + {iceberg.DecimalTypeOf(38, 26), + []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xcd, 0x85, 0xc5, 0x03, 0x38, 0x37, + 0x3c, 0x38, 0x66, 
0xd6, 0x4e}, + "123456789123.12345678912345678912345678"}, + {iceberg.DecimalTypeOf(38, 27), + []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0x31, 0x46, 0xfd, 0xc7, 0x79, + 0xca, 0x39, 0x7c, 0x04, 0x5f, 0x15}, + "12345678912.123456789123456789123456789"}, + {iceberg.DecimalTypeOf(38, 28), + []byte{0x09, 0x49, 0xb0, 0xf7, 0x10, 0x52, 0x01, 0x72, 0x11, 0xda, + 0x08, 0x5b, 0x08, 0x2b, 0xb6, 0xd3}, + "1234567891.1234567891234567891234567891"}, + {iceberg.DecimalTypeOf(38, 29), + []byte{0x09, 0x49, 0xb0, 0xf7, 0x13, 0xe9, 0x18, 0x5b, 0x37, 0xc1, + 0x78, 0x0b, 0x91, 0xb5, 0x24, 0x40}, + "123456789.12345678912345678912345678912"}, + {iceberg.DecimalTypeOf(38, 30), + []byte{0x09, 0x49, 0xb0, 0xed, 0x1e, 0xdf, 0x80, 0x03, 0x47, 0x3b, + 0x16, 0x9b, 0xf1, 0x13, 0x6a, 0x83}, + "12345678.123456789123456789123456789123"}, + {iceberg.DecimalTypeOf(38, 31), + []byte{0x09, 0x49, 0xb0, 0x96, 0x2b, 0xac, 0x29, 0x64, 0x28, 0x70, + 0x36, 0x29, 0xea, 0xc2, 0x29, 0x22}, + "1234567.1234567891234567891234567891234"}, + {iceberg.DecimalTypeOf(38, 32), + []byte{0x09, 0x49, 0xad, 0xae, 0xe3, 0x68, 0xe7, 0x4f, 0xb5, 0x14, + 0xbc, 0xdc, 0x2b, 0x95, 0x9b, 0x59}, + "123456.12345678912345678912345678912345"}, + {iceberg.DecimalTypeOf(38, 33), + []byte{0x09, 0x49, 0x95, 0x94, 0x3e, 0x35, 0x93, 0xde, 0xb9, 0x2e, + 0xef, 0x53, 0xb3, 0xd8, 0x11, 0x80}, + "12345.123456789123456789123456789123456"}, + {iceberg.DecimalTypeOf(38, 34), + []byte{0x09, 0x48, 0xd5, 0xd7, 0x90, 0x78, 0xdf, 0x08, 0x1a, 0xf6, + 0x43, 0x09, 0x06, 0x70, 0xaf, 0x07}, + "1234.1234567891234567891234567891234567"}, + {iceberg.DecimalTypeOf(38, 35), + []byte{0x09, 0x43, 0x45, 0x82, 0x85, 0xc7, 0x56, 0x66, 0x24, 0x4d, + 0x16, 0x82, 0x40, 0x66, 0xd6, 0x4e}, + "123.12345678912345678912345678912345678"}, + {iceberg.DecimalTypeOf(21, 16), + []byte{0x06, 0xb1, 0x3a, 0xe3, 0xc4, 0x4e, 0x94, 0xaf, 0x07}, + "12345.1234567891234567"}, + {iceberg.DecimalTypeOf(22, 17), + []byte{0x42, 0xec, 0x4c, 0xe5, 0xab, 0x11, 0xce, 0xd6, 0x4e}, + 
"12345.12345678912345678"}, + {iceberg.DecimalTypeOf(23, 18), + []byte{0x02, 0x9d, 0x3b, 0x00, 0xf8, 0xae, 0xb2, 0x14, 0x5f, 0x15}, + "12345.123456789123456789"}, + {iceberg.DecimalTypeOf(24, 19), + []byte{0x1a, 0x24, 0x4e, 0x09, 0xb6, 0xd2, 0xf4, 0xcb, 0xb6, 0xd3}, + "12345.1234567891234567891"}, + {iceberg.DecimalTypeOf(25, 20), + []byte{0x01, 0x05, 0x6b, 0x0c, 0x61, 0x24, 0x3d, 0x8f, 0xf5, 0x24, 0x40}, + "12345.12345678912345678912"}, + {iceberg.DecimalTypeOf(26, 21), + []byte{0x0a, 0x36, 0x2e, 0x7b, 0xcb, 0x6a, 0x67, 0x9f, 0x93, 0x6a, 0x83}, + "12345.123456789123456789123"}, + {iceberg.DecimalTypeOf(27, 22), + []byte{0x66, 0x1d, 0xd0, 0xd5, 0xf2, 0x28, 0x0c, 0x3b, 0xc2, 0x29, 0x22}, + "12345.1234567891234567891234"}, + {iceberg.DecimalTypeOf(28, 23), + []byte{0x03, 0xfd, 0x2a, 0x28, 0x5b, 0x75, 0x90, 0x7a, 0x55, 0x95, 0x9b, 0x59}, + "12345.12345678912345678912345"}, + {iceberg.DecimalTypeOf(29, 24), + []byte{0x27, 0xe3, 0xa5, 0x93, 0x92, 0x97, 0xa4, 0xc7, 0x57, 0xd8, 0x11, 0x80}, + "12345.123456789123456789123456"}, + {iceberg.DecimalTypeOf(30, 25), + []byte{0x01, 0x8e, 0xe4, 0x77, 0xc3, 0xb9, 0xec, 0x6f, 0xc9, 0x6e, 0x70, 0xaf, 0x07}, + "12345.1234567891234567891234567"}, + {iceberg.DecimalTypeOf(31, 26), + []byte{0x0f, 0x94, 0xec, 0xad, 0xa5, 0x43, 0x3c, 0x5d, 0xde, 0x50, 0x66, 0xd6, 0x4e}, + "12345.12345678912345678912345678"}, + } + + for _, tt := range tests { + t.Run(tt.val, func(t *testing.T) { + lit, err := iceberg.LiteralFromBytes(tt.typ, tt.b) + require.NoError(t, err) + + v, err := decimal128.FromString(tt.val, int32(tt.typ.Precision()), int32(tt.typ.Scale())) + require.NoError(t, err) + + assert.True(t, lit.Equals(iceberg.DecimalLiteral{Scale: tt.typ.Scale(), Val: v})) + + data, err := lit.MarshalBinary() + require.NoError(t, err) + + assert.Equal(t, tt.b, data) + }) + } +} + +func TestDecimalMaxMinRoundTrip(t *testing.T) { + tests := []struct { + typ iceberg.DecimalType + v string + }{ + {iceberg.DecimalTypeOf(6, 2), "9999.99"}, + 
{iceberg.DecimalTypeOf(10, 10), ".9999999999"}, + {iceberg.DecimalTypeOf(2, 1), "9.9"}, + {iceberg.DecimalTypeOf(38, 37), "9.9999999999999999999999999999999999999"}, + {iceberg.DecimalTypeOf(20, 1), "9999999999999999999.9"}, + {iceberg.DecimalTypeOf(6, 2), "-9999.99"}, + {iceberg.DecimalTypeOf(10, 10), "-.9999999999"}, + {iceberg.DecimalTypeOf(2, 1), "-9.9"}, + {iceberg.DecimalTypeOf(38, 37), "-9.9999999999999999999999999999999999999"}, + {iceberg.DecimalTypeOf(20, 1), "-9999999999999999999.9"}, + } + + for _, tt := range tests { + t.Run(tt.v, func(t *testing.T) { + v, err := decimal128.FromString(tt.v, int32(tt.typ.Precision()), int32(tt.typ.Scale())) + require.NoError(t, err) + + lit := iceberg.DecimalLiteral{Scale: tt.typ.Scale(), Val: v} + b, err := lit.MarshalBinary() + require.NoError(t, err) + val, err := iceberg.LiteralFromBytes(tt.typ, b) + require.NoError(t, err) + + assert.True(t, val.Equals(lit)) + }) + } +} diff --git a/manifest.go b/manifest.go index b8320e3..bc59282 100644 --- a/manifest.go +++ b/manifest.go @@ -1,968 +1,968 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package iceberg - -import ( - "io" - "sync" - "time" - - iceio "github.com/apache/iceberg-go/io" - - "github.com/hamba/avro/v2" - "github.com/hamba/avro/v2/ocf" -) - -// ManifestContent indicates the type of data inside of the files -// described by a manifest. This will indicate whether the data files -// contain active data or deleted rows. -type ManifestContent int32 - -const ( - ManifestContentData ManifestContent = 0 - ManifestContentDeletes ManifestContent = 1 -) - -type FieldSummary struct { - ContainsNull bool `avro:"contains_null"` - ContainsNaN *bool `avro:"contains_nan"` - LowerBound *[]byte `avro:"lower_bound"` - UpperBound *[]byte `avro:"upper_bound"` -} - -// ManifestV1Builder is a helper for building a V1 manifest file -// struct which will conform to the ManifestFile interface. -type ManifestV1Builder struct { - m *manifestFileV1 -} - -// NewManifestV1Builder is passed all of the required fields and then allows -// all of the optional fields to be set by calling the corresponding methods -// before calling [ManifestV1Builder.Build] to construct the object. 
-func NewManifestV1Builder(path string, length int64, partitionSpecID int32, addedSnapshotID int64) *ManifestV1Builder { - return &ManifestV1Builder{ - m: &manifestFileV1{ - Path: path, - Len: length, - SpecID: partitionSpecID, - AddedSnapshotID: addedSnapshotID, - }, - } -} - -func (b *ManifestV1Builder) AddedFiles(cnt int32) *ManifestV1Builder { - b.m.AddedFilesCount = &cnt - return b -} - -func (b *ManifestV1Builder) ExistingFiles(cnt int32) *ManifestV1Builder { - b.m.ExistingFilesCount = &cnt - return b -} - -func (b *ManifestV1Builder) DeletedFiles(cnt int32) *ManifestV1Builder { - b.m.DeletedFilesCount = &cnt - return b -} - -func (b *ManifestV1Builder) AddedRows(cnt int64) *ManifestV1Builder { - b.m.AddedRowsCount = &cnt - return b -} - -func (b *ManifestV1Builder) ExistingRows(cnt int64) *ManifestV1Builder { - b.m.ExistingRowsCount = &cnt - return b -} - -func (b *ManifestV1Builder) DeletedRows(cnt int64) *ManifestV1Builder { - b.m.DeletedRowsCount = &cnt - return b -} - -func (b *ManifestV1Builder) Partitions(p []FieldSummary) *ManifestV1Builder { - b.m.PartitionList = &p - return b -} - -func (b *ManifestV1Builder) KeyMetadata(km []byte) *ManifestV1Builder { - b.m.Key = km - return b -} - -// Build returns the constructed manifest file, after calling Build this -// builder should not be used further as we avoid copying by just returning -// a pointer to the constructed manifest file. Further calls to the modifier -// methods after calling build would modify the constructed ManifestFile. 
-func (b *ManifestV1Builder) Build() ManifestFile { - return b.m -} - -type fallbackManifestFileV1 struct { - manifestFileV1 - AddedSnapshotID *int64 `avro:"added_snapshot_id"` -} - -func (f *fallbackManifestFileV1) toManifest() *manifestFileV1 { - f.manifestFileV1.AddedSnapshotID = *f.AddedSnapshotID - return &f.manifestFileV1 -} - -type manifestFileV1 struct { - Path string `avro:"manifest_path"` - Len int64 `avro:"manifest_length"` - SpecID int32 `avro:"partition_spec_id"` - AddedSnapshotID int64 `avro:"added_snapshot_id"` - AddedFilesCount *int32 `avro:"added_data_files_count"` - ExistingFilesCount *int32 `avro:"existing_data_files_count"` - DeletedFilesCount *int32 `avro:"deleted_data_files_count"` - AddedRowsCount *int64 `avro:"added_rows_count"` - ExistingRowsCount *int64 `avro:"existing_rows_count"` - DeletedRowsCount *int64 `avro:"deleted_rows_count"` - PartitionList *[]FieldSummary `avro:"partitions"` - Key []byte `avro:"key_metadata"` -} - -func (*manifestFileV1) Version() int { return 1 } -func (m *manifestFileV1) FilePath() string { return m.Path } -func (m *manifestFileV1) Length() int64 { return m.Len } -func (m *manifestFileV1) PartitionSpecID() int32 { return m.SpecID } -func (m *manifestFileV1) ManifestContent() ManifestContent { - return ManifestContentData -} -func (m *manifestFileV1) SnapshotID() int64 { - return m.AddedSnapshotID -} - -func (m *manifestFileV1) AddedDataFiles() int32 { - if m.AddedFilesCount == nil { - return 0 - } - return *m.AddedFilesCount -} - -func (m *manifestFileV1) ExistingDataFiles() int32 { - if m.ExistingFilesCount == nil { - return 0 - } - return *m.ExistingFilesCount -} - -func (m *manifestFileV1) DeletedDataFiles() int32 { - if m.DeletedFilesCount == nil { - return 0 - } - return *m.DeletedFilesCount -} - -func (m *manifestFileV1) AddedRows() int64 { - if m.AddedRowsCount == nil { - return 0 - } - return *m.AddedRowsCount -} - -func (m *manifestFileV1) ExistingRows() int64 { - if m.ExistingRowsCount == nil { - 
return 0 - } - return *m.ExistingRowsCount -} - -func (m *manifestFileV1) DeletedRows() int64 { - if m.DeletedRowsCount == nil { - return 0 - } - return *m.DeletedRowsCount -} - -func (m *manifestFileV1) HasAddedFiles() bool { - return m.AddedFilesCount == nil || *m.AddedFilesCount > 0 -} - -func (m *manifestFileV1) HasExistingFiles() bool { - return m.ExistingFilesCount == nil || *m.ExistingFilesCount > 0 -} - -func (m *manifestFileV1) SequenceNum() int64 { return 0 } -func (m *manifestFileV1) MinSequenceNum() int64 { return 0 } -func (m *manifestFileV1) KeyMetadata() []byte { return m.Key } -func (m *manifestFileV1) Partitions() []FieldSummary { - if m.PartitionList == nil { - return nil - } - return *m.PartitionList -} - -func (m *manifestFileV1) FetchEntries(fs iceio.IO, discardDeleted bool) ([]ManifestEntry, error) { - return fetchManifestEntries(m, fs, discardDeleted) -} - -// ManifestV2Builder is a helper for building a V2 manifest file -// struct which will conform to the ManifestFile interface. -type ManifestV2Builder struct { - m *manifestFileV2 -} - -// NewManifestV2Builder is constructed with the primary fields, with the remaining -// fields set to their zero value unless modified by calling the corresponding -// methods of the builder. Then calling [ManifestV2Builder.Build] to retrieve the -// constructed ManifestFile. 
-func NewManifestV2Builder(path string, length int64, partitionSpecID int32, content ManifestContent, addedSnapshotID int64) *ManifestV2Builder { - return &ManifestV2Builder{ - m: &manifestFileV2{ - Path: path, - Len: length, - SpecID: partitionSpecID, - Content: content, - AddedSnapshotID: addedSnapshotID, - }, - } -} - -func (b *ManifestV2Builder) SequenceNum(num, minSeqNum int64) *ManifestV2Builder { - b.m.SeqNumber, b.m.MinSeqNumber = num, minSeqNum - return b -} - -func (b *ManifestV2Builder) AddedFiles(cnt int32) *ManifestV2Builder { - b.m.AddedFilesCount = cnt - return b -} - -func (b *ManifestV2Builder) ExistingFiles(cnt int32) *ManifestV2Builder { - b.m.ExistingFilesCount = cnt - return b -} - -func (b *ManifestV2Builder) DeletedFiles(cnt int32) *ManifestV2Builder { - b.m.DeletedFilesCount = cnt - return b -} - -func (b *ManifestV2Builder) AddedRows(cnt int64) *ManifestV2Builder { - b.m.AddedRowsCount = cnt - return b -} - -func (b *ManifestV2Builder) ExistingRows(cnt int64) *ManifestV2Builder { - b.m.ExistingRowsCount = cnt - return b -} - -func (b *ManifestV2Builder) DeletedRows(cnt int64) *ManifestV2Builder { - b.m.DeletedRowsCount = cnt - return b -} - -func (b *ManifestV2Builder) Partitions(p []FieldSummary) *ManifestV2Builder { - b.m.PartitionList = &p - return b -} - -func (b *ManifestV2Builder) KeyMetadata(km []byte) *ManifestV2Builder { - b.m.Key = km - return b -} - -// Build returns the constructed manifest file, after calling Build this -// builder should not be used further as we avoid copying by just returning -// a pointer to the constructed manifest file. Further calls to the modifier -// methods after calling build would modify the constructed ManifestFile. 
-func (b *ManifestV2Builder) Build() ManifestFile { - return b.m -} - -type manifestFileV2 struct { - Path string `avro:"manifest_path"` - Len int64 `avro:"manifest_length"` - SpecID int32 `avro:"partition_spec_id"` - Content ManifestContent `avro:"content"` - SeqNumber int64 `avro:"sequence_number"` - MinSeqNumber int64 `avro:"min_sequence_number"` - AddedSnapshotID int64 `avro:"added_snapshot_id"` - AddedFilesCount int32 `avro:"added_files_count"` - ExistingFilesCount int32 `avro:"existing_files_count"` - DeletedFilesCount int32 `avro:"deleted_files_count"` - AddedRowsCount int64 `avro:"added_rows_count"` - ExistingRowsCount int64 `avro:"existing_rows_count"` - DeletedRowsCount int64 `avro:"deleted_rows_count"` - PartitionList *[]FieldSummary `avro:"partitions"` - Key []byte `avro:"key_metadata"` -} - -func (*manifestFileV2) Version() int { return 2 } - -func (m *manifestFileV2) FilePath() string { return m.Path } -func (m *manifestFileV2) Length() int64 { return m.Len } -func (m *manifestFileV2) PartitionSpecID() int32 { return m.SpecID } -func (m *manifestFileV2) ManifestContent() ManifestContent { return m.Content } -func (m *manifestFileV2) SnapshotID() int64 { - return m.AddedSnapshotID -} - -func (m *manifestFileV2) AddedDataFiles() int32 { - return m.AddedFilesCount -} - -func (m *manifestFileV2) ExistingDataFiles() int32 { - return m.ExistingFilesCount -} - -func (m *manifestFileV2) DeletedDataFiles() int32 { - return m.DeletedFilesCount -} - -func (m *manifestFileV2) AddedRows() int64 { - return m.AddedRowsCount -} - -func (m *manifestFileV2) ExistingRows() int64 { - return m.ExistingRowsCount -} - -func (m *manifestFileV2) DeletedRows() int64 { - return m.DeletedRowsCount -} - -func (m *manifestFileV2) SequenceNum() int64 { return m.SeqNumber } -func (m *manifestFileV2) MinSequenceNum() int64 { return m.MinSeqNumber } -func (m *manifestFileV2) KeyMetadata() []byte { return m.Key } - -func (m *manifestFileV2) Partitions() []FieldSummary { - if 
m.PartitionList == nil { - return nil - } - return *m.PartitionList -} - -func (m *manifestFileV2) HasAddedFiles() bool { - return m.AddedFilesCount > 0 -} - -func (m *manifestFileV2) HasExistingFiles() bool { - return m.ExistingFilesCount > 0 -} - -func (m *manifestFileV2) FetchEntries(fs iceio.IO, discardDeleted bool) ([]ManifestEntry, error) { - return fetchManifestEntries(m, fs, discardDeleted) -} - -func getFieldIDMap(sc avro.Schema) map[string]int { - getField := func(rs *avro.RecordSchema, name string) *avro.Field { - for _, f := range rs.Fields() { - if f.Name() == name { - return f - } - } - return nil - } - - result := make(map[string]int) - entryField := getField(sc.(*avro.RecordSchema), "data_file") - partitionField := getField(entryField.Type().(*avro.RecordSchema), "partition") - - for _, field := range partitionField.Type().(*avro.RecordSchema).Fields() { - if fid, ok := field.Prop("field-id").(float64); ok { - result[field.Name()] = int(fid) - } - } - return result -} - -type hasFieldToIDMap interface { - setFieldNameToIDMap(map[string]int) -} - -func fetchManifestEntries(m ManifestFile, fs iceio.IO, discardDeleted bool) ([]ManifestEntry, error) { - f, err := fs.Open(m.FilePath()) - if err != nil { - return nil, err - } - defer f.Close() - - dec, err := ocf.NewDecoder(f) - if err != nil { - return nil, err - } - - metadata := dec.Metadata() - sc, err := avro.ParseBytes(dec.Metadata()["avro.schema"]) - if err != nil { - return nil, err - } - - fieldNameToID := getFieldIDMap(sc) - isVer1, isFallback := true, false - if string(metadata["format-version"]) == "2" { - isVer1 = false - } else { - for _, f := range sc.(*avro.RecordSchema).Fields() { - if f.Name() == "snapshot_id" { - if f.Type().Type() == avro.Union { - isFallback = true - } - break - } - } - } - - results := make([]ManifestEntry, 0) - for dec.HasNext() { - var tmp ManifestEntry - if isVer1 { - if isFallback { - tmp = &fallbackManifestEntryV1{} - } else { - tmp = &manifestEntryV1{} - } - } 
else { - tmp = &manifestEntryV2{} - } - - if err := dec.Decode(tmp); err != nil { - return nil, err - } - - if isFallback { - tmp = tmp.(*fallbackManifestEntryV1).toEntry() - } - - if !discardDeleted || tmp.Status() != EntryStatusDELETED { - tmp.inheritSeqNum(m) - if fieldToIDMap, ok := tmp.DataFile().(hasFieldToIDMap); ok { - fieldToIDMap.setFieldNameToIDMap(fieldNameToID) - } - results = append(results, tmp) - } - } - - return results, dec.Error() -} - -// ManifestFile is the interface which covers both V1 and V2 manifest files. -type ManifestFile interface { - // Version returns the version number of this manifest file. - // It should be 1 or 2. - Version() int - // FilePath is the location URI of this manifest file. - FilePath() string - // Length is the length in bytes of the manifest file. - Length() int64 - // PartitionSpecID is the ID of the partition spec used to write - // this manifest. It must be listed in the table metadata - // partition-specs. - PartitionSpecID() int32 - // ManifestContent is the type of files tracked by this manifest, - // either data or delete files. All v1 manifests track data files. - ManifestContent() ManifestContent - // SnapshotID is the ID of the snapshot where this manifest file - // was added. - SnapshotID() int64 - // AddedDataFiles returns the number of entries in the manifest that - // have the status of EntryStatusADDED. - AddedDataFiles() int32 - // ExistingDataFiles returns the number of entries in the manifest - // which have the status of EntryStatusEXISTING. - ExistingDataFiles() int32 - // DeletedDataFiles returns the number of entries in the manifest - // which have the status of EntryStatusDELETED. - DeletedDataFiles() int32 - // AddedRows returns the number of rows in all files of the manifest - // that have status EntryStatusADDED. - AddedRows() int64 - // ExistingRows returns the number of rows in all files of the manifest - // which have status EntryStatusEXISTING. 
- ExistingRows() int64 - // DeletedRows returns the number of rows in all files of the manifest - // which have status EntryStatusDELETED. - DeletedRows() int64 - // SequenceNum returns the sequence number when this manifest was - // added to the table. Will be 0 for v1 manifest lists. - SequenceNum() int64 - // MinSequenceNum is the minimum data sequence number of all live data - // or delete files in the manifest. Will be 0 for v1 manifest lists. - MinSequenceNum() int64 - // KeyMetadata returns implementation-specific key metadata for encryption - // if it exists in the manifest list. - KeyMetadata() []byte - // Partitions returns a list of field summaries for each partition - // field in the spec. Each field in the list corresponds to a field in - // the manifest file's partition spec. - Partitions() []FieldSummary - - // HasAddedFiles returns true if AddedDataFiles > 0 or if it was null. - HasAddedFiles() bool - // HasExistingFiles returns true if ExistingDataFiles > 0 or if it was null. - HasExistingFiles() bool - // FetchEntries reads the manifest list file to fetch the list of - // manifest entries using the provided file system IO interface. - // If discardDeleted is true, entries for files containing deleted rows - // will be skipped. - FetchEntries(fs iceio.IO, discardDeleted bool) ([]ManifestEntry, error) -} - -// ReadManifestList reads in an avro manifest list file and returns a slice -// of manifest files or an error if one is encountered. 
-func ReadManifestList(in io.Reader) ([]ManifestFile, error) { - dec, err := ocf.NewDecoder(in) - if err != nil { - return nil, err - } - - sc, err := avro.ParseBytes(dec.Metadata()["avro.schema"]) - if err != nil { - return nil, err - } - - var fallbackAddedSnapshot bool - for _, f := range sc.(*avro.RecordSchema).Fields() { - if f.Name() == "added_snapshot_id" { - if f.Type().Type() == avro.Union { - fallbackAddedSnapshot = true - } - break - } - } - - out := make([]ManifestFile, 0) - for dec.HasNext() { - var file ManifestFile - if string(dec.Metadata()["format-version"]) == "2" { - file = &manifestFileV2{} - } else { - if fallbackAddedSnapshot { - file = &fallbackManifestFileV1{} - } else { - file = &manifestFileV1{} - } - } - - if err := dec.Decode(file); err != nil { - return nil, err - } - - if fallbackAddedSnapshot { - file = file.(*fallbackManifestFileV1).toManifest() - } - - out = append(out, file) - } - - return out, dec.Error() -} - -// ManifestEntryStatus defines constants for the entry status of -// existing, added or deleted. -type ManifestEntryStatus int8 - -const ( - EntryStatusEXISTING ManifestEntryStatus = 0 - EntryStatusADDED ManifestEntryStatus = 1 - EntryStatusDELETED ManifestEntryStatus = 2 -) - -// ManifestEntryContent defines constants for the type of file contents -// in the file entries. Data, Position based deletes and equality based -// deletes. -type ManifestEntryContent int8 - -const ( - EntryContentData ManifestEntryContent = 0 - EntryContentPosDeletes ManifestEntryContent = 1 - EntryContentEqDeletes ManifestEntryContent = 2 -) - -func (m ManifestEntryContent) String() string { - switch m { - case EntryContentData: - return "Data" - case EntryContentPosDeletes: - return "Positional_Deletes" - case EntryContentEqDeletes: - return "Equality_Deletes" - default: - return "UNKNOWN" - } -} - -// FileFormat defines constants for the format of data files. 
-type FileFormat string - -const ( - AvroFile FileFormat = "AVRO" - OrcFile FileFormat = "ORC" - ParquetFile FileFormat = "PARQUET" -) - -type colMap[K, V any] struct { - Key K `avro:"key"` - Value V `avro:"value"` -} - -func avroColMapToMap[K comparable, V any](c *[]colMap[K, V]) map[K]V { - if c == nil { - return nil - } - - out := make(map[K]V) - for _, data := range *c { - out[data.Key] = data.Value - } - return out -} - -func avroPartitionData(input map[string]any) map[string]any { - // hambra/avro/v2 will unmarshal a map[string]any such that - // each entry will actually be a map[string]any with the key being - // the avro type, not the field name. - // - // This means that partition data that looks like this: - // - // [{"field-id": 1000, "name": "ts", "type": {"type": "int", "logicalType": "date"}}] - // - // Becomes: - // - // map[string]any{"ts": map[string]any{"int.date": time.Time{}}} - // - // so we need to simplify our map and make the partition data handling easier - out := make(map[string]any) - for k, v := range input { - switch v := v.(type) { - case map[string]any: - for typeName, val := range v { - switch typeName { - case "int.date": - out[k] = Date(val.(time.Time).Truncate(24*time.Hour).Unix() / int64((time.Hour * 24).Seconds())) - case "int.time-millis": - out[k] = Time(val.(time.Duration).Microseconds()) - case "long.time-micros": - out[k] = Time(val.(time.Duration).Microseconds()) - case "long.timestamp-millis": - out[k] = Timestamp(val.(time.Time).UTC().UnixMicro()) - case "long.timestamp-micros": - out[k] = Timestamp(val.(time.Time).UTC().UnixMicro()) - case "bytes.decimal": - // not implemented yet - case "fixed.decimal": - // not implemented yet - default: - out[k] = val - } - } - default: - switch v := v.(type) { - case time.Time: - out[k] = Timestamp(v.UTC().UnixMicro()) - default: - out[k] = v - } - } - } - return out -} - -type dataFile struct { - Content ManifestEntryContent `avro:"content"` - Path string `avro:"file_path"` - 
Format FileFormat `avro:"file_format"` - PartitionData map[string]any `avro:"partition"` - RecordCount int64 `avro:"record_count"` - FileSize int64 `avro:"file_size_in_bytes"` - BlockSizeInBytes int64 `avro:"block_size_in_bytes"` - ColSizes *[]colMap[int, int64] `avro:"column_sizes"` - ValCounts *[]colMap[int, int64] `avro:"value_counts"` - NullCounts *[]colMap[int, int64] `avro:"null_value_counts"` - NaNCounts *[]colMap[int, int64] `avro:"nan_value_counts"` - DistinctCounts *[]colMap[int, int64] `avro:"distinct_counts"` - LowerBounds *[]colMap[int, []byte] `avro:"lower_bounds"` - UpperBounds *[]colMap[int, []byte] `avro:"upper_bounds"` - Key *[]byte `avro:"key_metadata"` - Splits *[]int64 `avro:"split_offsets"` - EqualityIDs *[]int `avro:"equality_ids"` - SortOrder *int `avro:"sort_order_id"` - - colSizeMap map[int]int64 - valCntMap map[int]int64 - nullCntMap map[int]int64 - nanCntMap map[int]int64 - distinctCntMap map[int]int64 - lowerBoundMap map[int][]byte - upperBoundMap map[int][]byte - - // not used for anything yet, but important to maintain the information - // for future development and updates such as when we get to writes, - // and scan planning - fieldNameToID map[string]int - - initMaps sync.Once -} - -func (d *dataFile) initializeMapData() { - d.initMaps.Do(func() { - d.colSizeMap = avroColMapToMap(d.ColSizes) - d.valCntMap = avroColMapToMap(d.ValCounts) - d.nullCntMap = avroColMapToMap(d.NullCounts) - d.nanCntMap = avroColMapToMap(d.NaNCounts) - d.distinctCntMap = avroColMapToMap(d.DistinctCounts) - d.lowerBoundMap = avroColMapToMap(d.LowerBounds) - d.upperBoundMap = avroColMapToMap(d.UpperBounds) - d.PartitionData = avroPartitionData(d.PartitionData) - }) -} - -func (d *dataFile) setFieldNameToIDMap(m map[string]int) { d.fieldNameToID = m } - -func (d *dataFile) ContentType() ManifestEntryContent { return d.Content } -func (d *dataFile) FilePath() string { return d.Path } -func (d *dataFile) FileFormat() FileFormat { return d.Format } -func (d 
*dataFile) Partition() map[string]any { - d.initializeMapData() - return d.PartitionData -} - -func (d *dataFile) Count() int64 { return d.RecordCount } -func (d *dataFile) FileSizeBytes() int64 { return d.FileSize } - -func (d *dataFile) ColumnSizes() map[int]int64 { - d.initializeMapData() - return d.colSizeMap -} - -func (d *dataFile) ValueCounts() map[int]int64 { - d.initializeMapData() - return d.valCntMap -} - -func (d *dataFile) NullValueCounts() map[int]int64 { - d.initializeMapData() - return d.nullCntMap -} - -func (d *dataFile) NaNValueCounts() map[int]int64 { - d.initializeMapData() - return d.nanCntMap -} - -func (d *dataFile) DistinctValueCounts() map[int]int64 { - d.initializeMapData() - return d.distinctCntMap -} - -func (d *dataFile) LowerBoundValues() map[int][]byte { - d.initializeMapData() - return d.lowerBoundMap -} - -func (d *dataFile) UpperBoundValues() map[int][]byte { - d.initializeMapData() - return d.upperBoundMap -} - -func (d *dataFile) KeyMetadata() []byte { - if d.Key == nil { - return nil - } - return *d.Key -} - -func (d *dataFile) SplitOffsets() []int64 { - if d.Splits == nil { - return nil - } - return *d.Splits -} - -func (d *dataFile) EqualityFieldIDs() []int { - if d.EqualityIDs == nil { - return nil - } - return d.EqualityFieldIDs() -} - -func (d *dataFile) SortOrderID() *int { return d.SortOrder } - -type manifestEntryV1 struct { - EntryStatus ManifestEntryStatus `avro:"status"` - Snapshot int64 `avro:"snapshot_id"` - SeqNum *int64 - FileSeqNum *int64 - Data dataFile `avro:"data_file"` -} - -type fallbackManifestEntryV1 struct { - manifestEntryV1 - Snapshot *int64 `avro:"snapshot_id"` -} - -func (f *fallbackManifestEntryV1) toEntry() *manifestEntryV1 { - f.manifestEntryV1.Snapshot = *f.Snapshot - return &f.manifestEntryV1 -} - -func (m *manifestEntryV1) inheritSeqNum(manifest ManifestFile) {} - -func (m *manifestEntryV1) Status() ManifestEntryStatus { return m.EntryStatus } -func (m *manifestEntryV1) SnapshotID() int64 { 
return m.Snapshot } - -func (m *manifestEntryV1) SequenceNum() int64 { - if m.SeqNum == nil { - return 0 - } - return *m.SeqNum -} - -func (m *manifestEntryV1) FileSequenceNum() *int64 { - return m.FileSeqNum -} - -func (m *manifestEntryV1) DataFile() DataFile { return &m.Data } - -type manifestEntryV2 struct { - EntryStatus ManifestEntryStatus `avro:"status"` - Snapshot *int64 `avro:"snapshot_id"` - SeqNum *int64 `avro:"sequence_number"` - FileSeqNum *int64 `avro:"file_sequence_number"` - Data dataFile `avro:"data_file"` -} - -func (m *manifestEntryV2) inheritSeqNum(manifest ManifestFile) { - if m.Snapshot == nil { - snap := manifest.SnapshotID() - m.Snapshot = &snap - } - - manifestSequenceNum := manifest.SequenceNum() - if m.SeqNum == nil && (manifestSequenceNum == 0 || m.EntryStatus == EntryStatusADDED) { - m.SeqNum = &manifestSequenceNum - } - - if m.FileSeqNum == nil && (manifestSequenceNum == 0 || m.EntryStatus == EntryStatusADDED) { - m.FileSeqNum = &manifestSequenceNum - } -} - -func (m *manifestEntryV2) Status() ManifestEntryStatus { return m.EntryStatus } -func (m *manifestEntryV2) SnapshotID() int64 { - if m.Snapshot == nil { - return 0 - } - return *m.Snapshot -} - -func (m *manifestEntryV2) SequenceNum() int64 { - if m.SeqNum == nil { - return 0 - } - return *m.SeqNum -} - -func (m *manifestEntryV2) FileSequenceNum() *int64 { - return m.FileSeqNum -} - -func (m *manifestEntryV2) DataFile() DataFile { return &m.Data } - -// DataFile is the interface for reading the information about a -// given data file indicated by an entry in a manifest list. -type DataFile interface { - // ContentType is the type of the content stored by the data file, - // either Data, Equality deletes, or Position deletes. All v1 files - // are Data files. - ContentType() ManifestEntryContent - // FilePath is the full URI for the file, complete with FS scheme. - FilePath() string - // FileFormat is the format of the data file, AVRO, Orc, or Parquet. 
- FileFormat() FileFormat - // Partition returns a mapping of field name to partition value for - // each of the partition spec's fields. - Partition() map[string]any - // Count returns the number of records in this file. - Count() int64 - // FileSizeBytes is the total file size in bytes. - FileSizeBytes() int64 - // ColumnSizes is a mapping from column id to the total size on disk - // of all regions that store the column. Does not include bytes - // necessary to read other columns, like footers. Map will be nil for - // row-oriented formats (avro). - ColumnSizes() map[int]int64 - // ValueCounts is a mapping from column id to the number of values - // in the column, including null and NaN values. - ValueCounts() map[int]int64 - // NullValueCounts is a mapping from column id to the number of - // null values in the column. - NullValueCounts() map[int]int64 - // NaNValueCounts is a mapping from column id to the number of NaN - // values in the column. - NaNValueCounts() map[int]int64 - // DistictValueCounts is a mapping from column id to the number of - // distinct values in the column. Distinct counts must be derived - // using values in the file by counting or using sketches, but not - // using methods like merging existing distinct counts. - DistinctValueCounts() map[int]int64 - // LowerBoundValues is a mapping from column id to the lower bounded - // value of the column, serialized as binary. Each value in the column - // must be less than or requal to all non-null, non-NaN values in the - // column for the file. - LowerBoundValues() map[int][]byte - // UpperBoundValues is a mapping from column id to the upper bounded - // value of the column, serialized as binary. Each value in the column - // must be greater than or equal to all non-null, non-NaN values in - // the column for the file. - UpperBoundValues() map[int][]byte - // KeyMetadata is implementation-specific key metadata for encryption. 
- KeyMetadata() []byte - // SplitOffsets are the split offsets for the data file. For example, - // all row group offsets in a Parquet file. Must be sorted ascending. - SplitOffsets() []int64 - // EqualityFieldIDs are used to determine row equality in equality - // delete files. It is required when the content type is - // EntryContentEqDeletes. - EqualityFieldIDs() []int - // SortOrderID returns the id representing the sort order for this - // file, or nil if there is no sort order. - SortOrderID() *int -} - -// ManifestEntry is an interface for both v1 and v2 manifest entries. -type ManifestEntry interface { - // Status returns the type of the file tracked by this entry. - // Deletes are informational only and not used in scans. - Status() ManifestEntryStatus - // SnapshotID is the id where the file was added, or deleted, - // if null it is inherited from the manifest list. - SnapshotID() int64 - // SequenceNum returns the data sequence number of the file. - // If it was null and the status is EntryStatusADDED then it - // is inherited from the manifest list. - SequenceNum() int64 - // FileSequenceNum returns the file sequence number indicating - // when the file was added. If it was null and the status is - // EntryStatusADDED then it is inherited from the manifest list. - FileSequenceNum() *int64 - // DataFile provides the information about the data file indicated - // by this manifest entry. - DataFile() DataFile - - inheritSeqNum(manifest ManifestFile) -} - -var PositionalDeleteSchema = NewSchema(0, - NestedField{ID: 2147483546, Type: PrimitiveTypes.String, Name: "file_path", Required: true}, - NestedField{ID: 2147483545, Type: PrimitiveTypes.Int32, Name: "pos", Required: true}, -) +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package iceberg + +import ( + "io" + "sync" + "time" + + iceio "github.com/apache/iceberg-go/io" + + "github.com/hamba/avro/v2" + "github.com/hamba/avro/v2/ocf" +) + +// ManifestContent indicates the type of data inside of the files +// described by a manifest. This will indicate whether the data files +// contain active data or deleted rows. +type ManifestContent int32 + +const ( + ManifestContentData ManifestContent = 0 + ManifestContentDeletes ManifestContent = 1 +) + +type FieldSummary struct { + ContainsNull bool `avro:"contains_null"` + ContainsNaN *bool `avro:"contains_nan"` + LowerBound *[]byte `avro:"lower_bound"` + UpperBound *[]byte `avro:"upper_bound"` +} + +// ManifestV1Builder is a helper for building a V1 manifest file +// struct which will conform to the ManifestFile interface. +type ManifestV1Builder struct { + m *manifestFileV1 +} + +// NewManifestV1Builder is passed all of the required fields and then allows +// all of the optional fields to be set by calling the corresponding methods +// before calling [ManifestV1Builder.Build] to construct the object. 
+func NewManifestV1Builder(path string, length int64, partitionSpecID int32, addedSnapshotID int64) *ManifestV1Builder { + return &ManifestV1Builder{ + m: &manifestFileV1{ + Path: path, + Len: length, + SpecID: partitionSpecID, + AddedSnapshotID: addedSnapshotID, + }, + } +} + +func (b *ManifestV1Builder) AddedFiles(cnt int32) *ManifestV1Builder { + b.m.AddedFilesCount = &cnt + return b +} + +func (b *ManifestV1Builder) ExistingFiles(cnt int32) *ManifestV1Builder { + b.m.ExistingFilesCount = &cnt + return b +} + +func (b *ManifestV1Builder) DeletedFiles(cnt int32) *ManifestV1Builder { + b.m.DeletedFilesCount = &cnt + return b +} + +func (b *ManifestV1Builder) AddedRows(cnt int64) *ManifestV1Builder { + b.m.AddedRowsCount = &cnt + return b +} + +func (b *ManifestV1Builder) ExistingRows(cnt int64) *ManifestV1Builder { + b.m.ExistingRowsCount = &cnt + return b +} + +func (b *ManifestV1Builder) DeletedRows(cnt int64) *ManifestV1Builder { + b.m.DeletedRowsCount = &cnt + return b +} + +func (b *ManifestV1Builder) Partitions(p []FieldSummary) *ManifestV1Builder { + b.m.PartitionList = &p + return b +} + +func (b *ManifestV1Builder) KeyMetadata(km []byte) *ManifestV1Builder { + b.m.Key = km + return b +} + +// Build returns the constructed manifest file, after calling Build this +// builder should not be used further as we avoid copying by just returning +// a pointer to the constructed manifest file. Further calls to the modifier +// methods after calling build would modify the constructed ManifestFile. 
+func (b *ManifestV1Builder) Build() ManifestFile { + return b.m +} + +type fallbackManifestFileV1 struct { + manifestFileV1 + AddedSnapshotID *int64 `avro:"added_snapshot_id"` +} + +func (f *fallbackManifestFileV1) toManifest() *manifestFileV1 { + f.manifestFileV1.AddedSnapshotID = *f.AddedSnapshotID + return &f.manifestFileV1 +} + +type manifestFileV1 struct { + Path string `avro:"manifest_path"` + Len int64 `avro:"manifest_length"` + SpecID int32 `avro:"partition_spec_id"` + AddedSnapshotID int64 `avro:"added_snapshot_id"` + AddedFilesCount *int32 `avro:"added_data_files_count"` + ExistingFilesCount *int32 `avro:"existing_data_files_count"` + DeletedFilesCount *int32 `avro:"deleted_data_files_count"` + AddedRowsCount *int64 `avro:"added_rows_count"` + ExistingRowsCount *int64 `avro:"existing_rows_count"` + DeletedRowsCount *int64 `avro:"deleted_rows_count"` + PartitionList *[]FieldSummary `avro:"partitions"` + Key []byte `avro:"key_metadata"` +} + +func (*manifestFileV1) Version() int { return 1 } +func (m *manifestFileV1) FilePath() string { return m.Path } +func (m *manifestFileV1) Length() int64 { return m.Len } +func (m *manifestFileV1) PartitionSpecID() int32 { return m.SpecID } +func (m *manifestFileV1) ManifestContent() ManifestContent { + return ManifestContentData +} +func (m *manifestFileV1) SnapshotID() int64 { + return m.AddedSnapshotID +} + +func (m *manifestFileV1) AddedDataFiles() int32 { + if m.AddedFilesCount == nil { + return 0 + } + return *m.AddedFilesCount +} + +func (m *manifestFileV1) ExistingDataFiles() int32 { + if m.ExistingFilesCount == nil { + return 0 + } + return *m.ExistingFilesCount +} + +func (m *manifestFileV1) DeletedDataFiles() int32 { + if m.DeletedFilesCount == nil { + return 0 + } + return *m.DeletedFilesCount +} + +func (m *manifestFileV1) AddedRows() int64 { + if m.AddedRowsCount == nil { + return 0 + } + return *m.AddedRowsCount +} + +func (m *manifestFileV1) ExistingRows() int64 { + if m.ExistingRowsCount == nil { + 
return 0 + } + return *m.ExistingRowsCount +} + +func (m *manifestFileV1) DeletedRows() int64 { + if m.DeletedRowsCount == nil { + return 0 + } + return *m.DeletedRowsCount +} + +func (m *manifestFileV1) HasAddedFiles() bool { + return m.AddedFilesCount == nil || *m.AddedFilesCount > 0 +} + +func (m *manifestFileV1) HasExistingFiles() bool { + return m.ExistingFilesCount == nil || *m.ExistingFilesCount > 0 +} + +func (m *manifestFileV1) SequenceNum() int64 { return 0 } +func (m *manifestFileV1) MinSequenceNum() int64 { return 0 } +func (m *manifestFileV1) KeyMetadata() []byte { return m.Key } +func (m *manifestFileV1) Partitions() []FieldSummary { + if m.PartitionList == nil { + return nil + } + return *m.PartitionList +} + +func (m *manifestFileV1) FetchEntries(fs iceio.IO, discardDeleted bool) ([]ManifestEntry, error) { + return fetchManifestEntries(m, fs, discardDeleted) +} + +// ManifestV2Builder is a helper for building a V2 manifest file +// struct which will conform to the ManifestFile interface. +type ManifestV2Builder struct { + m *manifestFileV2 +} + +// NewManifestV2Builder is constructed with the primary fields, with the remaining +// fields set to their zero value unless modified by calling the corresponding +// methods of the builder. Then calling [ManifestV2Builder.Build] to retrieve the +// constructed ManifestFile. 
+func NewManifestV2Builder(path string, length int64, partitionSpecID int32, content ManifestContent, addedSnapshotID int64) *ManifestV2Builder { + return &ManifestV2Builder{ + m: &manifestFileV2{ + Path: path, + Len: length, + SpecID: partitionSpecID, + Content: content, + AddedSnapshotID: addedSnapshotID, + }, + } +} + +func (b *ManifestV2Builder) SequenceNum(num, minSeqNum int64) *ManifestV2Builder { + b.m.SeqNumber, b.m.MinSeqNumber = num, minSeqNum + return b +} + +func (b *ManifestV2Builder) AddedFiles(cnt int32) *ManifestV2Builder { + b.m.AddedFilesCount = cnt + return b +} + +func (b *ManifestV2Builder) ExistingFiles(cnt int32) *ManifestV2Builder { + b.m.ExistingFilesCount = cnt + return b +} + +func (b *ManifestV2Builder) DeletedFiles(cnt int32) *ManifestV2Builder { + b.m.DeletedFilesCount = cnt + return b +} + +func (b *ManifestV2Builder) AddedRows(cnt int64) *ManifestV2Builder { + b.m.AddedRowsCount = cnt + return b +} + +func (b *ManifestV2Builder) ExistingRows(cnt int64) *ManifestV2Builder { + b.m.ExistingRowsCount = cnt + return b +} + +func (b *ManifestV2Builder) DeletedRows(cnt int64) *ManifestV2Builder { + b.m.DeletedRowsCount = cnt + return b +} + +func (b *ManifestV2Builder) Partitions(p []FieldSummary) *ManifestV2Builder { + b.m.PartitionList = &p + return b +} + +func (b *ManifestV2Builder) KeyMetadata(km []byte) *ManifestV2Builder { + b.m.Key = km + return b +} + +// Build returns the constructed manifest file, after calling Build this +// builder should not be used further as we avoid copying by just returning +// a pointer to the constructed manifest file. Further calls to the modifier +// methods after calling build would modify the constructed ManifestFile. 
+func (b *ManifestV2Builder) Build() ManifestFile { + return b.m +} + +type manifestFileV2 struct { + Path string `avro:"manifest_path"` + Len int64 `avro:"manifest_length"` + SpecID int32 `avro:"partition_spec_id"` + Content ManifestContent `avro:"content"` + SeqNumber int64 `avro:"sequence_number"` + MinSeqNumber int64 `avro:"min_sequence_number"` + AddedSnapshotID int64 `avro:"added_snapshot_id"` + AddedFilesCount int32 `avro:"added_files_count"` + ExistingFilesCount int32 `avro:"existing_files_count"` + DeletedFilesCount int32 `avro:"deleted_files_count"` + AddedRowsCount int64 `avro:"added_rows_count"` + ExistingRowsCount int64 `avro:"existing_rows_count"` + DeletedRowsCount int64 `avro:"deleted_rows_count"` + PartitionList *[]FieldSummary `avro:"partitions"` + Key []byte `avro:"key_metadata"` +} + +func (*manifestFileV2) Version() int { return 2 } + +func (m *manifestFileV2) FilePath() string { return m.Path } +func (m *manifestFileV2) Length() int64 { return m.Len } +func (m *manifestFileV2) PartitionSpecID() int32 { return m.SpecID } +func (m *manifestFileV2) ManifestContent() ManifestContent { return m.Content } +func (m *manifestFileV2) SnapshotID() int64 { + return m.AddedSnapshotID +} + +func (m *manifestFileV2) AddedDataFiles() int32 { + return m.AddedFilesCount +} + +func (m *manifestFileV2) ExistingDataFiles() int32 { + return m.ExistingFilesCount +} + +func (m *manifestFileV2) DeletedDataFiles() int32 { + return m.DeletedFilesCount +} + +func (m *manifestFileV2) AddedRows() int64 { + return m.AddedRowsCount +} + +func (m *manifestFileV2) ExistingRows() int64 { + return m.ExistingRowsCount +} + +func (m *manifestFileV2) DeletedRows() int64 { + return m.DeletedRowsCount +} + +func (m *manifestFileV2) SequenceNum() int64 { return m.SeqNumber } +func (m *manifestFileV2) MinSequenceNum() int64 { return m.MinSeqNumber } +func (m *manifestFileV2) KeyMetadata() []byte { return m.Key } + +func (m *manifestFileV2) Partitions() []FieldSummary { + if 
m.PartitionList == nil { + return nil + } + return *m.PartitionList +} + +func (m *manifestFileV2) HasAddedFiles() bool { + return m.AddedFilesCount > 0 +} + +func (m *manifestFileV2) HasExistingFiles() bool { + return m.ExistingFilesCount > 0 +} + +func (m *manifestFileV2) FetchEntries(fs iceio.IO, discardDeleted bool) ([]ManifestEntry, error) { + return fetchManifestEntries(m, fs, discardDeleted) +} + +func getFieldIDMap(sc avro.Schema) map[string]int { + getField := func(rs *avro.RecordSchema, name string) *avro.Field { + for _, f := range rs.Fields() { + if f.Name() == name { + return f + } + } + return nil + } + + result := make(map[string]int) + entryField := getField(sc.(*avro.RecordSchema), "data_file") + partitionField := getField(entryField.Type().(*avro.RecordSchema), "partition") + + for _, field := range partitionField.Type().(*avro.RecordSchema).Fields() { + if fid, ok := field.Prop("field-id").(float64); ok { + result[field.Name()] = int(fid) + } + } + return result +} + +type hasFieldToIDMap interface { + setFieldNameToIDMap(map[string]int) +} + +func fetchManifestEntries(m ManifestFile, fs iceio.IO, discardDeleted bool) ([]ManifestEntry, error) { + f, err := fs.Open(m.FilePath()) + if err != nil { + return nil, err + } + defer f.Close() + + dec, err := ocf.NewDecoder(f) + if err != nil { + return nil, err + } + + metadata := dec.Metadata() + sc, err := avro.ParseBytes(dec.Metadata()["avro.schema"]) + if err != nil { + return nil, err + } + + fieldNameToID := getFieldIDMap(sc) + isVer1, isFallback := true, false + if string(metadata["format-version"]) == "2" { + isVer1 = false + } else { + for _, f := range sc.(*avro.RecordSchema).Fields() { + if f.Name() == "snapshot_id" { + if f.Type().Type() == avro.Union { + isFallback = true + } + break + } + } + } + + results := make([]ManifestEntry, 0) + for dec.HasNext() { + var tmp ManifestEntry + if isVer1 { + if isFallback { + tmp = &fallbackManifestEntryV1{} + } else { + tmp = &manifestEntryV1{} + } + } 
else { + tmp = &manifestEntryV2{} + } + + if err := dec.Decode(tmp); err != nil { + return nil, err + } + + if isFallback { + tmp = tmp.(*fallbackManifestEntryV1).toEntry() + } + + if !discardDeleted || tmp.Status() != EntryStatusDELETED { + tmp.inheritSeqNum(m) + if fieldToIDMap, ok := tmp.DataFile().(hasFieldToIDMap); ok { + fieldToIDMap.setFieldNameToIDMap(fieldNameToID) + } + results = append(results, tmp) + } + } + + return results, dec.Error() +} + +// ManifestFile is the interface which covers both V1 and V2 manifest files. +type ManifestFile interface { + // Version returns the version number of this manifest file. + // It should be 1 or 2. + Version() int + // FilePath is the location URI of this manifest file. + FilePath() string + // Length is the length in bytes of the manifest file. + Length() int64 + // PartitionSpecID is the ID of the partition spec used to write + // this manifest. It must be listed in the table metadata + // partition-specs. + PartitionSpecID() int32 + // ManifestContent is the type of files tracked by this manifest, + // either data or delete files. All v1 manifests track data files. + ManifestContent() ManifestContent + // SnapshotID is the ID of the snapshot where this manifest file + // was added. + SnapshotID() int64 + // AddedDataFiles returns the number of entries in the manifest that + // have the status of EntryStatusADDED. + AddedDataFiles() int32 + // ExistingDataFiles returns the number of entries in the manifest + // which have the status of EntryStatusEXISTING. + ExistingDataFiles() int32 + // DeletedDataFiles returns the number of entries in the manifest + // which have the status of EntryStatusDELETED. + DeletedDataFiles() int32 + // AddedRows returns the number of rows in all files of the manifest + // that have status EntryStatusADDED. + AddedRows() int64 + // ExistingRows returns the number of rows in all files of the manifest + // which have status EntryStatusEXISTING. 
+ ExistingRows() int64 + // DeletedRows returns the number of rows in all files of the manifest + // which have status EntryStatusDELETED. + DeletedRows() int64 + // SequenceNum returns the sequence number when this manifest was + // added to the table. Will be 0 for v1 manifest lists. + SequenceNum() int64 + // MinSequenceNum is the minimum data sequence number of all live data + // or delete files in the manifest. Will be 0 for v1 manifest lists. + MinSequenceNum() int64 + // KeyMetadata returns implementation-specific key metadata for encryption + // if it exists in the manifest list. + KeyMetadata() []byte + // Partitions returns a list of field summaries for each partition + // field in the spec. Each field in the list corresponds to a field in + // the manifest file's partition spec. + Partitions() []FieldSummary + + // HasAddedFiles returns true if AddedDataFiles > 0 or if it was null. + HasAddedFiles() bool + // HasExistingFiles returns true if ExistingDataFiles > 0 or if it was null. + HasExistingFiles() bool + // FetchEntries reads the manifest list file to fetch the list of + // manifest entries using the provided file system IO interface. + // If discardDeleted is true, entries for files containing deleted rows + // will be skipped. + FetchEntries(fs iceio.IO, discardDeleted bool) ([]ManifestEntry, error) +} + +// ReadManifestList reads in an avro manifest list file and returns a slice +// of manifest files or an error if one is encountered. 
+func ReadManifestList(in io.Reader) ([]ManifestFile, error) { + dec, err := ocf.NewDecoder(in) + if err != nil { + return nil, err + } + + sc, err := avro.ParseBytes(dec.Metadata()["avro.schema"]) + if err != nil { + return nil, err + } + + var fallbackAddedSnapshot bool + for _, f := range sc.(*avro.RecordSchema).Fields() { + if f.Name() == "added_snapshot_id" { + if f.Type().Type() == avro.Union { + fallbackAddedSnapshot = true + } + break + } + } + + out := make([]ManifestFile, 0) + for dec.HasNext() { + var file ManifestFile + if string(dec.Metadata()["format-version"]) == "2" { + file = &manifestFileV2{} + } else { + if fallbackAddedSnapshot { + file = &fallbackManifestFileV1{} + } else { + file = &manifestFileV1{} + } + } + + if err := dec.Decode(file); err != nil { + return nil, err + } + + if fb, ok := file.(*fallbackManifestFileV1); ok { // convert only records actually decoded into the v1 fallback type + file = fb.toManifest() + } + + out = append(out, file) + } + + return out, dec.Error() +} + +// ManifestEntryStatus defines constants for the entry status of +// existing, added or deleted. +type ManifestEntryStatus int8 + +const ( + EntryStatusEXISTING ManifestEntryStatus = 0 + EntryStatusADDED ManifestEntryStatus = 1 + EntryStatusDELETED ManifestEntryStatus = 2 +) + +// ManifestEntryContent defines constants for the type of file contents +// in the file entries. Data, Position based deletes and equality based +// deletes. +type ManifestEntryContent int8 + +const ( + EntryContentData ManifestEntryContent = 0 + EntryContentPosDeletes ManifestEntryContent = 1 + EntryContentEqDeletes ManifestEntryContent = 2 +) + +func (m ManifestEntryContent) String() string { + switch m { + case EntryContentData: + return "Data" + case EntryContentPosDeletes: + return "Positional_Deletes" + case EntryContentEqDeletes: + return "Equality_Deletes" + default: + return "UNKNOWN" + } +} + +// FileFormat defines constants for the format of data files.
+type FileFormat string + +const ( + AvroFile FileFormat = "AVRO" + OrcFile FileFormat = "ORC" + ParquetFile FileFormat = "PARQUET" +) + +type colMap[K, V any] struct { + Key K `avro:"key"` + Value V `avro:"value"` +} + +func avroColMapToMap[K comparable, V any](c *[]colMap[K, V]) map[K]V { + if c == nil { + return nil + } + + out := make(map[K]V) + for _, data := range *c { + out[data.Key] = data.Value + } + return out +} + +func avroPartitionData(input map[string]any) map[string]any { + // hamba/avro/v2 will unmarshal a map[string]any such that + // each entry will actually be a map[string]any with the key being + // the avro type, not the field name. + // + // This means that partition data that looks like this: + // + // [{"field-id": 1000, "name": "ts", "type": {"type": "int", "logicalType": "date"}}] + // + // Becomes: + // + // map[string]any{"ts": map[string]any{"int.date": time.Time{}}} + // + // so we need to simplify our map and make the partition data handling easier + out := make(map[string]any) + for k, v := range input { + switch v := v.(type) { + case map[string]any: + for typeName, val := range v { + switch typeName { + case "int.date": + out[k] = Date(val.(time.Time).Truncate(24*time.Hour).Unix() / int64((time.Hour * 24).Seconds())) + case "int.time-millis": + out[k] = Time(val.(time.Duration).Microseconds()) + case "long.time-micros": + out[k] = Time(val.(time.Duration).Microseconds()) + case "long.timestamp-millis": + out[k] = Timestamp(val.(time.Time).UTC().UnixMicro()) + case "long.timestamp-micros": + out[k] = Timestamp(val.(time.Time).UTC().UnixMicro()) + case "bytes.decimal": + // not implemented yet + case "fixed.decimal": + // not implemented yet + default: + out[k] = val + } + } + default: + switch v := v.(type) { + case time.Time: + out[k] = Timestamp(v.UTC().UnixMicro()) + default: + out[k] = v + } + } + } + return out +} + +type dataFile struct { + Content ManifestEntryContent `avro:"content"` + Path string `avro:"file_path"` +
Format FileFormat `avro:"file_format"` + PartitionData map[string]any `avro:"partition"` + RecordCount int64 `avro:"record_count"` + FileSize int64 `avro:"file_size_in_bytes"` + BlockSizeInBytes int64 `avro:"block_size_in_bytes"` + ColSizes *[]colMap[int, int64] `avro:"column_sizes"` + ValCounts *[]colMap[int, int64] `avro:"value_counts"` + NullCounts *[]colMap[int, int64] `avro:"null_value_counts"` + NaNCounts *[]colMap[int, int64] `avro:"nan_value_counts"` + DistinctCounts *[]colMap[int, int64] `avro:"distinct_counts"` + LowerBounds *[]colMap[int, []byte] `avro:"lower_bounds"` + UpperBounds *[]colMap[int, []byte] `avro:"upper_bounds"` + Key *[]byte `avro:"key_metadata"` + Splits *[]int64 `avro:"split_offsets"` + EqualityIDs *[]int `avro:"equality_ids"` + SortOrder *int `avro:"sort_order_id"` + + colSizeMap map[int]int64 + valCntMap map[int]int64 + nullCntMap map[int]int64 + nanCntMap map[int]int64 + distinctCntMap map[int]int64 + lowerBoundMap map[int][]byte + upperBoundMap map[int][]byte + + // not used for anything yet, but important to maintain the information + // for future development and updates such as when we get to writes, + // and scan planning + fieldNameToID map[string]int + + initMaps sync.Once +} + +func (d *dataFile) initializeMapData() { + d.initMaps.Do(func() { + d.colSizeMap = avroColMapToMap(d.ColSizes) + d.valCntMap = avroColMapToMap(d.ValCounts) + d.nullCntMap = avroColMapToMap(d.NullCounts) + d.nanCntMap = avroColMapToMap(d.NaNCounts) + d.distinctCntMap = avroColMapToMap(d.DistinctCounts) + d.lowerBoundMap = avroColMapToMap(d.LowerBounds) + d.upperBoundMap = avroColMapToMap(d.UpperBounds) + d.PartitionData = avroPartitionData(d.PartitionData) + }) +} + +func (d *dataFile) setFieldNameToIDMap(m map[string]int) { d.fieldNameToID = m } + +func (d *dataFile) ContentType() ManifestEntryContent { return d.Content } +func (d *dataFile) FilePath() string { return d.Path } +func (d *dataFile) FileFormat() FileFormat { return d.Format } +func (d 
*dataFile) Partition() map[string]any { + d.initializeMapData() + return d.PartitionData +} + +func (d *dataFile) Count() int64 { return d.RecordCount } +func (d *dataFile) FileSizeBytes() int64 { return d.FileSize } + +func (d *dataFile) ColumnSizes() map[int]int64 { + d.initializeMapData() + return d.colSizeMap +} + +func (d *dataFile) ValueCounts() map[int]int64 { + d.initializeMapData() + return d.valCntMap +} + +func (d *dataFile) NullValueCounts() map[int]int64 { + d.initializeMapData() + return d.nullCntMap +} + +func (d *dataFile) NaNValueCounts() map[int]int64 { + d.initializeMapData() + return d.nanCntMap +} + +func (d *dataFile) DistinctValueCounts() map[int]int64 { + d.initializeMapData() + return d.distinctCntMap +} + +func (d *dataFile) LowerBoundValues() map[int][]byte { + d.initializeMapData() + return d.lowerBoundMap +} + +func (d *dataFile) UpperBoundValues() map[int][]byte { + d.initializeMapData() + return d.upperBoundMap +} + +func (d *dataFile) KeyMetadata() []byte { + if d.Key == nil { + return nil + } + return *d.Key +} + +func (d *dataFile) SplitOffsets() []int64 { + if d.Splits == nil { + return nil + } + return *d.Splits +} + +func (d *dataFile) EqualityFieldIDs() []int { + if d.EqualityIDs == nil { + return nil + } + return *d.EqualityIDs +} + +func (d *dataFile) SortOrderID() *int { return d.SortOrder } + +type manifestEntryV1 struct { + EntryStatus ManifestEntryStatus `avro:"status"` + Snapshot int64 `avro:"snapshot_id"` + SeqNum *int64 + FileSeqNum *int64 + Data dataFile `avro:"data_file"` +} + +type fallbackManifestEntryV1 struct { + manifestEntryV1 + Snapshot *int64 `avro:"snapshot_id"` +} + +func (f *fallbackManifestEntryV1) toEntry() *manifestEntryV1 { + f.manifestEntryV1.Snapshot = *f.Snapshot + return &f.manifestEntryV1 +} + +func (m *manifestEntryV1) inheritSeqNum(manifest ManifestFile) {} + +func (m *manifestEntryV1) Status() ManifestEntryStatus { return m.EntryStatus } +func (m *manifestEntryV1) SnapshotID() int64 {
return m.Snapshot } + +func (m *manifestEntryV1) SequenceNum() int64 { + if m.SeqNum == nil { + return 0 + } + return *m.SeqNum +} + +func (m *manifestEntryV1) FileSequenceNum() *int64 { + return m.FileSeqNum +} + +func (m *manifestEntryV1) DataFile() DataFile { return &m.Data } + +type manifestEntryV2 struct { + EntryStatus ManifestEntryStatus `avro:"status"` + Snapshot *int64 `avro:"snapshot_id"` + SeqNum *int64 `avro:"sequence_number"` + FileSeqNum *int64 `avro:"file_sequence_number"` + Data dataFile `avro:"data_file"` +} + +func (m *manifestEntryV2) inheritSeqNum(manifest ManifestFile) { + if m.Snapshot == nil { + snap := manifest.SnapshotID() + m.Snapshot = &snap + } + + manifestSequenceNum := manifest.SequenceNum() + if m.SeqNum == nil && (manifestSequenceNum == 0 || m.EntryStatus == EntryStatusADDED) { + m.SeqNum = &manifestSequenceNum + } + + if m.FileSeqNum == nil && (manifestSequenceNum == 0 || m.EntryStatus == EntryStatusADDED) { + m.FileSeqNum = &manifestSequenceNum + } +} + +func (m *manifestEntryV2) Status() ManifestEntryStatus { return m.EntryStatus } +func (m *manifestEntryV2) SnapshotID() int64 { + if m.Snapshot == nil { + return 0 + } + return *m.Snapshot +} + +func (m *manifestEntryV2) SequenceNum() int64 { + if m.SeqNum == nil { + return 0 + } + return *m.SeqNum +} + +func (m *manifestEntryV2) FileSequenceNum() *int64 { + return m.FileSeqNum +} + +func (m *manifestEntryV2) DataFile() DataFile { return &m.Data } + +// DataFile is the interface for reading the information about a +// given data file indicated by an entry in a manifest list. +type DataFile interface { + // ContentType is the type of the content stored by the data file, + // either Data, Equality deletes, or Position deletes. All v1 files + // are Data files. + ContentType() ManifestEntryContent + // FilePath is the full URI for the file, complete with FS scheme. + FilePath() string + // FileFormat is the format of the data file, AVRO, Orc, or Parquet. 
+ FileFormat() FileFormat + // Partition returns a mapping of field name to partition value for + // each of the partition spec's fields. + Partition() map[string]any + // Count returns the number of records in this file. + Count() int64 + // FileSizeBytes is the total file size in bytes. + FileSizeBytes() int64 + // ColumnSizes is a mapping from column id to the total size on disk + // of all regions that store the column. Does not include bytes + // necessary to read other columns, like footers. Map will be nil for + // row-oriented formats (avro). + ColumnSizes() map[int]int64 + // ValueCounts is a mapping from column id to the number of values + // in the column, including null and NaN values. + ValueCounts() map[int]int64 + // NullValueCounts is a mapping from column id to the number of + // null values in the column. + NullValueCounts() map[int]int64 + // NaNValueCounts is a mapping from column id to the number of NaN + // values in the column. + NaNValueCounts() map[int]int64 + // DistinctValueCounts is a mapping from column id to the number of + // distinct values in the column. Distinct counts must be derived + // using values in the file by counting or using sketches, but not + // using methods like merging existing distinct counts. + DistinctValueCounts() map[int]int64 + // LowerBoundValues is a mapping from column id to the lower bounded + // value of the column, serialized as binary. Each value in the map + // must be less than or equal to all non-null, non-NaN values in the + // column for the file. + LowerBoundValues() map[int][]byte + // UpperBoundValues is a mapping from column id to the upper bounded + // value of the column, serialized as binary. Each value in the map + // must be greater than or equal to all non-null, non-NaN values in + // the column for the file. + UpperBoundValues() map[int][]byte + // KeyMetadata is implementation-specific key metadata for encryption.
+ KeyMetadata() []byte + // SplitOffsets are the split offsets for the data file. For example, + // all row group offsets in a Parquet file. Must be sorted ascending. + SplitOffsets() []int64 + // EqualityFieldIDs are used to determine row equality in equality + // delete files. It is required when the content type is + // EntryContentEqDeletes. + EqualityFieldIDs() []int + // SortOrderID returns the id representing the sort order for this + // file, or nil if there is no sort order. + SortOrderID() *int +} + +// ManifestEntry is an interface for both v1 and v2 manifest entries. +type ManifestEntry interface { + // Status returns the type of the file tracked by this entry. + // Deletes are informational only and not used in scans. + Status() ManifestEntryStatus + // SnapshotID is the id where the file was added, or deleted, + // if null it is inherited from the manifest list. + SnapshotID() int64 + // SequenceNum returns the data sequence number of the file. + // If it was null and the status is EntryStatusADDED then it + // is inherited from the manifest list. + SequenceNum() int64 + // FileSequenceNum returns the file sequence number indicating + // when the file was added. If it was null and the status is + // EntryStatusADDED then it is inherited from the manifest list. + FileSequenceNum() *int64 + // DataFile provides the information about the data file indicated + // by this manifest entry. + DataFile() DataFile + + inheritSeqNum(manifest ManifestFile) +} + +var PositionalDeleteSchema = NewSchema(0, + NestedField{ID: 2147483546, Type: PrimitiveTypes.String, Name: "file_path", Required: true}, + NestedField{ID: 2147483545, Type: PrimitiveTypes.Int32, Name: "pos", Required: true}, +) diff --git a/manifest_test.go b/manifest_test.go index 486c6a8..cfee6d2 100644 --- a/manifest_test.go +++ b/manifest_test.go @@ -1,772 +1,772 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package iceberg - -import ( - "bytes" - "testing" - "time" - - "github.com/apache/iceberg-go/internal" - "github.com/hamba/avro/v2/ocf" - "github.com/stretchr/testify/suite" -) - -var ( - falseBool = false - snapshotID int64 = 9182715666859759686 - addedRows int64 = 237993 - manifestFileRecordsV1 = []ManifestFile{ - NewManifestV1Builder("/home/iceberg/warehouse/nyc/taxis_partitioned/metadata/0125c686-8aa6-4502-bdcc-b6d17ca41a3b-m0.avro", - 7989, 0, snapshotID). - AddedFiles(3). - ExistingFiles(0). - DeletedFiles(0). - AddedRows(addedRows). - ExistingRows(0). - DeletedRows(0). - Partitions([]FieldSummary{{ - ContainsNull: true, ContainsNaN: &falseBool, - LowerBound: &[]byte{0x01, 0x00, 0x00, 0x00}, - UpperBound: &[]byte{0x02, 0x00, 0x00, 0x00}, - }}).Build()} - - manifestFileRecordsV2 = []ManifestFile{ - NewManifestV2Builder("/home/iceberg/warehouse/nyc/taxis_partitioned/metadata/0125c686-8aa6-4502-bdcc-b6d17ca41a3b-m0.avro", - 7989, 0, ManifestContentDeletes, snapshotID). - SequenceNum(3, 3). - AddedFiles(3). - ExistingFiles(0). - DeletedFiles(0). - AddedRows(addedRows). - ExistingRows(0). - DeletedRows(0). 
- Partitions([]FieldSummary{{ - ContainsNull: true, - ContainsNaN: &falseBool, - LowerBound: &[]byte{0x01, 0x00, 0x00, 0x00}, - UpperBound: &[]byte{0x02, 0x00, 0x00, 0x00}, - }}).Build()} - - entrySnapshotID int64 = 8744736658442914487 - intZero = 0 - manifestEntryV1Records = []*manifestEntryV1{ - { - EntryStatus: EntryStatusADDED, - Snapshot: entrySnapshotID, - Data: dataFile{ - // bad value for Content but this field doesn't exist in V1 - // so it shouldn't get written and shouldn't be read back out - // so the roundtrip test asserts that we get the default value - // back out. - Content: EntryContentEqDeletes, - Path: "/home/iceberg/warehouse/nyc/taxis_partitioned/data/VendorID=null/00000-633-d8a4223e-dc97-45a1-86e1-adaba6e8abd7-00001.parquet", - Format: ParquetFile, - PartitionData: map[string]any{"VendorID": int(1), "tpep_pickup_datetime": time.Unix(1925, 0)}, - RecordCount: 19513, - FileSize: 388872, - BlockSizeInBytes: 67108864, - ColSizes: &[]colMap[int, int64]{ - {Key: 1, Value: 53}, - {Key: 2, Value: 98153}, - {Key: 3, Value: 98693}, - {Key: 4, Value: 53}, - {Key: 5, Value: 53}, - {Key: 6, Value: 53}, - {Key: 7, Value: 17425}, - {Key: 8, Value: 18528}, - {Key: 9, Value: 53}, - {Key: 10, Value: 44788}, - {Key: 11, Value: 35571}, - {Key: 12, Value: 53}, - {Key: 13, Value: 1243}, - {Key: 14, Value: 2355}, - {Key: 15, Value: 12750}, - {Key: 16, Value: 4029}, - {Key: 17, Value: 110}, - {Key: 18, Value: 47194}, - {Key: 19, Value: 2948}, - }, - ValCounts: &[]colMap[int, int64]{ - {Key: 1, Value: 19513}, - {Key: 2, Value: 19513}, - {Key: 3, Value: 19513}, - {Key: 4, Value: 19513}, - {Key: 5, Value: 19513}, - {Key: 6, Value: 19513}, - {Key: 7, Value: 19513}, - {Key: 8, Value: 19513}, - {Key: 9, Value: 19513}, - {Key: 10, Value: 19513}, - {Key: 11, Value: 19513}, - {Key: 12, Value: 19513}, - {Key: 13, Value: 19513}, - {Key: 14, Value: 19513}, - {Key: 15, Value: 19513}, - {Key: 16, Value: 19513}, - {Key: 17, Value: 19513}, - {Key: 18, Value: 19513}, - {Key: 19, 
Value: 19513}, - }, - NullCounts: &[]colMap[int, int64]{ - {Key: 1, Value: 19513}, - {Key: 2, Value: 0}, - {Key: 3, Value: 0}, - {Key: 4, Value: 19513}, - {Key: 5, Value: 19513}, - {Key: 6, Value: 19513}, - {Key: 7, Value: 0}, - {Key: 8, Value: 0}, - {Key: 9, Value: 19513}, - {Key: 10, Value: 0}, - {Key: 11, Value: 0}, - {Key: 12, Value: 19513}, - {Key: 13, Value: 0}, - {Key: 14, Value: 0}, - {Key: 15, Value: 0}, - {Key: 16, Value: 0}, - {Key: 17, Value: 0}, - {Key: 18, Value: 0}, - {Key: 19, Value: 0}, - }, - NaNCounts: &[]colMap[int, int64]{ - {Key: 16, Value: 0}, - {Key: 17, Value: 0}, - {Key: 18, Value: 0}, - {Key: 19, Value: 0}, - {Key: 10, Value: 0}, - {Key: 11, Value: 0}, - {Key: 12, Value: 0}, - {Key: 13, Value: 0}, - {Key: 14, Value: 0}, - {Key: 15, Value: 0}, - }, - LowerBounds: &[]colMap[int, []byte]{ - {Key: 2, Value: []byte("2020-04-01 00:00")}, - {Key: 3, Value: []byte("2020-04-01 00:12")}, - {Key: 7, Value: []byte{0x03, 0x00, 0x00, 0x00}}, - {Key: 8, Value: []byte{0x01, 0x00, 0x00, 0x00}}, - {Key: 10, Value: []byte{0xf6, 0x28, 0x5c, 0x8f, 0xc2, 0x05, 'S', 0xc0}}, - {Key: 11, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 13, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 14, Value: []byte{0, 0, 0, 0, 0, 0, 0xe0, 0xbf}}, - {Key: 15, Value: []byte{')', '\\', 0x8f, 0xc2, 0xf5, '(', 0x08, 0xc0}}, - {Key: 16, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 17, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 18, Value: []byte{0xf6, '(', '\\', 0x8f, 0xc2, 0xc5, 'S', 0xc0}}, - {Key: 19, Value: []byte{0, 0, 0, 0, 0, 0, 0x04, 0xc0}}, - }, - UpperBounds: &[]colMap[int, []byte]{ - {Key: 2, Value: []byte("2020-04-30 23:5:")}, - {Key: 3, Value: []byte("2020-05-01 00:41")}, - {Key: 7, Value: []byte{'\t', 0x01, 0x00, 0x00}}, - {Key: 8, Value: []byte{'\t', 0x01, 0x00, 0x00}}, - {Key: 10, Value: []byte{0xcd, 0xcc, 0xcc, 0xcc, 0xcc, ',', '_', '@'}}, - {Key: 11, Value: []byte{0x1f, 0x85, 0xeb, 'Q', '\\', 0xe2, 0xfe, '@'}}, - {Key: 13, Value: []byte{0, 0, 0, 0, 0, 0, 
0x12, '@'}}, - {Key: 14, Value: []byte{0, 0, 0, 0, 0, 0, 0xe0, '?'}}, - {Key: 15, Value: []byte{'q', '=', '\n', 0xd7, 0xa3, 0xf0, '1', '@'}}, - {Key: 16, Value: []byte{0, 0, 0, 0, 0, '`', 'B', '@'}}, - {Key: 17, Value: []byte{'3', '3', '3', '3', '3', '3', 0xd3, '?'}}, - {Key: 18, Value: []byte{0, 0, 0, 0, 0, 0x18, 'b', '@'}}, - {Key: 19, Value: []byte{0, 0, 0, 0, 0, 0, 0x04, '@'}}, - }, - Splits: &[]int64{4}, - SortOrder: &intZero, - }, - }, - { - EntryStatus: EntryStatusADDED, - Snapshot: 8744736658442914487, - Data: dataFile{ - Path: "/home/iceberg/warehouse/nyc/taxis_partitioned/data/VendorID=1/00000-633-d8a4223e-dc97-45a1-86e1-adaba6e8abd7-00002.parquet", - Format: ParquetFile, - PartitionData: map[string]any{"VendorID": int(1), "tpep_pickup_datetime": time.Unix(1925, 0)}, - RecordCount: 95050, - FileSize: 1265950, - BlockSizeInBytes: 67108864, - ColSizes: &[]colMap[int, int64]{ - {Key: 1, Value: 318}, - {Key: 2, Value: 329806}, - {Key: 3, Value: 331632}, - {Key: 4, Value: 15343}, - {Key: 5, Value: 2351}, - {Key: 6, Value: 3389}, - {Key: 7, Value: 71269}, - {Key: 8, Value: 76429}, - {Key: 9, Value: 16383}, - {Key: 10, Value: 86992}, - {Key: 11, Value: 89608}, - {Key: 12, Value: 265}, - {Key: 13, Value: 19377}, - {Key: 14, Value: 1692}, - {Key: 15, Value: 76162}, - {Key: 16, Value: 4354}, - {Key: 17, Value: 759}, - {Key: 18, Value: 120650}, - {Key: 19, Value: 11804}, - }, - ValCounts: &[]colMap[int, int64]{ - {Key: 1, Value: 95050}, - {Key: 2, Value: 95050}, - {Key: 3, Value: 95050}, - {Key: 4, Value: 95050}, - {Key: 5, Value: 95050}, - {Key: 6, Value: 95050}, - {Key: 7, Value: 95050}, - {Key: 8, Value: 95050}, - {Key: 9, Value: 95050}, - {Key: 10, Value: 95050}, - {Key: 11, Value: 95050}, - {Key: 12, Value: 95050}, - {Key: 13, Value: 95050}, - {Key: 14, Value: 95050}, - {Key: 15, Value: 95050}, - {Key: 16, Value: 95050}, - {Key: 17, Value: 95050}, - {Key: 18, Value: 95050}, - {Key: 19, Value: 95050}, - }, - NullCounts: &[]colMap[int, int64]{ - {Key: 1, Value: 
0}, - {Key: 2, Value: 0}, - {Key: 3, Value: 0}, - {Key: 4, Value: 0}, - {Key: 5, Value: 0}, - {Key: 6, Value: 0}, - {Key: 7, Value: 0}, - {Key: 8, Value: 0}, - {Key: 9, Value: 0}, - {Key: 10, Value: 0}, - {Key: 11, Value: 0}, - {Key: 12, Value: 95050}, - {Key: 13, Value: 0}, - {Key: 14, Value: 0}, - {Key: 15, Value: 0}, - {Key: 16, Value: 0}, - {Key: 17, Value: 0}, - {Key: 18, Value: 0}, - {Key: 19, Value: 0}, - }, - NaNCounts: &[]colMap[int, int64]{ - {Key: 16, Value: 0}, - {Key: 17, Value: 0}, - {Key: 18, Value: 0}, - {Key: 19, Value: 0}, - {Key: 10, Value: 0}, - {Key: 11, Value: 0}, - {Key: 12, Value: 0}, - {Key: 13, Value: 0}, - {Key: 14, Value: 0}, - {Key: 15, Value: 0}, - }, - LowerBounds: &[]colMap[int, []byte]{ - {Key: 1, Value: []byte{0x01, 0x00, 0x00, 0x00}}, - {Key: 2, Value: []byte("2020-04-01 00:00")}, - {Key: 3, Value: []byte("2020-04-01 00:13")}, - {Key: 4, Value: []byte{0x00, 0x00, 0x00, 0x00}}, - {Key: 5, Value: []byte{0x01, 0x00, 0x00, 0x00}}, - {Key: 6, Value: []byte("N")}, - {Key: 7, Value: []byte{0x01, 0x00, 0x00, 0x00}}, - {Key: 8, Value: []byte{0x01, 0x00, 0x00, 0x00}}, - {Key: 9, Value: []byte{0x01, 0x00, 0x00, 0x00}}, - {Key: 10, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 11, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 13, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 14, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 15, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 16, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 17, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 18, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - {Key: 19, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, - }, - UpperBounds: &[]colMap[int, []byte]{ - {Key: 1, Value: []byte{0x01, 0x00, 0x00, 0x00}}, - {Key: 2, Value: []byte("2020-04-30 23:5:")}, - {Key: 3, Value: []byte("2020-05-01 00:1:")}, - {Key: 4, Value: []byte{0x06, 0x00, 0x00, 0x00}}, - {Key: 5, Value: []byte{'c', 0x00, 0x00, 0x00}}, - {Key: 6, Value: []byte("Y")}, - {Key: 7, Value: []byte{'\t', 0x01, 0x00, 
0x00}}, - {Key: 8, Value: []byte{'\t', 0x01, 0x00, 0x00}}, - {Key: 9, Value: []byte{0x04, 0x01, 0x00, 0x00}}, - {Key: 10, Value: []byte{'\\', 0x8f, 0xc2, 0xf5, '(', '8', 0x8c, '@'}}, - {Key: 11, Value: []byte{0xcd, 0xcc, 0xcc, 0xcc, 0xcc, ',', 'f', '@'}}, - {Key: 13, Value: []byte{0, 0, 0, 0, 0, 0, 0x1c, '@'}}, - {Key: 14, Value: []byte{0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xf1, '?'}}, - {Key: 15, Value: []byte{0, 0, 0, 0, 0, 0, 'Y', '@'}}, - {Key: 16, Value: []byte{0, 0, 0, 0, 0, 0xb0, 'X', '@'}}, - {Key: 17, Value: []byte{'3', '3', '3', '3', '3', '3', 0xd3, '?'}}, - {Key: 18, Value: []byte{0xc3, 0xf5, '(', '\\', 0x8f, ':', 0x8c, '@'}}, - {Key: 19, Value: []byte{0, 0, 0, 0, 0, 0, 0x04, '@'}}, - }, - Splits: &[]int64{4}, - SortOrder: &intZero, - }, - }, - } - - manifestEntryV2Records = []*manifestEntryV2{ - { - EntryStatus: EntryStatusADDED, - Snapshot: &entrySnapshotID, - Data: dataFile{ - Path: manifestEntryV1Records[0].Data.Path, - Format: manifestEntryV1Records[0].Data.Format, - PartitionData: manifestEntryV1Records[0].Data.PartitionData, - RecordCount: manifestEntryV1Records[0].Data.RecordCount, - FileSize: manifestEntryV1Records[0].Data.FileSize, - BlockSizeInBytes: manifestEntryV1Records[0].Data.BlockSizeInBytes, - ColSizes: manifestEntryV1Records[0].Data.ColSizes, - ValCounts: manifestEntryV1Records[0].Data.ValCounts, - NullCounts: manifestEntryV1Records[0].Data.NullCounts, - NaNCounts: manifestEntryV1Records[0].Data.NaNCounts, - LowerBounds: manifestEntryV1Records[0].Data.LowerBounds, - UpperBounds: manifestEntryV1Records[0].Data.UpperBounds, - Splits: manifestEntryV1Records[0].Data.Splits, - SortOrder: manifestEntryV1Records[0].Data.SortOrder, - }, - }, - { - EntryStatus: EntryStatusADDED, - Snapshot: &entrySnapshotID, - Data: dataFile{ - Path: manifestEntryV1Records[1].Data.Path, - Format: manifestEntryV1Records[1].Data.Format, - PartitionData: manifestEntryV1Records[1].Data.PartitionData, - RecordCount: manifestEntryV1Records[1].Data.RecordCount, - 
FileSize: manifestEntryV1Records[1].Data.FileSize, - BlockSizeInBytes: manifestEntryV1Records[1].Data.BlockSizeInBytes, - ColSizes: manifestEntryV1Records[1].Data.ColSizes, - ValCounts: manifestEntryV1Records[1].Data.ValCounts, - NullCounts: manifestEntryV1Records[1].Data.NullCounts, - NaNCounts: manifestEntryV1Records[1].Data.NaNCounts, - LowerBounds: manifestEntryV1Records[1].Data.LowerBounds, - UpperBounds: manifestEntryV1Records[1].Data.UpperBounds, - Splits: manifestEntryV1Records[1].Data.Splits, - SortOrder: manifestEntryV1Records[1].Data.SortOrder, - }, - }, - } -) - -type ManifestTestSuite struct { - suite.Suite - - v1ManifestList bytes.Buffer - v1ManifestEntries bytes.Buffer - - v2ManifestList bytes.Buffer - v2ManifestEntries bytes.Buffer -} - -func (m *ManifestTestSuite) writeManifestList() { - enc, err := ocf.NewEncoder(internal.AvroSchemaCache.Get(internal.ManifestListV1Key).String(), - &m.v1ManifestList, ocf.WithMetadata(map[string][]byte{ - "avro.codec": []byte("deflate"), - }), - ocf.WithCodec(ocf.Deflate)) - m.Require().NoError(err) - - m.Require().NoError(enc.Encode(manifestFileRecordsV1[0])) - enc.Close() - - enc, err = ocf.NewEncoder(internal.AvroSchemaCache.Get(internal.ManifestListV2Key).String(), - &m.v2ManifestList, ocf.WithMetadata(map[string][]byte{ - "format-version": []byte("2"), - "avro.codec": []byte("deflate"), - }), ocf.WithCodec(ocf.Deflate)) - m.Require().NoError(err) - - m.Require().NoError(enc.Encode(manifestFileRecordsV2[0])) - enc.Close() -} - -func (m *ManifestTestSuite) writeManifestEntries() { - enc, err := ocf.NewEncoder(internal.AvroSchemaCache.Get(internal.ManifestEntryV1Key).String(), &m.v1ManifestEntries, - ocf.WithMetadata(map[string][]byte{ - "format-version": []byte("1"), - }), ocf.WithCodec(ocf.Deflate)) - m.Require().NoError(err) - - for _, ent := range manifestEntryV1Records { - m.Require().NoError(enc.Encode(ent)) - } - m.Require().NoError(enc.Close()) - - enc, err = 
ocf.NewEncoder(internal.AvroSchemaCache.Get(internal.ManifestEntryV2Key).String(), - &m.v2ManifestEntries, ocf.WithMetadata(map[string][]byte{ - "format-version": []byte("2"), - "avro.codec": []byte("deflate"), - }), ocf.WithCodec(ocf.Deflate)) - m.Require().NoError(err) - - for _, ent := range manifestEntryV2Records { - m.Require().NoError(enc.Encode(ent)) - } - m.Require().NoError(enc.Close()) -} - -func (m *ManifestTestSuite) SetupSuite() { - m.writeManifestList() - m.writeManifestEntries() -} - -func (m *ManifestTestSuite) TestManifestEntriesV1() { - var mockfs internal.MockFS - manifest := manifestFileV1{ - Path: manifestFileRecordsV1[0].FilePath(), - } - - mockfs.Test(m.T()) - mockfs.On("Open", manifest.FilePath()).Return(&internal.MockFile{ - Contents: bytes.NewReader(m.v1ManifestEntries.Bytes())}, nil) - defer mockfs.AssertExpectations(m.T()) - entries, err := manifest.FetchEntries(&mockfs, false) - m.Require().NoError(err) - m.Len(entries, 2) - m.Zero(manifest.PartitionSpecID()) - m.Zero(manifest.SnapshotID()) - m.Zero(manifest.AddedDataFiles()) - m.Zero(manifest.ExistingDataFiles()) - m.Zero(manifest.DeletedDataFiles()) - m.Zero(manifest.ExistingRows()) - m.Zero(manifest.DeletedRows()) - m.Zero(manifest.AddedRows()) - - entry1 := entries[0] - - m.Equal(EntryStatusADDED, entry1.Status()) - m.EqualValues(8744736658442914487, entry1.SnapshotID()) - m.Zero(entry1.SequenceNum()) - m.Nil(entry1.FileSequenceNum()) - - datafile := entry1.DataFile() - m.Equal(EntryContentData, datafile.ContentType()) - m.Equal("/home/iceberg/warehouse/nyc/taxis_partitioned/data/VendorID=null/00000-633-d8a4223e-dc97-45a1-86e1-adaba6e8abd7-00001.parquet", datafile.FilePath()) - m.Equal(ParquetFile, datafile.FileFormat()) - m.EqualValues(19513, datafile.Count()) - m.EqualValues(388872, datafile.FileSizeBytes()) - m.Equal(map[int]int64{ - 1: 53, - 2: 98153, - 3: 98693, - 4: 53, - 5: 53, - 6: 53, - 7: 17425, - 8: 18528, - 9: 53, - 10: 44788, - 11: 35571, - 12: 53, - 13: 1243, - 14: 
2355, - 15: 12750, - 16: 4029, - 17: 110, - 18: 47194, - 19: 2948, - }, datafile.ColumnSizes()) - m.Equal(map[int]int64{ - 1: 19513, - 2: 19513, - 3: 19513, - 4: 19513, - 5: 19513, - 6: 19513, - 7: 19513, - 8: 19513, - 9: 19513, - 10: 19513, - 11: 19513, - 12: 19513, - 13: 19513, - 14: 19513, - 15: 19513, - 16: 19513, - 17: 19513, - 18: 19513, - 19: 19513, - }, datafile.ValueCounts()) - m.Equal(map[int]int64{ - 1: 19513, - 2: 0, - 3: 0, - 4: 19513, - 5: 19513, - 6: 19513, - 7: 0, - 8: 0, - 9: 19513, - 10: 0, - 11: 0, - 12: 19513, - 13: 0, - 14: 0, - 15: 0, - 16: 0, - 17: 0, - 18: 0, - 19: 0, - }, datafile.NullValueCounts()) - m.Equal(map[int]int64{ - 16: 0, 17: 0, 18: 0, 19: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, - }, datafile.NaNValueCounts()) - - m.Equal(map[int][]byte{ - 2: []byte("2020-04-01 00:00"), - 3: []byte("2020-04-01 00:12"), - 7: {0x03, 0x00, 0x00, 0x00}, - 8: {0x01, 0x00, 0x00, 0x00}, - 10: {0xf6, '(', '\\', 0x8f, 0xc2, 0x05, 'S', 0xc0}, - 11: {0, 0, 0, 0, 0, 0, 0, 0}, - 13: {0, 0, 0, 0, 0, 0, 0, 0}, - 14: {0, 0, 0, 0, 0, 0, 0xe0, 0xbf}, - 15: {')', '\\', 0x8f, 0xc2, 0xf5, '(', 0x08, 0xc0}, - 16: {0, 0, 0, 0, 0, 0, 0, 0}, - 17: {0, 0, 0, 0, 0, 0, 0, 0}, - 18: {0xf6, '(', '\\', 0x8f, 0xc2, 0xc5, 'S', 0xc0}, - 19: {0, 0, 0, 0, 0, 0, 0x04, 0xc0}, - }, datafile.LowerBoundValues()) - - m.Equal(map[int][]byte{ - 2: []byte("2020-04-30 23:5:"), - 3: []byte("2020-05-01 00:41"), - 7: {'\t', 0x01, 0, 0}, - 8: {'\t', 0x01, 0, 0}, - 10: {0xcd, 0xcc, 0xcc, 0xcc, 0xcc, ',', '_', '@'}, - 11: {0x1f, 0x85, 0xeb, 'Q', '\\', 0xe2, 0xfe, '@'}, - 13: {0, 0, 0, 0, 0, 0, 0x12, '@'}, - 14: {0, 0, 0, 0, 0, 0, 0xe0, '?'}, - 15: {'q', '=', '\n', 0xd7, 0xa3, 0xf0, '1', '@'}, - 16: {0, 0, 0, 0, 0, '`', 'B', '@'}, - 17: {'3', '3', '3', '3', '3', '3', 0xd3, '?'}, - 18: {0, 0, 0, 0, 0, 0x18, 'b', '@'}, - 19: {0, 0, 0, 0, 0, 0, 0x04, '@'}, - }, datafile.UpperBoundValues()) - - m.Nil(datafile.KeyMetadata()) - m.Equal([]int64{4}, datafile.SplitOffsets()) - 
m.Nil(datafile.EqualityFieldIDs()) - m.Zero(*datafile.SortOrderID()) -} - -func (m *ManifestTestSuite) TestReadManifestListV1() { - list, err := ReadManifestList(&m.v1ManifestList) - m.Require().NoError(err) - - m.Len(list, 1) - m.Equal(1, list[0].Version()) - m.EqualValues(7989, list[0].Length()) - m.Equal(ManifestContentData, list[0].ManifestContent()) - m.Zero(list[0].SequenceNum()) - m.Zero(list[0].MinSequenceNum()) - m.EqualValues(9182715666859759686, list[0].SnapshotID()) - m.EqualValues(3, list[0].AddedDataFiles()) - m.True(list[0].HasAddedFiles()) - m.Zero(list[0].ExistingDataFiles()) - m.False(list[0].HasExistingFiles()) - m.Zero(list[0].DeletedDataFiles()) - m.Equal(addedRows, list[0].AddedRows()) - m.Zero(list[0].ExistingRows()) - m.Zero(list[0].DeletedRows()) - m.Nil(list[0].KeyMetadata()) - m.Zero(list[0].PartitionSpecID()) - m.Equal(snapshotID, list[0].SnapshotID()) - - part := list[0].Partitions()[0] - m.True(part.ContainsNull) - m.False(*part.ContainsNaN) - m.Equal([]byte{0x01, 0x00, 0x00, 0x00}, *part.LowerBound) - m.Equal([]byte{0x02, 0x00, 0x00, 0x00}, *part.UpperBound) -} - -func (m *ManifestTestSuite) TestReadManifestListV2() { - list, err := ReadManifestList(&m.v2ManifestList) - m.Require().NoError(err) - - m.Equal("/home/iceberg/warehouse/nyc/taxis_partitioned/metadata/0125c686-8aa6-4502-bdcc-b6d17ca41a3b-m0.avro", list[0].FilePath()) - m.Len(list, 1) - m.Equal(2, list[0].Version()) - m.EqualValues(7989, list[0].Length()) - m.Equal(ManifestContentDeletes, list[0].ManifestContent()) - m.EqualValues(3, list[0].SequenceNum()) - m.EqualValues(3, list[0].MinSequenceNum()) - m.EqualValues(9182715666859759686, list[0].SnapshotID()) - m.EqualValues(3, list[0].AddedDataFiles()) - m.True(list[0].HasAddedFiles()) - m.Zero(list[0].ExistingDataFiles()) - m.False(list[0].HasExistingFiles()) - m.Zero(list[0].DeletedDataFiles()) - m.Equal(addedRows, list[0].AddedRows()) - m.Zero(list[0].ExistingRows()) - m.Zero(list[0].DeletedRows()) - 
m.Nil(list[0].KeyMetadata()) - m.Zero(list[0].PartitionSpecID()) - - part := list[0].Partitions()[0] - m.True(part.ContainsNull) - m.False(*part.ContainsNaN) - m.Equal([]byte{0x01, 0x00, 0x00, 0x00}, *part.LowerBound) - m.Equal([]byte{0x02, 0x00, 0x00, 0x00}, *part.UpperBound) -} - -func (m *ManifestTestSuite) TestManifestEntriesV2() { - var mockfs internal.MockFS - manifest := manifestFileV2{ - Path: manifestFileRecordsV2[0].FilePath(), - } - - mockfs.Test(m.T()) - mockfs.On("Open", manifest.FilePath()).Return(&internal.MockFile{ - Contents: bytes.NewReader(m.v2ManifestEntries.Bytes())}, nil) - defer mockfs.AssertExpectations(m.T()) - entries, err := manifest.FetchEntries(&mockfs, false) - m.Require().NoError(err) - m.Len(entries, 2) - m.Zero(manifest.PartitionSpecID()) - m.Zero(manifest.SnapshotID()) - m.Zero(manifest.AddedDataFiles()) - m.Zero(manifest.ExistingDataFiles()) - m.Zero(manifest.DeletedDataFiles()) - m.Zero(manifest.ExistingRows()) - m.Zero(manifest.DeletedRows()) - m.Zero(manifest.AddedRows()) - - entry1 := entries[0] - - m.Equal(EntryStatusADDED, entry1.Status()) - m.Equal(entrySnapshotID, entry1.SnapshotID()) - m.Zero(entry1.SequenceNum()) - m.Zero(*entry1.FileSequenceNum()) - - datafile := entry1.DataFile() - m.Equal(EntryContentData, datafile.ContentType()) - m.Equal("/home/iceberg/warehouse/nyc/taxis_partitioned/data/VendorID=null/00000-633-d8a4223e-dc97-45a1-86e1-adaba6e8abd7-00001.parquet", datafile.FilePath()) - m.Equal(ParquetFile, datafile.FileFormat()) - m.EqualValues(19513, datafile.Count()) - m.EqualValues(388872, datafile.FileSizeBytes()) - m.Equal(map[int]int64{ - 1: 53, - 2: 98153, - 3: 98693, - 4: 53, - 5: 53, - 6: 53, - 7: 17425, - 8: 18528, - 9: 53, - 10: 44788, - 11: 35571, - 12: 53, - 13: 1243, - 14: 2355, - 15: 12750, - 16: 4029, - 17: 110, - 18: 47194, - 19: 2948, - }, datafile.ColumnSizes()) - m.Equal(map[int]int64{ - 1: 19513, - 2: 19513, - 3: 19513, - 4: 19513, - 5: 19513, - 6: 19513, - 7: 19513, - 8: 19513, - 9: 19513, - 
10: 19513, - 11: 19513, - 12: 19513, - 13: 19513, - 14: 19513, - 15: 19513, - 16: 19513, - 17: 19513, - 18: 19513, - 19: 19513, - }, datafile.ValueCounts()) - m.Equal(map[int]int64{ - 1: 19513, - 2: 0, - 3: 0, - 4: 19513, - 5: 19513, - 6: 19513, - 7: 0, - 8: 0, - 9: 19513, - 10: 0, - 11: 0, - 12: 19513, - 13: 0, - 14: 0, - 15: 0, - 16: 0, - 17: 0, - 18: 0, - 19: 0, - }, datafile.NullValueCounts()) - m.Equal(map[int]int64{ - 16: 0, 17: 0, 18: 0, 19: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, - }, datafile.NaNValueCounts()) - - m.Equal(map[int][]byte{ - 2: []byte("2020-04-01 00:00"), - 3: []byte("2020-04-01 00:12"), - 7: {0x03, 0x00, 0x00, 0x00}, - 8: {0x01, 0x00, 0x00, 0x00}, - 10: {0xf6, '(', '\\', 0x8f, 0xc2, 0x05, 'S', 0xc0}, - 11: {0, 0, 0, 0, 0, 0, 0, 0}, - 13: {0, 0, 0, 0, 0, 0, 0, 0}, - 14: {0, 0, 0, 0, 0, 0, 0xe0, 0xbf}, - 15: {')', '\\', 0x8f, 0xc2, 0xf5, '(', 0x08, 0xc0}, - 16: {0, 0, 0, 0, 0, 0, 0, 0}, - 17: {0, 0, 0, 0, 0, 0, 0, 0}, - 18: {0xf6, '(', '\\', 0x8f, 0xc2, 0xc5, 'S', 0xc0}, - 19: {0, 0, 0, 0, 0, 0, 0x04, 0xc0}, - }, datafile.LowerBoundValues()) - - m.Equal(map[int][]byte{ - 2: []byte("2020-04-30 23:5:"), - 3: []byte("2020-05-01 00:41"), - 7: {'\t', 0x01, 0, 0}, - 8: {'\t', 0x01, 0, 0}, - 10: {0xcd, 0xcc, 0xcc, 0xcc, 0xcc, ',', '_', '@'}, - 11: {0x1f, 0x85, 0xeb, 'Q', '\\', 0xe2, 0xfe, '@'}, - 13: {0, 0, 0, 0, 0, 0, 0x12, '@'}, - 14: {0, 0, 0, 0, 0, 0, 0xe0, '?'}, - 15: {'q', '=', '\n', 0xd7, 0xa3, 0xf0, '1', '@'}, - 16: {0, 0, 0, 0, 0, '`', 'B', '@'}, - 17: {'3', '3', '3', '3', '3', '3', 0xd3, '?'}, - 18: {0, 0, 0, 0, 0, 0x18, 'b', '@'}, - 19: {0, 0, 0, 0, 0, 0, 0x04, '@'}, - }, datafile.UpperBoundValues()) - - m.Nil(datafile.KeyMetadata()) - m.Equal([]int64{4}, datafile.SplitOffsets()) - m.Nil(datafile.EqualityFieldIDs()) - m.Zero(*datafile.SortOrderID()) -} - -func TestManifests(t *testing.T) { - suite.Run(t, new(ManifestTestSuite)) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license 
agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package iceberg + +import ( + "bytes" + "testing" + "time" + + "github.com/apache/iceberg-go/internal" + "github.com/hamba/avro/v2/ocf" + "github.com/stretchr/testify/suite" +) + +var ( + falseBool = false + snapshotID int64 = 9182715666859759686 + addedRows int64 = 237993 + manifestFileRecordsV1 = []ManifestFile{ + NewManifestV1Builder("/home/iceberg/warehouse/nyc/taxis_partitioned/metadata/0125c686-8aa6-4502-bdcc-b6d17ca41a3b-m0.avro", + 7989, 0, snapshotID). + AddedFiles(3). + ExistingFiles(0). + DeletedFiles(0). + AddedRows(addedRows). + ExistingRows(0). + DeletedRows(0). + Partitions([]FieldSummary{{ + ContainsNull: true, ContainsNaN: &falseBool, + LowerBound: &[]byte{0x01, 0x00, 0x00, 0x00}, + UpperBound: &[]byte{0x02, 0x00, 0x00, 0x00}, + }}).Build()} + + manifestFileRecordsV2 = []ManifestFile{ + NewManifestV2Builder("/home/iceberg/warehouse/nyc/taxis_partitioned/metadata/0125c686-8aa6-4502-bdcc-b6d17ca41a3b-m0.avro", + 7989, 0, ManifestContentDeletes, snapshotID). + SequenceNum(3, 3). + AddedFiles(3). + ExistingFiles(0). + DeletedFiles(0). + AddedRows(addedRows). + ExistingRows(0). + DeletedRows(0). 
+ Partitions([]FieldSummary{{ + ContainsNull: true, + ContainsNaN: &falseBool, + LowerBound: &[]byte{0x01, 0x00, 0x00, 0x00}, + UpperBound: &[]byte{0x02, 0x00, 0x00, 0x00}, + }}).Build()} + + entrySnapshotID int64 = 8744736658442914487 + intZero = 0 + manifestEntryV1Records = []*manifestEntryV1{ + { + EntryStatus: EntryStatusADDED, + Snapshot: entrySnapshotID, + Data: dataFile{ + // bad value for Content but this field doesn't exist in V1 + // so it shouldn't get written and shouldn't be read back out + // so the roundtrip test asserts that we get the default value + // back out. + Content: EntryContentEqDeletes, + Path: "/home/iceberg/warehouse/nyc/taxis_partitioned/data/VendorID=null/00000-633-d8a4223e-dc97-45a1-86e1-adaba6e8abd7-00001.parquet", + Format: ParquetFile, + PartitionData: map[string]any{"VendorID": int(1), "tpep_pickup_datetime": time.Unix(1925, 0)}, + RecordCount: 19513, + FileSize: 388872, + BlockSizeInBytes: 67108864, + ColSizes: &[]colMap[int, int64]{ + {Key: 1, Value: 53}, + {Key: 2, Value: 98153}, + {Key: 3, Value: 98693}, + {Key: 4, Value: 53}, + {Key: 5, Value: 53}, + {Key: 6, Value: 53}, + {Key: 7, Value: 17425}, + {Key: 8, Value: 18528}, + {Key: 9, Value: 53}, + {Key: 10, Value: 44788}, + {Key: 11, Value: 35571}, + {Key: 12, Value: 53}, + {Key: 13, Value: 1243}, + {Key: 14, Value: 2355}, + {Key: 15, Value: 12750}, + {Key: 16, Value: 4029}, + {Key: 17, Value: 110}, + {Key: 18, Value: 47194}, + {Key: 19, Value: 2948}, + }, + ValCounts: &[]colMap[int, int64]{ + {Key: 1, Value: 19513}, + {Key: 2, Value: 19513}, + {Key: 3, Value: 19513}, + {Key: 4, Value: 19513}, + {Key: 5, Value: 19513}, + {Key: 6, Value: 19513}, + {Key: 7, Value: 19513}, + {Key: 8, Value: 19513}, + {Key: 9, Value: 19513}, + {Key: 10, Value: 19513}, + {Key: 11, Value: 19513}, + {Key: 12, Value: 19513}, + {Key: 13, Value: 19513}, + {Key: 14, Value: 19513}, + {Key: 15, Value: 19513}, + {Key: 16, Value: 19513}, + {Key: 17, Value: 19513}, + {Key: 18, Value: 19513}, + {Key: 19, 
Value: 19513}, + }, + NullCounts: &[]colMap[int, int64]{ + {Key: 1, Value: 19513}, + {Key: 2, Value: 0}, + {Key: 3, Value: 0}, + {Key: 4, Value: 19513}, + {Key: 5, Value: 19513}, + {Key: 6, Value: 19513}, + {Key: 7, Value: 0}, + {Key: 8, Value: 0}, + {Key: 9, Value: 19513}, + {Key: 10, Value: 0}, + {Key: 11, Value: 0}, + {Key: 12, Value: 19513}, + {Key: 13, Value: 0}, + {Key: 14, Value: 0}, + {Key: 15, Value: 0}, + {Key: 16, Value: 0}, + {Key: 17, Value: 0}, + {Key: 18, Value: 0}, + {Key: 19, Value: 0}, + }, + NaNCounts: &[]colMap[int, int64]{ + {Key: 16, Value: 0}, + {Key: 17, Value: 0}, + {Key: 18, Value: 0}, + {Key: 19, Value: 0}, + {Key: 10, Value: 0}, + {Key: 11, Value: 0}, + {Key: 12, Value: 0}, + {Key: 13, Value: 0}, + {Key: 14, Value: 0}, + {Key: 15, Value: 0}, + }, + LowerBounds: &[]colMap[int, []byte]{ + {Key: 2, Value: []byte("2020-04-01 00:00")}, + {Key: 3, Value: []byte("2020-04-01 00:12")}, + {Key: 7, Value: []byte{0x03, 0x00, 0x00, 0x00}}, + {Key: 8, Value: []byte{0x01, 0x00, 0x00, 0x00}}, + {Key: 10, Value: []byte{0xf6, 0x28, 0x5c, 0x8f, 0xc2, 0x05, 'S', 0xc0}}, + {Key: 11, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 13, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 14, Value: []byte{0, 0, 0, 0, 0, 0, 0xe0, 0xbf}}, + {Key: 15, Value: []byte{')', '\\', 0x8f, 0xc2, 0xf5, '(', 0x08, 0xc0}}, + {Key: 16, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 17, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 18, Value: []byte{0xf6, '(', '\\', 0x8f, 0xc2, 0xc5, 'S', 0xc0}}, + {Key: 19, Value: []byte{0, 0, 0, 0, 0, 0, 0x04, 0xc0}}, + }, + UpperBounds: &[]colMap[int, []byte]{ + {Key: 2, Value: []byte("2020-04-30 23:5:")}, + {Key: 3, Value: []byte("2020-05-01 00:41")}, + {Key: 7, Value: []byte{'\t', 0x01, 0x00, 0x00}}, + {Key: 8, Value: []byte{'\t', 0x01, 0x00, 0x00}}, + {Key: 10, Value: []byte{0xcd, 0xcc, 0xcc, 0xcc, 0xcc, ',', '_', '@'}}, + {Key: 11, Value: []byte{0x1f, 0x85, 0xeb, 'Q', '\\', 0xe2, 0xfe, '@'}}, + {Key: 13, Value: []byte{0, 0, 0, 0, 0, 0, 
0x12, '@'}}, + {Key: 14, Value: []byte{0, 0, 0, 0, 0, 0, 0xe0, '?'}}, + {Key: 15, Value: []byte{'q', '=', '\n', 0xd7, 0xa3, 0xf0, '1', '@'}}, + {Key: 16, Value: []byte{0, 0, 0, 0, 0, '`', 'B', '@'}}, + {Key: 17, Value: []byte{'3', '3', '3', '3', '3', '3', 0xd3, '?'}}, + {Key: 18, Value: []byte{0, 0, 0, 0, 0, 0x18, 'b', '@'}}, + {Key: 19, Value: []byte{0, 0, 0, 0, 0, 0, 0x04, '@'}}, + }, + Splits: &[]int64{4}, + SortOrder: &intZero, + }, + }, + { + EntryStatus: EntryStatusADDED, + Snapshot: 8744736658442914487, + Data: dataFile{ + Path: "/home/iceberg/warehouse/nyc/taxis_partitioned/data/VendorID=1/00000-633-d8a4223e-dc97-45a1-86e1-adaba6e8abd7-00002.parquet", + Format: ParquetFile, + PartitionData: map[string]any{"VendorID": int(1), "tpep_pickup_datetime": time.Unix(1925, 0)}, + RecordCount: 95050, + FileSize: 1265950, + BlockSizeInBytes: 67108864, + ColSizes: &[]colMap[int, int64]{ + {Key: 1, Value: 318}, + {Key: 2, Value: 329806}, + {Key: 3, Value: 331632}, + {Key: 4, Value: 15343}, + {Key: 5, Value: 2351}, + {Key: 6, Value: 3389}, + {Key: 7, Value: 71269}, + {Key: 8, Value: 76429}, + {Key: 9, Value: 16383}, + {Key: 10, Value: 86992}, + {Key: 11, Value: 89608}, + {Key: 12, Value: 265}, + {Key: 13, Value: 19377}, + {Key: 14, Value: 1692}, + {Key: 15, Value: 76162}, + {Key: 16, Value: 4354}, + {Key: 17, Value: 759}, + {Key: 18, Value: 120650}, + {Key: 19, Value: 11804}, + }, + ValCounts: &[]colMap[int, int64]{ + {Key: 1, Value: 95050}, + {Key: 2, Value: 95050}, + {Key: 3, Value: 95050}, + {Key: 4, Value: 95050}, + {Key: 5, Value: 95050}, + {Key: 6, Value: 95050}, + {Key: 7, Value: 95050}, + {Key: 8, Value: 95050}, + {Key: 9, Value: 95050}, + {Key: 10, Value: 95050}, + {Key: 11, Value: 95050}, + {Key: 12, Value: 95050}, + {Key: 13, Value: 95050}, + {Key: 14, Value: 95050}, + {Key: 15, Value: 95050}, + {Key: 16, Value: 95050}, + {Key: 17, Value: 95050}, + {Key: 18, Value: 95050}, + {Key: 19, Value: 95050}, + }, + NullCounts: &[]colMap[int, int64]{ + {Key: 1, Value: 
0}, + {Key: 2, Value: 0}, + {Key: 3, Value: 0}, + {Key: 4, Value: 0}, + {Key: 5, Value: 0}, + {Key: 6, Value: 0}, + {Key: 7, Value: 0}, + {Key: 8, Value: 0}, + {Key: 9, Value: 0}, + {Key: 10, Value: 0}, + {Key: 11, Value: 0}, + {Key: 12, Value: 95050}, + {Key: 13, Value: 0}, + {Key: 14, Value: 0}, + {Key: 15, Value: 0}, + {Key: 16, Value: 0}, + {Key: 17, Value: 0}, + {Key: 18, Value: 0}, + {Key: 19, Value: 0}, + }, + NaNCounts: &[]colMap[int, int64]{ + {Key: 16, Value: 0}, + {Key: 17, Value: 0}, + {Key: 18, Value: 0}, + {Key: 19, Value: 0}, + {Key: 10, Value: 0}, + {Key: 11, Value: 0}, + {Key: 12, Value: 0}, + {Key: 13, Value: 0}, + {Key: 14, Value: 0}, + {Key: 15, Value: 0}, + }, + LowerBounds: &[]colMap[int, []byte]{ + {Key: 1, Value: []byte{0x01, 0x00, 0x00, 0x00}}, + {Key: 2, Value: []byte("2020-04-01 00:00")}, + {Key: 3, Value: []byte("2020-04-01 00:13")}, + {Key: 4, Value: []byte{0x00, 0x00, 0x00, 0x00}}, + {Key: 5, Value: []byte{0x01, 0x00, 0x00, 0x00}}, + {Key: 6, Value: []byte("N")}, + {Key: 7, Value: []byte{0x01, 0x00, 0x00, 0x00}}, + {Key: 8, Value: []byte{0x01, 0x00, 0x00, 0x00}}, + {Key: 9, Value: []byte{0x01, 0x00, 0x00, 0x00}}, + {Key: 10, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 11, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 13, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 14, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 15, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 16, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 17, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 18, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + {Key: 19, Value: []byte{0, 0, 0, 0, 0, 0, 0, 0}}, + }, + UpperBounds: &[]colMap[int, []byte]{ + {Key: 1, Value: []byte{0x01, 0x00, 0x00, 0x00}}, + {Key: 2, Value: []byte("2020-04-30 23:5:")}, + {Key: 3, Value: []byte("2020-05-01 00:1:")}, + {Key: 4, Value: []byte{0x06, 0x00, 0x00, 0x00}}, + {Key: 5, Value: []byte{'c', 0x00, 0x00, 0x00}}, + {Key: 6, Value: []byte("Y")}, + {Key: 7, Value: []byte{'\t', 0x01, 0x00, 
0x00}}, + {Key: 8, Value: []byte{'\t', 0x01, 0x00, 0x00}}, + {Key: 9, Value: []byte{0x04, 0x01, 0x00, 0x00}}, + {Key: 10, Value: []byte{'\\', 0x8f, 0xc2, 0xf5, '(', '8', 0x8c, '@'}}, + {Key: 11, Value: []byte{0xcd, 0xcc, 0xcc, 0xcc, 0xcc, ',', 'f', '@'}}, + {Key: 13, Value: []byte{0, 0, 0, 0, 0, 0, 0x1c, '@'}}, + {Key: 14, Value: []byte{0x9a, 0x99, 0x99, 0x99, 0x99, 0x99, 0xf1, '?'}}, + {Key: 15, Value: []byte{0, 0, 0, 0, 0, 0, 'Y', '@'}}, + {Key: 16, Value: []byte{0, 0, 0, 0, 0, 0xb0, 'X', '@'}}, + {Key: 17, Value: []byte{'3', '3', '3', '3', '3', '3', 0xd3, '?'}}, + {Key: 18, Value: []byte{0xc3, 0xf5, '(', '\\', 0x8f, ':', 0x8c, '@'}}, + {Key: 19, Value: []byte{0, 0, 0, 0, 0, 0, 0x04, '@'}}, + }, + Splits: &[]int64{4}, + SortOrder: &intZero, + }, + }, + } + + manifestEntryV2Records = []*manifestEntryV2{ + { + EntryStatus: EntryStatusADDED, + Snapshot: &entrySnapshotID, + Data: dataFile{ + Path: manifestEntryV1Records[0].Data.Path, + Format: manifestEntryV1Records[0].Data.Format, + PartitionData: manifestEntryV1Records[0].Data.PartitionData, + RecordCount: manifestEntryV1Records[0].Data.RecordCount, + FileSize: manifestEntryV1Records[0].Data.FileSize, + BlockSizeInBytes: manifestEntryV1Records[0].Data.BlockSizeInBytes, + ColSizes: manifestEntryV1Records[0].Data.ColSizes, + ValCounts: manifestEntryV1Records[0].Data.ValCounts, + NullCounts: manifestEntryV1Records[0].Data.NullCounts, + NaNCounts: manifestEntryV1Records[0].Data.NaNCounts, + LowerBounds: manifestEntryV1Records[0].Data.LowerBounds, + UpperBounds: manifestEntryV1Records[0].Data.UpperBounds, + Splits: manifestEntryV1Records[0].Data.Splits, + SortOrder: manifestEntryV1Records[0].Data.SortOrder, + }, + }, + { + EntryStatus: EntryStatusADDED, + Snapshot: &entrySnapshotID, + Data: dataFile{ + Path: manifestEntryV1Records[1].Data.Path, + Format: manifestEntryV1Records[1].Data.Format, + PartitionData: manifestEntryV1Records[1].Data.PartitionData, + RecordCount: manifestEntryV1Records[1].Data.RecordCount, + 
FileSize: manifestEntryV1Records[1].Data.FileSize, + BlockSizeInBytes: manifestEntryV1Records[1].Data.BlockSizeInBytes, + ColSizes: manifestEntryV1Records[1].Data.ColSizes, + ValCounts: manifestEntryV1Records[1].Data.ValCounts, + NullCounts: manifestEntryV1Records[1].Data.NullCounts, + NaNCounts: manifestEntryV1Records[1].Data.NaNCounts, + LowerBounds: manifestEntryV1Records[1].Data.LowerBounds, + UpperBounds: manifestEntryV1Records[1].Data.UpperBounds, + Splits: manifestEntryV1Records[1].Data.Splits, + SortOrder: manifestEntryV1Records[1].Data.SortOrder, + }, + }, + } +) + +type ManifestTestSuite struct { + suite.Suite + + v1ManifestList bytes.Buffer + v1ManifestEntries bytes.Buffer + + v2ManifestList bytes.Buffer + v2ManifestEntries bytes.Buffer +} + +func (m *ManifestTestSuite) writeManifestList() { + enc, err := ocf.NewEncoder(internal.AvroSchemaCache.Get(internal.ManifestListV1Key).String(), + &m.v1ManifestList, ocf.WithMetadata(map[string][]byte{ + "avro.codec": []byte("deflate"), + }), + ocf.WithCodec(ocf.Deflate)) + m.Require().NoError(err) + + m.Require().NoError(enc.Encode(manifestFileRecordsV1[0])) + enc.Close() + + enc, err = ocf.NewEncoder(internal.AvroSchemaCache.Get(internal.ManifestListV2Key).String(), + &m.v2ManifestList, ocf.WithMetadata(map[string][]byte{ + "format-version": []byte("2"), + "avro.codec": []byte("deflate"), + }), ocf.WithCodec(ocf.Deflate)) + m.Require().NoError(err) + + m.Require().NoError(enc.Encode(manifestFileRecordsV2[0])) + enc.Close() +} + +func (m *ManifestTestSuite) writeManifestEntries() { + enc, err := ocf.NewEncoder(internal.AvroSchemaCache.Get(internal.ManifestEntryV1Key).String(), &m.v1ManifestEntries, + ocf.WithMetadata(map[string][]byte{ + "format-version": []byte("1"), + }), ocf.WithCodec(ocf.Deflate)) + m.Require().NoError(err) + + for _, ent := range manifestEntryV1Records { + m.Require().NoError(enc.Encode(ent)) + } + m.Require().NoError(enc.Close()) + + enc, err = 
ocf.NewEncoder(internal.AvroSchemaCache.Get(internal.ManifestEntryV2Key).String(), + &m.v2ManifestEntries, ocf.WithMetadata(map[string][]byte{ + "format-version": []byte("2"), + "avro.codec": []byte("deflate"), + }), ocf.WithCodec(ocf.Deflate)) + m.Require().NoError(err) + + for _, ent := range manifestEntryV2Records { + m.Require().NoError(enc.Encode(ent)) + } + m.Require().NoError(enc.Close()) +} + +func (m *ManifestTestSuite) SetupSuite() { + m.writeManifestList() + m.writeManifestEntries() +} + +func (m *ManifestTestSuite) TestManifestEntriesV1() { + var mockfs internal.MockFS + manifest := manifestFileV1{ + Path: manifestFileRecordsV1[0].FilePath(), + } + + mockfs.Test(m.T()) + mockfs.On("Open", manifest.FilePath()).Return(&internal.MockFile{ + Contents: bytes.NewReader(m.v1ManifestEntries.Bytes())}, nil) + defer mockfs.AssertExpectations(m.T()) + entries, err := manifest.FetchEntries(&mockfs, false) + m.Require().NoError(err) + m.Len(entries, 2) + m.Zero(manifest.PartitionSpecID()) + m.Zero(manifest.SnapshotID()) + m.Zero(manifest.AddedDataFiles()) + m.Zero(manifest.ExistingDataFiles()) + m.Zero(manifest.DeletedDataFiles()) + m.Zero(manifest.ExistingRows()) + m.Zero(manifest.DeletedRows()) + m.Zero(manifest.AddedRows()) + + entry1 := entries[0] + + m.Equal(EntryStatusADDED, entry1.Status()) + m.EqualValues(8744736658442914487, entry1.SnapshotID()) + m.Zero(entry1.SequenceNum()) + m.Nil(entry1.FileSequenceNum()) + + datafile := entry1.DataFile() + m.Equal(EntryContentData, datafile.ContentType()) + m.Equal("/home/iceberg/warehouse/nyc/taxis_partitioned/data/VendorID=null/00000-633-d8a4223e-dc97-45a1-86e1-adaba6e8abd7-00001.parquet", datafile.FilePath()) + m.Equal(ParquetFile, datafile.FileFormat()) + m.EqualValues(19513, datafile.Count()) + m.EqualValues(388872, datafile.FileSizeBytes()) + m.Equal(map[int]int64{ + 1: 53, + 2: 98153, + 3: 98693, + 4: 53, + 5: 53, + 6: 53, + 7: 17425, + 8: 18528, + 9: 53, + 10: 44788, + 11: 35571, + 12: 53, + 13: 1243, + 14: 
2355, + 15: 12750, + 16: 4029, + 17: 110, + 18: 47194, + 19: 2948, + }, datafile.ColumnSizes()) + m.Equal(map[int]int64{ + 1: 19513, + 2: 19513, + 3: 19513, + 4: 19513, + 5: 19513, + 6: 19513, + 7: 19513, + 8: 19513, + 9: 19513, + 10: 19513, + 11: 19513, + 12: 19513, + 13: 19513, + 14: 19513, + 15: 19513, + 16: 19513, + 17: 19513, + 18: 19513, + 19: 19513, + }, datafile.ValueCounts()) + m.Equal(map[int]int64{ + 1: 19513, + 2: 0, + 3: 0, + 4: 19513, + 5: 19513, + 6: 19513, + 7: 0, + 8: 0, + 9: 19513, + 10: 0, + 11: 0, + 12: 19513, + 13: 0, + 14: 0, + 15: 0, + 16: 0, + 17: 0, + 18: 0, + 19: 0, + }, datafile.NullValueCounts()) + m.Equal(map[int]int64{ + 16: 0, 17: 0, 18: 0, 19: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, + }, datafile.NaNValueCounts()) + + m.Equal(map[int][]byte{ + 2: []byte("2020-04-01 00:00"), + 3: []byte("2020-04-01 00:12"), + 7: {0x03, 0x00, 0x00, 0x00}, + 8: {0x01, 0x00, 0x00, 0x00}, + 10: {0xf6, '(', '\\', 0x8f, 0xc2, 0x05, 'S', 0xc0}, + 11: {0, 0, 0, 0, 0, 0, 0, 0}, + 13: {0, 0, 0, 0, 0, 0, 0, 0}, + 14: {0, 0, 0, 0, 0, 0, 0xe0, 0xbf}, + 15: {')', '\\', 0x8f, 0xc2, 0xf5, '(', 0x08, 0xc0}, + 16: {0, 0, 0, 0, 0, 0, 0, 0}, + 17: {0, 0, 0, 0, 0, 0, 0, 0}, + 18: {0xf6, '(', '\\', 0x8f, 0xc2, 0xc5, 'S', 0xc0}, + 19: {0, 0, 0, 0, 0, 0, 0x04, 0xc0}, + }, datafile.LowerBoundValues()) + + m.Equal(map[int][]byte{ + 2: []byte("2020-04-30 23:5:"), + 3: []byte("2020-05-01 00:41"), + 7: {'\t', 0x01, 0, 0}, + 8: {'\t', 0x01, 0, 0}, + 10: {0xcd, 0xcc, 0xcc, 0xcc, 0xcc, ',', '_', '@'}, + 11: {0x1f, 0x85, 0xeb, 'Q', '\\', 0xe2, 0xfe, '@'}, + 13: {0, 0, 0, 0, 0, 0, 0x12, '@'}, + 14: {0, 0, 0, 0, 0, 0, 0xe0, '?'}, + 15: {'q', '=', '\n', 0xd7, 0xa3, 0xf0, '1', '@'}, + 16: {0, 0, 0, 0, 0, '`', 'B', '@'}, + 17: {'3', '3', '3', '3', '3', '3', 0xd3, '?'}, + 18: {0, 0, 0, 0, 0, 0x18, 'b', '@'}, + 19: {0, 0, 0, 0, 0, 0, 0x04, '@'}, + }, datafile.UpperBoundValues()) + + m.Nil(datafile.KeyMetadata()) + m.Equal([]int64{4}, datafile.SplitOffsets()) + 
m.Nil(datafile.EqualityFieldIDs()) + m.Zero(*datafile.SortOrderID()) +} + +func (m *ManifestTestSuite) TestReadManifestListV1() { + list, err := ReadManifestList(&m.v1ManifestList) + m.Require().NoError(err) + + m.Len(list, 1) + m.Equal(1, list[0].Version()) + m.EqualValues(7989, list[0].Length()) + m.Equal(ManifestContentData, list[0].ManifestContent()) + m.Zero(list[0].SequenceNum()) + m.Zero(list[0].MinSequenceNum()) + m.EqualValues(9182715666859759686, list[0].SnapshotID()) + m.EqualValues(3, list[0].AddedDataFiles()) + m.True(list[0].HasAddedFiles()) + m.Zero(list[0].ExistingDataFiles()) + m.False(list[0].HasExistingFiles()) + m.Zero(list[0].DeletedDataFiles()) + m.Equal(addedRows, list[0].AddedRows()) + m.Zero(list[0].ExistingRows()) + m.Zero(list[0].DeletedRows()) + m.Nil(list[0].KeyMetadata()) + m.Zero(list[0].PartitionSpecID()) + m.Equal(snapshotID, list[0].SnapshotID()) + + part := list[0].Partitions()[0] + m.True(part.ContainsNull) + m.False(*part.ContainsNaN) + m.Equal([]byte{0x01, 0x00, 0x00, 0x00}, *part.LowerBound) + m.Equal([]byte{0x02, 0x00, 0x00, 0x00}, *part.UpperBound) +} + +func (m *ManifestTestSuite) TestReadManifestListV2() { + list, err := ReadManifestList(&m.v2ManifestList) + m.Require().NoError(err) + + m.Equal("/home/iceberg/warehouse/nyc/taxis_partitioned/metadata/0125c686-8aa6-4502-bdcc-b6d17ca41a3b-m0.avro", list[0].FilePath()) + m.Len(list, 1) + m.Equal(2, list[0].Version()) + m.EqualValues(7989, list[0].Length()) + m.Equal(ManifestContentDeletes, list[0].ManifestContent()) + m.EqualValues(3, list[0].SequenceNum()) + m.EqualValues(3, list[0].MinSequenceNum()) + m.EqualValues(9182715666859759686, list[0].SnapshotID()) + m.EqualValues(3, list[0].AddedDataFiles()) + m.True(list[0].HasAddedFiles()) + m.Zero(list[0].ExistingDataFiles()) + m.False(list[0].HasExistingFiles()) + m.Zero(list[0].DeletedDataFiles()) + m.Equal(addedRows, list[0].AddedRows()) + m.Zero(list[0].ExistingRows()) + m.Zero(list[0].DeletedRows()) + 
m.Nil(list[0].KeyMetadata()) + m.Zero(list[0].PartitionSpecID()) + + part := list[0].Partitions()[0] + m.True(part.ContainsNull) + m.False(*part.ContainsNaN) + m.Equal([]byte{0x01, 0x00, 0x00, 0x00}, *part.LowerBound) + m.Equal([]byte{0x02, 0x00, 0x00, 0x00}, *part.UpperBound) +} + +func (m *ManifestTestSuite) TestManifestEntriesV2() { + var mockfs internal.MockFS + manifest := manifestFileV2{ + Path: manifestFileRecordsV2[0].FilePath(), + } + + mockfs.Test(m.T()) + mockfs.On("Open", manifest.FilePath()).Return(&internal.MockFile{ + Contents: bytes.NewReader(m.v2ManifestEntries.Bytes())}, nil) + defer mockfs.AssertExpectations(m.T()) + entries, err := manifest.FetchEntries(&mockfs, false) + m.Require().NoError(err) + m.Len(entries, 2) + m.Zero(manifest.PartitionSpecID()) + m.Zero(manifest.SnapshotID()) + m.Zero(manifest.AddedDataFiles()) + m.Zero(manifest.ExistingDataFiles()) + m.Zero(manifest.DeletedDataFiles()) + m.Zero(manifest.ExistingRows()) + m.Zero(manifest.DeletedRows()) + m.Zero(manifest.AddedRows()) + + entry1 := entries[0] + + m.Equal(EntryStatusADDED, entry1.Status()) + m.Equal(entrySnapshotID, entry1.SnapshotID()) + m.Zero(entry1.SequenceNum()) + m.Zero(*entry1.FileSequenceNum()) + + datafile := entry1.DataFile() + m.Equal(EntryContentData, datafile.ContentType()) + m.Equal("/home/iceberg/warehouse/nyc/taxis_partitioned/data/VendorID=null/00000-633-d8a4223e-dc97-45a1-86e1-adaba6e8abd7-00001.parquet", datafile.FilePath()) + m.Equal(ParquetFile, datafile.FileFormat()) + m.EqualValues(19513, datafile.Count()) + m.EqualValues(388872, datafile.FileSizeBytes()) + m.Equal(map[int]int64{ + 1: 53, + 2: 98153, + 3: 98693, + 4: 53, + 5: 53, + 6: 53, + 7: 17425, + 8: 18528, + 9: 53, + 10: 44788, + 11: 35571, + 12: 53, + 13: 1243, + 14: 2355, + 15: 12750, + 16: 4029, + 17: 110, + 18: 47194, + 19: 2948, + }, datafile.ColumnSizes()) + m.Equal(map[int]int64{ + 1: 19513, + 2: 19513, + 3: 19513, + 4: 19513, + 5: 19513, + 6: 19513, + 7: 19513, + 8: 19513, + 9: 19513, + 
10: 19513, + 11: 19513, + 12: 19513, + 13: 19513, + 14: 19513, + 15: 19513, + 16: 19513, + 17: 19513, + 18: 19513, + 19: 19513, + }, datafile.ValueCounts()) + m.Equal(map[int]int64{ + 1: 19513, + 2: 0, + 3: 0, + 4: 19513, + 5: 19513, + 6: 19513, + 7: 0, + 8: 0, + 9: 19513, + 10: 0, + 11: 0, + 12: 19513, + 13: 0, + 14: 0, + 15: 0, + 16: 0, + 17: 0, + 18: 0, + 19: 0, + }, datafile.NullValueCounts()) + m.Equal(map[int]int64{ + 16: 0, 17: 0, 18: 0, 19: 0, 10: 0, 11: 0, 12: 0, 13: 0, 14: 0, 15: 0, + }, datafile.NaNValueCounts()) + + m.Equal(map[int][]byte{ + 2: []byte("2020-04-01 00:00"), + 3: []byte("2020-04-01 00:12"), + 7: {0x03, 0x00, 0x00, 0x00}, + 8: {0x01, 0x00, 0x00, 0x00}, + 10: {0xf6, '(', '\\', 0x8f, 0xc2, 0x05, 'S', 0xc0}, + 11: {0, 0, 0, 0, 0, 0, 0, 0}, + 13: {0, 0, 0, 0, 0, 0, 0, 0}, + 14: {0, 0, 0, 0, 0, 0, 0xe0, 0xbf}, + 15: {')', '\\', 0x8f, 0xc2, 0xf5, '(', 0x08, 0xc0}, + 16: {0, 0, 0, 0, 0, 0, 0, 0}, + 17: {0, 0, 0, 0, 0, 0, 0, 0}, + 18: {0xf6, '(', '\\', 0x8f, 0xc2, 0xc5, 'S', 0xc0}, + 19: {0, 0, 0, 0, 0, 0, 0x04, 0xc0}, + }, datafile.LowerBoundValues()) + + m.Equal(map[int][]byte{ + 2: []byte("2020-04-30 23:5:"), + 3: []byte("2020-05-01 00:41"), + 7: {'\t', 0x01, 0, 0}, + 8: {'\t', 0x01, 0, 0}, + 10: {0xcd, 0xcc, 0xcc, 0xcc, 0xcc, ',', '_', '@'}, + 11: {0x1f, 0x85, 0xeb, 'Q', '\\', 0xe2, 0xfe, '@'}, + 13: {0, 0, 0, 0, 0, 0, 0x12, '@'}, + 14: {0, 0, 0, 0, 0, 0, 0xe0, '?'}, + 15: {'q', '=', '\n', 0xd7, 0xa3, 0xf0, '1', '@'}, + 16: {0, 0, 0, 0, 0, '`', 'B', '@'}, + 17: {'3', '3', '3', '3', '3', '3', 0xd3, '?'}, + 18: {0, 0, 0, 0, 0, 0x18, 'b', '@'}, + 19: {0, 0, 0, 0, 0, 0, 0x04, '@'}, + }, datafile.UpperBoundValues()) + + m.Nil(datafile.KeyMetadata()) + m.Equal([]int64{4}, datafile.SplitOffsets()) + m.Nil(datafile.EqualityFieldIDs()) + m.Zero(*datafile.SortOrderID()) +} + +func TestManifests(t *testing.T) { + suite.Run(t, new(ManifestTestSuite)) +} diff --git a/operation_string.go b/operation_string.go index 3af65e3..63e7121 100644 --- 
a/operation_string.go +++ b/operation_string.go @@ -1,41 +1,41 @@ -// Code generated by "stringer -type=Operation -linecomment"; DO NOT EDIT. - -package iceberg - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[OpTrue-0] - _ = x[OpFalse-1] - _ = x[OpIsNull-2] - _ = x[OpNotNull-3] - _ = x[OpIsNan-4] - _ = x[OpNotNan-5] - _ = x[OpLT-6] - _ = x[OpLTEQ-7] - _ = x[OpGT-8] - _ = x[OpGTEQ-9] - _ = x[OpEQ-10] - _ = x[OpNEQ-11] - _ = x[OpStartsWith-12] - _ = x[OpNotStartsWith-13] - _ = x[OpIn-14] - _ = x[OpNotIn-15] - _ = x[OpNot-16] - _ = x[OpAnd-17] - _ = x[OpOr-18] -} - -const _Operation_name = "TrueFalseIsNullNotNullIsNaNNotNaNLessThanLessThanEqualGreaterThanGreaterThanEqualEqualNotEqualStartsWithNotStartsWithInNotInNotAndOr" - -var _Operation_index = [...]uint8{0, 4, 9, 15, 22, 27, 33, 41, 54, 65, 81, 86, 94, 104, 117, 119, 124, 127, 130, 132} - -func (i Operation) String() string { - if i < 0 || i >= Operation(len(_Operation_index)-1) { - return "Operation(" + strconv.FormatInt(int64(i), 10) + ")" - } - return _Operation_name[_Operation_index[i]:_Operation_index[i+1]] -} +// Code generated by "stringer -type=Operation -linecomment"; DO NOT EDIT. + +package iceberg + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. 
+ var x [1]struct{} + _ = x[OpTrue-0] + _ = x[OpFalse-1] + _ = x[OpIsNull-2] + _ = x[OpNotNull-3] + _ = x[OpIsNan-4] + _ = x[OpNotNan-5] + _ = x[OpLT-6] + _ = x[OpLTEQ-7] + _ = x[OpGT-8] + _ = x[OpGTEQ-9] + _ = x[OpEQ-10] + _ = x[OpNEQ-11] + _ = x[OpStartsWith-12] + _ = x[OpNotStartsWith-13] + _ = x[OpIn-14] + _ = x[OpNotIn-15] + _ = x[OpNot-16] + _ = x[OpAnd-17] + _ = x[OpOr-18] +} + +const _Operation_name = "TrueFalseIsNullNotNullIsNaNNotNaNLessThanLessThanEqualGreaterThanGreaterThanEqualEqualNotEqualStartsWithNotStartsWithInNotInNotAndOr" + +var _Operation_index = [...]uint8{0, 4, 9, 15, 22, 27, 33, 41, 54, 65, 81, 86, 94, 104, 117, 119, 124, 127, 130, 132} + +func (i Operation) String() string { + if i < 0 || i >= Operation(len(_Operation_index)-1) { + return "Operation(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _Operation_name[_Operation_index[i]:_Operation_index[i+1]] +} diff --git a/partitions.go b/partitions.go index 321af2e..0b94dee 100644 --- a/partitions.go +++ b/partitions.go @@ -1,232 +1,232 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package iceberg - -import ( - "encoding/json" - "fmt" - "strings" - - "golang.org/x/exp/slices" -) - -const ( - partitionDataIDStart = 1000 - InitialPartitionSpecID = 0 -) - -// UnpartitionedSpec is the default unpartitioned spec which can -// be used for comparisons or to just provide a convenience for referencing -// the same unpartitioned spec object. -var UnpartitionedSpec = &PartitionSpec{id: 0} - -// PartitionField represents how one partition value is derived from the -// source column by transformation. -type PartitionField struct { - // SourceID is the source column id of the table's schema - SourceID int `json:"source-id"` - // FieldID is the partition field id across all the table partition specs - FieldID int `json:"field-id"` - // Name is the name of the partition field itself - Name string `json:"name"` - // Transform is the transform used to produce the partition value - Transform Transform `json:"transform"` -} - -func (p *PartitionField) String() string { - return fmt.Sprintf("%d: %s: %s(%d)", p.FieldID, p.Name, p.Transform, p.SourceID) -} - -func (p *PartitionField) UnmarshalJSON(b []byte) error { - type Alias PartitionField - aux := struct { - TransformString string `json:"transform"` - *Alias - }{ - Alias: (*Alias)(p), - } - - err := json.Unmarshal(b, &aux) - if err != nil { - return err - } - - if p.Transform, err = ParseTransform(aux.TransformString); err != nil { - return err - } - - return nil -} - -// PartitionSpec captures the transformation from table data to partition values -type PartitionSpec struct { - // any change to a PartitionSpec will produce a new spec id - id int - fields []PartitionField - - // this is populated by initialize after creation - sourceIdToFields map[int][]PartitionField -} - -func NewPartitionSpec(fields ...PartitionField) PartitionSpec { - return NewPartitionSpecID(InitialPartitionSpecID, fields...) 
-} - -func NewPartitionSpecID(id int, fields ...PartitionField) PartitionSpec { - ret := PartitionSpec{id: id, fields: fields} - ret.initialize() - return ret -} - -// CompatibleWith returns true if this partition spec is considered -// compatible with the passed in partition spec. This means that the two -// specs have equivalent field lists regardless of the spec id. -func (ps *PartitionSpec) CompatibleWith(other *PartitionSpec) bool { - if ps == other { - return true - } - - if len(ps.fields) != len(other.fields) { - return false - } - - return slices.EqualFunc(ps.fields, other.fields, func(left, right PartitionField) bool { - return left.SourceID == right.SourceID && left.Name == right.Name && - left.Transform == right.Transform - }) -} - -// Equals returns true iff the field lists are the same AND the spec id -// is the same between this partition spec and the provided one. -func (ps PartitionSpec) Equals(other PartitionSpec) bool { - return ps.id == other.id && slices.Equal(ps.fields, other.fields) -} - -func (ps PartitionSpec) MarshalJSON() ([]byte, error) { - if ps.fields == nil { - ps.fields = []PartitionField{} - } - return json.Marshal(struct { - ID int `json:"spec-id"` - Fields []PartitionField `json:"fields"` - }{ps.id, ps.fields}) -} - -func (ps *PartitionSpec) UnmarshalJSON(b []byte) error { - aux := struct { - ID int `json:"spec-id"` - Fields []PartitionField `json:"fields"` - }{ID: ps.id, Fields: ps.fields} - - if err := json.Unmarshal(b, &aux); err != nil { - return err - } - - ps.id, ps.fields = aux.ID, aux.Fields - ps.initialize() - return nil -} - -func (ps *PartitionSpec) initialize() { - ps.sourceIdToFields = make(map[int][]PartitionField) - for _, f := range ps.fields { - ps.sourceIdToFields[f.SourceID] = - append(ps.sourceIdToFields[f.SourceID], f) - } -} - -func (ps *PartitionSpec) ID() int { return ps.id } -func (ps *PartitionSpec) NumFields() int { return len(ps.fields) } -func (ps *PartitionSpec) Field(i int) PartitionField { return 
ps.fields[i] } - -func (ps *PartitionSpec) IsUnpartitioned() bool { - if len(ps.fields) == 0 { - return true - } - - for _, f := range ps.fields { - if _, ok := f.Transform.(VoidTransform); !ok { - return false - } - } - - return true -} - -func (ps *PartitionSpec) FieldsBySourceID(fieldID int) []PartitionField { - return slices.Clone(ps.sourceIdToFields[fieldID]) -} - -func (ps PartitionSpec) String() string { - var b strings.Builder - b.WriteByte('[') - for i, f := range ps.fields { - if i == 0 { - b.WriteString("\n") - } - b.WriteString("\t") - b.WriteString(f.String()) - b.WriteString("\n") - } - b.WriteByte(']') - - return b.String() -} - -func (ps *PartitionSpec) LastAssignedFieldID() int { - if len(ps.fields) == 0 { - return partitionDataIDStart - 1 - } - - id := ps.fields[0].FieldID - for _, f := range ps.fields[1:] { - if f.FieldID > id { - id = f.FieldID - } - } - return id -} - -// PartitionType produces a struct of the partition spec. -// -// The partition fields should be optional: -// - All partition transforms are required to produce null if the input value -// is null. This can happen when the source column is optional. -// - Partition fields may be added later, in which case not all files would -// have the result field and it may be null. -// -// There is a case where we can guarantee that a partition field in the first -// and only parittion spec that uses a required source column will never be -// null, but it doesn't seem worth tracking this case. 
-func (ps *PartitionSpec) PartitionType(schema *Schema) *StructType { - nestedFields := []NestedField{} - for _, field := range ps.fields { - sourceType, ok := schema.FindTypeByID(field.SourceID) - if !ok { - continue - } - resultType := field.Transform.ResultType(sourceType) - nestedFields = append(nestedFields, NestedField{ - ID: field.FieldID, - Name: field.Name, - Type: resultType, - Required: false, - }) - } - return &StructType{FieldList: nestedFields} -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package iceberg + +import ( + "encoding/json" + "fmt" + "strings" + + "golang.org/x/exp/slices" +) + +const ( + partitionDataIDStart = 1000 + InitialPartitionSpecID = 0 +) + +// UnpartitionedSpec is the default unpartitioned spec which can +// be used for comparisons or to just provide a convenience for referencing +// the same unpartitioned spec object. +var UnpartitionedSpec = &PartitionSpec{id: 0} + +// PartitionField represents how one partition value is derived from the +// source column by transformation. 
+type PartitionField struct { + // SourceID is the source column id of the table's schema + SourceID int `json:"source-id"` + // FieldID is the partition field id across all the table partition specs + FieldID int `json:"field-id"` + // Name is the name of the partition field itself + Name string `json:"name"` + // Transform is the transform used to produce the partition value + Transform Transform `json:"transform"` +} + +func (p *PartitionField) String() string { + return fmt.Sprintf("%d: %s: %s(%d)", p.FieldID, p.Name, p.Transform, p.SourceID) +} + +func (p *PartitionField) UnmarshalJSON(b []byte) error { + type Alias PartitionField + aux := struct { + TransformString string `json:"transform"` + *Alias + }{ + Alias: (*Alias)(p), + } + + err := json.Unmarshal(b, &aux) + if err != nil { + return err + } + + if p.Transform, err = ParseTransform(aux.TransformString); err != nil { + return err + } + + return nil +} + +// PartitionSpec captures the transformation from table data to partition values +type PartitionSpec struct { + // any change to a PartitionSpec will produce a new spec id + id int + fields []PartitionField + + // this is populated by initialize after creation + sourceIdToFields map[int][]PartitionField +} + +func NewPartitionSpec(fields ...PartitionField) PartitionSpec { + return NewPartitionSpecID(InitialPartitionSpecID, fields...) +} + +func NewPartitionSpecID(id int, fields ...PartitionField) PartitionSpec { + ret := PartitionSpec{id: id, fields: fields} + ret.initialize() + return ret +} + +// CompatibleWith returns true if this partition spec is considered +// compatible with the passed in partition spec. This means that the two +// specs have equivalent field lists regardless of the spec id. 
+func (ps *PartitionSpec) CompatibleWith(other *PartitionSpec) bool { + if ps == other { + return true + } + + if len(ps.fields) != len(other.fields) { + return false + } + + return slices.EqualFunc(ps.fields, other.fields, func(left, right PartitionField) bool { + return left.SourceID == right.SourceID && left.Name == right.Name && + left.Transform == right.Transform + }) +} + +// Equals returns true iff the field lists are the same AND the spec id +// is the same between this partition spec and the provided one. +func (ps PartitionSpec) Equals(other PartitionSpec) bool { + return ps.id == other.id && slices.Equal(ps.fields, other.fields) +} + +func (ps PartitionSpec) MarshalJSON() ([]byte, error) { + if ps.fields == nil { + ps.fields = []PartitionField{} + } + return json.Marshal(struct { + ID int `json:"spec-id"` + Fields []PartitionField `json:"fields"` + }{ps.id, ps.fields}) +} + +func (ps *PartitionSpec) UnmarshalJSON(b []byte) error { + aux := struct { + ID int `json:"spec-id"` + Fields []PartitionField `json:"fields"` + }{ID: ps.id, Fields: ps.fields} + + if err := json.Unmarshal(b, &aux); err != nil { + return err + } + + ps.id, ps.fields = aux.ID, aux.Fields + ps.initialize() + return nil +} + +func (ps *PartitionSpec) initialize() { + ps.sourceIdToFields = make(map[int][]PartitionField) + for _, f := range ps.fields { + ps.sourceIdToFields[f.SourceID] = + append(ps.sourceIdToFields[f.SourceID], f) + } +} + +func (ps *PartitionSpec) ID() int { return ps.id } +func (ps *PartitionSpec) NumFields() int { return len(ps.fields) } +func (ps *PartitionSpec) Field(i int) PartitionField { return ps.fields[i] } + +func (ps *PartitionSpec) IsUnpartitioned() bool { + if len(ps.fields) == 0 { + return true + } + + for _, f := range ps.fields { + if _, ok := f.Transform.(VoidTransform); !ok { + return false + } + } + + return true +} + +func (ps *PartitionSpec) FieldsBySourceID(fieldID int) []PartitionField { + return slices.Clone(ps.sourceIdToFields[fieldID]) +} + 
+func (ps PartitionSpec) String() string { + var b strings.Builder + b.WriteByte('[') + for i, f := range ps.fields { + if i == 0 { + b.WriteString("\n") + } + b.WriteString("\t") + b.WriteString(f.String()) + b.WriteString("\n") + } + b.WriteByte(']') + + return b.String() +} + +func (ps *PartitionSpec) LastAssignedFieldID() int { + if len(ps.fields) == 0 { + return partitionDataIDStart - 1 + } + + id := ps.fields[0].FieldID + for _, f := range ps.fields[1:] { + if f.FieldID > id { + id = f.FieldID + } + } + return id +} + +// PartitionType produces a struct of the partition spec. +// +// The partition fields should be optional: +// - All partition transforms are required to produce null if the input value +// is null. This can happen when the source column is optional. +// - Partition fields may be added later, in which case not all files would +// have the result field and it may be null. +// +// There is a case where we can guarantee that a partition field in the first +// and only parittion spec that uses a required source column will never be +// null, but it doesn't seem worth tracking this case. +func (ps *PartitionSpec) PartitionType(schema *Schema) *StructType { + nestedFields := []NestedField{} + for _, field := range ps.fields { + sourceType, ok := schema.FindTypeByID(field.SourceID) + if !ok { + continue + } + resultType := field.Transform.ResultType(sourceType) + nestedFields = append(nestedFields, NestedField{ + ID: field.FieldID, + Name: field.Name, + Type: resultType, + Required: false, + }) + } + return &StructType{FieldList: nestedFields} +} diff --git a/partitions_test.go b/partitions_test.go index fd29190..7aa45b9 100644 --- a/partitions_test.go +++ b/partitions_test.go @@ -1,143 +1,143 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package iceberg_test - -import ( - "encoding/json" - "testing" - - "github.com/apache/iceberg-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestPartitionSpec(t *testing.T) { - assert.Equal(t, 999, iceberg.UnpartitionedSpec.LastAssignedFieldID()) - - bucket := iceberg.BucketTransform{NumBuckets: 4} - idField1 := iceberg.PartitionField{ - SourceID: 3, FieldID: 1001, Name: "id", Transform: bucket} - spec1 := iceberg.NewPartitionSpec(idField1) - - assert.Zero(t, spec1.ID()) - assert.Equal(t, 1, spec1.NumFields()) - assert.Equal(t, idField1, spec1.Field(0)) - assert.NotEqual(t, idField1, spec1) - assert.False(t, spec1.IsUnpartitioned()) - assert.True(t, spec1.CompatibleWith(&spec1)) - assert.True(t, spec1.Equals(spec1)) - assert.Equal(t, 1001, spec1.LastAssignedFieldID()) - assert.Equal(t, "[\n\t1001: id: bucket[4](3)\n]", spec1.String()) - - // only differs by PartitionField FieldID - idField2 := iceberg.PartitionField{ - SourceID: 3, FieldID: 1002, Name: "id", Transform: bucket} - spec2 := iceberg.NewPartitionSpec(idField2) - - assert.False(t, spec1.Equals(spec2)) - assert.True(t, spec1.CompatibleWith(&spec2)) - assert.Equal(t, []iceberg.PartitionField{idField1}, spec1.FieldsBySourceID(3)) - assert.Empty(t, spec1.FieldsBySourceID(1925)) - - spec3 := iceberg.NewPartitionSpec(idField1, idField2) - assert.False(t, 
spec1.CompatibleWith(&spec3)) - assert.Equal(t, 1002, spec3.LastAssignedFieldID()) -} - -func TestUnpartitionedWithVoidField(t *testing.T) { - spec := iceberg.NewPartitionSpec(iceberg.PartitionField{ - SourceID: 3, FieldID: 1001, Name: "void", Transform: iceberg.VoidTransform{}, - }) - - assert.True(t, spec.IsUnpartitioned()) - - spec2 := iceberg.NewPartitionSpec(iceberg.PartitionField{ - SourceID: 3, FieldID: 1001, Name: "void", Transform: iceberg.VoidTransform{}, - }, iceberg.PartitionField{ - SourceID: 3, FieldID: 1002, Name: "bucket", Transform: iceberg.BucketTransform{NumBuckets: 2}, - }) - - assert.False(t, spec2.IsUnpartitioned()) -} - -func TestSerializeUnpartitionedSpec(t *testing.T) { - data, err := json.Marshal(iceberg.UnpartitionedSpec) - require.NoError(t, err) - - assert.JSONEq(t, `{"spec-id": 0, "fields": []}`, string(data)) - assert.True(t, iceberg.UnpartitionedSpec.IsUnpartitioned()) -} - -func TestSerializePartitionSpec(t *testing.T) { - spec := iceberg.NewPartitionSpecID(3, - iceberg.PartitionField{SourceID: 1, FieldID: 1000, - Transform: iceberg.TruncateTransform{Width: 19}, Name: "str_truncate"}, - iceberg.PartitionField{SourceID: 2, FieldID: 1001, - Transform: iceberg.BucketTransform{NumBuckets: 25}, Name: "int_bucket"}, - ) - - data, err := json.Marshal(spec) - require.NoError(t, err) - - assert.JSONEq(t, `{ - "spec-id": 3, - "fields": [ - { - "source-id": 1, - "field-id": 1000, - "transform": "truncate[19]", - "name": "str_truncate" - }, - { - "source-id": 2, - "field-id": 1001, - "transform": "bucket[25]", - "name": "int_bucket" - } - ] - }`, string(data)) - - var outspec iceberg.PartitionSpec - require.NoError(t, json.Unmarshal(data, &outspec)) - - assert.True(t, spec.Equals(outspec)) -} - -func TestPartitionType(t *testing.T) { - spec := iceberg.NewPartitionSpecID(3, - iceberg.PartitionField{SourceID: 1, FieldID: 1000, - Transform: iceberg.TruncateTransform{Width: 19}, Name: "str_truncate"}, - iceberg.PartitionField{SourceID: 2, FieldID: 
1001, - Transform: iceberg.BucketTransform{NumBuckets: 25}, Name: "int_bucket"}, - iceberg.PartitionField{SourceID: 3, FieldID: 1002, - Transform: iceberg.IdentityTransform{}, Name: "bool_identity"}, - iceberg.PartitionField{SourceID: 1, FieldID: 1003, - Transform: iceberg.VoidTransform{}, Name: "str_void"}, - ) - - expected := &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 1000, Name: "str_truncate", Type: iceberg.PrimitiveTypes.String}, - {ID: 1001, Name: "int_bucket", Type: iceberg.PrimitiveTypes.Int32}, - {ID: 1002, Name: "bool_identity", Type: iceberg.PrimitiveTypes.Bool}, - {ID: 1003, Name: "str_void", Type: iceberg.PrimitiveTypes.String}, - }, - } - actual := spec.PartitionType(tableSchemaSimple) - assert.Truef(t, expected.Equals(actual), "expected: %s, got: %s", expected, actual) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package iceberg_test + +import ( + "encoding/json" + "testing" + + "github.com/apache/iceberg-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPartitionSpec(t *testing.T) { + assert.Equal(t, 999, iceberg.UnpartitionedSpec.LastAssignedFieldID()) + + bucket := iceberg.BucketTransform{NumBuckets: 4} + idField1 := iceberg.PartitionField{ + SourceID: 3, FieldID: 1001, Name: "id", Transform: bucket} + spec1 := iceberg.NewPartitionSpec(idField1) + + assert.Zero(t, spec1.ID()) + assert.Equal(t, 1, spec1.NumFields()) + assert.Equal(t, idField1, spec1.Field(0)) + assert.NotEqual(t, idField1, spec1) + assert.False(t, spec1.IsUnpartitioned()) + assert.True(t, spec1.CompatibleWith(&spec1)) + assert.True(t, spec1.Equals(spec1)) + assert.Equal(t, 1001, spec1.LastAssignedFieldID()) + assert.Equal(t, "[\n\t1001: id: bucket[4](3)\n]", spec1.String()) + + // only differs by PartitionField FieldID + idField2 := iceberg.PartitionField{ + SourceID: 3, FieldID: 1002, Name: "id", Transform: bucket} + spec2 := iceberg.NewPartitionSpec(idField2) + + assert.False(t, spec1.Equals(spec2)) + assert.True(t, spec1.CompatibleWith(&spec2)) + assert.Equal(t, []iceberg.PartitionField{idField1}, spec1.FieldsBySourceID(3)) + assert.Empty(t, spec1.FieldsBySourceID(1925)) + + spec3 := iceberg.NewPartitionSpec(idField1, idField2) + assert.False(t, spec1.CompatibleWith(&spec3)) + assert.Equal(t, 1002, spec3.LastAssignedFieldID()) +} + +func TestUnpartitionedWithVoidField(t *testing.T) { + spec := iceberg.NewPartitionSpec(iceberg.PartitionField{ + SourceID: 3, FieldID: 1001, Name: "void", Transform: iceberg.VoidTransform{}, + }) + + assert.True(t, spec.IsUnpartitioned()) + + spec2 := iceberg.NewPartitionSpec(iceberg.PartitionField{ + SourceID: 3, FieldID: 1001, Name: "void", Transform: iceberg.VoidTransform{}, + }, iceberg.PartitionField{ + SourceID: 3, FieldID: 1002, Name: "bucket", Transform: iceberg.BucketTransform{NumBuckets: 2}, + }) + + 
assert.False(t, spec2.IsUnpartitioned()) +} + +func TestSerializeUnpartitionedSpec(t *testing.T) { + data, err := json.Marshal(iceberg.UnpartitionedSpec) + require.NoError(t, err) + + assert.JSONEq(t, `{"spec-id": 0, "fields": []}`, string(data)) + assert.True(t, iceberg.UnpartitionedSpec.IsUnpartitioned()) +} + +func TestSerializePartitionSpec(t *testing.T) { + spec := iceberg.NewPartitionSpecID(3, + iceberg.PartitionField{SourceID: 1, FieldID: 1000, + Transform: iceberg.TruncateTransform{Width: 19}, Name: "str_truncate"}, + iceberg.PartitionField{SourceID: 2, FieldID: 1001, + Transform: iceberg.BucketTransform{NumBuckets: 25}, Name: "int_bucket"}, + ) + + data, err := json.Marshal(spec) + require.NoError(t, err) + + assert.JSONEq(t, `{ + "spec-id": 3, + "fields": [ + { + "source-id": 1, + "field-id": 1000, + "transform": "truncate[19]", + "name": "str_truncate" + }, + { + "source-id": 2, + "field-id": 1001, + "transform": "bucket[25]", + "name": "int_bucket" + } + ] + }`, string(data)) + + var outspec iceberg.PartitionSpec + require.NoError(t, json.Unmarshal(data, &outspec)) + + assert.True(t, spec.Equals(outspec)) +} + +func TestPartitionType(t *testing.T) { + spec := iceberg.NewPartitionSpecID(3, + iceberg.PartitionField{SourceID: 1, FieldID: 1000, + Transform: iceberg.TruncateTransform{Width: 19}, Name: "str_truncate"}, + iceberg.PartitionField{SourceID: 2, FieldID: 1001, + Transform: iceberg.BucketTransform{NumBuckets: 25}, Name: "int_bucket"}, + iceberg.PartitionField{SourceID: 3, FieldID: 1002, + Transform: iceberg.IdentityTransform{}, Name: "bool_identity"}, + iceberg.PartitionField{SourceID: 1, FieldID: 1003, + Transform: iceberg.VoidTransform{}, Name: "str_void"}, + ) + + expected := &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 1000, Name: "str_truncate", Type: iceberg.PrimitiveTypes.String}, + {ID: 1001, Name: "int_bucket", Type: iceberg.PrimitiveTypes.Int32}, + {ID: 1002, Name: "bool_identity", Type: iceberg.PrimitiveTypes.Bool}, + 
{ID: 1003, Name: "str_void", Type: iceberg.PrimitiveTypes.String}, + }, + } + actual := spec.PartitionType(tableSchemaSimple) + assert.Truef(t, expected.Equals(actual), "expected: %s, got: %s", expected, actual) +} diff --git a/predicates.go b/predicates.go index 24ace71..5b6e2ed 100644 --- a/predicates.go +++ b/predicates.go @@ -1,138 +1,138 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package iceberg - -// IsNull is a convenience wrapper for calling UnaryPredicate(OpIsNull, t) -// -// Will panic if t is nil -func IsNull(t UnboundTerm) UnboundPredicate { - return UnaryPredicate(OpIsNull, t) -} - -// NotNull is a convenience wrapper for calling UnaryPredicate(OpNotNull, t) -// -// Will panic if t is nil -func NotNull(t UnboundTerm) UnboundPredicate { - return UnaryPredicate(OpNotNull, t) -} - -// IsNaN is a convenience wrapper for calling UnaryPredicate(OpIsNan, t) -// -// Will panic if t is nil -func IsNaN(t UnboundTerm) UnboundPredicate { - return UnaryPredicate(OpIsNan, t) -} - -// NotNaN is a convenience wrapper for calling UnaryPredicate(OpNotNan, t) -// -// Will panic if t is nil -func NotNaN(t UnboundTerm) UnboundPredicate { - return UnaryPredicate(OpNotNan, t) -} - -// IsIn is a convenience wrapper for constructing an unbound set predicate for -// OpIn. It returns a BooleanExpression instead of an UnboundPredicate because -// depending on the arguments, it can automatically reduce to AlwaysFalse or -// AlwaysTrue (if given no values for examples). It may also reduce to EqualTo -// if only one value is provided. -// -// Will panic if t is nil -func IsIn[T LiteralType](t UnboundTerm, vals ...T) BooleanExpression { - lits := make([]Literal, 0, len(vals)) - for _, v := range vals { - lits = append(lits, NewLiteral(v)) - } - return SetPredicate(OpIn, t, lits) -} - -// NotIn is a convenience wrapper for constructing an unbound set predicate for -// OpNotIn. It returns a BooleanExpression instead of an UnboundPredicate because -// depending on the arguments, it can automatically reduce to AlwaysFalse or -// AlwaysTrue (if given no values for examples). It may also reduce to NotEqualTo -// if only one value is provided. 
-// -// Will panic if t is nil -func NotIn[T LiteralType](t UnboundTerm, vals ...T) BooleanExpression { - lits := make([]Literal, 0, len(vals)) - for _, v := range vals { - lits = append(lits, NewLiteral(v)) - } - return SetPredicate(OpNotIn, t, lits) -} - -// EqualTo is a convenience wrapper for calling LiteralPredicate(OpEQ, t, NewLiteral(v)) -// -// Will panic if t is nil -func EqualTo[T LiteralType](t UnboundTerm, v T) UnboundPredicate { - return LiteralPredicate(OpEQ, t, NewLiteral(v)) -} - -// NotEqualTo is a convenience wrapper for calling LiteralPredicate(OpNEQ, t, NewLiteral(v)) -// -// Will panic if t is nil -func NotEqualTo[T LiteralType](t UnboundTerm, v T) UnboundPredicate { - return LiteralPredicate(OpNEQ, t, NewLiteral(v)) -} - -// GreaterThanEqual is a convenience wrapper for calling LiteralPredicate(OpGTEQ, -// t, NewLiteral(v)) -// -// Will panic if t is nil -func GreaterThanEqual[T LiteralType](t UnboundTerm, v T) UnboundPredicate { - return LiteralPredicate(OpGTEQ, t, NewLiteral(v)) -} - -// GreaterThan is a convenience wrapper for calling LiteralPredicate(OpGT, -// t, NewLiteral(v)) -// -// Will panic if t is nil -func GreaterThan[T LiteralType](t UnboundTerm, v T) UnboundPredicate { - return LiteralPredicate(OpGT, t, NewLiteral(v)) -} - -// LessThanEqual is a convenience wrapper for calling LiteralPredicate(OpLTEQ, -// t, NewLiteral(v)) -// -// Will panic if t is nil -func LessThanEqual[T LiteralType](t UnboundTerm, v T) UnboundPredicate { - return LiteralPredicate(OpLTEQ, t, NewLiteral(v)) -} - -// LessThan is a convenience wrapper for calling LiteralPredicate(OpLT, -// t, NewLiteral(v)) -// -// Will panic if t is nil -func LessThan[T LiteralType](t UnboundTerm, v T) UnboundPredicate { - return LiteralPredicate(OpLT, t, NewLiteral(v)) -} - -// StartsWith is a convenience wrapper for calling LiteralPredicate(OpStartsWith, -// t, NewLiteral(v)) -// -// Will panic if t is nil -func StartsWith(t UnboundTerm, v string) UnboundPredicate { - return 
LiteralPredicate(OpStartsWith, t, NewLiteral(v)) -} - -// NotStartsWith is a convenience wrapper for calling LiteralPredicate(OpNotStartsWith, -// t, NewLiteral(v)) -// -// Will panic if t is nil -func NotStartsWith(t UnboundTerm, v string) UnboundPredicate { - return LiteralPredicate(OpNotStartsWith, t, NewLiteral(v)) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package iceberg + +// IsNull is a convenience wrapper for calling UnaryPredicate(OpIsNull, t) +// +// Will panic if t is nil +func IsNull(t UnboundTerm) UnboundPredicate { + return UnaryPredicate(OpIsNull, t) +} + +// NotNull is a convenience wrapper for calling UnaryPredicate(OpNotNull, t) +// +// Will panic if t is nil +func NotNull(t UnboundTerm) UnboundPredicate { + return UnaryPredicate(OpNotNull, t) +} + +// IsNaN is a convenience wrapper for calling UnaryPredicate(OpIsNan, t) +// +// Will panic if t is nil +func IsNaN(t UnboundTerm) UnboundPredicate { + return UnaryPredicate(OpIsNan, t) +} + +// NotNaN is a convenience wrapper for calling UnaryPredicate(OpNotNan, t) +// +// Will panic if t is nil +func NotNaN(t UnboundTerm) UnboundPredicate { + return UnaryPredicate(OpNotNan, t) +} + +// IsIn is a convenience wrapper for constructing an unbound set predicate for +// OpIn. It returns a BooleanExpression instead of an UnboundPredicate because +// depending on the arguments, it can automatically reduce to AlwaysFalse or +// AlwaysTrue (if given no values for examples). It may also reduce to EqualTo +// if only one value is provided. +// +// Will panic if t is nil +func IsIn[T LiteralType](t UnboundTerm, vals ...T) BooleanExpression { + lits := make([]Literal, 0, len(vals)) + for _, v := range vals { + lits = append(lits, NewLiteral(v)) + } + return SetPredicate(OpIn, t, lits) +} + +// NotIn is a convenience wrapper for constructing an unbound set predicate for +// OpNotIn. It returns a BooleanExpression instead of an UnboundPredicate because +// depending on the arguments, it can automatically reduce to AlwaysFalse or +// AlwaysTrue (if given no values for examples). It may also reduce to NotEqualTo +// if only one value is provided. 
+// +// Will panic if t is nil +func NotIn[T LiteralType](t UnboundTerm, vals ...T) BooleanExpression { + lits := make([]Literal, 0, len(vals)) + for _, v := range vals { + lits = append(lits, NewLiteral(v)) + } + return SetPredicate(OpNotIn, t, lits) +} + +// EqualTo is a convenience wrapper for calling LiteralPredicate(OpEQ, t, NewLiteral(v)) +// +// Will panic if t is nil +func EqualTo[T LiteralType](t UnboundTerm, v T) UnboundPredicate { + return LiteralPredicate(OpEQ, t, NewLiteral(v)) +} + +// NotEqualTo is a convenience wrapper for calling LiteralPredicate(OpNEQ, t, NewLiteral(v)) +// +// Will panic if t is nil +func NotEqualTo[T LiteralType](t UnboundTerm, v T) UnboundPredicate { + return LiteralPredicate(OpNEQ, t, NewLiteral(v)) +} + +// GreaterThanEqual is a convenience wrapper for calling LiteralPredicate(OpGTEQ, +// t, NewLiteral(v)) +// +// Will panic if t is nil +func GreaterThanEqual[T LiteralType](t UnboundTerm, v T) UnboundPredicate { + return LiteralPredicate(OpGTEQ, t, NewLiteral(v)) +} + +// GreaterThan is a convenience wrapper for calling LiteralPredicate(OpGT, +// t, NewLiteral(v)) +// +// Will panic if t is nil +func GreaterThan[T LiteralType](t UnboundTerm, v T) UnboundPredicate { + return LiteralPredicate(OpGT, t, NewLiteral(v)) +} + +// LessThanEqual is a convenience wrapper for calling LiteralPredicate(OpLTEQ, +// t, NewLiteral(v)) +// +// Will panic if t is nil +func LessThanEqual[T LiteralType](t UnboundTerm, v T) UnboundPredicate { + return LiteralPredicate(OpLTEQ, t, NewLiteral(v)) +} + +// LessThan is a convenience wrapper for calling LiteralPredicate(OpLT, +// t, NewLiteral(v)) +// +// Will panic if t is nil +func LessThan[T LiteralType](t UnboundTerm, v T) UnboundPredicate { + return LiteralPredicate(OpLT, t, NewLiteral(v)) +} + +// StartsWith is a convenience wrapper for calling LiteralPredicate(OpStartsWith, +// t, NewLiteral(v)) +// +// Will panic if t is nil +func StartsWith(t UnboundTerm, v string) UnboundPredicate { + return 
LiteralPredicate(OpStartsWith, t, NewLiteral(v)) +} + +// NotStartsWith is a convenience wrapper for calling LiteralPredicate(OpNotStartsWith, +// t, NewLiteral(v)) +// +// Will panic if t is nil +func NotStartsWith(t UnboundTerm, v string) UnboundPredicate { + return LiteralPredicate(OpNotStartsWith, t, NewLiteral(v)) +} diff --git a/schema.go b/schema.go index 7ea7757..5d60fe9 100644 --- a/schema.go +++ b/schema.go @@ -1,1167 +1,1167 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package iceberg - -import ( - "encoding/json" - "fmt" - "maps" - "strings" - "sync" - "sync/atomic" - - "golang.org/x/exp/slices" -) - -// Schema is an Iceberg table schema, represented as a struct with -// multiple fields. The fields are only exported via accessor methods -// rather than exposing the slice directly in order to ensure a schema -// as immutable. -type Schema struct { - ID int `json:"schema-id"` - IdentifierFieldIDs []int `json:"identifier-field-ids"` - - fields []NestedField - - // the following maps are lazily populated as needed. - // rather than have lock contention with a mutex, we can use - // atomic pointers to Store/Load the values. 
- idToName atomic.Pointer[map[int]string] - idToField atomic.Pointer[map[int]NestedField] - nameToID atomic.Pointer[map[string]int] - nameToIDLower atomic.Pointer[map[string]int] - idToAccessor atomic.Pointer[map[int]accessor] - - lazyIDToParent func() (map[int]int, error) -} - -// NewSchema constructs a new schema with the provided ID -// and list of fields. -func NewSchema(id int, fields ...NestedField) *Schema { - return NewSchemaWithIdentifiers(id, []int{}, fields...) -} - -// NewSchemaWithIdentifiers constructs a new schema with the provided ID -// and fields, along with a slice of field IDs to be listed as identifier -// fields. -func NewSchemaWithIdentifiers(id int, identifierIDs []int, fields ...NestedField) *Schema { - s := &Schema{ID: id, fields: fields, IdentifierFieldIDs: identifierIDs} - s.lazyIDToParent = sync.OnceValues(func() (map[int]int, error) { - return IndexParents(s) - }) - return s -} - -func (s *Schema) String() string { - var b strings.Builder - b.WriteString("table {") - for _, f := range s.fields { - b.WriteString("\n\t") - b.WriteString(f.String()) - } - b.WriteString("\n}") - return b.String() -} - -func (s *Schema) lazyNameToID() (map[string]int, error) { - index := s.nameToID.Load() - if index != nil { - return *index, nil - } - - idx, err := IndexByName(s) - if err != nil { - return nil, err - } - - s.nameToID.Store(&idx) - return idx, nil -} - -func (s *Schema) lazyIDToField() (map[int]NestedField, error) { - index := s.idToField.Load() - if index != nil { - return *index, nil - } - - idx, err := IndexByID(s) - if err != nil { - return nil, err - } - - s.idToField.Store(&idx) - return idx, nil -} - -func (s *Schema) lazyIDToName() (map[int]string, error) { - index := s.idToName.Load() - if index != nil { - return *index, nil - } - - idx, err := IndexNameByID(s) - if err != nil { - return nil, err - } - - s.idToName.Store(&idx) - return idx, nil -} - -func (s *Schema) lazyNameToIDLower() (map[string]int, error) { - index := 
s.nameToIDLower.Load() - if index != nil { - return *index, nil - } - - idx, err := s.lazyNameToID() - if err != nil { - return nil, err - } - - out := make(map[string]int) - for k, v := range idx { - out[strings.ToLower(k)] = v - } - - s.nameToIDLower.Store(&out) - return out, nil -} - -func (s *Schema) lazyIdToAccessor() (map[int]accessor, error) { - index := s.idToAccessor.Load() - if index != nil { - return *index, nil - } - - idx, err := buildAccessors(s) - if err != nil { - return nil, err - } - - s.idToAccessor.Store(&idx) - return idx, nil -} - -func (s *Schema) Type() string { return "struct" } - -// AsStruct returns a Struct with the same fields as the schema which can -// then be used as a Type. -func (s *Schema) AsStruct() StructType { return StructType{FieldList: s.fields} } -func (s *Schema) NumFields() int { return len(s.fields) } -func (s *Schema) Field(i int) NestedField { return s.fields[i] } -func (s *Schema) Fields() []NestedField { return slices.Clone(s.fields) } - -func (s *Schema) UnmarshalJSON(b []byte) error { - type Alias Schema - aux := struct { - Fields []NestedField `json:"fields"` - *Alias - }{Alias: (*Alias)(s)} - - if err := json.Unmarshal(b, &aux); err != nil { - return err - } - - if s.lazyIDToParent == nil { - s.lazyIDToParent = sync.OnceValues(func() (map[int]int, error) { - return IndexParents(s) - }) - } - - s.fields = aux.Fields - if s.IdentifierFieldIDs == nil { - s.IdentifierFieldIDs = []int{} - } - return nil -} - -func (s *Schema) MarshalJSON() ([]byte, error) { - if s.IdentifierFieldIDs == nil { - s.IdentifierFieldIDs = []int{} - } - - type Alias Schema - return json.Marshal(struct { - Type string `json:"type"` - Fields []NestedField `json:"fields"` - *Alias - }{Type: "struct", Fields: s.fields, Alias: (*Alias)(s)}) -} - -// FindColumnName returns the name of the column identified by the -// passed in field id. The second return value reports whether or -// not the field id was found in the schema. 
-func (s *Schema) FindColumnName(fieldID int) (string, bool) { - idx, _ := s.lazyIDToName() - col, ok := idx[fieldID] - return col, ok -} - -// FindFieldByName returns the field identified by the name given, -// the second return value will be false if no field by this name -// is found. -// -// Note: This search is done in a case sensitive manner. To perform -// a case insensitive search, use [*Schema.FindFieldByNameCaseInsensitive]. -func (s *Schema) FindFieldByName(name string) (NestedField, bool) { - idx, _ := s.lazyNameToID() - - id, ok := idx[name] - if !ok { - return NestedField{}, false - } - - return s.FindFieldByID(id) -} - -// FindFieldByNameCaseInsensitive is like [*Schema.FindFieldByName], -// but performs a case insensitive search. -func (s *Schema) FindFieldByNameCaseInsensitive(name string) (NestedField, bool) { - idx, _ := s.lazyNameToIDLower() - - id, ok := idx[strings.ToLower(name)] - if !ok { - return NestedField{}, false - } - - return s.FindFieldByID(id) -} - -// FindFieldByID is like [*Schema.FindColumnName], but returns the whole -// field rather than just the field name. -func (s *Schema) FindFieldByID(id int) (NestedField, bool) { - idx, _ := s.lazyIDToField() - f, ok := idx[id] - return f, ok -} - -// FindTypeByID is like [*Schema.FindFieldByID], but returns only the data -// type of the field. -func (s *Schema) FindTypeByID(id int) (Type, bool) { - f, ok := s.FindFieldByID(id) - if !ok { - return nil, false - } - - return f.Type, true -} - -// FindTypeByName is a convenience function for calling [*Schema.FindFieldByName], -// and then returning just the type. -func (s *Schema) FindTypeByName(name string) (Type, bool) { - f, ok := s.FindFieldByName(name) - if !ok { - return nil, false - } - - return f.Type, true -} - -// FindTypeByNameCaseInsensitive is like [*Schema.FindTypeByName] but -// performs a case insensitive search. 
-func (s *Schema) FindTypeByNameCaseInsensitive(name string) (Type, bool) { - f, ok := s.FindFieldByNameCaseInsensitive(name) - if !ok { - return nil, false - } - - return f.Type, true -} - -func (s *Schema) accessorForField(id int) (accessor, bool) { - idx, err := s.lazyIdToAccessor() - if err != nil { - return accessor{}, false - } - - acc, ok := idx[id] - return acc, ok -} - -// Equals compares the fields and identifierIDs, but does not compare -// the schema ID itself. -func (s *Schema) Equals(other *Schema) bool { - if other == nil { - return false - } - - if s == other { - return true - } - - if len(s.fields) != len(other.fields) { - return false - } - - if !slices.Equal(s.IdentifierFieldIDs, other.IdentifierFieldIDs) { - return false - } - - return slices.EqualFunc(s.fields, other.fields, func(a, b NestedField) bool { - return a.Equals(b) - }) -} - -// HighestFieldID returns the value of the numerically highest field ID -// in this schema. -func (s *Schema) HighestFieldID() int { - id, _ := Visit[int](s, findLastFieldID{}) - return id -} - -type Void = struct{} - -var void = Void{} - -// Select creates a new schema with just the fields identified by name -// passed in the order they are provided. If caseSensitive is false, -// then fields will be identified by case insensitive search. -// -// An error is returned if a requested name cannot be found. 
-func (s *Schema) Select(caseSensitive bool, names ...string) (*Schema, error) { - ids := make(map[int]Void) - if caseSensitive { - nameMap, _ := s.lazyNameToID() - for _, n := range names { - id, ok := nameMap[n] - if !ok { - return nil, fmt.Errorf("%w: could not find column %s", ErrInvalidSchema, n) - } - ids[id] = void - } - } else { - nameMap, _ := s.lazyNameToIDLower() - for _, n := range names { - id, ok := nameMap[strings.ToLower(n)] - if !ok { - return nil, fmt.Errorf("%w: could not find column %s", ErrInvalidSchema, n) - } - ids[id] = void - } - } - - return PruneColumns(s, ids, true) -} - -func (s *Schema) FieldHasOptionalParent(id int) bool { - idToParent, _ := s.lazyIDToParent() - idToField, _ := s.lazyIDToField() - - f, ok := idToField[id] - if !ok { - return false - } - - for { - parent, ok := idToParent[f.ID] - if !ok { - return false - } - - if f = idToField[parent]; !f.Required { - return true - } - } -} - -// SchemaVisitor is an interface that can be implemented to allow for -// easy traversal and processing of a schema. -// -// A SchemaVisitor can also optionally implement the Before/After Field, -// ListElement, MapKey, or MapValue interfaces to allow them to get called -// at the appropriate points within schema traversal. 
-type SchemaVisitor[T any] interface { - Schema(schema *Schema, structResult T) T - Struct(st StructType, fieldResults []T) T - Field(field NestedField, fieldResult T) T - List(list ListType, elemResult T) T - Map(mapType MapType, keyResult, valueResult T) T - Primitive(p PrimitiveType) T -} - -type BeforeFieldVisitor interface { - BeforeField(field NestedField) -} - -type AfterFieldVisitor interface { - AfterField(field NestedField) -} - -type BeforeListElementVisitor interface { - BeforeListElement(elem NestedField) -} - -type AfterListElementVisitor interface { - AfterListElement(elem NestedField) -} - -type BeforeMapKeyVisitor interface { - BeforeMapKey(key NestedField) -} - -type AfterMapKeyVisitor interface { - AfterMapKey(key NestedField) -} - -type BeforeMapValueVisitor interface { - BeforeMapValue(value NestedField) -} - -type AfterMapValueVisitor interface { - AfterMapValue(value NestedField) -} - -// Visit accepts a visitor and performs a post-order traversal of the given schema. 
-func Visit[T any](sc *Schema, visitor SchemaVisitor[T]) (res T, err error) { - if sc == nil { - err = fmt.Errorf("%w: cannot visit nil schema", ErrInvalidArgument) - return - } - - defer func() { - if r := recover(); r != nil { - switch e := r.(type) { - case string: - err = fmt.Errorf("error encountered during schema visitor: %s", e) - case error: - err = fmt.Errorf("error encountered during schema visitor: %w", e) - } - } - }() - - return visitor.Schema(sc, visitStruct(sc.AsStruct(), visitor)), nil -} - -func visitStruct[T any](obj StructType, visitor SchemaVisitor[T]) T { - results := make([]T, len(obj.FieldList)) - - bf, _ := visitor.(BeforeFieldVisitor) - af, _ := visitor.(AfterFieldVisitor) - - for i, f := range obj.FieldList { - if bf != nil { - bf.BeforeField(f) - } - - res := visitField(f, visitor) - - if af != nil { - af.AfterField(f) - } - - results[i] = visitor.Field(f, res) - } - - return visitor.Struct(obj, results) -} - -func visitList[T any](obj ListType, visitor SchemaVisitor[T]) T { - elemField := obj.ElementField() - - if bl, ok := visitor.(BeforeListElementVisitor); ok { - bl.BeforeListElement(elemField) - } else if bf, ok := visitor.(BeforeFieldVisitor); ok { - bf.BeforeField(elemField) - } - - res := visitField(elemField, visitor) - - if al, ok := visitor.(AfterListElementVisitor); ok { - al.AfterListElement(elemField) - } else if af, ok := visitor.(AfterFieldVisitor); ok { - af.AfterField(elemField) - } - - return visitor.List(obj, res) -} - -func visitMap[T any](obj MapType, visitor SchemaVisitor[T]) T { - keyField, valueField := obj.KeyField(), obj.ValueField() - - if bmk, ok := visitor.(BeforeMapKeyVisitor); ok { - bmk.BeforeMapKey(keyField) - } else if bf, ok := visitor.(BeforeFieldVisitor); ok { - bf.BeforeField(keyField) - } - - keyRes := visitField(keyField, visitor) - - if amk, ok := visitor.(AfterMapKeyVisitor); ok { - amk.AfterMapKey(keyField) - } else if af, ok := visitor.(AfterFieldVisitor); ok { - af.AfterField(keyField) - } - - 
if bmk, ok := visitor.(BeforeMapValueVisitor); ok { - bmk.BeforeMapValue(valueField) - } else if bf, ok := visitor.(BeforeFieldVisitor); ok { - bf.BeforeField(valueField) - } - - valueRes := visitField(valueField, visitor) - - if amk, ok := visitor.(AfterMapValueVisitor); ok { - amk.AfterMapValue(valueField) - } else if af, ok := visitor.(AfterFieldVisitor); ok { - af.AfterField(valueField) - } - - return visitor.Map(obj, keyRes, valueRes) -} - -func visitField[T any](f NestedField, visitor SchemaVisitor[T]) T { - switch typ := f.Type.(type) { - case *StructType: - return visitStruct(*typ, visitor) - case *ListType: - return visitList(*typ, visitor) - case *MapType: - return visitMap(*typ, visitor) - default: // primitive - return visitor.Primitive(typ.(PrimitiveType)) - } -} - -// IndexByID performs a post-order traversal of the given schema and -// returns a mapping from field ID to field. -func IndexByID(schema *Schema) (map[int]NestedField, error) { - return Visit[map[int]NestedField](schema, &indexByID{index: make(map[int]NestedField)}) -} - -type indexByID struct { - index map[int]NestedField -} - -func (i *indexByID) Schema(*Schema, map[int]NestedField) map[int]NestedField { - return i.index -} - -func (i *indexByID) Struct(StructType, []map[int]NestedField) map[int]NestedField { - return i.index -} - -func (i *indexByID) Field(field NestedField, _ map[int]NestedField) map[int]NestedField { - i.index[field.ID] = field - return i.index -} - -func (i *indexByID) List(list ListType, _ map[int]NestedField) map[int]NestedField { - i.index[list.ElementID] = list.ElementField() - return i.index -} - -func (i *indexByID) Map(mapType MapType, _, _ map[int]NestedField) map[int]NestedField { - i.index[mapType.KeyID] = mapType.KeyField() - i.index[mapType.ValueID] = mapType.ValueField() - return i.index -} - -func (i *indexByID) Primitive(PrimitiveType) map[int]NestedField { - return i.index -} - -// IndexByName performs a post-order traversal of the schema and returns 
-// a mapping from field name to field ID. -func IndexByName(schema *Schema) (map[string]int, error) { - if schema == nil { - return nil, fmt.Errorf("%w: cannot index nil schema", ErrInvalidArgument) - } - - if len(schema.fields) > 0 { - indexer := &indexByName{ - index: make(map[string]int), - shortNameId: make(map[string]int), - fieldNames: make([]string, 0), - shortFieldNames: make([]string, 0), - } - if _, err := Visit[map[string]int](schema, indexer); err != nil { - return nil, err - } - - return indexer.ByName(), nil - } - return map[string]int{}, nil -} - -// IndexNameByID performs a post-order traversal of the schema and returns -// a mapping from field ID to field name. -func IndexNameByID(schema *Schema) (map[int]string, error) { - indexer := &indexByName{ - index: make(map[string]int), - shortNameId: make(map[string]int), - fieldNames: make([]string, 0), - shortFieldNames: make([]string, 0), - } - if _, err := Visit[map[string]int](schema, indexer); err != nil { - return nil, err - } - return indexer.ByID(), nil -} - -type indexByName struct { - index map[string]int - shortNameId map[string]int - combinedIndex map[string]int - fieldNames []string - shortFieldNames []string -} - -func (i *indexByName) ByID() map[int]string { - idToName := make(map[int]string) - for k, v := range i.index { - idToName[v] = k - } - return idToName -} - -func (i *indexByName) ByName() map[string]int { - i.combinedIndex = maps.Clone(i.shortNameId) - maps.Copy(i.combinedIndex, i.index) - return i.combinedIndex -} - -func (i *indexByName) Primitive(PrimitiveType) map[string]int { return i.index } -func (i *indexByName) addField(name string, fieldID int) { - fullName := name - if len(i.fieldNames) > 0 { - fullName = strings.Join(i.fieldNames, ".") + "." 
+ name - } - - if _, ok := i.index[fullName]; ok { - panic(fmt.Errorf("%w: multiple fields for name %s: %d and %d", - ErrInvalidSchema, fullName, i.index[fullName], fieldID)) - } - - i.index[fullName] = fieldID - if len(i.shortFieldNames) > 0 { - shortName := strings.Join(i.shortFieldNames, ".") + "." + name - i.shortNameId[shortName] = fieldID - } -} - -func (i *indexByName) Schema(*Schema, map[string]int) map[string]int { - return i.index -} - -func (i *indexByName) Struct(StructType, []map[string]int) map[string]int { - return i.index -} - -func (i *indexByName) Field(field NestedField, _ map[string]int) map[string]int { - i.addField(field.Name, field.ID) - return i.index -} - -func (i *indexByName) List(list ListType, _ map[string]int) map[string]int { - i.addField(list.ElementField().Name, list.ElementID) - return i.index -} - -func (i *indexByName) Map(mapType MapType, _, _ map[string]int) map[string]int { - i.addField(mapType.KeyField().Name, mapType.KeyID) - i.addField(mapType.ValueField().Name, mapType.ValueID) - return i.index -} - -func (i *indexByName) BeforeListElement(elem NestedField) { - if _, ok := elem.Type.(*StructType); !ok { - i.shortFieldNames = append(i.shortFieldNames, elem.Name) - } - i.fieldNames = append(i.fieldNames, elem.Name) -} - -func (i *indexByName) AfterListElement(elem NestedField) { - if _, ok := elem.Type.(*StructType); !ok { - i.shortFieldNames = i.shortFieldNames[:len(i.shortFieldNames)-1] - } - i.fieldNames = i.fieldNames[:len(i.fieldNames)-1] -} - -func (i *indexByName) BeforeField(field NestedField) { - i.fieldNames = append(i.fieldNames, field.Name) - i.shortFieldNames = append(i.shortFieldNames, field.Name) -} - -func (i *indexByName) AfterField(field NestedField) { - i.fieldNames = i.fieldNames[:len(i.fieldNames)-1] - i.shortFieldNames = i.shortFieldNames[:len(i.shortFieldNames)-1] -} - -// PruneColumns visits a schema pruning any columns which do not exist in the -// provided selected set. 
Parent fields of a selected child will be retained. -func PruneColumns(schema *Schema, selected map[int]Void, selectFullTypes bool) (*Schema, error) { - - result, err := Visit[Type](schema, &pruneColVisitor{selected: selected, - fullTypes: selectFullTypes}) - if err != nil { - return nil, err - } - - n, ok := result.(NestedType) - if !ok { - n = &StructType{} - } - - newIdentifierIDs := make([]int, 0, len(schema.IdentifierFieldIDs)) - for _, id := range schema.IdentifierFieldIDs { - if _, ok := selected[id]; ok { - newIdentifierIDs = append(newIdentifierIDs, id) - } - } - - return &Schema{ - fields: n.Fields(), - ID: schema.ID, - IdentifierFieldIDs: newIdentifierIDs, - }, nil -} - -type pruneColVisitor struct { - selected map[int]Void - fullTypes bool -} - -func (p *pruneColVisitor) Schema(_ *Schema, structResult Type) Type { - return structResult -} - -func (p *pruneColVisitor) Struct(st StructType, fieldResults []Type) Type { - selected, fields := []NestedField{}, st.FieldList - sameType := true - - for i, t := range fieldResults { - field := fields[i] - if field.Type == t { - selected = append(selected, field) - } else if t != nil { - sameType = false - // type has changed, create a new field with the projected type - selected = append(selected, NestedField{ - ID: field.ID, - Name: field.Name, - Type: t, - Doc: field.Doc, - Required: field.Required, - }) - } - } - - if len(selected) > 0 { - if len(selected) == len(fields) && sameType { - // nothing changed, return the original - return &st - } else { - return &StructType{FieldList: selected} - } - } - - return nil -} - -func (p *pruneColVisitor) Field(field NestedField, fieldResult Type) Type { - _, ok := p.selected[field.ID] - if !ok { - if fieldResult != nil { - return fieldResult - } - - return nil - } - - if p.fullTypes { - return field.Type - } - - if _, ok := field.Type.(*StructType); ok { - return p.projectSelectedStruct(fieldResult) - } - - typ, ok := field.Type.(PrimitiveType) - if !ok { - 
panic(fmt.Errorf("%w: cannot explicitly project List or Map types, %d:%s of type %s was selected", - ErrInvalidSchema, field.ID, field.Name, field.Type)) - } - return typ -} - -func (p *pruneColVisitor) List(list ListType, elemResult Type) Type { - _, ok := p.selected[list.ElementID] - if !ok { - if elemResult != nil { - return p.projectList(&list, elemResult) - } - - return nil - } - - if p.fullTypes { - return &list - } - - _, ok = list.Element.(*StructType) - if list.Element != nil && ok { - projected := p.projectSelectedStruct(elemResult) - return p.projectList(&list, projected) - } - - if _, ok = list.Element.(PrimitiveType); !ok { - panic(fmt.Errorf("%w: cannot explicitly project List or Map types, %d of type %s was selected", - ErrInvalidSchema, list.ElementID, list.Element)) - } - - return &list -} - -func (p *pruneColVisitor) Map(mapType MapType, keyResult, valueResult Type) Type { - _, ok := p.selected[mapType.ValueID] - if !ok { - if valueResult != nil { - return p.projectMap(&mapType, valueResult) - } - - if _, ok = p.selected[mapType.KeyID]; ok { - return &mapType - } - - return nil - } - - if p.fullTypes { - return &mapType - } - - _, ok = mapType.ValueType.(*StructType) - if mapType.ValueType != nil && ok { - projected := p.projectSelectedStruct(valueResult) - return p.projectMap(&mapType, projected) - } - - if _, ok = mapType.ValueType.(PrimitiveType); !ok { - panic(fmt.Errorf("%w: cannot explicitly project List or Map types, Map value %d of type %s was selected", - ErrInvalidSchema, mapType.ValueID, mapType.ValueType)) - } - - return &mapType -} - -func (p *pruneColVisitor) Primitive(_ PrimitiveType) Type { return nil } - -func (*pruneColVisitor) projectSelectedStruct(projected Type) *StructType { - if projected == nil { - return &StructType{} - } - - if ty, ok := projected.(*StructType); ok { - return ty - } - - panic("expected a struct") -} - -func (*pruneColVisitor) projectList(listType *ListType, elementResult Type) *ListType { - if 
listType.Element.Equals(elementResult) { - return listType - } - - return &ListType{ElementID: listType.ElementID, Element: elementResult, - ElementRequired: listType.ElementRequired} -} - -func (*pruneColVisitor) projectMap(mapType *MapType, valueResult Type) *MapType { - if mapType.ValueType.Equals(valueResult) { - return mapType - } - - return &MapType{ - KeyID: mapType.KeyID, - ValueID: mapType.ValueID, - KeyType: mapType.KeyType, - ValueType: valueResult, - ValueRequired: mapType.ValueRequired, - } -} - -type findLastFieldID struct{} - -func (findLastFieldID) Schema(_ *Schema, result int) int { - return result -} - -func (findLastFieldID) Struct(_ StructType, fieldResults []int) int { - return max(fieldResults...) -} - -func (findLastFieldID) Field(field NestedField, fieldResult int) int { - return max(field.ID, fieldResult) -} - -func (findLastFieldID) List(_ ListType, elemResult int) int { return elemResult } - -func (findLastFieldID) Map(_ MapType, keyResult, valueResult int) int { - return max(keyResult, valueResult) -} - -func (findLastFieldID) Primitive(PrimitiveType) int { return 0 } - -// IndexParents generates an index of field IDs to their parent field -// IDs. 
Root fields are not indexed -func IndexParents(schema *Schema) (map[int]int, error) { - indexer := &indexParents{ - idToParent: make(map[int]int), - idStack: make([]int, 0), - } - return Visit(schema, indexer) -} - -type indexParents struct { - idToParent map[int]int - idStack []int -} - -func (i *indexParents) BeforeField(field NestedField) { - i.idStack = append(i.idStack, field.ID) -} - -func (i *indexParents) AfterField(field NestedField) { - i.idStack = i.idStack[:len(i.idStack)-1] -} - -func (i *indexParents) Schema(schema *Schema, _ map[int]int) map[int]int { - return i.idToParent -} - -func (i *indexParents) Struct(st StructType, _ []map[int]int) map[int]int { - var parent int - stackLen := len(i.idStack) - if stackLen > 0 { - parent = i.idStack[stackLen-1] - for _, f := range st.FieldList { - i.idToParent[f.ID] = parent - } - } - - return i.idToParent -} - -func (i *indexParents) Field(NestedField, map[int]int) map[int]int { - return i.idToParent -} - -func (i *indexParents) List(list ListType, _ map[int]int) map[int]int { - i.idToParent[list.ElementID] = i.idStack[len(i.idStack)-1] - return i.idToParent -} - -func (i *indexParents) Map(mapType MapType, _, _ map[int]int) map[int]int { - parent := i.idStack[len(i.idStack)-1] - i.idToParent[mapType.KeyID] = parent - i.idToParent[mapType.ValueID] = parent - return i.idToParent -} - -func (i *indexParents) Primitive(PrimitiveType) map[int]int { - return i.idToParent -} - -type buildPosAccessors struct{} - -func (buildPosAccessors) Schema(_ *Schema, structResult map[int]accessor) map[int]accessor { - return structResult -} - -func (buildPosAccessors) Struct(st StructType, fieldResults []map[int]accessor) map[int]accessor { - result := map[int]accessor{} - for pos, f := range st.FieldList { - if innerMap := fieldResults[pos]; len(innerMap) != 0 { - for inner, acc := range innerMap { - acc := acc - result[inner] = accessor{pos: pos, inner: &acc} - } - } else { - result[f.ID] = accessor{pos: pos} - } - } - return 
result -} - -func (buildPosAccessors) Field(_ NestedField, fieldResult map[int]accessor) map[int]accessor { - return fieldResult -} - -func (buildPosAccessors) List(ListType, map[int]accessor) map[int]accessor { - return map[int]accessor{} -} - -func (buildPosAccessors) Map(_ MapType, _, _ map[int]accessor) map[int]accessor { - return map[int]accessor{} -} - -func (buildPosAccessors) Primitive(PrimitiveType) map[int]accessor { - return map[int]accessor{} -} - -func buildAccessors(schema *Schema) (map[int]accessor, error) { - return Visit(schema, buildPosAccessors{}) -} - -type SchemaWithPartnerVisitor[T, P any] interface { - Schema(sc *Schema, schemaPartner P, structResult T) T - Struct(st StructType, structPartner P, fieldResults []T) T - Field(field NestedField, fieldPartner P, fieldResult T) T - List(l ListType, listPartner P, elemResult T) T - Map(m MapType, mapPartner P, keyResult, valResult T) T - Primitive(p PrimitiveType, primitivePartner P) T -} - -type PartnerAccessor[P any] interface { - SchemaPartner(P) P - FieldPartner(partnerStruct P, fieldID int, fieldName string) P - ListElementPartner(P) P - MapKeyPartner(P) P - MapValuePartner(P) P -} - -func VisitSchemaWithPartner[T, P any](sc *Schema, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) (res T, err error) { - if sc == nil { - err = fmt.Errorf("%w: cannot visit nil schema", ErrInvalidArgument) - return - } - - if visitor == nil || accessor == nil { - err = fmt.Errorf("%w: cannot visit with nil visitor or accessor", ErrInvalidArgument) - return - } - - defer func() { - if r := recover(); r != nil { - switch e := r.(type) { - case string: - err = fmt.Errorf("error encountered during schema visitor: %s", e) - case error: - err = fmt.Errorf("error encountered during schema visitor: %w", e) - } - } - }() - - structPartner := accessor.SchemaPartner(partner) - return visitor.Schema(sc, partner, visitStructWithPartner(sc.AsStruct(), structPartner, visitor, accessor)), nil -} - 
-func visitStructWithPartner[T, P any](st StructType, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { - type ( - beforeField interface { - BeforeField(NestedField, P) - } - afterField interface { - AfterField(NestedField, P) - } - ) - - bf, _ := visitor.(beforeField) - af, _ := visitor.(afterField) - - fieldResults := make([]T, len(st.FieldList)) - - for i, f := range st.FieldList { - fieldPartner := accessor.FieldPartner(partner, f.ID, f.Name) - if bf != nil { - bf.BeforeField(f, fieldPartner) - } - fieldResult := visitTypeWithPartner(f.Type, fieldPartner, visitor, accessor) - fieldResults[i] = visitor.Field(f, fieldPartner, fieldResult) - if af != nil { - af.AfterField(f, fieldPartner) - } - } - - return visitor.Struct(st, partner, fieldResults) -} - -func visitListWithPartner[T, P any](listType ListType, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { - type ( - beforeListElem interface { - BeforeListElement(NestedField, P) - } - afterListElem interface { - AfterListElement(NestedField, P) - } - ) - - elemPartner := accessor.ListElementPartner(partner) - if ble, ok := visitor.(beforeListElem); ok { - ble.BeforeListElement(listType.ElementField(), elemPartner) - } - elemResult := visitTypeWithPartner(listType.Element, elemPartner, visitor, accessor) - if ale, ok := visitor.(afterListElem); ok { - ale.AfterListElement(listType.ElementField(), elemPartner) - } - - return visitor.List(listType, partner, elemResult) -} - -func visitMapWithPartner[T, P any](m MapType, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { - type ( - beforeMapKey interface { - BeforeMapKey(NestedField, P) - } - afterMapKey interface { - AfterMapKey(NestedField, P) - } - - beforeMapValue interface { - BeforeMapValue(NestedField, P) - } - afterMapValue interface { - AfterMapValue(NestedField, P) - } - ) - - keyPartner := accessor.MapKeyPartner(partner) - if bmk, ok := 
visitor.(beforeMapKey); ok { - bmk.BeforeMapKey(m.KeyField(), keyPartner) - } - keyResult := visitTypeWithPartner(m.KeyType, keyPartner, visitor, accessor) - if amk, ok := visitor.(afterMapKey); ok { - amk.AfterMapKey(m.KeyField(), keyPartner) - } - - valPartner := accessor.MapValuePartner(partner) - if bmv, ok := visitor.(beforeMapValue); ok { - bmv.BeforeMapValue(m.ValueField(), valPartner) - } - valResult := visitTypeWithPartner(m.ValueType, valPartner, visitor, accessor) - if amv, ok := visitor.(afterMapValue); ok { - amv.AfterMapValue(m.ValueField(), valPartner) - } - - return visitor.Map(m, partner, keyResult, valResult) -} - -func visitTypeWithPartner[T, P any](t Type, fieldPartner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { - switch t := t.(type) { - case *ListType: - return visitListWithPartner(*t, fieldPartner, visitor, accessor) - case *StructType: - return visitStructWithPartner(*t, fieldPartner, visitor, accessor) - case *MapType: - return visitMapWithPartner(*t, fieldPartner, visitor, accessor) - default: - return visitor.Primitive(t.(PrimitiveType), fieldPartner) - } -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package iceberg + +import ( + "encoding/json" + "fmt" + "maps" + "strings" + "sync" + "sync/atomic" + + "golang.org/x/exp/slices" +) + +// Schema is an Iceberg table schema, represented as a struct with +// multiple fields. The fields are only exported via accessor methods +// rather than exposing the slice directly in order to ensure a schema +// as immutable. +type Schema struct { + ID int `json:"schema-id"` + IdentifierFieldIDs []int `json:"identifier-field-ids"` + + fields []NestedField + + // the following maps are lazily populated as needed. + // rather than have lock contention with a mutex, we can use + // atomic pointers to Store/Load the values. + idToName atomic.Pointer[map[int]string] + idToField atomic.Pointer[map[int]NestedField] + nameToID atomic.Pointer[map[string]int] + nameToIDLower atomic.Pointer[map[string]int] + idToAccessor atomic.Pointer[map[int]accessor] + + lazyIDToParent func() (map[int]int, error) +} + +// NewSchema constructs a new schema with the provided ID +// and list of fields. +func NewSchema(id int, fields ...NestedField) *Schema { + return NewSchemaWithIdentifiers(id, []int{}, fields...) +} + +// NewSchemaWithIdentifiers constructs a new schema with the provided ID +// and fields, along with a slice of field IDs to be listed as identifier +// fields. 
+func NewSchemaWithIdentifiers(id int, identifierIDs []int, fields ...NestedField) *Schema { + s := &Schema{ID: id, fields: fields, IdentifierFieldIDs: identifierIDs} + s.lazyIDToParent = sync.OnceValues(func() (map[int]int, error) { + return IndexParents(s) + }) + return s +} + +func (s *Schema) String() string { + var b strings.Builder + b.WriteString("table {") + for _, f := range s.fields { + b.WriteString("\n\t") + b.WriteString(f.String()) + } + b.WriteString("\n}") + return b.String() +} + +func (s *Schema) lazyNameToID() (map[string]int, error) { + index := s.nameToID.Load() + if index != nil { + return *index, nil + } + + idx, err := IndexByName(s) + if err != nil { + return nil, err + } + + s.nameToID.Store(&idx) + return idx, nil +} + +func (s *Schema) lazyIDToField() (map[int]NestedField, error) { + index := s.idToField.Load() + if index != nil { + return *index, nil + } + + idx, err := IndexByID(s) + if err != nil { + return nil, err + } + + s.idToField.Store(&idx) + return idx, nil +} + +func (s *Schema) lazyIDToName() (map[int]string, error) { + index := s.idToName.Load() + if index != nil { + return *index, nil + } + + idx, err := IndexNameByID(s) + if err != nil { + return nil, err + } + + s.idToName.Store(&idx) + return idx, nil +} + +func (s *Schema) lazyNameToIDLower() (map[string]int, error) { + index := s.nameToIDLower.Load() + if index != nil { + return *index, nil + } + + idx, err := s.lazyNameToID() + if err != nil { + return nil, err + } + + out := make(map[string]int) + for k, v := range idx { + out[strings.ToLower(k)] = v + } + + s.nameToIDLower.Store(&out) + return out, nil +} + +func (s *Schema) lazyIdToAccessor() (map[int]accessor, error) { + index := s.idToAccessor.Load() + if index != nil { + return *index, nil + } + + idx, err := buildAccessors(s) + if err != nil { + return nil, err + } + + s.idToAccessor.Store(&idx) + return idx, nil +} + +func (s *Schema) Type() string { return "struct" } + +// AsStruct returns a Struct with the 
same fields as the schema which can +// then be used as a Type. +func (s *Schema) AsStruct() StructType { return StructType{FieldList: s.fields} } +func (s *Schema) NumFields() int { return len(s.fields) } +func (s *Schema) Field(i int) NestedField { return s.fields[i] } +func (s *Schema) Fields() []NestedField { return slices.Clone(s.fields) } + +func (s *Schema) UnmarshalJSON(b []byte) error { + type Alias Schema + aux := struct { + Fields []NestedField `json:"fields"` + *Alias + }{Alias: (*Alias)(s)} + + if err := json.Unmarshal(b, &aux); err != nil { + return err + } + + if s.lazyIDToParent == nil { + s.lazyIDToParent = sync.OnceValues(func() (map[int]int, error) { + return IndexParents(s) + }) + } + + s.fields = aux.Fields + if s.IdentifierFieldIDs == nil { + s.IdentifierFieldIDs = []int{} + } + return nil +} + +func (s *Schema) MarshalJSON() ([]byte, error) { + if s.IdentifierFieldIDs == nil { + s.IdentifierFieldIDs = []int{} + } + + type Alias Schema + return json.Marshal(struct { + Type string `json:"type"` + Fields []NestedField `json:"fields"` + *Alias + }{Type: "struct", Fields: s.fields, Alias: (*Alias)(s)}) +} + +// FindColumnName returns the name of the column identified by the +// passed in field id. The second return value reports whether or +// not the field id was found in the schema. +func (s *Schema) FindColumnName(fieldID int) (string, bool) { + idx, _ := s.lazyIDToName() + col, ok := idx[fieldID] + return col, ok +} + +// FindFieldByName returns the field identified by the name given, +// the second return value will be false if no field by this name +// is found. +// +// Note: This search is done in a case sensitive manner. To perform +// a case insensitive search, use [*Schema.FindFieldByNameCaseInsensitive]. 
+func (s *Schema) FindFieldByName(name string) (NestedField, bool) { + idx, _ := s.lazyNameToID() + + id, ok := idx[name] + if !ok { + return NestedField{}, false + } + + return s.FindFieldByID(id) +} + +// FindFieldByNameCaseInsensitive is like [*Schema.FindFieldByName], +// but performs a case insensitive search. +func (s *Schema) FindFieldByNameCaseInsensitive(name string) (NestedField, bool) { + idx, _ := s.lazyNameToIDLower() + + id, ok := idx[strings.ToLower(name)] + if !ok { + return NestedField{}, false + } + + return s.FindFieldByID(id) +} + +// FindFieldByID is like [*Schema.FindColumnName], but returns the whole +// field rather than just the field name. +func (s *Schema) FindFieldByID(id int) (NestedField, bool) { + idx, _ := s.lazyIDToField() + f, ok := idx[id] + return f, ok +} + +// FindTypeByID is like [*Schema.FindFieldByID], but returns only the data +// type of the field. +func (s *Schema) FindTypeByID(id int) (Type, bool) { + f, ok := s.FindFieldByID(id) + if !ok { + return nil, false + } + + return f.Type, true +} + +// FindTypeByName is a convenience function for calling [*Schema.FindFieldByName], +// and then returning just the type. +func (s *Schema) FindTypeByName(name string) (Type, bool) { + f, ok := s.FindFieldByName(name) + if !ok { + return nil, false + } + + return f.Type, true +} + +// FindTypeByNameCaseInsensitive is like [*Schema.FindTypeByName] but +// performs a case insensitive search. +func (s *Schema) FindTypeByNameCaseInsensitive(name string) (Type, bool) { + f, ok := s.FindFieldByNameCaseInsensitive(name) + if !ok { + return nil, false + } + + return f.Type, true +} + +func (s *Schema) accessorForField(id int) (accessor, bool) { + idx, err := s.lazyIdToAccessor() + if err != nil { + return accessor{}, false + } + + acc, ok := idx[id] + return acc, ok +} + +// Equals compares the fields and identifierIDs, but does not compare +// the schema ID itself. 
+func (s *Schema) Equals(other *Schema) bool { + if other == nil { + return false + } + + if s == other { + return true + } + + if len(s.fields) != len(other.fields) { + return false + } + + if !slices.Equal(s.IdentifierFieldIDs, other.IdentifierFieldIDs) { + return false + } + + return slices.EqualFunc(s.fields, other.fields, func(a, b NestedField) bool { + return a.Equals(b) + }) +} + +// HighestFieldID returns the value of the numerically highest field ID +// in this schema. +func (s *Schema) HighestFieldID() int { + id, _ := Visit[int](s, findLastFieldID{}) + return id +} + +type Void = struct{} + +var void = Void{} + +// Select creates a new schema with just the fields identified by name +// passed in the order they are provided. If caseSensitive is false, +// then fields will be identified by case insensitive search. +// +// An error is returned if a requested name cannot be found. +func (s *Schema) Select(caseSensitive bool, names ...string) (*Schema, error) { + ids := make(map[int]Void) + if caseSensitive { + nameMap, _ := s.lazyNameToID() + for _, n := range names { + id, ok := nameMap[n] + if !ok { + return nil, fmt.Errorf("%w: could not find column %s", ErrInvalidSchema, n) + } + ids[id] = void + } + } else { + nameMap, _ := s.lazyNameToIDLower() + for _, n := range names { + id, ok := nameMap[strings.ToLower(n)] + if !ok { + return nil, fmt.Errorf("%w: could not find column %s", ErrInvalidSchema, n) + } + ids[id] = void + } + } + + return PruneColumns(s, ids, true) +} + +func (s *Schema) FieldHasOptionalParent(id int) bool { + idToParent, _ := s.lazyIDToParent() + idToField, _ := s.lazyIDToField() + + f, ok := idToField[id] + if !ok { + return false + } + + for { + parent, ok := idToParent[f.ID] + if !ok { + return false + } + + if f = idToField[parent]; !f.Required { + return true + } + } +} + +// SchemaVisitor is an interface that can be implemented to allow for +// easy traversal and processing of a schema. 
+// +// A SchemaVisitor can also optionally implement the Before/After Field, +// ListElement, MapKey, or MapValue interfaces to allow them to get called +// at the appropriate points within schema traversal. +type SchemaVisitor[T any] interface { + Schema(schema *Schema, structResult T) T + Struct(st StructType, fieldResults []T) T + Field(field NestedField, fieldResult T) T + List(list ListType, elemResult T) T + Map(mapType MapType, keyResult, valueResult T) T + Primitive(p PrimitiveType) T +} + +type BeforeFieldVisitor interface { + BeforeField(field NestedField) +} + +type AfterFieldVisitor interface { + AfterField(field NestedField) +} + +type BeforeListElementVisitor interface { + BeforeListElement(elem NestedField) +} + +type AfterListElementVisitor interface { + AfterListElement(elem NestedField) +} + +type BeforeMapKeyVisitor interface { + BeforeMapKey(key NestedField) +} + +type AfterMapKeyVisitor interface { + AfterMapKey(key NestedField) +} + +type BeforeMapValueVisitor interface { + BeforeMapValue(value NestedField) +} + +type AfterMapValueVisitor interface { + AfterMapValue(value NestedField) +} + +// Visit accepts a visitor and performs a post-order traversal of the given schema. 
+func Visit[T any](sc *Schema, visitor SchemaVisitor[T]) (res T, err error) { + if sc == nil { + err = fmt.Errorf("%w: cannot visit nil schema", ErrInvalidArgument) + return + } + + defer func() { + if r := recover(); r != nil { + switch e := r.(type) { + case string: + err = fmt.Errorf("error encountered during schema visitor: %s", e) + case error: + err = fmt.Errorf("error encountered during schema visitor: %w", e) + } + } + }() + + return visitor.Schema(sc, visitStruct(sc.AsStruct(), visitor)), nil +} + +func visitStruct[T any](obj StructType, visitor SchemaVisitor[T]) T { + results := make([]T, len(obj.FieldList)) + + bf, _ := visitor.(BeforeFieldVisitor) + af, _ := visitor.(AfterFieldVisitor) + + for i, f := range obj.FieldList { + if bf != nil { + bf.BeforeField(f) + } + + res := visitField(f, visitor) + + if af != nil { + af.AfterField(f) + } + + results[i] = visitor.Field(f, res) + } + + return visitor.Struct(obj, results) +} + +func visitList[T any](obj ListType, visitor SchemaVisitor[T]) T { + elemField := obj.ElementField() + + if bl, ok := visitor.(BeforeListElementVisitor); ok { + bl.BeforeListElement(elemField) + } else if bf, ok := visitor.(BeforeFieldVisitor); ok { + bf.BeforeField(elemField) + } + + res := visitField(elemField, visitor) + + if al, ok := visitor.(AfterListElementVisitor); ok { + al.AfterListElement(elemField) + } else if af, ok := visitor.(AfterFieldVisitor); ok { + af.AfterField(elemField) + } + + return visitor.List(obj, res) +} + +func visitMap[T any](obj MapType, visitor SchemaVisitor[T]) T { + keyField, valueField := obj.KeyField(), obj.ValueField() + + if bmk, ok := visitor.(BeforeMapKeyVisitor); ok { + bmk.BeforeMapKey(keyField) + } else if bf, ok := visitor.(BeforeFieldVisitor); ok { + bf.BeforeField(keyField) + } + + keyRes := visitField(keyField, visitor) + + if amk, ok := visitor.(AfterMapKeyVisitor); ok { + amk.AfterMapKey(keyField) + } else if af, ok := visitor.(AfterFieldVisitor); ok { + af.AfterField(keyField) + } + + 
if bmk, ok := visitor.(BeforeMapValueVisitor); ok { + bmk.BeforeMapValue(valueField) + } else if bf, ok := visitor.(BeforeFieldVisitor); ok { + bf.BeforeField(valueField) + } + + valueRes := visitField(valueField, visitor) + + if amk, ok := visitor.(AfterMapValueVisitor); ok { + amk.AfterMapValue(valueField) + } else if af, ok := visitor.(AfterFieldVisitor); ok { + af.AfterField(valueField) + } + + return visitor.Map(obj, keyRes, valueRes) +} + +func visitField[T any](f NestedField, visitor SchemaVisitor[T]) T { + switch typ := f.Type.(type) { + case *StructType: + return visitStruct(*typ, visitor) + case *ListType: + return visitList(*typ, visitor) + case *MapType: + return visitMap(*typ, visitor) + default: // primitive + return visitor.Primitive(typ.(PrimitiveType)) + } +} + +// IndexByID performs a post-order traversal of the given schema and +// returns a mapping from field ID to field. +func IndexByID(schema *Schema) (map[int]NestedField, error) { + return Visit[map[int]NestedField](schema, &indexByID{index: make(map[int]NestedField)}) +} + +type indexByID struct { + index map[int]NestedField +} + +func (i *indexByID) Schema(*Schema, map[int]NestedField) map[int]NestedField { + return i.index +} + +func (i *indexByID) Struct(StructType, []map[int]NestedField) map[int]NestedField { + return i.index +} + +func (i *indexByID) Field(field NestedField, _ map[int]NestedField) map[int]NestedField { + i.index[field.ID] = field + return i.index +} + +func (i *indexByID) List(list ListType, _ map[int]NestedField) map[int]NestedField { + i.index[list.ElementID] = list.ElementField() + return i.index +} + +func (i *indexByID) Map(mapType MapType, _, _ map[int]NestedField) map[int]NestedField { + i.index[mapType.KeyID] = mapType.KeyField() + i.index[mapType.ValueID] = mapType.ValueField() + return i.index +} + +func (i *indexByID) Primitive(PrimitiveType) map[int]NestedField { + return i.index +} + +// IndexByName performs a post-order traversal of the schema and returns 
+// a mapping from field name to field ID. +func IndexByName(schema *Schema) (map[string]int, error) { + if schema == nil { + return nil, fmt.Errorf("%w: cannot index nil schema", ErrInvalidArgument) + } + + if len(schema.fields) > 0 { + indexer := &indexByName{ + index: make(map[string]int), + shortNameId: make(map[string]int), + fieldNames: make([]string, 0), + shortFieldNames: make([]string, 0), + } + if _, err := Visit[map[string]int](schema, indexer); err != nil { + return nil, err + } + + return indexer.ByName(), nil + } + return map[string]int{}, nil +} + +// IndexNameByID performs a post-order traversal of the schema and returns +// a mapping from field ID to field name. +func IndexNameByID(schema *Schema) (map[int]string, error) { + indexer := &indexByName{ + index: make(map[string]int), + shortNameId: make(map[string]int), + fieldNames: make([]string, 0), + shortFieldNames: make([]string, 0), + } + if _, err := Visit[map[string]int](schema, indexer); err != nil { + return nil, err + } + return indexer.ByID(), nil +} + +type indexByName struct { + index map[string]int + shortNameId map[string]int + combinedIndex map[string]int + fieldNames []string + shortFieldNames []string +} + +func (i *indexByName) ByID() map[int]string { + idToName := make(map[int]string) + for k, v := range i.index { + idToName[v] = k + } + return idToName +} + +func (i *indexByName) ByName() map[string]int { + i.combinedIndex = maps.Clone(i.shortNameId) + maps.Copy(i.combinedIndex, i.index) + return i.combinedIndex +} + +func (i *indexByName) Primitive(PrimitiveType) map[string]int { return i.index } +func (i *indexByName) addField(name string, fieldID int) { + fullName := name + if len(i.fieldNames) > 0 { + fullName = strings.Join(i.fieldNames, ".") + "." 
+ name + } + + if _, ok := i.index[fullName]; ok { + panic(fmt.Errorf("%w: multiple fields for name %s: %d and %d", + ErrInvalidSchema, fullName, i.index[fullName], fieldID)) + } + + i.index[fullName] = fieldID + if len(i.shortFieldNames) > 0 { + shortName := strings.Join(i.shortFieldNames, ".") + "." + name + i.shortNameId[shortName] = fieldID + } +} + +func (i *indexByName) Schema(*Schema, map[string]int) map[string]int { + return i.index +} + +func (i *indexByName) Struct(StructType, []map[string]int) map[string]int { + return i.index +} + +func (i *indexByName) Field(field NestedField, _ map[string]int) map[string]int { + i.addField(field.Name, field.ID) + return i.index +} + +func (i *indexByName) List(list ListType, _ map[string]int) map[string]int { + i.addField(list.ElementField().Name, list.ElementID) + return i.index +} + +func (i *indexByName) Map(mapType MapType, _, _ map[string]int) map[string]int { + i.addField(mapType.KeyField().Name, mapType.KeyID) + i.addField(mapType.ValueField().Name, mapType.ValueID) + return i.index +} + +func (i *indexByName) BeforeListElement(elem NestedField) { + if _, ok := elem.Type.(*StructType); !ok { + i.shortFieldNames = append(i.shortFieldNames, elem.Name) + } + i.fieldNames = append(i.fieldNames, elem.Name) +} + +func (i *indexByName) AfterListElement(elem NestedField) { + if _, ok := elem.Type.(*StructType); !ok { + i.shortFieldNames = i.shortFieldNames[:len(i.shortFieldNames)-1] + } + i.fieldNames = i.fieldNames[:len(i.fieldNames)-1] +} + +func (i *indexByName) BeforeField(field NestedField) { + i.fieldNames = append(i.fieldNames, field.Name) + i.shortFieldNames = append(i.shortFieldNames, field.Name) +} + +func (i *indexByName) AfterField(field NestedField) { + i.fieldNames = i.fieldNames[:len(i.fieldNames)-1] + i.shortFieldNames = i.shortFieldNames[:len(i.shortFieldNames)-1] +} + +// PruneColumns visits a schema pruning any columns which do not exist in the +// provided selected set. 
Parent fields of a selected child will be retained. +func PruneColumns(schema *Schema, selected map[int]Void, selectFullTypes bool) (*Schema, error) { + + result, err := Visit[Type](schema, &pruneColVisitor{selected: selected, + fullTypes: selectFullTypes}) + if err != nil { + return nil, err + } + + n, ok := result.(NestedType) + if !ok { + n = &StructType{} + } + + newIdentifierIDs := make([]int, 0, len(schema.IdentifierFieldIDs)) + for _, id := range schema.IdentifierFieldIDs { + if _, ok := selected[id]; ok { + newIdentifierIDs = append(newIdentifierIDs, id) + } + } + + return &Schema{ + fields: n.Fields(), + ID: schema.ID, + IdentifierFieldIDs: newIdentifierIDs, + }, nil +} + +type pruneColVisitor struct { + selected map[int]Void + fullTypes bool +} + +func (p *pruneColVisitor) Schema(_ *Schema, structResult Type) Type { + return structResult +} + +func (p *pruneColVisitor) Struct(st StructType, fieldResults []Type) Type { + selected, fields := []NestedField{}, st.FieldList + sameType := true + + for i, t := range fieldResults { + field := fields[i] + if field.Type == t { + selected = append(selected, field) + } else if t != nil { + sameType = false + // type has changed, create a new field with the projected type + selected = append(selected, NestedField{ + ID: field.ID, + Name: field.Name, + Type: t, + Doc: field.Doc, + Required: field.Required, + }) + } + } + + if len(selected) > 0 { + if len(selected) == len(fields) && sameType { + // nothing changed, return the original + return &st + } else { + return &StructType{FieldList: selected} + } + } + + return nil +} + +func (p *pruneColVisitor) Field(field NestedField, fieldResult Type) Type { + _, ok := p.selected[field.ID] + if !ok { + if fieldResult != nil { + return fieldResult + } + + return nil + } + + if p.fullTypes { + return field.Type + } + + if _, ok := field.Type.(*StructType); ok { + return p.projectSelectedStruct(fieldResult) + } + + typ, ok := field.Type.(PrimitiveType) + if !ok { + 
panic(fmt.Errorf("%w: cannot explicitly project List or Map types, %d:%s of type %s was selected", + ErrInvalidSchema, field.ID, field.Name, field.Type)) + } + return typ +} + +func (p *pruneColVisitor) List(list ListType, elemResult Type) Type { + _, ok := p.selected[list.ElementID] + if !ok { + if elemResult != nil { + return p.projectList(&list, elemResult) + } + + return nil + } + + if p.fullTypes { + return &list + } + + _, ok = list.Element.(*StructType) + if list.Element != nil && ok { + projected := p.projectSelectedStruct(elemResult) + return p.projectList(&list, projected) + } + + if _, ok = list.Element.(PrimitiveType); !ok { + panic(fmt.Errorf("%w: cannot explicitly project List or Map types, %d of type %s was selected", + ErrInvalidSchema, list.ElementID, list.Element)) + } + + return &list +} + +func (p *pruneColVisitor) Map(mapType MapType, keyResult, valueResult Type) Type { + _, ok := p.selected[mapType.ValueID] + if !ok { + if valueResult != nil { + return p.projectMap(&mapType, valueResult) + } + + if _, ok = p.selected[mapType.KeyID]; ok { + return &mapType + } + + return nil + } + + if p.fullTypes { + return &mapType + } + + _, ok = mapType.ValueType.(*StructType) + if mapType.ValueType != nil && ok { + projected := p.projectSelectedStruct(valueResult) + return p.projectMap(&mapType, projected) + } + + if _, ok = mapType.ValueType.(PrimitiveType); !ok { + panic(fmt.Errorf("%w: cannot explicitly project List or Map types, Map value %d of type %s was selected", + ErrInvalidSchema, mapType.ValueID, mapType.ValueType)) + } + + return &mapType +} + +func (p *pruneColVisitor) Primitive(_ PrimitiveType) Type { return nil } + +func (*pruneColVisitor) projectSelectedStruct(projected Type) *StructType { + if projected == nil { + return &StructType{} + } + + if ty, ok := projected.(*StructType); ok { + return ty + } + + panic("expected a struct") +} + +func (*pruneColVisitor) projectList(listType *ListType, elementResult Type) *ListType { + if 
listType.Element.Equals(elementResult) { + return listType + } + + return &ListType{ElementID: listType.ElementID, Element: elementResult, + ElementRequired: listType.ElementRequired} +} + +func (*pruneColVisitor) projectMap(mapType *MapType, valueResult Type) *MapType { + if mapType.ValueType.Equals(valueResult) { + return mapType + } + + return &MapType{ + KeyID: mapType.KeyID, + ValueID: mapType.ValueID, + KeyType: mapType.KeyType, + ValueType: valueResult, + ValueRequired: mapType.ValueRequired, + } +} + +type findLastFieldID struct{} + +func (findLastFieldID) Schema(_ *Schema, result int) int { + return result +} + +func (findLastFieldID) Struct(_ StructType, fieldResults []int) int { + return max(fieldResults...) +} + +func (findLastFieldID) Field(field NestedField, fieldResult int) int { + return max(field.ID, fieldResult) +} + +func (findLastFieldID) List(_ ListType, elemResult int) int { return elemResult } + +func (findLastFieldID) Map(_ MapType, keyResult, valueResult int) int { + return max(keyResult, valueResult) +} + +func (findLastFieldID) Primitive(PrimitiveType) int { return 0 } + +// IndexParents generates an index of field IDs to their parent field +// IDs. 
Root fields are not indexed +func IndexParents(schema *Schema) (map[int]int, error) { + indexer := &indexParents{ + idToParent: make(map[int]int), + idStack: make([]int, 0), + } + return Visit(schema, indexer) +} + +type indexParents struct { + idToParent map[int]int + idStack []int +} + +func (i *indexParents) BeforeField(field NestedField) { + i.idStack = append(i.idStack, field.ID) +} + +func (i *indexParents) AfterField(field NestedField) { + i.idStack = i.idStack[:len(i.idStack)-1] +} + +func (i *indexParents) Schema(schema *Schema, _ map[int]int) map[int]int { + return i.idToParent +} + +func (i *indexParents) Struct(st StructType, _ []map[int]int) map[int]int { + var parent int + stackLen := len(i.idStack) + if stackLen > 0 { + parent = i.idStack[stackLen-1] + for _, f := range st.FieldList { + i.idToParent[f.ID] = parent + } + } + + return i.idToParent +} + +func (i *indexParents) Field(NestedField, map[int]int) map[int]int { + return i.idToParent +} + +func (i *indexParents) List(list ListType, _ map[int]int) map[int]int { + i.idToParent[list.ElementID] = i.idStack[len(i.idStack)-1] + return i.idToParent +} + +func (i *indexParents) Map(mapType MapType, _, _ map[int]int) map[int]int { + parent := i.idStack[len(i.idStack)-1] + i.idToParent[mapType.KeyID] = parent + i.idToParent[mapType.ValueID] = parent + return i.idToParent +} + +func (i *indexParents) Primitive(PrimitiveType) map[int]int { + return i.idToParent +} + +type buildPosAccessors struct{} + +func (buildPosAccessors) Schema(_ *Schema, structResult map[int]accessor) map[int]accessor { + return structResult +} + +func (buildPosAccessors) Struct(st StructType, fieldResults []map[int]accessor) map[int]accessor { + result := map[int]accessor{} + for pos, f := range st.FieldList { + if innerMap := fieldResults[pos]; len(innerMap) != 0 { + for inner, acc := range innerMap { + acc := acc + result[inner] = accessor{pos: pos, inner: &acc} + } + } else { + result[f.ID] = accessor{pos: pos} + } + } + return 
result +} + +func (buildPosAccessors) Field(_ NestedField, fieldResult map[int]accessor) map[int]accessor { + return fieldResult +} + +func (buildPosAccessors) List(ListType, map[int]accessor) map[int]accessor { + return map[int]accessor{} +} + +func (buildPosAccessors) Map(_ MapType, _, _ map[int]accessor) map[int]accessor { + return map[int]accessor{} +} + +func (buildPosAccessors) Primitive(PrimitiveType) map[int]accessor { + return map[int]accessor{} +} + +func buildAccessors(schema *Schema) (map[int]accessor, error) { + return Visit(schema, buildPosAccessors{}) +} + +type SchemaWithPartnerVisitor[T, P any] interface { + Schema(sc *Schema, schemaPartner P, structResult T) T + Struct(st StructType, structPartner P, fieldResults []T) T + Field(field NestedField, fieldPartner P, fieldResult T) T + List(l ListType, listPartner P, elemResult T) T + Map(m MapType, mapPartner P, keyResult, valResult T) T + Primitive(p PrimitiveType, primitivePartner P) T +} + +type PartnerAccessor[P any] interface { + SchemaPartner(P) P + FieldPartner(partnerStruct P, fieldID int, fieldName string) P + ListElementPartner(P) P + MapKeyPartner(P) P + MapValuePartner(P) P +} + +func VisitSchemaWithPartner[T, P any](sc *Schema, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) (res T, err error) { + if sc == nil { + err = fmt.Errorf("%w: cannot visit nil schema", ErrInvalidArgument) + return + } + + if visitor == nil || accessor == nil { + err = fmt.Errorf("%w: cannot visit with nil visitor or accessor", ErrInvalidArgument) + return + } + + defer func() { + if r := recover(); r != nil { + switch e := r.(type) { + case string: + err = fmt.Errorf("error encountered during schema visitor: %s", e) + case error: + err = fmt.Errorf("error encountered during schema visitor: %w", e) + } + } + }() + + structPartner := accessor.SchemaPartner(partner) + return visitor.Schema(sc, partner, visitStructWithPartner(sc.AsStruct(), structPartner, visitor, accessor)), nil +} + 
+func visitStructWithPartner[T, P any](st StructType, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { + type ( + beforeField interface { + BeforeField(NestedField, P) + } + afterField interface { + AfterField(NestedField, P) + } + ) + + bf, _ := visitor.(beforeField) + af, _ := visitor.(afterField) + + fieldResults := make([]T, len(st.FieldList)) + + for i, f := range st.FieldList { + fieldPartner := accessor.FieldPartner(partner, f.ID, f.Name) + if bf != nil { + bf.BeforeField(f, fieldPartner) + } + fieldResult := visitTypeWithPartner(f.Type, fieldPartner, visitor, accessor) + fieldResults[i] = visitor.Field(f, fieldPartner, fieldResult) + if af != nil { + af.AfterField(f, fieldPartner) + } + } + + return visitor.Struct(st, partner, fieldResults) +} + +func visitListWithPartner[T, P any](listType ListType, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { + type ( + beforeListElem interface { + BeforeListElement(NestedField, P) + } + afterListElem interface { + AfterListElement(NestedField, P) + } + ) + + elemPartner := accessor.ListElementPartner(partner) + if ble, ok := visitor.(beforeListElem); ok { + ble.BeforeListElement(listType.ElementField(), elemPartner) + } + elemResult := visitTypeWithPartner(listType.Element, elemPartner, visitor, accessor) + if ale, ok := visitor.(afterListElem); ok { + ale.AfterListElement(listType.ElementField(), elemPartner) + } + + return visitor.List(listType, partner, elemResult) +} + +func visitMapWithPartner[T, P any](m MapType, partner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { + type ( + beforeMapKey interface { + BeforeMapKey(NestedField, P) + } + afterMapKey interface { + AfterMapKey(NestedField, P) + } + + beforeMapValue interface { + BeforeMapValue(NestedField, P) + } + afterMapValue interface { + AfterMapValue(NestedField, P) + } + ) + + keyPartner := accessor.MapKeyPartner(partner) + if bmk, ok := 
visitor.(beforeMapKey); ok { + bmk.BeforeMapKey(m.KeyField(), keyPartner) + } + keyResult := visitTypeWithPartner(m.KeyType, keyPartner, visitor, accessor) + if amk, ok := visitor.(afterMapKey); ok { + amk.AfterMapKey(m.KeyField(), keyPartner) + } + + valPartner := accessor.MapValuePartner(partner) + if bmv, ok := visitor.(beforeMapValue); ok { + bmv.BeforeMapValue(m.ValueField(), valPartner) + } + valResult := visitTypeWithPartner(m.ValueType, valPartner, visitor, accessor) + if amv, ok := visitor.(afterMapValue); ok { + amv.AfterMapValue(m.ValueField(), valPartner) + } + + return visitor.Map(m, partner, keyResult, valResult) +} + +func visitTypeWithPartner[T, P any](t Type, fieldPartner P, visitor SchemaWithPartnerVisitor[T, P], accessor PartnerAccessor[P]) T { + switch t := t.(type) { + case *ListType: + return visitListWithPartner(*t, fieldPartner, visitor, accessor) + case *StructType: + return visitStructWithPartner(*t, fieldPartner, visitor, accessor) + case *MapType: + return visitMapWithPartner(*t, fieldPartner, visitor, accessor) + default: + return visitor.Primitive(t.(PrimitiveType), fieldPartner) + } +} diff --git a/schema_test.go b/schema_test.go index 4e8e746..d1b9311 100644 --- a/schema_test.go +++ b/schema_test.go @@ -1,758 +1,758 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -package iceberg_test - -import ( - "encoding/json" - "strings" - "testing" - - "github.com/apache/iceberg-go" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var ( - tableSchemaNested = iceberg.NewSchemaWithIdentifiers(1, - []int{1}, - iceberg.NestedField{ - ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: false}, - iceberg.NestedField{ - ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - iceberg.NestedField{ - ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool, Required: false}, - iceberg.NestedField{ - ID: 4, Name: "qux", Required: true, Type: &iceberg.ListType{ - ElementID: 5, Element: iceberg.PrimitiveTypes.String, ElementRequired: true}}, - iceberg.NestedField{ - ID: 6, Name: "quux", - Type: &iceberg.MapType{ - KeyID: 7, - KeyType: iceberg.PrimitiveTypes.String, - ValueID: 8, - ValueType: &iceberg.MapType{ - KeyID: 9, - KeyType: iceberg.PrimitiveTypes.String, - ValueID: 10, - ValueType: iceberg.PrimitiveTypes.Int32, - ValueRequired: true, - }, - ValueRequired: true, - }, - Required: true}, - iceberg.NestedField{ - ID: 11, Name: "location", Type: &iceberg.ListType{ - ElementID: 12, Element: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 13, Name: "latitude", Type: iceberg.PrimitiveTypes.Float32, Required: false}, - {ID: 14, Name: "longitude", Type: iceberg.PrimitiveTypes.Float32, Required: false}, - }, - }, - ElementRequired: true}, - Required: true}, - iceberg.NestedField{ - ID: 15, - Name: "person", - Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 16, Name: "name", Type: iceberg.PrimitiveTypes.String, Required: false}, - {ID: 17, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - }, - }, - Required: false, - }, - ) - - tableSchemaSimple = iceberg.NewSchemaWithIdentifiers(1, - []int{2}, - iceberg.NestedField{ID: 
1, Name: "foo", Type: iceberg.PrimitiveTypes.String}, - iceberg.NestedField{ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - iceberg.NestedField{ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool}, - ) -) - -func TestSchemaToString(t *testing.T) { - assert.Equal(t, 3, tableSchemaSimple.NumFields()) - assert.Equal(t, `table { - 1: foo: optional string - 2: bar: required int - 3: baz: optional boolean -}`, tableSchemaSimple.String()) -} - -func TestNestedFieldToString(t *testing.T) { - tests := []struct { - idx int - expected string - }{ - {0, "1: foo: optional string"}, - {1, "2: bar: required int"}, - {2, "3: baz: optional boolean"}, - {3, "4: qux: required list"}, - {4, "6: quux: required map>"}, - {5, "11: location: required list>"}, - {6, "15: person: optional struct<16: name: optional string, 17: age: required int>"}, - } - - for _, tt := range tests { - assert.Equal(t, tt.expected, tableSchemaNested.Field(tt.idx).String()) - } -} - -func TestSchemaIndexByIDVisitor(t *testing.T) { - index, err := iceberg.IndexByID(tableSchemaNested) - require.NoError(t, err) - - assert.Equal(t, map[int]iceberg.NestedField{ - 1: tableSchemaNested.Field(0), - 2: tableSchemaNested.Field(1), - 3: tableSchemaNested.Field(2), - 4: tableSchemaNested.Field(3), - 5: {ID: 5, Name: "element", Type: iceberg.PrimitiveTypes.String, Required: true}, - 6: tableSchemaNested.Field(4), - 7: {ID: 7, Name: "key", Type: iceberg.PrimitiveTypes.String, Required: true}, - 8: {ID: 8, Name: "value", Type: &iceberg.MapType{ - KeyID: 9, - KeyType: iceberg.PrimitiveTypes.String, - ValueID: 10, - ValueType: iceberg.PrimitiveTypes.Int32, - ValueRequired: true, - }, Required: true}, - 9: {ID: 9, Name: "key", Type: iceberg.PrimitiveTypes.String, Required: true}, - 10: {ID: 10, Name: "value", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - 11: tableSchemaNested.Field(5), - 12: {ID: 12, Name: "element", Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 13, 
Name: "latitude", Type: iceberg.PrimitiveTypes.Float32, Required: false}, - {ID: 14, Name: "longitude", Type: iceberg.PrimitiveTypes.Float32, Required: false}, - }, - }, Required: true}, - 13: {ID: 13, Name: "latitude", Type: iceberg.PrimitiveTypes.Float32, Required: false}, - 14: {ID: 14, Name: "longitude", Type: iceberg.PrimitiveTypes.Float32, Required: false}, - 15: tableSchemaNested.Field(6), - 16: {ID: 16, Name: "name", Type: iceberg.PrimitiveTypes.String, Required: false}, - 17: {ID: 17, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - }, index) -} - -func TestSchemaIndexByName(t *testing.T) { - index, err := iceberg.IndexByName(tableSchemaNested) - require.NoError(t, err) - - assert.Equal(t, map[string]int{ - "foo": 1, - "bar": 2, - "baz": 3, - "qux": 4, - "qux.element": 5, - "quux": 6, - "quux.key": 7, - "quux.value": 8, - "quux.value.key": 9, - "quux.value.value": 10, - "location": 11, - "location.element": 12, - "location.element.latitude": 13, - "location.element.longitude": 14, - "location.latitude": 13, - "location.longitude": 14, - "person": 15, - "person.name": 16, - "person.age": 17, - }, index) -} - -func TestSchemaFindColumnName(t *testing.T) { - tests := []struct { - id int - name string - }{ - {1, "foo"}, - {2, "bar"}, - {3, "baz"}, - {4, "qux"}, - {5, "qux.element"}, - {6, "quux"}, - {7, "quux.key"}, - {8, "quux.value"}, - {9, "quux.value.key"}, - {10, "quux.value.value"}, - {11, "location"}, - {12, "location.element"}, - {13, "location.element.latitude"}, - {14, "location.element.longitude"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - n, ok := tableSchemaNested.FindColumnName(tt.id) - assert.True(t, ok) - assert.Equal(t, tt.name, n) - }) - } -} - -func TestSchemaFindColumnNameIDNotFound(t *testing.T) { - n, ok := tableSchemaNested.FindColumnName(99) - assert.False(t, ok) - assert.Empty(t, n) -} - -func TestSchemaFindColumnNameByID(t *testing.T) { - tests := []struct { - id int - name string - 
}{ - {1, "foo"}, - {2, "bar"}, - {3, "baz"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - n, ok := tableSchemaSimple.FindColumnName(tt.id) - assert.True(t, ok) - assert.Equal(t, tt.name, n) - }) - } -} - -func TestSchemaFindFieldByID(t *testing.T) { - index, err := iceberg.IndexByID(tableSchemaSimple) - require.NoError(t, err) - - col1 := index[1] - assert.Equal(t, 1, col1.ID) - assert.Equal(t, iceberg.PrimitiveTypes.String, col1.Type) - assert.False(t, col1.Required) - - col2 := index[2] - assert.Equal(t, 2, col2.ID) - assert.Equal(t, iceberg.PrimitiveTypes.Int32, col2.Type) - assert.True(t, col2.Required) - - col3 := index[3] - assert.Equal(t, 3, col3.ID) - assert.Equal(t, iceberg.PrimitiveTypes.Bool, col3.Type) - assert.False(t, col3.Required) -} - -func TestFindFieldByIDUnknownField(t *testing.T) { - index, err := iceberg.IndexByID(tableSchemaSimple) - require.NoError(t, err) - _, ok := index[4] - assert.False(t, ok) -} - -func TestSchemaFindField(t *testing.T) { - tests := []iceberg.NestedField{ - {ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: false}, - {ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - {ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool, Required: false}, - } - - for _, tt := range tests { - t.Run(tt.Name, func(t *testing.T) { - f, ok := tableSchemaSimple.FindFieldByID(tt.ID) - assert.True(t, ok) - assert.Equal(t, tt, f) - - f, ok = tableSchemaSimple.FindFieldByName(tt.Name) - assert.True(t, ok) - assert.Equal(t, tt, f) - - f, ok = tableSchemaSimple.FindFieldByNameCaseInsensitive(strings.ToUpper(tt.Name)) - assert.True(t, ok) - assert.Equal(t, tt, f) - }) - } -} - -func TestSchemaFindType(t *testing.T) { - _, ok := tableSchemaSimple.FindTypeByID(0) - assert.False(t, ok) - _, ok = tableSchemaSimple.FindTypeByName("FOOBAR") - assert.False(t, ok) - _, ok = tableSchemaSimple.FindTypeByNameCaseInsensitive("FOOBAR") - assert.False(t, ok) - - tests := []struct { - id int 
- name string - typ iceberg.Type - }{ - {1, "foo", iceberg.PrimitiveTypes.String}, - {2, "bar", iceberg.PrimitiveTypes.Int32}, - {3, "baz", iceberg.PrimitiveTypes.Bool}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - typ, ok := tableSchemaSimple.FindTypeByID(tt.id) - assert.True(t, ok) - assert.Equal(t, tt.typ, typ) - - typ, ok = tableSchemaSimple.FindTypeByName(tt.name) - assert.True(t, ok) - assert.Equal(t, tt.typ, typ) - - typ, ok = tableSchemaSimple.FindTypeByNameCaseInsensitive(strings.ToUpper(tt.name)) - assert.True(t, ok) - assert.Equal(t, tt.typ, typ) - }) - } -} - -func TestSerializeSchema(t *testing.T) { - data, err := json.Marshal(tableSchemaSimple) - require.NoError(t, err) - - assert.JSONEq(t, `{ - "type": "struct", - "fields": [ - {"id": 1, "name": "foo", "type": "string", "required": false}, - {"id": 2, "name": "bar", "type": "int", "required": true}, - {"id": 3, "name": "baz", "type": "boolean", "required": false} - ], - "schema-id": 1, - "identifier-field-ids": [2] - }`, string(data)) -} - -func TestUnmarshalSchema(t *testing.T) { - var schema iceberg.Schema - require.NoError(t, json.Unmarshal([]byte(`{ - "type": "struct", - "fields": [ - {"id": 1, "name": "foo", "type": "string", "required": false}, - {"id": 2, "name": "bar", "type": "int", "required": true}, - {"id": 3, "name": "baz", "type": "boolean", "required": false} - ], - "schema-id": 1, - "identifier-field-ids": [2] - }`), &schema)) - - assert.True(t, tableSchemaSimple.Equals(&schema)) -} - -func TestPruneColumnsString(t *testing.T) { - sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{1: {}}, false) - require.NoError(t, err) - - assert.True(t, sc.Equals(iceberg.NewSchemaWithIdentifiers(1, []int{1}, - iceberg.NestedField{ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: false}))) -} - -func TestPruneColumnsStringFull(t *testing.T) { - sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{1: {}}, true) - 
require.NoError(t, err) - - assert.True(t, sc.Equals(iceberg.NewSchemaWithIdentifiers(1, []int{1}, - iceberg.NestedField{ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: false}))) -} - -func TestPruneColumnsList(t *testing.T) { - sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{5: {}}, false) - require.NoError(t, err) - - assert.True(t, sc.Equals(iceberg.NewSchema(1, - iceberg.NestedField{ID: 4, Name: "qux", Required: true, Type: &iceberg.ListType{ - ElementID: 5, Element: iceberg.PrimitiveTypes.String, ElementRequired: true, - }}))) -} - -func TestPruneColumnsListItself(t *testing.T) { - _, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{4: {}}, false) - assert.ErrorIs(t, err, iceberg.ErrInvalidSchema) - - assert.ErrorContains(t, err, "cannot explicitly project List or Map types, 4:qux of type list was selected") -} - -func TestPruneColumnsListFull(t *testing.T) { - sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{5: {}}, true) - require.NoError(t, err) - - assert.True(t, sc.Equals(iceberg.NewSchema(1, - iceberg.NestedField{ID: 4, Name: "qux", Required: true, Type: &iceberg.ListType{ - ElementID: 5, Element: iceberg.PrimitiveTypes.String, ElementRequired: true, - }}))) -} - -func TestPruneColumnsMap(t *testing.T) { - sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{9: {}}, false) - require.NoError(t, err) - - assert.True(t, sc.Equals(iceberg.NewSchema(1, - iceberg.NestedField{ - ID: 6, - Name: "quux", - Required: true, - Type: &iceberg.MapType{ - KeyID: 7, - KeyType: iceberg.PrimitiveTypes.String, - ValueID: 8, - ValueType: &iceberg.MapType{ - KeyID: 9, - KeyType: iceberg.PrimitiveTypes.String, - ValueID: 10, - ValueType: iceberg.PrimitiveTypes.Int32, - ValueRequired: true, - }, - ValueRequired: true, - }, - }))) -} - -func TestPruneColumnsMapItself(t *testing.T) { - _, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{6: {}}, false) - 
assert.ErrorIs(t, err, iceberg.ErrInvalidSchema) - assert.ErrorContains(t, err, "cannot explicitly project List or Map types, 6:quux of type map> was selected") -} - -func TestPruneColumnsMapFull(t *testing.T) { - sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{9: {}}, true) - require.NoError(t, err) - - assert.True(t, sc.Equals(iceberg.NewSchema(1, - iceberg.NestedField{ - ID: 6, - Name: "quux", - Required: true, - Type: &iceberg.MapType{ - KeyID: 7, - KeyType: iceberg.PrimitiveTypes.String, - ValueID: 8, - ValueType: &iceberg.MapType{ - KeyID: 9, - KeyType: iceberg.PrimitiveTypes.String, - ValueID: 10, - ValueType: iceberg.PrimitiveTypes.Int32, - ValueRequired: true, - }, - ValueRequired: true, - }, - }))) -} - -func TestPruneColumnsMapKey(t *testing.T) { - sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{10: {}}, false) - require.NoError(t, err) - - assert.True(t, sc.Equals(iceberg.NewSchema(1, - iceberg.NestedField{ - ID: 6, - Name: "quux", - Required: true, - Type: &iceberg.MapType{ - KeyID: 7, - KeyType: iceberg.PrimitiveTypes.String, - ValueID: 8, - ValueType: &iceberg.MapType{ - KeyID: 9, - KeyType: iceberg.PrimitiveTypes.String, - ValueID: 10, - ValueType: iceberg.PrimitiveTypes.Int32, - ValueRequired: true, - }, - ValueRequired: true, - }, - }))) -} - -func TestPruneColumnsStruct(t *testing.T) { - sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{16: {}}, false) - require.NoError(t, err) - - assert.True(t, sc.Equals(iceberg.NewSchema(1, - iceberg.NestedField{ - ID: 15, - Name: "person", - Required: false, - Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{{ - ID: 16, Name: "name", Type: iceberg.PrimitiveTypes.String, Required: false, - }}, - }, - }))) -} - -func TestPruneColumnsStructFull(t *testing.T) { - sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{16: {}}, true) - require.NoError(t, err) - - assert.True(t, sc.Equals(iceberg.NewSchema(1, - 
iceberg.NestedField{ - ID: 15, - Name: "person", - Required: false, - Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{{ - ID: 16, Name: "name", Type: iceberg.PrimitiveTypes.String, Required: false, - }}, - }, - }))) -} - -func TestPruneColumnsEmptyStruct(t *testing.T) { - schemaEmptyStruct := iceberg.NewSchema(0, iceberg.NestedField{ - ID: 15, Name: "person", Type: &iceberg.StructType{}, Required: false, - }) - - sc, err := iceberg.PruneColumns(schemaEmptyStruct, map[int]iceberg.Void{15: {}}, false) - require.NoError(t, err) - - assert.True(t, sc.Equals(iceberg.NewSchema(0, - iceberg.NestedField{ - ID: 15, Name: "person", Type: &iceberg.StructType{}, Required: false}))) -} - -func TestPruneColumnsEmptyStructFull(t *testing.T) { - schemaEmptyStruct := iceberg.NewSchema(0, iceberg.NestedField{ - ID: 15, Name: "person", Type: &iceberg.StructType{}, Required: false, - }) - - sc, err := iceberg.PruneColumns(schemaEmptyStruct, map[int]iceberg.Void{15: {}}, true) - require.NoError(t, err) - - assert.True(t, sc.Equals(iceberg.NewSchema(0, - iceberg.NestedField{ - ID: 15, Name: "person", Type: &iceberg.StructType{}, Required: false}))) -} - -func TestPruneColumnsStructInMap(t *testing.T) { - nestedSchema := iceberg.NewSchemaWithIdentifiers(1, []int{1}, - iceberg.NestedField{ - ID: 6, - Name: "id_to_person", - Required: true, - Type: &iceberg.MapType{ - KeyID: 7, - KeyType: iceberg.PrimitiveTypes.Int32, - ValueID: 8, - ValueType: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 10, Name: "name", Type: iceberg.PrimitiveTypes.String}, - {ID: 11, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - }, - }, - ValueRequired: true, - }, - }) - - sc, err := iceberg.PruneColumns(nestedSchema, map[int]iceberg.Void{11: {}}, false) - require.NoError(t, err) - - expected := iceberg.NewSchema(1, - iceberg.NestedField{ - ID: 6, - Name: "id_to_person", - Required: true, - Type: &iceberg.MapType{ - KeyID: 7, - KeyType: 
iceberg.PrimitiveTypes.Int32, - ValueID: 8, - ValueType: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 11, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - }, - }, - ValueRequired: true, - }, - }) - - assert.Truef(t, sc.Equals(expected), "expected: %s\ngot: %s", expected, sc) -} - -func TestPruneColumnsStructInMapFull(t *testing.T) { - nestedSchema := iceberg.NewSchemaWithIdentifiers(1, []int{1}, - iceberg.NestedField{ - ID: 6, - Name: "id_to_person", - Required: true, - Type: &iceberg.MapType{ - KeyID: 7, - KeyType: iceberg.PrimitiveTypes.Int32, - ValueID: 8, - ValueType: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 10, Name: "name", Type: iceberg.PrimitiveTypes.String}, - {ID: 11, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - }, - }, - ValueRequired: true, - }, - }) - - sc, err := iceberg.PruneColumns(nestedSchema, map[int]iceberg.Void{11: {}}, true) - require.NoError(t, err) - - expected := iceberg.NewSchema(1, - iceberg.NestedField{ - ID: 6, - Name: "id_to_person", - Required: true, - Type: &iceberg.MapType{ - KeyID: 7, - KeyType: iceberg.PrimitiveTypes.Int32, - ValueID: 8, - ValueType: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 11, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - }, - }, - ValueRequired: true, - }, - }) - - assert.Truef(t, sc.Equals(expected), "expected: %s\ngot: %s", expected, sc) -} - -func TestPruneColumnsSelectOriginalSchema(t *testing.T) { - id := tableSchemaNested.HighestFieldID() - selected := make(map[int]iceberg.Void) - for i := 0; i < id; i++ { - selected[i] = iceberg.Void{} - } - - sc, err := iceberg.PruneColumns(tableSchemaNested, selected, true) - require.NoError(t, err) - - assert.True(t, sc.Equals(tableSchemaNested)) -} - -func TestPruneNilSchema(t *testing.T) { - _, err := iceberg.PruneColumns(nil, nil, true) - assert.ErrorIs(t, err, iceberg.ErrInvalidArgument) -} - -func TestSchemaRoundTrip(t *testing.T) { - 
data, err := json.Marshal(tableSchemaNested) - require.NoError(t, err) - - assert.JSONEq(t, `{ - "type": "struct", - "schema-id": 1, - "identifier-field-ids": [1], - "fields": [ - { - "type": "string", - "id": 1, - "name": "foo", - "required": false - }, - { - "type": "int", - "id": 2, - "name": "bar", - "required": true - }, - { - "type": "boolean", - "id": 3, - "name": "baz", - "required": false - }, - { - "id": 4, - "name": "qux", - "required": true, - "type": { - "type": "list", - "element-id": 5, - "element-required": true, - "element": "string" - } - }, - { - "id": 6, - "name": "quux", - "required": true, - "type": { - "type": "map", - "key-id": 7, - "key": "string", - "value-id": 8, - "value": { - "type": "map", - "key-id": 9, - "key": "string", - "value-id": 10, - "value": "int", - "value-required": true - }, - "value-required": true - } - }, - { - "id": 11, - "name": "location", - "required": true, - "type": { - "type": "list", - "element-id": 12, - "element-required": true, - "element": { - "type": "struct", - "fields": [ - { - "id": 13, - "name": "latitude", - "type": "float", - "required": false - }, - { - "id": 14, - "name": "longitude", - "type": "float", - "required": false - } - ] - } - } - }, - { - "id": 15, - "name": "person", - "required": false, - "type": { - "type": "struct", - "fields": [ - { - "id": 16, - "name": "name", - "type": "string", - "required": false - }, - { - "id": 17, - "name": "age", - "type": "int", - "required": true - } - ] - } - } - ] - }`, string(data)) - - var sc iceberg.Schema - require.NoError(t, json.Unmarshal(data, &sc)) - - assert.Truef(t, tableSchemaNested.Equals(&sc), "expected: %s\ngot: %s", tableSchemaNested, &sc) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package iceberg_test + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/apache/iceberg-go" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + tableSchemaNested = iceberg.NewSchemaWithIdentifiers(1, + []int{1}, + iceberg.NestedField{ + ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: false}, + iceberg.NestedField{ + ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + iceberg.NestedField{ + ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool, Required: false}, + iceberg.NestedField{ + ID: 4, Name: "qux", Required: true, Type: &iceberg.ListType{ + ElementID: 5, Element: iceberg.PrimitiveTypes.String, ElementRequired: true}}, + iceberg.NestedField{ + ID: 6, Name: "quux", + Type: &iceberg.MapType{ + KeyID: 7, + KeyType: iceberg.PrimitiveTypes.String, + ValueID: 8, + ValueType: &iceberg.MapType{ + KeyID: 9, + KeyType: iceberg.PrimitiveTypes.String, + ValueID: 10, + ValueType: iceberg.PrimitiveTypes.Int32, + ValueRequired: true, + }, + ValueRequired: true, + }, + Required: true}, + iceberg.NestedField{ + ID: 11, Name: "location", Type: &iceberg.ListType{ + ElementID: 12, Element: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 13, Name: "latitude", Type: iceberg.PrimitiveTypes.Float32, Required: false}, + {ID: 14, Name: "longitude", Type: iceberg.PrimitiveTypes.Float32, 
Required: false}, + }, + }, + ElementRequired: true}, + Required: true}, + iceberg.NestedField{ + ID: 15, + Name: "person", + Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 16, Name: "name", Type: iceberg.PrimitiveTypes.String, Required: false}, + {ID: 17, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + }, + }, + Required: false, + }, + ) + + tableSchemaSimple = iceberg.NewSchemaWithIdentifiers(1, + []int{2}, + iceberg.NestedField{ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + iceberg.NestedField{ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool}, + ) +) + +func TestSchemaToString(t *testing.T) { + assert.Equal(t, 3, tableSchemaSimple.NumFields()) + assert.Equal(t, `table { + 1: foo: optional string + 2: bar: required int + 3: baz: optional boolean +}`, tableSchemaSimple.String()) +} + +func TestNestedFieldToString(t *testing.T) { + tests := []struct { + idx int + expected string + }{ + {0, "1: foo: optional string"}, + {1, "2: bar: required int"}, + {2, "3: baz: optional boolean"}, + {3, "4: qux: required list"}, + {4, "6: quux: required map>"}, + {5, "11: location: required list>"}, + {6, "15: person: optional struct<16: name: optional string, 17: age: required int>"}, + } + + for _, tt := range tests { + assert.Equal(t, tt.expected, tableSchemaNested.Field(tt.idx).String()) + } +} + +func TestSchemaIndexByIDVisitor(t *testing.T) { + index, err := iceberg.IndexByID(tableSchemaNested) + require.NoError(t, err) + + assert.Equal(t, map[int]iceberg.NestedField{ + 1: tableSchemaNested.Field(0), + 2: tableSchemaNested.Field(1), + 3: tableSchemaNested.Field(2), + 4: tableSchemaNested.Field(3), + 5: {ID: 5, Name: "element", Type: iceberg.PrimitiveTypes.String, Required: true}, + 6: tableSchemaNested.Field(4), + 7: {ID: 7, Name: "key", Type: iceberg.PrimitiveTypes.String, Required: true}, + 8: {ID: 8, Name: "value", 
Type: &iceberg.MapType{ + KeyID: 9, + KeyType: iceberg.PrimitiveTypes.String, + ValueID: 10, + ValueType: iceberg.PrimitiveTypes.Int32, + ValueRequired: true, + }, Required: true}, + 9: {ID: 9, Name: "key", Type: iceberg.PrimitiveTypes.String, Required: true}, + 10: {ID: 10, Name: "value", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + 11: tableSchemaNested.Field(5), + 12: {ID: 12, Name: "element", Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 13, Name: "latitude", Type: iceberg.PrimitiveTypes.Float32, Required: false}, + {ID: 14, Name: "longitude", Type: iceberg.PrimitiveTypes.Float32, Required: false}, + }, + }, Required: true}, + 13: {ID: 13, Name: "latitude", Type: iceberg.PrimitiveTypes.Float32, Required: false}, + 14: {ID: 14, Name: "longitude", Type: iceberg.PrimitiveTypes.Float32, Required: false}, + 15: tableSchemaNested.Field(6), + 16: {ID: 16, Name: "name", Type: iceberg.PrimitiveTypes.String, Required: false}, + 17: {ID: 17, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + }, index) +} + +func TestSchemaIndexByName(t *testing.T) { + index, err := iceberg.IndexByName(tableSchemaNested) + require.NoError(t, err) + + assert.Equal(t, map[string]int{ + "foo": 1, + "bar": 2, + "baz": 3, + "qux": 4, + "qux.element": 5, + "quux": 6, + "quux.key": 7, + "quux.value": 8, + "quux.value.key": 9, + "quux.value.value": 10, + "location": 11, + "location.element": 12, + "location.element.latitude": 13, + "location.element.longitude": 14, + "location.latitude": 13, + "location.longitude": 14, + "person": 15, + "person.name": 16, + "person.age": 17, + }, index) +} + +func TestSchemaFindColumnName(t *testing.T) { + tests := []struct { + id int + name string + }{ + {1, "foo"}, + {2, "bar"}, + {3, "baz"}, + {4, "qux"}, + {5, "qux.element"}, + {6, "quux"}, + {7, "quux.key"}, + {8, "quux.value"}, + {9, "quux.value.key"}, + {10, "quux.value.value"}, + {11, "location"}, + {12, "location.element"}, + {13, 
"location.element.latitude"}, + {14, "location.element.longitude"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n, ok := tableSchemaNested.FindColumnName(tt.id) + assert.True(t, ok) + assert.Equal(t, tt.name, n) + }) + } +} + +func TestSchemaFindColumnNameIDNotFound(t *testing.T) { + n, ok := tableSchemaNested.FindColumnName(99) + assert.False(t, ok) + assert.Empty(t, n) +} + +func TestSchemaFindColumnNameByID(t *testing.T) { + tests := []struct { + id int + name string + }{ + {1, "foo"}, + {2, "bar"}, + {3, "baz"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + n, ok := tableSchemaSimple.FindColumnName(tt.id) + assert.True(t, ok) + assert.Equal(t, tt.name, n) + }) + } +} + +func TestSchemaFindFieldByID(t *testing.T) { + index, err := iceberg.IndexByID(tableSchemaSimple) + require.NoError(t, err) + + col1 := index[1] + assert.Equal(t, 1, col1.ID) + assert.Equal(t, iceberg.PrimitiveTypes.String, col1.Type) + assert.False(t, col1.Required) + + col2 := index[2] + assert.Equal(t, 2, col2.ID) + assert.Equal(t, iceberg.PrimitiveTypes.Int32, col2.Type) + assert.True(t, col2.Required) + + col3 := index[3] + assert.Equal(t, 3, col3.ID) + assert.Equal(t, iceberg.PrimitiveTypes.Bool, col3.Type) + assert.False(t, col3.Required) +} + +func TestFindFieldByIDUnknownField(t *testing.T) { + index, err := iceberg.IndexByID(tableSchemaSimple) + require.NoError(t, err) + _, ok := index[4] + assert.False(t, ok) +} + +func TestSchemaFindField(t *testing.T) { + tests := []iceberg.NestedField{ + {ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: false}, + {ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + {ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool, Required: false}, + } + + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + f, ok := tableSchemaSimple.FindFieldByID(tt.ID) + assert.True(t, ok) + assert.Equal(t, tt, f) + + f, ok = 
tableSchemaSimple.FindFieldByName(tt.Name) + assert.True(t, ok) + assert.Equal(t, tt, f) + + f, ok = tableSchemaSimple.FindFieldByNameCaseInsensitive(strings.ToUpper(tt.Name)) + assert.True(t, ok) + assert.Equal(t, tt, f) + }) + } +} + +func TestSchemaFindType(t *testing.T) { + _, ok := tableSchemaSimple.FindTypeByID(0) + assert.False(t, ok) + _, ok = tableSchemaSimple.FindTypeByName("FOOBAR") + assert.False(t, ok) + _, ok = tableSchemaSimple.FindTypeByNameCaseInsensitive("FOOBAR") + assert.False(t, ok) + + tests := []struct { + id int + name string + typ iceberg.Type + }{ + {1, "foo", iceberg.PrimitiveTypes.String}, + {2, "bar", iceberg.PrimitiveTypes.Int32}, + {3, "baz", iceberg.PrimitiveTypes.Bool}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + typ, ok := tableSchemaSimple.FindTypeByID(tt.id) + assert.True(t, ok) + assert.Equal(t, tt.typ, typ) + + typ, ok = tableSchemaSimple.FindTypeByName(tt.name) + assert.True(t, ok) + assert.Equal(t, tt.typ, typ) + + typ, ok = tableSchemaSimple.FindTypeByNameCaseInsensitive(strings.ToUpper(tt.name)) + assert.True(t, ok) + assert.Equal(t, tt.typ, typ) + }) + } +} + +func TestSerializeSchema(t *testing.T) { + data, err := json.Marshal(tableSchemaSimple) + require.NoError(t, err) + + assert.JSONEq(t, `{ + "type": "struct", + "fields": [ + {"id": 1, "name": "foo", "type": "string", "required": false}, + {"id": 2, "name": "bar", "type": "int", "required": true}, + {"id": 3, "name": "baz", "type": "boolean", "required": false} + ], + "schema-id": 1, + "identifier-field-ids": [2] + }`, string(data)) +} + +func TestUnmarshalSchema(t *testing.T) { + var schema iceberg.Schema + require.NoError(t, json.Unmarshal([]byte(`{ + "type": "struct", + "fields": [ + {"id": 1, "name": "foo", "type": "string", "required": false}, + {"id": 2, "name": "bar", "type": "int", "required": true}, + {"id": 3, "name": "baz", "type": "boolean", "required": false} + ], + "schema-id": 1, + "identifier-field-ids": [2] + }`), 
&schema)) + + assert.True(t, tableSchemaSimple.Equals(&schema)) +} + +func TestPruneColumnsString(t *testing.T) { + sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{1: {}}, false) + require.NoError(t, err) + + assert.True(t, sc.Equals(iceberg.NewSchemaWithIdentifiers(1, []int{1}, + iceberg.NestedField{ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: false}))) +} + +func TestPruneColumnsStringFull(t *testing.T) { + sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{1: {}}, true) + require.NoError(t, err) + + assert.True(t, sc.Equals(iceberg.NewSchemaWithIdentifiers(1, []int{1}, + iceberg.NestedField{ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: false}))) +} + +func TestPruneColumnsList(t *testing.T) { + sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{5: {}}, false) + require.NoError(t, err) + + assert.True(t, sc.Equals(iceberg.NewSchema(1, + iceberg.NestedField{ID: 4, Name: "qux", Required: true, Type: &iceberg.ListType{ + ElementID: 5, Element: iceberg.PrimitiveTypes.String, ElementRequired: true, + }}))) +} + +func TestPruneColumnsListItself(t *testing.T) { + _, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{4: {}}, false) + assert.ErrorIs(t, err, iceberg.ErrInvalidSchema) + + assert.ErrorContains(t, err, "cannot explicitly project List or Map types, 4:qux of type list was selected") +} + +func TestPruneColumnsListFull(t *testing.T) { + sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{5: {}}, true) + require.NoError(t, err) + + assert.True(t, sc.Equals(iceberg.NewSchema(1, + iceberg.NestedField{ID: 4, Name: "qux", Required: true, Type: &iceberg.ListType{ + ElementID: 5, Element: iceberg.PrimitiveTypes.String, ElementRequired: true, + }}))) +} + +func TestPruneColumnsMap(t *testing.T) { + sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{9: {}}, false) + require.NoError(t, err) + + 
assert.True(t, sc.Equals(iceberg.NewSchema(1, + iceberg.NestedField{ + ID: 6, + Name: "quux", + Required: true, + Type: &iceberg.MapType{ + KeyID: 7, + KeyType: iceberg.PrimitiveTypes.String, + ValueID: 8, + ValueType: &iceberg.MapType{ + KeyID: 9, + KeyType: iceberg.PrimitiveTypes.String, + ValueID: 10, + ValueType: iceberg.PrimitiveTypes.Int32, + ValueRequired: true, + }, + ValueRequired: true, + }, + }))) +} + +func TestPruneColumnsMapItself(t *testing.T) { + _, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{6: {}}, false) + assert.ErrorIs(t, err, iceberg.ErrInvalidSchema) + assert.ErrorContains(t, err, "cannot explicitly project List or Map types, 6:quux of type map> was selected") +} + +func TestPruneColumnsMapFull(t *testing.T) { + sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{9: {}}, true) + require.NoError(t, err) + + assert.True(t, sc.Equals(iceberg.NewSchema(1, + iceberg.NestedField{ + ID: 6, + Name: "quux", + Required: true, + Type: &iceberg.MapType{ + KeyID: 7, + KeyType: iceberg.PrimitiveTypes.String, + ValueID: 8, + ValueType: &iceberg.MapType{ + KeyID: 9, + KeyType: iceberg.PrimitiveTypes.String, + ValueID: 10, + ValueType: iceberg.PrimitiveTypes.Int32, + ValueRequired: true, + }, + ValueRequired: true, + }, + }))) +} + +func TestPruneColumnsMapKey(t *testing.T) { + sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{10: {}}, false) + require.NoError(t, err) + + assert.True(t, sc.Equals(iceberg.NewSchema(1, + iceberg.NestedField{ + ID: 6, + Name: "quux", + Required: true, + Type: &iceberg.MapType{ + KeyID: 7, + KeyType: iceberg.PrimitiveTypes.String, + ValueID: 8, + ValueType: &iceberg.MapType{ + KeyID: 9, + KeyType: iceberg.PrimitiveTypes.String, + ValueID: 10, + ValueType: iceberg.PrimitiveTypes.Int32, + ValueRequired: true, + }, + ValueRequired: true, + }, + }))) +} + +func TestPruneColumnsStruct(t *testing.T) { + sc, err := iceberg.PruneColumns(tableSchemaNested, 
map[int]iceberg.Void{16: {}}, false) + require.NoError(t, err) + + assert.True(t, sc.Equals(iceberg.NewSchema(1, + iceberg.NestedField{ + ID: 15, + Name: "person", + Required: false, + Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{{ + ID: 16, Name: "name", Type: iceberg.PrimitiveTypes.String, Required: false, + }}, + }, + }))) +} + +func TestPruneColumnsStructFull(t *testing.T) { + sc, err := iceberg.PruneColumns(tableSchemaNested, map[int]iceberg.Void{16: {}}, true) + require.NoError(t, err) + + assert.True(t, sc.Equals(iceberg.NewSchema(1, + iceberg.NestedField{ + ID: 15, + Name: "person", + Required: false, + Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{{ + ID: 16, Name: "name", Type: iceberg.PrimitiveTypes.String, Required: false, + }}, + }, + }))) +} + +func TestPruneColumnsEmptyStruct(t *testing.T) { + schemaEmptyStruct := iceberg.NewSchema(0, iceberg.NestedField{ + ID: 15, Name: "person", Type: &iceberg.StructType{}, Required: false, + }) + + sc, err := iceberg.PruneColumns(schemaEmptyStruct, map[int]iceberg.Void{15: {}}, false) + require.NoError(t, err) + + assert.True(t, sc.Equals(iceberg.NewSchema(0, + iceberg.NestedField{ + ID: 15, Name: "person", Type: &iceberg.StructType{}, Required: false}))) +} + +func TestPruneColumnsEmptyStructFull(t *testing.T) { + schemaEmptyStruct := iceberg.NewSchema(0, iceberg.NestedField{ + ID: 15, Name: "person", Type: &iceberg.StructType{}, Required: false, + }) + + sc, err := iceberg.PruneColumns(schemaEmptyStruct, map[int]iceberg.Void{15: {}}, true) + require.NoError(t, err) + + assert.True(t, sc.Equals(iceberg.NewSchema(0, + iceberg.NestedField{ + ID: 15, Name: "person", Type: &iceberg.StructType{}, Required: false}))) +} + +func TestPruneColumnsStructInMap(t *testing.T) { + nestedSchema := iceberg.NewSchemaWithIdentifiers(1, []int{1}, + iceberg.NestedField{ + ID: 6, + Name: "id_to_person", + Required: true, + Type: &iceberg.MapType{ + KeyID: 7, + KeyType: iceberg.PrimitiveTypes.Int32, + 
ValueID: 8, + ValueType: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 10, Name: "name", Type: iceberg.PrimitiveTypes.String}, + {ID: 11, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + }, + }, + ValueRequired: true, + }, + }) + + sc, err := iceberg.PruneColumns(nestedSchema, map[int]iceberg.Void{11: {}}, false) + require.NoError(t, err) + + expected := iceberg.NewSchema(1, + iceberg.NestedField{ + ID: 6, + Name: "id_to_person", + Required: true, + Type: &iceberg.MapType{ + KeyID: 7, + KeyType: iceberg.PrimitiveTypes.Int32, + ValueID: 8, + ValueType: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 11, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + }, + }, + ValueRequired: true, + }, + }) + + assert.Truef(t, sc.Equals(expected), "expected: %s\ngot: %s", expected, sc) +} + +func TestPruneColumnsStructInMapFull(t *testing.T) { + nestedSchema := iceberg.NewSchemaWithIdentifiers(1, []int{1}, + iceberg.NestedField{ + ID: 6, + Name: "id_to_person", + Required: true, + Type: &iceberg.MapType{ + KeyID: 7, + KeyType: iceberg.PrimitiveTypes.Int32, + ValueID: 8, + ValueType: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 10, Name: "name", Type: iceberg.PrimitiveTypes.String}, + {ID: 11, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + }, + }, + ValueRequired: true, + }, + }) + + sc, err := iceberg.PruneColumns(nestedSchema, map[int]iceberg.Void{11: {}}, true) + require.NoError(t, err) + + expected := iceberg.NewSchema(1, + iceberg.NestedField{ + ID: 6, + Name: "id_to_person", + Required: true, + Type: &iceberg.MapType{ + KeyID: 7, + KeyType: iceberg.PrimitiveTypes.Int32, + ValueID: 8, + ValueType: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 11, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + }, + }, + ValueRequired: true, + }, + }) + + assert.Truef(t, sc.Equals(expected), "expected: %s\ngot: %s", expected, sc) +} + +func 
TestPruneColumnsSelectOriginalSchema(t *testing.T) { + id := tableSchemaNested.HighestFieldID() + selected := make(map[int]iceberg.Void) + for i := 0; i < id; i++ { + selected[i] = iceberg.Void{} + } + + sc, err := iceberg.PruneColumns(tableSchemaNested, selected, true) + require.NoError(t, err) + + assert.True(t, sc.Equals(tableSchemaNested)) +} + +func TestPruneNilSchema(t *testing.T) { + _, err := iceberg.PruneColumns(nil, nil, true) + assert.ErrorIs(t, err, iceberg.ErrInvalidArgument) +} + +func TestSchemaRoundTrip(t *testing.T) { + data, err := json.Marshal(tableSchemaNested) + require.NoError(t, err) + + assert.JSONEq(t, `{ + "type": "struct", + "schema-id": 1, + "identifier-field-ids": [1], + "fields": [ + { + "type": "string", + "id": 1, + "name": "foo", + "required": false + }, + { + "type": "int", + "id": 2, + "name": "bar", + "required": true + }, + { + "type": "boolean", + "id": 3, + "name": "baz", + "required": false + }, + { + "id": 4, + "name": "qux", + "required": true, + "type": { + "type": "list", + "element-id": 5, + "element-required": true, + "element": "string" + } + }, + { + "id": 6, + "name": "quux", + "required": true, + "type": { + "type": "map", + "key-id": 7, + "key": "string", + "value-id": 8, + "value": { + "type": "map", + "key-id": 9, + "key": "string", + "value-id": 10, + "value": "int", + "value-required": true + }, + "value-required": true + } + }, + { + "id": 11, + "name": "location", + "required": true, + "type": { + "type": "list", + "element-id": 12, + "element-required": true, + "element": { + "type": "struct", + "fields": [ + { + "id": 13, + "name": "latitude", + "type": "float", + "required": false + }, + { + "id": 14, + "name": "longitude", + "type": "float", + "required": false + } + ] + } + } + }, + { + "id": 15, + "name": "person", + "required": false, + "type": { + "type": "struct", + "fields": [ + { + "id": 16, + "name": "name", + "type": "string", + "required": false + }, + { + "id": 17, + "name": "age", + "type": 
"int", + "required": true + } + ] + } + } + ] + }`, string(data)) + + var sc iceberg.Schema + require.NoError(t, json.Unmarshal(data, &sc)) + + assert.Truef(t, tableSchemaNested.Equals(&sc), "expected: %s\ngot: %s", tableSchemaNested, &sc) +} diff --git a/table/arrow_utils.go b/table/arrow_utils.go index 6104fc6..bb9be85 100644 --- a/table/arrow_utils.go +++ b/table/arrow_utils.go @@ -1,412 +1,412 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package table - -import ( - "fmt" - "slices" - "strconv" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/iceberg-go" -) - -// constants to look for as Keys in Arrow field metadata -const ( - ArrowFieldDocKey = "doc" - // Arrow schemas that are generated from the Parquet library will utilize - // this key to identify the field id of the source Parquet field. 
- // We use this when converting to Iceberg to provide field IDs - ArrowParquetFieldIDKey = "PARQUET:field_id" -) - -// ArrowSchemaVisitor is an interface that can be implemented and used to -// call VisitArrowSchema for iterating -type ArrowSchemaVisitor[T any] interface { - Schema(*arrow.Schema, T) T - Struct(*arrow.StructType, []T) T - Field(arrow.Field, T) T - List(arrow.ListLikeType, T) T - Map(mt *arrow.MapType, keyResult T, valueResult T) T - Primitive(arrow.DataType) T -} - -func recoverError(err *error) { - if r := recover(); r != nil { - switch e := r.(type) { - case string: - *err = fmt.Errorf("error encountered during arrow schema visitor: %s", e) - case error: - *err = fmt.Errorf("error encountered during arrow schema visitor: %w", e) - } - } -} - -func VisitArrowSchema[T any](sc *arrow.Schema, visitor ArrowSchemaVisitor[T]) (res T, err error) { - if sc == nil { - err = fmt.Errorf("%w: cannot visit nil arrow schema", iceberg.ErrInvalidArgument) - return - } - - defer recoverError(&err) - - return visitor.Schema(sc, visitArrowStruct(arrow.StructOf(sc.Fields()...), visitor)), err -} - -func visitArrowField[T any](f arrow.Field, visitor ArrowSchemaVisitor[T]) T { - switch typ := f.Type.(type) { - case *arrow.StructType: - return visitArrowStruct(typ, visitor) - case *arrow.MapType: - return visitArrowMap(typ, visitor) - case arrow.ListLikeType: - return visitArrowList(typ, visitor) - default: - return visitor.Primitive(typ) - } -} - -func visitArrowStruct[T any](dt *arrow.StructType, visitor ArrowSchemaVisitor[T]) T { - type ( - beforeField interface { - BeforeField(arrow.Field) - } - afterField interface { - AfterField(arrow.Field) - } - ) - - results := make([]T, dt.NumFields()) - bf, _ := visitor.(beforeField) - af, _ := visitor.(afterField) - - for i, f := range dt.Fields() { - if bf != nil { - bf.BeforeField(f) - } - - res := visitArrowField(f, visitor) - - if af != nil { - af.AfterField(f) - } - - results[i] = visitor.Field(f, res) - } - - return 
visitor.Struct(dt, results) -} - -func visitArrowMap[T any](dt *arrow.MapType, visitor ArrowSchemaVisitor[T]) T { - type ( - beforeMapKey interface { - BeforeMapKey(arrow.Field) - } - beforeMapValue interface { - BeforeMapValue(arrow.Field) - } - afterMapKey interface { - AfterMapKey(arrow.Field) - } - afterMapValue interface { - AfterMapValue(arrow.Field) - } - ) - - key, val := dt.KeyField(), dt.ItemField() - - if bmk, ok := visitor.(beforeMapKey); ok { - bmk.BeforeMapKey(key) - } - - keyResult := visitArrowField(key, visitor) - - if amk, ok := visitor.(afterMapKey); ok { - amk.AfterMapKey(key) - } - - if bmv, ok := visitor.(beforeMapValue); ok { - bmv.BeforeMapValue(val) - } - - valueResult := visitArrowField(val, visitor) - - if amv, ok := visitor.(afterMapValue); ok { - amv.AfterMapValue(val) - } - - return visitor.Map(dt, keyResult, valueResult) -} - -func visitArrowList[T any](dt arrow.ListLikeType, visitor ArrowSchemaVisitor[T]) T { - type ( - beforeListElem interface { - BeforeListElement(arrow.Field) - } - afterListElem interface { - AfterListElement(arrow.Field) - } - ) - - elemField := dt.ElemField() - - if bl, ok := visitor.(beforeListElem); ok { - bl.BeforeListElement(elemField) - } - - res := visitArrowField(elemField, visitor) - - if al, ok := visitor.(afterListElem); ok { - al.AfterListElement(elemField) - } - - return visitor.List(dt, res) -} - -func getFieldID(f arrow.Field) *int { - if !f.HasMetadata() { - return nil - } - - fieldIDStr, ok := f.Metadata.GetValue(ArrowParquetFieldIDKey) - if !ok { - return nil - } - - id, err := strconv.Atoi(fieldIDStr) - if err != nil { - return nil - } - - return &id -} - -type hasIDs struct{} - -func (hasIDs) Schema(sc *arrow.Schema, result bool) bool { - return result -} - -func (hasIDs) Struct(st *arrow.StructType, results []bool) bool { - return !slices.Contains(results, false) -} - -func (hasIDs) Field(f arrow.Field, result bool) bool { - return getFieldID(f) != nil -} - -func (hasIDs) List(dt 
arrow.ListLikeType, elem bool) bool { - elemField := dt.ElemField() - return elem && getFieldID(elemField) != nil -} - -func (hasIDs) Map(m *arrow.MapType, key, val bool) bool { - return key && val && - getFieldID(m.KeyField()) != nil && getFieldID(m.ItemField()) != nil -} - -func (hasIDs) Primitive(arrow.DataType) bool { return true } - -type convertToIceberg struct { - downcastTimestamp bool - - fieldID func(arrow.Field) int -} - -func (convertToIceberg) Schema(_ *arrow.Schema, result iceberg.NestedField) iceberg.NestedField { - return result -} - -func (convertToIceberg) Struct(_ *arrow.StructType, results []iceberg.NestedField) iceberg.NestedField { - return iceberg.NestedField{ - Type: &iceberg.StructType{FieldList: results}, - } -} - -func (c convertToIceberg) Field(field arrow.Field, result iceberg.NestedField) iceberg.NestedField { - result.ID = c.fieldID(field) - if field.HasMetadata() { - if doc, ok := field.Metadata.GetValue(ArrowFieldDocKey); ok { - result.Doc = doc - } - } - - result.Required = !field.Nullable - result.Name = field.Name - return result -} - -func (c convertToIceberg) List(dt arrow.ListLikeType, elemResult iceberg.NestedField) iceberg.NestedField { - elemField := dt.ElemField() - elemID := c.fieldID(elemField) - - return iceberg.NestedField{ - Type: &iceberg.ListType{ - ElementID: elemID, - Element: elemResult.Type, - ElementRequired: !elemField.Nullable, - }, - } -} - -func (c convertToIceberg) Map(m *arrow.MapType, keyResult, valueResult iceberg.NestedField) iceberg.NestedField { - keyField, valField := m.KeyField(), m.ItemField() - keyID, valID := c.fieldID(keyField), c.fieldID(valField) - - return iceberg.NestedField{ - Type: &iceberg.MapType{ - KeyID: keyID, - KeyType: keyResult.Type, - ValueID: valID, - ValueType: valueResult.Type, - ValueRequired: !valField.Nullable, - }, - } -} - -var ( - utcAliases = []string{"UTC", "+00:00", "Etc/UTC", "Z"} -) - -func (c convertToIceberg) Primitive(dt arrow.DataType) (result 
iceberg.NestedField) { - switch dt := dt.(type) { - case *arrow.DictionaryType: - if _, ok := dt.ValueType.(arrow.NestedType); ok { - panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) - } - return c.Primitive(dt.ValueType) - case *arrow.RunEndEncodedType: - if _, ok := dt.Encoded().(arrow.NestedType); ok { - panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) - } - return c.Primitive(dt.Encoded()) - case *arrow.BooleanType: - result.Type = iceberg.PrimitiveTypes.Bool - case *arrow.Uint8Type, *arrow.Uint16Type, *arrow.Uint32Type, - *arrow.Int8Type, *arrow.Int16Type, *arrow.Int32Type: - result.Type = iceberg.PrimitiveTypes.Int32 - case *arrow.Uint64Type, *arrow.Int64Type: - result.Type = iceberg.PrimitiveTypes.Int64 - case *arrow.Float16Type, *arrow.Float32Type: - result.Type = iceberg.PrimitiveTypes.Float32 - case *arrow.Float64Type: - result.Type = iceberg.PrimitiveTypes.Float64 - case *arrow.Decimal32Type, *arrow.Decimal64Type, *arrow.Decimal128Type: - dec := dt.(arrow.DecimalType) - result.Type = iceberg.DecimalTypeOf(int(dec.GetPrecision()), int(dec.GetScale())) - case *arrow.StringType, *arrow.LargeStringType: - result.Type = iceberg.PrimitiveTypes.String - case *arrow.BinaryType, *arrow.LargeBinaryType: - result.Type = iceberg.PrimitiveTypes.Binary - case *arrow.Date32Type: - result.Type = iceberg.PrimitiveTypes.Date - case *arrow.Time64Type: - if dt.Unit == arrow.Microsecond { - result.Type = iceberg.PrimitiveTypes.Time - } else { - panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) - } - case *arrow.TimestampType: - if dt.Unit == arrow.Nanosecond { - if !c.downcastTimestamp { - panic(fmt.Errorf("%w: 'ns' timestamp precision not supported", iceberg.ErrType)) - } - // TODO: log something - } - - if slices.Contains(utcAliases, dt.TimeZone) { - result.Type = iceberg.PrimitiveTypes.TimestampTz - } else if dt.TimeZone == "" { 
- result.Type = iceberg.PrimitiveTypes.Timestamp - } else { - panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) - } - case *arrow.FixedSizeBinaryType: - result.Type = iceberg.FixedTypeOf(dt.ByteWidth) - case arrow.ExtensionType: - if dt.ExtensionName() == "arrow.uuid" { - result.Type = iceberg.PrimitiveTypes.UUID - } else { - panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) - } - default: - panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) - } - - return -} - -func ArrowTypeToIceberg(dt arrow.DataType, downcastNsTimestamp bool) (iceberg.Type, error) { - sc := arrow.NewSchema([]arrow.Field{{Type: dt, - Metadata: arrow.NewMetadata([]string{ArrowParquetFieldIDKey}, []string{"1"})}}, nil) - - out, err := VisitArrowSchema(sc, convertToIceberg{ - downcastTimestamp: downcastNsTimestamp, - fieldID: func(field arrow.Field) int { - if id := getFieldID(field); id != nil { - return *id - } - - panic(fmt.Errorf("%w: cannot convert %s to Iceberg field, missing field_id", - iceberg.ErrInvalidSchema, field)) - }, - }) - if err != nil { - return nil, err - } - - return out.Type.(*iceberg.StructType).FieldList[0].Type, nil -} - -func ArrowSchemaToIceberg(sc *arrow.Schema, downcastNsTimestamp bool, nameMapping NameMapping) (*iceberg.Schema, error) { - hasIDs, _ := VisitArrowSchema(sc, hasIDs{}) - - switch { - case hasIDs: - out, err := VisitArrowSchema(sc, convertToIceberg{ - downcastTimestamp: downcastNsTimestamp, - fieldID: func(field arrow.Field) int { - if id := getFieldID(field); id != nil { - return *id - } - - panic(fmt.Errorf("%w: cannot convert %s to Iceberg field, missing field_id", - iceberg.ErrInvalidSchema, field)) - }, - }) - if err != nil { - return nil, err - } - - return iceberg.NewSchema(0, out.Type.(*iceberg.StructType).FieldList...), nil - case nameMapping != nil: - withoutIDs, err := VisitArrowSchema(sc, 
convertToIceberg{ - downcastTimestamp: downcastNsTimestamp, - fieldID: func(_ arrow.Field) int { return -1 }, - }) - if err != nil { - return nil, err - } - - schemaWithoutIDs := iceberg.NewSchema(0, withoutIDs.Type.(*iceberg.StructType).FieldList...) - return ApplyNameMapping(schemaWithoutIDs, nameMapping) - default: - return nil, fmt.Errorf("%w: arrow schema does not have field-ids and no name mapping provided", - iceberg.ErrInvalidSchema) - } -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package table + +import ( + "fmt" + "slices" + "strconv" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/iceberg-go" +) + +// constants to look for as Keys in Arrow field metadata +const ( + ArrowFieldDocKey = "doc" + // Arrow schemas that are generated from the Parquet library will utilize + // this key to identify the field id of the source Parquet field. 
+ // We use this when converting to Iceberg to provide field IDs + ArrowParquetFieldIDKey = "PARQUET:field_id" +) + +// ArrowSchemaVisitor is an interface that can be implemented and used to +// call VisitArrowSchema for iterating +type ArrowSchemaVisitor[T any] interface { + Schema(*arrow.Schema, T) T + Struct(*arrow.StructType, []T) T + Field(arrow.Field, T) T + List(arrow.ListLikeType, T) T + Map(mt *arrow.MapType, keyResult T, valueResult T) T + Primitive(arrow.DataType) T +} + +func recoverError(err *error) { + if r := recover(); r != nil { + switch e := r.(type) { + case string: + *err = fmt.Errorf("error encountered during arrow schema visitor: %s", e) + case error: + *err = fmt.Errorf("error encountered during arrow schema visitor: %w", e) + } + } +} + +func VisitArrowSchema[T any](sc *arrow.Schema, visitor ArrowSchemaVisitor[T]) (res T, err error) { + if sc == nil { + err = fmt.Errorf("%w: cannot visit nil arrow schema", iceberg.ErrInvalidArgument) + return + } + + defer recoverError(&err) + + return visitor.Schema(sc, visitArrowStruct(arrow.StructOf(sc.Fields()...), visitor)), err +} + +func visitArrowField[T any](f arrow.Field, visitor ArrowSchemaVisitor[T]) T { + switch typ := f.Type.(type) { + case *arrow.StructType: + return visitArrowStruct(typ, visitor) + case *arrow.MapType: + return visitArrowMap(typ, visitor) + case arrow.ListLikeType: + return visitArrowList(typ, visitor) + default: + return visitor.Primitive(typ) + } +} + +func visitArrowStruct[T any](dt *arrow.StructType, visitor ArrowSchemaVisitor[T]) T { + type ( + beforeField interface { + BeforeField(arrow.Field) + } + afterField interface { + AfterField(arrow.Field) + } + ) + + results := make([]T, dt.NumFields()) + bf, _ := visitor.(beforeField) + af, _ := visitor.(afterField) + + for i, f := range dt.Fields() { + if bf != nil { + bf.BeforeField(f) + } + + res := visitArrowField(f, visitor) + + if af != nil { + af.AfterField(f) + } + + results[i] = visitor.Field(f, res) + } + + return 
visitor.Struct(dt, results) +} + +func visitArrowMap[T any](dt *arrow.MapType, visitor ArrowSchemaVisitor[T]) T { + type ( + beforeMapKey interface { + BeforeMapKey(arrow.Field) + } + beforeMapValue interface { + BeforeMapValue(arrow.Field) + } + afterMapKey interface { + AfterMapKey(arrow.Field) + } + afterMapValue interface { + AfterMapValue(arrow.Field) + } + ) + + key, val := dt.KeyField(), dt.ItemField() + + if bmk, ok := visitor.(beforeMapKey); ok { + bmk.BeforeMapKey(key) + } + + keyResult := visitArrowField(key, visitor) + + if amk, ok := visitor.(afterMapKey); ok { + amk.AfterMapKey(key) + } + + if bmv, ok := visitor.(beforeMapValue); ok { + bmv.BeforeMapValue(val) + } + + valueResult := visitArrowField(val, visitor) + + if amv, ok := visitor.(afterMapValue); ok { + amv.AfterMapValue(val) + } + + return visitor.Map(dt, keyResult, valueResult) +} + +func visitArrowList[T any](dt arrow.ListLikeType, visitor ArrowSchemaVisitor[T]) T { + type ( + beforeListElem interface { + BeforeListElement(arrow.Field) + } + afterListElem interface { + AfterListElement(arrow.Field) + } + ) + + elemField := dt.ElemField() + + if bl, ok := visitor.(beforeListElem); ok { + bl.BeforeListElement(elemField) + } + + res := visitArrowField(elemField, visitor) + + if al, ok := visitor.(afterListElem); ok { + al.AfterListElement(elemField) + } + + return visitor.List(dt, res) +} + +func getFieldID(f arrow.Field) *int { + if !f.HasMetadata() { + return nil + } + + fieldIDStr, ok := f.Metadata.GetValue(ArrowParquetFieldIDKey) + if !ok { + return nil + } + + id, err := strconv.Atoi(fieldIDStr) + if err != nil { + return nil + } + + return &id +} + +type hasIDs struct{} + +func (hasIDs) Schema(sc *arrow.Schema, result bool) bool { + return result +} + +func (hasIDs) Struct(st *arrow.StructType, results []bool) bool { + return !slices.Contains(results, false) +} + +func (hasIDs) Field(f arrow.Field, result bool) bool { + return getFieldID(f) != nil +} + +func (hasIDs) List(dt 
arrow.ListLikeType, elem bool) bool { + elemField := dt.ElemField() + return elem && getFieldID(elemField) != nil +} + +func (hasIDs) Map(m *arrow.MapType, key, val bool) bool { + return key && val && + getFieldID(m.KeyField()) != nil && getFieldID(m.ItemField()) != nil +} + +func (hasIDs) Primitive(arrow.DataType) bool { return true } + +type convertToIceberg struct { + downcastTimestamp bool + + fieldID func(arrow.Field) int +} + +func (convertToIceberg) Schema(_ *arrow.Schema, result iceberg.NestedField) iceberg.NestedField { + return result +} + +func (convertToIceberg) Struct(_ *arrow.StructType, results []iceberg.NestedField) iceberg.NestedField { + return iceberg.NestedField{ + Type: &iceberg.StructType{FieldList: results}, + } +} + +func (c convertToIceberg) Field(field arrow.Field, result iceberg.NestedField) iceberg.NestedField { + result.ID = c.fieldID(field) + if field.HasMetadata() { + if doc, ok := field.Metadata.GetValue(ArrowFieldDocKey); ok { + result.Doc = doc + } + } + + result.Required = !field.Nullable + result.Name = field.Name + return result +} + +func (c convertToIceberg) List(dt arrow.ListLikeType, elemResult iceberg.NestedField) iceberg.NestedField { + elemField := dt.ElemField() + elemID := c.fieldID(elemField) + + return iceberg.NestedField{ + Type: &iceberg.ListType{ + ElementID: elemID, + Element: elemResult.Type, + ElementRequired: !elemField.Nullable, + }, + } +} + +func (c convertToIceberg) Map(m *arrow.MapType, keyResult, valueResult iceberg.NestedField) iceberg.NestedField { + keyField, valField := m.KeyField(), m.ItemField() + keyID, valID := c.fieldID(keyField), c.fieldID(valField) + + return iceberg.NestedField{ + Type: &iceberg.MapType{ + KeyID: keyID, + KeyType: keyResult.Type, + ValueID: valID, + ValueType: valueResult.Type, + ValueRequired: !valField.Nullable, + }, + } +} + +var ( + utcAliases = []string{"UTC", "+00:00", "Etc/UTC", "Z"} +) + +func (c convertToIceberg) Primitive(dt arrow.DataType) (result 
iceberg.NestedField) { + switch dt := dt.(type) { + case *arrow.DictionaryType: + if _, ok := dt.ValueType.(arrow.NestedType); ok { + panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) + } + return c.Primitive(dt.ValueType) + case *arrow.RunEndEncodedType: + if _, ok := dt.Encoded().(arrow.NestedType); ok { + panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) + } + return c.Primitive(dt.Encoded()) + case *arrow.BooleanType: + result.Type = iceberg.PrimitiveTypes.Bool + case *arrow.Uint8Type, *arrow.Uint16Type, *arrow.Uint32Type, + *arrow.Int8Type, *arrow.Int16Type, *arrow.Int32Type: + result.Type = iceberg.PrimitiveTypes.Int32 + case *arrow.Uint64Type, *arrow.Int64Type: + result.Type = iceberg.PrimitiveTypes.Int64 + case *arrow.Float16Type, *arrow.Float32Type: + result.Type = iceberg.PrimitiveTypes.Float32 + case *arrow.Float64Type: + result.Type = iceberg.PrimitiveTypes.Float64 + case *arrow.Decimal32Type, *arrow.Decimal64Type, *arrow.Decimal128Type: + dec := dt.(arrow.DecimalType) + result.Type = iceberg.DecimalTypeOf(int(dec.GetPrecision()), int(dec.GetScale())) + case *arrow.StringType, *arrow.LargeStringType: + result.Type = iceberg.PrimitiveTypes.String + case *arrow.BinaryType, *arrow.LargeBinaryType: + result.Type = iceberg.PrimitiveTypes.Binary + case *arrow.Date32Type: + result.Type = iceberg.PrimitiveTypes.Date + case *arrow.Time64Type: + if dt.Unit == arrow.Microsecond { + result.Type = iceberg.PrimitiveTypes.Time + } else { + panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) + } + case *arrow.TimestampType: + if dt.Unit == arrow.Nanosecond { + if !c.downcastTimestamp { + panic(fmt.Errorf("%w: 'ns' timestamp precision not supported", iceberg.ErrType)) + } + // TODO: log something + } + + if slices.Contains(utcAliases, dt.TimeZone) { + result.Type = iceberg.PrimitiveTypes.TimestampTz + } else if dt.TimeZone == "" { 
+ result.Type = iceberg.PrimitiveTypes.Timestamp + } else { + panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) + } + case *arrow.FixedSizeBinaryType: + result.Type = iceberg.FixedTypeOf(dt.ByteWidth) + case arrow.ExtensionType: + if dt.ExtensionName() == "arrow.uuid" { + result.Type = iceberg.PrimitiveTypes.UUID + } else { + panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) + } + default: + panic(fmt.Errorf("%w: unsupported arrow type for conversion - %s", iceberg.ErrInvalidSchema, dt)) + } + + return +} + +func ArrowTypeToIceberg(dt arrow.DataType, downcastNsTimestamp bool) (iceberg.Type, error) { + sc := arrow.NewSchema([]arrow.Field{{Type: dt, + Metadata: arrow.NewMetadata([]string{ArrowParquetFieldIDKey}, []string{"1"})}}, nil) + + out, err := VisitArrowSchema(sc, convertToIceberg{ + downcastTimestamp: downcastNsTimestamp, + fieldID: func(field arrow.Field) int { + if id := getFieldID(field); id != nil { + return *id + } + + panic(fmt.Errorf("%w: cannot convert %s to Iceberg field, missing field_id", + iceberg.ErrInvalidSchema, field)) + }, + }) + if err != nil { + return nil, err + } + + return out.Type.(*iceberg.StructType).FieldList[0].Type, nil +} + +func ArrowSchemaToIceberg(sc *arrow.Schema, downcastNsTimestamp bool, nameMapping NameMapping) (*iceberg.Schema, error) { + hasIDs, _ := VisitArrowSchema(sc, hasIDs{}) + + switch { + case hasIDs: + out, err := VisitArrowSchema(sc, convertToIceberg{ + downcastTimestamp: downcastNsTimestamp, + fieldID: func(field arrow.Field) int { + if id := getFieldID(field); id != nil { + return *id + } + + panic(fmt.Errorf("%w: cannot convert %s to Iceberg field, missing field_id", + iceberg.ErrInvalidSchema, field)) + }, + }) + if err != nil { + return nil, err + } + + return iceberg.NewSchema(0, out.Type.(*iceberg.StructType).FieldList...), nil + case nameMapping != nil: + withoutIDs, err := VisitArrowSchema(sc, 
convertToIceberg{ + downcastTimestamp: downcastNsTimestamp, + fieldID: func(_ arrow.Field) int { return -1 }, + }) + if err != nil { + return nil, err + } + + schemaWithoutIDs := iceberg.NewSchema(0, withoutIDs.Type.(*iceberg.StructType).FieldList...) + return ApplyNameMapping(schemaWithoutIDs, nameMapping) + default: + return nil, fmt.Errorf("%w: arrow schema does not have field-ids and no name mapping provided", + iceberg.ErrInvalidSchema) + } +} diff --git a/table/arrow_utils_test.go b/table/arrow_utils_test.go index 1d8173e..5ccd236 100644 --- a/table/arrow_utils_test.go +++ b/table/arrow_utils_test.go @@ -1,371 +1,371 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package table_test - -import ( - "testing" - - "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/extensions" - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/table" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func fieldIDMeta(id string) arrow.Metadata { - return arrow.MetadataFrom(map[string]string{table.ArrowParquetFieldIDKey: id}) -} - -func TestArrowToIceberg(t *testing.T) { - tests := []struct { - dt arrow.DataType - ice iceberg.Type - err string - }{ - {&arrow.FixedSizeBinaryType{ByteWidth: 23}, iceberg.FixedTypeOf(23), ""}, - {&arrow.Decimal32Type{Precision: 8, Scale: 9}, iceberg.DecimalTypeOf(8, 9), ""}, - {&arrow.Decimal64Type{Precision: 15, Scale: 14}, iceberg.DecimalTypeOf(15, 14), ""}, - {&arrow.Decimal128Type{Precision: 26, Scale: 20}, iceberg.DecimalTypeOf(26, 20), ""}, - {&arrow.Decimal256Type{Precision: 8, Scale: 9}, nil, "unsupported arrow type for conversion - decimal256(8, 9)"}, - {arrow.FixedWidthTypes.Boolean, iceberg.PrimitiveTypes.Bool, ""}, - {arrow.PrimitiveTypes.Int8, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Uint8, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Int16, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Uint16, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Int32, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Uint32, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Int64, iceberg.PrimitiveTypes.Int64, ""}, - {arrow.PrimitiveTypes.Uint64, iceberg.PrimitiveTypes.Int64, ""}, - {arrow.FixedWidthTypes.Float16, iceberg.PrimitiveTypes.Float32, ""}, - {arrow.PrimitiveTypes.Float32, iceberg.PrimitiveTypes.Float32, ""}, - {arrow.PrimitiveTypes.Float64, iceberg.PrimitiveTypes.Float64, ""}, - {arrow.FixedWidthTypes.Date32, iceberg.PrimitiveTypes.Date, ""}, - {arrow.FixedWidthTypes.Date64, nil, "unsupported arrow type for conversion - date64"}, - {arrow.FixedWidthTypes.Time32s, 
nil, "unsupported arrow type for conversion - time32[s]"}, - {arrow.FixedWidthTypes.Time32ms, nil, "unsupported arrow type for conversion - time32[ms]"}, - {arrow.FixedWidthTypes.Time64us, iceberg.PrimitiveTypes.Time, ""}, - {arrow.FixedWidthTypes.Time64ns, nil, "unsupported arrow type for conversion - time64[ns]"}, - {arrow.FixedWidthTypes.Timestamp_s, iceberg.PrimitiveTypes.TimestampTz, ""}, - {arrow.FixedWidthTypes.Timestamp_ms, iceberg.PrimitiveTypes.TimestampTz, ""}, - {arrow.FixedWidthTypes.Timestamp_us, iceberg.PrimitiveTypes.TimestampTz, ""}, - {arrow.FixedWidthTypes.Timestamp_ns, nil, "'ns' timestamp precision not supported"}, - {&arrow.TimestampType{Unit: arrow.Second}, iceberg.PrimitiveTypes.Timestamp, ""}, - {&arrow.TimestampType{Unit: arrow.Millisecond}, iceberg.PrimitiveTypes.Timestamp, ""}, - {&arrow.TimestampType{Unit: arrow.Microsecond}, iceberg.PrimitiveTypes.Timestamp, ""}, - {&arrow.TimestampType{Unit: arrow.Nanosecond}, nil, "'ns' timestamp precision not supported"}, - {&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: "US/Pacific"}, nil, "unsupported arrow type for conversion - timestamp[us, tz=US/Pacific]"}, - {arrow.BinaryTypes.String, iceberg.PrimitiveTypes.String, ""}, - {arrow.BinaryTypes.LargeString, iceberg.PrimitiveTypes.String, ""}, - {arrow.BinaryTypes.StringView, nil, "unsupported arrow type for conversion - string_view"}, - {arrow.BinaryTypes.Binary, iceberg.PrimitiveTypes.Binary, ""}, - {arrow.BinaryTypes.LargeBinary, iceberg.PrimitiveTypes.Binary, ""}, - {arrow.BinaryTypes.BinaryView, nil, "unsupported arrow type for conversion - binary_view"}, - {extensions.NewUUIDType(), iceberg.PrimitiveTypes.UUID, ""}, - {arrow.StructOf(arrow.Field{ - Name: "foo", - Type: arrow.BinaryTypes.LargeString, - Nullable: true, - Metadata: arrow.MetadataFrom(map[string]string{ - table.ArrowParquetFieldIDKey: "1", table.ArrowFieldDocKey: "foo doc", - }), - }, arrow.Field{ - Name: "bar", - Type: arrow.PrimitiveTypes.Int32, - Metadata: 
fieldIDMeta("2"), - }, arrow.Field{ - Name: "baz", - Type: arrow.FixedWidthTypes.Boolean, - Nullable: true, - Metadata: fieldIDMeta("3"), - }), &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: false, Doc: "foo doc"}, - {ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - {ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool, Required: false}, - }}, ""}, - {arrow.ListOfField(arrow.Field{ - Name: "element", - Type: arrow.PrimitiveTypes.Int32, - Nullable: false, - Metadata: fieldIDMeta("1"), - }), &iceberg.ListType{ - ElementID: 1, - Element: iceberg.PrimitiveTypes.Int32, - ElementRequired: true, - }, ""}, - {arrow.LargeListOfField(arrow.Field{ - Name: "element", - Type: arrow.PrimitiveTypes.Int32, - Nullable: false, - Metadata: fieldIDMeta("1"), - }), &iceberg.ListType{ - ElementID: 1, - Element: iceberg.PrimitiveTypes.Int32, - ElementRequired: true, - }, ""}, - {arrow.FixedSizeListOfField(1, arrow.Field{ - Name: "element", - Type: arrow.PrimitiveTypes.Int32, - Nullable: false, - Metadata: fieldIDMeta("1"), - }), &iceberg.ListType{ - ElementID: 1, - Element: iceberg.PrimitiveTypes.Int32, - ElementRequired: true, - }, ""}, - {arrow.MapOfWithMetadata(arrow.PrimitiveTypes.Int32, - fieldIDMeta("1"), - arrow.BinaryTypes.String, fieldIDMeta("2")), - &iceberg.MapType{ - KeyID: 1, KeyType: iceberg.PrimitiveTypes.Int32, - ValueID: 2, ValueType: iceberg.PrimitiveTypes.String, ValueRequired: false, - }, ""}, - {&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, - ValueType: arrow.BinaryTypes.String}, iceberg.PrimitiveTypes.String, ""}, - {&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, - ValueType: arrow.PrimitiveTypes.Int32}, iceberg.PrimitiveTypes.Int32, ""}, - {&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int64, - ValueType: arrow.PrimitiveTypes.Float64}, iceberg.PrimitiveTypes.Float64, ""}, - {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, 
arrow.BinaryTypes.String), iceberg.PrimitiveTypes.String, ""}, - {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float64), iceberg.PrimitiveTypes.Float64, ""}, - {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int16), iceberg.PrimitiveTypes.Int32, ""}, - } - - for _, tt := range tests { - t.Run(tt.dt.String(), func(t *testing.T) { - out, err := table.ArrowTypeToIceberg(tt.dt, false) - if tt.err == "" { - require.NoError(t, err) - assert.True(t, out.Equals(tt.ice), out.String(), tt.ice.String()) - } else { - assert.ErrorContains(t, err, tt.err) - } - }) - } -} - -func TestArrowSchemaToIceb(t *testing.T) { - tests := []struct { - name string - sc *arrow.Schema - expected string - err string - }{ - {"simple", arrow.NewSchema([]arrow.Field{ - {Name: "foo", Nullable: true, Type: arrow.BinaryTypes.String, - Metadata: fieldIDMeta("1")}, - {Name: "bar", Nullable: false, Type: arrow.PrimitiveTypes.Int32, - Metadata: fieldIDMeta("2")}, - {Name: "baz", Nullable: true, Type: arrow.FixedWidthTypes.Boolean, - Metadata: fieldIDMeta("3")}, - }, nil), `table { - 1: foo: optional string - 2: bar: required int - 3: baz: optional boolean -}`, ""}, - {"nested", arrow.NewSchema([]arrow.Field{ - {Name: "qux", Nullable: false, Metadata: fieldIDMeta("4"), - Type: arrow.ListOfField(arrow.Field{ - Name: "element", - Type: arrow.BinaryTypes.String, - Metadata: fieldIDMeta("5"), - })}, - {Name: "quux", Nullable: false, Metadata: fieldIDMeta("6"), - Type: arrow.MapOfWithMetadata(arrow.BinaryTypes.String, fieldIDMeta("7"), - arrow.MapOfWithMetadata(arrow.BinaryTypes.String, fieldIDMeta("9"), - arrow.PrimitiveTypes.Int32, fieldIDMeta("10")), fieldIDMeta("8"))}, - {Name: "location", Nullable: false, Metadata: fieldIDMeta("11"), - Type: arrow.ListOfField( - arrow.Field{ - Name: "element", Metadata: fieldIDMeta("12"), - Type: arrow.StructOf( - arrow.Field{Name: "latitude", Nullable: true, - Type: arrow.PrimitiveTypes.Float32, Metadata: 
fieldIDMeta("13")}, - arrow.Field{Name: "longitude", Nullable: true, - Type: arrow.PrimitiveTypes.Float32, Metadata: fieldIDMeta("14")}, - )})}, - {Name: "person", Nullable: true, Metadata: fieldIDMeta("15"), - Type: arrow.StructOf( - arrow.Field{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true, Metadata: fieldIDMeta("16")}, - arrow.Field{Name: "age", Type: arrow.PrimitiveTypes.Int32, Metadata: fieldIDMeta("17")}, - )}, - }, nil), `table { - 4: qux: required list - 6: quux: required map> - 11: location: required list> - 15: person: optional struct<16: name: optional string, 17: age: required int> -}`, ""}, - {"missing ids", arrow.NewSchema([]arrow.Field{ - {Name: "foo", Type: arrow.BinaryTypes.String, Nullable: false}, - }, nil), "", "arrow schema does not have field-ids and no name mapping provided"}, - {"missing ids partial", arrow.NewSchema([]arrow.Field{ - {Name: "foo", Type: arrow.BinaryTypes.String, Metadata: fieldIDMeta("1")}, - {Name: "bar", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, - }, nil), "", "arrow schema does not have field-ids and no name mapping provided"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - out, err := table.ArrowSchemaToIceberg(tt.sc, true, nil) - if tt.err == "" { - require.NoError(t, err) - assert.Equal(t, tt.expected, out.String()) - } else { - assert.ErrorContains(t, err, tt.err) - } - }) - } -} - -func makeID(v int) *int { return &v } - -var ( - icebergSchemaNested = iceberg.NewSchema(0, - iceberg.NestedField{ - ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: true}, - iceberg.NestedField{ - ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - iceberg.NestedField{ - ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool, Required: false}, - iceberg.NestedField{ - ID: 4, Name: "qux", Required: true, Type: &iceberg.ListType{ - ElementID: 5, Element: iceberg.PrimitiveTypes.String, ElementRequired: false}}, - iceberg.NestedField{ - ID: 6, Name: 
"quux", - Type: &iceberg.MapType{ - KeyID: 7, - KeyType: iceberg.PrimitiveTypes.String, - ValueID: 8, - ValueType: &iceberg.MapType{ - KeyID: 9, - KeyType: iceberg.PrimitiveTypes.String, - ValueID: 10, - ValueType: iceberg.PrimitiveTypes.Int32, - ValueRequired: false, - }, - ValueRequired: false, - }, - Required: true}, - iceberg.NestedField{ - ID: 11, Name: "location", Type: &iceberg.ListType{ - ElementID: 12, Element: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 13, Name: "latitude", Type: iceberg.PrimitiveTypes.Float32, Required: true}, - {ID: 14, Name: "longitude", Type: iceberg.PrimitiveTypes.Float32, Required: true}, - }, - }, - ElementRequired: false}, - Required: true}, - iceberg.NestedField{ - ID: 15, - Name: "person", - Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 16, Name: "name", Type: iceberg.PrimitiveTypes.String, Required: false}, - {ID: 17, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - }, - }, - Required: false, - }, - ) - - icebergSchemaSimple = iceberg.NewSchema(0, - iceberg.NestedField{ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String}, - iceberg.NestedField{ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - iceberg.NestedField{ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool}, - ) -) - -func TestArrowSchemaWithNameMapping(t *testing.T) { - schemaWithoutIDs := arrow.NewSchema([]arrow.Field{ - {Name: "foo", Type: arrow.BinaryTypes.String, Nullable: true}, - {Name: "bar", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, - {Name: "baz", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, - }, nil) - - schemaNestedWithoutIDs := arrow.NewSchema([]arrow.Field{ - {Name: "foo", Type: arrow.BinaryTypes.String, Nullable: false}, - {Name: "bar", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, - {Name: "baz", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, - {Name: "qux", Type: arrow.ListOf(arrow.BinaryTypes.String), Nullable: false}, - {Name: 
"quux", Type: arrow.MapOf(arrow.BinaryTypes.String, - arrow.MapOf(arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int32)), Nullable: false}, - {Name: "location", Type: arrow.ListOf(arrow.StructOf( - arrow.Field{Name: "latitude", Type: arrow.PrimitiveTypes.Float32, Nullable: false}, - arrow.Field{Name: "longitude", Type: arrow.PrimitiveTypes.Float32, Nullable: false}, - )), Nullable: false}, - {Name: "person", Type: arrow.StructOf( - arrow.Field{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true}, - arrow.Field{Name: "age", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, - ), Nullable: true}, - }, nil) - - tests := []struct { - name string - schema *arrow.Schema - mapping table.NameMapping - expected *iceberg.Schema - err string - }{ - {"simple", schemaWithoutIDs, table.NameMapping{ - {FieldID: makeID(1), Names: []string{"foo"}}, - {FieldID: makeID(2), Names: []string{"bar"}}, - {FieldID: makeID(3), Names: []string{"baz"}}, - }, icebergSchemaSimple, ""}, - {"field missing", schemaWithoutIDs, table.NameMapping{ - {FieldID: makeID(1), Names: []string{"foo"}}, - }, nil, "field missing from name mapping: bar"}, - {"nested schema", schemaNestedWithoutIDs, table.NameMapping{ - {FieldID: makeID(1), Names: []string{"foo"}}, - {FieldID: makeID(2), Names: []string{"bar"}}, - {FieldID: makeID(3), Names: []string{"baz"}}, - {FieldID: makeID(4), Names: []string{"qux"}, - Fields: []table.MappedField{{FieldID: makeID(5), Names: []string{"element"}}}}, - {FieldID: makeID(6), Names: []string{"quux"}, Fields: []table.MappedField{ - {FieldID: makeID(7), Names: []string{"key"}}, - {FieldID: makeID(8), Names: []string{"value"}, Fields: []table.MappedField{ - {FieldID: makeID(9), Names: []string{"key"}}, - {FieldID: makeID(10), Names: []string{"value"}}, - }}, - }}, - {FieldID: makeID(11), Names: []string{"location"}, Fields: []table.MappedField{ - {FieldID: makeID(12), Names: []string{"element"}, Fields: []table.MappedField{ - {FieldID: makeID(13), Names: 
[]string{"latitude"}}, - {FieldID: makeID(14), Names: []string{"longitude"}}, - }}, - }}, - {FieldID: makeID(15), Names: []string{"person"}, Fields: []table.MappedField{ - {FieldID: makeID(16), Names: []string{"name"}}, - {FieldID: makeID(17), Names: []string{"age"}}, - }}, - }, icebergSchemaNested, ""}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - out, err := table.ArrowSchemaToIceberg(tt.schema, false, tt.mapping) - if tt.err != "" { - assert.ErrorContains(t, err, tt.err) - } else { - require.NoError(t, err) - assert.True(t, tt.expected.Equals(out), out.String(), tt.expected.String()) - } - }) - } -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package table_test + +import ( + "testing" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/extensions" + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/table" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func fieldIDMeta(id string) arrow.Metadata { + return arrow.MetadataFrom(map[string]string{table.ArrowParquetFieldIDKey: id}) +} + +func TestArrowToIceberg(t *testing.T) { + tests := []struct { + dt arrow.DataType + ice iceberg.Type + err string + }{ + {&arrow.FixedSizeBinaryType{ByteWidth: 23}, iceberg.FixedTypeOf(23), ""}, + {&arrow.Decimal32Type{Precision: 8, Scale: 9}, iceberg.DecimalTypeOf(8, 9), ""}, + {&arrow.Decimal64Type{Precision: 15, Scale: 14}, iceberg.DecimalTypeOf(15, 14), ""}, + {&arrow.Decimal128Type{Precision: 26, Scale: 20}, iceberg.DecimalTypeOf(26, 20), ""}, + {&arrow.Decimal256Type{Precision: 8, Scale: 9}, nil, "unsupported arrow type for conversion - decimal256(8, 9)"}, + {arrow.FixedWidthTypes.Boolean, iceberg.PrimitiveTypes.Bool, ""}, + {arrow.PrimitiveTypes.Int8, iceberg.PrimitiveTypes.Int32, ""}, + {arrow.PrimitiveTypes.Uint8, iceberg.PrimitiveTypes.Int32, ""}, + {arrow.PrimitiveTypes.Int16, iceberg.PrimitiveTypes.Int32, ""}, + {arrow.PrimitiveTypes.Uint16, iceberg.PrimitiveTypes.Int32, ""}, + {arrow.PrimitiveTypes.Int32, iceberg.PrimitiveTypes.Int32, ""}, + {arrow.PrimitiveTypes.Uint32, iceberg.PrimitiveTypes.Int32, ""}, + {arrow.PrimitiveTypes.Int64, iceberg.PrimitiveTypes.Int64, ""}, + {arrow.PrimitiveTypes.Uint64, iceberg.PrimitiveTypes.Int64, ""}, + {arrow.FixedWidthTypes.Float16, iceberg.PrimitiveTypes.Float32, ""}, + {arrow.PrimitiveTypes.Float32, iceberg.PrimitiveTypes.Float32, ""}, + {arrow.PrimitiveTypes.Float64, iceberg.PrimitiveTypes.Float64, ""}, + {arrow.FixedWidthTypes.Date32, iceberg.PrimitiveTypes.Date, ""}, + {arrow.FixedWidthTypes.Date64, nil, "unsupported arrow type for conversion - date64"}, + {arrow.FixedWidthTypes.Time32s, 
nil, "unsupported arrow type for conversion - time32[s]"}, + {arrow.FixedWidthTypes.Time32ms, nil, "unsupported arrow type for conversion - time32[ms]"}, + {arrow.FixedWidthTypes.Time64us, iceberg.PrimitiveTypes.Time, ""}, + {arrow.FixedWidthTypes.Time64ns, nil, "unsupported arrow type for conversion - time64[ns]"}, + {arrow.FixedWidthTypes.Timestamp_s, iceberg.PrimitiveTypes.TimestampTz, ""}, + {arrow.FixedWidthTypes.Timestamp_ms, iceberg.PrimitiveTypes.TimestampTz, ""}, + {arrow.FixedWidthTypes.Timestamp_us, iceberg.PrimitiveTypes.TimestampTz, ""}, + {arrow.FixedWidthTypes.Timestamp_ns, nil, "'ns' timestamp precision not supported"}, + {&arrow.TimestampType{Unit: arrow.Second}, iceberg.PrimitiveTypes.Timestamp, ""}, + {&arrow.TimestampType{Unit: arrow.Millisecond}, iceberg.PrimitiveTypes.Timestamp, ""}, + {&arrow.TimestampType{Unit: arrow.Microsecond}, iceberg.PrimitiveTypes.Timestamp, ""}, + {&arrow.TimestampType{Unit: arrow.Nanosecond}, nil, "'ns' timestamp precision not supported"}, + {&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: "US/Pacific"}, nil, "unsupported arrow type for conversion - timestamp[us, tz=US/Pacific]"}, + {arrow.BinaryTypes.String, iceberg.PrimitiveTypes.String, ""}, + {arrow.BinaryTypes.LargeString, iceberg.PrimitiveTypes.String, ""}, + {arrow.BinaryTypes.StringView, nil, "unsupported arrow type for conversion - string_view"}, + {arrow.BinaryTypes.Binary, iceberg.PrimitiveTypes.Binary, ""}, + {arrow.BinaryTypes.LargeBinary, iceberg.PrimitiveTypes.Binary, ""}, + {arrow.BinaryTypes.BinaryView, nil, "unsupported arrow type for conversion - binary_view"}, + {extensions.NewUUIDType(), iceberg.PrimitiveTypes.UUID, ""}, + {arrow.StructOf(arrow.Field{ + Name: "foo", + Type: arrow.BinaryTypes.LargeString, + Nullable: true, + Metadata: arrow.MetadataFrom(map[string]string{ + table.ArrowParquetFieldIDKey: "1", table.ArrowFieldDocKey: "foo doc", + }), + }, arrow.Field{ + Name: "bar", + Type: arrow.PrimitiveTypes.Int32, + Metadata: 
fieldIDMeta("2"), + }, arrow.Field{ + Name: "baz", + Type: arrow.FixedWidthTypes.Boolean, + Nullable: true, + Metadata: fieldIDMeta("3"), + }), &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: false, Doc: "foo doc"}, + {ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + {ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool, Required: false}, + }}, ""}, + {arrow.ListOfField(arrow.Field{ + Name: "element", + Type: arrow.PrimitiveTypes.Int32, + Nullable: false, + Metadata: fieldIDMeta("1"), + }), &iceberg.ListType{ + ElementID: 1, + Element: iceberg.PrimitiveTypes.Int32, + ElementRequired: true, + }, ""}, + {arrow.LargeListOfField(arrow.Field{ + Name: "element", + Type: arrow.PrimitiveTypes.Int32, + Nullable: false, + Metadata: fieldIDMeta("1"), + }), &iceberg.ListType{ + ElementID: 1, + Element: iceberg.PrimitiveTypes.Int32, + ElementRequired: true, + }, ""}, + {arrow.FixedSizeListOfField(1, arrow.Field{ + Name: "element", + Type: arrow.PrimitiveTypes.Int32, + Nullable: false, + Metadata: fieldIDMeta("1"), + }), &iceberg.ListType{ + ElementID: 1, + Element: iceberg.PrimitiveTypes.Int32, + ElementRequired: true, + }, ""}, + {arrow.MapOfWithMetadata(arrow.PrimitiveTypes.Int32, + fieldIDMeta("1"), + arrow.BinaryTypes.String, fieldIDMeta("2")), + &iceberg.MapType{ + KeyID: 1, KeyType: iceberg.PrimitiveTypes.Int32, + ValueID: 2, ValueType: iceberg.PrimitiveTypes.String, ValueRequired: false, + }, ""}, + {&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, + ValueType: arrow.BinaryTypes.String}, iceberg.PrimitiveTypes.String, ""}, + {&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, + ValueType: arrow.PrimitiveTypes.Int32}, iceberg.PrimitiveTypes.Int32, ""}, + {&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int64, + ValueType: arrow.PrimitiveTypes.Float64}, iceberg.PrimitiveTypes.Float64, ""}, + {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, 
arrow.BinaryTypes.String), iceberg.PrimitiveTypes.String, ""}, + {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float64), iceberg.PrimitiveTypes.Float64, ""}, + {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int16), iceberg.PrimitiveTypes.Int32, ""}, + } + + for _, tt := range tests { + t.Run(tt.dt.String(), func(t *testing.T) { + out, err := table.ArrowTypeToIceberg(tt.dt, false) + if tt.err == "" { + require.NoError(t, err) + assert.True(t, out.Equals(tt.ice), out.String(), tt.ice.String()) + } else { + assert.ErrorContains(t, err, tt.err) + } + }) + } +} + +func TestArrowSchemaToIceb(t *testing.T) { + tests := []struct { + name string + sc *arrow.Schema + expected string + err string + }{ + {"simple", arrow.NewSchema([]arrow.Field{ + {Name: "foo", Nullable: true, Type: arrow.BinaryTypes.String, + Metadata: fieldIDMeta("1")}, + {Name: "bar", Nullable: false, Type: arrow.PrimitiveTypes.Int32, + Metadata: fieldIDMeta("2")}, + {Name: "baz", Nullable: true, Type: arrow.FixedWidthTypes.Boolean, + Metadata: fieldIDMeta("3")}, + }, nil), `table { + 1: foo: optional string + 2: bar: required int + 3: baz: optional boolean +}`, ""}, + {"nested", arrow.NewSchema([]arrow.Field{ + {Name: "qux", Nullable: false, Metadata: fieldIDMeta("4"), + Type: arrow.ListOfField(arrow.Field{ + Name: "element", + Type: arrow.BinaryTypes.String, + Metadata: fieldIDMeta("5"), + })}, + {Name: "quux", Nullable: false, Metadata: fieldIDMeta("6"), + Type: arrow.MapOfWithMetadata(arrow.BinaryTypes.String, fieldIDMeta("7"), + arrow.MapOfWithMetadata(arrow.BinaryTypes.String, fieldIDMeta("9"), + arrow.PrimitiveTypes.Int32, fieldIDMeta("10")), fieldIDMeta("8"))}, + {Name: "location", Nullable: false, Metadata: fieldIDMeta("11"), + Type: arrow.ListOfField( + arrow.Field{ + Name: "element", Metadata: fieldIDMeta("12"), + Type: arrow.StructOf( + arrow.Field{Name: "latitude", Nullable: true, + Type: arrow.PrimitiveTypes.Float32, Metadata: 
fieldIDMeta("13")}, + arrow.Field{Name: "longitude", Nullable: true, + Type: arrow.PrimitiveTypes.Float32, Metadata: fieldIDMeta("14")}, + )})}, + {Name: "person", Nullable: true, Metadata: fieldIDMeta("15"), + Type: arrow.StructOf( + arrow.Field{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true, Metadata: fieldIDMeta("16")}, + arrow.Field{Name: "age", Type: arrow.PrimitiveTypes.Int32, Metadata: fieldIDMeta("17")}, + )}, + }, nil), `table { + 4: qux: required list + 6: quux: required map> + 11: location: required list> + 15: person: optional struct<16: name: optional string, 17: age: required int> +}`, ""}, + {"missing ids", arrow.NewSchema([]arrow.Field{ + {Name: "foo", Type: arrow.BinaryTypes.String, Nullable: false}, + }, nil), "", "arrow schema does not have field-ids and no name mapping provided"}, + {"missing ids partial", arrow.NewSchema([]arrow.Field{ + {Name: "foo", Type: arrow.BinaryTypes.String, Metadata: fieldIDMeta("1")}, + {Name: "bar", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, + }, nil), "", "arrow schema does not have field-ids and no name mapping provided"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := table.ArrowSchemaToIceberg(tt.sc, true, nil) + if tt.err == "" { + require.NoError(t, err) + assert.Equal(t, tt.expected, out.String()) + } else { + assert.ErrorContains(t, err, tt.err) + } + }) + } +} + +func makeID(v int) *int { return &v } + +var ( + icebergSchemaNested = iceberg.NewSchema(0, + iceberg.NestedField{ + ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: true}, + iceberg.NestedField{ + ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + iceberg.NestedField{ + ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool, Required: false}, + iceberg.NestedField{ + ID: 4, Name: "qux", Required: true, Type: &iceberg.ListType{ + ElementID: 5, Element: iceberg.PrimitiveTypes.String, ElementRequired: false}}, + iceberg.NestedField{ + ID: 6, Name: 
"quux", + Type: &iceberg.MapType{ + KeyID: 7, + KeyType: iceberg.PrimitiveTypes.String, + ValueID: 8, + ValueType: &iceberg.MapType{ + KeyID: 9, + KeyType: iceberg.PrimitiveTypes.String, + ValueID: 10, + ValueType: iceberg.PrimitiveTypes.Int32, + ValueRequired: false, + }, + ValueRequired: false, + }, + Required: true}, + iceberg.NestedField{ + ID: 11, Name: "location", Type: &iceberg.ListType{ + ElementID: 12, Element: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 13, Name: "latitude", Type: iceberg.PrimitiveTypes.Float32, Required: true}, + {ID: 14, Name: "longitude", Type: iceberg.PrimitiveTypes.Float32, Required: true}, + }, + }, + ElementRequired: false}, + Required: true}, + iceberg.NestedField{ + ID: 15, + Name: "person", + Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 16, Name: "name", Type: iceberg.PrimitiveTypes.String, Required: false}, + {ID: 17, Name: "age", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + }, + }, + Required: false, + }, + ) + + icebergSchemaSimple = iceberg.NewSchema(0, + iceberg.NestedField{ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + iceberg.NestedField{ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool}, + ) +) + +func TestArrowSchemaWithNameMapping(t *testing.T) { + schemaWithoutIDs := arrow.NewSchema([]arrow.Field{ + {Name: "foo", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "bar", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, + {Name: "baz", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, + }, nil) + + schemaNestedWithoutIDs := arrow.NewSchema([]arrow.Field{ + {Name: "foo", Type: arrow.BinaryTypes.String, Nullable: false}, + {Name: "bar", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, + {Name: "baz", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, + {Name: "qux", Type: arrow.ListOf(arrow.BinaryTypes.String), Nullable: false}, + {Name: 
"quux", Type: arrow.MapOf(arrow.BinaryTypes.String, + arrow.MapOf(arrow.BinaryTypes.String, arrow.PrimitiveTypes.Int32)), Nullable: false}, + {Name: "location", Type: arrow.ListOf(arrow.StructOf( + arrow.Field{Name: "latitude", Type: arrow.PrimitiveTypes.Float32, Nullable: false}, + arrow.Field{Name: "longitude", Type: arrow.PrimitiveTypes.Float32, Nullable: false}, + )), Nullable: false}, + {Name: "person", Type: arrow.StructOf( + arrow.Field{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true}, + arrow.Field{Name: "age", Type: arrow.PrimitiveTypes.Int32, Nullable: false}, + ), Nullable: true}, + }, nil) + + tests := []struct { + name string + schema *arrow.Schema + mapping table.NameMapping + expected *iceberg.Schema + err string + }{ + {"simple", schemaWithoutIDs, table.NameMapping{ + {FieldID: makeID(1), Names: []string{"foo"}}, + {FieldID: makeID(2), Names: []string{"bar"}}, + {FieldID: makeID(3), Names: []string{"baz"}}, + }, icebergSchemaSimple, ""}, + {"field missing", schemaWithoutIDs, table.NameMapping{ + {FieldID: makeID(1), Names: []string{"foo"}}, + }, nil, "field missing from name mapping: bar"}, + {"nested schema", schemaNestedWithoutIDs, table.NameMapping{ + {FieldID: makeID(1), Names: []string{"foo"}}, + {FieldID: makeID(2), Names: []string{"bar"}}, + {FieldID: makeID(3), Names: []string{"baz"}}, + {FieldID: makeID(4), Names: []string{"qux"}, + Fields: []table.MappedField{{FieldID: makeID(5), Names: []string{"element"}}}}, + {FieldID: makeID(6), Names: []string{"quux"}, Fields: []table.MappedField{ + {FieldID: makeID(7), Names: []string{"key"}}, + {FieldID: makeID(8), Names: []string{"value"}, Fields: []table.MappedField{ + {FieldID: makeID(9), Names: []string{"key"}}, + {FieldID: makeID(10), Names: []string{"value"}}, + }}, + }}, + {FieldID: makeID(11), Names: []string{"location"}, Fields: []table.MappedField{ + {FieldID: makeID(12), Names: []string{"element"}, Fields: []table.MappedField{ + {FieldID: makeID(13), Names: 
[]string{"latitude"}}, + {FieldID: makeID(14), Names: []string{"longitude"}}, + }}, + }}, + {FieldID: makeID(15), Names: []string{"person"}, Fields: []table.MappedField{ + {FieldID: makeID(16), Names: []string{"name"}}, + {FieldID: makeID(17), Names: []string{"age"}}, + }}, + }, icebergSchemaNested, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := table.ArrowSchemaToIceberg(tt.schema, false, tt.mapping) + if tt.err != "" { + assert.ErrorContains(t, err, tt.err) + } else { + require.NoError(t, err) + assert.True(t, tt.expected.Equals(out), out.String(), tt.expected.String()) + } + }) + } +} diff --git a/table/evaluators.go b/table/evaluators.go index 1458c98..626e2d0 100644 --- a/table/evaluators.go +++ b/table/evaluators.go @@ -1,1125 +1,1125 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package table - -import ( - "fmt" - "math" - "slices" - - "github.com/apache/iceberg-go" - "github.com/google/uuid" -) - -const ( - rowsMightMatch, rowsMustMatch = true, true - rowsCannotMatch, rowsMightNotMatch = false, false - inPredicateLimit = 200 -) - -// newManifestEvaluator returns a function that can be used to evaluate whether a particular -// manifest file has rows that might or might not match a given partition filter by using -// the stats provided in the partitions (UpperBound/LowerBound/ContainsNull/ContainsNaN). -func newManifestEvaluator(spec iceberg.PartitionSpec, schema *iceberg.Schema, partitionFilter iceberg.BooleanExpression, caseSensitive bool) (func(iceberg.ManifestFile) (bool, error), error) { - partType := spec.PartitionType(schema) - partSchema := iceberg.NewSchema(0, partType.FieldList...) - filter, err := iceberg.RewriteNotExpr(partitionFilter) - if err != nil { - return nil, err - } - - boundFilter, err := iceberg.BindExpr(partSchema, filter, caseSensitive) - if err != nil { - return nil, err - } - - return (&manifestEvalVisitor{partitionFilter: boundFilter}).Eval, nil -} - -type manifestEvalVisitor struct { - partitionFields []iceberg.FieldSummary - partitionFilter iceberg.BooleanExpression -} - -func (m *manifestEvalVisitor) Eval(manifest iceberg.ManifestFile) (bool, error) { - if parts := manifest.Partitions(); len(parts) > 0 { - m.partitionFields = parts - return iceberg.VisitExpr(m.partitionFilter, m) - } - - return rowsMightMatch, nil -} - -func removeBoundCmp[T iceberg.LiteralType](bound iceberg.Literal, vals []iceberg.Literal, cmpToDelete int) []iceberg.Literal { - val := bound.(iceberg.TypedLiteral[T]) - cmp := val.Comparator() - - return slices.DeleteFunc(vals, func(v iceberg.Literal) bool { - return cmp(val.Value(), v.(iceberg.TypedLiteral[T]).Value()) == cmpToDelete - }) -} - -func removeBoundCheck(bound iceberg.Literal, vals []iceberg.Literal, toDelete int) []iceberg.Literal { - switch bound.Type().(type) { - case 
iceberg.BooleanType: - return removeBoundCmp[bool](bound, vals, toDelete) - case iceberg.Int32Type: - return removeBoundCmp[int32](bound, vals, toDelete) - case iceberg.Int64Type: - return removeBoundCmp[int64](bound, vals, toDelete) - case iceberg.Float32Type: - return removeBoundCmp[float32](bound, vals, toDelete) - case iceberg.Float64Type: - return removeBoundCmp[float64](bound, vals, toDelete) - case iceberg.DateType: - return removeBoundCmp[iceberg.Date](bound, vals, toDelete) - case iceberg.TimeType: - return removeBoundCmp[iceberg.Time](bound, vals, toDelete) - case iceberg.TimestampType, iceberg.TimestampTzType: - return removeBoundCmp[iceberg.Timestamp](bound, vals, toDelete) - case iceberg.BinaryType, iceberg.FixedType: - return removeBoundCmp[[]byte](bound, vals, toDelete) - case iceberg.StringType: - return removeBoundCmp[string](bound, vals, toDelete) - case iceberg.UUIDType: - return removeBoundCmp[uuid.UUID](bound, vals, toDelete) - case iceberg.DecimalType: - return removeBoundCmp[iceberg.Decimal](bound, vals, toDelete) - } - panic("unrecognized type") -} - -func allBoundCmp[T iceberg.LiteralType](bound iceberg.Literal, set iceberg.Set[iceberg.Literal], want int) bool { - val := bound.(iceberg.TypedLiteral[T]) - cmp := val.Comparator() - - return set.All(func(e iceberg.Literal) bool { - return cmp(val.Value(), e.(iceberg.TypedLiteral[T]).Value()) == want - }) -} - -func allBoundCheck(bound iceberg.Literal, set iceberg.Set[iceberg.Literal], want int) bool { - switch bound.Type().(type) { - case iceberg.BooleanType: - return allBoundCmp[bool](bound, set, want) - case iceberg.Int32Type: - return allBoundCmp[int32](bound, set, want) - case iceberg.Int64Type: - return allBoundCmp[int64](bound, set, want) - case iceberg.Float32Type: - return allBoundCmp[float32](bound, set, want) - case iceberg.Float64Type: - return allBoundCmp[float64](bound, set, want) - case iceberg.DateType: - return allBoundCmp[iceberg.Date](bound, set, want) - case 
iceberg.TimeType: - return allBoundCmp[iceberg.Time](bound, set, want) - case iceberg.TimestampType, iceberg.TimestampTzType: - return allBoundCmp[iceberg.Timestamp](bound, set, want) - case iceberg.BinaryType, iceberg.FixedType: - return allBoundCmp[[]byte](bound, set, want) - case iceberg.StringType: - return allBoundCmp[string](bound, set, want) - case iceberg.UUIDType: - return allBoundCmp[uuid.UUID](bound, set, want) - case iceberg.DecimalType: - return allBoundCmp[iceberg.Decimal](bound, set, want) - } - panic(iceberg.ErrType) -} - -func (m *manifestEvalVisitor) VisitIn(term iceberg.BoundTerm, literals iceberg.Set[iceberg.Literal]) bool { - pos := term.Ref().Pos() - field := m.partitionFields[pos] - - if field.LowerBound == nil { - return rowsCannotMatch - } - - if literals.Len() > inPredicateLimit { - return rowsMightMatch - } - - lower, err := iceberg.LiteralFromBytes(term.Type(), *field.LowerBound) - if err != nil { - panic(err) - } - - if allBoundCheck(lower, literals, 1) { - return rowsCannotMatch - } - - if field.UpperBound != nil { - upper, err := iceberg.LiteralFromBytes(term.Type(), *field.UpperBound) - if err != nil { - panic(err) - } - - if allBoundCheck(upper, literals, -1) { - return rowsCannotMatch - } - } - - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitNotIn(term iceberg.BoundTerm, literals iceberg.Set[iceberg.Literal]) bool { - // because the bounds are not necessarily a min or max value, this cannot be answered using them - // notIn(col, {X, ...}) with (X, Y) doesn't guarantee that X is a value in col - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitIsNan(term iceberg.BoundTerm) bool { - pos := term.Ref().Pos() - field := m.partitionFields[pos] - - if field.ContainsNaN != nil && !*field.ContainsNaN { - return rowsCannotMatch - } - - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitNotNan(term iceberg.BoundTerm) bool { - pos := term.Ref().Pos() - field := m.partitionFields[pos] - - if 
field.ContainsNaN != nil && *field.ContainsNaN && !field.ContainsNull && field.LowerBound == nil { - return rowsCannotMatch - } - - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitIsNull(term iceberg.BoundTerm) bool { - pos := term.Ref().Pos() - field := m.partitionFields[pos] - - if !field.ContainsNull { - return rowsCannotMatch - } - - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitNotNull(term iceberg.BoundTerm) bool { - pos := term.Ref().Pos() - field := m.partitionFields[pos] - - // ContainsNull encodes whether at least one partition value is null - // lowerBound is null if all partition values are null - allNull := field.ContainsNull && field.LowerBound == nil - if allNull && (term.Ref().Type().Equals(iceberg.PrimitiveTypes.Float32) || term.Ref().Type().Equals(iceberg.PrimitiveTypes.Float64)) { - // floating point types may include NaN values, which we check separately - // in case bounds don't include NaN values, ContainsNaN needsz to be checked - allNull = field.ContainsNaN != nil && !*field.ContainsNaN - } - - if allNull { - return rowsCannotMatch - } - - return rowsMightMatch -} - -func getCmp[T iceberg.LiteralType](b iceberg.TypedLiteral[T]) func(iceberg.Literal, iceberg.Literal) int { - cmp := b.Comparator() - return func(l1, l2 iceberg.Literal) int { - return cmp(l1.(iceberg.TypedLiteral[T]).Value(), l2.(iceberg.TypedLiteral[T]).Value()) - } -} - -func getCmpLiteral(boundary iceberg.Literal) func(iceberg.Literal, iceberg.Literal) int { - switch l := boundary.(type) { - case iceberg.TypedLiteral[bool]: - return getCmp(l) - case iceberg.TypedLiteral[int32]: - return getCmp(l) - case iceberg.TypedLiteral[int64]: - return getCmp(l) - case iceberg.TypedLiteral[float32]: - return getCmp(l) - case iceberg.TypedLiteral[float64]: - return getCmp(l) - case iceberg.TypedLiteral[iceberg.Date]: - return getCmp(l) - case iceberg.TypedLiteral[iceberg.Time]: - return getCmp(l) - case iceberg.TypedLiteral[iceberg.Timestamp]: - return 
getCmp(l) - case iceberg.TypedLiteral[[]byte]: - return getCmp(l) - case iceberg.TypedLiteral[string]: - return getCmp(l) - case iceberg.TypedLiteral[uuid.UUID]: - return getCmp(l) - case iceberg.TypedLiteral[iceberg.Decimal]: - return getCmp(l) - } - panic(iceberg.ErrType) -} - -func (m *manifestEvalVisitor) VisitEqual(term iceberg.BoundTerm, lit iceberg.Literal) bool { - pos := term.Ref().Pos() - field := m.partitionFields[pos] - - if field.LowerBound == nil || field.UpperBound == nil { - // values are all null and literal cannot contain null - return rowsCannotMatch - } - - lower, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.LowerBound) - if err != nil { - panic(err) - } - - cmp := getCmpLiteral(lower) - if cmp(lower, lit) == 1 { - return rowsCannotMatch - } - - upper, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.UpperBound) - if err != nil { - panic(err) - } - - if cmp(lit, upper) == 1 { - return rowsCannotMatch - } - - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitNotEqual(term iceberg.BoundTerm, lit iceberg.Literal) bool { - // because bounds are not necessarily a min or max, this cannot be answered - // using them. 
notEq(col, X) with (X, Y) doesn't guarantee X is a value in col - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitGreaterEqual(term iceberg.BoundTerm, lit iceberg.Literal) bool { - pos := term.Ref().Pos() - field := m.partitionFields[pos] - - if field.UpperBound == nil { - return rowsCannotMatch - } - - upper, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.UpperBound) - if err != nil { - panic(err) - } - - if getCmpLiteral(upper)(lit, upper) == 1 { - return rowsCannotMatch - } - - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitGreater(term iceberg.BoundTerm, lit iceberg.Literal) bool { - pos := term.Ref().Pos() - field := m.partitionFields[pos] - - if field.UpperBound == nil { - return rowsCannotMatch - } - - upper, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.UpperBound) - if err != nil { - panic(err) - } - - if getCmpLiteral(upper)(lit, upper) >= 0 { - return rowsCannotMatch - } - - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitLessEqual(term iceberg.BoundTerm, lit iceberg.Literal) bool { - pos := term.Ref().Pos() - field := m.partitionFields[pos] - - if field.LowerBound == nil { - return rowsCannotMatch - } - - lower, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.LowerBound) - if err != nil { - panic(err) - } - - if getCmpLiteral(lower)(lit, lower) == -1 { - return rowsCannotMatch - } - - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitLess(term iceberg.BoundTerm, lit iceberg.Literal) bool { - pos := term.Ref().Pos() - field := m.partitionFields[pos] - - if field.LowerBound == nil { - return rowsCannotMatch - } - - lower, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.LowerBound) - if err != nil { - panic(err) - } - - if getCmpLiteral(lower)(lit, lower) <= 0 { - return rowsCannotMatch - } - - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitStartsWith(term iceberg.BoundTerm, lit iceberg.Literal) bool { - pos := term.Ref().Pos() - field := 
m.partitionFields[pos] - - var prefix string - if val, ok := lit.(iceberg.TypedLiteral[string]); ok { - prefix = val.Value() - } else { - prefix = string(lit.(iceberg.TypedLiteral[[]byte]).Value()) - } - - lenPrefix := len(prefix) - - if field.LowerBound == nil { - return rowsCannotMatch - } - - lower, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.LowerBound) - if err != nil { - panic(err) - } - - // truncate lower bound so that it's length is not greater than the length of prefix - var v string - switch l := lower.(type) { - case iceberg.TypedLiteral[string]: - v = l.Value() - if len(v) > lenPrefix { - v = v[:lenPrefix] - } - case iceberg.TypedLiteral[[]byte]: - v = string(l.Value()) - if len(v) > lenPrefix { - v = v[:lenPrefix] - } - } - - if v > prefix { - return rowsCannotMatch - } - - if field.UpperBound == nil { - return rowsCannotMatch - } - - upper, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.UpperBound) - if err != nil { - panic(err) - } - - switch u := upper.(type) { - case iceberg.TypedLiteral[string]: - v = u.Value() - if len(v) > lenPrefix { - v = v[:lenPrefix] - } - case iceberg.TypedLiteral[[]byte]: - v = string(u.Value()) - if len(v) > lenPrefix { - v = v[:lenPrefix] - } - } - - if v < prefix { - return rowsCannotMatch - } - - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitNotStartsWith(term iceberg.BoundTerm, lit iceberg.Literal) bool { - pos := term.Ref().Pos() - field := m.partitionFields[pos] - - if field.ContainsNull || field.LowerBound == nil || field.UpperBound == nil { - return rowsMightMatch - } - - // NotStartsWith will match unless ALL values must start with the prefix. 
- // this happens when the lower and upper bounds BOTH start with the prefix - lower, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.LowerBound) - if err != nil { - panic(err) - } - - upper, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.UpperBound) - if err != nil { - panic(err) - } - - var ( - prefix, lowerBound, upperBound string - ) - if val, ok := lit.(iceberg.TypedLiteral[string]); ok { - prefix = val.Value() - lowerBound, upperBound = lower.(iceberg.TypedLiteral[string]).Value(), upper.(iceberg.TypedLiteral[string]).Value() - } else { - prefix = string(lit.(iceberg.TypedLiteral[[]byte]).Value()) - lowerBound = string(lower.(iceberg.TypedLiteral[[]byte]).Value()) - upperBound = string(upper.(iceberg.TypedLiteral[[]byte]).Value()) - } - - lenPrefix := len(prefix) - if len(lowerBound) < lenPrefix { - return rowsMightMatch - } - - if lowerBound[:lenPrefix] == prefix { - // if upper is shorter then upper can't start with the prefix - if len(upperBound) < lenPrefix { - return rowsMightMatch - } - - if upperBound[:lenPrefix] == prefix { - return rowsCannotMatch - } - } - - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitTrue() bool { - return rowsMightMatch -} - -func (m *manifestEvalVisitor) VisitFalse() bool { - return rowsCannotMatch -} - -func (m *manifestEvalVisitor) VisitUnbound(iceberg.UnboundPredicate) bool { - panic("need bound predicate") -} - -func (m *manifestEvalVisitor) VisitBound(pred iceberg.BoundPredicate) bool { - return iceberg.VisitBoundPredicate(pred, m) -} - -func (m *manifestEvalVisitor) VisitNot(child bool) bool { return !child } -func (m *manifestEvalVisitor) VisitAnd(left, right bool) bool { return left && right } -func (m *manifestEvalVisitor) VisitOr(left, right bool) bool { return left || right } - -type projectionEvaluator struct { - spec iceberg.PartitionSpec - schema *iceberg.Schema - caseSensitive bool -} - -func (*projectionEvaluator) VisitTrue() iceberg.BooleanExpression { return 
iceberg.AlwaysTrue{} } -func (*projectionEvaluator) VisitFalse() iceberg.BooleanExpression { return iceberg.AlwaysFalse{} } -func (*projectionEvaluator) VisitNot(child iceberg.BooleanExpression) iceberg.BooleanExpression { - panic(fmt.Errorf("%w: cannot project 'not' expression, should be rewritten %s", - iceberg.ErrInvalidArgument, child)) -} - -func (*projectionEvaluator) VisitAnd(left, right iceberg.BooleanExpression) iceberg.BooleanExpression { - return iceberg.NewAnd(left, right) -} - -func (*projectionEvaluator) VisitOr(left, right iceberg.BooleanExpression) iceberg.BooleanExpression { - return iceberg.NewOr(left, right) -} - -func (*projectionEvaluator) VisitUnbound(pred iceberg.UnboundPredicate) iceberg.BooleanExpression { - panic(fmt.Errorf("%w: cannot project unbound predicate: %s", iceberg.ErrInvalidArgument, pred)) -} - -type inclusiveProjection struct{ projectionEvaluator } - -func (p *inclusiveProjection) Project(expr iceberg.BooleanExpression) (iceberg.BooleanExpression, error) { - expr, err := iceberg.RewriteNotExpr(expr) - if err != nil { - return nil, err - } - - bound, err := iceberg.BindExpr(p.schema, expr, p.caseSensitive) - if err != nil { - return nil, err - } - - return iceberg.VisitExpr(bound, p) -} - -func (p *inclusiveProjection) VisitBound(pred iceberg.BoundPredicate) iceberg.BooleanExpression { - parts := p.spec.FieldsBySourceID(pred.Term().Ref().Field().ID) - - var result iceberg.BooleanExpression = iceberg.AlwaysTrue{} - for _, part := range parts { - // consider (d = 2019-01-01) with bucket(7, d) and bucket(5, d) - // projections: b1 = bucket(7, '2019-01-01') = 5, b2 = bucket(5, '2019-01-01') = 0 - // any value where b1 != 5 or any value where b2 != 0 cannot be the '2019-01-01' - // - // similarly, if partitioning by day(ts) and hour(ts), the more restrictive - // projection should be used. ts = 2019-01-01T01:00:00 produces day=2019-01-01 and - // hour=2019-01-01-01. the value will be in 2019-01-01-01 and not in 2019-01-01-02. 
- inclProjection, err := part.Transform.Project(part.Name, pred) - if err != nil { - panic(err) - } - if inclProjection != nil { - result = iceberg.NewAnd(result, inclProjection) - } - } - - return result -} - -func newInclusiveProjection(s *iceberg.Schema, spec iceberg.PartitionSpec, caseSensitive bool) func(iceberg.BooleanExpression) (iceberg.BooleanExpression, error) { - return (&inclusiveProjection{ - projectionEvaluator: projectionEvaluator{ - schema: s, - spec: spec, - caseSensitive: caseSensitive, - }, - }).Project -} - -type metricsEvaluator struct { - valueCounts map[int]int64 - nullCounts map[int]int64 - nanCounts map[int]int64 - lowerBounds map[int][]byte - upperBounds map[int][]byte -} - -func (m *metricsEvaluator) VisitTrue() bool { return rowsMightMatch } -func (m *metricsEvaluator) VisitFalse() bool { return rowsCannotMatch } -func (m *metricsEvaluator) VisitNot(child bool) bool { - panic(fmt.Errorf("%w: NOT should be rewritten %v", iceberg.ErrInvalidArgument, child)) -} -func (m *metricsEvaluator) VisitAnd(left, right bool) bool { return left && right } -func (m *metricsEvaluator) VisitOr(left, right bool) bool { return left || right } - -func (m *metricsEvaluator) containsNullsOnly(id int) bool { - valCount, ok := m.valueCounts[id] - if !ok { - return false - } - - nullCount, ok := m.nullCounts[id] - if !ok { - return false - } - - return valCount == nullCount -} - -func (m *metricsEvaluator) containsNansOnly(id int) bool { - nanCount, ok := m.nanCounts[id] - if !ok { - return false - } - - valCount, ok := m.valueCounts[id] - if !ok { - return false - } - - return nanCount == valCount -} - -func (m *metricsEvaluator) isNan(v iceberg.Literal) bool { - switch v := v.(type) { - case iceberg.Float32Literal: - return math.IsNaN(float64(v)) - case iceberg.Float64Literal: - return math.IsNaN(float64(v)) - default: - return false - } -} - -func newInclusiveMetricsEvaluator(s *iceberg.Schema, expr iceberg.BooleanExpression, - caseSensitive bool, 
includeEmptyFiles bool) (func(iceberg.DataFile) (bool, error), error) { - - rewritten, err := iceberg.RewriteNotExpr(expr) - if err != nil { - return nil, err - } - - bound, err := iceberg.BindExpr(s, rewritten, caseSensitive) - if err != nil { - return nil, err - } - - return (&inclusiveMetricsEval{ - st: s.AsStruct(), - includeEmptyFiles: includeEmptyFiles, - expr: bound, - }).Eval, nil -} - -type inclusiveMetricsEval struct { - metricsEvaluator - - st iceberg.StructType - expr iceberg.BooleanExpression - includeEmptyFiles bool -} - -func (m *inclusiveMetricsEval) Eval(file iceberg.DataFile) (bool, error) { - if !m.includeEmptyFiles && file.Count() == 0 { - return rowsCannotMatch, nil - } - - m.valueCounts, m.nullCounts = file.ValueCounts(), file.NullValueCounts() - m.nanCounts = file.NaNValueCounts() - m.lowerBounds, m.upperBounds = file.LowerBoundValues(), file.UpperBoundValues() - - return iceberg.VisitExpr(m.expr, m) -} - -func (m *inclusiveMetricsEval) mayContainNull(fieldID int) bool { - if m.nullCounts == nil { - return true - } - - _, ok := m.nullCounts[fieldID] - return ok -} - -func (m *inclusiveMetricsEval) VisitUnbound(iceberg.UnboundPredicate) bool { - panic("need bound predicate") -} - -func (m *inclusiveMetricsEval) VisitBound(pred iceberg.BoundPredicate) bool { - return iceberg.VisitBoundPredicate(pred, m) -} - -func (m *inclusiveMetricsEval) VisitIsNull(t iceberg.BoundTerm) bool { - fieldID := t.Ref().Field().ID - if cnt, exists := m.nullCounts[fieldID]; exists && cnt == 0 { - return rowsCannotMatch - } - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitNotNull(t iceberg.BoundTerm) bool { - // no need to check whether the field is required because binding evaluates - // that case if the column has no non-null values, the expression cannot match - fieldID := t.Ref().Field().ID - if m.containsNullsOnly(fieldID) { - return rowsCannotMatch - } - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitIsNan(t iceberg.BoundTerm) 
bool { - fieldID := t.Ref().Field().ID - if cnt, exists := m.nanCounts[fieldID]; exists && cnt == 0 { - return rowsCannotMatch - } - // when there's no nancounts information but we already know the column - // contains null it's guaranteed that there's no nan value - if m.containsNullsOnly(fieldID) { - return rowsCannotMatch - } - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitNotNan(t iceberg.BoundTerm) bool { - fieldID := t.Ref().Field().ID - - if m.containsNansOnly(fieldID) { - return rowsCannotMatch - } - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitLess(t iceberg.BoundTerm, lit iceberg.Literal) bool { - field := t.Ref().Field() - fieldID := field.ID - - if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { - return rowsCannotMatch - } - - if _, ok := field.Type.(iceberg.PrimitiveType); !ok { - panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", - iceberg.ErrInvalidTypeString, field.Type)) - } - - if lowerBoundBytes := m.lowerBounds[fieldID]; lowerBoundBytes != nil { - lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) - if err != nil { - panic(err) - } - - if m.isNan(lowerBound) { - // nan indicates unreliable bounds - return rowsMightMatch - } - - if getCmpLiteral(lowerBound)(lowerBound, lit) >= 0 { - return rowsCannotMatch - } - } - - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitLessEqual(t iceberg.BoundTerm, lit iceberg.Literal) bool { - field := t.Ref().Field() - fieldID := field.ID - - if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { - return rowsCannotMatch - } - - if _, ok := field.Type.(iceberg.PrimitiveType); !ok { - panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", - iceberg.ErrInvalidTypeString, field.Type)) - } - - if lowerBoundBytes := m.lowerBounds[fieldID]; lowerBoundBytes != nil { - lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) - if err != nil { - panic(err) - } - - if m.isNan(lowerBound) { - 
// nan indicates unreliable bounds - return rowsMightMatch - } - - if getCmpLiteral(lowerBound)(lowerBound, lit) > 0 { - return rowsCannotMatch - } - } - - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitGreater(t iceberg.BoundTerm, lit iceberg.Literal) bool { - field := t.Ref().Field() - fieldID := field.ID - - if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { - return rowsCannotMatch - } - - if _, ok := field.Type.(iceberg.PrimitiveType); !ok { - panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", - iceberg.ErrInvalidTypeString, field.Type)) - } - - if upperBoundBytes := m.upperBounds[fieldID]; upperBoundBytes != nil { - upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) - if err != nil { - panic(err) - } - - if getCmpLiteral(upperBound)(upperBound, lit) <= 0 { - if m.isNan(upperBound) { - return rowsMightMatch - } - - return rowsCannotMatch - } - } - - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitGreaterEqual(t iceberg.BoundTerm, lit iceberg.Literal) bool { - field := t.Ref().Field() - fieldID := field.ID - - if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { - return rowsCannotMatch - } - - if _, ok := field.Type.(iceberg.PrimitiveType); !ok { - panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", - iceberg.ErrInvalidTypeString, field.Type)) - } - - if upperBoundBytes := m.upperBounds[fieldID]; upperBoundBytes != nil { - upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) - if err != nil { - panic(err) - } - - if getCmpLiteral(upperBound)(upperBound, lit) < 0 { - if m.isNan(upperBound) { - return rowsMightMatch - } - - return rowsCannotMatch - } - } - - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitEqual(t iceberg.BoundTerm, lit iceberg.Literal) bool { - field := t.Ref().Field() - fieldID := field.ID - - if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { - return rowsCannotMatch - } - - if _, ok := 
field.Type.(iceberg.PrimitiveType); !ok { - panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", - iceberg.ErrInvalidTypeString, field.Type)) - } - - var cmp func(iceberg.Literal, iceberg.Literal) int - if lowerBoundBytes := m.lowerBounds[fieldID]; lowerBoundBytes != nil { - lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) - if err != nil { - panic(err) - } - - if m.isNan(lowerBound) { - return rowsMightMatch - } - - cmp = getCmpLiteral(lowerBound) - if cmp(lowerBound, lit) == 1 { - return rowsCannotMatch - } - } - - if upperBoundBytes := m.upperBounds[fieldID]; upperBoundBytes != nil { - upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) - if err != nil { - panic(err) - } - - if m.isNan(upperBound) { - return rowsMightMatch - } - - if cmp(upperBound, lit) == -1 { - return rowsCannotMatch - } - } - - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitNotEqual(iceberg.BoundTerm, iceberg.Literal) bool { - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitIn(t iceberg.BoundTerm, s iceberg.Set[iceberg.Literal]) bool { - field := t.Ref().Field() - fieldID := field.ID - - if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { - return rowsCannotMatch - } - - if s.Len() > inPredicateLimit { - // skip evaluating the predicate if the number of values is too big - return rowsMightMatch - } - - if _, ok := field.Type.(iceberg.PrimitiveType); !ok { - panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", - iceberg.ErrInvalidTypeString, field.Type)) - } - - values := s.Members() - if lowerBoundBytes := m.lowerBounds[fieldID]; lowerBoundBytes != nil { - lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) - if err != nil { - panic(lowerBound) - } - - if m.isNan(lowerBound) { - return rowsMightMatch - } - - values = removeBoundCheck(lowerBound, values, 1) - if len(values) == 0 { - return rowsCannotMatch - } - } - - if upperBoundBytes := m.upperBounds[fieldID]; 
upperBoundBytes != nil { - upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) - if err != nil { - panic(err) - } - - if m.isNan(upperBound) { - return rowsMightMatch - } - - values = removeBoundCheck(upperBound, values, -1) - if len(values) == 0 { - return rowsCannotMatch - } - } - - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitNotIn(iceberg.BoundTerm, iceberg.Set[iceberg.Literal]) bool { - // because the bounds are not necessarily a min or max value, this cannot be - // answered using them. notIn(col, {X, ...}) with (XX, Y) doesn't guarantee that - // X is a value in col - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitStartsWith(t iceberg.BoundTerm, lit iceberg.Literal) bool { - field := t.Ref().Field() - fieldID := field.ID - - if m.containsNullsOnly(fieldID) { - return rowsCannotMatch - } - - if _, ok := field.Type.(iceberg.PrimitiveType); !ok { - panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", - iceberg.ErrInvalidTypeString, field.Type)) - } - - var prefix string - if val, ok := lit.(iceberg.TypedLiteral[string]); ok { - prefix = val.Value() - } else { - prefix = string(lit.(iceberg.TypedLiteral[[]byte]).Value()) - } - - lenPrefix := len(prefix) - - if lowerBoundBytes := m.lowerBounds[fieldID]; lowerBoundBytes != nil { - lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) - if err != nil { - panic(err) - } - - var v string - switch l := lowerBound.(type) { - case iceberg.TypedLiteral[string]: - v = l.Value() - case iceberg.TypedLiteral[[]byte]: - v = string(l.Value()) - } - - if len(v) > lenPrefix { - v = v[:lenPrefix] - } - - if len(v) > 0 && v > prefix { - return rowsCannotMatch - } - } - - if upperBoundBytes := m.upperBounds[fieldID]; upperBoundBytes != nil { - upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) - if err != nil { - panic(err) - } - - var v string - switch u := upperBound.(type) { - case iceberg.TypedLiteral[string]: - v = 
u.Value() - case iceberg.TypedLiteral[[]byte]: - v = string(u.Value()) - } - - if len(v) > lenPrefix { - v = v[:lenPrefix] - } - - if len(v) > 0 && v < prefix { - return rowsCannotMatch - } - } - - return rowsMightMatch -} - -func (m *inclusiveMetricsEval) VisitNotStartsWith(t iceberg.BoundTerm, lit iceberg.Literal) bool { - field := t.Ref().Field() - fieldID := field.ID - - if m.mayContainNull(fieldID) { - return rowsMightMatch - } - - if _, ok := field.Type.(iceberg.PrimitiveType); !ok { - panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", - iceberg.ErrInvalidTypeString, field.Type)) - } - - // not_starts_with will match unless all values must start with the prefix. - // this happens when the lower and upper bounds both start with the prefix - lowerBoundBytes, upperBoundBytes := m.lowerBounds[fieldID], m.upperBounds[fieldID] - if lowerBoundBytes != nil && upperBoundBytes != nil { - lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) - if err != nil { - panic(err) - } - - upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) - if err != nil { - panic(err) - } - - var prefix, lower, upper string - if val, ok := lit.(iceberg.TypedLiteral[string]); ok { - prefix = val.Value() - lower, upper = lowerBound.(iceberg.TypedLiteral[string]).Value(), upperBound.(iceberg.TypedLiteral[string]).Value() - } else { - prefix = string(lit.(iceberg.TypedLiteral[[]byte]).Value()) - lower, upper = string(lowerBound.(iceberg.TypedLiteral[[]byte]).Value()), string(upperBound.(iceberg.TypedLiteral[[]byte]).Value()) - } - - lenPrefix := len(prefix) - if len(lower) < lenPrefix { - return rowsMightMatch - } - - if lower[:lenPrefix] == prefix { - if len(upper) < lenPrefix { - return rowsMightMatch - } - - if upper[:lenPrefix] == prefix { - return rowsCannotMatch - } - } - } - - return rowsMightMatch -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package table + +import ( + "fmt" + "math" + "slices" + + "github.com/apache/iceberg-go" + "github.com/google/uuid" +) + +const ( + rowsMightMatch, rowsMustMatch = true, true + rowsCannotMatch, rowsMightNotMatch = false, false + inPredicateLimit = 200 +) + +// newManifestEvaluator returns a function that can be used to evaluate whether a particular +// manifest file has rows that might or might not match a given partition filter by using +// the stats provided in the partitions (UpperBound/LowerBound/ContainsNull/ContainsNaN). +func newManifestEvaluator(spec iceberg.PartitionSpec, schema *iceberg.Schema, partitionFilter iceberg.BooleanExpression, caseSensitive bool) (func(iceberg.ManifestFile) (bool, error), error) { + partType := spec.PartitionType(schema) + partSchema := iceberg.NewSchema(0, partType.FieldList...) 
+ filter, err := iceberg.RewriteNotExpr(partitionFilter) + if err != nil { + return nil, err + } + + boundFilter, err := iceberg.BindExpr(partSchema, filter, caseSensitive) + if err != nil { + return nil, err + } + + return (&manifestEvalVisitor{partitionFilter: boundFilter}).Eval, nil +} + +type manifestEvalVisitor struct { + partitionFields []iceberg.FieldSummary + partitionFilter iceberg.BooleanExpression +} + +func (m *manifestEvalVisitor) Eval(manifest iceberg.ManifestFile) (bool, error) { + if parts := manifest.Partitions(); len(parts) > 0 { + m.partitionFields = parts + return iceberg.VisitExpr(m.partitionFilter, m) + } + + return rowsMightMatch, nil +} + +func removeBoundCmp[T iceberg.LiteralType](bound iceberg.Literal, vals []iceberg.Literal, cmpToDelete int) []iceberg.Literal { + val := bound.(iceberg.TypedLiteral[T]) + cmp := val.Comparator() + + return slices.DeleteFunc(vals, func(v iceberg.Literal) bool { + return cmp(val.Value(), v.(iceberg.TypedLiteral[T]).Value()) == cmpToDelete + }) +} + +func removeBoundCheck(bound iceberg.Literal, vals []iceberg.Literal, toDelete int) []iceberg.Literal { + switch bound.Type().(type) { + case iceberg.BooleanType: + return removeBoundCmp[bool](bound, vals, toDelete) + case iceberg.Int32Type: + return removeBoundCmp[int32](bound, vals, toDelete) + case iceberg.Int64Type: + return removeBoundCmp[int64](bound, vals, toDelete) + case iceberg.Float32Type: + return removeBoundCmp[float32](bound, vals, toDelete) + case iceberg.Float64Type: + return removeBoundCmp[float64](bound, vals, toDelete) + case iceberg.DateType: + return removeBoundCmp[iceberg.Date](bound, vals, toDelete) + case iceberg.TimeType: + return removeBoundCmp[iceberg.Time](bound, vals, toDelete) + case iceberg.TimestampType, iceberg.TimestampTzType: + return removeBoundCmp[iceberg.Timestamp](bound, vals, toDelete) + case iceberg.BinaryType, iceberg.FixedType: + return removeBoundCmp[[]byte](bound, vals, toDelete) + case iceberg.StringType: + return 
removeBoundCmp[string](bound, vals, toDelete) + case iceberg.UUIDType: + return removeBoundCmp[uuid.UUID](bound, vals, toDelete) + case iceberg.DecimalType: + return removeBoundCmp[iceberg.Decimal](bound, vals, toDelete) + } + panic("unrecognized type") +} + +func allBoundCmp[T iceberg.LiteralType](bound iceberg.Literal, set iceberg.Set[iceberg.Literal], want int) bool { + val := bound.(iceberg.TypedLiteral[T]) + cmp := val.Comparator() + + return set.All(func(e iceberg.Literal) bool { + return cmp(val.Value(), e.(iceberg.TypedLiteral[T]).Value()) == want + }) +} + +func allBoundCheck(bound iceberg.Literal, set iceberg.Set[iceberg.Literal], want int) bool { + switch bound.Type().(type) { + case iceberg.BooleanType: + return allBoundCmp[bool](bound, set, want) + case iceberg.Int32Type: + return allBoundCmp[int32](bound, set, want) + case iceberg.Int64Type: + return allBoundCmp[int64](bound, set, want) + case iceberg.Float32Type: + return allBoundCmp[float32](bound, set, want) + case iceberg.Float64Type: + return allBoundCmp[float64](bound, set, want) + case iceberg.DateType: + return allBoundCmp[iceberg.Date](bound, set, want) + case iceberg.TimeType: + return allBoundCmp[iceberg.Time](bound, set, want) + case iceberg.TimestampType, iceberg.TimestampTzType: + return allBoundCmp[iceberg.Timestamp](bound, set, want) + case iceberg.BinaryType, iceberg.FixedType: + return allBoundCmp[[]byte](bound, set, want) + case iceberg.StringType: + return allBoundCmp[string](bound, set, want) + case iceberg.UUIDType: + return allBoundCmp[uuid.UUID](bound, set, want) + case iceberg.DecimalType: + return allBoundCmp[iceberg.Decimal](bound, set, want) + } + panic(iceberg.ErrType) +} + +func (m *manifestEvalVisitor) VisitIn(term iceberg.BoundTerm, literals iceberg.Set[iceberg.Literal]) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + if field.LowerBound == nil { + return rowsCannotMatch + } + + if literals.Len() > inPredicateLimit { + return rowsMightMatch + } + 
+ lower, err := iceberg.LiteralFromBytes(term.Type(), *field.LowerBound) + if err != nil { + panic(err) + } + + if allBoundCheck(lower, literals, 1) { + return rowsCannotMatch + } + + if field.UpperBound != nil { + upper, err := iceberg.LiteralFromBytes(term.Type(), *field.UpperBound) + if err != nil { + panic(err) + } + + if allBoundCheck(upper, literals, -1) { + return rowsCannotMatch + } + } + + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitNotIn(term iceberg.BoundTerm, literals iceberg.Set[iceberg.Literal]) bool { + // because the bounds are not necessarily a min or max value, this cannot be answered using them + // notIn(col, {X, ...}) with (X, Y) doesn't guarantee that X is a value in col + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitIsNan(term iceberg.BoundTerm) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + if field.ContainsNaN != nil && !*field.ContainsNaN { + return rowsCannotMatch + } + + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitNotNan(term iceberg.BoundTerm) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + if field.ContainsNaN != nil && *field.ContainsNaN && !field.ContainsNull && field.LowerBound == nil { + return rowsCannotMatch + } + + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitIsNull(term iceberg.BoundTerm) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + if !field.ContainsNull { + return rowsCannotMatch + } + + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitNotNull(term iceberg.BoundTerm) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + // ContainsNull encodes whether at least one partition value is null + // lowerBound is null if all partition values are null + allNull := field.ContainsNull && field.LowerBound == nil + if allNull && (term.Ref().Type().Equals(iceberg.PrimitiveTypes.Float32) || term.Ref().Type().Equals(iceberg.PrimitiveTypes.Float64)) { + // floating 
point types may include NaN values, which we check separately + // in case bounds don't include NaN values, ContainsNaN needs to be checked + allNull = field.ContainsNaN != nil && !*field.ContainsNaN + } + + if allNull { + return rowsCannotMatch + } + + return rowsMightMatch +} + +func getCmp[T iceberg.LiteralType](b iceberg.TypedLiteral[T]) func(iceberg.Literal, iceberg.Literal) int { + cmp := b.Comparator() + return func(l1, l2 iceberg.Literal) int { + return cmp(l1.(iceberg.TypedLiteral[T]).Value(), l2.(iceberg.TypedLiteral[T]).Value()) + } +} + +func getCmpLiteral(boundary iceberg.Literal) func(iceberg.Literal, iceberg.Literal) int { + switch l := boundary.(type) { + case iceberg.TypedLiteral[bool]: + return getCmp(l) + case iceberg.TypedLiteral[int32]: + return getCmp(l) + case iceberg.TypedLiteral[int64]: + return getCmp(l) + case iceberg.TypedLiteral[float32]: + return getCmp(l) + case iceberg.TypedLiteral[float64]: + return getCmp(l) + case iceberg.TypedLiteral[iceberg.Date]: + return getCmp(l) + case iceberg.TypedLiteral[iceberg.Time]: + return getCmp(l) + case iceberg.TypedLiteral[iceberg.Timestamp]: + return getCmp(l) + case iceberg.TypedLiteral[[]byte]: + return getCmp(l) + case iceberg.TypedLiteral[string]: + return getCmp(l) + case iceberg.TypedLiteral[uuid.UUID]: + return getCmp(l) + case iceberg.TypedLiteral[iceberg.Decimal]: + return getCmp(l) + } + panic(iceberg.ErrType) +} + +func (m *manifestEvalVisitor) VisitEqual(term iceberg.BoundTerm, lit iceberg.Literal) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + if field.LowerBound == nil || field.UpperBound == nil { + // values are all null and literal cannot contain null + return rowsCannotMatch + } + + lower, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.LowerBound) + if err != nil { + panic(err) + } + + cmp := getCmpLiteral(lower) + if cmp(lower, lit) == 1 { + return rowsCannotMatch + } + + upper, err := iceberg.LiteralFromBytes(term.Ref().Type(),
*field.UpperBound) + if err != nil { + panic(err) + } + + if cmp(lit, upper) == 1 { + return rowsCannotMatch + } + + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitNotEqual(term iceberg.BoundTerm, lit iceberg.Literal) bool { + // because bounds are not necessarily a min or max, this cannot be answered + // using them. notEq(col, X) with (X, Y) doesn't guarantee X is a value in col + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitGreaterEqual(term iceberg.BoundTerm, lit iceberg.Literal) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + if field.UpperBound == nil { + return rowsCannotMatch + } + + upper, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.UpperBound) + if err != nil { + panic(err) + } + + if getCmpLiteral(upper)(lit, upper) == 1 { + return rowsCannotMatch + } + + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitGreater(term iceberg.BoundTerm, lit iceberg.Literal) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + if field.UpperBound == nil { + return rowsCannotMatch + } + + upper, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.UpperBound) + if err != nil { + panic(err) + } + + if getCmpLiteral(upper)(lit, upper) >= 0 { + return rowsCannotMatch + } + + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitLessEqual(term iceberg.BoundTerm, lit iceberg.Literal) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + if field.LowerBound == nil { + return rowsCannotMatch + } + + lower, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.LowerBound) + if err != nil { + panic(err) + } + + if getCmpLiteral(lower)(lit, lower) == -1 { + return rowsCannotMatch + } + + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitLess(term iceberg.BoundTerm, lit iceberg.Literal) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + if field.LowerBound == nil { + return rowsCannotMatch + } + + lower, err := 
iceberg.LiteralFromBytes(term.Ref().Type(), *field.LowerBound) + if err != nil { + panic(err) + } + + if getCmpLiteral(lower)(lit, lower) <= 0 { + return rowsCannotMatch + } + + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitStartsWith(term iceberg.BoundTerm, lit iceberg.Literal) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + var prefix string + if val, ok := lit.(iceberg.TypedLiteral[string]); ok { + prefix = val.Value() + } else { + prefix = string(lit.(iceberg.TypedLiteral[[]byte]).Value()) + } + + lenPrefix := len(prefix) + + if field.LowerBound == nil { + return rowsCannotMatch + } + + lower, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.LowerBound) + if err != nil { + panic(err) + } + + // truncate lower bound so that its length is not greater than the length of prefix + var v string + switch l := lower.(type) { + case iceberg.TypedLiteral[string]: + v = l.Value() + if len(v) > lenPrefix { + v = v[:lenPrefix] + } + case iceberg.TypedLiteral[[]byte]: + v = string(l.Value()) + if len(v) > lenPrefix { + v = v[:lenPrefix] + } + } + + if v > prefix { + return rowsCannotMatch + } + + if field.UpperBound == nil { + return rowsCannotMatch + } + + upper, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.UpperBound) + if err != nil { + panic(err) + } + + switch u := upper.(type) { + case iceberg.TypedLiteral[string]: + v = u.Value() + if len(v) > lenPrefix { + v = v[:lenPrefix] + } + case iceberg.TypedLiteral[[]byte]: + v = string(u.Value()) + if len(v) > lenPrefix { + v = v[:lenPrefix] + } + } + + if v < prefix { + return rowsCannotMatch + } + + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitNotStartsWith(term iceberg.BoundTerm, lit iceberg.Literal) bool { + pos := term.Ref().Pos() + field := m.partitionFields[pos] + + if field.ContainsNull || field.LowerBound == nil || field.UpperBound == nil { + return rowsMightMatch + } + + // NotStartsWith will match unless ALL values must start with the
prefix. + // this happens when the lower and upper bounds BOTH start with the prefix + lower, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.LowerBound) + if err != nil { + panic(err) + } + + upper, err := iceberg.LiteralFromBytes(term.Ref().Type(), *field.UpperBound) + if err != nil { + panic(err) + } + + var ( + prefix, lowerBound, upperBound string + ) + if val, ok := lit.(iceberg.TypedLiteral[string]); ok { + prefix = val.Value() + lowerBound, upperBound = lower.(iceberg.TypedLiteral[string]).Value(), upper.(iceberg.TypedLiteral[string]).Value() + } else { + prefix = string(lit.(iceberg.TypedLiteral[[]byte]).Value()) + lowerBound = string(lower.(iceberg.TypedLiteral[[]byte]).Value()) + upperBound = string(upper.(iceberg.TypedLiteral[[]byte]).Value()) + } + + lenPrefix := len(prefix) + if len(lowerBound) < lenPrefix { + return rowsMightMatch + } + + if lowerBound[:lenPrefix] == prefix { + // if upper is shorter then upper can't start with the prefix + if len(upperBound) < lenPrefix { + return rowsMightMatch + } + + if upperBound[:lenPrefix] == prefix { + return rowsCannotMatch + } + } + + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitTrue() bool { + return rowsMightMatch +} + +func (m *manifestEvalVisitor) VisitFalse() bool { + return rowsCannotMatch +} + +func (m *manifestEvalVisitor) VisitUnbound(iceberg.UnboundPredicate) bool { + panic("need bound predicate") +} + +func (m *manifestEvalVisitor) VisitBound(pred iceberg.BoundPredicate) bool { + return iceberg.VisitBoundPredicate(pred, m) +} + +func (m *manifestEvalVisitor) VisitNot(child bool) bool { return !child } +func (m *manifestEvalVisitor) VisitAnd(left, right bool) bool { return left && right } +func (m *manifestEvalVisitor) VisitOr(left, right bool) bool { return left || right } + +type projectionEvaluator struct { + spec iceberg.PartitionSpec + schema *iceberg.Schema + caseSensitive bool +} + +func (*projectionEvaluator) VisitTrue() iceberg.BooleanExpression { return 
iceberg.AlwaysTrue{} } +func (*projectionEvaluator) VisitFalse() iceberg.BooleanExpression { return iceberg.AlwaysFalse{} } +func (*projectionEvaluator) VisitNot(child iceberg.BooleanExpression) iceberg.BooleanExpression { + panic(fmt.Errorf("%w: cannot project 'not' expression, should be rewritten %s", + iceberg.ErrInvalidArgument, child)) +} + +func (*projectionEvaluator) VisitAnd(left, right iceberg.BooleanExpression) iceberg.BooleanExpression { + return iceberg.NewAnd(left, right) +} + +func (*projectionEvaluator) VisitOr(left, right iceberg.BooleanExpression) iceberg.BooleanExpression { + return iceberg.NewOr(left, right) +} + +func (*projectionEvaluator) VisitUnbound(pred iceberg.UnboundPredicate) iceberg.BooleanExpression { + panic(fmt.Errorf("%w: cannot project unbound predicate: %s", iceberg.ErrInvalidArgument, pred)) +} + +type inclusiveProjection struct{ projectionEvaluator } + +func (p *inclusiveProjection) Project(expr iceberg.BooleanExpression) (iceberg.BooleanExpression, error) { + expr, err := iceberg.RewriteNotExpr(expr) + if err != nil { + return nil, err + } + + bound, err := iceberg.BindExpr(p.schema, expr, p.caseSensitive) + if err != nil { + return nil, err + } + + return iceberg.VisitExpr(bound, p) +} + +func (p *inclusiveProjection) VisitBound(pred iceberg.BoundPredicate) iceberg.BooleanExpression { + parts := p.spec.FieldsBySourceID(pred.Term().Ref().Field().ID) + + var result iceberg.BooleanExpression = iceberg.AlwaysTrue{} + for _, part := range parts { + // consider (d = 2019-01-01) with bucket(7, d) and bucket(5, d) + // projections: b1 = bucket(7, '2019-01-01') = 5, b2 = bucket(5, '2019-01-01') = 0 + // any value where b1 != 5 or any value where b2 != 0 cannot be the '2019-01-01' + // + // similarly, if partitioning by day(ts) and hour(ts), the more restrictive + // projection should be used. ts = 2019-01-01T01:00:00 produces day=2019-01-01 and + // hour=2019-01-01-01. the value will be in 2019-01-01-01 and not in 2019-01-01-02. 
+ inclProjection, err := part.Transform.Project(part.Name, pred) + if err != nil { + panic(err) + } + if inclProjection != nil { + result = iceberg.NewAnd(result, inclProjection) + } + } + + return result +} + +func newInclusiveProjection(s *iceberg.Schema, spec iceberg.PartitionSpec, caseSensitive bool) func(iceberg.BooleanExpression) (iceberg.BooleanExpression, error) { + return (&inclusiveProjection{ + projectionEvaluator: projectionEvaluator{ + schema: s, + spec: spec, + caseSensitive: caseSensitive, + }, + }).Project +} + +type metricsEvaluator struct { + valueCounts map[int]int64 + nullCounts map[int]int64 + nanCounts map[int]int64 + lowerBounds map[int][]byte + upperBounds map[int][]byte +} + +func (m *metricsEvaluator) VisitTrue() bool { return rowsMightMatch } +func (m *metricsEvaluator) VisitFalse() bool { return rowsCannotMatch } +func (m *metricsEvaluator) VisitNot(child bool) bool { + panic(fmt.Errorf("%w: NOT should be rewritten %v", iceberg.ErrInvalidArgument, child)) +} +func (m *metricsEvaluator) VisitAnd(left, right bool) bool { return left && right } +func (m *metricsEvaluator) VisitOr(left, right bool) bool { return left || right } + +func (m *metricsEvaluator) containsNullsOnly(id int) bool { + valCount, ok := m.valueCounts[id] + if !ok { + return false + } + + nullCount, ok := m.nullCounts[id] + if !ok { + return false + } + + return valCount == nullCount +} + +func (m *metricsEvaluator) containsNansOnly(id int) bool { + nanCount, ok := m.nanCounts[id] + if !ok { + return false + } + + valCount, ok := m.valueCounts[id] + if !ok { + return false + } + + return nanCount == valCount +} + +func (m *metricsEvaluator) isNan(v iceberg.Literal) bool { + switch v := v.(type) { + case iceberg.Float32Literal: + return math.IsNaN(float64(v)) + case iceberg.Float64Literal: + return math.IsNaN(float64(v)) + default: + return false + } +} + +func newInclusiveMetricsEvaluator(s *iceberg.Schema, expr iceberg.BooleanExpression, + caseSensitive bool, 
includeEmptyFiles bool) (func(iceberg.DataFile) (bool, error), error) { + + rewritten, err := iceberg.RewriteNotExpr(expr) + if err != nil { + return nil, err + } + + bound, err := iceberg.BindExpr(s, rewritten, caseSensitive) + if err != nil { + return nil, err + } + + return (&inclusiveMetricsEval{ + st: s.AsStruct(), + includeEmptyFiles: includeEmptyFiles, + expr: bound, + }).Eval, nil +} + +type inclusiveMetricsEval struct { + metricsEvaluator + + st iceberg.StructType + expr iceberg.BooleanExpression + includeEmptyFiles bool +} + +func (m *inclusiveMetricsEval) Eval(file iceberg.DataFile) (bool, error) { + if !m.includeEmptyFiles && file.Count() == 0 { + return rowsCannotMatch, nil + } + + m.valueCounts, m.nullCounts = file.ValueCounts(), file.NullValueCounts() + m.nanCounts = file.NaNValueCounts() + m.lowerBounds, m.upperBounds = file.LowerBoundValues(), file.UpperBoundValues() + + return iceberg.VisitExpr(m.expr, m) +} + +func (m *inclusiveMetricsEval) mayContainNull(fieldID int) bool { + if m.nullCounts == nil { + return true + } + + _, ok := m.nullCounts[fieldID] + return ok +} + +func (m *inclusiveMetricsEval) VisitUnbound(iceberg.UnboundPredicate) bool { + panic("need bound predicate") +} + +func (m *inclusiveMetricsEval) VisitBound(pred iceberg.BoundPredicate) bool { + return iceberg.VisitBoundPredicate(pred, m) +} + +func (m *inclusiveMetricsEval) VisitIsNull(t iceberg.BoundTerm) bool { + fieldID := t.Ref().Field().ID + if cnt, exists := m.nullCounts[fieldID]; exists && cnt == 0 { + return rowsCannotMatch + } + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitNotNull(t iceberg.BoundTerm) bool { + // no need to check whether the field is required because binding evaluates + // that case if the column has no non-null values, the expression cannot match + fieldID := t.Ref().Field().ID + if m.containsNullsOnly(fieldID) { + return rowsCannotMatch + } + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitIsNan(t iceberg.BoundTerm) 
bool { + fieldID := t.Ref().Field().ID + if cnt, exists := m.nanCounts[fieldID]; exists && cnt == 0 { + return rowsCannotMatch + } + // when there's no nancounts information but we already know the column + // contains null it's guaranteed that there's no nan value + if m.containsNullsOnly(fieldID) { + return rowsCannotMatch + } + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitNotNan(t iceberg.BoundTerm) bool { + fieldID := t.Ref().Field().ID + + if m.containsNansOnly(fieldID) { + return rowsCannotMatch + } + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitLess(t iceberg.BoundTerm, lit iceberg.Literal) bool { + field := t.Ref().Field() + fieldID := field.ID + + if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { + return rowsCannotMatch + } + + if _, ok := field.Type.(iceberg.PrimitiveType); !ok { + panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", + iceberg.ErrInvalidTypeString, field.Type)) + } + + if lowerBoundBytes := m.lowerBounds[fieldID]; lowerBoundBytes != nil { + lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) + if err != nil { + panic(err) + } + + if m.isNan(lowerBound) { + // nan indicates unreliable bounds + return rowsMightMatch + } + + if getCmpLiteral(lowerBound)(lowerBound, lit) >= 0 { + return rowsCannotMatch + } + } + + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitLessEqual(t iceberg.BoundTerm, lit iceberg.Literal) bool { + field := t.Ref().Field() + fieldID := field.ID + + if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { + return rowsCannotMatch + } + + if _, ok := field.Type.(iceberg.PrimitiveType); !ok { + panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", + iceberg.ErrInvalidTypeString, field.Type)) + } + + if lowerBoundBytes := m.lowerBounds[fieldID]; lowerBoundBytes != nil { + lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) + if err != nil { + panic(err) + } + + if m.isNan(lowerBound) { + 
// nan indicates unreliable bounds + return rowsMightMatch + } + + if getCmpLiteral(lowerBound)(lowerBound, lit) > 0 { + return rowsCannotMatch + } + } + + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitGreater(t iceberg.BoundTerm, lit iceberg.Literal) bool { + field := t.Ref().Field() + fieldID := field.ID + + if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { + return rowsCannotMatch + } + + if _, ok := field.Type.(iceberg.PrimitiveType); !ok { + panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", + iceberg.ErrInvalidTypeString, field.Type)) + } + + if upperBoundBytes := m.upperBounds[fieldID]; upperBoundBytes != nil { + upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) + if err != nil { + panic(err) + } + + if getCmpLiteral(upperBound)(upperBound, lit) <= 0 { + if m.isNan(upperBound) { + return rowsMightMatch + } + + return rowsCannotMatch + } + } + + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitGreaterEqual(t iceberg.BoundTerm, lit iceberg.Literal) bool { + field := t.Ref().Field() + fieldID := field.ID + + if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { + return rowsCannotMatch + } + + if _, ok := field.Type.(iceberg.PrimitiveType); !ok { + panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", + iceberg.ErrInvalidTypeString, field.Type)) + } + + if upperBoundBytes := m.upperBounds[fieldID]; upperBoundBytes != nil { + upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) + if err != nil { + panic(err) + } + + if getCmpLiteral(upperBound)(upperBound, lit) < 0 { + if m.isNan(upperBound) { + return rowsMightMatch + } + + return rowsCannotMatch + } + } + + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitEqual(t iceberg.BoundTerm, lit iceberg.Literal) bool { + field := t.Ref().Field() + fieldID := field.ID + + if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { + return rowsCannotMatch + } + + if _, ok := 
field.Type.(iceberg.PrimitiveType); !ok { + panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", + iceberg.ErrInvalidTypeString, field.Type)) + } + + var cmp func(iceberg.Literal, iceberg.Literal) int + if lowerBoundBytes := m.lowerBounds[fieldID]; lowerBoundBytes != nil { + lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) + if err != nil { + panic(err) + } + + if m.isNan(lowerBound) { + return rowsMightMatch + } + + cmp = getCmpLiteral(lowerBound) + if cmp(lowerBound, lit) == 1 { + return rowsCannotMatch + } + } + + if upperBoundBytes := m.upperBounds[fieldID]; upperBoundBytes != nil { + upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) + if err != nil { + panic(err) + } + + if m.isNan(upperBound) { + return rowsMightMatch + } + + if cmp(upperBound, lit) == -1 { + return rowsCannotMatch + } + } + + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitNotEqual(iceberg.BoundTerm, iceberg.Literal) bool { + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitIn(t iceberg.BoundTerm, s iceberg.Set[iceberg.Literal]) bool { + field := t.Ref().Field() + fieldID := field.ID + + if m.containsNullsOnly(fieldID) || m.containsNansOnly(fieldID) { + return rowsCannotMatch + } + + if s.Len() > inPredicateLimit { + // skip evaluating the predicate if the number of values is too big + return rowsMightMatch + } + + if _, ok := field.Type.(iceberg.PrimitiveType); !ok { + panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", + iceberg.ErrInvalidTypeString, field.Type)) + } + + values := s.Members() + if lowerBoundBytes := m.lowerBounds[fieldID]; lowerBoundBytes != nil { + lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) + if err != nil { + panic(lowerBound) + } + + if m.isNan(lowerBound) { + return rowsMightMatch + } + + values = removeBoundCheck(lowerBound, values, 1) + if len(values) == 0 { + return rowsCannotMatch + } + } + + if upperBoundBytes := m.upperBounds[fieldID]; 
upperBoundBytes != nil { + upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) + if err != nil { + panic(err) + } + + if m.isNan(upperBound) { + return rowsMightMatch + } + + values = removeBoundCheck(upperBound, values, -1) + if len(values) == 0 { + return rowsCannotMatch + } + } + + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitNotIn(iceberg.BoundTerm, iceberg.Set[iceberg.Literal]) bool { + // because the bounds are not necessarily a min or max value, this cannot be + // answered using them. notIn(col, {X, ...}) with (XX, Y) doesn't guarantee that + // X is a value in col + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitStartsWith(t iceberg.BoundTerm, lit iceberg.Literal) bool { + field := t.Ref().Field() + fieldID := field.ID + + if m.containsNullsOnly(fieldID) { + return rowsCannotMatch + } + + if _, ok := field.Type.(iceberg.PrimitiveType); !ok { + panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", + iceberg.ErrInvalidTypeString, field.Type)) + } + + var prefix string + if val, ok := lit.(iceberg.TypedLiteral[string]); ok { + prefix = val.Value() + } else { + prefix = string(lit.(iceberg.TypedLiteral[[]byte]).Value()) + } + + lenPrefix := len(prefix) + + if lowerBoundBytes := m.lowerBounds[fieldID]; lowerBoundBytes != nil { + lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) + if err != nil { + panic(err) + } + + var v string + switch l := lowerBound.(type) { + case iceberg.TypedLiteral[string]: + v = l.Value() + case iceberg.TypedLiteral[[]byte]: + v = string(l.Value()) + } + + if len(v) > lenPrefix { + v = v[:lenPrefix] + } + + if len(v) > 0 && v > prefix { + return rowsCannotMatch + } + } + + if upperBoundBytes := m.upperBounds[fieldID]; upperBoundBytes != nil { + upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) + if err != nil { + panic(err) + } + + var v string + switch u := upperBound.(type) { + case iceberg.TypedLiteral[string]: + v = 
u.Value() + case iceberg.TypedLiteral[[]byte]: + v = string(u.Value()) + } + + if len(v) > lenPrefix { + v = v[:lenPrefix] + } + + if len(v) > 0 && v < prefix { + return rowsCannotMatch + } + } + + return rowsMightMatch +} + +func (m *inclusiveMetricsEval) VisitNotStartsWith(t iceberg.BoundTerm, lit iceberg.Literal) bool { + field := t.Ref().Field() + fieldID := field.ID + + if m.mayContainNull(fieldID) { + return rowsMightMatch + } + + if _, ok := field.Type.(iceberg.PrimitiveType); !ok { + panic(fmt.Errorf("%w: expected iceberg.PrimitiveType, got %s", + iceberg.ErrInvalidTypeString, field.Type)) + } + + // not_starts_with will match unless all values must start with the prefix. + // this happens when the lower and upper bounds both start with the prefix + lowerBoundBytes, upperBoundBytes := m.lowerBounds[fieldID], m.upperBounds[fieldID] + if lowerBoundBytes != nil && upperBoundBytes != nil { + lowerBound, err := iceberg.LiteralFromBytes(field.Type, lowerBoundBytes) + if err != nil { + panic(err) + } + + upperBound, err := iceberg.LiteralFromBytes(field.Type, upperBoundBytes) + if err != nil { + panic(err) + } + + var prefix, lower, upper string + if val, ok := lit.(iceberg.TypedLiteral[string]); ok { + prefix = val.Value() + lower, upper = lowerBound.(iceberg.TypedLiteral[string]).Value(), upperBound.(iceberg.TypedLiteral[string]).Value() + } else { + prefix = string(lit.(iceberg.TypedLiteral[[]byte]).Value()) + lower, upper = string(lowerBound.(iceberg.TypedLiteral[[]byte]).Value()), string(upperBound.(iceberg.TypedLiteral[[]byte]).Value()) + } + + lenPrefix := len(prefix) + if len(lower) < lenPrefix { + return rowsMightMatch + } + + if lower[:lenPrefix] == prefix { + if len(upper) < lenPrefix { + return rowsMightMatch + } + + if upper[:lenPrefix] == prefix { + return rowsCannotMatch + } + } + } + + return rowsMightMatch +} diff --git a/table/evaluators_test.go b/table/evaluators_test.go index a543e93..67237c0 100644 --- a/table/evaluators_test.go +++ 
b/table/evaluators_test.go @@ -1,1799 +1,1799 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package table - -import ( - "math" - "testing" - - "github.com/apache/iceberg-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -const ( - IntMinValue, IntMaxValue int32 = 30, 79 -) - -func TestManifestEvaluator(t *testing.T) { - - var ( - IntMin, IntMax = []byte{byte(IntMinValue), 0x00, 0x00, 0x00}, []byte{byte(IntMaxValue), 0x00, 0x00, 0x00} - StringMin, StringMax = []byte("a"), []byte("z") - FloatMin, _ = iceberg.Float32Literal(0).MarshalBinary() - FloatMax, _ = iceberg.Float32Literal(20).MarshalBinary() - DblMin, _ = iceberg.Float64Literal(0).MarshalBinary() - DblMax, _ = iceberg.Float64Literal(20).MarshalBinary() - NanTrue, NanFalse = true, false - - testSchema = iceberg.NewSchema(1, - iceberg.NestedField{ID: 1, Name: "id", - Type: iceberg.PrimitiveTypes.Int32, Required: true}, - iceberg.NestedField{ID: 2, Name: "all_nulls_missing_nan", - Type: iceberg.PrimitiveTypes.String, Required: false}, - iceberg.NestedField{ID: 3, Name: "some_nulls", - Type: iceberg.PrimitiveTypes.String, Required: false}, - iceberg.NestedField{ID: 4, Name: 
"no_nulls", - Type: iceberg.PrimitiveTypes.String, Required: false}, - iceberg.NestedField{ID: 5, Name: "float", - Type: iceberg.PrimitiveTypes.Float32, Required: false}, - iceberg.NestedField{ID: 6, Name: "all_nulls_double", - Type: iceberg.PrimitiveTypes.Float64, Required: false}, - iceberg.NestedField{ID: 7, Name: "all_nulls_no_nans", - Type: iceberg.PrimitiveTypes.Float32, Required: false}, - iceberg.NestedField{ID: 8, Name: "all_nans", - Type: iceberg.PrimitiveTypes.Float64, Required: false}, - iceberg.NestedField{ID: 9, Name: "both_nan_and_null", - Type: iceberg.PrimitiveTypes.Float32, Required: false}, - iceberg.NestedField{ID: 10, Name: "no_nan_or_null", - Type: iceberg.PrimitiveTypes.Float64, Required: false}, - iceberg.NestedField{ID: 11, Name: "all_nulls_missing_nan_float", - Type: iceberg.PrimitiveTypes.Float32, Required: false}, - iceberg.NestedField{ID: 12, Name: "all_same_value_or_null", - Type: iceberg.PrimitiveTypes.String, Required: false}, - iceberg.NestedField{ID: 13, Name: "no_nulls_same_value_a", - Type: iceberg.PrimitiveTypes.Binary, Required: false}, - ) - ) - - partFields := make([]iceberg.PartitionField, 0, testSchema.NumFields()) - for _, f := range testSchema.Fields() { - partFields = append(partFields, iceberg.PartitionField{ - Name: f.Name, - SourceID: f.ID, - FieldID: f.ID, - Transform: iceberg.IdentityTransform{}, - }) - } - - spec := iceberg.NewPartitionSpec(partFields...) 
- manifestNoStats := iceberg.NewManifestV1Builder("", 0, 0, 0).Build() - manifest := iceberg.NewManifestV1Builder("", 0, 0, 0).Partitions( - []iceberg.FieldSummary{ - { // id - ContainsNull: false, - ContainsNaN: nil, - LowerBound: &IntMin, - UpperBound: &IntMax, - }, - { // all_nulls_missing_nan - ContainsNull: true, - ContainsNaN: nil, - LowerBound: nil, - UpperBound: nil, - }, - { // some_nulls - ContainsNull: true, - ContainsNaN: nil, - LowerBound: &StringMin, - UpperBound: &StringMax, - }, - { // no_nulls - ContainsNull: false, - ContainsNaN: nil, - LowerBound: &StringMin, - UpperBound: &StringMax, - }, - { // float - ContainsNull: true, - ContainsNaN: nil, - LowerBound: &FloatMin, - UpperBound: &FloatMax, - }, - { // all_nulls_double - ContainsNull: true, - ContainsNaN: nil, - LowerBound: nil, - UpperBound: nil, - }, - { // all_nulls_no_nans - ContainsNull: true, - ContainsNaN: &NanFalse, - LowerBound: nil, - UpperBound: nil, - }, - { // all_nans - ContainsNull: false, - ContainsNaN: &NanTrue, - LowerBound: nil, - UpperBound: nil, - }, - { // both_nan_and_null - ContainsNull: true, - ContainsNaN: &NanTrue, - LowerBound: nil, - UpperBound: nil, - }, - { // no_nan_or_null - ContainsNull: false, - ContainsNaN: &NanFalse, - LowerBound: &DblMin, - UpperBound: &DblMax, - }, - { // all_nulls_missing_nan_float - ContainsNull: true, - ContainsNaN: nil, - LowerBound: nil, - UpperBound: nil, - }, - { // all_same_value_or_null - ContainsNull: true, - ContainsNaN: nil, - LowerBound: &StringMin, - UpperBound: &StringMin, - }, - { // no_nulls_same_value_a - ContainsNull: false, - ContainsNaN: nil, - LowerBound: &StringMin, - UpperBound: &StringMin, - }, - }).Build() - - t.Run("all nulls", func(t *testing.T) { - tests := []struct { - field string - expected bool - msg string - }{ - {"all_nulls_missing_nan", false, "should skip: all nulls column with non-floating type contains all null"}, - {"all_nulls_missing_nan_float", true, "should read: no NaN information may indicate 
presence of NaN value"}, - {"some_nulls", true, "should read: column with some nulls contains a non-null value"}, - {"no_nulls", true, "should read: non-null column contains a non-null value"}, - } - - for _, tt := range tests { - eval, err := newManifestEvaluator(spec, testSchema, - iceberg.NotNull(iceberg.Reference(tt.field)), true) - require.NoError(t, err) - - result, err := eval(manifest) - require.NoError(t, err) - assert.Equal(t, tt.expected, result, tt.msg) - } - }) - - t.Run("no nulls", func(t *testing.T) { - tests := []struct { - field string - expected bool - msg string - }{ - {"all_nulls_missing_nan", true, "should read: at least one null value in all null column"}, - {"some_nulls", true, "should read: column with some nulls contains a null value"}, - {"no_nulls", false, "should skip: non-null column contains no null values"}, - {"both_nan_and_null", true, "should read: both_nan_and_null column contains no null values"}, - } - - for _, tt := range tests { - eval, err := newManifestEvaluator(spec, testSchema, - iceberg.IsNull(iceberg.Reference(tt.field)), true) - require.NoError(t, err) - - result, err := eval(manifest) - require.NoError(t, err) - assert.Equal(t, tt.expected, result, tt.msg) - } - }) - - t.Run("is nan", func(t *testing.T) { - tests := []struct { - field string - expected bool - msg string - }{ - {"float", true, "should read: no information on if there are nan values in float column"}, - {"all_nulls_double", true, "should read: no NaN information may indicate presence of NaN value"}, - {"all_nulls_missing_nan_float", true, "should read: no NaN information may indicate presence of NaN value"}, - {"all_nulls_no_nans", false, "should skip: no nan column doesn't contain nan value"}, - {"all_nans", true, "should read: all_nans column contains nan value"}, - {"both_nan_and_null", true, "should read: both_nan_and_null column contains nan value"}, - {"no_nan_or_null", false, "should skip: no_nan_or_null column doesn't contain nan value"}, - } - - 
for _, tt := range tests { - eval, err := newManifestEvaluator(spec, testSchema, - iceberg.IsNaN(iceberg.Reference(tt.field)), true) - require.NoError(t, err) - - result, err := eval(manifest) - require.NoError(t, err) - assert.Equal(t, tt.expected, result, tt.msg) - } - }) - - t.Run("not nan", func(t *testing.T) { - tests := []struct { - field string - expected bool - msg string - }{ - {"float", true, "should read: no information on if there are nan values in float column"}, - {"all_nulls_double", true, "should read: all null column contains non nan value"}, - {"all_nulls_no_nans", true, "should read: no_nans column contains non nan value"}, - {"all_nans", false, "should skip: all nans columndoesn't contain non nan value"}, - {"both_nan_and_null", true, "should read: both_nan_and_null nans column contains non nan value"}, - {"no_nan_or_null", true, "should read: no_nan_or_null column contains non nan value"}, - } - - for _, tt := range tests { - eval, err := newManifestEvaluator(spec, testSchema, - iceberg.NotNaN(iceberg.Reference(tt.field)), true) - require.NoError(t, err) - - result, err := eval(manifest) - require.NoError(t, err) - assert.Equal(t, tt.expected, result, tt.msg) - } - }) - - t.Run("test missing stats", func(t *testing.T) { - exprs := []iceberg.BooleanExpression{ - iceberg.LessThan(iceberg.Reference("id"), int32(5)), - iceberg.LessThanEqual(iceberg.Reference("id"), int32(30)), - iceberg.EqualTo(iceberg.Reference("id"), int32(70)), - iceberg.GreaterThan(iceberg.Reference("id"), int32(78)), - iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(90)), - iceberg.NotEqualTo(iceberg.Reference("id"), int32(101)), - iceberg.IsNull(iceberg.Reference("id")), - iceberg.NotNull(iceberg.Reference("id")), - iceberg.IsNaN(iceberg.Reference("float")), - iceberg.NotNaN(iceberg.Reference("float")), - } - - for _, tt := range exprs { - eval, err := newManifestEvaluator(spec, testSchema, tt, true) - require.NoError(t, err) - - result, err := eval(manifestNoStats) - 
require.NoError(t, err) - assert.Truef(t, result, "should read when missing stats for expr: %s", tt) - } - }) - - t.Run("test exprs", func(t *testing.T) { - tests := []struct { - expr iceberg.BooleanExpression - expect bool - msg string - }{ - {iceberg.NewNot(iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25))), - true, "should read: not(false)"}, - {iceberg.NewNot(iceberg.GreaterThan(iceberg.Reference("id"), int32(IntMinValue-25))), - false, "should skip: not(true)"}, - {iceberg.NewAnd( - iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25)), - iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMinValue-30))), - false, "should skip: and(false, true)"}, - {iceberg.NewAnd( - iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25)), - iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue+1))), - false, "should skip: and(false, false)"}, - {iceberg.NewAnd( - iceberg.GreaterThan(iceberg.Reference("id"), int32(IntMinValue-25)), - iceberg.LessThanEqual(iceberg.Reference("id"), int32(IntMinValue))), - true, "should read: and(true, true)"}, - {iceberg.NewOr( - iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25)), - iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue+1))), - false, "should skip: or(false, false)"}, - {iceberg.NewOr( - iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25)), - iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue-19))), - true, "should read: or(false, true)"}, - {iceberg.LessThan(iceberg.Reference("some_nulls"), "1"), false, - "should not read: id range below lower bound"}, - {iceberg.LessThan(iceberg.Reference("some_nulls"), "b"), true, - "should read: lower bound in range"}, - {iceberg.LessThan(iceberg.Reference("float"), 15.50), true, - "should read: lower bound in range"}, - {iceberg.LessThan(iceberg.Reference("no_nan_or_null"), 15.50), true, - "should read: lower bound in range"}, - 
{iceberg.LessThanEqual(iceberg.Reference("no_nulls_same_value_a"), "a"), true, - "should read: lower bound in range"}, - {iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25)), false, - "should not read: id range below lower bound (5 < 30)"}, - {iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue)), false, - "should not read: id range below lower bound (30 is not < 30)"}, - {iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue+1)), true, - "should read: one possible id"}, - {iceberg.LessThan(iceberg.Reference("id"), int32(IntMaxValue)), true, - "should read: many possible ids"}, - {iceberg.LessThanEqual(iceberg.Reference("id"), int32(IntMinValue-25)), false, - "should not read: id range below lower bound (5 < 30)"}, - {iceberg.LessThanEqual(iceberg.Reference("id"), int32(IntMinValue-1)), false, - "should not read: id range below lower bound 29 < 30"}, - {iceberg.LessThanEqual(iceberg.Reference("id"), int32(IntMinValue)), true, - "should read: one possible id"}, - {iceberg.LessThanEqual(iceberg.Reference("id"), int32(IntMaxValue)), true, - "should read: many possible ids"}, - {iceberg.GreaterThan(iceberg.Reference("id"), int32(IntMaxValue+6)), false, - "should not read: id range above upper bound (85 < 79)"}, - {iceberg.GreaterThan(iceberg.Reference("id"), int32(IntMaxValue)), false, - "should not read: id range above upper bound (79 is not > 79)"}, - {iceberg.GreaterThan(iceberg.Reference("id"), int32(IntMaxValue-1)), true, - "should read: one possible id"}, - {iceberg.GreaterThan(iceberg.Reference("id"), int32(IntMaxValue-4)), true, - "should read: many possible ids"}, - {iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue+6)), false, - "should not read: id range is above upper bound (85 < 79)"}, - {iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue+1)), false, - "should not read: id range above upper bound (80 > 79)"}, - {iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue)), true, - 
"should read: one possible id"}, - {iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue)), true, - "should read: many possible ids"}, - {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue-25)), false, - "should not read: id below lower bound"}, - {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue-1)), false, - "should not read: id below lower bound"}, - {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue)), true, - "should read: id equal to lower bound"}, - {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue-4)), true, - "should read: id between lower and upper bounds"}, - {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue)), true, - "should read: id equal to upper bound"}, - {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue+1)), false, - "should not read: id above upper bound"}, - {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue+6)), false, - "should not read: id above upper bound"}, - {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMinValue-25)), true, - "should read: id below lower bound"}, - {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMinValue-1)), true, - "should read: id below lower bound"}, - {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMinValue)), true, - "should read: id equal to lower bound"}, - {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMaxValue-4)), true, - "should read: id between lower and upper bounds"}, - {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMaxValue)), true, - "should read: id equal to upper bound"}, - {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMaxValue+1)), true, - "should read: id above upper bound"}, - {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMaxValue+6)), true, - "should read: id above upper bound"}, - {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue-25))), true, - "should read: id below lower bound"}, - {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), 
int32(IntMinValue-1))), true, - "should read: id below lower bound"}, - {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue))), true, - "should read: id equal to lower bound"}, - {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue-4))), true, - "should read: id between lower and upper bounds"}, - {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue))), true, - "should read: id equal to upper bound"}, - {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue+1))), true, - "should read: id above upper bound"}, - {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue+6))), true, - "should read: id above upper bound"}, - {iceberg.IsIn(iceberg.Reference("id"), int32(IntMinValue-25), IntMinValue-24), false, - "should not read: id below lower bound (5 < 30, 6 < 30)"}, - {iceberg.IsIn(iceberg.Reference("id"), int32(IntMinValue-2), IntMinValue-1), false, - "should not read: id below lower bound (28 < 30, 29 < 30)"}, - {iceberg.IsIn(iceberg.Reference("id"), int32(IntMinValue-1), IntMinValue), true, - "should read: id equal to lower bound (30 == 30)"}, - {iceberg.IsIn(iceberg.Reference("id"), int32(IntMaxValue-4), IntMaxValue-3), true, - "should read: id between lower and upper bounds (30 < 75 < 79, 30 < 76 < 79)"}, - {iceberg.IsIn(iceberg.Reference("id"), int32(IntMaxValue), IntMaxValue+1), true, - "should read: id equal to upper bound (79 == 79)"}, - {iceberg.IsIn(iceberg.Reference("id"), int32(IntMaxValue+1), IntMaxValue+2), false, - "should not read: id above upper bound (80 > 79, 81 > 79)"}, - {iceberg.IsIn(iceberg.Reference("id"), int32(IntMaxValue+6), IntMaxValue+7), false, - "should not read: id above upper bound (85 > 79, 86 > 79)"}, - {iceberg.IsIn(iceberg.Reference("all_nulls_missing_nan"), "abc", "def"), false, - "should skip: in on all nulls column"}, - {iceberg.IsIn(iceberg.Reference("some_nulls"), "abc", "def"), true, - "should read: in on some nulls 
column"}, - {iceberg.IsIn(iceberg.Reference("no_nulls"), "abc", "def"), true, - "should read: in on no nulls column"}, - {iceberg.IsIn(iceberg.Reference("no_nulls_same_value_a"), "a", "b"), true, - "should read: in on no nulls column"}, - {iceberg.IsIn(iceberg.Reference("float"), 0, -5.5), true, - "should read: float equal to lower bound"}, - {iceberg.IsIn(iceberg.Reference("no_nan_or_null"), 0, -5.5), true, - "should read: float equal to lower bound"}, - {iceberg.NotIn(iceberg.Reference("id"), int32(IntMinValue-25), IntMinValue-24), true, - "should read: id below lower bound (5 < 30, 6 < 30)"}, - {iceberg.NotIn(iceberg.Reference("id"), int32(IntMinValue-2), IntMinValue-1), true, - "should read: id below lower bound (28 < 30, 29 < 30)"}, - {iceberg.NotIn(iceberg.Reference("id"), int32(IntMinValue-1), IntMinValue), true, - "should read: id equal to lower bound (30 == 30)"}, - {iceberg.NotIn(iceberg.Reference("id"), int32(IntMaxValue-4), IntMaxValue-3), true, - "should read: id between lower and upper bounds (30 < 75 < 79, 30 < 76 < 79)"}, - {iceberg.NotIn(iceberg.Reference("id"), int32(IntMaxValue), IntMaxValue+1), true, - "should read: id equal to upper bound (79 == 79)"}, - {iceberg.NotIn(iceberg.Reference("id"), int32(IntMaxValue+1), IntMaxValue+2), true, - "should read: id above upper bound (80 > 79, 81 > 79)"}, - {iceberg.NotIn(iceberg.Reference("id"), int32(IntMaxValue+6), IntMaxValue+7), true, - "should read: id above upper bound (85 > 79, 86 > 79)"}, - {iceberg.NotIn(iceberg.Reference("all_nulls_missing_nan"), "abc", "def"), true, - "should read: notIn on all nulls column"}, - {iceberg.NotIn(iceberg.Reference("some_nulls"), "abc", "def"), true, - "should read: notIn on some nulls column"}, - {iceberg.NotIn(iceberg.Reference("no_nulls"), "abc", "def"), true, - "should read: notIn on no nulls column"}, - {iceberg.StartsWith(iceberg.Reference("some_nulls"), "a"), true, - "should read: range matches"}, - {iceberg.StartsWith(iceberg.Reference("some_nulls"), 
"aa"), true, - "should read: range matches"}, - {iceberg.StartsWith(iceberg.Reference("some_nulls"), "dddd"), true, - "should read: range matches"}, - {iceberg.StartsWith(iceberg.Reference("some_nulls"), "z"), true, - "should read: range matches"}, - {iceberg.StartsWith(iceberg.Reference("no_nulls"), "a"), true, - "should read: range matches"}, - {iceberg.StartsWith(iceberg.Reference("some_nulls"), "zzzz"), false, - "should skip: range doesn't match"}, - {iceberg.StartsWith(iceberg.Reference("some_nulls"), "1"), false, - "should skip: range doesn't match"}, - {iceberg.StartsWith(iceberg.Reference("no_nulls_same_value_a"), "a"), true, - "should read: all values start with the prefix"}, - {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "a"), true, - "should read: range matches"}, - {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "aa"), true, - "should read: range matches"}, - {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "dddd"), true, - "should read: range matches"}, - {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "z"), true, - "should read: range matches"}, - {iceberg.NotStartsWith(iceberg.Reference("no_nulls"), "a"), true, - "should read: range matches"}, - {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "zzzz"), true, - "should read: range matches"}, - {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "1"), true, - "should read: range matches"}, - {iceberg.NotStartsWith(iceberg.Reference("all_same_value_or_null"), "a"), true, - "should read: range matches"}, - {iceberg.NotStartsWith(iceberg.Reference("all_same_value_or_null"), "aa"), true, - "should read: range matches"}, - {iceberg.NotStartsWith(iceberg.Reference("all_same_value_or_null"), "A"), true, - "should read: range matches"}, - // Iceberg does not implement SQL 3-way boolean logic, so the choice of an - // all null column matching is by definition in order to surface more values - // to the query engine to allow it to make its own decision - 
{iceberg.NotStartsWith(iceberg.Reference("all_nulls_missing_nan"), "A"), true, - "should read: range matches"}, - {iceberg.NotStartsWith(iceberg.Reference("no_nulls_same_value_a"), "a"), false, - "should not read: all values start with the prefix"}, - } - - for _, tt := range tests { - t.Run(tt.expr.String(), func(t *testing.T) { - eval, err := newManifestEvaluator(spec, testSchema, - tt.expr, true) - require.NoError(t, err) - - result, err := eval(manifest) - require.NoError(t, err) - assert.Equal(t, tt.expect, result, tt.msg) - }) - } - }) -} - -type ProjectionTestSuite struct { - suite.Suite -} - -func (*ProjectionTestSuite) schema() *iceberg.Schema { - return iceberg.NewSchema(0, - iceberg.NestedField{ID: 1, Name: "id", Type: iceberg.PrimitiveTypes.Int64}, - iceberg.NestedField{ID: 2, Name: "data", Type: iceberg.PrimitiveTypes.String}, - iceberg.NestedField{ID: 3, Name: "event_date", Type: iceberg.PrimitiveTypes.Date}, - iceberg.NestedField{ID: 4, Name: "event_ts", Type: iceberg.PrimitiveTypes.Timestamp}, - ) -} - -func (*ProjectionTestSuite) emptySpec() iceberg.PartitionSpec { - return iceberg.NewPartitionSpec() -} - -func (*ProjectionTestSuite) idSpec() iceberg.PartitionSpec { - return iceberg.NewPartitionSpec( - iceberg.PartitionField{SourceID: 1, FieldID: 1000, - Transform: iceberg.IdentityTransform{}, Name: "id_part"}, - ) -} - -func (*ProjectionTestSuite) bucketSpec() iceberg.PartitionSpec { - return iceberg.NewPartitionSpec( - iceberg.PartitionField{SourceID: 2, FieldID: 1000, - Transform: iceberg.BucketTransform{NumBuckets: 16}, Name: "data_bucket"}, - ) -} - -func (*ProjectionTestSuite) daySpec() iceberg.PartitionSpec { - return iceberg.NewPartitionSpec( - iceberg.PartitionField{SourceID: 4, FieldID: 1000, - Transform: iceberg.DayTransform{}, Name: "date"}, - iceberg.PartitionField{SourceID: 3, FieldID: 1001, - Transform: iceberg.DayTransform{}, Name: "ddate"}, - ) -} - -func (*ProjectionTestSuite) hourSpec() iceberg.PartitionSpec { - return 
iceberg.NewPartitionSpec( - iceberg.PartitionField{SourceID: 4, FieldID: 1000, - Transform: iceberg.HourTransform{}, Name: "hour"}, - ) -} - -func (*ProjectionTestSuite) truncateStrSpec() iceberg.PartitionSpec { - return iceberg.NewPartitionSpec( - iceberg.PartitionField{SourceID: 2, FieldID: 1000, - Transform: iceberg.TruncateTransform{Width: 2}, Name: "data_trunc"}, - ) -} - -func (*ProjectionTestSuite) truncateIntSpec() iceberg.PartitionSpec { - return iceberg.NewPartitionSpec( - iceberg.PartitionField{SourceID: 1, FieldID: 1000, - Transform: iceberg.TruncateTransform{Width: 10}, Name: "id_trunc"}, - ) -} - -func (*ProjectionTestSuite) idAndBucketSpec() iceberg.PartitionSpec { - return iceberg.NewPartitionSpec( - iceberg.PartitionField{SourceID: 1, FieldID: 1000, - Transform: iceberg.IdentityTransform{}, Name: "id_part"}, - iceberg.PartitionField{SourceID: 2, FieldID: 1001, - Transform: iceberg.BucketTransform{NumBuckets: 16}, Name: "data_bucket"}, - ) -} - -func (p *ProjectionTestSuite) TestIdentityProjection() { - schema, spec := p.schema(), p.idSpec() - - idRef, idPartRef := iceberg.Reference("id"), iceberg.Reference("id_part") - tests := []struct { - pred iceberg.BooleanExpression - expected iceberg.BooleanExpression - }{ - {iceberg.NotNull(idRef), iceberg.NotNull(idPartRef)}, - {iceberg.IsNull(idRef), iceberg.IsNull(idPartRef)}, - {iceberg.LessThan(idRef, int64(100)), iceberg.LessThan(idPartRef, int64(100))}, - {iceberg.LessThanEqual(idRef, int64(101)), iceberg.LessThanEqual(idPartRef, int64(101))}, - {iceberg.GreaterThan(idRef, int64(102)), iceberg.GreaterThan(idPartRef, int64(102))}, - {iceberg.GreaterThanEqual(idRef, int64(103)), iceberg.GreaterThanEqual(idPartRef, int64(103))}, - {iceberg.EqualTo(idRef, int64(104)), iceberg.EqualTo(idPartRef, int64(104))}, - {iceberg.NotEqualTo(idRef, int64(105)), iceberg.NotEqualTo(idPartRef, int64(105))}, - {iceberg.IsIn(idRef, int64(3), 4, 5), iceberg.IsIn(idPartRef, int64(3), 4, 5)}, - {iceberg.NotIn(idRef, 
int64(3), 4, 5), iceberg.NotIn(idPartRef, int64(3), 4, 5)}, - } - - project := newInclusiveProjection(schema, spec, true) - for _, tt := range tests { - p.Run(tt.pred.String(), func() { - expr, err := project(tt.pred) - p.Require().NoError(err) - p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) - }) - } -} - -func (p *ProjectionTestSuite) TestBucketProjection() { - schema, spec := p.schema(), p.bucketSpec() - - dataRef, dataBkt := iceberg.Reference("data"), iceberg.Reference("data_bucket") - tests := []struct { - pred, expected iceberg.BooleanExpression - }{ - {iceberg.NotNull(dataRef), iceberg.NotNull(dataBkt)}, - {iceberg.IsNull(dataRef), iceberg.IsNull(dataBkt)}, - {iceberg.LessThan(dataRef, "val"), iceberg.AlwaysTrue{}}, - {iceberg.LessThanEqual(dataRef, "val"), iceberg.AlwaysTrue{}}, - {iceberg.GreaterThan(dataRef, "val"), iceberg.AlwaysTrue{}}, - {iceberg.GreaterThanEqual(dataRef, "val"), iceberg.AlwaysTrue{}}, - {iceberg.EqualTo(dataRef, "val"), iceberg.EqualTo(dataBkt, int32(14))}, - {iceberg.NotEqualTo(dataRef, "val"), iceberg.AlwaysTrue{}}, - {iceberg.IsIn(dataRef, "v1", "v2", "v3"), iceberg.IsIn(dataBkt, int32(1), 3, 13)}, - {iceberg.NotIn(dataRef, "v1", "v2", "v3"), iceberg.AlwaysTrue{}}, - } - - project := newInclusiveProjection(schema, spec, true) - for _, tt := range tests { - p.Run(tt.pred.String(), func() { - expr, err := project(tt.pred) - p.Require().NoError(err) - p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) - }) - } -} - -func (p *ProjectionTestSuite) TestHourProjection() { - schema, spec := p.schema(), p.hourSpec() - - ref, hour := iceberg.Reference("event_ts"), iceberg.Reference("hour") - tests := []struct { - pred, expected iceberg.BooleanExpression - }{ - {iceberg.NotNull(ref), iceberg.NotNull(hour)}, - {iceberg.IsNull(ref), iceberg.IsNull(hour)}, - {iceberg.LessThan(ref, "2022-11-27T10:00:00"), iceberg.LessThanEqual(hour, int32(463761))}, - {iceberg.LessThanEqual(ref, 
"2022-11-27T10:00:00"), iceberg.LessThanEqual(hour, int32(463762))}, - {iceberg.GreaterThan(ref, "2022-11-27T09:59:59.999999"), iceberg.GreaterThanEqual(hour, int32(463762))}, - {iceberg.GreaterThanEqual(ref, "2022-11-27T09:59:59.999999"), iceberg.GreaterThanEqual(hour, int32(463761))}, - {iceberg.EqualTo(ref, "2022-11-27T10:00:00"), iceberg.EqualTo(hour, int32(463762))}, - {iceberg.NotEqualTo(ref, "2022-11-27T10:00:00"), iceberg.AlwaysTrue{}}, - {iceberg.IsIn(ref, "2022-11-27T10:00:00", "2022-11-27T09:59:59.999999"), iceberg.IsIn(hour, int32(463761), 463762)}, - {iceberg.NotIn(ref, "2022-11-27T10:00:00", "2022-11-27T09:59:59.999999"), iceberg.AlwaysTrue{}}, - } - - project := newInclusiveProjection(schema, spec, true) - for _, tt := range tests { - p.Run(tt.pred.String(), func() { - expr, err := project(tt.pred) - p.Require().NoError(err) - p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) - }) - } -} - -func (p *ProjectionTestSuite) TestDayProjection() { - schema, spec := p.schema(), p.daySpec() - - ref, date := iceberg.Reference("event_ts"), iceberg.Reference("date") - tests := []struct { - pred, expected iceberg.BooleanExpression - }{ - {iceberg.NotNull(ref), iceberg.NotNull(date)}, - {iceberg.IsNull(ref), iceberg.IsNull(date)}, - {iceberg.LessThan(ref, "2022-11-27T00:00:00"), iceberg.LessThanEqual(date, int32(19322))}, - {iceberg.LessThanEqual(ref, "2022-11-27T00:00:00"), iceberg.LessThanEqual(date, int32(19323))}, - {iceberg.GreaterThan(ref, "2022-11-26T23:59:59.999999"), iceberg.GreaterThanEqual(date, int32(19323))}, - {iceberg.GreaterThanEqual(ref, "2022-11-26T23:59:59.999999"), iceberg.GreaterThanEqual(date, int32(19322))}, - {iceberg.EqualTo(ref, "2022-11-27T10:00:00"), iceberg.EqualTo(date, int32(19323))}, - {iceberg.NotEqualTo(ref, "2022-11-27T10:00:00"), iceberg.AlwaysTrue{}}, - {iceberg.IsIn(ref, "2022-11-27T00:00:00", "2022-11-26T23:59:59.999999"), iceberg.IsIn(date, int32(19322), 19323)}, - {iceberg.NotIn(ref, 
"2022-11-27T00:00:00", "2022-11-26T23:59:59.999999"), iceberg.AlwaysTrue{}}, - } - - project := newInclusiveProjection(schema, spec, true) - for _, tt := range tests { - p.Run(tt.pred.String(), func() { - expr, err := project(tt.pred) - p.Require().NoError(err) - p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) - }) - } -} - -func (p *ProjectionTestSuite) TestDateDayProjection() { - schema, spec := p.schema(), p.daySpec() - - ref, date := iceberg.Reference("event_date"), iceberg.Reference("ddate") - tests := []struct { - pred, expected iceberg.BooleanExpression - }{ - {iceberg.NotNull(ref), iceberg.NotNull(date)}, - {iceberg.IsNull(ref), iceberg.IsNull(date)}, - {iceberg.LessThan(ref, "2022-11-27"), iceberg.LessThanEqual(date, int32(19322))}, - {iceberg.LessThanEqual(ref, "2022-11-27"), iceberg.LessThanEqual(date, int32(19323))}, - {iceberg.GreaterThan(ref, "2022-11-26"), iceberg.GreaterThanEqual(date, int32(19323))}, - {iceberg.GreaterThanEqual(ref, "2022-11-26"), iceberg.GreaterThanEqual(date, int32(19322))}, - {iceberg.EqualTo(ref, "2022-11-27"), iceberg.EqualTo(date, int32(19323))}, - {iceberg.NotEqualTo(ref, "2022-11-27"), iceberg.AlwaysTrue{}}, - {iceberg.IsIn(ref, "2022-11-27", "2022-11-26"), iceberg.IsIn(date, int32(19322), 19323)}, - {iceberg.NotIn(ref, "2022-11-27", "2022-11-26"), iceberg.AlwaysTrue{}}, - } - - project := newInclusiveProjection(schema, spec, true) - for _, tt := range tests { - p.Run(tt.pred.String(), func() { - expr, err := project(tt.pred) - p.Require().NoError(err) - p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) - }) - } -} - -func (p *ProjectionTestSuite) TestStringTruncateProjection() { - schema, spec := p.schema(), p.truncateStrSpec() - - ref, truncStr := iceberg.Reference("data"), iceberg.Reference("data_trunc") - tests := []struct { - pred, expected iceberg.BooleanExpression - }{ - {iceberg.NotNull(ref), iceberg.NotNull(truncStr)}, - {iceberg.IsNull(ref), 
iceberg.IsNull(truncStr)}, - {iceberg.LessThan(ref, "aaa"), iceberg.LessThanEqual(truncStr, "aa")}, - {iceberg.LessThanEqual(ref, "aaa"), iceberg.LessThanEqual(truncStr, "aa")}, - {iceberg.GreaterThan(ref, "aaa"), iceberg.GreaterThanEqual(truncStr, "aa")}, - {iceberg.GreaterThanEqual(ref, "aaa"), iceberg.GreaterThanEqual(truncStr, "aa")}, - {iceberg.EqualTo(ref, "aaa"), iceberg.EqualTo(truncStr, "aa")}, - {iceberg.NotEqualTo(ref, "aaa"), iceberg.AlwaysTrue{}}, - {iceberg.IsIn(ref, "aaa", "aab"), iceberg.EqualTo(truncStr, "aa")}, - {iceberg.NotIn(ref, "aaa", "aab"), iceberg.AlwaysTrue{}}, - } - - project := newInclusiveProjection(schema, spec, true) - for _, tt := range tests { - p.Run(tt.pred.String(), func() { - expr, err := project(tt.pred) - p.Require().NoError(err) - p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) - }) - } -} - -func (p *ProjectionTestSuite) TestIntTruncateProjection() { - schema, spec := p.schema(), p.truncateIntSpec() - - ref, idTrunc := iceberg.Reference("id"), iceberg.Reference("id_trunc") - tests := []struct { - pred, expected iceberg.BooleanExpression - }{ - {iceberg.NotNull(ref), iceberg.NotNull(idTrunc)}, - {iceberg.IsNull(ref), iceberg.IsNull(idTrunc)}, - {iceberg.LessThan(ref, int32(10)), iceberg.LessThanEqual(idTrunc, int64(0))}, - {iceberg.LessThanEqual(ref, int32(10)), iceberg.LessThanEqual(idTrunc, int64(10))}, - {iceberg.GreaterThan(ref, int32(9)), iceberg.GreaterThanEqual(idTrunc, int64(10))}, - {iceberg.GreaterThanEqual(ref, int32(10)), iceberg.GreaterThanEqual(idTrunc, int64(10))}, - {iceberg.EqualTo(ref, int32(15)), iceberg.EqualTo(idTrunc, int64(10))}, - {iceberg.NotEqualTo(ref, int32(15)), iceberg.AlwaysTrue{}}, - {iceberg.IsIn(ref, int32(15), 16), iceberg.EqualTo(idTrunc, int64(10))}, - {iceberg.NotIn(ref, int32(15), 16), iceberg.AlwaysTrue{}}, - } - - project := newInclusiveProjection(schema, spec, true) - for _, tt := range tests { - p.Run(tt.pred.String(), func() { - expr, err := 
project(tt.pred) - p.Require().NoError(err) - p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) - }) - } -} - -func (p *ProjectionTestSuite) TestProjectionCaseSensitive() { - schema, spec := p.schema(), p.idSpec() - project := newInclusiveProjection(schema, spec, true) - _, err := project(iceberg.NotNull(iceberg.Reference("ID"))) - p.ErrorIs(err, iceberg.ErrInvalidSchema) - p.ErrorContains(err, "could not bind reference 'ID', caseSensitive=true") -} - -func (p *ProjectionTestSuite) TestProjectionCaseInsensitive() { - schema, spec := p.schema(), p.idSpec() - project := newInclusiveProjection(schema, spec, false) - expr, err := project(iceberg.NotNull(iceberg.Reference("ID"))) - p.Require().NoError(err) - p.True(expr.Equals(iceberg.NotNull(iceberg.Reference("id_part")))) -} - -func (p *ProjectionTestSuite) TestProjectEmptySpec() { - project := newInclusiveProjection(p.schema(), p.emptySpec(), true) - expr, err := project(iceberg.NewAnd(iceberg.LessThan(iceberg.Reference("id"), int32(5)), - iceberg.NotNull(iceberg.Reference("data")))) - p.Require().NoError(err) - p.Equal(iceberg.AlwaysTrue{}, expr) -} - -func (p *ProjectionTestSuite) TestAndProjectionMultipleFields() { - project := newInclusiveProjection(p.schema(), p.idAndBucketSpec(), true) - expr, err := project(iceberg.NewAnd(iceberg.LessThan(iceberg.Reference("id"), - int32(5)), iceberg.IsIn(iceberg.Reference("data"), "a", "b", "c"))) - p.Require().NoError(err) - - p.True(expr.Equals(iceberg.NewAnd(iceberg.LessThan(iceberg.Reference("id_part"), int64(5)), - iceberg.IsIn(iceberg.Reference("data_bucket"), int32(2), 3, 15)))) -} - -func (p *ProjectionTestSuite) TestOrProjectionMultipleFields() { - project := newInclusiveProjection(p.schema(), p.idAndBucketSpec(), true) - expr, err := project(iceberg.NewOr(iceberg.LessThan(iceberg.Reference("id"), int32(5)), - iceberg.IsIn(iceberg.Reference("data"), "a", "b", "c"))) - p.Require().NoError(err) - - 
p.True(expr.Equals(iceberg.NewOr(iceberg.LessThan(iceberg.Reference("id_part"), int64(5)), - iceberg.IsIn(iceberg.Reference("data_bucket"), int32(2), 3, 15)))) -} - -func (p *ProjectionTestSuite) TestNotProjectionMultipleFields() { - project := newInclusiveProjection(p.schema(), p.idAndBucketSpec(), true) - // not causes In to be rewritten to NotIn, which cannot be projected - expr, err := project(iceberg.NewNot(iceberg.NewOr(iceberg.LessThan(iceberg.Reference("id"), int64(5)), - iceberg.IsIn(iceberg.Reference("data"), "a", "b", "c")))) - p.Require().NoError(err) - - p.True(expr.Equals(iceberg.GreaterThanEqual(iceberg.Reference("id_part"), int64(5)))) -} - -func (p *ProjectionTestSuite) TestPartialProjectedFields() { - project := newInclusiveProjection(p.schema(), p.idSpec(), true) - expr, err := project(iceberg.NewAnd(iceberg.LessThan(iceberg.Reference("id"), int32(5)), - iceberg.IsIn(iceberg.Reference("data"), "a", "b", "c"))) - p.Require().NoError(err) - p.True(expr.Equals(iceberg.LessThan(iceberg.Reference("id_part"), int64(5)))) -} - -type mockDataFile struct { - path string - format iceberg.FileFormat - partition map[string]any - count int64 - columnSizes map[int]int64 - filesize int64 - valueCounts map[int]int64 - nullCounts map[int]int64 - nanCounts map[int]int64 - lowerBounds map[int][]byte - upperBounds map[int][]byte -} - -func (*mockDataFile) ContentType() iceberg.ManifestEntryContent { return iceberg.EntryContentData } -func (m *mockDataFile) FilePath() string { return m.path } -func (m *mockDataFile) FileFormat() iceberg.FileFormat { return m.format } -func (m *mockDataFile) Partition() map[string]any { return m.partition } -func (m *mockDataFile) Count() int64 { return m.count } -func (m *mockDataFile) FileSizeBytes() int64 { return m.filesize } -func (m *mockDataFile) ColumnSizes() map[int]int64 { return m.columnSizes } -func (m *mockDataFile) ValueCounts() map[int]int64 { return m.valueCounts } -func (m *mockDataFile) NullValueCounts() 
map[int]int64 { return m.nullCounts } -func (m *mockDataFile) NaNValueCounts() map[int]int64 { return m.nanCounts } -func (*mockDataFile) DistinctValueCounts() map[int]int64 { return nil } -func (m *mockDataFile) LowerBoundValues() map[int][]byte { return m.lowerBounds } -func (m *mockDataFile) UpperBoundValues() map[int][]byte { return m.upperBounds } -func (*mockDataFile) KeyMetadata() []byte { return nil } -func (*mockDataFile) SplitOffsets() []int64 { return nil } -func (*mockDataFile) EqualityFieldIDs() []int { return nil } -func (*mockDataFile) SortOrderID() *int { return nil } - -type InclusiveMetricsTestSuite struct { - suite.Suite - - schemaDataFile *iceberg.Schema - dataFiles [4]iceberg.DataFile - - schemaDataFileNan *iceberg.Schema - dataFileNan iceberg.DataFile -} - -func (suite *InclusiveMetricsTestSuite) SetupSuite() { - suite.schemaDataFile = iceberg.NewSchema(0, - iceberg.NestedField{ID: 1, Name: "id", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - iceberg.NestedField{ID: 2, Name: "no_stats", Type: iceberg.PrimitiveTypes.Int32, Required: false}, - iceberg.NestedField{ID: 3, Name: "required", Type: iceberg.PrimitiveTypes.String, Required: true}, - iceberg.NestedField{ID: 4, Name: "all_nulls", Type: iceberg.PrimitiveTypes.String}, - iceberg.NestedField{ID: 5, Name: "some_nulls", Type: iceberg.PrimitiveTypes.String}, - iceberg.NestedField{ID: 6, Name: "no_nulls", Type: iceberg.PrimitiveTypes.String}, - iceberg.NestedField{ID: 7, Name: "all_nans", Type: iceberg.PrimitiveTypes.Float64}, - iceberg.NestedField{ID: 8, Name: "some_nans", Type: iceberg.PrimitiveTypes.Float32}, - iceberg.NestedField{ID: 9, Name: "no_nans", Type: iceberg.PrimitiveTypes.Float32}, - iceberg.NestedField{ID: 10, Name: "all_nulls_double", Type: iceberg.PrimitiveTypes.Float64}, - iceberg.NestedField{ID: 11, Name: "all_nans_v1_stats", Type: iceberg.PrimitiveTypes.Float32}, - iceberg.NestedField{ID: 12, Name: "nan_and_null_only", Type: iceberg.PrimitiveTypes.Float64}, - 
iceberg.NestedField{ID: 13, Name: "no_nan_stats", Type: iceberg.PrimitiveTypes.Float64}, - iceberg.NestedField{ID: 14, Name: "some_empty", Type: iceberg.PrimitiveTypes.String}, - ) - - var ( - IntMin, _ = iceberg.Int32Literal(IntMinValue).MarshalBinary() - IntMax, _ = iceberg.Int32Literal(IntMaxValue).MarshalBinary() - FltNan, _ = iceberg.Float32Literal(float32(math.NaN())).MarshalBinary() - DblNan, _ = iceberg.Float64Literal(math.NaN()).MarshalBinary() - FltSeven, _ = iceberg.Float32Literal(7).MarshalBinary() - DblSeven, _ = iceberg.Float64Literal(7).MarshalBinary() - FltMax, _ = iceberg.Float32Literal(22).MarshalBinary() - ) - - suite.dataFiles = [4]iceberg.DataFile{ - &mockDataFile{ - path: "file_1.parquet", - format: iceberg.ParquetFile, - count: 50, - filesize: 3, - valueCounts: map[int]int64{ - 4: 50, 5: 50, 6: 50, 7: 50, 8: 50, 9: 50, - 10: 50, 11: 50, 12: 50, 13: 50, 14: 50, - }, - nullCounts: map[int]int64{4: 50, 5: 10, 6: 0, 10: 50, 11: 0, 12: 1, 14: 8}, - nanCounts: map[int]int64{7: 50, 8: 10, 9: 0}, - lowerBounds: map[int][]byte{ - 1: IntMin, - 11: FltNan, - 12: DblNan, - 14: {}, - }, - upperBounds: map[int][]byte{ - 1: IntMax, - 11: FltNan, - 12: DblNan, - 14: []byte("房东整租霍营小区二层两居室"), - }, - }, - &mockDataFile{ - path: "file_2.parquet", - format: iceberg.ParquetFile, - count: 50, - filesize: 3, - valueCounts: map[int]int64{3: 20}, - nullCounts: map[int]int64{3: 2}, - nanCounts: nil, - lowerBounds: map[int][]byte{3: {'a', 'a'}}, - upperBounds: map[int][]byte{3: {'d', 'C'}}, - }, - &mockDataFile{ - path: "file_3.parquet", - format: iceberg.ParquetFile, - count: 50, - filesize: 3, - valueCounts: map[int]int64{3: 20}, - nullCounts: map[int]int64{3: 2}, - nanCounts: nil, - lowerBounds: map[int][]byte{3: []byte("1str1")}, - upperBounds: map[int][]byte{3: []byte("3str3")}, - }, - &mockDataFile{ - path: "file_4.parquet", - format: iceberg.ParquetFile, - count: 50, - filesize: 3, - valueCounts: map[int]int64{3: 20}, - nullCounts: map[int]int64{3: 2}, - 
nanCounts: nil, - lowerBounds: map[int][]byte{3: []byte("abc")}, - upperBounds: map[int][]byte{3: []byte("イロハニホヘト")}, - }, - } - - suite.schemaDataFileNan = iceberg.NewSchema(0, - iceberg.NestedField{ID: 1, Name: "all_nan", Type: iceberg.PrimitiveTypes.Float64, Required: true}, - iceberg.NestedField{ID: 2, Name: "max_nan", Type: iceberg.PrimitiveTypes.Float64, Required: true}, - iceberg.NestedField{ID: 3, Name: "min_max_nan", Type: iceberg.PrimitiveTypes.Float32}, - iceberg.NestedField{ID: 4, Name: "all_nan_null_bounds", Type: iceberg.PrimitiveTypes.Float64, Required: true}, - iceberg.NestedField{ID: 5, Name: "some_nan_correct_bounds", Type: iceberg.PrimitiveTypes.Float32}, - ) - - suite.dataFileNan = &mockDataFile{ - path: "file.avro", - format: iceberg.AvroFile, - count: 50, - filesize: 3, - columnSizes: map[int]int64{1: 10, 2: 10, 3: 10, 4: 10, 5: 10}, - valueCounts: map[int]int64{1: 10, 2: 10, 3: 10, 4: 10, 5: 10}, - nullCounts: map[int]int64{1: 0, 2: 0, 3: 0, 4: 0, 5: 0}, - nanCounts: map[int]int64{1: 10, 4: 10, 5: 5}, - lowerBounds: map[int][]byte{ - 1: DblNan, - 2: DblSeven, - 3: FltNan, - 5: FltSeven, - }, - upperBounds: map[int][]byte{ - 1: DblNan, - 2: DblNan, - 3: FltNan, - 5: FltMax, - }, - } -} - -func (suite *InclusiveMetricsTestSuite) TestAllNull() { - allNull, someNull, noNull := iceberg.Reference("all_nulls"), iceberg.Reference("some_nulls"), iceberg.Reference("no_nulls") - - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.NotNull(allNull), false, "should skip: no non-null value in all null column"}, - {iceberg.LessThan(allNull, "a"), false, "should skip: lessThan on all null column"}, - {iceberg.LessThanEqual(allNull, "a"), false, "should skip: lessThanEqual on all null column"}, - {iceberg.GreaterThan(allNull, "a"), false, "should skip: greaterThan on all null column"}, - {iceberg.GreaterThanEqual(allNull, "a"), false, "should skip: greaterThanEqual on all null column"}, - 
{iceberg.EqualTo(allNull, "a"), false, "should skip: equal on all null column"}, - {iceberg.NotNull(someNull), true, "should read: column with some nulls contains a non-null value"}, - {iceberg.NotNull(noNull), true, "should read: non-null column contains a non-null value"}, - {iceberg.StartsWith(allNull, "asad"), false, "should skip: starts with on all null column"}, - {iceberg.NotStartsWith(allNull, "asad"), true, "should read: notStartsWith on all null column"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestNoNulls() { - allNull, someNull, noNull := iceberg.Reference("all_nulls"), iceberg.Reference("some_nulls"), iceberg.Reference("no_nulls") - - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.IsNull(allNull), true, "should read: at least one null value in all null column"}, - {iceberg.IsNull(someNull), true, "should read: column with some nulls contains a null value"}, - {iceberg.IsNull(noNull), false, "should skip: non-null column contains no null values"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestIsNan() { - allNan, someNan, noNan := iceberg.Reference("all_nans"), iceberg.Reference("some_nans"), iceberg.Reference("no_nans") - allNullsDbl, noNanStats := iceberg.Reference("all_nulls_double"), iceberg.Reference("no_nan_stats") - allNansV1, nanNullOnly := 
iceberg.Reference("all_nans_v1_stats"), iceberg.Reference("nan_and_null_only") - - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.IsNaN(allNan), true, "should read: at least one nan value in all nan column"}, - {iceberg.IsNaN(someNan), true, "should read: at least one nan value in some nan column"}, - {iceberg.IsNaN(noNan), false, "should skip: no-nans column has no nans"}, - {iceberg.IsNaN(allNullsDbl), false, "should skip: all-null column doesn't contain nan values"}, - {iceberg.IsNaN(noNanStats), true, "should read: no guarantee if contains nan without stats"}, - {iceberg.IsNaN(allNansV1), true, "should read: at least one nan value in all nan column"}, - {iceberg.IsNaN(nanNullOnly), true, "should read: at least one nan value in nan and nulls only column"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestNotNaN() { - allNan, someNan, noNan := iceberg.Reference("all_nans"), iceberg.Reference("some_nans"), iceberg.Reference("no_nans") - allNullsDbl, noNanStats := iceberg.Reference("all_nulls_double"), iceberg.Reference("no_nan_stats") - allNansV1, nanNullOnly := iceberg.Reference("all_nans_v1_stats"), iceberg.Reference("nan_and_null_only") - - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.NotNaN(allNan), false, "should skip: column with all nans will not contain non-nan"}, - {iceberg.NotNaN(someNan), true, "should read: at least one non-nan value in some nan column"}, - {iceberg.NotNaN(noNan), true, "should read: at least one non-nan value in no nan column"}, - {iceberg.NotNaN(allNullsDbl), true, "should read: at least one non-nan value in all 
null column"}, - {iceberg.NotNaN(noNanStats), true, "should read: no guarantee if contains nan without stats"}, - {iceberg.NotNaN(allNansV1), true, "should read: no guarantee"}, - {iceberg.NotNaN(nanNullOnly), true, "should read: at least one null value in nan and nulls only column"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestRequiredColumn() { - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.NotNull(iceberg.Reference("required")), true, "should read: required columns are always non-null"}, - {iceberg.IsNull(iceberg.Reference("required")), false, "should skip: required columns are always non-null"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestMissingColumn() { - _, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, iceberg.LessThan(iceberg.Reference("missing"), int32(22)), true, true) - suite.ErrorIs(err, iceberg.ErrInvalidSchema) -} - -func (suite *InclusiveMetricsTestSuite) TestMissingStats() { - noStatsSchema := iceberg.NewSchema(0, - iceberg.NestedField{ID: 2, Name: "no_stats", Type: iceberg.PrimitiveTypes.Float64}) - - noStatsFile := &mockDataFile{ - path: "file_1.parquet", - format: iceberg.ParquetFile, - count: 50, - } - - ref := iceberg.Reference("no_stats") - tests := []iceberg.BooleanExpression{ - iceberg.LessThan(ref, int32(5)), - 
iceberg.LessThanEqual(ref, int32(30)), - iceberg.EqualTo(ref, int32(70)), - iceberg.GreaterThan(ref, int32(78)), - iceberg.GreaterThanEqual(ref, int32(90)), - iceberg.NotEqualTo(ref, int32(101)), - iceberg.IsNull(ref), - iceberg.NotNull(ref), - iceberg.IsNaN(ref), - iceberg.NotNaN(ref), - } - - for _, tt := range tests { - suite.Run(tt.String(), func() { - eval, err := newInclusiveMetricsEvaluator(noStatsSchema, tt, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(noStatsFile) - suite.Require().NoError(err) - suite.True(shouldRead, "should read when stats are missing") - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestZeroRecordFileStats() { - zeroRecordFile := &mockDataFile{ - path: "file_1.parquet", - format: iceberg.ParquetFile, - count: 0, - } - - ref := iceberg.Reference("no_stats") - tests := []iceberg.BooleanExpression{ - iceberg.LessThan(ref, int32(5)), - iceberg.LessThanEqual(ref, int32(30)), - iceberg.EqualTo(ref, int32(70)), - iceberg.GreaterThan(ref, int32(78)), - iceberg.GreaterThanEqual(ref, int32(90)), - iceberg.NotEqualTo(ref, int32(101)), - iceberg.IsNull(ref), - iceberg.NotNull(ref), - iceberg.IsNaN(ref), - iceberg.NotNaN(ref), - } - - for _, tt := range tests { - suite.Run(tt.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt, true, false) - suite.Require().NoError(err) - shouldRead, err := eval(zeroRecordFile) - suite.Require().NoError(err) - suite.False(shouldRead, "should skip datafile without records") - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestNot() { - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.NewNot(iceberg.LessThan(iceberg.Reference("id"), IntMinValue-25)), true, "should read: not(false)"}, - {iceberg.NewNot(iceberg.GreaterThan(iceberg.Reference("id"), IntMinValue-25)), false, "should skip: not(true)"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := 
newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestAnd() { - ref := iceberg.Reference("id") - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.NewAnd( - iceberg.LessThan(ref, IntMinValue-25), - iceberg.GreaterThanEqual(ref, IntMinValue-30)), false, "should skip: and(false, true)"}, - {iceberg.NewAnd( - iceberg.LessThan(ref, IntMinValue-25), - iceberg.GreaterThanEqual(ref, IntMinValue+1)), false, "should skip: and(false, false)"}, - {iceberg.NewAnd( - iceberg.GreaterThan(ref, IntMinValue-25), - iceberg.LessThanEqual(ref, IntMinValue)), true, "should read: and(true, true)"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestOr() { - ref := iceberg.Reference("id") - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.NewOr( - iceberg.LessThan(ref, IntMinValue-25), - iceberg.GreaterThanEqual(ref, IntMaxValue+1)), false, "should skip: or(false, false)"}, - {iceberg.NewOr( - iceberg.LessThan(ref, IntMinValue-25), - iceberg.GreaterThanEqual(ref, IntMaxValue-19)), true, "should read: or(false, true)"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - 
-func (suite *InclusiveMetricsTestSuite) TestIntLt() { - ref := iceberg.Reference("id") - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.LessThan(ref, IntMinValue-25), false, "should skip: id range below lower bound (5 < 30)"}, - {iceberg.LessThan(ref, IntMinValue), false, "should skip: id range below lower bound (30 is not < 30)"}, - {iceberg.LessThan(ref, IntMinValue+1), true, "should read: one possible id"}, - {iceberg.LessThan(ref, IntMaxValue), true, "should read: many possible ids"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestIntLtEq() { - ref := iceberg.Reference("id") - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.LessThanEqual(ref, IntMinValue-25), false, "should skip: id range below lower bound (5 < 30)"}, - {iceberg.LessThanEqual(ref, IntMinValue-1), false, "should skip: id range below lower bound (29 is not <= 30)"}, - {iceberg.LessThanEqual(ref, IntMinValue), true, "should read: one possible id"}, - {iceberg.LessThanEqual(ref, IntMaxValue), true, "should read: many possible ids"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestIntGt() { - ref := iceberg.Reference("id") - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.GreaterThan(ref, IntMaxValue+6), 
false, "should skip: id range above upper bound (85 > 79)"}, - {iceberg.GreaterThan(ref, IntMaxValue), false, "should skip: id range above upper bound (79 is not > 79)"}, - {iceberg.GreaterThan(ref, IntMinValue-1), true, "should read: one possible id"}, - {iceberg.GreaterThan(ref, IntMaxValue-4), true, "should read: many possible ids"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestIntGtEq() { - ref := iceberg.Reference("id") - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.GreaterThanEqual(ref, IntMaxValue+6), false, "should skip: id range above upper bound (85 < 79)"}, - {iceberg.GreaterThanEqual(ref, IntMaxValue+1), false, "should skip: id range above upper bound (80 > 79)"}, - {iceberg.GreaterThanEqual(ref, IntMaxValue), true, "should read: one possible id"}, - {iceberg.GreaterThanEqual(ref, IntMaxValue-4), true, "should read: many possible ids"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestIntEq() { - ref := iceberg.Reference("id") - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.EqualTo(ref, IntMinValue-25), false, "should skip: id range below lower bound"}, - {iceberg.EqualTo(ref, IntMinValue-1), false, "should skip: id range below lower bound"}, - {iceberg.EqualTo(ref, IntMinValue), true, "should read: id 
equal to lower bound"}, - {iceberg.EqualTo(ref, IntMaxValue-4), true, "should read: id between lower and upper bounds"}, - {iceberg.EqualTo(ref, IntMaxValue), true, "should read: id equal to upper bound"}, - {iceberg.EqualTo(ref, IntMaxValue+1), false, "should skip: id above upper bound"}, - {iceberg.EqualTo(ref, IntMaxValue+6), false, "should skip: id above upper bound"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestIntNeq() { - ref := iceberg.Reference("id") - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.NotEqualTo(ref, IntMinValue-25), true, "should read: id range below lower bound"}, - {iceberg.NotEqualTo(ref, IntMinValue-1), true, "should read: id range below lower bound"}, - {iceberg.NotEqualTo(ref, IntMinValue), true, "should read: id equal to lower bound"}, - {iceberg.NotEqualTo(ref, IntMaxValue-4), true, "should read: id between lower and upper bounds"}, - {iceberg.NotEqualTo(ref, IntMaxValue), true, "should read: id equal to upper bound"}, - {iceberg.NotEqualTo(ref, IntMaxValue+1), true, "should read: id above upper bound"}, - {iceberg.NotEqualTo(ref, IntMaxValue+6), true, "should read: id above upper bound"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestIntNeqRewritten() { - ref := iceberg.Reference("id") - tests := []struct { - expr 
iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.EqualTo(ref, IntMinValue-25), true, "should read: id range below lower bound"}, - {iceberg.EqualTo(ref, IntMinValue-1), true, "should read: id range below lower bound"}, - {iceberg.EqualTo(ref, IntMinValue), true, "should read: id equal to lower bound"}, - {iceberg.EqualTo(ref, IntMaxValue-4), true, "should read: id between lower and upper bounds"}, - {iceberg.EqualTo(ref, IntMaxValue), true, "should read: id equal to upper bound"}, - {iceberg.EqualTo(ref, IntMaxValue+1), true, "should read: id above upper bound"}, - {iceberg.EqualTo(ref, IntMaxValue+6), true, "should read: id above upper bound"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, iceberg.NewNot(tt.expr), true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestIntNeqRewrittenCaseInsensitive() { - ref := iceberg.Reference("ID") - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.EqualTo(ref, IntMinValue-25), true, "should read: id range below lower bound"}, - {iceberg.EqualTo(ref, IntMinValue-1), true, "should read: id range below lower bound"}, - {iceberg.EqualTo(ref, IntMinValue), true, "should read: id equal to lower bound"}, - {iceberg.EqualTo(ref, IntMaxValue-4), true, "should read: id between lower and upper bounds"}, - {iceberg.EqualTo(ref, IntMaxValue), true, "should read: id equal to upper bound"}, - {iceberg.EqualTo(ref, IntMaxValue+1), true, "should read: id above upper bound"}, - {iceberg.EqualTo(ref, IntMaxValue+6), true, "should read: id above upper bound"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, 
iceberg.NewNot(tt.expr), false, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestInMetrics() { - ref := iceberg.Reference("id") - - ids := make([]int32, 400) - for i := range ids { - ids[i] = int32(i) - } - - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.IsIn(ref, IntMinValue-25, IntMinValue-24), false, "should skip: id below lower bound"}, - {iceberg.IsIn(ref, IntMinValue-2, IntMinValue-1), false, "should skip: id below lower bound"}, - {iceberg.IsIn(ref, IntMinValue-1, IntMinValue), true, "should read: id equal to lower bound"}, - {iceberg.IsIn(ref, IntMaxValue-4, IntMaxValue-3), true, "should read: id between upper and lower bounds"}, - {iceberg.IsIn(ref, IntMaxValue, IntMaxValue+1), true, "should read: id equal to upper bound"}, - {iceberg.IsIn(ref, IntMaxValue+1, IntMaxValue+2), false, "should skip: id above upper bound"}, - {iceberg.IsIn(ref, IntMaxValue+6, IntMaxValue+7), false, "should skip: id above upper bound"}, - {iceberg.IsIn(iceberg.Reference("all_nulls"), "abc", "def"), false, "should skip: in on all nulls column"}, - {iceberg.IsIn(iceberg.Reference("some_nulls"), "abc", "def"), true, "should read: in on some nulls column"}, - {iceberg.IsIn(iceberg.Reference("no_nulls"), "abc", "def"), true, "should read: in on no nulls column"}, - {iceberg.IsIn(ref, ids...), true, "should read: large in expression"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestNotInMetrics() { - ref := iceberg.Reference("id") - - tests := 
[]struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.NotIn(ref, IntMinValue-25, IntMinValue-24), true, "should read: id below lower bound"}, - {iceberg.NotIn(ref, IntMinValue-2, IntMinValue-1), true, "should read: id below lower bound"}, - {iceberg.NotIn(ref, IntMinValue-1, IntMinValue), true, "should read: id equal to lower bound"}, - {iceberg.NotIn(ref, IntMaxValue-4, IntMaxValue-3), true, "should read: id between upper and lower bounds"}, - {iceberg.NotIn(ref, IntMaxValue, IntMaxValue+1), true, "should read: id equal to upper bound"}, - {iceberg.NotIn(ref, IntMaxValue+1, IntMaxValue+2), true, "should read: id above upper bound"}, - {iceberg.NotIn(ref, IntMaxValue+6, IntMaxValue+7), true, "should read: id above upper bound"}, - {iceberg.NotIn(iceberg.Reference("all_nulls"), "abc", "def"), true, "should read: in on all nulls column"}, - {iceberg.NotIn(iceberg.Reference("some_nulls"), "abc", "def"), true, "should read: in on some nulls column"}, - {iceberg.NotIn(iceberg.Reference("no_nulls"), "abc", "def"), true, "should read: in on no nulls column"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFiles[0]) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestLessAndLessEqualNans() { - type Op func(iceberg.UnboundTerm, int32) iceberg.UnboundPredicate - for _, operator := range []Op{iceberg.LessThan[int32], iceberg.LessThanEqual[int32]} { - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {operator(iceberg.Reference("all_nan"), int32(1)), false, "should skip: all nan column doesn't contain number"}, - {operator(iceberg.Reference("max_nan"), int32(1)), false, "should skip: 1 is smaller than lower bound"}, - 
{operator(iceberg.Reference("max_nan"), int32(10)), true, "should read: 10 is larger than lower bound"}, - {operator(iceberg.Reference("min_max_nan"), int32(1)), true, "should read: no visibility"}, - {operator(iceberg.Reference("all_nan_null_bounds"), int32(1)), false, "should skip: all nan column doesn't contain number"}, - {operator(iceberg.Reference("some_nan_correct_bounds"), int32(1)), false, "should skip: 1 is smaller than lower bound"}, - {operator(iceberg.Reference("some_nan_correct_bounds"), int32(10)), true, "should read: 10 is larger than lower bound"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFileNan) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } - } -} - -func (suite *InclusiveMetricsTestSuite) TestGreaterAndGreaterEqualNans() { - type Op func(iceberg.UnboundTerm, int32) iceberg.UnboundPredicate - for _, operator := range []Op{iceberg.GreaterThan[int32], iceberg.GreaterThanEqual[int32]} { - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {operator(iceberg.Reference("all_nan"), int32(1)), false, "should skip: all nan column doesn't contain number"}, - {operator(iceberg.Reference("max_nan"), int32(1)), true, "should read: upper bound is larger than 1"}, - {operator(iceberg.Reference("max_nan"), int32(10)), true, "should read: 10 is smaller than upper bound"}, - {operator(iceberg.Reference("min_max_nan"), int32(1)), true, "should read: no visibility"}, - {operator(iceberg.Reference("all_nan_null_bounds"), int32(1)), false, "should skip: all nan column doesn't contain number"}, - {operator(iceberg.Reference("some_nan_correct_bounds"), int32(1)), true, "should read: 1 is smaller than upper bound"}, - {operator(iceberg.Reference("some_nan_correct_bounds"), int32(10)), true, "should read: 10 
is smaller than upper bound"}, - {operator(iceberg.Reference("all_nan"), int32(30)), false, "should skip: 30 is larger than upper bound"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFileNan) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } - } -} - -func (suite *InclusiveMetricsTestSuite) TestEqualsNans() { - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.EqualTo(iceberg.Reference("all_nan"), int32(1)), false, "should skip: all nan column doesn't contain number"}, - {iceberg.EqualTo(iceberg.Reference("max_nan"), int32(1)), false, "should skip: 1 is smaller than lower bound"}, - {iceberg.EqualTo(iceberg.Reference("max_nan"), int32(10)), true, "should read: 10 is within bounds"}, - {iceberg.EqualTo(iceberg.Reference("min_max_nan"), int32(1)), true, "should read: no visibility"}, - {iceberg.EqualTo(iceberg.Reference("all_nan_null_bounds"), int32(1)), false, "should skip: all nan column doesn't contain number"}, - {iceberg.EqualTo(iceberg.Reference("some_nan_correct_bounds"), int32(1)), false, "should skip: 1 is smaller than lower bound"}, - {iceberg.EqualTo(iceberg.Reference("some_nan_correct_bounds"), int32(10)), true, "should read: 10 within bounds"}, - {iceberg.EqualTo(iceberg.Reference("all_nan"), int32(30)), false, "should skip: 30 is larger than upper bound"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFileNan) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestNotEqualsNans() { - tests := []struct { - expr 
iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.NotEqualTo(iceberg.Reference("all_nan"), int32(1)), true, "should read: no visibility"}, - {iceberg.NotEqualTo(iceberg.Reference("max_nan"), int32(1)), true, "should read: no visibility"}, - {iceberg.NotEqualTo(iceberg.Reference("max_nan"), int32(10)), true, "should read: no visibility"}, - {iceberg.NotEqualTo(iceberg.Reference("min_max_nan"), int32(1)), true, "should read: no visibility"}, - {iceberg.NotEqualTo(iceberg.Reference("all_nan_null_bounds"), int32(1)), true, "should read: no visibility"}, - {iceberg.NotEqualTo(iceberg.Reference("some_nan_correct_bounds"), int32(1)), true, "should read: no visibility"}, - {iceberg.NotEqualTo(iceberg.Reference("some_nan_correct_bounds"), int32(10)), true, "should read: no visibility"}, - {iceberg.NotEqualTo(iceberg.Reference("all_nan"), int32(30)), true, "should read: no visibility"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFileNan) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestInWithNans() { - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.IsIn(iceberg.Reference("all_nan"), int32(1), 10, 30), false, "should skip: all nan column doesn't contain number"}, - {iceberg.IsIn(iceberg.Reference("max_nan"), int32(1), 10, 30), true, "should read: 10 and 30 are greater than lower bound"}, - {iceberg.IsIn(iceberg.Reference("min_max_nan"), int32(1), 10, 30), true, "should read: no visibility"}, - {iceberg.IsIn(iceberg.Reference("all_nan_null_bounds"), int32(1), 10, 30), false, "should skip: all nan column doesn't contain number"}, - {iceberg.IsIn(iceberg.Reference("some_nan_correct_bounds"), int32(1), 10, 30), true, "should read: 10 
within bounds"}, - {iceberg.IsIn(iceberg.Reference("some_nan_correct_bounds"), int32(1), 30), false, "should skip: 1 and 30 not within bounds"}, - {iceberg.IsIn(iceberg.Reference("some_nan_correct_bounds"), int32(5), 7), true, "should read: overlap with lower bound"}, - {iceberg.IsIn(iceberg.Reference("some_nan_correct_bounds"), int32(22), 25), true, "should read: overlap with upper bound"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFileNan) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestNotInWithNans() { - tests := []struct { - expr iceberg.BooleanExpression - expected bool - msg string - }{ - {iceberg.NotIn(iceberg.Reference("all_nan"), int32(1), 10, 30), true, "should read: no visibility"}, - {iceberg.NotIn(iceberg.Reference("max_nan"), int32(1), 10, 30), true, "should read: no visibility"}, - {iceberg.NotIn(iceberg.Reference("min_max_nan"), int32(1), 10, 30), true, "should read: no visibility"}, - {iceberg.NotIn(iceberg.Reference("all_nan_null_bounds"), int32(1), 10, 30), true, "should read: no visibility"}, - {iceberg.NotIn(iceberg.Reference("some_nan_correct_bounds"), int32(1), 10, 30), true, "should read: no visibility"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(suite.dataFileNan) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestStartsWith() { - ref, refEmpty := iceberg.Reference("required"), iceberg.Reference("some_empty") - - tests := []struct { - expr iceberg.BooleanExpression - expected bool - dataFile 
iceberg.DataFile - msg string - }{ - {iceberg.StartsWith(ref, "a"), true, suite.dataFiles[0], "should read: no stats"}, - {iceberg.StartsWith(ref, "a"), true, suite.dataFiles[1], "should read: range matches"}, - {iceberg.StartsWith(ref, "aa"), true, suite.dataFiles[1], "should read: range matches"}, - {iceberg.StartsWith(ref, "aaa"), true, suite.dataFiles[1], "should read: range matches"}, - {iceberg.StartsWith(ref, "1s"), true, suite.dataFiles[2], "should read: range matches"}, - {iceberg.StartsWith(ref, "1str1x"), true, suite.dataFiles[2], "should read: range matches"}, - {iceberg.StartsWith(ref, "ff"), true, suite.dataFiles[3], "should read: range matches"}, - {iceberg.StartsWith(ref, "aB"), false, suite.dataFiles[1], "should skip: range doesn't match"}, - {iceberg.StartsWith(ref, "dWx"), false, suite.dataFiles[1], "should skip: range doesn't match"}, - {iceberg.StartsWith(ref, "5"), false, suite.dataFiles[2], "should skip: range doesn't match"}, - {iceberg.StartsWith(ref, "3str3x"), false, suite.dataFiles[2], "should skip: range doesn't match"}, - {iceberg.StartsWith(refEmpty, "房东整租霍"), true, suite.dataFiles[0], "should read: range matches"}, - {iceberg.StartsWith(iceberg.Reference("all_nulls"), ""), false, suite.dataFiles[0], "should skip: range doesn't match"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(tt.dataFile) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func (suite *InclusiveMetricsTestSuite) TestNotStartsWith() { - ref, refEmpty := iceberg.Reference("required"), iceberg.Reference("some_empty") - - tests := []struct { - expr iceberg.BooleanExpression - expected bool - dataFile iceberg.DataFile - msg string - }{ - {iceberg.NotStartsWith(ref, "a"), true, suite.dataFiles[0], "should read: no stats"}, - {iceberg.NotStartsWith(ref, "a"), 
true, suite.dataFiles[1], "should read: range matches"}, - {iceberg.NotStartsWith(ref, "aa"), true, suite.dataFiles[1], "should read: range matches"}, - {iceberg.NotStartsWith(ref, "aaa"), true, suite.dataFiles[1], "should read: range matches"}, - {iceberg.NotStartsWith(ref, "1s"), true, suite.dataFiles[2], "should read: range matches"}, - {iceberg.NotStartsWith(ref, "1str1x"), true, suite.dataFiles[2], "should read: range matches"}, - {iceberg.NotStartsWith(ref, "ff"), true, suite.dataFiles[3], "should read: range matches"}, - {iceberg.NotStartsWith(ref, "aB"), true, suite.dataFiles[1], "should read: range doesn't match"}, - {iceberg.NotStartsWith(ref, "dWx"), true, suite.dataFiles[1], "should read: range doesn't match"}, - {iceberg.NotStartsWith(ref, "5"), true, suite.dataFiles[2], "should read: range doesn't match"}, - {iceberg.NotStartsWith(ref, "3str3x"), true, suite.dataFiles[2], "should read: range doesn't match"}, - {iceberg.NotStartsWith(refEmpty, "房东整租霍"), true, suite.dataFiles[0], "should read: range matches"}, - } - - for _, tt := range tests { - suite.Run(tt.expr.String(), func() { - eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) - suite.Require().NoError(err) - shouldRead, err := eval(tt.dataFile) - suite.Require().NoError(err) - suite.Equal(tt.expected, shouldRead, tt.msg) - }) - } -} - -func TestEvaluators(t *testing.T) { - suite.Run(t, &ProjectionTestSuite{}) - suite.Run(t, &InclusiveMetricsTestSuite{}) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package table + +import ( + "math" + "testing" + + "github.com/apache/iceberg-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +const ( + IntMinValue, IntMaxValue int32 = 30, 79 +) + +func TestManifestEvaluator(t *testing.T) { + + var ( + IntMin, IntMax = []byte{byte(IntMinValue), 0x00, 0x00, 0x00}, []byte{byte(IntMaxValue), 0x00, 0x00, 0x00} + StringMin, StringMax = []byte("a"), []byte("z") + FloatMin, _ = iceberg.Float32Literal(0).MarshalBinary() + FloatMax, _ = iceberg.Float32Literal(20).MarshalBinary() + DblMin, _ = iceberg.Float64Literal(0).MarshalBinary() + DblMax, _ = iceberg.Float64Literal(20).MarshalBinary() + NanTrue, NanFalse = true, false + + testSchema = iceberg.NewSchema(1, + iceberg.NestedField{ID: 1, Name: "id", + Type: iceberg.PrimitiveTypes.Int32, Required: true}, + iceberg.NestedField{ID: 2, Name: "all_nulls_missing_nan", + Type: iceberg.PrimitiveTypes.String, Required: false}, + iceberg.NestedField{ID: 3, Name: "some_nulls", + Type: iceberg.PrimitiveTypes.String, Required: false}, + iceberg.NestedField{ID: 4, Name: "no_nulls", + Type: iceberg.PrimitiveTypes.String, Required: false}, + iceberg.NestedField{ID: 5, Name: "float", + Type: iceberg.PrimitiveTypes.Float32, Required: false}, + iceberg.NestedField{ID: 6, Name: "all_nulls_double", + Type: iceberg.PrimitiveTypes.Float64, Required: false}, + iceberg.NestedField{ID: 7, Name: "all_nulls_no_nans", + Type: iceberg.PrimitiveTypes.Float32, Required: false}, + iceberg.NestedField{ID: 8, Name: 
"all_nans", + Type: iceberg.PrimitiveTypes.Float64, Required: false}, + iceberg.NestedField{ID: 9, Name: "both_nan_and_null", + Type: iceberg.PrimitiveTypes.Float32, Required: false}, + iceberg.NestedField{ID: 10, Name: "no_nan_or_null", + Type: iceberg.PrimitiveTypes.Float64, Required: false}, + iceberg.NestedField{ID: 11, Name: "all_nulls_missing_nan_float", + Type: iceberg.PrimitiveTypes.Float32, Required: false}, + iceberg.NestedField{ID: 12, Name: "all_same_value_or_null", + Type: iceberg.PrimitiveTypes.String, Required: false}, + iceberg.NestedField{ID: 13, Name: "no_nulls_same_value_a", + Type: iceberg.PrimitiveTypes.Binary, Required: false}, + ) + ) + + partFields := make([]iceberg.PartitionField, 0, testSchema.NumFields()) + for _, f := range testSchema.Fields() { + partFields = append(partFields, iceberg.PartitionField{ + Name: f.Name, + SourceID: f.ID, + FieldID: f.ID, + Transform: iceberg.IdentityTransform{}, + }) + } + + spec := iceberg.NewPartitionSpec(partFields...) + manifestNoStats := iceberg.NewManifestV1Builder("", 0, 0, 0).Build() + manifest := iceberg.NewManifestV1Builder("", 0, 0, 0).Partitions( + []iceberg.FieldSummary{ + { // id + ContainsNull: false, + ContainsNaN: nil, + LowerBound: &IntMin, + UpperBound: &IntMax, + }, + { // all_nulls_missing_nan + ContainsNull: true, + ContainsNaN: nil, + LowerBound: nil, + UpperBound: nil, + }, + { // some_nulls + ContainsNull: true, + ContainsNaN: nil, + LowerBound: &StringMin, + UpperBound: &StringMax, + }, + { // no_nulls + ContainsNull: false, + ContainsNaN: nil, + LowerBound: &StringMin, + UpperBound: &StringMax, + }, + { // float + ContainsNull: true, + ContainsNaN: nil, + LowerBound: &FloatMin, + UpperBound: &FloatMax, + }, + { // all_nulls_double + ContainsNull: true, + ContainsNaN: nil, + LowerBound: nil, + UpperBound: nil, + }, + { // all_nulls_no_nans + ContainsNull: true, + ContainsNaN: &NanFalse, + LowerBound: nil, + UpperBound: nil, + }, + { // all_nans + ContainsNull: false, + 
ContainsNaN: &NanTrue, + LowerBound: nil, + UpperBound: nil, + }, + { // both_nan_and_null + ContainsNull: true, + ContainsNaN: &NanTrue, + LowerBound: nil, + UpperBound: nil, + }, + { // no_nan_or_null + ContainsNull: false, + ContainsNaN: &NanFalse, + LowerBound: &DblMin, + UpperBound: &DblMax, + }, + { // all_nulls_missing_nan_float + ContainsNull: true, + ContainsNaN: nil, + LowerBound: nil, + UpperBound: nil, + }, + { // all_same_value_or_null + ContainsNull: true, + ContainsNaN: nil, + LowerBound: &StringMin, + UpperBound: &StringMin, + }, + { // no_nulls_same_value_a + ContainsNull: false, + ContainsNaN: nil, + LowerBound: &StringMin, + UpperBound: &StringMin, + }, + }).Build() + + t.Run("all nulls", func(t *testing.T) { + tests := []struct { + field string + expected bool + msg string + }{ + {"all_nulls_missing_nan", false, "should skip: all nulls column with non-floating type contains all null"}, + {"all_nulls_missing_nan_float", true, "should read: no NaN information may indicate presence of NaN value"}, + {"some_nulls", true, "should read: column with some nulls contains a non-null value"}, + {"no_nulls", true, "should read: non-null column contains a non-null value"}, + } + + for _, tt := range tests { + eval, err := newManifestEvaluator(spec, testSchema, + iceberg.NotNull(iceberg.Reference(tt.field)), true) + require.NoError(t, err) + + result, err := eval(manifest) + require.NoError(t, err) + assert.Equal(t, tt.expected, result, tt.msg) + } + }) + + t.Run("no nulls", func(t *testing.T) { + tests := []struct { + field string + expected bool + msg string + }{ + {"all_nulls_missing_nan", true, "should read: at least one null value in all null column"}, + {"some_nulls", true, "should read: column with some nulls contains a null value"}, + {"no_nulls", false, "should skip: non-null column contains no null values"}, + {"both_nan_and_null", true, "should read: both_nan_and_null column contains no null values"}, + } + + for _, tt := range tests { + eval, err 
:= newManifestEvaluator(spec, testSchema, + iceberg.IsNull(iceberg.Reference(tt.field)), true) + require.NoError(t, err) + + result, err := eval(manifest) + require.NoError(t, err) + assert.Equal(t, tt.expected, result, tt.msg) + } + }) + + t.Run("is nan", func(t *testing.T) { + tests := []struct { + field string + expected bool + msg string + }{ + {"float", true, "should read: no information on if there are nan values in float column"}, + {"all_nulls_double", true, "should read: no NaN information may indicate presence of NaN value"}, + {"all_nulls_missing_nan_float", true, "should read: no NaN information may indicate presence of NaN value"}, + {"all_nulls_no_nans", false, "should skip: no nan column doesn't contain nan value"}, + {"all_nans", true, "should read: all_nans column contains nan value"}, + {"both_nan_and_null", true, "should read: both_nan_and_null column contains nan value"}, + {"no_nan_or_null", false, "should skip: no_nan_or_null column doesn't contain nan value"}, + } + + for _, tt := range tests { + eval, err := newManifestEvaluator(spec, testSchema, + iceberg.IsNaN(iceberg.Reference(tt.field)), true) + require.NoError(t, err) + + result, err := eval(manifest) + require.NoError(t, err) + assert.Equal(t, tt.expected, result, tt.msg) + } + }) + + t.Run("not nan", func(t *testing.T) { + tests := []struct { + field string + expected bool + msg string + }{ + {"float", true, "should read: no information on if there are nan values in float column"}, + {"all_nulls_double", true, "should read: all null column contains non nan value"}, + {"all_nulls_no_nans", true, "should read: no_nans column contains non nan value"}, + {"all_nans", false, "should skip: all nans columndoesn't contain non nan value"}, + {"both_nan_and_null", true, "should read: both_nan_and_null nans column contains non nan value"}, + {"no_nan_or_null", true, "should read: no_nan_or_null column contains non nan value"}, + } + + for _, tt := range tests { + eval, err := 
newManifestEvaluator(spec, testSchema, + iceberg.NotNaN(iceberg.Reference(tt.field)), true) + require.NoError(t, err) + + result, err := eval(manifest) + require.NoError(t, err) + assert.Equal(t, tt.expected, result, tt.msg) + } + }) + + t.Run("test missing stats", func(t *testing.T) { + exprs := []iceberg.BooleanExpression{ + iceberg.LessThan(iceberg.Reference("id"), int32(5)), + iceberg.LessThanEqual(iceberg.Reference("id"), int32(30)), + iceberg.EqualTo(iceberg.Reference("id"), int32(70)), + iceberg.GreaterThan(iceberg.Reference("id"), int32(78)), + iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(90)), + iceberg.NotEqualTo(iceberg.Reference("id"), int32(101)), + iceberg.IsNull(iceberg.Reference("id")), + iceberg.NotNull(iceberg.Reference("id")), + iceberg.IsNaN(iceberg.Reference("float")), + iceberg.NotNaN(iceberg.Reference("float")), + } + + for _, tt := range exprs { + eval, err := newManifestEvaluator(spec, testSchema, tt, true) + require.NoError(t, err) + + result, err := eval(manifestNoStats) + require.NoError(t, err) + assert.Truef(t, result, "should read when missing stats for expr: %s", tt) + } + }) + + t.Run("test exprs", func(t *testing.T) { + tests := []struct { + expr iceberg.BooleanExpression + expect bool + msg string + }{ + {iceberg.NewNot(iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25))), + true, "should read: not(false)"}, + {iceberg.NewNot(iceberg.GreaterThan(iceberg.Reference("id"), int32(IntMinValue-25))), + false, "should skip: not(true)"}, + {iceberg.NewAnd( + iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25)), + iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMinValue-30))), + false, "should skip: and(false, true)"}, + {iceberg.NewAnd( + iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25)), + iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue+1))), + false, "should skip: and(false, false)"}, + {iceberg.NewAnd( + iceberg.GreaterThan(iceberg.Reference("id"), 
int32(IntMinValue-25)), + iceberg.LessThanEqual(iceberg.Reference("id"), int32(IntMinValue))), + true, "should read: and(true, true)"}, + {iceberg.NewOr( + iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25)), + iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue+1))), + false, "should skip: or(false, false)"}, + {iceberg.NewOr( + iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25)), + iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue-19))), + true, "should read: or(false, true)"}, + {iceberg.LessThan(iceberg.Reference("some_nulls"), "1"), false, + "should not read: id range below lower bound"}, + {iceberg.LessThan(iceberg.Reference("some_nulls"), "b"), true, + "should read: lower bound in range"}, + {iceberg.LessThan(iceberg.Reference("float"), 15.50), true, + "should read: lower bound in range"}, + {iceberg.LessThan(iceberg.Reference("no_nan_or_null"), 15.50), true, + "should read: lower bound in range"}, + {iceberg.LessThanEqual(iceberg.Reference("no_nulls_same_value_a"), "a"), true, + "should read: lower bound in range"}, + {iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue-25)), false, + "should not read: id range below lower bound (5 < 30)"}, + {iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue)), false, + "should not read: id range below lower bound (30 is not < 30)"}, + {iceberg.LessThan(iceberg.Reference("id"), int32(IntMinValue+1)), true, + "should read: one possible id"}, + {iceberg.LessThan(iceberg.Reference("id"), int32(IntMaxValue)), true, + "should read: many possible ids"}, + {iceberg.LessThanEqual(iceberg.Reference("id"), int32(IntMinValue-25)), false, + "should not read: id range below lower bound (5 < 30)"}, + {iceberg.LessThanEqual(iceberg.Reference("id"), int32(IntMinValue-1)), false, + "should not read: id range below lower bound 29 < 30"}, + {iceberg.LessThanEqual(iceberg.Reference("id"), int32(IntMinValue)), true, + "should read: one possible id"}, + 
{iceberg.LessThanEqual(iceberg.Reference("id"), int32(IntMaxValue)), true, + "should read: many possible ids"}, + {iceberg.GreaterThan(iceberg.Reference("id"), int32(IntMaxValue+6)), false, + "should not read: id range above upper bound (85 < 79)"}, + {iceberg.GreaterThan(iceberg.Reference("id"), int32(IntMaxValue)), false, + "should not read: id range above upper bound (79 is not > 79)"}, + {iceberg.GreaterThan(iceberg.Reference("id"), int32(IntMaxValue-1)), true, + "should read: one possible id"}, + {iceberg.GreaterThan(iceberg.Reference("id"), int32(IntMaxValue-4)), true, + "should read: many possible ids"}, + {iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue+6)), false, + "should not read: id range is above upper bound (85 < 79)"}, + {iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue+1)), false, + "should not read: id range above upper bound (80 > 79)"}, + {iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue)), true, + "should read: one possible id"}, + {iceberg.GreaterThanEqual(iceberg.Reference("id"), int32(IntMaxValue)), true, + "should read: many possible ids"}, + {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue-25)), false, + "should not read: id below lower bound"}, + {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue-1)), false, + "should not read: id below lower bound"}, + {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue)), true, + "should read: id equal to lower bound"}, + {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue-4)), true, + "should read: id between lower and upper bounds"}, + {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue)), true, + "should read: id equal to upper bound"}, + {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue+1)), false, + "should not read: id above upper bound"}, + {iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue+6)), false, + "should not read: id above upper bound"}, + 
{iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMinValue-25)), true, + "should read: id below lower bound"}, + {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMinValue-1)), true, + "should read: id below lower bound"}, + {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMinValue)), true, + "should read: id equal to lower bound"}, + {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMaxValue-4)), true, + "should read: id between lower and upper bounds"}, + {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMaxValue)), true, + "should read: id equal to upper bound"}, + {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMaxValue+1)), true, + "should read: id above upper bound"}, + {iceberg.NotEqualTo(iceberg.Reference("id"), int32(IntMaxValue+6)), true, + "should read: id above upper bound"}, + {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue-25))), true, + "should read: id below lower bound"}, + {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue-1))), true, + "should read: id below lower bound"}, + {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMinValue))), true, + "should read: id equal to lower bound"}, + {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue-4))), true, + "should read: id between lower and upper bounds"}, + {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue))), true, + "should read: id equal to upper bound"}, + {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue+1))), true, + "should read: id above upper bound"}, + {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("id"), int32(IntMaxValue+6))), true, + "should read: id above upper bound"}, + {iceberg.IsIn(iceberg.Reference("id"), int32(IntMinValue-25), IntMinValue-24), false, + "should not read: id below lower bound (5 < 30, 6 < 30)"}, + {iceberg.IsIn(iceberg.Reference("id"), int32(IntMinValue-2), IntMinValue-1), false, + "should not read: 
id below lower bound (28 < 30, 29 < 30)"}, + {iceberg.IsIn(iceberg.Reference("id"), int32(IntMinValue-1), IntMinValue), true, + "should read: id equal to lower bound (30 == 30)"}, + {iceberg.IsIn(iceberg.Reference("id"), int32(IntMaxValue-4), IntMaxValue-3), true, + "should read: id between lower and upper bounds (30 < 75 < 79, 30 < 76 < 79)"}, + {iceberg.IsIn(iceberg.Reference("id"), int32(IntMaxValue), IntMaxValue+1), true, + "should read: id equal to upper bound (79 == 79)"}, + {iceberg.IsIn(iceberg.Reference("id"), int32(IntMaxValue+1), IntMaxValue+2), false, + "should not read: id above upper bound (80 > 79, 81 > 79)"}, + {iceberg.IsIn(iceberg.Reference("id"), int32(IntMaxValue+6), IntMaxValue+7), false, + "should not read: id above upper bound (85 > 79, 86 > 79)"}, + {iceberg.IsIn(iceberg.Reference("all_nulls_missing_nan"), "abc", "def"), false, + "should skip: in on all nulls column"}, + {iceberg.IsIn(iceberg.Reference("some_nulls"), "abc", "def"), true, + "should read: in on some nulls column"}, + {iceberg.IsIn(iceberg.Reference("no_nulls"), "abc", "def"), true, + "should read: in on no nulls column"}, + {iceberg.IsIn(iceberg.Reference("no_nulls_same_value_a"), "a", "b"), true, + "should read: in on no nulls column"}, + {iceberg.IsIn(iceberg.Reference("float"), 0, -5.5), true, + "should read: float equal to lower bound"}, + {iceberg.IsIn(iceberg.Reference("no_nan_or_null"), 0, -5.5), true, + "should read: float equal to lower bound"}, + {iceberg.NotIn(iceberg.Reference("id"), int32(IntMinValue-25), IntMinValue-24), true, + "should read: id below lower bound (5 < 30, 6 < 30)"}, + {iceberg.NotIn(iceberg.Reference("id"), int32(IntMinValue-2), IntMinValue-1), true, + "should read: id below lower bound (28 < 30, 29 < 30)"}, + {iceberg.NotIn(iceberg.Reference("id"), int32(IntMinValue-1), IntMinValue), true, + "should read: id equal to lower bound (30 == 30)"}, + {iceberg.NotIn(iceberg.Reference("id"), int32(IntMaxValue-4), IntMaxValue-3), true, + "should read: id 
between lower and upper bounds (30 < 75 < 79, 30 < 76 < 79)"}, + {iceberg.NotIn(iceberg.Reference("id"), int32(IntMaxValue), IntMaxValue+1), true, + "should read: id equal to upper bound (79 == 79)"}, + {iceberg.NotIn(iceberg.Reference("id"), int32(IntMaxValue+1), IntMaxValue+2), true, + "should read: id above upper bound (80 > 79, 81 > 79)"}, + {iceberg.NotIn(iceberg.Reference("id"), int32(IntMaxValue+6), IntMaxValue+7), true, + "should read: id above upper bound (85 > 79, 86 > 79)"}, + {iceberg.NotIn(iceberg.Reference("all_nulls_missing_nan"), "abc", "def"), true, + "should read: notIn on all nulls column"}, + {iceberg.NotIn(iceberg.Reference("some_nulls"), "abc", "def"), true, + "should read: notIn on some nulls column"}, + {iceberg.NotIn(iceberg.Reference("no_nulls"), "abc", "def"), true, + "should read: notIn on no nulls column"}, + {iceberg.StartsWith(iceberg.Reference("some_nulls"), "a"), true, + "should read: range matches"}, + {iceberg.StartsWith(iceberg.Reference("some_nulls"), "aa"), true, + "should read: range matches"}, + {iceberg.StartsWith(iceberg.Reference("some_nulls"), "dddd"), true, + "should read: range matches"}, + {iceberg.StartsWith(iceberg.Reference("some_nulls"), "z"), true, + "should read: range matches"}, + {iceberg.StartsWith(iceberg.Reference("no_nulls"), "a"), true, + "should read: range matches"}, + {iceberg.StartsWith(iceberg.Reference("some_nulls"), "zzzz"), false, + "should skip: range doesn't match"}, + {iceberg.StartsWith(iceberg.Reference("some_nulls"), "1"), false, + "should skip: range doesn't match"}, + {iceberg.StartsWith(iceberg.Reference("no_nulls_same_value_a"), "a"), true, + "should read: all values start with the prefix"}, + {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "a"), true, + "should read: range matches"}, + {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "aa"), true, + "should read: range matches"}, + {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "dddd"), true, + "should read: range 
matches"}, + {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "z"), true, + "should read: range matches"}, + {iceberg.NotStartsWith(iceberg.Reference("no_nulls"), "a"), true, + "should read: range matches"}, + {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "zzzz"), true, + "should read: range matches"}, + {iceberg.NotStartsWith(iceberg.Reference("some_nulls"), "1"), true, + "should read: range matches"}, + {iceberg.NotStartsWith(iceberg.Reference("all_same_value_or_null"), "a"), true, + "should read: range matches"}, + {iceberg.NotStartsWith(iceberg.Reference("all_same_value_or_null"), "aa"), true, + "should read: range matches"}, + {iceberg.NotStartsWith(iceberg.Reference("all_same_value_or_null"), "A"), true, + "should read: range matches"}, + // Iceberg does not implement SQL 3-way boolean logic, so the choice of an + // all null column matching is by definition in order to surface more values + // to the query engine to allow it to make its own decision + {iceberg.NotStartsWith(iceberg.Reference("all_nulls_missing_nan"), "A"), true, + "should read: range matches"}, + {iceberg.NotStartsWith(iceberg.Reference("no_nulls_same_value_a"), "a"), false, + "should not read: all values start with the prefix"}, + } + + for _, tt := range tests { + t.Run(tt.expr.String(), func(t *testing.T) { + eval, err := newManifestEvaluator(spec, testSchema, + tt.expr, true) + require.NoError(t, err) + + result, err := eval(manifest) + require.NoError(t, err) + assert.Equal(t, tt.expect, result, tt.msg) + }) + } + }) +} + +type ProjectionTestSuite struct { + suite.Suite +} + +func (*ProjectionTestSuite) schema() *iceberg.Schema { + return iceberg.NewSchema(0, + iceberg.NestedField{ID: 1, Name: "id", Type: iceberg.PrimitiveTypes.Int64}, + iceberg.NestedField{ID: 2, Name: "data", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 3, Name: "event_date", Type: iceberg.PrimitiveTypes.Date}, + iceberg.NestedField{ID: 4, Name: "event_ts", Type: 
iceberg.PrimitiveTypes.Timestamp}, + ) +} + +func (*ProjectionTestSuite) emptySpec() iceberg.PartitionSpec { + return iceberg.NewPartitionSpec() +} + +func (*ProjectionTestSuite) idSpec() iceberg.PartitionSpec { + return iceberg.NewPartitionSpec( + iceberg.PartitionField{SourceID: 1, FieldID: 1000, + Transform: iceberg.IdentityTransform{}, Name: "id_part"}, + ) +} + +func (*ProjectionTestSuite) bucketSpec() iceberg.PartitionSpec { + return iceberg.NewPartitionSpec( + iceberg.PartitionField{SourceID: 2, FieldID: 1000, + Transform: iceberg.BucketTransform{NumBuckets: 16}, Name: "data_bucket"}, + ) +} + +func (*ProjectionTestSuite) daySpec() iceberg.PartitionSpec { + return iceberg.NewPartitionSpec( + iceberg.PartitionField{SourceID: 4, FieldID: 1000, + Transform: iceberg.DayTransform{}, Name: "date"}, + iceberg.PartitionField{SourceID: 3, FieldID: 1001, + Transform: iceberg.DayTransform{}, Name: "ddate"}, + ) +} + +func (*ProjectionTestSuite) hourSpec() iceberg.PartitionSpec { + return iceberg.NewPartitionSpec( + iceberg.PartitionField{SourceID: 4, FieldID: 1000, + Transform: iceberg.HourTransform{}, Name: "hour"}, + ) +} + +func (*ProjectionTestSuite) truncateStrSpec() iceberg.PartitionSpec { + return iceberg.NewPartitionSpec( + iceberg.PartitionField{SourceID: 2, FieldID: 1000, + Transform: iceberg.TruncateTransform{Width: 2}, Name: "data_trunc"}, + ) +} + +func (*ProjectionTestSuite) truncateIntSpec() iceberg.PartitionSpec { + return iceberg.NewPartitionSpec( + iceberg.PartitionField{SourceID: 1, FieldID: 1000, + Transform: iceberg.TruncateTransform{Width: 10}, Name: "id_trunc"}, + ) +} + +func (*ProjectionTestSuite) idAndBucketSpec() iceberg.PartitionSpec { + return iceberg.NewPartitionSpec( + iceberg.PartitionField{SourceID: 1, FieldID: 1000, + Transform: iceberg.IdentityTransform{}, Name: "id_part"}, + iceberg.PartitionField{SourceID: 2, FieldID: 1001, + Transform: iceberg.BucketTransform{NumBuckets: 16}, Name: "data_bucket"}, + ) +} + +func (p 
*ProjectionTestSuite) TestIdentityProjection() { + schema, spec := p.schema(), p.idSpec() + + idRef, idPartRef := iceberg.Reference("id"), iceberg.Reference("id_part") + tests := []struct { + pred iceberg.BooleanExpression + expected iceberg.BooleanExpression + }{ + {iceberg.NotNull(idRef), iceberg.NotNull(idPartRef)}, + {iceberg.IsNull(idRef), iceberg.IsNull(idPartRef)}, + {iceberg.LessThan(idRef, int64(100)), iceberg.LessThan(idPartRef, int64(100))}, + {iceberg.LessThanEqual(idRef, int64(101)), iceberg.LessThanEqual(idPartRef, int64(101))}, + {iceberg.GreaterThan(idRef, int64(102)), iceberg.GreaterThan(idPartRef, int64(102))}, + {iceberg.GreaterThanEqual(idRef, int64(103)), iceberg.GreaterThanEqual(idPartRef, int64(103))}, + {iceberg.EqualTo(idRef, int64(104)), iceberg.EqualTo(idPartRef, int64(104))}, + {iceberg.NotEqualTo(idRef, int64(105)), iceberg.NotEqualTo(idPartRef, int64(105))}, + {iceberg.IsIn(idRef, int64(3), 4, 5), iceberg.IsIn(idPartRef, int64(3), 4, 5)}, + {iceberg.NotIn(idRef, int64(3), 4, 5), iceberg.NotIn(idPartRef, int64(3), 4, 5)}, + } + + project := newInclusiveProjection(schema, spec, true) + for _, tt := range tests { + p.Run(tt.pred.String(), func() { + expr, err := project(tt.pred) + p.Require().NoError(err) + p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) + }) + } +} + +func (p *ProjectionTestSuite) TestBucketProjection() { + schema, spec := p.schema(), p.bucketSpec() + + dataRef, dataBkt := iceberg.Reference("data"), iceberg.Reference("data_bucket") + tests := []struct { + pred, expected iceberg.BooleanExpression + }{ + {iceberg.NotNull(dataRef), iceberg.NotNull(dataBkt)}, + {iceberg.IsNull(dataRef), iceberg.IsNull(dataBkt)}, + {iceberg.LessThan(dataRef, "val"), iceberg.AlwaysTrue{}}, + {iceberg.LessThanEqual(dataRef, "val"), iceberg.AlwaysTrue{}}, + {iceberg.GreaterThan(dataRef, "val"), iceberg.AlwaysTrue{}}, + {iceberg.GreaterThanEqual(dataRef, "val"), iceberg.AlwaysTrue{}}, + {iceberg.EqualTo(dataRef, 
"val"), iceberg.EqualTo(dataBkt, int32(14))}, + {iceberg.NotEqualTo(dataRef, "val"), iceberg.AlwaysTrue{}}, + {iceberg.IsIn(dataRef, "v1", "v2", "v3"), iceberg.IsIn(dataBkt, int32(1), 3, 13)}, + {iceberg.NotIn(dataRef, "v1", "v2", "v3"), iceberg.AlwaysTrue{}}, + } + + project := newInclusiveProjection(schema, spec, true) + for _, tt := range tests { + p.Run(tt.pred.String(), func() { + expr, err := project(tt.pred) + p.Require().NoError(err) + p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) + }) + } +} + +func (p *ProjectionTestSuite) TestHourProjection() { + schema, spec := p.schema(), p.hourSpec() + + ref, hour := iceberg.Reference("event_ts"), iceberg.Reference("hour") + tests := []struct { + pred, expected iceberg.BooleanExpression + }{ + {iceberg.NotNull(ref), iceberg.NotNull(hour)}, + {iceberg.IsNull(ref), iceberg.IsNull(hour)}, + {iceberg.LessThan(ref, "2022-11-27T10:00:00"), iceberg.LessThanEqual(hour, int32(463761))}, + {iceberg.LessThanEqual(ref, "2022-11-27T10:00:00"), iceberg.LessThanEqual(hour, int32(463762))}, + {iceberg.GreaterThan(ref, "2022-11-27T09:59:59.999999"), iceberg.GreaterThanEqual(hour, int32(463762))}, + {iceberg.GreaterThanEqual(ref, "2022-11-27T09:59:59.999999"), iceberg.GreaterThanEqual(hour, int32(463761))}, + {iceberg.EqualTo(ref, "2022-11-27T10:00:00"), iceberg.EqualTo(hour, int32(463762))}, + {iceberg.NotEqualTo(ref, "2022-11-27T10:00:00"), iceberg.AlwaysTrue{}}, + {iceberg.IsIn(ref, "2022-11-27T10:00:00", "2022-11-27T09:59:59.999999"), iceberg.IsIn(hour, int32(463761), 463762)}, + {iceberg.NotIn(ref, "2022-11-27T10:00:00", "2022-11-27T09:59:59.999999"), iceberg.AlwaysTrue{}}, + } + + project := newInclusiveProjection(schema, spec, true) + for _, tt := range tests { + p.Run(tt.pred.String(), func() { + expr, err := project(tt.pred) + p.Require().NoError(err) + p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) + }) + } +} + +func (p *ProjectionTestSuite) TestDayProjection() { 
+ schema, spec := p.schema(), p.daySpec() + + ref, date := iceberg.Reference("event_ts"), iceberg.Reference("date") + tests := []struct { + pred, expected iceberg.BooleanExpression + }{ + {iceberg.NotNull(ref), iceberg.NotNull(date)}, + {iceberg.IsNull(ref), iceberg.IsNull(date)}, + {iceberg.LessThan(ref, "2022-11-27T00:00:00"), iceberg.LessThanEqual(date, int32(19322))}, + {iceberg.LessThanEqual(ref, "2022-11-27T00:00:00"), iceberg.LessThanEqual(date, int32(19323))}, + {iceberg.GreaterThan(ref, "2022-11-26T23:59:59.999999"), iceberg.GreaterThanEqual(date, int32(19323))}, + {iceberg.GreaterThanEqual(ref, "2022-11-26T23:59:59.999999"), iceberg.GreaterThanEqual(date, int32(19322))}, + {iceberg.EqualTo(ref, "2022-11-27T10:00:00"), iceberg.EqualTo(date, int32(19323))}, + {iceberg.NotEqualTo(ref, "2022-11-27T10:00:00"), iceberg.AlwaysTrue{}}, + {iceberg.IsIn(ref, "2022-11-27T00:00:00", "2022-11-26T23:59:59.999999"), iceberg.IsIn(date, int32(19322), 19323)}, + {iceberg.NotIn(ref, "2022-11-27T00:00:00", "2022-11-26T23:59:59.999999"), iceberg.AlwaysTrue{}}, + } + + project := newInclusiveProjection(schema, spec, true) + for _, tt := range tests { + p.Run(tt.pred.String(), func() { + expr, err := project(tt.pred) + p.Require().NoError(err) + p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) + }) + } +} + +func (p *ProjectionTestSuite) TestDateDayProjection() { + schema, spec := p.schema(), p.daySpec() + + ref, date := iceberg.Reference("event_date"), iceberg.Reference("ddate") + tests := []struct { + pred, expected iceberg.BooleanExpression + }{ + {iceberg.NotNull(ref), iceberg.NotNull(date)}, + {iceberg.IsNull(ref), iceberg.IsNull(date)}, + {iceberg.LessThan(ref, "2022-11-27"), iceberg.LessThanEqual(date, int32(19322))}, + {iceberg.LessThanEqual(ref, "2022-11-27"), iceberg.LessThanEqual(date, int32(19323))}, + {iceberg.GreaterThan(ref, "2022-11-26"), iceberg.GreaterThanEqual(date, int32(19323))}, + {iceberg.GreaterThanEqual(ref, "2022-11-26"), 
iceberg.GreaterThanEqual(date, int32(19322))}, + {iceberg.EqualTo(ref, "2022-11-27"), iceberg.EqualTo(date, int32(19323))}, + {iceberg.NotEqualTo(ref, "2022-11-27"), iceberg.AlwaysTrue{}}, + {iceberg.IsIn(ref, "2022-11-27", "2022-11-26"), iceberg.IsIn(date, int32(19322), 19323)}, + {iceberg.NotIn(ref, "2022-11-27", "2022-11-26"), iceberg.AlwaysTrue{}}, + } + + project := newInclusiveProjection(schema, spec, true) + for _, tt := range tests { + p.Run(tt.pred.String(), func() { + expr, err := project(tt.pred) + p.Require().NoError(err) + p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) + }) + } +} + +func (p *ProjectionTestSuite) TestStringTruncateProjection() { + schema, spec := p.schema(), p.truncateStrSpec() + + ref, truncStr := iceberg.Reference("data"), iceberg.Reference("data_trunc") + tests := []struct { + pred, expected iceberg.BooleanExpression + }{ + {iceberg.NotNull(ref), iceberg.NotNull(truncStr)}, + {iceberg.IsNull(ref), iceberg.IsNull(truncStr)}, + {iceberg.LessThan(ref, "aaa"), iceberg.LessThanEqual(truncStr, "aa")}, + {iceberg.LessThanEqual(ref, "aaa"), iceberg.LessThanEqual(truncStr, "aa")}, + {iceberg.GreaterThan(ref, "aaa"), iceberg.GreaterThanEqual(truncStr, "aa")}, + {iceberg.GreaterThanEqual(ref, "aaa"), iceberg.GreaterThanEqual(truncStr, "aa")}, + {iceberg.EqualTo(ref, "aaa"), iceberg.EqualTo(truncStr, "aa")}, + {iceberg.NotEqualTo(ref, "aaa"), iceberg.AlwaysTrue{}}, + {iceberg.IsIn(ref, "aaa", "aab"), iceberg.EqualTo(truncStr, "aa")}, + {iceberg.NotIn(ref, "aaa", "aab"), iceberg.AlwaysTrue{}}, + } + + project := newInclusiveProjection(schema, spec, true) + for _, tt := range tests { + p.Run(tt.pred.String(), func() { + expr, err := project(tt.pred) + p.Require().NoError(err) + p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) + }) + } +} + +func (p *ProjectionTestSuite) TestIntTruncateProjection() { + schema, spec := p.schema(), p.truncateIntSpec() + + ref, idTrunc := 
iceberg.Reference("id"), iceberg.Reference("id_trunc") + tests := []struct { + pred, expected iceberg.BooleanExpression + }{ + {iceberg.NotNull(ref), iceberg.NotNull(idTrunc)}, + {iceberg.IsNull(ref), iceberg.IsNull(idTrunc)}, + {iceberg.LessThan(ref, int32(10)), iceberg.LessThanEqual(idTrunc, int64(0))}, + {iceberg.LessThanEqual(ref, int32(10)), iceberg.LessThanEqual(idTrunc, int64(10))}, + {iceberg.GreaterThan(ref, int32(9)), iceberg.GreaterThanEqual(idTrunc, int64(10))}, + {iceberg.GreaterThanEqual(ref, int32(10)), iceberg.GreaterThanEqual(idTrunc, int64(10))}, + {iceberg.EqualTo(ref, int32(15)), iceberg.EqualTo(idTrunc, int64(10))}, + {iceberg.NotEqualTo(ref, int32(15)), iceberg.AlwaysTrue{}}, + {iceberg.IsIn(ref, int32(15), 16), iceberg.EqualTo(idTrunc, int64(10))}, + {iceberg.NotIn(ref, int32(15), 16), iceberg.AlwaysTrue{}}, + } + + project := newInclusiveProjection(schema, spec, true) + for _, tt := range tests { + p.Run(tt.pred.String(), func() { + expr, err := project(tt.pred) + p.Require().NoError(err) + p.Truef(tt.expected.Equals(expr), "expected: %s\ngot: %s", tt.expected, expr) + }) + } +} + +func (p *ProjectionTestSuite) TestProjectionCaseSensitive() { + schema, spec := p.schema(), p.idSpec() + project := newInclusiveProjection(schema, spec, true) + _, err := project(iceberg.NotNull(iceberg.Reference("ID"))) + p.ErrorIs(err, iceberg.ErrInvalidSchema) + p.ErrorContains(err, "could not bind reference 'ID', caseSensitive=true") +} + +func (p *ProjectionTestSuite) TestProjectionCaseInsensitive() { + schema, spec := p.schema(), p.idSpec() + project := newInclusiveProjection(schema, spec, false) + expr, err := project(iceberg.NotNull(iceberg.Reference("ID"))) + p.Require().NoError(err) + p.True(expr.Equals(iceberg.NotNull(iceberg.Reference("id_part")))) +} + +func (p *ProjectionTestSuite) TestProjectEmptySpec() { + project := newInclusiveProjection(p.schema(), p.emptySpec(), true) + expr, err := 
project(iceberg.NewAnd(iceberg.LessThan(iceberg.Reference("id"), int32(5)), + iceberg.NotNull(iceberg.Reference("data")))) + p.Require().NoError(err) + p.Equal(iceberg.AlwaysTrue{}, expr) +} + +func (p *ProjectionTestSuite) TestAndProjectionMultipleFields() { + project := newInclusiveProjection(p.schema(), p.idAndBucketSpec(), true) + expr, err := project(iceberg.NewAnd(iceberg.LessThan(iceberg.Reference("id"), + int32(5)), iceberg.IsIn(iceberg.Reference("data"), "a", "b", "c"))) + p.Require().NoError(err) + + p.True(expr.Equals(iceberg.NewAnd(iceberg.LessThan(iceberg.Reference("id_part"), int64(5)), + iceberg.IsIn(iceberg.Reference("data_bucket"), int32(2), 3, 15)))) +} + +func (p *ProjectionTestSuite) TestOrProjectionMultipleFields() { + project := newInclusiveProjection(p.schema(), p.idAndBucketSpec(), true) + expr, err := project(iceberg.NewOr(iceberg.LessThan(iceberg.Reference("id"), int32(5)), + iceberg.IsIn(iceberg.Reference("data"), "a", "b", "c"))) + p.Require().NoError(err) + + p.True(expr.Equals(iceberg.NewOr(iceberg.LessThan(iceberg.Reference("id_part"), int64(5)), + iceberg.IsIn(iceberg.Reference("data_bucket"), int32(2), 3, 15)))) +} + +func (p *ProjectionTestSuite) TestNotProjectionMultipleFields() { + project := newInclusiveProjection(p.schema(), p.idAndBucketSpec(), true) + // not causes In to be rewritten to NotIn, which cannot be projected + expr, err := project(iceberg.NewNot(iceberg.NewOr(iceberg.LessThan(iceberg.Reference("id"), int64(5)), + iceberg.IsIn(iceberg.Reference("data"), "a", "b", "c")))) + p.Require().NoError(err) + + p.True(expr.Equals(iceberg.GreaterThanEqual(iceberg.Reference("id_part"), int64(5)))) +} + +func (p *ProjectionTestSuite) TestPartialProjectedFields() { + project := newInclusiveProjection(p.schema(), p.idSpec(), true) + expr, err := project(iceberg.NewAnd(iceberg.LessThan(iceberg.Reference("id"), int32(5)), + iceberg.IsIn(iceberg.Reference("data"), "a", "b", "c"))) + p.Require().NoError(err) + 
p.True(expr.Equals(iceberg.LessThan(iceberg.Reference("id_part"), int64(5)))) +} + +type mockDataFile struct { + path string + format iceberg.FileFormat + partition map[string]any + count int64 + columnSizes map[int]int64 + filesize int64 + valueCounts map[int]int64 + nullCounts map[int]int64 + nanCounts map[int]int64 + lowerBounds map[int][]byte + upperBounds map[int][]byte +} + +func (*mockDataFile) ContentType() iceberg.ManifestEntryContent { return iceberg.EntryContentData } +func (m *mockDataFile) FilePath() string { return m.path } +func (m *mockDataFile) FileFormat() iceberg.FileFormat { return m.format } +func (m *mockDataFile) Partition() map[string]any { return m.partition } +func (m *mockDataFile) Count() int64 { return m.count } +func (m *mockDataFile) FileSizeBytes() int64 { return m.filesize } +func (m *mockDataFile) ColumnSizes() map[int]int64 { return m.columnSizes } +func (m *mockDataFile) ValueCounts() map[int]int64 { return m.valueCounts } +func (m *mockDataFile) NullValueCounts() map[int]int64 { return m.nullCounts } +func (m *mockDataFile) NaNValueCounts() map[int]int64 { return m.nanCounts } +func (*mockDataFile) DistinctValueCounts() map[int]int64 { return nil } +func (m *mockDataFile) LowerBoundValues() map[int][]byte { return m.lowerBounds } +func (m *mockDataFile) UpperBoundValues() map[int][]byte { return m.upperBounds } +func (*mockDataFile) KeyMetadata() []byte { return nil } +func (*mockDataFile) SplitOffsets() []int64 { return nil } +func (*mockDataFile) EqualityFieldIDs() []int { return nil } +func (*mockDataFile) SortOrderID() *int { return nil } + +type InclusiveMetricsTestSuite struct { + suite.Suite + + schemaDataFile *iceberg.Schema + dataFiles [4]iceberg.DataFile + + schemaDataFileNan *iceberg.Schema + dataFileNan iceberg.DataFile +} + +func (suite *InclusiveMetricsTestSuite) SetupSuite() { + suite.schemaDataFile = iceberg.NewSchema(0, + iceberg.NestedField{ID: 1, Name: "id", Type: iceberg.PrimitiveTypes.Int32, Required: true}, 
+ iceberg.NestedField{ID: 2, Name: "no_stats", Type: iceberg.PrimitiveTypes.Int32, Required: false}, + iceberg.NestedField{ID: 3, Name: "required", Type: iceberg.PrimitiveTypes.String, Required: true}, + iceberg.NestedField{ID: 4, Name: "all_nulls", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 5, Name: "some_nulls", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 6, Name: "no_nulls", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 7, Name: "all_nans", Type: iceberg.PrimitiveTypes.Float64}, + iceberg.NestedField{ID: 8, Name: "some_nans", Type: iceberg.PrimitiveTypes.Float32}, + iceberg.NestedField{ID: 9, Name: "no_nans", Type: iceberg.PrimitiveTypes.Float32}, + iceberg.NestedField{ID: 10, Name: "all_nulls_double", Type: iceberg.PrimitiveTypes.Float64}, + iceberg.NestedField{ID: 11, Name: "all_nans_v1_stats", Type: iceberg.PrimitiveTypes.Float32}, + iceberg.NestedField{ID: 12, Name: "nan_and_null_only", Type: iceberg.PrimitiveTypes.Float64}, + iceberg.NestedField{ID: 13, Name: "no_nan_stats", Type: iceberg.PrimitiveTypes.Float64}, + iceberg.NestedField{ID: 14, Name: "some_empty", Type: iceberg.PrimitiveTypes.String}, + ) + + var ( + IntMin, _ = iceberg.Int32Literal(IntMinValue).MarshalBinary() + IntMax, _ = iceberg.Int32Literal(IntMaxValue).MarshalBinary() + FltNan, _ = iceberg.Float32Literal(float32(math.NaN())).MarshalBinary() + DblNan, _ = iceberg.Float64Literal(math.NaN()).MarshalBinary() + FltSeven, _ = iceberg.Float32Literal(7).MarshalBinary() + DblSeven, _ = iceberg.Float64Literal(7).MarshalBinary() + FltMax, _ = iceberg.Float32Literal(22).MarshalBinary() + ) + + suite.dataFiles = [4]iceberg.DataFile{ + &mockDataFile{ + path: "file_1.parquet", + format: iceberg.ParquetFile, + count: 50, + filesize: 3, + valueCounts: map[int]int64{ + 4: 50, 5: 50, 6: 50, 7: 50, 8: 50, 9: 50, + 10: 50, 11: 50, 12: 50, 13: 50, 14: 50, + }, + nullCounts: map[int]int64{4: 50, 5: 10, 6: 0, 10: 50, 11: 0, 12: 1, 14: 8}, + nanCounts: 
map[int]int64{7: 50, 8: 10, 9: 0}, + lowerBounds: map[int][]byte{ + 1: IntMin, + 11: FltNan, + 12: DblNan, + 14: {}, + }, + upperBounds: map[int][]byte{ + 1: IntMax, + 11: FltNan, + 12: DblNan, + 14: []byte("房东整租霍营小区二层两居室"), + }, + }, + &mockDataFile{ + path: "file_2.parquet", + format: iceberg.ParquetFile, + count: 50, + filesize: 3, + valueCounts: map[int]int64{3: 20}, + nullCounts: map[int]int64{3: 2}, + nanCounts: nil, + lowerBounds: map[int][]byte{3: {'a', 'a'}}, + upperBounds: map[int][]byte{3: {'d', 'C'}}, + }, + &mockDataFile{ + path: "file_3.parquet", + format: iceberg.ParquetFile, + count: 50, + filesize: 3, + valueCounts: map[int]int64{3: 20}, + nullCounts: map[int]int64{3: 2}, + nanCounts: nil, + lowerBounds: map[int][]byte{3: []byte("1str1")}, + upperBounds: map[int][]byte{3: []byte("3str3")}, + }, + &mockDataFile{ + path: "file_4.parquet", + format: iceberg.ParquetFile, + count: 50, + filesize: 3, + valueCounts: map[int]int64{3: 20}, + nullCounts: map[int]int64{3: 2}, + nanCounts: nil, + lowerBounds: map[int][]byte{3: []byte("abc")}, + upperBounds: map[int][]byte{3: []byte("イロハニホヘト")}, + }, + } + + suite.schemaDataFileNan = iceberg.NewSchema(0, + iceberg.NestedField{ID: 1, Name: "all_nan", Type: iceberg.PrimitiveTypes.Float64, Required: true}, + iceberg.NestedField{ID: 2, Name: "max_nan", Type: iceberg.PrimitiveTypes.Float64, Required: true}, + iceberg.NestedField{ID: 3, Name: "min_max_nan", Type: iceberg.PrimitiveTypes.Float32}, + iceberg.NestedField{ID: 4, Name: "all_nan_null_bounds", Type: iceberg.PrimitiveTypes.Float64, Required: true}, + iceberg.NestedField{ID: 5, Name: "some_nan_correct_bounds", Type: iceberg.PrimitiveTypes.Float32}, + ) + + suite.dataFileNan = &mockDataFile{ + path: "file.avro", + format: iceberg.AvroFile, + count: 50, + filesize: 3, + columnSizes: map[int]int64{1: 10, 2: 10, 3: 10, 4: 10, 5: 10}, + valueCounts: map[int]int64{1: 10, 2: 10, 3: 10, 4: 10, 5: 10}, + nullCounts: map[int]int64{1: 0, 2: 0, 3: 0, 4: 0, 5: 0}, + 
nanCounts: map[int]int64{1: 10, 4: 10, 5: 5}, + lowerBounds: map[int][]byte{ + 1: DblNan, + 2: DblSeven, + 3: FltNan, + 5: FltSeven, + }, + upperBounds: map[int][]byte{ + 1: DblNan, + 2: DblNan, + 3: FltNan, + 5: FltMax, + }, + } +} + +func (suite *InclusiveMetricsTestSuite) TestAllNull() { + allNull, someNull, noNull := iceberg.Reference("all_nulls"), iceberg.Reference("some_nulls"), iceberg.Reference("no_nulls") + + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.NotNull(allNull), false, "should skip: no non-null value in all null column"}, + {iceberg.LessThan(allNull, "a"), false, "should skip: lessThan on all null column"}, + {iceberg.LessThanEqual(allNull, "a"), false, "should skip: lessThanEqual on all null column"}, + {iceberg.GreaterThan(allNull, "a"), false, "should skip: greaterThan on all null column"}, + {iceberg.GreaterThanEqual(allNull, "a"), false, "should skip: greaterThanEqual on all null column"}, + {iceberg.EqualTo(allNull, "a"), false, "should skip: equal on all null column"}, + {iceberg.NotNull(someNull), true, "should read: column with some nulls contains a non-null value"}, + {iceberg.NotNull(noNull), true, "should read: non-null column contains a non-null value"}, + {iceberg.StartsWith(allNull, "asad"), false, "should skip: starts with on all null column"}, + {iceberg.NotStartsWith(allNull, "asad"), true, "should read: notStartsWith on all null column"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestNoNulls() { + allNull, someNull, noNull := iceberg.Reference("all_nulls"), iceberg.Reference("some_nulls"), iceberg.Reference("no_nulls") + + tests := []struct { + expr 
iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.IsNull(allNull), true, "should read: at least one null value in all null column"}, + {iceberg.IsNull(someNull), true, "should read: column with some nulls contains a null value"}, + {iceberg.IsNull(noNull), false, "should skip: non-null column contains no null values"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestIsNan() { + allNan, someNan, noNan := iceberg.Reference("all_nans"), iceberg.Reference("some_nans"), iceberg.Reference("no_nans") + allNullsDbl, noNanStats := iceberg.Reference("all_nulls_double"), iceberg.Reference("no_nan_stats") + allNansV1, nanNullOnly := iceberg.Reference("all_nans_v1_stats"), iceberg.Reference("nan_and_null_only") + + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.IsNaN(allNan), true, "should read: at least one nan value in all nan column"}, + {iceberg.IsNaN(someNan), true, "should read: at least one nan value in some nan column"}, + {iceberg.IsNaN(noNan), false, "should skip: no-nans column has no nans"}, + {iceberg.IsNaN(allNullsDbl), false, "should skip: all-null column doesn't contain nan values"}, + {iceberg.IsNaN(noNanStats), true, "should read: no guarantee if contains nan without stats"}, + {iceberg.IsNaN(allNansV1), true, "should read: at least one nan value in all nan column"}, + {iceberg.IsNaN(nanNullOnly), true, "should read: at least one nan value in nan and nulls only column"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + 
shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestNotNaN() { + allNan, someNan, noNan := iceberg.Reference("all_nans"), iceberg.Reference("some_nans"), iceberg.Reference("no_nans") + allNullsDbl, noNanStats := iceberg.Reference("all_nulls_double"), iceberg.Reference("no_nan_stats") + allNansV1, nanNullOnly := iceberg.Reference("all_nans_v1_stats"), iceberg.Reference("nan_and_null_only") + + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.NotNaN(allNan), false, "should skip: column with all nans will not contain non-nan"}, + {iceberg.NotNaN(someNan), true, "should read: at least one non-nan value in some nan column"}, + {iceberg.NotNaN(noNan), true, "should read: at least one non-nan value in no nan column"}, + {iceberg.NotNaN(allNullsDbl), true, "should read: at least one non-nan value in all null column"}, + {iceberg.NotNaN(noNanStats), true, "should read: no guarantee if contains nan without stats"}, + {iceberg.NotNaN(allNansV1), true, "should read: no guarantee"}, + {iceberg.NotNaN(nanNullOnly), true, "should read: at least one null value in nan and nulls only column"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestRequiredColumn() { + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.NotNull(iceberg.Reference("required")), true, "should read: required columns are always non-null"}, + {iceberg.IsNull(iceberg.Reference("required")), false, "should skip: required columns are always non-null"}, + } + + for _, tt := 
range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestMissingColumn() { + _, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, iceberg.LessThan(iceberg.Reference("missing"), int32(22)), true, true) + suite.ErrorIs(err, iceberg.ErrInvalidSchema) +} + +func (suite *InclusiveMetricsTestSuite) TestMissingStats() { + noStatsSchema := iceberg.NewSchema(0, + iceberg.NestedField{ID: 2, Name: "no_stats", Type: iceberg.PrimitiveTypes.Float64}) + + noStatsFile := &mockDataFile{ + path: "file_1.parquet", + format: iceberg.ParquetFile, + count: 50, + } + + ref := iceberg.Reference("no_stats") + tests := []iceberg.BooleanExpression{ + iceberg.LessThan(ref, int32(5)), + iceberg.LessThanEqual(ref, int32(30)), + iceberg.EqualTo(ref, int32(70)), + iceberg.GreaterThan(ref, int32(78)), + iceberg.GreaterThanEqual(ref, int32(90)), + iceberg.NotEqualTo(ref, int32(101)), + iceberg.IsNull(ref), + iceberg.NotNull(ref), + iceberg.IsNaN(ref), + iceberg.NotNaN(ref), + } + + for _, tt := range tests { + suite.Run(tt.String(), func() { + eval, err := newInclusiveMetricsEvaluator(noStatsSchema, tt, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(noStatsFile) + suite.Require().NoError(err) + suite.True(shouldRead, "should read when stats are missing") + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestZeroRecordFileStats() { + zeroRecordFile := &mockDataFile{ + path: "file_1.parquet", + format: iceberg.ParquetFile, + count: 0, + } + + ref := iceberg.Reference("no_stats") + tests := []iceberg.BooleanExpression{ + iceberg.LessThan(ref, int32(5)), + iceberg.LessThanEqual(ref, int32(30)), + iceberg.EqualTo(ref, int32(70)), + iceberg.GreaterThan(ref, int32(78)), + 
iceberg.GreaterThanEqual(ref, int32(90)), + iceberg.NotEqualTo(ref, int32(101)), + iceberg.IsNull(ref), + iceberg.NotNull(ref), + iceberg.IsNaN(ref), + iceberg.NotNaN(ref), + } + + for _, tt := range tests { + suite.Run(tt.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt, true, false) + suite.Require().NoError(err) + shouldRead, err := eval(zeroRecordFile) + suite.Require().NoError(err) + suite.False(shouldRead, "should skip datafile without records") + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestNot() { + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.NewNot(iceberg.LessThan(iceberg.Reference("id"), IntMinValue-25)), true, "should read: not(false)"}, + {iceberg.NewNot(iceberg.GreaterThan(iceberg.Reference("id"), IntMinValue-25)), false, "should skip: not(true)"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestAnd() { + ref := iceberg.Reference("id") + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.NewAnd( + iceberg.LessThan(ref, IntMinValue-25), + iceberg.GreaterThanEqual(ref, IntMinValue-30)), false, "should skip: and(false, true)"}, + {iceberg.NewAnd( + iceberg.LessThan(ref, IntMinValue-25), + iceberg.GreaterThanEqual(ref, IntMinValue+1)), false, "should skip: and(false, false)"}, + {iceberg.NewAnd( + iceberg.GreaterThan(ref, IntMinValue-25), + iceberg.LessThanEqual(ref, IntMinValue)), true, "should read: and(true, true)"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + 
suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestOr() { + ref := iceberg.Reference("id") + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.NewOr( + iceberg.LessThan(ref, IntMinValue-25), + iceberg.GreaterThanEqual(ref, IntMaxValue+1)), false, "should skip: or(false, false)"}, + {iceberg.NewOr( + iceberg.LessThan(ref, IntMinValue-25), + iceberg.GreaterThanEqual(ref, IntMaxValue-19)), true, "should read: or(false, true)"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestIntLt() { + ref := iceberg.Reference("id") + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.LessThan(ref, IntMinValue-25), false, "should skip: id range below lower bound (5 < 30)"}, + {iceberg.LessThan(ref, IntMinValue), false, "should skip: id range below lower bound (30 is not < 30)"}, + {iceberg.LessThan(ref, IntMinValue+1), true, "should read: one possible id"}, + {iceberg.LessThan(ref, IntMaxValue), true, "should read: many possible ids"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestIntLtEq() { + ref := iceberg.Reference("id") + tests := []struct { + expr iceberg.BooleanExpression + expected 
bool + msg string + }{ + {iceberg.LessThanEqual(ref, IntMinValue-25), false, "should skip: id range below lower bound (5 < 30)"}, + {iceberg.LessThanEqual(ref, IntMinValue-1), false, "should skip: id range below lower bound (29 is not <= 30)"}, + {iceberg.LessThanEqual(ref, IntMinValue), true, "should read: one possible id"}, + {iceberg.LessThanEqual(ref, IntMaxValue), true, "should read: many possible ids"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestIntGt() { + ref := iceberg.Reference("id") + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.GreaterThan(ref, IntMaxValue+6), false, "should skip: id range above upper bound (85 > 79)"}, + {iceberg.GreaterThan(ref, IntMaxValue), false, "should skip: id range above upper bound (79 is not > 79)"}, + {iceberg.GreaterThan(ref, IntMinValue-1), true, "should read: one possible id"}, + {iceberg.GreaterThan(ref, IntMaxValue-4), true, "should read: many possible ids"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestIntGtEq() { + ref := iceberg.Reference("id") + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.GreaterThanEqual(ref, IntMaxValue+6), false, "should skip: id range above upper bound (85 < 79)"}, + {iceberg.GreaterThanEqual(ref, IntMaxValue+1), false, "should skip: id 
range above upper bound (80 > 79)"}, + {iceberg.GreaterThanEqual(ref, IntMaxValue), true, "should read: one possible id"}, + {iceberg.GreaterThanEqual(ref, IntMaxValue-4), true, "should read: many possible ids"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestIntEq() { + ref := iceberg.Reference("id") + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.EqualTo(ref, IntMinValue-25), false, "should skip: id range below lower bound"}, + {iceberg.EqualTo(ref, IntMinValue-1), false, "should skip: id range below lower bound"}, + {iceberg.EqualTo(ref, IntMinValue), true, "should read: id equal to lower bound"}, + {iceberg.EqualTo(ref, IntMaxValue-4), true, "should read: id between lower and upper bounds"}, + {iceberg.EqualTo(ref, IntMaxValue), true, "should read: id equal to upper bound"}, + {iceberg.EqualTo(ref, IntMaxValue+1), false, "should skip: id above upper bound"}, + {iceberg.EqualTo(ref, IntMaxValue+6), false, "should skip: id above upper bound"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestIntNeq() { + ref := iceberg.Reference("id") + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.NotEqualTo(ref, IntMinValue-25), true, "should read: id range below lower bound"}, + {iceberg.NotEqualTo(ref, IntMinValue-1), true, "should 
read: id range below lower bound"}, + {iceberg.NotEqualTo(ref, IntMinValue), true, "should read: id equal to lower bound"}, + {iceberg.NotEqualTo(ref, IntMaxValue-4), true, "should read: id between lower and upper bounds"}, + {iceberg.NotEqualTo(ref, IntMaxValue), true, "should read: id equal to upper bound"}, + {iceberg.NotEqualTo(ref, IntMaxValue+1), true, "should read: id above upper bound"}, + {iceberg.NotEqualTo(ref, IntMaxValue+6), true, "should read: id above upper bound"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestIntNeqRewritten() { + ref := iceberg.Reference("id") + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.EqualTo(ref, IntMinValue-25), true, "should read: id range below lower bound"}, + {iceberg.EqualTo(ref, IntMinValue-1), true, "should read: id range below lower bound"}, + {iceberg.EqualTo(ref, IntMinValue), true, "should read: id equal to lower bound"}, + {iceberg.EqualTo(ref, IntMaxValue-4), true, "should read: id between lower and upper bounds"}, + {iceberg.EqualTo(ref, IntMaxValue), true, "should read: id equal to upper bound"}, + {iceberg.EqualTo(ref, IntMaxValue+1), true, "should read: id above upper bound"}, + {iceberg.EqualTo(ref, IntMaxValue+6), true, "should read: id above upper bound"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, iceberg.NewNot(tt.expr), true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite 
*InclusiveMetricsTestSuite) TestIntNeqRewrittenCaseInsensitive() { + ref := iceberg.Reference("ID") + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.EqualTo(ref, IntMinValue-25), true, "should read: id range below lower bound"}, + {iceberg.EqualTo(ref, IntMinValue-1), true, "should read: id range below lower bound"}, + {iceberg.EqualTo(ref, IntMinValue), true, "should read: id equal to lower bound"}, + {iceberg.EqualTo(ref, IntMaxValue-4), true, "should read: id between lower and upper bounds"}, + {iceberg.EqualTo(ref, IntMaxValue), true, "should read: id equal to upper bound"}, + {iceberg.EqualTo(ref, IntMaxValue+1), true, "should read: id above upper bound"}, + {iceberg.EqualTo(ref, IntMaxValue+6), true, "should read: id above upper bound"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, iceberg.NewNot(tt.expr), false, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestInMetrics() { + ref := iceberg.Reference("id") + + ids := make([]int32, 400) + for i := range ids { + ids[i] = int32(i) + } + + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.IsIn(ref, IntMinValue-25, IntMinValue-24), false, "should skip: id below lower bound"}, + {iceberg.IsIn(ref, IntMinValue-2, IntMinValue-1), false, "should skip: id below lower bound"}, + {iceberg.IsIn(ref, IntMinValue-1, IntMinValue), true, "should read: id equal to lower bound"}, + {iceberg.IsIn(ref, IntMaxValue-4, IntMaxValue-3), true, "should read: id between upper and lower bounds"}, + {iceberg.IsIn(ref, IntMaxValue, IntMaxValue+1), true, "should read: id equal to upper bound"}, + {iceberg.IsIn(ref, IntMaxValue+1, IntMaxValue+2), false, "should skip: id above upper 
bound"}, + {iceberg.IsIn(ref, IntMaxValue+6, IntMaxValue+7), false, "should skip: id above upper bound"}, + {iceberg.IsIn(iceberg.Reference("all_nulls"), "abc", "def"), false, "should skip: in on all nulls column"}, + {iceberg.IsIn(iceberg.Reference("some_nulls"), "abc", "def"), true, "should read: in on some nulls column"}, + {iceberg.IsIn(iceberg.Reference("no_nulls"), "abc", "def"), true, "should read: in on no nulls column"}, + {iceberg.IsIn(ref, ids...), true, "should read: large in expression"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestNotInMetrics() { + ref := iceberg.Reference("id") + + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.NotIn(ref, IntMinValue-25, IntMinValue-24), true, "should read: id below lower bound"}, + {iceberg.NotIn(ref, IntMinValue-2, IntMinValue-1), true, "should read: id below lower bound"}, + {iceberg.NotIn(ref, IntMinValue-1, IntMinValue), true, "should read: id equal to lower bound"}, + {iceberg.NotIn(ref, IntMaxValue-4, IntMaxValue-3), true, "should read: id between upper and lower bounds"}, + {iceberg.NotIn(ref, IntMaxValue, IntMaxValue+1), true, "should read: id equal to upper bound"}, + {iceberg.NotIn(ref, IntMaxValue+1, IntMaxValue+2), true, "should read: id above upper bound"}, + {iceberg.NotIn(ref, IntMaxValue+6, IntMaxValue+7), true, "should read: id above upper bound"}, + {iceberg.NotIn(iceberg.Reference("all_nulls"), "abc", "def"), true, "should read: in on all nulls column"}, + {iceberg.NotIn(iceberg.Reference("some_nulls"), "abc", "def"), true, "should read: in on some nulls column"}, + {iceberg.NotIn(iceberg.Reference("no_nulls"), "abc", 
"def"), true, "should read: in on no nulls column"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFiles[0]) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestLessAndLessEqualNans() { + type Op func(iceberg.UnboundTerm, int32) iceberg.UnboundPredicate + for _, operator := range []Op{iceberg.LessThan[int32], iceberg.LessThanEqual[int32]} { + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {operator(iceberg.Reference("all_nan"), int32(1)), false, "should skip: all nan column doesn't contain number"}, + {operator(iceberg.Reference("max_nan"), int32(1)), false, "should skip: 1 is smaller than lower bound"}, + {operator(iceberg.Reference("max_nan"), int32(10)), true, "should read: 10 is larger than lower bound"}, + {operator(iceberg.Reference("min_max_nan"), int32(1)), true, "should read: no visibility"}, + {operator(iceberg.Reference("all_nan_null_bounds"), int32(1)), false, "should skip: all nan column doesn't contain number"}, + {operator(iceberg.Reference("some_nan_correct_bounds"), int32(1)), false, "should skip: 1 is smaller than lower bound"}, + {operator(iceberg.Reference("some_nan_correct_bounds"), int32(10)), true, "should read: 10 is larger than lower bound"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFileNan) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } + } +} + +func (suite *InclusiveMetricsTestSuite) TestGreaterAndGreaterEqualNans() { + type Op func(iceberg.UnboundTerm, int32) iceberg.UnboundPredicate + for _, operator := 
range []Op{iceberg.GreaterThan[int32], iceberg.GreaterThanEqual[int32]} { + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {operator(iceberg.Reference("all_nan"), int32(1)), false, "should skip: all nan column doesn't contain number"}, + {operator(iceberg.Reference("max_nan"), int32(1)), true, "should read: upper bound is larger than 1"}, + {operator(iceberg.Reference("max_nan"), int32(10)), true, "should read: 10 is smaller than upper bound"}, + {operator(iceberg.Reference("min_max_nan"), int32(1)), true, "should read: no visibility"}, + {operator(iceberg.Reference("all_nan_null_bounds"), int32(1)), false, "should skip: all nan column doesn't contain number"}, + {operator(iceberg.Reference("some_nan_correct_bounds"), int32(1)), true, "should read: 1 is smaller than upper bound"}, + {operator(iceberg.Reference("some_nan_correct_bounds"), int32(10)), true, "should read: 10 is smaller than upper bound"}, + {operator(iceberg.Reference("all_nan"), int32(30)), false, "should skip: 30 is larger than upper bound"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFileNan) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } + } +} + +func (suite *InclusiveMetricsTestSuite) TestEqualsNans() { + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.EqualTo(iceberg.Reference("all_nan"), int32(1)), false, "should skip: all nan column doesn't contain number"}, + {iceberg.EqualTo(iceberg.Reference("max_nan"), int32(1)), false, "should skip: 1 is smaller than lower bound"}, + {iceberg.EqualTo(iceberg.Reference("max_nan"), int32(10)), true, "should read: 10 is within bounds"}, + {iceberg.EqualTo(iceberg.Reference("min_max_nan"), int32(1)), true, "should read: no visibility"}, + 
{iceberg.EqualTo(iceberg.Reference("all_nan_null_bounds"), int32(1)), false, "should skip: all nan column doesn't contain number"}, + {iceberg.EqualTo(iceberg.Reference("some_nan_correct_bounds"), int32(1)), false, "should skip: 1 is smaller than lower bound"}, + {iceberg.EqualTo(iceberg.Reference("some_nan_correct_bounds"), int32(10)), true, "should read: 10 within bounds"}, + {iceberg.EqualTo(iceberg.Reference("all_nan"), int32(30)), false, "should skip: 30 is larger than upper bound"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFileNan) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestNotEqualsNans() { + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.NotEqualTo(iceberg.Reference("all_nan"), int32(1)), true, "should read: no visibility"}, + {iceberg.NotEqualTo(iceberg.Reference("max_nan"), int32(1)), true, "should read: no visibility"}, + {iceberg.NotEqualTo(iceberg.Reference("max_nan"), int32(10)), true, "should read: no visibility"}, + {iceberg.NotEqualTo(iceberg.Reference("min_max_nan"), int32(1)), true, "should read: no visibility"}, + {iceberg.NotEqualTo(iceberg.Reference("all_nan_null_bounds"), int32(1)), true, "should read: no visibility"}, + {iceberg.NotEqualTo(iceberg.Reference("some_nan_correct_bounds"), int32(1)), true, "should read: no visibility"}, + {iceberg.NotEqualTo(iceberg.Reference("some_nan_correct_bounds"), int32(10)), true, "should read: no visibility"}, + {iceberg.NotEqualTo(iceberg.Reference("all_nan"), int32(30)), true, "should read: no visibility"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) + 
suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFileNan) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestInWithNans() { + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.IsIn(iceberg.Reference("all_nan"), int32(1), 10, 30), false, "should skip: all nan column doesn't contain number"}, + {iceberg.IsIn(iceberg.Reference("max_nan"), int32(1), 10, 30), true, "should read: 10 and 30 are greater than lower bound"}, + {iceberg.IsIn(iceberg.Reference("min_max_nan"), int32(1), 10, 30), true, "should read: no visibility"}, + {iceberg.IsIn(iceberg.Reference("all_nan_null_bounds"), int32(1), 10, 30), false, "should skip: all nan column doesn't contain number"}, + {iceberg.IsIn(iceberg.Reference("some_nan_correct_bounds"), int32(1), 10, 30), true, "should read: 10 within bounds"}, + {iceberg.IsIn(iceberg.Reference("some_nan_correct_bounds"), int32(1), 30), false, "should skip: 1 and 30 not within bounds"}, + {iceberg.IsIn(iceberg.Reference("some_nan_correct_bounds"), int32(5), 7), true, "should read: overlap with lower bound"}, + {iceberg.IsIn(iceberg.Reference("some_nan_correct_bounds"), int32(22), 25), true, "should read: overlap with upper bound"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFileNan) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestNotInWithNans() { + tests := []struct { + expr iceberg.BooleanExpression + expected bool + msg string + }{ + {iceberg.NotIn(iceberg.Reference("all_nan"), int32(1), 10, 30), true, "should read: no visibility"}, + {iceberg.NotIn(iceberg.Reference("max_nan"), int32(1), 10, 30), true, "should read: no 
visibility"}, + {iceberg.NotIn(iceberg.Reference("min_max_nan"), int32(1), 10, 30), true, "should read: no visibility"}, + {iceberg.NotIn(iceberg.Reference("all_nan_null_bounds"), int32(1), 10, 30), true, "should read: no visibility"}, + {iceberg.NotIn(iceberg.Reference("some_nan_correct_bounds"), int32(1), 10, 30), true, "should read: no visibility"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFileNan, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(suite.dataFileNan) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestStartsWith() { + ref, refEmpty := iceberg.Reference("required"), iceberg.Reference("some_empty") + + tests := []struct { + expr iceberg.BooleanExpression + expected bool + dataFile iceberg.DataFile + msg string + }{ + {iceberg.StartsWith(ref, "a"), true, suite.dataFiles[0], "should read: no stats"}, + {iceberg.StartsWith(ref, "a"), true, suite.dataFiles[1], "should read: range matches"}, + {iceberg.StartsWith(ref, "aa"), true, suite.dataFiles[1], "should read: range matches"}, + {iceberg.StartsWith(ref, "aaa"), true, suite.dataFiles[1], "should read: range matches"}, + {iceberg.StartsWith(ref, "1s"), true, suite.dataFiles[2], "should read: range matches"}, + {iceberg.StartsWith(ref, "1str1x"), true, suite.dataFiles[2], "should read: range matches"}, + {iceberg.StartsWith(ref, "ff"), true, suite.dataFiles[3], "should read: range matches"}, + {iceberg.StartsWith(ref, "aB"), false, suite.dataFiles[1], "should skip: range doesn't match"}, + {iceberg.StartsWith(ref, "dWx"), false, suite.dataFiles[1], "should skip: range doesn't match"}, + {iceberg.StartsWith(ref, "5"), false, suite.dataFiles[2], "should skip: range doesn't match"}, + {iceberg.StartsWith(ref, "3str3x"), false, suite.dataFiles[2], "should skip: range doesn't match"}, + 
{iceberg.StartsWith(refEmpty, "房东整租霍"), true, suite.dataFiles[0], "should read: range matches"}, + {iceberg.StartsWith(iceberg.Reference("all_nulls"), ""), false, suite.dataFiles[0], "should skip: range doesn't match"}, + } + + for _, tt := range tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(tt.dataFile) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func (suite *InclusiveMetricsTestSuite) TestNotStartsWith() { + ref, refEmpty := iceberg.Reference("required"), iceberg.Reference("some_empty") + + tests := []struct { + expr iceberg.BooleanExpression + expected bool + dataFile iceberg.DataFile + msg string + }{ + {iceberg.NotStartsWith(ref, "a"), true, suite.dataFiles[0], "should read: no stats"}, + {iceberg.NotStartsWith(ref, "a"), true, suite.dataFiles[1], "should read: range matches"}, + {iceberg.NotStartsWith(ref, "aa"), true, suite.dataFiles[1], "should read: range matches"}, + {iceberg.NotStartsWith(ref, "aaa"), true, suite.dataFiles[1], "should read: range matches"}, + {iceberg.NotStartsWith(ref, "1s"), true, suite.dataFiles[2], "should read: range matches"}, + {iceberg.NotStartsWith(ref, "1str1x"), true, suite.dataFiles[2], "should read: range matches"}, + {iceberg.NotStartsWith(ref, "ff"), true, suite.dataFiles[3], "should read: range matches"}, + {iceberg.NotStartsWith(ref, "aB"), true, suite.dataFiles[1], "should read: range doesn't match"}, + {iceberg.NotStartsWith(ref, "dWx"), true, suite.dataFiles[1], "should read: range doesn't match"}, + {iceberg.NotStartsWith(ref, "5"), true, suite.dataFiles[2], "should read: range doesn't match"}, + {iceberg.NotStartsWith(ref, "3str3x"), true, suite.dataFiles[2], "should read: range doesn't match"}, + {iceberg.NotStartsWith(refEmpty, "房东整租霍"), true, suite.dataFiles[0], "should read: range matches"}, + } + + for _, tt := range 
tests { + suite.Run(tt.expr.String(), func() { + eval, err := newInclusiveMetricsEvaluator(suite.schemaDataFile, tt.expr, true, true) + suite.Require().NoError(err) + shouldRead, err := eval(tt.dataFile) + suite.Require().NoError(err) + suite.Equal(tt.expected, shouldRead, tt.msg) + }) + } +} + +func TestEvaluators(t *testing.T) { + suite.Run(t, &ProjectionTestSuite{}) + suite.Run(t, &InclusiveMetricsTestSuite{}) +} diff --git a/table/metadata.go b/table/metadata.go index 47b3ffe..ebc3d76 100644 --- a/table/metadata.go +++ b/table/metadata.go @@ -1,467 +1,467 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package table - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "maps" - "slices" - - "github.com/apache/iceberg-go" - - "github.com/google/uuid" -) - -// Metadata for an iceberg table as specified in the Iceberg spec -// -// https://iceberg.apache.org/spec/#iceberg-table-spec -type Metadata interface { - // Version indicates the version of this metadata, 1 for V1, 2 for V2, etc. - Version() int - // TableUUID returns a UUID that identifies the table, generated when the - // table is created. 
Implementations must throw an exception if a table's - // UUID does not match the expected UUID after refreshing metadata. - TableUUID() uuid.UUID - // Location is the table's base location. This is used by writers to determine - // where to store data files, manifest files, and table metadata files. - Location() string - // LastUpdatedMillis is the timestamp in milliseconds from the unix epoch when - // the table was last updated. Each table metadata file should update this - // field just before writing. - LastUpdatedMillis() int64 - // LastColumnID returns the highest assigned column ID for the table. - // This is used to ensure fields are always assigned an unused ID when - // evolving schemas. - LastColumnID() int - // Schemas returns the list of schemas, stored as objects with their - // schema-id. - Schemas() []*iceberg.Schema - // CurrentSchema returns the table's current schema. - CurrentSchema() *iceberg.Schema - // PartitionSpecs returns the list of all partition specs in the table. - PartitionSpecs() []iceberg.PartitionSpec - // PartitionSpec returns the current partition spec that the table is using. - PartitionSpec() iceberg.PartitionSpec - // DefaultPartitionSpec is the ID of the current spec that writers should - // use by default. - DefaultPartitionSpec() int - // LastPartitionSpecID is the highest assigned partition field ID across - // all partition specs for the table. This is used to ensure partition - // fields are always assigned an unused ID when evolving specs. - LastPartitionSpecID() *int - // Snapshots returns the list of valid snapshots. Valid snapshots are - // snapshots for which all data files exist in the file system. A data - // file must not be deleted from the file system until the last snapshot - // in which it was listed is garbage collected. - Snapshots() []Snapshot - // SnapshotByID find and return a specific snapshot by its ID. Returns - // nil if the ID is not found in the list of snapshots. 
- SnapshotByID(int64) *Snapshot - // SnapshotByName searches the list of snapshots for a snapshot with a given - // ref name. Returns nil if there's no ref with this name for a snapshot. - SnapshotByName(name string) *Snapshot - // CurrentSnapshot returns the table's current snapshot. - CurrentSnapshot() *Snapshot - // SortOrder returns the table's current sort order, ie: the one with the - // ID that matches the default-sort-order-id. - SortOrder() SortOrder - // SortOrders returns the list of sort orders in the table. - SortOrders() []SortOrder - // Properties is a string to string map of table properties. This is used - // to control settings that affect reading and writing and is not intended - // to be used for arbitrary metadata. For example, commit.retry.num-retries - // is used to control the number of commit retries. - Properties() iceberg.Properties - - Equals(Metadata) bool -} - -var ( - ErrInvalidMetadataFormatVersion = errors.New("invalid or missing format-version in table metadata") - ErrInvalidMetadata = errors.New("invalid metadata") -) - -// ParseMetadata parses json metadata provided by the passed in reader, -// returning an error if one is encountered. -func ParseMetadata(r io.Reader) (Metadata, error) { - data, err := io.ReadAll(r) - if err != nil { - return nil, err - } - - return ParseMetadataBytes(data) -} - -// ParseMetadataString is like [ParseMetadata], but for a string rather than -// an io.Reader. -func ParseMetadataString(s string) (Metadata, error) { - return ParseMetadataBytes([]byte(s)) -} - -// ParseMetadataBytes is like [ParseMetadataString] but for a byte slice. 
-func ParseMetadataBytes(b []byte) (Metadata, error) { - ver := struct { - FormatVersion int `json:"format-version"` - }{} - if err := json.Unmarshal(b, &ver); err != nil { - return nil, err - } - - var ret Metadata - switch ver.FormatVersion { - case 1: - ret = &MetadataV1{} - case 2: - ret = &MetadataV2{} - default: - return nil, ErrInvalidMetadataFormatVersion - } - - return ret, json.Unmarshal(b, ret) -} - -func sliceEqualHelper[T interface{ Equals(T) bool }](s1, s2 []T) bool { - return slices.EqualFunc(s1, s2, func(t1, t2 T) bool { - return t1.Equals(t2) - }) -} - -// https://iceberg.apache.org/spec/#iceberg-table-spec -type commonMetadata struct { - FormatVersion int `json:"format-version"` - UUID uuid.UUID `json:"table-uuid"` - Loc string `json:"location"` - LastUpdatedMS int64 `json:"last-updated-ms"` - LastColumnId int `json:"last-column-id"` - SchemaList []*iceberg.Schema `json:"schemas"` - CurrentSchemaID int `json:"current-schema-id"` - Specs []iceberg.PartitionSpec `json:"partition-specs"` - DefaultSpecID int `json:"default-spec-id"` - LastPartitionID *int `json:"last-partition-id,omitempty"` - Props iceberg.Properties `json:"properties"` - SnapshotList []Snapshot `json:"snapshots,omitempty"` - CurrentSnapshotID *int64 `json:"current-snapshot-id,omitempty"` - SnapshotLog []SnapshotLogEntry `json:"snapshot-log"` - MetadataLog []MetadataLogEntry `json:"metadata-log"` - SortOrderList []SortOrder `json:"sort-orders"` - DefaultSortOrderID int `json:"default-sort-order-id"` - Refs map[string]SnapshotRef `json:"refs"` -} - -func (c *commonMetadata) Equals(other *commonMetadata) bool { - switch { - case c.LastPartitionID == nil && other.LastPartitionID != nil: - fallthrough - case c.LastPartitionID != nil && other.LastPartitionID == nil: - fallthrough - case c.CurrentSnapshotID == nil && other.CurrentSnapshotID != nil: - fallthrough - case c.CurrentSnapshotID != nil && other.CurrentSnapshotID == nil: - return false - } - - switch { - case 
!sliceEqualHelper(c.SchemaList, other.SchemaList): - fallthrough - case !sliceEqualHelper(c.SnapshotList, other.SnapshotList): - fallthrough - case !sliceEqualHelper(c.Specs, other.Specs): - fallthrough - case !maps.Equal(c.Props, other.Props): - fallthrough - case !maps.EqualFunc(c.Refs, other.Refs, func(sr1, sr2 SnapshotRef) bool { return sr1.Equals(sr2) }): - return false - } - - return c.FormatVersion == other.FormatVersion && c.UUID == other.UUID && - ((c.LastPartitionID == other.LastPartitionID) || (*c.LastPartitionID == *other.LastPartitionID)) && - ((c.CurrentSnapshotID == other.CurrentSnapshotID) || (*c.CurrentSnapshotID == *other.CurrentSnapshotID)) && - c.Loc == other.Loc && c.LastUpdatedMS == other.LastUpdatedMS && - c.LastColumnId == other.LastColumnId && c.CurrentSchemaID == other.CurrentSchemaID && - c.DefaultSpecID == other.DefaultSpecID && c.DefaultSortOrderID == other.DefaultSortOrderID && - slices.Equal(c.SnapshotLog, other.SnapshotLog) && slices.Equal(c.MetadataLog, other.MetadataLog) && - sliceEqualHelper(c.SortOrderList, other.SortOrderList) - -} - -func (c *commonMetadata) TableUUID() uuid.UUID { return c.UUID } -func (c *commonMetadata) Location() string { return c.Loc } -func (c *commonMetadata) LastUpdatedMillis() int64 { return c.LastUpdatedMS } -func (c *commonMetadata) LastColumnID() int { return c.LastColumnId } -func (c *commonMetadata) Schemas() []*iceberg.Schema { return c.SchemaList } -func (c *commonMetadata) CurrentSchema() *iceberg.Schema { - for _, s := range c.SchemaList { - if s.ID == c.CurrentSchemaID { - return s - } - } - panic("should never get here") -} - -func (c *commonMetadata) PartitionSpecs() []iceberg.PartitionSpec { - return c.Specs -} - -func (c *commonMetadata) DefaultPartitionSpec() int { - return c.DefaultSpecID -} - -func (c *commonMetadata) PartitionSpec() iceberg.PartitionSpec { - for _, s := range c.Specs { - if s.ID() == c.DefaultSpecID { - return s - } - } - return *iceberg.UnpartitionedSpec -} - -func 
(c *commonMetadata) LastPartitionSpecID() *int { return c.LastPartitionID } -func (c *commonMetadata) Snapshots() []Snapshot { return c.SnapshotList } -func (c *commonMetadata) SnapshotByID(id int64) *Snapshot { - for i := range c.SnapshotList { - if c.SnapshotList[i].SnapshotID == id { - return &c.SnapshotList[i] - } - } - return nil -} - -func (c *commonMetadata) SnapshotByName(name string) *Snapshot { - if ref, ok := c.Refs[name]; ok { - return c.SnapshotByID(ref.SnapshotID) - } - return nil -} - -func (c *commonMetadata) CurrentSnapshot() *Snapshot { - if c.CurrentSnapshotID == nil { - return nil - } - return c.SnapshotByID(*c.CurrentSnapshotID) -} - -func (c *commonMetadata) SortOrders() []SortOrder { return c.SortOrderList } -func (c *commonMetadata) SortOrder() SortOrder { - for _, s := range c.SortOrderList { - if s.OrderID == c.DefaultSortOrderID { - return s - } - } - return UnsortedSortOrder -} - -func (c *commonMetadata) Properties() iceberg.Properties { - return c.Props -} - -// preValidate updates values in the metadata struct with defaults based on -// combinations of struct members. Such as initializing slices as empty slices -// if they were null in the metadata, or normalizing inconsistencies between -// metadata versions. 
-func (c *commonMetadata) preValidate() { - if c.CurrentSnapshotID != nil && *c.CurrentSnapshotID == -1 { - // treat -1 as the same as nil, clean this up in pre-validation - // to make the validation logic simplified later - c.CurrentSnapshotID = nil - } - - if c.CurrentSnapshotID != nil { - if _, ok := c.Refs[MainBranch]; !ok { - c.Refs[MainBranch] = SnapshotRef{ - SnapshotID: *c.CurrentSnapshotID, - SnapshotRefType: BranchRef, - } - } - } - - if c.MetadataLog == nil { - c.MetadataLog = []MetadataLogEntry{} - } - - if c.Refs == nil { - c.Refs = make(map[string]SnapshotRef) - } - - if c.SnapshotLog == nil { - c.SnapshotLog = []SnapshotLogEntry{} - } -} - -func (c *commonMetadata) checkSchemas() error { - // check that current-schema-id is present in schemas - for _, s := range c.SchemaList { - if s.ID == c.CurrentSchemaID { - return nil - } - } - - return fmt.Errorf("%w: current-schema-id %d can't be found in any schema", - ErrInvalidMetadata, c.CurrentSchemaID) -} - -func (c *commonMetadata) checkPartitionSpecs() error { - for _, spec := range c.Specs { - if spec.ID() == c.DefaultSpecID { - return nil - } - } - - return fmt.Errorf("%w: default-spec-id %d can't be found", - ErrInvalidMetadata, c.DefaultSpecID) -} - -func (c *commonMetadata) checkSortOrders() error { - if c.DefaultSortOrderID == UnsortedSortOrderID { - return nil - } - - for _, o := range c.SortOrderList { - if o.OrderID == c.DefaultSortOrderID { - return nil - } - } - - return fmt.Errorf("%w: default-sort-order-id %d can't be found in %+v", - ErrInvalidMetadata, c.DefaultSortOrderID, c.SortOrderList) -} - -func (c *commonMetadata) validate() error { - if err := c.checkSchemas(); err != nil { - return err - } - - if err := c.checkPartitionSpecs(); err != nil { - return err - } - - if err := c.checkSortOrders(); err != nil { - return err - } - - switch { - case c.LastUpdatedMS == 0: - // last-updated-ms is required - return fmt.Errorf("%w: missing last-updated-ms", ErrInvalidMetadata) - case 
c.LastColumnId == 0: - // last-column-id is required - return fmt.Errorf("%w: missing last-column-id", ErrInvalidMetadata) - } - - return nil -} - -func (c *commonMetadata) Version() int { return c.FormatVersion } - -type MetadataV1 struct { - Schema iceberg.Schema `json:"schema"` - Partition []iceberg.PartitionField `json:"partition-spec"` - - commonMetadata -} - -func (m *MetadataV1) Equals(other Metadata) bool { - rhs, ok := other.(*MetadataV1) - if !ok { - return false - } - - return m.Schema.Equals(&rhs.Schema) && slices.Equal(m.Partition, rhs.Partition) && - m.commonMetadata.Equals(&rhs.commonMetadata) -} - -func (m *MetadataV1) preValidate() { - if len(m.SchemaList) == 0 { - m.SchemaList = []*iceberg.Schema{&m.Schema} - } - - if len(m.Specs) == 0 { - m.Specs = []iceberg.PartitionSpec{ - iceberg.NewPartitionSpec(m.Partition...)} - m.DefaultSpecID = m.Specs[0].ID() - } - - if m.LastPartitionID == nil { - id := m.Specs[0].LastAssignedFieldID() - for _, spec := range m.Specs[1:] { - last := spec.LastAssignedFieldID() - if last > id { - id = last - } - } - m.LastPartitionID = &id - } - - if len(m.SortOrderList) == 0 { - m.SortOrderList = []SortOrder{UnsortedSortOrder} - } - - m.commonMetadata.preValidate() -} - -func (m *MetadataV1) UnmarshalJSON(b []byte) error { - type Alias MetadataV1 - aux := (*Alias)(m) - - if err := json.Unmarshal(b, aux); err != nil { - return err - } - - m.preValidate() - return m.validate() -} - -func (m *MetadataV1) ToV2() MetadataV2 { - commonOut := m.commonMetadata - commonOut.FormatVersion = 2 - if commonOut.UUID.String() == "" { - commonOut.UUID = uuid.New() - } - - return MetadataV2{commonMetadata: commonOut} -} - -type MetadataV2 struct { - LastSequenceNumber int `json:"last-sequence-number"` - - commonMetadata -} - -func (m *MetadataV2) Equals(other Metadata) bool { - rhs, ok := other.(*MetadataV2) - if !ok { - return false - } - - return m.LastSequenceNumber == rhs.LastSequenceNumber && - 
m.commonMetadata.Equals(&rhs.commonMetadata) -} - -func (m *MetadataV2) UnmarshalJSON(b []byte) error { - type Alias MetadataV2 - aux := (*Alias)(m) - - if err := json.Unmarshal(b, aux); err != nil { - return err - } - - m.preValidate() - return m.validate() -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package table + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "maps" + "slices" + + "github.com/apache/iceberg-go" + + "github.com/google/uuid" +) + +// Metadata for an iceberg table as specified in the Iceberg spec +// +// https://iceberg.apache.org/spec/#iceberg-table-spec +type Metadata interface { + // Version indicates the version of this metadata, 1 for V1, 2 for V2, etc. + Version() int + // TableUUID returns a UUID that identifies the table, generated when the + // table is created. Implementations must throw an exception if a table's + // UUID does not match the expected UUID after refreshing metadata. + TableUUID() uuid.UUID + // Location is the table's base location. This is used by writers to determine + // where to store data files, manifest files, and table metadata files. 
+ Location() string + // LastUpdatedMillis is the timestamp in milliseconds from the unix epoch when + // the table was last updated. Each table metadata file should update this + // field just before writing. + LastUpdatedMillis() int64 + // LastColumnID returns the highest assigned column ID for the table. + // This is used to ensure fields are always assigned an unused ID when + // evolving schemas. + LastColumnID() int + // Schemas returns the list of schemas, stored as objects with their + // schema-id. + Schemas() []*iceberg.Schema + // CurrentSchema returns the table's current schema. + CurrentSchema() *iceberg.Schema + // PartitionSpecs returns the list of all partition specs in the table. + PartitionSpecs() []iceberg.PartitionSpec + // PartitionSpec returns the current partition spec that the table is using. + PartitionSpec() iceberg.PartitionSpec + // DefaultPartitionSpec is the ID of the current spec that writers should + // use by default. + DefaultPartitionSpec() int + // LastPartitionSpecID is the highest assigned partition field ID across + // all partition specs for the table. This is used to ensure partition + // fields are always assigned an unused ID when evolving specs. + LastPartitionSpecID() *int + // Snapshots returns the list of valid snapshots. Valid snapshots are + // snapshots for which all data files exist in the file system. A data + // file must not be deleted from the file system until the last snapshot + // in which it was listed is garbage collected. + Snapshots() []Snapshot + // SnapshotByID find and return a specific snapshot by its ID. Returns + // nil if the ID is not found in the list of snapshots. + SnapshotByID(int64) *Snapshot + // SnapshotByName searches the list of snapshots for a snapshot with a given + // ref name. Returns nil if there's no ref with this name for a snapshot. + SnapshotByName(name string) *Snapshot + // CurrentSnapshot returns the table's current snapshot. 
+ CurrentSnapshot() *Snapshot + // SortOrder returns the table's current sort order, ie: the one with the + // ID that matches the default-sort-order-id. + SortOrder() SortOrder + // SortOrders returns the list of sort orders in the table. + SortOrders() []SortOrder + // Properties is a string to string map of table properties. This is used + // to control settings that affect reading and writing and is not intended + // to be used for arbitrary metadata. For example, commit.retry.num-retries + // is used to control the number of commit retries. + Properties() iceberg.Properties + + Equals(Metadata) bool +} + +var ( + ErrInvalidMetadataFormatVersion = errors.New("invalid or missing format-version in table metadata") + ErrInvalidMetadata = errors.New("invalid metadata") +) + +// ParseMetadata parses json metadata provided by the passed in reader, +// returning an error if one is encountered. +func ParseMetadata(r io.Reader) (Metadata, error) { + data, err := io.ReadAll(r) + if err != nil { + return nil, err + } + + return ParseMetadataBytes(data) +} + +// ParseMetadataString is like [ParseMetadata], but for a string rather than +// an io.Reader. +func ParseMetadataString(s string) (Metadata, error) { + return ParseMetadataBytes([]byte(s)) +} + +// ParseMetadataBytes is like [ParseMetadataString] but for a byte slice. 
+func ParseMetadataBytes(b []byte) (Metadata, error) { + ver := struct { + FormatVersion int `json:"format-version"` + }{} + if err := json.Unmarshal(b, &ver); err != nil { + return nil, err + } + + var ret Metadata + switch ver.FormatVersion { + case 1: + ret = &MetadataV1{} + case 2: + ret = &MetadataV2{} + default: + return nil, ErrInvalidMetadataFormatVersion + } + + return ret, json.Unmarshal(b, ret) +} + +func sliceEqualHelper[T interface{ Equals(T) bool }](s1, s2 []T) bool { + return slices.EqualFunc(s1, s2, func(t1, t2 T) bool { + return t1.Equals(t2) + }) +} + +// https://iceberg.apache.org/spec/#iceberg-table-spec +type commonMetadata struct { + FormatVersion int `json:"format-version"` + UUID uuid.UUID `json:"table-uuid"` + Loc string `json:"location"` + LastUpdatedMS int64 `json:"last-updated-ms"` + LastColumnId int `json:"last-column-id"` + SchemaList []*iceberg.Schema `json:"schemas"` + CurrentSchemaID int `json:"current-schema-id"` + Specs []iceberg.PartitionSpec `json:"partition-specs"` + DefaultSpecID int `json:"default-spec-id"` + LastPartitionID *int `json:"last-partition-id,omitempty"` + Props iceberg.Properties `json:"properties"` + SnapshotList []Snapshot `json:"snapshots,omitempty"` + CurrentSnapshotID *int64 `json:"current-snapshot-id,omitempty"` + SnapshotLog []SnapshotLogEntry `json:"snapshot-log"` + MetadataLog []MetadataLogEntry `json:"metadata-log"` + SortOrderList []SortOrder `json:"sort-orders"` + DefaultSortOrderID int `json:"default-sort-order-id"` + Refs map[string]SnapshotRef `json:"refs"` +} + +func (c *commonMetadata) Equals(other *commonMetadata) bool { + switch { + case c.LastPartitionID == nil && other.LastPartitionID != nil: + fallthrough + case c.LastPartitionID != nil && other.LastPartitionID == nil: + fallthrough + case c.CurrentSnapshotID == nil && other.CurrentSnapshotID != nil: + fallthrough + case c.CurrentSnapshotID != nil && other.CurrentSnapshotID == nil: + return false + } + + switch { + case 
!sliceEqualHelper(c.SchemaList, other.SchemaList): + fallthrough + case !sliceEqualHelper(c.SnapshotList, other.SnapshotList): + fallthrough + case !sliceEqualHelper(c.Specs, other.Specs): + fallthrough + case !maps.Equal(c.Props, other.Props): + fallthrough + case !maps.EqualFunc(c.Refs, other.Refs, func(sr1, sr2 SnapshotRef) bool { return sr1.Equals(sr2) }): + return false + } + + return c.FormatVersion == other.FormatVersion && c.UUID == other.UUID && + ((c.LastPartitionID == other.LastPartitionID) || (*c.LastPartitionID == *other.LastPartitionID)) && + ((c.CurrentSnapshotID == other.CurrentSnapshotID) || (*c.CurrentSnapshotID == *other.CurrentSnapshotID)) && + c.Loc == other.Loc && c.LastUpdatedMS == other.LastUpdatedMS && + c.LastColumnId == other.LastColumnId && c.CurrentSchemaID == other.CurrentSchemaID && + c.DefaultSpecID == other.DefaultSpecID && c.DefaultSortOrderID == other.DefaultSortOrderID && + slices.Equal(c.SnapshotLog, other.SnapshotLog) && slices.Equal(c.MetadataLog, other.MetadataLog) && + sliceEqualHelper(c.SortOrderList, other.SortOrderList) + +} + +func (c *commonMetadata) TableUUID() uuid.UUID { return c.UUID } +func (c *commonMetadata) Location() string { return c.Loc } +func (c *commonMetadata) LastUpdatedMillis() int64 { return c.LastUpdatedMS } +func (c *commonMetadata) LastColumnID() int { return c.LastColumnId } +func (c *commonMetadata) Schemas() []*iceberg.Schema { return c.SchemaList } +func (c *commonMetadata) CurrentSchema() *iceberg.Schema { + for _, s := range c.SchemaList { + if s.ID == c.CurrentSchemaID { + return s + } + } + panic("should never get here") +} + +func (c *commonMetadata) PartitionSpecs() []iceberg.PartitionSpec { + return c.Specs +} + +func (c *commonMetadata) DefaultPartitionSpec() int { + return c.DefaultSpecID +} + +func (c *commonMetadata) PartitionSpec() iceberg.PartitionSpec { + for _, s := range c.Specs { + if s.ID() == c.DefaultSpecID { + return s + } + } + return *iceberg.UnpartitionedSpec +} + +func 
(c *commonMetadata) LastPartitionSpecID() *int { return c.LastPartitionID } +func (c *commonMetadata) Snapshots() []Snapshot { return c.SnapshotList } +func (c *commonMetadata) SnapshotByID(id int64) *Snapshot { + for i := range c.SnapshotList { + if c.SnapshotList[i].SnapshotID == id { + return &c.SnapshotList[i] + } + } + return nil +} + +func (c *commonMetadata) SnapshotByName(name string) *Snapshot { + if ref, ok := c.Refs[name]; ok { + return c.SnapshotByID(ref.SnapshotID) + } + return nil +} + +func (c *commonMetadata) CurrentSnapshot() *Snapshot { + if c.CurrentSnapshotID == nil { + return nil + } + return c.SnapshotByID(*c.CurrentSnapshotID) +} + +func (c *commonMetadata) SortOrders() []SortOrder { return c.SortOrderList } +func (c *commonMetadata) SortOrder() SortOrder { + for _, s := range c.SortOrderList { + if s.OrderID == c.DefaultSortOrderID { + return s + } + } + return UnsortedSortOrder +} + +func (c *commonMetadata) Properties() iceberg.Properties { + return c.Props +} + +// preValidate updates values in the metadata struct with defaults based on +// combinations of struct members. Such as initializing slices as empty slices +// if they were null in the metadata, or normalizing inconsistencies between +// metadata versions. 
+func (c *commonMetadata) preValidate() { + if c.CurrentSnapshotID != nil && *c.CurrentSnapshotID == -1 { + // treat -1 as the same as nil, clean this up in pre-validation + // to make the validation logic simplified later + c.CurrentSnapshotID = nil + } + + if c.CurrentSnapshotID != nil { + if _, ok := c.Refs[MainBranch]; !ok { + c.Refs[MainBranch] = SnapshotRef{ + SnapshotID: *c.CurrentSnapshotID, + SnapshotRefType: BranchRef, + } + } + } + + if c.MetadataLog == nil { + c.MetadataLog = []MetadataLogEntry{} + } + + if c.Refs == nil { + c.Refs = make(map[string]SnapshotRef) + } + + if c.SnapshotLog == nil { + c.SnapshotLog = []SnapshotLogEntry{} + } +} + +func (c *commonMetadata) checkSchemas() error { + // check that current-schema-id is present in schemas + for _, s := range c.SchemaList { + if s.ID == c.CurrentSchemaID { + return nil + } + } + + return fmt.Errorf("%w: current-schema-id %d can't be found in any schema", + ErrInvalidMetadata, c.CurrentSchemaID) +} + +func (c *commonMetadata) checkPartitionSpecs() error { + for _, spec := range c.Specs { + if spec.ID() == c.DefaultSpecID { + return nil + } + } + + return fmt.Errorf("%w: default-spec-id %d can't be found", + ErrInvalidMetadata, c.DefaultSpecID) +} + +func (c *commonMetadata) checkSortOrders() error { + if c.DefaultSortOrderID == UnsortedSortOrderID { + return nil + } + + for _, o := range c.SortOrderList { + if o.OrderID == c.DefaultSortOrderID { + return nil + } + } + + return fmt.Errorf("%w: default-sort-order-id %d can't be found in %+v", + ErrInvalidMetadata, c.DefaultSortOrderID, c.SortOrderList) +} + +func (c *commonMetadata) validate() error { + if err := c.checkSchemas(); err != nil { + return err + } + + if err := c.checkPartitionSpecs(); err != nil { + return err + } + + if err := c.checkSortOrders(); err != nil { + return err + } + + switch { + case c.LastUpdatedMS == 0: + // last-updated-ms is required + return fmt.Errorf("%w: missing last-updated-ms", ErrInvalidMetadata) + case 
c.LastColumnId == 0: + // last-column-id is required + return fmt.Errorf("%w: missing last-column-id", ErrInvalidMetadata) + } + + return nil +} + +func (c *commonMetadata) Version() int { return c.FormatVersion } + +type MetadataV1 struct { + Schema iceberg.Schema `json:"schema"` + Partition []iceberg.PartitionField `json:"partition-spec"` + + commonMetadata +} + +func (m *MetadataV1) Equals(other Metadata) bool { + rhs, ok := other.(*MetadataV1) + if !ok { + return false + } + + return m.Schema.Equals(&rhs.Schema) && slices.Equal(m.Partition, rhs.Partition) && + m.commonMetadata.Equals(&rhs.commonMetadata) +} + +func (m *MetadataV1) preValidate() { + if len(m.SchemaList) == 0 { + m.SchemaList = []*iceberg.Schema{&m.Schema} + } + + if len(m.Specs) == 0 { + m.Specs = []iceberg.PartitionSpec{ + iceberg.NewPartitionSpec(m.Partition...)} + m.DefaultSpecID = m.Specs[0].ID() + } + + if m.LastPartitionID == nil { + id := m.Specs[0].LastAssignedFieldID() + for _, spec := range m.Specs[1:] { + last := spec.LastAssignedFieldID() + if last > id { + id = last + } + } + m.LastPartitionID = &id + } + + if len(m.SortOrderList) == 0 { + m.SortOrderList = []SortOrder{UnsortedSortOrder} + } + + m.commonMetadata.preValidate() +} + +func (m *MetadataV1) UnmarshalJSON(b []byte) error { + type Alias MetadataV1 + aux := (*Alias)(m) + + if err := json.Unmarshal(b, aux); err != nil { + return err + } + + m.preValidate() + return m.validate() +} + +func (m *MetadataV1) ToV2() MetadataV2 { + commonOut := m.commonMetadata + commonOut.FormatVersion = 2 + if commonOut.UUID.String() == "" { + commonOut.UUID = uuid.New() + } + + return MetadataV2{commonMetadata: commonOut} +} + +type MetadataV2 struct { + LastSequenceNumber int `json:"last-sequence-number"` + + commonMetadata +} + +func (m *MetadataV2) Equals(other Metadata) bool { + rhs, ok := other.(*MetadataV2) + if !ok { + return false + } + + return m.LastSequenceNumber == rhs.LastSequenceNumber && + 
m.commonMetadata.Equals(&rhs.commonMetadata) +} + +func (m *MetadataV2) UnmarshalJSON(b []byte) error { + type Alias MetadataV2 + aux := (*Alias)(m) + + if err := json.Unmarshal(b, aux); err != nil { + return err + } + + m.preValidate() + return m.validate() +} diff --git a/table/metadata_test.go b/table/metadata_test.go index e268d88..1597cb9 100644 --- a/table/metadata_test.go +++ b/table/metadata_test.go @@ -1,494 +1,494 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package table_test - -import ( - "encoding/json" - "slices" - "testing" - - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/table" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -const ExampleTableMetadataV2 = `{ - "format-version": 2, - "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", - "location": "s3://bucket/test/location", - "last-sequence-number": 34, - "last-updated-ms": 1602638573590, - "last-column-id": 3, - "current-schema-id": 1, - "schemas": [ - {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": true, "type": "long"}]}, - { - "type": "struct", - "schema-id": 1, - "identifier-field-ids": [1, 2], - "fields": [ - {"id": 1, "name": "x", "required": true, "type": "long"}, - {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": true, "type": "long"} - ] - } - ], - "default-spec-id": 0, - "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], - "last-partition-id": 1000, - "default-sort-order-id": 3, - "sort-orders": [ - { - "order-id": 3, - "fields": [ - {"transform": "identity", "source-id": 2, "direction": "asc", "null-order": "nulls-first"}, - {"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"} - ] - } - ], - "properties": {"read.split.target.size": "134217728"}, - "current-snapshot-id": 3055729675574597004, - "snapshots": [ - { - "snapshot-id": 3051729675574597004, - "timestamp-ms": 1515100955770, - "sequence-number": 0, - "summary": {"operation": "append"}, - "manifest-list": "s3://a/b/1.avro" - }, - { - "snapshot-id": 3055729675574597004, - "parent-snapshot-id": 3051729675574597004, - "timestamp-ms": 1555100955770, - "sequence-number": 1, - "summary": {"operation": "append"}, - "manifest-list": "s3://a/b/2.avro", - "schema-id": 1 - } - ], - "snapshot-log": [ - 
{"snapshot-id": 3051729675574597004, "timestamp-ms": 1515100955770}, - {"snapshot-id": 3055729675574597004, "timestamp-ms": 1555100955770} - ], - "metadata-log": [{"metadata-file": "s3://bucket/.../v1.json", "timestamp-ms": 1515100}], - "refs": {"test": {"snapshot-id": 3051729675574597004, "type": "tag", "max-ref-age-ms": 10000000}} -}` - -const ExampleTableMetadataV1 = `{ - "format-version": 1, - "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", - "location": "s3://bucket/test/location", - "last-updated-ms": 1602638573874, - "last-column-id": 3, - "schema": { - "type": "struct", - "fields": [ - {"id": 1, "name": "x", "required": true, "type": "long"}, - {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": true, "type": "long"} - ] - }, - "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], - "properties": {}, - "current-snapshot-id": -1, - "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}] -}` - -func TestMetadataV1Parsing(t *testing.T) { - meta, err := table.ParseMetadataBytes([]byte(ExampleTableMetadataV1)) - require.NoError(t, err) - require.NotNil(t, meta) - - assert.IsType(t, (*table.MetadataV1)(nil), meta) - assert.Equal(t, 1, meta.Version()) - - data := meta.(*table.MetadataV1) - assert.Equal(t, uuid.MustParse("d20125c8-7284-442c-9aea-15fee620737c"), meta.TableUUID()) - assert.Equal(t, "s3://bucket/test/location", meta.Location()) - assert.Equal(t, int64(1602638573874), meta.LastUpdatedMillis()) - assert.Equal(t, 3, meta.LastColumnID()) - - expected := iceberg.NewSchema( - 0, - iceberg.NestedField{ID: 1, Name: "x", Type: iceberg.PrimitiveTypes.Int64, Required: true}, - iceberg.NestedField{ID: 2, Name: "y", Type: iceberg.PrimitiveTypes.Int64, Required: true, Doc: "comment"}, - iceberg.NestedField{ID: 3, Name: "z", Type: iceberg.PrimitiveTypes.Int64, Required: true}, - ) - - assert.True(t, slices.EqualFunc([]*iceberg.Schema{expected}, 
meta.Schemas(), func(s1, s2 *iceberg.Schema) bool { - return s1.Equals(s2) - })) - assert.Zero(t, data.SchemaList[0].ID) - assert.True(t, meta.CurrentSchema().Equals(expected)) - assert.Equal(t, []iceberg.PartitionSpec{ - iceberg.NewPartitionSpec(iceberg.PartitionField{ - SourceID: 1, FieldID: 1000, Transform: iceberg.IdentityTransform{}, Name: "x", - }), - }, meta.PartitionSpecs()) - - assert.Equal(t, iceberg.NewPartitionSpec(iceberg.PartitionField{ - SourceID: 1, FieldID: 1000, Transform: iceberg.IdentityTransform{}, Name: "x", - }), meta.PartitionSpec()) - - assert.Equal(t, 0, meta.DefaultPartitionSpec()) - assert.Equal(t, 1000, *meta.LastPartitionSpecID()) - assert.Nil(t, data.CurrentSnapshotID) - assert.Nil(t, meta.CurrentSnapshot()) - assert.Len(t, meta.Snapshots(), 1) - assert.NotNil(t, meta.SnapshotByID(1925)) - assert.Nil(t, meta.SnapshotByID(0)) - assert.Nil(t, meta.SnapshotByName("foo")) - assert.Zero(t, data.DefaultSortOrderID) - assert.Equal(t, table.UnsortedSortOrder, meta.SortOrder()) -} - -func TestMetadataV2Parsing(t *testing.T) { - meta, err := table.ParseMetadataBytes([]byte(ExampleTableMetadataV2)) - require.NoError(t, err) - require.NotNil(t, meta) - - assert.IsType(t, (*table.MetadataV2)(nil), meta) - assert.Equal(t, 2, meta.Version()) - - data := meta.(*table.MetadataV2) - assert.Equal(t, uuid.MustParse("9c12d441-03fe-4693-9a96-a0705ddf69c1"), data.UUID) - assert.Equal(t, "s3://bucket/test/location", data.Location()) - assert.Equal(t, 34, data.LastSequenceNumber) - assert.Equal(t, int64(1602638573590), data.LastUpdatedMS) - assert.Equal(t, 3, data.LastColumnId) - assert.Equal(t, 0, data.SchemaList[0].ID) - assert.Equal(t, 1, data.CurrentSchemaID) - assert.Equal(t, 0, data.Specs[0].ID()) - assert.Equal(t, 0, data.DefaultSpecID) - assert.Equal(t, 1000, *data.LastPartitionID) - assert.EqualValues(t, "134217728", data.Props["read.split.target.size"]) - assert.EqualValues(t, 3055729675574597004, *data.CurrentSnapshotID) - assert.EqualValues(t, 
3051729675574597004, data.SnapshotList[0].SnapshotID) - assert.Equal(t, int64(1515100955770), data.SnapshotLog[0].TimestampMs) - assert.Equal(t, 3, data.SortOrderList[0].OrderID) - assert.Equal(t, 3, data.DefaultSortOrderID) - - assert.Len(t, meta.Snapshots(), 2) - assert.Equal(t, data.SnapshotList[1], *meta.CurrentSnapshot()) - assert.Equal(t, data.SnapshotList[0], *meta.SnapshotByName("test")) - assert.EqualValues(t, "134217728", meta.Properties()["read.split.target.size"]) -} - -func TestParsingCorrectTypes(t *testing.T) { - var meta table.MetadataV2 - require.NoError(t, json.Unmarshal([]byte(ExampleTableMetadataV2), &meta)) - - assert.IsType(t, &iceberg.Schema{}, meta.SchemaList[0]) - assert.IsType(t, iceberg.NestedField{}, meta.SchemaList[0].Field(0)) - assert.IsType(t, iceberg.PrimitiveTypes.Int64, meta.SchemaList[0].Field(0).Type) -} - -func TestSerializeMetadataV1(t *testing.T) { - var meta table.MetadataV1 - require.NoError(t, json.Unmarshal([]byte(ExampleTableMetadataV1), &meta)) - - data, err := json.Marshal(&meta) - require.NoError(t, err) - - assert.JSONEq(t, `{"location": "s3://bucket/test/location", "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", "last-updated-ms": 1602638573874, "last-column-id": 3, "schemas": [{"type": "struct", "fields": [{"id": 1, "name": "x", "type": "long", "required": true}, {"id": 2, "name": "y", "type": "long", "required": true, "doc": "comment"}, {"id": 3, "name": "z", "type": "long", "required": true}], "schema-id": 0, "identifier-field-ids": []}], "current-schema-id": 0, "partition-specs": [{"spec-id": 0, "fields": [{"source-id": 1, "field-id": 1000, "transform": "identity", "name": "x"}]}], "default-spec-id": 0, "last-partition-id": 1000, "properties": {}, "snapshots": [{"snapshot-id": 1925, "sequence-number": 0, "timestamp-ms": 1602638573822}], "snapshot-log": [], "metadata-log": [], "sort-orders": [{"order-id": 0, "fields": []}], "default-sort-order-id": 0, "refs": {}, "format-version": 1, "schema": {"type": 
"struct", "fields": [{"id": 1, "name": "x", "type": "long", "required": true}, {"id": 2, "name": "y", "type": "long", "required": true, "doc": "comment"}, {"id": 3, "name": "z", "type": "long", "required": true}], "schema-id": 0, "identifier-field-ids": []}, "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}`, - string(data)) -} - -func TestSerializeMetadataV2(t *testing.T) { - var meta table.MetadataV2 - require.NoError(t, json.Unmarshal([]byte(ExampleTableMetadataV2), &meta)) - - data, err := json.Marshal(&meta) - require.NoError(t, err) - - assert.JSONEq(t, `{"location": "s3://bucket/test/location", "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", "last-updated-ms": 1602638573590, "last-column-id": 3, "schemas": [{"type": "struct", "fields": [{"id": 1, "name": "x", "type": "long", "required": true}], "schema-id": 0, "identifier-field-ids": []}, {"type": "struct", "fields": [{"id": 1, "name": "x", "type": "long", "required": true}, {"id": 2, "name": "y", "type": "long", "required": true, "doc": "comment"}, {"id": 3, "name": "z", "type": "long", "required": true}], "schema-id": 1, "identifier-field-ids": [1, 2]}], "current-schema-id": 1, "partition-specs": [{"spec-id": 0, "fields": [{"source-id": 1, "field-id": 1000, "transform": "identity", "name": "x"}]}], "default-spec-id": 0, "last-partition-id": 1000, "properties": {"read.split.target.size": "134217728"}, "current-snapshot-id": 3055729675574597004, "snapshots": [{"snapshot-id": 3051729675574597004, "sequence-number": 0, "timestamp-ms": 1515100955770, "manifest-list": "s3://a/b/1.avro", "summary": {"operation": "append"}}, {"snapshot-id": 3055729675574597004, "parent-snapshot-id": 3051729675574597004, "sequence-number": 1, "timestamp-ms": 1555100955770, "manifest-list": "s3://a/b/2.avro", "summary": {"operation": "append"}, "schema-id": 1}], "snapshot-log": [{"snapshot-id": 3051729675574597004, "timestamp-ms": 1515100955770}, {"snapshot-id": 3055729675574597004, 
"timestamp-ms": 1555100955770}], "metadata-log": [{"metadata-file": "s3://bucket/.../v1.json", "timestamp-ms": 1515100}], "sort-orders": [{"order-id": 3, "fields": [{"source-id": 2, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}, {"source-id": 3, "transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"}]}], "default-sort-order-id": 3, "refs": {"test": {"snapshot-id": 3051729675574597004, "type": "tag", "max-ref-age-ms": 10000000}, "main": {"snapshot-id": 3055729675574597004, "type": "branch"}}, "format-version": 2, "last-sequence-number": 34}`, - string(data)) -} - -func TestInvalidFormatVersion(t *testing.T) { - metadataInvalidFormat := `{ - "format-version": -1, - "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", - "location": "s3://bucket/test/location", - "last-updated-ms": 1602638573874, - "last-column-id": 3, - "schema": { - "type": "struct", - "fields": [ - {"id": 1, "name": "x", "required": true, "type": "long"}, - {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": true, "type": "long"} - ] - }, - "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], - "properties": {}, - "current-snapshot-id": -1, - "snapshots": [] - }` - - _, err := table.ParseMetadataBytes([]byte(metadataInvalidFormat)) - assert.Error(t, err) - assert.ErrorIs(t, err, table.ErrInvalidMetadataFormatVersion) -} - -func TestCurrentSchemaNotFound(t *testing.T) { - schemaNotFound := `{ - "format-version": 2, - "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", - "location": "s3://bucket/test/location", - "last-updated-ms": 1602638573874, - "last-column-id": 3, - "schemas": [ - {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": true, "type": "long"}]}, - { - "type": "struct", - "schema-id": 1, - "identifier-field-ids": [1, 2], - "fields": [ - {"id": 1, "name": "x", "required": true, "type": "long"}, - {"id": 2, 
"name": "y", "required": true, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": true, "type": "long"} - ] - } - ], - "current-schema-id": 2, - "default-spec-id": 0, - "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], - "last-partition-id": 1000, - "default-sort-order-id": 0, - "properties": {}, - "current-snapshot-id": -1, - "snapshots": [] - }` - - _, err := table.ParseMetadataBytes([]byte(schemaNotFound)) - assert.Error(t, err) - assert.ErrorIs(t, err, table.ErrInvalidMetadata) - assert.ErrorContains(t, err, "current-schema-id 2 can't be found in any schema") -} - -func TestSortOrderNotFound(t *testing.T) { - metadataSortOrderNotFound := `{ - "format-version": 2, - "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", - "location": "s3://bucket/test/location", - "last-updated-ms": 1602638573874, - "last-column-id": 3, - "schemas": [ - { - "type": "struct", - "schema-id": 0, - "identifier-field-ids": [1, 2], - "fields": [ - {"id": 1, "name": "x", "required": true, "type": "long"}, - {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": true, "type": "long"} - ] - } - ], - "default-sort-order-id": 4, - "sort-orders": [ - { - "order-id": 3, - "fields": [ - {"transform": "identity", "source-id": 2, "direction": "asc", "null-order": "nulls-first"}, - {"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"} - ] - } - ], - "current-schema-id": 0, - "default-spec-id": 0, - "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], - "last-partition-id": 1000, - "properties": {}, - "current-snapshot-id": -1, - "snapshots": [] - }` - - _, err := table.ParseMetadataBytes([]byte(metadataSortOrderNotFound)) - assert.Error(t, err) - assert.ErrorIs(t, err, table.ErrInvalidMetadata) - assert.ErrorContains(t, err, 
"default-sort-order-id 4 can't be found in [3: [\n2 asc nulls-first\nbucket[4](3) desc nulls-last\n]]") -} - -func TestSortOrderUnsorted(t *testing.T) { - sortOrderUnsorted := `{ - "format-version": 2, - "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", - "location": "s3://bucket/test/location", - "last-updated-ms": 1602638573874, - "last-column-id": 3, - "schemas": [ - { - "type": "struct", - "schema-id": 0, - "identifier-field-ids": [1, 2], - "fields": [ - {"id": 1, "name": "x", "required": true, "type": "long"}, - {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": true, "type": "long"} - ] - } - ], - "default-sort-order-id": 0, - "sort-orders": [], - "current-schema-id": 0, - "default-spec-id": 0, - "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], - "last-partition-id": 1000, - "properties": {}, - "current-snapshot-id": -1, - "snapshots": [] - }` - - var meta table.MetadataV2 - require.NoError(t, json.Unmarshal([]byte(sortOrderUnsorted), &meta)) - - assert.Equal(t, table.UnsortedSortOrderID, meta.DefaultSortOrderID) - assert.Len(t, meta.SortOrderList, 0) -} - -func TestInvalidPartitionSpecID(t *testing.T) { - invalidSpecID := `{ - "format-version": 2, - "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", - "location": "s3://bucket/test/location", - "last-sequence-number": 34, - "last-updated-ms": 1602638573590, - "last-column-id": 3, - "current-schema-id": 1, - "schemas": [ - {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": true, "type": "long"}]}, - { - "type": "struct", - "schema-id": 1, - "identifier-field-ids": [1, 2], - "fields": [ - {"id": 1, "name": "x", "required": true, "type": "long"}, - {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": true, "type": "long"} - ] - } - ], - "sort-orders": [], - "default-sort-order-id": 
0, - "default-spec-id": 1, - "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], - "last-partition-id": 1000 - }` - - var meta table.MetadataV2 - err := json.Unmarshal([]byte(invalidSpecID), &meta) - assert.ErrorIs(t, err, table.ErrInvalidMetadata) - assert.ErrorContains(t, err, "default-spec-id 1 can't be found") -} - -func TestV2RefCreation(t *testing.T) { - var meta table.MetadataV2 - require.NoError(t, json.Unmarshal([]byte(ExampleTableMetadataV2), &meta)) - - maxRefAge := int64(10000000) - assert.Equal(t, map[string]table.SnapshotRef{ - "main": { - SnapshotID: 3055729675574597004, - SnapshotRefType: table.BranchRef, - }, - "test": { - SnapshotID: 3051729675574597004, - SnapshotRefType: table.TagRef, - MaxRefAgeMs: &maxRefAge, - }, - }, meta.Refs) -} - -func TestV1WriteMetadataToV2(t *testing.T) { - // https://iceberg.apache.org/spec/#version-2 - // - // Table metadata JSON: - // - last-sequence-number was added and is required; default to 0 when reading v1 metadata - // - table-uuid is now required - // - current-schema-id is now required - // - schemas is now required - // - partition-specs is now required - // - default-spec-id is now required - // - last-partition-id is now required - // - sort-orders is now required - // - default-sort-order-id is now required - // - schema is no longer required and should be omitted; use schemas and current-schema-id instead - // - partition-spec is no longer required and should be omitted; use partition-specs and default-spec-id instead - - minimalV1Example := `{ - "format-version": 1, - "location": "s3://bucket/test/location", - "last-updated-ms": 1062638573874, - "last-column-id": 3, - "schema": { - "type": "struct", - "fields": [ - {"id": 1, "name": "x", "required": true, "type": "long"}, - {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": true, "type": "long"} - ] - }, - 
"partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], - "properties": {}, - "current-snapshot-id": -1, - "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}] - }` - - meta, err := table.ParseMetadataString(minimalV1Example) - require.NoError(t, err) - assert.IsType(t, (*table.MetadataV1)(nil), meta) - - metaV2 := meta.(*table.MetadataV1).ToV2() - metaV2Json, err := json.Marshal(metaV2) - require.NoError(t, err) - - rawData := make(map[string]any) - require.NoError(t, json.Unmarshal(metaV2Json, &rawData)) - - assert.EqualValues(t, 0, rawData["last-sequence-number"]) - assert.NotEmpty(t, rawData["table-uuid"]) - assert.EqualValues(t, 0, rawData["current-schema-id"]) - assert.Equal(t, []any{map[string]any{ - "fields": []any{ - map[string]any{"id": float64(1), "name": "x", "required": true, "type": "long"}, - map[string]any{"id": float64(2), "name": "y", "required": true, "type": "long", "doc": "comment"}, - map[string]any{"id": float64(3), "name": "z", "required": true, "type": "long"}, - }, - "identifier-field-ids": []any{}, - "schema-id": float64(0), - "type": "struct", - }}, rawData["schemas"]) - assert.Equal(t, []any{map[string]any{ - "spec-id": float64(0), - "fields": []any{map[string]any{ - "name": "x", "transform": "identity", - "source-id": float64(1), "field-id": float64(1000), - }}, - }}, rawData["partition-specs"]) - - assert.Zero(t, rawData["default-spec-id"]) - assert.EqualValues(t, 1000, rawData["last-partition-id"]) - assert.Zero(t, rawData["default-sort-order-id"]) - assert.Equal(t, []any{map[string]any{"order-id": float64(0), "fields": []any{}}}, rawData["sort-orders"]) - assert.NotContains(t, rawData, "schema") - assert.NotContains(t, rawData, "partition-spec") -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package table_test + +import ( + "encoding/json" + "slices" + "testing" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/table" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ExampleTableMetadataV2 = `{ + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "s3://bucket/test/location", + "last-sequence-number": 34, + "last-updated-ms": 1602638573590, + "last-column-id": 3, + "current-schema-id": 1, + "schemas": [ + {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": true, "type": "long"}]}, + { + "type": "struct", + "schema-id": 1, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": true, "type": "long"}, + {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": true, "type": "long"} + ] + } + ], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], + "last-partition-id": 1000, + "default-sort-order-id": 3, + "sort-orders": [ + { + "order-id": 3, + "fields": [ + {"transform": "identity", "source-id": 2, "direction": "asc", "null-order": "nulls-first"}, + {"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"} + ] 
+ } + ], + "properties": {"read.split.target.size": "134217728"}, + "current-snapshot-id": 3055729675574597004, + "snapshots": [ + { + "snapshot-id": 3051729675574597004, + "timestamp-ms": 1515100955770, + "sequence-number": 0, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/b/1.avro" + }, + { + "snapshot-id": 3055729675574597004, + "parent-snapshot-id": 3051729675574597004, + "timestamp-ms": 1555100955770, + "sequence-number": 1, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/b/2.avro", + "schema-id": 1 + } + ], + "snapshot-log": [ + {"snapshot-id": 3051729675574597004, "timestamp-ms": 1515100955770}, + {"snapshot-id": 3055729675574597004, "timestamp-ms": 1555100955770} + ], + "metadata-log": [{"metadata-file": "s3://bucket/.../v1.json", "timestamp-ms": 1515100}], + "refs": {"test": {"snapshot-id": 3051729675574597004, "type": "tag", "max-ref-age-ms": 10000000}} +}` + +const ExampleTableMetadataV1 = `{ + "format-version": 1, + "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", + "location": "s3://bucket/test/location", + "last-updated-ms": 1602638573874, + "last-column-id": 3, + "schema": { + "type": "struct", + "fields": [ + {"id": 1, "name": "x", "required": true, "type": "long"}, + {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": true, "type": "long"} + ] + }, + "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}] +}` + +func TestMetadataV1Parsing(t *testing.T) { + meta, err := table.ParseMetadataBytes([]byte(ExampleTableMetadataV1)) + require.NoError(t, err) + require.NotNil(t, meta) + + assert.IsType(t, (*table.MetadataV1)(nil), meta) + assert.Equal(t, 1, meta.Version()) + + data := meta.(*table.MetadataV1) + assert.Equal(t, uuid.MustParse("d20125c8-7284-442c-9aea-15fee620737c"), meta.TableUUID()) + 
assert.Equal(t, "s3://bucket/test/location", meta.Location()) + assert.Equal(t, int64(1602638573874), meta.LastUpdatedMillis()) + assert.Equal(t, 3, meta.LastColumnID()) + + expected := iceberg.NewSchema( + 0, + iceberg.NestedField{ID: 1, Name: "x", Type: iceberg.PrimitiveTypes.Int64, Required: true}, + iceberg.NestedField{ID: 2, Name: "y", Type: iceberg.PrimitiveTypes.Int64, Required: true, Doc: "comment"}, + iceberg.NestedField{ID: 3, Name: "z", Type: iceberg.PrimitiveTypes.Int64, Required: true}, + ) + + assert.True(t, slices.EqualFunc([]*iceberg.Schema{expected}, meta.Schemas(), func(s1, s2 *iceberg.Schema) bool { + return s1.Equals(s2) + })) + assert.Zero(t, data.SchemaList[0].ID) + assert.True(t, meta.CurrentSchema().Equals(expected)) + assert.Equal(t, []iceberg.PartitionSpec{ + iceberg.NewPartitionSpec(iceberg.PartitionField{ + SourceID: 1, FieldID: 1000, Transform: iceberg.IdentityTransform{}, Name: "x", + }), + }, meta.PartitionSpecs()) + + assert.Equal(t, iceberg.NewPartitionSpec(iceberg.PartitionField{ + SourceID: 1, FieldID: 1000, Transform: iceberg.IdentityTransform{}, Name: "x", + }), meta.PartitionSpec()) + + assert.Equal(t, 0, meta.DefaultPartitionSpec()) + assert.Equal(t, 1000, *meta.LastPartitionSpecID()) + assert.Nil(t, data.CurrentSnapshotID) + assert.Nil(t, meta.CurrentSnapshot()) + assert.Len(t, meta.Snapshots(), 1) + assert.NotNil(t, meta.SnapshotByID(1925)) + assert.Nil(t, meta.SnapshotByID(0)) + assert.Nil(t, meta.SnapshotByName("foo")) + assert.Zero(t, data.DefaultSortOrderID) + assert.Equal(t, table.UnsortedSortOrder, meta.SortOrder()) +} + +func TestMetadataV2Parsing(t *testing.T) { + meta, err := table.ParseMetadataBytes([]byte(ExampleTableMetadataV2)) + require.NoError(t, err) + require.NotNil(t, meta) + + assert.IsType(t, (*table.MetadataV2)(nil), meta) + assert.Equal(t, 2, meta.Version()) + + data := meta.(*table.MetadataV2) + assert.Equal(t, uuid.MustParse("9c12d441-03fe-4693-9a96-a0705ddf69c1"), data.UUID) + assert.Equal(t, 
"s3://bucket/test/location", data.Location()) + assert.Equal(t, 34, data.LastSequenceNumber) + assert.Equal(t, int64(1602638573590), data.LastUpdatedMS) + assert.Equal(t, 3, data.LastColumnId) + assert.Equal(t, 0, data.SchemaList[0].ID) + assert.Equal(t, 1, data.CurrentSchemaID) + assert.Equal(t, 0, data.Specs[0].ID()) + assert.Equal(t, 0, data.DefaultSpecID) + assert.Equal(t, 1000, *data.LastPartitionID) + assert.EqualValues(t, "134217728", data.Props["read.split.target.size"]) + assert.EqualValues(t, 3055729675574597004, *data.CurrentSnapshotID) + assert.EqualValues(t, 3051729675574597004, data.SnapshotList[0].SnapshotID) + assert.Equal(t, int64(1515100955770), data.SnapshotLog[0].TimestampMs) + assert.Equal(t, 3, data.SortOrderList[0].OrderID) + assert.Equal(t, 3, data.DefaultSortOrderID) + + assert.Len(t, meta.Snapshots(), 2) + assert.Equal(t, data.SnapshotList[1], *meta.CurrentSnapshot()) + assert.Equal(t, data.SnapshotList[0], *meta.SnapshotByName("test")) + assert.EqualValues(t, "134217728", meta.Properties()["read.split.target.size"]) +} + +func TestParsingCorrectTypes(t *testing.T) { + var meta table.MetadataV2 + require.NoError(t, json.Unmarshal([]byte(ExampleTableMetadataV2), &meta)) + + assert.IsType(t, &iceberg.Schema{}, meta.SchemaList[0]) + assert.IsType(t, iceberg.NestedField{}, meta.SchemaList[0].Field(0)) + assert.IsType(t, iceberg.PrimitiveTypes.Int64, meta.SchemaList[0].Field(0).Type) +} + +func TestSerializeMetadataV1(t *testing.T) { + var meta table.MetadataV1 + require.NoError(t, json.Unmarshal([]byte(ExampleTableMetadataV1), &meta)) + + data, err := json.Marshal(&meta) + require.NoError(t, err) + + assert.JSONEq(t, `{"location": "s3://bucket/test/location", "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", "last-updated-ms": 1602638573874, "last-column-id": 3, "schemas": [{"type": "struct", "fields": [{"id": 1, "name": "x", "type": "long", "required": true}, {"id": 2, "name": "y", "type": "long", "required": true, "doc": "comment"}, 
{"id": 3, "name": "z", "type": "long", "required": true}], "schema-id": 0, "identifier-field-ids": []}], "current-schema-id": 0, "partition-specs": [{"spec-id": 0, "fields": [{"source-id": 1, "field-id": 1000, "transform": "identity", "name": "x"}]}], "default-spec-id": 0, "last-partition-id": 1000, "properties": {}, "snapshots": [{"snapshot-id": 1925, "sequence-number": 0, "timestamp-ms": 1602638573822}], "snapshot-log": [], "metadata-log": [], "sort-orders": [{"order-id": 0, "fields": []}], "default-sort-order-id": 0, "refs": {}, "format-version": 1, "schema": {"type": "struct", "fields": [{"id": 1, "name": "x", "type": "long", "required": true}, {"id": 2, "name": "y", "type": "long", "required": true, "doc": "comment"}, {"id": 3, "name": "z", "type": "long", "required": true}], "schema-id": 0, "identifier-field-ids": []}, "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}`, + string(data)) +} + +func TestSerializeMetadataV2(t *testing.T) { + var meta table.MetadataV2 + require.NoError(t, json.Unmarshal([]byte(ExampleTableMetadataV2), &meta)) + + data, err := json.Marshal(&meta) + require.NoError(t, err) + + assert.JSONEq(t, `{"location": "s3://bucket/test/location", "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", "last-updated-ms": 1602638573590, "last-column-id": 3, "schemas": [{"type": "struct", "fields": [{"id": 1, "name": "x", "type": "long", "required": true}], "schema-id": 0, "identifier-field-ids": []}, {"type": "struct", "fields": [{"id": 1, "name": "x", "type": "long", "required": true}, {"id": 2, "name": "y", "type": "long", "required": true, "doc": "comment"}, {"id": 3, "name": "z", "type": "long", "required": true}], "schema-id": 1, "identifier-field-ids": [1, 2]}], "current-schema-id": 1, "partition-specs": [{"spec-id": 0, "fields": [{"source-id": 1, "field-id": 1000, "transform": "identity", "name": "x"}]}], "default-spec-id": 0, "last-partition-id": 1000, "properties": {"read.split.target.size": 
"134217728"}, "current-snapshot-id": 3055729675574597004, "snapshots": [{"snapshot-id": 3051729675574597004, "sequence-number": 0, "timestamp-ms": 1515100955770, "manifest-list": "s3://a/b/1.avro", "summary": {"operation": "append"}}, {"snapshot-id": 3055729675574597004, "parent-snapshot-id": 3051729675574597004, "sequence-number": 1, "timestamp-ms": 1555100955770, "manifest-list": "s3://a/b/2.avro", "summary": {"operation": "append"}, "schema-id": 1}], "snapshot-log": [{"snapshot-id": 3051729675574597004, "timestamp-ms": 1515100955770}, {"snapshot-id": 3055729675574597004, "timestamp-ms": 1555100955770}], "metadata-log": [{"metadata-file": "s3://bucket/.../v1.json", "timestamp-ms": 1515100}], "sort-orders": [{"order-id": 3, "fields": [{"source-id": 2, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}, {"source-id": 3, "transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"}]}], "default-sort-order-id": 3, "refs": {"test": {"snapshot-id": 3051729675574597004, "type": "tag", "max-ref-age-ms": 10000000}, "main": {"snapshot-id": 3055729675574597004, "type": "branch"}}, "format-version": 2, "last-sequence-number": 34}`, + string(data)) +} + +func TestInvalidFormatVersion(t *testing.T) { + metadataInvalidFormat := `{ + "format-version": -1, + "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", + "location": "s3://bucket/test/location", + "last-updated-ms": 1602638573874, + "last-column-id": 3, + "schema": { + "type": "struct", + "fields": [ + {"id": 1, "name": "x", "required": true, "type": "long"}, + {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": true, "type": "long"} + ] + }, + "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [] + }` + + _, err := table.ParseMetadataBytes([]byte(metadataInvalidFormat)) + assert.Error(t, err) + assert.ErrorIs(t, err, 
table.ErrInvalidMetadataFormatVersion) +} + +func TestCurrentSchemaNotFound(t *testing.T) { + schemaNotFound := `{ + "format-version": 2, + "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", + "location": "s3://bucket/test/location", + "last-updated-ms": 1602638573874, + "last-column-id": 3, + "schemas": [ + {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": true, "type": "long"}]}, + { + "type": "struct", + "schema-id": 1, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": true, "type": "long"}, + {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": true, "type": "long"} + ] + } + ], + "current-schema-id": 2, + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], + "last-partition-id": 1000, + "default-sort-order-id": 0, + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [] + }` + + _, err := table.ParseMetadataBytes([]byte(schemaNotFound)) + assert.Error(t, err) + assert.ErrorIs(t, err, table.ErrInvalidMetadata) + assert.ErrorContains(t, err, "current-schema-id 2 can't be found in any schema") +} + +func TestSortOrderNotFound(t *testing.T) { + metadataSortOrderNotFound := `{ + "format-version": 2, + "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", + "location": "s3://bucket/test/location", + "last-updated-ms": 1602638573874, + "last-column-id": 3, + "schemas": [ + { + "type": "struct", + "schema-id": 0, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": true, "type": "long"}, + {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": true, "type": "long"} + ] + } + ], + "default-sort-order-id": 4, + "sort-orders": [ + { + "order-id": 3, + "fields": [ + {"transform": "identity", "source-id": 2, "direction": "asc", "null-order": 
"nulls-first"}, + {"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"} + ] + } + ], + "current-schema-id": 0, + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], + "last-partition-id": 1000, + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [] + }` + + _, err := table.ParseMetadataBytes([]byte(metadataSortOrderNotFound)) + assert.Error(t, err) + assert.ErrorIs(t, err, table.ErrInvalidMetadata) + assert.ErrorContains(t, err, "default-sort-order-id 4 can't be found in [3: [\n2 asc nulls-first\nbucket[4](3) desc nulls-last\n]]") +} + +func TestSortOrderUnsorted(t *testing.T) { + sortOrderUnsorted := `{ + "format-version": 2, + "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", + "location": "s3://bucket/test/location", + "last-updated-ms": 1602638573874, + "last-column-id": 3, + "schemas": [ + { + "type": "struct", + "schema-id": 0, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": true, "type": "long"}, + {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": true, "type": "long"} + ] + } + ], + "default-sort-order-id": 0, + "sort-orders": [], + "current-schema-id": 0, + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], + "last-partition-id": 1000, + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [] + }` + + var meta table.MetadataV2 + require.NoError(t, json.Unmarshal([]byte(sortOrderUnsorted), &meta)) + + assert.Equal(t, table.UnsortedSortOrderID, meta.DefaultSortOrderID) + assert.Len(t, meta.SortOrderList, 0) +} + +func TestInvalidPartitionSpecID(t *testing.T) { + invalidSpecID := `{ + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "s3://bucket/test/location", + 
"last-sequence-number": 34, + "last-updated-ms": 1602638573590, + "last-column-id": 3, + "current-schema-id": 1, + "schemas": [ + {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": true, "type": "long"}]}, + { + "type": "struct", + "schema-id": 1, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": true, "type": "long"}, + {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": true, "type": "long"} + ] + } + ], + "sort-orders": [], + "default-sort-order-id": 0, + "default-spec-id": 1, + "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], + "last-partition-id": 1000 + }` + + var meta table.MetadataV2 + err := json.Unmarshal([]byte(invalidSpecID), &meta) + assert.ErrorIs(t, err, table.ErrInvalidMetadata) + assert.ErrorContains(t, err, "default-spec-id 1 can't be found") +} + +func TestV2RefCreation(t *testing.T) { + var meta table.MetadataV2 + require.NoError(t, json.Unmarshal([]byte(ExampleTableMetadataV2), &meta)) + + maxRefAge := int64(10000000) + assert.Equal(t, map[string]table.SnapshotRef{ + "main": { + SnapshotID: 3055729675574597004, + SnapshotRefType: table.BranchRef, + }, + "test": { + SnapshotID: 3051729675574597004, + SnapshotRefType: table.TagRef, + MaxRefAgeMs: &maxRefAge, + }, + }, meta.Refs) +} + +func TestV1WriteMetadataToV2(t *testing.T) { + // https://iceberg.apache.org/spec/#version-2 + // + // Table metadata JSON: + // - last-sequence-number was added and is required; default to 0 when reading v1 metadata + // - table-uuid is now required + // - current-schema-id is now required + // - schemas is now required + // - partition-specs is now required + // - default-spec-id is now required + // - last-partition-id is now required + // - sort-orders is now required + // - default-sort-order-id is now required + // - schema is no longer required and should be 
omitted; use schemas and current-schema-id instead + // - partition-spec is no longer required and should be omitted; use partition-specs and default-spec-id instead + + minimalV1Example := `{ + "format-version": 1, + "location": "s3://bucket/test/location", + "last-updated-ms": 1062638573874, + "last-column-id": 3, + "schema": { + "type": "struct", + "fields": [ + {"id": 1, "name": "x", "required": true, "type": "long"}, + {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": true, "type": "long"} + ] + }, + "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}] + }` + + meta, err := table.ParseMetadataString(minimalV1Example) + require.NoError(t, err) + assert.IsType(t, (*table.MetadataV1)(nil), meta) + + metaV2 := meta.(*table.MetadataV1).ToV2() + metaV2Json, err := json.Marshal(metaV2) + require.NoError(t, err) + + rawData := make(map[string]any) + require.NoError(t, json.Unmarshal(metaV2Json, &rawData)) + + assert.EqualValues(t, 0, rawData["last-sequence-number"]) + assert.NotEmpty(t, rawData["table-uuid"]) + assert.EqualValues(t, 0, rawData["current-schema-id"]) + assert.Equal(t, []any{map[string]any{ + "fields": []any{ + map[string]any{"id": float64(1), "name": "x", "required": true, "type": "long"}, + map[string]any{"id": float64(2), "name": "y", "required": true, "type": "long", "doc": "comment"}, + map[string]any{"id": float64(3), "name": "z", "required": true, "type": "long"}, + }, + "identifier-field-ids": []any{}, + "schema-id": float64(0), + "type": "struct", + }}, rawData["schemas"]) + assert.Equal(t, []any{map[string]any{ + "spec-id": float64(0), + "fields": []any{map[string]any{ + "name": "x", "transform": "identity", + "source-id": float64(1), "field-id": float64(1000), + }}, + }}, rawData["partition-specs"]) + + assert.Zero(t, 
rawData["default-spec-id"]) + assert.EqualValues(t, 1000, rawData["last-partition-id"]) + assert.Zero(t, rawData["default-sort-order-id"]) + assert.Equal(t, []any{map[string]any{"order-id": float64(0), "fields": []any{}}}, rawData["sort-orders"]) + assert.NotContains(t, rawData, "schema") + assert.NotContains(t, rawData, "partition-spec") +} diff --git a/table/name_mapping.go b/table/name_mapping.go index b71b7d3..9ce80aa 100644 --- a/table/name_mapping.go +++ b/table/name_mapping.go @@ -1,296 +1,296 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package table - -import ( - "fmt" - "slices" - "strconv" - "strings" - - "github.com/apache/iceberg-go" -) - -type MappedField struct { - Names []string `json:"names"` - // iceberg spec says this is optional, but I don't see any examples - // of this being left empty. Does pyiceberg need to be updated or should - // the spec not say field-id is optional? 
- FieldID *int `json:"field-id,omitempty"` - Fields []MappedField `json:"fields,omitempty"` -} - -func (m *MappedField) Len() int { return len(m.Fields) } - -func (m *MappedField) String() string { - var bldr strings.Builder - bldr.WriteString("([") - bldr.WriteString(strings.Join(m.Names, ", ")) - bldr.WriteString("] -> ") - - if m.FieldID != nil { - bldr.WriteString(strconv.Itoa(*m.FieldID)) - } else { - bldr.WriteByte('?') - } - - if len(m.Fields) > 0 { - bldr.WriteByte(' ') - for i, f := range m.Fields { - if i != 0 { - bldr.WriteString(", ") - } - bldr.WriteString(f.String()) - } - } - - bldr.WriteByte(')') - return bldr.String() -} - -type NameMapping []MappedField - -func (nm NameMapping) String() string { - var bldr strings.Builder - bldr.WriteString("[\n") - for _, f := range nm { - bldr.WriteByte('\t') - bldr.WriteString(f.String()) - bldr.WriteByte('\n') - } - bldr.WriteByte(']') - return bldr.String() -} - -type NameMappingVisitor[S, T any] interface { - Mapping(nm NameMapping, fieldResults S) S - Fields(st []MappedField, fieldResults []T) S - Field(field MappedField, fieldResult S) T -} - -func VisitNameMapping[S, T any](obj NameMapping, visitor NameMappingVisitor[S, T]) (res S, err error) { - if obj == nil { - err = fmt.Errorf("%w: cannot visit nil NameMapping", iceberg.ErrInvalidArgument) - return - } - - defer recoverError(&err) - - return visitor.Mapping(obj, visitMappedFields([]MappedField(obj), visitor)), err -} - -func VisitMappedFields[S, T any](fields []MappedField, visitor NameMappingVisitor[S, T]) (res S, err error) { - defer recoverError(&err) - - return visitMappedFields(fields, visitor), err -} - -func visitMappedFields[S, T any](fields []MappedField, visitor NameMappingVisitor[S, T]) S { - results := make([]T, len(fields)) - for i, f := range fields { - results[i] = visitor.Field(f, visitMappedFields(f.Fields, visitor)) - } - - return visitor.Fields(fields, results) -} - -type NameMappingAccessor struct{} - -func (NameMappingAccessor) 
SchemaPartner(partner *MappedField) *MappedField { - return partner -} - -func (NameMappingAccessor) getField(p *MappedField, field string) *MappedField { - for _, f := range p.Fields { - if slices.Contains(f.Names, field) { - return &f - } - } - - return nil -} - -func (n NameMappingAccessor) FieldPartner(partnerStruct *MappedField, _ int, fieldName string) *MappedField { - if partnerStruct == nil { - return nil - } - - return n.getField(partnerStruct, fieldName) -} - -func (n NameMappingAccessor) ListElementPartner(partnerList *MappedField) *MappedField { - if partnerList == nil { - return nil - } - - return n.getField(partnerList, "element") -} - -func (n NameMappingAccessor) MapKeyPartner(partnerMap *MappedField) *MappedField { - if partnerMap == nil { - return nil - } - - return n.getField(partnerMap, "key") -} - -func (n NameMappingAccessor) MapValuePartner(partnerMap *MappedField) *MappedField { - if partnerMap == nil { - return nil - } - - return n.getField(partnerMap, "value") -} - -type nameMapProjectVisitor struct { - currentPath []string -} - -func (n *nameMapProjectVisitor) popPath() { - n.currentPath = n.currentPath[:len(n.currentPath)-1] -} - -func (n *nameMapProjectVisitor) BeforeField(f iceberg.NestedField, _ *MappedField) { - n.currentPath = append(n.currentPath, f.Name) -} - -func (n *nameMapProjectVisitor) AfterField(iceberg.NestedField, *MappedField) { - n.popPath() -} - -func (n *nameMapProjectVisitor) BeforeListElement(iceberg.NestedField, *MappedField) { - n.currentPath = append(n.currentPath, "element") -} - -func (n *nameMapProjectVisitor) AfterListElement(iceberg.NestedField, *MappedField) { - n.popPath() -} - -func (n *nameMapProjectVisitor) BeforeMapKey(iceberg.NestedField, *MappedField) { - n.currentPath = append(n.currentPath, "key") -} - -func (n *nameMapProjectVisitor) AfterMapKey(iceberg.NestedField, *MappedField) { - n.popPath() -} - -func (n *nameMapProjectVisitor) BeforeMapValue(iceberg.NestedField, *MappedField) { - 
n.currentPath = append(n.currentPath, "value") -} - -func (n *nameMapProjectVisitor) AfterMapValue(iceberg.NestedField, *MappedField) { - n.popPath() -} - -func (n *nameMapProjectVisitor) Schema(_ *iceberg.Schema, _ *MappedField, structResult iceberg.NestedField) iceberg.NestedField { - return structResult -} - -func (n *nameMapProjectVisitor) Struct(_ iceberg.StructType, _ *MappedField, fieldResults []iceberg.NestedField) iceberg.NestedField { - return iceberg.NestedField{ - Type: &iceberg.StructType{FieldList: fieldResults}, - } -} - -func (n *nameMapProjectVisitor) Field(field iceberg.NestedField, fieldPartner *MappedField, fieldResult iceberg.NestedField) iceberg.NestedField { - if fieldPartner == nil { - panic(fmt.Errorf("%w: field missing from name mapping: %s", - iceberg.ErrInvalidArgument, strings.Join(n.currentPath, "."))) - } - - return iceberg.NestedField{ - ID: *fieldPartner.FieldID, - Name: field.Name, - Type: fieldResult.Type, - Required: field.Required, - Doc: field.Doc, - InitialDefault: field.InitialDefault, - WriteDefault: field.WriteDefault, - } -} - -func (nameMapProjectVisitor) mappedFieldID(mapped *MappedField, name string) int { - for _, f := range mapped.Fields { - if slices.Contains(f.Names, name) { - if f.FieldID != nil { - return *f.FieldID - } - return -1 - } - } - - return -1 -} - -func (n *nameMapProjectVisitor) List(lt iceberg.ListType, listPartner *MappedField, elemResult iceberg.NestedField) iceberg.NestedField { - if listPartner == nil { - panic(fmt.Errorf("%w: field missing from name mapping: %s", - iceberg.ErrInvalidArgument, strings.Join(n.currentPath, "."))) - } - - elementID := n.mappedFieldID(listPartner, "element") - - return iceberg.NestedField{ - Type: &iceberg.ListType{ - ElementID: elementID, - Element: elemResult.Type, - ElementRequired: lt.ElementRequired, - }, - } -} - -func (n *nameMapProjectVisitor) Map(m iceberg.MapType, mapPartner *MappedField, keyResult, valResult iceberg.NestedField) iceberg.NestedField { - if 
mapPartner == nil { - panic(fmt.Errorf("%w: field missing from name mapping: %s", - iceberg.ErrInvalidArgument, strings.Join(n.currentPath, "."))) - } - - keyID := n.mappedFieldID(mapPartner, "key") - valID := n.mappedFieldID(mapPartner, "value") - return iceberg.NestedField{ - Type: &iceberg.MapType{ - KeyID: keyID, - KeyType: keyResult.Type, - ValueID: valID, - ValueType: valResult.Type, - ValueRequired: m.ValueRequired, - }, - } -} - -func (n *nameMapProjectVisitor) Primitive(p iceberg.PrimitiveType, primitivePartner *MappedField) iceberg.NestedField { - if primitivePartner == nil { - panic(fmt.Errorf("%w: field missing from name mapping: %s", - iceberg.ErrInvalidArgument, strings.Join(n.currentPath, "."))) - } - - return iceberg.NestedField{Type: p} -} - -func ApplyNameMapping(schemaWithoutIDs *iceberg.Schema, nameMapping NameMapping) (*iceberg.Schema, error) { - top, err := iceberg.VisitSchemaWithPartner[iceberg.NestedField, *MappedField](schemaWithoutIDs, - &MappedField{Fields: nameMapping}, - &nameMapProjectVisitor{currentPath: make([]string, 0, 1)}, - NameMappingAccessor{}) - if err != nil { - return nil, err - } - - return iceberg.NewSchema(schemaWithoutIDs.ID, - top.Type.(*iceberg.StructType).FieldList...), nil -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package table + +import ( + "fmt" + "slices" + "strconv" + "strings" + + "github.com/apache/iceberg-go" +) + +type MappedField struct { + Names []string `json:"names"` + // iceberg spec says this is optional, but I don't see any examples + // of this being left empty. Does pyiceberg need to be updated or should + // the spec not say field-id is optional? + FieldID *int `json:"field-id,omitempty"` + Fields []MappedField `json:"fields,omitempty"` +} + +func (m *MappedField) Len() int { return len(m.Fields) } + +func (m *MappedField) String() string { + var bldr strings.Builder + bldr.WriteString("([") + bldr.WriteString(strings.Join(m.Names, ", ")) + bldr.WriteString("] -> ") + + if m.FieldID != nil { + bldr.WriteString(strconv.Itoa(*m.FieldID)) + } else { + bldr.WriteByte('?') + } + + if len(m.Fields) > 0 { + bldr.WriteByte(' ') + for i, f := range m.Fields { + if i != 0 { + bldr.WriteString(", ") + } + bldr.WriteString(f.String()) + } + } + + bldr.WriteByte(')') + return bldr.String() +} + +type NameMapping []MappedField + +func (nm NameMapping) String() string { + var bldr strings.Builder + bldr.WriteString("[\n") + for _, f := range nm { + bldr.WriteByte('\t') + bldr.WriteString(f.String()) + bldr.WriteByte('\n') + } + bldr.WriteByte(']') + return bldr.String() +} + +type NameMappingVisitor[S, T any] interface { + Mapping(nm NameMapping, fieldResults S) S + Fields(st []MappedField, fieldResults []T) S + Field(field MappedField, fieldResult S) T +} + +func VisitNameMapping[S, T any](obj NameMapping, visitor NameMappingVisitor[S, T]) (res S, err error) { + if obj == nil { + err = fmt.Errorf("%w: cannot visit nil NameMapping", iceberg.ErrInvalidArgument) + return + } + + defer recoverError(&err) + + return visitor.Mapping(obj, visitMappedFields([]MappedField(obj), visitor)), err +} + +func VisitMappedFields[S, T any](fields []MappedField, visitor 
NameMappingVisitor[S, T]) (res S, err error) { + defer recoverError(&err) + + return visitMappedFields(fields, visitor), err +} + +func visitMappedFields[S, T any](fields []MappedField, visitor NameMappingVisitor[S, T]) S { + results := make([]T, len(fields)) + for i, f := range fields { + results[i] = visitor.Field(f, visitMappedFields(f.Fields, visitor)) + } + + return visitor.Fields(fields, results) +} + +type NameMappingAccessor struct{} + +func (NameMappingAccessor) SchemaPartner(partner *MappedField) *MappedField { + return partner +} + +func (NameMappingAccessor) getField(p *MappedField, field string) *MappedField { + for _, f := range p.Fields { + if slices.Contains(f.Names, field) { + return &f + } + } + + return nil +} + +func (n NameMappingAccessor) FieldPartner(partnerStruct *MappedField, _ int, fieldName string) *MappedField { + if partnerStruct == nil { + return nil + } + + return n.getField(partnerStruct, fieldName) +} + +func (n NameMappingAccessor) ListElementPartner(partnerList *MappedField) *MappedField { + if partnerList == nil { + return nil + } + + return n.getField(partnerList, "element") +} + +func (n NameMappingAccessor) MapKeyPartner(partnerMap *MappedField) *MappedField { + if partnerMap == nil { + return nil + } + + return n.getField(partnerMap, "key") +} + +func (n NameMappingAccessor) MapValuePartner(partnerMap *MappedField) *MappedField { + if partnerMap == nil { + return nil + } + + return n.getField(partnerMap, "value") +} + +type nameMapProjectVisitor struct { + currentPath []string +} + +func (n *nameMapProjectVisitor) popPath() { + n.currentPath = n.currentPath[:len(n.currentPath)-1] +} + +func (n *nameMapProjectVisitor) BeforeField(f iceberg.NestedField, _ *MappedField) { + n.currentPath = append(n.currentPath, f.Name) +} + +func (n *nameMapProjectVisitor) AfterField(iceberg.NestedField, *MappedField) { + n.popPath() +} + +func (n *nameMapProjectVisitor) BeforeListElement(iceberg.NestedField, *MappedField) { + n.currentPath = 
append(n.currentPath, "element") +} + +func (n *nameMapProjectVisitor) AfterListElement(iceberg.NestedField, *MappedField) { + n.popPath() +} + +func (n *nameMapProjectVisitor) BeforeMapKey(iceberg.NestedField, *MappedField) { + n.currentPath = append(n.currentPath, "key") +} + +func (n *nameMapProjectVisitor) AfterMapKey(iceberg.NestedField, *MappedField) { + n.popPath() +} + +func (n *nameMapProjectVisitor) BeforeMapValue(iceberg.NestedField, *MappedField) { + n.currentPath = append(n.currentPath, "value") +} + +func (n *nameMapProjectVisitor) AfterMapValue(iceberg.NestedField, *MappedField) { + n.popPath() +} + +func (n *nameMapProjectVisitor) Schema(_ *iceberg.Schema, _ *MappedField, structResult iceberg.NestedField) iceberg.NestedField { + return structResult +} + +func (n *nameMapProjectVisitor) Struct(_ iceberg.StructType, _ *MappedField, fieldResults []iceberg.NestedField) iceberg.NestedField { + return iceberg.NestedField{ + Type: &iceberg.StructType{FieldList: fieldResults}, + } +} + +func (n *nameMapProjectVisitor) Field(field iceberg.NestedField, fieldPartner *MappedField, fieldResult iceberg.NestedField) iceberg.NestedField { + if fieldPartner == nil { + panic(fmt.Errorf("%w: field missing from name mapping: %s", + iceberg.ErrInvalidArgument, strings.Join(n.currentPath, "."))) + } + + return iceberg.NestedField{ + ID: *fieldPartner.FieldID, + Name: field.Name, + Type: fieldResult.Type, + Required: field.Required, + Doc: field.Doc, + InitialDefault: field.InitialDefault, + WriteDefault: field.WriteDefault, + } +} + +func (nameMapProjectVisitor) mappedFieldID(mapped *MappedField, name string) int { + for _, f := range mapped.Fields { + if slices.Contains(f.Names, name) { + if f.FieldID != nil { + return *f.FieldID + } + return -1 + } + } + + return -1 +} + +func (n *nameMapProjectVisitor) List(lt iceberg.ListType, listPartner *MappedField, elemResult iceberg.NestedField) iceberg.NestedField { + if listPartner == nil { + panic(fmt.Errorf("%w: field 
missing from name mapping: %s", + iceberg.ErrInvalidArgument, strings.Join(n.currentPath, "."))) + } + + elementID := n.mappedFieldID(listPartner, "element") + + return iceberg.NestedField{ + Type: &iceberg.ListType{ + ElementID: elementID, + Element: elemResult.Type, + ElementRequired: lt.ElementRequired, + }, + } +} + +func (n *nameMapProjectVisitor) Map(m iceberg.MapType, mapPartner *MappedField, keyResult, valResult iceberg.NestedField) iceberg.NestedField { + if mapPartner == nil { + panic(fmt.Errorf("%w: field missing from name mapping: %s", + iceberg.ErrInvalidArgument, strings.Join(n.currentPath, "."))) + } + + keyID := n.mappedFieldID(mapPartner, "key") + valID := n.mappedFieldID(mapPartner, "value") + return iceberg.NestedField{ + Type: &iceberg.MapType{ + KeyID: keyID, + KeyType: keyResult.Type, + ValueID: valID, + ValueType: valResult.Type, + ValueRequired: m.ValueRequired, + }, + } +} + +func (n *nameMapProjectVisitor) Primitive(p iceberg.PrimitiveType, primitivePartner *MappedField) iceberg.NestedField { + if primitivePartner == nil { + panic(fmt.Errorf("%w: field missing from name mapping: %s", + iceberg.ErrInvalidArgument, strings.Join(n.currentPath, "."))) + } + + return iceberg.NestedField{Type: p} +} + +func ApplyNameMapping(schemaWithoutIDs *iceberg.Schema, nameMapping NameMapping) (*iceberg.Schema, error) { + top, err := iceberg.VisitSchemaWithPartner[iceberg.NestedField, *MappedField](schemaWithoutIDs, + &MappedField{Fields: nameMapping}, + &nameMapProjectVisitor{currentPath: make([]string, 0, 1)}, + NameMappingAccessor{}) + if err != nil { + return nil, err + } + + return iceberg.NewSchema(schemaWithoutIDs.ID, + top.Type.(*iceberg.StructType).FieldList...), nil +} diff --git a/table/name_mapping_test.go b/table/name_mapping_test.go index bbef128..ecc6d4c 100644 --- a/table/name_mapping_test.go +++ b/table/name_mapping_test.go @@ -1,145 +1,145 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license 
agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package table_test - -import ( - "encoding/json" - "testing" - - "github.com/apache/iceberg-go/table" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var ( - tableNameMappingNested = table.NameMapping{ - {FieldID: makeID(1), Names: []string{"foo"}}, - {FieldID: makeID(2), Names: []string{"bar"}}, - {FieldID: makeID(3), Names: []string{"baz"}}, - {FieldID: makeID(4), Names: []string{"qux"}, - Fields: []table.MappedField{{FieldID: makeID(5), Names: []string{"element"}}}}, - {FieldID: makeID(6), Names: []string{"quux"}, Fields: []table.MappedField{ - {FieldID: makeID(7), Names: []string{"key"}}, - {FieldID: makeID(8), Names: []string{"value"}, Fields: []table.MappedField{ - {FieldID: makeID(9), Names: []string{"key"}}, - {FieldID: makeID(10), Names: []string{"value"}}, - }}, - }}, - {FieldID: makeID(11), Names: []string{"location"}, Fields: []table.MappedField{ - {FieldID: makeID(12), Names: []string{"element"}, Fields: []table.MappedField{ - {FieldID: makeID(13), Names: []string{"latitude"}}, - {FieldID: makeID(14), Names: []string{"longitude"}}, - }}, - }}, - {FieldID: makeID(15), Names: []string{"person"}, Fields: []table.MappedField{ - {FieldID: makeID(16), Names: []string{"name"}}, - {FieldID: makeID(17), Names: 
[]string{"age"}}, - }}, - } -) - -func TestJsonMappedField(t *testing.T) { - tests := []struct { - name string - str string - exp table.MappedField - }{ - {"simple", `{"field-id": 1, "names": ["id", "record_id"]}`, - table.MappedField{FieldID: makeID(1), Names: []string{"id", "record_id"}}}, - {"with null fields", `{"field-id": 1, "names": ["id", "record_id"], "fields": null}`, - table.MappedField{FieldID: makeID(1), Names: []string{"id", "record_id"}}}, - {"no names", `{"field-id": 1, "names": []}`, table.MappedField{FieldID: makeID(1), Names: []string{}}}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var n table.MappedField - require.NoError(t, json.Unmarshal([]byte(tt.str), &n)) - assert.Equal(t, tt.exp, n) - }) - } -} - -func TestNameMappingFromJson(t *testing.T) { - mapping := `[ - {"names": ["foo", "bar"]}, - {"field-id": 1, "names": ["id", "record_id"]}, - {"field-id": 2, "names": ["data"]}, - {"field-id": 3, "names": ["location"], "fields": [ - {"field-id": 4, "names": ["latitude", "lat"]}, - {"field-id": 5, "names": ["longitude", "long"]} - ]} - ]` - - var nm table.NameMapping - require.NoError(t, json.Unmarshal([]byte(mapping), &nm)) - - assert.Equal(t, nm, table.NameMapping{ - {FieldID: nil, Names: []string{"foo", "bar"}}, - {FieldID: makeID(1), Names: []string{"id", "record_id"}}, - {FieldID: makeID(2), Names: []string{"data"}}, - {FieldID: makeID(3), Names: []string{"location"}, Fields: []table.MappedField{ - {FieldID: makeID(4), Names: []string{"latitude", "lat"}}, - {FieldID: makeID(5), Names: []string{"longitude", "long"}}, - }}, - }) -} - -func TestNameMappingToJson(t *testing.T) { - result, err := json.Marshal(tableNameMappingNested) - require.NoError(t, err) - assert.JSONEq(t, `[ - {"field-id": 1, "names": ["foo"]}, - {"field-id": 2, "names": ["bar"]}, - {"field-id": 3, "names": ["baz"]}, - {"field-id": 4, "names": ["qux"], "fields": [{"field-id": 5, "names": ["element"]}]}, - {"field-id": 6, "names": ["quux"], 
"fields": [ - {"field-id": 7, "names": ["key"]}, - {"field-id": 8, "names": ["value"], "fields": [ - {"field-id": 9, "names": ["key"]}, - {"field-id": 10, "names": ["value"]} - ]} - ]}, - {"field-id": 11, "names": ["location"], "fields": [ - {"field-id": 12, "names": ["element"], "fields": [ - {"field-id": 13, "names": ["latitude"]}, - {"field-id": 14, "names": ["longitude"]} - ]} - ]}, - {"field-id": 15, "names": ["person"], "fields": [ - {"field-id": 16, "names": ["name"]}, - {"field-id": 17, "names": ["age"]} - ]} -]`, string(result)) -} - -func TestNameMappingToString(t *testing.T) { - assert.Equal(t, `[ - ([foo] -> ?) - ([id, record_id] -> 1) - ([data] -> 2) - ([location] -> 3 ([lat, latitude] -> 4), ([long, longitude] -> 5)) -]`, table.NameMapping{ - {Names: []string{"foo"}}, - {FieldID: makeID(1), Names: []string{"id", "record_id"}}, - {FieldID: makeID(2), Names: []string{"data"}}, - {FieldID: makeID(3), Names: []string{"location"}, Fields: []table.MappedField{ - {FieldID: makeID(4), Names: []string{"lat", "latitude"}}, - {FieldID: makeID(5), Names: []string{"long", "longitude"}}, - }}}.String()) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package table_test + +import ( + "encoding/json" + "testing" + + "github.com/apache/iceberg-go/table" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + tableNameMappingNested = table.NameMapping{ + {FieldID: makeID(1), Names: []string{"foo"}}, + {FieldID: makeID(2), Names: []string{"bar"}}, + {FieldID: makeID(3), Names: []string{"baz"}}, + {FieldID: makeID(4), Names: []string{"qux"}, + Fields: []table.MappedField{{FieldID: makeID(5), Names: []string{"element"}}}}, + {FieldID: makeID(6), Names: []string{"quux"}, Fields: []table.MappedField{ + {FieldID: makeID(7), Names: []string{"key"}}, + {FieldID: makeID(8), Names: []string{"value"}, Fields: []table.MappedField{ + {FieldID: makeID(9), Names: []string{"key"}}, + {FieldID: makeID(10), Names: []string{"value"}}, + }}, + }}, + {FieldID: makeID(11), Names: []string{"location"}, Fields: []table.MappedField{ + {FieldID: makeID(12), Names: []string{"element"}, Fields: []table.MappedField{ + {FieldID: makeID(13), Names: []string{"latitude"}}, + {FieldID: makeID(14), Names: []string{"longitude"}}, + }}, + }}, + {FieldID: makeID(15), Names: []string{"person"}, Fields: []table.MappedField{ + {FieldID: makeID(16), Names: []string{"name"}}, + {FieldID: makeID(17), Names: []string{"age"}}, + }}, + } +) + +func TestJsonMappedField(t *testing.T) { + tests := []struct { + name string + str string + exp table.MappedField + }{ + {"simple", `{"field-id": 1, "names": ["id", "record_id"]}`, + table.MappedField{FieldID: makeID(1), Names: []string{"id", "record_id"}}}, + {"with null fields", `{"field-id": 1, "names": ["id", "record_id"], "fields": null}`, + table.MappedField{FieldID: makeID(1), Names: []string{"id", "record_id"}}}, + {"no names", `{"field-id": 1, "names": []}`, table.MappedField{FieldID: makeID(1), Names: []string{}}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var n table.MappedField + require.NoError(t, json.Unmarshal([]byte(tt.str), &n)) + 
assert.Equal(t, tt.exp, n) + }) + } +} + +func TestNameMappingFromJson(t *testing.T) { + mapping := `[ + {"names": ["foo", "bar"]}, + {"field-id": 1, "names": ["id", "record_id"]}, + {"field-id": 2, "names": ["data"]}, + {"field-id": 3, "names": ["location"], "fields": [ + {"field-id": 4, "names": ["latitude", "lat"]}, + {"field-id": 5, "names": ["longitude", "long"]} + ]} + ]` + + var nm table.NameMapping + require.NoError(t, json.Unmarshal([]byte(mapping), &nm)) + + assert.Equal(t, nm, table.NameMapping{ + {FieldID: nil, Names: []string{"foo", "bar"}}, + {FieldID: makeID(1), Names: []string{"id", "record_id"}}, + {FieldID: makeID(2), Names: []string{"data"}}, + {FieldID: makeID(3), Names: []string{"location"}, Fields: []table.MappedField{ + {FieldID: makeID(4), Names: []string{"latitude", "lat"}}, + {FieldID: makeID(5), Names: []string{"longitude", "long"}}, + }}, + }) +} + +func TestNameMappingToJson(t *testing.T) { + result, err := json.Marshal(tableNameMappingNested) + require.NoError(t, err) + assert.JSONEq(t, `[ + {"field-id": 1, "names": ["foo"]}, + {"field-id": 2, "names": ["bar"]}, + {"field-id": 3, "names": ["baz"]}, + {"field-id": 4, "names": ["qux"], "fields": [{"field-id": 5, "names": ["element"]}]}, + {"field-id": 6, "names": ["quux"], "fields": [ + {"field-id": 7, "names": ["key"]}, + {"field-id": 8, "names": ["value"], "fields": [ + {"field-id": 9, "names": ["key"]}, + {"field-id": 10, "names": ["value"]} + ]} + ]}, + {"field-id": 11, "names": ["location"], "fields": [ + {"field-id": 12, "names": ["element"], "fields": [ + {"field-id": 13, "names": ["latitude"]}, + {"field-id": 14, "names": ["longitude"]} + ]} + ]}, + {"field-id": 15, "names": ["person"], "fields": [ + {"field-id": 16, "names": ["name"]}, + {"field-id": 17, "names": ["age"]} + ]} +]`, string(result)) +} + +func TestNameMappingToString(t *testing.T) { + assert.Equal(t, `[ + ([foo] -> ?) 
+ ([id, record_id] -> 1) + ([data] -> 2) + ([location] -> 3 ([lat, latitude] -> 4), ([long, longitude] -> 5)) +]`, table.NameMapping{ + {Names: []string{"foo"}}, + {FieldID: makeID(1), Names: []string{"id", "record_id"}}, + {FieldID: makeID(2), Names: []string{"data"}}, + {FieldID: makeID(3), Names: []string{"location"}, Fields: []table.MappedField{ + {FieldID: makeID(4), Names: []string{"lat", "latitude"}}, + {FieldID: makeID(5), Names: []string{"long", "longitude"}}, + }}}.String()) +} diff --git a/table/refs.go b/table/refs.go index f0eb697..81286e7 100644 --- a/table/refs.go +++ b/table/refs.go @@ -1,68 +1,68 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package table - -import ( - "encoding/json" - "errors" - "reflect" -) - -const MainBranch = "main" - -// RefType will be either a BranchRef or a TagRef -type RefType string - -const ( - BranchRef RefType = "branch" - TagRef RefType = "tag" -) - -var ( - ErrInvalidRefType = errors.New("invalid snapshot ref type, should be 'branch' or 'tag'") -) - -// SnapshotRef represents the reference information for a specific snapshot -type SnapshotRef struct { - SnapshotID int64 `json:"snapshot-id"` - SnapshotRefType RefType `json:"type"` - MinSnapshotsToKeep *int `json:"min-snapshots-to-keep,omitempty"` - MaxSnapshotAgeMs *int64 `json:"max-snapshot-age-ms,omitempty"` - MaxRefAgeMs *int64 `json:"max-ref-age-ms,omitempty"` -} - -func (s *SnapshotRef) Equals(rhs SnapshotRef) bool { - return reflect.DeepEqual(s, &rhs) -} - -func (s *SnapshotRef) UnmarshalJSON(b []byte) error { - type Alias SnapshotRef - aux := (*Alias)(s) - - if err := json.Unmarshal(b, aux); err != nil { - return nil - } - - switch s.SnapshotRefType { - case BranchRef, TagRef: - default: - return ErrInvalidRefType - } - - return nil -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package table + +import ( + "encoding/json" + "errors" + "reflect" +) + +const MainBranch = "main" + +// RefType will be either a BranchRef or a TagRef +type RefType string + +const ( + BranchRef RefType = "branch" + TagRef RefType = "tag" +) + +var ( + ErrInvalidRefType = errors.New("invalid snapshot ref type, should be 'branch' or 'tag'") +) + +// SnapshotRef represents the reference information for a specific snapshot +type SnapshotRef struct { + SnapshotID int64 `json:"snapshot-id"` + SnapshotRefType RefType `json:"type"` + MinSnapshotsToKeep *int `json:"min-snapshots-to-keep,omitempty"` + MaxSnapshotAgeMs *int64 `json:"max-snapshot-age-ms,omitempty"` + MaxRefAgeMs *int64 `json:"max-ref-age-ms,omitempty"` +} + +func (s *SnapshotRef) Equals(rhs SnapshotRef) bool { + return reflect.DeepEqual(s, &rhs) +} + +func (s *SnapshotRef) UnmarshalJSON(b []byte) error { + type Alias SnapshotRef + aux := (*Alias)(s) + + if err := json.Unmarshal(b, aux); err != nil { + return nil + } + + switch s.SnapshotRefType { + case BranchRef, TagRef: + default: + return ErrInvalidRefType + } + + return nil +} diff --git a/table/refs_test.go b/table/refs_test.go index d8b54e4..ca616fd 100644 --- a/table/refs_test.go +++ b/table/refs_test.go @@ -1,72 +1,72 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -package table_test - -import ( - "encoding/json" - "testing" - - "github.com/apache/iceberg-go/table" - "github.com/stretchr/testify/assert" -) - -func TestInvalidSnapshotRefType(t *testing.T) { - ref := `{ - "snapshot-id": 3051729675574597004, - "type": "foobar" - }` - - var snapRef table.SnapshotRef - err := json.Unmarshal([]byte(ref), &snapRef) - assert.ErrorIs(t, err, table.ErrInvalidRefType) -} - -func TestSnapshotBranchRef(t *testing.T) { - ref := `{ - "snapshot-id": 3051729675574597004, - "type": "branch" - }` - - var snapRef table.SnapshotRef - err := json.Unmarshal([]byte(ref), &snapRef) - assert.NoError(t, err) - - assert.Equal(t, table.BranchRef, snapRef.SnapshotRefType) - assert.Equal(t, int64(3051729675574597004), snapRef.SnapshotID) - assert.Nil(t, snapRef.MinSnapshotsToKeep) - assert.Nil(t, snapRef.MaxRefAgeMs) - assert.Nil(t, snapRef.MaxSnapshotAgeMs) -} - -func TestSnapshotTagRef(t *testing.T) { - ref := `{ - "snapshot-id": 3051729675574597004, - "type": "tag", - "min-snapshots-to-keep": 10 - }` - - var snapRef table.SnapshotRef - err := json.Unmarshal([]byte(ref), &snapRef) - assert.NoError(t, err) - - assert.Equal(t, table.TagRef, snapRef.SnapshotRefType) - assert.Equal(t, int64(3051729675574597004), snapRef.SnapshotID) - assert.Equal(t, 10, *snapRef.MinSnapshotsToKeep) - assert.Nil(t, snapRef.MaxRefAgeMs) - assert.Nil(t, snapRef.MaxSnapshotAgeMs) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package table_test + +import ( + "encoding/json" + "testing" + + "github.com/apache/iceberg-go/table" + "github.com/stretchr/testify/assert" +) + +func TestInvalidSnapshotRefType(t *testing.T) { + ref := `{ + "snapshot-id": 3051729675574597004, + "type": "foobar" + }` + + var snapRef table.SnapshotRef + err := json.Unmarshal([]byte(ref), &snapRef) + assert.ErrorIs(t, err, table.ErrInvalidRefType) +} + +func TestSnapshotBranchRef(t *testing.T) { + ref := `{ + "snapshot-id": 3051729675574597004, + "type": "branch" + }` + + var snapRef table.SnapshotRef + err := json.Unmarshal([]byte(ref), &snapRef) + assert.NoError(t, err) + + assert.Equal(t, table.BranchRef, snapRef.SnapshotRefType) + assert.Equal(t, int64(3051729675574597004), snapRef.SnapshotID) + assert.Nil(t, snapRef.MinSnapshotsToKeep) + assert.Nil(t, snapRef.MaxRefAgeMs) + assert.Nil(t, snapRef.MaxSnapshotAgeMs) +} + +func TestSnapshotTagRef(t *testing.T) { + ref := `{ + "snapshot-id": 3051729675574597004, + "type": "tag", + "min-snapshots-to-keep": 10 + }` + + var snapRef table.SnapshotRef + err := json.Unmarshal([]byte(ref), &snapRef) + assert.NoError(t, err) + + assert.Equal(t, table.TagRef, snapRef.SnapshotRefType) + assert.Equal(t, int64(3051729675574597004), snapRef.SnapshotID) + assert.Equal(t, 10, *snapRef.MinSnapshotsToKeep) + assert.Nil(t, snapRef.MaxRefAgeMs) + assert.Nil(t, snapRef.MaxSnapshotAgeMs) +} diff --git a/table/scanner.go b/table/scanner.go index ea33372..84e2b74 100644 --- a/table/scanner.go +++ b/table/scanner.go @@ -1,395 +1,395 @@ -// Licensed to the Apache 
Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package table - -import ( - "cmp" - "context" - "fmt" - "runtime" - "slices" - "sync" - - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/io" -) - -type keyDefaultMap[K comparable, V any] struct { - defaultFactory func(K) V - data map[K]V - - mx sync.RWMutex -} - -func (k *keyDefaultMap[K, V]) Get(key K) V { - k.mx.RLock() - if v, ok := k.data[key]; ok { - k.mx.RUnlock() - return v - } - - k.mx.RUnlock() - k.mx.Lock() - defer k.mx.Unlock() - - v := k.defaultFactory(key) - k.data[key] = v - return v -} - -func newKeyDefaultMap[K comparable, V any](factory func(K) V) keyDefaultMap[K, V] { - return keyDefaultMap[K, V]{ - data: make(map[K]V), - defaultFactory: factory, - } -} - -func newKeyDefaultMapWrapErr[K comparable, V any](factory func(K) (V, error)) keyDefaultMap[K, V] { - return keyDefaultMap[K, V]{ - data: make(map[K]V), - defaultFactory: func(k K) V { - v, err := factory(k) - if err != nil { - panic(err) - } - return v - }, - } -} - -type partitionRecord []any - -func (p partitionRecord) Size() int { return len(p) } -func (p partitionRecord) Get(pos int) any { return p[pos] } -func (p partitionRecord) Set(pos int, val any) { p[pos] = val } - -func 
getPartitionRecord(dataFile iceberg.DataFile, partitionType *iceberg.StructType) partitionRecord { - partitionData := dataFile.Partition() - - out := make(partitionRecord, len(partitionType.FieldList)) - for i, f := range partitionType.FieldList { - out[i] = partitionData[f.Name] - } - return out -} - -func openManifest(io io.IO, manifest iceberg.ManifestFile, - partitionFilter, metricsEval func(iceberg.DataFile) (bool, error)) ([]iceberg.ManifestEntry, error) { - - entries, err := manifest.FetchEntries(io, true) - if err != nil { - return nil, err - } - - out := make([]iceberg.ManifestEntry, 0, len(entries)) - for _, entry := range entries { - p, err := partitionFilter(entry.DataFile()) - if err != nil { - return nil, err - } - - m, err := metricsEval(entry.DataFile()) - if err != nil { - return nil, err - } - - if p && m { - out = append(out, entry) - } - } - - return out, nil -} - -type Scan struct { - metadata Metadata - io io.IO - rowFilter iceberg.BooleanExpression - selectedFields []string - caseSensitive bool - snapshotID *int64 - options iceberg.Properties - - partitionFilters keyDefaultMap[int, iceberg.BooleanExpression] -} - -func (s *Scan) UseRef(name string) (*Scan, error) { - if s.snapshotID != nil { - return nil, fmt.Errorf("%w: cannot override ref, already set snapshot id %d", - iceberg.ErrInvalidArgument, *s.snapshotID) - } - - if snap := s.metadata.SnapshotByName(name); snap != nil { - out := &Scan{ - metadata: s.metadata, - io: s.io, - rowFilter: s.rowFilter, - selectedFields: s.selectedFields, - caseSensitive: s.caseSensitive, - snapshotID: &snap.SnapshotID, - options: s.options, - } - out.partitionFilters = newKeyDefaultMapWrapErr(out.buildPartitionProjection) - - return out, nil - } - - return nil, fmt.Errorf("%w: cannot scan unknown ref=%s", iceberg.ErrInvalidArgument, name) -} - -func (s *Scan) Snapshot() *Snapshot { - if s.snapshotID != nil { - return s.metadata.SnapshotByID(*s.snapshotID) - } - return s.metadata.CurrentSnapshot() -} - 
-func (s *Scan) Projection() (*iceberg.Schema, error) { - curSchema := s.metadata.CurrentSchema() - if s.snapshotID != nil { - snap := s.metadata.SnapshotByID(*s.snapshotID) - if snap == nil { - return nil, fmt.Errorf("%w: snapshot not found: %d", ErrInvalidOperation, *s.snapshotID) - } - - if snap.SchemaID != nil { - for _, schema := range s.metadata.Schemas() { - if schema.ID == *snap.SchemaID { - curSchema = schema - break - } - } - } - } - - if slices.Contains(s.selectedFields, "*") { - return curSchema, nil - } - - return curSchema.Select(s.caseSensitive, s.selectedFields...) -} - -func (s *Scan) buildPartitionProjection(specID int) (iceberg.BooleanExpression, error) { - project := newInclusiveProjection(s.metadata.CurrentSchema(), - s.metadata.PartitionSpecs()[specID], true) - return project(s.rowFilter) -} - -func (s *Scan) buildManifestEvaluator(specID int) (func(iceberg.ManifestFile) (bool, error), error) { - spec := s.metadata.PartitionSpecs()[specID] - return newManifestEvaluator(spec, s.metadata.CurrentSchema(), - s.partitionFilters.Get(specID), s.caseSensitive) -} - -func (s *Scan) buildPartitionEvaluator(specID int) func(iceberg.DataFile) (bool, error) { - spec := s.metadata.PartitionSpecs()[specID] - partType := spec.PartitionType(s.metadata.CurrentSchema()) - partSchema := iceberg.NewSchema(0, partType.FieldList...) 
- partExpr := s.partitionFilters.Get(specID) - - return func(d iceberg.DataFile) (bool, error) { - fn, err := iceberg.ExpressionEvaluator(partSchema, partExpr, s.caseSensitive) - if err != nil { - return false, err - } - - return fn(getPartitionRecord(d, partType)) - } -} - -func (s *Scan) checkSequenceNumber(minSeqNum int64, manifest iceberg.ManifestFile) bool { - return manifest.ManifestContent() == iceberg.ManifestContentData || - (manifest.ManifestContent() == iceberg.ManifestContentDeletes && - manifest.SequenceNum() >= minSeqNum) -} - -func minSequenceNum(manifests []iceberg.ManifestFile) int64 { - n := int64(0) - for _, m := range manifests { - if m.ManifestContent() == iceberg.ManifestContentData { - n = min(n, m.MinSequenceNum()) - } - } - return n -} - -func matchDeletesToData(entry iceberg.ManifestEntry, positionalDeletes []iceberg.ManifestEntry) ([]iceberg.DataFile, error) { - idx, _ := slices.BinarySearchFunc(positionalDeletes, entry, func(me1, me2 iceberg.ManifestEntry) int { - return cmp.Compare(me1.SequenceNum(), me2.SequenceNum()) - }) - - evaluator, err := newInclusiveMetricsEvaluator(iceberg.PositionalDeleteSchema, - iceberg.EqualTo(iceberg.Reference("file_path"), entry.DataFile().FilePath()), true, false) - if err != nil { - return nil, err - } - - out := make([]iceberg.DataFile, 0) - for _, relevant := range positionalDeletes[idx:] { - df := relevant.DataFile() - ok, err := evaluator(df) - if err != nil { - return nil, err - } - if ok { - out = append(out, df) - } - } - - return out, nil -} - -func (s *Scan) PlanFiles(ctx context.Context) ([]FileScanTask, error) { - snap := s.Snapshot() - if snap == nil { - return nil, nil - } - - // step 1: filter manifests using partition summaries - // the filter depends on the partition spec used to write the manifest file - // so create a cache of filters for each spec id - manifestEvaluators := newKeyDefaultMapWrapErr(s.buildManifestEvaluator) - manifestList, err := snap.Manifests(s.io) - if err != nil { 
- return nil, err - } - - // remove any manifests that we don't need to use - manifestList = slices.DeleteFunc(manifestList, func(mf iceberg.ManifestFile) bool { - eval := manifestEvaluators.Get(int(mf.PartitionSpecID())) - use, err := eval(mf) - return !use || err != nil - }) - - // step 2: filter the data files in each manifest - // this filter depends on the partition spec used to write the manifest file - partitionEvaluators := newKeyDefaultMap(s.buildPartitionEvaluator) - metricsEval, err := newInclusiveMetricsEvaluator( - s.metadata.CurrentSchema(), s.rowFilter, s.caseSensitive, s.options["include_empty_files"] == "true") - if err != nil { - return nil, err - } - - minSeqNum := minSequenceNum(manifestList) - dataEntries := make([]iceberg.ManifestEntry, 0) - positionalDeleteEntries := make([]iceberg.ManifestEntry, 0) - - nworkers := runtime.NumCPU() - var wg sync.WaitGroup - - manifestChan := make(chan iceberg.ManifestFile, len(manifestList)) - entryChan := make(chan []iceberg.ManifestEntry, 20) - - ctx, cancel := context.WithCancelCause(ctx) - for i := 0; i < nworkers; i++ { - wg.Add(1) - - go func() { - defer wg.Done() - - for { - select { - case m, ok := <-manifestChan: - if !ok { - return - } - - if !s.checkSequenceNumber(minSeqNum, m) { - continue - } - - entries, err := openManifest(s.io, m, - partitionEvaluators.Get(int(m.PartitionSpecID())), metricsEval) - if err != nil { - cancel(err) - break - } - - entryChan <- entries - case <-ctx.Done(): - return - } - } - }() - } - - go func() { - wg.Wait() - close(entryChan) - }() - - for _, m := range manifestList { - manifestChan <- m - } - close(manifestChan) - -Loop: - for { - select { - case <-ctx.Done(): - return nil, context.Cause(ctx) - case entries, ok := <-entryChan: - if !ok { - // closed! 
- break Loop - } - - for _, e := range entries { - df := e.DataFile() - switch df.ContentType() { - case iceberg.EntryContentData: - dataEntries = append(dataEntries, e) - case iceberg.EntryContentPosDeletes: - positionalDeleteEntries = append(positionalDeleteEntries, e) - case iceberg.EntryContentEqDeletes: - return nil, fmt.Errorf("iceberg-go does not yet support equality deletes") - default: - return nil, fmt.Errorf("%w: unknown DataFileContent type (%s): %s", - ErrInvalidMetadata, df.ContentType(), e) - } - } - } - } - - slices.SortFunc(positionalDeleteEntries, func(a, b iceberg.ManifestEntry) int { - return cmp.Compare(a.SequenceNum(), b.SequenceNum()) - }) - - results := make([]FileScanTask, 0) - for _, e := range dataEntries { - deleteFiles, err := matchDeletesToData(e, positionalDeleteEntries) - if err != nil { - return nil, err - } - - results = append(results, FileScanTask{ - File: e.DataFile(), - DeleteFiles: deleteFiles, - Start: 0, - Length: e.DataFile().FileSizeBytes(), - }) - } - - return results, nil -} - -type FileScanTask struct { - File iceberg.DataFile - DeleteFiles []iceberg.DataFile - Start, Length int64 -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package table + +import ( + "cmp" + "context" + "fmt" + "runtime" + "slices" + "sync" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/io" +) + +type keyDefaultMap[K comparable, V any] struct { + defaultFactory func(K) V + data map[K]V + + mx sync.RWMutex +} + +func (k *keyDefaultMap[K, V]) Get(key K) V { + k.mx.RLock() + if v, ok := k.data[key]; ok { + k.mx.RUnlock() + return v + } + + k.mx.RUnlock() + k.mx.Lock() + defer k.mx.Unlock() + + v := k.defaultFactory(key) + k.data[key] = v + return v +} + +func newKeyDefaultMap[K comparable, V any](factory func(K) V) keyDefaultMap[K, V] { + return keyDefaultMap[K, V]{ + data: make(map[K]V), + defaultFactory: factory, + } +} + +func newKeyDefaultMapWrapErr[K comparable, V any](factory func(K) (V, error)) keyDefaultMap[K, V] { + return keyDefaultMap[K, V]{ + data: make(map[K]V), + defaultFactory: func(k K) V { + v, err := factory(k) + if err != nil { + panic(err) + } + return v + }, + } +} + +type partitionRecord []any + +func (p partitionRecord) Size() int { return len(p) } +func (p partitionRecord) Get(pos int) any { return p[pos] } +func (p partitionRecord) Set(pos int, val any) { p[pos] = val } + +func getPartitionRecord(dataFile iceberg.DataFile, partitionType *iceberg.StructType) partitionRecord { + partitionData := dataFile.Partition() + + out := make(partitionRecord, len(partitionType.FieldList)) + for i, f := range partitionType.FieldList { + out[i] = partitionData[f.Name] + } + return out +} + +func openManifest(io io.IO, manifest iceberg.ManifestFile, + partitionFilter, metricsEval func(iceberg.DataFile) (bool, error)) ([]iceberg.ManifestEntry, error) { + + entries, err := manifest.FetchEntries(io, true) + if err != nil { + return nil, err + } + + out := make([]iceberg.ManifestEntry, 0, len(entries)) + for _, entry := range entries { + p, err := partitionFilter(entry.DataFile()) + if err != nil { + return nil, err + } + + m, err := metricsEval(entry.DataFile()) + if err != nil { + return 
nil, err + } + + if p && m { + out = append(out, entry) + } + } + + return out, nil +} + +type Scan struct { + metadata Metadata + io io.IO + rowFilter iceberg.BooleanExpression + selectedFields []string + caseSensitive bool + snapshotID *int64 + options iceberg.Properties + + partitionFilters keyDefaultMap[int, iceberg.BooleanExpression] +} + +func (s *Scan) UseRef(name string) (*Scan, error) { + if s.snapshotID != nil { + return nil, fmt.Errorf("%w: cannot override ref, already set snapshot id %d", + iceberg.ErrInvalidArgument, *s.snapshotID) + } + + if snap := s.metadata.SnapshotByName(name); snap != nil { + out := &Scan{ + metadata: s.metadata, + io: s.io, + rowFilter: s.rowFilter, + selectedFields: s.selectedFields, + caseSensitive: s.caseSensitive, + snapshotID: &snap.SnapshotID, + options: s.options, + } + out.partitionFilters = newKeyDefaultMapWrapErr(out.buildPartitionProjection) + + return out, nil + } + + return nil, fmt.Errorf("%w: cannot scan unknown ref=%s", iceberg.ErrInvalidArgument, name) +} + +func (s *Scan) Snapshot() *Snapshot { + if s.snapshotID != nil { + return s.metadata.SnapshotByID(*s.snapshotID) + } + return s.metadata.CurrentSnapshot() +} + +func (s *Scan) Projection() (*iceberg.Schema, error) { + curSchema := s.metadata.CurrentSchema() + if s.snapshotID != nil { + snap := s.metadata.SnapshotByID(*s.snapshotID) + if snap == nil { + return nil, fmt.Errorf("%w: snapshot not found: %d", ErrInvalidOperation, *s.snapshotID) + } + + if snap.SchemaID != nil { + for _, schema := range s.metadata.Schemas() { + if schema.ID == *snap.SchemaID { + curSchema = schema + break + } + } + } + } + + if slices.Contains(s.selectedFields, "*") { + return curSchema, nil + } + + return curSchema.Select(s.caseSensitive, s.selectedFields...) 
+} + +func (s *Scan) buildPartitionProjection(specID int) (iceberg.BooleanExpression, error) { + project := newInclusiveProjection(s.metadata.CurrentSchema(), + s.metadata.PartitionSpecs()[specID], true) + return project(s.rowFilter) +} + +func (s *Scan) buildManifestEvaluator(specID int) (func(iceberg.ManifestFile) (bool, error), error) { + spec := s.metadata.PartitionSpecs()[specID] + return newManifestEvaluator(spec, s.metadata.CurrentSchema(), + s.partitionFilters.Get(specID), s.caseSensitive) +} + +func (s *Scan) buildPartitionEvaluator(specID int) func(iceberg.DataFile) (bool, error) { + spec := s.metadata.PartitionSpecs()[specID] + partType := spec.PartitionType(s.metadata.CurrentSchema()) + partSchema := iceberg.NewSchema(0, partType.FieldList...) + partExpr := s.partitionFilters.Get(specID) + + return func(d iceberg.DataFile) (bool, error) { + fn, err := iceberg.ExpressionEvaluator(partSchema, partExpr, s.caseSensitive) + if err != nil { + return false, err + } + + return fn(getPartitionRecord(d, partType)) + } +} + +func (s *Scan) checkSequenceNumber(minSeqNum int64, manifest iceberg.ManifestFile) bool { + return manifest.ManifestContent() == iceberg.ManifestContentData || + (manifest.ManifestContent() == iceberg.ManifestContentDeletes && + manifest.SequenceNum() >= minSeqNum) +} + +func minSequenceNum(manifests []iceberg.ManifestFile) int64 { + n := int64(0) + for _, m := range manifests { + if m.ManifestContent() == iceberg.ManifestContentData { + n = min(n, m.MinSequenceNum()) + } + } + return n +} + +func matchDeletesToData(entry iceberg.ManifestEntry, positionalDeletes []iceberg.ManifestEntry) ([]iceberg.DataFile, error) { + idx, _ := slices.BinarySearchFunc(positionalDeletes, entry, func(me1, me2 iceberg.ManifestEntry) int { + return cmp.Compare(me1.SequenceNum(), me2.SequenceNum()) + }) + + evaluator, err := newInclusiveMetricsEvaluator(iceberg.PositionalDeleteSchema, + iceberg.EqualTo(iceberg.Reference("file_path"), entry.DataFile().FilePath()), 
true, false) + if err != nil { + return nil, err + } + + out := make([]iceberg.DataFile, 0) + for _, relevant := range positionalDeletes[idx:] { + df := relevant.DataFile() + ok, err := evaluator(df) + if err != nil { + return nil, err + } + if ok { + out = append(out, df) + } + } + + return out, nil +} + +func (s *Scan) PlanFiles(ctx context.Context) ([]FileScanTask, error) { + snap := s.Snapshot() + if snap == nil { + return nil, nil + } + + // step 1: filter manifests using partition summaries + // the filter depends on the partition spec used to write the manifest file + // so create a cache of filters for each spec id + manifestEvaluators := newKeyDefaultMapWrapErr(s.buildManifestEvaluator) + manifestList, err := snap.Manifests(s.io) + if err != nil { + return nil, err + } + + // remove any manifests that we don't need to use + manifestList = slices.DeleteFunc(manifestList, func(mf iceberg.ManifestFile) bool { + eval := manifestEvaluators.Get(int(mf.PartitionSpecID())) + use, err := eval(mf) + return !use || err != nil + }) + + // step 2: filter the data files in each manifest + // this filter depends on the partition spec used to write the manifest file + partitionEvaluators := newKeyDefaultMap(s.buildPartitionEvaluator) + metricsEval, err := newInclusiveMetricsEvaluator( + s.metadata.CurrentSchema(), s.rowFilter, s.caseSensitive, s.options["include_empty_files"] == "true") + if err != nil { + return nil, err + } + + minSeqNum := minSequenceNum(manifestList) + dataEntries := make([]iceberg.ManifestEntry, 0) + positionalDeleteEntries := make([]iceberg.ManifestEntry, 0) + + nworkers := runtime.NumCPU() + var wg sync.WaitGroup + + manifestChan := make(chan iceberg.ManifestFile, len(manifestList)) + entryChan := make(chan []iceberg.ManifestEntry, 20) + + ctx, cancel := context.WithCancelCause(ctx) + for i := 0; i < nworkers; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + for { + select { + case m, ok := <-manifestChan: + if !ok { + return + } + + if 
!s.checkSequenceNumber(minSeqNum, m) { + continue + } + + entries, err := openManifest(s.io, m, + partitionEvaluators.Get(int(m.PartitionSpecID())), metricsEval) + if err != nil { + cancel(err) + break + } + + entryChan <- entries + case <-ctx.Done(): + return + } + } + }() + } + + go func() { + wg.Wait() + close(entryChan) + }() + + for _, m := range manifestList { + manifestChan <- m + } + close(manifestChan) + +Loop: + for { + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case entries, ok := <-entryChan: + if !ok { + // closed! + break Loop + } + + for _, e := range entries { + df := e.DataFile() + switch df.ContentType() { + case iceberg.EntryContentData: + dataEntries = append(dataEntries, e) + case iceberg.EntryContentPosDeletes: + positionalDeleteEntries = append(positionalDeleteEntries, e) + case iceberg.EntryContentEqDeletes: + return nil, fmt.Errorf("iceberg-go does not yet support equality deletes") + default: + return nil, fmt.Errorf("%w: unknown DataFileContent type (%s): %s", + ErrInvalidMetadata, df.ContentType(), e) + } + } + } + } + + slices.SortFunc(positionalDeleteEntries, func(a, b iceberg.ManifestEntry) int { + return cmp.Compare(a.SequenceNum(), b.SequenceNum()) + }) + + results := make([]FileScanTask, 0) + for _, e := range dataEntries { + deleteFiles, err := matchDeletesToData(e, positionalDeleteEntries) + if err != nil { + return nil, err + } + + results = append(results, FileScanTask{ + File: e.DataFile(), + DeleteFiles: deleteFiles, + Start: 0, + Length: e.DataFile().FileSizeBytes(), + }) + } + + return results, nil +} + +type FileScanTask struct { + File iceberg.DataFile + DeleteFiles []iceberg.DataFile + Start, Length int64 +} diff --git a/table/scanner_test.go b/table/scanner_test.go index af4b8f6..cee2bae 100644 --- a/table/scanner_test.go +++ b/table/scanner_test.go @@ -1,121 +1,121 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//go:build integration - -package table_test - -import ( - "context" - "testing" - - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/catalog" - "github.com/apache/iceberg-go/io" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestScanner(t *testing.T) { - cat, err := catalog.NewRestCatalog("rest", "http://localhost:8181") - require.NoError(t, err) - - props := iceberg.Properties{ - io.S3Region: "us-east-1", - io.S3AccessKeyID: "admin", io.S3SecretAccessKey: "password"} - - tests := []struct { - table string - expr iceberg.BooleanExpression - expectedNumTasks int - }{ - {"test_all_types", iceberg.AlwaysTrue{}, 5}, - {"test_all_types", iceberg.LessThan(iceberg.Reference("intCol"), int32(3)), 3}, - {"test_all_types", iceberg.GreaterThanEqual(iceberg.Reference("intCol"), int32(3)), 2}, - {"test_partitioned_by_identity", - iceberg.GreaterThanEqual(iceberg.Reference("ts"), "2023-03-05T00:00:00+00:00"), 8}, - {"test_partitioned_by_identity", - iceberg.LessThan(iceberg.Reference("ts"), "2023-03-05T00:00:00+00:00"), 4}, - {"test_partitioned_by_years", iceberg.AlwaysTrue{}, 2}, - {"test_partitioned_by_years", iceberg.LessThan(iceberg.Reference("dt"), "2023-03-05"), 1}, - {"test_partitioned_by_years", 
iceberg.GreaterThanEqual(iceberg.Reference("dt"), "2023-03-05"), 1}, - {"test_partitioned_by_months", iceberg.GreaterThanEqual(iceberg.Reference("dt"), "2023-03-05"), 1}, - {"test_partitioned_by_days", iceberg.GreaterThanEqual(iceberg.Reference("ts"), "2023-03-05T00:00:00+00:00"), 4}, - {"test_partitioned_by_hours", iceberg.GreaterThanEqual(iceberg.Reference("ts"), "2023-03-05T00:00:00+00:00"), 8}, - {"test_partitioned_by_truncate", iceberg.GreaterThanEqual(iceberg.Reference("letter"), "e"), 8}, - {"test_partitioned_by_bucket", iceberg.GreaterThanEqual(iceberg.Reference("number"), int32(5)), 6}, - {"test_uuid_and_fixed_unpartitioned", iceberg.AlwaysTrue{}, 4}, - {"test_uuid_and_fixed_unpartitioned", iceberg.EqualTo(iceberg.Reference("uuid_col"), "102cb62f-e6f8-4eb0-9973-d9b012ff0967"), 1}, - } - - for _, tt := range tests { - t.Run(tt.table+" "+tt.expr.String(), func(t *testing.T) { - ident := catalog.ToRestIdentifier("default", tt.table) - - tbl, err := cat.LoadTable(context.Background(), ident, props) - require.NoError(t, err) - - scan := tbl.Scan(tt.expr, 0, true, "*") - tasks, err := scan.PlanFiles(context.Background()) - require.NoError(t, err) - - assert.Len(t, tasks, tt.expectedNumTasks) - }) - } -} - -func TestScannerWithDeletes(t *testing.T) { - cat, err := catalog.NewRestCatalog("rest", "http://localhost:8181") - require.NoError(t, err) - - props := iceberg.Properties{ - io.S3Region: "us-east-1", - io.S3AccessKeyID: "admin", io.S3SecretAccessKey: "password"} - - ident := catalog.ToRestIdentifier("default", "test_positional_mor_deletes") - - tbl, err := cat.LoadTable(context.Background(), ident, props) - require.NoError(t, err) - - scan := tbl.Scan(iceberg.AlwaysTrue{}, 0, true, "*") - tasks, err := scan.PlanFiles(context.Background()) - require.NoError(t, err) - - assert.Len(t, tasks, 1) - assert.Len(t, tasks[0].DeleteFiles, 1) - - tagScan, err := scan.UseRef("tag_12") - require.NoError(t, err) - - tasks, err = tagScan.PlanFiles(context.Background()) - 
require.NoError(t, err) - - assert.Len(t, tasks, 1) - assert.Len(t, tasks[0].DeleteFiles, 0) - - _, err = tagScan.UseRef("without_5") - assert.ErrorIs(t, err, iceberg.ErrInvalidArgument) - - tagScan, err = scan.UseRef("without_5") - require.NoError(t, err) - - tasks, err = tagScan.PlanFiles(context.Background()) - require.NoError(t, err) - - assert.Len(t, tasks, 1) - assert.Len(t, tasks[0].DeleteFiles, 1) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +//go:build integration + +package table_test + +import ( + "context" + "testing" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/catalog" + "github.com/apache/iceberg-go/io" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScanner(t *testing.T) { + cat, err := catalog.NewRestCatalog("rest", "http://localhost:8181") + require.NoError(t, err) + + props := iceberg.Properties{ + io.S3Region: "us-east-1", + io.S3AccessKeyID: "admin", io.S3SecretAccessKey: "password"} + + tests := []struct { + table string + expr iceberg.BooleanExpression + expectedNumTasks int + }{ + {"test_all_types", iceberg.AlwaysTrue{}, 5}, + {"test_all_types", iceberg.LessThan(iceberg.Reference("intCol"), int32(3)), 3}, + {"test_all_types", iceberg.GreaterThanEqual(iceberg.Reference("intCol"), int32(3)), 2}, + {"test_partitioned_by_identity", + iceberg.GreaterThanEqual(iceberg.Reference("ts"), "2023-03-05T00:00:00+00:00"), 8}, + {"test_partitioned_by_identity", + iceberg.LessThan(iceberg.Reference("ts"), "2023-03-05T00:00:00+00:00"), 4}, + {"test_partitioned_by_years", iceberg.AlwaysTrue{}, 2}, + {"test_partitioned_by_years", iceberg.LessThan(iceberg.Reference("dt"), "2023-03-05"), 1}, + {"test_partitioned_by_years", iceberg.GreaterThanEqual(iceberg.Reference("dt"), "2023-03-05"), 1}, + {"test_partitioned_by_months", iceberg.GreaterThanEqual(iceberg.Reference("dt"), "2023-03-05"), 1}, + {"test_partitioned_by_days", iceberg.GreaterThanEqual(iceberg.Reference("ts"), "2023-03-05T00:00:00+00:00"), 4}, + {"test_partitioned_by_hours", iceberg.GreaterThanEqual(iceberg.Reference("ts"), "2023-03-05T00:00:00+00:00"), 8}, + {"test_partitioned_by_truncate", iceberg.GreaterThanEqual(iceberg.Reference("letter"), "e"), 8}, + {"test_partitioned_by_bucket", iceberg.GreaterThanEqual(iceberg.Reference("number"), int32(5)), 6}, + {"test_uuid_and_fixed_unpartitioned", iceberg.AlwaysTrue{}, 4}, + {"test_uuid_and_fixed_unpartitioned", 
iceberg.EqualTo(iceberg.Reference("uuid_col"), "102cb62f-e6f8-4eb0-9973-d9b012ff0967"), 1}, + } + + for _, tt := range tests { + t.Run(tt.table+" "+tt.expr.String(), func(t *testing.T) { + ident := catalog.ToRestIdentifier("default", tt.table) + + tbl, err := cat.LoadTable(context.Background(), ident, props) + require.NoError(t, err) + + scan := tbl.Scan(tt.expr, 0, true, "*") + tasks, err := scan.PlanFiles(context.Background()) + require.NoError(t, err) + + assert.Len(t, tasks, tt.expectedNumTasks) + }) + } +} + +func TestScannerWithDeletes(t *testing.T) { + cat, err := catalog.NewRestCatalog("rest", "http://localhost:8181") + require.NoError(t, err) + + props := iceberg.Properties{ + io.S3Region: "us-east-1", + io.S3AccessKeyID: "admin", io.S3SecretAccessKey: "password"} + + ident := catalog.ToRestIdentifier("default", "test_positional_mor_deletes") + + tbl, err := cat.LoadTable(context.Background(), ident, props) + require.NoError(t, err) + + scan := tbl.Scan(iceberg.AlwaysTrue{}, 0, true, "*") + tasks, err := scan.PlanFiles(context.Background()) + require.NoError(t, err) + + assert.Len(t, tasks, 1) + assert.Len(t, tasks[0].DeleteFiles, 1) + + tagScan, err := scan.UseRef("tag_12") + require.NoError(t, err) + + tasks, err = tagScan.PlanFiles(context.Background()) + require.NoError(t, err) + + assert.Len(t, tasks, 1) + assert.Len(t, tasks[0].DeleteFiles, 0) + + _, err = tagScan.UseRef("without_5") + assert.ErrorIs(t, err, iceberg.ErrInvalidArgument) + + tagScan, err = scan.UseRef("without_5") + require.NoError(t, err) + + tasks, err = tagScan.PlanFiles(context.Background()) + require.NoError(t, err) + + assert.Len(t, tasks, 1) + assert.Len(t, tasks[0].DeleteFiles, 1) +} diff --git a/table/snapshots.go b/table/snapshots.go index c880d7d..b569ced 100644 --- a/table/snapshots.go +++ b/table/snapshots.go @@ -1,196 +1,196 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package table - -import ( - "encoding/json" - "errors" - "fmt" - "maps" - "strconv" - - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/io" -) - -type Operation string - -const ( - OpAppend Operation = "append" - OpReplace Operation = "replace" - OpOverwrite Operation = "overwrite" - OpDelete Operation = "delete" -) - -var ( - ErrInvalidOperation = errors.New("invalid operation value") - ErrMissingOperation = errors.New("missing operation key") -) - -// ValidOperation ensures that a given string is one of the valid operation -// types: append,replace,overwrite,delete -func ValidOperation(s string) (Operation, error) { - switch s { - case "append", "replace", "overwrite", "delete": - return Operation(s), nil - } - return "", fmt.Errorf("%w: found '%s'", ErrInvalidOperation, s) -} - -const operationKey = "operation" - -// Summary stores the summary information for a snapshot indicating -// the operation that created the snapshot, and various properties -// which might exist in the summary. 
-type Summary struct { - Operation Operation - Properties map[string]string -} - -func (s *Summary) String() string { - out := string(s.Operation) - if s.Properties != nil { - data, _ := json.Marshal(s.Properties) - out += ", " + string(data) - } - return out -} - -func (s *Summary) Equals(other *Summary) bool { - if s == other { - return true - } - - if s != nil && other == nil { - return false - } - - if s.Operation != other.Operation { - return false - } - - if len(s.Properties) == 0 && len(other.Properties) == 0 { - return true - } - - return maps.Equal(s.Properties, other.Properties) -} - -func (s *Summary) UnmarshalJSON(b []byte) (err error) { - alias := map[string]string{} - if err = json.Unmarshal(b, &alias); err != nil { - return - } - - op, ok := alias[operationKey] - if !ok { - return ErrMissingOperation - } - - if s.Operation, err = ValidOperation(op); err != nil { - return - } - - delete(alias, operationKey) - s.Properties = alias - return nil -} - -func (s *Summary) MarshalJSON() ([]byte, error) { - props := maps.Clone(s.Properties) - if s.Operation != "" { - if props == nil { - props = make(map[string]string) - } - props[operationKey] = string(s.Operation) - } - - return json.Marshal(props) -} - -type Snapshot struct { - SnapshotID int64 `json:"snapshot-id"` - ParentSnapshotID *int64 `json:"parent-snapshot-id,omitempty"` - SequenceNumber int64 `json:"sequence-number"` - TimestampMs int64 `json:"timestamp-ms"` - ManifestList string `json:"manifest-list,omitempty"` - Summary *Summary `json:"summary,omitempty"` - SchemaID *int `json:"schema-id,omitempty"` -} - -func (s Snapshot) String() string { - var ( - op, parent, schema string - ) - - if s.Summary != nil { - op = s.Summary.String() + ": " - } - if s.ParentSnapshotID != nil { - parent = ", parent_id=" + strconv.FormatInt(*s.ParentSnapshotID, 10) - } - if s.SchemaID != nil { - schema = ", schema_id=" + strconv.Itoa(*s.SchemaID) - } - return fmt.Sprintf("%sid=%d%s%s, sequence_number=%d, 
timestamp_ms=%d, manifest_list=%s", - op, s.SnapshotID, parent, schema, s.SequenceNumber, s.TimestampMs, s.ManifestList) -} - -func (s Snapshot) Equals(other Snapshot) bool { - switch { - case s.ParentSnapshotID == nil && other.ParentSnapshotID != nil: - fallthrough - case s.ParentSnapshotID != nil && other.ParentSnapshotID == nil: - fallthrough - case s.SchemaID == nil && other.SchemaID != nil: - fallthrough - case s.SchemaID != nil && other.SchemaID == nil: - return false - } - - return s.SnapshotID == other.SnapshotID && - ((s.ParentSnapshotID == other.ParentSnapshotID) || (*s.ParentSnapshotID == *other.ParentSnapshotID)) && - ((s.SchemaID == other.SchemaID) || (*s.SchemaID == *other.SchemaID)) && - s.SequenceNumber == other.SequenceNumber && - s.TimestampMs == other.TimestampMs && - s.ManifestList == other.ManifestList && - s.Summary.Equals(other.Summary) -} - -func (s Snapshot) Manifests(fio io.IO) ([]iceberg.ManifestFile, error) { - if s.ManifestList != "" { - f, err := fio.Open(s.ManifestList) - if err != nil { - return nil, fmt.Errorf("could not open manifest file: %w", err) - } - defer f.Close() - return iceberg.ReadManifestList(f) - } - - return nil, nil -} - -type MetadataLogEntry struct { - MetadataFile string `json:"metadata-file"` - TimestampMs int64 `json:"timestamp-ms"` -} - -type SnapshotLogEntry struct { - SnapshotID int64 `json:"snapshot-id"` - TimestampMs int64 `json:"timestamp-ms"` -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package table + +import ( + "encoding/json" + "errors" + "fmt" + "maps" + "strconv" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/io" +) + +type Operation string + +const ( + OpAppend Operation = "append" + OpReplace Operation = "replace" + OpOverwrite Operation = "overwrite" + OpDelete Operation = "delete" +) + +var ( + ErrInvalidOperation = errors.New("invalid operation value") + ErrMissingOperation = errors.New("missing operation key") +) + +// ValidOperation ensures that a given string is one of the valid operation +// types: append,replace,overwrite,delete +func ValidOperation(s string) (Operation, error) { + switch s { + case "append", "replace", "overwrite", "delete": + return Operation(s), nil + } + return "", fmt.Errorf("%w: found '%s'", ErrInvalidOperation, s) +} + +const operationKey = "operation" + +// Summary stores the summary information for a snapshot indicating +// the operation that created the snapshot, and various properties +// which might exist in the summary. 
+type Summary struct { + Operation Operation + Properties map[string]string +} + +func (s *Summary) String() string { + out := string(s.Operation) + if s.Properties != nil { + data, _ := json.Marshal(s.Properties) + out += ", " + string(data) + } + return out +} + +func (s *Summary) Equals(other *Summary) bool { + if s == other { + return true + } + + if s != nil && other == nil { + return false + } + + if s.Operation != other.Operation { + return false + } + + if len(s.Properties) == 0 && len(other.Properties) == 0 { + return true + } + + return maps.Equal(s.Properties, other.Properties) +} + +func (s *Summary) UnmarshalJSON(b []byte) (err error) { + alias := map[string]string{} + if err = json.Unmarshal(b, &alias); err != nil { + return + } + + op, ok := alias[operationKey] + if !ok { + return ErrMissingOperation + } + + if s.Operation, err = ValidOperation(op); err != nil { + return + } + + delete(alias, operationKey) + s.Properties = alias + return nil +} + +func (s *Summary) MarshalJSON() ([]byte, error) { + props := maps.Clone(s.Properties) + if s.Operation != "" { + if props == nil { + props = make(map[string]string) + } + props[operationKey] = string(s.Operation) + } + + return json.Marshal(props) +} + +type Snapshot struct { + SnapshotID int64 `json:"snapshot-id"` + ParentSnapshotID *int64 `json:"parent-snapshot-id,omitempty"` + SequenceNumber int64 `json:"sequence-number"` + TimestampMs int64 `json:"timestamp-ms"` + ManifestList string `json:"manifest-list,omitempty"` + Summary *Summary `json:"summary,omitempty"` + SchemaID *int `json:"schema-id,omitempty"` +} + +func (s Snapshot) String() string { + var ( + op, parent, schema string + ) + + if s.Summary != nil { + op = s.Summary.String() + ": " + } + if s.ParentSnapshotID != nil { + parent = ", parent_id=" + strconv.FormatInt(*s.ParentSnapshotID, 10) + } + if s.SchemaID != nil { + schema = ", schema_id=" + strconv.Itoa(*s.SchemaID) + } + return fmt.Sprintf("%sid=%d%s%s, sequence_number=%d, 
timestamp_ms=%d, manifest_list=%s", + op, s.SnapshotID, parent, schema, s.SequenceNumber, s.TimestampMs, s.ManifestList) +} + +func (s Snapshot) Equals(other Snapshot) bool { + switch { + case s.ParentSnapshotID == nil && other.ParentSnapshotID != nil: + fallthrough + case s.ParentSnapshotID != nil && other.ParentSnapshotID == nil: + fallthrough + case s.SchemaID == nil && other.SchemaID != nil: + fallthrough + case s.SchemaID != nil && other.SchemaID == nil: + return false + } + + return s.SnapshotID == other.SnapshotID && + ((s.ParentSnapshotID == other.ParentSnapshotID) || (*s.ParentSnapshotID == *other.ParentSnapshotID)) && + ((s.SchemaID == other.SchemaID) || (*s.SchemaID == *other.SchemaID)) && + s.SequenceNumber == other.SequenceNumber && + s.TimestampMs == other.TimestampMs && + s.ManifestList == other.ManifestList && + s.Summary.Equals(other.Summary) +} + +func (s Snapshot) Manifests(fio io.IO) ([]iceberg.ManifestFile, error) { + if s.ManifestList != "" { + f, err := fio.Open(s.ManifestList) + if err != nil { + return nil, fmt.Errorf("could not open manifest file: %w", err) + } + defer f.Close() + return iceberg.ReadManifestList(f) + } + + return nil, nil +} + +type MetadataLogEntry struct { + MetadataFile string `json:"metadata-file"` + TimestampMs int64 `json:"timestamp-ms"` +} + +type SnapshotLogEntry struct { + SnapshotID int64 `json:"snapshot-id"` + TimestampMs int64 `json:"timestamp-ms"` +} diff --git a/table/snapshots_test.go b/table/snapshots_test.go index 6a101a5..e647eff 100644 --- a/table/snapshots_test.go +++ b/table/snapshots_test.go @@ -1,115 +1,115 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package table_test - -import ( - "encoding/json" - "testing" - - "github.com/apache/iceberg-go/table" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func Snapshot() table.Snapshot { - parentID := int64(19) - manifest, schemaid := "s3:/a/b/c.avro", 3 - return table.Snapshot{ - SnapshotID: 25, - ParentSnapshotID: &parentID, - SequenceNumber: 200, - TimestampMs: 1602638573590, - ManifestList: manifest, - SchemaID: &schemaid, - Summary: &table.Summary{ - Operation: table.OpAppend, - }, - } -} - -func SnapshotWithProperties() table.Snapshot { - parentID := int64(19) - manifest, schemaid := "s3:/a/b/c.avro", 3 - return table.Snapshot{ - SnapshotID: 25, - ParentSnapshotID: &parentID, - SequenceNumber: 200, - TimestampMs: 1602638573590, - ManifestList: manifest, - SchemaID: &schemaid, - Summary: &table.Summary{ - Operation: table.OpAppend, - Properties: map[string]string{"foo": "bar"}, - }, - } -} - -func TestSerializeSnapshot(t *testing.T) { - snapshot := Snapshot() - data, err := json.Marshal(snapshot) - require.NoError(t, err) - - assert.JSONEq(t, `{ - "snapshot-id": 25, - "parent-snapshot-id": 19, - "sequence-number": 200, - "timestamp-ms": 1602638573590, - "manifest-list": "s3:/a/b/c.avro", - "summary": {"operation": "append"}, - "schema-id": 3 - }`, string(data)) -} - -func TestSerializeSnapshotWithProps(t *testing.T) { - snapshot := SnapshotWithProperties() - data, err := json.Marshal(snapshot) - require.NoError(t, err) - - assert.JSONEq(t, `{ - "snapshot-id": 25, - "parent-snapshot-id": 19, - 
"sequence-number": 200, - "timestamp-ms": 1602638573590, - "manifest-list": "s3:/a/b/c.avro", - "summary": {"operation": "append", "foo": "bar"}, - "schema-id": 3 - }`, string(data)) -} - -func TestMissingOperation(t *testing.T) { - var summary table.Summary - err := json.Unmarshal([]byte(`{"foo": "bar"}`), &summary) - assert.ErrorIs(t, err, table.ErrMissingOperation) -} - -func TestInvalidOperation(t *testing.T) { - var summary table.Summary - err := json.Unmarshal([]byte(`{"operation": "foobar"}`), &summary) - assert.ErrorIs(t, err, table.ErrInvalidOperation) - assert.ErrorContains(t, err, "found 'foobar'") -} - -func TestSnapshotString(t *testing.T) { - snapshot := Snapshot() - assert.Equal(t, `append: id=25, parent_id=19, schema_id=3, sequence_number=200, timestamp_ms=1602638573590, manifest_list=s3:/a/b/c.avro`, - snapshot.String()) - - snapshot = SnapshotWithProperties() - assert.Equal(t, `append, {"foo":"bar"}: id=25, parent_id=19, schema_id=3, sequence_number=200, timestamp_ms=1602638573590, manifest_list=s3:/a/b/c.avro`, - snapshot.String()) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package table_test + +import ( + "encoding/json" + "testing" + + "github.com/apache/iceberg-go/table" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Snapshot() table.Snapshot { + parentID := int64(19) + manifest, schemaid := "s3:/a/b/c.avro", 3 + return table.Snapshot{ + SnapshotID: 25, + ParentSnapshotID: &parentID, + SequenceNumber: 200, + TimestampMs: 1602638573590, + ManifestList: manifest, + SchemaID: &schemaid, + Summary: &table.Summary{ + Operation: table.OpAppend, + }, + } +} + +func SnapshotWithProperties() table.Snapshot { + parentID := int64(19) + manifest, schemaid := "s3:/a/b/c.avro", 3 + return table.Snapshot{ + SnapshotID: 25, + ParentSnapshotID: &parentID, + SequenceNumber: 200, + TimestampMs: 1602638573590, + ManifestList: manifest, + SchemaID: &schemaid, + Summary: &table.Summary{ + Operation: table.OpAppend, + Properties: map[string]string{"foo": "bar"}, + }, + } +} + +func TestSerializeSnapshot(t *testing.T) { + snapshot := Snapshot() + data, err := json.Marshal(snapshot) + require.NoError(t, err) + + assert.JSONEq(t, `{ + "snapshot-id": 25, + "parent-snapshot-id": 19, + "sequence-number": 200, + "timestamp-ms": 1602638573590, + "manifest-list": "s3:/a/b/c.avro", + "summary": {"operation": "append"}, + "schema-id": 3 + }`, string(data)) +} + +func TestSerializeSnapshotWithProps(t *testing.T) { + snapshot := SnapshotWithProperties() + data, err := json.Marshal(snapshot) + require.NoError(t, err) + + assert.JSONEq(t, `{ + "snapshot-id": 25, + "parent-snapshot-id": 19, + "sequence-number": 200, + "timestamp-ms": 1602638573590, + "manifest-list": "s3:/a/b/c.avro", + "summary": {"operation": "append", "foo": "bar"}, + "schema-id": 3 + }`, string(data)) +} + +func TestMissingOperation(t *testing.T) { + var summary table.Summary + err := json.Unmarshal([]byte(`{"foo": "bar"}`), &summary) + assert.ErrorIs(t, err, table.ErrMissingOperation) +} + +func TestInvalidOperation(t *testing.T) { + var summary 
table.Summary + err := json.Unmarshal([]byte(`{"operation": "foobar"}`), &summary) + assert.ErrorIs(t, err, table.ErrInvalidOperation) + assert.ErrorContains(t, err, "found 'foobar'") +} + +func TestSnapshotString(t *testing.T) { + snapshot := Snapshot() + assert.Equal(t, `append: id=25, parent_id=19, schema_id=3, sequence_number=200, timestamp_ms=1602638573590, manifest_list=s3:/a/b/c.avro`, + snapshot.String()) + + snapshot = SnapshotWithProperties() + assert.Equal(t, `append, {"foo":"bar"}: id=25, parent_id=19, schema_id=3, sequence_number=200, timestamp_ms=1602638573590, manifest_list=s3:/a/b/c.avro`, + snapshot.String()) +} diff --git a/table/sorting.go b/table/sorting.go index 425a92e..0878dc4 100644 --- a/table/sorting.go +++ b/table/sorting.go @@ -1,177 +1,177 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package table - -import ( - "encoding/json" - "errors" - "fmt" - "slices" - "strings" - - "github.com/apache/iceberg-go" -) - -type SortDirection string - -const ( - SortASC SortDirection = "asc" - SortDESC SortDirection = "desc" -) - -type NullOrder string - -const ( - NullsFirst NullOrder = "nulls-first" - NullsLast NullOrder = "nulls-last" -) - -var ( - ErrInvalidSortDirection = errors.New("invalid sort direction, must be 'asc' or 'desc'") - ErrInvalidNullOrder = errors.New("invalid null order, must be 'nulls-first' or 'nulls-last'") -) - -// SortField describes a field used in a sort order definition. -type SortField struct { - // SourceID is the source column id from the table's schema - SourceID int `json:"source-id"` - // Transform is the tranformation used to produce values to be - // sorted on from the source column. - Transform iceberg.Transform `json:"transform"` - // Direction is an enum indicating ascending or descending direction. - Direction SortDirection `json:"direction"` - // NullOrder describes the order of null values when sorting - // should be only either nulls-first or nulls-last enum values. 
- NullOrder NullOrder `json:"null-order"` -} - -func (s *SortField) String() string { - if _, ok := s.Transform.(iceberg.IdentityTransform); ok { - return fmt.Sprintf("%d %s %s", s.SourceID, s.Direction, s.NullOrder) - } - return fmt.Sprintf("%s(%d) %s %s", s.Transform, s.SourceID, s.Direction, s.NullOrder) -} - -func (s *SortField) MarshalJSON() ([]byte, error) { - if s.Direction == "" { - s.Direction = SortASC - } - - if s.NullOrder == "" { - if s.Direction == SortASC { - s.NullOrder = NullsFirst - } else { - s.NullOrder = NullsLast - } - } - - type Alias SortField - return json.Marshal((*Alias)(s)) -} - -func (s *SortField) UnmarshalJSON(b []byte) error { - type Alias SortField - var aux = struct { - TransformString string `json:"transform"` - *Alias - }{ - Alias: (*Alias)(s), - } - - err := json.Unmarshal(b, &aux) - if err != nil { - return err - } - - if s.Transform, err = iceberg.ParseTransform(aux.TransformString); err != nil { - return err - } - - switch s.Direction { - case SortASC, SortDESC: - default: - return ErrInvalidSortDirection - } - - switch s.NullOrder { - case NullsFirst, NullsLast: - default: - return ErrInvalidNullOrder - } - - return nil -} - -const ( - InitialSortOrderID = 1 - UnsortedSortOrderID = 0 -) - -// A default Sort Order indicating no sort order at all -var UnsortedSortOrder = SortOrder{OrderID: UnsortedSortOrderID, Fields: []SortField{}} - -// SortOrder describes how the data is sorted within the table. -// -// Data can be sorted within partitions by columns to gain performance. The -// order of the sort fields within the list defines the order in which the -// sort is applied to the data. 
-type SortOrder struct { - OrderID int `json:"order-id"` - Fields []SortField `json:"fields"` -} - -func (s SortOrder) Equals(rhs SortOrder) bool { - return s.OrderID == rhs.OrderID && - slices.Equal(s.Fields, rhs.Fields) -} - -func (s SortOrder) String() string { - var b strings.Builder - fmt.Fprintf(&b, "%d: ", s.OrderID) - b.WriteByte('[') - for i, f := range s.Fields { - if i == 0 { - b.WriteByte('\n') - } - b.WriteString(f.String()) - b.WriteByte('\n') - } - b.WriteByte(']') - return b.String() -} - -func (s *SortOrder) UnmarshalJSON(b []byte) error { - type Alias SortOrder - aux := (*Alias)(s) - - if err := json.Unmarshal(b, aux); err != nil { - return err - } - - if len(s.Fields) == 0 { - s.Fields = []SortField{} - s.OrderID = 0 - return nil - } - - if s.OrderID == 0 { - s.OrderID = InitialSortOrderID // initialize default sort order id - } - - return nil -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package table + +import ( + "encoding/json" + "errors" + "fmt" + "slices" + "strings" + + "github.com/apache/iceberg-go" +) + +type SortDirection string + +const ( + SortASC SortDirection = "asc" + SortDESC SortDirection = "desc" +) + +type NullOrder string + +const ( + NullsFirst NullOrder = "nulls-first" + NullsLast NullOrder = "nulls-last" +) + +var ( + ErrInvalidSortDirection = errors.New("invalid sort direction, must be 'asc' or 'desc'") + ErrInvalidNullOrder = errors.New("invalid null order, must be 'nulls-first' or 'nulls-last'") +) + +// SortField describes a field used in a sort order definition. +type SortField struct { + // SourceID is the source column id from the table's schema + SourceID int `json:"source-id"` + // Transform is the tranformation used to produce values to be + // sorted on from the source column. + Transform iceberg.Transform `json:"transform"` + // Direction is an enum indicating ascending or descending direction. + Direction SortDirection `json:"direction"` + // NullOrder describes the order of null values when sorting + // should be only either nulls-first or nulls-last enum values. 
+ NullOrder NullOrder `json:"null-order"` +} + +func (s *SortField) String() string { + if _, ok := s.Transform.(iceberg.IdentityTransform); ok { + return fmt.Sprintf("%d %s %s", s.SourceID, s.Direction, s.NullOrder) + } + return fmt.Sprintf("%s(%d) %s %s", s.Transform, s.SourceID, s.Direction, s.NullOrder) +} + +func (s *SortField) MarshalJSON() ([]byte, error) { + if s.Direction == "" { + s.Direction = SortASC + } + + if s.NullOrder == "" { + if s.Direction == SortASC { + s.NullOrder = NullsFirst + } else { + s.NullOrder = NullsLast + } + } + + type Alias SortField + return json.Marshal((*Alias)(s)) +} + +func (s *SortField) UnmarshalJSON(b []byte) error { + type Alias SortField + var aux = struct { + TransformString string `json:"transform"` + *Alias + }{ + Alias: (*Alias)(s), + } + + err := json.Unmarshal(b, &aux) + if err != nil { + return err + } + + if s.Transform, err = iceberg.ParseTransform(aux.TransformString); err != nil { + return err + } + + switch s.Direction { + case SortASC, SortDESC: + default: + return ErrInvalidSortDirection + } + + switch s.NullOrder { + case NullsFirst, NullsLast: + default: + return ErrInvalidNullOrder + } + + return nil +} + +const ( + InitialSortOrderID = 1 + UnsortedSortOrderID = 0 +) + +// A default Sort Order indicating no sort order at all +var UnsortedSortOrder = SortOrder{OrderID: UnsortedSortOrderID, Fields: []SortField{}} + +// SortOrder describes how the data is sorted within the table. +// +// Data can be sorted within partitions by columns to gain performance. The +// order of the sort fields within the list defines the order in which the +// sort is applied to the data. 
+type SortOrder struct { + OrderID int `json:"order-id"` + Fields []SortField `json:"fields"` +} + +func (s SortOrder) Equals(rhs SortOrder) bool { + return s.OrderID == rhs.OrderID && + slices.Equal(s.Fields, rhs.Fields) +} + +func (s SortOrder) String() string { + var b strings.Builder + fmt.Fprintf(&b, "%d: ", s.OrderID) + b.WriteByte('[') + for i, f := range s.Fields { + if i == 0 { + b.WriteByte('\n') + } + b.WriteString(f.String()) + b.WriteByte('\n') + } + b.WriteByte(']') + return b.String() +} + +func (s *SortOrder) UnmarshalJSON(b []byte) error { + type Alias SortOrder + aux := (*Alias)(s) + + if err := json.Unmarshal(b, aux); err != nil { + return err + } + + if len(s.Fields) == 0 { + s.Fields = []SortField{} + s.OrderID = 0 + return nil + } + + if s.OrderID == 0 { + s.OrderID = InitialSortOrderID // initialize default sort order id + } + + return nil +} diff --git a/table/sorting_test.go b/table/sorting_test.go index c12c8ff..d6c04fa 100644 --- a/table/sorting_test.go +++ b/table/sorting_test.go @@ -1,110 +1,110 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package table_test - -import ( - "encoding/json" - "testing" - - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/table" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var sortOrder = table.SortOrder{ - OrderID: 22, - Fields: []table.SortField{ - {SourceID: 19, Transform: iceberg.IdentityTransform{}, NullOrder: table.NullsFirst}, - {SourceID: 25, Transform: iceberg.BucketTransform{NumBuckets: 4}, Direction: table.SortDESC}, - {SourceID: 22, Transform: iceberg.VoidTransform{}, Direction: table.SortASC}, - }, -} - -func TestSerializeUnsortedSortOrder(t *testing.T) { - data, err := json.Marshal(table.UnsortedSortOrder) - require.NoError(t, err) - assert.JSONEq(t, `{"order-id": 0, "fields": []}`, string(data)) -} - -func TestSerializeSortOrder(t *testing.T) { - data, err := json.Marshal(sortOrder) - require.NoError(t, err) - assert.JSONEq(t, `{ - "order-id": 22, - "fields": [ - {"source-id": 19, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}, - {"source-id": 25, "transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"}, - {"source-id": 22, "transform": "void", "direction": "asc", "null-order": "nulls-first"} - ] - }`, string(data)) -} - -func TestUnmarshalSortOrderDefaults(t *testing.T) { - var order table.SortOrder - require.NoError(t, json.Unmarshal([]byte(`{"fields": []}`), &order)) - assert.Equal(t, table.UnsortedSortOrder, order) - - require.NoError(t, json.Unmarshal([]byte(`{"fields": [{"source-id": 19, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`), &order)) - assert.Equal(t, table.InitialSortOrderID, order.OrderID) -} - -func TestUnmarshalInvalidSortDirection(t *testing.T) { - badJson := `{ - "order-id": 22, - "fields": [ - {"source-id": 19, "transform": "identity", "direction": "foobar", "null-order": "nulls-first"}, - {"source-id": 25, "transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"}, - {"source-id": 22, 
"transform": "void", "direction": "asc", "null-order": "nulls-first"} - ] - }` - - var order table.SortOrder - err := json.Unmarshal([]byte(badJson), &order) - assert.ErrorIs(t, err, table.ErrInvalidSortDirection) -} - -func TestUnmarshalInvalidSortNullOrder(t *testing.T) { - badJson := `{ - "order-id": 22, - "fields": [ - {"source-id": 19, "transform": "identity", "direction": "asc", "null-order": "foobar"}, - {"source-id": 25, "transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"}, - {"source-id": 22, "transform": "void", "direction": "asc", "null-order": "nulls-first"} - ] - }` - - var order table.SortOrder - err := json.Unmarshal([]byte(badJson), &order) - assert.ErrorIs(t, err, table.ErrInvalidNullOrder) -} - -func TestUnmarshalInvalidSortTransform(t *testing.T) { - badJson := `{ - "order-id": 22, - "fields": [ - {"source-id": 19, "transform": "foobar", "direction": "asc", "null-order": "nulls-first"}, - {"source-id": 25, "transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"}, - {"source-id": 22, "transform": "void", "direction": "asc", "null-order": "nulls-first"} - ] - }` - - var order table.SortOrder - err := json.Unmarshal([]byte(badJson), &order) - assert.ErrorIs(t, err, iceberg.ErrInvalidTransform) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package table_test + +import ( + "encoding/json" + "testing" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/table" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var sortOrder = table.SortOrder{ + OrderID: 22, + Fields: []table.SortField{ + {SourceID: 19, Transform: iceberg.IdentityTransform{}, NullOrder: table.NullsFirst}, + {SourceID: 25, Transform: iceberg.BucketTransform{NumBuckets: 4}, Direction: table.SortDESC}, + {SourceID: 22, Transform: iceberg.VoidTransform{}, Direction: table.SortASC}, + }, +} + +func TestSerializeUnsortedSortOrder(t *testing.T) { + data, err := json.Marshal(table.UnsortedSortOrder) + require.NoError(t, err) + assert.JSONEq(t, `{"order-id": 0, "fields": []}`, string(data)) +} + +func TestSerializeSortOrder(t *testing.T) { + data, err := json.Marshal(sortOrder) + require.NoError(t, err) + assert.JSONEq(t, `{ + "order-id": 22, + "fields": [ + {"source-id": 19, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}, + {"source-id": 25, "transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"}, + {"source-id": 22, "transform": "void", "direction": "asc", "null-order": "nulls-first"} + ] + }`, string(data)) +} + +func TestUnmarshalSortOrderDefaults(t *testing.T) { + var order table.SortOrder + require.NoError(t, json.Unmarshal([]byte(`{"fields": []}`), &order)) + assert.Equal(t, table.UnsortedSortOrder, order) + + require.NoError(t, json.Unmarshal([]byte(`{"fields": [{"source-id": 19, "transform": "identity", "direction": "asc", "null-order": "nulls-first"}]}`), &order)) + assert.Equal(t, table.InitialSortOrderID, order.OrderID) +} + +func TestUnmarshalInvalidSortDirection(t *testing.T) { + badJson := `{ + "order-id": 22, + "fields": [ + {"source-id": 19, "transform": "identity", "direction": "foobar", "null-order": "nulls-first"}, + 
{"source-id": 25, "transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"}, + {"source-id": 22, "transform": "void", "direction": "asc", "null-order": "nulls-first"} + ] + }` + + var order table.SortOrder + err := json.Unmarshal([]byte(badJson), &order) + assert.ErrorIs(t, err, table.ErrInvalidSortDirection) +} + +func TestUnmarshalInvalidSortNullOrder(t *testing.T) { + badJson := `{ + "order-id": 22, + "fields": [ + {"source-id": 19, "transform": "identity", "direction": "asc", "null-order": "foobar"}, + {"source-id": 25, "transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"}, + {"source-id": 22, "transform": "void", "direction": "asc", "null-order": "nulls-first"} + ] + }` + + var order table.SortOrder + err := json.Unmarshal([]byte(badJson), &order) + assert.ErrorIs(t, err, table.ErrInvalidNullOrder) +} + +func TestUnmarshalInvalidSortTransform(t *testing.T) { + badJson := `{ + "order-id": 22, + "fields": [ + {"source-id": 19, "transform": "foobar", "direction": "asc", "null-order": "nulls-first"}, + {"source-id": 25, "transform": "bucket[4]", "direction": "desc", "null-order": "nulls-last"}, + {"source-id": 22, "transform": "void", "direction": "asc", "null-order": "nulls-first"} + ] + }` + + var order table.SortOrder + err := json.Unmarshal([]byte(badJson), &order) + assert.ErrorIs(t, err, iceberg.ErrInvalidTransform) +} diff --git a/table/table.go b/table/table.go index b4cf92c..4749dd7 100644 --- a/table/table.go +++ b/table/table.go @@ -1,112 +1,112 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package table - -import ( - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/io" - "golang.org/x/exp/slices" -) - -type Identifier = []string - -type Table struct { - identifier Identifier - metadata Metadata - metadataLocation string - fs io.IO -} - -func (t Table) Equals(other Table) bool { - return slices.Equal(t.identifier, other.identifier) && - t.metadataLocation == other.metadataLocation && - t.metadata.Equals(other.metadata) -} - -func (t Table) Identifier() Identifier { return t.identifier } -func (t Table) Metadata() Metadata { return t.metadata } -func (t Table) MetadataLocation() string { return t.metadataLocation } -func (t Table) FS() io.IO { return t.fs } - -func (t Table) Schema() *iceberg.Schema { return t.metadata.CurrentSchema() } -func (t Table) Spec() iceberg.PartitionSpec { return t.metadata.PartitionSpec() } -func (t Table) SortOrder() SortOrder { return t.metadata.SortOrder() } -func (t Table) Properties() iceberg.Properties { return t.metadata.Properties() } -func (t Table) Location() string { return t.metadata.Location() } -func (t Table) CurrentSnapshot() *Snapshot { return t.metadata.CurrentSnapshot() } -func (t Table) SnapshotByID(id int64) *Snapshot { return t.metadata.SnapshotByID(id) } -func (t Table) SnapshotByName(name string) *Snapshot { return t.metadata.SnapshotByName(name) } -func (t Table) Schemas() map[int]*iceberg.Schema { - m := make(map[int]*iceberg.Schema) - for _, s := range t.metadata.Schemas() { - m[s.ID] = s - } - return m -} - -func (t Table) Scan(rowFilter iceberg.BooleanExpression, 
snapshotID int64, caseSensitive bool, fields ...string) *Scan { - s := &Scan{ - metadata: t.metadata, - io: t.fs, - rowFilter: rowFilter, - selectedFields: fields, - caseSensitive: caseSensitive, - } - - if snapshotID != 0 { - s.snapshotID = &snapshotID - } - - s.partitionFilters = newKeyDefaultMapWrapErr(s.buildPartitionProjection) - return s -} - -func New(ident Identifier, meta Metadata, location string, fs io.IO) *Table { - return &Table{ - identifier: ident, - metadata: meta, - metadataLocation: location, - fs: fs, - } -} - -func NewFromLocation(ident Identifier, metalocation string, fsys io.IO) (*Table, error) { - var meta Metadata - - if rf, ok := fsys.(io.ReadFileIO); ok { - data, err := rf.ReadFile(metalocation) - if err != nil { - return nil, err - } - - if meta, err = ParseMetadataBytes(data); err != nil { - return nil, err - } - } else { - f, err := fsys.Open(metalocation) - if err != nil { - return nil, err - } - defer f.Close() - - if meta, err = ParseMetadata(f); err != nil { - return nil, err - } - } - return New(ident, meta, metalocation, fsys), nil -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package table + +import ( + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/io" + "golang.org/x/exp/slices" +) + +type Identifier = []string + +type Table struct { + identifier Identifier + metadata Metadata + metadataLocation string + fs io.IO +} + +func (t Table) Equals(other Table) bool { + return slices.Equal(t.identifier, other.identifier) && + t.metadataLocation == other.metadataLocation && + t.metadata.Equals(other.metadata) +} + +func (t Table) Identifier() Identifier { return t.identifier } +func (t Table) Metadata() Metadata { return t.metadata } +func (t Table) MetadataLocation() string { return t.metadataLocation } +func (t Table) FS() io.IO { return t.fs } + +func (t Table) Schema() *iceberg.Schema { return t.metadata.CurrentSchema() } +func (t Table) Spec() iceberg.PartitionSpec { return t.metadata.PartitionSpec() } +func (t Table) SortOrder() SortOrder { return t.metadata.SortOrder() } +func (t Table) Properties() iceberg.Properties { return t.metadata.Properties() } +func (t Table) Location() string { return t.metadata.Location() } +func (t Table) CurrentSnapshot() *Snapshot { return t.metadata.CurrentSnapshot() } +func (t Table) SnapshotByID(id int64) *Snapshot { return t.metadata.SnapshotByID(id) } +func (t Table) SnapshotByName(name string) *Snapshot { return t.metadata.SnapshotByName(name) } +func (t Table) Schemas() map[int]*iceberg.Schema { + m := make(map[int]*iceberg.Schema) + for _, s := range t.metadata.Schemas() { + m[s.ID] = s + } + return m +} + +func (t Table) Scan(rowFilter iceberg.BooleanExpression, snapshotID int64, caseSensitive bool, fields ...string) *Scan { + s := &Scan{ + metadata: t.metadata, + io: t.fs, + rowFilter: rowFilter, + selectedFields: fields, + caseSensitive: caseSensitive, + } + + if snapshotID != 0 { + s.snapshotID = &snapshotID + } + + s.partitionFilters = newKeyDefaultMapWrapErr(s.buildPartitionProjection) + return s +} + +func New(ident Identifier, meta Metadata, location string, fs io.IO) 
*Table { + return &Table{ + identifier: ident, + metadata: meta, + metadataLocation: location, + fs: fs, + } +} + +func NewFromLocation(ident Identifier, metalocation string, fsys io.IO) (*Table, error) { + var meta Metadata + + if rf, ok := fsys.(io.ReadFileIO); ok { + data, err := rf.ReadFile(metalocation) + if err != nil { + return nil, err + } + + if meta, err = ParseMetadataBytes(data); err != nil { + return nil, err + } + } else { + f, err := fsys.Open(metalocation) + if err != nil { + return nil, err + } + defer f.Close() + + if meta, err = ParseMetadata(f); err != nil { + return nil, err + } + } + return New(ident, meta, metalocation, fsys), nil +} diff --git a/table/table_test.go b/table/table_test.go index cde94ab..ec4236d 100644 --- a/table/table_test.go +++ b/table/table_test.go @@ -1,130 +1,130 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package table_test - -import ( - "bytes" - "testing" - - "github.com/apache/iceberg-go" - "github.com/apache/iceberg-go/internal" - "github.com/apache/iceberg-go/table" - "github.com/stretchr/testify/suite" -) - -type TableTestSuite struct { - suite.Suite - - tbl *table.Table -} - -func TestTable(t *testing.T) { - suite.Run(t, new(TableTestSuite)) -} - -func (t *TableTestSuite) SetupSuite() { - var mockfs internal.MockFS - mockfs.Test(t.T()) - mockfs.On("Open", "s3://bucket/test/location/uuid.metadata.json"). - Return(&internal.MockFile{Contents: bytes.NewReader([]byte(ExampleTableMetadataV2))}, nil) - defer mockfs.AssertExpectations(t.T()) - - tbl, err := table.NewFromLocation([]string{"foo"}, "s3://bucket/test/location/uuid.metadata.json", &mockfs) - t.Require().NoError(err) - t.Require().NotNil(tbl) - - t.Equal([]string{"foo"}, tbl.Identifier()) - t.Equal("s3://bucket/test/location/uuid.metadata.json", tbl.MetadataLocation()) - t.Equal(&mockfs, tbl.FS()) - - t.tbl = tbl -} - -func (t *TableTestSuite) TestNewTableFromReadFile() { - var mockfsReadFile internal.MockFSReadFile - mockfsReadFile.Test(t.T()) - mockfsReadFile.On("ReadFile", "s3://bucket/test/location/uuid.metadata.json"). 
- Return([]byte(ExampleTableMetadataV2), nil) - defer mockfsReadFile.AssertExpectations(t.T()) - - tbl2, err := table.NewFromLocation([]string{"foo"}, "s3://bucket/test/location/uuid.metadata.json", &mockfsReadFile) - t.Require().NoError(err) - t.Require().NotNil(tbl2) - - t.True(t.tbl.Equals(*tbl2)) -} - -func (t *TableTestSuite) TestSchema() { - t.True(t.tbl.Schema().Equals(iceberg.NewSchemaWithIdentifiers(1, []int{1, 2}, - iceberg.NestedField{ID: 1, Name: "x", Type: iceberg.PrimitiveTypes.Int64, Required: true}, - iceberg.NestedField{ID: 2, Name: "y", Type: iceberg.PrimitiveTypes.Int64, Required: true, Doc: "comment"}, - iceberg.NestedField{ID: 3, Name: "z", Type: iceberg.PrimitiveTypes.Int64, Required: true}, - ))) -} - -func (t *TableTestSuite) TestPartitionSpec() { - t.Equal(iceberg.NewPartitionSpec( - iceberg.PartitionField{SourceID: 1, FieldID: 1000, Transform: iceberg.IdentityTransform{}, Name: "x"}, - ), t.tbl.Spec()) -} - -func (t *TableTestSuite) TestSortOrder() { - t.Equal(table.SortOrder{ - OrderID: 3, - Fields: []table.SortField{ - {SourceID: 2, Transform: iceberg.IdentityTransform{}, Direction: table.SortASC, NullOrder: table.NullsFirst}, - {SourceID: 3, Transform: iceberg.BucketTransform{NumBuckets: 4}, Direction: table.SortDESC, NullOrder: table.NullsLast}, - }, - }, t.tbl.SortOrder()) -} - -func (t *TableTestSuite) TestLocation() { - t.Equal("s3://bucket/test/location", t.tbl.Location()) -} - -func (t *TableTestSuite) TestSnapshot() { - var ( - parentSnapshotID int64 = 3051729675574597004 - one = 1 - manifestList = "s3://a/b/2.avro" - ) - - testSnapshot := table.Snapshot{ - SnapshotID: 3055729675574597004, - ParentSnapshotID: &parentSnapshotID, - SequenceNumber: 1, - TimestampMs: 1555100955770, - ManifestList: manifestList, - Summary: &table.Summary{Operation: table.OpAppend, Properties: map[string]string{}}, - SchemaID: &one, - } - t.True(testSnapshot.Equals(*t.tbl.CurrentSnapshot())) - - 
t.True(testSnapshot.Equals(*t.tbl.SnapshotByID(3055729675574597004))) -} - -func (t *TableTestSuite) TestSnapshotByName() { - testSnapshot := table.Snapshot{ - SnapshotID: 3051729675574597004, - TimestampMs: 1515100955770, - ManifestList: "s3://a/b/1.avro", - Summary: &table.Summary{Operation: table.OpAppend}, - } - - t.True(testSnapshot.Equals(*t.tbl.SnapshotByName("test"))) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package table_test + +import ( + "bytes" + "testing" + + "github.com/apache/iceberg-go" + "github.com/apache/iceberg-go/internal" + "github.com/apache/iceberg-go/table" + "github.com/stretchr/testify/suite" +) + +type TableTestSuite struct { + suite.Suite + + tbl *table.Table +} + +func TestTable(t *testing.T) { + suite.Run(t, new(TableTestSuite)) +} + +func (t *TableTestSuite) SetupSuite() { + var mockfs internal.MockFS + mockfs.Test(t.T()) + mockfs.On("Open", "s3://bucket/test/location/uuid.metadata.json"). 
+ Return(&internal.MockFile{Contents: bytes.NewReader([]byte(ExampleTableMetadataV2))}, nil) + defer mockfs.AssertExpectations(t.T()) + + tbl, err := table.NewFromLocation([]string{"foo"}, "s3://bucket/test/location/uuid.metadata.json", &mockfs) + t.Require().NoError(err) + t.Require().NotNil(tbl) + + t.Equal([]string{"foo"}, tbl.Identifier()) + t.Equal("s3://bucket/test/location/uuid.metadata.json", tbl.MetadataLocation()) + t.Equal(&mockfs, tbl.FS()) + + t.tbl = tbl +} + +func (t *TableTestSuite) TestNewTableFromReadFile() { + var mockfsReadFile internal.MockFSReadFile + mockfsReadFile.Test(t.T()) + mockfsReadFile.On("ReadFile", "s3://bucket/test/location/uuid.metadata.json"). + Return([]byte(ExampleTableMetadataV2), nil) + defer mockfsReadFile.AssertExpectations(t.T()) + + tbl2, err := table.NewFromLocation([]string{"foo"}, "s3://bucket/test/location/uuid.metadata.json", &mockfsReadFile) + t.Require().NoError(err) + t.Require().NotNil(tbl2) + + t.True(t.tbl.Equals(*tbl2)) +} + +func (t *TableTestSuite) TestSchema() { + t.True(t.tbl.Schema().Equals(iceberg.NewSchemaWithIdentifiers(1, []int{1, 2}, + iceberg.NestedField{ID: 1, Name: "x", Type: iceberg.PrimitiveTypes.Int64, Required: true}, + iceberg.NestedField{ID: 2, Name: "y", Type: iceberg.PrimitiveTypes.Int64, Required: true, Doc: "comment"}, + iceberg.NestedField{ID: 3, Name: "z", Type: iceberg.PrimitiveTypes.Int64, Required: true}, + ))) +} + +func (t *TableTestSuite) TestPartitionSpec() { + t.Equal(iceberg.NewPartitionSpec( + iceberg.PartitionField{SourceID: 1, FieldID: 1000, Transform: iceberg.IdentityTransform{}, Name: "x"}, + ), t.tbl.Spec()) +} + +func (t *TableTestSuite) TestSortOrder() { + t.Equal(table.SortOrder{ + OrderID: 3, + Fields: []table.SortField{ + {SourceID: 2, Transform: iceberg.IdentityTransform{}, Direction: table.SortASC, NullOrder: table.NullsFirst}, + {SourceID: 3, Transform: iceberg.BucketTransform{NumBuckets: 4}, Direction: table.SortDESC, NullOrder: table.NullsLast}, + }, + }, 
t.tbl.SortOrder()) +} + +func (t *TableTestSuite) TestLocation() { + t.Equal("s3://bucket/test/location", t.tbl.Location()) +} + +func (t *TableTestSuite) TestSnapshot() { + var ( + parentSnapshotID int64 = 3051729675574597004 + one = 1 + manifestList = "s3://a/b/2.avro" + ) + + testSnapshot := table.Snapshot{ + SnapshotID: 3055729675574597004, + ParentSnapshotID: &parentSnapshotID, + SequenceNumber: 1, + TimestampMs: 1555100955770, + ManifestList: manifestList, + Summary: &table.Summary{Operation: table.OpAppend, Properties: map[string]string{}}, + SchemaID: &one, + } + t.True(testSnapshot.Equals(*t.tbl.CurrentSnapshot())) + + t.True(testSnapshot.Equals(*t.tbl.SnapshotByID(3055729675574597004))) +} + +func (t *TableTestSuite) TestSnapshotByName() { + testSnapshot := table.Snapshot{ + SnapshotID: 3051729675574597004, + TimestampMs: 1515100955770, + ManifestList: "s3://a/b/1.avro", + Summary: &table.Summary{Operation: table.OpAppend}, + } + + t.True(testSnapshot.Equals(*t.tbl.SnapshotByName("test"))) +} diff --git a/transforms.go b/transforms.go index 477ef18..cef5b8b 100644 --- a/transforms.go +++ b/transforms.go @@ -1,871 +1,871 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package iceberg - -import ( - "encoding" - "encoding/binary" - "fmt" - "math" - "math/big" - "strconv" - "strings" - "time" - "unsafe" - - "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/google/uuid" - "github.com/twmb/murmur3" -) - -// ParseTransform takes the string representation of a transform as -// defined in the iceberg spec, and produces the appropriate Transform -// object or an error if the string is not a valid transform string. -func ParseTransform(s string) (Transform, error) { - s = strings.ToLower(s) - switch { - case strings.HasPrefix(s, "bucket"): - matches := regexFromBrackets.FindStringSubmatch(s) - if len(matches) != 2 { - break - } - - n, _ := strconv.Atoi(matches[1]) - return BucketTransform{NumBuckets: n}, nil - case strings.HasPrefix(s, "truncate"): - matches := regexFromBrackets.FindStringSubmatch(s) - if len(matches) != 2 { - break - } - - n, _ := strconv.Atoi(matches[1]) - return TruncateTransform{Width: n}, nil - default: - switch s { - case "identity": - return IdentityTransform{}, nil - case "void": - return VoidTransform{}, nil - case "year": - return YearTransform{}, nil - case "month": - return MonthTransform{}, nil - case "day": - return DayTransform{}, nil - case "hour": - return HourTransform{}, nil - } - } - - return nil, fmt.Errorf("%w: %s", ErrInvalidTransform, s) -} - -// Transform is an interface for the various Transformation types -// in partition specs. Currently, they do not yet provide actual -// transformation functions or implementation. That will come later as -// data reading gets implemented. -type Transform interface { - fmt.Stringer - encoding.TextMarshaler - ResultType(t Type) Type - Equals(Transform) bool - Apply(Optional[Literal]) Optional[Literal] - Project(name string, pred BoundPredicate) (UnboundPredicate, error) -} - -// IdentityTransform uses the identity function, performing no transformation -// but instead partitioning on the value itself. 
-type IdentityTransform struct{} - -func (t IdentityTransform) MarshalText() ([]byte, error) { - return []byte(t.String()), nil -} - -func (IdentityTransform) String() string { return "identity" } - -func (IdentityTransform) ResultType(t Type) Type { return t } - -func (IdentityTransform) Equals(other Transform) bool { - _, ok := other.(IdentityTransform) - return ok -} - -func (IdentityTransform) Apply(value Optional[Literal]) Optional[Literal] { - return value -} - -func (t IdentityTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { - if _, ok := pred.Term().(*BoundTransform); ok { - return projectTransformPredicate(t, name, pred) - } - - switch p := pred.(type) { - case BoundUnaryPredicate: - return p.AsUnbound(Reference(name)), nil - case BoundLiteralPredicate: - return p.AsUnbound(Reference(name), p.Literal()), nil - case BoundSetPredicate: - return p.AsUnbound(Reference(name), p.Literals().Members()), nil - } - - return nil, nil -} - -// VoidTransform is a transformation that always returns nil. -type VoidTransform struct{} - -func (t VoidTransform) MarshalText() ([]byte, error) { - return []byte(t.String()), nil -} - -func (VoidTransform) String() string { return "void" } - -func (VoidTransform) ResultType(t Type) Type { return t } - -func (VoidTransform) Equals(other Transform) bool { - _, ok := other.(VoidTransform) - return ok -} - -func (VoidTransform) Apply(value Optional[Literal]) Optional[Literal] { - return Optional[Literal]{} -} - -func (VoidTransform) Project(string, BoundPredicate) (UnboundPredicate, error) { - return nil, nil -} - -// BucketTransform transforms values into a bucket partition value. It is -// parameterized by a number of buckets. Bucket partition transforms use -// a 32-bit hash of the source value to produce a positive value by mod -// the bucket number. 
-type BucketTransform struct { - NumBuckets int -} - -func (t BucketTransform) MarshalText() ([]byte, error) { - return []byte(t.String()), nil -} - -func (t BucketTransform) String() string { return fmt.Sprintf("bucket[%d]", t.NumBuckets) } - -func (BucketTransform) ResultType(Type) Type { return PrimitiveTypes.Int32 } - -func hashHelperInt[T ~int32 | ~int64](v any) uint32 { - var ( - val = uint64(v.(T)) - buf [8]byte - b = buf[:] - ) - - binary.LittleEndian.PutUint64(b, val) - return murmur3.Sum32(b) -} - -func (t BucketTransform) Equals(other Transform) bool { - rhs, ok := other.(BucketTransform) - if !ok { - return false - } - - return t.NumBuckets == rhs.NumBuckets -} - -func (t BucketTransform) Apply(value Optional[Literal]) Optional[Literal] { - if !value.Valid { - return Optional[Literal]{} - } - - var hash uint32 - switch v := value.Val.(type) { - case TypedLiteral[[]byte]: - hash = murmur3.Sum32(v.Value()) - case StringLiteral: - hash = murmur3.Sum32(unsafe.Slice(unsafe.StringData(string(v)), len(v))) - case UUIDLiteral: - hash = murmur3.Sum32(v[:]) - case DecimalLiteral: - b, _ := v.MarshalBinary() - hash = murmur3.Sum32(b) - case Int32Literal: - hash = hashHelperInt[int64](int64(v)) - case Int64Literal: - hash = hashHelperInt[int64](int64(v)) - case DateLiteral: - hash = hashHelperInt[int64](int64(v)) - case TimeLiteral: - hash = hashHelperInt[int64](int64(v)) - case TimestampLiteral: - hash = hashHelperInt[int64](int64(v)) - default: - return Optional[Literal]{} - } - - return Optional[Literal]{ - Valid: true, - Val: Int32Literal((int32(hash) & math.MaxInt32) % int32(t.NumBuckets))} -} - -func (t BucketTransform) Transformer(src Type) func(any) Optional[int32] { - var h func(any) uint32 - - switch src.(type) { - case Int32Type: - h = hashHelperInt[int32] - case DateType: - h = hashHelperInt[Date] - case Int64Type: - h = hashHelperInt[int64] - case TimeType: - h = hashHelperInt[Time] - case TimestampType: - h = hashHelperInt[Timestamp] - case 
TimestampTzType: - h = hashHelperInt[Timestamp] - case DecimalType: - h = func(v any) uint32 { - b, _ := DecimalLiteral(v.(Decimal)).MarshalBinary() - return murmur3.Sum32(b) - } - case StringType, FixedType, BinaryType: - h = func(v any) uint32 { - if v, ok := v.([]byte); ok { - return murmur3.Sum32(v) - } - - str := v.(string) - return murmur3.Sum32(unsafe.Slice(unsafe.StringData(str), len(str))) - } - case UUIDType: - h = func(v any) uint32 { - if v, ok := v.([]byte); ok { - return murmur3.Sum32(v) - } - - u := v.(uuid.UUID) - return murmur3.Sum32(u[:]) - } - } - - return func(v any) Optional[int32] { - if v == nil { - return Optional[int32]{} - } - - return Optional[int32]{ - Valid: true, - Val: int32((int32(h(v)) & math.MaxInt32) % int32(t.NumBuckets))} - } -} - -func (t BucketTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { - if _, ok := pred.Term().(*BoundTransform); ok { - return projectTransformPredicate(t, name, pred) - } - - transformer := t.Transformer(pred.Term().Type()) - switch p := pred.(type) { - case BoundUnaryPredicate: - return p.AsUnbound(Reference(name)), nil - case BoundLiteralPredicate: - if p.Op() != OpEQ { - break - } - return p.AsUnbound(Reference(name), transformLiteral(transformer, p.Literal())), nil - case BoundSetPredicate: - if p.Op() != OpIn { - break - } - - return setApplyTransform(name, p, transformer), nil - } - - return nil, nil -} - -// TruncateTransform is a transformation for truncating a value to a specified width. 
-type TruncateTransform struct { - Width int -} - -func (t TruncateTransform) MarshalText() ([]byte, error) { - return []byte(t.String()), nil -} - -func (t TruncateTransform) String() string { return fmt.Sprintf("truncate[%d]", t.Width) } - -func (TruncateTransform) ResultType(t Type) Type { return t } - -func (t TruncateTransform) Equals(other Transform) bool { - rhs, ok := other.(TruncateTransform) - if !ok { - return false - } - - return t.Width == rhs.Width -} - -func (t TruncateTransform) Transformer(src Type) (func(any) any, error) { - switch src.(type) { - case Int32Type: - return func(v any) any { - if v == nil { - return nil - } - - val := v.(int32) - return val - (val % int32(t.Width)) - }, nil - case Int64Type: - return func(v any) any { - if v == nil { - return nil - } - - val := v.(int64) - return val - (val % int64(t.Width)) - }, nil - case StringType, BinaryType: - return func(v any) any { - switch v := v.(type) { - case string: - return v[:min(len(v), t.Width)] - case []byte: - return v[:min(len(v), t.Width)] - default: - return nil - } - }, nil - case DecimalType: - bigWidth := big.NewInt(int64(t.Width)) - return func(v any) any { - if v == nil { - return nil - } - - val := v.(Decimal) - unscaled := val.Val.BigInt() - // unscaled - (((unscaled % width) + width) % width) - applied := (&big.Int{}).Mod(unscaled, bigWidth) - applied.Add(applied, bigWidth).Mod(applied, bigWidth) - val.Val = decimal128.FromBigInt(unscaled.Sub(unscaled, applied)) - return val - }, nil - } - - return nil, fmt.Errorf("%w: cannot truncate for type %s", - ErrInvalidArgument, src) -} - -func (t TruncateTransform) Apply(value Optional[Literal]) (out Optional[Literal]) { - if !value.Valid { - return - } - - fn, err := t.Transformer(value.Val.Type()) - if err != nil { - return - } - - out.Valid = true - switch v := value.Val.(type) { - case Int32Literal: - out.Val = Int32Literal(fn(int32(v)).(int32)) - case Int64Literal: - out.Val = Int64Literal(fn(int64(v)).(int64)) - case 
DecimalLiteral: - out.Val = DecimalLiteral(fn(Decimal(v)).(Decimal)) - case StringLiteral: - out.Val = StringLiteral(fn(string(v)).(string)) - case BinaryLiteral: - out.Val = BinaryLiteral(fn([]byte(v)).([]byte)) - } - - return -} - -func (t TruncateTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { - if _, ok := pred.Term().(*BoundTransform); ok { - return projectTransformPredicate(t, name, pred) - } - - fieldType := pred.Term().Ref().Field().Type - - transformer, err := t.Transformer(fieldType) - if err != nil { - return nil, err - } - - switch p := pred.(type) { - case BoundUnaryPredicate: - return p.AsUnbound(Reference(name)), nil - case BoundSetPredicate: - if p.Op() != OpIn { - break - } - - switch fieldType.(type) { - case Int32Type: - return setApplyTransform(name, p, wrapTransformFn[int32](transformer)), nil - case Int64Type: - return setApplyTransform(name, p, wrapTransformFn[int64](transformer)), nil - case DecimalType: - return setApplyTransform(name, p, wrapTransformFn[Decimal](transformer)), nil - case StringType: - return setApplyTransform(name, p, wrapTransformFn[string](transformer)), nil - case BinaryType: - return setApplyTransform(name, p, wrapTransformFn[[]byte](transformer)), nil - } - case BoundLiteralPredicate: - switch fieldType.(type) { - case Int32Type: - return truncateNumber(name, p, wrapTransformFn[int32](transformer)) - case Int64Type: - return truncateNumber(name, p, wrapTransformFn[int64](transformer)) - case DecimalType: - return truncateNumber(name, p, wrapTransformFn[Decimal](transformer)) - case StringType: - return truncateArray(name, p, wrapTransformFn[string](transformer)) - case BinaryType: - return truncateArray(name, p, wrapTransformFn[[]byte](transformer)) - } - } - - return nil, nil -} - -var epochTM = time.Unix(0, 0).UTC() - -type timeTransform interface { - Transform - Transformer(Type) (func(any) Optional[int32], error) -} - -func projectTimeTransform(t timeTransform, name string, pred 
BoundPredicate) (UnboundPredicate, error) { - if _, ok := pred.Term().(*BoundTransform); ok { - return projectTransformPredicate(t, name, pred) - } - - transformer, err := t.Transformer(pred.Term().Ref().Type()) - if err != nil { - return nil, err - } - - switch p := pred.(type) { - case BoundUnaryPredicate: - return p.AsUnbound(Reference(name)), nil - case BoundLiteralPredicate: - return truncateNumber(name, p, transformer) - case BoundSetPredicate: - if p.Op() != OpIn { - break - } - - return setApplyTransform(name, p, transformer), nil - } - - return nil, nil -} - -// YearTransform transforms a datetime value into a year value. -type YearTransform struct{} - -func (t YearTransform) MarshalText() ([]byte, error) { - return []byte(t.String()), nil -} - -func (YearTransform) String() string { return "year" } - -func (YearTransform) ResultType(Type) Type { return PrimitiveTypes.Int32 } - -func (YearTransform) Equals(other Transform) bool { - _, ok := other.(YearTransform) - return ok -} - -func (YearTransform) Transformer(src Type) (func(any) Optional[int32], error) { - switch src.(type) { - case DateType: - return func(v any) Optional[int32] { - if v == nil { - return Optional[int32]{} - } - - return Optional[int32]{ - Valid: true, - Val: int32(v.(Date).ToTime().Year() - epochTM.Year()), - } - }, nil - case TimestampType, TimestampTzType: - return func(v any) Optional[int32] { - if v == nil { - return Optional[int32]{} - } - - return Optional[int32]{ - Valid: true, - Val: int32(v.(Timestamp).ToTime().Year() - epochTM.Year()), - } - }, nil - } - - return nil, fmt.Errorf("%w: cannot apply year transform for type %s", - ErrInvalidArgument, src) -} - -func (YearTransform) Apply(value Optional[Literal]) (out Optional[Literal]) { - if !value.Valid { - return - } - - switch v := value.Val.(type) { - case DateLiteral: - out.Valid = true - out.Val = Int32Literal(Date(v).ToTime().Year() - epochTM.Year()) - case TimestampLiteral: - out.Valid = true - out.Val = 
Int32Literal(Timestamp(v).ToTime().Year() - epochTM.Year()) - } - - return -} - -func (t YearTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { - return projectTimeTransform(t, name, pred) -} - -// MonthTransform transforms a datetime value into a month value. -type MonthTransform struct{} - -func (t MonthTransform) MarshalText() ([]byte, error) { - return []byte(t.String()), nil -} - -func (MonthTransform) String() string { return "month" } - -func (MonthTransform) ResultType(Type) Type { return PrimitiveTypes.Int32 } - -func (MonthTransform) Equals(other Transform) bool { - _, ok := other.(MonthTransform) - return ok -} - -func (MonthTransform) Transformer(src Type) (func(any) Optional[int32], error) { - switch src.(type) { - case DateType: - return func(v any) Optional[int32] { - if v == nil { - return Optional[int32]{} - } - - d := v.(Date).ToTime() - return Optional[int32]{ - Valid: true, - Val: int32((d.Year()-epochTM.Year())*12 + (int(d.Month()) - int(epochTM.Month()))), - } - - }, nil - case TimestampType, TimestampTzType: - return func(v any) Optional[int32] { - if v == nil { - return Optional[int32]{} - } - - d := v.(Timestamp).ToTime() - return Optional[int32]{ - Valid: true, - Val: int32((d.Year()-epochTM.Year())*12 + (int(d.Month()) - int(epochTM.Month()))), - } - - }, nil - - } - - return nil, fmt.Errorf("%w: cannot apply month transform for type %s", - ErrInvalidArgument, src) -} - -func (MonthTransform) Apply(value Optional[Literal]) (out Optional[Literal]) { - if !value.Valid { - return - } - - var tm time.Time - switch v := value.Val.(type) { - case DateLiteral: - tm = Date(v).ToTime() - case TimestampLiteral: - tm = Timestamp(v).ToTime() - default: - return - } - - out.Valid = true - out.Val = Int32Literal(int32((tm.Year()-epochTM.Year())*12 + (int(tm.Month()) - int(epochTM.Month())))) - return -} - -func (t MonthTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { - return 
projectTimeTransform(t, name, pred) -} - -// DayTransform transforms a datetime value into a date value. -type DayTransform struct{} - -func (t DayTransform) MarshalText() ([]byte, error) { - return []byte(t.String()), nil -} - -func (DayTransform) String() string { return "day" } - -func (DayTransform) ResultType(Type) Type { return PrimitiveTypes.Date } - -func (DayTransform) Equals(other Transform) bool { - _, ok := other.(DayTransform) - return ok -} - -func (DayTransform) Transformer(src Type) (func(any) Optional[int32], error) { - switch src.(type) { - case DateType: - return func(v any) Optional[int32] { - if v == nil { - return Optional[int32]{} - } - - return Optional[int32]{ - Valid: true, - Val: int32(v.(Date)), - } - }, nil - case TimestampType, TimestampTzType: - return func(v any) Optional[int32] { - if v == nil { - return Optional[int32]{} - } - - return Optional[int32]{ - Valid: true, - Val: int32(v.(Timestamp).ToDate()), - } - }, nil - } - - return nil, fmt.Errorf("%w: cannot apply day transform for type %s", - ErrInvalidArgument, src) -} - -func (DayTransform) Apply(value Optional[Literal]) (out Optional[Literal]) { - if !value.Valid { - return - } - - switch v := value.Val.(type) { - case DateLiteral: - out.Valid, out.Val = true, Int32Literal(v) - case TimestampLiteral: - out.Valid, out.Val = true, Int32Literal(Timestamp(v).ToDate()) - } - return -} - -func (t DayTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { - return projectTimeTransform(t, name, pred) -} - -// HourTransform transforms a datetime value into an hour value. 
-type HourTransform struct{} - -func (t HourTransform) MarshalText() ([]byte, error) { - return []byte(t.String()), nil -} - -func (HourTransform) String() string { return "hour" } - -func (HourTransform) ResultType(Type) Type { return PrimitiveTypes.Int32 } - -func (HourTransform) Equals(other Transform) bool { - _, ok := other.(HourTransform) - return ok -} - -func (HourTransform) Transformer(src Type) (func(any) Optional[int32], error) { - switch src.(type) { - case TimestampType, TimestampTzType: - const factor = int64(time.Hour / time.Microsecond) - return func(v any) Optional[int32] { - if v == nil { - return Optional[int32]{} - } - - return Optional[int32]{ - Valid: true, - Val: int32(int64(v.(Timestamp)) / factor), - } - }, nil - } - - return nil, fmt.Errorf("%w: cannot apply hour transform for type %s", - ErrInvalidArgument, src) -} - -func (HourTransform) Apply(value Optional[Literal]) (out Optional[Literal]) { - if !value.Valid { - return - } - - switch v := value.Val.(type) { - case TimestampLiteral: - const factor = int64(time.Hour / time.Microsecond) - out.Valid, out.Val = true, Int32Literal(int32(int64(v)/factor)) - } - - return -} - -func (t HourTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { - return projectTimeTransform(t, name, pred) -} - -func removeTransform(partName string, pred BoundPredicate) (UnboundPredicate, error) { - switch p := pred.(type) { - case BoundUnaryPredicate: - return p.AsUnbound(Reference(partName)), nil - case BoundLiteralPredicate: - return p.AsUnbound(Reference(partName), p.Literal()), nil - case BoundSetPredicate: - return p.AsUnbound(Reference(partName), p.Literals().Members()), nil - } - - return nil, fmt.Errorf("%w: cannot replace transform in unknown predicate: %s", - ErrInvalidArgument, pred) -} - -func projectTransformPredicate(t Transform, partitionName string, pred BoundPredicate) (UnboundPredicate, error) { - term := pred.Term() - bt, ok := term.(*BoundTransform) - if !ok || 
!t.Equals(bt.transform) { - return nil, nil - } - - return removeTransform(partitionName, pred) -} - -func transformLiteral[T LiteralType](fn func(any) Optional[T], lit Literal) Literal { - switch l := lit.(type) { - case BoolLiteral: - return NewLiteral(fn(bool(l)).Val) - case Int32Literal: - return NewLiteral(fn(int32(l)).Val) - case Int64Literal: - return NewLiteral(fn(int64(l)).Val) - case Float32Literal: - return NewLiteral(fn(float32(l)).Val) - case Float64Literal: - return NewLiteral(fn(float64(l)).Val) - case DateLiteral: - return NewLiteral(fn(Date(l)).Val) - case TimeLiteral: - return NewLiteral(fn(Time(l)).Val) - case TimestampLiteral: - return NewLiteral(fn(Timestamp(l)).Val) - case StringLiteral: - return NewLiteral(fn(string(l)).Val) - case FixedLiteral: - return NewLiteral(fn([]byte(l)).Val) - case BinaryLiteral: - return NewLiteral(fn([]byte(l)).Val) - case UUIDLiteral: - return NewLiteral(fn(uuid.UUID(l)).Val) - case DecimalLiteral: - return NewLiteral(fn(Decimal(l)).Val) - } - - panic("invalid literal type") -} - -func wrapTransformFn[T LiteralType](fn func(any) any) func(any) Optional[T] { - return func(v any) Optional[T] { - out := fn(v) - if out == nil { - return Optional[T]{} - } - return Optional[T]{Valid: true, Val: out.(T)} - } -} - -func truncateNumber[T LiteralType](name string, pred BoundLiteralPredicate, fn func(any) Optional[T]) (UnboundPredicate, error) { - boundary, ok := pred.Literal().(NumericLiteral) - if !ok { - return nil, fmt.Errorf("%w: expected numeric literal, got %s", - ErrInvalidArgument, boundary.Type()) - } - - switch pred.Op() { - case OpLT: - return LiteralPredicate(OpLTEQ, Reference(name), - transformLiteral(fn, boundary.Decrement())), nil - case OpLTEQ: - return LiteralPredicate(OpLTEQ, Reference(name), - transformLiteral(fn, boundary)), nil - case OpGT: - return LiteralPredicate(OpGTEQ, Reference(name), - transformLiteral(fn, boundary.Increment())), nil - case OpGTEQ: - return LiteralPredicate(OpGTEQ, 
Reference(name), - transformLiteral(fn, boundary)), nil - case OpEQ: - return LiteralPredicate(OpEQ, Reference(name), - transformLiteral(fn, boundary)), nil - } - - return nil, nil -} - -func truncateArray[T LiteralType](name string, pred BoundLiteralPredicate, fn func(any) Optional[T]) (UnboundPredicate, error) { - boundary := pred.Literal() - - switch pred.Op() { - case OpLT, OpLTEQ: - return LiteralPredicate(OpLTEQ, Reference(name), - transformLiteral(fn, boundary)), nil - case OpGT, OpGTEQ: - return LiteralPredicate(OpGTEQ, Reference(name), - transformLiteral(fn, boundary)), nil - case OpEQ: - return LiteralPredicate(OpEQ, Reference(name), - transformLiteral(fn, boundary)), nil - case OpStartsWith: - return LiteralPredicate(OpStartsWith, Reference(name), - transformLiteral(fn, boundary)), nil - case OpNotStartsWith: - return LiteralPredicate(OpNotStartsWith, Reference(name), - transformLiteral(fn, boundary)), nil - } - - return nil, nil -} - -func setApplyTransform[T LiteralType](name string, pred BoundSetPredicate, fn func(any) Optional[T]) UnboundPredicate { - lits := pred.Literals().Members() - for i, l := range lits { - lits[i] = transformLiteral(fn, l) - } - - return pred.AsUnbound(Reference(name), lits) -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package iceberg + +import ( + "encoding" + "encoding/binary" + "fmt" + "math" + "math/big" + "strconv" + "strings" + "time" + "unsafe" + + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/google/uuid" + "github.com/twmb/murmur3" +) + +// ParseTransform takes the string representation of a transform as +// defined in the iceberg spec, and produces the appropriate Transform +// object or an error if the string is not a valid transform string. +func ParseTransform(s string) (Transform, error) { + s = strings.ToLower(s) + switch { + case strings.HasPrefix(s, "bucket"): + matches := regexFromBrackets.FindStringSubmatch(s) + if len(matches) != 2 { + break + } + + n, _ := strconv.Atoi(matches[1]) + return BucketTransform{NumBuckets: n}, nil + case strings.HasPrefix(s, "truncate"): + matches := regexFromBrackets.FindStringSubmatch(s) + if len(matches) != 2 { + break + } + + n, _ := strconv.Atoi(matches[1]) + return TruncateTransform{Width: n}, nil + default: + switch s { + case "identity": + return IdentityTransform{}, nil + case "void": + return VoidTransform{}, nil + case "year": + return YearTransform{}, nil + case "month": + return MonthTransform{}, nil + case "day": + return DayTransform{}, nil + case "hour": + return HourTransform{}, nil + } + } + + return nil, fmt.Errorf("%w: %s", ErrInvalidTransform, s) +} + +// Transform is an interface for the various Transformation types +// in partition specs. Currently, they do not yet provide actual +// transformation functions or implementation. That will come later as +// data reading gets implemented. 
+type Transform interface { + fmt.Stringer + encoding.TextMarshaler + ResultType(t Type) Type + Equals(Transform) bool + Apply(Optional[Literal]) Optional[Literal] + Project(name string, pred BoundPredicate) (UnboundPredicate, error) +} + +// IdentityTransform uses the identity function, performing no transformation +// but instead partitioning on the value itself. +type IdentityTransform struct{} + +func (t IdentityTransform) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +func (IdentityTransform) String() string { return "identity" } + +func (IdentityTransform) ResultType(t Type) Type { return t } + +func (IdentityTransform) Equals(other Transform) bool { + _, ok := other.(IdentityTransform) + return ok +} + +func (IdentityTransform) Apply(value Optional[Literal]) Optional[Literal] { + return value +} + +func (t IdentityTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { + if _, ok := pred.Term().(*BoundTransform); ok { + return projectTransformPredicate(t, name, pred) + } + + switch p := pred.(type) { + case BoundUnaryPredicate: + return p.AsUnbound(Reference(name)), nil + case BoundLiteralPredicate: + return p.AsUnbound(Reference(name), p.Literal()), nil + case BoundSetPredicate: + return p.AsUnbound(Reference(name), p.Literals().Members()), nil + } + + return nil, nil +} + +// VoidTransform is a transformation that always returns nil. 
+type VoidTransform struct{} + +func (t VoidTransform) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +func (VoidTransform) String() string { return "void" } + +func (VoidTransform) ResultType(t Type) Type { return t } + +func (VoidTransform) Equals(other Transform) bool { + _, ok := other.(VoidTransform) + return ok +} + +func (VoidTransform) Apply(value Optional[Literal]) Optional[Literal] { + return Optional[Literal]{} +} + +func (VoidTransform) Project(string, BoundPredicate) (UnboundPredicate, error) { + return nil, nil +} + +// BucketTransform transforms values into a bucket partition value. It is +// parameterized by a number of buckets. Bucket partition transforms use +// a 32-bit hash of the source value to produce a positive value by mod +// the bucket number. +type BucketTransform struct { + NumBuckets int +} + +func (t BucketTransform) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +func (t BucketTransform) String() string { return fmt.Sprintf("bucket[%d]", t.NumBuckets) } + +func (BucketTransform) ResultType(Type) Type { return PrimitiveTypes.Int32 } + +func hashHelperInt[T ~int32 | ~int64](v any) uint32 { + var ( + val = uint64(v.(T)) + buf [8]byte + b = buf[:] + ) + + binary.LittleEndian.PutUint64(b, val) + return murmur3.Sum32(b) +} + +func (t BucketTransform) Equals(other Transform) bool { + rhs, ok := other.(BucketTransform) + if !ok { + return false + } + + return t.NumBuckets == rhs.NumBuckets +} + +func (t BucketTransform) Apply(value Optional[Literal]) Optional[Literal] { + if !value.Valid { + return Optional[Literal]{} + } + + var hash uint32 + switch v := value.Val.(type) { + case TypedLiteral[[]byte]: + hash = murmur3.Sum32(v.Value()) + case StringLiteral: + hash = murmur3.Sum32(unsafe.Slice(unsafe.StringData(string(v)), len(v))) + case UUIDLiteral: + hash = murmur3.Sum32(v[:]) + case DecimalLiteral: + b, _ := v.MarshalBinary() + hash = murmur3.Sum32(b) + case Int32Literal: + hash = 
hashHelperInt[int64](int64(v)) + case Int64Literal: + hash = hashHelperInt[int64](int64(v)) + case DateLiteral: + hash = hashHelperInt[int64](int64(v)) + case TimeLiteral: + hash = hashHelperInt[int64](int64(v)) + case TimestampLiteral: + hash = hashHelperInt[int64](int64(v)) + default: + return Optional[Literal]{} + } + + return Optional[Literal]{ + Valid: true, + Val: Int32Literal((int32(hash) & math.MaxInt32) % int32(t.NumBuckets))} +} + +func (t BucketTransform) Transformer(src Type) func(any) Optional[int32] { + var h func(any) uint32 + + switch src.(type) { + case Int32Type: + h = hashHelperInt[int32] + case DateType: + h = hashHelperInt[Date] + case Int64Type: + h = hashHelperInt[int64] + case TimeType: + h = hashHelperInt[Time] + case TimestampType: + h = hashHelperInt[Timestamp] + case TimestampTzType: + h = hashHelperInt[Timestamp] + case DecimalType: + h = func(v any) uint32 { + b, _ := DecimalLiteral(v.(Decimal)).MarshalBinary() + return murmur3.Sum32(b) + } + case StringType, FixedType, BinaryType: + h = func(v any) uint32 { + if v, ok := v.([]byte); ok { + return murmur3.Sum32(v) + } + + str := v.(string) + return murmur3.Sum32(unsafe.Slice(unsafe.StringData(str), len(str))) + } + case UUIDType: + h = func(v any) uint32 { + if v, ok := v.([]byte); ok { + return murmur3.Sum32(v) + } + + u := v.(uuid.UUID) + return murmur3.Sum32(u[:]) + } + } + + return func(v any) Optional[int32] { + if v == nil { + return Optional[int32]{} + } + + return Optional[int32]{ + Valid: true, + Val: int32((int32(h(v)) & math.MaxInt32) % int32(t.NumBuckets))} + } +} + +func (t BucketTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { + if _, ok := pred.Term().(*BoundTransform); ok { + return projectTransformPredicate(t, name, pred) + } + + transformer := t.Transformer(pred.Term().Type()) + switch p := pred.(type) { + case BoundUnaryPredicate: + return p.AsUnbound(Reference(name)), nil + case BoundLiteralPredicate: + if p.Op() != OpEQ { + break + } 
+ return p.AsUnbound(Reference(name), transformLiteral(transformer, p.Literal())), nil + case BoundSetPredicate: + if p.Op() != OpIn { + break + } + + return setApplyTransform(name, p, transformer), nil + } + + return nil, nil +} + +// TruncateTransform is a transformation for truncating a value to a specified width. +type TruncateTransform struct { + Width int +} + +func (t TruncateTransform) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +func (t TruncateTransform) String() string { return fmt.Sprintf("truncate[%d]", t.Width) } + +func (TruncateTransform) ResultType(t Type) Type { return t } + +func (t TruncateTransform) Equals(other Transform) bool { + rhs, ok := other.(TruncateTransform) + if !ok { + return false + } + + return t.Width == rhs.Width +} + +func (t TruncateTransform) Transformer(src Type) (func(any) any, error) { + switch src.(type) { + case Int32Type: + return func(v any) any { + if v == nil { + return nil + } + + val := v.(int32) + return val - (val % int32(t.Width)) + }, nil + case Int64Type: + return func(v any) any { + if v == nil { + return nil + } + + val := v.(int64) + return val - (val % int64(t.Width)) + }, nil + case StringType, BinaryType: + return func(v any) any { + switch v := v.(type) { + case string: + return v[:min(len(v), t.Width)] + case []byte: + return v[:min(len(v), t.Width)] + default: + return nil + } + }, nil + case DecimalType: + bigWidth := big.NewInt(int64(t.Width)) + return func(v any) any { + if v == nil { + return nil + } + + val := v.(Decimal) + unscaled := val.Val.BigInt() + // unscaled - (((unscaled % width) + width) % width) + applied := (&big.Int{}).Mod(unscaled, bigWidth) + applied.Add(applied, bigWidth).Mod(applied, bigWidth) + val.Val = decimal128.FromBigInt(unscaled.Sub(unscaled, applied)) + return val + }, nil + } + + return nil, fmt.Errorf("%w: cannot truncate for type %s", + ErrInvalidArgument, src) +} + +func (t TruncateTransform) Apply(value Optional[Literal]) (out 
Optional[Literal]) { + if !value.Valid { + return + } + + fn, err := t.Transformer(value.Val.Type()) + if err != nil { + return + } + + out.Valid = true + switch v := value.Val.(type) { + case Int32Literal: + out.Val = Int32Literal(fn(int32(v)).(int32)) + case Int64Literal: + out.Val = Int64Literal(fn(int64(v)).(int64)) + case DecimalLiteral: + out.Val = DecimalLiteral(fn(Decimal(v)).(Decimal)) + case StringLiteral: + out.Val = StringLiteral(fn(string(v)).(string)) + case BinaryLiteral: + out.Val = BinaryLiteral(fn([]byte(v)).([]byte)) + } + + return +} + +func (t TruncateTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { + if _, ok := pred.Term().(*BoundTransform); ok { + return projectTransformPredicate(t, name, pred) + } + + fieldType := pred.Term().Ref().Field().Type + + transformer, err := t.Transformer(fieldType) + if err != nil { + return nil, err + } + + switch p := pred.(type) { + case BoundUnaryPredicate: + return p.AsUnbound(Reference(name)), nil + case BoundSetPredicate: + if p.Op() != OpIn { + break + } + + switch fieldType.(type) { + case Int32Type: + return setApplyTransform(name, p, wrapTransformFn[int32](transformer)), nil + case Int64Type: + return setApplyTransform(name, p, wrapTransformFn[int64](transformer)), nil + case DecimalType: + return setApplyTransform(name, p, wrapTransformFn[Decimal](transformer)), nil + case StringType: + return setApplyTransform(name, p, wrapTransformFn[string](transformer)), nil + case BinaryType: + return setApplyTransform(name, p, wrapTransformFn[[]byte](transformer)), nil + } + case BoundLiteralPredicate: + switch fieldType.(type) { + case Int32Type: + return truncateNumber(name, p, wrapTransformFn[int32](transformer)) + case Int64Type: + return truncateNumber(name, p, wrapTransformFn[int64](transformer)) + case DecimalType: + return truncateNumber(name, p, wrapTransformFn[Decimal](transformer)) + case StringType: + return truncateArray(name, p, wrapTransformFn[string](transformer)) 
+ case BinaryType: + return truncateArray(name, p, wrapTransformFn[[]byte](transformer)) + } + } + + return nil, nil +} + +var epochTM = time.Unix(0, 0).UTC() + +type timeTransform interface { + Transform + Transformer(Type) (func(any) Optional[int32], error) +} + +func projectTimeTransform(t timeTransform, name string, pred BoundPredicate) (UnboundPredicate, error) { + if _, ok := pred.Term().(*BoundTransform); ok { + return projectTransformPredicate(t, name, pred) + } + + transformer, err := t.Transformer(pred.Term().Ref().Type()) + if err != nil { + return nil, err + } + + switch p := pred.(type) { + case BoundUnaryPredicate: + return p.AsUnbound(Reference(name)), nil + case BoundLiteralPredicate: + return truncateNumber(name, p, transformer) + case BoundSetPredicate: + if p.Op() != OpIn { + break + } + + return setApplyTransform(name, p, transformer), nil + } + + return nil, nil +} + +// YearTransform transforms a datetime value into a year value. +type YearTransform struct{} + +func (t YearTransform) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +func (YearTransform) String() string { return "year" } + +func (YearTransform) ResultType(Type) Type { return PrimitiveTypes.Int32 } + +func (YearTransform) Equals(other Transform) bool { + _, ok := other.(YearTransform) + return ok +} + +func (YearTransform) Transformer(src Type) (func(any) Optional[int32], error) { + switch src.(type) { + case DateType: + return func(v any) Optional[int32] { + if v == nil { + return Optional[int32]{} + } + + return Optional[int32]{ + Valid: true, + Val: int32(v.(Date).ToTime().Year() - epochTM.Year()), + } + }, nil + case TimestampType, TimestampTzType: + return func(v any) Optional[int32] { + if v == nil { + return Optional[int32]{} + } + + return Optional[int32]{ + Valid: true, + Val: int32(v.(Timestamp).ToTime().Year() - epochTM.Year()), + } + }, nil + } + + return nil, fmt.Errorf("%w: cannot apply year transform for type %s", + ErrInvalidArgument, src) +} 
+ +func (YearTransform) Apply(value Optional[Literal]) (out Optional[Literal]) { + if !value.Valid { + return + } + + switch v := value.Val.(type) { + case DateLiteral: + out.Valid = true + out.Val = Int32Literal(Date(v).ToTime().Year() - epochTM.Year()) + case TimestampLiteral: + out.Valid = true + out.Val = Int32Literal(Timestamp(v).ToTime().Year() - epochTM.Year()) + } + + return +} + +func (t YearTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { + return projectTimeTransform(t, name, pred) +} + +// MonthTransform transforms a datetime value into a month value. +type MonthTransform struct{} + +func (t MonthTransform) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +func (MonthTransform) String() string { return "month" } + +func (MonthTransform) ResultType(Type) Type { return PrimitiveTypes.Int32 } + +func (MonthTransform) Equals(other Transform) bool { + _, ok := other.(MonthTransform) + return ok +} + +func (MonthTransform) Transformer(src Type) (func(any) Optional[int32], error) { + switch src.(type) { + case DateType: + return func(v any) Optional[int32] { + if v == nil { + return Optional[int32]{} + } + + d := v.(Date).ToTime() + return Optional[int32]{ + Valid: true, + Val: int32((d.Year()-epochTM.Year())*12 + (int(d.Month()) - int(epochTM.Month()))), + } + + }, nil + case TimestampType, TimestampTzType: + return func(v any) Optional[int32] { + if v == nil { + return Optional[int32]{} + } + + d := v.(Timestamp).ToTime() + return Optional[int32]{ + Valid: true, + Val: int32((d.Year()-epochTM.Year())*12 + (int(d.Month()) - int(epochTM.Month()))), + } + + }, nil + + } + + return nil, fmt.Errorf("%w: cannot apply month transform for type %s", + ErrInvalidArgument, src) +} + +func (MonthTransform) Apply(value Optional[Literal]) (out Optional[Literal]) { + if !value.Valid { + return + } + + var tm time.Time + switch v := value.Val.(type) { + case DateLiteral: + tm = Date(v).ToTime() + case TimestampLiteral: + 
tm = Timestamp(v).ToTime() + default: + return + } + + out.Valid = true + out.Val = Int32Literal(int32((tm.Year()-epochTM.Year())*12 + (int(tm.Month()) - int(epochTM.Month())))) + return +} + +func (t MonthTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { + return projectTimeTransform(t, name, pred) +} + +// DayTransform transforms a datetime value into a date value. +type DayTransform struct{} + +func (t DayTransform) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +func (DayTransform) String() string { return "day" } + +func (DayTransform) ResultType(Type) Type { return PrimitiveTypes.Date } + +func (DayTransform) Equals(other Transform) bool { + _, ok := other.(DayTransform) + return ok +} + +func (DayTransform) Transformer(src Type) (func(any) Optional[int32], error) { + switch src.(type) { + case DateType: + return func(v any) Optional[int32] { + if v == nil { + return Optional[int32]{} + } + + return Optional[int32]{ + Valid: true, + Val: int32(v.(Date)), + } + }, nil + case TimestampType, TimestampTzType: + return func(v any) Optional[int32] { + if v == nil { + return Optional[int32]{} + } + + return Optional[int32]{ + Valid: true, + Val: int32(v.(Timestamp).ToDate()), + } + }, nil + } + + return nil, fmt.Errorf("%w: cannot apply day transform for type %s", + ErrInvalidArgument, src) +} + +func (DayTransform) Apply(value Optional[Literal]) (out Optional[Literal]) { + if !value.Valid { + return + } + + switch v := value.Val.(type) { + case DateLiteral: + out.Valid, out.Val = true, Int32Literal(v) + case TimestampLiteral: + out.Valid, out.Val = true, Int32Literal(Timestamp(v).ToDate()) + } + return +} + +func (t DayTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { + return projectTimeTransform(t, name, pred) +} + +// HourTransform transforms a datetime value into an hour value. 
+type HourTransform struct{} + +func (t HourTransform) MarshalText() ([]byte, error) { + return []byte(t.String()), nil +} + +func (HourTransform) String() string { return "hour" } + +func (HourTransform) ResultType(Type) Type { return PrimitiveTypes.Int32 } + +func (HourTransform) Equals(other Transform) bool { + _, ok := other.(HourTransform) + return ok +} + +func (HourTransform) Transformer(src Type) (func(any) Optional[int32], error) { + switch src.(type) { + case TimestampType, TimestampTzType: + const factor = int64(time.Hour / time.Microsecond) + return func(v any) Optional[int32] { + if v == nil { + return Optional[int32]{} + } + + return Optional[int32]{ + Valid: true, + Val: int32(int64(v.(Timestamp)) / factor), + } + }, nil + } + + return nil, fmt.Errorf("%w: cannot apply hour transform for type %s", + ErrInvalidArgument, src) +} + +func (HourTransform) Apply(value Optional[Literal]) (out Optional[Literal]) { + if !value.Valid { + return + } + + switch v := value.Val.(type) { + case TimestampLiteral: + const factor = int64(time.Hour / time.Microsecond) + out.Valid, out.Val = true, Int32Literal(int32(int64(v)/factor)) + } + + return +} + +func (t HourTransform) Project(name string, pred BoundPredicate) (UnboundPredicate, error) { + return projectTimeTransform(t, name, pred) +} + +func removeTransform(partName string, pred BoundPredicate) (UnboundPredicate, error) { + switch p := pred.(type) { + case BoundUnaryPredicate: + return p.AsUnbound(Reference(partName)), nil + case BoundLiteralPredicate: + return p.AsUnbound(Reference(partName), p.Literal()), nil + case BoundSetPredicate: + return p.AsUnbound(Reference(partName), p.Literals().Members()), nil + } + + return nil, fmt.Errorf("%w: cannot replace transform in unknown predicate: %s", + ErrInvalidArgument, pred) +} + +func projectTransformPredicate(t Transform, partitionName string, pred BoundPredicate) (UnboundPredicate, error) { + term := pred.Term() + bt, ok := term.(*BoundTransform) + if !ok || 
!t.Equals(bt.transform) { + return nil, nil + } + + return removeTransform(partitionName, pred) +} + +func transformLiteral[T LiteralType](fn func(any) Optional[T], lit Literal) Literal { + switch l := lit.(type) { + case BoolLiteral: + return NewLiteral(fn(bool(l)).Val) + case Int32Literal: + return NewLiteral(fn(int32(l)).Val) + case Int64Literal: + return NewLiteral(fn(int64(l)).Val) + case Float32Literal: + return NewLiteral(fn(float32(l)).Val) + case Float64Literal: + return NewLiteral(fn(float64(l)).Val) + case DateLiteral: + return NewLiteral(fn(Date(l)).Val) + case TimeLiteral: + return NewLiteral(fn(Time(l)).Val) + case TimestampLiteral: + return NewLiteral(fn(Timestamp(l)).Val) + case StringLiteral: + return NewLiteral(fn(string(l)).Val) + case FixedLiteral: + return NewLiteral(fn([]byte(l)).Val) + case BinaryLiteral: + return NewLiteral(fn([]byte(l)).Val) + case UUIDLiteral: + return NewLiteral(fn(uuid.UUID(l)).Val) + case DecimalLiteral: + return NewLiteral(fn(Decimal(l)).Val) + } + + panic("invalid literal type") +} + +func wrapTransformFn[T LiteralType](fn func(any) any) func(any) Optional[T] { + return func(v any) Optional[T] { + out := fn(v) + if out == nil { + return Optional[T]{} + } + return Optional[T]{Valid: true, Val: out.(T)} + } +} + +func truncateNumber[T LiteralType](name string, pred BoundLiteralPredicate, fn func(any) Optional[T]) (UnboundPredicate, error) { + boundary, ok := pred.Literal().(NumericLiteral) + if !ok { + return nil, fmt.Errorf("%w: expected numeric literal, got %s", + ErrInvalidArgument, boundary.Type()) + } + + switch pred.Op() { + case OpLT: + return LiteralPredicate(OpLTEQ, Reference(name), + transformLiteral(fn, boundary.Decrement())), nil + case OpLTEQ: + return LiteralPredicate(OpLTEQ, Reference(name), + transformLiteral(fn, boundary)), nil + case OpGT: + return LiteralPredicate(OpGTEQ, Reference(name), + transformLiteral(fn, boundary.Increment())), nil + case OpGTEQ: + return LiteralPredicate(OpGTEQ, 
Reference(name), + transformLiteral(fn, boundary)), nil + case OpEQ: + return LiteralPredicate(OpEQ, Reference(name), + transformLiteral(fn, boundary)), nil + } + + return nil, nil +} + +func truncateArray[T LiteralType](name string, pred BoundLiteralPredicate, fn func(any) Optional[T]) (UnboundPredicate, error) { + boundary := pred.Literal() + + switch pred.Op() { + case OpLT, OpLTEQ: + return LiteralPredicate(OpLTEQ, Reference(name), + transformLiteral(fn, boundary)), nil + case OpGT, OpGTEQ: + return LiteralPredicate(OpGTEQ, Reference(name), + transformLiteral(fn, boundary)), nil + case OpEQ: + return LiteralPredicate(OpEQ, Reference(name), + transformLiteral(fn, boundary)), nil + case OpStartsWith: + return LiteralPredicate(OpStartsWith, Reference(name), + transformLiteral(fn, boundary)), nil + case OpNotStartsWith: + return LiteralPredicate(OpNotStartsWith, Reference(name), + transformLiteral(fn, boundary)), nil + } + + return nil, nil +} + +func setApplyTransform[T LiteralType](name string, pred BoundSetPredicate, fn func(any) Optional[T]) UnboundPredicate { + lits := pred.Literals().Members() + for i, l := range lits { + lits[i] = transformLiteral(fn, l) + } + + return pred.AsUnbound(Reference(name), lits) +} diff --git a/transforms_test.go b/transforms_test.go index a455ede..6656a7a 100644 --- a/transforms_test.go +++ b/transforms_test.go @@ -1,89 +1,89 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package iceberg_test - -import ( - "strings" - "testing" - - "github.com/apache/iceberg-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParseTransform(t *testing.T) { - tests := []struct { - toparse string - expected iceberg.Transform - }{ - {"identity", iceberg.IdentityTransform{}}, - {"IdEnTiTy", iceberg.IdentityTransform{}}, - {"void", iceberg.VoidTransform{}}, - {"VOId", iceberg.VoidTransform{}}, - {"year", iceberg.YearTransform{}}, - {"yEAr", iceberg.YearTransform{}}, - {"month", iceberg.MonthTransform{}}, - {"MONtH", iceberg.MonthTransform{}}, - {"day", iceberg.DayTransform{}}, - {"DaY", iceberg.DayTransform{}}, - {"hour", iceberg.HourTransform{}}, - {"hOuR", iceberg.HourTransform{}}, - {"bucket[5]", iceberg.BucketTransform{NumBuckets: 5}}, - {"bucket[100]", iceberg.BucketTransform{NumBuckets: 100}}, - {"BUCKET[5]", iceberg.BucketTransform{NumBuckets: 5}}, - {"bUCKeT[100]", iceberg.BucketTransform{NumBuckets: 100}}, - {"truncate[10]", iceberg.TruncateTransform{Width: 10}}, - {"truncate[255]", iceberg.TruncateTransform{Width: 255}}, - {"TRUNCATE[10]", iceberg.TruncateTransform{Width: 10}}, - {"tRuNCATe[255]", iceberg.TruncateTransform{Width: 255}}, - } - - for _, tt := range tests { - t.Run(tt.toparse, func(t *testing.T) { - transform, err := iceberg.ParseTransform(tt.toparse) - require.NoError(t, err) - assert.Equal(t, tt.expected, transform) - - txt, err := transform.MarshalText() - assert.NoError(t, err) - assert.Equal(t, strings.ToLower(tt.toparse), string(txt)) - }) - } - - errorTests 
:= []struct { - name string - toparse string - }{ - {"foobar", "foobar"}, - {"bucket no brackets", "bucket"}, - {"truncate no brackets", "truncate"}, - {"bucket no val", "bucket[]"}, - {"truncate no val", "truncate[]"}, - {"bucket neg", "bucket[-1]"}, - {"truncate neg", "truncate[-1]"}, - } - - for _, tt := range errorTests { - t.Run(tt.name, func(t *testing.T) { - tr, err := iceberg.ParseTransform(tt.toparse) - assert.Nil(t, tr) - assert.ErrorIs(t, err, iceberg.ErrInvalidTransform) - assert.ErrorContains(t, err, tt.toparse) - }) - } -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package iceberg_test + +import ( + "strings" + "testing" + + "github.com/apache/iceberg-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseTransform(t *testing.T) { + tests := []struct { + toparse string + expected iceberg.Transform + }{ + {"identity", iceberg.IdentityTransform{}}, + {"IdEnTiTy", iceberg.IdentityTransform{}}, + {"void", iceberg.VoidTransform{}}, + {"VOId", iceberg.VoidTransform{}}, + {"year", iceberg.YearTransform{}}, + {"yEAr", iceberg.YearTransform{}}, + {"month", iceberg.MonthTransform{}}, + {"MONtH", iceberg.MonthTransform{}}, + {"day", iceberg.DayTransform{}}, + {"DaY", iceberg.DayTransform{}}, + {"hour", iceberg.HourTransform{}}, + {"hOuR", iceberg.HourTransform{}}, + {"bucket[5]", iceberg.BucketTransform{NumBuckets: 5}}, + {"bucket[100]", iceberg.BucketTransform{NumBuckets: 100}}, + {"BUCKET[5]", iceberg.BucketTransform{NumBuckets: 5}}, + {"bUCKeT[100]", iceberg.BucketTransform{NumBuckets: 100}}, + {"truncate[10]", iceberg.TruncateTransform{Width: 10}}, + {"truncate[255]", iceberg.TruncateTransform{Width: 255}}, + {"TRUNCATE[10]", iceberg.TruncateTransform{Width: 10}}, + {"tRuNCATe[255]", iceberg.TruncateTransform{Width: 255}}, + } + + for _, tt := range tests { + t.Run(tt.toparse, func(t *testing.T) { + transform, err := iceberg.ParseTransform(tt.toparse) + require.NoError(t, err) + assert.Equal(t, tt.expected, transform) + + txt, err := transform.MarshalText() + assert.NoError(t, err) + assert.Equal(t, strings.ToLower(tt.toparse), string(txt)) + }) + } + + errorTests := []struct { + name string + toparse string + }{ + {"foobar", "foobar"}, + {"bucket no brackets", "bucket"}, + {"truncate no brackets", "truncate"}, + {"bucket no val", "bucket[]"}, + {"truncate no val", "truncate[]"}, + {"bucket neg", "bucket[-1]"}, + {"truncate neg", "truncate[-1]"}, + } + + for _, tt := range errorTests { + t.Run(tt.name, func(t *testing.T) { + tr, err := iceberg.ParseTransform(tt.toparse) + 
assert.Nil(t, tr) + assert.ErrorIs(t, err, iceberg.ErrInvalidTransform) + assert.ErrorContains(t, err, tt.toparse) + }) + } +} diff --git a/types.go b/types.go index 6729964..8aab26c 100644 --- a/types.go +++ b/types.go @@ -1,639 +1,639 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package iceberg - -import ( - "encoding/json" - "fmt" - "regexp" - "strconv" - "strings" - "time" - - "github.com/apache/arrow-go/v18/arrow/decimal128" - "golang.org/x/exp/slices" -) - -var ( - regexFromBrackets = regexp.MustCompile(`^\w+\[(\d+)\]$`) - decimalRegex = regexp.MustCompile(`decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)`) -) - -type Properties map[string]string - -// Type is an interface representing any of the available iceberg types, -// such as primitives (int32/int64/etc.) or nested types (list/struct/map). -type Type interface { - fmt.Stringer - Type() string - Equals(Type) bool -} - -// NestedType is an interface that allows access to the child fields of -// a nested type such as a list/struct/map type. 
-type NestedType interface { - Type - Fields() []NestedField -} - -type typeIFace struct { - Type -} - -func (t *typeIFace) MarshalJSON() ([]byte, error) { - if nested, ok := t.Type.(NestedType); ok { - return json.Marshal(nested) - } - return []byte(`"` + t.Type.Type() + `"`), nil -} - -func (t *typeIFace) UnmarshalJSON(b []byte) error { - var typename string - err := json.Unmarshal(b, &typename) - if err == nil { - switch typename { - case "boolean": - t.Type = BooleanType{} - case "int": - t.Type = Int32Type{} - case "long": - t.Type = Int64Type{} - case "float": - t.Type = Float32Type{} - case "double": - t.Type = Float64Type{} - case "date": - t.Type = DateType{} - case "time": - t.Type = TimeType{} - case "timestamp": - t.Type = TimestampType{} - case "timestamptz": - t.Type = TimestampTzType{} - case "string": - t.Type = StringType{} - case "uuid": - t.Type = UUIDType{} - case "binary": - t.Type = BinaryType{} - default: - switch { - case strings.HasPrefix(typename, "fixed"): - matches := regexFromBrackets.FindStringSubmatch(typename) - if len(matches) != 2 { - return fmt.Errorf("%w: %s", ErrInvalidTypeString, typename) - } - - n, _ := strconv.Atoi(matches[1]) - t.Type = FixedType{len: n} - case strings.HasPrefix(typename, "decimal"): - matches := decimalRegex.FindStringSubmatch(typename) - if len(matches) != 3 { - return fmt.Errorf("%w: %s", ErrInvalidTypeString, typename) - } - - prec, _ := strconv.Atoi(matches[1]) - scale, _ := strconv.Atoi(matches[2]) - t.Type = DecimalType{precision: prec, scale: scale} - default: - return fmt.Errorf("%w: unrecognized field type", ErrInvalidSchema) - } - } - return nil - } - - aux := struct { - TypeName string `json:"type"` - }{} - if err = json.Unmarshal(b, &aux); err != nil { - return err - } - - switch aux.TypeName { - case "list": - t.Type = &ListType{} - case "map": - t.Type = &MapType{} - case "struct": - t.Type = &StructType{} - default: - return fmt.Errorf("%w: %s", ErrInvalidTypeString, aux.TypeName) - } - - 
return json.Unmarshal(b, t.Type) -} - -type NestedField struct { - Type `json:"-"` - - ID int `json:"id"` - Name string `json:"name"` - Required bool `json:"required"` - Doc string `json:"doc,omitempty"` - InitialDefault any `json:"initial-default,omitempty"` - WriteDefault any `json:"write-default,omitempty"` -} - -func optOrReq(required bool) string { - if required { - return "required" - } - return "optional" -} - -func (n NestedField) String() string { - doc := n.Doc - if doc != "" { - doc = " (" + doc + ")" - } - - return fmt.Sprintf("%d: %s: %s %s%s", - n.ID, n.Name, optOrReq(n.Required), n.Type, doc) -} - -func (n *NestedField) Equals(other NestedField) bool { - return n.ID == other.ID && - n.Name == other.Name && - n.Required == other.Required && - n.Doc == other.Doc && - n.InitialDefault == other.InitialDefault && - n.WriteDefault == other.WriteDefault && - n.Type.Equals(other.Type) -} - -func (n NestedField) MarshalJSON() ([]byte, error) { - type Alias NestedField - return json.Marshal(struct { - Type *typeIFace `json:"type"` - *Alias - }{Type: &typeIFace{n.Type}, Alias: (*Alias)(&n)}) -} - -func (n *NestedField) UnmarshalJSON(b []byte) error { - type Alias NestedField - aux := struct { - Type typeIFace `json:"type"` - *Alias - }{ - Alias: (*Alias)(n), - } - - if err := json.Unmarshal(b, &aux); err != nil { - return err - } - - n.Type = aux.Type.Type - - return nil -} - -type StructType struct { - FieldList []NestedField `json:"fields"` -} - -func (s *StructType) Equals(other Type) bool { - st, ok := other.(*StructType) - if !ok { - return false - } - - return slices.EqualFunc(s.FieldList, st.FieldList, func(a, b NestedField) bool { - return a.Equals(b) - }) -} - -func (s *StructType) Fields() []NestedField { return s.FieldList } - -func (s *StructType) MarshalJSON() ([]byte, error) { - type Alias StructType - return json.Marshal(struct { - Type string `json:"type"` - *Alias - }{Type: s.Type(), Alias: (*Alias)(s)}) -} - -func (*StructType) Type() string { 
return "struct" } -func (s *StructType) String() string { - var b strings.Builder - b.WriteString("struct<") - for i, f := range s.FieldList { - if i != 0 { - b.WriteString(", ") - } - fmt.Fprintf(&b, "%d: %s: ", - f.ID, f.Name) - if f.Required { - b.WriteString("required ") - } else { - b.WriteString("optional ") - } - b.WriteString(f.Type.String()) - if f.Doc != "" { - b.WriteString(" (") - b.WriteString(f.Doc) - b.WriteByte(')') - } - } - b.WriteString(">") - - return b.String() -} - -type ListType struct { - ElementID int `json:"element-id"` - Element Type `json:"-"` - ElementRequired bool `json:"element-required"` -} - -func (l *ListType) MarshalJSON() ([]byte, error) { - type Alias ListType - return json.Marshal(struct { - Type string `json:"type"` - *Alias - Element *typeIFace `json:"element"` - }{Type: l.Type(), Alias: (*Alias)(l), Element: &typeIFace{l.Element}}) -} - -func (l *ListType) Equals(other Type) bool { - rhs, ok := other.(*ListType) - if !ok { - return false - } - - return l.ElementID == rhs.ElementID && - l.Element.Equals(rhs.Element) && - l.ElementRequired == rhs.ElementRequired -} - -func (l *ListType) Fields() []NestedField { - return []NestedField{l.ElementField()} -} - -func (l *ListType) ElementField() NestedField { - return NestedField{ - ID: l.ElementID, - Name: "element", - Type: l.Element, - Required: l.ElementRequired, - } -} - -func (*ListType) Type() string { return "list" } -func (l *ListType) String() string { return fmt.Sprintf("list<%s>", l.Element) } - -func (l *ListType) UnmarshalJSON(b []byte) error { - aux := struct { - ID int `json:"element-id"` - Elem typeIFace `json:"element"` - Req bool `json:"element-required"` - }{} - if err := json.Unmarshal(b, &aux); err != nil { - return err - } - - l.ElementID = aux.ID - l.Element = aux.Elem.Type - l.ElementRequired = aux.Req - return nil -} - -type MapType struct { - KeyID int `json:"key-id"` - KeyType Type `json:"-"` - ValueID int `json:"value-id"` - ValueType Type `json:"-"` - 
ValueRequired bool `json:"value-required"` -} - -func (m *MapType) MarshalJSON() ([]byte, error) { - type Alias MapType - return json.Marshal(struct { - Type string `json:"type"` - *Alias - KeyType *typeIFace `json:"key"` - ValueType *typeIFace `json:"value"` - }{Type: m.Type(), Alias: (*Alias)(m), - KeyType: &typeIFace{m.KeyType}, - ValueType: &typeIFace{m.ValueType}}) -} - -func (m *MapType) Equals(other Type) bool { - rhs, ok := other.(*MapType) - if !ok { - return false - } - - return m.KeyID == rhs.KeyID && - m.KeyType.Equals(rhs.KeyType) && - m.ValueID == rhs.ValueID && - m.ValueType.Equals(rhs.ValueType) && - m.ValueRequired == rhs.ValueRequired -} - -func (m *MapType) Fields() []NestedField { - return []NestedField{m.KeyField(), m.ValueField()} -} - -func (m *MapType) KeyField() NestedField { - return NestedField{ - Name: "key", - ID: m.KeyID, - Type: m.KeyType, - Required: true, - } -} - -func (m *MapType) ValueField() NestedField { - return NestedField{ - Name: "value", - ID: m.ValueID, - Type: m.ValueType, - Required: m.ValueRequired, - } -} - -func (*MapType) Type() string { return "map" } -func (m *MapType) String() string { - return fmt.Sprintf("map<%s, %s>", m.KeyType, m.ValueType) -} - -func (m *MapType) UnmarshalJSON(b []byte) error { - aux := struct { - KeyID int `json:"key-id"` - Key typeIFace `json:"key"` - ValueID int `json:"value-id"` - Value typeIFace `json:"value"` - ValueReq *bool `json:"value-required"` - }{} - if err := json.Unmarshal(b, &aux); err != nil { - return err - } - - m.KeyID, m.KeyType = aux.KeyID, aux.Key.Type - m.ValueID, m.ValueType = aux.ValueID, aux.Value.Type - if aux.ValueReq == nil { - m.ValueRequired = true - } else { - m.ValueRequired = *aux.ValueReq - } - return nil -} - -func FixedTypeOf(n int) FixedType { return FixedType{len: n} } - -type FixedType struct { - len int -} - -func (f FixedType) Equals(other Type) bool { - rhs, ok := other.(FixedType) - if !ok { - return false - } - - return f.len == rhs.len -} -func 
(f FixedType) Len() int { return f.len } -func (f FixedType) Type() string { return fmt.Sprintf("fixed[%d]", f.len) } -func (f FixedType) String() string { return fmt.Sprintf("fixed[%d]", f.len) } -func (f FixedType) primitive() {} - -func DecimalTypeOf(prec, scale int) DecimalType { - return DecimalType{precision: prec, scale: scale} -} - -type DecimalType struct { - precision, scale int -} - -func (d DecimalType) Equals(other Type) bool { - rhs, ok := other.(DecimalType) - if !ok { - return false - } - - return d.precision == rhs.precision && - d.scale == rhs.scale -} - -func (d DecimalType) Type() string { return fmt.Sprintf("decimal(%d, %d)", d.precision, d.scale) } -func (d DecimalType) String() string { return fmt.Sprintf("decimal(%d, %d)", d.precision, d.scale) } -func (d DecimalType) Precision() int { return d.precision } -func (d DecimalType) Scale() int { return d.scale } -func (DecimalType) primitive() {} - -type Decimal struct { - Val decimal128.Num - Scale int -} - -type PrimitiveType interface { - Type - primitive() -} - -type BooleanType struct{} - -func (BooleanType) Equals(other Type) bool { - _, ok := other.(BooleanType) - return ok -} - -func (BooleanType) primitive() {} -func (BooleanType) Type() string { return "boolean" } -func (BooleanType) String() string { return "boolean" } - -// Int32Type is the "int"/"integer" type of the iceberg spec. -type Int32Type struct{} - -func (Int32Type) Equals(other Type) bool { - _, ok := other.(Int32Type) - return ok -} - -func (Int32Type) primitive() {} -func (Int32Type) Type() string { return "int" } -func (Int32Type) String() string { return "int" } - -// Int64Type is the "long" type of the iceberg spec. 
-type Int64Type struct{} - -func (Int64Type) Equals(other Type) bool { - _, ok := other.(Int64Type) - return ok -} - -func (Int64Type) primitive() {} -func (Int64Type) Type() string { return "long" } -func (Int64Type) String() string { return "long" } - -// Float32Type is the "float" type in the iceberg spec. -type Float32Type struct{} - -func (Float32Type) Equals(other Type) bool { - _, ok := other.(Float32Type) - return ok -} - -func (Float32Type) primitive() {} -func (Float32Type) Type() string { return "float" } -func (Float32Type) String() string { return "float" } - -// Float64Type represents the "double" type of the iceberg spec. -type Float64Type struct{} - -func (Float64Type) Equals(other Type) bool { - _, ok := other.(Float64Type) - return ok -} - -func (Float64Type) primitive() {} -func (Float64Type) Type() string { return "double" } -func (Float64Type) String() string { return "double" } - -type Date int32 - -func (d Date) ToTime() time.Time { - return epochTM.AddDate(0, 0, int(d)) -} - -// DateType represents a calendar date without a timezone or time, -// represented as a 32-bit integer denoting the number of days since -// the unix epoch. -type DateType struct{} - -func (DateType) Equals(other Type) bool { - _, ok := other.(DateType) - return ok -} - -func (DateType) primitive() {} -func (DateType) Type() string { return "date" } -func (DateType) String() string { return "date" } - -type Time int64 - -// TimeType represents a number of microseconds since midnight. 
-type TimeType struct{} - -func (TimeType) Equals(other Type) bool { - _, ok := other.(TimeType) - return ok -} - -func (TimeType) primitive() {} -func (TimeType) Type() string { return "time" } -func (TimeType) String() string { return "time" } - -type Timestamp int64 - -func (t Timestamp) ToTime() time.Time { - return time.UnixMicro(int64(t)).UTC() -} - -func (t Timestamp) ToDate() Date { - tm := time.UnixMicro(int64(t)).UTC() - return Date(tm.Truncate(24*time.Hour).Unix() / int64((time.Hour * 24).Seconds())) -} - -// TimestampType represents a number of microseconds since the unix epoch -// without regard for timezone. -type TimestampType struct{} - -func (TimestampType) Equals(other Type) bool { - _, ok := other.(TimestampType) - return ok -} - -func (TimestampType) primitive() {} -func (TimestampType) Type() string { return "timestamp" } -func (TimestampType) String() string { return "timestamp" } - -// TimestampTzType represents a timestamp stored as UTC representing the -// number of microseconds since the unix epoch. 
-type TimestampTzType struct{} - -func (TimestampTzType) Equals(other Type) bool { - _, ok := other.(TimestampTzType) - return ok -} - -func (TimestampTzType) primitive() {} -func (TimestampTzType) Type() string { return "timestamptz" } -func (TimestampTzType) String() string { return "timestamptz" } - -type StringType struct{} - -func (StringType) Equals(other Type) bool { - _, ok := other.(StringType) - return ok -} - -func (StringType) primitive() {} -func (StringType) Type() string { return "string" } -func (StringType) String() string { return "string" } - -type UUIDType struct{} - -func (UUIDType) Equals(other Type) bool { - _, ok := other.(UUIDType) - return ok -} - -func (UUIDType) primitive() {} -func (UUIDType) Type() string { return "uuid" } -func (UUIDType) String() string { return "uuid" } - -type BinaryType struct{} - -func (BinaryType) Equals(other Type) bool { - _, ok := other.(BinaryType) - return ok -} - -func (BinaryType) primitive() {} -func (BinaryType) Type() string { return "binary" } -func (BinaryType) String() string { return "binary" } - -var PrimitiveTypes = struct { - Bool PrimitiveType - Int32 PrimitiveType - Int64 PrimitiveType - Float32 PrimitiveType - Float64 PrimitiveType - Date PrimitiveType - Time PrimitiveType - Timestamp PrimitiveType - TimestampTz PrimitiveType - String PrimitiveType - Binary PrimitiveType - UUID PrimitiveType -}{ - Bool: BooleanType{}, - Int32: Int32Type{}, - Int64: Int64Type{}, - Float32: Float32Type{}, - Float64: Float64Type{}, - Date: DateType{}, - Time: TimeType{}, - Timestamp: TimestampType{}, - TimestampTz: TimestampTzType{}, - String: StringType{}, - Binary: BinaryType{}, - UUID: UUIDType{}, -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package iceberg + +import ( + "encoding/json" + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "github.com/apache/arrow-go/v18/arrow/decimal128" + "golang.org/x/exp/slices" +) + +var ( + regexFromBrackets = regexp.MustCompile(`^\w+\[(\d+)\]$`) + decimalRegex = regexp.MustCompile(`decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)`) +) + +type Properties map[string]string + +// Type is an interface representing any of the available iceberg types, +// such as primitives (int32/int64/etc.) or nested types (list/struct/map). +type Type interface { + fmt.Stringer + Type() string + Equals(Type) bool +} + +// NestedType is an interface that allows access to the child fields of +// a nested type such as a list/struct/map type. 
+type NestedType interface { + Type + Fields() []NestedField +} + +type typeIFace struct { + Type +} + +func (t *typeIFace) MarshalJSON() ([]byte, error) { + if nested, ok := t.Type.(NestedType); ok { + return json.Marshal(nested) + } + return []byte(`"` + t.Type.Type() + `"`), nil +} + +func (t *typeIFace) UnmarshalJSON(b []byte) error { + var typename string + err := json.Unmarshal(b, &typename) + if err == nil { + switch typename { + case "boolean": + t.Type = BooleanType{} + case "int": + t.Type = Int32Type{} + case "long": + t.Type = Int64Type{} + case "float": + t.Type = Float32Type{} + case "double": + t.Type = Float64Type{} + case "date": + t.Type = DateType{} + case "time": + t.Type = TimeType{} + case "timestamp": + t.Type = TimestampType{} + case "timestamptz": + t.Type = TimestampTzType{} + case "string": + t.Type = StringType{} + case "uuid": + t.Type = UUIDType{} + case "binary": + t.Type = BinaryType{} + default: + switch { + case strings.HasPrefix(typename, "fixed"): + matches := regexFromBrackets.FindStringSubmatch(typename) + if len(matches) != 2 { + return fmt.Errorf("%w: %s", ErrInvalidTypeString, typename) + } + + n, _ := strconv.Atoi(matches[1]) + t.Type = FixedType{len: n} + case strings.HasPrefix(typename, "decimal"): + matches := decimalRegex.FindStringSubmatch(typename) + if len(matches) != 3 { + return fmt.Errorf("%w: %s", ErrInvalidTypeString, typename) + } + + prec, _ := strconv.Atoi(matches[1]) + scale, _ := strconv.Atoi(matches[2]) + t.Type = DecimalType{precision: prec, scale: scale} + default: + return fmt.Errorf("%w: unrecognized field type", ErrInvalidSchema) + } + } + return nil + } + + aux := struct { + TypeName string `json:"type"` + }{} + if err = json.Unmarshal(b, &aux); err != nil { + return err + } + + switch aux.TypeName { + case "list": + t.Type = &ListType{} + case "map": + t.Type = &MapType{} + case "struct": + t.Type = &StructType{} + default: + return fmt.Errorf("%w: %s", ErrInvalidTypeString, aux.TypeName) + } + + 
return json.Unmarshal(b, t.Type) +} + +type NestedField struct { + Type `json:"-"` + + ID int `json:"id"` + Name string `json:"name"` + Required bool `json:"required"` + Doc string `json:"doc,omitempty"` + InitialDefault any `json:"initial-default,omitempty"` + WriteDefault any `json:"write-default,omitempty"` +} + +func optOrReq(required bool) string { + if required { + return "required" + } + return "optional" +} + +func (n NestedField) String() string { + doc := n.Doc + if doc != "" { + doc = " (" + doc + ")" + } + + return fmt.Sprintf("%d: %s: %s %s%s", + n.ID, n.Name, optOrReq(n.Required), n.Type, doc) +} + +func (n *NestedField) Equals(other NestedField) bool { + return n.ID == other.ID && + n.Name == other.Name && + n.Required == other.Required && + n.Doc == other.Doc && + n.InitialDefault == other.InitialDefault && + n.WriteDefault == other.WriteDefault && + n.Type.Equals(other.Type) +} + +func (n NestedField) MarshalJSON() ([]byte, error) { + type Alias NestedField + return json.Marshal(struct { + Type *typeIFace `json:"type"` + *Alias + }{Type: &typeIFace{n.Type}, Alias: (*Alias)(&n)}) +} + +func (n *NestedField) UnmarshalJSON(b []byte) error { + type Alias NestedField + aux := struct { + Type typeIFace `json:"type"` + *Alias + }{ + Alias: (*Alias)(n), + } + + if err := json.Unmarshal(b, &aux); err != nil { + return err + } + + n.Type = aux.Type.Type + + return nil +} + +type StructType struct { + FieldList []NestedField `json:"fields"` +} + +func (s *StructType) Equals(other Type) bool { + st, ok := other.(*StructType) + if !ok { + return false + } + + return slices.EqualFunc(s.FieldList, st.FieldList, func(a, b NestedField) bool { + return a.Equals(b) + }) +} + +func (s *StructType) Fields() []NestedField { return s.FieldList } + +func (s *StructType) MarshalJSON() ([]byte, error) { + type Alias StructType + return json.Marshal(struct { + Type string `json:"type"` + *Alias + }{Type: s.Type(), Alias: (*Alias)(s)}) +} + +func (*StructType) Type() string { 
return "struct" } +func (s *StructType) String() string { + var b strings.Builder + b.WriteString("struct<") + for i, f := range s.FieldList { + if i != 0 { + b.WriteString(", ") + } + fmt.Fprintf(&b, "%d: %s: ", + f.ID, f.Name) + if f.Required { + b.WriteString("required ") + } else { + b.WriteString("optional ") + } + b.WriteString(f.Type.String()) + if f.Doc != "" { + b.WriteString(" (") + b.WriteString(f.Doc) + b.WriteByte(')') + } + } + b.WriteString(">") + + return b.String() +} + +type ListType struct { + ElementID int `json:"element-id"` + Element Type `json:"-"` + ElementRequired bool `json:"element-required"` +} + +func (l *ListType) MarshalJSON() ([]byte, error) { + type Alias ListType + return json.Marshal(struct { + Type string `json:"type"` + *Alias + Element *typeIFace `json:"element"` + }{Type: l.Type(), Alias: (*Alias)(l), Element: &typeIFace{l.Element}}) +} + +func (l *ListType) Equals(other Type) bool { + rhs, ok := other.(*ListType) + if !ok { + return false + } + + return l.ElementID == rhs.ElementID && + l.Element.Equals(rhs.Element) && + l.ElementRequired == rhs.ElementRequired +} + +func (l *ListType) Fields() []NestedField { + return []NestedField{l.ElementField()} +} + +func (l *ListType) ElementField() NestedField { + return NestedField{ + ID: l.ElementID, + Name: "element", + Type: l.Element, + Required: l.ElementRequired, + } +} + +func (*ListType) Type() string { return "list" } +func (l *ListType) String() string { return fmt.Sprintf("list<%s>", l.Element) } + +func (l *ListType) UnmarshalJSON(b []byte) error { + aux := struct { + ID int `json:"element-id"` + Elem typeIFace `json:"element"` + Req bool `json:"element-required"` + }{} + if err := json.Unmarshal(b, &aux); err != nil { + return err + } + + l.ElementID = aux.ID + l.Element = aux.Elem.Type + l.ElementRequired = aux.Req + return nil +} + +type MapType struct { + KeyID int `json:"key-id"` + KeyType Type `json:"-"` + ValueID int `json:"value-id"` + ValueType Type `json:"-"` + 
ValueRequired bool `json:"value-required"` +} + +func (m *MapType) MarshalJSON() ([]byte, error) { + type Alias MapType + return json.Marshal(struct { + Type string `json:"type"` + *Alias + KeyType *typeIFace `json:"key"` + ValueType *typeIFace `json:"value"` + }{Type: m.Type(), Alias: (*Alias)(m), + KeyType: &typeIFace{m.KeyType}, + ValueType: &typeIFace{m.ValueType}}) +} + +func (m *MapType) Equals(other Type) bool { + rhs, ok := other.(*MapType) + if !ok { + return false + } + + return m.KeyID == rhs.KeyID && + m.KeyType.Equals(rhs.KeyType) && + m.ValueID == rhs.ValueID && + m.ValueType.Equals(rhs.ValueType) && + m.ValueRequired == rhs.ValueRequired +} + +func (m *MapType) Fields() []NestedField { + return []NestedField{m.KeyField(), m.ValueField()} +} + +func (m *MapType) KeyField() NestedField { + return NestedField{ + Name: "key", + ID: m.KeyID, + Type: m.KeyType, + Required: true, + } +} + +func (m *MapType) ValueField() NestedField { + return NestedField{ + Name: "value", + ID: m.ValueID, + Type: m.ValueType, + Required: m.ValueRequired, + } +} + +func (*MapType) Type() string { return "map" } +func (m *MapType) String() string { + return fmt.Sprintf("map<%s, %s>", m.KeyType, m.ValueType) +} + +func (m *MapType) UnmarshalJSON(b []byte) error { + aux := struct { + KeyID int `json:"key-id"` + Key typeIFace `json:"key"` + ValueID int `json:"value-id"` + Value typeIFace `json:"value"` + ValueReq *bool `json:"value-required"` + }{} + if err := json.Unmarshal(b, &aux); err != nil { + return err + } + + m.KeyID, m.KeyType = aux.KeyID, aux.Key.Type + m.ValueID, m.ValueType = aux.ValueID, aux.Value.Type + if aux.ValueReq == nil { + m.ValueRequired = true + } else { + m.ValueRequired = *aux.ValueReq + } + return nil +} + +func FixedTypeOf(n int) FixedType { return FixedType{len: n} } + +type FixedType struct { + len int +} + +func (f FixedType) Equals(other Type) bool { + rhs, ok := other.(FixedType) + if !ok { + return false + } + + return f.len == rhs.len +} +func 
(f FixedType) Len() int { return f.len } +func (f FixedType) Type() string { return fmt.Sprintf("fixed[%d]", f.len) } +func (f FixedType) String() string { return fmt.Sprintf("fixed[%d]", f.len) } +func (f FixedType) primitive() {} + +func DecimalTypeOf(prec, scale int) DecimalType { + return DecimalType{precision: prec, scale: scale} +} + +type DecimalType struct { + precision, scale int +} + +func (d DecimalType) Equals(other Type) bool { + rhs, ok := other.(DecimalType) + if !ok { + return false + } + + return d.precision == rhs.precision && + d.scale == rhs.scale +} + +func (d DecimalType) Type() string { return fmt.Sprintf("decimal(%d, %d)", d.precision, d.scale) } +func (d DecimalType) String() string { return fmt.Sprintf("decimal(%d, %d)", d.precision, d.scale) } +func (d DecimalType) Precision() int { return d.precision } +func (d DecimalType) Scale() int { return d.scale } +func (DecimalType) primitive() {} + +type Decimal struct { + Val decimal128.Num + Scale int +} + +type PrimitiveType interface { + Type + primitive() +} + +type BooleanType struct{} + +func (BooleanType) Equals(other Type) bool { + _, ok := other.(BooleanType) + return ok +} + +func (BooleanType) primitive() {} +func (BooleanType) Type() string { return "boolean" } +func (BooleanType) String() string { return "boolean" } + +// Int32Type is the "int"/"integer" type of the iceberg spec. +type Int32Type struct{} + +func (Int32Type) Equals(other Type) bool { + _, ok := other.(Int32Type) + return ok +} + +func (Int32Type) primitive() {} +func (Int32Type) Type() string { return "int" } +func (Int32Type) String() string { return "int" } + +// Int64Type is the "long" type of the iceberg spec. 
+type Int64Type struct{} + +func (Int64Type) Equals(other Type) bool { + _, ok := other.(Int64Type) + return ok +} + +func (Int64Type) primitive() {} +func (Int64Type) Type() string { return "long" } +func (Int64Type) String() string { return "long" } + +// Float32Type is the "float" type in the iceberg spec. +type Float32Type struct{} + +func (Float32Type) Equals(other Type) bool { + _, ok := other.(Float32Type) + return ok +} + +func (Float32Type) primitive() {} +func (Float32Type) Type() string { return "float" } +func (Float32Type) String() string { return "float" } + +// Float64Type represents the "double" type of the iceberg spec. +type Float64Type struct{} + +func (Float64Type) Equals(other Type) bool { + _, ok := other.(Float64Type) + return ok +} + +func (Float64Type) primitive() {} +func (Float64Type) Type() string { return "double" } +func (Float64Type) String() string { return "double" } + +type Date int32 + +func (d Date) ToTime() time.Time { + return epochTM.AddDate(0, 0, int(d)) +} + +// DateType represents a calendar date without a timezone or time, +// represented as a 32-bit integer denoting the number of days since +// the unix epoch. +type DateType struct{} + +func (DateType) Equals(other Type) bool { + _, ok := other.(DateType) + return ok +} + +func (DateType) primitive() {} +func (DateType) Type() string { return "date" } +func (DateType) String() string { return "date" } + +type Time int64 + +// TimeType represents a number of microseconds since midnight. 
+type TimeType struct{} + +func (TimeType) Equals(other Type) bool { + _, ok := other.(TimeType) + return ok +} + +func (TimeType) primitive() {} +func (TimeType) Type() string { return "time" } +func (TimeType) String() string { return "time" } + +type Timestamp int64 + +func (t Timestamp) ToTime() time.Time { + return time.UnixMicro(int64(t)).UTC() +} + +func (t Timestamp) ToDate() Date { + tm := time.UnixMicro(int64(t)).UTC() + return Date(tm.Truncate(24*time.Hour).Unix() / int64((time.Hour * 24).Seconds())) +} + +// TimestampType represents a number of microseconds since the unix epoch +// without regard for timezone. +type TimestampType struct{} + +func (TimestampType) Equals(other Type) bool { + _, ok := other.(TimestampType) + return ok +} + +func (TimestampType) primitive() {} +func (TimestampType) Type() string { return "timestamp" } +func (TimestampType) String() string { return "timestamp" } + +// TimestampTzType represents a timestamp stored as UTC representing the +// number of microseconds since the unix epoch. 
+type TimestampTzType struct{} + +func (TimestampTzType) Equals(other Type) bool { + _, ok := other.(TimestampTzType) + return ok +} + +func (TimestampTzType) primitive() {} +func (TimestampTzType) Type() string { return "timestamptz" } +func (TimestampTzType) String() string { return "timestamptz" } + +type StringType struct{} + +func (StringType) Equals(other Type) bool { + _, ok := other.(StringType) + return ok +} + +func (StringType) primitive() {} +func (StringType) Type() string { return "string" } +func (StringType) String() string { return "string" } + +type UUIDType struct{} + +func (UUIDType) Equals(other Type) bool { + _, ok := other.(UUIDType) + return ok +} + +func (UUIDType) primitive() {} +func (UUIDType) Type() string { return "uuid" } +func (UUIDType) String() string { return "uuid" } + +type BinaryType struct{} + +func (BinaryType) Equals(other Type) bool { + _, ok := other.(BinaryType) + return ok +} + +func (BinaryType) primitive() {} +func (BinaryType) Type() string { return "binary" } +func (BinaryType) String() string { return "binary" } + +var PrimitiveTypes = struct { + Bool PrimitiveType + Int32 PrimitiveType + Int64 PrimitiveType + Float32 PrimitiveType + Float64 PrimitiveType + Date PrimitiveType + Time PrimitiveType + Timestamp PrimitiveType + TimestampTz PrimitiveType + String PrimitiveType + Binary PrimitiveType + UUID PrimitiveType +}{ + Bool: BooleanType{}, + Int32: Int32Type{}, + Int64: Int64Type{}, + Float32: Float32Type{}, + Float64: Float64Type{}, + Date: DateType{}, + Time: TimeType{}, + Timestamp: TimestampType{}, + TimestampTz: TimestampTzType{}, + String: StringType{}, + Binary: BinaryType{}, + UUID: UUIDType{}, +} diff --git a/types_test.go b/types_test.go index abe233f..1f2e37c 100644 --- a/types_test.go +++ b/types_test.go @@ -1,236 +1,236 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package iceberg_test - -import ( - "encoding/json" - "testing" - - "github.com/apache/iceberg-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestTypesBasic(t *testing.T) { - tests := []struct { - expected string - typ iceberg.Type - }{ - {"boolean", iceberg.PrimitiveTypes.Bool}, - {"int", iceberg.PrimitiveTypes.Int32}, - {"long", iceberg.PrimitiveTypes.Int64}, - {"float", iceberg.PrimitiveTypes.Float32}, - {"double", iceberg.PrimitiveTypes.Float64}, - {"date", iceberg.PrimitiveTypes.Date}, - {"time", iceberg.PrimitiveTypes.Time}, - {"timestamp", iceberg.PrimitiveTypes.Timestamp}, - {"timestamptz", iceberg.PrimitiveTypes.TimestampTz}, - {"uuid", iceberg.PrimitiveTypes.UUID}, - {"binary", iceberg.PrimitiveTypes.Binary}, - {"fixed[5]", iceberg.FixedTypeOf(5)}, - {"decimal(9, 4)", iceberg.DecimalTypeOf(9, 4)}, - } - - for _, tt := range tests { - t.Run(tt.expected, func(t *testing.T) { - var data = `{ - "id": 1, - "name": "test", - "type": "` + tt.expected + `", - "required": false - }` - - var n iceberg.NestedField - require.NoError(t, json.Unmarshal([]byte(data), &n)) - assert.Truef(t, n.Type.Equals(tt.typ), "expected: %s\ngot: %s", tt.typ, n.Type) - - out, err := json.Marshal(n) - require.NoError(t, err) - assert.JSONEq(t, 
data, string(out)) - }) - } -} - -func TestFixedType(t *testing.T) { - typ := iceberg.FixedTypeOf(5) - assert.Equal(t, 5, typ.Len()) - assert.Equal(t, "fixed[5]", typ.String()) - assert.True(t, typ.Equals(iceberg.FixedTypeOf(5))) - assert.False(t, typ.Equals(iceberg.FixedTypeOf(6))) -} - -func TestDecimalType(t *testing.T) { - typ := iceberg.DecimalTypeOf(9, 2) - assert.Equal(t, 9, typ.Precision()) - assert.Equal(t, 2, typ.Scale()) - assert.Equal(t, "decimal(9, 2)", typ.String()) - assert.True(t, typ.Equals(iceberg.DecimalTypeOf(9, 2))) - assert.False(t, typ.Equals(iceberg.DecimalTypeOf(9, 3))) -} - -func TestStructType(t *testing.T) { - typ := &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 1, Name: "required_field", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - {ID: 2, Name: "optional_field", Type: iceberg.FixedTypeOf(5), Required: false}, - {ID: 3, Name: "required_field", Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 4, Name: "optional_field", Type: iceberg.DecimalTypeOf(8, 2), Required: false}, - {ID: 5, Name: "required_field", Type: iceberg.PrimitiveTypes.Int64, Required: false}, - }, - }, Required: false}, - }, - } - - assert.Len(t, typ.FieldList, 3) - assert.False(t, typ.Equals(&iceberg.StructType{FieldList: []iceberg.NestedField{{ID: 1, Name: "optional_field", Type: iceberg.PrimitiveTypes.Int32, Required: true}}})) - out, err := json.Marshal(typ) - require.NoError(t, err) - - var actual iceberg.StructType - require.NoError(t, json.Unmarshal(out, &actual)) - assert.True(t, typ.Equals(&actual)) -} - -func TestListType(t *testing.T) { - typ := &iceberg.ListType{ - ElementID: 1, - ElementRequired: false, - Element: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 2, Name: "required_field", Type: iceberg.DecimalTypeOf(8, 2), Required: true}, - {ID: 3, Name: "optional_field", Type: iceberg.PrimitiveTypes.Int64, Required: false}, - }, - }, - } - - assert.IsType(t, (*iceberg.StructType)(nil), 
typ.ElementField().Type) - assert.Len(t, typ.ElementField().Type.(iceberg.NestedType).Fields(), 2) - assert.Equal(t, 1, typ.ElementField().ID) - assert.False(t, typ.Equals(&iceberg.ListType{ - ElementID: 1, - ElementRequired: true, - Element: &iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 2, Name: "required_field", Type: iceberg.DecimalTypeOf(8, 2), Required: true}, - }, - }, - })) - - out, err := json.Marshal(typ) - require.NoError(t, err) - - var actual iceberg.ListType - require.NoError(t, json.Unmarshal(out, &actual)) - assert.True(t, typ.Equals(&actual)) -} - -func TestMapType(t *testing.T) { - typ := &iceberg.MapType{ - KeyID: 1, - KeyType: iceberg.PrimitiveTypes.Float64, - ValueID: 2, - ValueType: iceberg.PrimitiveTypes.UUID, - ValueRequired: false, - } - - assert.IsType(t, iceberg.PrimitiveTypes.Float64, typ.KeyField().Type) - assert.Equal(t, 1, typ.KeyField().ID) - assert.IsType(t, iceberg.PrimitiveTypes.UUID, typ.ValueField().Type) - assert.Equal(t, 2, typ.ValueField().ID) - assert.False(t, typ.Equals(&iceberg.MapType{ - KeyID: 1, KeyType: iceberg.PrimitiveTypes.Int64, - ValueID: 2, ValueType: iceberg.PrimitiveTypes.UUID, ValueRequired: false, - })) - assert.False(t, typ.Equals(&iceberg.MapType{ - KeyID: 1, KeyType: iceberg.PrimitiveTypes.Float64, - ValueID: 2, ValueType: iceberg.PrimitiveTypes.String, ValueRequired: true, - })) - - out, err := json.Marshal(typ) - require.NoError(t, err) - - var actual iceberg.MapType - require.NoError(t, json.Unmarshal(out, &actual)) - assert.True(t, typ.Equals(&actual)) -} - -var ( - NonParameterizedTypes = []iceberg.Type{ - iceberg.PrimitiveTypes.Bool, - iceberg.PrimitiveTypes.Int32, - iceberg.PrimitiveTypes.Int64, - iceberg.PrimitiveTypes.Float32, - iceberg.PrimitiveTypes.Float64, - iceberg.PrimitiveTypes.Date, - iceberg.PrimitiveTypes.Time, - iceberg.PrimitiveTypes.Timestamp, - iceberg.PrimitiveTypes.TimestampTz, - iceberg.PrimitiveTypes.String, - iceberg.PrimitiveTypes.Binary, - 
iceberg.PrimitiveTypes.UUID, - } -) - -func TestNonParameterizedTypeEquality(t *testing.T) { - for i, in := range NonParameterizedTypes { - for j, check := range NonParameterizedTypes { - if i == j { - assert.Truef(t, in.Equals(check), "expected %s == %s", in, check) - } else { - assert.Falsef(t, in.Equals(check), "expected %s != %s", in, check) - } - } - } -} - -func TestTypeStrings(t *testing.T) { - tests := []struct { - typ iceberg.Type - str string - }{ - {iceberg.PrimitiveTypes.Bool, "boolean"}, - {iceberg.PrimitiveTypes.Int32, "int"}, - {iceberg.PrimitiveTypes.Int64, "long"}, - {iceberg.PrimitiveTypes.Float32, "float"}, - {iceberg.PrimitiveTypes.Float64, "double"}, - {iceberg.PrimitiveTypes.Date, "date"}, - {iceberg.PrimitiveTypes.Time, "time"}, - {iceberg.PrimitiveTypes.Timestamp, "timestamp"}, - {iceberg.PrimitiveTypes.TimestampTz, "timestamptz"}, - {iceberg.PrimitiveTypes.String, "string"}, - {iceberg.PrimitiveTypes.UUID, "uuid"}, - {iceberg.PrimitiveTypes.Binary, "binary"}, - {iceberg.FixedTypeOf(22), "fixed[22]"}, - {iceberg.DecimalTypeOf(19, 25), "decimal(19, 25)"}, - {&iceberg.StructType{ - FieldList: []iceberg.NestedField{ - {ID: 1, Name: "required_field", Type: iceberg.PrimitiveTypes.String, Required: true, Doc: "this is a doc"}, - {ID: 2, Name: "optional_field", Type: iceberg.PrimitiveTypes.Int32, Required: true}, - }, - }, "struct<1: required_field: required string (this is a doc), 2: optional_field: required int>"}, - {&iceberg.ListType{ - ElementID: 22, Element: iceberg.PrimitiveTypes.String}, "list"}, - {&iceberg.MapType{KeyID: 19, KeyType: iceberg.PrimitiveTypes.String, ValueID: 25, ValueType: iceberg.PrimitiveTypes.Float64}, - "map"}, - } - - for _, tt := range tests { - assert.Equal(t, tt.str, tt.typ.String()) - } -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package iceberg_test + +import ( + "encoding/json" + "testing" + + "github.com/apache/iceberg-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTypesBasic(t *testing.T) { + tests := []struct { + expected string + typ iceberg.Type + }{ + {"boolean", iceberg.PrimitiveTypes.Bool}, + {"int", iceberg.PrimitiveTypes.Int32}, + {"long", iceberg.PrimitiveTypes.Int64}, + {"float", iceberg.PrimitiveTypes.Float32}, + {"double", iceberg.PrimitiveTypes.Float64}, + {"date", iceberg.PrimitiveTypes.Date}, + {"time", iceberg.PrimitiveTypes.Time}, + {"timestamp", iceberg.PrimitiveTypes.Timestamp}, + {"timestamptz", iceberg.PrimitiveTypes.TimestampTz}, + {"uuid", iceberg.PrimitiveTypes.UUID}, + {"binary", iceberg.PrimitiveTypes.Binary}, + {"fixed[5]", iceberg.FixedTypeOf(5)}, + {"decimal(9, 4)", iceberg.DecimalTypeOf(9, 4)}, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + var data = `{ + "id": 1, + "name": "test", + "type": "` + tt.expected + `", + "required": false + }` + + var n iceberg.NestedField + require.NoError(t, json.Unmarshal([]byte(data), &n)) + assert.Truef(t, n.Type.Equals(tt.typ), "expected: %s\ngot: %s", tt.typ, n.Type) + + out, err := json.Marshal(n) + require.NoError(t, err) + assert.JSONEq(t, data, string(out)) + }) + } +} + +func TestFixedType(t *testing.T) { + typ := iceberg.FixedTypeOf(5) + 
assert.Equal(t, 5, typ.Len()) + assert.Equal(t, "fixed[5]", typ.String()) + assert.True(t, typ.Equals(iceberg.FixedTypeOf(5))) + assert.False(t, typ.Equals(iceberg.FixedTypeOf(6))) +} + +func TestDecimalType(t *testing.T) { + typ := iceberg.DecimalTypeOf(9, 2) + assert.Equal(t, 9, typ.Precision()) + assert.Equal(t, 2, typ.Scale()) + assert.Equal(t, "decimal(9, 2)", typ.String()) + assert.True(t, typ.Equals(iceberg.DecimalTypeOf(9, 2))) + assert.False(t, typ.Equals(iceberg.DecimalTypeOf(9, 3))) +} + +func TestStructType(t *testing.T) { + typ := &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 1, Name: "required_field", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + {ID: 2, Name: "optional_field", Type: iceberg.FixedTypeOf(5), Required: false}, + {ID: 3, Name: "required_field", Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 4, Name: "optional_field", Type: iceberg.DecimalTypeOf(8, 2), Required: false}, + {ID: 5, Name: "required_field", Type: iceberg.PrimitiveTypes.Int64, Required: false}, + }, + }, Required: false}, + }, + } + + assert.Len(t, typ.FieldList, 3) + assert.False(t, typ.Equals(&iceberg.StructType{FieldList: []iceberg.NestedField{{ID: 1, Name: "optional_field", Type: iceberg.PrimitiveTypes.Int32, Required: true}}})) + out, err := json.Marshal(typ) + require.NoError(t, err) + + var actual iceberg.StructType + require.NoError(t, json.Unmarshal(out, &actual)) + assert.True(t, typ.Equals(&actual)) +} + +func TestListType(t *testing.T) { + typ := &iceberg.ListType{ + ElementID: 1, + ElementRequired: false, + Element: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 2, Name: "required_field", Type: iceberg.DecimalTypeOf(8, 2), Required: true}, + {ID: 3, Name: "optional_field", Type: iceberg.PrimitiveTypes.Int64, Required: false}, + }, + }, + } + + assert.IsType(t, (*iceberg.StructType)(nil), typ.ElementField().Type) + assert.Len(t, typ.ElementField().Type.(iceberg.NestedType).Fields(), 2) + 
assert.Equal(t, 1, typ.ElementField().ID) + assert.False(t, typ.Equals(&iceberg.ListType{ + ElementID: 1, + ElementRequired: true, + Element: &iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 2, Name: "required_field", Type: iceberg.DecimalTypeOf(8, 2), Required: true}, + }, + }, + })) + + out, err := json.Marshal(typ) + require.NoError(t, err) + + var actual iceberg.ListType + require.NoError(t, json.Unmarshal(out, &actual)) + assert.True(t, typ.Equals(&actual)) +} + +func TestMapType(t *testing.T) { + typ := &iceberg.MapType{ + KeyID: 1, + KeyType: iceberg.PrimitiveTypes.Float64, + ValueID: 2, + ValueType: iceberg.PrimitiveTypes.UUID, + ValueRequired: false, + } + + assert.IsType(t, iceberg.PrimitiveTypes.Float64, typ.KeyField().Type) + assert.Equal(t, 1, typ.KeyField().ID) + assert.IsType(t, iceberg.PrimitiveTypes.UUID, typ.ValueField().Type) + assert.Equal(t, 2, typ.ValueField().ID) + assert.False(t, typ.Equals(&iceberg.MapType{ + KeyID: 1, KeyType: iceberg.PrimitiveTypes.Int64, + ValueID: 2, ValueType: iceberg.PrimitiveTypes.UUID, ValueRequired: false, + })) + assert.False(t, typ.Equals(&iceberg.MapType{ + KeyID: 1, KeyType: iceberg.PrimitiveTypes.Float64, + ValueID: 2, ValueType: iceberg.PrimitiveTypes.String, ValueRequired: true, + })) + + out, err := json.Marshal(typ) + require.NoError(t, err) + + var actual iceberg.MapType + require.NoError(t, json.Unmarshal(out, &actual)) + assert.True(t, typ.Equals(&actual)) +} + +var ( + NonParameterizedTypes = []iceberg.Type{ + iceberg.PrimitiveTypes.Bool, + iceberg.PrimitiveTypes.Int32, + iceberg.PrimitiveTypes.Int64, + iceberg.PrimitiveTypes.Float32, + iceberg.PrimitiveTypes.Float64, + iceberg.PrimitiveTypes.Date, + iceberg.PrimitiveTypes.Time, + iceberg.PrimitiveTypes.Timestamp, + iceberg.PrimitiveTypes.TimestampTz, + iceberg.PrimitiveTypes.String, + iceberg.PrimitiveTypes.Binary, + iceberg.PrimitiveTypes.UUID, + } +) + +func TestNonParameterizedTypeEquality(t *testing.T) { + for i, in := range 
NonParameterizedTypes { + for j, check := range NonParameterizedTypes { + if i == j { + assert.Truef(t, in.Equals(check), "expected %s == %s", in, check) + } else { + assert.Falsef(t, in.Equals(check), "expected %s != %s", in, check) + } + } + } +} + +func TestTypeStrings(t *testing.T) { + tests := []struct { + typ iceberg.Type + str string + }{ + {iceberg.PrimitiveTypes.Bool, "boolean"}, + {iceberg.PrimitiveTypes.Int32, "int"}, + {iceberg.PrimitiveTypes.Int64, "long"}, + {iceberg.PrimitiveTypes.Float32, "float"}, + {iceberg.PrimitiveTypes.Float64, "double"}, + {iceberg.PrimitiveTypes.Date, "date"}, + {iceberg.PrimitiveTypes.Time, "time"}, + {iceberg.PrimitiveTypes.Timestamp, "timestamp"}, + {iceberg.PrimitiveTypes.TimestampTz, "timestamptz"}, + {iceberg.PrimitiveTypes.String, "string"}, + {iceberg.PrimitiveTypes.UUID, "uuid"}, + {iceberg.PrimitiveTypes.Binary, "binary"}, + {iceberg.FixedTypeOf(22), "fixed[22]"}, + {iceberg.DecimalTypeOf(19, 25), "decimal(19, 25)"}, + {&iceberg.StructType{ + FieldList: []iceberg.NestedField{ + {ID: 1, Name: "required_field", Type: iceberg.PrimitiveTypes.String, Required: true, Doc: "this is a doc"}, + {ID: 2, Name: "optional_field", Type: iceberg.PrimitiveTypes.Int32, Required: true}, + }, + }, "struct<1: required_field: required string (this is a doc), 2: optional_field: required int>"}, + {&iceberg.ListType{ + ElementID: 22, Element: iceberg.PrimitiveTypes.String}, "list"}, + {&iceberg.MapType{KeyID: 19, KeyType: iceberg.PrimitiveTypes.String, ValueID: 25, ValueType: iceberg.PrimitiveTypes.Float64}, + "map"}, + } + + for _, tt := range tests { + assert.Equal(t, tt.str, tt.typ.String()) + } +} diff --git a/utils.go b/utils.go index c0a00fe..cf7c1ab 100644 --- a/utils.go +++ b/utils.go @@ -1,198 +1,198 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package iceberg - -import ( - "cmp" - "fmt" - "hash/maphash" - "maps" - "runtime/debug" - "strings" -) - -var version string - -func init() { - version = "(unknown version)" - if info, ok := debug.ReadBuildInfo(); ok { - for _, dep := range info.Deps { - if strings.HasPrefix(dep.Path, "github.com/apache/iceberg-go") { - version = dep.Version - break - } - } - } -} - -func Version() string { return version } - -func max[T cmp.Ordered](vals ...T) T { - if len(vals) == 0 { - panic("can't call max with no arguments") - } - - out := vals[0] - for _, v := range vals[1:] { - if v > out { - out = v - } - } - return out -} - -// Optional represents a typed value that could be null -type Optional[T any] struct { - Val T - Valid bool -} - -// represents a single row in a record -type structLike interface { - // Size returns the number of columns in this row - Size() int - // Get returns the value in the requested column, - // will panic if pos is out of bounds. - Get(pos int) any - // Set changes the value in the column indicated, - // will panic if pos is out of bounds. 
- Set(pos int, val any) -} - -type accessor struct { - pos int - inner *accessor -} - -func (a *accessor) String() string { - return fmt.Sprintf("Accessor(position=%d, inner=%s)", a.pos, a.inner) -} - -func (a *accessor) Get(s structLike) any { - val, inner := s.Get(a.pos), a - for val != nil && inner.inner != nil { - inner = inner.inner - val = val.(structLike).Get(inner.pos) - } - return val -} - -type Set[E any] interface { - Add(...E) - Contains(E) bool - Members() []E - Equals(Set[E]) bool - Len() int - All(func(E) bool) bool -} - -var lzseed = maphash.MakeSeed() - -type literalSet map[any]struct{ orig Literal } - -func newLiteralSet(vals ...Literal) Set[Literal] { - s := literalSet{} - for _, v := range vals { - s.addliteral(v) - } - return s -} - -func (l literalSet) addliteral(v Literal) { - switch v := v.(type) { - case FixedLiteral: - l[maphash.Bytes(lzseed, []byte(v))] = struct{ orig Literal }{v} - case BinaryLiteral: - l[maphash.Bytes(lzseed, []byte(v))] = struct{ orig Literal }{v} - default: - l[v] = struct{ orig Literal }{} - } -} - -func (l literalSet) Add(lits ...Literal) { - for _, v := range lits { - l.addliteral(v) - } -} - -func (l literalSet) Contains(lit Literal) bool { - switch lit := lit.(type) { - case BinaryLiteral: - v, ok := l[maphash.Bytes(lzseed, []byte(lit))] - if !ok { - return false - } - return lit.Equals(v.orig) - case FixedLiteral: - v, ok := l[maphash.Bytes(lzseed, []byte(lit))] - if !ok { - return false - } - return lit.Equals(v.orig) - default: - _, ok := l[lit] - return ok - } -} - -func (l literalSet) Members() []Literal { - result := make([]Literal, 0, len(l)) - for k, v := range l { - if k, ok := k.(Literal); ok { - result = append(result, k) - } else { - result = append(result, v.orig) - } - } - return result -} - -func (l literalSet) Equals(other Set[Literal]) bool { - rhs, ok := other.(literalSet) - if !ok { - return false - } - return maps.EqualFunc(l, rhs, func(v1, v2 struct{ orig Literal }) bool { - switch { - case 
v1.orig == nil: - return v2.orig == nil - case v2.orig == nil: - return v1.orig == nil - default: - return v1.orig.Equals(v2.orig) - } - }) -} - -func (l literalSet) Len() int { return len(l) } - -func (l literalSet) All(fn func(Literal) bool) bool { - for k, v := range l { - var e Literal - if k, ok := k.(Literal); ok { - e = k - } else { - e = v.orig - } - - if !fn(e) { - return false - } - } - return true -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package iceberg + +import ( + "cmp" + "fmt" + "hash/maphash" + "maps" + "runtime/debug" + "strings" +) + +var version string + +func init() { + version = "(unknown version)" + if info, ok := debug.ReadBuildInfo(); ok { + for _, dep := range info.Deps { + if strings.HasPrefix(dep.Path, "github.com/apache/iceberg-go") { + version = dep.Version + break + } + } + } +} + +func Version() string { return version } + +func max[T cmp.Ordered](vals ...T) T { + if len(vals) == 0 { + panic("can't call max with no arguments") + } + + out := vals[0] + for _, v := range vals[1:] { + if v > out { + out = v + } + } + return out +} + +// Optional represents a typed value that could be null +type Optional[T any] struct { + Val T + Valid bool +} + +// represents a single row in a record +type structLike interface { + // Size returns the number of columns in this row + Size() int + // Get returns the value in the requested column, + // will panic if pos is out of bounds. + Get(pos int) any + // Set changes the value in the column indicated, + // will panic if pos is out of bounds. 
+ Set(pos int, val any) +} + +type accessor struct { + pos int + inner *accessor +} + +func (a *accessor) String() string { + return fmt.Sprintf("Accessor(position=%d, inner=%s)", a.pos, a.inner) +} + +func (a *accessor) Get(s structLike) any { + val, inner := s.Get(a.pos), a + for val != nil && inner.inner != nil { + inner = inner.inner + val = val.(structLike).Get(inner.pos) + } + return val +} + +type Set[E any] interface { + Add(...E) + Contains(E) bool + Members() []E + Equals(Set[E]) bool + Len() int + All(func(E) bool) bool +} + +var lzseed = maphash.MakeSeed() + +type literalSet map[any]struct{ orig Literal } + +func newLiteralSet(vals ...Literal) Set[Literal] { + s := literalSet{} + for _, v := range vals { + s.addliteral(v) + } + return s +} + +func (l literalSet) addliteral(v Literal) { + switch v := v.(type) { + case FixedLiteral: + l[maphash.Bytes(lzseed, []byte(v))] = struct{ orig Literal }{v} + case BinaryLiteral: + l[maphash.Bytes(lzseed, []byte(v))] = struct{ orig Literal }{v} + default: + l[v] = struct{ orig Literal }{} + } +} + +func (l literalSet) Add(lits ...Literal) { + for _, v := range lits { + l.addliteral(v) + } +} + +func (l literalSet) Contains(lit Literal) bool { + switch lit := lit.(type) { + case BinaryLiteral: + v, ok := l[maphash.Bytes(lzseed, []byte(lit))] + if !ok { + return false + } + return lit.Equals(v.orig) + case FixedLiteral: + v, ok := l[maphash.Bytes(lzseed, []byte(lit))] + if !ok { + return false + } + return lit.Equals(v.orig) + default: + _, ok := l[lit] + return ok + } +} + +func (l literalSet) Members() []Literal { + result := make([]Literal, 0, len(l)) + for k, v := range l { + if k, ok := k.(Literal); ok { + result = append(result, k) + } else { + result = append(result, v.orig) + } + } + return result +} + +func (l literalSet) Equals(other Set[Literal]) bool { + rhs, ok := other.(literalSet) + if !ok { + return false + } + return maps.EqualFunc(l, rhs, func(v1, v2 struct{ orig Literal }) bool { + switch { + case 
v1.orig == nil: + return v2.orig == nil + case v2.orig == nil: + return v1.orig == nil + default: + return v1.orig.Equals(v2.orig) + } + }) +} + +func (l literalSet) Len() int { return len(l) } + +func (l literalSet) All(fn func(Literal) bool) bool { + for k, v := range l { + var e Literal + if k, ok := k.(Literal); ok { + e = k + } else { + e = v.orig + } + + if !fn(e) { + return false + } + } + return true +} diff --git a/visitors.go b/visitors.go index 7525026..c16b1a3 100644 --- a/visitors.go +++ b/visitors.go @@ -1,397 +1,397 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package iceberg - -import ( - "fmt" - "math" - "strings" - - "github.com/google/uuid" -) - -// BooleanExprVisitor is an interface for recursively visiting the nodes of a -// boolean expression -type BooleanExprVisitor[T any] interface { - VisitTrue() T - VisitFalse() T - VisitNot(childResult T) T - VisitAnd(left, right T) T - VisitOr(left, right T) T - VisitUnbound(UnboundPredicate) T - VisitBound(BoundPredicate) T -} - -// BoundBooleanExprVisitor builds on BooleanExprVisitor by adding interface -// methods for visiting bound expressions, because we do casting of literals -// during binding you can assume that the BoundTerm and the Literal passed -// to a method have the same type. -type BoundBooleanExprVisitor[T any] interface { - BooleanExprVisitor[T] - - VisitIn(BoundTerm, Set[Literal]) T - VisitNotIn(BoundTerm, Set[Literal]) T - VisitIsNan(BoundTerm) T - VisitNotNan(BoundTerm) T - VisitIsNull(BoundTerm) T - VisitNotNull(BoundTerm) T - VisitEqual(BoundTerm, Literal) T - VisitNotEqual(BoundTerm, Literal) T - VisitGreaterEqual(BoundTerm, Literal) T - VisitGreater(BoundTerm, Literal) T - VisitLessEqual(BoundTerm, Literal) T - VisitLess(BoundTerm, Literal) T - VisitStartsWith(BoundTerm, Literal) T - VisitNotStartsWith(BoundTerm, Literal) T -} - -// VisitExpr is a convenience function to use a given visitor to visit all parts of -// a boolean expression in-order. Values returned from the methods are passed to the -// subsequent methods, effectively "bubbling up" the results. 
-func VisitExpr[T any](expr BooleanExpression, visitor BooleanExprVisitor[T]) (res T, err error) { - defer func() { - if r := recover(); r != nil { - switch e := r.(type) { - case string: - err = fmt.Errorf("error encountered during visitExpr: %s", e) - case error: - err = e - } - } - }() - - return visitBoolExpr(expr, visitor), err -} - -func visitBoolExpr[T any](e BooleanExpression, visitor BooleanExprVisitor[T]) T { - switch e := e.(type) { - case AlwaysFalse: - return visitor.VisitFalse() - case AlwaysTrue: - return visitor.VisitTrue() - case AndExpr: - left, right := visitBoolExpr(e.left, visitor), visitBoolExpr(e.right, visitor) - return visitor.VisitAnd(left, right) - case OrExpr: - left, right := visitBoolExpr(e.left, visitor), visitBoolExpr(e.right, visitor) - return visitor.VisitOr(left, right) - case NotExpr: - child := visitBoolExpr(e.child, visitor) - return visitor.VisitNot(child) - case UnboundPredicate: - return visitor.VisitUnbound(e) - case BoundPredicate: - return visitor.VisitBound(e) - } - panic(fmt.Errorf("%w: VisitBooleanExpression type %s", ErrNotImplemented, e)) -} - -// VisitBoundPredicate uses a BoundBooleanExprVisitor to call the appropriate method -// based on the type of operation in the predicate. This is a convenience function -// for implementing the VisitBound method of a BoundBooleanExprVisitor by simply calling -// iceberg.VisitBoundPredicate(pred, this). 
-func VisitBoundPredicate[T any](e BoundPredicate, visitor BoundBooleanExprVisitor[T]) T { - switch e.Op() { - case OpIn: - return visitor.VisitIn(e.Term(), e.(BoundSetPredicate).Literals()) - case OpNotIn: - return visitor.VisitNotIn(e.Term(), e.(BoundSetPredicate).Literals()) - case OpIsNan: - return visitor.VisitIsNan(e.Term()) - case OpNotNan: - return visitor.VisitNotNan(e.Term()) - case OpIsNull: - return visitor.VisitIsNull(e.Term()) - case OpNotNull: - return visitor.VisitNotNull(e.Term()) - case OpEQ: - return visitor.VisitEqual(e.Term(), e.(BoundLiteralPredicate).Literal()) - case OpNEQ: - return visitor.VisitNotEqual(e.Term(), e.(BoundLiteralPredicate).Literal()) - case OpGTEQ: - return visitor.VisitGreaterEqual(e.Term(), e.(BoundLiteralPredicate).Literal()) - case OpGT: - return visitor.VisitGreater(e.Term(), e.(BoundLiteralPredicate).Literal()) - case OpLTEQ: - return visitor.VisitLessEqual(e.Term(), e.(BoundLiteralPredicate).Literal()) - case OpLT: - return visitor.VisitLess(e.Term(), e.(BoundLiteralPredicate).Literal()) - case OpStartsWith: - return visitor.VisitStartsWith(e.Term(), e.(BoundLiteralPredicate).Literal()) - case OpNotStartsWith: - return visitor.VisitNotStartsWith(e.Term(), e.(BoundLiteralPredicate).Literal()) - } - panic(fmt.Errorf("%w: unhandled bound predicate type: %s", ErrNotImplemented, e)) -} - -// BindExpr recursively binds each portion of an expression using the provided schema. -// Because the expression can end up being simplified to just AlwaysTrue/AlwaysFalse, -// this returns a BooleanExpression. 
-func BindExpr(s *Schema, expr BooleanExpression, caseSensitive bool) (BooleanExpression, error) { - return VisitExpr(expr, &bindVisitor{schema: s, caseSensitive: caseSensitive}) -} - -type bindVisitor struct { - schema *Schema - caseSensitive bool -} - -func (*bindVisitor) VisitTrue() BooleanExpression { return AlwaysTrue{} } -func (*bindVisitor) VisitFalse() BooleanExpression { return AlwaysFalse{} } -func (*bindVisitor) VisitNot(child BooleanExpression) BooleanExpression { - return NewNot(child) -} -func (*bindVisitor) VisitAnd(left, right BooleanExpression) BooleanExpression { - return NewAnd(left, right) -} -func (*bindVisitor) VisitOr(left, right BooleanExpression) BooleanExpression { - return NewOr(left, right) -} -func (b *bindVisitor) VisitUnbound(pred UnboundPredicate) BooleanExpression { - expr, err := pred.Bind(b.schema, b.caseSensitive) - if err != nil { - panic(err) - } - return expr -} -func (*bindVisitor) VisitBound(pred BoundPredicate) BooleanExpression { - panic(fmt.Errorf("%w: found already bound predicate: %s", ErrInvalidArgument, pred)) -} - -// ExpressionEvaluator returns a function which can be used to evaluate a given expression -// as long as a structlike value is passed which operates like and matches the passed in -// schema. 
-func ExpressionEvaluator(s *Schema, unbound BooleanExpression, caseSensitive bool) (func(structLike) (bool, error), error) { - bound, err := BindExpr(s, unbound, caseSensitive) - if err != nil { - return nil, err - } - - return (&exprEvaluator{bound: bound}).Eval, nil -} - -type exprEvaluator struct { - bound BooleanExpression - st structLike -} - -func (e *exprEvaluator) Eval(st structLike) (bool, error) { - e.st = st - return VisitExpr(e.bound, e) -} - -func (e *exprEvaluator) VisitUnbound(UnboundPredicate) bool { - panic("found unbound predicate when evaluating expression") -} - -func (e *exprEvaluator) VisitBound(pred BoundPredicate) bool { - return VisitBoundPredicate(pred, e) -} - -func (*exprEvaluator) VisitTrue() bool { return true } -func (*exprEvaluator) VisitFalse() bool { return false } -func (*exprEvaluator) VisitNot(child bool) bool { return !child } -func (*exprEvaluator) VisitAnd(left, right bool) bool { return left && right } -func (*exprEvaluator) VisitOr(left, right bool) bool { return left || right } - -func (e *exprEvaluator) VisitIn(term BoundTerm, literals Set[Literal]) bool { - v := term.evalToLiteral(e.st) - if !v.Valid { - return false - } - - return literals.Contains(v.Val) -} - -func (e *exprEvaluator) VisitNotIn(term BoundTerm, literals Set[Literal]) bool { - return !e.VisitIn(term, literals) -} - -func (e *exprEvaluator) VisitIsNan(term BoundTerm) bool { - switch term.Type().(type) { - case Float32Type: - v := term.(bound[float32]).eval(e.st) - if !v.Valid { - break - } - return math.IsNaN(float64(v.Val)) - case Float64Type: - v := term.(bound[float64]).eval(e.st) - if !v.Valid { - break - } - return math.IsNaN(v.Val) - } - - return false -} - -func (e *exprEvaluator) VisitNotNan(term BoundTerm) bool { - return !e.VisitIsNan(term) -} - -func (e *exprEvaluator) VisitIsNull(term BoundTerm) bool { - return term.evalIsNull(e.st) -} - -func (e *exprEvaluator) VisitNotNull(term BoundTerm) bool { - return !term.evalIsNull(e.st) -} - -func 
nullsFirstCmp[T LiteralType](cmp Comparator[T], v1, v2 Optional[T]) int { - if !v1.Valid { - if !v2.Valid { - // both are null - return 0 - } - // v1 is null, v2 is not - return -1 - } - - if !v2.Valid { - return 1 - } - - return cmp(v1.Val, v2.Val) -} - -func typedCmp[T LiteralType](st structLike, term BoundTerm, lit Literal) int { - v := term.(bound[T]).eval(st) - var l Optional[T] - - rhs := lit.(TypedLiteral[T]) - if lit != nil { - l.Valid = true - l.Val = rhs.Value() - } - - return nullsFirstCmp(rhs.Comparator(), v, l) -} - -func doCmp(st structLike, term BoundTerm, lit Literal) int { - // we already properly casted and converted everything during binding - // so we can type assert based on the term type - switch term.Type().(type) { - case BooleanType: - return typedCmp[bool](st, term, lit) - case Int32Type: - return typedCmp[int32](st, term, lit) - case Int64Type: - return typedCmp[int64](st, term, lit) - case Float32Type: - return typedCmp[float32](st, term, lit) - case Float64Type: - return typedCmp[float64](st, term, lit) - case DateType: - return typedCmp[Date](st, term, lit) - case TimeType: - return typedCmp[Time](st, term, lit) - case TimestampType, TimestampTzType: - return typedCmp[Timestamp](st, term, lit) - case BinaryType, FixedType: - return typedCmp[[]byte](st, term, lit) - case StringType: - return typedCmp[string](st, term, lit) - case UUIDType: - return typedCmp[uuid.UUID](st, term, lit) - case DecimalType: - return typedCmp[Decimal](st, term, lit) - } - panic(ErrType) -} - -func (e *exprEvaluator) VisitEqual(term BoundTerm, lit Literal) bool { - return doCmp(e.st, term, lit) == 0 -} - -func (e *exprEvaluator) VisitNotEqual(term BoundTerm, lit Literal) bool { - return doCmp(e.st, term, lit) != 0 -} - -func (e *exprEvaluator) VisitGreater(term BoundTerm, lit Literal) bool { - return doCmp(e.st, term, lit) > 0 -} - -func (e *exprEvaluator) VisitGreaterEqual(term BoundTerm, lit Literal) bool { - return doCmp(e.st, term, lit) >= 0 -} - -func (e 
*exprEvaluator) VisitLess(term BoundTerm, lit Literal) bool { - return doCmp(e.st, term, lit) < 0 -} - -func (e *exprEvaluator) VisitLessEqual(term BoundTerm, lit Literal) bool { - return doCmp(e.st, term, lit) <= 0 -} - -func (e *exprEvaluator) VisitStartsWith(term BoundTerm, lit Literal) bool { - var value, prefix string - - switch lit.(type) { - case TypedLiteral[string]: - val := term.(bound[string]).eval(e.st) - if !val.Valid { - return false - } - prefix, value = lit.(StringLiteral).Value(), val.Val - case TypedLiteral[[]byte]: - val := term.(bound[[]byte]).eval(e.st) - if !val.Valid { - return false - } - prefix, value = string(lit.(TypedLiteral[[]byte]).Value()), string(val.Val) - } - - return strings.HasPrefix(value, prefix) -} - -func (e *exprEvaluator) VisitNotStartsWith(term BoundTerm, lit Literal) bool { - return !e.VisitStartsWith(term, lit) -} - -// RewriteNotExpr rewrites a boolean expression to remove "Not" nodes from the expression -// tree. This is because Projections assume there are no "not" nodes. -// -// Not nodes will be replaced with simply calling `Negate` on the child in the tree. 
-func RewriteNotExpr(expr BooleanExpression) (BooleanExpression, error) { - return VisitExpr(expr, rewriteNotVisitor{}) -} - -type rewriteNotVisitor struct{} - -func (rewriteNotVisitor) VisitTrue() BooleanExpression { return AlwaysTrue{} } -func (rewriteNotVisitor) VisitFalse() BooleanExpression { return AlwaysFalse{} } -func (rewriteNotVisitor) VisitNot(child BooleanExpression) BooleanExpression { - return child.Negate() -} - -func (rewriteNotVisitor) VisitAnd(left, right BooleanExpression) BooleanExpression { - return NewAnd(left, right) -} - -func (rewriteNotVisitor) VisitOr(left, right BooleanExpression) BooleanExpression { - return NewOr(left, right) -} - -func (rewriteNotVisitor) VisitUnbound(pred UnboundPredicate) BooleanExpression { - return pred -} - -func (rewriteNotVisitor) VisitBound(pred BoundPredicate) BooleanExpression { - return pred -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package iceberg + +import ( + "fmt" + "math" + "strings" + + "github.com/google/uuid" +) + +// BooleanExprVisitor is an interface for recursively visiting the nodes of a +// boolean expression +type BooleanExprVisitor[T any] interface { + VisitTrue() T + VisitFalse() T + VisitNot(childResult T) T + VisitAnd(left, right T) T + VisitOr(left, right T) T + VisitUnbound(UnboundPredicate) T + VisitBound(BoundPredicate) T +} + +// BoundBooleanExprVisitor builds on BooleanExprVisitor by adding interface +// methods for visiting bound expressions, because we do casting of literals +// during binding you can assume that the BoundTerm and the Literal passed +// to a method have the same type. +type BoundBooleanExprVisitor[T any] interface { + BooleanExprVisitor[T] + + VisitIn(BoundTerm, Set[Literal]) T + VisitNotIn(BoundTerm, Set[Literal]) T + VisitIsNan(BoundTerm) T + VisitNotNan(BoundTerm) T + VisitIsNull(BoundTerm) T + VisitNotNull(BoundTerm) T + VisitEqual(BoundTerm, Literal) T + VisitNotEqual(BoundTerm, Literal) T + VisitGreaterEqual(BoundTerm, Literal) T + VisitGreater(BoundTerm, Literal) T + VisitLessEqual(BoundTerm, Literal) T + VisitLess(BoundTerm, Literal) T + VisitStartsWith(BoundTerm, Literal) T + VisitNotStartsWith(BoundTerm, Literal) T +} + +// VisitExpr is a convenience function to use a given visitor to visit all parts of +// a boolean expression in-order. Values returned from the methods are passed to the +// subsequent methods, effectively "bubbling up" the results. 
+func VisitExpr[T any](expr BooleanExpression, visitor BooleanExprVisitor[T]) (res T, err error) { + defer func() { + if r := recover(); r != nil { + switch e := r.(type) { + case string: + err = fmt.Errorf("error encountered during visitExpr: %s", e) + case error: + err = e + } + } + }() + + return visitBoolExpr(expr, visitor), err +} + +func visitBoolExpr[T any](e BooleanExpression, visitor BooleanExprVisitor[T]) T { + switch e := e.(type) { + case AlwaysFalse: + return visitor.VisitFalse() + case AlwaysTrue: + return visitor.VisitTrue() + case AndExpr: + left, right := visitBoolExpr(e.left, visitor), visitBoolExpr(e.right, visitor) + return visitor.VisitAnd(left, right) + case OrExpr: + left, right := visitBoolExpr(e.left, visitor), visitBoolExpr(e.right, visitor) + return visitor.VisitOr(left, right) + case NotExpr: + child := visitBoolExpr(e.child, visitor) + return visitor.VisitNot(child) + case UnboundPredicate: + return visitor.VisitUnbound(e) + case BoundPredicate: + return visitor.VisitBound(e) + } + panic(fmt.Errorf("%w: VisitBooleanExpression type %s", ErrNotImplemented, e)) +} + +// VisitBoundPredicate uses a BoundBooleanExprVisitor to call the appropriate method +// based on the type of operation in the predicate. This is a convenience function +// for implementing the VisitBound method of a BoundBooleanExprVisitor by simply calling +// iceberg.VisitBoundPredicate(pred, this). 
+func VisitBoundPredicate[T any](e BoundPredicate, visitor BoundBooleanExprVisitor[T]) T { + switch e.Op() { + case OpIn: + return visitor.VisitIn(e.Term(), e.(BoundSetPredicate).Literals()) + case OpNotIn: + return visitor.VisitNotIn(e.Term(), e.(BoundSetPredicate).Literals()) + case OpIsNan: + return visitor.VisitIsNan(e.Term()) + case OpNotNan: + return visitor.VisitNotNan(e.Term()) + case OpIsNull: + return visitor.VisitIsNull(e.Term()) + case OpNotNull: + return visitor.VisitNotNull(e.Term()) + case OpEQ: + return visitor.VisitEqual(e.Term(), e.(BoundLiteralPredicate).Literal()) + case OpNEQ: + return visitor.VisitNotEqual(e.Term(), e.(BoundLiteralPredicate).Literal()) + case OpGTEQ: + return visitor.VisitGreaterEqual(e.Term(), e.(BoundLiteralPredicate).Literal()) + case OpGT: + return visitor.VisitGreater(e.Term(), e.(BoundLiteralPredicate).Literal()) + case OpLTEQ: + return visitor.VisitLessEqual(e.Term(), e.(BoundLiteralPredicate).Literal()) + case OpLT: + return visitor.VisitLess(e.Term(), e.(BoundLiteralPredicate).Literal()) + case OpStartsWith: + return visitor.VisitStartsWith(e.Term(), e.(BoundLiteralPredicate).Literal()) + case OpNotStartsWith: + return visitor.VisitNotStartsWith(e.Term(), e.(BoundLiteralPredicate).Literal()) + } + panic(fmt.Errorf("%w: unhandled bound predicate type: %s", ErrNotImplemented, e)) +} + +// BindExpr recursively binds each portion of an expression using the provided schema. +// Because the expression can end up being simplified to just AlwaysTrue/AlwaysFalse, +// this returns a BooleanExpression. 
+func BindExpr(s *Schema, expr BooleanExpression, caseSensitive bool) (BooleanExpression, error) { + return VisitExpr(expr, &bindVisitor{schema: s, caseSensitive: caseSensitive}) +} + +type bindVisitor struct { + schema *Schema + caseSensitive bool +} + +func (*bindVisitor) VisitTrue() BooleanExpression { return AlwaysTrue{} } +func (*bindVisitor) VisitFalse() BooleanExpression { return AlwaysFalse{} } +func (*bindVisitor) VisitNot(child BooleanExpression) BooleanExpression { + return NewNot(child) +} +func (*bindVisitor) VisitAnd(left, right BooleanExpression) BooleanExpression { + return NewAnd(left, right) +} +func (*bindVisitor) VisitOr(left, right BooleanExpression) BooleanExpression { + return NewOr(left, right) +} +func (b *bindVisitor) VisitUnbound(pred UnboundPredicate) BooleanExpression { + expr, err := pred.Bind(b.schema, b.caseSensitive) + if err != nil { + panic(err) + } + return expr +} +func (*bindVisitor) VisitBound(pred BoundPredicate) BooleanExpression { + panic(fmt.Errorf("%w: found already bound predicate: %s", ErrInvalidArgument, pred)) +} + +// ExpressionEvaluator returns a function which can be used to evaluate a given expression +// as long as a structlike value is passed which operates like and matches the passed in +// schema. 
+func ExpressionEvaluator(s *Schema, unbound BooleanExpression, caseSensitive bool) (func(structLike) (bool, error), error) { + bound, err := BindExpr(s, unbound, caseSensitive) + if err != nil { + return nil, err + } + + return (&exprEvaluator{bound: bound}).Eval, nil +} + +type exprEvaluator struct { + bound BooleanExpression + st structLike +} + +func (e *exprEvaluator) Eval(st structLike) (bool, error) { + e.st = st + return VisitExpr(e.bound, e) +} + +func (e *exprEvaluator) VisitUnbound(UnboundPredicate) bool { + panic("found unbound predicate when evaluating expression") +} + +func (e *exprEvaluator) VisitBound(pred BoundPredicate) bool { + return VisitBoundPredicate(pred, e) +} + +func (*exprEvaluator) VisitTrue() bool { return true } +func (*exprEvaluator) VisitFalse() bool { return false } +func (*exprEvaluator) VisitNot(child bool) bool { return !child } +func (*exprEvaluator) VisitAnd(left, right bool) bool { return left && right } +func (*exprEvaluator) VisitOr(left, right bool) bool { return left || right } + +func (e *exprEvaluator) VisitIn(term BoundTerm, literals Set[Literal]) bool { + v := term.evalToLiteral(e.st) + if !v.Valid { + return false + } + + return literals.Contains(v.Val) +} + +func (e *exprEvaluator) VisitNotIn(term BoundTerm, literals Set[Literal]) bool { + return !e.VisitIn(term, literals) +} + +func (e *exprEvaluator) VisitIsNan(term BoundTerm) bool { + switch term.Type().(type) { + case Float32Type: + v := term.(bound[float32]).eval(e.st) + if !v.Valid { + break + } + return math.IsNaN(float64(v.Val)) + case Float64Type: + v := term.(bound[float64]).eval(e.st) + if !v.Valid { + break + } + return math.IsNaN(v.Val) + } + + return false +} + +func (e *exprEvaluator) VisitNotNan(term BoundTerm) bool { + return !e.VisitIsNan(term) +} + +func (e *exprEvaluator) VisitIsNull(term BoundTerm) bool { + return term.evalIsNull(e.st) +} + +func (e *exprEvaluator) VisitNotNull(term BoundTerm) bool { + return !term.evalIsNull(e.st) +} + +func 
nullsFirstCmp[T LiteralType](cmp Comparator[T], v1, v2 Optional[T]) int { + if !v1.Valid { + if !v2.Valid { + // both are null + return 0 + } + // v1 is null, v2 is not + return -1 + } + + if !v2.Valid { + return 1 + } + + return cmp(v1.Val, v2.Val) +} + +func typedCmp[T LiteralType](st structLike, term BoundTerm, lit Literal) int { + v := term.(bound[T]).eval(st) + var l Optional[T] + + rhs := lit.(TypedLiteral[T]) + if lit != nil { + l.Valid = true + l.Val = rhs.Value() + } + + return nullsFirstCmp(rhs.Comparator(), v, l) +} + +func doCmp(st structLike, term BoundTerm, lit Literal) int { + // we already properly casted and converted everything during binding + // so we can type assert based on the term type + switch term.Type().(type) { + case BooleanType: + return typedCmp[bool](st, term, lit) + case Int32Type: + return typedCmp[int32](st, term, lit) + case Int64Type: + return typedCmp[int64](st, term, lit) + case Float32Type: + return typedCmp[float32](st, term, lit) + case Float64Type: + return typedCmp[float64](st, term, lit) + case DateType: + return typedCmp[Date](st, term, lit) + case TimeType: + return typedCmp[Time](st, term, lit) + case TimestampType, TimestampTzType: + return typedCmp[Timestamp](st, term, lit) + case BinaryType, FixedType: + return typedCmp[[]byte](st, term, lit) + case StringType: + return typedCmp[string](st, term, lit) + case UUIDType: + return typedCmp[uuid.UUID](st, term, lit) + case DecimalType: + return typedCmp[Decimal](st, term, lit) + } + panic(ErrType) +} + +func (e *exprEvaluator) VisitEqual(term BoundTerm, lit Literal) bool { + return doCmp(e.st, term, lit) == 0 +} + +func (e *exprEvaluator) VisitNotEqual(term BoundTerm, lit Literal) bool { + return doCmp(e.st, term, lit) != 0 +} + +func (e *exprEvaluator) VisitGreater(term BoundTerm, lit Literal) bool { + return doCmp(e.st, term, lit) > 0 +} + +func (e *exprEvaluator) VisitGreaterEqual(term BoundTerm, lit Literal) bool { + return doCmp(e.st, term, lit) >= 0 +} + +func (e 
*exprEvaluator) VisitLess(term BoundTerm, lit Literal) bool { + return doCmp(e.st, term, lit) < 0 +} + +func (e *exprEvaluator) VisitLessEqual(term BoundTerm, lit Literal) bool { + return doCmp(e.st, term, lit) <= 0 +} + +func (e *exprEvaluator) VisitStartsWith(term BoundTerm, lit Literal) bool { + var value, prefix string + + switch lit.(type) { + case TypedLiteral[string]: + val := term.(bound[string]).eval(e.st) + if !val.Valid { + return false + } + prefix, value = lit.(StringLiteral).Value(), val.Val + case TypedLiteral[[]byte]: + val := term.(bound[[]byte]).eval(e.st) + if !val.Valid { + return false + } + prefix, value = string(lit.(TypedLiteral[[]byte]).Value()), string(val.Val) + } + + return strings.HasPrefix(value, prefix) +} + +func (e *exprEvaluator) VisitNotStartsWith(term BoundTerm, lit Literal) bool { + return !e.VisitStartsWith(term, lit) +} + +// RewriteNotExpr rewrites a boolean expression to remove "Not" nodes from the expression +// tree. This is because Projections assume there are no "not" nodes. +// +// Not nodes will be replaced with simply calling `Negate` on the child in the tree. 
+func RewriteNotExpr(expr BooleanExpression) (BooleanExpression, error) { + return VisitExpr(expr, rewriteNotVisitor{}) +} + +type rewriteNotVisitor struct{} + +func (rewriteNotVisitor) VisitTrue() BooleanExpression { return AlwaysTrue{} } +func (rewriteNotVisitor) VisitFalse() BooleanExpression { return AlwaysFalse{} } +func (rewriteNotVisitor) VisitNot(child BooleanExpression) BooleanExpression { + return child.Negate() +} + +func (rewriteNotVisitor) VisitAnd(left, right BooleanExpression) BooleanExpression { + return NewAnd(left, right) +} + +func (rewriteNotVisitor) VisitOr(left, right BooleanExpression) BooleanExpression { + return NewOr(left, right) +} + +func (rewriteNotVisitor) VisitUnbound(pred UnboundPredicate) BooleanExpression { + return pred +} + +func (rewriteNotVisitor) VisitBound(pred BoundPredicate) BooleanExpression { + return pred +} diff --git a/visitors_test.go b/visitors_test.go index cd93a60..b1f0342 100644 --- a/visitors_test.go +++ b/visitors_test.go @@ -1,607 +1,607 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package iceberg_test - -import ( - "math" - "strings" - "testing" - - "github.com/apache/arrow-go/v18/arrow/decimal128" - "github.com/apache/iceberg-go" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type ExampleVisitor struct { - visitHistory []string -} - -func (e *ExampleVisitor) VisitTrue() []string { - e.visitHistory = append(e.visitHistory, "TRUE") - return e.visitHistory -} - -func (e *ExampleVisitor) VisitFalse() []string { - e.visitHistory = append(e.visitHistory, "FALSE") - return e.visitHistory -} - -func (e *ExampleVisitor) VisitNot([]string) []string { - e.visitHistory = append(e.visitHistory, "NOT") - return e.visitHistory -} - -func (e *ExampleVisitor) VisitAnd(_, _ []string) []string { - e.visitHistory = append(e.visitHistory, "AND") - return e.visitHistory -} - -func (e *ExampleVisitor) VisitOr(_, _ []string) []string { - e.visitHistory = append(e.visitHistory, "OR") - return e.visitHistory -} - -func (e *ExampleVisitor) VisitUnbound(pred iceberg.UnboundPredicate) []string { - e.visitHistory = append(e.visitHistory, strings.ToUpper(pred.Op().String())) - return e.visitHistory -} - -func (e *ExampleVisitor) VisitBound(pred iceberg.BoundPredicate) []string { - e.visitHistory = append(e.visitHistory, strings.ToUpper(pred.Op().String())) - return e.visitHistory -} - -type FooBoundExprVisitor struct { - ExampleVisitor -} - -func (e *FooBoundExprVisitor) VisitBound(pred iceberg.BoundPredicate) []string { - return iceberg.VisitBoundPredicate(pred, e) -} - -func (e *FooBoundExprVisitor) VisitUnbound(pred iceberg.UnboundPredicate) []string { - panic("found unbound predicate when evaluating") -} - -func (e *FooBoundExprVisitor) VisitIn(iceberg.BoundTerm, iceberg.Set[iceberg.Literal]) []string { - e.visitHistory = append(e.visitHistory, "IN") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitNotIn(iceberg.BoundTerm, iceberg.Set[iceberg.Literal]) []string { - e.visitHistory = 
append(e.visitHistory, "NOT_IN") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitIsNan(iceberg.BoundTerm) []string { - e.visitHistory = append(e.visitHistory, "IS_NAN") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitNotNan(iceberg.BoundTerm) []string { - e.visitHistory = append(e.visitHistory, "NOT_NAN") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitIsNull(iceberg.BoundTerm) []string { - e.visitHistory = append(e.visitHistory, "IS_NULL") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitNotNull(iceberg.BoundTerm) []string { - e.visitHistory = append(e.visitHistory, "NOT_NULL") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitEqual(iceberg.BoundTerm, iceberg.Literal) []string { - e.visitHistory = append(e.visitHistory, "EQUAL") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitNotEqual(iceberg.BoundTerm, iceberg.Literal) []string { - e.visitHistory = append(e.visitHistory, "NOT_EQUAL") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitGreaterEqual(iceberg.BoundTerm, iceberg.Literal) []string { - e.visitHistory = append(e.visitHistory, "GREATER_THAN_OR_EQUAL") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitGreater(iceberg.BoundTerm, iceberg.Literal) []string { - e.visitHistory = append(e.visitHistory, "GREATER_THAN") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitLessEqual(iceberg.BoundTerm, iceberg.Literal) []string { - e.visitHistory = append(e.visitHistory, "LESS_THAN_OR_EQUAL") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitLess(iceberg.BoundTerm, iceberg.Literal) []string { - e.visitHistory = append(e.visitHistory, "LESS_THAN") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) VisitStartsWith(iceberg.BoundTerm, iceberg.Literal) []string { - e.visitHistory = append(e.visitHistory, "STARTS_WITH") - return e.visitHistory -} - -func (e *FooBoundExprVisitor) 
VisitNotStartsWith(iceberg.BoundTerm, iceberg.Literal) []string { - e.visitHistory = append(e.visitHistory, "NOT_STARTS_WITH") - return e.visitHistory -} - -func TestBooleanExprVisitor(t *testing.T) { - expr := iceberg.NewAnd( - iceberg.NewOr( - iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("a"), int32(1))), - iceberg.NewNot(iceberg.NotEqualTo(iceberg.Reference("b"), int32(0))), - iceberg.EqualTo(iceberg.Reference("a"), int32(1)), - iceberg.NotEqualTo(iceberg.Reference("b"), int32(0)), - ), - iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("a"), int32(1))), - iceberg.NotEqualTo(iceberg.Reference("b"), int32(0))) - - visitor := ExampleVisitor{visitHistory: make([]string, 0)} - result, err := iceberg.VisitExpr(expr, &visitor) - require.NoError(t, err) - assert.Equal(t, []string{ - "EQUAL", - "NOT", - "NOTEQUAL", - "NOT", - "OR", - "EQUAL", - "OR", - "NOTEQUAL", - "OR", - "EQUAL", - "NOT", - "AND", - "NOTEQUAL", - "AND", - }, result) -} - -func TestBindVisitorAlready(t *testing.T) { - bound, err := iceberg.EqualTo(iceberg.Reference("foo"), "hello"). 
- Bind(tableSchemaSimple, false) - require.NoError(t, err) - - _, err = iceberg.BindExpr(tableSchemaSimple, bound, true) - assert.ErrorIs(t, err, iceberg.ErrInvalidArgument) - assert.ErrorContains(t, err, "found already bound predicate: BoundEqual(term=BoundReference(field=1: foo: optional string, accessor=Accessor(position=0, inner=)), literal=hello)") -} - -func TestAlwaysExprBinding(t *testing.T) { - tests := []struct { - expr iceberg.BooleanExpression - expected iceberg.BooleanExpression - }{ - {iceberg.AlwaysTrue{}, iceberg.AlwaysTrue{}}, - {iceberg.AlwaysFalse{}, iceberg.AlwaysFalse{}}, - {iceberg.NewAnd(iceberg.AlwaysTrue{}, iceberg.AlwaysFalse{}), iceberg.AlwaysFalse{}}, - {iceberg.NewOr(iceberg.AlwaysTrue{}, iceberg.AlwaysFalse{}), iceberg.AlwaysTrue{}}, - } - - for _, tt := range tests { - t.Run(tt.expr.String(), func(t *testing.T) { - bound, err := iceberg.BindExpr(tableSchemaSimple, tt.expr, true) - require.NoError(t, err) - assert.Equal(t, tt.expected, bound) - }) - } -} - -func TestBoundBoolExprVisitor(t *testing.T) { - tests := []struct { - expr iceberg.BooleanExpression - expected []string - }{ - {iceberg.NewAnd(iceberg.IsIn(iceberg.Reference("foo"), "foo", "bar"), - iceberg.IsIn(iceberg.Reference("bar"), int32(1), int32(2))), []string{"IN", "IN", "AND"}}, - {iceberg.NewOr(iceberg.NewNot(iceberg.IsIn(iceberg.Reference("foo"), "foo", "bar")), - iceberg.NewNot(iceberg.IsIn(iceberg.Reference("bar"), int32(1), int32(2)))), - []string{"IN", "NOT", "IN", "NOT", "OR"}}, - {iceberg.EqualTo(iceberg.Reference("bar"), int32(1)), []string{"EQUAL"}}, - {iceberg.NotEqualTo(iceberg.Reference("foo"), "foo"), []string{"NOT_EQUAL"}}, - {iceberg.AlwaysTrue{}, []string{"TRUE"}}, - {iceberg.AlwaysFalse{}, []string{"FALSE"}}, - {iceberg.NotIn(iceberg.Reference("foo"), "bar", "foo"), []string{"NOT_IN"}}, - {iceberg.IsNull(iceberg.Reference("foo")), []string{"IS_NULL"}}, - {iceberg.NotNull(iceberg.Reference("foo")), []string{"NOT_NULL"}}, - 
{iceberg.GreaterThan(iceberg.Reference("foo"), "foo"), []string{"GREATER_THAN"}}, - {iceberg.GreaterThanEqual(iceberg.Reference("foo"), "foo"), []string{"GREATER_THAN_OR_EQUAL"}}, - {iceberg.LessThan(iceberg.Reference("foo"), "foo"), []string{"LESS_THAN"}}, - {iceberg.LessThanEqual(iceberg.Reference("foo"), "foo"), []string{"LESS_THAN_OR_EQUAL"}}, - {iceberg.StartsWith(iceberg.Reference("foo"), "foo"), []string{"STARTS_WITH"}}, - {iceberg.NotStartsWith(iceberg.Reference("foo"), "foo"), []string{"NOT_STARTS_WITH"}}, - } - - for _, tt := range tests { - t.Run(tt.expr.String(), func(t *testing.T) { - bound, err := iceberg.BindExpr(tableSchemaNested, - tt.expr, - true) - require.NoError(t, err) - - visitor := FooBoundExprVisitor{ExampleVisitor: ExampleVisitor{visitHistory: []string{}}} - result, err := iceberg.VisitExpr(bound, &visitor) - require.NoError(t, err) - assert.Equal(t, tt.expected, result) - }) - } -} - -type rowTester []any - -func (r rowTester) Size() int { return len(r) } -func (r rowTester) Get(pos int) any { return r[pos] } -func (r rowTester) Set(pos int, val any) { - r[pos] = val -} - -func rowOf(vals ...any) rowTester { - return rowTester(vals) -} - -var testSchema = iceberg.NewSchema(1, - iceberg.NestedField{ID: 13, Name: "x", - Type: iceberg.PrimitiveTypes.Int32, Required: true}, - iceberg.NestedField{ID: 14, Name: "y", - Type: iceberg.PrimitiveTypes.Float64, Required: true}, - iceberg.NestedField{ID: 15, Name: "z", - Type: iceberg.PrimitiveTypes.Int32}, - iceberg.NestedField{ID: 16, Name: "s1", - Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{{ - ID: 17, Name: "s2", Required: true, - Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{{ - ID: 18, Name: "s3", Required: true, - Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{{ - ID: 19, Name: "s4", Required: true, - Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{{ - ID: 20, Name: "i", Required: true, - Type: iceberg.PrimitiveTypes.Int32, - }}, - 
}, - }}, - }, - }}, - }, - }}, - }}, - iceberg.NestedField{ID: 21, Name: "s5", Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{{ - ID: 22, Name: "s6", Required: true, Type: &iceberg.StructType{ - FieldList: []iceberg.NestedField{{ - ID: 23, Name: "f", Required: true, Type: iceberg.PrimitiveTypes.Float32, - }}, - }, - }}, - }}, - iceberg.NestedField{ID: 24, Name: "s", Type: iceberg.PrimitiveTypes.String}) - -func TestExprEvaluator(t *testing.T) { - type testCase struct { - str string - row rowTester - result bool - } - - tests := []struct { - exp iceberg.BooleanExpression - cases []testCase - }{ - {iceberg.AlwaysTrue{}, []testCase{{"always true", rowOf(), true}}}, - {iceberg.AlwaysFalse{}, []testCase{{"always false", rowOf(), false}}}, - {iceberg.LessThan(iceberg.Reference("x"), int32(7)), []testCase{ - {"7 < 7 => false", rowOf(7, 8, nil, nil), false}, - {"6 < 7 => true", rowOf(6, 8, nil, nil), true}, - }}, - {iceberg.LessThan(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ - {"7 < 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false}, - {"6 < 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true}, - {"nil < 7 => true", rowOf(7, 8, nil, nil), true}, - }}, - {iceberg.LessThanEqual(iceberg.Reference("x"), int32(7)), []testCase{ - {"7 <= 7 => true", rowOf(7, 8, nil), true}, - {"6 <= 7 => true", rowOf(6, 8, nil), true}, - {"8 <= 7 => false", rowOf(8, 8, nil), false}, - }}, - {iceberg.LessThanEqual(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ - {"7 <= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true}, - {"6 <= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true}, - {"8 <= 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(8))))), false}, - }}, - {iceberg.GreaterThan(iceberg.Reference("x"), int32(7)), []testCase{ - {"7 > 7 => false", rowOf(7, 8, nil), false}, - {"6 > 7 => false", rowOf(6, 8, nil), false}, - {"8 > 7 => true", rowOf(8, 8, nil), true}, - }}, - 
{iceberg.GreaterThan(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ - {"7 > 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false}, - {"6 > 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false}, - {"8 > 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(8))))), true}, - }}, - {iceberg.GreaterThanEqual(iceberg.Reference("x"), int32(7)), []testCase{ - {"7 >= 7 => true", rowOf(7, 8, nil), true}, - {"6 >= 7 => false", rowOf(6, 8, nil), false}, - {"8 >= 7 => true", rowOf(8, 8, nil), true}, - }}, - {iceberg.GreaterThanEqual(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ - {"7 >= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true}, - {"6 >= 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false}, - {"8 >= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(8))))), true}, - }}, - {iceberg.EqualTo(iceberg.Reference("x"), int32(7)), []testCase{ - {"7 == 7 => true", rowOf(7, 8, nil), true}, - {"6 == 7 => false", rowOf(6, 8, nil), false}, - }}, - {iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ - {"7 == 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true}, - {"6 == 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false}, - }}, - {iceberg.NotEqualTo(iceberg.Reference("x"), int32(7)), []testCase{ - {"7 != 7 => false", rowOf(7, 8, nil), false}, - {"6 != 7 => true", rowOf(6, 8, nil), true}, - }}, - {iceberg.NotEqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ - {"7 != 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false}, - {"6 != 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true}, - }}, - {iceberg.IsNull(iceberg.Reference("z")), []testCase{ - {"nil is null", rowOf(1, 2, nil), true}, - {"3 is not null", rowOf(1, 2, 3), false}, - }}, - {iceberg.IsNull(iceberg.Reference("s1.s2.s3.s4.i")), []testCase{ - {"3 is not null", rowOf(1, 2, 3, rowOf(rowOf(rowOf(rowOf(3))))), false}, - }}, - 
{iceberg.NotNull(iceberg.Reference("z")), []testCase{ - {"nil is null", rowOf(1, 2, nil), false}, - {"3 is not null", rowOf(1, 2, 3), true}, - }}, - {iceberg.NotNull(iceberg.Reference("s1.s2.s3.s4.i")), []testCase{ - {"3 is not null", rowOf(1, 2, 3, rowOf(rowOf(rowOf(rowOf(3))))), true}, - }}, - {iceberg.IsNaN(iceberg.Reference("y")), []testCase{ - {"NaN is NaN", rowOf(1, math.NaN(), 3), true}, - {"2 is not NaN", rowOf(1, 2.0, 3), false}, - }}, - {iceberg.IsNaN(iceberg.Reference("s5.s6.f")), []testCase{ - {"NaN is NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(math.NaN()))), true}, - {"4 is not NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(4.0))), false}, - {"nil is not NaN", rowOf(1, 2, 3, nil, nil), false}, - }}, - {iceberg.NotNaN(iceberg.Reference("y")), []testCase{ - {"NaN is NaN", rowOf(1, math.NaN(), 3), false}, - {"2 is not NaN", rowOf(1, 2.0, 3), true}, - }}, - {iceberg.NotNaN(iceberg.Reference("s5.s6.f")), []testCase{ - {"NaN is NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(math.NaN()))), false}, - {"4 is not NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(4.0))), true}, - }}, - {iceberg.NewAnd(iceberg.EqualTo(iceberg.Reference("x"), int32(7)), iceberg.NotNull(iceberg.Reference("z"))), []testCase{ - {"7, 3 => true", rowOf(7, 0, 3), true}, - {"8, 3 => false", rowOf(8, 0, 3), false}, - {"7, null => false", rowOf(7, 0, nil), false}, - {"8, null => false", rowOf(8, 0, nil), false}, - }}, - {iceberg.NewAnd(iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), - iceberg.NotNull(iceberg.Reference("s1.s2.s3.s4.i"))), []testCase{ - {"7, 7 => true", rowOf(5, 0, 3, rowOf(rowOf(rowOf(rowOf(7))))), true}, - {"8, 8 => false", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), false}, - {"7, null => false", rowOf(5, 0, 3, nil), false}, - {"8, notnull => false", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), false}, - }}, - {iceberg.NewOr(iceberg.EqualTo(iceberg.Reference("x"), int32(7)), iceberg.NotNull(iceberg.Reference("z"))), []testCase{ - {"7, 3 => true", rowOf(7, 0, 3), true}, - {"8, 3 => true", 
rowOf(8, 0, 3), true}, - {"7, null => true", rowOf(7, 0, nil), true}, - {"8, null => false", rowOf(8, 0, nil), false}, - }}, - {iceberg.NewOr(iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), - iceberg.NotNull(iceberg.Reference("s1.s2.s3.s4.i"))), []testCase{ - {"7, 7 => true", rowOf(5, 0, 3, rowOf(rowOf(rowOf(rowOf(7))))), true}, - {"8, notnull => true", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), true}, - {"7, null => false", rowOf(5, 0, 3, nil), false}, - {"8, notnull => true", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), true}, - }}, - {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("x"), int32(7))), []testCase{ - {"not(7 == 7) => false", rowOf(7), false}, - {"not(8 == 7) => true", rowOf(8), true}, - }}, - {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7))), []testCase{ - {"not(7 == 7) => false", rowOf(7, nil, nil, rowOf(rowOf(rowOf(rowOf(7))))), false}, - {"not(8 == 7) => true", rowOf(7, nil, nil, rowOf(rowOf(rowOf(rowOf(8))))), true}, - }}, - {iceberg.IsIn(iceberg.Reference("x"), int64(7), 8, math.MaxInt64), []testCase{ - {"7 in [7, 8, Int64Max] => true", rowOf(7, 8, nil), true}, - {"9 in [7, 8, Int64Max] => false", rowOf(9, 8, nil), false}, - {"8 in [7, 8, Int64Max] => true", rowOf(8, 8, nil), true}, - }}, - {iceberg.IsIn(iceberg.Reference("x"), int64(math.MaxInt64), math.MaxInt32, math.MinInt64), []testCase{ - {"Int32Max in [Int64Max, Int32Max, Int64Min] => true", rowOf(math.MaxInt32, 7.0, nil), true}, - {"6 in [Int64Max, Int32Max, Int64Min] => false", rowOf(6, 6.9, nil), false}, - }}, - {iceberg.IsIn(iceberg.Reference("y"), float64(7), 8, 9.1), []testCase{ - {"7.0 in [7, 8, 9.1] => true", rowOf(0, 7.0, nil), true}, - {"9.1 in [7, 8, 9.1] => true", rowOf(7, 9.1, nil), true}, - {"6.8 in [7, 8, 9.1] => false", rowOf(7, 6.8, nil), false}, - }}, - {iceberg.IsIn(iceberg.Reference("s1.s2.s3.s4.i"), int32(7), 8, 9), []testCase{ - {"7 in [7, 8, 9] => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true}, - {"6 
in [7, 8, 9] => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false}, - {"nil in [7, 8, 9] => false", rowOf(7, 8, nil, nil), false}, - }}, - {iceberg.NotIn(iceberg.Reference("x"), int64(7), 8, math.MaxInt64), []testCase{ - {"7 not in [7, 8, Int64Max] => false", rowOf(7, 8, nil), false}, - {"9 not in [7, 8, Int64Max] => true", rowOf(9, 8, nil), true}, - {"8 not in [7, 8, Int64Max] => false", rowOf(8, 8, nil), false}, - }}, - {iceberg.NotIn(iceberg.Reference("x"), int64(math.MaxInt64), math.MaxInt32, math.MinInt64), []testCase{ - {"Int32Max not in [Int64Max, Int32Max, Int64Min] => false", rowOf(math.MaxInt32, 7.0, nil), false}, - {"6 not in [Int64Max, Int32Max, Int64Min] => true", rowOf(6, 6.9, nil), true}, - }}, - {iceberg.NotIn(iceberg.Reference("y"), float64(7), 8, 9.1), []testCase{ - {"7.0 not in [7, 8, 9.1] => false", rowOf(0, 7.0, nil), false}, - {"9.1 not in [7, 8, 9.1] => false", rowOf(7, 9.1, nil), false}, - {"6.8 not in [7, 8, 9.1] => true", rowOf(7, 6.8, nil), true}, - }}, - {iceberg.NotIn(iceberg.Reference("s1.s2.s3.s4.i"), int32(7), 8, 9), []testCase{ - {"7 not in [7, 8, 9] => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false}, - {"6 not in [7, 8, 9] => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true}, - }}, - {iceberg.EqualTo(iceberg.Reference("s"), "abc"), []testCase{ - {"abc == abc => true", rowOf(1, 2, nil, nil, nil, "abc"), true}, - {"abd == abc => false", rowOf(1, 2, nil, nil, nil, "abd"), false}, - }}, - {iceberg.StartsWith(iceberg.Reference("s"), "abc"), []testCase{ - {"abc startsWith abc => true", rowOf(1, 2, nil, nil, nil, "abc"), true}, - {"xabc startsWith abc => false", rowOf(1, 2, nil, nil, nil, "xabc"), false}, - {"Abc startsWith abc => false", rowOf(1, 2, nil, nil, nil, "Abc"), false}, - {"a startsWith abc => false", rowOf(1, 2, nil, nil, nil, "a"), false}, - {"abcd startsWith abc => true", rowOf(1, 2, nil, nil, nil, "abcd"), true}, - {"nil startsWith abc => false", rowOf(1, 2, nil, nil, nil, nil), 
false}, - }}, - {iceberg.NotStartsWith(iceberg.Reference("s"), "abc"), []testCase{ - {"abc not startsWith abc => false", rowOf(1, 2, nil, nil, nil, "abc"), false}, - {"xabc not startsWith abc => true", rowOf(1, 2, nil, nil, nil, "xabc"), true}, - {"Abc not startsWith abc => true", rowOf(1, 2, nil, nil, nil, "Abc"), true}, - {"a not startsWith abc => true", rowOf(1, 2, nil, nil, nil, "a"), true}, - {"abcd not startsWith abc => false", rowOf(1, 2, nil, nil, nil, "abcd"), false}, - {"nil not startsWith abc => true", rowOf(1, 2, nil, nil, nil, nil), true}, - }}, - } - - for _, tt := range tests { - t.Run(tt.exp.String(), func(t *testing.T) { - ev, err := iceberg.ExpressionEvaluator(testSchema, tt.exp, true) - require.NoError(t, err) - - for _, c := range tt.cases { - res, err := ev(c.row) - require.NoError(t, err) - - assert.Equal(t, c.result, res, c.str) - } - }) - } -} - -func TestEvaluatorCmpTypes(t *testing.T) { - sc := iceberg.NewSchema(1, - iceberg.NestedField{ID: 1, Name: "a", Type: iceberg.PrimitiveTypes.Bool}, - iceberg.NestedField{ID: 2, Name: "b", Type: iceberg.PrimitiveTypes.Int32}, - iceberg.NestedField{ID: 3, Name: "c", Type: iceberg.PrimitiveTypes.Int64}, - iceberg.NestedField{ID: 4, Name: "d", Type: iceberg.PrimitiveTypes.Float32}, - iceberg.NestedField{ID: 5, Name: "e", Type: iceberg.PrimitiveTypes.Float64}, - iceberg.NestedField{ID: 6, Name: "f", Type: iceberg.PrimitiveTypes.Date}, - iceberg.NestedField{ID: 7, Name: "g", Type: iceberg.PrimitiveTypes.Time}, - iceberg.NestedField{ID: 8, Name: "h", Type: iceberg.PrimitiveTypes.Timestamp}, - iceberg.NestedField{ID: 9, Name: "i", Type: iceberg.DecimalTypeOf(9, 2)}, - iceberg.NestedField{ID: 10, Name: "j", Type: iceberg.PrimitiveTypes.String}, - iceberg.NestedField{ID: 11, Name: "k", Type: iceberg.PrimitiveTypes.Binary}, - iceberg.NestedField{ID: 12, Name: "l", Type: iceberg.PrimitiveTypes.UUID}, - iceberg.NestedField{ID: 13, Name: "m", Type: iceberg.FixedTypeOf(5)}) - - rowData := rowOf(true, - 5, 5, 
float32(5.0), float64(5.0), - 29, 51661919000, 1503066061919234, - iceberg.Decimal{Scale: 2, Val: decimal128.FromI64(3456)}, - "abcdef", []byte{0x01, 0x02, 0x03}, - uuid.New(), []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x0}) - - tests := []struct { - ref iceberg.BooleanExpression - exp bool - }{ - {iceberg.EqualTo(iceberg.Reference("a"), true), true}, - {iceberg.EqualTo(iceberg.Reference("a"), false), false}, - {iceberg.EqualTo(iceberg.Reference("c"), int64(5)), true}, - {iceberg.EqualTo(iceberg.Reference("c"), int64(6)), false}, - {iceberg.EqualTo(iceberg.Reference("d"), int64(5)), true}, - {iceberg.EqualTo(iceberg.Reference("d"), int64(6)), false}, - {iceberg.EqualTo(iceberg.Reference("e"), int64(5)), true}, - {iceberg.EqualTo(iceberg.Reference("e"), int64(6)), false}, - {iceberg.EqualTo(iceberg.Reference("f"), "1970-01-30"), true}, - {iceberg.EqualTo(iceberg.Reference("f"), "1970-01-31"), false}, - {iceberg.EqualTo(iceberg.Reference("g"), "14:21:01.919"), true}, - {iceberg.EqualTo(iceberg.Reference("g"), "14:21:02.919"), false}, - {iceberg.EqualTo(iceberg.Reference("h"), "2017-08-18T14:21:01.919234"), true}, - {iceberg.EqualTo(iceberg.Reference("h"), "2017-08-19T14:21:01.919234"), false}, - {iceberg.LessThan(iceberg.Reference("i"), "32.22"), false}, - {iceberg.GreaterThan(iceberg.Reference("i"), "32.22"), true}, - {iceberg.LessThanEqual(iceberg.Reference("j"), "abcd"), false}, - {iceberg.GreaterThan(iceberg.Reference("j"), "abcde"), true}, - {iceberg.GreaterThan(iceberg.Reference("k"), []byte{0x00}), true}, - {iceberg.LessThan(iceberg.Reference("k"), []byte{0x00}), false}, - {iceberg.EqualTo(iceberg.Reference("l"), uuid.New().String()), false}, - {iceberg.EqualTo(iceberg.Reference("l"), rowData[11].(uuid.UUID)), true}, - {iceberg.EqualTo(iceberg.Reference("m"), []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x1}), false}, - {iceberg.EqualTo(iceberg.Reference("m"), []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x0}), true}, - } - - for _, tt := range tests { - t.Run(tt.ref.String(), func(t *testing.T) { 
- ev, err := iceberg.ExpressionEvaluator(sc, tt.ref, true) - require.NoError(t, err) - - res, err := ev(rowData) - require.NoError(t, err) - assert.Equal(t, tt.exp, res) - }) - } -} - -func TestRewriteNot(t *testing.T) { - tests := []struct { - expr, expected iceberg.BooleanExpression - }{ - {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("x"), 34.56)), - iceberg.NotEqualTo(iceberg.Reference("x"), 34.56)}, - {iceberg.NewNot(iceberg.NotEqualTo(iceberg.Reference("x"), 34.56)), - iceberg.EqualTo(iceberg.Reference("x"), 34.56)}, - {iceberg.NewNot(iceberg.IsIn(iceberg.Reference("x"), 34.56, 23.45)), - iceberg.NotIn(iceberg.Reference("x"), 34.56, 23.45)}, - {iceberg.NewNot(iceberg.NewAnd( - iceberg.EqualTo(iceberg.Reference("x"), 34.56), iceberg.EqualTo(iceberg.Reference("y"), 34.56))), - iceberg.NewOr( - iceberg.NotEqualTo(iceberg.Reference("x"), 34.56), iceberg.NotEqualTo(iceberg.Reference("y"), 34.56))}, - {iceberg.NewNot(iceberg.NewOr( - iceberg.EqualTo(iceberg.Reference("x"), 34.56), iceberg.EqualTo(iceberg.Reference("y"), 34.56))), - iceberg.NewAnd(iceberg.NotEqualTo(iceberg.Reference("x"), 34.56), iceberg.NotEqualTo(iceberg.Reference("y"), 34.56))}, - {iceberg.NewNot(iceberg.AlwaysFalse{}), iceberg.AlwaysTrue{}}, - {iceberg.NewNot(iceberg.AlwaysTrue{}), iceberg.AlwaysFalse{}}, - } - - for _, tt := range tests { - t.Run(tt.expr.String(), func(t *testing.T) { - out, err := iceberg.RewriteNotExpr(tt.expr) - require.NoError(t, err) - assert.True(t, out.Equals(tt.expected)) - }) - } -} +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package iceberg_test + +import ( + "math" + "strings" + "testing" + + "github.com/apache/arrow-go/v18/arrow/decimal128" + "github.com/apache/iceberg-go" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type ExampleVisitor struct { + visitHistory []string +} + +func (e *ExampleVisitor) VisitTrue() []string { + e.visitHistory = append(e.visitHistory, "TRUE") + return e.visitHistory +} + +func (e *ExampleVisitor) VisitFalse() []string { + e.visitHistory = append(e.visitHistory, "FALSE") + return e.visitHistory +} + +func (e *ExampleVisitor) VisitNot([]string) []string { + e.visitHistory = append(e.visitHistory, "NOT") + return e.visitHistory +} + +func (e *ExampleVisitor) VisitAnd(_, _ []string) []string { + e.visitHistory = append(e.visitHistory, "AND") + return e.visitHistory +} + +func (e *ExampleVisitor) VisitOr(_, _ []string) []string { + e.visitHistory = append(e.visitHistory, "OR") + return e.visitHistory +} + +func (e *ExampleVisitor) VisitUnbound(pred iceberg.UnboundPredicate) []string { + e.visitHistory = append(e.visitHistory, strings.ToUpper(pred.Op().String())) + return e.visitHistory +} + +func (e *ExampleVisitor) VisitBound(pred iceberg.BoundPredicate) []string { + e.visitHistory = append(e.visitHistory, strings.ToUpper(pred.Op().String())) + return e.visitHistory +} + +type FooBoundExprVisitor struct { + ExampleVisitor +} + +func (e *FooBoundExprVisitor) VisitBound(pred iceberg.BoundPredicate) []string { + return iceberg.VisitBoundPredicate(pred, e) +} + +func (e 
*FooBoundExprVisitor) VisitUnbound(pred iceberg.UnboundPredicate) []string { + panic("found unbound predicate when evaluating") +} + +func (e *FooBoundExprVisitor) VisitIn(iceberg.BoundTerm, iceberg.Set[iceberg.Literal]) []string { + e.visitHistory = append(e.visitHistory, "IN") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitNotIn(iceberg.BoundTerm, iceberg.Set[iceberg.Literal]) []string { + e.visitHistory = append(e.visitHistory, "NOT_IN") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitIsNan(iceberg.BoundTerm) []string { + e.visitHistory = append(e.visitHistory, "IS_NAN") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitNotNan(iceberg.BoundTerm) []string { + e.visitHistory = append(e.visitHistory, "NOT_NAN") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitIsNull(iceberg.BoundTerm) []string { + e.visitHistory = append(e.visitHistory, "IS_NULL") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitNotNull(iceberg.BoundTerm) []string { + e.visitHistory = append(e.visitHistory, "NOT_NULL") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitEqual(iceberg.BoundTerm, iceberg.Literal) []string { + e.visitHistory = append(e.visitHistory, "EQUAL") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitNotEqual(iceberg.BoundTerm, iceberg.Literal) []string { + e.visitHistory = append(e.visitHistory, "NOT_EQUAL") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitGreaterEqual(iceberg.BoundTerm, iceberg.Literal) []string { + e.visitHistory = append(e.visitHistory, "GREATER_THAN_OR_EQUAL") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitGreater(iceberg.BoundTerm, iceberg.Literal) []string { + e.visitHistory = append(e.visitHistory, "GREATER_THAN") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitLessEqual(iceberg.BoundTerm, iceberg.Literal) []string { + e.visitHistory = append(e.visitHistory, "LESS_THAN_OR_EQUAL") + return 
e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitLess(iceberg.BoundTerm, iceberg.Literal) []string { + e.visitHistory = append(e.visitHistory, "LESS_THAN") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitStartsWith(iceberg.BoundTerm, iceberg.Literal) []string { + e.visitHistory = append(e.visitHistory, "STARTS_WITH") + return e.visitHistory +} + +func (e *FooBoundExprVisitor) VisitNotStartsWith(iceberg.BoundTerm, iceberg.Literal) []string { + e.visitHistory = append(e.visitHistory, "NOT_STARTS_WITH") + return e.visitHistory +} + +func TestBooleanExprVisitor(t *testing.T) { + expr := iceberg.NewAnd( + iceberg.NewOr( + iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("a"), int32(1))), + iceberg.NewNot(iceberg.NotEqualTo(iceberg.Reference("b"), int32(0))), + iceberg.EqualTo(iceberg.Reference("a"), int32(1)), + iceberg.NotEqualTo(iceberg.Reference("b"), int32(0)), + ), + iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("a"), int32(1))), + iceberg.NotEqualTo(iceberg.Reference("b"), int32(0))) + + visitor := ExampleVisitor{visitHistory: make([]string, 0)} + result, err := iceberg.VisitExpr(expr, &visitor) + require.NoError(t, err) + assert.Equal(t, []string{ + "EQUAL", + "NOT", + "NOTEQUAL", + "NOT", + "OR", + "EQUAL", + "OR", + "NOTEQUAL", + "OR", + "EQUAL", + "NOT", + "AND", + "NOTEQUAL", + "AND", + }, result) +} + +func TestBindVisitorAlready(t *testing.T) { + bound, err := iceberg.EqualTo(iceberg.Reference("foo"), "hello"). 
+ Bind(tableSchemaSimple, false) + require.NoError(t, err) + + _, err = iceberg.BindExpr(tableSchemaSimple, bound, true) + assert.ErrorIs(t, err, iceberg.ErrInvalidArgument) + assert.ErrorContains(t, err, "found already bound predicate: BoundEqual(term=BoundReference(field=1: foo: optional string, accessor=Accessor(position=0, inner=)), literal=hello)") +} + +func TestAlwaysExprBinding(t *testing.T) { + tests := []struct { + expr iceberg.BooleanExpression + expected iceberg.BooleanExpression + }{ + {iceberg.AlwaysTrue{}, iceberg.AlwaysTrue{}}, + {iceberg.AlwaysFalse{}, iceberg.AlwaysFalse{}}, + {iceberg.NewAnd(iceberg.AlwaysTrue{}, iceberg.AlwaysFalse{}), iceberg.AlwaysFalse{}}, + {iceberg.NewOr(iceberg.AlwaysTrue{}, iceberg.AlwaysFalse{}), iceberg.AlwaysTrue{}}, + } + + for _, tt := range tests { + t.Run(tt.expr.String(), func(t *testing.T) { + bound, err := iceberg.BindExpr(tableSchemaSimple, tt.expr, true) + require.NoError(t, err) + assert.Equal(t, tt.expected, bound) + }) + } +} + +func TestBoundBoolExprVisitor(t *testing.T) { + tests := []struct { + expr iceberg.BooleanExpression + expected []string + }{ + {iceberg.NewAnd(iceberg.IsIn(iceberg.Reference("foo"), "foo", "bar"), + iceberg.IsIn(iceberg.Reference("bar"), int32(1), int32(2))), []string{"IN", "IN", "AND"}}, + {iceberg.NewOr(iceberg.NewNot(iceberg.IsIn(iceberg.Reference("foo"), "foo", "bar")), + iceberg.NewNot(iceberg.IsIn(iceberg.Reference("bar"), int32(1), int32(2)))), + []string{"IN", "NOT", "IN", "NOT", "OR"}}, + {iceberg.EqualTo(iceberg.Reference("bar"), int32(1)), []string{"EQUAL"}}, + {iceberg.NotEqualTo(iceberg.Reference("foo"), "foo"), []string{"NOT_EQUAL"}}, + {iceberg.AlwaysTrue{}, []string{"TRUE"}}, + {iceberg.AlwaysFalse{}, []string{"FALSE"}}, + {iceberg.NotIn(iceberg.Reference("foo"), "bar", "foo"), []string{"NOT_IN"}}, + {iceberg.IsNull(iceberg.Reference("foo")), []string{"IS_NULL"}}, + {iceberg.NotNull(iceberg.Reference("foo")), []string{"NOT_NULL"}}, + 
{iceberg.GreaterThan(iceberg.Reference("foo"), "foo"), []string{"GREATER_THAN"}}, + {iceberg.GreaterThanEqual(iceberg.Reference("foo"), "foo"), []string{"GREATER_THAN_OR_EQUAL"}}, + {iceberg.LessThan(iceberg.Reference("foo"), "foo"), []string{"LESS_THAN"}}, + {iceberg.LessThanEqual(iceberg.Reference("foo"), "foo"), []string{"LESS_THAN_OR_EQUAL"}}, + {iceberg.StartsWith(iceberg.Reference("foo"), "foo"), []string{"STARTS_WITH"}}, + {iceberg.NotStartsWith(iceberg.Reference("foo"), "foo"), []string{"NOT_STARTS_WITH"}}, + } + + for _, tt := range tests { + t.Run(tt.expr.String(), func(t *testing.T) { + bound, err := iceberg.BindExpr(tableSchemaNested, + tt.expr, + true) + require.NoError(t, err) + + visitor := FooBoundExprVisitor{ExampleVisitor: ExampleVisitor{visitHistory: []string{}}} + result, err := iceberg.VisitExpr(bound, &visitor) + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} + +type rowTester []any + +func (r rowTester) Size() int { return len(r) } +func (r rowTester) Get(pos int) any { return r[pos] } +func (r rowTester) Set(pos int, val any) { + r[pos] = val +} + +func rowOf(vals ...any) rowTester { + return rowTester(vals) +} + +var testSchema = iceberg.NewSchema(1, + iceberg.NestedField{ID: 13, Name: "x", + Type: iceberg.PrimitiveTypes.Int32, Required: true}, + iceberg.NestedField{ID: 14, Name: "y", + Type: iceberg.PrimitiveTypes.Float64, Required: true}, + iceberg.NestedField{ID: 15, Name: "z", + Type: iceberg.PrimitiveTypes.Int32}, + iceberg.NestedField{ID: 16, Name: "s1", + Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{{ + ID: 17, Name: "s2", Required: true, + Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{{ + ID: 18, Name: "s3", Required: true, + Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{{ + ID: 19, Name: "s4", Required: true, + Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{{ + ID: 20, Name: "i", Required: true, + Type: iceberg.PrimitiveTypes.Int32, + }}, + 
}, + }}, + }, + }}, + }, + }}, + }}, + iceberg.NestedField{ID: 21, Name: "s5", Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{{ + ID: 22, Name: "s6", Required: true, Type: &iceberg.StructType{ + FieldList: []iceberg.NestedField{{ + ID: 23, Name: "f", Required: true, Type: iceberg.PrimitiveTypes.Float32, + }}, + }, + }}, + }}, + iceberg.NestedField{ID: 24, Name: "s", Type: iceberg.PrimitiveTypes.String}) + +func TestExprEvaluator(t *testing.T) { + type testCase struct { + str string + row rowTester + result bool + } + + tests := []struct { + exp iceberg.BooleanExpression + cases []testCase + }{ + {iceberg.AlwaysTrue{}, []testCase{{"always true", rowOf(), true}}}, + {iceberg.AlwaysFalse{}, []testCase{{"always false", rowOf(), false}}}, + {iceberg.LessThan(iceberg.Reference("x"), int32(7)), []testCase{ + {"7 < 7 => false", rowOf(7, 8, nil, nil), false}, + {"6 < 7 => true", rowOf(6, 8, nil, nil), true}, + }}, + {iceberg.LessThan(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ + {"7 < 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false}, + {"6 < 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true}, + {"nil < 7 => true", rowOf(7, 8, nil, nil), true}, + }}, + {iceberg.LessThanEqual(iceberg.Reference("x"), int32(7)), []testCase{ + {"7 <= 7 => true", rowOf(7, 8, nil), true}, + {"6 <= 7 => true", rowOf(6, 8, nil), true}, + {"8 <= 7 => false", rowOf(8, 8, nil), false}, + }}, + {iceberg.LessThanEqual(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ + {"7 <= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true}, + {"6 <= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true}, + {"8 <= 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(8))))), false}, + }}, + {iceberg.GreaterThan(iceberg.Reference("x"), int32(7)), []testCase{ + {"7 > 7 => false", rowOf(7, 8, nil), false}, + {"6 > 7 => false", rowOf(6, 8, nil), false}, + {"8 > 7 => true", rowOf(8, 8, nil), true}, + }}, + 
{iceberg.GreaterThan(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ + {"7 > 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false}, + {"6 > 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false}, + {"8 > 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(8))))), true}, + }}, + {iceberg.GreaterThanEqual(iceberg.Reference("x"), int32(7)), []testCase{ + {"7 >= 7 => true", rowOf(7, 8, nil), true}, + {"6 >= 7 => false", rowOf(6, 8, nil), false}, + {"8 >= 7 => true", rowOf(8, 8, nil), true}, + }}, + {iceberg.GreaterThanEqual(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ + {"7 >= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true}, + {"6 >= 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false}, + {"8 >= 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(8))))), true}, + }}, + {iceberg.EqualTo(iceberg.Reference("x"), int32(7)), []testCase{ + {"7 == 7 => true", rowOf(7, 8, nil), true}, + {"6 == 7 => false", rowOf(6, 8, nil), false}, + }}, + {iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ + {"7 == 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true}, + {"6 == 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false}, + }}, + {iceberg.NotEqualTo(iceberg.Reference("x"), int32(7)), []testCase{ + {"7 != 7 => false", rowOf(7, 8, nil), false}, + {"6 != 7 => true", rowOf(6, 8, nil), true}, + }}, + {iceberg.NotEqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), []testCase{ + {"7 != 7 => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false}, + {"6 != 7 => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true}, + }}, + {iceberg.IsNull(iceberg.Reference("z")), []testCase{ + {"nil is null", rowOf(1, 2, nil), true}, + {"3 is not null", rowOf(1, 2, 3), false}, + }}, + {iceberg.IsNull(iceberg.Reference("s1.s2.s3.s4.i")), []testCase{ + {"3 is not null", rowOf(1, 2, 3, rowOf(rowOf(rowOf(rowOf(3))))), false}, + }}, + 
{iceberg.NotNull(iceberg.Reference("z")), []testCase{ + {"nil is null", rowOf(1, 2, nil), false}, + {"3 is not null", rowOf(1, 2, 3), true}, + }}, + {iceberg.NotNull(iceberg.Reference("s1.s2.s3.s4.i")), []testCase{ + {"3 is not null", rowOf(1, 2, 3, rowOf(rowOf(rowOf(rowOf(3))))), true}, + }}, + {iceberg.IsNaN(iceberg.Reference("y")), []testCase{ + {"NaN is NaN", rowOf(1, math.NaN(), 3), true}, + {"2 is not NaN", rowOf(1, 2.0, 3), false}, + }}, + {iceberg.IsNaN(iceberg.Reference("s5.s6.f")), []testCase{ + {"NaN is NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(math.NaN()))), true}, + {"4 is not NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(4.0))), false}, + {"nil is not NaN", rowOf(1, 2, 3, nil, nil), false}, + }}, + {iceberg.NotNaN(iceberg.Reference("y")), []testCase{ + {"NaN is NaN", rowOf(1, math.NaN(), 3), false}, + {"2 is not NaN", rowOf(1, 2.0, 3), true}, + }}, + {iceberg.NotNaN(iceberg.Reference("s5.s6.f")), []testCase{ + {"NaN is NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(math.NaN()))), false}, + {"4 is not NaN", rowOf(1, 2, 3, nil, rowOf(rowOf(4.0))), true}, + }}, + {iceberg.NewAnd(iceberg.EqualTo(iceberg.Reference("x"), int32(7)), iceberg.NotNull(iceberg.Reference("z"))), []testCase{ + {"7, 3 => true", rowOf(7, 0, 3), true}, + {"8, 3 => false", rowOf(8, 0, 3), false}, + {"7, null => false", rowOf(7, 0, nil), false}, + {"8, null => false", rowOf(8, 0, nil), false}, + }}, + {iceberg.NewAnd(iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), + iceberg.NotNull(iceberg.Reference("s1.s2.s3.s4.i"))), []testCase{ + {"7, 7 => true", rowOf(5, 0, 3, rowOf(rowOf(rowOf(rowOf(7))))), true}, + {"8, 8 => false", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), false}, + {"7, null => false", rowOf(5, 0, 3, nil), false}, + {"8, notnull => false", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), false}, + }}, + {iceberg.NewOr(iceberg.EqualTo(iceberg.Reference("x"), int32(7)), iceberg.NotNull(iceberg.Reference("z"))), []testCase{ + {"7, 3 => true", rowOf(7, 0, 3), true}, + {"8, 3 => true", 
rowOf(8, 0, 3), true}, + {"7, null => true", rowOf(7, 0, nil), true}, + {"8, null => false", rowOf(8, 0, nil), false}, + }}, + {iceberg.NewOr(iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7)), + iceberg.NotNull(iceberg.Reference("s1.s2.s3.s4.i"))), []testCase{ + {"7, 7 => true", rowOf(5, 0, 3, rowOf(rowOf(rowOf(rowOf(7))))), true}, + {"8, notnull => true", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), true}, + {"7, null => false", rowOf(5, 0, 3, nil), false}, + {"8, notnull => true", rowOf(7, 0, 3, rowOf(rowOf(rowOf(rowOf(8))))), true}, + }}, + {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("x"), int32(7))), []testCase{ + {"not(7 == 7) => false", rowOf(7), false}, + {"not(8 == 7) => true", rowOf(8), true}, + }}, + {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("s1.s2.s3.s4.i"), int32(7))), []testCase{ + {"not(7 == 7) => false", rowOf(7, nil, nil, rowOf(rowOf(rowOf(rowOf(7))))), false}, + {"not(8 == 7) => true", rowOf(7, nil, nil, rowOf(rowOf(rowOf(rowOf(8))))), true}, + }}, + {iceberg.IsIn(iceberg.Reference("x"), int64(7), 8, math.MaxInt64), []testCase{ + {"7 in [7, 8, Int64Max] => true", rowOf(7, 8, nil), true}, + {"9 in [7, 8, Int64Max] => false", rowOf(9, 8, nil), false}, + {"8 in [7, 8, Int64Max] => true", rowOf(8, 8, nil), true}, + }}, + {iceberg.IsIn(iceberg.Reference("x"), int64(math.MaxInt64), math.MaxInt32, math.MinInt64), []testCase{ + {"Int32Max in [Int64Max, Int32Max, Int64Min] => true", rowOf(math.MaxInt32, 7.0, nil), true}, + {"6 in [Int64Max, Int32Max, Int64Min] => false", rowOf(6, 6.9, nil), false}, + }}, + {iceberg.IsIn(iceberg.Reference("y"), float64(7), 8, 9.1), []testCase{ + {"7.0 in [7, 8, 9.1] => true", rowOf(0, 7.0, nil), true}, + {"9.1 in [7, 8, 9.1] => true", rowOf(7, 9.1, nil), true}, + {"6.8 in [7, 8, 9.1] => false", rowOf(7, 6.8, nil), false}, + }}, + {iceberg.IsIn(iceberg.Reference("s1.s2.s3.s4.i"), int32(7), 8, 9), []testCase{ + {"7 in [7, 8, 9] => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), true}, + {"6 
in [7, 8, 9] => true", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), false}, + {"nil in [7, 8, 9] => false", rowOf(7, 8, nil, nil), false}, + }}, + {iceberg.NotIn(iceberg.Reference("x"), int64(7), 8, math.MaxInt64), []testCase{ + {"7 not in [7, 8, Int64Max] => false", rowOf(7, 8, nil), false}, + {"9 not in [7, 8, Int64Max] => true", rowOf(9, 8, nil), true}, + {"8 not in [7, 8, Int64Max] => false", rowOf(8, 8, nil), false}, + }}, + {iceberg.NotIn(iceberg.Reference("x"), int64(math.MaxInt64), math.MaxInt32, math.MinInt64), []testCase{ + {"Int32Max not in [Int64Max, Int32Max, Int64Min] => false", rowOf(math.MaxInt32, 7.0, nil), false}, + {"6 not in [Int64Max, Int32Max, Int64Min] => true", rowOf(6, 6.9, nil), true}, + }}, + {iceberg.NotIn(iceberg.Reference("y"), float64(7), 8, 9.1), []testCase{ + {"7.0 not in [7, 8, 9.1] => false", rowOf(0, 7.0, nil), false}, + {"9.1 not in [7, 8, 9.1] => false", rowOf(7, 9.1, nil), false}, + {"6.8 not in [7, 8, 9.1] => true", rowOf(7, 6.8, nil), true}, + }}, + {iceberg.NotIn(iceberg.Reference("s1.s2.s3.s4.i"), int32(7), 8, 9), []testCase{ + {"7 not in [7, 8, 9] => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(7))))), false}, + {"6 not in [7, 8, 9] => false", rowOf(7, 8, nil, rowOf(rowOf(rowOf(rowOf(6))))), true}, + }}, + {iceberg.EqualTo(iceberg.Reference("s"), "abc"), []testCase{ + {"abc == abc => true", rowOf(1, 2, nil, nil, nil, "abc"), true}, + {"abd == abc => false", rowOf(1, 2, nil, nil, nil, "abd"), false}, + }}, + {iceberg.StartsWith(iceberg.Reference("s"), "abc"), []testCase{ + {"abc startsWith abc => true", rowOf(1, 2, nil, nil, nil, "abc"), true}, + {"xabc startsWith abc => false", rowOf(1, 2, nil, nil, nil, "xabc"), false}, + {"Abc startsWith abc => false", rowOf(1, 2, nil, nil, nil, "Abc"), false}, + {"a startsWith abc => false", rowOf(1, 2, nil, nil, nil, "a"), false}, + {"abcd startsWith abc => true", rowOf(1, 2, nil, nil, nil, "abcd"), true}, + {"nil startsWith abc => false", rowOf(1, 2, nil, nil, nil, nil), 
false}, + }}, + {iceberg.NotStartsWith(iceberg.Reference("s"), "abc"), []testCase{ + {"abc not startsWith abc => false", rowOf(1, 2, nil, nil, nil, "abc"), false}, + {"xabc not startsWith abc => true", rowOf(1, 2, nil, nil, nil, "xabc"), true}, + {"Abc not startsWith abc => true", rowOf(1, 2, nil, nil, nil, "Abc"), true}, + {"a not startsWith abc => true", rowOf(1, 2, nil, nil, nil, "a"), true}, + {"abcd not startsWith abc => false", rowOf(1, 2, nil, nil, nil, "abcd"), false}, + {"nil not startsWith abc => true", rowOf(1, 2, nil, nil, nil, nil), true}, + }}, + } + + for _, tt := range tests { + t.Run(tt.exp.String(), func(t *testing.T) { + ev, err := iceberg.ExpressionEvaluator(testSchema, tt.exp, true) + require.NoError(t, err) + + for _, c := range tt.cases { + res, err := ev(c.row) + require.NoError(t, err) + + assert.Equal(t, c.result, res, c.str) + } + }) + } +} + +func TestEvaluatorCmpTypes(t *testing.T) { + sc := iceberg.NewSchema(1, + iceberg.NestedField{ID: 1, Name: "a", Type: iceberg.PrimitiveTypes.Bool}, + iceberg.NestedField{ID: 2, Name: "b", Type: iceberg.PrimitiveTypes.Int32}, + iceberg.NestedField{ID: 3, Name: "c", Type: iceberg.PrimitiveTypes.Int64}, + iceberg.NestedField{ID: 4, Name: "d", Type: iceberg.PrimitiveTypes.Float32}, + iceberg.NestedField{ID: 5, Name: "e", Type: iceberg.PrimitiveTypes.Float64}, + iceberg.NestedField{ID: 6, Name: "f", Type: iceberg.PrimitiveTypes.Date}, + iceberg.NestedField{ID: 7, Name: "g", Type: iceberg.PrimitiveTypes.Time}, + iceberg.NestedField{ID: 8, Name: "h", Type: iceberg.PrimitiveTypes.Timestamp}, + iceberg.NestedField{ID: 9, Name: "i", Type: iceberg.DecimalTypeOf(9, 2)}, + iceberg.NestedField{ID: 10, Name: "j", Type: iceberg.PrimitiveTypes.String}, + iceberg.NestedField{ID: 11, Name: "k", Type: iceberg.PrimitiveTypes.Binary}, + iceberg.NestedField{ID: 12, Name: "l", Type: iceberg.PrimitiveTypes.UUID}, + iceberg.NestedField{ID: 13, Name: "m", Type: iceberg.FixedTypeOf(5)}) + + rowData := rowOf(true, + 5, 5, 
float32(5.0), float64(5.0), + 29, 51661919000, 1503066061919234, + iceberg.Decimal{Scale: 2, Val: decimal128.FromI64(3456)}, + "abcdef", []byte{0x01, 0x02, 0x03}, + uuid.New(), []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x0}) + + tests := []struct { + ref iceberg.BooleanExpression + exp bool + }{ + {iceberg.EqualTo(iceberg.Reference("a"), true), true}, + {iceberg.EqualTo(iceberg.Reference("a"), false), false}, + {iceberg.EqualTo(iceberg.Reference("c"), int64(5)), true}, + {iceberg.EqualTo(iceberg.Reference("c"), int64(6)), false}, + {iceberg.EqualTo(iceberg.Reference("d"), int64(5)), true}, + {iceberg.EqualTo(iceberg.Reference("d"), int64(6)), false}, + {iceberg.EqualTo(iceberg.Reference("e"), int64(5)), true}, + {iceberg.EqualTo(iceberg.Reference("e"), int64(6)), false}, + {iceberg.EqualTo(iceberg.Reference("f"), "1970-01-30"), true}, + {iceberg.EqualTo(iceberg.Reference("f"), "1970-01-31"), false}, + {iceberg.EqualTo(iceberg.Reference("g"), "14:21:01.919"), true}, + {iceberg.EqualTo(iceberg.Reference("g"), "14:21:02.919"), false}, + {iceberg.EqualTo(iceberg.Reference("h"), "2017-08-18T14:21:01.919234"), true}, + {iceberg.EqualTo(iceberg.Reference("h"), "2017-08-19T14:21:01.919234"), false}, + {iceberg.LessThan(iceberg.Reference("i"), "32.22"), false}, + {iceberg.GreaterThan(iceberg.Reference("i"), "32.22"), true}, + {iceberg.LessThanEqual(iceberg.Reference("j"), "abcd"), false}, + {iceberg.GreaterThan(iceberg.Reference("j"), "abcde"), true}, + {iceberg.GreaterThan(iceberg.Reference("k"), []byte{0x00}), true}, + {iceberg.LessThan(iceberg.Reference("k"), []byte{0x00}), false}, + {iceberg.EqualTo(iceberg.Reference("l"), uuid.New().String()), false}, + {iceberg.EqualTo(iceberg.Reference("l"), rowData[11].(uuid.UUID)), true}, + {iceberg.EqualTo(iceberg.Reference("m"), []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x1}), false}, + {iceberg.EqualTo(iceberg.Reference("m"), []byte{0xDE, 0xAD, 0xBE, 0xEF, 0x0}), true}, + } + + for _, tt := range tests { + t.Run(tt.ref.String(), func(t *testing.T) { 
+ ev, err := iceberg.ExpressionEvaluator(sc, tt.ref, true) + require.NoError(t, err) + + res, err := ev(rowData) + require.NoError(t, err) + assert.Equal(t, tt.exp, res) + }) + } +} + +func TestRewriteNot(t *testing.T) { + tests := []struct { + expr, expected iceberg.BooleanExpression + }{ + {iceberg.NewNot(iceberg.EqualTo(iceberg.Reference("x"), 34.56)), + iceberg.NotEqualTo(iceberg.Reference("x"), 34.56)}, + {iceberg.NewNot(iceberg.NotEqualTo(iceberg.Reference("x"), 34.56)), + iceberg.EqualTo(iceberg.Reference("x"), 34.56)}, + {iceberg.NewNot(iceberg.IsIn(iceberg.Reference("x"), 34.56, 23.45)), + iceberg.NotIn(iceberg.Reference("x"), 34.56, 23.45)}, + {iceberg.NewNot(iceberg.NewAnd( + iceberg.EqualTo(iceberg.Reference("x"), 34.56), iceberg.EqualTo(iceberg.Reference("y"), 34.56))), + iceberg.NewOr( + iceberg.NotEqualTo(iceberg.Reference("x"), 34.56), iceberg.NotEqualTo(iceberg.Reference("y"), 34.56))}, + {iceberg.NewNot(iceberg.NewOr( + iceberg.EqualTo(iceberg.Reference("x"), 34.56), iceberg.EqualTo(iceberg.Reference("y"), 34.56))), + iceberg.NewAnd(iceberg.NotEqualTo(iceberg.Reference("x"), 34.56), iceberg.NotEqualTo(iceberg.Reference("y"), 34.56))}, + {iceberg.NewNot(iceberg.AlwaysFalse{}), iceberg.AlwaysTrue{}}, + {iceberg.NewNot(iceberg.AlwaysTrue{}), iceberg.AlwaysFalse{}}, + } + + for _, tt := range tests { + t.Run(tt.expr.String(), func(t *testing.T) { + out, err := iceberg.RewriteNotExpr(tt.expr) + require.NoError(t, err) + assert.True(t, out.Equals(tt.expected)) + }) + } +}