diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 191e1fd72e5..8c4aff4faa9 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1,4 @@ * @snowflakedb/snowpark-python-api-reviewers +/src/snowflake/snowpark/modin/ @snowflakedb/snowpandas +/tests/integ/modin/ @snowflakedb/snowpandas +/tests/unit/modin/ @snowflakedb/snowpandas diff --git a/.github/workflows/changedoc_snowpark_pandas.yml b/.github/workflows/changedoc_snowpark_pandas.yml new file mode 100644 index 00000000000..831232860de --- /dev/null +++ b/.github/workflows/changedoc_snowpark_pandas.yml @@ -0,0 +1,22 @@ +name: Snowpark pandas Changedoc Check + +on: + pull_request: + types: [opened, synchronize, labeled, unlabeled] + branches: + - pandas-main + paths: + - 'src/snowflake/snowpark/modin/**' + +jobs: + check_pandas_change_doc: + runs-on: ubuntu-latest + if: ${{!contains(github.event.pull_request.labels.*.name, 'NO-PANDAS-CHANGEDOC-UPDATES')}} + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Ensure Snowpark pandas docs is updated + run: git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -q "docs/source/modin" diff --git a/.github/workflows/changelog_snowpark_pandas.yml b/.github/workflows/changelog_snowpark_pandas.yml new file mode 100644 index 00000000000..236796774cc --- /dev/null +++ b/.github/workflows/changelog_snowpark_pandas.yml @@ -0,0 +1,22 @@ +name: Snowpark pandas Changelog Check + +on: + pull_request: + types: [opened, synchronize, labeled, unlabeled] + branches: + - pandas-main + paths: + - 'src/snowflake/snowpark/modin/**' + +jobs: + check_pandas_change_log: + runs-on: ubuntu-latest + if: ${{!contains(github.event.pull_request.labels.*.name, 'NO-PANDAS-CHANGELOG-UPDATES')}} + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Ensure PANDAS_CHANGELOG.md is updated + run: git diff --name-only --diff-filter=ACMRT ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | grep -wq "src/snowflake/snowpark/modin/plugin/PANDAS_CHANGELOG.md" diff --git a/.github/workflows/precommit.yml b/.github/workflows/precommit.yml index 6d77c8234ff..6c0b02f8fe2 100644 --- a/.github/workflows/precommit.yml +++ b/.github/workflows/precommit.yml @@ -302,6 +302,115 @@ jobs: .tox/.coverage .tox/coverage.xml + test-snowpark-pandas: + name: Test modin-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: build + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [macos-latest, windows-latest-64-cores, ubuntu-latest-64-cores] + python-version: [ "3.9", "3.10", "3.11" ] + cloud-provider: [aws, azure, gcp] + exclude: + # only run macos with aws py3.9 for doctest + - os: macos-latest + python-version: "3.10" + - os: macos-latest + python-version: "3.11" + - os: macos-latest + python-version: "3.9" + cloud-provider: azure + - os: macos-latest + python-version: "3.9" + cloud-provider: gcp + # only run ubuntu with py3.9 on aws and py3.10 on azure + - os: ubuntu-latest-64-cores + python-version: "3.11" + - os: ubuntu-latest-64-cores + python-version: "3.9" + cloud-provider: azure + - os: ubuntu-latest-64-cores + python-version: "3.9" + cloud-provider: gcp + - os: ubuntu-latest-64-cores + python-version: "3.10" + cloud-provider: aws + - os: ubuntu-latest-64-cores + python-version: "3.10" + cloud-provider: gcp + # only run windows with py3.10 on gcp + - os: windows-latest-64-cores + python-version: "3.9" + - os: windows-latest-64-cores + 
python-version: "3.10" + - os: windows-latest-64-cores + python-version: "3.11" + cloud-provider: aws + - os: windows-latest-64-cores + python-version: "3.11" + cloud-provider: azure + steps: + - name: Checkout Code + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Display Python version + run: python -c "import sys; print(sys.version)" + - name: Decrypt parameters.py + shell: bash + run: .github/scripts/decrypt_parameters.sh + env: + PARAMETER_PASSWORD: ${{ secrets.PARAMETER_PASSWORD }} + CLOUD_PROVIDER: ${{ matrix.cloud-provider }} + - name: Download wheel(s) + uses: actions/download-artifact@v4 + with: + name: wheel + path: dist + - name: Show wheels downloaded + run: ls -lh dist + shell: bash + - name: Upgrade setuptools, pip and wheel + run: python -m pip install -U setuptools pip wheel + - name: Install tox + run: python -m pip install tox + # only run doctest for macos on aws + - if: ${{ matrix.os == 'macos-latest' && matrix.cloud-provider == 'aws' }} + name: Run Snowpark pandas API doctests + run: python -m tox -e "py${PYTHON_VERSION}-doctest-snowparkpandasdoctest-modin-ci" + env: + PYTHON_VERSION: ${{ matrix.python-version }} + cloud_provider: ${{ matrix.cloud-provider }} + PYTEST_ADDOPTS: --color=yes --tb=short + TOX_PARALLEL_NO_SPINNER: 1 + # Specify SNOWFLAKE_IS_PYTHON_RUNTIME_TEST: 1 when adding >= python3.11 with no server-side support + # For example, see https://github.com/snowflakedb/snowpark-python/pull/681 + shell: bash + # do not run other tests for macos on aws + - if: ${{ !(matrix.os == 'macos-latest' && matrix.cloud-provider == 'aws') }} + name: Run Snowpark pandas API tests (excluding doctests) + run: python -m tox -e "py${PYTHON_VERSION/\./}-snowparkpandasnotdoctest-modin-ci" + env: + PYTHON_VERSION: ${{ matrix.python-version }} + cloud_provider: ${{ matrix.cloud-provider }} + PYTEST_ADDOPTS: --color=yes --tb=short + TOX_PARALLEL_NO_SPINNER: 1 + shell: bash + - name: Combine coverages + run: python -m tox -e coverage --skip-missing-interpreters false + shell: bash + env: + SNOWFLAKE_IS_PYTHON_RUNTIME_TEST: 1 + - uses: actions/upload-artifact@v4 + with: + name: coverage_${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}-snowpark-pandas-testing + path: | + .tox/.coverage + .tox/coverage.xml + combine-coverage: if: ${{ success() || failure() }} name: Combine coverage @@ -365,7 +474,7 @@ jobs: - name: Upgrade setuptools and pip run: python -m pip install -U setuptools pip - name: Install Snowpark - run: python -m pip install ".[development, pandas]" + run: python -m pip install ".[modin-development, development, pandas]" - name: Install Sphinx run: python -m pip install sphinx - name: Build document diff --git a/LICENSE.txt b/LICENSE.txt index f024f6cad12..5995cf32a72 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -187,7 +187,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright (c) 2012-2023 Snowflake Computing, Inc. + Copyright (c) 2012-2024 Snowflake Computing, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/ci/check_standalone_function_snowpark_pandas_telemetry_decorator.py b/ci/check_standalone_function_snowpark_pandas_telemetry_decorator.py new file mode 100644 index 00000000000..32b5da09168 --- /dev/null +++ b/ci/check_standalone_function_snowpark_pandas_telemetry_decorator.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# +import ast + + +class DecoratorError(Exception): + pass + + +def check_standalone_function_snowpark_pandas_telemetry_decorator( + target_file: str, + telemetry_decorator_name: str, +) -> None: + """ + Check that all standalone functions in the target file are decorated with the decorator named + telemetry_decorator_name. + Raises a DecoratorError if the decorator is missing. + + Args: + target_file (str): Path to the target file. + telemetry_decorator_name: Name of the telemetry decorator that is checked. + """ + # Get the source code of the target file + with open(target_file) as file: + source_code = file.read() + assert source_code.strip(), f"Source code in '{target_file}' is empty." + # Parse the abstract syntax tree + tree = ast.parse(source_code) + + # List of str: function names that need the decorator. + failed_funcs = [] + + # Check each top-level (standalone) function for the telemetry decorator. + # iter_child_nodes yields only direct child nodes, so nested functions are excluded. + for node in ast.iter_child_nodes(tree): + if ( + isinstance(node, ast.FunctionDef) # Check if it is a function definition + and not node.name.startswith( + "_" + ) # the function is not private (does not start with an underscore) + and node.name + ): + has_telemetry_decorator = False + for decorator in node.decorator_list: + if ( + hasattr(decorator, "id") + and decorator.id == telemetry_decorator_name + ): + has_telemetry_decorator = True + break + if not has_telemetry_decorator: + failed_funcs.append(node.name) + if len(failed_funcs) > 0: + raise DecoratorError( + f"functions {failed_funcs} should be decorated with {telemetry_decorator_name}" + ) + + +if __name__ == "__main__": + check_standalone_function_snowpark_pandas_telemetry_decorator( + target_file="src/snowflake/snowpark/modin/pandas/io.py", + telemetry_decorator_name="snowpark_pandas_telemetry_standalone_function_decorator", + ) + check_standalone_function_snowpark_pandas_telemetry_decorator( + target_file="src/snowflake/snowpark/modin/pandas/general.py", + telemetry_decorator_name="snowpark_pandas_telemetry_standalone_function_decorator", + ) + check_standalone_function_snowpark_pandas_telemetry_decorator( + target_file="src/snowflake/snowpark/modin/plugin/extensions/pd_extensions.py", + telemetry_decorator_name="snowpark_pandas_telemetry_standalone_function_decorator", + ) + check_standalone_function_snowpark_pandas_telemetry_decorator( + target_file="src/snowflake/snowpark/modin/plugin/extensions/pd_overrides.py", + telemetry_decorator_name="snowpark_pandas_telemetry_standalone_function_decorator", + ) diff --git a/setup.py b/setup.py index a400f84aec4..454a0ae92be 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,9 @@ THIS_DIR = os.path.dirname(os.path.realpath(__file__)) SRC_DIR = os.path.join(THIS_DIR, "src") SNOWPARK_SRC_DIR = os.path.join(SRC_DIR, "snowflake", "snowpark") +MODIN_DEPENDENCY_VERSION = ( + "==0.28.1" # Snowpark pandas requires modin 0.28.1, which depends on pandas 2.2.1 +) CONNECTOR_DEPENDENCY_VERSION = ">=3.6.0, <4.0.0" INSTALL_REQ_LIST = [ "setuptools>=40.6.0", @@ -65,6 +68,22 @@ "snowflake.snowpark._internal",
"snowflake.snowpark._internal.analyzer", "snowflake.snowpark.mock", + "snowflake.snowpark.modin", + "snowflake.snowpark.modin.config", + "snowflake.snowpark.modin.core.dataframe.algebra.default2pandas", + "snowflake.snowpark.modin.core.execution.dispatching", + "snowflake.snowpark.modin.core.execution.dispatching.factories", + "snowflake.snowpark.modin.pandas", + "snowflake.snowpark.modin.pandas.api.extensions", + "snowflake.snowpark.modin.plugin", + "snowflake.snowpark.modin.plugin._internal", + "snowflake.snowpark.modin.plugin.compiler", + "snowflake.snowpark.modin.plugin.docstrings", + "snowflake.snowpark.modin.plugin.default2pandas", + "snowflake.snowpark.modin.plugin.docstrings", + "snowflake.snowpark.modin.plugin.extensions", + "snowflake.snowpark.modin.plugin.io", + "snowflake.snowpark.modin.plugin.utils", ], package_dir={ "": "src", @@ -76,6 +95,9 @@ "pandas": [ f"snowflake-connector-python[pandas]{CONNECTOR_DEPENDENCY_VERSION}", ], + "modin": [ + f"modin{MODIN_DEPENDENCY_VERSION}", + ], "secure-local-storage": [ f"snowflake-connector-python[secure-local-storage]{CONNECTOR_DEPENDENCY_VERSION}", ], @@ -88,6 +110,13 @@ "pytest-timeout", "pre-commit", ], + "modin-development": [ + "pytest-assume", # Snowpark pandas + "decorator", # Snowpark pandas + "scipy", # Snowpark pandas 3rd party library testing + "statsmodels", # Snowpark pandas 3rd party library testing + f"modin{MODIN_DEPENDENCY_VERSION}", + ], "localtest": [ "pandas", "pyarrow", diff --git a/src/snowflake/snowpark/_internal/open_telemetry.py b/src/snowflake/snowpark/_internal/open_telemetry.py index 20b1b81858c..fbe0e8cb03b 100644 --- a/src/snowflake/snowpark/_internal/open_telemetry.py +++ b/src/snowflake/snowpark/_internal/open_telemetry.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # # diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 8caa7c4d178..5bc44a57ddd 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -927,6 +927,96 @@ def to_df(self, *names: Union[str, Iterable[str]]) -> "DataFrame": new_cols.append(Column(attr).alias(name)) return self.select(new_cols) + @df_collect_api_telemetry + def to_snowpark_pandas( + self, + index_col: Optional[Union[str, List[str]]] = None, + columns: Optional[List[str]] = None, + ) -> "snowflake.snowpark.modin.pandas.DataFrame": + """ + Convert the Snowpark DataFrame to a Snowpark pandas DataFrame. + + Args: + index_col: A column name or a list of column names to use as the index. + columns: A list of column names for the columns to select from the Snowpark DataFrame. If not specified, select + all columns except ones configured in index_col. + + Returns: + :class:`~snowflake.snowpark.modin.pandas.DataFrame` + A Snowpark pandas DataFrame that contains the index and data columns based on a snapshot of the current + Snowpark DataFrame; creating it triggers an eager evaluation. + + If index_col is provided, the specified index_col is selected as the index column(s) for the result dataframe, + otherwise, a default range index from 0 to n - 1 is created as the index column, where n is the number + of rows. Please note that this index is also used as the initial row ordering for the dataframe, but there is no + guarantee that the default row ordering is the same for two Snowpark pandas dataframes created from + the same Snowpark DataFrame.
+ + If columns are provided, the specified columns are selected as the data column(s) for the result dataframe, + otherwise, all Snowpark DataFrame columns (excluding index_col) are selected as data columns. + + Note: + Transformations performed on the returned Snowpark pandas DataFrame do not affect the Snowpark DataFrame + from which it was created. Call + - :func:`snowflake.snowpark.modin.pandas.to_snowpark ` + to transform a Snowpark pandas DataFrame back to a Snowpark DataFrame. + + The column names used for columns or index_col must be Normalized Snowflake Identifiers, and the + Normalized Snowflake Identifiers of a Snowpark DataFrame can be displayed by calling df.show(). + For details about Normalized Snowflake Identifiers, please refer to the Note in :func:`~snowflake.snowpark.modin.pandas.read_snowflake`. + + `to_snowpark_pandas` works only when the environment is set up correctly for Snowpark pandas. This environment + may require a version of Python and pandas different from what Snowpark Python uses. If the environment is set up + incorrectly, an error will be raised when `to_snowpark_pandas` is called. + + For Python version support information, please refer to: + - the prerequisites section https://docs.snowflake.com/LIMITEDACCESS/snowpark-pandas#prerequisites + - the installation section https://docs.snowflake.com/LIMITEDACCESS/snowpark-pandas#installing-the-snowpark-pandas-api + + See also: + - :func:`snowflake.snowpark.modin.pandas.to_snowpark ` + - :func:`snowflake.snowpark.modin.pandas.DataFrame.to_snowpark ` + - :func:`snowflake.snowpark.modin.pandas.Series.to_snowpark ` + + Example:: + >>> df = session.create_dataframe([[1, 2, 3]], schema=["a", "b", "c"]) + >>> snowpark_pandas_df = df.to_snowpark_pandas() # doctest: +SKIP + >>> snowpark_pandas_df # doctest: +SKIP +NORMALIZE_WHITESPACE + A B C + 0 1 2 3 + + >>> snowpark_pandas_df = df.to_snowpark_pandas(index_col='A') # doctest: +SKIP + >>> snowpark_pandas_df # doctest: +SKIP +NORMALIZE_WHITESPACE + B C + A + 1 2 3 + >>> snowpark_pandas_df = df.to_snowpark_pandas(index_col='A', columns=['B']) # doctest: +SKIP + >>> snowpark_pandas_df # doctest: +SKIP +NORMALIZE_WHITESPACE + B + A + 1 2 + >>> snowpark_pandas_df = df.to_snowpark_pandas(index_col=['B', 'A'], columns=['A', 'C', 'A']) # doctest: +SKIP + >>> snowpark_pandas_df # doctest: +SKIP +NORMALIZE_WHITESPACE + A C A + B A + 2 1 1 3 1 + """ + import snowflake.snowpark.modin.pandas as pd # pragma: no cover + + # create a temporary table out of the current snowpark dataframe + temporary_table_name = random_name_for_temp_object( + TempObjectType.TABLE + ) # pragma: no cover + self.write.save_as_table( + temporary_table_name, mode="errorifexists", table_type="temporary" + ) # pragma: no cover + + snowpandas_df = pd.read_snowflake( + name_or_query=temporary_table_name, index_col=index_col, columns=columns + ) # pragma: no cover + + return snowpandas_df + def __getitem__(self, item: Union[str, Column, List, Tuple, int]): if isinstance(item, str): return self.col(item) diff --git a/src/snowflake/snowpark/modin/NOTICE b/src/snowflake/snowpark/modin/NOTICE new file mode 100644 index 00000000000..b2a3d24db71 --- /dev/null +++ b/src/snowflake/snowpark/modin/NOTICE @@ -0,0 +1,249 @@ +Use of the Snowflake Snowpark pandas library is governed by your customer agreement with Snowflake +as Client Software.
The open source libraries used in the Snowflake Snowpark pandas library include: + +#### [Modin (Apache 2)](https://github.com/modin-project/modin/) + +``` +Modin + +Copyright (c) 2018-2024 Modin Developers. + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +``` + +#### [pandas (BSD 3-Clause License)](https://github.com/pandas-dev/pandas) + +``` +BSD 3-Clause License + +Copyright (c) 2008-2011, AQR Capital Management, LLC, Lambda Foundry, Inc. and PyData Development Team +All rights reserved. + +Copyright (c) 2011-2024, Open source contributors. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/src/snowflake/snowpark/modin/__init__.py b/src/snowflake/snowpark/modin/__init__.py new file mode 100644 index 00000000000..0fbef920926 --- /dev/null +++ b/src/snowflake/snowpark/modin/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# diff --git a/src/snowflake/snowpark/modin/config/__init__.py b/src/snowflake/snowpark/modin/config/__init__.py new file mode 100644 index 00000000000..16275bef184 --- /dev/null +++ b/src/snowflake/snowpark/modin/config/__init__.py @@ -0,0 +1,119 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific language +# governing permissions and limitations under the License. + +"""Module houses config entities which can be used for Modin behavior tuning.""" + +from snowflake.snowpark.modin.config.envvars import ( + AsvDataSizeConfig, + AsvImplementation, + AsyncReadMode, + BenchmarkMode, + CIAWSAccessKeyID, + CIAWSSecretAccessKey, + CpuCount, + DaskThreadsPerWorker, + DocModule, + DoUseCalcite, + Engine, + EnvironmentVariable, + ExperimentalGroupbyImpl, + ExperimentalNumPyAPI, + GithubCI, + GpuCount, + HdkFragmentSize, + HdkLaunchParameters, + IsDebug, + IsExperimental, + IsRayCluster, + LazyExecution, + LogFileSize, + LogMemoryInterval, + LogMode, + Memory, + MinPartitionSize, + ModinNumpy, + NPartitions, + PersistentPickle, + ProgressBar, + RangePartitioning, + RangePartitioningGroupby, + RayRedisAddress, + RayRedisPassword, + ReadSqlEngine, + StorageFormat, + TestDatasetSize, + TestReadFromPostgres, + TestReadFromSqlServer, + TrackFileLeaks, +) +from snowflake.snowpark.modin.config.pubsub import Parameter, ValueSource + +__all__ = [ + "EnvironmentVariable", + "Parameter", + "ValueSource", + # General settings + "IsDebug", + "Engine", + "StorageFormat", + "CpuCount", + "GpuCount", + "Memory", + # Ray specific + "IsRayCluster", + "RayRedisAddress", + "RayRedisPassword", + "LazyExecution", + # Dask specific + "DaskThreadsPerWorker", + # Partitioning + "NPartitions", + "MinPartitionSize", + # HDK specific + "HdkFragmentSize", + "DoUseCalcite", + "HdkLaunchParameters", + # ASV specific + "TestDatasetSize", + "AsvImplementation", + "AsvDataSizeConfig", + # Specific features + "ProgressBar", + "BenchmarkMode", + "PersistentPickle", + "ModinNumpy", + "ExperimentalNumPyAPI", + "RangePartitioningGroupby", + "RangePartitioning", + "ExperimentalGroupbyImpl", + "AsyncReadMode", + "ReadSqlEngine", + "IsExperimental", + # For tests + "TrackFileLeaks", + "TestReadFromSqlServer", + "TestReadFromPostgres", + "GithubCI", + "CIAWSSecretAccessKey", + "CIAWSAccessKeyID", + # Logging + "LogMode", + "LogMemoryInterval", + "LogFileSize", + # Plugin settings + "DocModule", +] diff --git a/src/snowflake/snowpark/modin/config/__main__.py b/src/snowflake/snowpark/modin/config/__main__.py new file mode 100644 index 00000000000..db55f40c548 --- /dev/null +++ b/src/snowflake/snowpark/modin/config/__main__.py @@ -0,0 +1,108 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +""" +Content of this file should be executed if module `modin.config` is called. + +If module is called (using `python -m modin.config`) configs help will be printed. +Using `-export_path` option configs description can be exported to the external CSV file +provided with this flag. 
+""" # pragma: no cover + +import argparse # pragma: no cover +from textwrap import dedent # pragma: no cover + +import pandas # pragma: no cover + +import snowflake.snowpark.modin.config as cfg # pragma: no cover + + +def print_config_help() -> None: # pragma: no cover + """Print configs help messages.""" + for objname in sorted(cfg.__all__): + obj = getattr(cfg, objname) + if ( + isinstance(obj, type) + and issubclass(obj, cfg.Parameter) + and not obj.is_abstract + ): + print(f"{obj.get_help()}\n\tCurrent value: {obj.get()}") # noqa: T201 + + +def export_config_help(filename: str) -> None: # pragma: no cover + """ + Export all configs help messages to the CSV file. + + Parameters + ---------- + filename : str + Name of the file to export configs data. + """ + configs_data = [] + default_values = dict( + RayRedisPassword="random string", + CpuCount="multiprocessing.cpu_count()", + NPartitions="equals to MODIN_CPUS env", + ) + for objname in sorted(cfg.__all__): + obj = getattr(cfg, objname) + if ( + isinstance(obj, type) + and issubclass(obj, cfg.Parameter) + and not obj.is_abstract + ): + data = { + "Config Name": obj.__name__, + "Env. Variable Name": getattr( + obj, "varname", "not backed by environment" + ), + "Default Value": default_values.get(obj.__name__, obj._get_default()), + # `Notes` `-` underlining can't be correctly parsed inside csv table by sphinx + "Description": dedent(obj.__doc__ or "").replace( + "Notes\n-----", "Notes:\n" + ), + "Options": obj.choices, + } + configs_data.append(data) + + pandas.DataFrame( + configs_data, + columns=[ + "Config Name", + "Env. Variable Name", + "Default Value", + "Description", + "Options", + ], + ).to_csv(filename, index=False) + + +if __name__ == "__main__": # pragma: no cover + parser = argparse.ArgumentParser() + parser.add_argument( + "--export-path", + dest="export_path", + type=str, + required=False, + default=None, + help="File path to export configs data.", + ) + export_path = parser.parse_args().export_path + if export_path: + export_config_help(export_path) + else: + print_config_help() diff --git a/src/snowflake/snowpark/modin/config/envvars.py b/src/snowflake/snowpark/modin/config/envvars.py new file mode 100644 index 00000000000..63bea0b0f28 --- /dev/null +++ b/src/snowflake/snowpark/modin/config/envvars.py @@ -0,0 +1,913 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. 
+ +"""Module houses Modin configs originated from environment variables.""" + +import os +import secrets +import sys +import warnings +from textwrap import dedent +from typing import Any, Optional + +from packaging import version +from pandas.util._decorators import doc # type: ignore[attr-defined] + +from snowflake.snowpark.modin.config.pubsub import ( + _TYPE_PARAMS, + _UNSET, + DeprecationDescriptor, + ExactStr, + Parameter, + ValueSource, +) + + +class EnvironmentVariable(Parameter, type=str, abstract=True): # pragma: no cover + """Base class for environment variables-based configuration.""" + + varname: Optional[str] = None + + @classmethod + def _get_raw_from_config(cls) -> str: + """ + Read the value from environment variable. + + Returns + ------- + str + Config raw value. + + Raises + ------ + TypeError + If `varname` is None. + KeyError + If value is absent. + """ + if cls.varname is None: + raise TypeError("varname should not be None") + return os.environ[cls.varname] + + @classmethod + def get_help(cls) -> str: + """ + Generate user-presentable help for the config. + + Returns + ------- + str + """ + help = f"{cls.varname}: {dedent(cls.__doc__ or 'Unknown').strip()}\n\tProvide {_TYPE_PARAMS[cls.type].help}" + if cls.choices: + help += f" (valid examples are: {', '.join(str(c) for c in cls.choices)})" + return help + + +class EnvWithSibilings( + EnvironmentVariable, + # 'type' is a mandatory parameter for '__init_subclasses__', so we have to pass something here, + # this doesn't force child classes to have 'str' type though, they actually can be any type + type=str, +): # pragma: no cover + """Ensure values synchronization between sibling parameters.""" + + _update_sibling = True + + @classmethod + def _sibling(cls) -> type["EnvWithSibilings"]: + """Return a sibling parameter.""" + raise NotImplementedError() + + @classmethod + def get(cls) -> Any: + """ + Get parameter's value and ensure that it's equal to the sibling's value. + + Returns + ------- + Any + """ + sibling = cls._sibling() + + if sibling._value is _UNSET and cls._value is _UNSET: + super().get() + with warnings.catch_warnings(): + # filter warnings that can potentially come from the potentially deprecated sibling + warnings.filterwarnings("ignore", category=FutureWarning) + super(EnvWithSibilings, sibling).get() + + if ( + cls._value_source + == sibling._value_source + == ValueSource.GOT_FROM_CFG_SOURCE + ): + raise ValueError( + f"Configuration is ambiguous. You cannot set '{cls.varname}' and '{sibling.varname}' at the same time." 
+ ) + + # further we assume that there are only two valid sources for the variables: 'GOT_FROM_CFG' and 'DEFAULT', + # as otherwise we wouldn't ended-up in this branch at all, because all other ways of setting a value + # changes the '._value' attribute from '_UNSET' to something meaningful + from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage + + if cls._value_source == ValueSource.GOT_FROM_CFG_SOURCE: + ErrorMessage.catch_bugs_and_request_email( + failure_condition=sibling._value_source != ValueSource.DEFAULT + ) + sibling._value = cls._value + sibling._value_source = ValueSource.GOT_FROM_CFG_SOURCE + elif sibling._value_source == ValueSource.GOT_FROM_CFG_SOURCE: + ErrorMessage.catch_bugs_and_request_email( + failure_condition=cls._value_source != ValueSource.DEFAULT + ) + cls._value = sibling._value + cls._value_source = ValueSource.GOT_FROM_CFG_SOURCE + else: + ErrorMessage.catch_bugs_and_request_email( + failure_condition=cls._value_source != ValueSource.DEFAULT + or sibling._value_source != ValueSource.DEFAULT + ) + # propagating 'cls' default value to the sibling + sibling._value = cls._value + return super().get() + + @classmethod + def put(cls, value: Any) -> None: + """ + Set a new value to this parameter as well as to its sibling. + + Parameters + ---------- + value : Any + """ + super().put(value) + # avoid getting into an infinite recursion + if cls._update_sibling: + cls._update_sibling = False + try: + with warnings.catch_warnings(): + # filter potential future warnings of the sibling + warnings.filterwarnings("ignore", category=FutureWarning) + cls._sibling().put(value) + finally: + cls._update_sibling = True + + +class IsDebug(EnvironmentVariable, type=bool): # pragma: no cover + """Force Modin engine to be "Python" unless specified by $MODIN_ENGINE.""" + + varname = "MODIN_DEBUG" + + +class Engine(EnvironmentVariable, type=str): # pragma: no cover + """Distribution engine to run queries by.""" + + varname = "MODIN_ENGINE" + choices = ("Ray", "Dask", "Python", "Native", "Unidist") + + NOINIT_ENGINES = { + "Python", + } # engines that don't require initialization, useful for unit tests + + has_custom_engine = False + + @classmethod + def _get_default(cls) -> str: + """ + Get default value of the config. + + Returns + ------- + str + """ + from snowflake.snowpark.modin.utils import ( + MIN_DASK_VERSION, + MIN_RAY_VERSION, + MIN_UNIDIST_VERSION, + ) + + # If there's a custom engine, we don't need to check for any engine + # dependencies. Return the default "Python" engine. + if IsDebug.get() or cls.has_custom_engine: + return "Python" + try: + import ray + + except ImportError: + pass + else: + if version.parse(ray.__version__) < MIN_RAY_VERSION: + raise ImportError( + 'Please `pip install "modin[ray]"` to install compatible Ray ' + + "version " + + f"(>={MIN_RAY_VERSION})." + ) + return "Ray" + try: + import dask + import distributed + + except ImportError: + pass + else: + if ( + version.parse(dask.__version__) < MIN_DASK_VERSION + or version.parse(distributed.__version__) < MIN_DASK_VERSION + ): + raise ImportError( + f'Please `pip install "modin[dask]"` to install compatible Dask version (>={MIN_DASK_VERSION}).' + ) + return "Dask" + try: + # We import ``DbWorker`` from this module since correct import of ``DbWorker`` itself + # from HDK is located in it with all the necessary options for dlopen. 
+ from modin.experimental.core.execution.native.implementations.hdk_on_native.db_worker import ( # noqa + DbWorker, + ) + except ImportError: + pass + else: + return "Native" + try: + import unidist + + except ImportError: + pass + else: + if version.parse(unidist.__version__) < MIN_UNIDIST_VERSION: + raise ImportError( + 'Please `pip install "unidist[mpi]"` to install compatible unidist on MPI ' + + "version " + + f"(>={MIN_UNIDIST_VERSION})." + ) + return "Unidist" + raise ImportError( + "Please refer to installation documentation page to install an engine" + ) + + @classmethod + @doc(Parameter.add_option.__doc__) + def add_option(cls, choice: Any) -> Any: + choice = super().add_option(choice) + cls.NOINIT_ENGINES.add(choice) + cls.has_custom_engine = True + return choice + + +class StorageFormat(EnvironmentVariable, type=str): # pragma: no cover + """Engine to run on a single node of distribution.""" + + varname = "MODIN_STORAGE_FORMAT" + default = "Pandas" + choices = ("Pandas", "Hdk", "Cudf") + + +class IsExperimental(EnvironmentVariable, type=bool): # pragma: no cover + """Whether to Turn on experimental features.""" + + varname = "MODIN_EXPERIMENTAL" + + +class IsRayCluster(EnvironmentVariable, type=bool): # pragma: no cover + """Whether Modin is running on pre-initialized Ray cluster.""" + + varname = "MODIN_RAY_CLUSTER" + + +class RayRedisAddress(EnvironmentVariable, type=ExactStr): # pragma: no cover + """Redis address to connect to when running in Ray cluster.""" + + varname = "MODIN_REDIS_ADDRESS" + + +class RayRedisPassword(EnvironmentVariable, type=ExactStr): # pragma: no cover + """What password to use for connecting to Redis.""" + + varname = "MODIN_REDIS_PASSWORD" + default = secrets.token_hex(32) + + +class CpuCount(EnvironmentVariable, type=int): # pragma: no cover + """How many CPU cores to use during initialization of the Modin engine.""" + + varname = "MODIN_CPUS" + + @classmethod + def _get_default(cls) -> int: + """ + Get default value of the config. + + Returns + ------- + int + """ + import multiprocessing + + return multiprocessing.cpu_count() + + +class GpuCount(EnvironmentVariable, type=int): # pragma: no cover + """How may GPU devices to utilize across the whole distribution.""" + + varname = "MODIN_GPUS" + + +class Memory(EnvironmentVariable, type=int): # pragma: no cover + """ + How much memory (in bytes) give to an execution engine. + + Notes + ----- + * In Ray case: the amount of memory to start the Plasma object store with. + * In Dask case: the amount of memory that is given to each worker depending on CPUs used. + """ + + varname = "MODIN_MEMORY" + + +class NPartitions(EnvironmentVariable, type=int): # pragma: no cover + """How many partitions to use for a Modin DataFrame (along each axis).""" + + varname = "MODIN_NPARTITIONS" + + @classmethod + def _put(cls, value: int) -> None: + """ + Put specific value if NPartitions wasn't set by a user yet. + + Parameters + ---------- + value : int + Config value to set. + + Notes + ----- + This method is used to set NPartitions from cluster resources internally + and should not be called by a user. + """ + if cls.get_value_source() == ValueSource.DEFAULT: + cls.put(value) + + @classmethod + def _get_default(cls) -> int: + """ + Get default value of the config. 
+ + Returns + ------- + int + """ + if StorageFormat.get() == "Cudf": + return GpuCount.get() + else: + return CpuCount.get() + + +class HdkFragmentSize(EnvironmentVariable, type=int): # pragma: no cover + """How big a fragment in HDK should be when creating a table (in rows).""" + + varname = "MODIN_HDK_FRAGMENT_SIZE" + + +class DoUseCalcite(EnvironmentVariable, type=bool): # pragma: no cover + """Whether to use Calcite for HDK queries execution.""" + + varname = "MODIN_USE_CALCITE" + default = True + + +class TestDatasetSize(EnvironmentVariable, type=str): # pragma: no cover + """Dataset size for running some tests.""" + + varname = "MODIN_TEST_DATASET_SIZE" + choices = ("Small", "Normal", "Big") + + +class TrackFileLeaks(EnvironmentVariable, type=bool): # pragma: no cover + """Whether to track for open file handles leakage during testing.""" + + varname = "MODIN_TEST_TRACK_FILE_LEAKS" + # Turn off tracking on Windows by default because + # psutil's open_files() can be extremely slow on Windows (up to adding a few hours). + # see https://github.com/giampaolo/psutil/pull/597 + default = sys.platform != "win32" + + +class AsvImplementation(EnvironmentVariable, type=ExactStr): # pragma: no cover + """Allows to select a library that we will use for testing performance.""" + + varname = "MODIN_ASV_USE_IMPL" + choices = ("modin", "pandas") + + default = "modin" + + +class AsvDataSizeConfig(EnvironmentVariable, type=ExactStr): # pragma: no cover + """Allows to override default size of data (shapes).""" + + varname = "MODIN_ASV_DATASIZE_CONFIG" + default = None + + +class ProgressBar(EnvironmentVariable, type=bool): # pragma: no cover + """Whether or not to show the progress bar.""" + + varname = "MODIN_PROGRESS_BAR" + default = False + + @classmethod + def enable(cls) -> None: + """Enable ``ProgressBar`` feature.""" + cls.put(True) + + @classmethod + def disable(cls) -> None: + """Disable ``ProgressBar`` feature.""" + cls.put(False) + + @classmethod + def put(cls, value: bool) -> None: + """ + Set ``ProgressBar`` value only if synchronous benchmarking is disabled. + + Parameters + ---------- + value : bool + Config value to set. + """ + if value and BenchmarkMode.get(): + raise ValueError("ProgressBar isn't compatible with BenchmarkMode") + super().put(value) + + +class BenchmarkMode(EnvironmentVariable, type=bool): # pragma: no cover + """Whether or not to perform computations synchronously.""" + + varname = "MODIN_BENCHMARK_MODE" + default = False + + @classmethod + def put(cls, value: bool) -> None: + """ + Set ``BenchmarkMode`` value only if progress bar feature is disabled. + + Parameters + ---------- + value : bool + Config value to set. 
+ """ + if value and ProgressBar.get(): + raise ValueError("BenchmarkMode isn't compatible with ProgressBar") + super().put(value) + + +class LogMode(EnvironmentVariable, type=ExactStr): # pragma: no cover + """Set ``LogMode`` value if users want to opt-in.""" + + varname = "MODIN_LOG_MODE" + choices = ("enable", "disable", "enable_api_only") + default = "disable" + + @classmethod + def enable(cls) -> None: + """Enable all logging levels.""" + cls.put("enable") + + @classmethod + def disable(cls) -> None: + """Disable logging feature.""" + cls.put("disable") + + @classmethod + def enable_api_only(cls) -> None: + """Enable API level logging.""" + cls.put("enable_api_only") + + +class LogMemoryInterval(EnvironmentVariable, type=int): # pragma: no cover + """Interval (in seconds) to profile memory utilization for logging.""" + + varname = "MODIN_LOG_MEMORY_INTERVAL" + default = 5 + + @classmethod + def put(cls, value: int) -> None: + """ + Set ``LogMemoryInterval`` with extra checks. + + Parameters + ---------- + value : int + Config value to set. + """ + if value <= 0: + raise ValueError(f"Log memory Interval should be > 0, passed value {value}") + super().put(value) + + @classmethod + def get(cls) -> int: + """ + Get ``LogMemoryInterval`` with extra checks. + + Returns + ------- + int + """ + log_memory_interval = super().get() + assert log_memory_interval > 0, "`LogMemoryInterval` should be > 0" + return log_memory_interval + + +class LogFileSize(EnvironmentVariable, type=int): # pragma: no cover + """Max size of logs (in MBs) to store per Modin job.""" + + varname = "MODIN_LOG_FILE_SIZE" + default = 10 + + @classmethod + def put(cls, value: int) -> None: + """ + Set ``LogFileSize`` with extra checks. + + Parameters + ---------- + value : int + Config value to set. + """ + if value <= 0: + raise ValueError(f"Log file size should be > 0 MB, passed value {value}") + super().put(value) + + @classmethod + def get(cls) -> int: + """ + Get ``LogFileSize`` with extra checks. + + Returns + ------- + int + """ + log_file_size = super().get() + assert log_file_size > 0, "`LogFileSize` should be > 0" + return log_file_size + + +class PersistentPickle(EnvironmentVariable, type=bool): # pragma: no cover + """Whether serialization should be persistent.""" + + varname = "MODIN_PERSISTENT_PICKLE" + # When set to off, it allows faster serialization which is only + # valid in current run (i.e. useless for saving to disk). + # When set to on, Modin objects could be saved to disk and loaded + # but serialization/deserialization could take more time. + default = False + + +class HdkLaunchParameters(EnvironmentVariable, type=dict): # pragma: no cover + """ + Additional command line options for the HDK engine. + + Please visit OmniSci documentation for the description of available parameters: + https://docs.omnisci.com/installation-and-configuration/config-parameters#configuration-parameters-for-omniscidb + """ + + varname = "MODIN_HDK_LAUNCH_PARAMETERS" + + @classmethod + def get(cls) -> dict: + """ + Get the resulted command-line options. + + Decode and merge specified command-line options with the default one. + + Returns + ------- + dict + Decoded and verified config value. + """ + custom_parameters = super().get() + result = cls._get_default().copy() + result.update( + {key.replace("-", "_"): value for key, value in custom_parameters.items()} + ) + return result + + @classmethod + def _get_default(cls) -> Any: + """ + Get default value of the config. 
Checks the pyhdk version and omits variables unsupported in prior versions. + + Returns + ------- + dict + Config keys and corresponding values. + """ + if (default := getattr(cls, "default", None)) is None: + cls.default = default = { + "enable_union": 1, + "enable_columnar_output": 1, + "enable_lazy_fetch": 0, + "null_div_by_zero": 1, + "enable_watchdog": 0, + "enable_thrift_logs": 0, + "enable_multifrag_execution_result": 1, + "cpu_only": 1, + } + + try: + import pyhdk + + if version.parse(pyhdk.__version__) >= version.parse("0.6.1"): + default["enable_lazy_dict_materialization"] = 0 + default["log_dir"] = "pyhdk_log" + except ImportError: + # if pyhdk is not available, do not show any additional options + pass + return default + + +class MinPartitionSize(EnvironmentVariable, type=int): # pragma: no cover + """ + Minimum number of rows/columns in a single pandas partition split. + + Once a partition for a pandas dataframe has more than this many elements, + Modin adds another partition. + """ + + varname = "MODIN_MIN_PARTITION_SIZE" + default = 32 + + @classmethod + def put(cls, value: int) -> None: + """ + Set ``MinPartitionSize`` with extra checks. + + Parameters + ---------- + value : int + Config value to set. + """ + if value <= 0: + raise ValueError(f"Min partition size should be > 0, passed value {value}") + super().put(value) + + @classmethod + def get(cls) -> int: + """ + Get ``MinPartitionSize`` with extra checks. + + Returns + ------- + int + """ + min_partition_size = super().get() + assert min_partition_size > 0, "`min_partition_size` should be > 0" + return min_partition_size + + +class TestReadFromSqlServer(EnvironmentVariable, type=bool): # pragma: no cover + """Set to true to test reading from SQL server.""" + + varname = "MODIN_TEST_READ_FROM_SQL_SERVER" + default = False + + +class TestReadFromPostgres(EnvironmentVariable, type=bool): # pragma: no cover + """Set to true to test reading from Postgres.""" + + varname = "MODIN_TEST_READ_FROM_POSTGRES" + default = False + + +class GithubCI(EnvironmentVariable, type=bool): # pragma: no cover + """Set to true when running Modin in GitHub CI.""" + + varname = "MODIN_GITHUB_CI" + default = False + + +class ModinNumpy(EnvWithSibilings, type=bool): # pragma: no cover + """Set to true to use Modin's implementation of NumPy API.""" + + varname = "MODIN_NUMPY" + default = False + + @classmethod + def _sibling(cls) -> type[EnvWithSibilings]: + """Get a parameter sibling.""" + return ExperimentalNumPyAPI + + +class ExperimentalNumPyAPI(EnvWithSibilings, type=bool): # pragma: no cover + """ + Set to true to use Modin's implementation of NumPy API. + + This parameter is deprecated. Use ``ModinNumpy`` instead. + """ + + varname = "MODIN_EXPERIMENTAL_NUMPY_API" + default = False + + @classmethod + def _sibling(cls) -> type[EnvWithSibilings]: + """Get a parameter sibling.""" + return ModinNumpy + + +# Let the parameter's handling logic know that this variable is deprecated and that +# we should raise respective warnings +ExperimentalNumPyAPI._deprecation_descriptor = DeprecationDescriptor( + ExperimentalNumPyAPI, ModinNumpy +) + + +class RangePartitioningGroupby(EnvWithSibilings, type=bool): # pragma: no cover + """ + Set to true to use Modin's range-partitioning group by implementation. + + Experimental groupby is implemented using a range-partitioning technique, + note that it may not always work better than the original Modin's TreeReduce + and FullAxis implementations. 
For more information visit the according section + of Modin's documentation: TODO: add a link to the section once it's written. + """ + + varname = "MODIN_RANGE_PARTITIONING_GROUPBY" + default = False + + @classmethod + def _sibling(cls) -> type[EnvWithSibilings]: + """Get a parameter sibling.""" + return ExperimentalGroupbyImpl + + +class ExperimentalGroupbyImpl(EnvWithSibilings, type=bool): # pragma: no cover + """ + Set to true to use Modin's range-partitioning group by implementation. + + This parameter is deprecated. Use ``RangePartitioningGroupby`` instead. + """ + + varname = "MODIN_EXPERIMENTAL_GROUPBY" + default = False + + @classmethod + def _sibling(cls) -> type[EnvWithSibilings]: + """Get a parameter sibling.""" + return RangePartitioningGroupby + + +# Let the parameter's handling logic know that this variable is deprecated and that +# we should raise respective warnings +ExperimentalGroupbyImpl._deprecation_descriptor = DeprecationDescriptor( + ExperimentalGroupbyImpl, RangePartitioningGroupby +) + + +class RangePartitioning(EnvironmentVariable, type=bool): # pragma: no cover + """ + Set to true to use Modin's range-partitioning implementation where possible. + + Please refer to documentation for cases where enabling this options would be beneficial: + https://modin.readthedocs.io/en/stable/flow/modin/experimental/range_partitioning_groupby.html + """ + + varname = "MODIN_RANGE_PARTITIONING" + default = False + + +class CIAWSSecretAccessKey(EnvironmentVariable, type=str): # pragma: no cover + """Set to AWS_SECRET_ACCESS_KEY when running mock S3 tests for Modin in GitHub CI.""" + + varname = "AWS_SECRET_ACCESS_KEY" + default = "foobar_secret" + + +class CIAWSAccessKeyID(EnvironmentVariable, type=str): # pragma: no cover + """Set to AWS_ACCESS_KEY_ID when running mock S3 tests for Modin in GitHub CI.""" + + varname = "AWS_ACCESS_KEY_ID" + default = "foobar_key" + + +class AsyncReadMode(EnvironmentVariable, type=bool): # pragma: no cover + """ + It does not wait for the end of reading information from the source. + + It basically means, that the reading function only launches tasks for the dataframe + to be read/created, but not ensures that the construction is finalized by the time + the reading function returns a dataframe. + + This option was brought to improve performance of reading/construction + of Modin DataFrames, however it may also: + + 1. Increase the peak memory consumption. Since the garbage collection of the + temporary objects created during the reading is now also lazy and will only + be performed when the reading/construction is actually finished. + + 2. Can break situations when the source is manually deleted after the reading + function returns a result, for example, when reading inside of a context-block + that deletes the file on ``__exit__()``. + """ + + varname = "MODIN_ASYNC_READ_MODE" + default = False + + +class ReadSqlEngine(EnvironmentVariable, type=str): # pragma: no cover + """Engine to run `read_sql`.""" + + varname = "MODIN_READ_SQL_ENGINE" + default = "Pandas" + choices = ("Pandas", "Connectorx") + + +class LazyExecution(EnvironmentVariable, type=str): # pragma: no cover + """ + Lazy execution mode. + + Supported values: + `Auto` - the execution mode is chosen by the engine for each operation (default value). + `On` - the lazy execution is performed wherever it's possible. + `Off` - the lazy execution is disabled. 
+ """ + + varname = "MODIN_LAZY_EXECUTION" + choices = ("Auto", "On", "Off") + default = "Auto" + + +class DocModule(EnvironmentVariable, type=ExactStr): # pragma: no cover + """ + The module to use that will be used for docstrings. + + The value set here must be a valid, importable module. It should have + a `DataFrame`, `Series`, and/or several APIs directly (e.g. `read_csv`). + """ + + varname = "MODIN_DOC_MODULE" + default = "pandas" + + +class DaskThreadsPerWorker(EnvironmentVariable, type=int): # pragma: no cover + """Number of threads per Dask worker.""" + + varname = "MODIN_DASK_THREADS_PER_WORKER" + default = 1 + + +def _check_vars() -> None: # pragma: no cover + """ + Check validity of environment variables. + + Look out for any environment variables that start with "MODIN_" prefix + that are unknown - they might be a typo, so warn a user. + """ + valid_names = { + obj.varname + for obj in globals().values() + if isinstance(obj, type) + and issubclass(obj, EnvironmentVariable) + and not obj.is_abstract + } + valid_names.add("MODIN_PYTEST_CMD") + found_names = {name for name in os.environ if name.startswith("MODIN_")} + unknown = found_names - valid_names + deprecated: dict[str, DeprecationDescriptor] = { + obj.varname: obj._deprecation_descriptor + for obj in globals().values() + if isinstance(obj, type) + and issubclass(obj, EnvironmentVariable) + and not obj.is_abstract + and obj.varname is not None + and obj._deprecation_descriptor is not None + } + found_deprecated = found_names & deprecated.keys() + if unknown: + warnings.warn( # noqa: B028 + f"Found unknown environment variable{'s' if len(unknown) > 1 else ''}," + + f" please check {'their' if len(unknown) > 1 else 'its'} spelling: " + + ", ".join(sorted(unknown)) + ) + for depr_var in found_deprecated: + warnings.warn( # noqa: B028 + deprecated[depr_var].deprecation_message(use_envvar_names=True), + FutureWarning, + ) + + +_check_vars() diff --git a/src/snowflake/snowpark/modin/config/pubsub.py b/src/snowflake/snowpark/modin/config/pubsub.py new file mode 100644 index 00000000000..f83fe3647a8 --- /dev/null +++ b/src/snowflake/snowpark/modin/config/pubsub.py @@ -0,0 +1,447 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +"""Module houses ``Parameter`` class - base class for all configs.""" + +import warnings +from collections import defaultdict +from enum import IntEnum +from typing import TYPE_CHECKING, Any, Callable, DefaultDict, NamedTuple, Optional, cast + +if TYPE_CHECKING: # pragma: no cover + from snowflake.snowpark.modin.config.envvars import EnvironmentVariable + + +class DeprecationDescriptor: # pragma: no cover + """ + Describe deprecated parameter. 
+ + Parameters + ---------- + parameter : type[Parameter] + Deprecated parameter. + new_parameter : type[Parameter], optional + If there's a replacement parameter for the deprecated one, specify it here. + when_removed : str, optional + If known, the exact release when the deprecated parameter is planned to be removed. + """ + + _parameter: type["Parameter"] + _new_parameter: Optional[type["Parameter"]] + _when_removed: str + + def __init__( # noqa: FIR100 + self, + parameter: type["Parameter"], + new_parameter: Optional[type["Parameter"]] = None, + when_removed: Optional[str] = None, + ): + self._parameter = parameter + self._new_parameter = new_parameter + self._when_removed = "a future" if when_removed is None else when_removed + + def deprecation_message(self, use_envvar_names: bool = False) -> str: + """ + Generate a message to be used in a warning raised when using the deprecated parameter. + + Parameters + ---------- + use_envvar_names : bool, default: False + Whether to use environment variable names in the warning. If ``True``, both + ``self._parameter`` and ``self._new_parameter`` have to be a type of ``EnvironmentVariable``. + + Returns + ------- + str + """ + name = ( + cast("EnvironmentVariable", self._parameter).varname + if use_envvar_names + else self._parameter.__name__ + ) + msg = f"'{name}' is deprecated and will be removed in {self._when_removed} version." + if self._new_parameter is not None: + new_name = ( + cast("EnvironmentVariable", self._new_parameter).varname + if use_envvar_names + else self._new_parameter.__name__ + ) + msg += f" Use '{new_name}' instead." + return msg + + +class TypeDescriptor(NamedTuple): # pragma: no cover + """ + Class for config data manipulating of exact type. + + Parameters + ---------- + decode : callable + Callable to decode config value from the raw data. + normalize : callable + Callable to bring different config value variations to + the single form. + verify : callable + Callable to check that config value satisfies given config + type requirements. + help : str + Class description string. 
+ """ + + decode: Callable[[str], object] + normalize: Callable[[object], object] + verify: Callable[[object], bool] + help: str + + +class ExactStr(str): # pragma: no cover + """Class to be used in type params where no transformations are needed.""" + + +_TYPE_PARAMS = { + str: TypeDescriptor( + decode=lambda value: value.strip().title(), + normalize=lambda value: str(value).strip().title(), + verify=lambda value: True, + help="a case-insensitive string", + ), + ExactStr: TypeDescriptor( + decode=lambda value: value, + normalize=lambda value: value, + verify=lambda value: True, + help="a string", + ), + bool: TypeDescriptor( + decode=lambda value: value.strip().lower() in {"true", "yes", "1"}, + normalize=bool, + verify=lambda value: isinstance(value, bool) + or ( + isinstance(value, str) + and value.strip().lower() in {"true", "yes", "1", "false", "no", "0"} + ), + help="a boolean flag (any of 'true', 'yes' or '1' in case insensitive manner is considered positive)", + ), + int: TypeDescriptor( + decode=lambda value: int(value.strip()), + normalize=int, # type: ignore + verify=lambda value: isinstance(value, int) + or (isinstance(value, str) and value.strip().isdigit()), + help="an integer value", + ), + dict: TypeDescriptor( + decode=lambda value: { + key: int(val) if val.isdigit() else val + for key_value in value.split(",") + for key, val in [[v.strip() for v in key_value.split("=", maxsplit=1)]] + }, + normalize=lambda value: ( + value + if isinstance(value, dict) + else { + key: int(val) if val.isdigit() else val + for key_value in str(value).split(",") + for key, val in [[v.strip() for v in key_value.split("=", maxsplit=1)]] + } + ), + verify=lambda value: isinstance(value, dict) + or ( + isinstance(value, str) + and all( + key_value.find("=") not in (-1, len(key_value) - 1) + for key_value in value.split(",") + ) + ), + help="a sequence of KEY=VALUE values separated by comma (Example: 'KEY1=VALUE1,KEY2=VALUE2,KEY3=VALUE3')", + ), +} + +# special marker to distinguish unset value from None value +# as someone may want to use None as a real value for a parameter +_UNSET = object() + + +class ValueSource(IntEnum): # noqa: PR01 # pragma: no cover + """Class that describes the method of getting the value for a parameter.""" + + # got from default, i.e. neither user nor configuration source had the value + DEFAULT = 0 + # set by user + SET_BY_USER = 1 + # got from parameter configuration source, like environment variable + GOT_FROM_CFG_SOURCE = 2 + + +class Parameter: # pragma: no cover + """ + Base class describing interface for configuration entities. + + Attributes + ---------- + choices : Optional[Sequence[str]] + Array with possible options of ``Parameter`` values. + type : str + String that denotes ``Parameter`` type. + default : Optional[Any] + ``Parameter`` default value. + is_abstract : bool, default: True + Whether or not ``Parameter`` is abstract. + _value_source : Optional[ValueSource] + Source of the ``Parameter`` value, should be set by + ``ValueSource``. + _deprecation_descriptor : Optional[DeprecationDescriptor] + Indicate whether this parameter is deprecated. 
+ """ + + choices: Optional[tuple[str, ...]] = None + type = str + default: Optional[Any] = None + is_abstract = True + _value_source: Optional[ValueSource] = None + _value: Any = _UNSET + _subs: list = [] + _once: DefaultDict[Any, list] = defaultdict(list) + _deprecation_descriptor: Optional[DeprecationDescriptor] = None + + @classmethod + def _get_raw_from_config(cls) -> str: + """ + Read the value from config storage. + + Returns + ------- + str + Config raw value. + + Raises + ------ + KeyError + If value is absent. + + Notes + ----- + Config storage can be config file or environment variable or whatever. + Method should be implemented in the child class. + """ + raise NotImplementedError() + + @classmethod + def get_help(cls) -> str: + """ + Generate user-presentable help for the option. + + Returns + ------- + str + + Notes + ----- + Method should be implemented in the child class. + """ + raise NotImplementedError() + + def __init_subclass__(cls, type: Any, abstract: bool = False, **kw: dict): + """ + Initialize subclass. + + Parameters + ---------- + type : Any + Type of the config. + abstract : bool, default: False + Whether config is abstract. + **kw : dict + Optional arguments for config initialization. + """ + assert type in _TYPE_PARAMS, f"Unsupported variable type: {type}" + cls.type = type + cls.is_abstract = abstract + cls._value = _UNSET + cls._subs = [] + cls._once = defaultdict(list) + super().__init_subclass__(**kw) + + @classmethod + def subscribe(cls, callback: Callable) -> None: + """ + Add `callback` to the `_subs` list and then execute it. + + Parameters + ---------- + callback : callable + Callable to execute. + """ + cls._subs.append(callback) + callback(cls) + + @classmethod + def _get_default(cls) -> Any: + """ + Get default value of the config. + + Returns + ------- + Any + """ + return cls.default + + @classmethod + def get_value_source(cls) -> ValueSource: + """ + Get value source of the config. + + Returns + ------- + ValueSource + """ + if cls._value_source is None: + # dummy call to .get() to initialize the value + cls.get() + assert ( + cls._value_source is not None + ), "_value_source must be initialized by now in get()" + return cls._value_source + + @classmethod + def get(cls) -> Any: + """ + Get config value. + + Returns + ------- + Any + Decoded and verified config value. + """ + if cls._deprecation_descriptor is not None: + warnings.warn( # noqa: B028 + cls._deprecation_descriptor.deprecation_message(), FutureWarning + ) + if cls._value is _UNSET: + # get the value from env + try: + raw = cls._get_raw_from_config() + except KeyError: + cls._value = cls._get_default() + cls._value_source = ValueSource.DEFAULT + else: + if not _TYPE_PARAMS[cls.type].verify(raw): + raise ValueError(f"Unsupported raw value: {raw}") + cls._value = _TYPE_PARAMS[cls.type].decode(raw) + cls._value_source = ValueSource.GOT_FROM_CFG_SOURCE + return cls._value + + @classmethod + def put(cls, value: Any) -> None: + """ + Set config value. + + Parameters + ---------- + value : Any + Config value to set. + """ + if cls._deprecation_descriptor is not None: + warnings.warn( # noqa: B028 + cls._deprecation_descriptor.deprecation_message(), FutureWarning + ) + cls._check_callbacks(cls._put_nocallback(value)) + cls._value_source = ValueSource.SET_BY_USER + + @classmethod + def once(cls, onvalue: Any, callback: Callable) -> None: + """ + Execute `callback` if config value matches `onvalue` value. 
+ + Otherwise accumulate callbacks associated with the given `onvalue` + in the `_once` container. + + Parameters + ---------- + onvalue : Any + Config value to set. + callback : callable + Callable that should be executed if config value matches `onvalue`. + """ + onvalue = _TYPE_PARAMS[cls.type].normalize(onvalue) + if onvalue == cls.get(): + callback(cls) + else: + cls._once[onvalue].append(callback) + + @classmethod + def _put_nocallback(cls, value: Any) -> Any: + """ + Set config value without executing callbacks. + + Parameters + ---------- + value : Any + Config value to set. + + Returns + ------- + Any + Replaced (old) config value. + """ + if not _TYPE_PARAMS[cls.type].verify(value): + raise ValueError(f"Unsupported value: {value}") + value = _TYPE_PARAMS[cls.type].normalize(value) + oldvalue, cls._value = cls.get(), value + return oldvalue + + @classmethod + def _check_callbacks(cls, oldvalue: Any) -> None: + """ + Execute all needed callbacks if config value was changed. + + Parameters + ---------- + oldvalue : Any + Previous (old) config value. + """ + if oldvalue == cls.get(): + return + for callback in cls._subs: + callback(cls) + for callback in cls._once.pop(cls.get(), ()): + callback(cls) + + @classmethod + def add_option(cls, choice: Any) -> Any: + """ + Add a new choice for the parameter. + + Parameters + ---------- + choice : Any + New choice to add to the available choices. + + Returns + ------- + Any + Added choice normalized according to the parameter type. + """ + if cls.choices is not None: + if not _TYPE_PARAMS[cls.type].verify(choice): + raise ValueError(f"Unsupported choice value: {choice}") + choice = _TYPE_PARAMS[cls.type].normalize(choice) + if choice not in cls.choices: + cls.choices += (choice,) + return choice + raise TypeError("Cannot add a choice to a parameter where choices is None") + + +__all__ = ["Parameter"] diff --git a/src/snowflake/snowpark/modin/conftest.py b/src/snowflake/snowpark/modin/conftest.py new file mode 100644 index 00000000000..6ddcd4e8348 --- /dev/null +++ b/src/snowflake/snowpark/modin/conftest.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import modin.pandas as pd # pragma: no cover +import numpy as np # pragma: no cover +import pytest # pragma: no cover + +import snowflake.snowpark.modin.plugin # pragma: no cover # noqa: F401 + + +@pytest.fixture(autouse=True, scope="module") # pragma: no cover +def add_doctest_imports(doctest_namespace) -> None: # pragma: no cover + """ + Make `np` and `pd` names available for doctests. + """ + doctest_namespace["np"] = np # pragma: no cover + doctest_namespace["pd"] = pd # pragma: no cover diff --git a/src/snowflake/snowpark/modin/core/__init__.py b/src/snowflake/snowpark/modin/core/__init__.py new file mode 100644 index 00000000000..0fbef920926 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# diff --git a/src/snowflake/snowpark/modin/core/dataframe/__init__.py b/src/snowflake/snowpark/modin/core/dataframe/__init__.py new file mode 100644 index 00000000000..0fbef920926 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/dataframe/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
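The subscribe/once semantics described for Parameter above can be summarised with a toy class; this sketch skips decoding, value sources, and deprecation handling:

from collections import defaultdict

class TinyParam:
    """Toy stand-in for Parameter: a value plus 'always' and 'once' callbacks."""

    _value = "Auto"
    _subs = []
    _once = defaultdict(list)

    @classmethod
    def get(cls):
        return cls._value

    @classmethod
    def subscribe(cls, callback):
        cls._subs.append(callback)
        callback(cls)  # fired immediately, as Parameter.subscribe does

    @classmethod
    def once(cls, onvalue, callback):
        if onvalue == cls.get():
            callback(cls)  # value already matches
        else:
            cls._once[onvalue].append(callback)  # deferred until put() sets it

    @classmethod
    def put(cls, value):
        old, cls._value = cls._value, value
        if old != value:
            for cb in cls._subs:
                cb(cls)
            for cb in cls._once.pop(value, ()):
                cb(cls)

TinyParam.subscribe(lambda p: print("value is now:", p.get()))  # value is now: Auto
TinyParam.once("On", lambda p: print("lazy execution enabled"))
TinyParam.put("On")  # value is now: On / lazy execution enabled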
+# diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/__init__.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/__init__.py new file mode 100644 index 00000000000..0fbef920926 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/__init__.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/__init__.py new file mode 100644 index 00000000000..90a7794a316 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/__init__.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Module default2pandas provides templates for a query compiler default-to-pandas methods.""" + +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.binary import ( + BinaryDefault, +) +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.cat import ( + CatDefault, +) +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.dataframe import ( + DataFrameDefault, +) +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.datetime import ( + DateTimeDefault, +) +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.default import ( + DefaultMethod, +) +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.groupby import ( + GroupByDefault, +) +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.resample import ( + ResampleDefault, +) +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.rolling import ( + RollingDefault, +) +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.series import ( + SeriesDefault, +) +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.str import ( + StrDefault, +) + +__all__ = [ + "DataFrameDefault", + "DateTimeDefault", + "SeriesDefault", + "StrDefault", + "BinaryDefault", + "ResampleDefault", + "RollingDefault", + "DefaultMethod", + "CatDefault", + "GroupByDefault", +] diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/binary.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/binary.py new file mode 100644 index 00000000000..4052c113b25 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/binary.py @@ -0,0 +1,78 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. 
All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Module houses default binary functions builder class.""" +from typing import Any, Callable, Union + +import pandas +from pandas._typing import AnyArrayLike, Scalar + +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.default import ( + DefaultMethod, +) + + +class BinaryDefault(DefaultMethod): + """Build default-to-pandas methods which executes binary functions.""" + + @classmethod + def build_default_to_pandas(cls, fn: Callable, fn_name: str) -> Callable: + """ + Build function that do fallback to pandas for passed binary `fn`. + + Parameters + ---------- + fn : callable + Binary function to apply to the casted to pandas frame and other operand. + fn_name : str + Function name which will be shown in default-to-pandas warning message. + + Returns + ------- + callable + Function that takes query compiler, does fallback to pandas and applies binary `fn` + to the casted to pandas frame. + """ + + def bin_ops_wrapper( + df: pandas.DataFrame, + other: Union[pandas.DataFrame, pandas.Series, Scalar, AnyArrayLike], + *args: Any, + **kwargs: Any + ) -> pandas.DataFrame: + """Apply specified binary function to the passed operands.""" + squeeze_other = kwargs.pop("broadcast", False) or kwargs.pop( + "squeeze_other", False + ) + squeeze_self = kwargs.pop("squeeze_self", False) + + if squeeze_other: + other = other.squeeze(axis=1) + + if squeeze_self: + df = df.squeeze(axis=1) + + result = fn(df, other, *args, **kwargs) + if not isinstance(result, pandas.DataFrame): # pragma: no cover + result = pandas.DataFrame(result) + return result + + return super().build_default_to_pandas(bin_ops_wrapper, fn_name) diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/cat.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/cat.py new file mode 100644 index 00000000000..6a239a7a55c --- /dev/null +++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/cat.py @@ -0,0 +1,48 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Module houses default applied-on-category functions builder class.""" +import pandas + +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.series import ( + SeriesDefault, +) + + +class CatDefault(SeriesDefault): + """Builder for default-to-pandas methods which is executed under category accessor.""" + + @classmethod + def frame_wrapper( + cls, df: pandas.DataFrame + ) -> pandas.core.arrays.categorical.CategoricalAccessor: + """ + Get category accessor of the passed frame. + + Parameters + ---------- + df : pandas.DataFrame + + Returns + ------- + pandas.core.arrays.categorical.CategoricalAccessor + """ + return df.squeeze(axis=1).cat diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/dataframe.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/dataframe.py new file mode 100644 index 00000000000..f22f80f130b --- /dev/null +++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/dataframe.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Module houses default DataFrame functions builder class.""" + +import pandas + +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.default import ( + DefaultMethod, +) + +# from modin.utils import _inherit_docstrings +from snowflake.snowpark.modin.utils import _inherit_docstrings + + +@_inherit_docstrings(DefaultMethod) +class DataFrameDefault(DefaultMethod): + DEFAULT_OBJECT_TYPE = pandas.DataFrame diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/datetime.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/datetime.py new file mode 100644 index 00000000000..d7aefcd165e --- /dev/null +++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/datetime.py @@ -0,0 +1,48 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
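For reference, the accessor that CatDefault.frame_wrapper returns (and, analogously, the datetime accessor used by DateTimeDefault below) can be exercised directly in plain pandas:

import pandas as pd

ser = pd.Series(pd.Categorical(["a", "b", "a"]))
frame = ser.to_frame()                # query compilers hand the builder a one-column frame
accessor = frame.squeeze(axis=1).cat  # what CatDefault.frame_wrapper returns
print(list(accessor.categories))      # ['a', 'b']
print(accessor.codes.tolist())        # [0, 1, 0]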
+# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Module houses default applied-on-datetime functions builder class.""" +import pandas + +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.series import ( + SeriesDefault, +) + + +class DateTimeDefault(SeriesDefault): + """Builder for default-to-pandas methods which is executed under datetime accessor.""" + + @classmethod + def frame_wrapper( + cls, df: pandas.DataFrame + ) -> pandas.core.indexes.accessors.DatetimeProperties: + """ + Get datetime accessor of the passed frame. + + Parameters + ---------- + df : pandas.DataFrame + + Returns + ------- + pandas.core.indexes.accessors.DatetimeProperties + """ + return df.squeeze(axis=1).dt diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/default.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/default.py new file mode 100644 index 00000000000..b6d20f08a6c --- /dev/null +++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/default.py @@ -0,0 +1,279 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Module houses default functions builder class.""" + +from typing import Any, Callable, Optional, Union + +import pandas +from pandas.core.dtypes.common import is_list_like + +from snowflake.snowpark.modin.utils import ( + MODIN_UNNAMED_SERIES_LABEL, + is_property, + try_cast_to_pandas, +) + + +class ObjTypeDeterminer: + """ + Class that routes work to the frame. 
+ + Provides an instance which forwards all of the `__getattribute__` calls + to an object under which `key` function is applied. + """ + + def __getattr__(self, key: str) -> Callable: + """ + Build function that executes `key` function over passed frame. + + Parameters + ---------- + key : str + + Returns + ------- + callable + Function that takes DataFrame and executes `key` function on it. + """ + + def func(df: object, *args: Any, **kwargs: Any) -> Any: # pragma: no cover + """Access specified attribute of the passed object and call it if it's callable.""" + prop = getattr(df, key) + if callable(prop): + return prop(*args, **kwargs) + else: + return prop + + return func + + +class DefaultMethod: + """ + Builder for default-to-pandas methods. + + Attributes + ---------- + OBJECT_TYPE : str + Object type name that will be shown in default-to-pandas warning message. + DEFAULT_OBJECT_TYPE : object + Default place to search for a function. + """ + + OBJECT_TYPE = "DataFrame" + DEFAULT_OBJECT_TYPE = ObjTypeDeterminer + + # This function is pulled from the Operator class in modin/core/dataframe/algebra/operator.py + def __init__(self) -> None: + raise ValueError( # pragma: no cover + "Please use {}.register instead of the constructor".format( + type(self).__name__ + ) + ) + + @classmethod + def get_func_name_for_registered_method( + cls, fn: Union[Callable, property, str] + ) -> str: + """ + Function that takes in a Callable or a property and returns its name + """ + + if is_property(fn): + # when a property method without a name, fn_name will be something like + # "", here we use fget to get the name of the property. Note that this + # method is still not perfect because we cannot get the class name of the property, e.g., we can only get + # "hour" from series.dt.hour + fn_name = f"" # type: ignore[union-attr] + else: + fn_name = getattr(fn, "__name__", str(fn)) + + return fn_name + + @classmethod + def register( + cls, + func: Union[Callable, property, str], + obj_type: Optional[object] = None, + inplace: Optional[bool] = None, + fn_name: Optional[str] = None, + ) -> Callable: + """ + Build function that do fallback to default pandas implementation for passed `func`. + + Parameters + ---------- + func : callable or str, + Function to apply to the casted to pandas frame or its property accesed + by ``cls.frame_wrapper``. + obj_type : object, optional + If `func` is a string with a function name then `obj_type` provides an + object to search function in. + inplace : bool, optional + If True return an object to which `func` was applied, otherwise return + the result of `func`. + fn_name : str, optional + Function name which will be shown in default-to-pandas warning message. + If not specified, name will be deducted from `func`. + + Returns + ------- + callable + Function that takes query compiler, does fallback to pandas and applies `func` + to the casted to pandas frame or its property accessed by ``cls.frame_wrapper``. + """ + + if isinstance(func, str): + if obj_type is None: + obj_type = cls.DEFAULT_OBJECT_TYPE + fn = getattr(obj_type, func) + else: + fn = func + + if fn_name is None: + fn_name = cls.get_func_name_for_registered_method(func) + + if type(fn) == property: + fn = cls.build_property_wrapper(fn) + + def applyier(df: pandas.DataFrame, *args: Any, **kwargs: Any) -> Callable: + """ + Apply target function to the casted to pandas frame. 
+ + This function is directly applied to the casted to pandas frame, executes target + function under it and processes result so it is possible to create a valid + query compiler from it. + """ + args = try_cast_to_pandas(args) # pragma: no cover + kwargs = try_cast_to_pandas(kwargs) # pragma: no cover + + # pandas default implementation doesn't know how to handle `dtypes` keyword argument + kwargs.pop("dtypes", None) + df = cls.frame_wrapper(df) + result = fn(df, *args, **kwargs) + + if not isinstance( + result, pandas.Series + ) and not isinstance( # pragma: no cover + result, pandas.DataFrame + ): + # When applying a DatetimeProperties or TimedeltaProperties function, + # if we don't specify the dtype for the DataFrame, the frame might + # get the wrong dtype, e.g. for to_pydatetime in + # https://github.com/modin-project/modin/issues/4436 + astype_kwargs = {} + dtype = getattr(result, "dtype", None) + if dtype and isinstance( + df, + ( + pandas.core.indexes.accessors.DatetimeProperties, + pandas.core.indexes.accessors.TimedeltaProperties, + ), + ): + astype_kwargs["dtype"] = dtype + result = ( + pandas.DataFrame(result, **astype_kwargs) + if is_list_like(result) + else pandas.DataFrame([result], **astype_kwargs) + ) + if isinstance(result, pandas.Series): + if result.name is None: + result.name = MODIN_UNNAMED_SERIES_LABEL + result = result.to_frame() + + inplace_method = kwargs.get("inplace", False) + if inplace is not None: + inplace_method = inplace + return result if not inplace_method else df + + return cls.build_default_to_pandas(applyier, fn_name) # type: ignore[arg-type] + + @classmethod + def build_property_wrapper(cls, prop: property) -> Callable: + """ + Build function that accesses specified property of the frame. + + Parameters + ---------- + prop : str + Property name to access. + + Returns + ------- + callable + Function that takes DataFrame and returns its value of `prop` property. + """ + + def property_wrapper(df: Any) -> Any: + """Get specified property of the passed object.""" + return prop.fget(df) # type: ignore[misc] # pragma: no cover + + return property_wrapper + + @classmethod + def build_default_to_pandas(cls, fn: Callable, fn_name: str) -> Callable: + """ + Build function that do fallback to pandas for passed `fn`. + + Parameters + ---------- + fn : callable + Function to apply to the defaulted frame. + fn_name : str + Function name which will be shown in default-to-pandas warning message. + + Returns + ------- + callable + Method that does fallback to pandas and applies `fn` to the pandas frame. + """ + fn.__name__ = f"" + + def wrapper( # type: ignore[no-untyped-def] + self, *args: Any, **kwargs: Any + ) -> Callable: + """Do fallback to pandas for the specified function.""" + return self.default_to_pandas(fn, *args, **kwargs) + + return wrapper + + @classmethod + def frame_wrapper(cls, df: pandas.DataFrame) -> pandas.DataFrame: + """ + Extract frame property to apply function on. + + This method is executed under casted to pandas frame right before applying + a function passed to `register`, which gives an ability to transform frame somehow + or access its properties, by overriding this method in a child class. + + Parameters + ---------- + df : pandas.DataFrame + + Returns + ------- + pandas.DataFrame + + Notes + ----- + Being a base implementation, this particular method does nothing with passed frame. 
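Stripped of the warning machinery and dtype special-casing, the register()/applyier flow above reduces to a lookup-and-rewrap step. A toy reduction of that flow, not the actual implementation:

import pandas as pd

def register_default(func_name):
    """Toy version of DefaultMethod.register for a method looked up by name."""

    def wrapper(df, *args, **kwargs):
        result = getattr(df, func_name)(*args, **kwargs)
        # Query compilers expect two-dimensional results, so Series are re-wrapped.
        if isinstance(result, pd.Series):
            result = result.to_frame()
        return result

    return wrapper

fallback_sum = register_default("sum")
print(fallback_sum(pd.DataFrame({"a": [1, 2], "b": [3, 4]})))
#    0
# a  3
# b  7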
+ """ + return df diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/groupby.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/groupby.py new file mode 100644 index 00000000000..ad5d96dc6a0 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/groupby.py @@ -0,0 +1,728 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Module houses default GroupBy functions builder class.""" +from typing import Any, Callable, Optional, Union + +import pandas +from pandas.core.dtypes.common import is_list_like + +# Defines a set of string names of functions that are executed in a transform-way in groupby +from pandas.core.groupby.base import transformation_kernels + +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.default import ( + DefaultMethod, +) +from snowflake.snowpark.modin.utils import ( + MODIN_UNNAMED_SERIES_LABEL, + hashable, + is_property, +) + + +# FIXME: there is no sence of keeping `GroupBy` and `GroupByDefault` logic in a different +# classes. They should be combined. +class GroupBy: + """Builder for GroupBy aggregation functions.""" + + agg_aliases = [ + "agg", + "dict_agg", + pandas.core.groupby.DataFrameGroupBy.agg, + pandas.core.groupby.DataFrameGroupBy.aggregate, + ] + + @staticmethod + def is_transformation_kernel(agg_func: Any) -> bool: + """ + Check whether a passed aggregation function is a transformation. + + Transformation means that the result of the function will be broadcasted + to the frame's original shape. + + Parameters + ---------- + agg_func : Any + + Returns + ------- + bool + """ + return ( + hashable(agg_func) and agg_func in transformation_kernels + ) # pragma: no cover + + @classmethod + def _call_groupby( + cls, df: Union[pandas.DataFrame, pandas.Series], *args: Any, **kwargs: Any + ) -> Union[pandas.core.groupby.DataFrameGroupBy, pandas.core.groupby.SeriesGroupBy]: + """Call .groupby() on passed `df`.""" + return df.groupby(*args, **kwargs) # pragma: no cover + + @classmethod + def validate_by(cls, by: Any) -> Any: + """ + Build valid `by` parameter for `pandas.DataFrame.groupby`. + + Cast all DataFrames in `by` parameter to Series or list of Series in case + of multi-column frame. + + Parameters + ---------- + by : DateFrame, Series, index label or list of such + Object which indicates groups for GroupBy. + + Returns + ------- + Series, index label or list of such + By parameter with all DataFrames casted to Series. 
+ """ + + def try_cast_series(df: Any) -> Any: # pragma: no cover + """Cast one-column frame to Series.""" + if isinstance(df, pandas.DataFrame): + df = df.squeeze(axis=1) + if not isinstance(df, pandas.Series): + return df + if df.name == MODIN_UNNAMED_SERIES_LABEL: + df.name = None + return df + + if isinstance(by, pandas.DataFrame): + by = [try_cast_series(column) for _, column in by.items()] + elif isinstance(by, pandas.Series): + by = [try_cast_series(by)] + elif isinstance(by, list): + by = [try_cast_series(o) for o in by] + return by + + @classmethod + def inplace_applyier_builder( + cls, key: Callable, func: Optional[Union[Callable, str]] = None + ) -> Callable: + """ + Bind actual aggregation function to the GroupBy aggregation method. + + Parameters + ---------- + key : callable + Function that takes GroupBy object and evaluates passed aggregation function. + func : callable or str, optional + Function that takes DataFrame and aggregate its data. Will be applied + to each group at the grouped frame. + + Returns + ------- + callable, + Function that executes aggregation under GroupBy object. + """ + inplace_args = [] if func is None else [func] + + def inplace_applyier( # pragma: no cover + grp: Union[ + pandas.core.groupby.DataFrameGroupBy, pandas.core.groupby.SeriesGroupBy + ], + *func_args: Any, + **func_kwargs: Any, + ) -> Callable: + return key(grp, *inplace_args, *func_args, **func_kwargs) # type: ignore[operator] # pragma: no cover + + return inplace_applyier + + @classmethod + def get_func(cls, key: Callable, **kwargs: Any) -> Callable: + """ + Extract aggregation function from groupby arguments. + + Parameters + ---------- + key : callable or str + Default aggregation function. If aggregation function is not specified + via groupby arguments, then `key` function is used. + **kwargs : dict + GroupBy arguments that may contain aggregation function. + + Returns + ------- + callable + Aggregation function. + + Notes + ----- + There are two ways of how groupby aggregation can be invoked: + 1. Explicitly with query compiler method: `qc.groupby_sum()`. + 2. By passing aggregation function as an argument: `qc.groupby_agg("sum")`. + Both are going to produce the same result, however in the first case actual aggregation + function can be extracted from the method name, while for the second only from the method arguments. + """ + if "agg_func" in kwargs: + return cls.inplace_applyier_builder(key, kwargs["agg_func"]) + elif "func_dict" in kwargs: + return cls.inplace_applyier_builder(key, kwargs["func_dict"]) + else: + return cls.inplace_applyier_builder(key) + + @classmethod + def build_aggregate_method(cls, key: Callable) -> Callable: + """ + Build function for `QueryCompiler.groupby_agg` that can be executed as default-to-pandas. + + Parameters + ---------- + key : callable or str + Default aggregation function. If aggregation function is not specified + via groupby arguments, then `key` function is used. + + Returns + ------- + callable + Function that executes groupby aggregation. 
+ """ + + def fn( + df: Union[pandas.DataFrame, pandas.Series], + by: Any, + axis: int, + groupby_kwargs: dict[str, Any], + agg_args: Any, + agg_kwargs: dict[str, Any], + **kwargs: Any, + ) -> Any: + """Group DataFrame and apply aggregation function to each group.""" + by = cls.validate_by(by) + + grp = cls._call_groupby( + df, by, axis=axis, **groupby_kwargs + ) # pragma: no cover + agg_func = cls.get_func(key, **kwargs) + result = agg_func(grp, *agg_args, **agg_kwargs) + + return result + + return fn + + @classmethod + def build_groupby_reduce_method(cls, agg_func: Any) -> Callable: + """ + Build function for `QueryCompiler.groupby_*` that can be executed as default-to-pandas. + + Parameters + ---------- + agg_func : callable or str + Default aggregation function. If aggregation function is not specified + via groupby arguments, then `agg_func` function is used. + + Returns + ------- + callable + Function that executes groupby aggregation. + """ + + def fn( + df: Union[pandas.DataFrame, pandas.Series], + by: Any, + axis: int, + groupby_kwargs: dict[str, Any], + agg_args: Any, + agg_kwargs: dict[str, Any], + drop: bool = False, + **kwargs: Any, + ) -> Any: + """Group DataFrame and apply aggregation function to each group.""" + if not isinstance(by, (pandas.Series, pandas.DataFrame)): + by = cls.validate_by(by) + grp = cls._call_groupby( + df, by, axis=axis, **groupby_kwargs + ) # pragma: no cover + grp_agg_func = cls.get_func(agg_func, **kwargs) + return grp_agg_func( + grp, + *agg_args, + **agg_kwargs, + ) + + if isinstance(by, pandas.DataFrame): + by = by.squeeze(axis=1) + if ( + drop + and isinstance(by, pandas.Series) + and by.name in df + and df[by.name].equals(by) + ): + by = [by.name] + if isinstance(by, pandas.DataFrame): + df = pandas.concat([df] + [by[[o for o in by if o not in df]]], axis=1) + by = list(by.columns) + + groupby_kwargs = groupby_kwargs.copy() + as_index = groupby_kwargs.pop("as_index", True) + groupby_kwargs["as_index"] = True + + grp = cls._call_groupby( + df, by, axis=axis, **groupby_kwargs + ) # pragma: no cover + func = cls.get_func(agg_func, **kwargs) + result = func(grp, *agg_args, **agg_kwargs) + method = kwargs.get("method") + + if isinstance(result, pandas.Series): + result = result.to_frame( # pragma: no cover + MODIN_UNNAMED_SERIES_LABEL if result.name is None else result.name + ) + + if not as_index: + if isinstance(by, pandas.Series): + # 1. If `drop` is True then 'by' Series represents a column from the + # source frame and so the 'by' is internal. + # 2. If method is 'size' then any 'by' is considered to be internal. 
+ # This is a hacky legacy from the ``groupby_size`` implementation: + # https://github.com/modin-project/modin/issues/3739 + internal_by = (by.name,) if drop or method == "size" else tuple() + else: + internal_by = by + + cls.handle_as_index_for_dataframe( + result, + internal_by, + by_cols_dtypes=( + df.index.dtypes.values + if isinstance(df.index, pandas.MultiIndex) + else (df.index.dtype,) + ), + by_length=len(by), + drop=drop, + method=method, + inplace=True, + ) + + if result.index.name == MODIN_UNNAMED_SERIES_LABEL: + result.index.name = None + + return result + + return fn + + @classmethod + def is_aggregate(cls, key: Union[Callable, str, property]) -> bool: + """Check whether `key` is an alias for pandas.GroupBy.aggregation method.""" + return key in cls.agg_aliases + + @classmethod + def build_property_method(cls, property: property) -> Callable: + """ + Build function for `SnowflakeQueryCompiler.` that can be executed as default-to-pandas + + Parameters + ---------- + property: property + property of groupby object. + + Returns + ------- + callable + Function that executes groupby aggregation and returns property. + """ + + def fn( + df: Union[pandas.DataFrame, pandas.Series], + by: Any, + axis: int, + groupby_kwargs: dict[str, Any], + ) -> Any: + """Group DataFrame and apply aggregation function to each group.""" + by = cls.validate_by(by) # pragma: no cover + + grp = cls._call_groupby( + df, by, axis=axis, **groupby_kwargs + ) # pragma: no cover + + return property.fget(grp) # type: ignore[misc] + + return fn + + @classmethod + def build_groupby(cls, func: Union[Callable, property]) -> Callable: + """ + Build function that groups DataFrame and applies aggregation function to the every group. + + Parameters + ---------- + func : callable or str or property + Default aggregation function. If aggregation function is not specified + via groupby arguments, then `func` function is used. + + Returns + ------- + callable + Function that takes pandas DataFrame and does GroupBy aggregation. + """ + + if is_property(func): + return cls.build_property_method(func) # type: ignore[arg-type] + if cls.is_aggregate(func): + return cls.build_aggregate_method(func) # type: ignore[arg-type] + return cls.build_groupby_reduce_method( + func + ) # pragma: no cover # type: ignore[arg-type] + + @classmethod + def handle_as_index_for_dataframe( + cls, + result: pandas.DataFrame, + internal_by_cols: Any, + by_cols_dtypes: Optional[Any] = None, + by_length: Optional[int] = None, + selection: Optional[Any] = None, + partition_idx: int = 0, + drop: bool = True, + method: Optional[str] = None, + inplace: bool = False, + ) -> pandas.DataFrame: + """ + Handle `as_index=False` parameter for the passed GroupBy aggregation result. + + Parameters + ---------- + result : DataFrame + Frame containing GroupBy aggregation result computed with `as_index=True` + parameter (group names are located at the frame's index). + internal_by_cols : list-like + Internal 'by' columns. + by_cols_dtypes : list-like, optional + Data types of the internal 'by' columns. Required to do special casing + in case of categorical 'by'. If not specified, assume that there is no + categorical data in 'by'. + by_length : int, optional + Amount of keys to group on (including frame columns and external objects like list, Series, etc.) + If not specified, consider `by_length` to be equal ``len(internal_by_cols)``. 
+ selection : label or list of labels, optional + Set of columns that were explicitly selected for aggregation (for example + via dict-aggregation). If not specified assuming that aggregation was + applied to all of the available columns. + partition_idx : int, default: 0 + Positional index of the current partition. + drop : bool, default: True + Indicates whether or not any of the `by` data came from the same frame. + method : str, optional + Name of the groupby function. This is a hint to be able to do special casing. + Note: this parameter is a legacy from the ``groupby_size`` implementation, + it's a hacky one and probably will be removed in the future: https://github.com/modin-project/modin/issues/3739. + inplace : bool, default: False + Modify the DataFrame in place (do not create a new object). + + Returns + ------- + DataFrame + GroupBy aggregation result with the considered `as_index=False` parameter. + """ + if not inplace: + result = result.copy() + + ( + reset_index, + drop, + lvls_to_drop, + cols_to_drop, + ) = cls.handle_as_index( # pragma: no cover + result_cols=result.columns, + result_index_names=result.index.names, + internal_by_cols=internal_by_cols, + by_cols_dtypes=by_cols_dtypes, + by_length=by_length, + selection=selection, + partition_idx=partition_idx, + drop=drop, + method=method, + ) + + if len(lvls_to_drop) > 0: + result.index = result.index.droplevel(lvls_to_drop) + if len(cols_to_drop) > 0: + result.drop(columns=cols_to_drop, inplace=True) + if reset_index: + result.reset_index(drop=drop, inplace=True) + return result + + @staticmethod + def handle_as_index( + result_cols: pandas.Index, + result_index_names: Any, + internal_by_cols: Any, + by_cols_dtypes: Optional[Any] = None, + by_length: Optional[int] = None, + selection: Optional[Any] = None, + partition_idx: int = 0, + drop: bool = True, + method: Optional[str] = None, + ) -> tuple[bool, bool, list[int], Any]: + """ + Compute hints to process ``as_index=False`` parameter for the GroupBy result. + + This function resolves naming conflicts of the index levels to insert and the column labels + for the GroupBy result. The logic of this function assumes that the initial GroupBy result + was computed as ``as_index=True``. + + Parameters + ---------- + result_cols : pandas.Index + Columns of the GroupBy result. + result_index_names : list-like + Index names of the GroupBy result. + internal_by_cols : list-like + Internal 'by' columns. + by_cols_dtypes : list-like, optional + Data types of the internal 'by' columns. Required to do special casing + in case of categorical 'by'. If not specified, assume that there is no + categorical data in 'by'. + by_length : int, optional + Amount of keys to group on (including frame columns and external objects like list, Series, etc.) + If not specified, consider `by_length` to be equal ``len(internal_by_cols)``. + selection : label or list of labels, optional + Set of columns that were explicitly selected for aggregation (for example + via dict-aggregation). If not specified assuming that aggregation was + applied to all of the available columns. + partition_idx : int, default: 0 + Positional index of the current partition. + drop : bool, default: True + Indicates whether or not any of the `by` data came from the same frame. + method : str, optional + Name of the groupby function. This is a hint to be able to do special casing. 
+ Note: this parameter is a legacy from the ``groupby_size`` implementation, + it's a hacky one and probably will be removed in the future: https://github.com/modin-project/modin/issues/3739. + + Returns + ------- + reset_index : bool + Indicates whether to reset index to the default one (0, 1, 2 ... n) at this partition. + drop_index : bool + If `reset_index` is True, indicates whether to drop all index levels (True) or insert them into the + resulting columns (False). + lvls_to_drop : list of ints + Contains numeric indices of the levels of the result index to drop as intersected. + cols_to_drop : list of labels + Contains labels of the columns to drop from the result as intersected. + + Examples + -------- + >>> groupby_result = compute_groupby_without_processing_as_index_parameter() + >>> if not as_index: + >>> reset_index, drop, lvls_to_drop, cols_to_drop = handle_as_index(**extract_required_params(groupby_result)) + >>> if len(lvls_to_drop) > 0: + >>> groupby_result.index = groupby_result.index.droplevel(lvls_to_drop) + >>> if len(cols_to_drop) > 0: + >>> groupby_result = groupby_result.drop(columns=cols_to_drop) + >>> if reset_index: + >>> groupby_result_with_processed_as_index_parameter = groupby_result.reset_index(drop=drop) + >>> else: + >>> groupby_result_with_processed_as_index_parameter = groupby_result + """ + if by_length is None: + by_length = len(internal_by_cols) + + reset_index = method != "transform" and ( + by_length > 0 or selection is not None + ) # pragma: no cover + + # If the method is "size" then the result contains only one unique named column + # and we don't have to worry about any naming conflicts, so inserting all of + # the "by" into the result (just a fast-path) + if method == "size": + return reset_index, False, [], [] + + # pandas logic of resolving naming conflicts is the following: + # 1. If any categorical is in 'by' and 'by' is multi-column, then the categorical + # index is prioritized: drop intersected columns and insert all of the 'by' index + # levels to the frame as columns. + # 2. Otherwise, aggregation result is prioritized: drop intersected index levels and + # insert the filtered ones to the frame as columns. + if by_cols_dtypes is not None: + keep_index_levels = ( + by_length > 1 + and selection is None + and any(isinstance(x, pandas.CategoricalDtype) for x in by_cols_dtypes) + ) + else: + keep_index_levels = False + + # 1. We insert 'by'-columns to the result at the beginning of the frame and so only to the + # first partition, if partition_idx != 0 we just drop the index. If there are no columns + # that are required to drop (keep_index_levels is True) then we can exit here. + # 2. We don't insert 'by'-columns to the result if 'by'-data came from a different + # frame (drop is False), there's only one exception for this rule: if the `method` is "size", + # so if (drop is False) and method is not "size" we just drop the index and so can exit here. 
+ if (not keep_index_levels and partition_idx != 0) or ( + not drop and method != "size" + ): + return reset_index, True, [], [] + + if not isinstance(internal_by_cols, pandas.Index): + if not is_list_like(internal_by_cols): + internal_by_cols = [internal_by_cols] + internal_by_cols = pandas.Index(internal_by_cols) + + internal_by_cols = ( + internal_by_cols[ + ~internal_by_cols.str.startswith(MODIN_UNNAMED_SERIES_LABEL, na=False) + ] + if hasattr(internal_by_cols, "str") + else internal_by_cols + ) + + if selection is not None and not isinstance(selection, pandas.Index): + selection = pandas.Index(selection) + + lvls_to_drop: list[int] = [] # pragma: no cover + cols_to_drop: Any = [] # pragma: no cover + + if not keep_index_levels: + # We want to insert only these internal-by-cols that are not presented + # in the result in order to not create naming conflicts + if selection is None: + cols_to_insert = frozenset(internal_by_cols) - frozenset(result_cols) + else: + cols_to_insert = frozenset( + # We have to use explicit 'not in' check and not just difference + # of sets because of specific '__contains__' operator in case of + # scalar 'col' and MultiIndex 'selection'. + col + for col in internal_by_cols + if col not in selection + ) + else: + cols_to_insert = internal_by_cols + # We want to drop such internal-by-cols that are presented + # in the result in order to not create naming conflicts + cols_to_drop = frozenset(internal_by_cols) & frozenset(result_cols) + + if partition_idx == 0: + lvls_to_drop = [ + i + for i, name in enumerate(result_index_names) + if name not in cols_to_insert + ] + else: + lvls_to_drop = result_index_names + + drop = False + if len(lvls_to_drop) == len(result_index_names): + drop = True + lvls_to_drop = [] + + return reset_index, drop, lvls_to_drop, cols_to_drop + + +class SeriesGroupBy(GroupBy): + """Builder for GroupBy aggregation functions for Series.""" + + @classmethod + def _call_groupby(cls, df: pandas.DataFrame, *args: Any, **kwargs: Any) -> Callable: + """Call .groupby() on passed `df` squeezed to Series.""" + # We can end up here by two means - either by "true" call + # like Series().groupby() or by df.groupby()[item]. + + if len(df.columns) == 1: # pragma: no cover + # Series().groupby() case + return df.squeeze(axis=1).groupby(*args, **kwargs) # pragma: no cover + # In second case surrounding logic will supplement grouping columns, + # so we need to drop them after grouping is over; our originally + # selected column is always the first, so use it + return df.groupby(*args, **kwargs)[df.columns[0]] # pragma: no cover + + +class GroupByDefault(DefaultMethod): + """Builder for default-to-pandas GroupBy aggregation functions.""" + + _groupby_cls = GroupBy + + OBJECT_TYPE = "GroupBy" + + @classmethod + def register(cls, func: Callable, **kwargs: Any) -> Callable: + """ + Build default-to-pandas GroupBy aggregation function. + + Parameters + ---------- + func : callable or str + Default aggregation function. If aggregation function is not specified + via groupby arguments, then `func` function is used. + **kwargs : kwargs + Additional arguments that will be passed to function builder. + + Returns + ------- + callable + Functiom that takes query compiler and defaults to pandas to do GroupBy + aggregation. 
+        """
+        return super().register(
+            cls._groupby_cls.build_groupby(func),
+            fn_name=cls.get_func_name_for_registered_method(func),
+            **kwargs,
+        )
+
+    # This maps each `how` strategy onto the `pandas.DataFrameGroupBy` method that the
+    # `agg_func` is passed to. Going by the pandas documentation:
+    # 1. `.aggregate(func)` applies func row/column wise.
+    # 2. `.apply(func)` applies func to a DataFrame holding a whole group (group-wise).
+    # 3. `.transform(func)` is the same as `.apply()` but also broadcasts the `func`
+    #    result to the group's original shape.
+    # 4. 'direct' mode means that the passed `func` has to be applied directly
+    #    to the `pandas.DataFrameGroupBy` object.
+    _aggregation_methods_dict = {
+        "axis_wise": pandas.core.groupby.DataFrameGroupBy.aggregate,
+        "group_wise": pandas.core.groupby.DataFrameGroupBy.apply,
+        "transform": pandas.core.groupby.DataFrameGroupBy.transform,
+        "direct": lambda grp, func, *args, **kwargs: func(grp, *args, **kwargs),
+    }
+
+    @classmethod
+    def get_aggregation_method(cls, how: str) -> Callable:
+        """
+        Return the `pandas.DataFrameGroupBy` method that implements the passed `how` UDF-application strategy.
+
+        Parameters
+        ----------
+        how : {"axis_wise", "group_wise", "transform", "direct"}
+            The `how` parameter of ``BaseQueryCompiler.groupby_agg``.
+
+        Returns
+        -------
+        callable(pandas.DataFrameGroupBy, callable, *args, **kwargs) -> [pandas.DataFrame | pandas.Series]
+
+        Notes
+        -----
+        See the ``BaseQueryCompiler.groupby_agg`` docstring for more information about the `how` parameter.
+        """
+        return cls._aggregation_methods_dict[how]  # pragma: no cover
+
+
+class SeriesGroupByDefault(GroupByDefault):
+    """Builder for default-to-pandas GroupBy aggregation functions for Series."""
+
+    _groupby_cls = SeriesGroupBy
+
+    _aggregation_methods_dict = {
+        "axis_wise": pandas.core.groupby.SeriesGroupBy.aggregate,
+        "group_wise": pandas.core.groupby.SeriesGroupBy.apply,
+        "transform": pandas.core.groupby.SeriesGroupBy.transform,
+        "direct": lambda grp, func, *args, **kwargs: func(grp, *args, **kwargs),
+    }
diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/resample.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/resample.py
new file mode 100644
index 00000000000..559c1a9cb3b
--- /dev/null
+++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/resample.py
@@ -0,0 +1,106 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
+
+# Code in this file may constitute partial or total reimplementation, or modification of
+# existing code originally distributed by the Modin project, under the Apache License,
+# Version 2.0.
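For reference, the four `how` strategies in `GroupByDefault._aggregation_methods_dict` above correspond to plain pandas GroupBy calls along these lines (a minimal sketch with made-up data):

    import pandas as pd

    grp = pd.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]}).groupby("k")

    print(grp.aggregate("sum"))               # "axis_wise": column-wise aggregation
    print(grp.apply(lambda g: g["v"].max()))  # "group_wise": func sees each whole group
    print(grp.transform("sum"))               # "transform": result broadcast to the group's shape
    print((lambda g: g.size())(grp))          # "direct": func is called on the GroupBy object itself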
+
+"""Module houses default Resample functions builder class."""
+from typing import Any, Callable, Union
+
+import pandas
+
+from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.default import (
+    DefaultMethod,
+)
+
+
+# FIXME: there is no sense in keeping the `Resampler` and `ResampleDefault` logic in different
+# classes. They should be combined.
+class Resampler:
+    """Builder class for resampled aggregation functions."""
+
+    @classmethod
+    def build_resample(cls, func: Union[Callable, property], squeeze_self: bool) -> Any:
+        """
+        Build a function that resamples time-series data and does aggregation.
+
+        Parameters
+        ----------
+        func : callable
+            Aggregation function to execute on the resampled frame.
+        squeeze_self : bool
+            Whether or not to squeeze the frame before resampling.
+
+        Returns
+        -------
+        callable
+            Function that takes a pandas DataFrame and applies the aggregation
+            to resampled time-series data.
+        """
+
+        def fn(  # pragma: no cover
+            df: pandas.DataFrame,
+            resample_kwargs: dict[str, Any],
+            *args: Any,
+            **kwargs: Any
+        ) -> Any:
+            """Resample time-series data of the passed frame and apply the specified aggregation."""
+            if squeeze_self:
+                df = df.squeeze(axis=1)
+            resampler = df.resample(**resample_kwargs)
+
+            if type(func) == property:
+                return func.fget(resampler)  # type: ignore[misc] # pragma: no cover
+
+            return func(resampler, *args, **kwargs)  # type: ignore[operator] # pragma: no cover
+
+        return fn
+
+
+class ResampleDefault(DefaultMethod):
+    """Builder for default-to-pandas resampled aggregation functions."""
+
+    OBJECT_TYPE = "Resampler"
+
+    @classmethod
+    def register(
+        cls, func: Callable, squeeze_self: bool = False, **kwargs: Any
+    ) -> Callable:
+        """
+        Build a function that falls back to pandas to aggregate resampled data.
+
+        Parameters
+        ----------
+        func : callable
+            Aggregation function to execute on the resampled frame.
+        squeeze_self : bool, default: False
+            Whether or not to squeeze the frame before resampling.
+        **kwargs : kwargs
+            Additional arguments that will be passed to the function builder.
+
+        Returns
+        -------
+        callable
+            Function that takes a query compiler and falls back to pandas to resample
+            time-series data and apply the aggregation to it.
+        """
+        return super().register(
+            Resampler.build_resample(func, squeeze_self),
+            fn_name=func.__name__,
+            **kwargs
+        )
diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/rolling.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/rolling.py
new file mode 100644
index 00000000000..be662920c80
--- /dev/null
+++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/rolling.py
@@ -0,0 +1,159 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+
+# Licensed to Modin Development Team under one or more contributor license agreements.
+# See the NOTICE file distributed with this work for additional information regarding
+# copyright ownership. The Modin Development Team licenses this file to you under the
+# Apache License, Version 2.0 (the "License"); you may not use this file except in
+# compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific language
+# governing permissions and limitations under the License.
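In plain pandas terms, the `fn` built by `Resampler.build_resample` above boils down to the following (a minimal sketch; the data and resample frequency are made up):

    import pandas as pd

    idx = pd.date_range("2024-01-01", periods=6, freq="D")
    df = pd.DataFrame({"value": range(6)}, index=idx)

    # For a callable aggregation the generated `fn` essentially does:
    resampler = df.resample("2D")
    print(resampler.mean())

    # For a property-style `func` the builder calls `func.fget(resampler)`
    # instead of `func(resampler, *args, **kwargs)`.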
+ +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Module houses default Rolling functions builder class.""" + +from typing import Any, Callable, Union + +import pandas + +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.default import ( + DefaultMethod, +) + + +class RollingDefault(DefaultMethod): + """Builder for default-to-pandas aggregation on a rolling window functions.""" + + OBJECT_TYPE = "Rolling" + + @classmethod + def _build_rolling(cls, func: Union[Callable, property]) -> Callable: + """ + Build function that creates a rolling window and executes `func` on it. + + Parameters + ---------- + func : callable + Function to execute on a rolling window. + + Returns + ------- + callable + Function that takes pandas DataFrame and applies `func` on a rolling window. + """ + + def fn( # pragma: no cover + df: pandas.DataFrame, + rolling_kwargs: dict[str, Any], + *args: Any, + **kwargs: Any + ) -> Any: + """Create rolling window for the passed frame and execute specified `func` on it.""" + roller = df.rolling(**rolling_kwargs) # pragma: no cover + + if type(func) == property: # pragma: no cover + return func.fget(roller) # type: ignore[misc] # pragma: no cover + + return func(roller, *args, **kwargs) # type: ignore[operator] # pragma: no cover + + return fn # pragma: no cover + + @classmethod + def register(cls, func: Callable, **kwargs: Any) -> Callable: + """ + Build function that do fallback to pandas to apply `func` on a rolling window. + + Parameters + ---------- + func : callable + Function to execute on a rolling window. + **kwargs : kwargs + Additional arguments that will be passed to function builder. + + Returns + ------- + callable + Function that takes query compiler and defaults to pandas to apply aggregation + `func` on a rolling window. + """ + return super().register( # pragma: no cover + cls._build_rolling(func), fn_name=func.__name__, **kwargs + ) + + +class ExpandingDefault(DefaultMethod): + """Builder for default-to-pandas aggregation on an expanding window functions.""" + + OBJECT_TYPE = "Expanding" + + @classmethod + def _build_expanding( + cls, func: Union[Callable, property], squeeze_self: bool + ) -> Callable: + """ + Build function that creates an expanding window and executes `func` on it. + + Parameters + ---------- + func : callable + Function to execute on a expanding window. + squeeze_self : bool + Whether or not to squeeze frame before executing the window function. + + Returns + ------- + callable + Function that takes pandas DataFrame and applies `func` on a expanding window. + """ + + def fn( # pragma: no cover + df: pandas.DataFrame, rolling_args: Any, *args: Any, **kwargs: Any + ) -> Any: + """Create rolling window for the passed frame and execute specified `func` on it.""" + if squeeze_self: # pragma: no cover + df = df.squeeze(axis=1) # pragma: no cover + roller = df.expanding(*rolling_args) # pragma: no cover + + if type(func) == property: # pragma: no cover + return func.fget(roller) # type: ignore[misc] # pragma: no cover + + return func(roller, *args, **kwargs) # type: ignore[operator] # pragma: no cover + + return fn # pragma: no cover + + @classmethod + def register( + cls, func: Callable, squeeze_self: bool = False, **kwargs: Any + ) -> Callable: + """ + Build function that do fallback to pandas to apply `func` on a expanding window. 
+ + Parameters + ---------- + func : callable + Function to execute on an expanding window. + squeeze_self : bool, default: False + Whether or not to squeeze frame before executing the window function. + **kwargs : kwargs + Additional arguments that will be passed to function builder. + + Returns + ------- + callable + Function that takes query compiler and defaults to pandas to apply aggregation + `func` on an expanding window. + """ + return super().register( # pragma: no cover + cls._build_expanding(func, squeeze_self=squeeze_self), + fn_name=func.__name__, + **kwargs + ) diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/series.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/series.py new file mode 100644 index 00000000000..455360cef59 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/series.py @@ -0,0 +1,49 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Module houses default Series functions builder class.""" +import pandas + +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.default import ( + DefaultMethod, +) + + +class SeriesDefault(DefaultMethod): + """Builder for default-to-pandas methods which is executed under Series.""" + + OBJECT_TYPE = "Series" + + @classmethod + def frame_wrapper(cls, df: pandas.DataFrame) -> pandas.Series: + """ + Squeeze passed DataFrame to be able to process Series-specific functions on it. + + Parameters + ---------- + df : pandas.DataFrame + One-column DataFrame to squeeze. + + Returns + ------- + pandas.Series + """ + return df.squeeze(axis=1) diff --git a/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/str.py b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/str.py new file mode 100644 index 00000000000..4de39a3c084 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/dataframe/algebra/default2pandas/str.py @@ -0,0 +1,48 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Module houses default applied-on-str functions builder class.""" +import pandas + +from snowflake.snowpark.modin.core.dataframe.algebra.default2pandas.series import ( + SeriesDefault, +) + + +class StrDefault(SeriesDefault): + """Builder for default-to-pandas methods which is executed under `str` accessor.""" + + @classmethod + def frame_wrapper( + cls, df: pandas.DataFrame + ) -> pandas.core.strings.accessor.StringMethods: + """ + Get `str` accessor of the passed frame. + + Parameters + ---------- + df : pandas.DataFrame + + Returns + ------- + pandas.core.strings.accessor.StringMethods + """ + return df.squeeze(axis=1).str diff --git a/src/snowflake/snowpark/modin/core/execution/__init__.py b/src/snowflake/snowpark/modin/core/execution/__init__.py new file mode 100644 index 00000000000..0fbef920926 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/execution/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# diff --git a/src/snowflake/snowpark/modin/core/execution/dispatching/__init__.py b/src/snowflake/snowpark/modin/core/execution/dispatching/__init__.py new file mode 100644 index 00000000000..e71962dddec --- /dev/null +++ b/src/snowflake/snowpark/modin/core/execution/dispatching/__init__.py @@ -0,0 +1,22 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Modin's functionality related to dispatching to specific execution.""" diff --git a/src/snowflake/snowpark/modin/core/execution/dispatching/factories/__init__.py b/src/snowflake/snowpark/modin/core/execution/dispatching/factories/__init__.py new file mode 100644 index 00000000000..b1326d67f33 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/execution/dispatching/factories/__init__.py @@ -0,0 +1,26 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. 
+# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Factories responsible for dispatching to specific execution.""" + +from snowflake.snowpark.modin.core.execution.dispatching.factories import ( # noqa: F401 + factories, +) diff --git a/src/snowflake/snowpark/modin/core/execution/dispatching/factories/baseio.py b/src/snowflake/snowpark/modin/core/execution/dispatching/factories/baseio.py new file mode 100644 index 00000000000..6825cf30c23 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/execution/dispatching/factories/baseio.py @@ -0,0 +1,680 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# based on https://raw.githubusercontent.com/modin-project/modin/master/modin/core/io/io.py +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +""" +Module houses `BaseIO` class. + +`BaseIO` is base class for IO classes, that stores IO functions. +""" + +from collections import OrderedDict +from typing import Any + +import pandas +from pandas.util._decorators import doc + +from snowflake.snowpark.modin.plugin.compiler import BaseQueryCompiler +from snowflake.snowpark.modin.utils import _inherit_docstrings + +_doc_default_io_method = """ +{summary} using pandas. +For parameters description please refer to pandas API. 
+ +Returns +------- +{returns} +""" + +_doc_returns_qc = """BaseQueryCompiler + QueryCompiler with read data.""" + +_doc_returns_qc_or_parser = """BaseQueryCompiler or TextParser + QueryCompiler or TextParser with read data.""" + + +class BaseIO: + """Class for basic utils and default implementation of IO functions.""" + + query_compiler_cls: BaseQueryCompiler = None + frame_cls = None + + @classmethod + def from_non_pandas(cls, *args, **kwargs): + """ + Create a Modin `query_compiler` from a non-pandas `object`. + + Parameters + ---------- + *args : iterable + Positional arguments to be passed into `func`. + **kwargs : dict + Keyword arguments to be passed into `func`. + """ + return None + + @classmethod + def from_pandas(cls, df): + """ + Create a Modin `query_compiler` from a `pandas.DataFrame`. + + Parameters + ---------- + df : pandas.DataFrame + The pandas DataFrame to convert from. + + Returns + ------- + BaseQueryCompiler + QueryCompiler containing data from the `pandas.DataFrame`. + """ + return cls.query_compiler_cls.from_pandas(df, cls.frame_cls) + + @classmethod + def from_arrow(cls, at): + """ + Create a Modin `query_compiler` from a `pyarrow.Table`. + + Parameters + ---------- + at : Arrow Table + The Arrow Table to convert from. + + Returns + ------- + BaseQueryCompiler + QueryCompiler containing data from the Arrow Table. + """ + return cls.query_compiler_cls.from_arrow(at, cls.frame_cls) + + @classmethod + def from_dataframe(cls, df): + """ + Create a Modin QueryCompiler from a DataFrame supporting the DataFrame exchange protocol `__dataframe__()`. + + Parameters + ---------- + df : DataFrame + The DataFrame object supporting the DataFrame exchange protocol. + + Returns + ------- + BaseQueryCompiler + QueryCompiler containing data from the DataFrame. 
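Concrete executions plug into `BaseIO` through the two class attributes above; a hypothetical subclass (the names here are illustrative and not part of this codebase) only needs to point them at its own implementations:

    import pandas

    from snowflake.snowpark.modin.core.execution.dispatching.factories.baseio import BaseIO

    class MyQueryCompiler:  # hypothetical query compiler
        @classmethod
        def from_pandas(cls, df, frame_cls):
            # A real implementation would wrap `df` into `frame_cls` here.
            return cls()

    class MyIO(BaseIO):  # hypothetical execution-specific IO class
        query_compiler_cls = MyQueryCompiler
        frame_cls = object  # placeholder frame class

    # The default read_* methods then funnel everything through from_pandas:
    qc = MyIO.from_pandas(pandas.DataFrame({"a": [1, 2]}))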
+ """ + return cls.query_compiler_cls.from_dataframe(df, cls.frame_cls) + + @classmethod + @_inherit_docstrings(pandas.read_parquet, apilink="pandas.read_parquet") + @doc( + _doc_default_io_method, + summary="Load a parquet object from the file path, returning a query compiler", + returns=_doc_returns_qc, + ) + def read_parquet(cls, **kwargs): # noqa: PR01 + return cls.from_pandas( + pandas.read_parquet( + **kwargs, + ) + ) + + @classmethod + @_inherit_docstrings(pandas.read_csv, apilink="pandas.read_csv") + @doc( + _doc_default_io_method, + summary="Read a comma-separated values (CSV) file into query compiler", + returns=_doc_returns_qc_or_parser, + ) + def read_csv( + cls, + filepath_or_buffer, + **kwargs, + ): # noqa: PR01 + pd_obj = pandas.read_csv(filepath_or_buffer, **kwargs) + if isinstance(pd_obj, pandas.DataFrame): + return cls.from_pandas(pd_obj) + if isinstance(pd_obj, pandas.io.parsers.TextFileReader): + # Overwriting the read method should return a Modin DataFrame for calls + # to __next__ and get_chunk + pd_read = pd_obj.read + pd_obj.read = lambda *args, **kw: cls.from_pandas(pd_read(*args, **kw)) + return pd_obj + + @classmethod + @_inherit_docstrings(pandas.read_json, apilink="pandas.read_json") + @doc( + _doc_default_io_method, + summary="Convert a JSON string to query compiler", + returns=_doc_returns_qc, + ) + def read_json( + cls, + **kwargs, + ): # noqa: PR01 + return cls.from_pandas(pandas.read_json(**kwargs)) + + @classmethod + @_inherit_docstrings(pandas.read_gbq, apilink="pandas.read_gbq") + @doc( + _doc_default_io_method, + summary="Load data from Google BigQuery into query compiler", + returns=_doc_returns_qc, + ) + def read_gbq( + cls, + query: str, + project_id=None, + index_col=None, + col_order=None, + reauth=False, + auth_local_webserver=False, + dialect=None, + location=None, + configuration=None, + credentials=None, + use_bqstorage_api=None, + private_key=None, + verbose=None, + progress_bar_type=None, + max_results=None, + ): # noqa: PR01 + return cls.from_pandas( + pandas.read_gbq( + query, + project_id=project_id, + index_col=index_col, + col_order=col_order, + reauth=reauth, + auth_local_webserver=auth_local_webserver, + dialect=dialect, + location=location, + configuration=configuration, + credentials=credentials, + use_bqstorage_api=use_bqstorage_api, + progress_bar_type=progress_bar_type, + max_results=max_results, + ) + ) + + @classmethod + @_inherit_docstrings(pandas.read_html, apilink="pandas.read_html") + @doc( + _doc_default_io_method, + summary="Read HTML tables into query compiler", + returns=_doc_returns_qc, + ) + def read_html( + cls, + io, + match=".+", + flavor=None, + header=None, + index_col=None, + skiprows=None, + attrs=None, + parse_dates=False, + thousands=",", + encoding=None, + decimal=".", + converters=None, + na_values=None, + keep_default_na=True, + displayed_only=True, + **kwargs, + ): # noqa: PR01 + return cls.from_pandas( + pandas.read_html( + io=io, + match=match, + flavor=flavor, + header=header, + index_col=index_col, + skiprows=skiprows, + attrs=attrs, + parse_dates=parse_dates, + thousands=thousands, + encoding=encoding, + decimal=decimal, + converters=converters, + na_values=na_values, + keep_default_na=keep_default_na, + displayed_only=displayed_only, + **kwargs, + )[0] + ) + + @classmethod + @_inherit_docstrings(pandas.read_clipboard, apilink="pandas.read_clipboard") + @doc( + _doc_default_io_method, + summary="Read text from clipboard into query compiler", + returns=_doc_returns_qc, + ) + def read_clipboard(cls, 
sep=r"\s+", **kwargs): # pragma: no cover # noqa: PR01 + return cls.from_pandas(pandas.read_clipboard(sep=sep, **kwargs)) + + @classmethod + @_inherit_docstrings(pandas.read_excel, apilink="pandas.read_excel") + @doc( + _doc_default_io_method, + summary="Read an Excel file into query compiler", + returns="""BaseQueryCompiler or dict/OrderedDict : + QueryCompiler or OrderedDict/dict with read data.""", + ) + def read_excel( + cls, + io, + sheet_name=0, + header=0, + names=None, + index_col=None, + usecols=None, + squeeze=False, + dtype=None, + engine=None, + converters=None, + true_values=None, + false_values=None, + skiprows=None, + nrows=None, + na_values=None, + keep_default_na=True, + verbose=False, + parse_dates=False, + date_parser=None, + thousands=None, + comment=None, + skip_footer=0, + skipfooter=0, + convert_float=True, + mangle_dupe_cols=True, + na_filter=True, + **kwds, + ): # noqa: PR01 + if skip_footer != 0: + skipfooter = skip_footer + intermediate = pandas.read_excel( + io, + sheet_name=sheet_name, + header=header, + names=names, + index_col=index_col, + usecols=usecols, + squeeze=squeeze, + dtype=dtype, + engine=engine, + converters=converters, + true_values=true_values, + false_values=false_values, + skiprows=skiprows, + nrows=nrows, + na_values=na_values, + keep_default_na=keep_default_na, + verbose=verbose, + parse_dates=parse_dates, + date_parser=date_parser, + thousands=thousands, + comment=comment, + skipfooter=skipfooter, + convert_float=convert_float, + mangle_dupe_cols=mangle_dupe_cols, + na_filter=na_filter, + **kwds, + ) + if isinstance(intermediate, (OrderedDict, dict)): + parsed = type(intermediate)() + for key in intermediate.keys(): + parsed[key] = cls.from_pandas(intermediate.get(key)) + return parsed + else: + return cls.from_pandas(intermediate) + + @classmethod + @_inherit_docstrings(pandas.read_hdf, apilink="pandas.read_hdf") + @doc( + _doc_default_io_method, + summary="Read data from hdf store into query compiler", + returns=_doc_returns_qc, + ) + def read_hdf( + cls, + path_or_buf, + key=None, + mode: str = "r", + errors: str = "strict", + where=None, + start=None, + stop=None, + columns=None, + iterator=False, + chunksize=None, + **kwargs, + ): # noqa: PR01 + from modin.pandas.io import HDFStore + + modin_store = isinstance(path_or_buf, HDFStore) + if modin_store: + path_or_buf._return_modin_dataframe = False + df = pandas.read_hdf( + path_or_buf, + key=key, + mode=mode, + columns=columns, + errors=errors, + where=where, + start=start, + stop=stop, + iterator=iterator, + chunksize=chunksize, + **kwargs, + ) + if modin_store: + path_or_buf._return_modin_dataframe = True + + return cls.from_pandas(df) + + @classmethod + @_inherit_docstrings(pandas.read_feather, apilink="pandas.read_feather") + @doc( + _doc_default_io_method, + summary="Load a feather-format object from the file path into query compiler", + returns=_doc_returns_qc, + ) + def read_feather( + cls, + path, + **kwargs, + ): # noqa: PR01 + return cls.from_pandas( + pandas.read_feather( + path, + **kwargs, + ) + ) + + @classmethod + @_inherit_docstrings(pandas.read_stata, apilink="pandas.read_stata") + @doc( + _doc_default_io_method, + summary="Read Stata file into query compiler", + returns=_doc_returns_qc, + ) + def read_stata( + cls, + filepath_or_buffer, + **kwargs, + ): # noqa: PR01 + return cls.from_pandas(pandas.read_stata(filepath_or_buffer, **kwargs)) + + @classmethod + @_inherit_docstrings(pandas.read_sas, apilink="pandas.read_sas") + @doc( + _doc_default_io_method, + summary="Read 
SAS files stored as either XPORT or SAS7BDAT format files\ninto query compiler", + returns=_doc_returns_qc, + ) + def read_sas( + cls, + filepath_or_buffer, + format=None, + index=None, + encoding=None, + chunksize=None, + iterator=False, + **kwargs, + ): # pragma: no cover # noqa: PR01 + return cls.from_pandas( + pandas.read_sas( + filepath_or_buffer, + format=format, + index=index, + encoding=encoding, + chunksize=chunksize, + iterator=iterator, + **kwargs, + ) + ) + + @classmethod + @_inherit_docstrings(pandas.read_pickle, apilink="pandas.read_pickle") + @doc( + _doc_default_io_method, + summary="Load pickled pandas object (or any object) from file into query compiler", + returns=_doc_returns_qc, + ) + def read_pickle( + cls, + filepath_or_buffer, + **kwargs, + ): # noqa: PR01 + + return cls.from_pandas( + pandas.read_pickle( + filepath_or_buffer, + **kwargs, + ) + ) + + @classmethod + @_inherit_docstrings(pandas.read_sql, apilink="pandas.read_sql") + @doc( + _doc_default_io_method, + summary="Read SQL query or database table into query compiler", + returns=_doc_returns_qc, + ) + def read_sql( + cls, + sql, + con, + index_col=None, + coerce_float=True, + params=None, + parse_dates=None, + columns=None, + chunksize=None, + ): # noqa: PR01 + # if isinstance(con, ModinDatabaseConnection): + # con = con.get_connection() + return cls.from_pandas( + pandas.read_sql( + sql, + con, + index_col=index_col, + coerce_float=coerce_float, + params=params, + parse_dates=parse_dates, + columns=columns, + chunksize=chunksize, + ) + ) + + @classmethod + @_inherit_docstrings(pandas.read_fwf, apilink="pandas.read_fwf") + @doc( + _doc_default_io_method, + summary="Read a table of fixed-width formatted lines into query compiler", + returns=_doc_returns_qc_or_parser, + ) + def read_fwf( + cls, filepath_or_buffer, colspecs="infer", widths=None, infer_nrows=100, **kwds + ): # noqa: PR01 + pd_obj = pandas.read_fwf( + filepath_or_buffer, + colspecs=colspecs, + widths=widths, + infer_nrows=infer_nrows, + **kwds, + ) + if isinstance(pd_obj, pandas.DataFrame): + return cls.from_pandas(pd_obj) + if isinstance(pd_obj, pandas.io.parsers.TextFileReader): + # Overwriting the read method should return a Modin DataFrame for calls + # to __next__ and get_chunk + pd_read = pd_obj.read + pd_obj.read = lambda *args, **kwargs: cls.from_pandas( + pd_read(*args, **kwargs) + ) + return pd_obj + + @classmethod + @_inherit_docstrings(pandas.read_sql_table, apilink="pandas.read_sql_table") + @doc( + _doc_default_io_method, + summary="Read SQL database table into query compiler", + returns=_doc_returns_qc, + ) + def read_sql_table( + cls, + table_name, + con, + schema=None, + index_col=None, + coerce_float=True, + parse_dates=None, + columns=None, + chunksize=None, + ): # noqa: PR01 + return cls.from_pandas( + pandas.read_sql_table( + table_name, + con, + schema=schema, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + columns=columns, + chunksize=chunksize, + ) + ) + + @classmethod + @_inherit_docstrings(pandas.read_sql_query, apilink="pandas.read_sql_query") + @doc( + _doc_default_io_method, + summary="Read SQL query into query compiler", + returns=_doc_returns_qc, + ) + def read_sql_query( + cls, + sql, + con, + **kwargs, + ): # noqa: PR01 + return cls.from_pandas( + pandas.read_sql_query( + sql, + con, + **kwargs, + ) + ) + + @classmethod + @_inherit_docstrings(pandas.read_spss, apilink="pandas.read_spss") + @doc( + _doc_default_io_method, + summary="Load an SPSS file from the file path, returning a 
query compiler", + returns=_doc_returns_qc, + ) + def read_spss(cls, path, usecols, convert_categoricals): # noqa: PR01 + return cls.from_pandas(pandas.read_spss(path, usecols, convert_categoricals)) + + @classmethod + @_inherit_docstrings(pandas.DataFrame.to_sql, apilink="pandas.DataFrame.to_sql") + def to_sql( + cls, + qc, + name, + con, + schema=None, + if_exists="fail", + index=True, + index_label=None, + chunksize=None, + dtype=None, + method=None, + ): # noqa: PR01 + """ + Write records stored in a DataFrame to a SQL database using pandas. + + For parameters description please refer to pandas API. + """ + df = qc.to_pandas() + df.to_sql( + name=name, + con=con, + schema=schema, + if_exists=if_exists, + index=index, + index_label=index_label, + chunksize=chunksize, + dtype=dtype, + method=method, + ) + + @classmethod + @_inherit_docstrings( + pandas.DataFrame.to_pickle, apilink="pandas.DataFrame.to_pickle" + ) + def to_pickle( + cls, + obj: Any, + filepath_or_buffer, + **kwargs, + ): # noqa: PR01, D200 + """ + Pickle (serialize) object to file. + """ + if isinstance(obj, BaseQueryCompiler): + obj = obj.to_pandas() + + return pandas.to_pickle( + obj, + filepath_or_buffer=filepath_or_buffer, + **kwargs, + ) + + @classmethod + @_inherit_docstrings(pandas.DataFrame.to_csv, apilink="pandas.DataFrame.to_csv") + def to_csv(cls, obj, **kwargs): # noqa: PR01 + """ + Write object to a comma-separated values (CSV) file using pandas. + + For parameters description please refer to pandas API. + """ + if isinstance(obj, BaseQueryCompiler): + obj = obj.to_pandas() + + return obj.to_csv(**kwargs) + + @classmethod + @_inherit_docstrings( + pandas.DataFrame.to_parquet, apilink="pandas.DataFrame.to_parquet" + ) + def to_parquet(cls, obj, **kwargs): # noqa: PR01 + """ + Write object to the binary parquet format using pandas. + + For parameters description please refer to pandas API. + """ + if isinstance(obj, BaseQueryCompiler): + obj = obj.to_pandas() + + return obj.to_parquet(**kwargs) diff --git a/src/snowflake/snowpark/modin/core/execution/dispatching/factories/dispatcher.py b/src/snowflake/snowpark/modin/core/execution/dispatching/factories/dispatcher.py new file mode 100644 index 00000000000..d8180c94dae --- /dev/null +++ b/src/snowflake/snowpark/modin/core/execution/dispatching/factories/dispatcher.py @@ -0,0 +1,287 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +""" +Contain IO dispatcher class. + +Dispatcher routes the work to execution-specific functions. 
+""" + +from snowflake.snowpark.modin.core.execution.dispatching.factories import factories +from snowflake.snowpark.modin.core.execution.dispatching.factories.factories import ( + PandasOnSnowflakeFactory, +) +from snowflake.snowpark.modin.utils import _inherit_docstrings + + +class FactoryNotFoundError(AttributeError): + """ + ``FactoryNotFound`` exception class. + + Raise when no matching factory could be found. + """ + + pass + + +class StubIoEngine: + """ + IO-Engine that does nothing more than raise NotImplementedError when any method is called. + + Parameters + ---------- + factory_name : str + Factory name, which will be reflected in error messages. + + Notes + ----- + Used for testing purposes. + """ + + def __init__(self, factory_name="") -> None: + self.factory_name = factory_name or "Unknown" + + def __getattr__(self, name): + """ + Return a function that raises `NotImplementedError` for the `name` method. + + Parameters + ---------- + name : str + Method name to indicate in `NotImplementedError`. + + Returns + ------- + callable + """ + + def stub(*args, **kw): + raise NotImplementedError( + f"Method {self.factory_name}.{name} is not implemented" + ) + + return stub + + +class StubFactory(factories.BaseFactory): + """ + Factory that does nothing more than raise NotImplementedError when any method is called. + + Notes + ----- + Used for testing purposes. + """ + + io_cls = StubIoEngine() + + @classmethod + def set_failing_name(cls, factory_name): + """ + Fill in `.io_cls` class attribute with ``StubIoEngine`` engine. + + Parameters + ---------- + factory_name : str + Name to pass to the ``StubIoEngine`` constructor. + """ + cls.io_cls = StubIoEngine(factory_name) + return cls + + +class FactoryDispatcher: + """ + Class that routes IO-work to the factories. + + This class is responsible for keeping selected factory up-to-date and dispatching + calls of IO-functions to its actual execution-specific implementations. 
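The stub classes above rely on `__getattr__` returning a raising callable for every attribute access; the same pattern in isolation (a standalone sketch, independent of this module):

    class Stub:
        def __init__(self, name: str) -> None:
            self.name = name

        def __getattr__(self, attr: str):
            # Only called for attributes that don't exist, so every "method" raises.
            def stub(*args, **kwargs):
                raise NotImplementedError(f"Method {self.name}.{attr} is not implemented")

            return stub

    engine = Stub("TestEngine")
    try:
        engine.read_csv("data.csv")
    except NotImplementedError as exc:
        print(exc)  # Method TestEngine.read_csv is not implemented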
+ """ + + __factory: factories.BaseFactory = None + + @classmethod + def get_factory(cls) -> factories.BaseFactory: + """Get current factory.""" + if cls.__factory is None: + # set default factory to be PandasOnPythonFactory, TODO: replace with Snowflake factory + # lazy initialize + cls.__factory = PandasOnSnowflakeFactory() + cls.__factory.prepare() + return cls.__factory + + @classmethod + @_inherit_docstrings(factories.BaseFactory._from_pandas) + def from_pandas(cls, df): + return cls.get_factory()._from_pandas(df) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._from_arrow) + def from_arrow(cls, at): + return cls.get_factory()._from_arrow(at) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._from_non_pandas) + def from_non_pandas(cls, *args, **kwargs): + return cls.get_factory()._from_non_pandas(*args, **kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._from_dataframe) + def from_dataframe(cls, *args, **kwargs): + return cls.get_factory()._from_dataframe(*args, **kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_parquet) + def read_parquet(cls, **kwargs): + return cls.get_factory()._read_parquet(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_csv) + def read_csv(cls, **kwargs): + return cls.get_factory()._read_csv(**kwargs) + + # @classmethod + # #@_inherit_docstrings(factories.ExperimentalPandasOnRayFactory._read_csv_glob) + # def read_csv_glob(cls, **kwargs): + # return cls.get_factory()._read_csv_glob(**kwargs) + # + # @classmethod + # @_inherit_docstrings( + # factories.ExperimentalPandasOnRayFactory._read_pickle_distributed + # ) + # def read_pickle_distributed(cls, **kwargs): + # return cls.get_factory()._read_pickle_distributed(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_json) + def read_json(cls, **kwargs): + return cls.get_factory()._read_json(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_gbq) + def read_gbq(cls, **kwargs): + return cls.get_factory()._read_gbq(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_html) + def read_html(cls, **kwargs): + return cls.get_factory()._read_html(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_clipboard) + def read_clipboard(cls, **kwargs): + return cls.get_factory()._read_clipboard(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_excel) + def read_excel(cls, **kwargs): + return cls.get_factory()._read_excel(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_hdf) + def read_hdf(cls, **kwargs): + return cls.get_factory()._read_hdf(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_feather) + def read_feather(cls, **kwargs): + return cls.get_factory()._read_feather(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_stata) + def read_stata(cls, **kwargs): + return cls.get_factory()._read_stata(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_sas) + def read_sas(cls, **kwargs): # pragma: no cover + return cls.get_factory()._read_sas(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_pickle) + def read_pickle(cls, **kwargs): + return cls.get_factory()._read_pickle(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_sql) + def read_sql(cls, **kwargs): + return cls.get_factory()._read_sql(**kwargs) + + @classmethod + 
@_inherit_docstrings(factories.BaseFactory._read_fwf) + def read_fwf(cls, **kwargs): + return cls.get_factory()._read_fwf(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_sql_table) + def read_sql_table(cls, **kwargs): + return cls.get_factory()._read_sql_table(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_sql_query) + def read_sql_query(cls, **kwargs): + return cls.get_factory()._read_sql_query(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_spss) + def read_spss(cls, **kwargs): + return cls.get_factory()._read_spss(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._to_sql) + def to_sql(cls, *args, **kwargs): + return cls.get_factory()._to_sql(*args, **kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._to_pickle) + def to_pickle(cls, *args, **kwargs): + return cls.get_factory()._to_pickle(*args, **kwargs) + + # @classmethod + # @_inherit_docstrings( + # factories.ExperimentalPandasOnRayFactory._to_pickle_distributed + # ) + # def to_pickle_distributed(cls, *args, **kwargs): + # return cls.get_factory()._to_pickle_distributed(*args, **kwargs) + + # @classmethod + # @_inherit_docstrings(factories.ExperimentalPandasOnRayFactory._read_custom_text) + # def read_custom_text(cls, **kwargs): + # return cls.get_factory()._read_custom_text(**kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._to_csv) + def to_csv(cls, *args, **kwargs): + return cls.get_factory()._to_csv(*args, **kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._to_parquet) + def to_parquet(cls, *args, **kwargs): + return cls.get_factory()._to_parquet(*args, **kwargs) + + # Snowflake added methods + @classmethod + @_inherit_docstrings(factories.BaseFactory._read_snowflake) + def read_snowflake(cls, *args, **kwargs): + return cls.get_factory()._read_snowflake(*args, **kwargs) + + @classmethod + @_inherit_docstrings(factories.BaseFactory._to_snowflake) + def to_snowflake(cls, *args, **kwargs): + return cls.get_factory()._to_snowflake(*args, **kwargs) diff --git a/src/snowflake/snowpark/modin/core/execution/dispatching/factories/factories.py b/src/snowflake/snowpark/modin/core/execution/dispatching/factories/factories.py new file mode 100644 index 00000000000..72c2847c9ff --- /dev/null +++ b/src/snowflake/snowpark/modin/core/execution/dispatching/factories/factories.py @@ -0,0 +1,466 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. 
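Putting the pieces together, an IO call travels from the dispatcher through the lazily created factory to its `io_cls`; sketched here for `read_snowflake` (the table name is illustrative, and an active Snowpark session is assumed):

    from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import (
        FactoryDispatcher,
    )

    # 1. FactoryDispatcher.read_snowflake(...) calls get_factory(), which lazily creates
    #    and prepares a PandasOnSnowflakeFactory on first use.
    # 2. The factory's _read_snowflake(...) forwards to its io_cls, i.e. PandasOnSnowflakeIO.
    # 3. The IO class returns a query compiler that backs the user-facing DataFrame.
    qc = FactoryDispatcher.read_snowflake("MY_TABLE")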
+ +""" +Module contains Factories for all of the supported Modin executions. + +Factory is a bridge between calls of IO function from high-level API and its +actual implementation in the execution, bound to that factory. Each execution is represented +with a Factory class. +""" + +import re +import typing + +import pandas +from pandas.util._decorators import doc + +from snowflake.snowpark.modin.core.execution.dispatching.factories.baseio import BaseIO +from snowflake.snowpark.modin.plugin.io.snow_io import PandasOnSnowflakeIO + +_doc_abstract_factory_class = """ +Abstract {role} factory which allows to override the IO module easily. + +This class is responsible for dispatching calls of IO-functions to its +actual execution-specific implementations. + +Attributes +---------- +io_cls : BaseIO + IO module class of the underlying execution. The place to dispatch calls to. +""" + +_doc_factory_class = """ +Factory of {execution_name} execution. + +This class is responsible for dispatching calls of IO-functions to its +actual execution-specific implementations. + +Attributes +---------- +io_cls : {execution_name}IO + IO module class of the underlying execution. The place to dispatch calls to. +""" + +_doc_factory_prepare_method = """ +Initialize Factory. + +Fills in `.io_cls` class attribute with {io_module_name} lazily. +""" + +_doc_io_method_raw_template = """ +Build query compiler from {source}. + +Parameters +---------- +{params} + +Returns +------- +QueryCompiler + Query compiler of the selected storage format. +""" + +_doc_io_method_template = ( + _doc_io_method_raw_template + + """ +See Also +-------- +modin.pandas.{method} +""" +) + +_doc_io_method_all_params = """*args : args + Arguments to pass to the QueryCompiler builder method. +**kwargs : kwargs + Arguments to pass to the QueryCompiler builder method.""" + +_doc_io_method_kwargs_params = """**kwargs : kwargs + Arguments to pass to the QueryCompiler builder method.""" + + +types_dictionary = {"pandas": {"category": pandas.CategoricalDtype}} + + +class FactoryInfo(typing.NamedTuple): + """ + Structure that stores information about factory. + + Parameters + ---------- + engine : str + Name of underlying execution engine. + partition : str + Name of the partition format. + experimental : bool + Whether underlying engine is experimental-only. + """ + + engine: str + partition: str + experimental: bool + + +class NotRealFactory(Exception): + """ + ``NotRealFactory`` exception class. + + Raise when no matching factory could be found. + """ + + pass + + +@doc(_doc_abstract_factory_class, role="") +class BaseFactory: + io_cls: type[BaseIO] = None # The module where the I/O functionality exists. + + @classmethod + def get_info(cls) -> FactoryInfo: + """ + Get information about current factory. + + Notes + ----- + It parses factory name, so it must be conformant with how ``FactoryDispatcher`` + class constructs factory names. 
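The naming convention matters because `get_info` (just below) recovers the partition format and engine by parsing the class name with a regular expression, e.g.:

    import re

    match = re.match(r"^(Experimental)?(.*)On(.*)Factory$", "PandasOnSnowflakeFactory")
    experimental, partition, engine = match.groups()
    print(partition, engine, bool(experimental))  # Pandas Snowflake False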
+ """ + try: + experimental, partition, engine = re.match( + r"^(Experimental)?(.*)On(.*)Factory$", cls.__name__ + ).groups() + except AttributeError: + raise NotRealFactory() + return FactoryInfo( + engine=engine, partition=partition, experimental=bool(experimental) + ) + + @classmethod + @doc( + _doc_factory_prepare_method, + io_module_name="an underlying execution's IO-module", + ) + def prepare(cls): + raise NotImplementedError("Subclasses of BaseFactory must implement prepare") + + @classmethod + @doc( + _doc_io_method_template, + source="pandas DataFrame", + params="df : pandas.DataFrame", + method="utils.from_pandas", + ) + def _from_pandas(cls, df): + return cls.io_cls.from_pandas(df) + + @classmethod + @doc( + _doc_io_method_template, + source="Arrow Table", + params="at : pyarrow.Table", + method="utils.from_arrow", + ) + def _from_arrow(cls, at): + return cls.io_cls.from_arrow(at) + + @classmethod + @doc( + _doc_io_method_template, + source="a non-pandas object (dict, list, np.array etc...)", + params=_doc_io_method_all_params, + method="utils.from_non_pandas", + ) + def _from_non_pandas(cls, *args, **kwargs): + return cls.io_cls.from_non_pandas(*args, **kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a DataFrame object supporting exchange protocol `__dataframe__()`", + params=_doc_io_method_all_params, + method="utils.from_dataframe", + ) + def _from_dataframe(cls, *args, **kwargs): + return cls.io_cls.from_dataframe(*args, **kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a Parquet file", + params=_doc_io_method_kwargs_params, + method="read_parquet", + ) + def _read_parquet(cls, **kwargs): + return cls.io_cls.read_parquet(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a CSV file", + params=_doc_io_method_kwargs_params, + method="read_csv", + ) + def _read_csv(cls, **kwargs): + return cls.io_cls.read_csv(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a JSON file", + params=_doc_io_method_kwargs_params, + method="read_json", + ) + def _read_json(cls, **kwargs): + return cls.io_cls.read_json(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a Google BigQuery", + params=_doc_io_method_kwargs_params, + method="read_gbq", + ) + def _read_gbq(cls, **kwargs): + return cls.io_cls.read_gbq(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="an HTML document", + params=_doc_io_method_kwargs_params, + method="read_html", + ) + def _read_html(cls, **kwargs): + return cls.io_cls.read_html(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="clipboard", + params=_doc_io_method_kwargs_params, + method="read_clipboard", + ) + def _read_clipboard(cls, **kwargs): # pragma: no cover + return cls.io_cls.read_clipboard(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="an Excel file", + params=_doc_io_method_kwargs_params, + method="read_excel", + ) + def _read_excel(cls, **kwargs): + return cls.io_cls.read_excel(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="an HDFStore", + params=_doc_io_method_kwargs_params, + method="read_hdf", + ) + def _read_hdf(cls, **kwargs): + return cls.io_cls.read_hdf(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a feather-format object", + params=_doc_io_method_kwargs_params, + method="read_feather", + ) + def _read_feather(cls, **kwargs): + return cls.io_cls.read_feather(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a 
Stata file", + params=_doc_io_method_kwargs_params, + method="read_stata", + ) + def _read_stata(cls, **kwargs): + return cls.io_cls.read_stata(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a SAS file", + params=_doc_io_method_kwargs_params, + method="read_sas", + ) + def _read_sas(cls, **kwargs): # pragma: no cover + return cls.io_cls.read_sas(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a pickled Modin or pandas DataFrame", + params=_doc_io_method_kwargs_params, + method="read_pickle", + ) + def _read_pickle(cls, **kwargs): + return cls.io_cls.read_pickle(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a SQL query or database table", + params=_doc_io_method_kwargs_params, + method="read_sql", + ) + def _read_sql(cls, **kwargs): + return cls.io_cls.read_sql(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a table of fixed-width formatted lines", + params=_doc_io_method_kwargs_params, + method="read_fwf", + ) + def _read_fwf(cls, **kwargs): + return cls.io_cls.read_fwf(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a SQL database table", + params=_doc_io_method_kwargs_params, + method="read_sql_table", + ) + def _read_sql_table(cls, **kwargs): + return cls.io_cls.read_sql_table(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="a SQL query", + params=_doc_io_method_kwargs_params, + method="read_sql_query", + ) + def _read_sql_query(cls, **kwargs): + return cls.io_cls.read_sql_query(**kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="an SPSS file", + params=_doc_io_method_kwargs_params, + method="read_spss", + ) + def _read_spss(cls, **kwargs): + return cls.io_cls.read_spss(**kwargs) + + @classmethod + def _to_sql(cls, *args, **kwargs): + """ + Write query compiler content to a SQL database. + + Parameters + ---------- + *args : args + Arguments to the writer method. + **kwargs : kwargs + Arguments to the writer method. + """ + return cls.io_cls.to_sql(*args, **kwargs) + + @classmethod + def _to_pickle(cls, *args, **kwargs): + """ + Pickle query compiler object. + + Parameters + ---------- + *args : args + Arguments to the writer method. + **kwargs : kwargs + Arguments to the writer method. + """ + return cls.io_cls.to_pickle(*args, **kwargs) + + @classmethod + def _to_csv(cls, *args, **kwargs): + """ + Write query compiler content to a CSV file. + + Parameters + ---------- + *args : args + Arguments to pass to the writer method. + **kwargs : kwargs + Arguments to pass to the writer method. + """ + return cls.io_cls.to_csv(*args, **kwargs) + + @classmethod + def _to_parquet(cls, *args, **kwargs): + """ + Write query compiler content to a parquet file. + + Parameters + ---------- + *args : args + Arguments to pass to the writer method. + **kwargs : kwargs + Arguments to pass to the writer method. + """ + return cls.io_cls.to_parquet(*args, **kwargs) + + # following are snowflake specific functions, could also basically subclass BaseFactory into SnowflakeFactory an + # add the methods there. 
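As the comment above suggests, an alternative layout would keep the generic IO wrappers in `BaseFactory` and move the Snowflake-only entry points into a dedicated subclass; a rough sketch of that idea (a hypothetical refactoring, not what the code below does):

    class SnowflakeFactory(BaseFactory):  # hypothetical subclass
        @classmethod
        def _read_snowflake(cls, *args, **kwargs):
            return cls.io_cls.read_snowflake(*args, **kwargs)

        @classmethod
        def _to_snowflake(cls, *args, **kwargs):
            return cls.io_cls.to_snowflake(*args, **kwargs)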
+ @classmethod + @doc( + _doc_io_method_template, + source="read from Snowflake table", + params=_doc_io_method_kwargs_params, + method="read_snowflake", + ) + def _read_snowflake(cls, *args, **kwargs): + return cls.io_cls.read_snowflake(*args, **kwargs) + + @classmethod + @doc( + _doc_io_method_template, + source="save to Snowflake table", + params=_doc_io_method_kwargs_params, + method="to_snowflake", + ) + def _to_snowflake(cls, *args, **kwargs): + return cls.io_cls.to_snowflake(*args, **kwargs) + + @classmethod + def _to_local(cls, *args, **kwargs): + return cls.io_cls.to_local(*args, **kwargs) + + @classmethod + def _to_remote(cls, *args, **kwargs): + return cls.io_cls.to_remote(*args, **kwargs) + + +@doc(_doc_factory_class, backend_name="PandasOnSnowflake", execution_name="Snowflake") +class PandasOnSnowflakeFactory(BaseFactory): + @classmethod + @doc(_doc_factory_prepare_method, io_module_name="``PandasOnSnowflakeIO``") + def prepare(cls): + cls.io_cls = PandasOnSnowflakeIO diff --git a/src/snowflake/snowpark/modin/core/execution/dispatching/factories/pandasframe.py b/src/snowflake/snowpark/modin/core/execution/dispatching/factories/pandasframe.py new file mode 100644 index 00000000000..d3fce3a8fb2 --- /dev/null +++ b/src/snowflake/snowpark/modin/core/execution/dispatching/factories/pandasframe.py @@ -0,0 +1,3491 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +""" +Module contains class PandasDataframe. + +PandasDataframe is a parent abstract class for any dataframe class +for pandas storage format. 
+""" +import datetime +from collections import OrderedDict +from collections.abc import Hashable +from enum import Enum +from typing import Callable, Optional, Union + +import numpy as np +import pandas +from pandas._libs.lib import no_default +from pandas._typing import npt +from pandas.api.types import is_object_dtype +from pandas.core.dtypes.cast import find_common_type +from pandas.core.dtypes.common import is_list_like, is_numeric_dtype +from pandas.core.indexes.api import Index, RangeIndex, ensure_index + +from snowflake.snowpark.modin.pandas.indexing import is_range_like +from snowflake.snowpark.modin.pandas.utils import ( + check_both_not_none, + is_full_grab_slice, +) +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.modin.utils import ( + MODIN_UNNAMED_SERIES_LABEL, # pragma: no cover +) + + +# from https://github.com/modin-project/modin/blob/master/modin/core/dataframe/base/dataframe/utils.py +class Axis(Enum): # noqa: PR01 + """ + An enum that represents the `axis` argument provided to the algebra operators. + The enum has 3 values - ROW_WISE to represent the row axis, COL_WISE to represent the + column axis, and CELL_WISE to represent no axis. ROW_WISE operations iterate over the rows + COL_WISE operations over the columns, and CELL_WISE operations over any of the partitioning + schemes that are supported in Modin (row-wise, column-wise, or block-wise). + """ + + ROW_WISE = 0 + COL_WISE = 1 + CELL_WISE = None + + +class JoinType(Enum): # noqa: PR01 + """ + An enum that represents the `join_type` argument provided to the algebra operators. + The enum has 4 values - INNER to represent inner joins, LEFT to represent left joins, RIGHT to + represent right joins, and OUTER to represent outer joins. + """ + + INNER = "inner" + LEFT = "left" + RIGHT = "right" + OUTER = "outer" + + +# from modin.core.storage_formats.pandas.query_compiler import PandasQueryCompiler +# from modin.core.storage_formats.pandas.utils import get_length_list +# from modin.error_message import ErrorMessage +# from modin.core.storage_formats.pandas.parsers import ( +# find_common_type_cat as find_common_type, +# ) +# from modin.core.dataframe.base.dataframe.dataframe import ModinDataframe +# from modin.core.dataframe.base.dataframe.utils import ( +# Axis, +# JoinType, +# ) +# from modin.core.dataframe.pandas.dataframe.utils import build_sort_functions +# +# if TYPE_CHECKING: +# from modin.core.dataframe.base.interchange.dataframe_protocol.dataframe import ( +# ProtocolDataframe, +# ) +# from pandas._typing import npt +# +# from modin.pandas.indexing import is_range_like +# from modin.pandas.utils import is_full_grab_slice, check_both_not_none +# from modin.logging import ClassLogger +# from modin.utils import MODIN_UNNAMED_SERIES_LABEL + + +def lazy_metadata_decorator(apply_axis=None, axis_arg=-1, transpose=False): + """ + Lazily propagate metadata for the ``PandasDataframe``. + + This decorator first adds the minimum required reindexing operations + to each partition's queue of functions to be lazily applied for + each PandasDataframe in the arguments by applying the function + run_f_on_minimally_updated_metadata. The decorator also sets the + flags for deferred metadata synchronization on the function result + if necessary. + + Parameters + ---------- + apply_axis : str, default: None + The axes on which to apply the reindexing operations to the `self._partitions` lazily. + Case None: No lazy metadata propagation. 
+ Case "both": Add reindexing operations on both axes to partition queue. + Case "opposite": Add reindexing operations complementary to given axis. + Case "rows": Add reindexing operations on row axis to partition queue. + axis_arg : int, default: -1 + The index or column axis. + transpose : bool, default: False + Boolean for if a transpose operation is being used. + + Returns + ------- + Wrapped Function. + """ + + def decorator(f): + from functools import wraps + + @wraps(f) + def run_f_on_minimally_updated_metadata(self, *args, **kwargs): + for obj in ( + [self] + + [o for o in args if isinstance(o, PandasDataframe)] + + [v for v in kwargs.values() if isinstance(v, PandasDataframe)] + + [ + d + for o in args + if isinstance(o, list) + for d in o + if isinstance(d, PandasDataframe) + ] + + [ + d + for _, o in kwargs.items() + if isinstance(o, list) + for d in o + if isinstance(d, PandasDataframe) + ] + ): + if apply_axis == "both": + if obj._deferred_index and obj._deferred_column: + obj._propagate_index_objs(axis=None) + elif obj._deferred_index: + obj._propagate_index_objs(axis=0) + elif obj._deferred_column: + obj._propagate_index_objs(axis=1) + elif apply_axis == "opposite": + if "axis" not in kwargs: + axis = args[axis_arg] + else: + axis = kwargs["axis"] + if axis == 0 and obj._deferred_column: + obj._propagate_index_objs(axis=1) + elif axis == 1 and obj._deferred_index: + obj._propagate_index_objs(axis=0) + elif apply_axis == "rows": + obj._propagate_index_objs(axis=0) + result = f(self, *args, **kwargs) + if apply_axis is None and not transpose: + result._deferred_index = self._deferred_index + result._deferred_column = self._deferred_column + elif apply_axis is None and transpose: + result._deferred_index = self._deferred_column + result._deferred_column = self._deferred_index + elif apply_axis == "opposite": + if axis == 0: + result._deferred_index = self._deferred_index + else: + result._deferred_column = self._deferred_column + elif apply_axis == "rows": + result._deferred_column = self._deferred_column + return result + + return run_f_on_minimally_updated_metadata + + return decorator + + +class PandasDataframe: + """ + An abstract class that represents the parent class for any pandas storage format dataframe class. + + This class provides interfaces to run operations on dataframe partitions. + + Parameters + ---------- + partitions : np.ndarray + A 2D NumPy array of partitions. + index : sequence, optional + The index for the dataframe. Converted to a ``pandas.Index``. + Is computed from partitions on demand if not specified. + columns : sequence, optional + The columns object for the dataframe. Converted to a ``pandas.Index``. + Is computed from partitions on demand if not specified. + row_lengths : list, optional + The length of each partition in the rows. The "height" of + each of the block partitions. Is computed if not provided. + column_widths : list, optional + The width of each partition in the columns. The "width" of + each of the block partitions. Is computed if not provided. + dtypes : pandas.Series, optional + The data types for the dataframe columns. + """ + + _partition_mgr_cls = None + _query_compiler_cls = None # PandasQueryCompiler + # These properties flag whether or not we are deferring the metadata synchronization + _deferred_index = False + _deferred_column = False + + @pandas.util.cache_readonly + def __constructor__(self): + """ + Create a new instance of this object. 
+ + Returns + ------- + PandasDataframe + """ + return type(self) + + def __init__( + self, + partitions, + index=None, + columns=None, + row_lengths=None, + column_widths=None, + dtypes=None, + ) -> None: + self._partitions = partitions + self._index_cache = ensure_index(index) if index is not None else None + self._columns_cache = ensure_index(columns) if columns is not None else None + self._row_lengths_cache = row_lengths + self._column_widths_cache = column_widths + self._dtypes = dtypes + + self._validate_axes_lengths() + self._filter_empties(compute_metadata=False) + + def _validate_axes_lengths(self): + """Validate that labels are split correctly if split is known.""" + pass + # if self._row_lengths_cache is not None and len(self.index) > 0: + # # An empty frame can have 0 rows but a nonempty index. If the frame + # # does have rows, the number of rows must equal the size of the + # # index. + # num_rows = sum(self._row_lengths_cache) + # if num_rows > 0: + # ErrorMessage.catch_bugs_and_request_email( + # num_rows != len(self._index_cache), + # f"Row lengths: {num_rows} != {len(self._index_cache)}", + # ) + # ErrorMessage.catch_bugs_and_request_email( + # any(val < 0 for val in self._row_lengths_cache), + # f"Row lengths cannot be negative: {self._row_lengths_cache}", + # ) + # if self._column_widths_cache is not None and len(self.columns) > 0: + # # An empty frame can have 0 column but a nonempty column index. If + # # the frame does have columns, the number of columns must equal the + # # size of the columns. + # num_columns = sum(self._column_widths_cache) + # if num_columns > 0: + # ErrorMessage.catch_bugs_and_request_email( + # num_columns != len(self._columns_cache), + # f"Column widths: {num_columns} != {len(self._columns_cache)}", + # ) + # ErrorMessage.catch_bugs_and_request_email( + # any(val < 0 for val in self._column_widths_cache), + # f"Column widths cannot be negative: {self._column_widths_cache}", + # ) + + @property + def row_lengths(self): + """ + Compute the row partitions lengths if they are not cached. + + Returns + ------- + list + A list of row partitions lengths. + """ + if self._row_lengths_cache is None: + if len(self._partitions.T) > 0: + row_parts = self._partitions.T[0] + self._row_lengths_cache = [part.length() for part in row_parts] + else: + self._row_lengths_cache = [] + return self._row_lengths_cache + + @property + def column_widths(self): + """ + Compute the column partitions widths if they are not cached. + + Returns + ------- + list + A list of column partitions widths. + """ + if self._column_widths_cache is None: + if len(self._partitions) > 0: + col_parts = self._partitions[0] + self._column_widths_cache = [part.width() for part in col_parts] + else: + self._column_widths_cache = [] + return self._column_widths_cache + + @property + def _axes_lengths(self): + """ + Get a pair of row partitions lengths and column partitions widths. + + Returns + ------- + list + The pair of row partitions lengths and column partitions widths. + """ + return [self.row_lengths, self.column_widths] + + @property + def dtypes(self): + """ + Compute the data types if they are not cached. + + Returns + ------- + pandas.Series + A pandas Series containing the data types for this dataframe. + """ + if self._dtypes is None: + self._dtypes = self._compute_dtypes() + return self._dtypes + + def _compute_dtypes(self): + """ + Compute the data types via TreeReduce pattern. 
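+
+ Each partition first computes its own column dtypes; the per-partition results are then
+ reduced with ``find_common_type`` into a single dtype per column.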
+ + Returns + ------- + pandas.Series + A pandas Series containing the data types for this dataframe. + """ + + def dtype_builder(df): + return df.apply(lambda col: find_common_type(col.values), axis=0) + + # For now we will use a pandas Series for the dtypes. + if len(self.columns) > 0: + dtypes = ( + self.tree_reduce(0, lambda df: df.dtypes, dtype_builder) + .to_pandas() + .iloc[0] + ) + else: + dtypes = pandas.Series([]) + # reset name to None because we use MODIN_UNNAMED_SERIES_LABEL internally + dtypes.name = None + return dtypes + + _index_cache = None + _columns_cache = None + + def _validate_set_axis(self, new_labels, old_labels): + """ + Validate the possibility of replacement of old labels with the new labels. + + Parameters + ---------- + new_labels : list-like + The labels to replace with. + old_labels : list-like + The labels to replace. + + Returns + ------- + list-like + The validated labels. + """ + new_labels = ensure_index(new_labels) + old_len = len(old_labels) + new_len = len(new_labels) + if old_len != new_len: + raise ValueError( + f"Length mismatch: Expected axis has {old_len} elements, " + + f"new values have {new_len} elements" + ) + return new_labels + + def _get_index(self): + """ + Get the index from the cache object. + + Returns + ------- + pandas.Index + An index object containing the row labels. + """ + if self._index_cache is None: + self._index_cache, row_lengths = self._compute_axis_labels_and_lengths(0) + if self._row_lengths_cache is None: + self._row_lengths_cache = row_lengths + return self._index_cache + + def _get_columns(self): + """ + Get the columns from the cache object. + + Returns + ------- + pandas.Index + An index object containing the column labels. + """ + if self._columns_cache is None: + self._columns_cache, column_widths = self._compute_axis_labels_and_lengths( + 1 + ) + if self._column_widths_cache is None: + self._column_widths_cache = column_widths + return self._columns_cache + + def _set_index(self, new_index): + """ + Replace the current row labels with new labels. + + Parameters + ---------- + new_index : list-like + The new row labels. + """ + if self._index_cache is None: + self._index_cache = ensure_index(new_index) + else: + new_index = self._validate_set_axis(new_index, self._index_cache) + self._index_cache = new_index + self.synchronize_labels(axis=0) + + def _set_columns(self, new_columns): + """ + Replace the current column labels with new labels. + + Parameters + ---------- + new_columns : list-like + The new column labels. + """ + if self._columns_cache is None: + self._columns_cache = ensure_index(new_columns) + else: + new_columns = self._validate_set_axis(new_columns, self._columns_cache) + self._columns_cache = new_columns + if self._dtypes is not None: + self._dtypes.index = new_columns + self.synchronize_labels(axis=1) + + columns = property(_get_columns, _set_columns) + index = property(_get_index, _set_index) + + @property + def axes(self): + """ + Get index and columns that can be accessed with an `axis` integer. + + Returns + ------- + list + List with two values: index and columns. + """ + return [self.index, self.columns] + + def _compute_axis_labels_and_lengths(self, axis: int, partitions=None): + """ + Compute the labels for specific `axis`. + + Parameters + ---------- + axis : int + Axis to compute labels along. + partitions : np.ndarray, optional + A 2D NumPy array of partitions from which labels will be grabbed. + If not specified, partitions will be taken from `self._partitions`. 
+ + Returns + ------- + pandas.Index + Labels for the specified `axis`. + List of int + Size of partitions alongside specified `axis`. + """ + if partitions is None: + partitions = self._partitions + new_index, internal_idx = self._partition_mgr_cls.get_indices(axis, partitions) + return new_index, list(map(len, internal_idx)) + + def _filter_empties(self, compute_metadata=True): + """ + Remove empty partitions from `self._partitions` to avoid triggering excess computation. + + Parameters + ---------- + compute_metadata : bool, default: True + Trigger the computations for partition sizes and labels if they're not done already. + """ + if not compute_metadata and ( + self._index_cache is None + or self._columns_cache is None + or self._row_lengths_cache is None + or self._column_widths_cache is None + ): + # do not trigger the computations + return + + if len(self.axes[0]) == 0 or len(self.axes[1]) == 0: + # This is the case for an empty frame. We don't want to completely remove + # all metadata and partitions so for the moment, we won't prune if the frame + # is empty. + # TODO: Handle empty dataframes better + return + self._partitions = np.array( + [ + [ + self._partitions[i][j] + for j in range(len(self._partitions[i])) + if j < len(self.column_widths) and self.column_widths[j] != 0 + ] + for i in range(len(self._partitions)) + if i < len(self.row_lengths) and self.row_lengths[i] != 0 + ] + ) + self._column_widths_cache = [w for w in self.column_widths if w != 0] + self._row_lengths_cache = [r for r in self.row_lengths if r != 0] + + def synchronize_labels(self, axis=None): + """ + Set the deferred axes variables for the ``PandasDataframe``. + + Parameters + ---------- + axis : int, default: None + The deferred axis. + 0 for the index, 1 for the columns. + """ + if axis is None: + self._deferred_index = True + self._deferred_column = True + elif axis == 0: + self._deferred_index = True + else: + self._deferred_column = True + + def _propagate_index_objs(self, axis=None): + """ + Synchronize labels by applying the index object for specific `axis` to the `self._partitions` lazily. + + Adds `set_axis` function to call-queue of each partition from `self._partitions` + to apply new axis. + + Parameters + ---------- + axis : int, default: None + The axis to apply to. If it's None applies to both axes. 
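+
+ Notes
+ -----
+ Labels are assigned by slicing ``self.index`` / ``self.columns`` with the cumulative
+ partition lengths, so each partition is queued a ``set_axis`` call covering exactly the
+ label range it owns; the matching ``_deferred_index`` / ``_deferred_column`` flag is then
+ reset.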
+ """ + self._filter_empties() + if axis is None or axis == 0: + cum_row_lengths = np.cumsum([0] + self.row_lengths) + if axis is None or axis == 1: + cum_col_widths = np.cumsum([0] + self.column_widths) + + if axis is None: + + def apply_idx_objs(df, idx, cols): + return df.set_axis(idx, axis="index").set_axis(cols, axis="columns") + + self._partitions = np.array( + [ + [ + self._partitions[i][j].add_to_apply_calls( + apply_idx_objs, + idx=self.index[ + slice(cum_row_lengths[i], cum_row_lengths[i + 1]) + ], + cols=self.columns[ + slice(cum_col_widths[j], cum_col_widths[j + 1]) + ], + length=self.row_lengths[i], + width=self.column_widths[j], + ) + for j in range(len(self._partitions[i])) + ] + for i in range(len(self._partitions)) + ] + ) + self._deferred_index = False + self._deferred_column = False + elif axis == 0: + + def apply_idx_objs(df, idx): + return df.set_axis(idx, axis="index") + + self._partitions = np.array( + [ + [ + self._partitions[i][j].add_to_apply_calls( + apply_idx_objs, + idx=self.index[ + slice(cum_row_lengths[i], cum_row_lengths[i + 1]) + ], + length=self.row_lengths[i], + width=self.column_widths[j], + ) + for j in range(len(self._partitions[i])) + ] + for i in range(len(self._partitions)) + ] + ) + self._deferred_index = False + elif axis == 1: + + def apply_idx_objs(df, cols): + return df.set_axis(cols, axis="columns") + + self._partitions = np.array( + [ + [ + self._partitions[i][j].add_to_apply_calls( + apply_idx_objs, + cols=self.columns[ + slice(cum_col_widths[j], cum_col_widths[j + 1]) + ], + length=self.row_lengths[i], + width=self.column_widths[j], + ) + for j in range(len(self._partitions[i])) + ] + for i in range(len(self._partitions)) + ] + ) + self._deferred_column = False + else: + ErrorMessage.internal_error(axis is not None and axis not in [0, 1]) + + @lazy_metadata_decorator(apply_axis=None) + def take_2d_labels_or_positional( + self, + row_labels: Optional[list[Hashable]] = None, + row_positions: Optional[list[int]] = None, + col_labels: Optional[list[Hashable]] = None, + col_positions: Optional[list[int]] = None, + ) -> "PandasDataframe": + """ + Lazily select columns or rows from given indices. + + Parameters + ---------- + row_labels : list of hashable, optional + The row labels to extract. + row_positions : list-like of ints, optional + The row positions to extract. + col_labels : list of hashable, optional + The column labels to extract. + col_positions : list-like of ints, optional + The column positions to extract. + + Returns + ------- + PandasDataframe + A new PandasDataframe from the mask provided. + + Notes + ----- + If both `row_labels` and `row_positions` are provided, a ValueError is raised. + The same rule applies for `col_labels` and `col_positions`. + """ + if check_both_not_none(row_labels, row_positions): + raise ValueError( + "Both row_labels and row_positions were provided - " + + "please provide only one of row_labels and row_positions." + ) + if check_both_not_none(col_labels, col_positions): + raise ValueError( + "Both col_labels and col_positions were provided - " + + "please provide only one of col_labels and col_positions." 
+ ) + + if row_labels is not None: + # Get numpy array of positions of values from `row_labels` + if isinstance(self.index, pandas.MultiIndex): + row_positions = np.zeros(len(row_labels), dtype="int64") + # we can't use .get_locs(row_labels) because the function + # requires a different format for row_labels + for idx, label in enumerate(row_labels): + if isinstance(label, str): + label = [label] + # get_loc can return slice that _take_2d_positional can't handle + row_positions[idx] = self.index.get_locs(label)[0] + else: + row_positions = self.index.get_indexer_for(row_labels) + + if col_labels is not None: + # Get numpy array of positions of values from `col_labels` + if isinstance(self.columns, pandas.MultiIndex): + col_positions = np.zeros(len(col_labels), dtype="int64") + # we can't use .get_locs(col_labels) because the function + # requires a different format for row_labels + for idx, label in enumerate(col_labels): + if isinstance(label, str): + label = [label] + # get_loc can return slice that _take_2d_positional can't handle + col_positions[idx] = self.columns.get_locs(label)[0] + else: + col_positions = self.columns.get_indexer_for(col_labels) + + return self._take_2d_positional(row_positions, col_positions) + + def _get_sorted_positions(self, positions): + """ + Sort positions if necessary. + + Parameters + ---------- + positions : Sequence[int] + + Returns + ------- + Sequence[int] + """ + # Helper for take_2d_positional + if is_range_like(positions) and positions.step > 0: + sorted_positions = positions + else: + sorted_positions = np.sort(positions) + return sorted_positions + + def _get_new_lengths(self, partitions_dict, *, axis: int) -> list[int]: + """ + Find lengths of new partitions. + + Parameters + ---------- + partitions_dict : dict + axis : int + + Returns + ------- + list[int] + """ + # Helper for take_2d_positional + if axis == 0: + axis_lengths = self.row_lengths + else: + axis_lengths = self.column_widths + + new_lengths = [ + len( + # Row lengths for slice are calculated as the length of the slice + # on the partition. Often this will be the same length as the current + # length, but sometimes it is different, thus the extra calculation. + range(*part_indexer.indices(axis_lengths[part_idx])) + if isinstance(part_indexer, slice) + else part_indexer + ) + for part_idx, part_indexer in partitions_dict.items() + ] + return new_lengths + + def _get_new_index_obj( + self, positions, sorted_positions, axis: int + ) -> "tuple[pandas.Index, slice | npt.NDArray[np.intp]]": + """ + Find the new Index object for take_2d_positional result. + + Parameters + ---------- + positions : Sequence[int] + sorted_positions : Sequence[int] + axis : int + + Returns + ------- + pandas.Index + slice or Sequence[int] + """ + # Helper for take_2d_positional + # Use the slice to calculate the new columns + if axis == 0: + idx = self.index + else: + idx = self.columns + + # TODO: Support fast processing of negative-step ranges + if is_range_like(positions) and positions.step > 0: + # pandas Index is more likely to preserve its metadata if the indexer + # is slice + monotonic_idx = slice(positions.start, positions.stop, positions.step) + else: + monotonic_idx = np.asarray(sorted_positions, dtype=np.intp) + + new_idx = idx[monotonic_idx] + return new_idx, monotonic_idx + + def _take_2d_positional( + self, + row_positions: Optional[list[int]] = None, + col_positions: Optional[list[int]] = None, + ) -> "PandasDataframe": + """ + Lazily select columns or rows from given indices. 
+ + Parameters + ---------- + row_positions : list-like of ints, optional + The row positions to extract. + col_positions : list-like of ints, optional + The column positions to extract. + + Returns + ------- + PandasDataframe + A new PandasDataframe from the mask provided. + """ + indexers = [] + for axis, indexer in enumerate((row_positions, col_positions)): + if is_range_like(indexer): + if indexer.step == 1 and len(indexer) == len(self.axes[axis]): + # By this function semantics, `None` indexer is a full-axis access + indexer = None + elif indexer is not None and not isinstance(indexer, pandas.RangeIndex): + # Pure python's range is not fully compatible with a list of ints, + # converting it to ``pandas.RangeIndex``` that is compatible. + indexer = pandas.RangeIndex( + indexer.start, indexer.stop, indexer.step + ) + else: + ErrorMessage.internal_error( + failure_condition=not (indexer is None or is_list_like(indexer)), + extra_log="Mask takes only list-like numeric indexers, " + + f"received: {type(indexer)}", + ) + indexers.append(indexer) + row_positions, col_positions = indexers + + if col_positions is None and row_positions is None: + return self.copy() + + sorted_row_positions = sorted_col_positions = None + + if row_positions is not None: + sorted_row_positions = self._get_sorted_positions(row_positions) + # Get dict of row_parts as {row_index: row_internal_indices} + row_partitions_dict = self._get_dict_of_block_index( + 0, sorted_row_positions, are_indices_sorted=True + ) + new_row_lengths = self._get_new_lengths(row_partitions_dict, axis=0) + new_index, _ = self._get_new_index_obj( + row_positions, sorted_row_positions, axis=0 + ) + else: + row_partitions_dict = {i: slice(None) for i in range(len(self._partitions))} + new_row_lengths = self._row_lengths_cache + new_index = self._index_cache + + if col_positions is not None: + sorted_col_positions = self._get_sorted_positions(col_positions) + # Get dict of col_parts as {col_index: col_internal_indices} + col_partitions_dict = self._get_dict_of_block_index( + 1, sorted_col_positions, are_indices_sorted=True + ) + new_col_widths = self._get_new_lengths(col_partitions_dict, axis=1) + new_columns, monotonic_col_idx = self._get_new_index_obj( + col_positions, sorted_col_positions, axis=1 + ) + + ErrorMessage.internal_error( + failure_condition=sum(new_col_widths) != len(new_columns), + extra_log=f"{sum(new_col_widths)} != {len(new_columns)}.\n" + + f"{col_positions}\n{self.column_widths}\n{col_partitions_dict}", + ) + + if self._dtypes is not None: + new_dtypes = self.dtypes.iloc[monotonic_col_idx] + else: + new_dtypes = None + else: + col_partitions_dict = { + i: slice(None) for i in range(len(self._partitions.T)) + } + new_col_widths = self._column_widths_cache + new_columns = self._columns_cache + new_dtypes = self._dtypes + + new_partitions = np.array( + [ + [ + self._partitions[row_idx][col_idx].mask( + row_internal_indices, col_internal_indices + ) + for col_idx, col_internal_indices in col_partitions_dict.items() + ] + for row_idx, row_internal_indices in row_partitions_dict.items() + ] + ) + intermediate = self.__constructor__( + new_partitions, + new_index, + new_columns, + new_row_lengths, + new_col_widths, + new_dtypes, + ) + + return self._maybe_reorder_labels( + intermediate, + row_positions, + sorted_row_positions, + col_positions, + sorted_col_positions, + ) + + def _maybe_reorder_labels( + self, + intermediate: "PandasDataframe", + row_positions, + sorted_row_positions, + col_positions, + sorted_col_positions, + ) -> 
"PandasDataframe": + """ + Call re-order labels on take_2d_labels_or_positional result if necessary. + + Parameters + ---------- + intermediate : PandasDataFrame + row_positions : list-like of ints, optional + The row positions to extract. + sorted_row_positions : list-like of ints, optional + Sorted version of row_positions. + col_positions : list-like of ints, optional + The column positions to extract. + sorted_col_positions : list-like of ints, optional + Sorted version of col_positions. + + Returns + ------- + PandasDataframe + """ + # Check if monotonically increasing, return if it is. Fast track code path for + # common case to keep it fast. + if ( + row_positions is None + # Fast range processing of non-positive-step ranges is not yet supported + or (is_range_like(row_positions) and row_positions.step > 0) + or len(row_positions) == 1 + or np.all(row_positions[1:] >= row_positions[:-1]) + ) and ( + col_positions is None + # Fast range processing of non-positive-step ranges is not yet supported + or (is_range_like(col_positions) and col_positions.step > 0) + or len(col_positions) == 1 + or np.all(col_positions[1:] >= col_positions[:-1]) + ): + return intermediate + + # The new labels are often smaller than the old labels, so we can't reuse the + # original order values because those were mapped to the original data. We have + # to reorder here based on the expected order from within the data. + # We create a dictionary mapping the position of the numeric index with respect + # to all others, then recreate that order by mapping the new order values from + # the old. This information is sent to `_reorder_labels`. + if row_positions is not None: + row_order_mapping = dict( + zip(sorted_row_positions, range(len(row_positions))) + ) + new_row_order = [row_order_mapping[idx] for idx in row_positions] + else: + new_row_order = None + if col_positions is not None: + col_order_mapping = dict( + zip(sorted_col_positions, range(len(col_positions))) + ) + new_col_order = [col_order_mapping[idx] for idx in col_positions] + else: + new_col_order = None + return intermediate._reorder_labels( + row_positions=new_row_order, col_positions=new_col_order + ) + + @lazy_metadata_decorator(apply_axis="rows") + def from_labels(self) -> "PandasDataframe": + """ + Convert the row labels to a column of data, inserted at the first position. + + Gives result by similar way as `pandas.DataFrame.reset_index`. Each level + of `self.index` will be added as separate column of data. + + Returns + ------- + PandasDataframe + A PandasDataframe with new columns from index labels. + """ + new_row_labels = pandas.RangeIndex(len(self.index)) + if self.index.nlevels > 1: + level_names = [ + self.index.names[i] if self.index.names[i] is not None else f"level_{i}" + for i in range(self.index.nlevels) + ] + else: + level_names = [ + self.index.names[0] + if self.index.names[0] is not None + else "index" + if "index" not in self.columns + else f"level_{0}" + ] + new_dtypes = None + if self._dtypes is not None: + names = tuple(level_names) if len(level_names) > 1 else level_names[0] + new_dtypes = self.index.to_frame(name=names).dtypes + new_dtypes = pandas.concat([new_dtypes, self._dtypes]) + + # We will also use the `new_column_names` in the calculation of the internal metadata, so this is a + # lightweight way of ensuring the metadata matches. + if self.columns.nlevels > 1: + # Column labels are different for multilevel index. 
+ new_column_names = pandas.MultiIndex.from_tuples( + # Set level names on the 1st columns level and fill up empty level names with empty string. + # Expand tuples in level names. This is how reset_index works when col_level col_fill are not specified. + [ + tuple( + list(level) + [""] * (self.columns.nlevels - len(level)) + if isinstance(level, tuple) + else [level] + [""] * (self.columns.nlevels - 1) + ) + for level in level_names + ], + names=self.columns.names, + ) + else: + new_column_names = pandas.Index(level_names, tupleize_cols=False) + new_columns = new_column_names.append(self.columns) + + def from_labels_executor(df, **kwargs): + # Setting the names here ensures that external and internal metadata always match. + df.index.names = new_column_names + + # Handling of a case when columns have the same name as one of index levels names. + # In this case `df.reset_index` provides errors related to columns duplication. + # This case is possible because columns metadata updating is deferred. To workaround + # `df.reset_index` error we allow columns duplication in "if" branch via `concat`. + if any(name_level in df.columns for name_level in df.index.names): + columns_to_add = df.index.to_frame() + columns_to_add.reset_index(drop=True, inplace=True) + df = df.reset_index(drop=True) + result = pandas.concat([columns_to_add, df], axis=1, copy=False) + else: + result = df.reset_index() + # Put the index back to the original due to GH#4394 + result.index = df.index + return result + + new_parts = self._partition_mgr_cls.apply_func_to_select_indices( + 0, + self._partitions, + from_labels_executor, + [0], + keep_remaining=True, + ) + new_column_widths = [ + self.index.nlevels + self.column_widths[0] + ] + self.column_widths[1:] + result = self.__constructor__( + new_parts, + new_row_labels, + new_columns, + row_lengths=self._row_lengths_cache, + column_widths=new_column_widths, + dtypes=new_dtypes, + ) + # Set flag for propagating deferred row labels across dataframe partitions + result.synchronize_labels(axis=0) + return result + + def to_labels(self, column_list: list[Hashable]) -> "PandasDataframe": + """ + Move one or more columns into the row labels. Previous labels are dropped. + + Parameters + ---------- + column_list : list of hashable + The list of column names to place as the new row labels. + + Returns + ------- + PandasDataframe + A new PandasDataframe that has the updated labels. + """ + extracted_columns = self.take_2d_labels_or_positional( + col_labels=column_list + ).to_pandas() + + if len(column_list) == 1: + new_labels = pandas.Index( + extracted_columns.squeeze(axis=1), name=column_list[0] + ) + else: + new_labels = pandas.MultiIndex.from_frame( + extracted_columns, names=column_list + ) + result = self.take_2d_labels_or_positional( + col_labels=[i for i in self.columns if i not in extracted_columns.columns] + ) + result.index = new_labels + return result + + @lazy_metadata_decorator(apply_axis="both") + def _reorder_labels(self, row_positions=None, col_positions=None): + """ + Reorder the column and or rows in this DataFrame. + + Parameters + ---------- + row_positions : list of int, optional + The ordered list of new row orders such that each position within the list + indicates the new position. + col_positions : list of int, optional + The ordered list of new column orders such that each position within the + list indicates the new position. + + Returns + ------- + PandasDataframe + A new PandasDataframe with reordered columns and/or rows. 
+ """ + new_dtypes = self._dtypes + if row_positions is not None: + ordered_rows = self._partition_mgr_cls.map_axis_partitions( + 0, self._partitions, lambda df: df.iloc[row_positions] + ) + row_idx = self.index[row_positions] + + if self._partitions.shape[0] != ordered_rows.shape[0] or len( + row_idx + ) != len(self.index): + # The frame was re-partitioned along the 0 axis during reordering using + # the "standard" partitioning. Knowing the standard partitioning scheme + # we are able to compute new row lengths. + new_lengths = get_length_list( # noqa: F821 + axis_len=len(row_idx), num_splits=ordered_rows.shape[0] + ) + else: + # If the frame's partitioning was preserved then + # we can use previous row lengths cache + new_lengths = self._row_lengths_cache + else: + ordered_rows = self._partitions + row_idx = self.index + new_lengths = self._row_lengths_cache + if col_positions is not None: + ordered_cols = self._partition_mgr_cls.map_axis_partitions( + 1, ordered_rows, lambda df: df.iloc[:, col_positions] + ) + col_idx = self.columns[col_positions] + if new_dtypes is not None: + new_dtypes = self._dtypes.iloc[col_positions] + + if self._partitions.shape[1] != ordered_cols.shape[1] or len( + col_idx + ) != len(self.columns): + # The frame was re-partitioned along the 1 axis during reordering using + # the "standard" partitioning. Knowing the standard partitioning scheme + # we are able to compute new column widths. + new_widths = get_length_list( # noqa: F821 + axis_len=len(col_idx), num_splits=ordered_cols.shape[1] + ) + else: + # If the frame's partitioning was preserved then + # we can use previous column widths cache + new_widths = self._column_widths_cache + else: + ordered_cols = ordered_rows + col_idx = self.columns + new_widths = self._column_widths_cache + return self.__constructor__( + ordered_cols, row_idx, col_idx, new_lengths, new_widths, new_dtypes + ) + + @lazy_metadata_decorator(apply_axis=None) + def copy(self): + """ + Copy this object. + + Returns + ------- + PandasDataframe + A copied version of this object. + """ + return self.__constructor__( + self._partitions, + self._index_cache.copy() if self._index_cache is not None else None, + self._columns_cache.copy() if self._columns_cache is not None else None, + self._row_lengths_cache, + self._column_widths_cache, + self._dtypes.copy() if self._dtypes is not None else None, + ) + + @lazy_metadata_decorator(apply_axis="both") + def astype(self, col_dtypes): + """ + Convert the columns dtypes to given dtypes. + + Parameters + ---------- + col_dtypes : dictionary of {col: dtype,...} + Where col is the column name and dtype is a NumPy dtype. + + Returns + ------- + BaseDataFrame + Dataframe with updated dtypes. 
+ """ + columns = col_dtypes.keys() + # Create Series for the updated dtypes + new_dtypes = self.dtypes.copy() + for _, column in enumerate(columns): + dtype = col_dtypes[column] + if ( + not isinstance(dtype, type(self.dtypes[column])) + or dtype != self.dtypes[column] + ): + # Update the new dtype series to the proper pandas dtype + try: + new_dtype = np.dtype(dtype) + except TypeError: + new_dtype = dtype + + if dtype != np.int32 and new_dtype == np.int32: + new_dtypes[column] = np.dtype("int64") + elif dtype != np.float32 and new_dtype == np.float32: + new_dtypes[column] = np.dtype("float64") + # We cannot infer without computing the dtype if + elif isinstance(new_dtype, str) and new_dtype == "category": + new_dtypes = None + break + else: + new_dtypes[column] = new_dtype + + def astype_builder(df): + """Compute new partition frame with dtypes updated.""" + return df.astype({k: v for k, v in col_dtypes.items() if k in df}) + + new_frame = self._partition_mgr_cls.map_partitions( + self._partitions, astype_builder + ) + return self.__constructor__( + new_frame, + self._index_cache, + self._columns_cache, + self._row_lengths_cache, + self._column_widths_cache, + new_dtypes, + ) + + # Metadata modification methods + def add_prefix(self, prefix, axis): + """ + Add a prefix to the current row or column labels. + + Parameters + ---------- + prefix : str + The prefix to add. + axis : int + The axis to update. + + Returns + ------- + PandasDataframe + A new dataframe with the updated labels. + """ + + def new_labels_mapper(x, prefix=str(prefix)): # noqa B008 + return prefix + str(x) + + if axis == 0: + return self.rename(new_row_labels=new_labels_mapper) + return self.rename(new_col_labels=new_labels_mapper) + + def add_suffix(self, suffix, axis): + """ + Add a suffix to the current row or column labels. + + Parameters + ---------- + suffix : str + The suffix to add. + axis : int + The axis to update. + + Returns + ------- + PandasDataframe + A new dataframe with the updated labels. + """ + + def new_labels_mapper(x, suffix=str(suffix)): # noqa B008 + return str(x) + suffix + + if axis == 0: + return self.rename(new_row_labels=new_labels_mapper) + return self.rename(new_col_labels=new_labels_mapper) + + # END Metadata modification methods + + def numeric_columns(self, include_bool=True): + """ + Return the names of numeric columns in the frame. + + Parameters + ---------- + include_bool : bool, default: True + Whether to consider boolean columns as numeric. + + Returns + ------- + list + List of column names. + """ + columns = [] + for col, dtype in zip(self.columns, self.dtypes): + if is_numeric_dtype(dtype) and ( + include_bool or (not include_bool and dtype != np.bool_) + ): + columns.append(col) + return columns + + def _get_dict_of_block_index(self, axis, indices, are_indices_sorted=False): + """ + Convert indices to an ordered dict mapping partition (or block) index to internal indices in said partition. + + Parameters + ---------- + axis : {0, 1} + The axis along which to get the indices (0 - rows, 1 - columns). + indices : list of int, slice + A list of global indices to convert. + are_indices_sorted : bool, default: False + Flag indicating whether the `indices` sequence is sorted by ascending or not. + Note: the internal algorithm requires for the `indices` to be sorted, this + flag is used for optimization in order to not sort already sorted data. 
+ Be careful when passing ``True`` for this flag, if the data appears to be unsorted + with the flag set to ``True`` this would lead to undefined behavior. + + Returns + ------- + OrderedDict + A mapping from partition index to list of internal indices which correspond to `indices` in each + partition. + """ + # TODO: Support handling of slices with specified 'step'. For now, converting them into a range + if isinstance(indices, slice) and ( + indices.step is not None and indices.step != 1 + ): + indices = range(*indices.indices(len(self.axes[axis]))) + # Fasttrack slices + if isinstance(indices, slice) or (is_range_like(indices) and indices.step == 1): + # Converting range-like indexer to slice + indices = slice(indices.start, indices.stop, indices.step) + if is_full_grab_slice(indices, sequence_len=len(self.axes[axis])): + return OrderedDict( + zip( + range(self._partitions.shape[axis]), + [slice(None)] * self._partitions.shape[axis], + ) + ) + # Empty selection case + if indices.start == indices.stop and indices.start is not None: + return OrderedDict() + if indices.start is None or indices.start == 0: + last_part, last_idx = list( + self._get_dict_of_block_index(axis, [indices.stop]).items() + )[0] + dict_of_slices = OrderedDict( + zip(range(last_part), [slice(None)] * last_part) + ) + dict_of_slices.update({last_part: slice(last_idx[0])}) + return dict_of_slices + elif indices.stop is None or indices.stop >= len(self.axes[axis]): + first_part, first_idx = list( + self._get_dict_of_block_index(axis, [indices.start]).items() + )[0] + dict_of_slices = OrderedDict({first_part: slice(first_idx[0], None)}) + num_partitions = np.size(self._partitions, axis=axis) + part_list = range(first_part + 1, num_partitions) + dict_of_slices.update( + OrderedDict(zip(part_list, [slice(None)] * len(part_list))) + ) + return dict_of_slices + else: + first_part, first_idx = list( + self._get_dict_of_block_index(axis, [indices.start]).items() + )[0] + last_part, last_idx = list( + self._get_dict_of_block_index(axis, [indices.stop]).items() + )[0] + if first_part == last_part: + return OrderedDict({first_part: slice(first_idx[0], last_idx[0])}) + else: + if last_part - first_part == 1: + return OrderedDict( + # FIXME: this dictionary creation feels wrong - it might not maintain the order + { + first_part: slice(first_idx[0], None), + last_part: slice(None, last_idx[0]), + } + ) + else: + dict_of_slices = OrderedDict( + {first_part: slice(first_idx[0], None)} + ) + part_list = range(first_part + 1, last_part) + dict_of_slices.update( + OrderedDict(zip(part_list, [slice(None)] * len(part_list))) + ) + dict_of_slices.update({last_part: slice(None, last_idx[0])}) + return dict_of_slices + if isinstance(indices, list): + # Converting python list to numpy for faster processing + indices = np.array(indices, dtype=np.int64) + # Fasttrack empty numpy array + if isinstance(indices, np.ndarray) and indices.size == 0: + # This will help preserve metadata stored in empty dataframes (indexes and dtypes) + # Otherwise, we will get an empty `new_partitions` array, from which it will + # no longer be possible to obtain metadata + return OrderedDict([(0, np.array([], dtype=np.int64))]) + negative_mask = np.less(indices, 0) + has_negative = np.any(negative_mask) + if has_negative: + # We're going to modify 'indices' inplace in a numpy way, so doing a copy/converting indices to numpy. 
+ indices = ( + indices.copy() + if isinstance(indices, np.ndarray) + else np.array(indices, dtype=np.int64) + ) + indices[negative_mask] = indices[negative_mask] % len(self.axes[axis]) + # If the `indices` array was modified because of the negative indices conversion + # then the original order was broken and so we have to sort anyway: + if has_negative or not are_indices_sorted: + indices = np.sort(indices) + if axis == 0: + bins = np.array(self.row_lengths) + else: + bins = np.array(self.column_widths) + # INT_MAX to make sure we don't try to compute on partitions that don't exist. + cumulative = np.append(bins[:-1].cumsum(), np.iinfo(bins.dtype).max) + + def internal(block_idx: int, global_index): + """Transform global index to internal one for given block (identified by its index).""" + return ( + global_index + if not block_idx + else np.subtract( + global_index, cumulative[min(block_idx, len(cumulative) - 1) - 1] + ) + ) + + partition_ids = np.digitize(indices, cumulative) + count_for_each_partition = np.array( + [(partition_ids == i).sum() for i in range(len(cumulative))] + ).cumsum() + # Compute the internal indices and pair those with the partition index. + # If the first partition has any values we need to return, compute those + # first to make the list comprehension easier. Otherwise, just append the + # rest of the values to an empty list. + if count_for_each_partition[0] > 0: + first_partition_indices = [ + (0, internal(0, indices[slice(count_for_each_partition[0])])) + ] + else: + first_partition_indices = [] + partition_ids_with_indices = first_partition_indices + [ + ( + i, + internal( + i, + indices[ + slice( + count_for_each_partition[i - 1], + count_for_each_partition[i], + ) + ], + ), + ) + for i in range(1, len(count_for_each_partition)) + if count_for_each_partition[i] > count_for_each_partition[i - 1] + ] + return OrderedDict(partition_ids_with_indices) + + @staticmethod + def _join_index_objects(axis, indexes, how, sort): + """ + Join the pair of index objects (columns or rows) by a given strategy. + + Unlike Index.join() in pandas, if `axis` is 1, `sort` is False, + and `how` is "outer", the result will _not_ be sorted. + + Parameters + ---------- + axis : {0, 1} + The axis index object to join (0 - rows, 1 - columns). + indexes : list(Index) + The indexes to join on. + how : {'left', 'right', 'inner', 'outer', None} + The type of join to join to make. If `None` then joined index + considered to be the first index in the `indexes` list. + sort : boolean + Whether or not to sort the joined index. + + Returns + ------- + (Index, func) + Joined index with make_reindexer func. 
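+
+ Notes
+ -----
+ ``make_reindexer(do_reindex, frame_idx)`` builds the callback that reindexes a frame to
+ the joined index, falling back to positional indexers when duplicate labels make a plain
+ ``reindex`` ambiguous.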
+ """ + assert isinstance(indexes, list) + + # define helper functions + def merge(left_index, right_index): + """Combine a pair of indices depending on `axis`, `how` and `sort` from outside.""" + if axis == 1 and how == "outer" and not sort: + return left_index.union(right_index, sort=False) + else: + return left_index.join(right_index, how=how, sort=sort) + + # define condition for joining indexes + all_indices_equal = all(indexes[0].equals(index) for index in indexes[1:]) + do_join_index = how is not None and not all_indices_equal + + # define condition for joining indexes with getting indexers + need_indexers = ( + axis == 0 + and not all_indices_equal + and any(not index.is_unique for index in indexes) + ) + indexers = None + + # perform joining indexes + if do_join_index: + if len(indexes) == 2 and need_indexers: + # in case of count of indexes > 2 we should perform joining all indexes + # after that get indexers + # in the fast path we can obtain joined_index and indexers in one call + indexers = [None, None] + joined_index, indexers[0], indexers[1] = indexes[0].join( + indexes[1], how=how, sort=sort, return_indexers=True + ) + else: + joined_index = indexes[0] + # TODO: revisit for performance + for index in indexes[1:]: + joined_index = merge(joined_index, index) + else: + joined_index = indexes[0].copy() + + if need_indexers and indexers is None: + indexers = [index.get_indexer_for(joined_index) for index in indexes] + + def make_reindexer(do_reindex: bool, frame_idx: int): + """Create callback that reindexes the dataframe using newly computed index.""" + # the order of the frames must match the order of the indexes + if not do_reindex: + return lambda df: df + + if need_indexers: + assert indexers is not None + + return lambda df: df._reindex_with_indexers( + {0: [joined_index, indexers[frame_idx]]}, + copy=True, + allow_dups=True, + ) + return lambda df: df.reindex(joined_index, axis=axis) + + return joined_index, make_reindexer + + # Internal methods + # These methods are for building the correct answer in a modular way. + # Please be careful when changing these! + + def _build_treereduce_func(self, axis, func): + """ + Properly formats a TreeReduce result so that the partitioning is correct. + + Parameters + ---------- + axis : int + The axis along which to apply the function. + func : callable + The function to apply. + + Returns + ------- + callable + A function to be shipped to the partitions to be executed. + + Notes + ----- + This should be used for any TreeReduce style operation that results in a + reduced data dimensionality (dataframe -> series). + """ + + def _tree_reduce_func(df, *args, **kwargs): + """Tree-reducer function itself executing `func`, presenting the resulting pandas.Series as pandas.DataFrame.""" + series_result = func(df, *args, **kwargs) + if axis == 0 and isinstance(series_result, pandas.Series): + # In the case of axis=0, we need to keep the shape of the data + # consistent with what we have done. In the case of a reduce, the + # data for axis=0 should be a single value for each column. By + # transposing the data after we convert to a DataFrame, we ensure that + # the columns of the result line up with the columns from the data. + # axis=1 does not have this requirement because the index already will + # line up with the index of the data based on how pandas creates a + # DataFrame from a Series. 
+ result = pandas.DataFrame(series_result).T + result.index = [MODIN_UNNAMED_SERIES_LABEL] + else: + result = pandas.DataFrame(series_result) + if isinstance(series_result, pandas.Series): + result.columns = [MODIN_UNNAMED_SERIES_LABEL] + return result + + return _tree_reduce_func + + def _compute_tree_reduce_metadata(self, axis, new_parts): + """ + Compute the metadata for the result of reduce function. + + Parameters + ---------- + axis : int + The axis on which reduce function was applied. + new_parts : NumPy 2D array + Partitions with the result of applied function. + + Returns + ------- + PandasDataframe + Modin series (1xN frame) containing the reduced data. + """ + new_axes, new_axes_lengths = [0, 0], [0, 0] + + new_axes[axis] = [MODIN_UNNAMED_SERIES_LABEL] + new_axes[axis ^ 1] = self.axes[axis ^ 1] + + new_axes_lengths[axis] = [1] + new_axes_lengths[axis ^ 1] = self._axes_lengths[axis ^ 1] + + new_dtypes = None + result = self.__constructor__( + new_parts, + *new_axes, + *new_axes_lengths, + new_dtypes, + ) + return result + + @lazy_metadata_decorator(apply_axis="both") + def reduce( + self, + axis: Union[int, Axis], + function: Callable, + dtypes: Optional[str] = None, + ) -> "PandasDataframe": + """ + Perform a user-defined aggregation on the specified axis, where the axis reduces down to a singleton. Requires knowledge of the full axis for the reduction. + + Parameters + ---------- + axis : int or modin.core.dataframe.base.utils.Axis + The axis to perform the reduce over. + function : callable(row|col) -> single value + The reduce function to apply to each column. + dtypes : str, optional + The data types for the result. This is an optimization + because there are functions that always result in a particular data + type, and this allows us to avoid (re)computing it. + + Returns + ------- + PandasDataframe + Modin series (1xN frame) containing the reduced data. + + Notes + ----- + The user-defined function must reduce to a single value. + """ + axis = Axis(axis) + function = self._build_treereduce_func(axis.value, function) + new_parts = self._partition_mgr_cls.map_axis_partitions( + axis.value, self._partitions, function + ) + return self._compute_tree_reduce_metadata(axis.value, new_parts) + + @lazy_metadata_decorator(apply_axis="opposite", axis_arg=0) + def tree_reduce( + self, + axis: Union[int, Axis], + map_func: Callable, + reduce_func: Optional[Callable] = None, + dtypes: Optional[str] = None, + ) -> "PandasDataframe": + """ + Apply function that will reduce the data to a pandas Series. + + Parameters + ---------- + axis : int or modin.core.dataframe.base.utils.Axis + The axis to perform the tree reduce over. + map_func : callable(row|col) -> row|col + Callable function to map the dataframe. + reduce_func : callable(row|col) -> single value, optional + Callable function to reduce the dataframe. + If none, then apply map_func twice. + dtypes : str, optional + The data types for the result. This is an optimization + because there are functions that always result in a particular data + type, and this allows us to avoid (re)computing it. + + Returns + ------- + PandasDataframe + A new dataframe. 
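+
+ Notes
+ -----
+ A count reduction, for instance, can be phrased as a per-partition map followed by a sum
+ across the axis (an illustrative pairing, not a call made in this module)::
+
+     frame.tree_reduce(0, lambda df: df.count(), lambda df: df.sum())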
+ """ + axis = Axis(axis) + map_func = self._build_treereduce_func(axis.value, map_func) + if reduce_func is None: + reduce_func = map_func + else: + reduce_func = self._build_treereduce_func(axis.value, reduce_func) + + map_parts = self._partition_mgr_cls.map_partitions(self._partitions, map_func) + reduce_parts = self._partition_mgr_cls.map_axis_partitions( + axis.value, map_parts, reduce_func + ) + return self._compute_tree_reduce_metadata(axis.value, reduce_parts) + + @lazy_metadata_decorator(apply_axis=None) + def map(self, func: Callable, dtypes: Optional[str] = None) -> "PandasDataframe": + """ + Perform a function that maps across the entire dataset. + + Parameters + ---------- + func : callable(row|col|cell) -> row|col|cell + The function to apply. + dtypes : dtypes of the result, optional + The data types for the result. This is an optimization + because there are functions that always result in a particular data + type, and this allows us to avoid (re)computing it. + + Returns + ------- + PandasDataframe + A new dataframe. + """ + new_partitions = self._partition_mgr_cls.map_partitions(self._partitions, func) + if dtypes == "copy": + dtypes = self._dtypes + elif dtypes is not None: + dtypes = pandas.Series( + [np.dtype(dtypes)] * len(self.columns), index=self.columns + ) + return self.__constructor__( + new_partitions, + self._index_cache, + self._columns_cache, + self._row_lengths_cache, + self._column_widths_cache, + dtypes=dtypes, + ) + + def window( + self, + axis: Union[int, Axis], + reduce_fn: Callable, + window_size: int, + result_schema: Optional[dict[Hashable, type]] = None, + ) -> "PandasDataframe": + """ + Apply a sliding window operator that acts as a GROUPBY on each window, and reduces down to a single row (column) per window. + + Parameters + ---------- + axis : int or modin.core.dataframe.base.utils.Axis + The axis to slide over. + reduce_fn : callable(rowgroup|colgroup) -> row|col + The reduce function to apply over the data. + window_size : int + The number of row/columns to pass to the function. + (The size of the sliding window). + result_schema : dict, optional + Mapping from column labels to data types that represents the types of the output dataframe. + + Returns + ------- + PandasDataframe + A new PandasDataframe with the reduce function applied over windows of the specified + axis. + + Notes + ----- + The user-defined reduce function must reduce each window’s column + (row if axis=1) down to a single value. + """ + pass + + @lazy_metadata_decorator(apply_axis="both") + def fold(self, axis, func): + """ + Perform a function across an entire axis. + + Parameters + ---------- + axis : int + The axis to apply over. + func : callable + The function to apply. + + Returns + ------- + PandasDataframe + A new dataframe. + + Notes + ----- + The data shape is not changed (length and width of the table). + """ + new_partitions = self._partition_mgr_cls.map_axis_partitions( + axis, self._partitions, func, keep_partitioning=True + ) + return self.__constructor__( + new_partitions, + self._index_cache, + self._columns_cache, + self._row_lengths_cache, + self._column_widths_cache, + ) + + def infer_objects(self) -> "PandasDataframe": + """ + Attempt to infer better dtypes for object columns. + + Attempts soft conversion of object-dtyped columns, leaving non-object and unconvertible + columns unchanged. The inference rules are the same as during normal Series/DataFrame + construction. 
+ + Returns + ------- + PandasDataframe + A new PandasDataframe with the inferred schema. + """ + obj_cols = [ + col for col, dtype in enumerate(self.dtypes) if is_object_dtype(dtype) + ] + return self.infer_types(obj_cols) + + def infer_types(self, col_labels: list[str]) -> "PandasDataframe": + """ + Determine the compatible type shared by all values in the specified columns, and coerce them to that type. + + Parameters + ---------- + col_labels : list + List of column labels to infer and induce types over. + + Returns + ------- + PandasDataframe + A new PandasDataframe with the inferred schema. + """ + # Compute dtypes on the specified columns, and then set those dtypes on a new frame + new_cols = self.take_2d_labels_or_positional(col_labels=col_labels) + new_cols_dtypes = new_cols.tree_reduce(0, pandas.DataFrame.infer_objects).dtypes + new_dtypes = self.dtypes.copy() + new_dtypes[col_labels] = new_cols_dtypes + return self.__constructor__( + self._partitions, + self._index_cache, + self._columns_cache, + self._row_lengths_cache, + self._column_widths_cache, + new_dtypes, + ) + + def join( + self, + axis: Union[int, Axis], + condition: Callable, + other, # other: ModinDataframe, + join_type: Union[str, JoinType], + ) -> "PandasDataframe": + """ + Join this dataframe with the other. + + Parameters + ---------- + axis : int or modin.core.dataframe.base.utils.Axis + The axis to perform the join on. + condition : callable + Function that determines which rows should be joined. The condition can be a + simple equality, e.g. "left.col1 == right.col1" or can be arbitrarily complex. + other : ModinDataframe + The other data to join with, i.e. the right dataframe. + join_type : string {"inner", "left", "right", "outer"} or modin.core.dataframe.base.utils.JoinType + The type of join to perform. + + Returns + ------- + PandasDataframe + A new PandasDataframe that is the result of applying the specified join over the two + dataframes. + + Notes + ----- + During the join, this dataframe is considered the left, while the other is + treated as the right. + + Only inner joins, left outer, right outer, and full outer joins are currently supported. + Support for other join types (e.g. natural join) may be implemented in the future. + """ + pass + + def rename( + self, + new_row_labels: Optional[Union[dict[Hashable, Hashable], Callable]] = None, + new_col_labels: Optional[Union[dict[Hashable, Hashable], Callable]] = None, + level: Optional[Union[int, list[int]]] = None, + ) -> "PandasDataframe": + """ + Replace the row and column labels with the specified new labels. + + Parameters + ---------- + new_row_labels : dictionary or callable, optional + Mapping or callable that relates old row labels to new labels. + new_col_labels : dictionary or callable, optional + Mapping or callable that relates old col labels to new labels. + level : int, optional + Level whose row labels to replace. + + Returns + ------- + PandasDataframe + A new PandasDataframe with the new row and column labels. + + Notes + ----- + If level is not specified, the default behavior is to replace row labels in all levels. 
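+ Callables are applied label by label, so e.g. ``new_col_labels=str.upper`` upper-cases
+ every column label (an illustrative call, not one made in this module).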
+ """ + new_index = self.index.copy() + + def make_label_swapper(label_dict): + if isinstance(label_dict, dict): + return lambda label: label_dict.get(label, label) + return label_dict + + def swap_labels_levels(index_tuple): + if isinstance(new_row_labels, dict): + return tuple(new_row_labels.get(label, label) for label in index_tuple) + return tuple(new_row_labels(label) for label in index_tuple) + + if new_row_labels: + swap_row_labels = make_label_swapper(new_row_labels) + if isinstance(self.index, pandas.MultiIndex): + if level is not None: + new_index.set_levels( + new_index.levels[level].map(swap_row_labels), level + ) + else: + new_index = new_index.map(swap_labels_levels) + else: + new_index = new_index.map(swap_row_labels) + new_cols = self.columns.copy() + if new_col_labels: + new_cols = new_cols.map(make_label_swapper(new_col_labels)) + + def map_fn(df): + return df.rename(index=new_row_labels, columns=new_col_labels, level=level) + + new_parts = self._partition_mgr_cls.map_partitions(self._partitions, map_fn) + new_dtypes = None if self._dtypes is None else self._dtypes.set_axis(new_cols) + return self.__constructor__( + new_parts, + new_index, + new_cols, + self._row_lengths_cache, + self._column_widths_cache, + new_dtypes, + ) + + @lazy_metadata_decorator(apply_axis="both") + def sort_by( + self, + axis: Union[int, Axis], + columns: Union[str, list[str]], + ascending: bool = True, + **kwargs, + ) -> "PandasDataframe": + """ + Logically reorder rows (columns if axis=1) lexicographically by the data in a column or set of columns. + + Parameters + ---------- + axis : int or modin.core.dataframe.base.utils.Axis + The axis to perform the sort over. + columns : string or list + Column label(s) to use to determine lexicographical ordering. + ascending : boolean, default: True + Whether to sort in ascending or descending order. + **kwargs : dict + Keyword arguments to pass when sorting partitions. + + Returns + ------- + PandasDataframe + A new PandasDataframe sorted into lexicographical order by the specified column(s). + """ + if not isinstance(columns, list): + columns = [columns] + # When we do a sort on the result of Series.value_counts, we don't rename the index until + # after everything is done, which causes an error when sorting the partitions, since the + # index and the column share the same name, when in actuality, the index's name should be + # None. This fixes the indexes name beforehand in that case, so that the sort works. + + def sort_function(df): + index_renaming = None + if any(name in df.columns for name in df.index.names): + index_renaming = df.index.names + df.index = df.index.set_names([None] * len(df.index.names)) + df = df.sort_values(by=columns, ascending=ascending, **kwargs) + if index_renaming is not None: + df.index = df.index.set_names(index_renaming) + return df + + axis = Axis(axis) + if axis != Axis.ROW_WISE: + raise NotImplementedError( + f"Algebra sort only implemented row-wise. {axis.name} sort not implemented yet!" + ) + + # If this df is empty, we don't want to try and shuffle or sort. + if len(self.axes[0]) == 0 or len(self.axes[1]) == 0: + return self.copy() + # If this df only has one row partition, we don't want to do a shuffle and sort - we can + # just do a full-axis sort. + if len(self._partitions) == 1: + return self.apply_full_axis( + 1, + sort_function, + ) + if self.dtypes[columns[0]] == object: + # This means we are not sorting numbers, so we need our quantiles to not try + # arithmetic on the values. 
+ method = "inverted_cdf" + else: + method = "linear" + + shuffling_functions = build_sort_functions( # noqa F821 + self, + columns[0], + method, + ascending[0] if is_list_like(ascending) else ascending, + **kwargs, + ) + major_col_partition_index = self.columns.get_loc(columns[0]) + cols_seen = 0 + index = -1 + for i, length in enumerate(self.column_widths): + cols_seen += length + if major_col_partition_index < cols_seen: + index = i + break + new_partitions = self._partition_mgr_cls.shuffle_partitions( + self._partitions, + index, + shuffling_functions, + sort_function, + ) + new_axes = self.axes + new_lengths = [None, None] + if kwargs.get("ignore_index", False): + new_axes[axis.value] = RangeIndex(len(new_axes[axis.value])) + else: + ( + new_axes[axis.value], + new_lengths[axis.value], + ) = self._compute_axis_labels_and_lengths(axis.value, new_partitions) + + new_axes[axis.value] = new_axes[axis.value].set_names( + self.axes[axis.value].names + ) + # We perform the final steps of the sort on full axis partitions, so we know that the + # length of each partition is the full length of the dataframe. + new_lengths[axis.value ^ 1] = [len(self.columns)] + # Since the strategy to pick our pivots involves random sampling + # we could end up picking poor pivots, leading to skew in our partitions. + # We should add a fix to check if there is skew in the partitions and rebalance + # them if necessary. Calling `rebalance_partitions` won't do this, since it only + # resolves the case where there isn't the right amount of partitions - not where + # there is skew across the lengths of partitions. + new_modin_frame = self.__constructor__( + new_partitions, *new_axes, *new_lengths, self.dtypes + ) + if kwargs.get("ignore_index", False): + new_modin_frame._propagate_index_objs(axis=0) + return new_modin_frame + + @lazy_metadata_decorator(apply_axis="both") + def filter(self, axis: Union[Axis, int], condition: Callable) -> "PandasDataframe": + """ + Filter data based on the function provided along an entire axis. + + Parameters + ---------- + axis : int or modin.core.dataframe.base.utils.Axis + The axis to filter over. + condition : callable(row|col) -> bool + The function to use for the filter. This function should filter the + data itself. + + Returns + ------- + PandasDataframe + A new filtered dataframe. + """ + axis = Axis(axis) + assert axis in ( + Axis.ROW_WISE, + Axis.COL_WISE, + ), "Axis argument to filter operator must be 0 (rows) or 1 (columns)" + + new_partitions = self._partition_mgr_cls.map_axis_partitions( + axis.value, self._partitions, condition, keep_partitioning=True + ) + + new_axes, new_lengths = [0, 0], [0, 0] + + new_axes[axis.value] = ( + self._index_cache if axis.value == 0 else self._columns_cache + ) + new_lengths[axis.value] = ( + self._row_lengths_cache if axis.value == 0 else self._column_widths_cache + ) + new_axes[axis.value ^ 1], new_lengths[axis.value ^ 1] = None, None + + return self.__constructor__( + new_partitions, + *new_axes, + *new_lengths, + self._dtypes if axis == Axis.COL_WISE else None, + ) + + def filter_by_types(self, types: list[Hashable]) -> "PandasDataframe": + """ + Allow the user to specify a type or set of types by which to filter the columns. + + Parameters + ---------- + types : list + The types to filter columns by. + + Returns + ------- + PandasDataframe + A new PandasDataframe from the filter provided. 
+ """ + return self.take_2d_labels_or_positional( + col_positions=[i for i, dtype in enumerate(self.dtypes) if dtype in types] + ) + + @lazy_metadata_decorator(apply_axis="both") + def explode(self, axis: Union[int, Axis], func: Callable) -> "PandasDataframe": + """ + Explode list-like entries along an entire axis. + + Parameters + ---------- + axis : int or modin.core.dataframe.base.utils.Axis + The axis specifying how to explode. If axis=1, explode according + to columns. + func : callable + The function to use to explode a single element. + + Returns + ------- + PandasFrame + A new filtered dataframe. + """ + axis = Axis(axis) + partitions = self._partition_mgr_cls.map_axis_partitions( + axis.value, self._partitions, func, keep_partitioning=True + ) + if axis == Axis.COL_WISE: + new_index, row_lengths = self._compute_axis_labels_and_lengths( + 0, partitions + ) + new_columns, column_widths = self.columns, self._column_widths_cache + else: + new_index, row_lengths = self.index, self._row_lengths_cache + new_columns, column_widths = self._compute_axis_labels_and_lengths( + 1, partitions + ) + return self.__constructor__( + partitions, new_index, new_columns, row_lengths, column_widths + ) + + @lazy_metadata_decorator(apply_axis="both") + def apply_full_axis( + self, + axis, + func, + new_index=None, + new_columns=None, + dtypes=None, + keep_partitioning=True, + sync_labels=True, + pass_axis_lengths_to_partitions=False, + ): + """ + Perform a function across an entire axis. + + Parameters + ---------- + axis : {0, 1} + The axis to apply over (0 - rows, 1 - columns). + func : callable + The function to apply. + new_index : list-like, optional + The index of the result. We may know this in advance, + and if not provided it must be computed. + new_columns : list-like, optional + The columns of the result. We may know this in + advance, and if not provided it must be computed. + dtypes : list-like, optional + The data types of the result. This is an optimization + because there are functions that always result in a particular data + type, and allows us to avoid (re)computing it. + keep_partitioning : boolean, default: True + The flag to keep partition boundaries for Modin Frame. + Setting it to True disables shuffling data from one partition to another. + sync_labels : boolean, default: True + Synchronize external indexes (`new_index`, `new_columns`) with internal indexes. + This could be used when you're certain that the indices in partitions are equal to + the provided hints in order to save time on syncing them. + pass_axis_lengths_to_partitions : bool, default: False + Whether pass partition lengths along `axis ^ 1` to the kernel `func`. + Note that `func` must be able to obtain `df, *axis_lengths`. + + Returns + ------- + PandasDataframe + A new dataframe. + + Notes + ----- + The data shape may change as a result of the function. + """ + return self.broadcast_apply_full_axis( + axis=axis, + func=func, + new_index=new_index, + new_columns=new_columns, + dtypes=dtypes, + other=None, + keep_partitioning=keep_partitioning, + sync_labels=sync_labels, + pass_axis_lengths_to_partitions=pass_axis_lengths_to_partitions, + ) + + @lazy_metadata_decorator(apply_axis="both") + def apply_full_axis_select_indices( + self, + axis, + func, + apply_indices=None, + numeric_indices=None, + new_index=None, + new_columns=None, + keep_remaining=False, + ): + """ + Apply a function across an entire axis for a subset of the data. + + Parameters + ---------- + axis : int + The axis to apply over. 
+ func : callable + The function to apply. + apply_indices : list-like, default: None + The labels to apply over. + numeric_indices : list-like, default: None + The indices to apply over. + new_index : list-like, optional + The index of the result. We may know this in advance, + and if not provided it must be computed. + new_columns : list-like, optional + The columns of the result. We may know this in + advance, and if not provided it must be computed. + keep_remaining : boolean, default: False + Whether or not to drop the data that is not computed over. + + Returns + ------- + PandasDataframe + A new dataframe. + """ + assert apply_indices is not None or numeric_indices is not None + # Convert indices to numeric indices + old_index = self.index if axis else self.columns + if apply_indices is not None: + numeric_indices = old_index.get_indexer_for(apply_indices) + # Get the indices for the axis being applied to (it is the opposite of axis + # being applied over) + dict_indices = self._get_dict_of_block_index(axis ^ 1, numeric_indices) + new_partitions = ( + self._partition_mgr_cls.apply_func_to_select_indices_along_full_axis( + axis, + self._partitions, + func, + dict_indices, + keep_remaining=keep_remaining, + ) + ) + # TODO Infer columns and index from `keep_remaining` and `apply_indices` + if new_index is None: + new_index = self.index if axis == 1 else None + if new_columns is None: + new_columns = self.columns if axis == 0 else None + return self.__constructor__(new_partitions, new_index, new_columns, None, None) + + @lazy_metadata_decorator(apply_axis="both") + def apply_select_indices( + self, + axis, + func, + apply_indices=None, + row_labels=None, + col_labels=None, + new_index=None, + new_columns=None, + keep_remaining=False, + item_to_distribute=no_default, + ): + """ + Apply a function for a subset of the data. + + Parameters + ---------- + axis : {0, 1} + The axis to apply over. + func : callable + The function to apply. + apply_indices : list-like, default: None + The labels to apply over. Must be given if axis is provided. + row_labels : list-like, default: None + The row labels to apply over. Must be provided with + `col_labels` to apply over both axes. + col_labels : list-like, default: None + The column labels to apply over. Must be provided + with `row_labels` to apply over both axes. + new_index : list-like, optional + The index of the result. We may know this in advance, + and if not provided it must be computed. + new_columns : list-like, optional + The columns of the result. We may know this in + advance, and if not provided it must be computed. + keep_remaining : boolean, default: False + Whether or not to drop the data that is not computed over. + item_to_distribute : np.ndarray or scalar, default: no_default + The item to split up so it can be applied over both axes. + + Returns + ------- + PandasDataframe + A new dataframe. 
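# The *_select_indices methods turn label-based apply_indices into positional indices via
# get_indexer_for; a tiny pandas-only view of that conversion:
import pandas as pd

cols = pd.Index(["a", "b", "c", "d"])
print(cols.get_indexer_for(["b", "d"]))  # [1 3]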
+ """ + # TODO Infer columns and index from `keep_remaining` and `apply_indices` + if new_index is None: + new_index = self.index if axis == 1 else None + if new_columns is None: + new_columns = self.columns if axis == 0 else None + if axis is not None: + assert apply_indices is not None + # Convert indices to numeric indices + old_index = self.index if axis else self.columns + numeric_indices = old_index.get_indexer_for(apply_indices) + # Get indices being applied to (opposite of indices being applied over) + dict_indices = self._get_dict_of_block_index(axis ^ 1, numeric_indices) + new_partitions = self._partition_mgr_cls.apply_func_to_select_indices( + axis, + self._partitions, + func, + dict_indices, + keep_remaining=keep_remaining, + ) + # Length objects for new object creation. This is shorter than if..else + # This object determines the lengths and widths based on the given + # parameters and builds a dictionary used in the constructor below. 0 gives + # the row lengths and 1 gives the column widths. Since the dimension of + # `axis` given may have changed, we currently just recompute it. + # TODO Determine lengths from current lengths if `keep_remaining=False` + lengths_objs = { + axis: [len(apply_indices)] + if not keep_remaining + else [self.row_lengths, self.column_widths][axis], + axis ^ 1: [self.row_lengths, self.column_widths][axis ^ 1], + } + return self.__constructor__( + new_partitions, new_index, new_columns, lengths_objs[0], lengths_objs[1] + ) + else: + # We are applying over both axes here, so make sure we have all the right + # variables set. + assert row_labels is not None and col_labels is not None + assert keep_remaining + assert item_to_distribute is not no_default + row_partitions_list = self._get_dict_of_block_index(0, row_labels).items() + col_partitions_list = self._get_dict_of_block_index(1, col_labels).items() + new_partitions = self._partition_mgr_cls.apply_func_to_indices_both_axis( + self._partitions, + func, + row_partitions_list, + col_partitions_list, + item_to_distribute, + # Passing caches instead of values in order to not trigger shapes recomputation + # if they are not used inside this function. + self._row_lengths_cache, + self._column_widths_cache, + ) + return self.__constructor__( + new_partitions, + new_index, + new_columns, + self._row_lengths_cache, + self._column_widths_cache, + ) + + @lazy_metadata_decorator(apply_axis="both") + def broadcast_apply( + self, axis, func, other, join_type="left", labels="keep", dtypes=None + ): + """ + Broadcast axis partitions of `other` to partitions of `self` and apply a function. + + Parameters + ---------- + axis : {0, 1} + Axis to broadcast over. + func : callable + Function to apply. + other : PandasDataframe + Modin DataFrame to broadcast. + join_type : str, default: "left" + Type of join to apply. + labels : {"keep", "replace", "drop"}, default: "keep" + Whether keep labels from `self` Modin DataFrame, replace them with labels + from joined DataFrame or drop altogether to make them be computed lazily later. + dtypes : "copy" or None, default: None + Whether keep old dtypes or infer new dtypes from data. + + Returns + ------- + PandasDataframe + New Modin DataFrame. + """ + # Only sort the indices if they do not match + ( + left_parts, + right_parts, + joined_index, + partition_sizes_along_axis, + ) = self._copartition( + axis, other, join_type, sort=not self.axes[axis].equals(other.axes[axis]) + ) + # unwrap list returned by `copartition`. 
+ right_parts = right_parts[0] + new_frame = self._partition_mgr_cls.broadcast_apply( + axis, func, left_parts, right_parts + ) + if dtypes == "copy": + dtypes = self._dtypes + + def _pick_axis(get_axis, sizes_cache): + if labels == "keep": + return get_axis(), sizes_cache + if labels == "replace": + return joined_index, partition_sizes_along_axis + assert labels == "drop", f"Unexpected `labels`: {labels}" + return None, None + + if axis == 0: + # Pass shape caches instead of values in order to not trigger shape computation. + new_index, new_row_lengths = _pick_axis( + self._get_index, self._row_lengths_cache + ) + new_columns, new_column_widths = self.columns, self._column_widths_cache + else: + new_index, new_row_lengths = self.index, self._row_lengths_cache + new_columns, new_column_widths = _pick_axis( + self._get_columns, self._column_widths_cache + ) + + return self.__constructor__( + new_frame, + new_index, + new_columns, + new_row_lengths, + new_column_widths, + dtypes=dtypes, + ) + + def _prepare_frame_to_broadcast(self, axis, indices, broadcast_all): + """ + Compute the indices to broadcast `self` considering `indices`. + + Parameters + ---------- + axis : {0, 1} + Axis to broadcast along. + indices : dict + Dict of indices and internal indices of partitions where `self` must + be broadcasted. + broadcast_all : bool + Whether broadcast the whole axis of `self` frame or just a subset of it. + + Returns + ------- + dict + Dictionary with indices of partitions to broadcast. + + Notes + ----- + New dictionary of indices of `self` partitions represents that + you want to broadcast `self` at specified another partition named `other`. For example, + Dictionary {key: {key1: [0, 1], key2: [5]}} means, that in `other`[key] you want to + broadcast [self[key1], self[key2]] partitions and internal indices for `self` must be [[0, 1], [5]] + """ + if broadcast_all: + sizes = self.row_lengths if axis else self.column_widths + return {key: dict(enumerate(sizes)) for key in indices.keys()} + passed_len = 0 + result_dict = {} + for part_num, internal in indices.items(): + result_dict[part_num] = self._get_dict_of_block_index( + axis ^ 1, np.arange(passed_len, passed_len + len(internal)) + ) + passed_len += len(internal) + return result_dict + + def __make_init_labels_args(self, partitions, index, columns) -> dict: + kw = {} + kw["index"], kw["row_lengths"] = ( + self._compute_axis_labels_and_lengths(0, partitions) + if index is None + else (index, None) + ) + kw["columns"], kw["column_widths"] = ( + self._compute_axis_labels_and_lengths(1, partitions) + if columns is None + else (columns, None) + ) + return kw + + @lazy_metadata_decorator(apply_axis="both") + def broadcast_apply_select_indices( + self, + axis, + func, + other, + apply_indices=None, + numeric_indices=None, + keep_remaining=False, + broadcast_all=True, + new_index=None, + new_columns=None, + ): + """ + Apply a function to select indices at specified axis and broadcast partitions of `other` Modin DataFrame. + + Parameters + ---------- + axis : {0, 1} + Axis to apply function along. + func : callable + Function to apply. + other : PandasDataframe + Partitions of which should be broadcasted. + apply_indices : list, default: None + List of labels to apply (if `numeric_indices` are not specified). + numeric_indices : list, default: None + Numeric indices to apply (if `apply_indices` are not specified). + keep_remaining : bool, default: False + Whether drop the data that is not computed over or not. 
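# The broadcast_all branch of _prepare_frame_to_broadcast above builds a
# {partition index -> {block index: size}} mapping; a standalone view of that shape:
sizes = [3, 4, 2]              # e.g. column widths of `self`
indices = {0: [0, 1], 2: [5]}  # partitions of `other` to broadcast at
print({key: dict(enumerate(sizes)) for key in indices})
# {0: {0: 3, 1: 4, 2: 2}, 2: {0: 3, 1: 4, 2: 2}}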
+ broadcast_all : bool, default: True + Whether broadcast the whole axis of right frame to every + partition or just a subset of it. + new_index : pandas.Index, optional + Index of the result. We may know this in advance, + and if not provided it must be computed. + new_columns : pandas.Index, optional + Columns of the result. We may know this in advance, + and if not provided it must be computed. + + Returns + ------- + PandasDataframe + New Modin DataFrame. + """ + assert ( + apply_indices is not None or numeric_indices is not None + ), "Indices to apply must be specified!" + + if other is None: + if apply_indices is None: + apply_indices = self.axes[axis][numeric_indices] + return self.apply_select_indices( + axis=axis, + func=func, + apply_indices=apply_indices, + keep_remaining=keep_remaining, + new_index=new_index, + new_columns=new_columns, + ) + + if numeric_indices is None: + old_index = self.index if axis else self.columns + numeric_indices = old_index.get_indexer_for(apply_indices) + + dict_indices = self._get_dict_of_block_index(axis ^ 1, numeric_indices) + broadcasted_dict = other._prepare_frame_to_broadcast( + axis, dict_indices, broadcast_all=broadcast_all + ) + new_partitions = self._partition_mgr_cls.broadcast_apply_select_indices( + axis, + func, + self._partitions, + other._partitions, + dict_indices, + broadcasted_dict, + keep_remaining, + ) + + kw = self.__make_init_labels_args(new_partitions, new_index, new_columns) + return self.__constructor__(new_partitions, **kw) + + @lazy_metadata_decorator(apply_axis="both") + def broadcast_apply_full_axis( + self, + axis, + func, + other, + new_index=None, + new_columns=None, + apply_indices=None, + enumerate_partitions=False, + dtypes=None, + keep_partitioning=True, + sync_labels=True, + pass_axis_lengths_to_partitions=False, + ): + """ + Broadcast partitions of `other` Modin DataFrame and apply a function along full axis. + + Parameters + ---------- + axis : {0, 1} + Axis to apply over (0 - rows, 1 - columns). + func : callable + Function to apply. + other : PandasDataframe or list + Modin DataFrame(s) to broadcast. + new_index : list-like, optional + Index of the result. We may know this in advance, + and if not provided it must be computed. + new_columns : list-like, optional + Columns of the result. We may know this in + advance, and if not provided it must be computed. + apply_indices : list-like, default: None + Indices of `axis ^ 1` to apply function over. + enumerate_partitions : bool, default: False + Whether pass partition index into applied `func` or not. + Note that `func` must be able to obtain `partition_idx` kwarg. + dtypes : list-like, default: None + Data types of the result. This is an optimization + because there are functions that always result in a particular data + type, and allows us to avoid (re)computing it. + keep_partitioning : boolean, default: True + The flag to keep partition boundaries for Modin Frame. + Setting it to True disables shuffling data from one partition to another. + sync_labels : boolean, default: True + Synchronize external indexes (`new_index`, `new_columns`) with internal indexes. + This could be used when you're certain that the indices in partitions are equal to + the provided hints in order to save time on syncing them. + pass_axis_lengths_to_partitions : bool, default: False + Whether pass partition lengths along `axis ^ 1` to the kernel `func`. + Note that `func` must be able to obtain `df, *axis_lengths`. + + Returns + ------- + PandasDataframe + New Modin DataFrame. 
+ """ + if other is not None: + if not isinstance(other, list): + other = [other] + other = [o._partitions for o in other] if len(other) else None + + if apply_indices is not None: + numeric_indices = self.axes[axis ^ 1].get_indexer_for(apply_indices) + apply_indices = self._get_dict_of_block_index( + axis ^ 1, numeric_indices + ).keys() + + apply_func_args = None + if pass_axis_lengths_to_partitions: + if axis == 0: + apply_func_args = ( + self._column_widths_cache + if self._column_widths_cache is not None + else [part.width(materialize=False) for part in self._partitions[0]] + ) + else: + apply_func_args = ( + self._row_lengths_cache + if self._row_lengths_cache is not None + else [ + part.length(materialize=False) for part in self._partitions.T[0] + ] + ) + + new_partitions = self._partition_mgr_cls.broadcast_axis_partitions( + axis=axis, + left=self._partitions, + right=other, + apply_func=self._build_treereduce_func(axis, func), + apply_indices=apply_indices, + enumerate_partitions=enumerate_partitions, + keep_partitioning=keep_partitioning, + apply_func_args=apply_func_args, + ) + kw = {"row_lengths": None, "column_widths": None} + if dtypes == "copy": + kw["dtypes"] = self._dtypes + elif dtypes is not None: + if new_columns is None: + ( + new_columns, + kw["column_widths"], + ) = self._compute_axis_labels_and_lengths(1, new_partitions) + kw["dtypes"] = pandas.Series( + [np.dtype(dtypes)] * len(new_columns), index=new_columns + ) + + if not keep_partitioning: + if kw["row_lengths"] is None and new_index is not None: + if axis == 0: + kw["row_lengths"] = get_length_list( # noqa F821 + axis_len=len(new_index), num_splits=new_partitions.shape[0] + ) + elif ( + axis == 1 + and self._row_lengths_cache is not None + and len(new_index) == sum(self._row_lengths_cache) + ): + kw["row_lengths"] = self._row_lengths_cache + if kw["column_widths"] is None and new_columns is not None: + if axis == 1: + kw["column_widths"] = get_length_list( # noqa F821 + axis_len=len(new_columns), + num_splits=new_partitions.shape[1], + ) + elif ( + axis == 0 + and self._column_widths_cache is not None + and len(new_columns) == sum(self._column_widths_cache) + ): + kw["column_widths"] = self._column_widths_cache + result = self.__constructor__( + new_partitions, index=new_index, columns=new_columns, **kw + ) + if sync_labels and new_index is not None: + result.synchronize_labels(axis=0) + if sync_labels and new_columns is not None: + result.synchronize_labels(axis=1) + return result + + def _copartition(self, axis, other, how, sort, force_repartition=False): + """ + Copartition two Modin DataFrames. + + Perform aligning of partitions, index and partition blocks. + + Parameters + ---------- + axis : {0, 1} + Axis to copartition along (0 - rows, 1 - columns). + other : PandasDataframe + Other Modin DataFrame(s) to copartition against. + how : str + How to manage joining the index object ("left", "right", etc.). + sort : bool + Whether sort the joined index or not. + force_repartition : bool, default: False + Whether force the repartitioning or not. By default, + this method will skip repartitioning if it is possible. This is because + reindexing is extremely inefficient. Because this method is used to + `join` or `append`, it is vital that the internal indices match. 
+ + Returns + ------- + tuple + Tuple containing: + 1) 2-d NumPy array of aligned left partitions + 2) list of 2-d NumPy arrays of aligned right partitions + 3) joined index along ``axis`` + 4) List with sizes of partitions along axis that partitioning + was done on. This list will be empty if and only if all + the frames are empty. + """ + if isinstance(other, type(self)): + other = [other] + + self_index = self.axes[axis] + others_index = [o.axes[axis] for o in other] + joined_index, make_reindexer = self._join_index_objects( + axis, [self_index] + others_index, how, sort + ) + + frames = [self] + other + non_empty_frames_idx = [ + i for i, o in enumerate(frames) if o._partitions.size != 0 + ] + + # If all frames are empty + if len(non_empty_frames_idx) == 0: + return ( + self._partitions, + [o._partitions for o in other], + joined_index, + # There are no partition sizes because the resulting dataframe + # has no partitions. + [], + ) + + base_frame_idx = non_empty_frames_idx[0] + other_frames = frames[base_frame_idx + 1 :] + + # Picking first non-empty frame + base_frame = frames[non_empty_frames_idx[0]] + base_index = base_frame.axes[axis] + + # define conditions for reindexing and repartitioning `self` frame + do_reindex_base = not base_index.equals(joined_index) + do_repartition_base = force_repartition or do_reindex_base + + # Perform repartitioning and reindexing for `base_frame` if needed. + # Also define length of base and frames. We will need to know the + # lengths for alignment. + if do_repartition_base: + reindexed_base = base_frame._partition_mgr_cls.map_axis_partitions( + axis, + base_frame._partitions, + make_reindexer(do_reindex_base, base_frame_idx), + ) + if axis: + base_lengths = [obj.width() for obj in reindexed_base[0]] + else: + base_lengths = [obj.length() for obj in reindexed_base.T[0]] + else: + reindexed_base = base_frame._partitions + base_lengths = base_frame.column_widths if axis else base_frame.row_lengths + + others_lengths = [o._axes_lengths[axis] for o in other_frames] + + # define conditions for reindexing and repartitioning `other` frames + do_reindex_others = [ + not o.axes[axis].equals(joined_index) for o in other_frames + ] + + do_repartition_others = [None] * len(other_frames) + for i in range(len(other_frames)): + do_repartition_others[i] = ( + force_repartition + or do_reindex_others[i] + or others_lengths[i] != base_lengths + ) + + # perform repartitioning and reindexing for `other_frames` if needed + reindexed_other_list = [None] * len(other_frames) + for i in range(len(other_frames)): + if do_repartition_others[i]: + # indices of others frame start from `base_frame_idx` + 1 + reindexed_other_list[i] = other_frames[ + i + ]._partition_mgr_cls.map_axis_partitions( + axis, + other_frames[i]._partitions, + make_reindexer(do_repartition_others[i], base_frame_idx + 1 + i), + lengths=base_lengths, + ) + else: + reindexed_other_list[i] = other_frames[i]._partitions + reindexed_frames = ( + [frames[i]._partitions for i in range(base_frame_idx)] + + [reindexed_base] + + reindexed_other_list + ) + return (reindexed_frames[0], reindexed_frames[1:], joined_index, base_lengths) + + @lazy_metadata_decorator(apply_axis="both") + def n_ary_op( + self, + op, + right_frames: list, + join_type="outer", + copartition_along_columns=True, + dtypes=None, + ): + """ + Perform an n-opary operation by joining with other Modin DataFrame(s). + + Parameters + ---------- + op : callable + Function to apply after the join. 
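# _copartition starts by joining the index objects of all frames with the requested how/sort;
# the join semantics it needs can be illustrated with plain pandas Index.join:
import pandas as pd

left = pd.Index(["a", "b", "c"])
right = pd.Index(["b", "c", "d"])
print(left.join(right, how="outer", sort=True))  # Index(['a', 'b', 'c', 'd'], dtype='object')
print(left.join(right, how="inner", sort=True))  # Index(['b', 'c'], dtype='object')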
+ right_frames : list of PandasDataframe + Modin DataFrames to join with. + join_type : str, default: "outer" + Type of join to apply. + copartition_along_columns : bool, default: True + Whether to perform copartitioning along columns or not. + For some ops this isn't needed (e.g., `fillna`). + dtypes : series, default: None + Dtypes of the resultant dataframe, this argument will be + received if the resultant dtypes of n-opary operation is precomputed. + + Returns + ------- + PandasDataframe + New Modin DataFrame. + """ + left_parts, list_of_right_parts, joined_index, row_lengths = self._copartition( + 0, right_frames, join_type, sort=True + ) + if copartition_along_columns: + new_left_frame = self.__constructor__( + left_parts, joined_index, self.columns, row_lengths, self.column_widths + ) + new_right_frames = [ + self.__constructor__( + right_parts, + joined_index, + right_frame.columns, + row_lengths, + right_frame.column_widths, + ) + for right_parts, right_frame in zip(list_of_right_parts, right_frames) + ] + + ( + left_parts, + list_of_right_parts, + joined_columns, + column_widths, + ) = new_left_frame._copartition( + 1, + new_right_frames, + join_type, + sort=True, + ) + else: + joined_columns = self._columns_cache + column_widths = self._column_widths_cache + + new_frame = ( + np.array([]) + if len(left_parts) == 0 + or any(len(right_parts) == 0 for right_parts in list_of_right_parts) + else self._partition_mgr_cls.n_ary_operation( + left_parts, op, list_of_right_parts + ) + ) + + return self.__constructor__( + new_frame, + joined_index, + joined_columns, + row_lengths, + column_widths, + dtypes, + ) + + @lazy_metadata_decorator(apply_axis="both") + def concat( + self, + axis: Union[int, Axis], + others: Union["PandasDataframe", list["PandasDataframe"]], + how, + sort, + ) -> "PandasDataframe": + """ + Concatenate `self` with one or more other Modin DataFrames. + + Parameters + ---------- + axis : int or modin.core.dataframe.base.utils.Axis + Axis to concatenate over. + others : list + List of Modin DataFrames to concatenate with. + how : str + Type of join to use for the axis. + sort : bool + Whether sort the result or not. + + Returns + ------- + PandasDataframe + New Modin DataFrame. 
+ """ + axis = Axis(axis) + new_widths = None + new_lengths = None + + def _compute_new_widths(): + widths = None + if self._column_widths_cache is not None and all( + o._column_widths_cache is not None for o in others + ): + widths = self._column_widths_cache + [ + width for o in others for width in o._column_widths_cache + ] + return widths + + # Fast path for equivalent columns and partitioning + if ( + axis == Axis.ROW_WISE + and all(o.columns.equals(self.columns) for o in others) + and all(o.column_widths == self.column_widths for o in others) + ): + joined_index = self.columns + left_parts = self._partitions + right_parts = [o._partitions for o in others] + new_widths = self._column_widths_cache + elif ( + axis == Axis.COL_WISE + and all(o.index.equals(self.index) for o in others) + and all(o.row_lengths == self.row_lengths for o in others) + ): + joined_index = self.index + left_parts = self._partitions + right_parts = [o._partitions for o in others] + new_lengths = self._row_lengths_cache + # we can only do this for COL_WISE because `concat` might rebalance partitions for ROW_WISE + new_widths = _compute_new_widths() + else: + ( + left_parts, + right_parts, + joined_index, + partition_sizes_along_axis, + ) = self._copartition( + axis.value ^ 1, others, how, sort, force_repartition=False + ) + if axis == Axis.COL_WISE: + new_lengths = partition_sizes_along_axis + new_widths = _compute_new_widths() + else: + new_widths = partition_sizes_along_axis + new_partitions, new_lengths2 = self._partition_mgr_cls.concat( + axis.value, left_parts, right_parts + ) + if new_lengths is None: + new_lengths = new_lengths2 + new_dtypes = None + if axis == Axis.ROW_WISE: + new_index = self.index.append([other.index for other in others]) + new_columns = joined_index + all_dtypes = [frame._dtypes for frame in [self] + others] + if all(dtypes is not None for dtypes in all_dtypes): + new_dtypes = pandas.concat(all_dtypes, axis=1) + # 'nan' value will be placed in a row if a column doesn't exist in all frames; + # this value is np.float64 type so we need an explicit conversion + new_dtypes.fillna(np.dtype("float64"), inplace=True) + new_dtypes = new_dtypes.apply( + lambda row: find_common_type(row.values), axis=1 + ) + # If we have already cached the length of each row in at least one + # of the row's partitions, we can build new_lengths for the new + # frame. Typically, if we know the length for any partition in a + # row, we know the length for the first partition in the row. So + # just check the lengths of the first column of partitions. + if not new_lengths: + new_lengths = [] + if new_partitions.size > 0: + for part in new_partitions.T[0]: + if part._length_cache is not None: + new_lengths.append(part.length()) + else: + new_lengths = None + break + else: + new_columns = self.columns.append([other.columns for other in others]) + new_index = joined_index + if self._dtypes is not None and all(o._dtypes is not None for o in others): + new_dtypes = pandas.concat([self.dtypes] + [o.dtypes for o in others]) + # If we have already cached the width of each column in at least one + # of the column's partitions, we can build new_widths for the new + # frame. Typically, if we know the width for any partition in a + # column, we know the width for the first partition in the column. + # So just check the widths of the first row of partitions. 
+ if not new_widths: + new_widths = [] + if new_partitions.size > 0: + for part in new_partitions[0]: + if part._width_cache is not None: + new_widths.append(part.width()) + else: + new_widths = None + break + return self.__constructor__( + new_partitions, new_index, new_columns, new_lengths, new_widths, new_dtypes + ) + + def groupby( + self, + axis: Union[int, Axis], + by: Union[str, list[str]], + operator: Callable, + result_schema: Optional[dict[Hashable, type]] = None, + ) -> "PandasDataframe": + """ + Generate groups based on values in the input column(s) and perform the specified operation on each. + + Parameters + ---------- + axis : int or modin.core.dataframe.base.utils.Axis + The axis to apply the grouping over. + by : string or list of strings + One or more column labels to use for grouping. + operator : callable + The operation to carry out on each of the groups. The operator is another + algebraic operator with its own user-defined function parameter, depending + on the output desired by the user. + result_schema : dict, optional + Mapping from column labels to data types that represents the types of the output dataframe. + + Returns + ------- + PandasDataframe + A new PandasDataframe containing the groupings specified, with the operator + applied to each group. + + Notes + ----- + No communication between groups is allowed in this algebra implementation. + + The number of rows (columns if axis=1) returned by the user-defined function + passed to the groupby may be at most the number of rows in the group, and + may be as small as a single row. + + Unlike the pandas API, an intermediate “GROUP BY” object is not present in this + algebra implementation. + """ + pass + + @lazy_metadata_decorator(apply_axis="opposite", axis_arg=0) + def groupby_reduce( + self, + axis, + by, + map_func, + reduce_func, + new_index=None, + new_columns=None, + apply_indices=None, + ): + """ + Groupby another Modin DataFrame dataframe and aggregate the result. + + Parameters + ---------- + axis : {0, 1} + Axis to groupby and aggregate over. + by : PandasDataframe or None + A Modin DataFrame to group by. + map_func : callable + Map component of the aggregation. + reduce_func : callable + Reduce component of the aggregation. + new_index : pandas.Index, optional + Index of the result. We may know this in advance, + and if not provided it must be computed. + new_columns : pandas.Index, optional + Columns of the result. We may know this in advance, + and if not provided it must be computed. + apply_indices : list-like, default: None + Indices of `axis ^ 1` to apply groupby over. + + Returns + ------- + PandasDataframe + New Modin DataFrame. + """ + by_parts = by if by is None else by._partitions + if by is None: + self._propagate_index_objs(axis=0) + + if apply_indices is not None: + numeric_indices = self.axes[axis ^ 1].get_indexer_for(apply_indices) + apply_indices = list( + self._get_dict_of_block_index(axis ^ 1, numeric_indices).keys() + ) + + new_partitions = self._partition_mgr_cls.groupby_reduce( + axis, self._partitions, by_parts, map_func, reduce_func, apply_indices + ) + kw = self.__make_init_labels_args(new_partitions, new_index, new_columns) + return self.__constructor__(new_partitions, **kw) + + @classmethod + def from_pandas(cls, df): + """ + Create a Modin DataFrame from a pandas DataFrame. + + Parameters + ---------- + df : pandas.DataFrame + A pandas DataFrame. + + Returns + ------- + PandasDataframe + New Modin DataFrame. 
+ """ + new_index = df.index + new_columns = df.columns + new_dtypes = df.dtypes + new_frame, new_lengths, new_widths = cls._partition_mgr_cls.from_pandas( + df, True + ) + return cls( + new_frame, + new_index, + new_columns, + new_lengths, + new_widths, + dtypes=new_dtypes, + ) + + @classmethod + def from_arrow(cls, at): + """ + Create a Modin DataFrame from an Arrow Table. + + Parameters + ---------- + at : pyarrow.table + Arrow Table. + + Returns + ------- + PandasDataframe + New Modin DataFrame. + """ + new_frame, new_lengths, new_widths = cls._partition_mgr_cls.from_arrow( + at, return_dims=True + ) + new_columns = Index.__new__(Index, data=at.column_names, dtype="O") + new_index = Index.__new__(RangeIndex, data=range(at.num_rows)) + new_dtypes = pandas.Series( + [cls._arrow_type_to_dtype(col.type) for col in at.columns], + index=at.column_names, + ) + return cls( + partitions=new_frame, + index=new_index, + columns=new_columns, + row_lengths=new_lengths, + column_widths=new_widths, + dtypes=new_dtypes, + ) + + @classmethod + def _arrow_type_to_dtype(cls, arrow_type): + """ + Convert an arrow data type to a pandas data type. + + Parameters + ---------- + arrow_type : arrow dtype + Arrow data type to be converted to a pandas data type. + + Returns + ------- + object + Any dtype compatible with pandas. + """ + import pyarrow + + try: + res = arrow_type.to_pandas_dtype() + # Conversion to pandas is not implemented for some arrow types, + # perform manual conversion for them: + except NotImplementedError: + if pyarrow.types.is_time(arrow_type): + res = np.dtype(datetime.time) + else: + raise + + if not isinstance(res, (np.dtype, str)): + return np.dtype(res) + return res + + @lazy_metadata_decorator(apply_axis="both") + def to_pandas(self): + """ + Convert this Modin DataFrame to a pandas DataFrame. + + Returns + ------- + pandas.DataFrame + """ + df = self._partition_mgr_cls.to_pandas(self._partitions) + if df.empty: + df = pandas.DataFrame(columns=self.columns, index=self.index) + else: + for axis, external_index in enumerate( + [self._index_cache, self._columns_cache] + ): + # no need to check external and internal axes since in that case + # external axes will be computed from internal partitions + if external_index is not None: + ErrorMessage.internal_error( + not df.axes[axis].equals(external_index), + f"Internal and external indices on axis {axis} do not match.", + ) + # have to do this in order to assign some potentially missing metadata, + # the ones that were set to the external index but were never propagated + # into the internal ones + df = df.set_axis(axis=axis, labels=external_index, copy=False) + + return df + + def to_numpy(self, **kwargs): + """ + Convert this Modin DataFrame to a NumPy array. + + Parameters + ---------- + **kwargs : dict + Additional keyword arguments to be passed in `to_numpy`. + + Returns + ------- + np.ndarray + """ + return self._partition_mgr_cls.to_numpy(self._partitions, **kwargs) + + @lazy_metadata_decorator(apply_axis=None, transpose=True) + def transpose(self): + """ + Transpose the index and columns of this Modin DataFrame. + + Reflect this Modin DataFrame over its main diagonal + by writing rows as columns and vice-versa. + + Returns + ------- + PandasDataframe + New Modin DataFrame. 
+ """ + new_partitions = self._partition_mgr_cls.lazy_map_partitions( + self._partitions, lambda df: df.T + ).T + if self._dtypes is not None: + new_dtypes = pandas.Series( + np.full(len(self.index), find_common_type(self.dtypes.values)), + index=self.index, + ) + else: + new_dtypes = None + return self.__constructor__( + new_partitions, + self._columns_cache, + self._index_cache, + self._column_widths_cache, + self._row_lengths_cache, + dtypes=new_dtypes, + ) + + def finalize(self): + """ + Perform all deferred calls on partitions. + + This makes `self` Modin Dataframe independent of a history of queries + that were used to build it. + """ + self._partition_mgr_cls.finalize(self._partitions) + + def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): + """ + Get a Modin DataFrame that implements the dataframe exchange protocol. + + See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. + + Parameters + ---------- + nan_as_null : bool, default: False + A keyword intended for the consumer to tell the producer + to overwrite null values in the data with ``NaN`` (or ``NaT``). + This currently has no effect; once support for nullable extension + dtypes is added, this value should be propagated to columns. + allow_copy : bool, default: True + A keyword that defines whether or not the library is allowed + to make a copy of the data. For example, copying data would be necessary + if a library supports strided buffers, given that this protocol + specifies contiguous buffers. Currently, if the flag is set to ``False`` + and a copy is needed, a ``RuntimeError`` will be raised. + + Returns + ------- + ProtocolDataframe + A dataframe object following the dataframe protocol specification. + """ + from modin.core.dataframe.pandas.interchange.dataframe_protocol.dataframe import ( + PandasProtocolDataframe, + ) + + return PandasProtocolDataframe( + self, nan_as_null=nan_as_null, allow_copy=allow_copy + ) + + @classmethod + def from_dataframe(cls, df: "ProtocolDataframe") -> "PandasDataframe": # noqa F821 + """ + Convert a DataFrame implementing the dataframe exchange protocol to a Core Modin Dataframe. + + See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. + + Parameters + ---------- + df : ProtocolDataframe + The DataFrame object supporting the dataframe exchange protocol. + + Returns + ------- + PandasDataframe + A new Core Modin Dataframe object. + """ + if type(df) == cls: + return df + + if not hasattr(df, "__dataframe__"): + raise ValueError( + "`df` does not support DataFrame exchange protocol, i.e. `__dataframe__` method" + ) + + from modin.core.dataframe.pandas.interchange.dataframe_protocol.from_dataframe import ( + from_dataframe_to_pandas, + ) + + pandas_df = from_dataframe_to_pandas(df) + return cls.from_pandas(pandas_df) diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py new file mode 100644 index 00000000000..d2650e84640 --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/__init__.py @@ -0,0 +1,348 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. 
The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +import sys +import warnings +from typing import Any + +import pandas + +__pandas_version__ = "2.2.1" + + +if sys.version_info.major == 3 and sys.version_info.minor == 8: + raise RuntimeError( + "Snowpark pandas does not support Python 3.8. Please update to Python 3.9 or later, and" + + f" update your pandas version to {__pandas_version__}." + ) # pragma: no cover + +if pandas.__version__ != __pandas_version__: + raise RuntimeError( + f"The pandas version installed ({pandas.__version__}) does not match the supported pandas version in" + + f" Snowpark pandas ({__pandas_version__}). Please update with `pip install pandas=={__pandas_version__}`." + ) # pragma: no cover + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from pandas import describe_option # noqa: F401 + from pandas import get_option # noqa: F401 + from pandas import option_context # noqa: F401 + from pandas import reset_option # noqa: F401 + from pandas import ( # noqa: F401 + NA, + ArrowDtype, + BooleanDtype, + Categorical, + CategoricalDtype, + CategoricalIndex, + DateOffset, + DatetimeIndex, + DatetimeTZDtype, + ExcelWriter, + Flags, + Float32Dtype, + Float64Dtype, + Grouper, + Index, + IndexSlice, + Int8Dtype, + Int16Dtype, + Int32Dtype, + Int64Dtype, + Interval, + IntervalDtype, + IntervalIndex, + MultiIndex, + NamedAgg, + NaT, + Period, + PeriodDtype, + PeriodIndex, + RangeIndex, + SparseDtype, + StringDtype, + Timedelta, + TimedeltaIndex, + Timestamp, + UInt8Dtype, + UInt16Dtype, + UInt32Dtype, + UInt64Dtype, + api, + array, + bdate_range, + eval, + factorize, + from_dummies, + infer_freq, + interval_range, + offsets, + options, + period_range, + set_eng_float_format, + set_option, + test, + timedelta_range, + ) + +# TODO: SNOW-851745 make sure add all Snowpark pandas API general functions +from snowflake.snowpark.modin.pandas.dataframe import DataFrame +from snowflake.snowpark.modin.pandas.general import ( + concat, + crosstab, + cut, + date_range, + get_dummies, + isna, + isnull, + lreshape, + melt, + merge, + merge_asof, + merge_ordered, + notna, + notnull, + pivot, + pivot_table, + qcut, + to_datetime, + to_numeric, + to_timedelta, + unique, + value_counts, + wide_to_long, +) +from snowflake.snowpark.modin.pandas.io import ( + ExcelFile, + HDFStore, + json_normalize, + read_clipboard, + read_csv, + read_excel, + read_feather, + read_fwf, + read_gbq, + read_hdf, + read_html, + read_json, + read_orc, + read_parquet, + read_pickle, + read_sas, + read_spss, + read_sql, + read_sql_query, + read_sql_table, + read_stata, + read_table, + read_xml, + to_pickle, +) +from snowflake.snowpark.modin.pandas.plotting import Plotting as plotting +from snowflake.snowpark.modin.pandas.series import Series +from 
snowflake.snowpark.modin.plugin._internal.session import SnowpandasSessionHolder + +# The extensions assigned to this module +_PD_EXTENSIONS_: dict = {} + +# base needs to be re-exported in order to properly override docstrings for BasePandasDataset +# moving this import higher prevents sphinx from building documentation (??) +from snowflake.snowpark.modin.pandas import base # isort: skip # noqa: E402,F401 + +import snowflake.snowpark.modin.plugin.extensions.pd_extensions as pd_extensions # isort: skip # noqa: E402,F401 +import snowflake.snowpark.modin.plugin.extensions.pd_overrides # isort: skip # noqa: E402,F401 +import snowflake.snowpark.modin.plugin.extensions.dataframe_extensions # isort: skip # noqa: E402,F401 +import snowflake.snowpark.modin.plugin.extensions.dataframe_overrides # isort: skip # noqa: E402,F401 +import snowflake.snowpark.modin.plugin.extensions.series_extensions # isort: skip # noqa: E402,F401 +import snowflake.snowpark.modin.plugin.extensions.series_overrides # isort: skip # noqa: E402,F401 + + +def __getattr__(name: str) -> Any: + """ + Overrides getattr on the module to enable extensions. + Parameters + ---------- + name : str + The name of the attribute being retrieved. + Returns + ------- + Attribute + Returns the extension attribute, if it exists, otherwise returns the attribute + imported in this file. + """ + try: + return _PD_EXTENSIONS_.get(name, globals()[name]) + except KeyError: + raise AttributeError( + f"module 'snowflake.snowpark.modin.pandas' has no attribute '{name}'" + ) + + +__all__ = [ # noqa: F405 + "DataFrame", + "Series", + "read_csv", + "read_parquet", + "read_json", + "read_html", + "read_clipboard", + "read_excel", + "read_hdf", + "read_feather", + "read_stata", + "read_sas", + "read_pickle", + "read_sql", + "read_gbq", + "read_table", + "read_spss", + "read_orc", + "json_normalize", + "concat", + "eval", + "cut", + "factorize", + "test", + "qcut", + "to_datetime", + "get_dummies", + "isna", + "isnull", + "merge", + "pivot_table", + "date_range", + "Index", + "MultiIndex", + "Series", + "bdate_range", + "period_range", + "DatetimeIndex", + "to_timedelta", + "set_eng_float_format", + "options", + "set_option", + "CategoricalIndex", + "Timedelta", + "Timestamp", + "NaT", + "PeriodIndex", + "Categorical", + "__version__", + "melt", + "crosstab", + "plotting", + "Interval", + "UInt8Dtype", + "UInt16Dtype", + "UInt32Dtype", + "UInt64Dtype", + "SparseDtype", + "Int8Dtype", + "Int16Dtype", + "Int32Dtype", + "Int64Dtype", + "CategoricalDtype", + "DatetimeTZDtype", + "IntervalDtype", + "PeriodDtype", + "BooleanDtype", + "StringDtype", + "NA", + "RangeIndex", + "TimedeltaIndex", + "IntervalIndex", + "IndexSlice", + "Grouper", + "array", + "Period", + "show_versions", + "DateOffset", + "timedelta_range", + "infer_freq", + "interval_range", + "ExcelWriter", + "read_fwf", + "read_sql_table", + "read_sql_query", + "ExcelFile", + "to_pickle", + "HDFStore", + "lreshape", + "wide_to_long", + "merge_asof", + "merge_ordered", + "notnull", + "notna", + "pivot", + "to_numeric", + "unique", + "value_counts", + "NamedAgg", + "api", + "read_xml", + "ArrowDtype", + "Flags", + "Float32Dtype", + "Float64Dtype", + "from_dummies", +] + +del pandas + +# Make SnowpandasSessionHolder this module's and modin.pandas's __class__ so that we can make +# "session" a lazy property of the modules. 
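# The module-level __getattr__ above (PEP 562) lets entries registered in _PD_EXTENSIONS_
# shadow the names imported in this file. A stripped-down, standalone sketch of the same
# lookup order (extensions first, then the module's own globals); the names here are
# illustrative, not the plugin's registration API:
_EXTENSIONS = {}
_GLOBALS = {"existing_name": 1}

def module_getattr(name):
    if name in _EXTENSIONS:
        return _EXTENSIONS[name]
    try:
        return _GLOBALS[name]
    except KeyError:
        raise AttributeError(f"module has no attribute {name!r}")

print(module_getattr("existing_name"))   # 1 -- falls back to the module globals
_EXTENSIONS["existing_name"] = 2
print(module_getattr("existing_name"))   # 2 -- the registered extension wins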
+# This implementation follows Python's suggestion here: +# https://docs.python.org/3.12/reference/datamodel.html#customizing-module-attribute-access +sys.modules[__name__].__class__ = SnowpandasSessionHolder +# When docs are generated, modin.pandas is not imported, so do not perform this overwrite +if "modin.pandas" in sys.modules: + sys.modules["modin.pandas"].__class__ = SnowpandasSessionHolder + +_SKIP_TOP_LEVEL_ATTRS = [ + # __version__ and show_versions are exported by __all__, but not currently defined in Snowpark pandas. + "__version__", + "show_versions", + # SNOW-1316523: Snowpark pandas should re-export the native pandas.api submodule, but doing so + # would override register_pd_accessor and similar methods defined in our own modin.pandas.extensions + # module. + "api", +] + +# Manually re-export the members of the pd_extensions namespace, which are not declared in __all__. +_EXTENSION_ATTRS = ["read_snowflake", "to_snowflake", "to_snowpark", "to_pandas"] +# We also need to re-export native_pd.offsets, since modin.pandas doesn't re-export it. +# snowflake.snowpark.pandas.base also needs to be re-exported to make docstring overrides for BasePandasDataset work. +_ADDITIONAL_ATTRS = ["offsets", "base"] + +# This code should eventually be moved into the `snowflake.snowpark.modin.plugin` module instead. +# Currently trying to do so would result in incorrect results because `snowflake.snowpark.modin.pandas` +# import submodules of `snowflake.snowpark.modin.plugin`, so we would encounter errors due to +# partially initialized modules. +import modin.pandas.api.extensions as _ext # noqa: E402 + +# This loop overrides all methods in the `modin.pandas` namespace so users can obtain Snowpark pandas objects from it. +for name in __all__ + _ADDITIONAL_ATTRS: + if name not in _SKIP_TOP_LEVEL_ATTRS: + # instead of using this as a decorator, we can call the function directly + _ext.register_pd_accessor(name)(__getattr__(name)) + +for name in _EXTENSION_ATTRS: + _ext.register_pd_accessor(name)(getattr(pd_extensions, name)) diff --git a/src/snowflake/snowpark/modin/pandas/accessor.py b/src/snowflake/snowpark/modin/pandas/accessor.py new file mode 100644 index 00000000000..ea9946e4209 --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/accessor.py @@ -0,0 +1,201 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +""" +Implement various accessor classes for DataFrame and Series API. + +SparseFrameAccessor implements API of pandas.DataFrame.sparse accessor. 
+ +SparseAccessor implements API of pandas.Series.sparse accessor. + +CachedAccessor implements API of pandas.core.accessor.CachedAccessor +""" + +import pandas +from pandas.core.dtypes.dtypes import SparseDtype + +import snowflake.snowpark.modin.pandas as pd +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.modin.utils import _inherit_docstrings + + +class BaseSparseAccessor: + """ + Base class for various sparse DataFrame accessor classes. + + Parameters + ---------- + data : DataFrame or Series + Object to operate on. + """ + + _validation_msg = "Can only use the '.sparse' accessor with Sparse data." + + def __init__(self, data=None) -> None: + self._parent = data + self._validate(data) + + @classmethod + def _validate(cls, data): + """ + Verify that `data` dtypes are compatible with `pandas.core.arrays.sparse.dtype.SparseDtype`. + + Parameters + ---------- + data : DataFrame + Object to check. + + Raises + ------ + NotImplementedError + Function is implemented in child classes. + """ + ErrorMessage.not_implemented("Implemented by subclasses") # pragma: no cover + + def _default_to_pandas(self, op, *args, **kwargs): + """ + Convert dataset to pandas type and call a pandas sparse.`op` on it. + + Parameters + ---------- + op : str + Name of pandas function. + *args : list + Additional positional arguments to be passed in `op`. + **kwargs : dict + Additional keywords arguments to be passed in `op`. + + Returns + ------- + object + Result of operation. + """ + return self._parent._default_to_pandas( + lambda parent: op(parent.sparse, *args, **kwargs) + ) + + +# Snowpark pandas does not support sparse accessors - remove docstrings to prevent doctests from running +# @_inherit_docstrings(pandas.core.arrays.sparse.accessor.SparseFrameAccessor) +class SparseFrameAccessor(BaseSparseAccessor): + @classmethod + def _validate(cls, data): + """ + Verify that `data` dtypes are compatible with `pandas.core.arrays.sparse.dtype.SparseDtype`. + + Parameters + ---------- + data : DataFrame + Object to check. + + Raises + ------ + AttributeError + If check fails. + """ + dtypes = data.dtypes + if not all(isinstance(t, SparseDtype) for t in dtypes): + raise AttributeError(cls._validation_msg) + + @property + def density(self): + return self._parent._default_to_pandas(pandas.DataFrame.sparse).density + + @classmethod + def from_spmatrix(cls, data, index=None, columns=None): + return pd.DataFrame( + pandas.DataFrame.sparse.from_spmatrix(data, index=index, columns=columns) + ) + + def to_dense(self): + return self._default_to_pandas(pandas.DataFrame.sparse.to_dense) + + def to_coo(self): + return self._default_to_pandas(pandas.DataFrame.sparse.to_coo) + + +# Snowpark pandas does not support sparse accessors - remove docstrings to prevent doctests from running +# @_inherit_docstrings(pandas.core.arrays.sparse.accessor.SparseAccessor) +class SparseAccessor(BaseSparseAccessor): + @classmethod + def _validate(cls, data): + """ + Verify that `data` dtype is compatible with `pandas.core.arrays.sparse.dtype.SparseDtype`. + + Parameters + ---------- + data : Series + Object to check. + + Raises + ------ + AttributeError + If check fails. 
+ """ + if not isinstance(data.dtype, SparseDtype): + raise AttributeError(cls._validation_msg) + + @property + def density(self): + return self._parent._default_to_pandas(pandas.Series.sparse).density + + @property + def fill_value(self): + return self._parent._default_to_pandas(pandas.Series.sparse).fill_value + + @property + def npoints(self): + return self._parent._default_to_pandas(pandas.Series.sparse).npoints + + @property + def sp_values(self): + return self._parent._default_to_pandas(pandas.Series.sparse).sp_values + + @classmethod + def from_coo(cls, A, dense_index=False): + return cls._default_to_pandas( + pandas.Series.sparse.from_coo, A, dense_index=dense_index + ) + + def to_coo(self, row_levels=(0,), column_levels=(1,), sort_labels=False): + return self._default_to_pandas( + pandas.Series.sparse.to_coo, + row_levels=row_levels, + column_levels=column_levels, + sort_labels=sort_labels, + ) + + def to_dense(self): + return self._default_to_pandas(pandas.Series.sparse.to_dense) + + +@_inherit_docstrings(pandas.core.accessor.CachedAccessor) +class CachedAccessor: + def __init__(self, name: str, accessor) -> None: + self._name = name + self._accessor = accessor + + def __get__(self, obj, cls): + if obj is None: + return self._accessor + accessor_obj = self._accessor(obj) + object.__setattr__(obj, self._name, accessor_obj) + return accessor_obj diff --git a/src/snowflake/snowpark/modin/pandas/api/__init__.py b/src/snowflake/snowpark/modin/pandas/api/__init__.py new file mode 100644 index 00000000000..76858c5d9b1 --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/api/__init__.py @@ -0,0 +1,24 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +from snowflake.snowpark.modin.pandas.api import extensions + +__all__ = ["extensions"] diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py new file mode 100644 index 00000000000..d91de294dae --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/api/extensions/__init__.py @@ -0,0 +1,32 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +from .extensions import ( + register_dataframe_accessor, + register_pd_accessor, + register_series_accessor, +) + +__all__ = [ + "register_dataframe_accessor", + "register_series_accessor", + "register_pd_accessor", +] diff --git a/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py new file mode 100644 index 00000000000..0520e6a44f1 --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/api/extensions/extensions.py @@ -0,0 +1,186 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +""" +File containing decorators that allow registering extension APIs on Modin's API layer objects. +In Modin OSS, this file is placed under modin/pandas/api/extensions/extensions.py. However, +since our repository does not use Modin as an external dependency, following this path construction +would cause circular imports. + +plugin/{pd,dataframe,series}_extensions.py must import this module (modin/pandas/extensions.py) +in order to use the decorators defined here. + +Note that telemetry and other decorators must always be placed below the extension decorator: +the extension decorator calls setattr on the relevant object/module, and therefore must be called +last for other decorators to be applied. Furthermore, all DataFrame/Series methods declared in +extensions must have a telemetry decorator, unlike those defined directly on the class, which +have telemetry automatically added by the TelemetryMeta metaclass. +""" + +from types import ModuleType +from typing import TYPE_CHECKING, Any, Union + +if TYPE_CHECKING: + import snowflake.snowpark.modin.pandas as pd + + +def _set_attribute_on_obj( + name: str, + extensions_dict: dict, + obj: Union["pd.DataFrame", "pd.Series", ModuleType], +): + """ + Create a new or override existing attribute on obj. + + Parameters + ---------- + name : str + The name of the attribute to assign to `obj`. 
+ extensions_dict : dict + The dictionary mapping extension name to `new_attr` (assigned below). + obj : DataFrame, Series, or modin.pandas + The object we are assigning the new attribute to. + + Returns + ------- + decorator + Returns the decorator function. + """ + + def decorator(new_attr: Any): + """ + The decorator for a function or class to be assigned to name + + Parameters + ---------- + new_attr : Any + The new attribute to assign to name. + + Returns + ------- + new_attr + Unmodified new_attr is return from the decorator. + """ + extensions_dict[name] = new_attr + setattr(obj, name, new_attr) + return new_attr + + return decorator + + +def register_dataframe_accessor(name: str): + """ + Registers a dataframe attribute with the name provided. + This is a decorator that assigns a new attribute to DataFrame. It can be used + with the following syntax: + ``` + @register_dataframe_accessor("new_method") + def my_new_dataframe_method(*args, **kwargs): + # logic goes here + return + ``` + The new attribute can then be accessed with the name provided: + ``` + df.new_method(*my_args, **my_kwargs) + ``` + Parameters + ---------- + name : str + The name of the attribute to assign to DataFrame. + Returns + ------- + decorator + Returns the decorator function. + """ + import snowflake.snowpark.modin.pandas as pd + + return _set_attribute_on_obj( + name, + pd.dataframe._DATAFRAME_EXTENSIONS_, + pd.dataframe.DataFrame, + ) + + +def register_series_accessor(name: str): + """ + Registers a series attribute with the name provided. + This is a decorator that assigns a new attribute to Series. It can be used + with the following syntax: + ``` + @register_series_accessor("new_method") + def my_new_series_method(*args, **kwargs): + # logic goes here + return + ``` + The new attribute can then be accessed with the name provided: + ``` + s.new_method(*my_args, **my_kwargs) + ``` + Parameters + ---------- + name : str + The name of the attribute to assign to Series. + Returns + ------- + decorator + Returns the decorator function. + """ + import snowflake.snowpark.modin.pandas as pd + + return _set_attribute_on_obj(name, pd.series._SERIES_EXTENSIONS_, pd.series.Series) + + +def register_pd_accessor(name: str): + """ + Registers a pd namespace attribute with the name provided. + + This is a decorator that assigns a new attribute to modin.pandas. It can be used + with the following syntax: + + ``` + @register_pd_accessor("new_function") + def my_new_pd_function(*args, **kwargs): + # logic goes here + return + ``` + + The new attribute can then be accessed with the name provided: + + ``` + import modin.pandas as pd + + pd.new_method(*my_args, **my_kwargs) + ``` + + + Parameters + ---------- + name : str + The name of the attribute to assign to modin.pandas. + + Returns + ------- + decorator + Returns the decorator function. + """ + import snowflake.snowpark.modin.pandas as pd + + return _set_attribute_on_obj(name, pd._PD_EXTENSIONS_, pd) diff --git a/src/snowflake/snowpark/modin/pandas/base.py b/src/snowflake/snowpark/modin/pandas/base.py new file mode 100644 index 00000000000..c06ebeae1b7 --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/base.py @@ -0,0 +1,4184 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. 
The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Implement DataFrame/Series public API as pandas does.""" +from __future__ import annotations + +import pickle as pkl +import re +import warnings +from collections.abc import Hashable, Mapping, Sequence +from typing import Any, Callable, Literal, get_args + +import numpy as np +import numpy.typing as npt +import pandas +import pandas.core.generic +import pandas.core.resample +import pandas.core.window.rolling +from pandas._libs import lib +from pandas._libs.lib import NoDefault, is_bool, no_default +from pandas._typing import ( + AggFuncType, + AnyArrayLike, + Axes, + Axis, + CompressionOptions, + DtypeBackend, + FillnaOptions, + IgnoreRaise, + IndexKeyFunc, + IndexLabel, + Level, + NaPosition, + RandomState, + Scalar, + StorageOptions, + TimedeltaConvertibleTypes, + TimestampConvertibleTypes, +) +from pandas.compat import numpy as numpy_compat +from pandas.core.common import apply_if_callable, count_not_none, pipe +from pandas.core.dtypes.common import ( + is_dict_like, + is_dtype_equal, + is_list_like, + is_numeric_dtype, + is_object_dtype, + pandas_dtype, +) +from pandas.core.dtypes.inference import is_integer +from pandas.core.indexes.api import ensure_index +from pandas.util._validators import ( + validate_ascending, + validate_bool_kwarg, + validate_percentile, +) + +from snowflake.snowpark.modin import pandas as pd +from snowflake.snowpark.modin.pandas.utils import ( + _doc_binary_op, + get_as_shape_compatible_dataframe_or_series, + is_scalar, + raise_if_native_pandas_objects, + validate_and_try_convert_agg_func_arg_func_to_str, +) +from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta +from snowflake.snowpark.modin.plugin._typing import ListLike +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage +from snowflake.snowpark.modin.utils import ( + _inherit_docstrings, + try_cast_to_pandas, + validate_int_kwarg, +) + +# Similar to pandas, sentinel value to use as kwarg in place of None when None has +# special meaning and needs to be distinguished from a user explicitly passing None. 
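+# For example, a signature like `def method(self, arg=sentinel)` (hypothetical) lets the
+# implementation use an identity check, `arg is sentinel`, to distinguish "argument not
+# supplied" from an explicit `arg=None`.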
+sentinel = object() + +# Do not look up certain attributes in columns or index, as they're used for some +# special purposes, like serving remote context +_ATTRS_NO_LOOKUP = { + "____id_pack__", + "__name__", + "_cache", + "_ipython_canary_method_should_not_exist_", + "_ipython_display_", + "_repr_html_", + "_repr_javascript_", + "_repr_jpeg_", + "_repr_json_", + "_repr_latex_", + "_repr_markdown_", + "_repr_mimebundle_", + "_repr_pdf_", + "_repr_png_", + "_repr_svg_", + "__array_struct__", + "__array_interface__", + "_typ", +} + +_DEFAULT_BEHAVIOUR = { + "__init__", + "__class__", + "_get_index", + "_set_index", + "_pandas_class", + "_get_axis_number", + "empty", + "index", + "columns", + "name", + "dtypes", + "dtype", + "groupby", + "_get_name", + "_set_name", + "_default_to_pandas", + "_query_compiler", + "_to_pandas", + "_repartition", + "_build_repr_df", + "_reduce_dimension", + "__repr__", + "__len__", + "__constructor__", + "_create_or_update_from_compiler", + "_update_inplace", + # for persistance support; + # see DataFrame methods docstrings for more + "_inflate_light", + "_inflate_full", + "__reduce__", + "__reduce_ex__", + "_init", +} | _ATTRS_NO_LOOKUP + + +_doc_binary_op_kwargs = {"returns": "BasePandasDataset", "left": "BasePandasDataset"} + + +@_inherit_docstrings( + pandas.DataFrame, + apilink=["pandas.DataFrame", "pandas.Series"], + excluded=[ + pandas.DataFrame.between_time, + pandas.Series.between_time, + pandas.DataFrame.flags, + pandas.Series.flags, + pandas.DataFrame.kurt, + pandas.Series.kurt, + pandas.DataFrame.kurtosis, + pandas.Series.kurtosis, + pandas.DataFrame.rank, + pandas.Series.rank, + pandas.DataFrame.to_csv, + pandas.Series.to_csv, + pandas.DataFrame.sum, + ], +) +class BasePandasDataset(metaclass=TelemetryMeta): + """ + Implement most of the common code that exists in DataFrame/Series. + + Since both objects share the same underlying representation, and the algorithms + are the same, we use this object to define the general behavior of those objects + and then use those objects to define the output type. + + TelemetryMeta is a metaclass that automatically add telemetry decorators to classes/instance methods. + See TelemetryMeta for details. Note: Its subclasses will inherit this metaclass. + """ + + # pandas class that we pretend to be; usually it has the same name as our class + # but lives in "pandas" namespace. + _pandas_class = pandas.core.generic.NDFrame + + @pandas.util.cache_readonly + def _is_dataframe(self) -> bool: + """ + Tell whether this is a dataframe. + + Ideally, other methods of BasePandasDataset shouldn't care whether this + is a dataframe or a series, but sometimes we need to know. This method + is better than hasattr(self, "columns"), which for series will call + self.__getattr__("columns"), which requires materializing the index. + + Returns + ------- + bool : Whether this is a dataframe. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return issubclass(self._pandas_class, pandas.DataFrame) + + def _add_sibling(self, sibling): + """ + Add a DataFrame or Series object to the list of siblings. + + Siblings are objects that share the same query compiler. This function is called + when a shallow copy is made. + + Parameters + ---------- + sibling : BasePandasDataset + Dataset to add to siblings list. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + sibling._siblings = self._siblings + [self] + self._siblings += [sibling] + for sib in self._siblings: + sib._siblings += [sibling] + + def _update_inplace(self, new_query_compiler): + """ + Update the current DataFrame inplace. + + Parameters + ---------- + new_query_compiler : query_compiler + The new QueryCompiler to use to manage the data. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + old_query_compiler = self._query_compiler + self._query_compiler = new_query_compiler + for sib in self._siblings: + sib._query_compiler = new_query_compiler + old_query_compiler.free() + + def _validate_other( + self, + other, + axis, + dtype_check=False, + compare_index=False, + ): + """ + Help to check validity of other in inter-df operations. + + Parameters + ---------- + other : modin.pandas.BasePandasDataset + Another dataset to validate against `self`. + axis : {None, 0, 1} + Specifies axis along which to do validation. When `1` or `None` + is specified, validation is done along `index`, if `0` is specified + validation is done along `columns` of `other` frame. + dtype_check : bool, default: False + Validates that both frames have compatible dtypes. + compare_index : bool, default: False + Compare Index if True. + + Returns + ------- + modin.pandas.BasePandasDataset + Other frame if it is determined to be valid. + + Raises + ------ + ValueError + If `other` is `Series` and its length is different from + length of `self` `axis`. + TypeError + If any validation checks fail. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if isinstance(other, BasePandasDataset): + return other._query_compiler + if not is_list_like(other): + # We skip dtype checking if the other is a scalar. Note that pandas + # is_scalar can be misleading as it is False for almost all objects, + # even when those objects should be treated as scalars. See e.g. + # https://github.com/modin-project/modin/issues/5236. Therefore, we + # detect scalars by checking that `other` is neither a list-like nor + # another BasePandasDataset. + return other + axis = self._get_axis_number(axis) if axis is not None else 1 + result = other + if axis == 0: + if len(other) != len(self._query_compiler.index): + raise ValueError( + f"Unable to coerce to Series, length must be {len(self._query_compiler.index)}: " + + f"given {len(other)}" + ) + else: + if len(other) != len(self._query_compiler.columns): + raise ValueError( + f"Unable to coerce to Series, length must be {len(self._query_compiler.columns)}: " + + f"given {len(other)}" + ) + if hasattr(other, "dtype"): + other_dtypes = [other.dtype] * len(other) + elif is_dict_like(other): + other_dtypes = [ + type(other[label]) + for label in self._query_compiler.get_axis(axis) + # The binary operation is applied for intersection of axis labels + # and dictionary keys. So filtering out extra keys. + if label in other + ] + else: + other_dtypes = [type(x) for x in other] + if compare_index: + if not self.index.equals(other.index): + raise TypeError("Cannot perform operation with non-equal index") + # Do dtype checking. + if dtype_check: + self_dtypes = self._get_dtypes() + if is_dict_like(other): + # The binary operation is applied for the intersection of axis labels + # and dictionary keys. So filtering `self_dtypes` to match the `other` + # dictionary. 
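+                # For example, with self columns ['a', 'b', 'c'] and other = {'a': 1, 'c': 2.5},
+                # only the dtypes of columns 'a' and 'c' are compared against the types of the
+                # corresponding dictionary values.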
+ self_dtypes = [ + dtype + for label, dtype in zip( + self._query_compiler.get_axis(axis), self._get_dtypes() + ) + if label in other + ] + + # TODO(https://github.com/modin-project/modin/issues/5239): + # this spuriously rejects other that is a list including some + # custom type that can be added to self's elements. + if not all( + (is_numeric_dtype(self_dtype) and is_numeric_dtype(other_dtype)) + or (is_object_dtype(self_dtype) and is_object_dtype(other_dtype)) + # Check if dtype is timedelta ("m") or datetime ("M") + or ( + lib.is_np_dtype(self_dtype, "mM") + and lib.is_np_dtype(other_dtype, "mM") + ) + or is_dtype_equal(self_dtype, other_dtype) + for self_dtype, other_dtype in zip(self_dtypes, other_dtypes) + ): + raise TypeError("Cannot do operation with improper dtypes") + return result + + def _validate_function(self, func, on_invalid=None): + """ + Check the validity of the function which is intended to be applied to the frame. + + Parameters + ---------- + func : object + on_invalid : callable(str, cls), optional + Function to call in case invalid `func` is met, `on_invalid` takes an error + message and an exception type as arguments. If not specified raise an + appropriate exception. + **Note:** This parameter is a hack to concord with pandas error types. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + + def error_raiser(msg, exception=Exception): + raise exception(msg) + + if on_invalid is None: + on_invalid = error_raiser + + if isinstance(func, dict): + [self._validate_function(fn, on_invalid) for fn in func.values()] + return + # We also could validate this, but it may be quite expensive for lazy-frames + # if not all(idx in self.axes[axis] for idx in func.keys()): + # error_raiser("Invalid dict keys", KeyError) + + if not is_list_like(func): + func = [func] + + for fn in func: + if isinstance(fn, str): + if not (hasattr(self, fn) or hasattr(np, fn)): + on_invalid( + f"{fn} is not valid function for {type(self)} object.", + AttributeError, + ) + elif not callable(fn): + on_invalid( + f"One of the passed functions has an invalid type: {type(fn)}: {fn}, " + + "only callable or string is acceptable.", + TypeError, + ) + + def _binary_op( + self, + op: str, + other: BasePandasDataset, + axis: Axis, + level: Level | None = None, + fill_value: float | None = None, + **kwargs: Any, + ): + """ + Do binary operation between two datasets. + + Parameters + ---------- + op : str + Name of binary operation. + other : modin.pandas.BasePandasDataset + Second operand of binary operation. + axis: Whether to compare by the index (0 or ‘index’) or columns. (1 or ‘columns’). + level: Broadcast across a level, matching Index values on the passed MultiIndex level. + fill_value: Fill existing missing (NaN) values, and any new element needed for + successful DataFrame alignment, with this value before computation. + If data in both corresponding DataFrame locations is missing the result will be missing. + only arithmetic binary operation has this parameter (e.g., add() has, but eq() doesn't have). + + kwargs can contain the following parameters passed in at the frontend: + func: Only used for `combine` method. Function that takes two series as inputs and + return a Series or a scalar. Used to merge the two dataframes column by columns. + + Returns + ------- + modin.pandas.BasePandasDataset + Result of binary operation. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + raise_if_native_pandas_objects(other) + axis = self._get_axis_number(axis) + squeeze_self = isinstance(self, pd.Series) + + # pandas itself will ignore the axis argument when using Series.. + # Per default, it is set to axis=0. However, for the case of a Series interacting with + # a DataFrame the behavior is axis=1. Manually check here for this case and adjust the axis. + if isinstance(self, pd.Series) and isinstance(other, pd.DataFrame): + axis = 1 + + new_query_compiler = self._query_compiler.binary_op( + op=op, + other=other, + axis=axis, + level=level, + fill_value=fill_value, + squeeze_self=squeeze_self, + **kwargs, + ) + return self._create_or_update_from_compiler(new_query_compiler) + + def _default_to_pandas(self, op, *args, **kwargs): + """ + Convert dataset to pandas type and call a pandas function on it. + + Parameters + ---------- + op : str + Name of pandas function. + *args : list + Additional positional arguments to be passed to `op`. + **kwargs : dict + Additional keywords arguments to be passed to `op`. + + Returns + ------- + object + Result of operation. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + args = try_cast_to_pandas(args) + kwargs = try_cast_to_pandas(kwargs) + pandas_obj = self._to_pandas() + if callable(op): + result = op(pandas_obj, *args, **kwargs) + elif isinstance(op, str): + # The inner `getattr` is ensuring that we are treating this object (whether + # it is a DataFrame, Series, etc.) as a pandas object. The outer `getattr` + # will get the operation (`op`) from the pandas version of the class and run + # it on the object after we have converted it to pandas. + attr = getattr(self._pandas_class, op) + if isinstance(attr, property): + result = getattr(pandas_obj, op) + else: + result = attr(pandas_obj, *args, **kwargs) + else: + ErrorMessage.internal_error( + failure_condition=True, + extra_log=f"{op} is an unsupported operation", + ) + # SparseDataFrames cannot be serialized by arrow and cause problems for Modin. + # For now we will use pandas. + if isinstance(result, type(self)) and not isinstance( + result, (pandas.SparseDataFrame, pandas.SparseSeries) + ): + return self._create_or_update_from_compiler( + result, inplace=kwargs.get("inplace", False) + ) + elif isinstance(result, pandas.DataFrame): + from snowflake.snowpark.modin.pandas import DataFrame + + return DataFrame(result) + elif isinstance(result, pandas.Series): + from snowflake.snowpark.modin.pandas import Series + + return Series(result) + # inplace + elif result is None: + return self._create_or_update_from_compiler( + getattr(pd, type(pandas_obj).__name__)(pandas_obj)._query_compiler, + inplace=True, + ) + else: + try: + if ( + isinstance(result, (list, tuple)) + and len(result) == 2 + and isinstance(result[0], pandas.DataFrame) + ): + # Some operations split the DataFrame into two (e.g. align). We need to wrap + # both of the returned results + if isinstance(result[1], pandas.DataFrame): + second = self.__constructor__(result[1]) + else: + second = result[1] + return self.__constructor__(result[0]), second + else: + return result + except TypeError: + return result + + @classmethod + def _get_axis_number(cls, axis): + """ + Convert axis name or number to axis index. + + Parameters + ---------- + axis : int, str or pandas._libs.lib.NoDefault + Axis name ('index' or 'columns') or number to be converted to axis index. 
+ + Returns + ------- + int + 0 or 1 - axis index in the array of axes stored in the dataframe. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if axis is no_default: + axis = None + + return cls._pandas_class._get_axis_number(axis) if axis is not None else 0 + + @pandas.util.cache_readonly + def __constructor__(self): + """ + Construct DataFrame or Series object depending on self type. + + Returns + ------- + modin.pandas.BasePandasDataset + Constructed object. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return type(self) + + def abs(self): # noqa: RT01, D200 + """ + Return a `BasePandasDataset` with absolute numeric value of each element. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__(query_compiler=self._query_compiler.unary_op("abs")) + + def _to_series_list(self, index: pd.Index) -> list[pd.Series]: + """ + Convert index to a list of series + Args: + index: can be single or multi index + + Returns: + the list of series + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if isinstance(index, pd.MultiIndex): + return [ + pd.Series(index.get_level_values(level)) + for level in range(index.nlevels) + ] + elif isinstance(index, pd.Index): + return [pd.Series(index)] + + def _set_index(self, new_index: Axes) -> None: + """ + Set the index for this DataFrame. + + Parameters + ---------- + new_index : pandas.Index + The new index to set this. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + self._update_inplace( + new_query_compiler=self._query_compiler.set_index( + [ + s._query_compiler + for s in self._to_series_list(ensure_index(new_index)) + ] + ) + ) + + def set_axis( + self, + labels: IndexLabel, + *, + axis: Axis = 0, + copy: bool | NoDefault = no_default, + ): + """ + Assign desired index to given axis. + """ + # Behavior based on copy: + # ----------------------------------- + # - In native pandas, copy determines whether to create a copy of the data (not DataFrame). + # - We cannot emulate the native pandas' copy behavior in Snowpark since a copy of only data + # cannot be created -- you can only copy the whole object (DataFrame/Series). + # + # Snowpark behavior: + # ------------------ + # - copy is kept for compatibility with native pandas but is ignored. The user is warned that copy is unused. + # Warn user that copy does not do anything. + if copy is not no_default: + WarningMessage.single_warning( + message=f"{type(self).__name__}.set_axis 'copy' keyword is unused and is ignored." + ) + if labels is None: + raise TypeError("None is not a valid value for the parameter 'labels'.") + + # Determine whether to update self or a copy and perform update. + obj = self.copy() + setattr(obj, axis, labels) + return obj + + def _get_index(self): + """ + Get the index for this DataFrame. + + Returns + ------- + pandas.Index + The union of all indexes across the partitions. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._query_compiler.index + + index = property(_get_index, _set_index) + + def add( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Return addition of `BasePandasDataset` and `other`, element-wise (binary operator `add`). 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "add", other, axis=axis, level=level, fill_value=fill_value + ) + + def aggregate( + self, func: AggFuncType = None, axis: Axis | None = 0, *args: Any, **kwargs: Any + ): + """ + Aggregate using one or more operations over the specified axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from snowflake.snowpark.modin.pandas import Series + + origin_axis = axis + axis = self._get_axis_number(axis) + + if axis == 1 and isinstance(self, Series): + raise ValueError(f"No axis named {origin_axis} for object type Series") + + if len(self._query_compiler.columns) == 0: + # native pandas raise error with message "no result", here we raise a more readable error. + raise ValueError("No column to aggregate on.") + + func = validate_and_try_convert_agg_func_arg_func_to_str( + agg_func=func, + obj=self, + allow_duplication=False, + axis=axis, + ) + + # This is to stay consistent with pandas result format, when the func is single + # aggregation function in format of callable or str, reduce the result dimension to + # convert dataframe to series, or convert series to scalar. + need_reduce_dimension = ( + (callable(func) or isinstance(func, str)) + # A Series should be returned when a single scalar string/function aggregation function, or a + # dict of scalar string/functions is specified. In all other cases (including if the function + # is a 1-element list), the result is a DataFrame. + # + # The examples below have axis=1, but the same logic is applied for axis=0. + # >>> df = pd.DataFrame({"a": [0, 1], "b": [2, 3]}) + # + # single aggregation: return Series + # >>> df.agg("max", axis=1) + # 0 2 + # 1 3 + # dtype: int64 + # + # list of aggregations: return DF + # >>> df.agg(["max"], axis=1) + # max + # 0 2 + # 1 3 + # + # dict where all aggregations are strings: return Series + # >>> df.agg({1: "max", 0: "min"}, axis=1) + # 1 3 + # 0 0 + # dtype: int64 + # + # dict where one element is a list: return DF + # >>> df.agg({1: "max", 0: ["min"]}, axis=1) + # max min + # 1 3.0 NaN + # 0 NaN 0.0 + or ( + is_dict_like(func) + and all(not is_list_like(value) for value in func.values()) + ) + ) + + # If func is a dict, pandas will not respect kwargs for each aggregation function, and + # we should drop them before passing the to the query compiler. 
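+        # As the native pandas examples below illustrate, kwargs such as `skipna` only take
+        # effect when func is a single scalar aggregation: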
+ # + # >>> native_pd.DataFrame({"a": [0, 1], "b": [np.nan, 0]}).agg("max", skipna=False, axis=1) + # 0 NaN + # 1 1.0 + # dtype: float64 + # >>> native_pd.DataFrame({"a": [0, 1], "b": [np.nan, 0]}).agg(["max"], skipna=False, axis=1) + # max + # 0 0.0 + # 1 1.0 + # >>> pd.DataFrame([[np.nan], [0]]).aggregate("count", skipna=True, axis=0) + # 0 1 + # dtype: int8 + # >>> pd.DataFrame([[np.nan], [0]]).count(skipna=True, axis=0) + # TypeError: got an unexpected keyword argument 'skipna' + if is_dict_like(func): + kwargs.clear() + + result = self.__constructor__( + query_compiler=self._query_compiler.agg( + func=func, + axis=axis, + args=args, + kwargs=kwargs, + ) + ) + + if need_reduce_dimension: + if self._is_dataframe: + result = Series(query_compiler=result._query_compiler) + + if isinstance(result, Series): + # When func is just "quantile" with a scalar q, result has quantile value as name + q = kwargs.get("q", 0.5) + if func == "quantile" and is_scalar(q): + result.name = q + else: + result.name = None + + # handle case for single scalar (same as result._reduce_dimension()) + if isinstance(self, Series): + return result.to_pandas().squeeze() + + return result + + agg = aggregate + + def _string_function(self, func, *args, **kwargs): + """ + Execute a function identified by its string name. + + Parameters + ---------- + func : str + Function name to call on `self`. + *args : list + Positional arguments to pass to func. + **kwargs : dict + Keyword arguments to pass to func. + + Returns + ------- + object + Function result. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + assert isinstance(func, str) + f = getattr(self, func, None) + if f is not None: + if callable(f): + return f(*args, **kwargs) + assert len(args) == 0 + assert len([kwarg for kwarg in kwargs if kwarg != "axis"]) == 0 + return f + f = getattr(np, func, None) + if f is not None: + return self._default_to_pandas("agg", func, *args, **kwargs) + raise ValueError(f"{func} is an unknown string function") + + def _get_dtypes(self): + """ + Get dtypes as list. + + Returns + ------- + list + Either a one-element list that contains `dtype` if object denotes a Series + or a list that contains `dtypes` if object denotes a DataFrame. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if hasattr(self, "dtype"): + return [self.dtype] + else: + return list(self.dtypes) + + def align( + self, + other, + join="outer", + axis=None, + level=None, + copy=None, + fill_value=None, + method=lib.no_default, + limit=lib.no_default, + fill_axis=lib.no_default, + broadcast_axis=lib.no_default, + ): # noqa: PR01, RT01, D200 + """ + Align two objects on their axes with the specified join method. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "align", + other, + join=join, + axis=axis, + level=level, + copy=copy, + fill_value=fill_value, + method=method, + limit=limit, + fill_axis=fill_axis, + broadcast_axis=broadcast_axis, + ) + + def all(self, axis=0, bool_only=None, skipna=True, **kwargs): + """ + Return whether all elements are True, potentially over an axis. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + if axis is not None: + axis = self._get_axis_number(axis) + if bool_only and axis == 0: + if hasattr(self, "dtype"): + ErrorMessage.not_implemented( + "{}.{} does not implement numeric_only.".format( + type(self).__name__, "all" + ) + ) # pragma: no cover + data_for_compute = self[self.columns[self.dtypes == np.bool_]] + return data_for_compute.all( + axis=axis, bool_only=False, skipna=skipna, **kwargs + ) + return self._reduce_dimension( + self._query_compiler.all( + axis=axis, bool_only=bool_only, skipna=skipna, **kwargs + ) + ) + else: + if bool_only: + raise ValueError(f"Axis must be 0 or 1 (got {axis})") + # Reduce to a scalar if axis is None. + result = self._reduce_dimension( + # FIXME: Judging by pandas docs `**kwargs` serves only compatibility + # purpose and does not affect the result, we shouldn't pass them to the query compiler. + self._query_compiler.all( + axis=0, + bool_only=bool_only, + skipna=skipna, + **kwargs, + ) + ) + if isinstance(result, BasePandasDataset): + return result.all( + axis=axis, bool_only=bool_only, skipna=skipna, **kwargs + ) + return result + + def any(self, axis=0, bool_only=None, skipna=True, **kwargs): + """ + Return whether any element is True, potentially over an axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + if axis is not None: + axis = self._get_axis_number(axis) + if bool_only and axis == 0: + if hasattr(self, "dtype"): + ErrorMessage.not_implemented( + "{}.{} does not implement numeric_only.".format( + type(self).__name__, "all" + ) + ) # pragma: no cover + data_for_compute = self[self.columns[self.dtypes == np.bool_]] + return data_for_compute.any( + axis=axis, bool_only=False, skipna=skipna, **kwargs + ) + return self._reduce_dimension( + self._query_compiler.any( + axis=axis, bool_only=bool_only, skipna=skipna, **kwargs + ) + ) + else: + if bool_only: + raise ValueError(f"Axis must be 0 or 1 (got {axis})") + # Reduce to a scalar if axis is None. + result = self._reduce_dimension( + self._query_compiler.any( + axis=0, + bool_only=bool_only, + skipna=skipna, + **kwargs, + ) + ) + if isinstance(result, BasePandasDataset): + return result.any( + axis=axis, bool_only=bool_only, skipna=skipna, **kwargs + ) + return result + + def apply( + self, + func, + axis, + broadcast, + raw, + reduce, + result_type, + convert_dtype, + args, + **kwds, + ): # noqa: PR01, RT01, D200 + """ + Apply a function along an axis of the `BasePandasDataset`. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + + def error_raiser(msg, exception): + """Convert passed exception to the same type as pandas do and raise it.""" + # HACK: to concord with pandas error types by replacing all of the + # TypeErrors to the AssertionErrors + exception = exception if exception is not TypeError else AssertionError + raise exception(msg) + + self._validate_function(func, on_invalid=error_raiser) + axis = self._get_axis_number(axis) + # TODO SNOW-864025: Support str in series.apply and df.apply + if isinstance(func, str): + # if axis != 1 function can be bounded to the Series, which doesn't + # support axis parameter + if axis == 1: + kwds["axis"] = axis + result = self._string_function(func, *args, **kwds) + if isinstance(result, BasePandasDataset): + return result._query_compiler + return result + # TODO SNOW-856682: Support dict in series.apply and df.apply + elif isinstance(func, dict): + if len(self.columns) != len(set(self.columns)): + WarningMessage.mismatch_with_pandas( + operation="apply", + message="Duplicate column names not supported with apply().", + ) # pragma: no cover + query_compiler = self._query_compiler.apply( + func, + axis, + args=args, + raw=raw, + result_type=result_type, + **kwds, + ) + return query_compiler + + def asfreq( + self, freq, method=None, how=None, normalize=False, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Convert time series to specified frequency. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "asfreq", + freq, + method=method, + how=how, + normalize=normalize, + fill_value=fill_value, + ) + + def asof(self, where, subset=None): # noqa: PR01, RT01, D200 + """ + Return the last row(s) without any NaNs before `where`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + scalar = not is_list_like(where) + if isinstance(where, pandas.Index): + # Prevent accidental mutation of original: + where = where.copy() + else: + if scalar: + where = [where] + where = pandas.Index(where) + + if subset is None: + data = self + else: + # Only relevant for DataFrames: + data = self[subset] + no_na_index = data.dropna().index + new_index = pandas.Index([no_na_index.asof(i) for i in where]) + result = self.reindex(new_index) + result.index = where + + if scalar: + # Need to return a Series: + result = result.squeeze() + return result + + def astype( + self, + dtype: str | type | pd.Series | dict[str, type], + copy: bool = True, + errors: Literal["raise", "ignore"] = "raise", + ) -> pd.DataFrame | pd.Series: + """ + Cast a Modin object to a specified dtype `dtype`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # dtype can be a series, a dict, or a scalar. If it's series or scalar, + # convert it to a dict before passing it to the query compiler. + raise_if_native_pandas_objects(dtype) + from snowflake.snowpark.modin.pandas import Series + + if isinstance(dtype, Series): + dtype = dtype.to_pandas() + if not dtype.index.is_unique: + raise ValueError( + "The new Series of types must have a unique index, i.e. " + + "it must be one-to-one mapping from column names to " + + " their new dtypes." + ) + dtype = dtype.to_dict() + # If we got a series or dict originally, dtype is a dict now. Its keys + # must be column names. + if isinstance(dtype, dict): + # Avoid materializing columns. 
The query compiler will handle errors where + # dtype dict includes keys that are not in columns. + col_dtypes = dtype + for col_name in col_dtypes: + if col_name not in self._query_compiler.columns: + raise KeyError( + "Only a column name can be used for the key in a dtype mappings argument. " + f"'{col_name}' not found in columns." + ) + else: + # Assume that the dtype is a scalar. + col_dtypes = {column: dtype for column in self._query_compiler.columns} + + # ensure values are pandas dtypes + col_dtypes = {k: pandas_dtype(v) for k, v in col_dtypes.items()} + new_query_compiler = self._query_compiler.astype(col_dtypes, errors=errors) + return self._create_or_update_from_compiler(new_query_compiler, not copy) + + @property + def at(self, axis=None): # noqa: PR01, RT01, D200 + """ + Get a single value for a row/column label pair. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + from .indexing import _LocIndexer + + return _LocIndexer(self) + + def at_time(self, time, asof=False, axis=None): # noqa: PR01, RT01, D200 + """ + Select values at particular time of day (e.g., 9:30AM). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + axis = self._get_axis_number(axis) + idx = self.index if axis == 0 else self.columns + indexer = pandas.Series(index=idx).at_time(time, asof=asof).index + return self.loc[indexer] if axis == 0 else self.loc[:, indexer] + + @_inherit_docstrings( + pandas.DataFrame.between_time, apilink="pandas.DataFrame.between_time" + ) + def between_time( + self: BasePandasDataset, + start_time, + end_time, + inclusive: str | None = None, + axis=None, + ): # noqa: PR01, RT01, D200 + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + axis = self._get_axis_number(axis) + idx = self.index if axis == 0 else self.columns + indexer = ( + pandas.Series(index=idx) + .between_time( + start_time, + end_time, + inclusive=inclusive, + ) + .index + ) + return self.loc[indexer] if axis == 0 else self.loc[:, indexer] + + def bfill( + self, axis=None, inplace=False, limit=None, downcast=None + ): # noqa: PR01, RT01, D200 + """ + Synonym for `DataFrame.fillna` with ``method='bfill'``. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self.fillna( + method="bfill", axis=axis, limit=limit, downcast=downcast, inplace=inplace + ) + + backfill = bfill + + def bool(self): # noqa: RT01, D200 + """ + Return the bool of a single element `BasePandasDataset`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + shape = self.shape + if shape != (1,) and shape != (1, 1): + raise ValueError( + """The PandasObject does not have exactly + 1 element. Return the bool of a single + element PandasObject. The truth value is + ambiguous. Use a.empty, a.item(), a.any() + or a.all().""" + ) + else: + return self._to_pandas().bool() + + def clip( + self, lower=None, upper=None, axis=None, inplace=False, *args, **kwargs + ): # noqa: PR01, RT01, D200 + """ + Trim values at input threshold(s). 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # validate inputs + ErrorMessage.not_implemented() + if axis is not None: + axis = self._get_axis_number(axis) + self._validate_dtypes(numeric_only=True) + inplace = validate_bool_kwarg(inplace, "inplace") + axis = numpy_compat.function.validate_clip_with_axis(axis, args, kwargs) + # any np.nan bounds are treated as None + if lower is not None and np.any(np.isnan(lower)): + lower = None + if upper is not None and np.any(np.isnan(upper)): + upper = None + if is_list_like(lower) or is_list_like(upper): + if axis is None: + raise ValueError("Must specify axis = 0 or 1") + lower = self._validate_other(lower, axis) + upper = self._validate_other(upper, axis) + # FIXME: Judging by pandas docs `*args` and `**kwargs` serves only compatibility + # purpose and does not affect the result, we shouldn't pass them to the query compiler. + new_query_compiler = self._query_compiler.clip( + lower=lower, upper=upper, axis=axis, inplace=inplace, *args, **kwargs + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + def combine(self, other, func, fill_value=None, **kwargs): # noqa: PR01, RT01, D200 + """ + Perform combination of `BasePandasDataset`-s according to `func`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._binary_op( + "combine", other, axis=0, func=func, fill_value=fill_value, **kwargs + ) + + def combine_first(self, other): # noqa: PR01, RT01, D200 + """ + Update null elements with value in the same location in `other`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._binary_op("combine_first", other, axis=0) + + def copy(self, deep: bool = True): + """ + Make a copy of the object's metadata. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if deep: + return self.__constructor__(query_compiler=self._query_compiler.copy()) + new_obj = self.__constructor__(query_compiler=self._query_compiler) + self._add_sibling(new_obj) + return new_obj + + def count( + self, + axis: Axis | None = 0, + numeric_only: bool = False, + ): + """ + Count non-NA cells for `BasePandasDataset`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._agg_helper( + func="count", + axis=axis, + numeric_only=numeric_only, + ) + + def cummax(self, axis=None, skipna=True, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Return cumulative maximum over a `BasePandasDataset` axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axis = self._get_axis_number(axis) + if axis == 1: + self._validate_dtypes(numeric_only=True) + return self.__constructor__( + # FIXME: Judging by pandas docs `*args` and `**kwargs` serves only compatibility + # purpose and does not affect the result, we shouldn't pass them to the query compiler. + query_compiler=self._query_compiler.cummax( + fold_axis=axis, axis=axis, skipna=skipna, **kwargs + ) + ) + + def cummin(self, axis=None, skipna=True, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Return cumulative minimum over a `BasePandasDataset` axis. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axis = self._get_axis_number(axis) + if axis == 1: + self._validate_dtypes(numeric_only=True) + return self.__constructor__( + # FIXME: Judging by pandas docs `*args` and `**kwargs` serves only compatibility + # purpose and does not affect the result, we shouldn't pass them to the query compiler. + query_compiler=self._query_compiler.cummin( + fold_axis=axis, axis=axis, skipna=skipna, **kwargs + ) + ) + + def cumprod( + self, axis=None, skipna=True, *args, **kwargs + ): # noqa: PR01, RT01, D200 + """ + Return cumulative product over a `BasePandasDataset` axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axis = self._get_axis_number(axis) + self._validate_dtypes(numeric_only=True) + return self.__constructor__( + # FIXME: Judging by pandas docs `**kwargs` serves only compatibility + # purpose and does not affect the result, we shouldn't pass them to the query compiler. + query_compiler=self._query_compiler.cumprod( + fold_axis=axis, axis=axis, skipna=skipna, **kwargs + ) + ) + + def cumsum(self, axis=None, skipna=True, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Return cumulative sum over a `BasePandasDataset` axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axis = self._get_axis_number(axis) + self._validate_dtypes(numeric_only=True) + return self.__constructor__( + # FIXME: Judging by pandas docs `*args` and `**kwargs` serves only compatibility + # purpose and does not affect the result, we shouldn't pass them to the query compiler. + query_compiler=self._query_compiler.cumsum( + fold_axis=axis, axis=axis, skipna=skipna, **kwargs + ) + ) + + def describe( + self, + percentiles: ListLike | None = None, + include: ListLike | Literal["all"] | None = None, + exclude: ListLike | None = None, + ) -> BasePandasDataset: + """ + Generate descriptive statistics. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # Upstream modin uses pandas.core.methods.describe._refine_percentiles for this, + # which is not available in pandas 1.5.X + if percentiles is not None: + # explicit conversion of `percentiles` to list + percentiles = list(percentiles) + + # get them all to be in [0, 1] + validate_percentile(percentiles) + + # median should always be included + if 0.5 not in percentiles: + percentiles.append(0.5) + percentiles = np.asarray(percentiles) + else: + percentiles = np.array([0.25, 0.5, 0.75]) + + data = self + if self._is_dataframe: + # Upstream modin lacks this check because it defaults to pandas for describing empty dataframes + if len(self.columns) == 0: + raise ValueError("Cannot describe a DataFrame without columns") + + # include/exclude are ignored for Series + if (include is None) and (exclude is None): + # when some numerics are found, keep only numerics + default_include: list[npt.DTypeLike] = [np.number] + default_include.append("datetime") + data = self.select_dtypes(include=default_include) + if len(data.columns) == 0: + data = self + elif include == "all": + if exclude is not None: + raise ValueError("exclude must be None when include is 'all'") + data = self + else: + data = self.select_dtypes( + include=include, + exclude=exclude, + ) + # Upstream modin uses data.empty, but that incurs an extra row count query + if self._is_dataframe and len(data.columns) == 0: + # Match pandas error from concatenating empty list of series descriptions. 
+ raise ValueError("No objects to concatenate") + + return self.__constructor__( + query_compiler=data._query_compiler.describe(percentiles=percentiles) + ) + + def diff(self, periods: int = 1, axis: Axis = 0): + """ + First discrete difference of element. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # We must only accept integer (or float values that are whole numbers) + # for periods. + int_periods = validate_int_kwarg(periods, "periods", float_allowed=True) + axis = self._get_axis_number(axis) + return self.__constructor__( + query_compiler=self._query_compiler.diff(axis=axis, periods=int_periods) + ) + + def drop( + self, + labels: IndexLabel = None, + axis: Axis = 0, + index: IndexLabel = None, + columns: IndexLabel = None, + level: Level = None, + inplace: bool = False, + errors: IgnoreRaise = "raise", + ) -> BasePandasDataset | None: + """ + Drop specified labels from `BasePandasDataset`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + inplace = validate_bool_kwarg(inplace, "inplace") + if labels is not None: + if index is not None or columns is not None: + raise ValueError("Cannot specify both 'labels' and 'index'/'columns'") + axes = {self._get_axis_number(axis): labels} + elif index is not None or columns is not None: + axes = {0: index, 1: columns} + else: + raise ValueError( + "Need to specify at least one of 'labels', 'index' or 'columns'" + ) + + for axis, labels in axes.items(): + if labels is not None: + if level is not None and not self._query_compiler.has_multiindex( + axis=axis + ): + # Same error as native pandas. + raise AssertionError("axis must be a MultiIndex") + # According to pandas documentation, a tuple will be used as a single + # label and not treated as a list-like. + if not is_list_like(labels) or isinstance(labels, tuple): + axes[axis] = [labels] + + new_query_compiler = self._query_compiler.drop( + index=axes.get(0), columns=axes.get(1), level=level, errors=errors + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + def _dropna( + self, + axis: Axis = 0, + how: str | NoDefault = no_default, + thresh: int | NoDefault = no_default, + subset: IndexLabel = None, + inplace: bool = False, + ): + inplace = validate_bool_kwarg(inplace, "inplace") + + if is_list_like(axis): + raise TypeError("supplying multiple axes to axis is no longer supported.") + + axis = self._get_axis_number(axis) + + if (how is not no_default) and (thresh is not no_default): + raise TypeError( + "You cannot set both the how and thresh arguments at the same time." + ) + + if how is no_default: + how = "any" + if how not in ["any", "all"]: + raise ValueError("invalid how option: %s" % how) + if subset is not None: + if axis == 1: + indices = self.index.get_indexer_for(subset) + check = indices == -1 + if check.any(): + raise KeyError(list(np.compress(check, subset))) + else: + indices = self.columns.get_indexer_for(subset) + check = indices == -1 + if check.any(): + raise KeyError(list(np.compress(check, subset))) + + new_query_compiler = self._query_compiler.dropna( + axis=axis, + how=how, + thresh=thresh, + subset=subset, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + def droplevel(self, level, axis=0): # noqa: PR01, RT01, D200 + """ + Return `BasePandasDataset` with requested index / column level(s) removed. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + axis = self._get_axis_number(axis) + new_axis = self.axes[axis].droplevel(level) + result = self.copy() + if axis == 0: + result.index = new_axis + else: + result.columns = new_axis + return result + + def drop_duplicates( + self, keep="first", inplace=False, **kwargs + ): # noqa: PR01, RT01, D200 + """ + Return `BasePandasDataset` with duplicate rows removed. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + inplace = validate_bool_kwarg(inplace, "inplace") + ignore_index = kwargs.get("ignore_index", False) + subset = kwargs.get("subset", None) + if subset is not None: + if is_list_like(subset): + if not isinstance(subset, list): + subset = list(subset) + else: + subset = [subset] + df = self[subset] + else: + df = self + duplicated = df.duplicated(keep=keep) + result = self[~duplicated] + if ignore_index: + result.index = pandas.RangeIndex(stop=len(result)) + if inplace: + self._update_inplace(result._query_compiler) + else: + return result + + def mask( + self, + cond: BasePandasDataset | Callable | AnyArrayLike, + other: BasePandasDataset | Callable | Scalar | None = np.nan, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, + ): + """ + Replace values where the condition is True. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # TODO: https://snowflakecomputing.atlassian.net/browse/SNOW-985670 + # will move pre-processing to QC layer. + inplace = validate_bool_kwarg(inplace, "inplace") + if cond is None: + raise ValueError("Array conditional must be same shape as self") + + cond = apply_if_callable(cond, self) + + if isinstance(cond, Callable): + raise NotImplementedError("Do not support callable for 'cond' parameter.") + + from snowflake.snowpark.modin.pandas import Series + + if isinstance(cond, Series): + cond._query_compiler._shape_hint = "column" + if isinstance(self, Series): + self._query_compiler._shape_hint = "column" + if isinstance(other, Series): + other._query_compiler._shape_hint = "column" + + if not isinstance(cond, BasePandasDataset): + cond = get_as_shape_compatible_dataframe_or_series(cond, self) + cond._query_compiler._shape_hint = "array" + + if other is not None: + other = apply_if_callable(other, self) + + if isinstance(other, np.ndarray): + other = get_as_shape_compatible_dataframe_or_series( + other, + self, + shape_mismatch_message="other must be the same shape as self when an ndarray", + ) + other._query_compiler._shape_hint = "array" + + if isinstance(other, BasePandasDataset): + other = other._query_compiler + + query_compiler = self._query_compiler.mask( + cond._query_compiler, + other, + axis, + level, + ) + + return self._create_or_update_from_compiler(query_compiler, inplace) + + def where( + self, + cond: BasePandasDataset | Callable | AnyArrayLike, + other: BasePandasDataset | Callable | Scalar | None = np.nan, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, + ): + """ + Replace values where the condition is False. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # TODO: SNOW-985670: Refactor `where` and `mask` + # will move pre-processing to QC layer. 
+ inplace = validate_bool_kwarg(inplace, "inplace") + if cond is None: + raise ValueError("Array conditional must be same shape as self") + + cond = apply_if_callable(cond, self) + + if isinstance(cond, Callable): + raise NotImplementedError("Do not support callable for 'cond' parameter.") + + from snowflake.snowpark.modin.pandas import Series + + if isinstance(cond, Series): + cond._query_compiler._shape_hint = "column" + if isinstance(self, Series): + self._query_compiler._shape_hint = "column" + if isinstance(other, Series): + other._query_compiler._shape_hint = "column" + + if not isinstance(cond, BasePandasDataset): + cond = get_as_shape_compatible_dataframe_or_series(cond, self) + cond._query_compiler._shape_hint = "array" + + if other is not None: + other = apply_if_callable(other, self) + + if isinstance(other, np.ndarray): + other = get_as_shape_compatible_dataframe_or_series( + other, + self, + shape_mismatch_message="other must be the same shape as self when an ndarray", + ) + other._query_compiler._shape_hint = "array" + + if isinstance(other, BasePandasDataset): + other = other._query_compiler + + query_compiler = self._query_compiler.where( + cond._query_compiler, + other, + axis, + level, + ) + + return self._create_or_update_from_compiler(query_compiler, inplace) + + def eq(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Get equality of `BasePandasDataset` and `other`, element-wise (binary operator `eq`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op("eq", other, axis=axis, level=level, dtypes=np.bool_) + + def explode(self, column, ignore_index: bool = False): # noqa: PR01, RT01, D200 + """ + Transform each element of a list-like to a row. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + exploded = self.__constructor__( + query_compiler=self._query_compiler.explode(column) + ) + if ignore_index: + exploded = exploded.reset_index(drop=True) + return exploded + + def ewm( + self, + com: float | None = None, + span: float | None = None, + halflife: float | TimedeltaConvertibleTypes | None = None, + alpha: float | None = None, + min_periods: int | None = 0, + adjust: bool = True, + ignore_na: bool = False, + axis: Axis = 0, + times: str | np.ndarray | BasePandasDataset | None = None, + method: str = "single", + ) -> pandas.core.window.ewm.ExponentialMovingWindow: # noqa: PR01, RT01, D200 + """ + Provide exponentially weighted (EW) calculations. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "ewm", + com=com, + span=span, + halflife=halflife, + alpha=alpha, + min_periods=min_periods, + adjust=adjust, + ignore_na=ignore_na, + axis=axis, + times=times, + method=method, + ) + + def expanding( + self, min_periods=1, axis=0, method="single" + ): # noqa: PR01, RT01, D200 + """ + Provide expanding window calculations. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "expanding", + min_periods=min_periods, + axis=axis, + method=method, + ) + + def ffill( + self, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + downcast: dict | None = None, + ): + """ + Synonym for `DataFrame.fillna` with ``method='ffill'``. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.fillna( + method="ffill", axis=axis, limit=limit, downcast=downcast, inplace=inplace + ) + + pad = ffill + + def fillna( + self, + self_is_series, + value: Hashable | Mapping | pd.Series | pd.DataFrame = None, + method: FillnaOptions | None = None, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + downcast: dict | None = None, + ): + """ + Fill NA/NaN values using the specified method. + + Parameters + ---------- + self_is_series : bool + If True then self contains a Series object, if False then self contains + a DataFrame object. + value : scalar, dict, Series, or DataFrame, default: None + Value to use to fill holes (e.g. 0), alternately a + dict/Series/DataFrame of values specifying which value to use for + each index (for a Series) or column (for a DataFrame). Values not + in the dict/Series/DataFrame will not be filled. This value cannot + be a list. + method : {'backfill', 'bfill', 'pad', 'ffill', None}, default: None + Method to use for filling holes in reindexed Series + pad / ffill: propagate last valid observation forward to next valid + backfill / bfill: use next valid observation to fill gap. + axis : {None, 0, 1}, default: None + Axis along which to fill missing values. + inplace : bool, default: False + If True, fill in-place. Note: this will modify any + other views on this object (e.g., a no-copy slice for a column in a + DataFrame). + limit : int, default: None + If method is specified, this is the maximum number of consecutive + NaN values to forward/backward fill. In other words, if there is + a gap with more than this number of consecutive NaNs, it will only + be partially filled. If method is not specified, this is the + maximum number of entries along the entire axis where NaNs will be + filled. Must be greater than 0 if not None. + downcast : dict, default: None + A dict of item->dtype of what to downcast if possible, + or the string 'infer' which will try to downcast to an appropriate + equal type (e.g. float64 to int64 if possible). + + Returns + ------- + Series, DataFrame or None + Object with missing values filled or None if ``inplace=True``. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + raise_if_native_pandas_objects(value) + inplace = validate_bool_kwarg(inplace, "inplace") + axis = self._get_axis_number(axis) + if isinstance(value, (list, tuple)): + raise TypeError( + '"value" parameter must be a scalar or dict, but ' + + f'you passed a "{type(value).__name__}"' + ) + if value is None and method is None: + # same as pandas + raise ValueError("Must specify a fill 'value' or 'method'.") + if value is not None and method is not None: + raise ValueError("Cannot specify both 'value' and 'method'.") + if method is not None and method not in ["backfill", "bfill", "pad", "ffill"]: + expecting = "pad (ffill) or backfill (bfill)" + msg = "Invalid fill method. Expecting {expecting}. 
Got {method}".format( + expecting=expecting, method=method + ) + raise ValueError(msg) + if limit is not None: + if not isinstance(limit, int): + raise ValueError("Limit must be an integer") + elif limit <= 0: + raise ValueError("Limit must be greater than 0") + + new_query_compiler = self._query_compiler.fillna( + self_is_series=self_is_series, + value=value, + method=method, + axis=axis, + limit=limit, + downcast=downcast, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + def filter( + self, items=None, like=None, regex=None, axis=None + ): # noqa: PR01, RT01, D200 + """ + Subset the `BasePandasDataset` rows or columns according to the specified index labels. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + nkw = count_not_none(items, like, regex) + if nkw > 1: + raise TypeError( + "Keyword arguments `items`, `like`, or `regex` are mutually exclusive" + ) + if nkw == 0: + raise TypeError("Must pass either `items`, `like`, or `regex`") + if axis is None: + axis = "columns" # This is the default info axis for dataframes + + axis = self._get_axis_number(axis) + labels = self.columns if axis else self.index + + if items is not None: + bool_arr = labels.isin(items) + elif like is not None: + + def f(x): + return like in str(x) + + bool_arr = labels.map(f).tolist() + else: + + def f(x): + return matcher.search(str(x)) is not None + + matcher = re.compile(regex) + bool_arr = labels.map(f).tolist() + if not axis: + return self[bool_arr] + return self[self.columns[bool_arr]] + + def first(self, offset): # noqa: PR01, RT01, D200 + """ + Select initial periods of time series data based on a date offset. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.loc[pandas.Series(index=self.index).first(offset).index] + + def first_valid_index(self) -> Scalar | tuple[Scalar]: + """ + Return index for first non-NA value or None, if no non-NA value is found. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._query_compiler.first_valid_index() + + def floordiv( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get integer division of `BasePandasDataset` and `other`, element-wise (binary operator `floordiv`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "floordiv", other, axis=axis, level=level, fill_value=fill_value + ) + + def ge(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Get greater than or equal comparison of `BasePandasDataset` and `other`, element-wise (binary operator `ge`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op("ge", other, axis=axis, level=level, dtypes=np.bool_) + + def get(self, key, default=None): # noqa: PR01, RT01, D200 + """ + Get item from object for given key. + """ + try: + return self.__getitem__(key) + except (KeyError, ValueError, IndexError): + return default + + def gt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Get greater than comparison of `BasePandasDataset` and `other`, element-wise (binary operator `gt`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op("gt", other, axis=axis, level=level, dtypes=np.bool_) + + def head(self, n: int = 5): + """ + Return the first `n` rows. 
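+
+ This is implemented as ``self.iloc[:n]``, so a negative ``n`` returns all rows
+ except the last ``abs(n)`` rows, as in native pandas.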
+ """ + return self.iloc[:n] + + @property + def iat(self, axis=None): # noqa: PR01, RT01, D200 + """ + Get a single value for a row/column pair by integer position. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + from .indexing import _iLocIndexer + + return _iLocIndexer(self) + + def idxmax(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01, D200 + """ + Return index of first occurrence of maximum over requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + dtypes = self._get_dtypes() + if ( + axis == 1 + and not numeric_only + and any(not is_numeric_dtype(d) for d in dtypes) + and len(set(dtypes)) > 1 + ): + # For numeric_only=False, if we have any non-numeric dtype, e.g. + # a string type, we need every other column to be of the same type. + # We can't compare two objects of different non-numeric types, e.g. + # a string and a timestamp. + # If we have only numeric data, we can compare columns even if they + # different types, e.g. we can compare an int column to a float + # column. + raise TypeError("'>' not supported for these dtypes") + axis = self._get_axis_number(axis) + return self._reduce_dimension( + self._query_compiler.idxmax( + axis=axis, skipna=skipna, numeric_only=numeric_only + ) + ) + + def idxmin(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01, D200 + """ + Return index of first occurrence of minimum over requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + dtypes = self._get_dtypes() + if ( + axis == 1 + and not numeric_only + and any(not is_numeric_dtype(d) for d in dtypes) + and len(set(dtypes)) > 1 + ): + # For numeric_only=False, if we have any non-numeric dtype, e.g. + # a string type, we need every other column to be of the same type. + # We can't compare two objects of different non-numeric types, e.g. + # a string and a timestamp. + # If we have only numeric data, we can compare columns even if they + # different types, e.g. we can compare an int column to a float + # column. + raise TypeError("'<' not supported for these dtypes") + axis = self._get_axis_number(axis) + return self._reduce_dimension( + self._query_compiler.idxmin( + axis=axis, skipna=skipna, numeric_only=numeric_only + ) + ) + + def infer_objects( + self, copy: bool | None = None + ) -> BasePandasDataset: # pragma: no cover # noqa: RT01, D200 + """ + Attempt to infer better dtypes for object columns. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # This method is currently overriden in dataframe_overrides.py and series_overrides.py + # and raises NotImplementedError + new_query_compiler = self._query_compiler.infer_objects() + return self._create_or_update_from_compiler( + new_query_compiler, inplace=False if copy is None else not copy + ) + + def convert_dtypes( + self, + infer_objects: bool = True, + convert_string: bool = True, + convert_integer: bool = True, + convert_boolean: bool = True, + convert_floating: bool = True, + dtype_backend: DtypeBackend = "numpy_nullable", + ): # noqa: PR01, RT01, D200 + """ + Convert columns to best possible dtypes using dtypes supporting ``pd.NA``. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__( + query_compiler=self._query_compiler.convert_dtypes( + infer_objects=infer_objects, + convert_string=convert_string, + convert_integer=convert_integer, + convert_boolean=convert_boolean, + convert_floating=convert_floating, + dtype_backend=dtype_backend, + ) + ) + + def isin( + self, values: BasePandasDataset | ListLike | dict[Hashable, ListLike] + ) -> BasePandasDataset: # noqa: PR01, RT01, D200 + """ + Whether elements in `BasePandasDataset` are contained in `values`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + + # Pass as query compiler if values is BasePandasDataset. + if isinstance(values, BasePandasDataset): + values = values._query_compiler + + # Convert non-dict values to List if values is neither List[Any] nor np.ndarray. SnowflakeQueryCompiler + # expects for the non-lazy case, where values is not a BasePandasDataset, the data to be materialized + # as list or numpy array. Because numpy may perform implicit type conversions, use here list to be more general. + elif not isinstance(values, dict) and ( + not isinstance(values, list) or not isinstance(values, np.ndarray) + ): + values = list(values) + + return self.__constructor__( + query_compiler=self._query_compiler.isin(values=values) + ) + + def isna(self): # noqa: RT01, D200 + """ + Detect missing values. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__(query_compiler=self._query_compiler.isna()) + + isnull = isna + + @property + def iloc(self): + """ + Purely integer-location based indexing for selection by position. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # TODO: SNOW-930028 enable all skipped doctests + from .indexing import _iLocIndexer + + return _iLocIndexer(self) + + def kurt(self, axis=no_default, skipna=True, numeric_only=False, **kwargs): + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + axis = self._get_axis_number(axis) + if numeric_only is not None and not numeric_only: + self._validate_dtypes(numeric_only=True) + + data = ( + self._get_numeric_data(axis) + if numeric_only is None or numeric_only + else self + ) + + return self._reduce_dimension( + data._query_compiler.kurt( + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + ) + + kurtosis = kurt + + def last(self, offset): # noqa: PR01, RT01, D200 + """ + Select final periods of time series data based on a date offset. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.loc[pandas.Series(index=self.index).last(offset).index] + + def last_valid_index(self) -> Scalar | tuple[Scalar]: + """ + Return index for last non-NA value or None, if no non-NA value is found. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._query_compiler.last_valid_index() + + def le(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Get less than or equal comparison of `BasePandasDataset` and `other`, element-wise (binary operator `le`). 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op("le", other, axis=axis, level=level, dtypes=np.bool_) + + def lt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Get less than comparison of `BasePandasDataset` and `other`, element-wise (binary operator `lt`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op("lt", other, axis=axis, level=level, dtypes=np.bool_) + + @property + def loc(self): + """ + Get a group of rows and columns by label(s) or a boolean array. + """ + # TODO: SNOW-935444 fix doctest where index key has name + # TODO: SNOW-933782 fix multiindex transpose bug, e.g., Name: (cobra, mark ii) => Name: ('cobra', 'mark ii') + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from .indexing import _LocIndexer + + return _LocIndexer(self) + + def _agg_helper( + self, + func: str, + skipna: bool = True, + axis: int | None | NoDefault = no_default, + numeric_only: bool = False, + **kwargs: Any, + ): + if not self._is_dataframe and numeric_only and not is_numeric_dtype(self.dtype): + # Series aggregations on non-numeric data do not support numeric_only: + # https://github.com/pandas-dev/pandas/blob/cece8c6579854f6b39b143e22c11cac56502c4fd/pandas/core/series.py#L6358 + raise TypeError( + f"Series.{func} does not allow numeric_only=True with non-numeric dtypes." + ) + axis = self._get_axis_number(axis) + numeric_only = validate_bool_kwarg( + numeric_only, "numeric_only", none_allowed=True + ) + skipna = validate_bool_kwarg(skipna, "skipna", none_allowed=False) + agg_kwargs: dict[str, Any] = { + "numeric_only": numeric_only, + "skipna": skipna, + } + agg_kwargs.update(kwargs) + return self.aggregate(func=func, axis=axis, **agg_kwargs) + + def max( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs: Any, + ): + """ + Return the maximum of the values over the requested axis. + """ + return self._agg_helper( + func="max", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + def _stat_operation( + self, + op_name: str, + axis: int | str, + skipna: bool, + numeric_only: bool = False, + **kwargs, + ): + """ + Do common statistic reduce operations under frame. + + Parameters + ---------- + op_name : str + Name of method to apply. + axis : int or str + Axis to apply method on. + skipna : bool + Exclude NA/null values when computing the result. + numeric_only : bool + Include only float, int, boolean columns. + **kwargs : dict + Additional keyword arguments to pass to `op_name`. 
+ + Returns + ------- + scalar or Series + `scalar` - self is Series + `Series` - self is DataFrame + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axis = self._get_axis_number(axis) + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + if not numeric_only: + self._validate_dtypes(numeric_only=True) + + data = self._get_numeric_data(axis) if numeric_only else self + result_qc = getattr(data._query_compiler, op_name)( + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + result_qc = self._reduce_dimension(result_qc) + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # This pattern is seen throughout this file so we should try to correct it + # when we have a more general way of resetting the name to None + from snowflake.snowpark.modin.pandas import Series + + if isinstance(result_qc, Series): + result_qc.name = None + return result_qc + + def memory_usage(self, index=True, deep=False): # noqa: PR01, RT01, D200 + """ + Return the memory usage of the `BasePandasDataset`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._reduce_dimension( + self._query_compiler.memory_usage(index=index, deep=deep) + ) + + def min( + self, + axis: Axis | None | NoDefault = no_default, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ): + """ + Return the minimum of the values over the requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._agg_helper( + func="min", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + def mod( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get modulo of `BasePandasDataset` and `other`, element-wise (binary operator `mod`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "mod", other, axis=axis, level=level, fill_value=fill_value + ) + + def mode(self, axis=0, numeric_only=False, dropna=True): # noqa: PR01, RT01, D200 + """ + Get the mode(s) of each element along the selected axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + axis = self._get_axis_number(axis) + return self.__constructor__( + query_compiler=self._query_compiler.mode( + axis=axis, numeric_only=numeric_only, dropna=dropna + ) + ) + + def mul( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get multiplication of `BasePandasDataset` and `other`, element-wise (binary operator `mul`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "mul", other, axis=axis, level=level, fill_value=fill_value + ) + + multiply = mul + + def ne(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Get Not equal comparison of `BasePandasDataset` and `other`, element-wise (binary operator `ne`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op("ne", other, axis=axis, level=level, dtypes=np.bool_) + + def notna(self): # noqa: RT01, D200 + """ + Detect existing (non-missing) values. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__(query_compiler=self._query_compiler.notna()) + + notnull = notna + + def nunique(self, axis=0, dropna=True): # noqa: PR01, RT01, D200 + """ + Return number of unique elements in the `BasePandasDataset`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from snowflake.snowpark.modin.pandas import Series + + axis = self._get_axis_number(axis) + result = self._reduce_dimension( + self._query_compiler.nunique(axis=axis, dropna=dropna) + ) + if isinstance(result, Series): + result.name = None + return result + + def pct_change( + self, periods=1, fill_method="pad", limit=None, freq=None, **kwargs + ): # noqa: PR01, RT01, D200 + """ + Percentage change between the current and a prior element. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "pct_change", + periods=periods, + fill_method=fill_method, + limit=limit, + freq=freq, + **kwargs, + ) + + def pipe(self, func, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Apply chainable functions that expect `BasePandasDataset`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return pipe(self, func, *args, **kwargs) + + def pop(self, item): # noqa: PR01, RT01, D200 + """ + Return item and drop from frame. Raise KeyError if not found. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + result = self[item] + del self[item] + return result + + def pow( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get exponential power of `BasePandasDataset` and `other`, element-wise (binary operator `pow`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "pow", other, axis=axis, level=level, fill_value=fill_value + ) + + def quantile( + self, + q: Scalar | ListLike = 0.5, + axis: Axis = 0, + numeric_only: bool = False, + interpolation: Literal[ + "linear", "lower", "higher", "midpoint", "nearest" + ] = "linear", + method: Literal["single", "table"] = "single", + ) -> float | BasePandasDataset: + """ + Return values at the given quantile over requested axis. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axis = self._get_axis_number(axis) + + # TODO + # - SNOW-1008361: support axis=1 + # - SNOW-1008367: support when q is Snowpandas DF/Series (need to require QC interface to accept QC q values) + # - SNOW-1003587: support datetime/timedelta columns + + if ( + axis == 1 + or interpolation not in ["linear", "nearest"] + or method != "single" + ): + ErrorMessage.not_implemented( + f"quantile function with parameters axis={axis}, interpolation={interpolation}, method={method} not supported" + ) + + if not numeric_only: + # If not numeric_only and columns, then check all columns are either + # numeric, timestamp, or timedelta + # Check if dtype is numeric, timedelta ("m"), or datetime ("M") + if not axis and not all( + is_numeric_dtype(t) or lib.is_np_dtype(t, "mM") + for t in self._get_dtypes() + ): + raise TypeError("can't multiply sequence by non-int of type 'float'") + # If over rows, then make sure that all dtypes are equal for not + # numeric_only + elif axis: + for i in range(1, len(self._get_dtypes())): + pre_dtype = self._get_dtypes()[i - 1] + curr_dtype = self._get_dtypes()[i] + if not is_dtype_equal(pre_dtype, curr_dtype): + raise TypeError( + "Cannot compare type '{}' with type '{}'".format( + pre_dtype, curr_dtype + ) + ) + else: + # Normally pandas returns this near the end of the quantile, but we + # can't afford the overhead of running the entire operation before + # we error. + if not any(is_numeric_dtype(t) for t in self._get_dtypes()): + raise ValueError("need at least one array to concatenate") + + # check that all qs are between 0 and 1 + validate_percentile(q) + axis = self._get_axis_number(axis) + query_compiler = self._query_compiler.quantiles_along_axis0( + q=q if is_list_like(q) else [q], + numeric_only=numeric_only, + interpolation=interpolation, + method=method, + ) + if is_list_like(q): + return self.__constructor__(query_compiler=query_compiler) + else: + # result is either a scalar or Series + result = self._reduce_dimension(query_compiler.transpose_single_row()) + if isinstance(result, BasePandasDataset): + result.name = q + return result + + @_inherit_docstrings(pandas.DataFrame.rank, apilink="pandas.DataFrame.rank") + def rank( + self, + axis=0, + method: str = "average", + numeric_only: bool = False, + na_option: str = "keep", + ascending: bool = True, + pct: bool = False, + ): + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axis = self._get_axis_number(axis) + return self.__constructor__( + query_compiler=self._query_compiler.rank( + axis=axis, + method=method, + numeric_only=numeric_only, + na_option=na_option, + ascending=ascending, + pct=pct, + ) + ) + + def _copy_index_metadata(self, source, destination): # noqa: PR01, RT01, D200 + """ + Copy Index metadata from `source` to `destination` inplace. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if hasattr(source, "name") and hasattr(destination, "name"): + destination.name = source.name + if hasattr(source, "names") and hasattr(destination, "names"): + destination.names = source.names + return destination + + def _ensure_index(self, index_like, axis=0): # noqa: PR01, RT01, D200 + """ + Ensure that we have an index from some index-like object. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if ( + self._query_compiler.has_multiindex(axis=axis) + and not isinstance(index_like, pandas.Index) + and is_list_like(index_like) + and len(index_like) > 0 + and isinstance(index_like[0], tuple) + ): + try: + return pandas.MultiIndex.from_tuples(index_like) + except TypeError: + # not all tuples + pass + return ensure_index(index_like) + + def reindex( + self, + index=None, + columns=None, + copy=True, + **kwargs, + ): # noqa: PR01, RT01, D200 + """ + Conform `BasePandasDataset` to new index with optional filling logic. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() # pragma: no cover + + new_query_compiler = None + if index is not None: + if not isinstance(index, pandas.Index) or not index.equals(self.index): + new_query_compiler = self._query_compiler.reindex( + axis=0, labels=index, **kwargs + ) + if new_query_compiler is None: + new_query_compiler = self._query_compiler + final_query_compiler = None + if columns is not None: + if not isinstance(index, pandas.Index) or not columns.equals(self.columns): + final_query_compiler = new_query_compiler.reindex( + axis=1, labels=columns, **kwargs + ) + if final_query_compiler is None: + final_query_compiler = new_query_compiler + return self._create_or_update_from_compiler( + final_query_compiler, inplace=False if copy is None else not copy + ) + + def reindex_like( + self, other, method=None, copy=True, limit=None, tolerance=None + ): # noqa: PR01, RT01, D200 + """ + Return an object with matching indices as `other` object. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "reindex_like", + other, + method=method, + copy=copy, + limit=limit, + tolerance=tolerance, + ) + + def rename_axis( + self, + mapper=lib.no_default, + *, + index=lib.no_default, + columns=lib.no_default, + axis=0, + copy=None, + inplace=False, + ): # noqa: PR01, RT01, D200 + """ + Set the name of the axis for the index or columns. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axes = {"index": index, "columns": columns} + + if copy is None: + copy = True + + if axis is not None: + axis = self._get_axis_number(axis) + else: + axis = 0 + + inplace = validate_bool_kwarg(inplace, "inplace") + + if mapper is not lib.no_default and mapper is not None: + # Use v0.23 behavior if a scalar or list + non_mapper = is_scalar(mapper) or ( + is_list_like(mapper) and not is_dict_like(mapper) + ) + if non_mapper: + return self._set_axis_name(mapper, axis=axis, inplace=inplace) + else: + raise ValueError("Use `.rename` to alter labels with a mapper.") + else: + # Use new behavior. 
Means that index and/or columns is specified + result = self if inplace else self.copy(deep=copy) + + for axis in range(self.ndim): + v = axes.get(pandas.DataFrame._get_axis_name(axis)) + if v is lib.no_default: + continue + non_mapper = is_scalar(v) or (is_list_like(v) and not is_dict_like(v)) + if non_mapper: + newnames = v + else: + + def _get_rename_function(mapper): + if isinstance(mapper, (dict, BasePandasDataset)): + + def f(x): + if x in mapper: + return mapper[x] + else: + return x + + else: + f = mapper + + return f + + f = _get_rename_function(v) + curnames = self.index.names if axis == 0 else self.columns.names + newnames = [f(name) for name in curnames] + result._set_axis_name(newnames, axis=axis, inplace=True) + if not inplace: + return result + + def reorder_levels(self, order, axis=0): # noqa: PR01, RT01, D200 + """ + Rearrange index levels using input order. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + axis = self._get_axis_number(axis) + new_labels = self.axes[axis].reorder_levels(order) + return self.set_axis(new_labels, axis=axis) + + def resample( + self, + rule, + axis: Axis = lib.no_default, + closed: str | None = None, + label: str | None = None, + convention: str = "start", + kind: str | None = None, + on: Level = None, + level: Level = None, + origin: str | TimestampConvertibleTypes = "start_day", + offset: TimedeltaConvertibleTypes | None = None, + group_keys=no_default, + ): # noqa: PR01, RT01, D200 + """ + Resample time-series data. + """ + from .resample import Resampler + + if axis is not lib.no_default: # pragma: no cover + axis = self._get_axis_number(axis) + if axis == 1: + warnings.warn( + "DataFrame.resample with axis=1 is deprecated. Do " + + "`frame.T.resample(...)` without axis instead.", + FutureWarning, + stacklevel=1, + ) + else: + warnings.warn( + f"The 'axis' keyword in {type(self).__name__}.resample is " + + "deprecated and will be removed in a future version.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + return Resampler( + dataframe=self, + rule=rule, + axis=axis, + closed=closed, + label=label, + convention=convention, + kind=kind, + on=on, + level=level, + origin=origin, + offset=offset, + group_keys=group_keys, + ) + + def reset_index( + self, + level: IndexLabel = None, + drop: bool = False, + inplace: bool = False, + col_level: Hashable = 0, + col_fill: Hashable = "", + allow_duplicates=no_default, + names: Hashable | Sequence[Hashable] = None, + ): + """ + Reset the index, or a level of it. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + inplace = validate_bool_kwarg(inplace, "inplace") + if allow_duplicates is no_default: + allow_duplicates = False + new_query_compiler = self._query_compiler.reset_index( + drop=drop, + level=level, + col_level=col_level, + col_fill=col_fill, + allow_duplicates=allow_duplicates, + names=names, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + def radd( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Return addition of `BasePandasDataset` and `other`, element-wise (binary operator `radd`). 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "radd", other, axis=axis, level=level, fill_value=fill_value + ) + + def rfloordiv( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get integer division of `BasePandasDataset` and `other`, element-wise (binary operator `rfloordiv`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "rfloordiv", other, axis=axis, level=level, fill_value=fill_value + ) + + def rmod( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get modulo of `BasePandasDataset` and `other`, element-wise (binary operator `rmod`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "rmod", other, axis=axis, level=level, fill_value=fill_value + ) + + def rmul( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get Multiplication of dataframe and other, element-wise (binary operator `rmul`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "rmul", other, axis=axis, level=level, fill_value=fill_value + ) + + def rolling( + self, + window, + min_periods: int | None = None, + center: bool = False, + win_type: str | None = None, + on: str | None = None, + axis: Axis = lib.no_default, + closed: str | None = None, + step: int | None = None, + method: str = "single", + ): # noqa: PR01, RT01, D200 + """ + Provide rolling window calculations. + """ + if axis is not lib.no_default: + axis = self._get_axis_number(axis) + name = "rolling" + if axis == 1: + warnings.warn( + f"Support for axis=1 in {type(self).__name__}.{name} is " + + "deprecated and will be removed in a future version. " + + f"Use obj.T.{name}(...) instead", + FutureWarning, + stacklevel=1, + ) + else: # pragma: no cover + warnings.warn( + f"The 'axis' keyword in {type(self).__name__}.{name} is " + + "deprecated and will be removed in a future version. " + + "Call the method without the axis keyword instead.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + if win_type is not None: + from .window import Window + + return Window( + self, + window=window, + min_periods=min_periods, + center=center, + win_type=win_type, + on=on, + axis=axis, + closed=closed, + step=step, + method=method, + ) + from .window import Rolling + + return Rolling( + self, + window=window, + min_periods=min_periods, + center=center, + win_type=win_type, + on=on, + axis=axis, + closed=closed, + step=step, + method=method, + ) + + def round(self, decimals=0, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Round a `BasePandasDataset` to a variable number of decimal places. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # FIXME: Judging by pandas docs `*args` and `**kwargs` serves only compatibility + # purpose and does not affect the result, we shouldn't pass them to the query compiler. + return self.__constructor__( + query_compiler=self._query_compiler.round(decimals=decimals, **kwargs) + ) + + def rpow( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get exponential power of `BasePandasDataset` and `other`, element-wise (binary operator `rpow`). 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "rpow", other, axis=axis, level=level, fill_value=fill_value + ) + + def rsub( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get subtraction of `BasePandasDataset` and `other`, element-wise (binary operator `rsub`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "rsub", other, axis=axis, level=level, fill_value=fill_value + ) + + def rtruediv( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get floating division of `BasePandasDataset` and `other`, element-wise (binary operator `rtruediv`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "rtruediv", other, axis=axis, level=level, fill_value=fill_value + ) + + rdiv = rtruediv + + def sample( + self, + n: int | None = None, + frac: float | None = None, + replace: bool = False, + weights: str | np.ndarray | None = None, + random_state: RandomState | None = None, + axis: Axis | None = None, + ignore_index: bool = False, + ): + """ + Return a random sample of items from an axis of object. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if self._get_axis_number(axis): + if weights is not None and isinstance(weights, str): + raise ValueError( + "Strings can only be passed to weights when sampling from rows on a DataFrame" + ) + else: + if n is None and frac is None: + n = 1 + elif n is not None and frac is not None: + raise ValueError("Please enter a value for `frac` OR `n`, not both") + else: + if n is not None: + if n < 0: + raise ValueError( + "A negative number of rows requested. Please provide `n` >= 0." + ) + if n % 1 != 0: + raise ValueError("Only integers accepted as `n` values") + else: + if frac < 0: + raise ValueError( + "A negative number of rows requested. Please provide `frac` >= 0." + ) + + query_compiler = self._query_compiler.sample( + n, frac, replace, weights, random_state, axis, ignore_index + ) + return self.__constructor__(query_compiler=query_compiler) + + def sem( + self, + axis: Axis | None = None, + skipna: bool = True, + ddof: int = 1, + numeric_only=False, + **kwargs, + ): # noqa: PR01, RT01, D200 + """ + Return unbiased standard error of the mean over requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._stat_operation( + "sem", axis, skipna, numeric_only, ddof=ddof, **kwargs + ) + + def mean( + self, + axis: Axis | None | NoDefault = no_default, + skipna: bool = True, + numeric_only: bool = False, + **kwargs: Any, + ): + """ + Return the mean of the values over the requested axis. + """ + return self._agg_helper( + func="mean", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + def median( + self, + axis: Axis | None | NoDefault = no_default, + skipna: bool = True, + numeric_only: bool = False, + **kwargs: Any, + ): + """ + Return the mean of the values over the requested axis. + """ + return self._agg_helper( + func="median", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + def set_flags( + self, *, copy: bool = False, allows_duplicate_labels: bool | None = None + ): # noqa: PR01, RT01, D200 + """ + Return a new `BasePandasDataset` with updated flags. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.DataFrame.set_flags, + copy=copy, + allows_duplicate_labels=allows_duplicate_labels, + ) + + @property + def flags(self): + return self._default_to_pandas(lambda df: df.flags) + + def shift( + self, + periods: int = 1, + freq=None, + axis: Axis = 0, + fill_value: Hashable = no_default, + ) -> BasePandasDataset: + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if periods == 0 and freq is None: + # Check obvious case first, freq manipulates the index even for periods == 0 so check for it in addition. + return self.copy() + + # pandas compatible ValueError for freq='infer' + # TODO: Test as part of SNOW-1023324. + if freq == "infer": # pragma: no cover + if not hasattr(self, "freq") and not hasattr( # pragma: no cover + self, "inferred_freq" # pragma: no cover + ): # pragma: no cover + raise ValueError() # pragma: no cover + + axis = self._get_axis_number(axis) + + if fill_value == no_default: + fill_value = None + + new_query_compiler = self._query_compiler.shift(periods, freq, axis, fill_value) + return self._create_or_update_from_compiler(new_query_compiler, False) + + def skew( + self, + axis: Axis | None | NoDefault = no_default, + skipna: bool = True, + numeric_only=True, + **kwargs, + ): # noqa: PR01, RT01, D200 + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + """ + Return unbiased skew over requested axis. + """ + return self._stat_operation("skew", axis, skipna, numeric_only, **kwargs) + + def sort_index( + self, + axis=0, + level=None, + ascending=True, + inplace=False, + kind="quicksort", + na_position="last", + sort_remaining=True, + ignore_index: bool = False, + key: IndexKeyFunc | None = None, + ): # noqa: PR01, RT01, D200 + """ + Sort object by labels (along an axis). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # pandas throws this exception. See pandas issue #39434 + if ascending is None: + raise ValueError( + "the `axis` parameter is not supported in the pandas implementation of argsort()" + ) + axis = self._get_axis_number(axis) + inplace = validate_bool_kwarg(inplace, "inplace") + new_query_compiler = self._query_compiler.sort_index( + axis=axis, + level=level, + ascending=ascending, + inplace=inplace, + kind=kind, + na_position=na_position, + sort_remaining=sort_remaining, + ignore_index=ignore_index, + key=key, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + def sort_values( + self, + by, + axis=0, + ascending=True, + inplace: bool = False, + kind="quicksort", + na_position="last", + ignore_index: bool = False, + key: IndexKeyFunc | None = None, + ): # noqa: PR01, RT01, D200 + """ + Sort by the values along either axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axis = self._get_axis_number(axis) + inplace = validate_bool_kwarg(inplace, "inplace") + ascending = validate_ascending(ascending) + if axis == 0: + # If any column is None raise KeyError (same a native pandas). + if by is None or (isinstance(by, list) and None in by): + # Same error message as native pandas. + raise KeyError(None) + if not isinstance(by, list): + by = [by] + + # Convert 'ascending' to sequence if needed. + if not isinstance(ascending, Sequence): + ascending = [ascending] * len(by) + if len(by) != len(ascending): + # Same error message as native pandas. 
+ raise ValueError( + f"Length of ascending ({len(ascending)})" + f" != length of by ({len(by)})" + ) + + columns = self._query_compiler.columns.values.tolist() + index_names = self._query_compiler.get_index_names() + for by_col in by: + col_count = columns.count(by_col) + index_count = index_names.count(by_col) + if col_count == 0 and index_count == 0: + # Same error message as native pandas. + raise KeyError(by_col) + if col_count and index_count: + # Same error message as native pandas. + raise ValueError( + f"'{by_col}' is both an index level and a column label, which is ambiguous." + ) + if col_count > 1: + # Same error message as native pandas. + raise ValueError(f"The column label '{by_col}' is not unique.") + + if na_position not in get_args(NaPosition): + # Same error message as native pandas for invalid 'na_position' value. + raise ValueError(f"invalid na_position: {na_position}") + result = self._query_compiler.sort_rows_by_column_values( + by, + ascending=ascending, + kind=kind, + na_position=na_position, + ignore_index=ignore_index, + key=key, + ) + else: + result = self._query_compiler.sort_columns_by_row_values( + by, + ascending=ascending, + kind=kind, + na_position=na_position, + ignore_index=ignore_index, + key=key, + ) + return self._create_or_update_from_compiler(result, inplace) + + def std( + self, + axis: Axis | None = None, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, + ): + """ + Return sample standard deviation over requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + kwargs.update({"ddof": ddof}) + return self._agg_helper( + func="std", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + def sum( + self, + axis: Axis | None = None, + skipna: bool = True, + numeric_only: bool = False, + min_count: int = 0, + **kwargs: Any, + ): + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + min_count = validate_int_kwarg(min_count, "min_count") + kwargs.update({"min_count": min_count}) + return self._agg_helper( + func="sum", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + def sub( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get subtraction of `BasePandasDataset` and `other`, element-wise (binary operator `sub`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "sub", other, axis=axis, level=level, fill_value=fill_value + ) + + subtract = sub + + def swapaxes(self, axis1, axis2, copy=True): # noqa: PR01, RT01, D200 + """ + Interchange axes and swap values axes appropriately. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + axis1 = self._get_axis_number(axis1) + axis2 = self._get_axis_number(axis2) + if axis1 != axis2: + return self.transpose() + if copy: + return self.copy() + return self + + def swaplevel(self, i=-2, j=-1, axis=0): # noqa: PR01, RT01, D200 + """ + Swap levels `i` and `j` in a `MultiIndex`. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + axis = self._get_axis_number(axis) + idx = self.index if axis == 0 else self.columns + return self.set_axis(idx.swaplevel(i, j), axis=axis) + + def tail(self, n: int = 5): + if n == 0: + return self.iloc[0:0] + return self.iloc[-n:] + + def take( + self, + indices: list | AnyArrayLike | slice, + axis: Axis = 0, + **kwargs, + ): + """ + Return the elements in the given *positional* indices along an axis. + """ + axis = self._get_axis_number(axis) + slice_obj = indices if axis == 0 else (slice(None), indices) + return self.iloc[slice_obj] + + def to_clipboard( + self, excel=True, sep=None, **kwargs + ): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Copy object to the system clipboard. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas("to_clipboard", excel=excel, sep=sep, **kwargs) + + def to_csv( + self, + path_or_buf=None, + sep=",", + na_rep="", + float_format=None, + columns=None, + header=True, + index=True, + index_label=None, + mode="w", + encoding=None, + compression="infer", + quoting=None, + quotechar='"', + lineterminator=None, + chunksize=None, + date_format=None, + doublequote=True, + escapechar=None, + decimal=".", + errors: str = "strict", + storage_options: StorageOptions = None, + ): # pragma: no cover + ErrorMessage.not_implemented() + from snowflake.snowpark.modin.pandas.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return FactoryDispatcher.to_csv( + self._query_compiler, + path_or_buf=path_or_buf, + sep=sep, + na_rep=na_rep, + float_format=float_format, + columns=columns, + header=header, + index=index, + index_label=index_label, + mode=mode, + encoding=encoding, + compression=compression, + quoting=quoting, + quotechar=quotechar, + lineterminator=lineterminator, + chunksize=chunksize, + date_format=date_format, + doublequote=doublequote, + escapechar=escapechar, + decimal=decimal, + errors=errors, + storage_options=storage_options, + ) + + def to_excel( + self, + excel_writer, + sheet_name="Sheet1", + na_rep="", + float_format=None, + columns=None, + header=True, + index=True, + index_label=None, + startrow=0, + startcol=0, + engine=None, + merge_cells=True, + encoding=no_default, + inf_rep="inf", + verbose=no_default, + freeze_panes=None, + storage_options: StorageOptions = None, + ): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Write object to an Excel sheet. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "to_excel", + excel_writer, + sheet_name=sheet_name, + na_rep=na_rep, + float_format=float_format, + columns=columns, + header=header, + index=index, + index_label=index_label, + startrow=startrow, + startcol=startcol, + engine=engine, + merge_cells=merge_cells, + inf_rep=inf_rep, + freeze_panes=freeze_panes, + storage_options=storage_options, + ) + + def to_hdf( + self, path_or_buf, key, format="table", **kwargs + ): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Write the contained data to an HDF5 file using HDFStore. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "to_hdf", path_or_buf, key, format=format, **kwargs + ) + + def to_json( + self, + path_or_buf=None, + orient=None, + date_format=None, + double_precision=10, + force_ascii=True, + date_unit="ms", + default_handler=None, + lines=False, + compression="infer", + index=True, + indent=None, + storage_options: StorageOptions = None, + ): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Convert the object to a JSON string. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "to_json", + path_or_buf, + orient=orient, + date_format=date_format, + double_precision=double_precision, + force_ascii=force_ascii, + date_unit=date_unit, + default_handler=default_handler, + lines=lines, + compression=compression, + index=index, + indent=indent, + storage_options=storage_options, + ) + + def to_latex( + self, + buf=None, + columns=None, + col_space=None, + header=True, + index=True, + na_rep="NaN", + formatters=None, + float_format=None, + sparsify=None, + index_names=True, + bold_rows=False, + column_format=None, + longtable=None, + escape=None, + encoding=None, + decimal=".", + multicolumn=None, + multicolumn_format=None, + multirow=None, + caption=None, + label=None, + position=None, + ): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Render object to a LaTeX tabular, longtable, or nested table. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "to_latex", + buf=buf, + columns=columns, + col_space=col_space, + header=header, + index=index, + na_rep=na_rep, + formatters=formatters, + float_format=float_format, + sparsify=sparsify, + index_names=index_names, + bold_rows=bold_rows, + column_format=column_format, + longtable=longtable, + escape=escape, + encoding=encoding, + decimal=decimal, + multicolumn=multicolumn, + multicolumn_format=multicolumn_format, + multirow=multirow, + caption=caption, + label=label, + position=position, + ) + + def to_markdown( + self, + buf=None, + mode: str = "wt", + index: bool = True, + storage_options: StorageOptions = None, + **kwargs, + ): # noqa: PR01, RT01, D200 + """ + Print `BasePandasDataset` in Markdown-friendly format. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "to_markdown", + buf=buf, + mode=mode, + index=index, + storage_options=storage_options, + **kwargs, + ) + + def to_pickle( + self, + path, + compression: CompressionOptions = "infer", + protocol: int = pkl.HIGHEST_PROTOCOL, + storage_options: StorageOptions = None, + ): # pragma: no cover # noqa: PR01, D200 + """ + Pickle (serialize) object to file. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + from snowflake.snowpark.modin.pandas import to_pickle + + to_pickle( + self, + path, + compression=compression, + protocol=protocol, + storage_options=storage_options, + ) + + def to_numpy( + self, + dtype: npt.DTypeLike | None = None, + copy: bool = False, + na_value: object = no_default, + **kwargs: Any, + ) -> np.ndarray: + """ + Convert the `BasePandasDataset` to a NumPy array or a Modin wrapper for NumPy array. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if copy: + WarningMessage.ignored_argument( + operation="to_numpy", + argument="copy", + message="copy is ignored in Snowflake backend", + ) + return self._query_compiler.to_numpy( + dtype=dtype, + na_value=na_value, + **kwargs, + ) + + # TODO(williamma12): When this gets implemented, have the series one call this. + def to_period( + self, freq=None, axis=0, copy=True + ): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Convert `BasePandasDataset` from DatetimeIndex to PeriodIndex. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._default_to_pandas("to_period", freq=freq, axis=axis, copy=copy) + + def to_string( + self, + buf=None, + columns=None, + col_space=None, + header=True, + index=True, + na_rep="NaN", + formatters=None, + float_format=None, + sparsify=None, + index_names=True, + justify=None, + max_rows=None, + min_rows=None, + max_cols=None, + show_dimensions=False, + decimal=".", + line_width=None, + max_colwidth=None, + encoding=None, + ): # noqa: PR01, RT01, D200 + """ + Render a `BasePandasDataset` to a console-friendly tabular output. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas( + "to_string", + buf=buf, + columns=columns, + col_space=col_space, + header=header, + index=index, + na_rep=na_rep, + formatters=formatters, + float_format=float_format, + sparsify=sparsify, + index_names=index_names, + justify=justify, + max_rows=max_rows, + max_cols=max_cols, + show_dimensions=show_dimensions, + decimal=decimal, + line_width=line_width, + max_colwidth=max_colwidth, + encoding=encoding, + ) + + def to_sql( + self, + name, + con, + schema=None, + if_exists="fail", + index=True, + index_label=None, + chunksize=None, + dtype=None, + method=None, + ): # noqa: PR01, D200 + """ + Write records stored in a `BasePandasDataset` to a SQL database. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + new_query_compiler = self._query_compiler + # writing the index to the database by inserting it to the DF + if index: + if not index_label: + index_label = "index" + new_query_compiler = new_query_compiler.insert(0, index_label, self.index) + # so pandas._to_sql will not write the index to the database as well + index = False + + from modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + FactoryDispatcher.to_sql( + new_query_compiler, + name=name, + con=con, + schema=schema, + if_exists=if_exists, + index=index, + index_label=index_label, + chunksize=chunksize, + dtype=dtype, + method=method, + ) + + # TODO(williamma12): When this gets implemented, have the series one call this. + def to_timestamp( + self, freq=None, how="start", axis=0, copy=True + ): # noqa: PR01, RT01, D200 + """ + Cast to DatetimeIndex of timestamps, at *beginning* of period. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() # pragma: no cover + + return self._default_to_pandas( + "to_timestamp", freq=freq, how=how, axis=axis, copy=copy + ) + + def to_xarray(self): # noqa: PR01, RT01, D200 + """ + Return an xarray object from the `BasePandasDataset`. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + return self._default_to_pandas("to_xarray") + + def truediv( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get floating division of `BasePandasDataset` and `other`, element-wise (binary operator `truediv`). + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op( + "truediv", other, axis=axis, level=level, fill_value=fill_value + ) + + div = divide = truediv + + def truncate( + self, before=None, after=None, axis=None, copy=True + ): # noqa: PR01, RT01, D200 + """ + Truncate a `BasePandasDataset` before and after some index value. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + axis = self._get_axis_number(axis) + if ( + not self.axes[axis].is_monotonic_increasing + and not self.axes[axis].is_monotonic_decreasing + ): + raise ValueError("truncate requires a sorted index") + s = slice(*self.axes[axis].slice_locs(before, after)) + slice_obj = s if axis == 0 else (slice(None), s) + return self.iloc[slice_obj] + + def transform(self, func, axis=0, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Call ``func`` on self producing a `BasePandasDataset` with the same axis shape as self. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + kwargs["is_transform"] = True + self._validate_function(func) + try: + result = self.agg(func, axis=axis, *args, **kwargs) + except TypeError: + raise + except Exception as err: + raise ValueError("Transform function failed") from err + try: + assert len(result) == len(self) + except Exception: + raise ValueError("transforms cannot produce aggregated results") + return result + + def tz_convert(self, tz, axis=0, level=None, copy=True): # noqa: PR01, RT01, D200 + """ + Convert tz-aware axis to target time zone. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + axis = self._get_axis_number(axis) + if level is not None: + new_labels = ( + pandas.Series(index=self.axes[axis]).tz_convert(tz, level=level).index + ) + else: + new_labels = self.axes[axis].tz_convert(tz) + obj = self.copy() if copy else self + return obj.set_axis(new_labels, axis, copy=copy) + + def tz_localize( + self, tz, axis=0, level=None, copy=True, ambiguous="raise", nonexistent="raise" + ): # noqa: PR01, RT01, D200 + """ + Localize tz-naive index of a `BasePandasDataset` to target time zone. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() + axis = self._get_axis_number(axis) + new_labels = ( + pandas.Series(index=self.axes[axis]) + .tz_localize( + tz, + axis=axis, + level=level, + copy=False, + ambiguous=ambiguous, + nonexistent=nonexistent, + ) + .index + ) + return self.set_axis(new_labels, axis, copy=copy) + + def var( + self, + axis: Axis | None = None, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs: Any, + ): + """ + Return unbiased variance over requested axis. + """ + kwargs.update({"ddof": ddof}) + return self._agg_helper( + func="var", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + def __abs__(self): + """ + Return a `BasePandasDataset` with absolute numeric value of each element. 
+ + Returns + ------- + BasePandasDataset + Object containing the absolute value of each element. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.abs() + + @_doc_binary_op( + operation="union", bin_op="and", right="other", **_doc_binary_op_kwargs + ) + def __and__(self, other): + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op("__and__", other, axis=0) + + @_doc_binary_op( + operation="union", bin_op="rand", right="other", **_doc_binary_op_kwargs + ) + def __rand__(self, other): + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op("__rand__", other, axis=0) + + def __array__(self, dtype=None): + """ + Return the values as a NumPy array. + + Parameters + ---------- + dtype : str or np.dtype, optional + The dtype of returned array. + + Returns + ------- + arr : np.ndarray + NumPy representation of Modin object. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + arr = self.to_numpy(dtype) + return arr + + def __array_wrap__(self, result, context=None): + """ + Get called after a ufunc and other functions. + + Parameters + ---------- + result : np.ndarray + The result of the ufunc or other function called on the NumPy array + returned by __array__. + context : tuple of (func, tuple, int), optional + This parameter is returned by ufuncs as a 3-element tuple: (name of the + ufunc, arguments of the ufunc, domain of the ufunc), but is not set by + other NumPy functions. + + Returns + ------- + BasePandasDataset + Wrapped Modin object. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() # pragma: no cover + + # TODO: This is very inefficient. __array__ and as_matrix have been + # changed to call the more efficient to_numpy, but this has been left + # unchanged since we are not sure of its purpose. + return self._default_to_pandas("__array_wrap__", result, context=context) + + def __copy__(self, deep=True): + """ + Return the copy of the `BasePandasDataset`. + + Parameters + ---------- + deep : bool, default: True + Whether the copy should be deep or not. + + Returns + ------- + BasePandasDataset + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.copy(deep=deep) + + def __deepcopy__(self, memo=None): + """ + Return the deep copy of the `BasePandasDataset`. + + Parameters + ---------- + memo : Any, optional + Deprecated parameter. + + Returns + ------- + BasePandasDataset + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.copy(deep=True) + + @_doc_binary_op( + operation="equality comparison", + bin_op="eq", + right="other", + **_doc_binary_op_kwargs, + ) + def __eq__(self, other): + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.eq(other) + + def __finalize__(self, other, method=None, **kwargs): + """ + Propagate metadata from `other` to `self`. + + Parameters + ---------- + other : BasePandasDataset + The object from which to get the attributes that we are going + to propagate. + method : str, optional + A passed method name providing context on where `__finalize__` + was called. + **kwargs : dict + Additional keywords arguments to be passed to `__finalize__`. 
+ + Returns + ------- + BasePandasDataset + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() # pragma: no cover + + return self._default_to_pandas("__finalize__", other, method=method, **kwargs) + + @_doc_binary_op( + operation="greater than or equal comparison", + bin_op="ge", + right="right", + **_doc_binary_op_kwargs, + ) + def __ge__(self, right): + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.ge(right) + + def __getitem__(self, key): + """ + Retrieve dataset according to `key`. + + Parameters + ---------- + key : callable, scalar, slice, str or tuple + The global row index to retrieve data from. + + Returns + ------- + BasePandasDataset + Located dataset. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + key = apply_if_callable(key, self) + # If a slice is passed in, use .iloc[key]. + if isinstance(key, slice): + if (is_integer(key.start) or key.start is None) and ( + is_integer(key.stop) or key.stop is None + ): + return self.iloc[key] + else: + return self.loc[key] + + # If the object calling getitem is a Series, only use .loc[key] to filter index. + if isinstance(self, pd.Series): + return self.loc[key] + + # If the object is a boolean list-like object, use .loc[key] to filter index. + # The if statement is structured this way to avoid calling dtype and reduce query count. + if isinstance(key, pd.Series): + if key.dtype == bool: + return self.loc[key] + elif is_list_like(key): + if hasattr(key, "dtype"): + if key.dtype == bool: + return self.loc[key] + if (all(is_bool(k) for k in key)) and len(key) > 0: + return self.loc[key] + + # In all other cases, use .loc[:, key] to filter columns. + return self.loc[:, key] + + __hash__ = None + + @_doc_binary_op( + operation="greater than comparison", + bin_op="gt", + right="right", + **_doc_binary_op_kwargs, + ) + def __gt__(self, right): + return self.gt(right) + + def __invert__(self): + """ + Apply bitwise inverse to each element of the `BasePandasDataset`. + + Returns + ------- + BasePandasDataset + New BasePandasDataset containing bitwise inverse to each value. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__(query_compiler=self._query_compiler.invert()) + + @_doc_binary_op( + operation="less than or equal comparison", + bin_op="le", + right="right", + **_doc_binary_op_kwargs, + ) + def __le__(self, right): + return self.le(right) + + def __len__(self) -> int: + """ + Return length of info axis. + + Returns + ------- + int + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._query_compiler.get_axis_len(axis=0) + + @_doc_binary_op( + operation="less than comparison", + bin_op="lt", + right="right", + **_doc_binary_op_kwargs, + ) + def __lt__(self, right): + return self.lt(right) + + def __matmul__(self, other): + """ + Compute the matrix multiplication between the `BasePandasDataset` and `other`. + + Parameters + ---------- + other : BasePandasDataset or array-like + The other object to compute the matrix product with. 
+ + Returns + ------- + BasePandasDataset, np.ndarray or scalar + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.dot(other) + + @_doc_binary_op( + operation="not equal comparison", + bin_op="ne", + right="other", + **_doc_binary_op_kwargs, + ) + def __ne__(self, other): + return self.ne(other) + + def __neg__(self): + """ + Change the sign for every value of self. + + Returns + ------- + BasePandasDataset + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__( + query_compiler=self._query_compiler.unary_op("__neg__") + ) + + def __nonzero__(self): + """ + Evaluate `BasePandasDataset` as boolean object. + + Raises + ------ + ValueError + Always since truth value for self is ambiguous. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + raise ValueError( + f"The truth value of a {self.__class__.__name__} is ambiguous. " + + "Use a.empty, a.bool(), a.item(), a.any() or a.all()." + ) + + __bool__ = __nonzero__ + + @_doc_binary_op( + operation="disjunction", + bin_op="or", + right="other", + **_doc_binary_op_kwargs, + ) + def __or__(self, other): + return self._binary_op("__or__", other, axis=0) + + @_doc_binary_op( + operation="disjunction", + bin_op="ror", + right="other", + **_doc_binary_op_kwargs, + ) + def __ror__(self, other): + return self._binary_op("__ror__", other, axis=0) + + def __sizeof__(self): + """ + Generate the total memory usage for an `BasePandasDataset`. + + Returns + ------- + int + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + ErrorMessage.not_implemented() # pragma: no cover + + return self._default_to_pandas("__sizeof__") + + def __str__(self): # pragma: no cover + """ + Return str(self). + + Returns + ------- + str + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return repr(self) + + @_doc_binary_op( + operation="exclusive disjunction", + bin_op="xor", + right="other", + **_doc_binary_op_kwargs, + ) + def __xor__(self, other): + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op("__xor__", other, axis=0) + + @_doc_binary_op( + operation="exclusive disjunction", + bin_op="rxor", + right="other", + **_doc_binary_op_kwargs, + ) + def __rxor__(self, other): + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._binary_op("__rxor__", other, axis=0) + + @property + def size(self) -> int: + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return np.prod(self.shape) # type: ignore[return-value] + + @property + def values(self) -> np.ndarray: + """ + Return a NumPy representation of the `BasePandasDataset`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.to_numpy() + + def _repartition(self, axis: int | None = None): + """ + Repartitioning Modin objects to get ideal partitions inside. + + Allows to improve performance where the query compiler can't improve + yet by doing implicit repartitioning. + + Parameters + ---------- + axis : {0, 1, None}, optional + The axis along which the repartitioning occurs. + `None` is used for repartitioning along both axes. + + Returns + ------- + DataFrame or Series + The repartitioned dataframe or series, depending on the original type. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + allowed_axis_values = (0, 1, None) + if axis not in allowed_axis_values: + raise ValueError( + f"Passed `axis` parameter: {axis}, but should be one of {allowed_axis_values}" + ) + return self.__constructor__( + query_compiler=self._query_compiler.repartition(axis=axis) + ) + + def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): + """ + Apply the `ufunc` to the `BasePandasDataset`. + + Parameters + ---------- + ufunc : np.ufunc + The NumPy ufunc to apply. + method : str + The method to apply. + *inputs : tuple + The inputs to the ufunc. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + BasePandasDataset + The result of the ufunc applied to the `BasePandasDataset`. + """ + # Use pandas version of ufunc if it exists + if method != "__call__": + # Return sentinel value NotImplemented + return NotImplemented + from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import ( + numpy_to_pandas_universal_func_map, + ) + + if ufunc.__name__ in numpy_to_pandas_universal_func_map: + ufunc = numpy_to_pandas_universal_func_map[ufunc.__name__] + return ufunc(self, inputs[1:], kwargs) + # return the sentinel NotImplemented if we do not support this function + return NotImplemented + + def __array_function__( + self, func: callable, types: tuple, args: tuple, kwargs: dict + ): + """ + Apply the `func` to the `BasePandasDataset`. + + Parameters + ---------- + func : np.func + The NumPy func to apply. + types : tuple + The types of the args. + args : tuple + The args to the func. + kwargs : dict + Additional keyword arguments. + + Returns + ------- + BasePandasDataset + The result of the ufunc applied to the `BasePandasDataset`. + """ + from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import ( + numpy_to_pandas_func_map, + ) + + if func.__name__ in numpy_to_pandas_func_map: + return numpy_to_pandas_func_map[func.__name__](*args, **kwargs) + else: + # per NEP18 we raise NotImplementedError so that numpy can intercept + return NotImplemented # pragma: no cover diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py new file mode 100644 index 00000000000..a84748d7e8a --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/dataframe.py @@ -0,0 +1,3445 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. 
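As orientation for the new module that follows, a brief, hypothetical usage sketch (assumes an active Snowpark session; the import path mirrors the package layout this diff introduces):

    import snowflake.snowpark.modin.pandas as pd

    df = pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    counts = df.groupby("b").count()   # work is pushed down to Snowflake rather than run locally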
+ +"""Module houses ``DataFrame`` class, that is distributed version of ``pandas.DataFrame``.""" + +from __future__ import annotations + +import collections +import datetime +import functools +import itertools +import re +import sys +import warnings +from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence +from logging import getLogger +from typing import IO, Any, Callable, Literal + +import numpy as np +import pandas +from pandas._libs.lib import NoDefault, no_default +from pandas._typing import ( + AggFuncType, + AnyArrayLike, + Axes, + Axis, + CompressionOptions, + FilePath, + FillnaOptions, + IgnoreRaise, + IndexLabel, + Level, + PythonFuncType, + Renamer, + Scalar, + StorageOptions, + Suffixes, + WriteBuffer, +) +from pandas.core.common import apply_if_callable, is_bool_indexer +from pandas.core.dtypes.common import ( + infer_dtype_from_object, + is_bool_dtype, + is_dict_like, + is_list_like, + is_numeric_dtype, +) +from pandas.core.dtypes.inference import is_hashable, is_integer +from pandas.core.indexes.frozen import FrozenList +from pandas.io.formats.printing import pprint_thing +from pandas.util._validators import validate_bool_kwarg + +from snowflake.snowpark.modin import pandas as pd +from snowflake.snowpark.modin.pandas.accessor import CachedAccessor, SparseFrameAccessor +from snowflake.snowpark.modin.pandas.base import _ATTRS_NO_LOOKUP, BasePandasDataset +from snowflake.snowpark.modin.pandas.groupby import ( + DataFrameGroupBy, + validate_groupby_args, +) + +# from . import _update_engine +from snowflake.snowpark.modin.pandas.iterator import PartitionIterator +from snowflake.snowpark.modin.pandas.series import Series +from snowflake.snowpark.modin.pandas.snow_partition_iterator import ( + SnowparkPandasRowPartitionIterator, +) +from snowflake.snowpark.modin.pandas.utils import ( + create_empty_native_pandas_frame, + from_non_pandas, + from_pandas, + is_scalar, + raise_if_native_pandas_objects, + replace_external_data_keys_with_empty_pandas_series, + replace_external_data_keys_with_query_compiler, +) +from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + is_snowflake_agg_func, +) +from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated +from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.modin.plugin.utils.warning_message import ( + SET_DATAFRAME_ATTRIBUTE_WARNING, + WarningMessage, +) +from snowflake.snowpark.modin.utils import _inherit_docstrings, hashable, to_pandas +from snowflake.snowpark.udf import UserDefinedFunction + +logger = getLogger(__name__) + +DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE = ( + "Currently do not support Series or list-like keys with range-like values" +) + +DF_SETITEM_SLICE_AS_SCALAR_VALUE = ( + "Currently do not support assigning a slice value as if it's a scalar value" +) + +DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE = ( + "{} will result eager evaluation and potential data pulling, which is inefficient. For efficient Snowpark " + "pandas usage, consider rewriting the code with an operator (such as DataFrame.apply or DataFrame.applymap) which " + "can work on the entire DataFrame in one shot." 
+) + + # Dictionary of extensions assigned to this class +_DATAFRAME_EXTENSIONS_ = {} + + +@_inherit_docstrings( + pandas.DataFrame, + excluded=[ + pandas.DataFrame.flags, + pandas.DataFrame.cov, + pandas.DataFrame.merge, + pandas.DataFrame.reindex, + pandas.DataFrame.to_parquet, + pandas.DataFrame.fillna, + ], + apilink="pandas.DataFrame", +) +class DataFrame(BasePandasDataset): + _pandas_class = pandas.DataFrame + + def __init__( + self, + data=None, + index=None, + columns=None, + dtype=None, + copy=None, + query_compiler=None, + ) -> None: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # Siblings are other dataframes that share the same query compiler. We + # use this list to update inplace when there is a shallow copy. + self._siblings = [] + + # Engine.subscribe(_update_engine) + if isinstance(data, (DataFrame, Series)): + self._query_compiler = data._query_compiler.copy() + if index is not None and any(i not in data.index for i in index): + ErrorMessage.not_implemented( + "Passing non-existent columns or index values to constructor not" + + " yet implemented." + ) # pragma: no cover + if isinstance(data, Series): + # We set the column name if it is not in the provided Series + if data.name is None: + self.columns = [0] if columns is None else columns + # If the columns provided are not in the named Series, pandas clears + # the DataFrame and sets columns to the columns provided. + elif columns is not None and data.name not in columns: + self._query_compiler = from_pandas( + self.__constructor__(columns=columns) + )._query_compiler + if index is not None: + self._query_compiler = data.loc[index]._query_compiler + elif columns is None and index is None: + data._add_sibling(self) + else: + if columns is not None and any(i not in data.columns for i in columns): + ErrorMessage.not_implemented( + "Passing non-existent columns or index values to constructor not" + + " yet implemented."
+ ) # pragma: no cover + if index is None: + index = slice(None) + if columns is None: + columns = slice(None) + self._query_compiler = data.loc[index, columns]._query_compiler + + # Check type of data and use appropriate constructor + elif query_compiler is None: + distributed_frame = from_non_pandas(data, index, columns, dtype) + if distributed_frame is not None: + self._query_compiler = distributed_frame._query_compiler + return + + if isinstance(data, pandas.Index): + pass + elif is_list_like(data) and not is_dict_like(data): + old_dtype = getattr(data, "dtype", None) + values = [ + obj._to_pandas() if isinstance(obj, Series) else obj for obj in data + ] + if isinstance(data, np.ndarray): + data = np.array(values, dtype=old_dtype) + else: + try: + data = type(data)(values, dtype=old_dtype) + except TypeError: + data = values + elif is_dict_like(data) and not isinstance( + data, (pandas.Series, Series, pandas.DataFrame, DataFrame) + ): + if columns is not None: + data = {key: value for key, value in data.items() if key in columns} + + if len(data) and all(isinstance(v, Series) for v in data.values()): + from .general import concat + + new_qc = concat( + data.values(), axis=1, keys=data.keys() + )._query_compiler + + if dtype is not None: + new_qc = new_qc.astype({col: dtype for col in new_qc.columns}) + if index is not None: + new_qc = new_qc.reindex(axis=0, labels=index) + if columns is not None: + new_qc = new_qc.reindex(axis=1, labels=columns) + + self._query_compiler = new_qc + return + + data = { + k: v._to_pandas() if isinstance(v, Series) else v + for k, v in data.items() + } + pandas_df = pandas.DataFrame( + data=data, index=index, columns=columns, dtype=dtype, copy=copy + ) + self._query_compiler = from_pandas(pandas_df)._query_compiler + else: + self._query_compiler = query_compiler + + def __repr__(self): + """ + Return a string representation for a particular ``DataFrame``. + + Returns + ------- + str + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + num_rows = pandas.get_option("display.max_rows") or 10 + # see _repr_html_ for comment, allow here also all column behavior + num_cols = pandas.get_option("display.max_columns") + + ( + row_count, + col_count, + repr_df, + ) = self._query_compiler.build_repr_df(num_rows, num_cols, "x") + result = repr(repr_df) + + # if truncated, add shape information + if is_repr_truncated(row_count, col_count, num_rows, num_cols): + # The split here is so that we don't repr pandas row lengths. + return result.rsplit("\n\n", 1)[0] + "\n\n[{} rows x {} columns]".format( + row_count, col_count + ) + else: + return result + + def _repr_html_(self): # pragma: no cover + """ + Return a html representation for a particular ``DataFrame``. + + Returns + ------- + str + + Notes + ----- + Supports pandas `display.max_rows` and `display.max_columns` options. + """ + num_rows = pandas.get_option("display.max_rows") or 60 + # Modin uses here 20 as default, but this does not coincide well with pandas option. Therefore allow + # here value=0 which means display all columns. + num_cols = pandas.get_option("display.max_columns") + + ( + row_count, + col_count, + repr_df, + ) = self._query_compiler.build_repr_df(num_rows, num_cols) + result = repr_df._repr_html_() + + if is_repr_truncated(row_count, col_count, num_rows, num_cols): + # We split so that we insert our correct dataframe dimensions. + return ( + result.split("
<p>")[0] + + f"<p>{row_count} rows × {col_count} columns</p>
\n" + ) + else: + return result + + def _get_columns(self) -> pandas.Index: + """ + Get the columns for this Snowpark pandas ``DataFrame``. + + Returns + ------- + pandas.Index + The all columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._query_compiler.columns + + def _set_columns(self, new_columns: Axes) -> None: + """ + Set the columns for this Snowpark pandas ``DataFrame``. + + Parameters + ---------- + new_columns : + The new columns to set. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + self._update_inplace( + new_query_compiler=self._query_compiler.set_columns(new_columns) + ) + + columns = property(_get_columns, _set_columns) + + @property + def ndim(self) -> int: + return 2 + + def drop_duplicates( + self, subset=None, keep="first", inplace=False, ignore_index=False + ): # noqa: PR01, RT01, D200 + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + """ + Return ``DataFrame`` with duplicate rows removed. + """ + return super().drop_duplicates( + subset=subset, keep=keep, inplace=inplace, ignore_index=ignore_index + ) + + def dropna( + self, + *, + axis: Axis = 0, + how: str | NoDefault = no_default, + thresh: int | NoDefault = no_default, + subset: IndexLabel = None, + inplace: bool = False, + ): # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super()._dropna( + axis=axis, how=how, thresh=thresh, subset=subset, inplace=inplace + ) + + @property + def dtypes(self): # noqa: RT01, D200 + """ + Return the dtypes in the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._query_compiler.dtypes + + def duplicated( + self, subset: Hashable | Sequence[Hashable] = None, keep: DropKeep = "first" + ): + """ + Return boolean ``Series`` denoting duplicate rows. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + df = self[subset] if subset is not None else self + new_qc = df._query_compiler.duplicated(keep=keep) + duplicates = self._reduce_dimension(new_qc) + # remove Series name which was assigned automatically by .apply in QC + # this is pandas behavior, i.e., if duplicated result is a series, no name is returned + duplicates.name = None + return duplicates + + @property + def empty(self) -> bool: + """ + Indicate whether ``DataFrame`` is empty. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return len(self.columns) == 0 or len(self) == 0 + + @property + def axes(self): + """ + Return a list representing the axes of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return [self.index, self.columns] + + @property + def shape(self) -> tuple[int, int]: + """ + Return a tuple representing the dimensionality of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return len(self), len(self.columns) + + def add_prefix(self, prefix): + """ + Prefix labels with string `prefix`. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # pandas converts non-string prefix values into str and adds it to the column labels. + return self.__constructor__( + query_compiler=self._query_compiler.add_substring( + str(prefix), substring_type="prefix", axis=1 + ) + ) + + def add_suffix(self, suffix): + """ + Suffix labels with string `suffix`. 
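A short sketch of the expected relabeling (hypothetical frame; as the comment below notes, non-string suffixes are first converted with str(), and add_prefix above behaves symmetrically):

    df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
    df.add_suffix("_col").columns   # labels become ['A_col', 'B_col']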
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # pandas converts non-string suffix values into str and appends it to the column labels. + return self.__constructor__( + query_compiler=self._query_compiler.add_substring( + str(suffix), substring_type="suffix", axis=1 + ) + ) + + def applymap(self, func: PythonFuncType, na_action: str | None = None, **kwargs): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if not callable(func): + raise TypeError(f"{func} is not callable") + return self.__constructor__( + query_compiler=self._query_compiler.applymap( + func, na_action=na_action, **kwargs + ) + ) + + def aggregate( + self, func: AggFuncType = None, axis: Axis = 0, *args: Any, **kwargs: Any + ): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super().aggregate(func, axis, *args, **kwargs) + + agg = aggregate + + def apply( + self, + func: AggFuncType | UserDefinedFunction, + axis: Axis = 0, + raw: bool = False, + result_type: Literal["expand", "reduce", "broadcast"] | None = None, + args=(), + **kwargs, + ): + """ + Apply a function along an axis of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + axis = self._get_axis_number(axis) + query_compiler = self._query_compiler.apply( + func, + axis, + raw=raw, + result_type=result_type, + args=args, + **kwargs, + ) + if not isinstance(query_compiler, type(self._query_compiler)): + # A scalar was returned + return query_compiler + + # If True, it is an unamed series. + # Theoretically, if df.apply returns a Series, it will only be an unnamed series + # because the function is supposed to be series -> scalar. + if query_compiler._modin_frame.is_unnamed_series(): + return Series(query_compiler=query_compiler) + else: + return self.__constructor__(query_compiler=query_compiler) + + def groupby( + self, + by=None, + axis: Axis | NoDefault = no_default, + level: IndexLabel | None = None, + as_index: bool = True, + sort: bool = True, + group_keys: bool = True, + observed: bool | NoDefault = no_default, + dropna: bool = True, + ): + """ + Group ``DataFrame`` using a mapper or by a ``Series`` of columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if axis is not no_default: + axis = self._get_axis_number(axis) + if axis == 1: + warnings.warn( + "DataFrame.groupby with axis=1 is deprecated. Do " + + "`frame.T.groupby(...)` without axis instead.", + FutureWarning, + stacklevel=1, + ) + else: + warnings.warn( + "The 'axis' keyword in DataFrame.groupby is deprecated and " + + "will be removed in a future version.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + validate_groupby_args(by, level, observed) + + axis = self._get_axis_number(axis) + + if axis != 0 and as_index is False: + raise ValueError("as_index=False only valid for axis=0") + + idx_name = None + + if ( + not isinstance(by, Series) + and is_list_like(by) + and len(by) == 1 + # if by is a list-like of (None,), we have to keep it as a list because + # None may be referencing a column or index level whose label is + # `None`, and by=None wold mean that there is no `by` param. 
+ and by[0] is not None + ): + by = by[0] + + if hashable(by) and ( + not callable(by) and not isinstance(by, (pandas.Grouper, FrozenList)) + ): + idx_name = by + elif isinstance(by, Series): + idx_name = by.name + if by._parent is self: + # if the SnowSeries comes from the current dataframe, + # convert it to labels directly for easy processing + by = by.name + elif is_list_like(by): + if axis == 0 and all( + ( + (hashable(o) and (o in self)) + or isinstance(o, Series) + or (is_list_like(o) and len(o) == len(self.shape[axis])) + ) + for o in by + ): + # Split 'by's into those that belong to self (internal_by) + # and those that don't (external_by). For a SnowSeries that belongs + # to the current DataFrame, we convert it to labels for easy processing. + internal_by, external_by = [], [] + + for current_by in by: + if hashable(current_by): + internal_by.append(current_by) + elif isinstance(current_by, Series): + if current_by._parent is self: + internal_by.append(current_by.name) + else: + external_by.append(current_by) # pragma: no cover + else: + external_by.append(current_by) + + by = internal_by + external_by + + return DataFrameGroupBy( + self, + by, + axis, + level, + as_index, + sort, + group_keys, + idx_name, + observed=observed, + dropna=dropna, + ) + + def keys(self): # noqa: RT01, D200 + """ + Get columns of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self.columns + + def transform( + self, func: PythonFuncType, axis: Axis = 0, *args: Any, **kwargs: Any + ) -> DataFrame: # noqa: PR01, RT01, D200 + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if is_list_like(func) or is_dict_like(func): + ErrorMessage.not_implemented( + "dict and list parameters are not supported for transform" + ) + # throw the same error as pandas for cases where the function type is + # invalid. + if not isinstance(func, str) and not callable(func): + raise TypeError(f"{type(func)} object is not callable") + + # if the function is an aggregation function, we'll produce + # some bogus results while pandas will throw the error that the + # code below is throwing. So we do the same. + if is_snowflake_agg_func(func): + raise ValueError("Function did not transform") + + return self.apply(func, axis, False, args=args, **kwargs) + + def transpose(self, copy=False, *args): # noqa: PR01, RT01, D200 + """ + Transpose index and columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if copy: + WarningMessage.ignored_argument( + operation="transpose", + argument="copy", + message="Transpose ignores the copy argument in the Snowpark pandas API", + ) + + if args: + WarningMessage.ignored_argument( + operation="transpose", + argument="args", + message="Transpose ignores args in the Snowpark pandas API", + ) + + return self.__constructor__(query_compiler=self._query_compiler.transpose()) + + T = property(transpose) + + def add( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get addition of ``DataFrame`` and `other`, element-wise (binary operator `add`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "add", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + def assign(self, **kwargs): # noqa: PR01, RT01, D200 + """ + Assign new columns to a ``DataFrame``.
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() # pragma: no cover + + df = self.copy() + for k, v in kwargs.items(): + if callable(v): + df[k] = v(df) + else: + df[k] = v + return df + + def boxplot( + self, + column=None, + by=None, + ax=None, + fontsize=None, + rot=0, + grid=True, + figsize=None, + layout=None, + return_type=None, + backend=None, + **kwargs, + ): # noqa: PR01, RT01, D200 + """ + Make a box plot from ``DataFrame`` columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return to_pandas(self).boxplot( + column=column, + by=by, + ax=ax, + fontsize=fontsize, + rot=rot, + grid=grid, + figsize=figsize, + layout=layout, + return_type=return_type, + backend=backend, + **kwargs, + ) + + def combine( + self, other, func, fill_value=None, overwrite=True + ): # noqa: PR01, RT01, D200 + """ + Perform column-wise combine with another ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return super().combine(other, func, fill_value=fill_value, overwrite=overwrite) + + def compare( + self, + other, + align_axis=1, + keep_shape: bool = False, + keep_equal: bool = False, + result_names=("self", "other"), + ) -> DataFrame: # noqa: PR01, RT01, D200 + """ + Compare to another ``DataFrame`` and show the differences. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + if not isinstance(other, DataFrame): + raise TypeError(f"Cannot compare DataFrame to {type(other)}") + other = self._validate_other(other, 0, compare_index=True) + return self.__constructor__( + query_compiler=self._query_compiler.compare( + other, + align_axis=align_axis, + keep_shape=keep_shape, + keep_equal=keep_equal, + result_names=result_names, + ) + ) + + def corr( + self, method="pearson", min_periods=1, numeric_only=False + ): # noqa: PR01, RT01, D200 + """ + Compute pairwise correlation of columns, excluding NA/null values. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + if not numeric_only: + return self._default_to_pandas( + pandas.DataFrame.corr, + method=method, + min_periods=min_periods, + numeric_only=numeric_only, + ) + return self.__constructor__( + query_compiler=self._query_compiler.corr( + method=method, + min_periods=min_periods, + ) + ) + + def corrwith( + self, other, axis=0, drop=False, method="pearson", numeric_only=False + ): # noqa: PR01, RT01, D200 + """ + Compute pairwise correlation. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + if isinstance(other, DataFrame): + other = other._query_compiler.to_pandas() + return self._default_to_pandas( + pandas.DataFrame.corrwith, + other, + axis=axis, + drop=drop, + method=method, + numeric_only=numeric_only, + ) + + def cov( + self, + min_periods: int | None = None, + ddof: int | None = 1, + numeric_only: bool = False, + ): + """ + Compute pairwise covariance of columns, excluding NA/null values. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self.__constructor__( + query_compiler=self._query_compiler.cov( + min_periods=min_periods, + ddof=ddof, + numeric_only=numeric_only, + ) + ) + + def dot(self, other): # noqa: PR01, RT01, D200 + """ + Compute the matrix multiplication between the ``DataFrame`` and `other`. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() # pragma: no cover + + if isinstance(other, BasePandasDataset): + common = self.columns.union(other.index) + if len(common) > len(self.columns) or len(common) > len( + other + ): # pragma: no cover + raise ValueError("Matrices are not aligned") + + if isinstance(other, DataFrame): + return self.__constructor__( + query_compiler=self._query_compiler.dot( + other.reindex(index=common), squeeze_self=False + ) + ) + else: + return self._reduce_dimension( + query_compiler=self._query_compiler.dot( + other.reindex(index=common), squeeze_self=False + ) + ) + + other = np.asarray(other) + if self.shape[1] != other.shape[0]: + raise ValueError( + f"Dot product shape mismatch, {self.shape} vs {other.shape}" + ) + + if len(other.shape) > 1: + return self.__constructor__( + query_compiler=self._query_compiler.dot(other, squeeze_self=False) + ) + + return self._reduce_dimension( + query_compiler=self._query_compiler.dot(other, squeeze_self=False) + ) + + def eq(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Perform equality comparison of ``DataFrame`` and `other` (binary operator `eq`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("eq", other, axis=axis, level=level) + + def equals(self, other): # noqa: PR01, RT01, D200 + """ + Test whether two objects contain the same elements. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() # pragma: no cover + + if isinstance(other, pandas.DataFrame): + # Copy into a Modin DataFrame to simplify logic below + other = self.__constructor__(other) + return ( + self.index.equals(other.index) + and self.columns.equals(other.columns) + and self.eq(other).all().all() + ) + + def _update_var_dicts_in_kwargs(self, expr, kwargs): + """ + Copy variables with "@" prefix in `local_dict` and `global_dict` keys of kwargs. + + Parameters + ---------- + expr : str + The expression string to search variables with "@" prefix. + kwargs : dict + See the documentation for eval() for complete details on the keyword arguments accepted by query(). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if "@" not in expr: + return + frame = sys._getframe() + try: + f_locals = frame.f_back.f_back.f_back.f_back.f_locals + f_globals = frame.f_back.f_back.f_back.f_back.f_globals + finally: + del frame + local_names = set(re.findall(r"@([\w]+)", expr)) + local_dict = {} + global_dict = {} + + for name in local_names: + for dct_out, dct_in in ((local_dict, f_locals), (global_dict, f_globals)): + try: + dct_out[name] = dct_in[name] + except KeyError: + pass + + if local_dict: + local_dict.update(kwargs.get("local_dict") or {}) + kwargs["local_dict"] = local_dict + if global_dict: + global_dict.update(kwargs.get("global_dict") or {}) + kwargs["global_dict"] = global_dict + + def eval(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 + """ + Evaluate a string describing operations on ``DataFrame`` columns. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + self._validate_eval_query(expr, **kwargs) + inplace = validate_bool_kwarg(inplace, "inplace") + self._update_var_dicts_in_kwargs(expr, kwargs) + new_query_compiler = self._query_compiler.eval(expr, **kwargs) + return_type = type( + pandas.DataFrame(columns=self.columns) + .astype(self.dtypes) + .eval(expr, **kwargs) + ).__name__ + if return_type == type(self).__name__: + return self._create_or_update_from_compiler(new_query_compiler, inplace) + else: + if inplace: + raise ValueError("Cannot operate inplace if there is no assignment") + return getattr(sys.modules[self.__module__], return_type)( + query_compiler=new_query_compiler + ) + + def fillna( + self, + value: Hashable | Mapping | Series | DataFrame = None, + *, + method: FillnaOptions | None = None, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + downcast: dict | None = None, + ) -> DataFrame | None: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super().fillna( + self_is_series=False, + value=value, + method=method, + axis=axis, + inplace=inplace, + limit=limit, + downcast=downcast, + ) + + def floordiv( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `floordiv`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "floordiv", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + @classmethod + def from_dict( + cls, data, orient="columns", dtype=None, columns=None + ): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Construct ``DataFrame`` from dict of array-like or dicts. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return from_pandas( + pandas.DataFrame.from_dict( + data, orient=orient, dtype=dtype, columns=columns + ) + ) + + @classmethod + def from_records( + cls, + data, + index=None, + exclude=None, + columns=None, + coerce_float=False, + nrows=None, + ): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Convert structured or record ndarray to ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return from_pandas( + pandas.DataFrame.from_records( + data, + index=index, + exclude=exclude, + columns=columns, + coerce_float=coerce_float, + nrows=nrows, + ) + ) + + def ge(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Get greater than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ge`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("ge", other, axis=axis, level=level) + + def gt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Get greater than comparison of ``DataFrame`` and `other`, element-wise (binary operator `ge`). 
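A brief sketch of the element-wise comparison (hypothetical values; the result is a boolean DataFrame of the same shape, and eq/ne/le/lt/ge follow the same pattern via _binary_op):

    df = pd.DataFrame({"a": [1, 5], "b": [3, 2]})
    df.gt(2)   # a -> [False, True], b -> [True, False]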
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("gt", other, axis=axis, level=level) + + def hist( + self, + column=None, + by=None, + grid=True, + xlabelsize=None, + xrot=None, + ylabelsize=None, + yrot=None, + ax=None, + sharex=False, + sharey=False, + figsize=None, + layout=None, + bins=10, + **kwds, + ): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Make a histogram of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.DataFrame.hist, + column=column, + by=by, + grid=grid, + xlabelsize=xlabelsize, + xrot=xrot, + ylabelsize=ylabelsize, + yrot=yrot, + ax=ax, + sharex=sharex, + sharey=sharey, + figsize=figsize, + layout=layout, + bins=bins, + **kwds, + ) + + def info( + self, + verbose: bool | None = None, + buf: IO[str] | None = None, + max_cols: int | None = None, + memory_usage: bool | str | None = None, + show_counts: bool | None = None, + null_counts: bool | None = None, + ): # noqa: PR01, D200 + """ + Print a concise summary of the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + def put_str(src, output_len=None, spaces=2): + src = str(src) + return src.ljust(output_len if output_len else len(src)) + " " * spaces + + def format_size(num): + for x in ["bytes", "KB", "MB", "GB", "TB"]: + if num < 1024.0: + return f"{num:3.1f} {x}" + num /= 1024.0 + return f"{num:3.1f} PB" + + output = [] + + type_line = str(type(self)) + index_line = "SnowflakeIndex" + columns = self.columns + columns_len = len(columns) + dtypes = self.dtypes + dtypes_line = f"dtypes: {', '.join(['{}({})'.format(dtype, count) for dtype, count in dtypes.value_counts().items()])}" + + if max_cols is None: + max_cols = 100 + + exceeds_info_cols = columns_len > max_cols + + if buf is None: + buf = sys.stdout + + if null_counts is None: + null_counts = not exceeds_info_cols + + if verbose is None: + verbose = not exceeds_info_cols + + if null_counts and verbose: + # We're gonna take items from `non_null_count` in a loop, which + # works kinda slow with `Modin.Series`, that's why we call `_to_pandas()` here + # that will be faster. 
+ non_null_count = self.count()._to_pandas() + + if memory_usage is None: + memory_usage = True + + def get_header(spaces=2): + output = [] + head_label = " # " + column_label = "Column" + null_label = "Non-Null Count" + dtype_label = "Dtype" + non_null_label = " non-null" + delimiter = "-" + + lengths = {} + lengths["head"] = max(len(head_label), len(pprint_thing(len(columns)))) + lengths["column"] = max( + len(column_label), max(len(pprint_thing(col)) for col in columns) + ) + lengths["dtype"] = len(dtype_label) + dtype_spaces = ( + max(lengths["dtype"], max(len(pprint_thing(dtype)) for dtype in dtypes)) + - lengths["dtype"] + ) + + header = put_str(head_label, lengths["head"]) + put_str( + column_label, lengths["column"] + ) + if null_counts: + lengths["null"] = max( + len(null_label), + max(len(pprint_thing(x)) for x in non_null_count) + + len(non_null_label), + ) + header += put_str(null_label, lengths["null"]) + header += put_str(dtype_label, lengths["dtype"], spaces=dtype_spaces) + + output.append(header) + + delimiters = put_str(delimiter * lengths["head"]) + put_str( + delimiter * lengths["column"] + ) + if null_counts: + delimiters += put_str(delimiter * lengths["null"]) + delimiters += put_str(delimiter * lengths["dtype"], spaces=dtype_spaces) + output.append(delimiters) + + return output, lengths + + output.extend([type_line, index_line]) + + def verbose_repr(output): + columns_line = f"Data columns (total {len(columns)} columns):" + header, lengths = get_header() + output.extend([columns_line, *header]) + for i, col in enumerate(columns): + i, col_s, dtype = map(pprint_thing, [i, col, dtypes[col]]) + + to_append = put_str(f" {i}", lengths["head"]) + put_str( + col_s, lengths["column"] + ) + if null_counts: + non_null = pprint_thing(non_null_count[col]) + to_append += put_str(f"{non_null} non-null", lengths["null"]) + to_append += put_str(dtype, lengths["dtype"], spaces=0) + output.append(to_append) + + def non_verbose_repr(output): + output.append(columns._summary(name="Columns")) + + if verbose: + verbose_repr(output) + else: + non_verbose_repr(output) + + output.append(dtypes_line) + + if memory_usage: + deep = memory_usage == "deep" + mem_usage_bytes = self.memory_usage(index=True, deep=deep).sum() + mem_line = f"memory usage: {format_size(mem_usage_bytes)}" + + output.append(mem_line) + + output.append("") + buf.write("\n".join(output)) + + def insert( + self, + loc: int, + column: Hashable, + value: Scalar | AnyArrayLike, + allow_duplicates: bool | NoDefault = no_default, + ) -> None: + """ + Insert column into ``DataFrame`` at specified location. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + raise_if_native_pandas_objects(value) + if allow_duplicates is no_default: + allow_duplicates = False + if not allow_duplicates and column in self.columns: + raise ValueError(f"cannot insert {column}, already exists") + + if not isinstance(loc, int): + raise TypeError("loc must be int") + + # If columns labels are multilevel, we implement following behavior (this is + # name native pandas): + # Case 1: if 'column' is tuple it's length must be same as number of levels + # otherwise raise error. + # Case 2: if 'column' is not a tuple, create a tuple out of it by filling in + # empty strings to match the length of column levels in self frame. + if self.columns.nlevels > 1: + if isinstance(column, tuple) and len(column) != self.columns.nlevels: + # same error as native pandas. 
+ raise ValueError("Item must have length equal to number of levels.") + if not isinstance(column, tuple): + # Fill empty strings to match length of levels + suffix = [""] * (self.columns.nlevels - 1) + column = tuple([column] + suffix) + + # Dictionary keys are treated as index column and this should be joined with + # index of target dataframe. This behavior is similar to 'value' being DataFrame + # or Series, so we simply create Series from dict data here. + if isinstance(value, dict): + value = Series(value, name=column) + + if isinstance(value, DataFrame) or ( + isinstance(value, np.ndarray) and len(value.shape) > 1 + ): + # Supported numpy array shapes are + # 1. (N, ) -> Ex. [1, 2, 3] + # 2. (N, 1) -> Ex> [[1], [2], [3]] + if value.shape[1] != 1: + if isinstance(value, DataFrame): + # Error message updated in pandas 2.1, needs to be upstreamed to OSS modin + raise ValueError( + f"Expected a one-dimensional object, got a {type(value).__name__} with {value.shape[1]} columns instead." + ) + else: + raise ValueError( + f"Expected a 1D array, got an array with shape {value.shape}" + ) + # Change numpy array shape from (N, 1) to (N, ) + if isinstance(value, np.ndarray): + value = value.squeeze(axis=1) + + if ( + is_list_like(value) + and not isinstance(value, (Series, DataFrame)) + and len(value) != self.shape[0] + and not 0 == self.shape[0] # dataframe holds no rows + ): + raise ValueError( + "Length of values ({}) does not match length of index ({})".format( + len(value), len(self) + ) + ) + if not -len(self.columns) <= loc <= len(self.columns): + raise IndexError( + f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}" + ) + elif loc < 0: + raise ValueError("unbounded slice") + + join_on_index = False + if isinstance(value, (Series, DataFrame)): + value = value._query_compiler + join_on_index = True + elif is_list_like(value): + value = Series(value, name=column)._query_compiler + + new_query_compiler = self._query_compiler.insert( + loc, column, value, join_on_index + ) + # In pandas, 'insert' operation is always inplace. + self._update_inplace(new_query_compiler=new_query_compiler) + + def interpolate( + self, + method="linear", + axis=0, + limit=None, + inplace=False, + limit_direction: str | None = None, + limit_area=None, + downcast=None, + **kwargs, + ): # noqa: PR01, RT01, D200 + """ + Fill NaN values using an interpolation method. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.DataFrame.interpolate, + method=method, + axis=axis, + limit=limit, + inplace=inplace, + limit_direction=limit_direction, + limit_area=limit_area, + downcast=downcast, + **kwargs, + ) + + def iterrows(self) -> Iterator[tuple[Hashable, Series]]: + """ + Iterate over ``DataFrame`` rows as (index, ``Series``) pairs. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + def iterrow_builder(s): + """Return tuple of the given `s` parameter name and the parameter themselves.""" + return s.name, s + + # Raise warning message since iterrows is very inefficient. + WarningMessage.single_warning( + DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.iterrows") + ) + + partition_iterator = SnowparkPandasRowPartitionIterator(self, iterrow_builder) + yield from partition_iterator + + def items(self): # noqa: D200 + """ + Iterate over (column name, ``Series``) pairs. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + + def items_builder(s): + """Return tuple of the given `s` parameter name and the parameter themselves.""" + return s.name, s + + partition_iterator = PartitionIterator(self, 1, items_builder) + yield from partition_iterator + + def iteritems(self): # noqa: RT01, D200 + """ + Iterate over (column name, ``Series``) pairs. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self.items() + + def itertuples( + self, index: bool = True, name: str | None = "Pandas" + ) -> Iterable[tuple[Any, ...]]: + """ + Iterate over ``DataFrame`` rows as ``namedtuple``-s. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + + def itertuples_builder(s): + """Return the next namedtuple.""" + # s is the Series of values in the current row. + fields = [] # column names + data = [] # values under each column + + if index: + data.append(s.name) + fields.append("Index") + + # Fill column names and values. + fields.extend(list(self.columns)) + data.extend(s) + + if name is not None: + # Creating the namedtuple. + itertuple = collections.namedtuple(name, fields, rename=True) + return itertuple._make(data) + + # When the name is None, return a regular tuple. + return tuple(data) + + # Raise warning message since itertuples is very inefficient. + WarningMessage.single_warning( + DF_ITERROWS_ITERTUPLES_WARNING_MESSAGE.format("DataFrame.itertuples") + ) + return SnowparkPandasRowPartitionIterator(self, itertuples_builder, True) + + def join( + self, + other: DataFrame | Series | Iterable[DataFrame | Series], + on: IndexLabel | None = None, + how: str = "left", + lsuffix: str = "", + rsuffix: str = "", + sort: bool = False, + validate: str | None = None, + ) -> DataFrame: + """ + Join columns of another ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + for o in other if isinstance(other, list) else [other]: + raise_if_native_pandas_objects(o) + + # Similar to native pandas we implement 'join' using 'pd.merge' method. + # Following code is copied from native pandas (with few changes explained below) + # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10002 + if isinstance(other, Series): + # Same error as native pandas. + if other.name is None: + raise ValueError("Other Series must have a name") + other = DataFrame(other) + elif is_list_like(other): + if any([isinstance(o, Series) and o.name is None for o in other]): + raise ValueError("Other Series must have a name") + + if isinstance(other, DataFrame): + if how == "cross": + return pd.merge( + self, + other, + how=how, + on=on, + suffixes=(lsuffix, rsuffix), + sort=sort, + validate=validate, + ) + return pd.merge( + self, + other, + left_on=on, + how=how, + left_index=on is None, + right_index=True, + suffixes=(lsuffix, rsuffix), + sort=sort, + validate=validate, + ) + else: # List of DataFrame/Series + # Same error as native pandas. + if on is not None: + raise ValueError( + "Joining multiple DataFrames only supported for joining on index" + ) + + # Same error as native pandas. + if rsuffix or lsuffix: + raise ValueError( + "Suffixes not supported when joining multiple DataFrames" + ) + + # NOTE: These are not the differences between Snowpark pandas API and pandas behavior + # these are differences between native pandas join behavior when join + # frames have unique index or not. 
+ + # In native pandas logic to join multiple DataFrames/Series is data + # dependent. Under the hood it will either use 'concat' or 'merge' API + # Case 1. If all objects being joined have unique index use 'concat' (axis=1) + # Case 2. Otherwise use 'merge' API by looping through objects left to right. + # https://github.com/pandas-dev/pandas/blob/v1.5.3/pandas/core/frame.py#L10046 + + # Even though concat (axis=1) and merge are very similar APIs they have + # some differences which leads to inconsistent behavior in native pandas. + # 1. Treatment of un-named Series + # Case #1: Un-named series is allowed in concat API. Objects are joined + # successfully by assigning a number as columns name (see 'concat' API + # documentation for details on treatment of un-named series). + # Case #2: It raises 'ValueError: Other Series must have a name' + + # 2. how='right' + # Case #1: 'concat' API doesn't support right join. It raises + # 'ValueError: Only can inner (intersect) or outer (union) join the other axis' + # Case #2: Merges successfully. + + # 3. Joining frames with duplicate labels but no conflict with other frames + # Example: self = DataFrame(... columns=["A", "B"]) + # other = [DataFrame(... columns=["C", "C"])] + # Case #1: 'ValueError: Indexes have overlapping values' + # Case #2: Merged successfully. + + # In addition to this, native pandas implementation also leads to another + # type of inconsistency where left.join(other, ...) and + # left.join([other], ...) might behave differently for cases mentioned + # above. + # Example: + # import pandas as pd + # df = pd.DataFrame({"a": [4, 5]}) + # other = pd.Series([1, 2]) + # df.join([other]) # this is successful + # df.join(other) # this raises 'ValueError: Other Series must have a name' + + # In Snowpark pandas API, we provide consistent behavior by always using 'merge' API + # to join multiple DataFrame/Series. So always follow the behavior + # documented as Case #2 above. + + joined = self + for frame in other: + if isinstance(frame, DataFrame): + overlapping_cols = set(joined.columns).intersection( + set(frame.columns) + ) + if len(overlapping_cols) > 0: + # Native pandas raises: 'Indexes have overlapping values' + # We differ slightly from native pandas message to make it more + # useful to users. + raise ValueError( + f"Join dataframes have overlapping column labels: {overlapping_cols}" + ) + joined = pd.merge( + joined, + frame, + how=how, + left_index=True, + right_index=True, + validate=validate, + sort=sort, + suffixes=(None, None), + ) + return joined + + def isna(self): + return super().isna() + + def isnull(self): + return super().isnull() + + def isetitem(self, loc, value): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.DataFrame.isetitem, + loc=loc, + value=value, + ) + + def le(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Get less than or equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `le`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("le", other, axis=axis, level=level) + + def lt(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Get less than comparison of ``DataFrame`` and `other`, element-wise (binary operator `le`). 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("lt", other, axis=axis, level=level) + + def melt( + self, + id_vars=None, + value_vars=None, + var_name=None, + value_name="value", + col_level=None, + ignore_index=True, + ): # noqa: PR01, RT01, D200 + """ + Unpivot a ``DataFrame`` from wide to long format, optionally leaving identifiers set. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if id_vars is None: + id_vars = [] + if not is_list_like(id_vars): + id_vars = [id_vars] + if value_vars is None: + # Behavior of Index.difference changed in 2.2.x + # https://github.com/pandas-dev/pandas/pull/55113 + # This change needs upstream to Modin: + # https://github.com/modin-project/modin/issues/7206 + value_vars = self.columns.drop(id_vars) + if var_name is None: + columns_name = self._query_compiler.get_index_name(axis=1) + var_name = columns_name if columns_name is not None else "variable" + return self.__constructor__( + query_compiler=self._query_compiler.melt( + id_vars=id_vars, + value_vars=value_vars, + var_name=var_name, + value_name=value_name, + col_level=col_level, + ignore_index=ignore_index, + ) + ) + + def memory_usage(self, index=True, deep=False): # noqa: PR01, RT01, D200 + """ + Return the memory usage of each column in bytes. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() # pragma: no cover + + if index: + result = self._reduce_dimension( + self._query_compiler.memory_usage(index=False, deep=deep) + ) + index_value = self.index.memory_usage(deep=deep) + return pd.concat( + [Series(index_value, index=["Index"]), result] + ) # pragma: no cover + return super().memory_usage(index=index, deep=deep) + + def merge( + self, + right: DataFrame | Series, + how: str = "inner", + on: IndexLabel | None = None, + left_on: Hashable + | AnyArrayLike + | Sequence[Hashable | AnyArrayLike] + | None = None, + right_on: Hashable + | AnyArrayLike + | Sequence[Hashable | AnyArrayLike] + | None = None, + left_index: bool = False, + right_index: bool = False, + sort: bool = False, + suffixes: Suffixes = ("_x", "_y"), + copy: bool = True, + indicator: bool = False, + validate: str | None = None, + ) -> DataFrame: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # Raise error if native pandas objects are passed. + raise_if_native_pandas_objects(right) + + if isinstance(right, Series) and right.name is None: + raise ValueError("Cannot merge a Series without a name") + if not isinstance(right, (Series, DataFrame)): + raise TypeError( + f"Can only merge Series or DataFrame objects, a {type(right)} was passed" + ) + + if isinstance(right, Series): + right_column_nlevels = ( + len(right.name) if isinstance(right.name, tuple) else 1 + ) + else: + right_column_nlevels = right.columns.nlevels + if self.columns.nlevels != right_column_nlevels: + # This is deprecated in native pandas. We raise explicit error for this. + raise ValueError( + "Can not merge objects with different column levels." + + f" ({self.columns.nlevels} levels on the left," + + f" {right_column_nlevels} on the right)" + ) + + # Merge empty native pandas dataframes for error checking. Otherwise, it will + # require a lot of logic to be written. This takes care of raising errors for + # following scenarios: + # 1. Only 'left_index' is set to True. + # 2. Only 'right_index is set to True. + # 3. Only 'left_on' is provided. + # 4. Only 'right_on' is provided. + # 5. 
'on' and 'left_on' both are provided + # 6. 'on' and 'right_on' both are provided + # 7. 'on' and 'left_index' both are provided + # 8. 'on' and 'right_index' both are provided + # 9. 'left_on' and 'left_index' both are provided + # 10. 'right_on' and 'right_index' both are provided + # 11. Length mismatch between 'left_on' and 'right_on' + # 12. 'left_index' is not a bool + # 13. 'right_index' is not a bool + # 14. 'on' is not None and how='cross' + # 15. 'left_on' is not None and how='cross' + # 16. 'right_on' is not None and how='cross' + # 17. 'left_index' is True and how='cross' + # 18. 'right_index' is True and how='cross' + # 19. Unknown label in 'on', 'left_on' or 'right_on' + # 20. Provided 'suffixes' is not sufficient to resolve conflicts. + # 21. Merging on column with duplicate labels. + # 22. 'how' not in {'left', 'right', 'inner', 'outer', 'cross'} + # 23. conflict with existing labels for array-like join key + # 24. 'indicator' argument is not bool or str + # 25. indicator column label conflicts with existing data labels + create_empty_native_pandas_frame(self).merge( + create_empty_native_pandas_frame(right), + on=on, + how=how, + left_on=replace_external_data_keys_with_empty_pandas_series(left_on), + right_on=replace_external_data_keys_with_empty_pandas_series(right_on), + left_index=left_index, + right_index=right_index, + suffixes=suffixes, + indicator=indicator, + ) + + return self.__constructor__( + query_compiler=self._query_compiler.merge( + right._query_compiler, + how=how, + on=on, + left_on=replace_external_data_keys_with_query_compiler(self, left_on), + right_on=replace_external_data_keys_with_query_compiler( + right, right_on + ), + left_index=left_index, + right_index=right_index, + sort=sort, + suffixes=suffixes, + copy=copy, + indicator=indicator, + validate=validate, + ) + ) + + def mod( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `mod`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "mod", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + def mul( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `mul`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "mul", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + multiply = mul + + def rmul( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get multiplication of ``DataFrame`` and `other`, element-wise (binary operator `mul`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "rmul", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + def ne(self, other, axis="columns", level=None): # noqa: PR01, RT01, D200 + """ + Get not equal comparison of ``DataFrame`` and `other`, element-wise (binary operator `ne`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("ne", other, axis=axis, level=level) + + def nlargest(self, n, columns, keep="first"): # noqa: PR01, RT01, D200 + """ + Return the first `n` rows ordered by `columns` in descending order. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self.__constructor__( + query_compiler=self._query_compiler.nlargest(n, columns, keep) + ) + + def nsmallest(self, n, columns, keep="first"): # noqa: PR01, RT01, D200 + """ + Return the first `n` rows ordered by `columns` in ascending order. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self.__constructor__( + query_compiler=self._query_compiler.nsmallest( + n=n, columns=columns, keep=keep + ) + ) + + def unstack(self, level=-1, fill_value=None): # noqa: PR01, RT01, D200 + """ + Pivot a level of the (necessarily hierarchical) index labels. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + if not isinstance(self.index, pandas.MultiIndex) or ( + isinstance(self.index, pandas.MultiIndex) + and is_list_like(level) + and len(level) == self.index.nlevels + ): + return self._reduce_dimension( + query_compiler=self._query_compiler.unstack(level, fill_value) + ) + else: + return self.__constructor__( + query_compiler=self._query_compiler.unstack(level, fill_value) + ) + + def pivot(self, index=None, columns=None, values=None): # noqa: PR01, RT01, D200 + """ + Return reshaped ``DataFrame`` organized by given index / column values. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self.__constructor__( + query_compiler=self._query_compiler.pivot( + index=index, columns=columns, values=values + ) + ) + + def pivot_table( + self, + values=None, + index=None, + columns=None, + aggfunc="mean", + fill_value=None, + margins=False, + dropna=True, + margins_name="All", + observed=False, + sort=True, + ): + """ + Create a spreadsheet-style pivot table as a ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + result = self.__constructor__( + query_compiler=self._query_compiler.pivot_table( + index=index, + values=values, + columns=columns, + aggfunc=aggfunc, + fill_value=fill_value, + margins=margins, + dropna=dropna, + margins_name=margins_name, + observed=observed, + sort=sort, + ) + ) + return result + + @property + def plot( + self, + x=None, + y=None, + kind="line", + ax=None, + subplots=False, + sharex=None, + sharey=False, + layout=None, + figsize=None, + use_index=True, + title=None, + grid=None, + legend=True, + style=None, + logx=False, + logy=False, + loglog=False, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + rot=None, + fontsize=None, + colormap=None, + table=False, + yerr=None, + xerr=None, + secondary_y=False, + sort_columns=False, + **kwargs, + ): # noqa: PR01, RT01, D200 + """ + Make plots of ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self._to_pandas().plot + + def pow( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `pow`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "pow", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + def prod( + self, + axis=None, + skipna=True, + numeric_only=False, + min_count=0, + **kwargs, + ): # noqa: PR01, RT01, D200 + """ + Return the product of the values over the requested axis. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + axis = self._get_axis_number(axis) + axis_to_apply = self.columns if axis else self.index + if ( + skipna is not False + and numeric_only is None + and min_count > len(axis_to_apply) + ): + new_index = self.columns if not axis else self.index + return Series( + [np.nan] * len(new_index), index=new_index, dtype=np.dtype("object") + ) + + data = self._validate_dtypes_sum_prod_mean(axis, numeric_only, ignore_axis=True) + if min_count > 1: + return data._reduce_dimension( + data._query_compiler.prod_min_count( + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + min_count=min_count, + **kwargs, + ) + ) + return data._reduce_dimension( + data._query_compiler.prod( + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + min_count=min_count, + **kwargs, + ) + ) + + product = prod + + def quantile( + self, + q: Scalar | ListLike = 0.5, + axis: Axis = 0, + numeric_only: bool = False, + interpolation: Literal[ + "linear", "lower", "higher", "midpoint", "nearest" + ] = "linear", + method: Literal["single", "table"] = "single", + ): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super().quantile( + q=q, + axis=axis, + numeric_only=numeric_only, + interpolation=interpolation, + method=method, + ) + + def query(self, expr, inplace=False, **kwargs): # noqa: PR01, RT01, D200 + """ + Query the columns of a ``DataFrame`` with a boolean expression. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + self._update_var_dicts_in_kwargs(expr, kwargs) + self._validate_eval_query(expr, **kwargs) + inplace = validate_bool_kwarg(inplace, "inplace") + new_query_compiler = self._query_compiler.query(expr, **kwargs) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + def rename( + self, + mapper: Renamer | None = None, + *, + index: Renamer | None = None, + columns: Renamer | None = None, + axis: Axis | None = None, + copy: bool | None = None, + inplace: bool = False, + level: Level | None = None, + errors: IgnoreRaise = "ignore", + ) -> DataFrame | None: + """ + Alter axes labels. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + if mapper is None and index is None and columns is None: + raise TypeError("must pass an index to rename") + + if index is not None or columns is not None: + if axis is not None: + raise TypeError( + "Cannot specify both 'axis' and any of 'index' or 'columns'" + ) + elif mapper is not None: + raise TypeError( + "Cannot specify both 'mapper' and any of 'index' or 'columns'" + ) + else: + # use the mapper argument + if axis and self._get_axis_number(axis) == 1: + columns = mapper + else: + index = mapper + + if copy is not None: + WarningMessage.ignored_argument( + operation="dataframe.rename", + argument="copy", + message="copy parameter has been ignored with Snowflake execution engine", + ) + + if isinstance(index, dict): + index = Series(index) + + new_qc = self._query_compiler.rename( + index_renamer=index, columns_renamer=columns, level=level, errors=errors + ) + return self._create_or_update_from_compiler( + new_query_compiler=new_qc, inplace=inplace + ) + + def reindex( + self, + labels=None, + index=None, + columns=None, + axis=None, + method=None, + copy=None, + level=None, + fill_value=np.nan, + limit=None, + tolerance=None, + ): # noqa: PR01, RT01, D200 + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() # pragma: no cover + + axis = self._get_axis_number(axis) + if axis == 0 and labels is not None: + index = labels + elif labels is not None: + columns = labels + return super().reindex( + index=index, + columns=columns, + method=method, + copy=copy, + level=level, + fill_value=fill_value, + limit=limit, + tolerance=tolerance, + ) + + def replace( + self, + to_replace=None, + value=no_default, + inplace: bool = False, + limit=None, + regex: bool = False, + method: str | NoDefault = no_default, + ): + """ + Replace values given in `to_replace` with `value`. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + new_query_compiler = self._query_compiler.replace( + to_replace=to_replace, + value=value, + limit=limit, + regex=regex, + method=method, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + def rfloordiv( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get integer division of ``DataFrame`` and `other`, element-wise (binary operator `rfloordiv`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "rfloordiv", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + def radd( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get addition of ``DataFrame`` and `other`, element-wise (binary operator `radd`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "radd", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + def rmod( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get modulo of ``DataFrame`` and `other`, element-wise (binary operator `rmod`). 
+ """ + return self._binary_op( + "rmod", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + def round(self, decimals=0, *args, **kwargs): # noqa: PR01, RT01, D200 + return super().round(decimals, args=args, **kwargs) + + def rpow( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get exponential power of ``DataFrame`` and `other`, element-wise (binary operator `rpow`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "rpow", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + def rsub( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `rsub`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "rsub", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + def rtruediv( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `rtruediv`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "rtruediv", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + rdiv = rtruediv + + def select_dtypes( + self, + include: ListLike | str | type | None = None, + exclude: ListLike | str | type | None = None, + ) -> DataFrame: + """ + Return a subset of the ``DataFrame``'s columns based on the column dtypes. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # This line defers argument validation to pandas, which will raise errors on our behalf in cases + # like if `include` and `exclude` are None, the same type is specified in both lists, or a string + # dtype (as opposed to object) is specified. + pandas.DataFrame().select_dtypes(include, exclude) + + if include and not is_list_like(include): + include = [include] + elif include is None: + include = [] + if exclude and not is_list_like(exclude): + exclude = [exclude] + elif exclude is None: + exclude = [] + + sel = tuple(map(set, (include, exclude))) + + # The width of the np.int_/float_ alias differs between Windows and other platforms, so + # we need to include a workaround. 
+ # https://github.com/numpy/numpy/issues/9464 + # https://github.com/pandas-dev/pandas/blob/f538741432edf55c6b9fb5d0d496d2dd1d7c2457/pandas/core/frame.py#L5036 + def check_sized_number_infer_dtypes(dtype): + if (isinstance(dtype, str) and dtype == "int") or (dtype is int): + return [np.int32, np.int64] + elif dtype == "float" or dtype is float: + return [np.float64, np.float32] + else: + return [infer_dtype_from_object(dtype)] + + include, exclude = map( + lambda x: set( + itertools.chain.from_iterable(map(check_sized_number_infer_dtypes, x)) + ), + sel, + ) + # We need to index on column position rather than label in case of duplicates + include_these = pandas.Series(not bool(include), index=range(len(self.columns))) + exclude_these = pandas.Series(not bool(exclude), index=range(len(self.columns))) + + def is_dtype_instance_mapper(dtype): + return functools.partial(issubclass, dtype.type) + + for i, dtype in enumerate(self.dtypes): + if include: + include_these[i] = any(map(is_dtype_instance_mapper(dtype), include)) + if exclude: + exclude_these[i] = not any( + map(is_dtype_instance_mapper(dtype), exclude) + ) + + dtype_indexer = include_these & exclude_these + indicate = [i for i, should_keep in dtype_indexer.items() if should_keep] + # We need to use iloc instead of drop in case of duplicate column names + return self.iloc[:, indicate] + + def shift( + self, + periods: int = 1, + freq=None, + axis: Axis = 0, + fill_value: Hashable = no_default, + ) -> DataFrame: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super().shift(periods, freq, axis, fill_value) + + def set_index( + self, + keys: IndexLabel + | list[IndexLabel | pd.Index | pd.Series | list | np.ndarray | Iterable], + drop: bool = True, + append: bool = False, + inplace: bool = False, + verify_integrity: bool = False, + ) -> None | DataFrame: + """ + Set the ``DataFrame`` index using existing columns. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + inplace = validate_bool_kwarg(inplace, "inplace") + if not isinstance(keys, list): + keys = [keys] + + # make sure key is either hashable, index, or series + label_or_series = [] + + missing = [] + columns = self.columns.tolist() + for key in keys: + raise_if_native_pandas_objects(key) + if isinstance(key, pd.Series): + label_or_series.append(key._query_compiler) + elif isinstance(key, (np.ndarray, list, Iterator)): + label_or_series.append(pd.Series(key)._query_compiler) + elif isinstance(key, pd.Index): + label_or_series += [ + s._query_compiler for s in self._to_series_list(key) + ] + else: + if not is_hashable(key): + raise TypeError( + f'The parameter "keys" may be a column key, one-dimensional array, or a list ' + f"containing only valid column keys and one-dimensional arrays. 
Received column " + f"of type {type(key)}" + ) + label_or_series.append(key) + found = key in columns + if columns.count(key) > 1: + raise ValueError(f"The column label '{key}' is not unique") + elif not found: + missing.append(key) + + if missing: + raise KeyError(f"None of {missing} are in the columns") + + new_query_compiler = self._query_compiler.set_index( + label_or_series, drop=drop, append=append + ) + + # TODO: SNOW-782633 improve this code once duplicate is supported + # this needs to pull all index which is inefficient + if verify_integrity and not new_query_compiler.index.is_unique: + duplicates = new_query_compiler.index[ + new_query_compiler.index.duplicated() + ].unique() + raise ValueError(f"Index has duplicate keys: {duplicates}") + + return self._create_or_update_from_compiler(new_query_compiler, inplace=inplace) + + sparse = CachedAccessor("sparse", SparseFrameAccessor) + + def squeeze(self, axis: Axis | None = None): + """ + Squeeze 1 dimensional axis objects into scalars. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + axis = self._get_axis_number(axis) if axis is not None else None + len_columns = self._query_compiler.get_axis_len(1) + if axis == 1 and len_columns == 1: + return Series(query_compiler=self._query_compiler) + # get_axis_len(0) results in a sql query to count number of rows in current + # dataframe. We should only compute len_index if axis is 0 or None. + len_index = len(self) + if axis is None and (len_columns == 1 or len_index == 1): + return Series(query_compiler=self._query_compiler).squeeze() + if axis == 0 and len_index == 1: + return Series(query_compiler=self.T._query_compiler) + else: + return self.copy() + + def stack(self, level=-1, dropna=True): # noqa: PR01, RT01, D200 + """ + Stack the prescribed level(s) from columns to index. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + if not isinstance(self.columns, pandas.MultiIndex) or ( + isinstance(self.columns, pandas.MultiIndex) + and is_list_like(level) + and len(level) == self.columns.nlevels + ): + return self._reduce_dimension( + query_compiler=self._query_compiler.stack(level, dropna) + ) + else: + return self.__constructor__( + query_compiler=self._query_compiler.stack(level, dropna) + ) + + def sub( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get subtraction of ``DataFrame`` and `other`, element-wise (binary operator `sub`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "sub", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + subtract = sub + + def to_feather(self, path, **kwargs): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Write a ``DataFrame`` to the binary Feather format. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self._default_to_pandas(pandas.DataFrame.to_feather, path, **kwargs) + + def to_gbq( + self, + destination_table, + project_id=None, + chunksize=None, + reauth=False, + if_exists="fail", + auth_local_webserver=True, + table_schema=None, + location=None, + progress_bar=True, + credentials=None, + ): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Write a ``DataFrame`` to a Google BigQuery table. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functionsf + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.DataFrame.to_gbq, + destination_table, + project_id=project_id, + chunksize=chunksize, + reauth=reauth, + if_exists=if_exists, + auth_local_webserver=auth_local_webserver, + table_schema=table_schema, + location=location, + progress_bar=progress_bar, + credentials=credentials, + ) + + def to_orc(self, path=None, *, engine="pyarrow", index=None, engine_kwargs=None): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.DataFrame.to_orc, + path=path, + engine=engine, + index=index, + engine_kwargs=engine_kwargs, + ) + + def to_html( + self, + buf=None, + columns=None, + col_space=None, + header=True, + index=True, + na_rep="NaN", + formatters=None, + float_format=None, + sparsify=None, + index_names=True, + justify=None, + max_rows=None, + max_cols=None, + show_dimensions=False, + decimal=".", + bold_rows=True, + classes=None, + escape=True, + notebook=False, + border=None, + table_id=None, + render_links=False, + encoding=None, + ): # noqa: PR01, RT01, D200 + """ + Render a ``DataFrame`` as an HTML table. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.DataFrame.to_html, + buf=buf, + columns=columns, + col_space=col_space, + header=header, + index=index, + na_rep=na_rep, + formatters=formatters, + float_format=float_format, + sparsify=sparsify, + index_names=index_names, + justify=justify, + max_rows=max_rows, + max_cols=max_cols, + show_dimensions=show_dimensions, + decimal=decimal, + bold_rows=bold_rows, + classes=classes, + escape=escape, + notebook=notebook, + border=border, + table_id=table_id, + render_links=render_links, + encoding=None, + ) + + def to_parquet( + self, + path=None, + engine="auto", + compression="snappy", + index=None, + partition_cols=None, + storage_options: StorageOptions = None, + **kwargs, + ): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + from snowflake.snowpark.modin.pandas.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return FactoryDispatcher.to_parquet( + self._query_compiler, + path=path, + engine=engine, + compression=compression, + index=index, + partition_cols=partition_cols, + storage_options=storage_options, + **kwargs, + ) + + def to_period( + self, freq=None, axis=0, copy=True + ): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Convert ``DataFrame`` from ``DatetimeIndex`` to ``PeriodIndex``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return super().to_period(freq=freq, axis=axis, copy=copy) + + def to_records( + self, index=True, column_dtypes=None, index_dtypes=None + ): # noqa: PR01, RT01, D200 + """ + Convert ``DataFrame`` to a NumPy record array. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.DataFrame.to_records, + index=index, + column_dtypes=column_dtypes, + index_dtypes=index_dtypes, + ) + + def to_stata( + self, + path: FilePath | WriteBuffer[bytes], + convert_dates: dict[Hashable, str] | None = None, + write_index: bool = True, + byteorder: str | None = None, + time_stamp: datetime.datetime | None = None, + data_label: str | None = None, + variable_labels: dict[Hashable, str] | None = None, + version: int | None = 114, + convert_strl: Sequence[Hashable] | None = None, + compression: CompressionOptions = "infer", + storage_options: StorageOptions = None, + *, + value_labels: dict[Hashable, dict[float | int, str]] | None = None, + ): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.DataFrame.to_stata, + path, + convert_dates=convert_dates, + write_index=write_index, + byteorder=byteorder, + time_stamp=time_stamp, + data_label=data_label, + variable_labels=variable_labels, + version=version, + convert_strl=convert_strl, + compression=compression, + storage_options=storage_options, + value_labels=value_labels, + ) + + def to_xml( + self, + path_or_buffer=None, + index=True, + root_name="data", + row_name="row", + na_rep=None, + attr_cols=None, + elem_cols=None, + namespaces=None, + prefix=None, + encoding="utf-8", + xml_declaration=True, + pretty_print=True, + parser="lxml", + stylesheet=None, + compression="infer", + storage_options=None, + ): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self.__constructor__( + query_compiler=self._query_compiler.default_to_pandas( + pandas.DataFrame.to_xml, + path_or_buffer=path_or_buffer, + index=index, + root_name=root_name, + row_name=row_name, + na_rep=na_rep, + attr_cols=attr_cols, + elem_cols=elem_cols, + namespaces=namespaces, + prefix=prefix, + encoding=encoding, + xml_declaration=xml_declaration, + pretty_print=pretty_print, + parser=parser, + stylesheet=stylesheet, + compression=compression, + storage_options=storage_options, + ) + ) + + def to_dict( + self, + orient: Literal[ + "dict", "list", "series", "split", "tight", "records", "index" + ] = "dict", + into: type[dict] = dict, + ) -> dict | list[dict]: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._to_pandas().to_dict(orient=orient, into=into) + + def to_timestamp( + self, freq=None, how="start", axis=0, copy=True + ): # noqa: PR01, RT01, D200 + """ + Cast to DatetimeIndex of timestamps, at *beginning* of period. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return super().to_timestamp(freq=freq, how=how, axis=axis, copy=copy) + + def truediv( + self, other, axis="columns", level=None, fill_value=None + ): # noqa: PR01, RT01, D200 + """ + Get floating division of ``DataFrame`` and `other`, element-wise (binary operator `truediv`). + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op( + "truediv", + other, + axis=axis, + level=level, + fill_value=fill_value, + ) + + div = divide = truediv + + def update( + self, other, join="left", overwrite=True, filter_func=None, errors="ignore" + ): # noqa: PR01, RT01, D200 + """ + Modify in place using non-NA values from another ``DataFrame``. 
+ """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if not isinstance(other, DataFrame): + other = self.__constructor__(other) + query_compiler = self._query_compiler.df_update( + other._query_compiler, + join=join, + overwrite=overwrite, + filter_func=filter_func, + errors=errors, + ) + self._update_inplace(new_query_compiler=query_compiler) + + def diff( + self, + periods: int = 1, + axis: Axis = 0, + ): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super().diff( + periods=periods, + axis=axis, + ) + + def drop( + self, + labels: IndexLabel = None, + axis: Axis = 0, + index: IndexLabel = None, + columns: IndexLabel = None, + level: Level = None, + inplace: bool = False, + errors: IgnoreRaise = "raise", + ): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super().drop( + labels=labels, + axis=axis, + index=index, + columns=columns, + level=level, + inplace=inplace, + errors=errors, + ) + + def value_counts( + self, + subset: Sequence[Hashable] | None = None, + normalize: bool = False, + sort: bool = True, + ascending: bool = False, + dropna: bool = True, + ): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return Series( + query_compiler=self._query_compiler.value_counts( + subset=subset, + normalize=normalize, + sort=sort, + ascending=ascending, + dropna=dropna, + ), + name="proportion" if normalize else "count", + ) + + def mask( + self, + cond: DataFrame | Series | Callable | AnyArrayLike, + other: DataFrame | Series | Callable | Scalar | None = np.nan, + *, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, + ): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(other, Series) and axis is None: + raise ValueError( + "df.mask requires an axis parameter (0 or 1) when given a Series" + ) + + return super().mask( + cond, + other=other, + inplace=inplace, + axis=axis, + level=level, + ) + + def where( + self, + cond: DataFrame | Series | Callable | AnyArrayLike, + other: DataFrame | Series | Callable | Scalar | None = np.nan, + *, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, + ): + """ + Replace values where the condition is False. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(other, Series) and axis is None: + raise ValueError( + "df.where requires an axis parameter (0 or 1) when given a Series" + ) + + return super().where( + cond, + other=other, + inplace=inplace, + axis=axis, + level=level, + ) + + def xs(self, key, axis=0, level=None, drop_level=True): # noqa: PR01, RT01, D200 + """ + Return cross-section from the ``DataFrame``. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.DataFrame.xs, key, axis=axis, level=level, drop_level=drop_level + ) + + def set_axis( + self, + labels: IndexLabel, + *, + axis: Axis = 0, + copy: bool | NoDefault = no_default, # ignored + ): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if not is_scalar(axis): + raise TypeError(f"{type(axis).__name__} is not a valid type for axis.") + return super().set_axis( + labels=labels, + # 'columns', 'rows, 'index, 0, and 1 are the only valid axis values for df. + axis=pandas.DataFrame._get_axis_name(axis), + copy=copy, + ) + + def __getattr__(self, key): + """ + Return item identified by `key`. 
+ + Parameters + ---------- + key : hashable + Key to get. + + Returns + ------- + Any + + Notes + ----- + First try to use `__getattribute__` method. If it fails + try to get `key` from ``DataFrame`` fields. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + try: + return object.__getattribute__(self, key) + except AttributeError as err: + if key not in _ATTRS_NO_LOOKUP and key in self.columns: + return self[key] + raise err + + def __setattr__(self, key, value): + """ + Set attribute `value` identified by `key`. + + Parameters + ---------- + key : hashable + Key to set. + value : Any + Value to set. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # While we let users assign to a column labeled "x" with "df.x" , there + # are some attributes that we should assume are NOT column names and + # therefore should follow the default Python object assignment + # behavior. These are: + # - anything in self.__dict__. This includes any attributes that the + # user has added to the dataframe with, e.g., `df.c = 3`, and + # any attribute that Modin has added to the frame, e.g. + # `_query_compiler` and `_siblings` + # - `_query_compiler`, which Modin initializes before it appears in + # __dict__ + # - `_siblings`, which Modin initializes before it appears in __dict__ + # - `_cache`, which pandas.cache_readonly uses to cache properties + # before it appears in __dict__. + if key in ("_query_compiler", "_siblings", "_cache") or key in self.__dict__: + pass + elif key in self and key not in dir(self): + self.__setitem__(key, value) + # Note: return immediately so we don't keep this `key` as dataframe state. + # `__getattr__` will return the columns not present in `dir(self)`, so we do not need + # to manually track this state in the `dir`. + return + elif is_list_like(value) and key not in ["index", "columns"]: + WarningMessage.single_warning( + SET_DATAFRAME_ATTRIBUTE_WARNING + ) # pragma: no cover + object.__setattr__(self, key, value) + + def __setitem__(self, key: Any, value: Any): + """ + Set attribute `value` identified by `key`. + + Args: + key: Key to set + value: Value to set + + Note: + In the case where value is any list like or array, pandas checks the array length against the number of rows + of the input dataframe. If there is a mismatch, a ValueError is raised. Snowpark pandas indexing won't throw + a ValueError because knowing the length of the current dataframe can trigger eager evaluations; instead if + the array is longer than the number of rows we ignore the additional values. If the array is shorter, we use + enlargement filling with the last value in the array. + + Returns: + None + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + key = apply_if_callable(key, self) + if isinstance(key, DataFrame) or ( + isinstance(key, np.ndarray) and len(key.shape) == 2 + ): + # This case uses mask's codepath to perform the set, but + # we need to duplicate the code here since we are passing + # an additional kwarg `cond_fillna_with_true` to the QC here. + # We need this additional kwarg, since if df.shape + # and key.shape do not align (i.e. df has more rows), + # mask's codepath would mask the additional rows in df + # while for setitem, we need to keep the original values. 
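+            # As a hedged illustration of the intended user-visible behavior
+            # (values and shapes are hypothetical):
+            #     df = pd.DataFrame({"A": [1, 2, 3]})
+            #     key = pd.DataFrame({"A": [True, False]})  # fewer rows than df
+            #     df[key] = 0
+            # Only positions where the key is True are overwritten; rows the key
+            # does not cover keep their original values instead of being masked.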
+ if not isinstance(key, DataFrame): + if key.dtype != bool: + raise TypeError( + "Must pass DataFrame or 2-d ndarray with boolean values only" + ) + key = DataFrame(key) + key._query_compiler._shape_hint = "array" + + if value is not None: + value = apply_if_callable(value, self) + + if isinstance(value, np.ndarray): + value = DataFrame(value) + value._query_compiler._shape_hint = "array" + elif isinstance(value, pd.Series): + # pandas raises the `mask` ValueError here: Must specify axis = 0 or 1. We raise this + # error instead, since it is more descriptive. + raise ValueError( + "setitem with a 2D key does not support Series values." + ) + + if isinstance(value, BasePandasDataset): + value = value._query_compiler + + query_compiler = self._query_compiler.mask( + cond=key._query_compiler, + other=value, + axis=None, + level=None, + cond_fillna_with_true=True, + ) + + return self._create_or_update_from_compiler(query_compiler, inplace=True) + + # Error Checking: + if (isinstance(key, pd.Series) or is_list_like(key)) and ( + isinstance(value, range) + ): + raise NotImplementedError(DF_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE) + elif isinstance(value, slice): + # Here, the whole slice is assigned as a scalar variable, i.e., a spot at an index gets a slice value. + raise NotImplementedError(DF_SETITEM_SLICE_AS_SCALAR_VALUE) + + # Note: when key is a boolean indexer or slice the key is a row key; otherwise, the key is always a column + # key. + index, columns = slice(None), key + index_is_bool_indexer = False + if isinstance(key, slice): + if is_integer(key.start) and is_integer(key.stop): + # when slice are integer slice, e.g., df[1:2] = val, the behavior is the same as + # df.iloc[1:2, :] = val + self.iloc[key] = value + return + index, columns = key, slice(None) + elif isinstance(key, pd.Series): + if is_bool_dtype(key.dtype): + index, columns = key, slice(None) + index_is_bool_indexer = True + elif is_bool_indexer(key): + index, columns = pd.Series(key), slice(None) + index_is_bool_indexer = True + + # The reason we do not call loc directly is that setitem has different behavior compared to loc in this case + # we have to explicitly set matching_item_columns_by_label to False for setitem. + index = index._query_compiler if isinstance(index, BasePandasDataset) else index + columns = ( + columns._query_compiler + if isinstance(columns, BasePandasDataset) + else columns + ) + from .indexing import is_2d_array + + matching_item_rows_by_label = not is_2d_array(value) + if is_2d_array(value): + value = DataFrame(value) + item = value._query_compiler if isinstance(value, BasePandasDataset) else value + new_qc = self._query_compiler.set_2d_labels( + index, + columns, + item, + # setitem always matches item by position + matching_item_columns_by_label=False, + matching_item_rows_by_label=matching_item_rows_by_label, + index_is_bool_indexer=index_is_bool_indexer, + # setitem always deduplicates columns. E.g., if df has two columns "A" and "B", after calling + # df[["A","A"]] = item, df still only has two columns "A" and "B", and "A"'s values are set by the + # second "A" column from value; instead, if we call df.loc[:, ["A", "A"]] = item, then df will have + # three columns "A", "A", "B". Similarly, if we call df[["X","X"]] = item, df will have three columns + # "A", "B", "X", while if we call df.loc[:, ["X", "X"]] = item, then df will have four columns "A", "B", + # "X", "X". 
+ deduplicate_columns=True, + ) + return self._update_inplace(new_query_compiler=new_qc) + + def abs(self): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super().abs() + + def __and__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__and__", other, axis=1) + + def __rand__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__rand__", other, axis=1) + + def __or__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__or__", other, axis=1) + + def __ror__(self, other): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._binary_op("__ror__", other, axis=1) + + def __neg__(self): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super().__neg__() + + def __iter__(self): + """ + Iterate over info axis. + + Returns + ------- + iterable + Iterator of the columns names. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return iter(self.columns) + + def __contains__(self, key): + """ + Check if `key` in the ``DataFrame.columns``. + + Parameters + ---------- + key : hashable + Key to check the presence in the columns. + + Returns + ------- + bool + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self.columns.__contains__(key) + + def __round__(self, decimals=0): + """ + Round each value in a ``DataFrame`` to the given number of decimals. + + Parameters + ---------- + decimals : int, default: 0 + Number of decimal places to round to. + + Returns + ------- + DataFrame + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return super().round(decimals) + + def __delitem__(self, key): + """ + Delete item identified by `key` label. + + Parameters + ---------- + key : hashable + Key to delete. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() # pragma: no cover + + if key not in self: + raise KeyError(key) + self._update_inplace(new_query_compiler=self._query_compiler.delitem(key)) + + __add__ = add + __iadd__ = add # pragma: no cover + __radd__ = radd + __mul__ = mul + __imul__ = mul # pragma: no cover + __rmul__ = rmul + __pow__ = pow + __ipow__ = pow # pragma: no cover + __rpow__ = rpow + __sub__ = sub + __isub__ = sub # pragma: no cover + __rsub__ = rsub + __floordiv__ = floordiv + __ifloordiv__ = floordiv # pragma: no cover + __rfloordiv__ = rfloordiv + __truediv__ = truediv + __itruediv__ = truediv # pragma: no cover + __rtruediv__ = rtruediv + __mod__ = mod + __imod__ = mod # pragma: no cover + __rmod__ = rmod + __rdiv__ = rdiv + + def __dataframe__(self, nan_as_null: bool = False, allow_copy: bool = True): + """ + Get a Modin DataFrame that implements the dataframe exchange protocol. + + See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. + + Parameters + ---------- + nan_as_null : bool, default: False + A keyword intended for the consumer to tell the producer + to overwrite null values in the data with ``NaN`` (or ``NaT``). + This currently has no effect; once support for nullable extension + dtypes is added, this value should be propagated to columns. + allow_copy : bool, default: True + A keyword that defines whether or not the library is allowed + to make a copy of the data. 
For example, copying data would be necessary + if a library supports strided buffers, given that this protocol + specifies contiguous buffers. Currently, if the flag is set to ``False`` + and a copy is needed, a ``RuntimeError`` will be raised. + + Returns + ------- + ProtocolDataframe + A dataframe object following the dataframe protocol specification. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() # pragma: no cover + + return self._query_compiler.to_dataframe( + nan_as_null=nan_as_null, allow_copy=allow_copy + ) + + @property + def attrs(self): # noqa: RT01, D200 + """ + Return dictionary of global attributes of this dataset. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() # pragma: no cover + + def attrs(df): + return df.attrs + + return self._default_to_pandas(attrs) + + @property + def style(self): # noqa: RT01, D200 + """ + Return a Styler object. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() + + def style(df): + """Define __name__ attr because properties do not have it.""" + return df.style + + return self._default_to_pandas(style) + + def isin( + self, values: ListLike | Series | DataFrame | dict[Hashable, ListLike] + ) -> DataFrame: + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(values, dict): + return super().isin(values) + elif isinstance(values, Series): + # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. + # if not values.index.is_unique: + # raise ValueError("cannot compute isin with a duplicate axis.") + return self.__constructor__( + query_compiler=self._query_compiler.isin(values._query_compiler) + ) + elif isinstance(values, DataFrame): + # Note: pandas performs explicit is_unique check here, deactivated for performance reasons. + # if not (values.columns.is_unique and values.index.is_unique): + # raise ValueError("cannot compute isin with a duplicate axis.") + return self.__constructor__( + query_compiler=self._query_compiler.isin(values._query_compiler) + ) + else: + if not is_list_like(values): + # throw pandas compatible error + raise TypeError( + "only list-like or dict-like objects are allowed " + f"to be passed to {self.__class__.__name__}.isin(), " + f"you passed a '{type(values).__name__}'" + ) + return super().isin(values) + + def _create_or_update_from_compiler(self, new_query_compiler, inplace=False): + """ + Return or update a ``DataFrame`` with given `new_query_compiler`. + + Parameters + ---------- + new_query_compiler : PandasQueryCompiler + QueryCompiler to use to manage the data. + inplace : bool, default: False + Whether or not to perform update or creation inplace. + + Returns + ------- + DataFrame or None + None if update was done, ``DataFrame`` otherwise. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + assert ( + isinstance(new_query_compiler, type(self._query_compiler)) + or type(new_query_compiler) in self._query_compiler.__class__.__bases__ + ), f"Invalid Query Compiler object: {type(new_query_compiler)}" + if not inplace: + return self.__constructor__(query_compiler=new_query_compiler) + else: + self._update_inplace(new_query_compiler=new_query_compiler) + + def _get_numeric_data(self, axis: int): + """ + Grab only numeric data from ``DataFrame``. + + Parameters + ---------- + axis : {0, 1} + Axis to inspect on having numeric types only. 
+ + Returns + ------- + DataFrame + ``DataFrame`` with numeric data. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # pandas ignores `numeric_only` if `axis` is 1, but we do have to drop + # non-numeric columns if `axis` is 0. + if axis != 0: + return self + return self.drop( + columns=[ + i for i in self.dtypes.index if not is_numeric_dtype(self.dtypes[i]) + ] + ) + + def _validate_dtypes(self, numeric_only=False): + """ + Check that all the dtypes are the same. + + Parameters + ---------- + numeric_only : bool, default: False + Whether or not to allow only numeric data. + If True and non-numeric data is found, exception + will be raised. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + dtype = self.dtypes[0] + for t in self.dtypes: + if numeric_only and not is_numeric_dtype(t): + raise TypeError(f"{t} is not a numeric data type") + elif not numeric_only and t != dtype: + raise TypeError(f"Cannot compare type '{t}' with type '{dtype}'") + + def _validate_dtypes_sum_prod_mean(self, axis, numeric_only, ignore_axis=False): + """ + Validate data dtype for `sum`, `prod` and `mean` methods. + + Parameters + ---------- + axis : {0, 1} + Axis to validate over. + numeric_only : bool + Whether or not to allow only numeric data. + If True and non-numeric data is found, exception + will be raised. + ignore_axis : bool, default: False + Whether or not to ignore `axis` parameter. + + Returns + ------- + DataFrame + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + # We cannot add datetime types, so if we are summing a column with + # dtype datetime64 and cannot ignore non-numeric types, we must throw a + # TypeError. + if ( + not axis + and numeric_only is False + and any(dtype == np.dtype("datetime64[ns]") for dtype in self.dtypes) + ): + raise TypeError("Cannot add Timestamp Types") + + # If our DataFrame has both numeric and non-numeric dtypes then + # operations between these types do not make sense and we must raise a + # TypeError. The exception to this rule is when there are datetime and + # timedelta objects, in which case we proceed with the comparison + # without ignoring any non-numeric types. We must check explicitly if + # numeric_only is False because if it is None, it will default to True + # if the operation fails with mixed dtypes. + if ( + (axis or ignore_axis) + and numeric_only is False + and np.unique([is_numeric_dtype(dtype) for dtype in self.dtypes]).size == 2 + ): + # check if there are columns with dtypes datetime or timedelta + if all( + dtype != np.dtype("datetime64[ns]") + and dtype != np.dtype("timedelta64[ns]") + for dtype in self.dtypes + ): + raise TypeError("Cannot operate on Numeric and Non-Numeric Types") + + return self._get_numeric_data(axis) if numeric_only else self + + def _to_pandas( + self, + *, + statement_params: dict[str, str] | None = None, + **kwargs: Any, + ) -> pandas.DataFrame: + """ + Convert Snowpark pandas DataFrame to pandas DataFrame + + Args: + statement_params: Dictionary of statement level parameters to be set while executing this action. + + Returns: + pandas DataFrame + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._query_compiler.to_pandas( + statement_params=statement_params, **kwargs + ) + + def _validate_eval_query(self, expr, **kwargs): + """ + Validate the arguments of ``eval`` and ``query`` functions. + + Parameters + ---------- + expr : str + The expression to evaluate. 
This string cannot contain any + Python statements, only Python expressions. + **kwargs : dict + Optional arguments of ``eval`` and ``query`` functions. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + if isinstance(expr, str) and expr == "": + raise ValueError("expr cannot be an empty string") + + if isinstance(expr, str) and "not" in expr: + if "parser" in kwargs and kwargs["parser"] == "python": + ErrorMessage.not_implemented() # pragma: no cover + + def _reduce_dimension(self, query_compiler): + """ + Reduce the dimension of data from the `query_compiler`. + + Parameters + ---------- + query_compiler : BaseQueryCompiler + Query compiler to retrieve the data. + + Returns + ------- + Series + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return Series(query_compiler=query_compiler) + + def _set_axis_name(self, name, axis=0, inplace=False): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + axis = self._get_axis_number(axis) + renamed = self if inplace else self.copy() + if axis == 0: + renamed.index = renamed.index.set_names(name) + else: + renamed.columns = renamed.columns.set_names(name) + if not inplace: + return renamed + + def _to_datetime(self, **kwargs): + """ + Convert `self` to datetime. + + Parameters + ---------- + **kwargs : dict + Optional arguments to use during query compiler's + `to_datetime` invocation. + + Returns + ------- + Series of datetime64 dtype + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return self._reduce_dimension( + query_compiler=self._query_compiler.dataframe_to_datetime(**kwargs) + ) + + # Persistance support methods - BEGIN + @classmethod + def _inflate_light(cls, query_compiler): + """ + Re-creates the object from previously-serialized lightweight representation. + + The method is used for faster but not disk-storable persistence. + + Parameters + ---------- + query_compiler : BaseQueryCompiler + Query compiler to use for object re-creation. + + Returns + ------- + DataFrame + New ``DataFrame`` based on the `query_compiler`. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return cls(query_compiler=query_compiler) + + @classmethod + def _inflate_full(cls, pandas_df): + """ + Re-creates the object from previously-serialized disk-storable representation. + + Parameters + ---------- + pandas_df : pandas.DataFrame + Data to use for object re-creation. + + Returns + ------- + DataFrame + New ``DataFrame`` based on the `pandas_df`. + """ + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + return cls(data=from_pandas(pandas_df)) + + def __reduce__(self): + # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions + ErrorMessage.not_implemented() # pragma: no cover + + self._query_compiler.finalize() + # if PersistentPickle.get(): + # return self._inflate_full, (self._to_pandas(),) + return self._inflate_light, (self._query_compiler,) + + # Persistance support methods - END diff --git a/src/snowflake/snowpark/modin/pandas/general.py b/src/snowflake/snowpark/modin/pandas/general.py new file mode 100644 index 00000000000..d0ac54caeea --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/general.py @@ -0,0 +1,2387 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. 
+# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Implement pandas general API.""" +from __future__ import annotations + +from collections.abc import Hashable, Iterable, Mapping, Sequence +from datetime import date, datetime, tzinfo +from logging import getLogger +from typing import TYPE_CHECKING, Any, Literal, Union + +import numpy as np +import pandas +import pandas.core.common as common +from pandas import IntervalIndex, NaT, Timestamp +from pandas._libs import NaTType, lib +from pandas._libs.tslibs import to_offset +from pandas._typing import ( + AnyArrayLike, + Axis, + DateTimeErrorChoices, + IndexLabel, + IntervalClosedType, + Scalar, + Suffixes, +) +from pandas.core.arrays import datetimelike +from pandas.core.arrays.datetimes import ( + _infer_tz_from_endpoints, + _maybe_normalize_endpoints, +) +from pandas.core.dtypes.common import is_list_like +from pandas.core.dtypes.inference import is_array_like +from pandas.core.tools.datetimes import ( + ArrayConvertible, + DatetimeScalar, + DatetimeScalarOrArrayConvertible, + DictConvertible, +) +from pandas.util._validators import validate_inclusive + +# add this line to make doctests runnable +from snowflake.snowpark.modin import pandas as pd # noqa: F401 +from snowflake.snowpark.modin.pandas.base import BasePandasDataset +from snowflake.snowpark.modin.pandas.dataframe import DataFrame +from snowflake.snowpark.modin.pandas.series import Series +from snowflake.snowpark.modin.pandas.utils import ( + is_scalar, + raise_if_native_pandas_objects, +) +from snowflake.snowpark.modin.plugin._internal.telemetry import ( + snowpark_pandas_telemetry_standalone_function_decorator, +) +from snowflake.snowpark.modin.plugin._internal.timestamp_utils import ( + VALID_TO_DATETIME_UNIT, +) +from snowflake.snowpark.modin.plugin._typing import ListLike, ListLikeOfFloats +from snowflake.snowpark.modin.plugin.compiler import BaseQueryCompiler +from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( + SnowflakeQueryCompiler, +) +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage +from snowflake.snowpark.modin.utils import _inherit_docstrings, to_pandas + +if TYPE_CHECKING: + # To prevent cross-reference warnings when building documentation and prevent erroneously + # linking to `snowflake.snowpark.DataFrame`, we need to explicitly + # qualify return types in this file with `snowflake.snowpark.modin.pandas.DataFrame`. 
+ # SNOW-1233342: investigate how to fix these links without using absolute paths + import snowflake # pragma: no cover + +_logger = getLogger(__name__) + +VALID_DATE_TYPE = Union[ + np.integer, float, str, date, datetime, np.datetime64, pd.Timestamp +] + + +@snowpark_pandas_telemetry_standalone_function_decorator +@_inherit_docstrings(pandas.isna, apilink="pandas.isna") +def isna(obj): # noqa: PR01, RT01, D200 + """ + Detect missing values for an array-like object. + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + if isinstance(obj, BasePandasDataset): + return obj.isna() + else: + return pandas.isna(obj) + + +isnull = isna + + +@snowpark_pandas_telemetry_standalone_function_decorator +@_inherit_docstrings(pandas.notna, apilink="pandas.notna") +def notna(obj): # noqa: PR01, RT01, D200 + """ + Detect non-missing values for an array-like object. + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + if isinstance(obj, BasePandasDataset): + return obj.notna() + else: + return pandas.notna(obj) + + +notnull = notna + + +@snowpark_pandas_telemetry_standalone_function_decorator +def merge( + left: snowflake.snowpark.modin.pandas.DataFrame | Series, + right: snowflake.snowpark.modin.pandas.DataFrame | Series, + how: str | None = "inner", + on: IndexLabel | None = None, + left_on: None + | (Hashable | AnyArrayLike | Sequence[Hashable | AnyArrayLike]) = None, + right_on: None + | (Hashable | AnyArrayLike | Sequence[Hashable | AnyArrayLike]) = None, + left_index: bool | None = False, + right_index: bool | None = False, + sort: bool | None = False, + suffixes: Suffixes | None = ("_x", "_y"), + copy: bool | None = True, + indicator: bool | str | None = False, + validate: str | None = None, +): + """ + Merge DataFrame or named Series objects with a database-style join. + + A named Series object is treated as a DataFrame with a single named column. + + The join is done on columns or indexes. If joining columns on + columns, the DataFrame indexes *will be ignored*. Otherwise if joining indexes + on indexes or indexes on a column or columns, the index will be passed on. + When performing a cross merge, no column specifications to merge on are + allowed. + + .. warning:: + + If both key columns contain rows where the key is a null value, those + rows will be matched against each other. This is different from usual SQL + join behaviour and can lead to unexpected results. + + Parameters + ---------- + left : :class:`~snowflake.snowpark.modin.pandas.DataFrame` or named Series + right : :class:`~snowflake.snowpark.modin.pandas.DataFrame` or named Series + Object to merge with. + how : {'left', 'right', 'outer', 'inner', 'cross'}, default 'inner' + Type of merge to be performed. + + * left: use only keys from left frame, similar to a SQL left outer join; + preserve key order. + * right: use only keys from right frame, similar to a SQL right outer join; + preserve key order. + * outer: use union of keys from both frames, similar to a SQL full outer + join; sort keys lexicographically. + * inner: use intersection of keys from both frames, similar to a SQL inner + join; preserve the order of the left keys. + * cross: creates the cartesian product from both frames, preserves the order + of the left keys. + + on : label or list + Column or index level names to join on. These must be found in both + DataFrames. If `on` is None and not merging on indexes then this defaults + to the intersection of the columns in both DataFrames. 
+ left_on : label or list, or array-like + Column or index level names to join on in the left DataFrame. Can also + be an array or list of arrays of the length of the left DataFrame. + These arrays are treated as if they are columns. + right_on : label or list, or array-like + Column or index level names to join on in the right DataFrame. Can also + be an array or list of arrays of the length of the right DataFrame. + These arrays are treated as if they are columns. + left_index : bool, default False + Use the index from the left DataFrame as the join key(s). If it is a + MultiIndex, the number of keys in the other DataFrame (either the index + or a number of columns) must match the number of levels. + right_index : bool, default False + Use the index from the right DataFrame as the join key. Same caveats as + left_index. + sort : bool, default False + Sort the join keys lexicographically in the result DataFrame. If False, + the order of the join keys depends on the join type (how keyword). + suffixes : list-like, default is ("_x", "_y") + A length-2 sequence where each element is optionally a string + indicating the suffix to add to overlapping column names in + `left` and `right` respectively. Pass a value of `None` instead + of a string to indicate that the column name from `left` or + `right` should be left as-is, with no suffix. At least one of the + values must not be None. + copy : bool, default True + This argument is ignored in Snowpark pandas API. + indicator : bool or str, default False + If True, adds a column to the output DataFrame called "_merge" with + information on the source of each row. The column can be given a different + name by providing a string argument. The column will have a Categorical + type with the value of "left_only" for observations whose merge key only + appears in the left DataFrame, "right_only" for observations + whose merge key only appears in the right DataFrame, and "both" + if the observation's merge key is found in both DataFrames. + + validate : str, optional + This is not yet supported. + + Returns + ------- + :class:`~snowflake.snowpark.modin.pandas.DataFrame` + A DataFrame of the two merged objects. + + See Also + -------- + merge_ordered : Merge with optional filling/interpolation. + merge_asof : Merge on nearest keys. + DataFrame.join : Similar method using indices. + + Examples + -------- + >>> df1 = pd.DataFrame({'lkey': ['foo', 'bar', 'baz', 'foo'], + ... 'value': [1, 2, 3, 5]}) + >>> df2 = pd.DataFrame({'rkey': ['foo', 'bar', 'baz', 'foo'], + ... 'value': [5, 6, 7, 8]}) + >>> df1 + lkey value + 0 foo 1 + 1 bar 2 + 2 baz 3 + 3 foo 5 + >>> df2 + rkey value + 0 foo 5 + 1 bar 6 + 2 baz 7 + 3 foo 8 + + Merge df1 and df2 on the lkey and rkey columns. The value columns have + the default suffixes, _x and _y, appended. + + >>> df1.merge(df2, left_on='lkey', right_on='rkey') + lkey value_x rkey value_y + 0 foo 1 foo 5 + 1 foo 1 foo 8 + 2 bar 2 bar 6 + 3 baz 3 baz 7 + 4 foo 5 foo 5 + 5 foo 5 foo 8 + + Merge DataFrames df1 and df2 with specified left and right suffixes + appended to any overlapping columns. + + >>> df1.merge(df2, left_on='lkey', right_on='rkey', + ... 
suffixes=('_left', '_right')) + lkey value_left rkey value_right + 0 foo 1 foo 5 + 1 foo 1 foo 8 + 2 bar 2 bar 6 + 3 baz 3 baz 7 + 4 foo 5 foo 5 + 5 foo 5 foo 8 + + + >>> df1 = pd.DataFrame({'a': ['foo', 'bar'], 'b': [1, 2]}) + >>> df2 = pd.DataFrame({'a': ['foo', 'baz'], 'c': [3, 4]}) + >>> df1 + a b + 0 foo 1 + 1 bar 2 + >>> df2 + a c + 0 foo 3 + 1 baz 4 + + >>> df1.merge(df2, how='inner', on='a') + a b c + 0 foo 1 3 + + >>> df1.merge(df2, how='left', on='a') + a b c + 0 foo 1 3.0 + 1 bar 2 NaN + + >>> df1 = pd.DataFrame({'left': ['foo', 'bar']}) + >>> df2 = pd.DataFrame({'right': [7, 8]}) + >>> df1 + left + 0 foo + 1 bar + >>> df2 + right + 0 7 + 1 8 + + >>> df1.merge(df2, how='cross') + left right + 0 foo 7 + 1 foo 8 + 2 bar 7 + 3 bar 8 + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + # Raise error if 'left' or 'right' is native pandas object. + raise_if_native_pandas_objects(left) + raise_if_native_pandas_objects(right) + + if isinstance(left, Series): + if left.name is None: + raise ValueError("Cannot merge a Series without a name") + else: + left = left.to_frame() + + if not isinstance(left, DataFrame): + raise TypeError( + f"Can only merge Series or DataFrame objects, a {type(left)} was passed" + ) + + return left.merge( + right, + how=how, + on=on, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + sort=sort, + suffixes=suffixes, + copy=copy, + indicator=indicator, + validate=validate, + ) + + +@snowpark_pandas_telemetry_standalone_function_decorator +@_inherit_docstrings(pandas.merge_ordered, apilink="pandas.merge_ordered") +def merge_ordered( + left, + right, + on=None, + left_on=None, + right_on=None, + left_by=None, + right_by=None, + fill_method=None, + suffixes=("_x", "_y"), + how: str = "outer", +) -> DataFrame: # noqa: PR01, RT01, D200 + """ + Perform a merge for ordered data with optional filling/interpolation. + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + ErrorMessage.not_implemented() + + if not isinstance(left, DataFrame): + raise ValueError(f"can not merge DataFrame with instance of type {type(right)}") + if isinstance(right, DataFrame): + right = to_pandas(right) + return DataFrame( + pandas.merge_ordered( + to_pandas(left), + right, + on=on, + left_on=left_on, + right_on=right_on, + left_by=left_by, + right_by=right_by, + fill_method=fill_method, + suffixes=suffixes, + how=how, + ) + ) + + +@snowpark_pandas_telemetry_standalone_function_decorator +@_inherit_docstrings(pandas.merge_asof, apilink="pandas.merge_asof") +def merge_asof( + left, + right, + on=None, + left_on=None, + right_on=None, + left_index: bool = False, + right_index: bool = False, + by=None, + left_by=None, + right_by=None, + suffixes=("_x", "_y"), + tolerance=None, + allow_exact_matches: bool = True, + direction: str = "backward", +) -> DataFrame: # noqa: PR01, RT01, D200 + """ + Perform a merge by key distance. 
+ """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + ErrorMessage.not_implemented() + + if not isinstance(left, DataFrame): + raise ValueError(f"can not merge DataFrame with instance of type {type(right)}") + + # As of pandas 1.2 these should raise an error; before that it did + # something likely random: + if ( + (on and (left_index or right_index)) + or (left_on and left_index) + or (right_on and right_index) + ): + raise ValueError("Can't combine left/right_index with left/right_on or on.") + + if on is not None: + if left_on is not None or right_on is not None: + raise ValueError("If 'on' is set, 'left_on' and 'right_on' can't be set.") + left_on = on + right_on = on + + if by is not None: + if left_by is not None or right_by is not None: + raise ValueError("Can't have both 'by' and 'left_by' or 'right_by'") + left_by = right_by = by + + if left_on is None and not left_index: + raise ValueError("Must pass on, left_on, or left_index=True") + + if right_on is None and not right_index: + raise ValueError("Must pass on, right_on, or right_index=True") + + return DataFrame( + query_compiler=left._query_compiler.merge_asof( + right._query_compiler, + left_on, + right_on, + left_index, + right_index, + left_by, + right_by, + suffixes, + tolerance, + allow_exact_matches, + direction, + ) + ) + + +@snowpark_pandas_telemetry_standalone_function_decorator +def pivot_table( + data, + values=None, + index=None, + columns=None, + aggfunc="mean", + fill_value=None, + margins=False, + dropna=True, + margins_name="All", + observed=False, + sort=True, +): + """ + Create a spreadsheet-style pivot table as a ``DataFrame``. + + The levels in the pivot table will be stored in MultiIndex objects + (hierarchical indexes) on the index and columns of the result DataFrame. + + Parameters + ---------- + values : list-like or scalar, optional + Column or columns to aggregate. + index : column, Grouper, array, or list of the previous + Keys to group by on the pivot table index. If a list is passed, + it can contain any of the other types (except list). If an array is + passed, it must be the same length as the data and will be used in + the same manner as column values. + columns : column, Grouper, array, or list of the previous + Keys to group by on the pivot table column. If a list is passed, + it can contain any of the other types (except list). If an array is + passed, it must be the same length as the data and will be used in + the same manner as column values. + aggfunc : function, list of functions, dict in string, default "mean". + If a list of functions is passed, the resulting pivot table will have + hierarchical columns whose top level are the function names + (inferred from the function objects themselves). + If a dict is passed, the key is column to aggregate and the value is + function or list of functions. If ``margin=True``, aggfunc will be + used to calculate the partial aggregates. + fill_value : scalar, default None + Value to replace missing values with (in the resulting pivot table, + after aggregation). + margins : bool, default False + If ``margins=True``, special ``All`` columns and rows + will be added with partial group aggregates across the categories + on the rows and columns. + dropna : bool, default True + Do not include columns whose entries are all NaN. If True, + rows with a NaN value in any column will be omitted before + computing margins. 
+ margins_name : str, default 'All' + Name of the row / column that will contain the totals + when margins is True. + observed : bool, default False + This only applies if any of the groupers are Categoricals. + Categoricals are not yet implemented in Snowpark pandas. + If True: only show observed values for categorical groupers. + If False: show all values for categorical groupers. + + sort : bool, default True + Specifies if the result should be sorted. + + Returns + ------- + Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.DataFrame` + An Excel style pivot table. + + Notes + ----- + Raise NotImplementedError if + + * margins, observed, or sort is given; + * or index, columns, or values is not str; + * or DataFrame contains MultiIndex; + * or any argfunc is not "count", "mean", "min", "max", or "sum" + + See Also + -------- + DataFrame.pivot : Pivot without aggregation that can handle + non-numeric data. + DataFrame.melt: Unpivot a DataFrame from wide to long format, + optionally leaving identifiers set. + wide_to_long : Wide panel to long format. Less flexible but more + user-friendly than melt. + + Examples + -------- + >>> df = pd.DataFrame({"A": ["foo", "foo", "foo", "foo", "foo", + ... "bar", "bar", "bar", "bar"], + ... "B": ["one", "one", "one", "two", "two", + ... "one", "one", "two", "two"], + ... "C": ["small", "large", "large", "small", + ... "small", "large", "small", "small", + ... "large"], + ... "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + ... "E": [2, 4, 5, 5, 6, 6, 8, 9, 9]}) + >>> df + A B C D E + 0 foo one small 1 2 + 1 foo one large 2 4 + 2 foo one large 2 5 + 3 foo two small 3 5 + 4 foo two small 3 6 + 5 bar one large 4 6 + 6 bar one small 5 8 + 7 bar two small 6 9 + 8 bar two large 7 9 + + This first example aggregates values by taking the sum. + + >>> table = pd.pivot_table(df, values='D', index=['A', 'B'], + ... columns=['C'], aggfunc="sum") + >>> table # doctest: +NORMALIZE_WHITESPACE + C large small + A B + bar one 4.0 5 + two 7.0 6 + foo one 4.0 1 + two NaN 6 + + We can also fill missing values using the `fill_value` parameter. + + >>> table = pd.pivot_table(df, values='D', index=['A', 'B'], + ... columns=['C'], aggfunc="sum", fill_value=0) + >>> table # doctest: +NORMALIZE_WHITESPACE + C large small + A B + bar one 4.0 5 + two 7.0 6 + foo one 4.0 1 + two NaN 6 + + >>> table = pd.pivot_table(df, values=['D', 'E'], index=['A', 'C'], + ... aggfunc={'D': "mean", 'E': "mean"}) + >>> table # doctest: +NORMALIZE_WHITESPACE + D E + D E + A C + bar large 5.500000 7.500000 + small 5.500000 8.500000 + foo large 2.000000 4.500000 + small 2.333333 4.333333 + + >>> table = pd.pivot_table(df, values=['D', 'E'], index=['A', 'C'], + ... aggfunc={'D': "mean", + ... 
'E': ["min", "max", "mean"]}) + >>> table # doctest: +NORMALIZE_WHITESPACE + D E + mean max mean min + D E E E + A C + bar large 5.500000 9 7.500000 6 + small 5.500000 9 8.500000 8 + foo large 2.000000 5 4.500000 4 + small 2.333333 6 4.333333 2 + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + if not isinstance(data, DataFrame): + raise ValueError( + f"can not create pivot table with instance of type {type(data)}" + ) + + return data.pivot_table( + values=values, + index=index, + columns=columns, + aggfunc=aggfunc, + fill_value=fill_value, + margins=margins, + dropna=dropna, + margins_name=margins_name, + sort=sort, + ) + + +@snowpark_pandas_telemetry_standalone_function_decorator +@_inherit_docstrings(pandas.pivot, apilink="pandas.pivot") +def pivot(data, index=None, columns=None, values=None): # noqa: PR01, RT01, D200 + """ + Return reshaped DataFrame organized by given index / column values. + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + ErrorMessage.not_implemented() + + if not isinstance(data, DataFrame): + raise ValueError(f"can not pivot with instance of type {type(data)}") + return data.pivot(index=index, columns=columns, values=values) + + +@snowpark_pandas_telemetry_standalone_function_decorator +def to_numeric( + arg: Scalar | Series | ArrayConvertible, + errors: Literal["ignore", "raise", "coerce"] = "raise", + downcast: Literal["integer", "signed", "unsigned", "float"] | None = None, +) -> Series | Scalar | None: + # TODO: SNOW-1336091: Snowpark pandas cannot run in sprocs until modin 0.28.1 is available in conda + """ + Convert argument to a numeric type. + + If the input arg type is already a numeric type, the return dtype + will be the original type; otherwise, the return dtype is float. + + Parameters + ---------- + arg : scalar, list, tuple, 1-d array, or Series + Argument to be converted. + errors : {'ignore', 'raise', 'coerce'}, default 'raise' + - If 'raise', then invalid parsing will raise an exception. + - If 'coerce', then invalid parsing will be set as NaN. + - If 'ignore', then invalid parsing will return the input. + downcast : str, default None + downcast is ignored in Snowflake backend. + + Returns + ------- + ret + Numeric if parsing succeeded. + Return type depends on input. Series if `arg` is not scalar. + + See Also + -------- + DataFrame.astype : Cast argument to a specified dtype. + to_datetime : Convert argument to datetime. + to_timedelta : Convert argument to timedelta. + numpy.ndarray.astype : Cast a numpy array to a specified type. + DataFrame.convert_dtypes : Convert dtypes. 
+ + Examples + -------- + Take separate series and convert to numeric, coercing when told to + + >>> s = pd.Series(['1.0', '2', -3]) + >>> pd.to_numeric(s) + 0 1.0 + 1 2.0 + 2 -3.0 + dtype: float64 + + Note: to_numeric always converts non-numeric values to floats + >>> s = pd.Series(['1', '2', '-3']) + >>> pd.to_numeric(s) + 0 1.0 + 1 2.0 + 2 -3.0 + dtype: float64 + >>> pd.to_numeric(s, downcast='float') # downcast is ignored + 0 1.0 + 1 2.0 + 2 -3.0 + dtype: float64 + >>> pd.to_numeric(s, downcast='signed') # downcast is ignored + 0 1.0 + 1 2.0 + 2 -3.0 + dtype: float64 + >>> s = pd.Series(['apple', '1.0', '2', -3]) + >>> pd.to_numeric(s, errors='ignore') # doctest: +SKIP + 0 apple + 1 1.0 + 2 2 + 3 -3 + dtype: object + >>> pd.to_numeric(s, errors='coerce') + 0 NaN + 1 1.0 + 2 2.0 + 3 -3.0 + dtype: float64 + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + raise_if_native_pandas_objects(arg) + if errors not in ("ignore", "raise", "coerce"): + raise ValueError("invalid error value specified") + if downcast is not None: + WarningMessage.ignored_argument( + operation="to_numeric", + argument="downcast", + message="downcast is ignored in Snowflake backend", + ) + # convert arg to series + arg_is_scalar = is_scalar(arg) + + if ( + not arg_is_scalar + and not isinstance(arg, (list, tuple, Series)) + and not (is_array_like(arg) and arg.ndim == 1) + ): + raise TypeError("arg must be a list, tuple, 1-d array, or Series") + + if arg_is_scalar: + arg = Series([arg]) + + if not isinstance(arg, Series): + name = None + # keep index name + if isinstance(arg, pandas.Index): + name = arg.name + arg = Series(arg, name=name) + + ret = arg._to_numeric(errors=errors) + if arg_is_scalar: + # use squeeze to turn the series result into a scalar + ret = ret.squeeze() + return ret + + +@snowpark_pandas_telemetry_standalone_function_decorator +def unique(values) -> np.ndarray: + """ + Return unique values based on a hash table. Unique values are + returned in the order of appearance. This does NOT sort. + + Parameters + ---------- + values : ndarray (1-d), list, bytearray, tuple, Series, Index, list-like + Non-hashable objects like set, dict, and user defined classes are + invalid input. + Values to perform computation. + + Returns + ------- + ndarray + The unique values returned as a NumPy array. See Notes. + + See Also + -------- + Series.unique() + + Notes + ----- + Returns the unique values as a NumPy array. This includes + + * Datetime with Timezone + * IntegerNA + + See Examples section. + + Examples + -------- + >>> pd.unique([2, 1, 3, 3]) + array([2, 1, 3]) + + >>> pd.unique([pd.Timestamp('2016-01-01', tz='US/Eastern') + ... for _ in range(3)]) + array([Timestamp('2015-12-31 21:00:00-0800', tz='America/Los_Angeles')], + dtype=object) + + >>> pd.unique([("a", "b"), ("b", "a"), ("a", "c"), ("b", "a")]) + array([list(['a', 'b']), list(['b', 'a']), list(['a', 'c'])], dtype=object) + + >>> pd.unique([None, np.nan, 2]) + array([nan, 2.]) + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + if is_list_like(values) and not isinstance(values, dict): + return Series(values).unique() + else: + raise TypeError("Only list-like objects can be used with unique()") + + +# Adding docstring since pandas docs don't have web section for this function. 
+@snowpark_pandas_telemetry_standalone_function_decorator +def value_counts( + values, sort=True, ascending=False, normalize=False, bins=None, dropna=True +): + """ + Compute a histogram of the counts of non-null values. + + Parameters + ---------- + values : ndarray (1-d) + Values to perform computation. + sort : bool, default: True + Sort by values. + ascending : bool, default: False + Sort in ascending order. + normalize : bool, default: False + If True then compute a relative histogram. + bins : integer, optional + Rather than count values, group them into half-open bins, + convenience for pd.cut, only works with numeric data. + dropna : bool, default: True + Don't include counts of NaN. + + Returns + ------- + Series + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + ErrorMessage.not_implemented() + + return Series(values).value_counts( + sort=sort, + ascending=ascending, + normalize=normalize, + bins=bins, + dropna=dropna, + ) + + +@snowpark_pandas_telemetry_standalone_function_decorator +def concat( + objs: ( + Iterable[snowflake.snowpark.modin.pandas.DataFrame | Series] + | Mapping[Hashable, snowflake.snowpark.modin.pandas.DataFrame | Series] + ), + axis: Axis = 0, + join: str = "outer", + ignore_index: bool = False, + keys: Sequence[Hashable] = None, + levels: list[Sequence[Hashable]] = None, + names: list[Hashable] = None, + verify_integrity: bool = False, + sort: bool = False, + copy: bool = True, +) -> snowflake.snowpark.modin.pandas.DataFrame | Series: + """ + Concatenate pandas objects along a particular axis. + + Allows optional set logic along the other axes. + + Can also add a layer of hierarchical indexing on the concatenation axis, + which may be useful if the labels are the same (or overlapping) on + the passed axis number. + + Parameters + ---------- + objs : a sequence or mapping of Series or DataFrame objects + If a mapping is passed, the sorted keys will be used as the `keys` + argument, unless it is passed, in which case the values will be + selected (see below). Any None objects will be dropped silently unless + they are all None in which case a ValueError will be raised. + axis : {0/'index', 1/'columns'}, default 0 + The axis to concatenate along. + join : {'inner', 'outer'}, default 'outer' + How to handle indexes on other axis (or axes). + ignore_index : bool, default False + If True, do not use the index values along the concatenation axis. The + resulting axis will be labeled 0, ..., n - 1. This is useful if you are + concatenating objects where the concatenation axis does not have + meaningful indexing information. Note the index values on the other + axes are still respected in the join. + keys : sequence, default None + If multiple levels passed, should contain tuples. Construct + hierarchical index using the passed keys as the outermost level. + levels : list of sequences, default None + Specific levels (unique values) to use for constructing a + MultiIndex. Otherwise they will be inferred from the keys. + Snowpark pandas does not support 'levels' argument. + names : list, default None + Names for the levels in the resulting hierarchical index. + verify_integrity : bool, default False + Check whether the new concatenated axis contains duplicates. + Snowpark pandas does not support distributed computation of concat when + 'verify_integrity' is True. + sort : bool, default False + Sort non-concatenation axis if it is not already aligned. + copy : bool, default True + If False, do not copy data unnecessarily. 
+ This argument is ignored in Snowpark pandas. + + Returns + ------- + object, type of objs + When concatenating all Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.Series` along the index (axis=0), + a Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.Series` is returned. When ``objs`` contains at least + one Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.DataFrame`, + a Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.DataFrame` is returned. When concatenating along + the columns (axis=1), a Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.DataFrame` is returned. + + See Also + -------- + DataFrame.join : Join DataFrames using indexes. + DataFrame.merge : Merge DataFrames by indexes or columns. + + Notes + ----- + The keys, levels, and names arguments are all optional. + + It is not recommended to build DataFrames by adding single rows in a + for loop. Build a list of rows and make a DataFrame in a single concat. + + Examples + -------- + Combine two ``Series``. + + >>> s1 = pd.Series(['a', 'b']) + >>> s2 = pd.Series(['c', 'd']) + >>> pd.concat([s1, s2]) + 0 a + 1 b + 0 c + 1 d + dtype: object + + Clear the existing index and reset it in the result + by setting the ``ignore_index`` option to ``True``. + + >>> pd.concat([s1, s2], ignore_index=True) + 0 a + 1 b + 2 c + 3 d + dtype: object + + Add a hierarchical index at the outermost level of + the data with the ``keys`` option. + + >>> pd.concat([s1, s2], keys=['s1', 's2']) + s1 0 a + 1 b + s2 0 c + 1 d + dtype: object + + Label the index keys you create with the ``names`` option. + + >>> pd.concat([s1, s2], keys=['s1', 's2'], + ... names=['Series name', 'Row ID']) + Series name Row ID + s1 0 a + 1 b + s2 0 c + 1 d + dtype: object + + Combine two ``DataFrame`` objects with identical columns. + + >>> df1 = pd.DataFrame([['a', 1], ['b', 2]], + ... columns=['letter', 'number']) + >>> df1 + letter number + 0 a 1 + 1 b 2 + >>> df2 = pd.DataFrame([['c', 3], ['d', 4]], + ... columns=['letter', 'number']) + >>> df2 + letter number + 0 c 3 + 1 d 4 + >>> pd.concat([df1, df2]) + letter number + 0 a 1 + 1 b 2 + 0 c 3 + 1 d 4 + + Combine ``DataFrame`` objects with overlapping columns + and return everything. Columns outside the intersection will + be filled with ``NaN`` values. + + >>> df3 = pd.DataFrame([['c', 3, 'cat'], ['d', 4, 'dog']], + ... columns=['letter', 'number', 'animal']) + >>> df3 + letter number animal + 0 c 3 cat + 1 d 4 dog + >>> pd.concat([df1, df3], sort=False) + letter number animal + 0 a 1 None + 1 b 2 None + 0 c 3 cat + 1 d 4 dog + + Combine ``DataFrame`` objects with overlapping columns + and return only those that are shared by passing ``inner`` to + the ``join`` keyword argument. + + >>> pd.concat([df1, df3], join="inner") + letter number + 0 a 1 + 1 b 2 + 0 c 3 + 1 d 4 + + Combine ``DataFrame`` objects horizontally along the x axis by + passing in ``axis=1``. + + >>> df4 = pd.DataFrame([['bird', 'polly'], ['monkey', 'george']], + ... columns=['animal', 'name']) + >>> pd.concat([df1, df4], axis=1) + letter number animal name + 0 a 1 bird polly + 1 b 2 monkey george + + Combining series horizontally creates a DataFrame. Missing names are replaced with + numeric values. + + >>> pd.concat([s1, s2], axis=1) + 0 1 + 0 a c + 1 b d + + When combining objects horizoantally ``ignore_index=True`` will clear the existing + column names and reset it in the result. 
+ + >>> pd.concat([df1, df4], axis=1, ignore_index=True) + 0 1 2 3 + 0 a 1 bird polly + 1 b 2 monkey george + + When combining objects horizontally, add a hierarchical column index at the + outermost level of the column labels with the ``keys`` option. + + >>> pd.concat([df1, df4], axis=1, keys=['x', 'y']) # doctest: +NORMALIZE_WHITESPACE + x y + letter number animal name + 0 a 1 bird polly + 1 b 2 monkey george + + Concatenating series horizontally with ``keys``. + + >>> pd.concat([s1, s2], axis=1, keys=['x', 'y']) + x y + 0 1 + 0 a c + 1 b d + + When combining objects horizontally, pass ``join='inner'`` to keep only overlapping + index values. + + >>> df5 = pd.DataFrame([['a', 1], ['b', 2]], + ... columns=['letter', 'number'], + ... index=[1, 2]) + >>> df5 + letter number + 1 a 1 + 2 b 2 + >>> pd.concat([df1, df5], axis=1, join='inner') + letter number letter number + 1 b 2 a 1 + + Prevent the result from including duplicate index values with the + ``verify_integrity`` option. + + >>> df5 = pd.DataFrame([1], index=['a']) + >>> df5 + 0 + a 1 + >>> df6 = pd.DataFrame([2], index=['a']) + >>> df6 + 0 + a 2 + >>> pd.concat([df5, df6], verify_integrity=True) + Traceback (most recent call last): + ... + ValueError: Indexes have overlapping values: Index(['a'], dtype='object') + + Append a single row to the end of a ``DataFrame`` object. + + >>> df7 = pd.DataFrame({'a': 1, 'b': 2}, index=[0]) + >>> df7 + a b + 0 1 2 + >>> new_row = pd.DataFrame({'a': 3, 'b': 4}, index=[0]) + >>> new_row + a b + 0 3 4 + >>> pd.concat([df7, new_row], ignore_index=True) + a b + 0 1 2 + 1 3 4 + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + # Raise error if native pandas objects are passed. + raise_if_native_pandas_objects(objs) + + # In native pandas the 'concat' API is expected to work with all types of iterables like + # tuples, list, generators, custom iterators, deque etc. + # A few exceptions are 'DataFrame', 'Series', 'str': these are also technically + # iterables, but they are not iterables of pandas objects. + # Note other iterables can also have non pandas objects as elements in them, but it's + # not possible to know that in advance without iterating over all objects, so we + # also check each individual element later. + + # Raise error if 'objs' is not an iterable, or is itself a DataFrame, Series or str. + if not isinstance(objs, Iterable) or isinstance( + objs, (pd.DataFrame, pd.Series, str) + ): + # Same error as native pandas. + raise TypeError( + "first argument must be an iterable of pandas " + f'objects, you passed an object of type "{type(objs).__name__}"' + ) + + if isinstance(objs, dict): + if keys is None: + keys = list(objs.keys()) + # if 'keys' is not None, filter out additional objects from mapping. + objs = [objs[k] for k in keys] + else: + # Native pandas also supports generators as input, which can be iterated + # over only once, so first create a list from 'objs'. + objs = list(objs) + + for obj in objs: + # Raise error if native pandas objects are passed. + raise_if_native_pandas_objects(obj) + + if join not in ("inner", "outer"): + # Same error as native pandas. + raise ValueError( + "Only can inner (intersect) or outer (union) join the other axis" + ) + + axis = pandas.DataFrame()._get_axis_number(axis) + + if len(objs) == 0: + # Same error as native pandas.
+ raise ValueError("No objects to concatenate") + + # Filter out None objects + if keys is None: + objs = [o for o in objs if o is not None] + else: + tuples = [(k, v) for k, v in zip(keys, objs) if v is not None] + # convert list of tuples to tuples of list. + keys, objs = list(map(list, zip(*tuples))) if tuples else ([], []) + + if len(objs) == 0: + # Same error as native pandas. + raise ValueError("All objects passed were None") + + for obj in objs: + # Same error as native pandas. + if not isinstance(obj, (Series, DataFrame)): + raise TypeError( + f"cannot concatenate object of type '{type(obj)}'; " + "only Series and DataFrame objs are valid" + ) + + # Assign names to unnamed series + series_name = 0 + for i, obj in enumerate(objs): + if isinstance(obj, pd.Series) and obj.name is None: + objs[i] = obj.rename(series_name) + series_name = series_name + 1 + + # Check if all objects are of Series types. + all_series = all([isinstance(obj, pd.Series) for obj in objs]) + # When concatenating Series objects on axis 0, pandas tries to preserve name from + # input if all have same name otherwise set it to None. + if all_series and axis == 0: + unique_names = {obj.name for obj in objs} + name = objs[0].name if len(unique_names) == 1 else None + objs = [obj.rename(name) for obj in objs] + + if not copy: + WarningMessage.ignored_argument( + operation="concat", + argument="copy", + message="copy parameter has been ignored with Snowflake execution engine", + ) + + result = objs[0]._query_compiler.concat( + axis, + [o._query_compiler for o in objs[1:]], + join=join, + ignore_index=ignore_index, + keys=keys, + levels=levels, + names=names, + verify_integrity=verify_integrity, + sort=sort, + ) + # If all objects are series and concat axis=0, return Series else return DataFrame. + if all_series and axis == 0: + return Series(query_compiler=result) + return DataFrame(query_compiler=result) + + +@snowpark_pandas_telemetry_standalone_function_decorator +def to_datetime( + arg: DatetimeScalarOrArrayConvertible + | DictConvertible + | snowflake.snowpark.modin.pandas.DataFrame + | Series, + errors: DateTimeErrorChoices = "raise", + dayfirst: bool = False, + yearfirst: bool = False, + utc: bool = False, + format: str | None = None, + exact: bool | lib.NoDefault = lib.no_default, + unit: str | None = None, + infer_datetime_format: lib.NoDefault | bool = lib.no_default, + origin: Any = "unix", + cache: bool = True, +) -> Series | DatetimeScalar | NaTType | None: + # TODO: SNOW-1336091: Snowpark pandas cannot run in sprocs until modin 0.28.1 is available in conda + """ + Convert argument to datetime. + + This function converts a scalar, array-like, :class:`~snowflake.snowpark.modin.pandas.Series` or + :class:`~snowflake.snowpark.modin.pandas.DataFrame`/dict-like to a pandas datetime object. + + Parameters + ---------- + arg : int, float, str, datetime, list, tuple, 1-d array, Series, :class:`~snowflake.snowpark.modin.pandas.DataFrame`/dict-like + The object to convert to a datetime. If a :class:`~snowflake.snowpark.modin.pandas.DataFrame` is provided, the + method expects minimally the following columns: :const:`"year"`, + :const:`"month"`, :const:`"day"`. + errors : {'ignore', 'raise', 'coerce'}, default 'raise' + - If :const:`'raise'`, then invalid parsing will raise an exception. + - If :const:`'coerce'`, then invalid parsing will be set as :const:`NaT`. + - If :const:`'ignore'`, then invalid parsing will return the input. 
+ dayfirst : bool, default False + Specify a date parse order if `arg` is str or is list-like. + If :const:`True`, parses dates with the day first, e.g. :const:`"10/11/12"` + is parsed as :const:`2012-11-10`. + + .. warning:: + + ``dayfirst=True`` is not strict, but will prefer to parse + with day first. If a delimited date string cannot be parsed in + accordance with the given `dayfirst` option, e.g. + ``to_datetime(['31-12-2021'])``, then a warning will be shown. + + yearfirst : bool, default False + Specify a date parse order if `arg` is str or is list-like. + + - If :const:`True` parses dates with the year first, e.g. + :const:`"10/11/12"` is parsed as :const:`2010-11-12`. + - If both `dayfirst` and `yearfirst` are :const:`True`, `yearfirst` is + preceded (same as :mod:`dateutil`). + + .. warning:: + + ``yearfirst=True`` is not strict, but will prefer to parse + with year first. + + utc : bool, default None + Control timezone-related parsing, localization and conversion. + + - If :const:`True`, the function *always* returns a timezone-aware + UTC-localized :class:`Timestamp`, :class:`~snowflake.snowpark.modin.pandas.Series` or + :class:`DatetimeIndex`. To do this, timezone-naive inputs are + *localized* as UTC, while timezone-aware inputs are *converted* to UTC. + + - If :const:`False` (default), inputs will not be coerced to UTC. + Timezone-naive inputs will remain naive, while timezone-aware ones + will keep their time offsets. Limitations exist for mixed + offsets (typically, daylight savings), see :ref:`Examples + ` section for details. + + See also: pandas general documentation about `timezone conversion and + localization + `_. + + format : str, default None + The strftime to parse time, e.g. :const:`"%d/%m/%Y"`. Note that + :const:`"%f"` will parse all the way up to nanoseconds. See + `strftime documentation + `_ for more information on choices. + exact : bool, default True + Control how `format` is used: + + - If :const:`True`, require an exact `format` match. + - If :const:`False`, allow the `format` to match anywhere in the target + string. + + unit : str, default 'ns' + The unit of the arg (D,s,ms,us,ns) denote the unit, which is an + integer or float number. This will be based off the origin. + Example, with ``unit='ms'`` and ``origin='unix'``, this would calculate + the number of milliseconds to the unix epoch start. + infer_datetime_format : bool, default False + If :const:`True` and no `format` is given, attempt to infer the format + of the datetime strings based on the first non-NaN element, + and if it can be inferred, switch to a faster method of parsing them. + In some cases this can increase the parsing speed by ~5-10x. + origin : scalar, default 'unix' + Define the reference date. The numeric values would be parsed as number + of units (defined by `unit`) since this reference date. + + - If :const:`'unix'` (or POSIX) time; origin is set to 1970-01-01. + - If :const:`'julian'`, unit must be :const:`'D'`, and origin is set to + beginning of Julian Calendar. Julian day number :const:`0` is assigned + to the day starting at noon on January 1, 4713 BC. + - If Timestamp convertible, origin is set to Timestamp identified by + origin. + cache : bool, default True + cache parameter is ignored with Snowflake backend, i.e., no caching will be + applied + + Returns + ------- + datetime + If parsing succeeded. 
+ Return type depends on input (types in parenthesis correspond to + fallback in case of unsuccessful timezone or out-of-range timestamp + parsing): + + - scalar: :class:`Timestamp` (or :class:`datetime.datetime`) + - array-like: :class:`~snowflake.snowpark.modin.pandas.Series` with :class:`datetime64` dtype containing + :class:`datetime.datetime` (or + :class: :class:`~snowflake.snowpark.modin.pandas.Series` of :class:`object` dtype containing + :class:`datetime.datetime`) + - Series: :class:`~snowflake.snowpark.modin.pandas.Series` of :class:`datetime64` dtype (or + :class: :class:`~snowflake.snowpark.modin.pandas.Series` of :class:`object` dtype containing + :class:`datetime.datetime`) + - DataFrame: :class:`~snowflake.snowpark.modin.pandas.Series` of :class:`datetime64` dtype (or + :class:`~snowflake.snowpark.modin.pandas.Series` of :class:`object` dtype containing + :class:`datetime.datetime`) + + Raises + ------ + ParserError + When parsing a date from string fails. + ValueError + When another datetime conversion error happens. For example when one + of 'year', 'month', day' columns is missing in a :class:`~snowflake.snowpark.modin.pandas.DataFrame`, or + when a Timezone-aware :class:`datetime.datetime` is found in an array-like + of mixed time offsets, and ``utc=False``. + + See Also + -------- + DataFrame.astype : Cast argument to a specified dtype. + to_timedelta : Convert argument to timedelta. + convert_dtypes : Convert dtypes. + + Notes + ----- + + Many input types are supported, and lead to different output types: + + - **scalars** can be int, float, str, datetime object (from stdlib :mod:`datetime` + module or :mod:`numpy`). They are converted to :class:`Timestamp` when + possible, otherwise they are converted to :class:`datetime.datetime`. + None/NaN/null scalars are converted to :const:`NaT`. + + - **array-like** can contain int, float, str, datetime objects. They are + converted to :class:`DatetimeIndex` when possible, otherwise they are + converted to :class:`Index` with :class:`object` dtype, containing + :class:`datetime.datetime`. None/NaN/null entries are converted to + :const:`NaT` in both cases. + + - **Series** are converted to :class:`~snowflake.snowpark.modin.pandas.Series` with :class:`datetime64` + dtype when possible, otherwise they are converted to :class:`~snowflake.snowpark.modin.pandas.Series` with + :class:`object` dtype, containing :class:`datetime.datetime`. None/NaN/null + entries are converted to :const:`NaT` in both cases. + + - **DataFrame/dict-like** are converted to :class:`~snowflake.snowpark.modin.pandas.Series` with + :class:`datetime64` dtype. For each row a datetime is created from assembling + the various dataframe columns. Column keys can be common abbreviations + like [‘year’, ‘month’, ‘day’, ‘minute’, ‘second’, ‘ms’, ‘us’, ‘ns’]) or + plurals of the same. + + The following causes are responsible for :class:`datetime.datetime` objects + being returned (possibly inside an :class:`Index` or a :class:`~snowflake.snowpark.modin.pandas.Series` with + :class:`object` dtype) instead of a proper pandas designated type + (:class:`Timestamp` or :class:`~snowflake.snowpark.modin.pandas.Series` with :class:`datetime64` dtype): + + - when any input element is before :const:`Timestamp.min` or after + :const:`Timestamp.max`, see `timestamp limitations + `_. + + - when ``utc=False`` (default) and the input is an array-like or + :class:`~snowflake.snowpark.modin.pandas.Series` containing mixed naive/aware datetime, or aware with mixed + time offsets. 
Note that this happens in the (quite frequent) situation when + the timezone has a daylight savings policy. In that case you may wish to + use ``utc=True``. + + Examples + -------- + + **Handling various input formats** + + Assembling a datetime from multiple columns of a :class:`~snowflake.snowpark.modin.pandas.DataFrame`. The keys + can be common abbreviations like ['year', 'month', 'day', 'minute', 'second', + 'ms', 'us', 'ns']) or plurals of the same + + >>> df = pd.DataFrame({'year': [2015, 2016], + ... 'month': [2, 3], + ... 'day': [4, 5]}) + >>> pd.to_datetime(df) + 0 2015-02-04 + 1 2016-03-05 + dtype: datetime64[ns] + + Passing ``infer_datetime_format=True`` can often-times speedup a parsing + if it's not an ISO8601 format exactly, but in a regular format. + + >>> s = pd.Series(['3/11/2000', '3/12/2000', '3/13/2000'] * 1000) + >>> s.head() + 0 3/11/2000 + 1 3/12/2000 + 2 3/13/2000 + 3 3/11/2000 + 4 3/12/2000 + dtype: object + + Using a unix epoch time + + >>> pd.to_datetime(1490195805, unit='s') + Timestamp('2017-03-22 15:16:45') + >>> pd.to_datetime(1490195805433502912, unit='ns') + Timestamp('2017-03-22 15:16:45.433502912') + + .. warning:: For float arg, precision rounding might happen. To prevent + unexpected behavior use a fixed-width exact type. + + Using a non-unix epoch origin + + >>> pd.to_datetime([1, 2, 3], unit='D', + ... origin=pd.Timestamp('1960-01-01')) + 0 1960-01-02 + 1 1960-01-03 + 2 1960-01-04 + dtype: datetime64[ns] + + + **Non-convertible date/times** + + If a date does not meet the `timestamp limitations + `_, passing ``errors='ignore'`` + will return the original input instead of raising any exception. + + Passing ``errors='coerce'`` will force an out-of-bounds date to :const:`NaT`, + in addition to forcing non-dates (or non-parseable dates) to :const:`NaT`. + + >>> pd.to_datetime(['13000101', 'abc'], format='%Y%m%d', errors='ignore') # doctest: +SKIP + 0 13000101 + 1 abc + dtype: object + + >>> pd.to_datetime(['13000101', 'abc'], format='%Y%m%d', errors='coerce') + 0 NaT + 1 NaT + dtype: datetime64[ns] + + + .. _to_datetime_tz_examples: + + **Timezones and time offsets** + + The default behaviour (``utc=False``) is as follows: + + - Timezone-naive inputs are converted to timezone-naive :class:`~snowflake.snowpark.modin.pandas.Series`: + + >>> pd.to_datetime(['2018-10-26 12:00', '2018-10-26 13:00:15']) + 0 2018-10-26 12:00:00 + 1 2018-10-26 13:00:15 + dtype: datetime64[ns] + + - Timezone-aware inputs *with constant time offset* are still converted to + timezone-naive :class:`~snowflake.snowpark.modin.pandas.Series` by default. 
+ + >>> pd.to_datetime(['2018-10-26 12:00:00 -0500', '2018-10-26 13:00:00 -0500']) + 0 2018-10-26 12:00:00 + 1 2018-10-26 13:00:00 + dtype: datetime64[ns] + + - Use the right format to convert to a timezone-aware type (note that when calling + the Snowpark pandas to_pandas() API, the timezone-aware output will always be converted to the session timezone): + + >>> pd.to_datetime(['2018-10-26 12:00:00 -0500', '2018-10-26 13:00:00 -0500'], format="%Y-%m-%d %H:%M:%S %z") + 0 2018-10-26 10:00:00-07:00 + 1 2018-10-26 11:00:00-07:00 + dtype: datetime64[ns, America/Los_Angeles] + + - Timezone-aware inputs *with mixed time offsets* (for example + issued from a timezone with daylight savings, such as Europe/Paris): + + >>> pd.to_datetime(['2020-10-25 02:00:00 +0200', '2020-10-25 04:00:00 +0100']) + 0 2020-10-25 02:00:00 + 1 2020-10-25 04:00:00 + dtype: datetime64[ns] + + >>> pd.to_datetime(['2020-10-25 02:00:00 +0200', '2020-10-25 04:00:00 +0100'], format="%Y-%m-%d %H:%M:%S %z") + 0 2020-10-24 17:00:00-07:00 + 1 2020-10-24 20:00:00-07:00 + dtype: datetime64[ns, America/Los_Angeles] + + Setting ``utc=True`` ensures the output is always timezone-aware: + + - Timezone-naive inputs are *localized* based on the session timezone + + >>> pd.to_datetime(['2018-10-26 12:00', '2018-10-26 13:00'], utc=True) + 0 2018-10-26 12:00:00-07:00 + 1 2018-10-26 13:00:00-07:00 + dtype: datetime64[ns, America/Los_Angeles] + + - Timezone-aware inputs are *converted* to the session timezone + + >>> pd.to_datetime(['2018-10-26 12:00:00 -0530', '2018-10-26 12:00:00 -0500'], + ... utc=True) + 0 2018-10-26 10:30:00-07:00 + 1 2018-10-26 10:00:00-07:00 + dtype: datetime64[ns, America/Los_Angeles] + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + raise_if_native_pandas_objects(arg) + + if arg is None: + return None # same as pandas + if unit and unit not in VALID_TO_DATETIME_UNIT: + raise ValueError(f"Unrecognized unit {unit}") + + if not cache: + WarningMessage.ignored_argument( + operation="to_datetime", + argument="cache", + message="cache parameter is ignored with Snowflake backend, i.e., no caching will be applied", + ) + arg_is_scalar = is_scalar(arg) + # handle empty array, list, dict + if not arg_is_scalar and not isinstance(arg, (DataFrame, Series)) and len(arg) == 0: + return arg if isinstance(arg, Series) else Series(arg) # always return a Series + if not isinstance(arg, (DataFrame, Series)): + # turn dict-like arg into a DataFrame, and list-like or scalar into a Series + if isinstance(arg, dict): + arg = DataFrame(arg) # pragma: no cover + else: + name = None + # keep index name + if isinstance(arg, pandas.Index): + name = arg.name + arg = Series(arg) + arg.name = name + + series = arg._to_datetime( + errors=errors, + dayfirst=dayfirst, + yearfirst=yearfirst, + utc=utc, + format=format, + exact=exact, + unit=unit, + infer_datetime_format=infer_datetime_format, + origin=origin, + ) + if arg_is_scalar: + # Calling squeeze directly on a Snowpark pandas Series makes an unnecessary + # count SQL query. To avoid that, we convert the Snowpark pandas Series to a + # native pandas series first. + return series.to_pandas().squeeze() + return series + + + @snowpark_pandas_telemetry_standalone_function_decorator + def get_dummies( + data, + prefix=None, + prefix_sep="_", + dummy_na=False, + columns=None, + sparse=False, + drop_first=False, + dtype=None, + ): # noqa: PR01, RT01, D200 + """ + Convert categorical variable into dummy/indicator variables.
+ + Parameters + ---------- + data : array-like, Series, or :class:`~snowflake.snowpark.modin.pandas.DataFrame` + Data of which to get dummy indicators. + prefix : str, list of str, or dict of str, default None + String to append DataFrame column names. + Pass a list with length equal to the number of columns + when calling get_dummies on a DataFrame. Alternatively, `prefix` + can be a dictionary mapping column names to prefixes. + Only str, list of str and None is supported for this parameter. + prefix_sep : str, default '_' + If appending prefix, separator/delimiter to use. + dummy_na : bool, default False + Add a column to indicate NaNs, if False NaNs are ignored. Only the + value False is supported for this parameter. + columns : list-like, default None + Column names in the DataFrame to be encoded. + If `columns` is None then all the columns with + `string` dtype will be converted. + sparse : bool, default False + Whether the dummy-encoded columns should be backed by + a :class:`SparseArray` (True) or a regular NumPy array (False). + This parameter is ignored. + drop_first : bool, default False + Whether to get k-1 dummies out of k categorical levels by removing the + first level. Only the value False is supported for this parameter. + dtype : dtype, default np.uint8 + Data type for new columns. Only the value None is supported for this parameter. + + Returns + ------- + :class:`~snowflake.snowpark.modin.pandas.DataFrame` + Dummy-coded data. + + Examples + -------- + >>> s = pd.Series(list('abca')) + + >>> pd.get_dummies(s) + a b c + 0 1 0 0 + 1 0 1 0 + 2 0 0 1 + 3 1 0 0 + + >>> df = pd.DataFrame({'A': ['a', 'b', 'a'], 'B': ['b', 'a', 'c'], + ... 'C': [1, 2, 3]}) + + >>> pd.get_dummies(df, prefix=['col1', 'col2']) + C col1_a col1_b col2_a col2_b col2_c + 0 1 1 0 0 1 0 + 1 2 0 1 1 0 0 + 2 3 1 0 0 0 1 + + >>> pd.get_dummies(pd.Series(list('abcaa'))) + a b c + 0 1 0 0 + 1 0 1 0 + 2 0 0 1 + 3 1 0 0 + 4 1 0 0 + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + new_qc = data._query_compiler.get_dummies( + columns=columns, + prefix=prefix, + prefix_sep=prefix_sep, + dummy_na=dummy_na, + drop_first=drop_first, + dtype=dtype, + is_series=not data._is_dataframe, + ) + return DataFrame(query_compiler=new_qc) + + +@snowpark_pandas_telemetry_standalone_function_decorator +def melt( + frame, + id_vars=None, + value_vars=None, + var_name=None, + value_name="value", + col_level=None, + ignore_index: bool = True, +): # noqa: PR01, RT01, D200 + """ + Unpivot a DataFrame from wide to long format, optionally leaving identifiers set. + + Parameters + ---------- + id_vars : list of identifiers to retain in the result + value_vars : list of columns to unpivot on + defaults to all columns, excluding the id_vars columns + var_name : variable name, defaults to "variable" + value_name : value name, defaults to "value" + col_level : int, not implemented + ignore_index : bool, not implemented + + Returns + ------- + :class:`~snowflake.snowpark.modin.pandas.DataFrame` + unpivoted on the value columns + + Examples + -------- + >>> df = pd.DataFrame({'A': {0: 'a', 1: 'b', 2: 'c'}, + ... 'B': {0: 1, 1: 3, 2: 5}, + ... 'C': {0: 2, 1: 4, 2: 6}}) + + >>> pd.melt(df) + variable value + 0 A a + 1 A b + 2 A c + 3 B 1 + 4 B 3 + 5 B 5 + 6 C 2 + 7 C 4 + 8 C 6 + + >>> df = pd.DataFrame({'A': {0: 'a', 1: 'b', 2: 'c'}, + ... 'B': {0: 1, 1: 3, 2: 5}, + ... 
'C': {0: 2, 1: 4, 2: 6}}) + >>> pd.melt(df, id_vars=['A'], value_vars=['B'], var_name='myVarname', value_name='myValname') + A myVarname myValname + 0 a B 1 + 1 b B 3 + 2 c B 5 + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + return frame.melt( + id_vars=id_vars, + value_vars=value_vars, + var_name=var_name, + value_name=value_name, + col_level=col_level, + ignore_index=ignore_index, + ) + + +@snowpark_pandas_telemetry_standalone_function_decorator +@_inherit_docstrings(pandas.crosstab, apilink="pandas.crosstab") +def crosstab( + index, + columns, + values=None, + rownames=None, + colnames=None, + aggfunc=None, + margins=False, + margins_name: str = "All", + dropna: bool = True, + normalize=False, +) -> DataFrame: # noqa: PR01, RT01, D200 + """ + Compute a simple cross tabulation of two (or more) factors. + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + ErrorMessage.not_implemented() + + pandas_crosstab = pandas.crosstab( + index, + columns, + values, + rownames, + colnames, + aggfunc, + margins, + margins_name, + dropna, + normalize, + ) + return DataFrame(pandas_crosstab) + + +# Adding docstring since pandas docs don't have web section for this function. +@snowpark_pandas_telemetry_standalone_function_decorator +def lreshape(data: DataFrame, groups, dropna=True, label=None): + """ + Reshape wide-format data to long. Generalized inverse of ``DataFrame.pivot``. + + Accepts a dictionary, `groups`, in which each key is a new column name + and each value is a list of old column names that will be "melted" under + the new column name as part of the reshape. + + Parameters + ---------- + data : DataFrame + The wide-format DataFrame. + groups : dict + Dictionary in the form: `{new_name : list_of_columns}`. + dropna : bool, default: True + Whether include columns whose entries are all NaN or not. + label : optional + Deprecated parameter. + + Returns + ------- + DataFrame + Reshaped DataFrame. + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + ErrorMessage.not_implemented() + + if not isinstance(data, DataFrame): + raise ValueError(f"can not lreshape with instance of type {type(data)}") + return DataFrame( + pandas.lreshape(to_pandas(data), groups, dropna=dropna, label=label) + ) + + +@_inherit_docstrings(pandas.wide_to_long, apilink="pandas.wide_to_long") +@snowpark_pandas_telemetry_standalone_function_decorator +def wide_to_long( + df: DataFrame, stubnames, i, j, sep: str = "", suffix: str = r"\d+" +) -> DataFrame: # noqa: PR01, RT01, D200 + """ + Unpivot a DataFrame from wide to long format. + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + ErrorMessage.not_implemented() + + if not isinstance(df, DataFrame): + raise ValueError(f"can not wide_to_long with instance of type {type(df)}") + # ErrorMessage.default_to_pandas("`wide_to_long`") + return DataFrame( + pandas.wide_to_long(to_pandas(df), stubnames, i, j, sep=sep, suffix=suffix) + ) + + +def _determine_name(objs: Iterable[BaseQueryCompiler], axis: int | str): + """ + Determine names of index after concatenation along passed axis. + + Parameters + ---------- + objs : iterable of QueryCompilers + Objects to concatenate. + axis : int or str + The axis to concatenate along. + + Returns + ------- + list with single element + Computed index name, `None` if it could not be determined. 
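For reference, ``lreshape`` above currently raises not-implemented before reaching the native pandas fallback it wraps; the native pandas behavior it mirrors looks like this (illustrative):

    >>> data = pandas.DataFrame({'hr1': [514, 573], 'hr2': [545, 526],
    ...                          'team': ['Red Sox', 'Yankees'],
    ...                          'year1': [2007, 2007], 'year2': [2008, 2008]})
    >>> pandas.lreshape(data, {'year': ['year1', 'year2'], 'hr': ['hr1', 'hr2']})  # doctest: +SKIP
          team  year   hr
    0  Red Sox  2007  514
    1  Yankees  2007  573
    2  Red Sox  2008  545
    3  Yankees  2008  526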
+ """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + axis = pandas.DataFrame()._get_axis_number(axis) + + def get_names(obj): + return obj.columns.names if axis else obj.index.names + + names = np.array([get_names(obj) for obj in objs]) + + # saving old name, only if index names of all objs are the same + if np.all(names == names[0]): + # we must do this check to avoid this calls `list(str_like_name)` + return list(names[0]) if is_list_like(names[0]) else [names[0]] + else: + return None + + +@_inherit_docstrings(pandas.to_datetime, apilink="pandas.to_timedelta") +@snowpark_pandas_telemetry_standalone_function_decorator +def to_timedelta(arg, unit=None, errors="raise"): # noqa: PR01, RT01, D200 + """ + Convert argument to timedelta. + + Accepts str, timedelta, list-like or Series for arg parameter. + Returns a Series if and only if arg is provided as a Series. + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + ErrorMessage.not_implemented() + + if isinstance(arg, Series): + query_compiler = arg._query_compiler.to_timedelta(unit=unit, errors=errors) + return Series(query_compiler=query_compiler) + return pandas.to_timedelta(arg, unit=unit, errors=errors) + + +@snowpark_pandas_telemetry_standalone_function_decorator +def date_range( + start: VALID_DATE_TYPE | None = None, + end: VALID_DATE_TYPE | None = None, + periods: int | None = None, + freq: str | pd.DateOffset | None = None, + tz: str | tzinfo | None = None, + normalize: bool = False, + name: Hashable | None = None, + inclusive: IntervalClosedType = "both", + **kwargs, +) -> Series: + """ + Return a fixed frequency series. + + Returns the range of equally spaced time points (where the difference between any + two adjacent points is specified by the given frequency) such that they all + satisfy `start <[=] x <[=] end`, where the first one and the last one are, resp., + the first and last time points in that range that fall on the boundary of ``freq`` + (if given as a frequency string) or that are valid for ``freq`` (if given as a + :class:`pandas.tseries.offsets.DateOffset`). (If exactly one of ``start``, + ``end``, or ``freq`` is *not* specified, this missing parameter can be computed + given ``periods``, the number of timesteps in the range. See the note below.) + + Parameters + ---------- + start : str or datetime-like, optional + Left bound for generating dates. + end : str or datetime-like, optional + Right bound for generating dates. + periods : int, optional + Number of periods to generate. + freq : str or DateOffset, default 'D' + Frequency strings can have multiples, e.g. '5H'. + tz : str or tzinfo, optional + Time zone name for returning localized DatetimeIndex, for example + 'Asia/Hong_Kong'. By default, the resulting DatetimeIndex is + timezone-naive. + normalize : bool, default False + Normalize start/end dates to midnight before generating date range. + name : str, default None + Name of the resulting DatetimeIndex. + inclusive : {"both", "neither", "left", "right"}, default "both" + Include boundaries; Whether to set each bound as closed or open. + + .. versionadded:: 1.4.0 + **kwargs + For compatibility. Has no effect on the result. + + Returns + ------- + rng : DatetimeIndex + + See Also + -------- + DatetimeIndex : An immutable container for datetimes. + timedelta_range : Return a fixed frequency TimedeltaIndex. + period_range : Return a fixed frequency PeriodIndex. + interval_range : Return a fixed frequency IntervalIndex. 
+ + Notes + ----- + ``tz`` is not supported. + + Of the four parameters ``start``, ``end``, ``periods``, and ``freq``, + exactly three must be specified. If ``freq`` is omitted, the resulting + ``DatetimeIndex`` will have ``periods`` linearly spaced elements between + ``start`` and ``end`` (closed on both sides). + + To learn more about the frequency strings, please see `this link + `__. + + Also, custom or business frequencies are not implemented in Snowpark pandas, e.g., "B", "C", "SMS", "BMS", "CBMS", + "BQS", "BYS", "bh", "cbh". + + Examples + -------- + **Specifying the values** + + The next four examples generate the same `DatetimeIndex`, but vary + the combination of `start`, `end` and `periods`. + + Specify `start` and `end`, with the default daily frequency. + + >>> pd.date_range(start='1/1/2018', end='1/08/2018') + 0 2018-01-01 + 1 2018-01-02 + 2 2018-01-03 + 3 2018-01-04 + 4 2018-01-05 + 5 2018-01-06 + 6 2018-01-07 + 7 2018-01-08 + dtype: datetime64[ns] + + Specify `start` and `periods`, the number of periods (days). + + >>> pd.date_range(start='1/1/2018', periods=8) + 0 2018-01-01 + 1 2018-01-02 + 2 2018-01-03 + 3 2018-01-04 + 4 2018-01-05 + 5 2018-01-06 + 6 2018-01-07 + 7 2018-01-08 + dtype: datetime64[ns] + + Specify `end` and `periods`, the number of periods (days). + + >>> pd.date_range(end='1/1/2018', periods=8) + 0 2017-12-25 + 1 2017-12-26 + 2 2017-12-27 + 3 2017-12-28 + 4 2017-12-29 + 5 2017-12-30 + 6 2017-12-31 + 7 2018-01-01 + dtype: datetime64[ns] + + Specify `start`, `end`, and `periods`; the frequency is generated + automatically (linearly spaced). + + >>> pd.date_range(start='2018-04-24', end='2018-04-27', periods=3) + 0 2018-04-24 00:00:00 + 1 2018-04-25 12:00:00 + 2 2018-04-27 00:00:00 + dtype: datetime64[ns] + + **Other Parameters** + + Changed the `freq` (frequency) to ``'ME'`` (month end frequency). + + >>> pd.date_range(start='1/1/2018', periods=5, freq='ME') + 0 2018-01-31 + 1 2018-02-28 + 2 2018-03-31 + 3 2018-04-30 + 4 2018-05-31 + dtype: datetime64[ns] + + Multiples are allowed + + >>> pd.date_range(start='1/1/2018', periods=5, freq='3ME') + 0 2018-01-31 + 1 2018-04-30 + 2 2018-07-31 + 3 2018-10-31 + 4 2019-01-31 + dtype: datetime64[ns] + + `freq` can also be specified as an Offset object. + + >>> pd.date_range(start='1/1/2018', periods=5, freq=pd.offsets.MonthEnd(3)) + 0 2018-01-31 + 1 2018-04-30 + 2 2018-07-31 + 3 2018-10-31 + 4 2019-01-31 + dtype: datetime64[ns] + + Specify `tz` to set the timezone. + + >>> pd.date_range(start='1/1/2018', periods=5, tz='Asia/Tokyo') # doctest: +SKIP + DatetimeIndex(['2018-01-01 00:00:00+09:00', '2018-01-02 00:00:00+09:00', + '2018-01-03 00:00:00+09:00', '2018-01-04 00:00:00+09:00', + '2018-01-05 00:00:00+09:00'], + dtype='datetime64[ns, Asia/Tokyo]', freq='D') + + `inclusive` controls whether to include `start` and `end` that are on the + boundary. The default, "both", includes boundary points on either end. + + >>> pd.date_range(start='2017-01-01', end='2017-01-04', inclusive="both") + 0 2017-01-01 + 1 2017-01-02 + 2 2017-01-03 + 3 2017-01-04 + dtype: datetime64[ns] + + Use ``inclusive='left'`` to exclude `end` if it falls on the boundary. + + >>> pd.date_range(start='2017-01-01', end='2017-01-04', inclusive='left') + 0 2017-01-01 + 1 2017-01-02 + 2 2017-01-03 + dtype: datetime64[ns] + + Use ``inclusive='right'`` to exclude `start` if it falls on the boundary, and + similarly ``inclusive='neither'`` will exclude both `start` and `end`. 
+ + >>> pd.date_range(start='2017-01-01', end='2017-01-04', inclusive='right') + 0 2017-01-02 + 1 2017-01-03 + 2 2017-01-04 + dtype: datetime64[ns] + """ + # TODO: SNOW-1063345: Modin upgrade - modin.pandas functions in general.py + + if freq is None and common.any_none(periods, start, end): + freq = "D" + + if common.count_not_none(start, end, periods, freq) != 3: + raise ValueError( + "Of the four parameters: start, end, periods, and freq, exactly three must be specified" + ) + + # Validation code is mostly copied from pandas code DatetimeArray._generate_range and it will cast it to an integer + periods = datetimelike.validate_periods(periods) + + # Return DateOffset object from string or datetime.timedelta object + freq = to_offset(freq) + + if freq is None and periods < 0: + raise ValueError("Number of samples, %s, must be non-negative." % periods) + + if start is not None: + start = Timestamp(start) + + if end is not None: + end = Timestamp(end) + + if start is NaT or end is NaT: + raise ValueError("Neither `start` nor `end` can be NaT") + + # Check that the `inclusive` argument is among {"both", "neither", "left", "right"} + left_inclusive, right_inclusive = validate_inclusive(inclusive) + + # If normalize is needed, set start and end time to midnight + start, end = _maybe_normalize_endpoints(start, end, normalize) + + # If a timezone is not explicitly given via `tz`, see if one can be inferred from the `start` and `end` endpoints. + # If more than one of these inputs provides a timezone, require that they all agree. + tz = _infer_tz_from_endpoints(start, end, tz) + + qc = SnowflakeQueryCompiler.from_date_range( + start=start, + end=end, + periods=periods, + freq=freq, + tz=tz, + left_inclusive=left_inclusive, + right_inclusive=right_inclusive, + ) + s = Series(query_compiler=qc) + s.name = name + return s + + +@_inherit_docstrings(pandas.qcut, apilink="pandas.qcut") +@snowpark_pandas_telemetry_standalone_function_decorator +def qcut( + x: np.ndarray | Series, + q: int | ListLikeOfFloats, + labels: ListLike | bool | None = None, + retbins: bool = False, + precision: int = 3, + duplicates: Literal["raise"] | Literal["drop"] = "raise", +): # noqa: PR01, RT01, D200 + """ + Quantile-based discretization function. Inherits docstrings from Pandas. + retbins=True is not supported in Snowpark pandas. + + labels=False will run binning computation in Snowflake, whereas if labels is an array + the data will be fetched to the client and the binning run client-side, as Snowpark pandas API does + not yet support pd.Categorical in its ORM mapper. + """ + + kwargs = { + "labels": labels, + "retbins": retbins, + "precision": precision, + "duplicates": duplicates, + } + + # For numpy or list, call to native pandas. + if not isinstance(x, Series): + return pandas.qcut(x, q, **kwargs) + + # Check that labels is supported as in pandas. + if not (labels is None or labels is False or is_list_like(labels)): + raise ValueError( + "Bin labels must either be False, None or passed in as a list-like argument" + ) + + # Carry out check that for the list-like case quantiles are (monotonically) increasing, + # if not the case throw pandas compatible error. + if not isinstance(q, int) and np.all(np.diff(q) < 0): + # Note: Pandas 2.x changed the error message here, using Pandas 2.x behavior here. 
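        # For example, q=[0.8, 0.5, 0.2] (strictly decreasing) is rejected by this check,
        # while q=[0.2, 0.5, 0.8] passes and is deduplicated and sorted further below.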
+ raise ValueError("left side of interval must be <= right side") + + # remove duplicates (input like [0.5, 0.5] is ok) + q = sorted(list(set(q))) + + if labels is not False: + # Labels require categorical, not yet supported. Use native pandas conversion here to compute result. + return pandas.qcut(x.to_pandas(), q, **kwargs) + + ans = x._qcut(q, retbins, duplicates) + + # Within Snowpark Pandas, we avoid issuing a count query. However, for qcut if q !=1 and x is a Series/list-like containing + # a single element, an error will be produced ValueError: Bin edges must be unique: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]). + # You can drop duplicate edges by setting the 'duplicates' kwarg. + # With q qcut being an API that requires conversion, we can mimick this behavior here. + ret = ans.to_pandas().to_numpy() + + if len(ret) == 1 and isinstance(q, int) and q != 1: + if duplicates == "raise": + # Produce raising error. + raise ValueError( + f"Bin edges must be unique: {repr(np.array([0.] * q))}.\nYou can drop duplicate edges by setting the 'duplicates' kwarg." + ) + else: + # The result will be always NaN because no unique bin could be found. + return np.array([np.nan]) + + return ret + + +@snowpark_pandas_telemetry_standalone_function_decorator +def cut( + x: AnyArrayLike, + bins: int | Sequence[Scalar] | IntervalIndex, + right: bool = True, + labels=None, + retbins: bool = False, + precision: int = 3, + include_lowest: bool = False, + duplicates: str = "raise", + ordered: bool = True, +): + """ + Bin values into discrete intervals. + + Use `cut` when you need to segment and sort data values into bins. This + function is also useful for going from a continuous variable to a + categorical variable. For example, `cut` could convert ages to groups of + age ranges. Supports binning into an equal number of bins, or a + pre-specified array of bins. + + Parameters + ---------- + x : array-like + The input array to be binned. Must be 1-dimensional. + bins : int, sequence of scalars + The criteria to bin by. + + * int : Defines the number of equal-width bins in the range of `x`. The + range of `x` is extended by .1% on each side to include the minimum + and maximum values of `x`. + * sequence of scalars : Defines the bin edges allowing for non-uniform + width. No extension of the range of `x` is done. + + right : bool, default True + Indicates whether `bins` includes the rightmost edge or not. If + ``right == True`` (the default), then the `bins` ``[1, 2, 3, 4]`` + indicate (1,2], (2,3], (3,4]. This argument is ignored when + `bins` is an IntervalIndex. + labels : array or False, default None + Specifies the labels for the returned bins. Must be the same length as + the resulting bins. If False, returns only integer indicators of the + bins. This affects the type of the output container (see below). + This argument is ignored when `bins` is an IntervalIndex. If True, + raises an error. When `ordered=False`, labels must be provided. + + Snowpark pandas API does not support labels=None. + Labels must be of a Snowpark pandas API supported dtype. + + retbins : bool, default False + Snowpark pandas API does not support this parameter yet. + precision : int, default 3 + The precision at which to store and display the bins labels. + include_lowest : bool, default False + Whether the first interval should be left-inclusive or not. + duplicates : {default 'raise', 'drop'}, optional + If bin edges are not unique, raise ValueError or drop non-uniques. 
+ ordered : bool, default True + Whether the labels are ordered or not. Applies to returned types + Categorical and Series (with Categorical dtype). If True, + the resulting categorical will be ordered. If False, the resulting + categorical will be unordered (labels must be provided). + + Returns + ------- + out : Categorical, Series, or ndarray + An array-like object representing the respective bin for each value + of `x`. The type depends on the value of `labels`. + + * None (default) : returns a Series for Series `x` or a + Categorical for all other inputs. The values stored within + are Interval dtype. + + * sequence of scalars : returns a Series for Series `x` or a + Categorical for all other inputs. The values stored within + are whatever the type in the sequence is. + + * False : returns an ndarray of integers. + + bins : numpy.ndarray + The computed or specified bins. Only returned when `retbins=True`. + For scalar or sequence `bins`, this is an ndarray with the computed + bins. If set `duplicates=drop`, `bins` will drop non-unique bin. + + Notes + ----- + Any NA values will be NA in the result. Out of bounds values will be NA in + the resulting Series or Categorical object. + + Snowpark pandas API does not natively support Categorical and categorical types. When calling `cut` with a + Snowpark pandas Series and using `labels=False`, a Snowpark pandas Series object is returned. However, + for `labels != False` an error is raised. + + Examples + -------- + Discretize into three equal-sized bins. + + >>> pd.cut(np.array([1, 7, 5, 4, 6, 3]), 3, labels=False) + ... # doctest: +ELLIPSIS + array([0, 2, 1, 1, 2, 0]) + + ``labels=False`` implies you just want the bins back. + + >>> pd.cut([0, 1, 1, 2], bins=4, labels=False) + array([0, 1, 1, 3]) + + Passing a Series as an input returns a Series with labels=False: + + >>> s = pd.Series(np.array([2, 4, 6, 8, 10]), + ... index=['a', 'b', 'c', 'd', 'e']) + >>> pd.cut(s, 3, labels=False) + ... # doctest: +ELLIPSIS + a 0 + b 0 + c 1 + d 2 + e 2 + dtype: int64 + """ + + if retbins is True: + ErrorMessage.not_implemented("retbins not supported.") + + # Execute other supported objects via native pandas. + if not isinstance(x, Series): + return pandas.cut( + x, + bins, + right=right, + labels=labels, + retbins=retbins, + precision=precision, + include_lowest=include_lowest, + duplicates=duplicates, + ordered=ordered, + ) + + # Produce pandas-compatible error if ordered=False and labels are not specified. + # No error is raised when labels are not desired (labels=False). + if ordered is False and labels is None: + raise ValueError("'labels' must be provided if 'ordered = False'") + + bins, qc = x._query_compiler.cut( + bins, + right=right, + labels=labels, + precision=precision, + include_lowest=include_lowest, + duplicates=duplicates, + ) + + # Depending on setting, reconstruct bins and convert qc to the correct result. + if labels is False: + return pd.Series(query_compiler=qc) + else: + # Raise NotImplemented Error as categorical is not supported. + ErrorMessage.not_implemented("categorical not supported in Snowpark pandas API") + + # Following code would produce correct result, uncomment once categorical is supported. + # Convert to pandas categorical and return as Series. + # Note: In the future, once we support CategoricalType we could keep this lazily around. For now, + # match what pandas does here. In the future, change pandas -> pd and everything should work out-of-the box. 
+ # arr = qc.to_numpy().ravel() + # return pandas.Series( + # pandas.Categorical(values=arr, categories=labels, ordered=ordered) + # ) diff --git a/src/snowflake/snowpark/modin/pandas/groupby.py b/src/snowflake/snowpark/modin/pandas/groupby.py new file mode 100644 index 00000000000..8e44354ac4a --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/groupby.py @@ -0,0 +1,1209 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Implement GroupBy public API as pandas does.""" + +from collections.abc import Hashable +from typing import Any, Callable, Literal, Optional, Union + +import numpy as np # noqa: F401 +import numpy.typing as npt +import pandas +import pandas.core.groupby +from pandas._libs.lib import NoDefault, no_default +from pandas._typing import AggFuncType, Axis, IndexLabel +from pandas.core.dtypes.common import is_dict_like, is_list_like, is_numeric_dtype +from pandas.errors import SpecificationError +from pandas.io.formats.printing import PrettyDict +from pandas.util._validators import validate_bool_kwarg + +# following import are used in doctest +from snowflake.snowpark.modin import pandas as pd # noqa: F401 + +# Snowpark pandas API version +from snowflake.snowpark.modin.pandas.series import Series +from snowflake.snowpark.modin.pandas.utils import ( + raise_if_native_pandas_objects, + validate_and_try_convert_agg_func_arg_func_to_str, +) +from snowflake.snowpark.modin.plugin._internal.apply_utils import ( + create_groupby_transform_func, +) +from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta +from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( + SnowflakeQueryCompiler, +) +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage +from snowflake.snowpark.modin.utils import ( + MODIN_UNNAMED_SERIES_LABEL, + _inherit_docstrings, + doc_replace_dataframe_with_link, + hashable, + validate_int_kwarg, +) + +_DEFAULT_BEHAVIOUR = { + "__class__", + "__getitem__", + "__init__", + "__iter__", + "_as_index", + "_axis", + "_by", + "_check_index_name", + "_columns", + "_df", + "_groups_cache", + "_idx_name", + "_index", + "_indices_cache", + "_internal_by", + "_internal_by_cache", + "_iter", + "_kwargs", + "_level", + "_pandas_class", + "_query_compiler", + "_sort", + "_wrap_aggregation", +} + + +@_inherit_docstrings( + pandas.core.groupby.DataFrameGroupBy, modify_doc=doc_replace_dataframe_with_link +) +class 
DataFrameGroupBy(metaclass=TelemetryMeta): + _pandas_class = pandas.core.groupby.DataFrameGroupBy + + def __init__( + self, + df, + by, + axis, + level, + as_index, + sort, + group_keys, + idx_name, + **kwargs, + ) -> None: + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + self._axis = axis + self._idx_name = idx_name + self._df = df + self._df._query_compiler.validate_groupby(by, axis, level) + self._query_compiler = self._df._query_compiler + self._columns = self._query_compiler.columns + self._by = by + self._level = level + self._kwargs = { + "level": level, + "sort": sort, + "as_index": as_index, + "group_keys": group_keys, + } + self._kwargs.update(kwargs) + + def __getattr__(self, key): + """ + Alter regular attribute access, looks up the name in the columns. + + Parameters + ---------- + key : str + Attribute name. + + Returns + ------- + The value of the attribute. + """ + try: + return object.__getattribute__(self, key) + except AttributeError as err: + if key in self._columns: + return self.__getitem__(key) + raise err + + @property + def ngroups(self): + return self._query_compiler.groupby_ngroups( + by=self._by, + axis=self._axis, + groupby_kwargs=self._kwargs, + ) + + def skew(self, *args, **kwargs): + ErrorMessage.method_not_implemented_error(name="skew", class_="GroupBy") + + def ffill(self, limit=None): + ErrorMessage.method_not_implemented_error(name="ffill", class_="GroupBy") + + def sem(self, ddof=1): + ErrorMessage.method_not_implemented_error(name="sem", class_="GroupBy") + + def value_counts( + self, + subset=None, + normalize: bool = False, + sort: bool = True, + ascending: bool = False, + dropna: bool = True, + ): + ErrorMessage.method_not_implemented_error(name="value_counts", class_="GroupBy") + + def mean( + self, + numeric_only: bool = False, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + ): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + WarningMessage.warning_if_engine_args_is_set( + "groupby_mean", engine, engine_kwargs + ) + return self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_agg, + numeric_only=numeric_only, + agg_func="mean", + agg_kwargs=dict(numeric_only=numeric_only), + ) + + def any(self, skipna=True): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="any", class_="GroupBy") + + @property + def plot(self): # pragma: no cover + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="plot", class_="GroupBy") + + def ohlc(self): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="ohlc", class_="GroupBy") + + def __bytes__(self): + """ + Convert DataFrameGroupBy object into a python2-style byte string. + + Returns + ------- + bytearray + Byte array representation of `self`. + + Notes + ----- + Deprecated and removed in pandas and will be likely removed in Modin. 
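As a usage note on the ``__getattr__`` fallback defined above: unknown attributes are looked up in the grouped frame's columns, so attribute-style selection works (illustrative, assuming an active Snowpark session):

    >>> df = pd.DataFrame({"A": [1, 1, 2], "B": [3, 4, 5]})
    >>> df.groupby("A").B.sum()  # doctest: +SKIP
    A
    1    7
    2    5
    Name: B, dtype: int64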
+ """ + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="__bytes__", class_="GroupBy") + + _groups_cache = no_default + + # TODO: since python 3.9: + # @cached_property + @property + def groups(self) -> PrettyDict[Hashable, pd.Index]: + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + return self._query_compiler.groupby_groups( + self._by, + self._axis, + groupby_kwargs={ + # groupby.groups always treat as_index as True. this seems to be + # intended behavior: https://github.com/pandas-dev/pandas/issues/56965 + k: True if k == "as_index" else v + for k, v in self._kwargs.items() + }, + ) + + def min( + self, + numeric_only: bool = False, + min_count: int = -1, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + ): + WarningMessage.warning_if_engine_args_is_set( + "groupby_min", engine, engine_kwargs + ) + validate_int_kwarg(min_count, "min_count", float_allowed=False) + return self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_agg, + numeric_only=numeric_only, + agg_func="min", + agg_kwargs=dict(min_count=min_count, numeric_only=numeric_only), + ) + + def idxmax( + self, axis: Axis = no_default, skipna: bool = True, numeric_only: bool = False + ): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + axis_number = self._df._get_axis_number(axis) + if axis_number == 1: + # Performing idxmax is deprecated and will be removed in a future pandas version. + raise NotImplementedError( + "DataFrameGroupBy.idxmax with axis=1 is deprecated and will be removed in a " + "future version. Operate on the un-grouped DataFrame instead." + ) + else: + # When performing idxmax/idxmin on axis=0, it can be done column-wise. + result = self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_agg, + numeric_only=numeric_only, + how="axis_wise", + agg_func="idxmax", + # axis is also specified here since the axis used with idxmax/idxmin is different from + # the groupby axis. + agg_kwargs=dict(skipna=skipna, axis=0), + ) + return result + + def idxmin( + self, axis: Axis = no_default, skipna: bool = True, numeric_only: bool = False + ) -> Series: + axis_number = self._df._get_axis_number(axis) + if axis_number == 1: + # Performing idxmin is deprecated and will be removed in a future pandas version. + raise NotImplementedError( + "DataFrameGroupBy.idxmin with axis=1 is deprecated and will be removed in a " + "future version. Operate on the un-grouped DataFrame instead." + ) + else: + # When performing idxmax/idxmin on axis=0, it can be done column-wise. + result = self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_agg, + numeric_only=numeric_only, + how="axis_wise", + agg_func="idxmin", + # axis is also specified here since the axis used with idxmax/idxmin is different from + # the groupby axis. + agg_kwargs=dict(skipna=skipna, axis=0), + ) + return result + + @property + def ndim(self): + """ + Return 2. + + Returns + ------- + int + Returns 2. + + Notes + ----- + Deprecated and removed in pandas and will be likely removed in Modin. 
+ """ + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + return 2 # ndim is always 2 for DataFrames + + def shift( + self, periods: int = 1, freq: int = None, axis: Axis = 0, fill_value: Any = None + ): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + if not isinstance(periods, int): + raise TypeError( + f"Periods must be integer, but {periods} is {type(periods)}." + ) + qc = self._query_compiler.groupby_shift( + self._by, self._axis, self._level, periods, freq, fill_value, self.ndim == 1 + ) + return ( + pd.Series(query_compiler=qc) + if self.ndim == 1 + else pd.DataFrame(query_compiler=qc) + ) + + def nth(self, n, dropna=None): + ErrorMessage.method_not_implemented_error(name="nth", class_="GroupBy") + + def cumsum(self, axis: Axis = 0, *args, **kwargs): + qc = self._query_compiler.groupby_cumsum(self._by, self._axis, self._kwargs) + return ( + pd.Series(query_compiler=qc) + if self.ndim == 1 + else pd.DataFrame(query_compiler=qc) + ) + + @property + def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]: + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + return self._query_compiler.groupby_indices( + self._by, + self._axis, + groupby_kwargs={ + # groupby.indices always treat as_index as True. this seems to be + # intended behavior: https://github.com/pandas-dev/pandas/issues/56965 + k: True if k == "as_index" else v + for k, v in self._kwargs.items() + }, + ) + + @_inherit_docstrings(pandas.core.groupby.DataFrameGroupBy.pct_change) + def pct_change(self, *args, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="pct_change", class_="GroupBy") + + def filter(self, func, dropna=True, *args, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="filter", class_="GroupBy") + + def cummax(self, axis: Axis = 0, numeric_only: bool = False, *args, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + qc = self._query_compiler.groupby_cummax( + self._by, self._axis, numeric_only, self._kwargs + ) + return ( + pd.Series(query_compiler=qc) + if self.ndim == 1 + else pd.DataFrame(query_compiler=qc) + ) + + def apply(self, func, *args, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + # TODO: SNOW-1244717: Explore whether window function are performant and can be used + # whenever `func` is an aggregation function. 
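        # Illustrative behavior of the squeeze logic below (assuming an active Snowpark session):
        #   df.groupby("A").apply(lambda g: g["B"].sum())   -> result collapses to a Series
        #   df.groupby("A").apply(lambda g: g.assign(C=1))  -> result stays a DataFrame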
+ if not callable(func): + raise NotImplementedError("No support for non-callable `func`") + dataframe_result = pd.DataFrame( + query_compiler=self._query_compiler.groupby_apply( + self._by, + agg_func=func, + axis=self._axis, + groupby_kwargs=self._kwargs, + agg_args=args, + agg_kwargs=kwargs, + ) + ) + if dataframe_result.columns.equals(pandas.Index([MODIN_UNNAMED_SERIES_LABEL])): + return dataframe_result.squeeze(axis=1) + return dataframe_result + + @property + def dtypes(self): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="dtypes", class_="GroupBy") + + def first(self, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="first", class_="GroupBy") + + _internal_by_cache = no_default + + # TODO: since python 3.9: + # @cached_property + @property + def _internal_by(self): + """ + Get only those components of 'by' that are column labels of the source frame. + + Returns + ------- + tuple of labels + """ + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + if self._internal_by_cache is not no_default: + return self._internal_by_cache + + by_list = self._by if is_list_like(self._by) else [self._by] + + internal_by = tuple( + by for by in by_list if hashable(by) and by in self._columns + ) + + self._internal_by_cache = internal_by + return internal_by + + def __getitem__(self, key): + """ + Implement indexing operation on a DataFrameGroupBy object. + + Parameters + ---------- + key : list or str + Names of columns to use as subset of original object. + + Returns + ------- + DataFrameGroupBy or SeriesGroupBy + Result of indexing operation. + + Raises + ------ + NotImplementedError + Column lookups on GroupBy with arbitrary Series in by is not yet supported. + """ + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + if self._axis == 1: + raise ValueError("Cannot subset columns when using axis=1") + + # These parameters are common for building the resulted Series or DataFrame groupby object + kwargs = { + **self._kwargs.copy(), + "by": self._by, + "axis": self._axis, + "idx_name": self._idx_name, + } + # The rules of type deduction for the resulted object is the following: + # 1. If `key` is a list-like or `as_index is False`, then the resulted object is a DataFrameGroupBy + # 2. Otherwise, the resulted object is SeriesGroupBy + # 3. 
Result type does not depend on the `by` origin + # Examples: + # - drop: any, as_index: any, __getitem__(key: list_like) -> DataFrameGroupBy + # - drop: any, as_index: False, __getitem__(key: any) -> DataFrameGroupBy + # - drop: any, as_index: True, __getitem__(key: label) -> SeriesGroupBy + if is_list_like(key): + make_dataframe = True + else: + if self._as_index: + make_dataframe = False + else: + make_dataframe = True + key = [key] + + column_index = self._df.columns + # validate that all keys are labels belong to the data column of the df + for label in key: + if not (label in column_index): + raise KeyError(f"Columns not found: '{label}'") + + # internal_by records all label in by that belongs to the data columns + internal_by = frozenset(self._internal_by) + if len(internal_by.intersection(key)) != 0: + message = ( + "Data column selection with overlap of 'by' columns is not yet supported, " + "please duplicate the overlapped by columns and rename it to a different name" + ) + ErrorMessage.not_implemented(message=message) + + # select the union of the internal bys and select keys. Here we find all integer + # positions for all the selected columns, and then call iloc to select all columns. + # This is because loc currently doesn't support select with multiindex, once iloc and + # dataframe getitem is supported, this can be replaced with df[list(internal_by) + list(key)] + # TODO (SNOW-896342): update self._df.iloc[:, ilocs_list] to use df[list(internal_by) + list(key)] + # once dataframe getitem is supported. + _, by_ilocs = column_index._get_indexer_strict(list(internal_by), "columns") + _, key_ilocs = column_index._get_indexer_strict(list(key), "columns") + ilocs_list = list(by_ilocs) + list(key_ilocs) + + if len(key_ilocs) > 1: + make_dataframe = True + + if make_dataframe: + return DataFrameGroupBy( + self._df.iloc[:, ilocs_list], + **kwargs, + ) + else: + return SeriesGroupBy( + self._df.iloc[:, ilocs_list], + **kwargs, + ) + + def cummin(self, axis: Axis = 0, numeric_only: bool = False, *args, **kwargs): + qc = self._query_compiler.groupby_cummin( + self._by, self._axis, numeric_only, self._kwargs + ) + return ( + pd.Series(query_compiler=qc) + if self.ndim == 1 + else pd.DataFrame(query_compiler=qc) + ) + + def bfill(self, limit=None): + ErrorMessage.method_not_implemented_error(name="bfill", class_="GroupBy") + + def prod(self, numeric_only=False, min_count=0): + ErrorMessage.method_not_implemented_error(name="prod", class_="GroupBy") + + def std( + self, + ddof: int = 1, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + numeric_only: bool = False, + ): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + WarningMessage.warning_if_engine_args_is_set( + "groupby_std", engine, engine_kwargs + ) + return self._wrap_aggregation( + qc_method=SnowflakeQueryCompiler.groupby_agg, + numeric_only=numeric_only, + agg_func="std", + agg_kwargs=dict(ddof=ddof, numeric_only=numeric_only), + ) + + def aggregate( + self, + func: Optional[AggFuncType] = None, + *args: Any, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + **kwargs: Any, + ): + WarningMessage.warning_if_engine_args_is_set( + "groupby_aggregate", engine, engine_kwargs + ) + if self._axis != 0 and (is_dict_like(func) or is_list_like(func)): + # This is the same as pandas for func that is a list or dict + ErrorMessage.not_implemented( + "axis other than 0 is not supported" + ) # 
pragma: no cover + + func = validate_and_try_convert_agg_func_arg_func_to_str( + agg_func=func, obj=self, allow_duplication=True, axis=self._axis + ) + + if isinstance(func, str): + # Using "getattr" here masks possible AttributeError which we throw + # in __getattr__, so we should call __getattr__ directly instead. + agg_func = self.__getattr__(func) + if callable(agg_func): + return agg_func(*args, **kwargs) + + # when the aggregation function passed in is list like always return a Dataframe regardless + # it is SeriesGroupBy or DataFrameGroupBy + is_result_dataframe = (self.ndim == 2) or is_list_like(func) + result = self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_agg, + numeric_only=False, + agg_func=func, + agg_args=args, + agg_kwargs=kwargs, + how="axis_wise", + is_result_dataframe=is_result_dataframe, + ) + return result + + agg = aggregate + + def last(self, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="last", class_="GroupBy") + + def rank( + self, + method: str = "average", + ascending: bool = True, + na_option: str = "keep", + pct: bool = False, + *args, + **kwargs, + ): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + query_compiler = self._query_compiler.groupby_rank( + by=self._by, + axis=self._axis, + method=method, + na_option=na_option, + ascending=ascending, + pct=pct, + groupby_kwargs=self._kwargs, + agg_args=args, + agg_kwargs=kwargs, + ) + if self.ndim == 1: + result = pd.Series(query_compiler=query_compiler) + else: + result = pd.DataFrame(query_compiler=query_compiler) + return result + + @property + def corrwith(self): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="corrwith", class_="GroupBy") + + def max( + self, + numeric_only: bool = False, + min_count: int = -1, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + ): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + WarningMessage.warning_if_engine_args_is_set( + "groupby_max", engine, engine_kwargs + ) + validate_int_kwarg(min_count, "min_count", float_allowed=False) + return self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_agg, + numeric_only=numeric_only, + agg_func="max", + agg_kwargs=dict(min_count=min_count, numeric_only=numeric_only), + ) + + def var( + self, + ddof: int = 1, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + numeric_only: bool = False, + ): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + WarningMessage.warning_if_engine_args_is_set( + "groupby_var", engine, engine_kwargs + ) + + return self._wrap_aggregation( + qc_method=SnowflakeQueryCompiler.groupby_agg, + numeric_only=numeric_only, + agg_func="var", + agg_kwargs=dict(ddof=ddof, numeric_only=numeric_only), + ) + + def get_group(self, name, obj=None): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="get_group", class_="GroupBy") + + def __len__(self): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="__len__", class_="GroupBy") + + def all(self, skipna=True): + # TODO: SNOW-1063349: Modin upgrade - 
modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="all", class_="GroupBy") + + def size(self): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="size", class_="GroupBy") + + def sum( + self, + numeric_only: bool = False, + min_count: int = 0, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + ): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + WarningMessage.warning_if_engine_args_is_set( + "groupby_sum", engine, engine_kwargs + ) + validate_int_kwarg(min_count, "min_count", float_allowed=False) + return self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_agg, + numeric_only=numeric_only, + agg_func="sum", + agg_kwargs=dict(min_count=min_count, numeric_only=numeric_only), + ) + + def describe(self, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="describe", class_="GroupBy") + + def boxplot( + self, + grouped, + subplots=True, + column=None, + fontsize=None, + rot=0, + grid=True, + ax=None, + figsize=None, + layout=None, + **kwargs, + ): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="boxplot", class_="GroupBy") + + def ngroup(self, ascending=True): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="ngroup", class_="GroupBy") + + def nunique(self, dropna=True): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + return self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_nunique, + agg_func="nunique", + agg_kwargs=dict(dropna=dropna), + ) + + def resample(self, rule, *args, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="resample", class_="GroupBy") + + def sample(self, n=None, frac=None, replace=False, weights=None, random_state=None): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="sample", class_="GroupBy") + + def median(self, numeric_only: bool = False): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + return self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_agg, + numeric_only=numeric_only, + agg_func="median", + agg_kwargs=dict(numeric_only=numeric_only), + ) + + def head(self, n=5): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + # Ensure that n is an integer value. + if not isinstance(n, int): + raise TypeError("n must be an integer value.") + + # Only the groupby parameter "dropna" affects the output of head. None of the other groupby + # parameters: as_index, sort, and group_keys, affect head. + # Values needed for the helper functions. 
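        # Illustrative: df.groupby("A").head(2) returns the first two rows of every group,
        # keeping the original row order and index; rows whose group key is NA are included
        # only when the groupby was created with dropna=False.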
+ agg_kwargs = { + "n": n, + "level": self._level, + "dropna": self._kwargs.get("dropna", True), + } + + result = self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_agg, + agg_func="head", + agg_kwargs=agg_kwargs, + ) + return pd.DataFrame(result) + + def cumprod(self, axis=0, *args, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="cumprod", class_="GroupBy") + + def __iter__(self): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + return self._iter.__iter__() + + def cov(self): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="cov", class_="GroupBy") + + def transform( + self, + func: Union[str, Callable], + *args: Any, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + **kwargs: Any, + ) -> "pd.DataFrame": + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + WarningMessage.warning_if_engine_args_is_set( + "groupby_transform", engine, engine_kwargs + ) + + # The resultant DataFrame from `transform` has an index that always matches the original + # DataFrame's index. + # Create a new groupby object so that we can tune parameters to ensure that `apply` + # returns a DataFrame with the required index (same as original DataFrame). + # + # Unlike `transform`, groupby parameters affect the result of `apply`. + # - `group_keys` controls whether the grouped column(s) are included in the index. + # `group_keys` needs to be False to ensure that the resultant DataFrame has the + # original DataFrame's index. + # + # - `dropna` controls whether the NA values should be included as a group/be present + # in the group keys. `transform` always includes the NA values, therefore `dropna` + # needs to be False to ensure that all NA values are included. + # + # - `sort` controls whether the group keys are sorted. + # + # - `as_index` controls whether the groupby object has group labels as the index. + by = self._by + level = self._level + groupby_obj = self._df.groupby( + by=by, # either by or levels can be specified at a time + level=level, + as_index=self._as_index, + group_keys=False, + dropna=False, + sort=self._sort, + ) + + # Apply the transform function to each group. + res = groupby_obj.apply( + create_groupby_transform_func(func, by, level, *args, **kwargs) + ) + + dropna = self._kwargs.get("dropna", True) + if dropna is True: + # - To avoid dropping any NA values, `dropna` is set to False in both the groupby + # object created above and the groupby object created in `create_groupby_transform_func`. + # + # - If dropna is set to True in the groupby object, the output from this code (so far) + # and the expected native pandas result differs. + # + # - In the Snowpark pandas code, all rows grouped under NA keys calculate the result with + # the given `func`, thus resulting in non-NA values. + # + # - In the native pandas version, all rows grouped under NA keys take up + # "NaN" values in all columns. + # + # Therefore, we need to convert the rows grouped under NA keys to have NaN values in + # all columns. 
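        # For example, with df = pd.DataFrame({"key": ["a", None], "val": [1, 2]}),
        # df.groupby("key").transform("sum") leaves the key=None row as NaN in native pandas;
        # the masking below applies the same behavior to the Snowpark pandas result.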
+ na_col_data = self._df[by].isna() + condition = ( + na_col_data.any(axis=1) + if isinstance(na_col_data, pd.DataFrame) + else na_col_data + ) + res.loc[condition, :] = np.nan + + return res + + def corr(self, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="corr", class_="GroupBy") + + def fillna(self, *args, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="fillna", class_="GroupBy") + + def count(self): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + result = self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_agg, + numeric_only=False, + agg_func="count", + ) + return result + + def pipe(self, func, *args, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="pipe", class_="GroupBy") + + def cumcount(self, ascending: bool = True): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + query_compiler = self._query_compiler.groupby_cumcount( + self._by, self._axis, self._kwargs, ascending + ) + return pd.Series(query_compiler=query_compiler) + + def tail(self, n=5): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + # Ensure that n is an integer value. + if not isinstance(n, int): + raise TypeError("n must be an integer value.") + + # Only the groupby parameter "dropna" affects the output of tail. None of the other groupby + # parameters: as_index, sort, and group_keys, affect tail. + # Values needed for the helper functions. + agg_kwargs = { + "n": n, + "level": self._level, + "dropna": self._kwargs.get("dropna", True), + } + + result = self._wrap_aggregation( + qc_method=type(self._query_compiler).groupby_agg, + agg_func="tail", + agg_kwargs=agg_kwargs, + ) + return pd.DataFrame(result) + + # expanding and rolling are unique cases and need to likely be handled + # separately. They do not appear to be commonly used. + def expanding(self, *args, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="expanding", class_="GroupBy") + + def rolling(self, *args, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="rolling", class_="GroupBy") + + def hist(self): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="hist", class_="GroupBy") + + def quantile(self, q=0.5, interpolation="linear"): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + return self._wrap_aggregation( + type(self._query_compiler).groupby_agg, + numeric_only=False, + agg_func="quantile", + agg_kwargs=dict(q=q, interpolation=interpolation), + ) + + def diff(self): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="diff", class_="GroupBy") + + def take(self, *args, **kwargs): + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="take", class_="GroupBy") + + @property + def _index(self): + """ + Get index value. 
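For the cumulative counter defined above, behavior follows native pandas, e.g. (illustrative, assuming an active Snowpark session):

    >>> df = pd.DataFrame({"A": ["a", "a", "b", "a"]})
    >>> df.groupby("A").cumcount()  # doctest: +SKIP
    0    0
    1    1
    2    0
    3    2
    dtype: int64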
+ + Returns + ------- + pandas.Index + Index value. + """ + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + return self._query_compiler.index + + @property + def _sort(self): + """ + Get sort parameter value. + + Returns + ------- + bool + Value of sort parameter used to create DataFrameGroupBy object. + """ + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + return self._kwargs.get("sort") + + @property + def _as_index(self): + """ + Get as_index parameter value. + + Returns + ------- + bool + Value of as_index parameter used to create DataFrameGroupBy object. + """ + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + return self._kwargs.get("as_index") + + @property + def _iter(self): + """ + Construct a tuple of (group_id, DataFrame) tuples to allow iteration over groups. + + Returns + ------- + generator + Generator expression of GroupBy object broken down into tuples for iteration. + """ + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + ErrorMessage.method_not_implemented_error(name="_iter", class_="GroupBy") + + def _wrap_aggregation( + self, + qc_method: Callable, + numeric_only: bool = False, + agg_args: list[Any] = None, + agg_kwargs: dict[str, Any] = None, + is_result_dataframe: Optional[bool] = None, + **kwargs: Any, + ): + """ + Perform common metadata transformations and apply groupby functions. + + Parameters + ---------- + qc_method : callable + The query compiler method to call. + numeric_only : bool, default: False + Specifies whether to aggregate non numeric columns: + - True: include only numeric columns (including categories that holds a numeric dtype) + - False: include all columns + agg_args : list-like, optional + Positional arguments to pass to the aggregation function. + agg_kwargs : dict-like, optional + Keyword arguments to pass to the aggregation function. + is_result_dataframe: bool optional + whether the result of aggregation is a dataframe or series. If None, is_result_dataframe will be + False for SeriesGroupBy, and True for DataFrameGroupBy. + **kwargs : dict + Keyword arguments to pass to the specified query compiler's method. + + Returns + ------- + DataFrame or Series + Returns the same type as `self._df`. + """ + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + numeric_only = validate_bool_kwarg( + numeric_only, "numeric_only", none_allowed=True + ) + + agg_args = tuple() if agg_args is None else agg_args + agg_kwargs = dict() if agg_kwargs is None else agg_kwargs + + is_series_groupby = self.ndim == 1 + if is_series_groupby: + # when ndim is 1, it is SeriesGroupBy. SeriesGroupBy does not implement numeric_only + # parameter even if it accepts the parameter, and the aggregation is handled the + # same as numeric_only is False. 
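            # Illustrative: a non-numeric Series grouped with numeric_only=True, e.g.
            #   pd.Series(["x", "y", "x"]).groupby(level=0).sum(numeric_only=True)
            # reaches the branch below and raises instead of silently dropping the data.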
+ if numeric_only and not is_numeric_dtype(self._query_compiler.dtypes[0]): + # pandas throws an NotImplementedError when the numeric_only is True, but the + # series dtype is not numeric + ErrorMessage.not_implemented( + "SeriesGroupBy does not implement numeric_only" + ) + numeric_only = False + + if is_result_dataframe is None: + is_result_dataframe = not is_series_groupby + result_type = pd.DataFrame if is_result_dataframe else pd.Series + result = result_type( + query_compiler=qc_method( + self._query_compiler, + by=self._by, + axis=self._axis, + groupby_kwargs=self._kwargs, + agg_args=agg_args, + agg_kwargs=agg_kwargs, + numeric_only=numeric_only, + is_series_groupby=is_series_groupby, + **kwargs, + ) + ) + return result + + def _check_index_name(self, result): + """ + Check the result of groupby aggregation on the need of resetting index name. + + Parameters + ---------- + result : DataFrame + Group by aggregation result. + + Returns + ------- + DataFrame + """ + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions + if self._by is not None: + # pandas does not name the index for this case + result._query_compiler.set_index_name(None) + return result + + +@_inherit_docstrings( + pandas.core.groupby.SeriesGroupBy, modify_doc=doc_replace_dataframe_with_link +) +class SeriesGroupBy(DataFrameGroupBy): + _pandas_class = pandas.core.groupby.SeriesGroupBy + + @property + def ndim(self): + """ + Return 1. + + Returns + ------- + int + Returns 1. + + Notes + ----- + Deprecated and removed in pandas and will be likely removed in Modin. + """ + # TODO: SNOW-1063350: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions + return 1 # ndim is always 1 for Series + + @property + def _iter(self): + """ + Construct a tuple of (group_id, Series) tuples to allow iteration over groups. + + Returns + ------- + generator + Generator expression of GroupBy object broken down into tuples for iteration. + """ + # TODO: SNOW-1063350: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions + ErrorMessage.method_not_implemented_error(name="_iter", class_="GroupBy") + + @property + def is_monotonic_decreasing(self): + # TODO: SNOW-1063350: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions + ErrorMessage.method_not_implemented_error( + name="is_monotonic_decreasing", class_="GroupBy" + ) + + @property + def is_monotonic_increasing(self): + # TODO: SNOW-1063350: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions + ErrorMessage.method_not_implemented_error( + name="is_monotonic_increasing", class_="GroupBy" + ) + + def aggregate( + self, + func: Optional[AggFuncType] = None, + *args: Any, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + **kwargs: Any, + ): + # TODO: SNOW-1063350: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions + if is_dict_like(func): + raise SpecificationError( + "Value for func argument in dict format is not allowed for SeriesGroupBy." 
+ ) + + return super().aggregate( + func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs + ) + + agg = aggregate + + def nlargest(self, n=5, keep="first"): + # TODO: SNOW-1063350: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions + ErrorMessage.method_not_implemented_error(name="nlargest", class_="GroupBy") + + def nsmallest(self, n=5, keep="first"): + # TODO: SNOW-1063350: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions + ErrorMessage.method_not_implemented_error(name="nsmallest", class_="GroupBy") + + def unique(self): + # TODO: SNOW-1063350: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions + ErrorMessage.method_not_implemented_error(name="unique", class_="GroupBy") + + def apply(self, func, *args, **kwargs): + """Not implemented yet""" + # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions + ErrorMessage.not_implemented("apply is not implemented for SeriesGroupBy") + + +def validate_groupby_args( + by: Any, + level: Optional[IndexLabel], + observed: Union[bool, NoDefault], +) -> None: + """ + Common validation and checks for the groupby arguments that are used by both SeriesGroupBy + and DataFrameGroupBy. + + Raises: + TypeError if native pandas series is used as by item, or if both level and by are None + Warns: + If observed is True, this parameter is ignored because CategoryDType is not supported with Snowpark pandas API + """ + # TODO: SNOW-1063350: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions + # check if pandas.Series is used as by item, no native pandas series or dataframe + # object is allowed. + raise_if_native_pandas_objects(by) + if not isinstance(by, Series) and is_list_like(by): + for o in by: + raise_if_native_pandas_objects(o) + + if level is None and by is None: + raise TypeError("You have to supply one of 'by' and 'level'") + + if observed is not no_default and observed: + WarningMessage.ignored_argument( + operation="groupby", + argument="observed", + message="CategoricalDType is not yet supported with Snowpark pandas API, the observed parameter is ignored.", + ) diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py new file mode 100644 index 00000000000..e9ffc5e471d --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/indexing.py @@ -0,0 +1,1337 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +# noqa: MD02 +""" +Details about how Indexing Helper Class works. 
+
+_LocationIndexerBase provides the method framework for __getitem__
+  and __setitem__ that work with Modin DataFrame's internal index. The base
+  class's __{get,set}item__ takes in partitions & idx_in_partition data
+  and performs the lookup/item write.
+
+_LocIndexer and _iLocIndexer are responsible for indexer-specific logic and
+  lookup computation. Loc takes care of enlarging the DataFrame. Both indexers
+  take care of translating pandas' lookup to Modin DataFrame's internal
+  lookup.
+
+An illustration is available at
+https://github.com/ray-project/ray/pull/1955#issuecomment-386781826
+"""
+
+import itertools
+import numbers
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+import pandas
+from pandas._libs.tslibs import Resolution, parsing
+from pandas._typing import AnyArrayLike, Scalar
+from pandas.api.types import is_bool, is_list_like
+from pandas.core.dtypes.common import (
+    is_bool_dtype,
+    is_datetime64_any_dtype,
+    is_integer,
+    is_integer_dtype,
+    is_numeric_dtype,
+    pandas_dtype,
+)
+from pandas.core.indexing import IndexingError
+
+import snowflake.snowpark.modin.pandas as pd
+import snowflake.snowpark.modin.pandas.utils as frontend_utils
+from snowflake.snowpark.modin.pandas.base import BasePandasDataset
+from snowflake.snowpark.modin.pandas.dataframe import DataFrame
+from snowflake.snowpark.modin.pandas.series import (
+    SERIES_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE_ERROR_MESSAGE,
+    SERIES_SETITEM_SLICE_AS_SCALAR_VALUE_ERROR_MESSAGE,
+    Series,
+)
+from snowflake.snowpark.modin.pandas.utils import is_scalar
+from snowflake.snowpark.modin.plugin._internal.indexing_utils import (
+    MULTIPLE_ELLIPSIS_INDEXING_ERROR_MESSAGE,
+    TOO_MANY_INDEXERS_INDEXING_ERROR_MESSAGE,
+)
+from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
+    SnowflakeQueryCompiler,
+)
+from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage
+
+INDEXING_KEY_TYPE = Union[Scalar, list, slice, Callable, tuple, AnyArrayLike]
+INDEXING_ITEM_TYPE = Union[Scalar, AnyArrayLike, pd.Series, pd.DataFrame]
+INDEXING_LOCATOR_TYPE = Union[Scalar, list, slice, tuple, pd.Series]
+
+ILOC_SET_INDICES_MUST_BE_INTEGER_OR_BOOL_ERROR_MESSAGE = (
+    "arrays used as indices must be of integer (or boolean) type"
+)
+ILOC_GET_REQUIRES_NUMERIC_INDEXERS_ERROR_MESSAGE = (
+    ".iloc requires numeric indexers, got {}"
+)
+LOC_SET_INCOMPATIBLE_INDEXER_WITH_DF_ERROR_MESSAGE = (
+    "Incompatible indexer with DataFrame"
+)
+LOC_SET_INCOMPATIBLE_INDEXER_WITH_SERIES_ERROR_MESSAGE = (
+    "Incompatible indexer with Series"
+)
+LOC_SET_INCOMPATIBLE_INDEXER_WITH_SCALAR_ERROR_MESSAGE = (
+    "Scalar indexer incompatible with {} item"
+)
+SET_CELL_WITH_LIST_LIKE_VALUE_ERROR_MESSAGE = (
+    "Currently do not support setting cell with list-like values"
+)
+
+
+ILOC_GET_DATAFRAME_INDEXER_NOT_ALLOWED_ERROR_MESSAGE = (
+    "DataFrame indexer is not allowed for .iloc\nConsider using"
+    " .loc for automatic alignment."
+)
+
+
+def is_boolean_array(x: Any) -> bool:
+    """
+    Check that argument is an array of bool.
+
+    Parameters
+    ----------
+    x : object
+        Object to check.
+
+    Returns
+    -------
+    bool
+        True if argument is an array of bool, False otherwise.
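+
+    Examples
+    --------
+    Illustrative behavior of the checks below: `is_boolean_array([True, False])`
+    is True, while `is_boolean_array([])` and `is_boolean_array([1, 0])` are False.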
+ """ + + # special case empty list is not regarded as boolean array; + # because of later Numpy versions (for Python 3.9+), can't + # compare directly to [], but need workaround to detect list properly + if isinstance(x, list) and 0 == len(x): + return False + + if isinstance(x, (np.ndarray, Series, pandas.Series, pandas.Index)): + # check dtype, if != object, no need to perform element-wise check + if pandas_dtype(x.dtype) != pandas_dtype("object"): + return is_bool_dtype(x.dtype) + elif isinstance(x, (DataFrame, pandas.DataFrame)): + return all(map(is_bool_dtype, x.dtypes)) + return is_list_like(x) and all(map(is_bool, x)) + + +def is_2d_array(x: Any) -> bool: + """ + Check that argument is a 2D array. + + Parameters + ---------- + x : object + Object to check. + + Returns + ------- + bool + True if argument is a 2D array, False otherwise. + """ + return isinstance(x, (list, np.ndarray)) and len(x) > 0 and is_list_like(x[0]) + + +def is_range_like(obj: Any) -> bool: + """ + Check if the object is range-like. + + Objects that are considered range-like have information about the range (start and + stop positions, and step) and also have to be iterable. Examples of range-like + objects are: Python range, pandas.RangeIndex. + + Parameters + ---------- + obj : object + + Returns + ------- + bool + """ + if not isinstance(obj, (DataFrame, Series)): + return ( + hasattr(obj, "__iter__") + and hasattr(obj, "start") + and hasattr(obj, "stop") + and hasattr(obj, "step") + ) + else: + # This would potentially have to change once RangeIndex is supported + return False + + +def boolean_mask_to_numeric(indexer: Any) -> np.ndarray: + """ + Convert boolean mask to numeric indices. + + Parameters + ---------- + indexer : list-like of booleans + + Returns + ------- + np.ndarray of ints + Numerical positions of ``True`` elements in the passed `indexer`. + """ + if isinstance(indexer, (np.ndarray, Series, pandas.Series)): + return np.where(indexer)[0] + else: + # It's faster to build the resulting numpy array from the reduced amount of data via + # `compress` iterator than convert non-numpy-like `indexer` to numpy and apply `np.where`. + return np.fromiter( + # `itertools.compress` masks `data` with the `selectors` mask, + # works about ~10% faster than a pure list comprehension + itertools.compress(data=range(len(indexer)), selectors=indexer), + dtype=np.int64, + ) + + +def check_dict_or_set_indexers(key: Any) -> None: + """ + Check if the indexer is or contains a dict or set, which is no longer allowed since pandas 2.0. + Our error messages and types are the same as pandas 2.0. + + Raises + ---------- + TypeError: + If key is set or dict type or a tuple with any set or dict type item. + """ + if ( + isinstance(key, set) + or isinstance(key, tuple) + and any(isinstance(x, set) for x in key) + ): + raise TypeError( + "Passing a set as an indexer is not supported. Use a list instead." + ) + + if ( + isinstance(key, dict) + or isinstance(key, tuple) + and any(isinstance(x, dict) for x in key) + ): + raise TypeError( + "Passing a dict as an indexer is not supported. Use a list instead." + ) + + +def validate_positional_slice(slice_key: Any) -> None: + """ + Validate slice start, stop, and step are int typed. + + Parameters + ---------- + slice_key : slice or is_range_like + + Raises + ---------- + TypeError: + If the start, stop, or step of slice_key is not None and is not integer. 
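+
+    For example (illustrative): `slice(0, 10, 2)` and `slice(None)` pass this
+    validation, while `slice(0, "a")` raises the TypeError above.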
+    """
+    for key in [slice_key.start, slice_key.stop, slice_key.step]:
+        if key is not None and not is_integer(key):
+            raise TypeError(
+                f"cannot do positional indexing with these indexers [{key}] of type {type(key).__name__}"
+            )
+
+
+class _LocationIndexerBase:
+    """
+    Base class for location indexer like loc and iloc.
+
+    Parameters
+    ----------
+    modin_df : modin.pandas.DataFrame
+        DataFrame to operate on.
+    """
+
+    api_name = "undefined"
+
+    def __init__(self, modin_df: BasePandasDataset) -> None:
+        # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase
+        self.df = modin_df
+        self.qc = modin_df._query_compiler
+
+    def _validate_key_length_with_ellipsis_stripping(self, key: tuple) -> tuple:
+        """
+        Validate tuple type key's length and strip leading ellipsis.
+
+        If tuple length is no greater than ndim of DataFrame df: return key
+        Else:
+            If the first entry is ellipsis, strip leading ellipsis and call this function
+            on the remaining tuple again.
+            Else raise IndexingError.
+
+        e.g. (..., 2, 3) is reduced to (2, 3); (..., 3) is reduced to (3,)
+        """
+        if len(key) > self.df.ndim:
+            if key[0] is Ellipsis:
+                # e.g. Series.iloc[..., 3] reduces to just Series.iloc[3]
+                key = key[1:]
+                if Ellipsis in key:
+                    raise IndexingError(MULTIPLE_ELLIPSIS_INDEXING_ERROR_MESSAGE)
+                return self._validate_key_length_with_ellipsis_stripping(key)
+            raise IndexingError(TOO_MANY_INDEXERS_INDEXING_ERROR_MESSAGE)
+        return key
+
+    def __getitem__(self, key: INDEXING_KEY_TYPE) -> None:  # pragma: no cover
+        """
+        Retrieve dataset according to `key`.
+
+        Parameters
+        ----------
+        key : callable, scalar, or tuple
+            The global row index to retrieve data from.
+
+        Returns
+        -------
+        modin.pandas.DataFrame or modin.pandas.Series
+            Located dataset.
+
+        See Also
+        --------
+        pandas.DataFrame.loc
+        """
+        # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase
+        ErrorMessage.not_implemented("Implemented by subclasses")
+
+    def __setitem__(
+        self, key: INDEXING_KEY_TYPE, item: INDEXING_ITEM_TYPE
+    ) -> None:  # pragma: no cover
+        """
+        Assign `item` value to dataset located by `key`.
+
+        Parameters
+        ----------
+        key : callable or tuple
+            The global row numbers to assign data to.
+        item : modin.pandas.DataFrame, modin.pandas.Series or scalar
+            Value that should be assigned to located dataset.
+
+        See Also
+        --------
+        pandas.DataFrame.iloc
+        """
+        # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase
+        ErrorMessage.not_implemented("Implemented by subclasses")
+
+    def _should_squeeze(
+        self,
+        locator: Union[Scalar, list, slice, tuple, pd.Series],
+        axis: int,
+    ) -> Optional[bool]:
+        """
+        Help decide whether a squeeze is needed to get the final pandas object. This base implementation only
+        rules out the cases where squeeze is definitely not needed:
+        - if self is a Series and axis = 1
+        - if the locator is neither a scalar nor a tuple
+        Otherwise, the decision is left to the subclass (returns None).
+
+        Args:
+            locator: locator on the axis
+            axis: the axis to check
+
+        Returns:
+            False if no squeeze is needed on this axis, or None if the subclass should decide.
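+
+        For example (illustrative): a list locator such as df.loc[["a"], :] yields
+        False for axis=0 (no squeeze), while a scalar locator such as df.loc["a", :]
+        yields None here and the subclass decides.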
+ """ + # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase + if axis == 1 and isinstance(self.df, Series): + # squeeze col is always False for Series + return False + + not_dataset = not isinstance(locator, BasePandasDataset) + is_scalar_loc = not_dataset and is_scalar(locator) + is_tuple_loc = not_dataset and isinstance(locator, tuple) + + if not is_scalar_loc and not is_tuple_loc: + # no need to squeeze if any axis key are not scalar or tuple + return False + + # otherwise, not sure + return None + + def _get_pandas_object_from_qc_view( + self, + qc_view: SnowflakeQueryCompiler, + *, + squeeze_row: bool, + squeeze_col: bool, + ) -> Union[Scalar, pd.Series, pd.DataFrame]: + """ + Convert the query compiler view to the appropriate pandas object. The method helps to call squeeze to get the + final pandas object. + Args: + qc_view: SnowflakeQueryCompiler + Query compiler to convert. + squeeze_row: bool + Whether to squeeze row + squeeze_col: bool + Whether to squeeze column + + Returns: DataFrame, Series or Scalar + The pandas object with the data from the query compiler view. + """ + # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase + res_df = self.df.__constructor__(query_compiler=qc_view) + + if not squeeze_row and not squeeze_col: + return res_df + + if squeeze_row and squeeze_col: + axis = None + elif squeeze_row: + axis = 0 + else: + axis = 1 + return res_df.squeeze(axis=axis) + + def _parse_row_and_column_locators( + self, key: INDEXING_KEY_TYPE + ) -> tuple[INDEXING_LOCATOR_TYPE, INDEXING_LOCATOR_TYPE]: + """ + Unpack the user input. This shared parsing helper method is used by both iloc and loc's getitem and setitem. + + Examples: + loc[:] -> (slice(None), slice(None)) + loc[a] -> (a, slice(None)) + loc[,b] -> (slice(None), b) + loc[a,:] -> (a, slice(None)) + loc[:,b] -> (slice(None), b) + loc[a,...] -> (a, slice(None)) + loc[...,b] -> (slice(None), b) + loc[[a,b]] -> ([a,b], slice(None)), + loc[a,b] -> ([a], [b]) + loc[...,a,b] -> ([a], [b]) + loc[lambda df: df.col > 0,b] -> (df.col > 0, [b]) + (same for iloc too) + + Args: + key: User input to unpack. + + Returns: + row_loc : scalar or list + Row locator(s) as a scalar or list. + col_loc : scalar or list + Column locator(s) as a scalar or list. + + Raises: + index error if key is tuple(...,...) + """ + # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase + row_loc: INDEXING_LOCATOR_TYPE = slice(None) + col_loc: INDEXING_LOCATOR_TYPE = slice(None) + if isinstance(key, tuple): + key = self._validate_key_length_with_ellipsis_stripping(key) + if len(key) > 2: + raise IndexingError(TOO_MANY_INDEXERS_INDEXING_ERROR_MESSAGE) + if len(key) > 0: + row_loc = key[0] + if len(key) == 2: + if key[0] is Ellipsis and key[1] is Ellipsis: + raise IndexingError(MULTIPLE_ELLIPSIS_INDEXING_ERROR_MESSAGE) + col_loc = key[1] + else: + row_loc = key + + def _parse_locator(_key: INDEXING_LOCATOR_TYPE) -> INDEXING_LOCATOR_TYPE: + # Ellipsis to slice(None) + if _key is Ellipsis: + return slice(None) + # callable will be evaluated to use the result as locator + if callable(_key): + _key = _key(self.df) + return _key + + return _parse_locator(row_loc), _parse_locator(col_loc) + + def _parse_get_row_and_column_locators( + self, key: INDEXING_KEY_TYPE + ) -> tuple[ + Union[Scalar, list, slice, tuple, pd.Series], + Union[Scalar, list, slice, tuple, pd.Series], + ]: + """Used by loc and iloc. 
See _LocationIndexerBase._parse_row_and_column_locators""" + # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase + row_key, col_key = self._parse_row_and_column_locators(key) + self._validate_get_locator_key(row_key) + self._validate_get_locator_key(col_key) + + return row_key, col_key + + def _parse_set_row_and_column_locators( + self, key: INDEXING_KEY_TYPE + ) -> tuple[ + Union[Scalar, list, slice, tuple, pd.Series], + Union[Scalar, list, slice, tuple, pd.Series], + ]: + """Used by loc and iloc. See _LocationIndexerBase._parse_row_and_column_locators""" + # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase + row_key, col_key = self._parse_row_and_column_locators(key) + self._validate_set_locator_key(row_key) + self._validate_set_locator_key(col_key) + + return row_key, col_key + + def _is_multiindex_full_lookup( + self, axis: int, key: Union[Scalar, list, slice, tuple, pd.Series] + ) -> bool: + """ + Determine if the key will perform a full lookup for MultiIndex. "Multiindex full lookup" is True only when the + axis is MultiIndex and the key is a tuple and the number of levels matches up with the length of the tuple key. + When it is True, pandas will drop all levels from the multiindex axis and call squeeze on the axis. + + Examples: + if self has a three level multiindex ["l0","l1","l2], then key has to be a tuple with length equals to 3 to + perform a multiindex full lookup. + + Args: + axis: {0, 1} + 0 for row, 1 for column. + key: Scalar, tuple, or other list like + Lookup key for MultiIndex row/column. + + Returns: bool + True if the key will perform a full lookup for the MultiIndex. + + """ + # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase + if not self.qc.has_multiindex(axis=axis): + return False + + if not isinstance(key, tuple): + return False + + if any(isinstance(key_level, slice) for key_level in key): + # do not squeeze if any level of the key is a slice + return False + + return len(key) == self.qc.nlevels(axis) + + def _validate_locator_key(self, key: INDEXING_KEY_TYPE) -> None: + """Validate indexing key type. + + Parameters + ---------- + key: indexing key + + Raises + ------ + TypeError: + native pandas object. + set or dict. + all other types out of scalar, list like, slice, series, or, index. + For iloc, raise if scalar is not integer + IndexingError: + tuple. + ValueError: + SnowDataFrame. + """ + # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase + frontend_utils.raise_if_native_pandas_objects(key) + check_dict_or_set_indexers(key) + + if not ( + is_scalar(key) + or isinstance(key, (pd.Series, slice)) + or is_list_like(key) + or is_range_like(key) + ): + raise TypeError( + f".{self.api_name} requires scalars, list-like indexers, slices, or ranges. Got {key}" + ) + + def _validate_get_locator_key(self, key: INDEXING_KEY_TYPE) -> None: + """ + Helper function to validate the locator key for get is valid. + + Parameter: + ---------- + key: get locator key + + """ + # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase + self._validate_locator_key(key) + + def _validate_set_locator_key(self, key: INDEXING_KEY_TYPE) -> None: + """ + Helper function to validate the locator key for set is valid. 
+
+        Parameter:
+        ----------
+        key: set locator key
+
+        """
+        # TODO: SNOW-1063351: Modin upgrade - modin.pandas.indexing._LocationIndexerBase
+        self._validate_locator_key(key)
+
+
+class _LocIndexer(_LocationIndexerBase):
+    """
+    An indexer for modin_df.loc[] functionality.
+
+    Parameters
+    ----------
+    modin_df : modin.pandas.DataFrame
+        DataFrame to operate on.
+    """
+
+    api_name = "loc"
+
+    def _should_squeeze(
+        self,
+        locator: Union[Scalar, list, slice, tuple, pd.Series],
+        axis: int,
+    ) -> bool:
+        """
+        Help decide whether a squeeze is needed to get the final pandas object. For loc, squeeze is needed:
+        - if the locator is a scalar or a tuple and the axis does not have a MultiIndex
+        - if it is a multiindex full lookup, i.e., an exact match on all levels of the MultiIndex
+        Squeeze is not needed if self is a Series and axis = 1, if the locator is neither a scalar nor a tuple,
+        or in any other case (e.g., a scalar locator on a MultiIndex that is not a full lookup).
+
+        Args:
+            locator: locator on the axis
+            axis: the axis to check
+
+        Returns:
+            True if the axis should be squeezed, False otherwise.
+        """
+        # TODO: SNOW-1063352: Modin upgrade - modin.pandas.indexing._LocIndexer
+        do_squeeze = super()._should_squeeze(locator, axis)
+        if do_squeeze is not None:
+            return do_squeeze
+
+        not_dataset = not isinstance(locator, BasePandasDataset)
+        is_scalar_loc = not_dataset and is_scalar(locator)
+        is_tuple_loc = not_dataset and isinstance(locator, tuple)
+
+        if (is_scalar_loc or is_tuple_loc) and not self.qc.is_multiindex(axis=axis):
+            # for single index, if the locator is scalar or tuple, then squeeze is needed
+            return True
+
+        if self._is_multiindex_full_lookup(axis=axis, key=locator):
+            # for multiindex, squeeze is needed only when full lookup happens, i.e., exact match on all levels.
+            return True
+
+        # otherwise, no squeeze is needed
+        return False
+
+    def _parse_row_and_column_locators(
+        self, key: INDEXING_KEY_TYPE
+    ) -> tuple[
+        Union[Scalar, list, slice, tuple, pd.Series],
+        Union[Scalar, list, slice, tuple, pd.Series],
+    ]:
+        """
+        Unpack the user input. This shared parsing helper method is used by both iloc and loc's getitem and setitem.
+
+        Examples:
+        loc[:] -> (slice(None), slice(None))
+        loc[a] -> (a, slice(None))
+        loc[,b] -> (slice(None), b)
+        loc[a,:] -> (a, slice(None))
+        loc[:,b] -> (slice(None), b)
+        loc[a,...] -> (a, slice(None))
+        loc[...,b] -> (slice(None), b)
+        loc[[a,b]] -> ([a,b], slice(None)),
+        loc[a,b] -> ([a], [b])
+        loc[...,a,b] -> ([a], [b])
+        loc[lambda df: df.col > 0,b] -> (df.col > 0, [b])
+        Also, for multiindex cases used by loc:
+        loc[("level0", "level1")] -> (("level0", "level1"), slice(None))
+
+        Args:
+            key: User input to unpack.
+
+        Returns:
+            row_loc : scalar or list
+                Row locator(s) as a scalar or list.
+            col_loc : scalar or list
+                Column locator(s) as a scalar or list.
+
+        Raises:
+            index error if key is tuple(...,...)
+ """ + # TODO: SNOW-1063352: Modin upgrade - modin.pandas.indexing._LocIndexer + if isinstance(key, tuple): + is_nested_tuple = any([not is_scalar(k) for k in key]) + if ( + self.qc.is_multiindex(axis=0) + and not is_nested_tuple + and not (self.df.ndim == 2 and self.qc.is_multiindex(axis=1)) + ): + # always treat tuple loc key as row_loc when the key is not nested tuple and the frame is a Series or + # the frame's column is not multiindex + # e.g., df.loc['cobra', 'mark i'], key = ('cobra', 'mark i') should be treated as row_loc if the row is + # multiindex or the frame is a Series + row_loc = key + if len(row_loc) > self.qc.nlevels(axis=0): + raise IndexingError(TOO_MANY_INDEXERS_INDEXING_ERROR_MESSAGE) + return row_loc, slice(None) + + return super()._parse_row_and_column_locators(key) + + def _locator_type_convert( + self, locator: INDEXING_LOCATOR_TYPE + ) -> Union[INDEXING_LOCATOR_TYPE, "SnowflakeQueryCompiler"]: + """ + A helper function to convert locator type before passing to the backend + Args: + locator: row or column locator + + Returns: + Processed locator + """ + # TODO: SNOW-1063352: Modin upgrade - modin.pandas.indexing._LocIndexer + if isinstance(locator, pd.Series): + locator = locator._query_compiler + elif not isinstance(locator, slice) and is_range_like(locator): + locator = slice(locator.start, locator.stop, locator.step) # type: ignore[union-attr] + return locator + + def _try_partial_string_indexing( + self, row_loc: Union[Scalar, list, slice, tuple, pd.Series] + ) -> Union[Scalar, list, slice, tuple, pd.Series]: + """ + Try to convert row locator to slice if it matches partial string indexing criteria: + 1. `row_loc` needs to be a valid datetime string + 2. the index is datetime type + + Args: + row_loc: the original row locator + + Returns: + the new row locator for partial string indexing; otherwise, the original row locator + """ + # TODO: SNOW-1063352: Modin upgrade - modin.pandas.indexing._LocIndexer + + def _try_partial_string_indexing_for_string( + row_loc: str, + ) -> Union[Scalar, list, slice, tuple, pd.Series]: + """ + Convert string `row_loc` into slice if it matches the partial string indexing criteria. Otherwise, return + the original `row_loc`. 
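+
+            For example (illustrative), with a DatetimeIndex, "2021-01" is parsed at
+            month resolution and becomes roughly slice(Timestamp("2021-01-01 00:00:00"),
+            Timestamp("2021-01-31 23:59:59.999999999")).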
+ + Args: + row_loc: input + + Returns: + slice or the original `row_loc` + """ + # TODO: SNOW-1063352: Modin upgrade - modin.pandas.indexing._LocIndexer + try: + parsed, reso_str = parsing.parse_datetime_string_with_reso(row_loc) + except ValueError: + return row_loc + + # extract tzinfo first since Period will drop tzinfo later; then the tzinfo will be added back when + # assembling the final slice + tzinfo = parsed.tzinfo + reso = Resolution.from_attrname(reso_str) + period = pd.Period(parsed, freq=reso.attr_abbrev) + + # partial string indexing only works for DatetimeIndex + if is_datetime64_any_dtype(self.df._query_compiler.index_dtypes[0]): + return slice( + pd.Timestamp(period.start_time, tzinfo=tzinfo), + pd.Timestamp(period.end_time, tzinfo=tzinfo), + ) + + return row_loc + + if isinstance(row_loc, str): + return _try_partial_string_indexing_for_string(row_loc) + + if isinstance(row_loc, slice): + start, stop = row_loc.start, row_loc.stop + if isinstance(row_loc.start, str): + start = _try_partial_string_indexing_for_string(row_loc.start) + if isinstance(start, slice): + start = start.start + if isinstance(row_loc.stop, str): + stop = _try_partial_string_indexing_for_string(row_loc.stop) + if isinstance(stop, slice): + stop = stop.stop + # partial string indexing only updates start and stop, and should keep using the original step. + row_loc = slice(start, stop, row_loc.step) + + return row_loc + + def __getitem__( + self, key: INDEXING_KEY_TYPE + ) -> Union[Scalar, pd.Series, pd.DataFrame]: + """ + Retrieve dataset according to `key`. + + Parameters + ---------- + key : callable, scalar, list-like, boolean mask, Snowpark pandas Series, slice, or size-two tuple of these + The 2D locator. + + Returns + ------- + modin.pandas.DataFrame or modin.pandas.Series + Located dataset. + + See Also + -------- + pandas.DataFrame.loc + """ + # TODO: SNOW-1063352: Modin upgrade - modin.pandas.indexing._LocIndexer + row_loc, col_loc = self._parse_get_row_and_column_locators(key) + row_loc = self._try_partial_string_indexing(row_loc) + squeeze_row, squeeze_col = self._should_squeeze( + locator=row_loc, axis=0 + ), self._should_squeeze(locator=col_loc, axis=1) + + qc_view = self.qc.take_2d_labels( + self._locator_type_convert(row_loc), self._locator_type_convert(col_loc) + ) + + result = self._get_pandas_object_from_qc_view( + qc_view, squeeze_row=squeeze_row, squeeze_col=squeeze_col + ) + if isinstance(result, Series): + result._parent = self.df + result._parent_axis = 0 + + return result + + def _loc_set_matching_item_columns_by_label( + self, key: INDEXING_KEY_TYPE, item: INDEXING_ITEM_TYPE + ) -> bool: + """ + Decide whether loc set behavior is to match item columns by label or by position. + Note: loc set's behavior is different when key is a tuple of row and col keys vs. key is a row key only. When + key is tuple (e.g., df.loc[row_key, col_key] = item), only ``item``'s column labels that match with col_key are + used to set df values; otherwise, (e.g., df.loc[row_key_only] = item), loc set columns based on ``item``'s + column positions not labels. E.g., df has columns ["A", "B", "C"] and item has columns ["C", "B", "A"], + df.loc[:] = item will update df's columns "A", "B", "C" using item column "C", "B", "A" respectively. + TODO: SNOW-972417 pandas has some complicated logic to use dtypes from both self df and item to decide whether + the loc set behavior for df.loc[row_key, col_key] = item) is matching by label or not. 
Further effort is needed + to decide what the right behavior for Snowpark pandas. + + Args: + key: loc key + item: the RHS in loc set + + Returns: + True if matching item by label + """ + # TODO: SNOW-1063352: Modin upgrade - modin.pandas.indexing._LocIndexer + if is_2d_array(item): + return False + return ( + isinstance(self.df, pd.DataFrame) + and isinstance(key, tuple) + and not is_scalar( + key[1] + ) # e.g., df.loc[:, 'A'] = item is matching item by position + and isinstance(item, pd.DataFrame) + ) + + def __setitem__( + self, + key: INDEXING_KEY_TYPE, + item: INDEXING_ITEM_TYPE, + ) -> None: + """ + Assign `item` value to dataset located by label `key`. + + Args: + key: indexing key type + item: indexing item type + + See Also: + DataFrame.loc + """ + # TODO: SNOW-1063352: Modin upgrade - modin.pandas.indexing._LocIndexer + row_loc, col_loc = self._parse_row_and_column_locators(key) + + # TODO SNOW-962260 support multiindex + if self.qc.is_multiindex(axis=0) or self.qc.is_multiindex(axis=1): + ErrorMessage.not_implemented( + "loc set for multiindex is not yet implemented" + ) + + self._validate_item_type(item, row_loc) + + # If the row key is list-like (Index, list, np.ndarray, etc.), convert it to Series. + if not isinstance(row_loc, pd.Series) and is_list_like(row_loc): + row_loc = pd.Series(row_loc) + + matching_item_columns_by_label = self._loc_set_matching_item_columns_by_label( + key, item + ) + item_is_2d_array = is_2d_array(item) + matching_item_rows_by_label = not item_is_2d_array + + index_is_bool_indexer = isinstance( + row_loc, BasePandasDataset + ) and is_bool_dtype(row_loc.dtypes) + + index = ( + row_loc._query_compiler + if isinstance(row_loc, BasePandasDataset) + else row_loc + ) + columns = ( + col_loc._query_compiler + if isinstance(col_loc, BasePandasDataset) + else col_loc + ) + if item_is_2d_array: + item = pd.DataFrame(item) + item = item._query_compiler if isinstance(item, BasePandasDataset) else item + new_qc = self.qc.set_2d_labels( + index, + columns, + item, + matching_item_columns_by_label=matching_item_columns_by_label, + matching_item_rows_by_label=matching_item_rows_by_label, + index_is_bool_indexer=index_is_bool_indexer, + ) + + self.df._update_inplace(new_query_compiler=new_qc) + + def _validate_locator_key(self, key: INDEXING_KEY_TYPE) -> None: + """Used by loc. See LocationIndexerBase._validate_locator_key""" + # TODO: SNOW-1063352: Modin upgrade - modin.pandas.indexing._LocIndexer + super()._validate_locator_key(key) + if isinstance(key, pd.DataFrame): + raise ValueError("Cannot index with multidimensional key") + + def _validate_item_type( + self, + item: INDEXING_ITEM_TYPE, + row_loc: Union[Scalar, list, slice, tuple, AnyArrayLike], + ) -> None: + """ + Validate item data type for loc set. Raise error if the type is invalid. 
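+
+        For example (illustrative): assigning a DataFrame item to a Series, e.g.
+        ser.loc[key] = df, raises ValueError, and assigning a DataFrame item to a
+        single row of a DataFrame, e.g. df.loc[0] = other_df, also raises ValueError.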
+ Args: + item: the item to set + row_loc: row locator + + Returns: + None + """ + # TODO: SNOW-1063352: Modin upgrade - modin.pandas.indexing._LocIndexer + frontend_utils.raise_if_native_pandas_objects(item) + + if isinstance(self.df, pd.Series): + if isinstance(item, pd.DataFrame): + raise ValueError(LOC_SET_INCOMPATIBLE_INDEXER_WITH_DF_ERROR_MESSAGE) + elif is_scalar(row_loc) and ( + isinstance(item, pd.Series) or is_list_like(item) + ): + ErrorMessage.not_implemented( + SET_CELL_WITH_LIST_LIKE_VALUE_ERROR_MESSAGE + ) + else: + if is_scalar(row_loc) and ( + isinstance(item, pd.DataFrame) or is_2d_array(item) + ): + raise ValueError( + LOC_SET_INCOMPATIBLE_INDEXER_WITH_SCALAR_ERROR_MESSAGE.format( + item.__class__.__name__ + ) + ) + + if (isinstance(row_loc, pd.Series) or is_list_like(row_loc)) and ( + isinstance(item, range) + ): + ErrorMessage.not_implemented( + SERIES_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE_ERROR_MESSAGE + ) + + if isinstance(item, slice): + # Here, the whole slice is assigned as a scalar variable, i.e., a spot at an index gets a slice value. + ErrorMessage.not_implemented( + SERIES_SETITEM_SLICE_AS_SCALAR_VALUE_ERROR_MESSAGE + ) + + +class _iLocIndexer(_LocationIndexerBase): + """ + An indexer for modin_df.iloc[] functionality. + + Parameters + ---------- + modin_df : modin.pandas.DataFrame + DataFrame to operate on. + """ + + api_name = "iloc" + + def _should_squeeze( + self, + locator: Union[Scalar, list, slice, tuple, pd.Series], + axis: int, + ) -> bool: + """ + The method helps to make the decision whether squeeze is needed to get the final pandas object. Specifically, + squeeze is needed: + - if self is series and axis = 1 + - if the locator are not scalar and tuple + - if the locator is scalar + + Args: + locator: locator on the axis + axis: the axis to check + + Returns: + A tuple of boolean values to indicate whether to squeeze on the two axis. + """ + # TODO: SNOW-1063355: Modin upgrade - modin.pandas.indexing._iLocIndexer + do_squeeze = super()._should_squeeze(locator, axis) + if do_squeeze is not None: + return do_squeeze + + not_dataset = not isinstance(locator, BasePandasDataset) + is_scalar_loc = not_dataset and is_scalar(locator) + if is_scalar_loc: + return True + + # otherwise, no squeeze is needed + return False + + @staticmethod + def _convert_range_to_valid_slice(range_key: Any) -> slice: + # TODO: SNOW-1063355: Modin upgrade - modin.pandas.indexing._iLocIndexer + start, stop, step = range_key.start, range_key.stop, range_key.step + # range has different logic from slice: slice can handle cases where (start > stop and step > 0) + # and (start < stop and step < 0) but range has an empty result for this. For example, slice(3, -1, 1) + # and slice(-1, 0, -1) are not empty results but range is. + if (start > stop and step > 0) or (start < stop and step < 0): + return slice(0, 0, 1) + else: + return slice(start, stop, step) + + def __getitem__( + self, + key: INDEXING_KEY_TYPE, + ) -> Union[Scalar, pd.DataFrame, pd.Series]: + """ + Retrieve dataset according to positional `key`. + + Args: + key: int, bool, list like of int or bool, slice of int, series, callable or tuple + The global row numbers to retrieve data from. + + Returns: + DataFrame, Series, or scalar. 
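+
+        For example (illustrative): df.iloc[[0], :] returns a DataFrame,
+        df.iloc[0] squeezes the row into a Series, and df.iloc[0, 0] squeezes
+        both axes into a scalar.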
+ """ + # TODO: SNOW-1063355: Modin upgrade - modin.pandas.indexing._iLocIndexer + row_loc, col_loc = self._parse_get_row_and_column_locators(key) + squeeze_row = self._should_squeeze(locator=row_loc, axis=0) + squeeze_col = self._should_squeeze(locator=col_loc, axis=1) + + original_row_loc = row_loc # keep a copy for error message + + # Convert range to slice objects. + if not isinstance(row_loc, pd.Series) and is_range_like(row_loc): + row_loc = self._convert_range_to_valid_slice(row_loc) + if not isinstance(col_loc, pd.Series) and is_range_like(col_loc): + col_loc = self._convert_range_to_valid_slice(col_loc) + + # Convert all scalar, list-like, and indexer row_loc to a Series object to get a query compiler object. + if is_scalar(row_loc): + row_loc = pd.Series([row_loc]) + elif is_list_like(row_loc): + if hasattr(row_loc, "dtype"): + dtype = row_loc.dtype + elif not row_loc: + # If the list-like object is empty, we need to explicitly specify a dtype + dtype = float + else: + dtype = None + row_loc = pd.Series(row_loc, dtype=dtype) + + # Check whether the row and column input is of numeric dtype. + self._validate_numeric_get_key_values(row_loc, original_row_loc) + self._validate_numeric_get_key_values(col_loc) + + if isinstance(row_loc, pd.Series): + # Get the corresponding query compiler object. + row_loc = row_loc._query_compiler + + if isinstance(col_loc, pd.Series): + col_loc = col_loc.to_list() + + qc_view = self.qc.take_2d_positional(row_loc, col_loc) + result = self._get_pandas_object_from_qc_view( + qc_view, + squeeze_row=squeeze_row, + squeeze_col=squeeze_col, + ) + + if isinstance(result, Series): + result._parent = self.df + result._parent_axis = 0 + return result + + def _get_pandas_object_from_qc_view( + self, + qc_view: SnowflakeQueryCompiler, + *, + squeeze_row: bool, + squeeze_col: bool, + ) -> Union[Scalar, list, pd.Series, pd.DataFrame]: + """ + Convert the query compiler view to the appropriate pandas object. + + Args: + qc_view: SnowflakeQueryCompiler + Query compiler to convert. + squeeze_row: bool + Whether to squeeze row + squeeze_col: bool + Whether to squeeze column + + Returns: DataFrame, Series or Scalar + The pandas object with the data from the query compiler view. + """ + # TODO: SNOW-1063355: Modin upgrade - modin.pandas.indexing._iLocIndexer + res_df = self.df.__constructor__(query_compiler=qc_view) + + if not squeeze_row and not squeeze_col: + return res_df + + if squeeze_row and squeeze_col: + res = res_df.to_pandas().squeeze() + # res can be an empty pandas series where the key is out of bounds, here we convert to an empty list to + # avoid return a native pandas object + if isinstance(res, pandas.Series): + res = [] + return res + + if squeeze_row: + if isinstance(res_df, pd.Series): + # call _reduce_dimension directly instead of calling series.squeeze() to avoid to call len(res_df) + res = res_df._reduce_dimension(res_df._query_compiler) + # res can be an empty pandas series where the key is out of bounds, here we convert to an empty list to + # avoid return a native pandas object + if isinstance(res, pandas.Series): + res = [] + return res + return Series(query_compiler=res_df.T._query_compiler) + + # where only squeeze_col is True + len_columns = len(res_df.columns) + if len_columns == 1: + return Series(query_compiler=res_df._query_compiler) + else: + return res_df.copy() + + def __setitem__( + self, + key: INDEXING_KEY_TYPE, + item: INDEXING_ITEM_TYPE, + ) -> None: + """ + Assign `item` value to dataset located by `key`. 
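+
+        For example (illustrative): df.iloc[[0, 1], [0, 1]] = 1 sets the full 2x2
+        block, while df.iloc[(0, 1), (0, 1)] = 1 sets only the coordinates (0, 0) and
+        (1, 1); see the inline comments below on tuple vs. list keys.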
+ + Parameters + ---------- + key : callable or tuple + The global row numbers to assign data to. + item : modin.pandas.DataFrame, modin.pandas.Series, scalar or list like of similar + Value that should be assigned to located dataset. + """ + # TODO: SNOW-1063355: Modin upgrade - modin.pandas.indexing._iLocIndexer + row_loc, col_loc = self._parse_set_row_and_column_locators(key) + + self._validate_numeric_set_key_values(row_loc) + self._validate_numeric_set_key_values(col_loc) + + is_item_series = isinstance(item, pd.Series) + + if not isinstance(item, BasePandasDataset) and is_list_like(item): + if isinstance(self.df, pd.Series) and is_scalar(row_loc): + ErrorMessage.not_implemented( + SET_CELL_WITH_LIST_LIKE_VALUE_ERROR_MESSAGE + ) + + if isinstance(item, pandas.Index): + item = np.array(item.tolist()).transpose() + else: + item = np.array(item) + + if all(sz == 1 for sz in item.shape): + # Treat as a scalar if a single value regardless of dimensions + item = item.flatten()[0] + else: + if item.ndim == 1: + item = pd.Series(item) + is_item_series = True + else: + item = pd.DataFrame(item) + + is_row_key_df = isinstance(row_loc, pd.DataFrame) + is_col_key_df = isinstance(col_loc, pd.DataFrame) + + # The semantics of iloc setitem differ if the row and col key are both + # tuples or dataframes, in particular they set as row, key location coordinates + # rather than entire rows or columns. So for example + # + # row_key=[1,2] and col_key=[3,4] would be locations (1,3), (1,4), (2,3), (2,4) + # but + # row_key=(1,2) and col_key=(3,4) would only set locations (1,3), (2, 4). + + if not is_row_key_df and not is_col_key_df: + set_as_coords = isinstance(row_loc, tuple) or isinstance(col_loc, tuple) + else: + set_as_coords = is_row_key_df and is_col_key_df + + new_qc = self.qc.set_2d_positional( + row_loc._query_compiler + if isinstance(row_loc, BasePandasDataset) + else row_loc, + col_loc._query_compiler + if isinstance(col_loc, BasePandasDataset) + else col_loc, + item._query_compiler if isinstance(item, BasePandasDataset) else item, + set_as_coords, + is_item_series, + ) + + self.df._create_or_update_from_compiler(new_qc, inplace=True) + + def _validate_locator_key(self, key: INDEXING_KEY_TYPE) -> None: + """Used by iloc. See _LocationIndexerBase._validate_locator_key""" + # TODO: SNOW-1063355: Modin upgrade - modin.pandas.indexing._iLocIndexer + super()._validate_locator_key(key) + + if isinstance(key, pd.MultiIndex): + raise TypeError("key of type MultiIndex cannot be used with iloc") + + def _validate_get_locator_key(self, key: INDEXING_KEY_TYPE) -> None: + """Used by iloc. See _LocationIndexerBase._validate_get_locator_key""" + # TODO: SNOW-1063355: Modin upgrade - modin.pandas.indexing._iLocIndexer + super()._validate_get_locator_key(key) + + if is_scalar(key) and not is_integer(key): + raise IndexError( + ILOC_GET_REQUIRES_NUMERIC_INDEXERS_ERROR_MESSAGE.format(key) + ) + + # Tuple e.g. (1, 2) + if isinstance(key, tuple): + # `key` is not allowed to be tuple since nested tuple is not allowed. + # `key` here, which is a 1d indexing key, is generated from 2d indexing key which split into two 1d indexing + # keys if is tuple type. e,g. 2d_key = ((1,2),0), then 1d key for row key=(1,2). This is not allowed. + raise IndexingError(TOO_MANY_INDEXERS_INDEXING_ERROR_MESSAGE) + + if isinstance(key, pd.DataFrame): + raise IndexError(ILOC_GET_DATAFRAME_INDEXER_NOT_ALLOWED_ERROR_MESSAGE) + + def _validate_set_locator_key(self, key: INDEXING_KEY_TYPE) -> None: + """Used by iloc. 
See _LocationIndexerBase._validate_set_locator_key""" + # TODO: SNOW-1063355: Modin upgrade - modin.pandas.indexing._iLocIndexer + super()._validate_set_locator_key(key) + + if is_scalar(key) and not is_integer(key): + raise IndexError( + ILOC_SET_INDICES_MUST_BE_INTEGER_OR_BOOL_ERROR_MESSAGE.format(key) + ) + + def _are_valid_numeric_key_values( + self, + key: Union[slice, int, list[int], list[bool], AnyArrayLike], + is_valid_numeric_dtype: Callable = is_numeric_dtype, + is_valid_numeric_type: Callable = lambda v: isinstance(v, numbers.Number), + ) -> bool: + """ + Validate iloc input key type after relevant type conversion. + + Args: + key: positional key or pd.Series version of positional key + is_valid_numeric_dtype: callable that checks numeric dtype + is_valid_numeric_type: callable that checks numeric type + + Returns: + bool: True if the key is valid else False for invalid key + + Notes: + Snowpark pandas implicitly allows float list like or series values to be compatible with pandas. + For row values, array-like objects, Index objects, and scalars must be converted to a Series object + before calling this method. The original key should be passed in along with the Series version for + printing the error message. + + Raises: + Series: + validate numeric type; + Scalar: + validate numeric type; + slice or range like: + validate start, stop, and step are int type. + list_like: + validate numeric type; + Other invalid types: + raise IndexingError. + """ + # TODO: SNOW-1063355: Modin upgrade - modin.pandas.indexing._iLocIndexer + if isinstance(key, pd.Series): + return is_valid_numeric_dtype(key.dtype) + elif isinstance(key, slice) or is_range_like(key): + validate_positional_slice(key) + elif hasattr(key, "dtype"): + return is_valid_numeric_dtype(key.dtype) + elif is_list_like(key): + return all( + is_scalar(x) and (is_bool(x) or is_valid_numeric_type(x)) for x in key # type: ignore[union-attr] + ) + elif is_scalar(key): + return is_valid_numeric_type(key) + + return True + + def _validate_numeric_get_key_values( + self, + key: Union[slice, int, list[int], list[bool], AnyArrayLike], + original_key: Union[slice, int, list[int], list[bool], AnyArrayLike] = None, + ) -> None: + """See _iLocIndexer._validate_numeric_key_values""" + # TODO: SNOW-1063355: Modin upgrade - modin.pandas.indexing._iLocIndexer + are_valid = self._are_valid_numeric_key_values(key) + if not are_valid: + raise IndexError( + ILOC_GET_REQUIRES_NUMERIC_INDEXERS_ERROR_MESSAGE.format( + key if original_key is None else original_key + ) + ) + + def _validate_numeric_set_key_values( + self, + key: Union[slice, int, list[int], list[bool], AnyArrayLike], + ) -> None: + """See _iLocIndexer._validate_numeric_key_values""" + # TODO: SNOW-1063355: Modin upgrade - modin.pandas.indexing._iLocIndexer + are_valid = self._are_valid_numeric_key_values( + key, + lambda k: is_integer_dtype(k) or is_bool_dtype(k), + lambda k: isinstance(k, numbers.Integral) or is_bool(k), + ) + if not are_valid: + raise IndexError(ILOC_SET_INDICES_MUST_BE_INTEGER_OR_BOOL_ERROR_MESSAGE) diff --git a/src/snowflake/snowpark/modin/pandas/io.py b/src/snowflake/snowpark/modin/pandas/io.py new file mode 100644 index 00000000000..09fa448e97e --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/io.py @@ -0,0 +1,1163 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. 
+# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +""" +Implement I/O public API as pandas does. + +Almost all docstrings for public and magic methods should be inherited from pandas +for better maintability. +Manually add documentation for methods which are not presented in pandas. +""" + +from __future__ import annotations + +import csv +import inspect +import pathlib +import pickle +from collections import OrderedDict +from typing import ( + IO, + TYPE_CHECKING, + Any, + AnyStr, + Callable, + Hashable, + Iterable, + Iterator, + Literal, + Pattern, + Sequence, +) + +import numpy as np +import pandas +from pandas._libs.lib import NoDefault, no_default +from pandas._typing import ( + CompressionOptions, + ConvertersArg, + CSVEngine, + DtypeArg, + DtypeBackend, + FilePath, + IndexLabel, + IntStrT, + ParseDatesArg, + ReadBuffer, + ReadCsvBuffer, + StorageOptions, + XMLParsers, +) +from pandas.io.parsers import TextFileReader +from pandas.io.parsers.readers import _c_parser_defaults + +# add this line to enable doc tests to run +from snowflake.snowpark.modin import pandas as pd # noqa: F401 +from snowflake.snowpark.modin.config import ExperimentalNumPyAPI +from snowflake.snowpark.modin.plugin._internal.telemetry import ( + snowpark_pandas_telemetry_standalone_function_decorator, +) +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.modin.utils import ( + SupportsPrivateToNumPy, + SupportsPrivateToPandas, + SupportsPublicToNumPy, + _inherit_docstrings, + classproperty, + expanduser_path_arg, +) + +# below logic is to handle circular imports without errors +if TYPE_CHECKING: # pragma: no cover + from .dataframe import DataFrame + from .series import Series + +# TODO: SNOW-1265551: add inherit_docstrings decorators once docstring overrides are available + + +class ModinObjects: + """Lazily import Modin classes and provide an access to them.""" + + _dataframe = None + + @classproperty + def DataFrame(cls): + """Get ``modin.pandas.DataFrame`` class.""" + if cls._dataframe is None: + from .dataframe import DataFrame + + cls._dataframe = DataFrame + return cls._dataframe + + +def _read( + **kwargs, +): # pragma: no cover: our frontend currently overrides read_csv, so this is unused + """ + Read csv file from local disk. + + Parameters + ---------- + **kwargs : dict + Keyword arguments in pandas.read_csv. 
+ + Returns + ------- + modin.pandas.DataFrame + """ + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + squeeze = kwargs.pop( + "squeeze", False + ) # pragma: no cover: this is a removed argument and should be removed upstream + pd_obj = FactoryDispatcher.read_csv(**kwargs) + # This happens when `read_csv` returns a TextFileReader object for iterating through + if isinstance(pd_obj, TextFileReader): # pragma: no cover + reader = pd_obj.read + pd_obj.read = lambda *args, **kwargs: ModinObjects.DataFrame( + query_compiler=reader(*args, **kwargs) + ) + return pd_obj + result = ModinObjects.DataFrame(query_compiler=pd_obj) + if squeeze: + return result.squeeze(axis=1) + return result + + +# TODO: SNOW-1265551: add inherit_docstrings decorators once docstring overrides are available +@expanduser_path_arg("path_or_buffer") +@snowpark_pandas_telemetry_standalone_function_decorator +def read_xml( + path_or_buffer: FilePath | ReadBuffer[bytes] | ReadBuffer[str], + *, + xpath: str = "./*", + namespaces: dict[str, str] | None = None, + elems_only: bool = False, + attrs_only: bool = False, + names: Sequence[str] | None = None, + dtype: DtypeArg | None = None, + converters: ConvertersArg | None = None, + parse_dates: ParseDatesArg | None = None, + encoding: str | None = "utf-8", + parser: XMLParsers = "lxml", + stylesheet: FilePath | ReadBuffer[bytes] | ReadBuffer[str] | None = None, + iterparse: dict[str, list[str]] | None = None, + compression: CompressionOptions = "infer", + storage_options: StorageOptions = None, + dtype_backend: DtypeBackend | NoDefault = no_default, +) -> DataFrame: + # TODO(https://github.com/modin-project/modin/issues/7104): + # modin needs to remove defaults to pandas at API layer + ErrorMessage.not_implemented() + + +@_inherit_docstrings(pandas.read_csv, apilink="pandas.read_csv") +@expanduser_path_arg("filepath_or_buffer") +@snowpark_pandas_telemetry_standalone_function_decorator +def read_csv( + filepath_or_buffer: FilePath | ReadCsvBuffer[bytes] | ReadCsvBuffer[str], + *, + sep: str | None | NoDefault = no_default, + delimiter: str | None | NoDefault = None, + # Column and Index Locations and Names + header: int | Sequence[int] | None | Literal["infer"] = "infer", + names: Sequence[Hashable] | None | NoDefault = no_default, + index_col: IndexLabel | Literal[False] | None = None, + usecols=None, + # General Parsing Configuration + dtype: DtypeArg | None = None, + engine: CSVEngine | None = None, + converters=None, + true_values=None, + false_values=None, + skipinitialspace: bool = False, + skiprows=None, + skipfooter: int = 0, + nrows: int | None = None, + # NA and Missing Data Handling + na_values=None, + keep_default_na: bool = True, + na_filter: bool = True, + verbose: bool = no_default, + skip_blank_lines: bool = True, + # Datetime Handling + parse_dates=None, + infer_datetime_format: bool = no_default, + keep_date_col: bool = no_default, + date_parser=no_default, + date_format=None, + dayfirst: bool = False, + cache_dates: bool = True, + # Iteration + iterator: bool = False, + chunksize: int | None = None, + # Quoting, Compression, and File Format + compression: CompressionOptions = "infer", + thousands: str | None = None, + decimal: str = ".", + lineterminator: str | None = None, + quotechar: str = '"', + quoting: int = csv.QUOTE_MINIMAL, + doublequote: bool = True, + escapechar: str | None = None, + comment: str | None = None, + encoding: str | None = None, + encoding_errors: str | None = "strict", + 
dialect: str | csv.Dialect | None = None, + # Error Handling + on_bad_lines="error", + # Internal + delim_whitespace: bool = no_default, + low_memory=_c_parser_defaults["low_memory"], + memory_map: bool = False, + float_precision: Literal["high", "legacy"] | None = None, + storage_options: StorageOptions = None, + dtype_backend: DtypeBackend | NoDefault = no_default, +) -> DataFrame | TextFileReader: # pragma: no cover: this function is overridden by plugin/pd_overrides.py + # ISSUE #2408: parse parameter shared with pandas read_csv and read_table and update with provided args + _pd_read_csv_signature = { + val.name for val in inspect.signature(pandas.read_csv).parameters.values() + } + _, _, _, f_locals = inspect.getargvalues(inspect.currentframe()) + kwargs = {k: v for k, v in f_locals.items() if k in _pd_read_csv_signature} + return _read(**kwargs) + + +@_inherit_docstrings(pandas.read_table, apilink="pandas.read_table") +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("filepath_or_buffer") +def read_table( + filepath_or_buffer: FilePath | ReadCsvBuffer[bytes] | ReadCsvBuffer[str], + *, + sep: str | None | NoDefault = no_default, + delimiter: str | None | NoDefault = None, + # Column and Index Locations and Names + header: int | Sequence[int] | None | Literal["infer"] = "infer", + names: Sequence[Hashable] | None | NoDefault = no_default, + index_col: IndexLabel | Literal[False] | None = None, + usecols=None, + # General Parsing Configuration + dtype: DtypeArg | None = None, + engine: CSVEngine | None = None, + converters=None, + true_values=None, + false_values=None, + skipinitialspace: bool = False, + skiprows=None, + skipfooter: int = 0, + nrows: int | None = None, + # NA and Missing Data Handling + na_values=None, + keep_default_na: bool = True, + na_filter: bool = True, + verbose: bool = False, + skip_blank_lines: bool = True, + # Datetime Handling + parse_dates=False, + infer_datetime_format: bool = no_default, + keep_date_col: bool = False, + date_parser=no_default, + date_format: str = None, + dayfirst: bool = False, + cache_dates: bool = True, + # Iteration + iterator: bool = False, + chunksize: int | None = None, + # Quoting, Compression, and File Format + compression: CompressionOptions = "infer", + thousands: str | None = None, + decimal: str = ".", + lineterminator: str | None = None, + quotechar: str = '"', + quoting: int = csv.QUOTE_MINIMAL, + doublequote: bool = True, + escapechar: str | None = None, + comment: str | None = None, + encoding: str | None = None, + encoding_errors: str | None = "strict", + dialect: str | csv.Dialect | None = None, + # Error Handling + on_bad_lines="error", + # Internal + delim_whitespace=False, + low_memory=_c_parser_defaults["low_memory"], + memory_map: bool = False, + float_precision: str | None = None, + storage_options: StorageOptions = None, + dtype_backend: DtypeBackend | NoDefault = no_default, +) -> DataFrame | TextFileReader: # pragma: no cover + # ISSUE #2408: parse parameter shared with pandas read_csv and read_table and update with provided args + _pd_read_table_signature = { + val.name for val in inspect.signature(pandas.read_table).parameters.values() + } + _, _, _, f_locals = inspect.getargvalues(inspect.currentframe()) + if f_locals.get("sep", sep) is False or f_locals.get("sep", sep) is no_default: + f_locals["sep"] = "\t" + kwargs = {k: v for k, v in f_locals.items() if k in _pd_read_table_signature} + return _read(**kwargs) + + +# TODO: SNOW-1265551: add inherit_docstrings decorators once 
docstring overrides are available +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("path") +def read_parquet( + path, + engine: str = "auto", + columns: list[str] | None = None, + storage_options: StorageOptions = None, + use_nullable_dtypes: bool = no_default, + dtype_backend=no_default, + filesystem=None, + filters=None, + **kwargs, +) -> DataFrame: # pragma: no cover: this function is overridden by plugin/pd_overrides.py + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + if engine == "fastparquet" and dtype_backend is not no_default: + raise ValueError( + "The 'dtype_backend' argument is not supported for the fastparquet engine" + ) + + return ModinObjects.DataFrame( + query_compiler=FactoryDispatcher.read_parquet( + path=path, + engine=engine, + columns=columns, + storage_options=storage_options, + use_nullable_dtypes=use_nullable_dtypes, + dtype_backend=dtype_backend, + filesystem=filesystem, + filters=filters, + **kwargs, + ) + ) + + +# TODO: SNOW-1265551: add inherit_docstrings decorators once docstring overrides are available +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("path_or_buf") +def read_json( + path_or_buf, + *, + orient: str | None = None, + typ: Literal["frame", "series"] = "frame", + dtype: DtypeArg | None = None, + convert_axes=None, + convert_dates: bool | list[str] = True, + keep_default_dates: bool = True, + precise_float: bool = False, + date_unit: str | None = None, + encoding: str | None = None, + encoding_errors: str | None = "strict", + lines: bool = False, + chunksize: int | None = None, + compression: CompressionOptions = "infer", + nrows: int | None = None, + storage_options: StorageOptions = None, + dtype_backend: DtypeBackend | NoDefault = no_default, + engine="ujson", +) -> DataFrame | Series | pandas.io.json._json.JsonReader: # pragma: no cover: this function is overridden by plugin/pd_overrides.py + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame(query_compiler=FactoryDispatcher.read_json(**kwargs)) + + +@_inherit_docstrings(pandas.read_gbq, apilink="pandas.read_gbq") +@snowpark_pandas_telemetry_standalone_function_decorator +def read_gbq( + query: str, + project_id: str | None = None, + index_col: str | None = None, + col_order: list[str] | None = None, + reauth: bool = False, + auth_local_webserver: bool = True, + dialect: str | None = None, + location: str | None = None, + configuration: dict[str, Any] | None = None, + credentials=None, + use_bqstorage_api: bool | None = None, + max_results: int | None = None, + progress_bar_type: str | None = None, +) -> DataFrame: + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + kwargs.update(kwargs.pop("kwargs", {})) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame(query_compiler=FactoryDispatcher.read_gbq(**kwargs)) + + +@_inherit_docstrings(pandas.read_html, apilink="pandas.read_html") +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("io") +def read_html( + io, + *, + match: str | Pattern = ".+", + flavor: str | None = None, + header: int | Sequence[int] | None = None, + index_col: int | Sequence[int] | None = None, + skiprows: int | Sequence[int] | slice | None = None, + 
attrs: dict[str, str] | None = None, + parse_dates: bool = False, + thousands: str | None = ",", + encoding: str | None = None, + decimal: str = ".", + converters: dict | None = None, + na_values: Iterable[object] | None = None, + keep_default_na: bool = True, + displayed_only: bool = True, + extract_links: Literal[None, "header", "footer", "body", "all"] = None, + dtype_backend: DtypeBackend | NoDefault = no_default, + storage_options: StorageOptions = None, +) -> list[DataFrame]: # pragma: no cover # noqa: PR01, RT01, D200 + """ + Read HTML tables into a ``DataFrame`` object. + """ + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + qcs = FactoryDispatcher.read_html(**kwargs) + return [ModinObjects.DataFrame(query_compiler=qc) for qc in qcs] + + +@_inherit_docstrings(pandas.read_clipboard, apilink="pandas.read_clipboard") +@snowpark_pandas_telemetry_standalone_function_decorator +def read_clipboard( + sep=r"\s+", + dtype_backend: DtypeBackend | NoDefault = no_default, + **kwargs, +): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Read text from clipboard and pass to read_csv. + """ + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + kwargs.update(kwargs.pop("kwargs", {})) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame( + query_compiler=FactoryDispatcher.read_clipboard(**kwargs) + ) + + +@_inherit_docstrings(pandas.read_excel, apilink="pandas.read_excel") +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("io") +def read_excel( + io, + sheet_name: str | int | list[IntStrT] | None = 0, + *, + header: int | Sequence[int] | None = 0, + names: list[str] | None = None, + index_col: int | Sequence[int] | None = None, + usecols: int + | str + | Sequence[int] + | Sequence[str] + | Callable[[str], bool] + | None = None, + dtype: DtypeArg | None = None, + engine: Literal[("xlrd", "openpyxl", "odf", "pyxlsb")] | None = None, + converters: dict[str, Callable] | dict[int, Callable] | None = None, + true_values: Iterable[Hashable] | None = None, + false_values: Iterable[Hashable] | None = None, + skiprows: Sequence[int] | int | Callable[[int], object] | None = None, + nrows: int | None = None, + na_values=None, + keep_default_na: bool = True, + na_filter: bool = True, + verbose: bool = False, + parse_dates: list | dict | bool = False, + date_parser: Callable | NoDefault = no_default, + date_format=None, + thousands: str | None = None, + decimal: str = ".", + comment: str | None = None, + skipfooter: int = 0, + storage_options: StorageOptions = None, + dtype_backend: DtypeBackend | NoDefault = no_default, + engine_kwargs: dict | None = None, +) -> DataFrame | dict[IntStrT, DataFrame]: # pragma: no cover + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + intermediate = FactoryDispatcher.read_excel(**kwargs) + if isinstance(intermediate, (OrderedDict, dict)): + parsed = type(intermediate)() + for key in intermediate.keys(): + parsed[key] = ModinObjects.DataFrame(query_compiler=intermediate.get(key)) + return parsed + else: + return ModinObjects.DataFrame(query_compiler=intermediate) + + +@_inherit_docstrings(pandas.read_hdf, apilink="pandas.read_hdf") 
+@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("path_or_buf") +def read_hdf( + path_or_buf, + key=None, + mode: str = "r", + errors: str = "strict", + where=None, + start: int | None = None, + stop: int | None = None, + columns=None, + iterator=False, + chunksize: int | None = None, + **kwargs, +): # noqa: PR01, RT01, D200 + """ + Read data from the store into DataFrame. + """ + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + kwargs.update(kwargs.pop("kwargs", {})) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame(query_compiler=FactoryDispatcher.read_hdf(**kwargs)) + + +@_inherit_docstrings(pandas.read_feather, apilink="pandas.read_feather") +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("path") +def read_feather( + path, + columns: Sequence[Hashable] | None = None, + use_threads: bool = True, + storage_options: StorageOptions = None, + dtype_backend: DtypeBackend | NoDefault = no_default, +): + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame( + query_compiler=FactoryDispatcher.read_feather(**kwargs) + ) + + +@_inherit_docstrings(pandas.read_stata) +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("filepath_or_buffer") +def read_stata( + filepath_or_buffer, + *, + convert_dates: bool = True, + convert_categoricals: bool = True, + index_col: str | None = None, + convert_missing: bool = False, + preserve_dtypes: bool = True, + columns: Sequence[str] | None = None, + order_categoricals: bool = True, + chunksize: int | None = None, + iterator: bool = False, + compression: CompressionOptions = "infer", + storage_options: StorageOptions = None, +) -> DataFrame | pandas.io.stata.StataReader: + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame(query_compiler=FactoryDispatcher.read_stata(**kwargs)) + + +@_inherit_docstrings(pandas.read_sas, apilink="pandas.read_sas") +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("filepath_or_buffer") +def read_sas( + filepath_or_buffer, + *, + format: str | None = None, + index: Hashable | None = None, + encoding: str | None = None, + chunksize: int | None = None, + iterator: bool = False, + compression: CompressionOptions = "infer", +) -> DataFrame | pandas.io.sas.sasreader.ReaderBase: # noqa: PR01, RT01, D200 + """ + Read SAS files stored as either XPORT or SAS7BDAT format files. 
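The readers above (read_hdf, read_feather, read_stata, and the earlier read_csv/read_table) share one piece of plumbing: each captures its own call arguments with `inspect.getargvalues(inspect.currentframe())`, optionally filters them against the corresponding pandas signature, and forwards the result as keyword arguments. Below is a minimal, standalone sketch of that pattern; `_target` and `reader` are illustrative placeholders, not names from this module.

```python
import inspect


def _target(a, b=1, c=None):
    """Stand-in for the dispatcher/pandas call that finally receives the kwargs."""
    return {"a": a, "b": b, "c": c}


def reader(a, b=1, c=None, extra_option=False):
    # Names accepted by the target, mirroring inspect.signature(pandas.read_csv).
    accepted = {p.name for p in inspect.signature(_target).parameters.values()}
    # Grab this frame's locals, keep only what the target understands, forward.
    _, _, _, f_locals = inspect.getargvalues(inspect.currentframe())
    kwargs = {k: v for k, v in f_locals.items() if k in accepted}
    return _target(**kwargs)


print(reader(10, c="hello"))  # -> {'a': 10, 'b': 1, 'c': 'hello'}
```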
+ """ + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame( + query_compiler=FactoryDispatcher.read_sas( + filepath_or_buffer=filepath_or_buffer, + format=format, + index=index, + encoding=encoding, + chunksize=chunksize, + iterator=iterator, + compression=compression, + ) + ) + + +@_inherit_docstrings(pandas.read_pickle, apilink="pandas.read_pickle") +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("filepath_or_buffer") +def read_pickle( + filepath_or_buffer, + compression: CompressionOptions = "infer", + storage_options: StorageOptions = None, +): + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame( + query_compiler=FactoryDispatcher.read_pickle(**kwargs) + ) + + +@_inherit_docstrings(pandas.read_sql, apilink="pandas.read_sql") +@snowpark_pandas_telemetry_standalone_function_decorator +def read_sql( + sql, + con, + index_col=None, + coerce_float=True, + params=None, + parse_dates=None, + columns=None, + chunksize=None, + dtype_backend: DtypeBackend | NoDefault = no_default, + dtype=None, +): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Read SQL query or database table into a DataFrame. + """ + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + if kwargs.get("chunksize") is not None: + ErrorMessage.default_to_pandas("Parameters provided [chunksize]") + df_gen = pandas.read_sql(**kwargs) + return ( + ModinObjects.DataFrame(query_compiler=FactoryDispatcher.from_pandas(df)) + for df in df_gen + ) + return ModinObjects.DataFrame(query_compiler=FactoryDispatcher.read_sql(**kwargs)) + + +@_inherit_docstrings(pandas.read_fwf, apilink="pandas.read_fwf") +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("filepath_or_buffer") +def read_fwf( + filepath_or_buffer: str | pathlib.Path | IO[AnyStr], + *, + colspecs="infer", + widths=None, + infer_nrows=100, + dtype_backend: DtypeBackend | NoDefault = no_default, + **kwds, +): # pragma: no cover # noqa: PR01, RT01, D200 + """ + Read a table of fixed-width formatted lines into DataFrame. 
+ """ + from pandas.io.parsers.base_parser import parser_defaults + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + kwargs.update(kwargs.pop("kwds", {})) + target_kwargs = parser_defaults.copy() + target_kwargs.update(kwargs) + pd_obj = FactoryDispatcher.read_fwf(**target_kwargs) + # When `read_fwf` returns a TextFileReader object for iterating through + if isinstance(pd_obj, TextFileReader): + reader = pd_obj.read + pd_obj.read = lambda *args, **kwargs: ModinObjects.DataFrame( + query_compiler=reader(*args, **kwargs) + ) + return pd_obj + return ModinObjects.DataFrame(query_compiler=pd_obj) + + +@_inherit_docstrings(pandas.read_sql_table, apilink="pandas.read_sql_table") +@snowpark_pandas_telemetry_standalone_function_decorator +def read_sql_table( + table_name, + con, + schema=None, + index_col=None, + coerce_float=True, + parse_dates=None, + columns=None, + chunksize=None, + dtype_backend: DtypeBackend | NoDefault = no_default, +): # noqa: PR01, RT01, D200 + """ + Read SQL database table into a DataFrame. + """ + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame( + query_compiler=FactoryDispatcher.read_sql_table(**kwargs) + ) + + +@_inherit_docstrings(pandas.read_sql_query, apilink="pandas.read_sql_query") +@snowpark_pandas_telemetry_standalone_function_decorator +def read_sql_query( + sql, + con, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + params: list[str] | dict[str, str] | None = None, + parse_dates: list[str] | dict[str, str] | None = None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | NoDefault = no_default, +) -> DataFrame | Iterator[DataFrame]: + _, _, _, kwargs = inspect.getargvalues(inspect.currentframe()) + + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame( + query_compiler=FactoryDispatcher.read_sql_query(**kwargs) + ) + + +@_inherit_docstrings(pandas.to_pickle) +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("filepath_or_buffer") +def to_pickle( + obj: Any, + filepath_or_buffer, + compression: CompressionOptions = "infer", + protocol: int = pickle.HIGHEST_PROTOCOL, + storage_options: StorageOptions = None, +) -> None: + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + if isinstance(obj, ModinObjects.DataFrame): + obj = obj._query_compiler + return FactoryDispatcher.to_pickle( + obj, + filepath_or_buffer=filepath_or_buffer, + compression=compression, + protocol=protocol, + storage_options=storage_options, + ) + + +@_inherit_docstrings(pandas.read_spss, apilink="pandas.read_spss") +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("path") +def read_spss( + path: str | pathlib.Path, + usecols: Sequence[str] | None = None, + convert_categoricals: bool = True, + dtype_backend: DtypeBackend | NoDefault = no_default, +): # noqa: PR01, RT01, D200 + """ + Load an SPSS file from the file path, returning a DataFrame. 
+ """ + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame( + query_compiler=FactoryDispatcher.read_spss( + path=path, + usecols=usecols, + convert_categoricals=convert_categoricals, + dtype_backend=dtype_backend, + ) + ) + + +@_inherit_docstrings(pandas.json_normalize, apilink="pandas.json_normalize") +@snowpark_pandas_telemetry_standalone_function_decorator +def json_normalize( + data: dict | list[dict], + record_path: str | list | None = None, + meta: str | list[str | list[str]] | None = None, + meta_prefix: str | None = None, + record_prefix: str | None = None, + errors: str | None = "raise", + sep: str = ".", + max_level: int | None = None, +) -> DataFrame: # noqa: PR01, RT01, D200 + """ + Normalize semi-structured JSON data into a flat table. + """ + # TODO(https://github.com/modin-project/modin/issues/7104): + # modin needs to remove defaults to pandas at API layer + ErrorMessage.not_implemented() + + +@_inherit_docstrings(pandas.read_orc, apilink="pandas.read_orc") +@snowpark_pandas_telemetry_standalone_function_decorator +@expanduser_path_arg("path") +def read_orc( + path, + columns: list[str] | None = None, + dtype_backend: DtypeBackend | NoDefault = no_default, + filesystem=None, + **kwargs, +) -> DataFrame: # noqa: PR01, RT01, D200 + """ + Load an ORC object from the file path, returning a DataFrame. + """ + # TODO(https://github.com/modin-project/modin/issues/7104): + # modin needs to remove defaults to pandas at API layer + ErrorMessage.not_implemented() + + +@_inherit_docstrings(pandas.HDFStore) +@snowpark_pandas_telemetry_standalone_function_decorator +class HDFStore(pandas.HDFStore): # pragma: no cover # noqa: PR01, D200 + """ + Dict-like IO interface for storing pandas objects in PyTables. + """ + + _return_modin_dataframe = True + + def __getattribute__(self, item): + default_behaviors = ["__init__", "__class__"] + method = super().__getattribute__(item) + if item not in default_behaviors: + if callable(method): + + def return_handler(*args, **kwargs): + """ + Replace the default behavior of methods with inplace kwarg. + + Returns + ------- + A Modin DataFrame in place of a pandas DataFrame, or the same + return type as pandas.HDFStore. + + Notes + ----- + This function will replace all of the arguments passed to + methods of HDFStore with the pandas equivalent. It will convert + Modin DataFrame to pandas DataFrame, etc. Currently, pytables + does not accept Modin DataFrame objects, so we must convert to + pandas. + """ + # We don't want to constantly be giving this error message for + # internal methods. + if item[0] != "_": + ErrorMessage.default_to_pandas(f"`{item}`") + args = [ + to_pandas(arg) + if isinstance(arg, ModinObjects.DataFrame) + else arg + for arg in args + ] + kwargs = { + k: to_pandas(v) if isinstance(v, ModinObjects.DataFrame) else v + for k, v in kwargs.items() + } + obj = super(HDFStore, self).__getattribute__(item)(*args, **kwargs) + if self._return_modin_dataframe and isinstance( + obj, pandas.DataFrame + ): + return ModinObjects.DataFrame(obj) + return obj + + # We replace the method with `return_handler` for inplace operations + method = return_handler + return method + + +@_inherit_docstrings(pandas.ExcelFile) +@snowpark_pandas_telemetry_standalone_function_decorator +class ExcelFile(pandas.ExcelFile): # pragma: no cover # noqa: PR01, D200 + """ + Class for parsing tabular excel sheets into DataFrame objects. 
+ """ + + _behave_like_pandas = False + + def _set_pandas_mode(self): # noqa + # disable Modin behavior to be able to pass object to `pandas.read_excel` + # otherwise, Modin objects may be passed to the pandas context, resulting + # in undefined behavior + self._behave_like_pandas = True + + def __getattribute__(self, item): + if item in ["_set_pandas_mode", "_behave_like_pandas"]: + return object.__getattribute__(self, item) + + default_behaviors = ["__init__", "__class__"] + method = super().__getattribute__(item) + if not self._behave_like_pandas and item not in default_behaviors: + if callable(method): + + def return_handler(*args, **kwargs): + """ + Replace the default behavior of methods with inplace kwarg. + + Returns + ------- + A Modin DataFrame in place of a pandas DataFrame, or the same + return type as pandas.ExcelFile. + + Notes + ----- + This function will replace all of the arguments passed to + methods of ExcelFile with the pandas equivalent. It will convert + Modin DataFrame to pandas DataFrame, etc. + """ + # We don't want to constantly be giving this error message for + # internal methods. + if item[0] != "_": + ErrorMessage.default_to_pandas(f"`{item}`") + args = [ + to_pandas(arg) + if isinstance(arg, ModinObjects.DataFrame) + else arg + for arg in args + ] + kwargs = { + k: to_pandas(v) if isinstance(v, ModinObjects.DataFrame) else v + for k, v in kwargs.items() + } + obj = super(ExcelFile, self).__getattribute__(item)(*args, **kwargs) + if isinstance(obj, pandas.DataFrame): + return ModinObjects.DataFrame(obj) + return obj + + # We replace the method with `return_handler` for inplace operations + method = return_handler + return method + + +@snowpark_pandas_telemetry_standalone_function_decorator +def from_non_pandas(df, index, columns, dtype): # pragma: no cover + """ + Convert a non-pandas DataFrame into Modin DataFrame. + + Parameters + ---------- + df : object + Non-pandas DataFrame. + index : object + Index for non-pandas DataFrame. + columns : object + Columns for non-pandas DataFrame. + dtype : type + Data type to force. + + Returns + ------- + modin.pandas.DataFrame + Converted DataFrame. + """ + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + new_qc = FactoryDispatcher.from_non_pandas(df, index, columns, dtype) + if new_qc is not None: + return ModinObjects.DataFrame(query_compiler=new_qc) + return new_qc + + +@snowpark_pandas_telemetry_standalone_function_decorator +def from_pandas(df): # pragma: no cover + """ + Convert a pandas DataFrame to a Modin DataFrame. + + Parameters + ---------- + df : pandas.DataFrame + The pandas DataFrame to convert. + + Returns + ------- + modin.pandas.DataFrame + A new Modin DataFrame object. + """ + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame(query_compiler=FactoryDispatcher.from_pandas(df)) + + +@snowpark_pandas_telemetry_standalone_function_decorator +def from_arrow(at): # pragma: no cover + """ + Convert an Arrow Table to a Modin DataFrame. + + Parameters + ---------- + at : Arrow Table + The Arrow Table to convert from. + + Returns + ------- + DataFrame + A new Modin DataFrame object. 
+ """ + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame(query_compiler=FactoryDispatcher.from_arrow(at)) + + +@snowpark_pandas_telemetry_standalone_function_decorator +def from_dataframe(df): # pragma: no cover + """ + Convert a DataFrame implementing the dataframe exchange protocol to a Modin DataFrame. + + See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. + + Parameters + ---------- + df : DataFrame + The DataFrame object supporting the dataframe exchange protocol. + + Returns + ------- + DataFrame + A new Modin DataFrame object. + """ + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return ModinObjects.DataFrame(query_compiler=FactoryDispatcher.from_dataframe(df)) + + +@snowpark_pandas_telemetry_standalone_function_decorator +def to_pandas(modin_obj: SupportsPrivateToPandas) -> Any: # pragma: no cover + """ + Convert a Modin DataFrame/Series to a pandas DataFrame/Series. + + Parameters + ---------- + modin_obj : modin.DataFrame, modin.Series + The Modin DataFrame/Series to convert. + + Returns + ------- + pandas.DataFrame or pandas.Series + Converted object with type depending on input. + """ + return modin_obj._to_pandas() + + +@snowpark_pandas_telemetry_standalone_function_decorator +def to_numpy( + modin_obj: SupportsPrivateToNumPy | SupportsPublicToNumPy, +) -> np.ndarray: # pragma: no cover + """ + Convert a Modin object to a NumPy array. + + Parameters + ---------- + modin_obj : modin.DataFrame, modin."Series", modin.numpy.array + The Modin distributed object to convert. + + Returns + ------- + numpy.array + Converted object with type depending on input. + """ + if isinstance(modin_obj, SupportsPrivateToNumPy): + return modin_obj._to_numpy() + array = modin_obj.to_numpy() + if ExperimentalNumPyAPI.get(): + array = array._to_numpy() + return array + + +__all__ = [ + "ExcelFile", + "HDFStore", + "json_normalize", + "read_clipboard", + "read_csv", + "read_excel", + "read_feather", + "read_fwf", + "read_gbq", + "read_hdf", + "read_html", + "read_json", + "read_orc", + "read_parquet", + "read_pickle", + "read_sas", + "read_spss", + "read_sql", + "read_sql_query", + "read_sql_table", + "read_stata", + "read_table", + "read_xml", + "from_non_pandas", + "from_pandas", + "from_arrow", + "from_dataframe", + "to_pickle", + "to_pandas", + "to_numpy", +] diff --git a/src/snowflake/snowpark/modin/pandas/iterator.py b/src/snowflake/snowpark/modin/pandas/iterator.py new file mode 100644 index 00000000000..717a743861c --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/iterator.py @@ -0,0 +1,82 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Place to define the Modin iterator.""" + +from collections.abc import Iterator +from typing import Callable + +import snowflake.snowpark.modin.pandas.dataframe as DataFrame + + +class PartitionIterator(Iterator): + """ + Iterator on partitioned data. + + Parameters + ---------- + df : DataFrame + The dataframe to iterate over. + axis : {0, 1} + Axis to iterate over. + func : callable + The function to get inner iterables from each partition. + """ + + def __init__(self, df: DataFrame, axis: int, func: Callable) -> None: + self.df = df + self.axis = axis + self.index_iter = ( + zip( + iter(slice(None) for _ in range(len(self.df.columns))), + range(len(self.df.columns)), + ) + if axis + else zip( + range(len(self.df.index)), + iter(slice(None) for _ in range(len(self.df.index))), + ) + ) + self.func = func + + def __iter__(self) -> None: + """ + Implement iterator interface. + + Returns + ------- + PartitionIterator + Iterator object. + """ + return self + + def __next__(self): + """ + Implement iterator interface. + + Returns + ------- + PartitionIterator + Incremented iterator object. + """ + key = next(self.index_iter) + df = self.df.iloc[key] + return self.func(df) diff --git a/src/snowflake/snowpark/modin/pandas/plotting.py b/src/snowflake/snowpark/modin/pandas/plotting.py new file mode 100644 index 00000000000..e7099e94011 --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/plotting.py @@ -0,0 +1,84 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Implement pandas plotting API.""" + +from pandas import plotting as pdplot + +# Snowpark pandas API version +from snowflake.snowpark.modin.pandas.dataframe import DataFrame +from snowflake.snowpark.modin.utils import instancer, to_pandas + + +@instancer +class Plotting: + """Wrapper of pandas plotting module.""" + + def __dir__(self): + """ + Enable tab completion of plotting library. + + Returns + ------- + list + List of attributes in `self`. + """ + return dir(pdplot) + + def __getattribute__(self, item): + """ + Convert any Modin DataFrames in parameters to pandas so that they can be plotted normally. + + Parameters + ---------- + item : str + Attribute to look for. 
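PartitionIterator above builds its keys as `(row_position, slice(None))` pairs for axis=0 and `(slice(None), column_position)` pairs for axis=1, then feeds `df.iloc[key]` to `func`. A small plain-pandas illustration of that key scheme follows; `iter_axis` is a throwaway helper written for this example, not the class itself.

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})


def iter_axis(df, axis, func):
    # axis=0 walks rows via (position, slice(None)); axis=1 walks columns
    # via (slice(None), position), mirroring the keys PartitionIterator builds.
    if axis:
        keys = ((slice(None), j) for j in range(len(df.columns)))
    else:
        keys = ((i, slice(None)) for i in range(len(df.index)))
    for key in keys:
        yield func(df.iloc[key])


print(list(iter_axis(df, axis=0, func=lambda row: int(row.sum()))))  # [4, 6]
print(list(iter_axis(df, axis=1, func=lambda col: int(col.sum()))))  # [3, 7]
```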
+ + Returns + ------- + object + If attribute is found in pandas.plotting, and it is a callable, a wrapper function is + returned which converts its arguments to pandas and calls a function pandas.plotting.`item` + on these arguments. + If attribute is found in pandas.plotting but it is not a callable, returns it. + Otherwise function tries to look for an attribute in `self`. + """ + if hasattr(pdplot, item): + func = getattr(pdplot, item) + if callable(func): + + def wrap_func(*args, **kwargs): + """Convert Modin DataFrames to pandas then call the function.""" + args = tuple( + arg if not isinstance(arg, DataFrame) else to_pandas(arg) + for arg in args + ) + kwargs = { + kwd: val if not isinstance(val, DataFrame) else to_pandas(val) + for kwd, val in kwargs.items() + } + return func(*args, **kwargs) + + return wrap_func + else: + return func + else: + return object.__getattribute__(self, item) diff --git a/src/snowflake/snowpark/modin/pandas/resample.py b/src/snowflake/snowpark/modin/pandas/resample.py new file mode 100644 index 00000000000..e33d3dc565f --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/resample.py @@ -0,0 +1,510 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. 
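The Plotting wrapper defined just above forwards attribute lookups to `pandas.plotting` and, for callables, converts any wrapped frames to pandas before calling through. The sketch below captures that delegation idea using `__getattr__` (a simpler hook than the `__getattribute__` used above) and a made-up `MyFrame` type; it is an illustration of the approach, not the module's implementation.

```python
from pandas import plotting as pdplot


class MyFrame:
    """Placeholder for a non-pandas frame that knows how to become pandas."""

    def __init__(self, pandas_df):
        self._df = pandas_df

    def to_pandas(self):
        return self._df


class PlottingProxy:
    def __getattr__(self, item):  # only invoked when normal lookup fails
        func = getattr(pdplot, item)
        if not callable(func):
            return func

        def wrap(*args, **kwargs):
            # Convert wrapped frames to pandas before delegating to pandas.plotting.
            args = tuple(a.to_pandas() if isinstance(a, MyFrame) else a for a in args)
            kwargs = {k: v.to_pandas() if isinstance(v, MyFrame) else v for k, v in kwargs.items()}
            return func(*args, **kwargs)

        return wrap


# PlottingProxy().scatter_matrix(MyFrame(some_pandas_df)) would now accept
# MyFrame arguments and hand pandas.plotting a plain pandas DataFrame.
```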
+ +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the pandas project, under the BSD 3-Clause License + +"""Implement Resampler public API.""" +from typing import Any, Callable, Literal, Optional, Union + +import numpy as np +import pandas +import pandas.core.resample +from pandas._libs import lib +from pandas._libs.lib import no_default +from pandas._typing import AggFuncType, AnyArrayLike, Axis, T + +from snowflake.snowpark.modin import ( # noqa: F401 # add this line to enable doc tests to run + pandas as pd, +) +from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta +from snowflake.snowpark.modin.plugin._typing import InterpolateOptions +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage +from snowflake.snowpark.modin.utils import ( + _inherit_docstrings, + doc_replace_dataframe_with_link, +) + + +@_inherit_docstrings( + pandas.core.resample.Resampler, modify_doc=doc_replace_dataframe_with_link +) +class Resampler(metaclass=TelemetryMeta): + def __init__( + self, + dataframe, + rule, + axis=0, + closed=None, + label=None, + convention="start", + kind=None, + on=None, + level=None, + origin="start_day", + offset=None, + group_keys=no_default, + ) -> None: + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._dataframe = dataframe + self._query_compiler = dataframe._query_compiler + self.axis = self._dataframe._get_axis_number(axis) + self.resample_kwargs = { + "rule": rule, + "axis": axis, + "closed": closed, + "label": label, + "convention": convention, + "kind": kind, + "on": on, + "level": level, + "origin": origin, + "offset": offset, + "group_keys": group_keys, + } + self.__groups = self._get_groups() + + def _method_not_implemented(self, method: str): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + raise ErrorMessage.not_implemented( + f"Method {method} is not implemented for Resampler!" + ) + + def _validate_numeric_only_for_aggregate_methods(self, numeric_only): + """ + When the caller object is Series (ndim == 1), it is not valid to call aggregation + method with numeric_only = True. + + Raises: + NotImplementedError if the above condition is encountered. + """ + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + if self._dataframe.ndim == 1: + if numeric_only and numeric_only is not lib.no_default: + raise ErrorMessage.not_implemented( + "Series Resampler does not implement numeric_only." + ) + + def _get_groups(self): + """ + Compute the resampled groups. + + Returns + ------- + PandasGroupby + Groups as specified by resampling arguments. + """ + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + df = self._dataframe if self.axis == 0 else self._dataframe.T + groups = df.groupby( + pandas.Grouper( + key=self.resample_kwargs["on"], + freq=self.resample_kwargs["rule"], + closed=self.resample_kwargs["closed"], + label=self.resample_kwargs["label"], + convention=self.resample_kwargs["convention"], + level=self.resample_kwargs["level"], + origin=self.resample_kwargs["origin"], + offset=self.resample_kwargs["offset"], + ), + group_keys=self.resample_kwargs["group_keys"], + ) + return groups + + def __getitem__(self, key): # pragma: no cover + """ + Get ``Resampler`` based on `key` columns of original dataframe. 
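`_get_groups` above reduces resampling to a groupby over a `pandas.Grouper` built from the stored `resample_kwargs`. The equivalence it leans on can be seen directly in plain pandas; the data below is made up purely for illustration.

```python
import pandas as pd

idx = pd.date_range("2024-01-01", periods=6, freq="h")
df = pd.DataFrame({"v": range(6)}, index=idx)

via_resample = df.resample("2h").sum()
via_grouper = df.groupby(pd.Grouper(freq="2h")).sum()

# Both paths produce the same 2-hourly sums.
assert via_resample.equals(via_grouper)
print(via_resample)
```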
+ + Parameters + ---------- + key : str or list + String or list of selections. + + Returns + ------- + modin.pandas.BasePandasDataset + New ``Resampler`` based on `key` columns subset + of the original dataframe. + """ + + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + + def _get_new_resampler(key): + subset = self._dataframe[key] + resampler = type(self)(subset, **self.resample_kwargs) + return resampler + + from snowflake.snowpark.modin.pandas.series import Series + + if isinstance(key, (list, tuple, Series, pandas.Index, np.ndarray)): + if len(self._dataframe.columns.intersection(key)) != len(set(key)): + missed_keys = list(set(key).difference(self._dataframe.columns)) + raise KeyError(f"Columns not found: {str(sorted(missed_keys))[1:-1]}") + return _get_new_resampler(list(key)) + + if key not in self._dataframe: + raise KeyError(f"Column not found: {key}") + + return _get_new_resampler(key) + + @property + def groups(self): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("groups") + # This property is currently not supported, and NotImplementedError will be + # thrown before reach here. This is kept here because property function requires + # a return value. + return self._query_compiler.default_to_pandas( + lambda df: pandas.DataFrame.resample(df, **self.resample_kwargs).groups + ) + + @property + def indices(self): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("indices") + # Same as groups, keeps the return because indices requires return value + return self._query_compiler.default_to_pandas( + lambda df: pandas.DataFrame.resample(df, **self.resample_kwargs).indices + ) + + def get_group(self, name, obj=None): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("get_group") + + def apply( + self, func: Optional[AggFuncType] = None, *args: Any, **kwargs: Any + ): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("aggregate") + + def aggregate( + self, func: Optional[AggFuncType] = None, *args: Any, **kwargs: Any + ): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("aggregate") + + agg = aggregate + + def transform( + self, + arg: Union[Callable[..., T], tuple[Callable[..., T], str]], + *args: Any, + **kwargs: Any, + ): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("transform") + + def pipe( + self, + func: Union[Callable[..., T], tuple[Callable[..., T], str]], + *args: Any, + **kwargs: Any, + ): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("pipe") + + def ffill(self, limit: Optional[int] = None) -> Union[pd.DataFrame, pd.Series]: + is_series = not self._dataframe._is_dataframe + + if limit is not None: + ErrorMessage.not_implemented( + "Parameter limit of resample.ffill has not been implemented." 
+ ) + + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.resample( + self.resample_kwargs, + "ffill", + (), + {}, + is_series, + ) + ) + + def backfill(self, limit: Optional[int] = None): + self._method_not_implemented("backfill") # pragma: no cover + + def bfill(self, limit: Optional[int] = None): # pragma: no cover + self._method_not_implemented("bfill") + + def pad(self, limit: Optional[int] = None): # pragma: no cover + self._method_not_implemented("pad") + + def nearest(self, limit: Optional[int] = None): # pragma: no cover + self._method_not_implemented("nearest") + + def fillna(self, method, limit: Optional[int] = None): # pragma: no cover + self._method_not_implemented("fillna") + + def asfreq(self, fill_value: Optional[Any] = None): # pragma: no cover + self._method_not_implemented("asfreq") + + def interpolate( + self, + method: InterpolateOptions = "linear", + *, + axis: Axis = 0, + limit: Optional[int] = None, + inplace: bool = False, + limit_direction: Literal["forward", "backward", "both"] = "forward", + limit_area: Optional[Literal["inside", "outside"]] = None, + downcast: Optional[Literal["infer"]] = None, + **kwargs, + ): # pragma: no cover + self._method_not_implemented("interpolate") + + def count(self) -> Union[pd.DataFrame, pd.Series]: + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + is_series = not self._dataframe._is_dataframe + + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.resample( + self.resample_kwargs, + "count", + tuple(), + dict(), + is_series, + ) + ) + + def nunique(self, *args: Any, **kwargs: Any): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("nunique") + + def first( + self, + numeric_only: Union[bool, lib.NoDefault] = lib.no_default, + min_count: int = 0, + *args: Any, + **kwargs: Any, + ): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("first") + + def last( + self, + numeric_only: Union[bool, lib.NoDefault] = lib.no_default, + min_count: int = 0, + *args: Any, + **kwargs: Any, + ): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("last") + + def max( + self, + numeric_only: bool = False, + min_count: int = 0, + *args: Any, + **kwargs: Any, + ) -> Union[pd.DataFrame, pd.Series]: + self._validate_numeric_only_for_aggregate_methods(numeric_only) + WarningMessage.warning_if_engine_args_is_set("resample_max", args, kwargs) + + agg_kwargs = dict(numeric_only=numeric_only, min_count=min_count) + is_series = not self._dataframe._is_dataframe + + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.resample( + self.resample_kwargs, + "max", + tuple(), + agg_kwargs, + is_series, + ) + ) + + def mean( + self, + numeric_only: bool = False, + *args: Any, + **kwargs: Any, + ) -> Union[pd.DataFrame, pd.Series]: + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._validate_numeric_only_for_aggregate_methods(numeric_only) + WarningMessage.warning_if_engine_args_is_set("resample_mean", args, kwargs) + + agg_kwargs = dict(numeric_only=numeric_only) + is_series = not self._dataframe._is_dataframe + + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.resample( + self.resample_kwargs, + "mean", + tuple(), + agg_kwargs, + is_series, + ) + ) + + def median( + self, + numeric_only: bool = False, + 
*args: Any, + **kwargs: Any, + ) -> Union[pd.DataFrame, pd.Series]: + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._validate_numeric_only_for_aggregate_methods(numeric_only) + WarningMessage.warning_if_engine_args_is_set("resample_median", args, kwargs) + + agg_kwargs = dict(numeric_only=numeric_only) + is_series = not self._dataframe._is_dataframe + + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.resample( + self.resample_kwargs, + "median", + tuple(), + agg_kwargs, + is_series, + ) + ) + + def min( + self, + numeric_only: bool = False, + min_count: int = 0, + *args: Any, + **kwargs: Any, + ) -> Union[pd.DataFrame, pd.Series]: + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._validate_numeric_only_for_aggregate_methods(numeric_only) + WarningMessage.warning_if_engine_args_is_set("resample_min", args, kwargs) + + agg_kwargs = dict(numeric_only=numeric_only, min_count=min_count) + is_series = not self._dataframe._is_dataframe + + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.resample( + self.resample_kwargs, + "min", + tuple(), + agg_kwargs, + is_series, + ) + ) + + def ohlc(self, *args: Any, **kwargs: Any): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("ohlc") + + def prod( + self, + numeric_only: Union[bool, lib.NoDefault] = lib.no_default, + min_count: int = 0, + *args: Any, + **kwargs: Any, + ): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("prod") + + def size(self): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("size") + + def sem( + self, + ddof: int = 1, + numeric_only: Union[bool, lib.NoDefault] = lib.no_default, + *args: Any, + **kwargs: Any, + ): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("sem") + + def std( + self, + ddof: int = 1, + numeric_only: bool = False, + *args: Any, + **kwargs: Any, + ) -> Union[pd.DataFrame, pd.Series]: + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._validate_numeric_only_for_aggregate_methods(numeric_only) + WarningMessage.warning_if_engine_args_is_set("resample_std", args, kwargs) + + agg_kwargs = dict(numeric_only=numeric_only, ddof=ddof) + is_series = not self._dataframe._is_dataframe + + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.resample( + self.resample_kwargs, + "std", + tuple(), + agg_kwargs, + is_series, + ) + ) + + def sum( + self, + numeric_only: bool = False, + min_count: int = 0, + *args: Any, + **kwargs: Any, + ) -> Union[pd.DataFrame, pd.Series]: + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._validate_numeric_only_for_aggregate_methods(numeric_only) + WarningMessage.warning_if_engine_args_is_set("resample_sum", args, kwargs) + + agg_kwargs = dict(numeric_only=numeric_only, min_count=min_count) + is_series = not self._dataframe._is_dataframe + + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.resample( + self.resample_kwargs, + "sum", + tuple(), + agg_kwargs, + is_series, + ) + ) + + def var( + self, + ddof: int = 1, + numeric_only: Union[bool, lib.NoDefault] = lib.no_default, + *args: Any, + **kwargs: Any, + ) -> Union[pd.DataFrame, pd.Series]: + # TODO: SNOW-1063368: Modin upgrade - 
modin.pandas.resample.Resample + self._validate_numeric_only_for_aggregate_methods(numeric_only) + WarningMessage.warning_if_engine_args_is_set("resample_var", args, kwargs) + + agg_kwargs = dict(numeric_only=numeric_only, ddof=ddof) + is_series = not self._dataframe._is_dataframe + + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.resample( + self.resample_kwargs, + "var", + tuple(), + agg_kwargs, + is_series, + ) + ) + + def quantile( + self, q: Union[float, AnyArrayLike] = 0.5, **kwargs: Any + ): # pragma: no cover + # TODO: SNOW-1063368: Modin upgrade - modin.pandas.resample.Resample + self._method_not_implemented("quantile") diff --git a/src/snowflake/snowpark/modin/pandas/series.py b/src/snowflake/snowpark/modin/pandas/series.py new file mode 100644 index 00000000000..2e52624814c --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/series.py @@ -0,0 +1,2723 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. 
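Every aggregation above rebuilds its result with `self._dataframe.__constructor__(query_compiler=...)`, so a resampler created from a Series hands back a Series and one created from a DataFrame hands back a DataFrame. The toy classes below sketch that round-trip; `Base`, `Frame`, and `Column` are invented names for illustration only.

```python
class Base:
    def __init__(self, payload):
        self.payload = payload

    @property
    def __constructor__(self):
        # Ask the instance for its own class instead of hard-coding one.
        return type(self)

    def double(self):
        return self.__constructor__(self.payload * 2)


class Frame(Base):
    pass


class Column(Base):
    pass


print(type(Frame(3).double()).__name__)   # Frame
print(type(Column(4).double()).__name__)  # Column
```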
+ +"""Module houses `Series` class, that is distributed version of `pandas.Series`.""" + +from __future__ import annotations + +from collections.abc import Hashable, Mapping, Sequence +from logging import getLogger +from typing import IO, TYPE_CHECKING, Any, Callable, Literal + +import numpy as np +import numpy.typing as npt +import pandas +from pandas._libs.lib import NoDefault, is_integer, no_default +from pandas._typing import ( + AggFuncType, + AnyArrayLike, + Axis, + FillnaOptions, + IgnoreRaise, + IndexKeyFunc, + IndexLabel, + Level, + Renamer, + Scalar, +) +from pandas.api.types import is_datetime64_any_dtype, is_string_dtype +from pandas.core.common import apply_if_callable, is_bool_indexer +from pandas.core.dtypes.common import is_bool_dtype, is_dict_like, is_list_like +from pandas.core.series import _coerce_method +from pandas.errors import SpecificationError +from pandas.util._validators import validate_bool_kwarg + +from snowflake.snowpark.modin.pandas.accessor import CachedAccessor, SparseAccessor +from snowflake.snowpark.modin.pandas.base import _ATTRS_NO_LOOKUP, BasePandasDataset +from snowflake.snowpark.modin.pandas.iterator import PartitionIterator +from snowflake.snowpark.modin.pandas.utils import from_pandas, is_scalar +from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage +from snowflake.snowpark.modin.utils import ( + MODIN_UNNAMED_SERIES_LABEL, + _inherit_docstrings, +) + +if TYPE_CHECKING: + from snowflake.snowpark.modin.pandas.dataframe import DataFrame # pragma: no cover + +# add this line to enable doc tests to run +from snowflake.snowpark.modin import pandas as pd # noqa: F401 + +logger = getLogger(__name__) + +SERIES_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE_ERROR_MESSAGE = ( + "Currently do not support Series or list-like keys with range-like values" +) + +SERIES_SETITEM_SLICE_AS_SCALAR_VALUE_ERROR_MESSAGE = ( + "Currently do not support assigning a slice value as if it's a scalar value" +) + +SERIES_SETITEM_INCOMPATIBLE_INDEXER_WITH_SERIES_ERROR_MESSAGE = ( + "Snowpark pandas DataFrame cannot be used as an indexer with Series" +) + +SERIES_SETITEM_INCOMPATIBLE_INDEXER_WITH_SCALAR_ERROR_MESSAGE = ( + "Scalar key incompatible with {} value" +) + +# Dictionary of extensions assigned to this class +_SERIES_EXTENSIONS_ = {} + + +@_inherit_docstrings( + pandas.Series, + excluded=[ + pandas.Series.flags, + pandas.Series.info, + pandas.Series.prod, + pandas.Series.product, + pandas.Series.reindex, + pandas.Series.fillna, + ], + apilink="pandas.Series", +) +class Series(BasePandasDataset): + _pandas_class = pandas.Series + __array_priority__ = pandas.Series.__array_priority__ + + def __init__( + self, + data=None, + index=None, + dtype=None, + name=None, + copy=False, + fastpath=False, + query_compiler=None, + ) -> None: + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + # Siblings are other dataframes that share the same query compiler. We + # use this list to update inplace when there is a shallow copy. + self._siblings = [] + + # modified: + # Engine.subscribe(_update_engine) + + if isinstance(data, type(self)): + query_compiler = data._query_compiler.copy() + if index is not None: + if any(i not in data.index for i in index): + ErrorMessage.not_implemented( + "Passing non-existent columns or index values to constructor " + + "not yet implemented." 
+ ) # pragma: no cover + query_compiler = data.loc[index]._query_compiler + if query_compiler is None: + # Defaulting to pandas + if name is None: + name = MODIN_UNNAMED_SERIES_LABEL + if ( + isinstance(data, (pandas.Series, pandas.Index)) + and data.name is not None + ): + name = data.name + + query_compiler = from_pandas( + pandas.DataFrame( + pandas.Series( + data=data, + index=index, + dtype=dtype, + name=name, + copy=copy, + fastpath=fastpath, + ) + ) + )._query_compiler + self._query_compiler = query_compiler.columnarize() + if name is not None: + self.name = name + + def _get_name(self): + """ + Get the value of the `name` property. + + Returns + ------- + hashable + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + name = self._query_compiler.columns[0] + if name == MODIN_UNNAMED_SERIES_LABEL: + return None + return name + + def _set_name(self, name): + """ + Set the value of the `name` property. + + Parameters + ---------- + name : hashable + Name value to set. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if name is None: + name = MODIN_UNNAMED_SERIES_LABEL + if isinstance(name, tuple): + columns = pd.MultiIndex.from_tuples(tuples=[name]) + else: + columns = [name] + self._update_inplace( + new_query_compiler=self._query_compiler.set_columns(columns) + ) + + name = property(_get_name, _set_name) + _parent = None + # Parent axis denotes axis that was used to select series in a parent dataframe. + # If _parent_axis == 0, then it means that index axis was used via df.loc[row] + # indexing operations and assignments should be done to rows of parent. + # If _parent_axis == 1 it means that column axis was used via df[column] and assignments + # should be done to columns of parent. + _parent_axis = 0 + + def __add__(self, right): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.add(right) + + def __radd__(self, left): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.radd(left) + + def __and__(self, other): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().__and__(other) + + def __rand__(self, other): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().__rand__(other) + + # add `_inherit_docstrings` decorator to force method link addition. + @_inherit_docstrings(pandas.Series.__array__, apilink="pandas.Series.__array__") + def __array__(self, dtype=None): # noqa: PR01, RT01, D200 + """ + Return the values as a NumPy array. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().__array__(dtype).flatten() + + def __contains__(self, key): + """ + Check if `key` in the `Series.index`. + + Parameters + ---------- + key : hashable + Key to check the presence in the index. + + Returns + ------- + bool + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return key in self.index + + def __copy__(self, deep=True): + """ + Return the copy of the Series. + + Parameters + ---------- + deep : bool, default: True + Whether the copy should be deep or not. + + Returns + ------- + Series + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.copy(deep=deep) + + def __deepcopy__(self, memo=None): + """ + Return the deep copy of the Series. + + Parameters + ---------- + memo : Any, optional + Deprecated parameter. 
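`_get_name` and `_set_name` above store an unnamed Series under a reserved placeholder column label (MODIN_UNNAMED_SERIES_LABEL) and translate it back to `None` for the user. A tiny standalone sketch of that convention, with `UNNAMED` as a made-up stand-in for the real sentinel:

```python
UNNAMED = "__unnamed_series__"  # made-up stand-in for MODIN_UNNAMED_SERIES_LABEL


class ColumnBacked:
    """Placeholder for an object whose data lives in a single named column."""

    def __init__(self, name=None):
        self._column_label = UNNAMED if name is None else name

    @property
    def name(self):
        # Report the sentinel back to the user as "no name".
        return None if self._column_label == UNNAMED else self._column_label

    @name.setter
    def name(self, value):
        self._column_label = UNNAMED if value is None else value


s = ColumnBacked()
print(s.name)      # None
s.name = "price"
print(s.name)      # price
```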
+ + Returns + ------- + Series + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.copy(deep=True) + + def __delitem__(self, key): + """ + Delete item identified by `key` label. + + Parameters + ---------- + key : hashable + Key to delete. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if key not in self.keys(): + raise KeyError(key) + self.drop(labels=key, inplace=True) + + def __divmod__(self, right): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.divmod(right) + + def __rdivmod__(self, left): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.rdivmod(left) + + def __floordiv__(self, right): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.floordiv(right) + + def __rfloordiv__(self, right): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.rfloordiv(right) + + def __getattr__(self, key): + """ + Return item identified by `key`. + + Parameters + ---------- + key : hashable + Key to get. + + Returns + ------- + Any + + Notes + ----- + First try to use `__getattribute__` method. If it fails + try to get `key` from `Series` fields. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + try: + return object.__getattribute__(self, key) + except AttributeError as err: + if key not in _ATTRS_NO_LOOKUP: + try: + value = self[key] + if isinstance(value, Series) and value.empty: + raise err + return value + except Exception: + # We want to raise err if self[key] raises any kind of exception + raise err + raise err + + __float__ = _coerce_method(float) + __int__ = _coerce_method(int) + + def abs(self): + """ + Return a Series with absolute numeric value of each element. + + Returns + ------- + Series + + Examples + -------- + >>> ser = pd.Series([1, -2.29, 3, -4.77]) + >>> ser + 0 1.00 + 1 -2.29 + 2 3.00 + 3 -4.77 + dtype: float64 + + >>> abs(ser) + 0 1.00 + 1 2.29 + 2 3.00 + 3 4.77 + dtype: float64 + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().abs() + + def __neg__(self): + """ + Returns a Series with the sign changed for each element. + + Returns + ------- + Series + + Examples + -------- + >>> ser = pd.Series([1, -2.29, 3, -4.77]) + >>> ser + 0 1.00 + 1 -2.29 + 2 3.00 + 3 -4.77 + dtype: float64 + + >>> - ser + 0 -1.00 + 1 2.29 + 2 -3.00 + 3 4.77 + dtype: float64 + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().__neg__() + + def __iter__(self): + """ + Return an iterator of the values. 
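`__getattr__` above retries a failed attribute lookup as an item lookup (so `s.price` can resolve like `s["price"]`), re-raising the original AttributeError if that also fails. Below is a minimal sketch of the same fallback on a plain mapping-backed class; `Record` is a placeholder name, not the Series class itself.

```python
class Record:
    """Placeholder mapping-backed object; not the Series class above."""

    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        return self._data[key]

    def __getattr__(self, key):  # only called after normal attribute lookup fails
        try:
            return self._data[key]
        except KeyError:
            raise AttributeError(key) from None


r = Record({"price": 9.5})
print(r.price)      # 9.5  (resolved through the item-lookup fallback)
print(r["price"])   # 9.5
```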
+ + Returns + ------- + iterable + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self._to_pandas().__iter__() + + def __mod__(self, right): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.mod(right) + + def __rmod__(self, left): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.rmod(left) + + def __mul__(self, right): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.mul(right) + + def __rmul__(self, left): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.rmul(left) + + def __or__(self, other): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().__or__(other) + + def __ror__(self, other): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().__ror__(other) + + def __xor__(self, other): # pragma: no cover + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().__xor__(other) + + def __rxor__(self, other): # pragma: no cover + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().__rxor__(other) + + def __pow__(self, right): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.pow(right) + + def __rpow__(self, left): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.rpow(left) + + def __repr__(self): + """ + Return a string representation for a particular Series. + + Returns + ------- + str + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + num_rows = pandas.get_option("display.max_rows") or 60 + num_cols = pandas.get_option("display.max_columns") or 20 + + ( + row_count, + col_count, + temp_df, + ) = self._query_compiler.build_repr_df(num_rows, num_cols) + if isinstance(temp_df, pandas.DataFrame) and not temp_df.empty: + temp_df = temp_df.iloc[:, 0] + temp_str = repr(temp_df) + freq_str = ( + f"Freq: {temp_df.index.freqstr}, " + if isinstance(temp_df.index, pandas.DatetimeIndex) + else "" + ) + if self.name is not None: + name_str = f"Name: {str(self.name)}, " + else: + name_str = "" + if row_count > num_rows: + len_str = f"Length: {row_count}, " + else: + len_str = "" + dtype_str = "dtype: {}".format( + str(self.dtype) + ")" + if temp_df.empty + else temp_str.rsplit("dtype: ", 1)[-1] + ) + if row_count == 0: + return f"Series([], {freq_str}{name_str}{dtype_str}" + maxsplit = 1 + if ( + isinstance(temp_df, pandas.Series) + and temp_df.name is not None + and temp_df.dtype == "category" + ): + maxsplit = 2 + return temp_str.rsplit("\n", maxsplit)[0] + "\n{}{}{}{}".format( + freq_str, name_str, len_str, dtype_str + ) + + def __round__(self, decimals=0): + """ + Round each value in a Series to the given number of decimals. + + Parameters + ---------- + decimals : int, default: 0 + Number of decimal places to round to. + + Returns + ------- + Series + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().round(decimals) + + def __setitem__(self, key, value): + """ + Set `value` identified by `key` in the Series. + + Parameters + ---------- + key : hashable + Key to set. + value : Any + Value to set. + + Examples + -------- + Using the following series to set values on. __setitem__ is an inplace operation, so copies of `series`are made + in the examples to highlight the different behaviors produced. 
+ >>> series = pd.Series([1, "b", 3], index=["a", "b", "c"]) + + Using a scalar as the value to set a particular element. + >>> s = series.copy() + >>> s["c"] = "a" + >>> s + a 1 + b b + c a + dtype: object + + Using list-like objects as the key and value to set multiple elements. + >>> s = series.copy() + >>> s[["c", "a"]] = ["foo", "bar"] + >>> s # doctest: +SKIP + a bar + b 2 + c foo + dtype: object + + Having a duplicate label in the key. + >>> s = series.copy() + >>> s[["c", "a", "c"]] = pd.Index(["foo", "bar", "baz"]) + >>> s # doctest: +SKIP + a bar + b 2 + c baz + dtype: object + + When using a Series as the value, its index does not matter. + >>> s = series.copy() # doctest: +SKIP + >>> s[["a", "b"]] = pd.Series([9, 8], index=["foo", "bar"]) + >>> s # doctest: +SKIP + a 9 + b 8 + c 3 + dtype: int64 + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + key = apply_if_callable(key, self) + + # Error Checking: + # Currently do not support Series[scalar key] = Series item/DataFrame item since this results in a nested series + # or df. + if is_scalar(key) and isinstance(value, BasePandasDataset): + raise ValueError( + SERIES_SETITEM_INCOMPATIBLE_INDEXER_WITH_SCALAR_ERROR_MESSAGE.format( + "Snowpark pandas " + value.__class__.__name__ + if isinstance(value, BasePandasDataset) + else value.__class__.__name__ + ) + ) + if isinstance(key, pd.DataFrame): + raise ValueError( + SERIES_SETITEM_INCOMPATIBLE_INDEXER_WITH_SERIES_ERROR_MESSAGE + ) + elif (isinstance(key, pd.Series) or is_list_like(key)) and ( + isinstance(value, range) + ): + raise NotImplementedError( + SERIES_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE_ERROR_MESSAGE + ) + elif isinstance(value, slice): + # Here, the whole slice is assigned as a scalar variable, i.e., a spot at an index gets a slice value. + raise NotImplementedError( + SERIES_SETITEM_SLICE_AS_SCALAR_VALUE_ERROR_MESSAGE + ) + + if isinstance(key, (slice, range)): + if (key.start is None or is_integer(key.start)) and ( # pragma: no cover + key.stop is None or is_integer(key.stop) + ): + # integer slice behaves the same as iloc slice + self.iloc[key] = value # pragma: no cover + else: + # TODO: SNOW-976232 once the slice test is added to test_setitem, code here should be covered. + self.loc[key] = value # pragma: no cover + + elif isinstance(value, Series): + # If value is a Series, value's index doesn't matter/is ignored. However, loc setitem matches the key's + # index with value's index. To emulate this behavior, treat the Series as if it is matching by position. + # + # For example, + # With __setitem__, the index of value does not matter. + # >>> series = pd.Series([1, 2, 3], index=["a", "b", "c"]) + # >>> series[["a", "b"]] = pd.Series([9, 8]) + # a 9 + # b 8 + # c 3 + # dtype: int64 + # value = pd.Series([9, 8], index=["foo", "bar"]) also produces same result as above. + # + # However, with loc setitem, index matters. + # >>> series.loc[["a", "b"]] = pd.Series([9, 8]) + # a NaN + # b NaN + # c 3.0 + # dtype: float64 + # + # >>> series.loc[["a", "b"]] = pd.Series([9, 8], index=["a", "b"]) + # a 9 + # b 8 + # c 3 + # dtype: int64 + # Due to the behavior above, loc setitem can work with any kind of value regardless of length. + # With __setitem__, the length of the value must match length of the key. Currently, loc setitem can + # handle this with boolean keys. + + # Convert list-like keys to Series. 
+            if not isinstance(key, pd.Series) and is_list_like(key):
+                key = pd.Series(key)
+
+            index_is_bool_indexer = False
+
+            if isinstance(key, pd.Series) and is_bool_dtype(key.dtype):
+                index_is_bool_indexer = True  # pragma: no cover
+            elif is_bool_indexer(key):
+                index_is_bool_indexer = True  # pragma: no cover
+
+            new_qc = self._query_compiler.set_2d_labels(
+                key._query_compiler if isinstance(key, BasePandasDataset) else key,
+                slice(None),  # column key is not applicable to Series objects
+                value._query_compiler,
+                matching_item_columns_by_label=False,
+                matching_item_rows_by_label=False,
+                index_is_bool_indexer=index_is_bool_indexer,
+            )
+            self._update_inplace(new_query_compiler=new_qc)
+
+        else:
+            self.loc[key] = value
+
+    def __sub__(self, right):
+        return self.sub(right)
+
+    def __rsub__(self, left):
+        return self.rsub(left)
+
+    def __truediv__(self, right):
+        return self.truediv(right)
+
+    def __rtruediv__(self, left):
+        return self.rtruediv(left)
+
+    __iadd__ = __add__
+    __imul__ = __mul__
+    __ipow__ = __pow__
+    __isub__ = __sub__
+    __itruediv__ = __truediv__
+
+    def add(self, other, level=None, fill_value=None, axis=0):  # noqa: PR01, RT01, D200
+        """
+        Return Addition of series and other, element-wise (binary operator add).
+        """
+        # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
+        return super().add(other, level=level, fill_value=fill_value, axis=axis)
+
+    def radd(
+        self, other, level=None, fill_value=None, axis=0
+    ):  # noqa: PR01, RT01, D200
+        """
+        Return Addition of series and other, element-wise (binary operator radd).
+        """
+        # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
+        return super().radd(other, level=level, fill_value=fill_value, axis=axis)
+
+    def add_prefix(self, prefix):  # noqa: PR01, RT01, D200
+        """
+        Prefix labels with string `prefix`.
+        """
+        # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
+        # pandas converts non-string prefix values into str and adds it to the index labels.
+        return self.__constructor__(
+            query_compiler=self._query_compiler.add_substring(
+                str(prefix), substring_type="prefix", axis=0
+            )
+        )
+
+    def add_suffix(self, suffix):
+        """
+        Suffix labels with string `suffix`.
+        """
+        # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
+        # pandas converts non-string suffix values into str and appends it to the index labels.
+        return self.__constructor__(
+            query_compiler=self._query_compiler.add_substring(
+                str(suffix), substring_type="suffix", axis=0
+            )
+        )
+
+    def drop(
+        self,
+        labels: IndexLabel = None,
+        axis: Axis = 0,
+        index: IndexLabel = None,
+        columns: IndexLabel = None,
+        level: Level | None = None,
+        inplace: bool = False,
+        errors: IgnoreRaise = "raise",
+    ) -> Series | None:
+        """
+        Drop specified labels from `BasePandasDataset`.
+        """
+        # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
+        return super().drop(
+            labels=labels,
+            axis=axis,
+            index=index,
+            columns=columns,
+            level=level,
+            inplace=inplace,
+            errors=errors,
+        )
+
+    def aggregate(
+        self, func: AggFuncType = None, axis: Axis = 0, *args: Any, **kwargs: Any
+    ):
+        # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
+        if is_dict_like(func):
+            raise SpecificationError(
+                "Value for func argument in dict format is not allowed for Series aggregate."
+            )
+
+        return super().aggregate(func, axis, *args, **kwargs)
+
+    agg = aggregate
+
+    def apply(
+        self,
+        func: AggFuncType,
+        convert_dtype: bool = True,
+        args: tuple[Any, ...]
= (), + **kwargs: Any, + ): + """ + Apply a function along an axis of the `BasePandasDataset`. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + self._validate_function(func) + new_query_compiler = self._query_compiler.apply_on_series(func, args, **kwargs) + + if convert_dtype: + # TODO SNOW-810614: call convert_dtypes for consistency + WarningMessage.ignored_argument( + operation="apply", + argument="convert_dtype", + message="convert_dtype is ignored in Snowflake backend", + ) + + return self.__constructor__(query_compiler=new_query_compiler) + + def argmax(self, axis=None, skipna=True, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Return int position of the largest value in the Series. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + result = self.idxmax(axis=axis, skipna=skipna, *args, **kwargs) + if np.isnan(result) or result is pandas.NA: + result = -1 + return result + + def argmin(self, axis=None, skipna=True, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Return int position of the smallest value in the Series. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + result = self.idxmin(axis=axis, skipna=skipna, *args, **kwargs) + if np.isnan(result) or result is pandas.NA: + result = -1 + return result + + def argsort(self, axis=0, kind="quicksort", order=None): # noqa: PR01, RT01, D200 + """ + Return the integer indices that would sort the Series values. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.Series.argsort, axis=axis, kind=kind, order=order + ) + + def autocorr(self, lag=1): # noqa: PR01, RT01, D200 + """ + Compute the lag-N autocorrelation. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self.corr(self.shift(lag)) + + def between(self, left, right, inclusive: str = "both"): # noqa: PR01, RT01, D200 + """ + Return boolean Series equivalent to left <= series <= right. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.Series.between, left, right, inclusive=inclusive + ) + + def compare( + self, + other: Series, + align_axis: str | int = 1, + keep_shape: bool = False, + keep_equal: bool = False, + result_names: tuple = ("self", "other"), + ) -> Series: # noqa: PR01, RT01, D200 + """ + Compare to another Series and show the differences. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + if not isinstance(other, Series): + raise TypeError(f"Cannot compare Series to {type(other)}") + result = self.to_frame().compare( + other.to_frame(), + align_axis=align_axis, + keep_shape=keep_shape, + keep_equal=keep_equal, + result_names=result_names, + ) + if align_axis == "columns" or align_axis == 1: + # pandas.DataFrame.Compare returns a dataframe with a multidimensional index object as the + # columns so we have to change column object back. + result.columns = pandas.Index(["self", "other"]) + else: + result = result.squeeze().rename(None) + return result + + def corr(self, other, method="pearson", min_periods=None): # noqa: PR01, RT01, D200 + """ + Compute correlation with `other` Series, excluding missing values. 
+ """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + if method == "pearson": + this, other = self.align(other, join="inner", copy=False) + this = self.__constructor__(this) + other = self.__constructor__(other) + + if len(this) == 0: + return np.nan + if len(this) != len(other): + raise ValueError("Operands must have same size") + + if min_periods is None: + min_periods = 1 + + valid = this.notna() & other.notna() + if not valid.all(): + this = this[valid] + other = other[valid] + if len(this) < min_periods: + return np.nan + + this = this.astype(dtype="float64") + other = other.astype(dtype="float64") + this -= this.mean() + other -= other.mean() + + other = other.__constructor__(query_compiler=other._query_compiler.conj()) + result = this * other / (len(this) - 1) + result = np.array([result.sum()]) + + stddev_this = ((this * this) / (len(this) - 1)).sum() + stddev_other = ((other * other) / (len(other) - 1)).sum() + + stddev_this = np.array([np.sqrt(stddev_this)]) + stddev_other = np.array([np.sqrt(stddev_other)]) + + result /= stddev_this * stddev_other + + np.clip(result.real, -1, 1, out=result.real) + if np.iscomplexobj(result): + np.clip(result.imag, -1, 1, out=result.imag) + return result[0] + + return self.__constructor__( + query_compiler=self._query_compiler.default_to_pandas( + pandas.Series.corr, + other._query_compiler, + method=method, + min_periods=min_periods, + ) + ) + + def count(self): + return super().count() + + def cov( + self, other, min_periods=None, ddof: int | None = 1 + ): # noqa: PR01, RT01, D200 + """ + Compute covariance with Series, excluding missing values. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + this, other = self.align(other, join="inner", copy=False) + this = self.__constructor__(this) + other = self.__constructor__(other) + if len(this) == 0: + return np.nan + + if len(this) != len(other): + raise ValueError("Operands must have same size") + + if min_periods is None: + min_periods = 1 + + valid = this.notna() & other.notna() + if not valid.all(): + this = this[valid] + other = other[valid] + + if len(this) < min_periods: + return np.nan + + this = this.astype(dtype="float64") + other = other.astype(dtype="float64") + + this -= this.mean() + other -= other.mean() + + other = other.__constructor__(query_compiler=other._query_compiler.conj()) + result = this * other / (len(this) - ddof) + result = result.sum() + return result + + def describe( + self, + percentiles: ListLike | None = None, + include: ListLike | Literal["all"] | None = None, + exclude: ListLike | None = None, + ) -> Series: + """ + Generate descriptive statistics. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().describe( + percentiles=percentiles, + include=None, + exclude=None, + ) + + def diff(self, periods: int = 1): + """ + First discrete difference of element. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().diff(periods=periods, axis=0) + + def divmod( + self, other, level=None, fill_value=None, axis=0 + ): # noqa: PR01, RT01, D200 + """ + Return Integer division and modulo of series and `other`, element-wise (binary operator `divmod`). 
+ Not implemented + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + + def dot(self, other): # noqa: PR01, RT01, D200 + """ + Compute the dot product between the Series and the columns of `other`. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() # pragma: no cover + + if isinstance(other, BasePandasDataset): + common = self.index.union(other.index) + if len(common) > len(self) or len(common) > len(other): # pragma: no cover + raise ValueError("Matrices are not aligned") + + if isinstance(other, Series): + return self._reduce_dimension( + query_compiler=self._query_compiler.dot( + other.reindex(index=common), squeeze_self=True + ) + ) + else: + return self.__constructor__( + query_compiler=self._query_compiler.dot( + other.reindex(index=common), squeeze_self=True + ) + ) + + other = np.asarray(other) + if self.shape[0] != other.shape[0]: + raise ValueError( + f"Dot product shape mismatch, {self.shape} vs {other.shape}" + ) + + if len(other.shape) > 1: + return ( + self._query_compiler.dot(other, squeeze_self=True).to_numpy().squeeze() + ) + + return self._reduce_dimension( + query_compiler=self._query_compiler.dot(other, squeeze_self=True) + ) + + def drop_duplicates(self, keep="first", inplace=False): # noqa: PR01, RT01, D200 + """ + Return Series with duplicate values removed. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().drop_duplicates(keep=keep, inplace=inplace) + + def dropna( + self, + *, + axis: Axis = 0, + inplace: bool = False, + how: str | NoDefault = no_default, + ): + """ + Return a new Series with missing values removed. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super()._dropna(axis=axis, inplace=inplace, how=how) + + def duplicated(self, keep: DropKeep = "first"): + """ + Indicate duplicate Series values. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + name = self.name + series = self.to_frame().duplicated(keep=keep) + # we are using df.duplicated method for series but its result will lose the series name, so we preserve it here + series.name = name + return series + + def eq(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200 + """ + Return Equal to of series and `other`, element-wise (binary operator `eq`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().eq(other, level=level, axis=axis) + + def equals(self, other): # noqa: PR01, RT01, D200 + """ + Test whether two objects contain the same elements. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() # pragma: no cover + + return ( + self.name == other.name + and self.index.equals(other.index) + and self.eq(other).all() + ) + + def explode(self, ignore_index: bool = False): # noqa: PR01, RT01, D200 + """ + Transform each element of a list-like to a row. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + + return super().explode( + MODIN_UNNAMED_SERIES_LABEL if self.name is None else self.name, + ignore_index=ignore_index, + ) + + def factorize( + self, sort=False, na_sentinel=no_default, use_na_sentinel=no_default + ): # noqa: PR01, RT01, D200 + """ + Encode the object as an enumerated type or categorical variable. 
+ """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.Series.factorize, + sort=sort, + na_sentinel=na_sentinel, + use_na_sentinel=use_na_sentinel, + ) + + def fillna( + self, + value: Hashable | Mapping | Series = None, + *, + method: FillnaOptions | None = None, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + downcast: dict | None = None, + ) -> Series | None: + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if isinstance(value, BasePandasDataset) and not isinstance(value, Series): + raise TypeError( + '"value" parameter must be a scalar, dict or Series, but ' + + f'you passed a "{type(value).__name__}"' + ) + return super().fillna( + self_is_series=True, + value=value, + method=method, + axis=axis, + inplace=inplace, + limit=limit, + downcast=downcast, + ) + + def floordiv( + self, other, level=None, fill_value=None, axis=0 + ): # noqa: PR01, RT01, D200 + """ + Get Integer division of dataframe and `other`, element-wise (binary operator `floordiv`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().floordiv(other, level=level, fill_value=fill_value, axis=axis) + + def ge(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200 + """ + Return greater than or equal to of series and `other`, element-wise (binary operator `ge`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().ge(other, level=level, axis=axis) + + def groupby( + self, + by=None, + axis: Axis = 0, + level: IndexLabel | None = None, + as_index: bool = True, + sort: bool = True, + group_keys: bool = True, + observed: bool | NoDefault = no_default, + dropna: bool = True, + ): + """ + Group Series using a mapper or by a Series of columns. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + from snowflake.snowpark.modin.pandas.groupby import ( + SeriesGroupBy, + validate_groupby_args, + ) + + validate_groupby_args(by, level, observed) + + if not as_index: + raise TypeError("as_index=False only valid with DataFrame") + + axis = self._get_axis_number(axis) + return SeriesGroupBy( + self, + by, + axis, + level, + as_index, + sort, + group_keys, + idx_name=None, + observed=observed, + dropna=dropna, + ) + + def gt(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200 + """ + Return greater than of series and `other`, element-wise (binary operator `gt`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().gt(other, level=level, axis=axis) + + def hist( + self, + by=None, + ax=None, + grid=True, + xlabelsize=None, + xrot=None, + ylabelsize=None, + yrot=None, + figsize=None, + bins=10, + **kwds, + ): # noqa: PR01, RT01, D200 + """ + Draw histogram of the input series using matplotlib. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.Series.hist, + by=by, + ax=ax, + grid=grid, + xlabelsize=xlabelsize, + xrot=xrot, + ylabelsize=ylabelsize, + yrot=yrot, + figsize=figsize, + bins=bins, + **kwds, + ) + + def idxmax(self, axis=0, skipna=True, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Return the row label of the maximum value. + + Parameters + ---------- + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. 
+ skipna : bool, default True + Exclude NA/null values. If an entire Series is NA, the result will be NA. + *args, **kwargs + Additional arguments and keywords have no effect but might be accepted for compatibility with NumPy. + + Returns + ------- + Index, the label of the maximum value. + + Examples + -------- + >>> s = pd.Series(data=[1, None, 4, 3, 4], + ... index=['A', 'B', 'C', 'D', 'E']) + >>> s.idxmax() + 'C' + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if skipna is None: + skipna = True + return super().idxmax(axis=axis, skipna=skipna, *args, **kwargs) + + def idxmin(self, axis=0, skipna=True, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Return the row label of the minimum value. + + Parameters + ---------- + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + skipna : bool, default True + Exclude NA/null values. If an entire Series is NA, the result will be NA. + *args, **kwargs + Additional arguments and keywords have no effect but might be accepted for compatibility with NumPy. + + Returns + ------- + Index, the label of the minimum value. + + Examples + -------- + >>> s = pd.Series(data=[1, None, 4, 3, 4], + ... index=['A', 'B', 'C', 'D', 'E']) + >>> s.idxmin() + 'A' + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if skipna is None: + skipna = True + return super().idxmin(axis=axis, skipna=skipna, *args, **kwargs) + + def info( + self, + verbose: bool | None = None, + buf: IO[str] | None = None, + max_cols: int | None = None, + memory_usage: bool | str | None = None, + show_counts: bool = True, + ): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self._default_to_pandas( + pandas.Series.info, + verbose=verbose, + buf=buf, + max_cols=max_cols, + memory_usage=memory_usage, + show_counts=show_counts, + ) + + def interpolate( + self, + method="linear", + axis=0, + limit=None, + inplace=False, + limit_direction: str | None = None, + limit_area=None, + downcast=None, + **kwargs, + ): # noqa: PR01, RT01, D200 + """ + Fill NaN values using an interpolation method. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.Series.interpolate, + method=method, + axis=axis, + limit=limit, + inplace=inplace, + limit_direction=limit_direction, + limit_area=limit_area, + downcast=downcast, + **kwargs, + ) + + def item(self): # noqa: RT01, D200 + """ + Return the first element of the underlying data as a Python scalar. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self[0] + + def items(self): # noqa: D200 + """ + Lazily iterate over (index, value) tuples. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + + def item_builder(s): + return s.name, s.squeeze() + + partition_iterator = PartitionIterator(self.to_frame(), 0, item_builder) + yield from partition_iterator + + def keys(self): # noqa: RT01, D200 + """ + Return alias for index. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.index + + def kurt( + self, + axis: Axis | None | NoDefault = no_default, + skipna=True, + numeric_only=False, + **kwargs, + ): # noqa: PR01, RT01, D200 + """ + Return unbiased kurtosis over requested axis. 
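+
+        Examples
+        --------
+        The example below is an illustrative sketch assuming standard pandas
+        semantics; it is not run as a doctest against the Snowflake backend.
+
+        >>> s = pd.Series([1, 2, 3, 4, 5])
+        >>> s.kurt()  # doctest: +SKIP
+        -1.2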
+ """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + axis = self._get_axis_number(axis) + return super().kurt(axis, skipna, numeric_only, **kwargs) + + kurtosis = kurt + + def le(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200 + """ + Return less than or equal to of series and `other`, element-wise (binary operator `le`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().le(other, level=level, axis=axis) + + def lt(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200 + """ + Return less than of series and `other`, element-wise (binary operator `lt`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().lt(other, level=level, axis=axis) + + def map( + self, + arg: Callable | Mapping | Series, + na_action: Literal["ignore"] | None = None, + ) -> Series: + """ + Map values of Series according to input correspondence. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.__constructor__( + query_compiler=self._query_compiler.map(arg, na_action) + ) + + def mask( + self, + cond: DataFrame | Series | Callable | AnyArrayLike, + other: DataFrame | Series | Callable | Scalar | None = np.nan, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, + ): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().mask( + cond, + other=other, + inplace=inplace, + axis=axis, + level=level, + ) + + def memory_usage(self, index=True, deep=False): # noqa: PR01, RT01, D200 + """ + Return the memory usage of the Series. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() # pragma: no cover + + if index: + result = self._reduce_dimension( + self._query_compiler.memory_usage(index=False, deep=deep) + ) + index_value = self.index.memory_usage(deep=deep) + return result + index_value + return super().memory_usage(index=index, deep=deep) + + def mod(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200 + """ + Return Modulo of series and `other`, element-wise (binary operator `mod`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().mod(other, level=level, fill_value=fill_value, axis=axis) + + def mode(self, dropna=True): # noqa: PR01, RT01, D200 + """ + Return the mode(s) of the Series. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return super().mode(numeric_only=False, dropna=dropna) + + def mul(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200 + """ + Return multiplication of series and `other`, element-wise (binary operator `mul`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().mul(other, level=level, fill_value=fill_value, axis=axis) + + multiply = mul + + def rmul( + self, other, level=None, fill_value=None, axis=0 + ): # noqa: PR01, RT01, D200 + """ + Return multiplication of series and `other`, element-wise (binary operator `mul`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().rmul(other, level=level, fill_value=fill_value, axis=axis) + + def ne(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200 + """ + Return not equal to of series and `other`, element-wise (binary operator `ne`). 
+ """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().ne(other, level=level, axis=axis) + + def nlargest(self, n=5, keep="first"): # noqa: PR01, RT01, D200 + """ + Return the largest `n` elements. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._default_to_pandas(pandas.Series.nlargest, n=n, keep=keep) + + def nsmallest(self, n=5, keep="first"): # noqa: PR01, RT01, D200 + """ + Return the smallest `n` elements. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self.__constructor__( + query_compiler=self._query_compiler.nsmallest(n=n, keep=keep) + ) + + def set_axis( + self, + labels: IndexLabel, + *, + axis: Axis = 0, + copy: bool | NoDefault = no_default, # ignored + ): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if not is_scalar(axis): + raise TypeError(f"{type(axis).__name__} is not a valid type for axis.") + return super().set_axis( + labels=labels, + # 'rows', 'index, and 0 are valid axis values for Series. + # 'columns' and 1 are valid axis values only for DataFrame. + axis=pandas.Series._get_axis_name(axis), + copy=copy, + ) + + def unstack(self, level=-1, fill_value=None): # noqa: PR01, RT01, D200 + """ + Unstack, also known as pivot, Series with MultiIndex to produce DataFrame. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + from snowflake.snowpark.modin.pandas.dataframe import DataFrame + + result = DataFrame( + query_compiler=self._query_compiler.unstack(level, fill_value) + ) + + return result.droplevel(0, axis=1) if result.columns.nlevels > 1 else result + + @property + def plot( + self, + kind="line", + ax=None, + figsize=None, + use_index=True, + title=None, + grid=None, + legend=False, + style=None, + logx=False, + logy=False, + loglog=False, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + rot=None, + fontsize=None, + colormap=None, + table=False, + yerr=None, + xerr=None, + label=None, + secondary_y=False, + **kwds, + ): # noqa: PR01, RT01, D200 + """ + Make plot of Series. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._to_pandas().plot + + def pow(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200 + """ + Return exponential power of series and `other`, element-wise (binary operator `pow`). 
+ """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().pow(other, level=level, fill_value=fill_value, axis=axis) + + def prod( + self, + axis=None, + skipna=True, + level=None, + numeric_only=False, + min_count=0, + **kwargs, + ): + ErrorMessage.not_implemented() + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + axis = self._get_axis_number(axis) + if level is not None: + if ( + not self._query_compiler.has_multiindex(axis=axis) + and level > 0 + or level < -1 + and level != self.index.name + ): + raise ValueError("level > 0 or level < -1 only valid with MultiIndex") + return self.groupby(level=level, axis=axis, sort=False).prod( + numeric_only=numeric_only, min_count=min_count, **kwargs + ) + new_index = self.columns if axis else self.index + if min_count > len(new_index): + return np.nan + + data = self._validate_dtypes_sum_prod_mean(axis, numeric_only, ignore_axis=True) + if min_count > 1: + return data._reduce_dimension( + data._query_compiler.prod_min_count( + axis=axis, + skipna=skipna, + level=level, + numeric_only=numeric_only, + min_count=min_count, + **kwargs, + ) + ) + return data._reduce_dimension( + data._query_compiler.prod( + axis=axis, + skipna=skipna, + level=level, + numeric_only=numeric_only, + min_count=min_count, + **kwargs, + ) + ) + + product = prod + + def ravel(self, order="C"): # noqa: PR01, RT01, D200 + """ + Return the flattened underlying data as an ndarray. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + data = self._query_compiler.to_numpy().flatten(order=order) + if isinstance(self.dtype, pandas.CategoricalDtype): + data = pandas.Categorical(data, dtype=self.dtype) + + return data + + def reindex(self, *args, **kwargs): + if args: + if len(args) > 1: + raise TypeError("Only one positional argument ('index') is allowed") + if "index" in kwargs: + raise TypeError( + "'index' passed as both positional and keyword argument" + ) + kwargs.update({"index": args[0]}) + index = kwargs.pop("index", None) + method = kwargs.pop("method", None) + level = kwargs.pop("level", None) + copy = kwargs.pop("copy", True) + limit = kwargs.pop("limit", None) + tolerance = kwargs.pop("tolerance", None) + fill_value = kwargs.pop("fill_value", None) + if kwargs: + raise TypeError( + "reindex() got an unexpected keyword " + + f'argument "{list(kwargs.keys())[0]}"' + ) + return super().reindex( + index=index, + columns=None, + method=method, + level=level, + copy=copy, + limit=limit, + tolerance=tolerance, + fill_value=fill_value, + ) + + def rename( + self, + index: Renamer | Hashable | None = None, + *, + axis: Axis | None = None, + copy: bool | None = None, + inplace: bool = False, + level: Level | None = None, + errors: IgnoreRaise = "ignore", + ) -> Series | None: + """ + Alter Series index labels or name. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if axis is not None: + # make sure we raise if an invalid 'axis' is passed. + # note: axis is unused. It's needed for compatibility with DataFrame. 
+ self._get_axis_number(axis) + + if copy is not None: + WarningMessage.ignored_argument( + operation="series.rename", + argument="copy", + message="copy parameter has been ignored with Snowflake execution engine", + ) + + if callable(index) or is_dict_like(index): + if isinstance(index, dict): + index = Series(index) + new_qc = self._query_compiler.rename( + index_renamer=index, level=level, errors=errors + ) + new_series = self._create_or_update_from_compiler( + new_query_compiler=new_qc, inplace=inplace + ) + if not inplace and hasattr(self, "name"): + new_series.name = self.name + return new_series + else: + # just change Series.name + if inplace: + self.name = index + else: + self_cp = self.copy() + self_cp.name = index + return self_cp + + def repeat(self, repeats, axis=None): # noqa: PR01, RT01, D200 + """ + Repeat elements of a Series. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + if (isinstance(repeats, int) and repeats == 0) or ( + is_list_like(repeats) and len(repeats) == 1 and repeats[0] == 0 + ): + return self.__constructor__() + + return self.__constructor__(query_compiler=self._query_compiler.repeat(repeats)) + + def reset_index( + self, + level=None, + drop=False, + name=no_default, + inplace=False, + allow_duplicates=False, + ): + """ + Generate a new Series with the index reset. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if name is no_default: + # For backwards compatibility, keep columns as [0] instead of + # [None] when self.name is None + name = 0 if self.name is None else self.name + + if not drop and inplace: + raise TypeError( + "Cannot reset_index inplace on a Series to create a DataFrame" + ) + else: + obj = self.copy() + obj.name = name + new_query_compiler = obj._query_compiler.reset_index( + drop=drop, + level=level, + col_level=0, + col_fill="", + allow_duplicates=allow_duplicates, + names=None, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + def rdivmod( + self, other, level=None, fill_value=None, axis=0 + ): # noqa: PR01, RT01, D200 + """ + Return integer division and modulo of series and `other`, element-wise (binary operator `rdivmod`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + + def rfloordiv( + self, other, level=None, fill_value=None, axis=0 + ): # noqa: PR01, RT01, D200 + """ + Return integer division of series and `other`, element-wise (binary operator `rfloordiv`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().rfloordiv(other, level=level, fill_value=fill_value, axis=axis) + + def rmod( + self, other, level=None, fill_value=None, axis=0 + ): # noqa: PR01, RT01, D200 + """ + Return modulo of series and `other`, element-wise (binary operator `rmod`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().rmod(other, level=level, fill_value=fill_value, axis=axis) + + def round(self, decimals=0, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Round each value in a Series to the given number of decimals. + + Parameters + ---------- + decimals : int, default 0 + Number of decimal places to round to. If decimals is negative, it specifies the number of positions to the left of the decimal point. + *args, **kwargs + Additional arguments and keywords have no effect but might be accepted for compatibility with NumPy. 
+ + Returns + ------- + Series + Rounded values of the Series. + + See Also + -------- + numpy.around : Round values of an np.array. + DataFrame.round : Round values of a DataFrame. + + Examples + -------- + >>> s = pd.Series([0.1, 1.3, 2.7]) + >>> s.round() + 0 0.0 + 1 1.0 + 2 3.0 + dtype: float64 + """ + return super().round(decimals, args=args, **kwargs) + + def rpow( + self, other, level=None, fill_value=None, axis=0 + ): # noqa: PR01, RT01, D200 + """ + Return exponential power of series and `other`, element-wise (binary operator `rpow`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().rpow(other, level=level, fill_value=fill_value, axis=axis) + + def rsub( + self, other, level=None, fill_value=None, axis=0 + ): # noqa: PR01, RT01, D200 + """ + Return subtraction of series and `other`, element-wise (binary operator `rsub`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().rsub(other, level=level, fill_value=fill_value, axis=axis) + + def rtruediv( + self, other, level=None, fill_value=None, axis=0 + ): # noqa: PR01, RT01, D200 + """ + Return floating division of series and `other`, element-wise (binary operator `rtruediv`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().rtruediv(other, level=level, fill_value=fill_value, axis=axis) + + rdiv = rtruediv + + def quantile( + self, + q: Scalar | ListLike = 0.5, + interpolation: Literal[ + "linear", "lower", "higher", "midpoint", "nearest" + ] = "linear", + ): + """ + Return value at the given quantile. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().quantile( + q=q, + axis=0, + numeric_only=False, + interpolation=interpolation, + method="single", + ) + + def reorder_levels(self, order): # noqa: PR01, RT01, D200 + """ + Rearrange index levels using input order. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return super().reorder_levels(order) + + def replace( + self, + to_replace=None, + value=no_default, + inplace=False, + limit=None, + regex=False, + method: str | NoDefault = no_default, + ): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + inplace = validate_bool_kwarg(inplace, "inplace") + # The following errors cannot be raised by query compiler because we don't know + # if frontend object is Series or DataFrame. + if to_replace is not None and is_dict_like(value): + raise ValueError( + "In Series.replace 'to_replace' must be None if the 'value' is dict-like" + ) + if is_dict_like(to_replace) and value != no_default: + raise ValueError( + "In Series.replace 'to_replace' cannot be dict-like if 'value' is provided" + ) + new_query_compiler = self._query_compiler.replace( + to_replace=to_replace, + value=value, + limit=limit, + regex=regex, + method=method, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + def searchsorted(self, value, side="left", sorter=None): # noqa: PR01, RT01, D200 + """ + Find indices where elements should be inserted to maintain order. 
+ """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + searchsorted_qc = self._query_compiler + if sorter is not None: + # `iloc` method works slowly (https://github.com/modin-project/modin/issues/1903), + # so _default_to_pandas is used for now + # searchsorted_qc = self.iloc[sorter].reset_index(drop=True)._query_compiler + # sorter = None + return self._default_to_pandas( + pandas.Series.searchsorted, value, side=side, sorter=sorter + ) + # searchsorted should return item number irrespective of Series index, so + # Series.index is always set to pandas.RangeIndex, which can be easily processed + # on the query_compiler level + if not isinstance(searchsorted_qc.index, pandas.RangeIndex): + searchsorted_qc = searchsorted_qc.reset_index(drop=True) + + result = self.__constructor__( + query_compiler=searchsorted_qc.searchsorted( + value=value, side=side, sorter=sorter + ) + ).squeeze() + + # matching pandas output + if not is_scalar(value) and not is_list_like(result): + result = np.array([result]) + elif isinstance(result, type(self)): + result = result.to_numpy() + + return result + + def sort_values( + self, + axis: Axis = 0, + ascending: bool | int | Sequence[bool] | Sequence[int] = True, + inplace: bool = False, + kind: str = "quicksort", + na_position: str = "last", + ignore_index: bool = False, + key: IndexKeyFunc | None = None, + ): + """ + Sort by the values. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + from snowflake.snowpark.modin.pandas.dataframe import DataFrame + + if is_list_like(ascending) and len(ascending) != 1: + raise ValueError( + f"Length of ascending ({len(ascending)}) must be 1 for Series" + ) + + if axis is not None: + # Validate `axis` + self._get_axis_number(axis) + + # When we convert to a DataFrame, the name is automatically converted to 0 if it + # is None, so we do this to avoid a KeyError. + by = self.name if self.name is not None else 0 + result = ( + DataFrame(self.copy()) + .sort_values( + by=by, + ascending=ascending, + inplace=False, + kind=kind, + na_position=na_position, + ignore_index=ignore_index, + key=key, + ) + .squeeze(axis=1) + ) + result.name = self.name + return self._create_or_update_from_compiler( + result._query_compiler, inplace=inplace + ) + + sparse = CachedAccessor("sparse", SparseAccessor) + + def squeeze(self, axis: Axis | None = None): + """ + Squeeze 1 dimensional axis objects into scalars. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if axis is not None: + # Validate `axis` + pandas.Series._get_axis_number(axis) + if len(self) == 1: + return self._reduce_dimension(self._query_compiler) + else: + return self.copy() + + def sub(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200 + """ + Return subtraction of Series and `other`, element-wise (binary operator `sub`). + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().sub(other, level=level, fill_value=fill_value, axis=axis) + + subtract = sub + + def swaplevel(self, i=-2, j=-1, copy=True): # noqa: PR01, RT01, D200 + """ + Swap levels `i` and `j` in a `MultiIndex`. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._default_to_pandas("swaplevel", i=i, j=j, copy=copy) + + def take( + self, + indices: list | AnyArrayLike, + axis: Axis = 0, + **kwargs, + ): + """ + Return the elements in the given positional indices along an axis. 
+ """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().take(indices, axis=0, **kwargs) + + def to_dict(self, into: type[dict] = dict) -> dict: + """ + Convert Series to {label -> value} dict or dict-like object. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self._to_pandas().to_dict(into=into) + + def to_frame( + self, name: Hashable = no_default + ) -> DataFrame: # noqa: PR01, RT01, D200 + """ + Convert Series to {label -> value} dict or dict-like object. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + from snowflake.snowpark.modin.pandas.dataframe import DataFrame + + if name is None: + name = no_default + + self_cp = self.copy() + if name is not no_default: + self_cp.name = name + + return DataFrame(self_cp) + + def to_list(self) -> list: + """ + Return a list of the values. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.values.tolist() + + def to_numpy( + self, + dtype: npt.DTypeLike | None = None, + copy: bool = False, + na_value: object = no_default, + **kwargs: Any, + ) -> np.ndarray: + """ + Return the NumPy ndarray representing the values in this Series or Index. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return ( + super() + .to_numpy( + dtype=dtype, + copy=copy, + na_value=na_value, + **kwargs, + ) + .flatten() + ) + + tolist = to_list + + # TODO(williamma12): When we implement to_timestamp, have this call the version + # in base.py + def to_period(self, freq=None, copy=True): # noqa: PR01, RT01, D200 + """ + Cast to PeriodArray/Index at a particular frequency. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._default_to_pandas("to_period", freq=freq, copy=copy) + + def to_string( + self, + buf=None, + na_rep="NaN", + float_format=None, + header=True, + index=True, + length=False, + dtype=False, + name=False, + max_rows=None, + min_rows=None, + ): # noqa: PR01, RT01, D200 + """ + Render a string representation of the Series. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.Series.to_string, + buf=buf, + na_rep=na_rep, + float_format=float_format, + header=header, + index=index, + length=length, + dtype=dtype, + name=name, + max_rows=max_rows, + ) + + # TODO(williamma12): When we implement to_timestamp, have this call the version + # in base.py + def to_timestamp(self, freq=None, how="start", copy=True): # noqa: PR01, RT01, D200 + """ + Cast to DatetimeIndex of Timestamps, at beginning of period. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._default_to_pandas("to_timestamp", freq=freq, how=how, copy=copy) + + def transpose(self, *args, **kwargs): # noqa: PR01, RT01, D200 + """ + Return the transpose, which is by definition `self`. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self + + T = property(transpose) + + def truediv( + self, other, level=None, fill_value=None, axis=0 + ): # noqa: PR01, RT01, D200 + """ + Return floating division of series and `other`, element-wise (binary operator `truediv`). 
+ """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().truediv(other, level=level, fill_value=fill_value, axis=axis) + + div = divide = truediv + + def truncate( + self, before=None, after=None, axis=None, copy=True + ): # noqa: PR01, RT01, D200 + """ + Truncate a Series before and after some index value. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self._default_to_pandas( + pandas.Series.truncate, before=before, after=after, axis=axis, copy=copy + ) + + def unique(self): + """ + Return unique values of Series object. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.__constructor__( + query_compiler=self._query_compiler.unique() + ).to_numpy() + + def update(self, other): # noqa: PR01, D200 + """ + Modify Series in place using values from passed Series. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if not isinstance(other, Series): + other = self.__constructor__(other) + query_compiler = self._query_compiler.series_update(other) + self._update_inplace(new_query_compiler=query_compiler) + + def value_counts( + self, + normalize: bool = False, + sort: bool = True, + ascending: bool = False, + bins: int | None = None, + dropna: bool = True, + ): + """ + Return a Series containing counts of unique values. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.__constructor__( + query_compiler=self._query_compiler.value_counts( + subset=None, + normalize=normalize, + sort=sort, + ascending=ascending, + bins=bins, + dropna=dropna, + ).set_index_names([self.name]), + name="proportion" if normalize else "count", + ) + + def view(self, dtype=None): # noqa: PR01, RT01, D200 + """ + Create a new view of the Series. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self.__constructor__( + query_compiler=self._query_compiler.series_view(dtype=dtype) + ) + + def where( + self, + cond: DataFrame | Series | Callable | AnyArrayLike, + other: DataFrame | Series | Callable | Scalar | None = np.nan, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, + ): + """ + Replace values where the condition is False. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().where( + cond, + other=other, + inplace=inplace, + axis=axis, + level=level, + ) + + def xs( + self, key, axis=0, level=None, drop_level=True + ): # pragma: no cover # noqa: PR01, D200 + """ + Return cross-section from the Series/DataFrame. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented("") + + @property + def attrs(self): # noqa: RT01, D200 + """ + Return dictionary of global attributes of this dataset. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + + def attrs(df): + return df.attrs + + return self._default_to_pandas(attrs) + + @property + def array(self): # noqa: RT01, D200 + """ + Return the ExtensionArray of the data backing this Series or Index. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + + def array(df): + return df.array + + return self._default_to_pandas(array) + + @property + def axes(self): # noqa: RT01, D200 + """ + Return a list of the row axis labels. 
+ """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return [self.index] + + @property + def cat(self): # noqa: RT01, D200 + """ + Accessor object for categorical properties of the Series values. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + from .series_utils import CategoryMethods + + return CategoryMethods(self) + + @property + def dt(self): # noqa: RT01, D200 + """ + Accessor object for datetimelike properties of the Series values. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + current_dtype = self.dtype + if not is_datetime64_any_dtype(current_dtype): + raise AttributeError("Can only use .dt accessor with datetimelike values") + + from .series_utils import DatetimeProperties + + return DatetimeProperties(self) + + @property + def dtype(self): # noqa: RT01, D200 + """ + Return the dtype object of the underlying data. + See :func:`DataFrame.dtypes` for exact behavior. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.dtype + dtype('int64') + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self._query_compiler.dtypes.squeeze() + + dtypes = dtype + + @property + def empty(self) -> bool: + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return len(self) == 0 + + @property + def hasnans(self): # noqa: RT01, D200 + """ + Return True if Series has any nans. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.isna().sum() > 0 + + def isna(self): + """ + Detect missing values. + + Returns + ------- + The result of detecting missing values. + """ + return super().isna() + + def isnull(self): + """ + Detect missing values. + + Returns + ------- + The result of detecting missing values. + """ + return super().isnull() + + @property + def is_monotonic_increasing(self): # noqa: RT01, D200 + """ + Return True if values in the Series are monotonic_increasing. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self._reduce_dimension(self._query_compiler.is_monotonic_increasing()) + + @property + def is_monotonic_decreasing(self): # noqa: RT01, D200 + """ + Return True if values in the Series are monotonic_decreasing. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self._reduce_dimension(self._query_compiler.is_monotonic_decreasing()) + + @property + def is_unique(self): # noqa: RT01, D200 + """ + Return True if values in the Series are unique. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.nunique(dropna=False) == len(self) + + @property + def nbytes(self): # noqa: RT01, D200 + """ + Return the number of bytes in the underlying data. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() + return self.memory_usage(index=False) + + @property + def ndim(self) -> int: + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return 1 + + def nunique(self, dropna=True): # noqa: PR01, RT01, D200 + """ + Return number of unique elements in the object. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super().nunique(dropna=dropna) + + @property + def shape( + self, + ) -> tuple(int,): + return (len(self),) + + def shift( + self, + periods: int = 1, + freq=None, + axis: Axis = 0, + fill_value: Hashable = no_default, + ): + """ + Shift index by desired number of periods with an optional time `freq`. 
+ """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if axis == 1: + # pandas compatible error. + raise ValueError("No axis named 1 for object type Series") + + return super().shift(periods, freq, axis, fill_value) + + @property + def str(self): # noqa: RT01, D200 + """ + Vectorized string functions for Series and Index. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + current_dtype = self.dtype + if not is_string_dtype(current_dtype): + raise AttributeError("Can only use .str accessor with string values!") + + from .series_utils import StringMethods + + return StringMethods(self) + + def _to_pandas( + self, + *, + statement_params: dict[str, str] | None = None, + **kwargs: Any, + ): + """ + Convert Snowpark pandas Series to pandas Series + + Args: + statement_params: Dictionary of statement level parameters to be set while executing this action. + + Returns: + pandas series + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + df = self._query_compiler.to_pandas(statement_params=statement_params, **kwargs) + if len(df.columns) == 0: + return pandas.Series([]) + series = df[df.columns[0]] + # special case when series is wrapped as dataframe, but has not label. + # This is indicated with MODIN_UNNAMED_SERIES_LABEL + if self._query_compiler.columns[0] == MODIN_UNNAMED_SERIES_LABEL: + series.name = None + + return series + + def _to_datetime(self, **kwargs): + """ + Convert `self` to datetime. + + Parameters + ---------- + **kwargs : dict + Optional arguments to use during query compiler's + `to_datetime` invocation. + + Returns + ------- + datetime + Series of datetime64 dtype. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.__constructor__( + query_compiler=self._query_compiler.series_to_datetime(**kwargs) + ) + + def _to_numeric(self, **kwargs: Any) -> Series: + """ + Convert `self` to numeric. + + Parameters + ---------- + **kwargs : dict + Optional arguments to use during query compiler's + `to_numeric` invocation. + + Returns + ------- + numeric + Series of numeric dtype. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self.__constructor__( + query_compiler=self._query_compiler.to_numeric(**kwargs) + ) + + def _qcut( + self, + q: int | ListLike, + retbins: bool = False, + duplicates: Literal["raise", "drop"] = "raise", + ) -> Series: + """ + Quantile-based discretization function. + + See SnowflakeQueryCompiler.qcut for details. + + """ + + return self.__constructor__( + query_compiler=self._query_compiler.qcut(q, retbins, duplicates) + ) + + def _reduce_dimension(self, query_compiler): + """ + Try to reduce the dimension of data from the `query_compiler`. + + Parameters + ---------- + query_compiler : BaseQueryCompiler + Query compiler to retrieve the data. + + Returns + ------- + pandas.Series or pandas.DataFrame. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return query_compiler.to_pandas().squeeze() + + def _validate_dtypes_sum_prod_mean(self, axis, numeric_only, ignore_axis=False): + """ + Validate data dtype for `sum`, `prod` and `mean` methods. + + Parameters + ---------- + axis : {0, 1} + Axis to validate over. + numeric_only : bool + Whether or not to allow only numeric data. + If True and non-numeric data is found, exception + will be raised. + ignore_axis : bool, default: False + Whether or not to ignore `axis` parameter. 
+ + Returns + ------- + Series + + Notes + ----- + Actually returns unmodified `self` object, + added for compatibility with Modin DataFrame. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self + + def _validate_dtypes(self, numeric_only=False): + """ + Check that all the dtypes are the same. + + Parameters + ---------- + numeric_only : bool, default: False + Whether or not to allow only numeric data. + If True and non-numeric data is found, exception + will be raised. + + Notes + ----- + Actually does nothing, added for compatibility with Modin DataFrame. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + pass + + def _get_numeric_data(self, axis: int): + """ + Grab only numeric data from Series. + + Parameters + ---------- + axis : {0, 1} + Axis to inspect on having numeric types only. + + Returns + ------- + Series + + Notes + ----- + `numeric_only` parameter is not supported by Series, so this method + does not do anything. The method is added for compatibility with Modin DataFrame. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return self + + def _update_inplace(self, new_query_compiler): + """ + Update the current Series in-place using `new_query_compiler`. + + Parameters + ---------- + new_query_compiler : BaseQueryCompiler + QueryCompiler to use to manage the data. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + super()._update_inplace(new_query_compiler=new_query_compiler) + # Propagate changes back to parent so that column in dataframe had the same contents + if self._parent is not None: + if self._parent_axis == 0: + self._parent.loc[self.name] = self + else: + self._parent[self.name] = self + + def _create_or_update_from_compiler(self, new_query_compiler, inplace=False): + """ + Return or update a Series with given `new_query_compiler`. + + Parameters + ---------- + new_query_compiler : PandasQueryCompiler + QueryCompiler to use to manage the data. + inplace : bool, default: False + Whether or not to perform update or creation inplace. + + Returns + ------- + Series, DataFrame or None + None if update was done, Series or DataFrame otherwise. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + assert ( + isinstance(new_query_compiler, type(self._query_compiler)) + or type(new_query_compiler) in self._query_compiler.__class__.__bases__ + ), f"Invalid Query Compiler object: {type(new_query_compiler)}" + if not inplace and new_query_compiler.is_series_like(): + return self.__constructor__(query_compiler=new_query_compiler) + elif not inplace: + # This can happen with things like `reset_index` where we can add columns. + from snowflake.snowpark.modin.pandas.dataframe import DataFrame + + return DataFrame(query_compiler=new_query_compiler) + else: + self._update_inplace(new_query_compiler=new_query_compiler) + + def _repartition(self): + """ + Repartitioning Series to get ideal partitions inside. + + Allows to improve performance where the query compiler can't improve + yet by doing implicit repartitioning. + + Returns + ------- + Series + The repartitioned Series. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return super()._repartition(axis=0) + + # Persistance support methods - BEGIN + @classmethod + def _inflate_light(cls, query_compiler, name): + """ + Re-creates the object from previously-serialized lightweight representation. + + The method is used for faster but not disk-storable persistence. 
+ + Parameters + ---------- + query_compiler : BaseQueryCompiler + Query compiler to use for object re-creation. + name : str + The name to give to the new object. + + Returns + ------- + Series + New Series based on the `query_compiler`. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return cls(query_compiler=query_compiler, name=name) + + @classmethod + def _inflate_full(cls, pandas_series): + """ + Re-creates the object from previously-serialized disk-storable representation. + + Parameters + ---------- + pandas_series : pandas.Series + Data to use for object re-creation. + + Returns + ------- + Series + New Series based on the `pandas_series`. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + return cls(data=pandas_series) + + def __reduce__(self): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + ErrorMessage.not_implemented() # pragma: no cover + + self._query_compiler.finalize() + # if PersistentPickle.get(): + # return self._inflate_full, (self._to_pandas(),) + return self._inflate_light, (self._query_compiler, self.name) + + # Persistance support methods - END diff --git a/src/snowflake/snowpark/modin/pandas/series_utils.py b/src/snowflake/snowpark/modin/pandas/series_utils.py new file mode 100644 index 00000000000..0a513955ed0 --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/series_utils.py @@ -0,0 +1,1517 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +""" +Implement Series's accessors public API as pandas does. + +Accessors: `Series.cat`, `Series.str`, `Series.dt` +""" +import re +import sys +from typing import TYPE_CHECKING, Callable, Optional, Union + +import numpy as np +import pandas + +from snowflake.snowpark.modin.pandas import DataFrame, Series +from snowflake.snowpark.modin.utils import _inherit_docstrings + +if sys.version_info[0] == 3 and sys.version_info[1] >= 7: + # Python >= 3.7 + from re import Pattern as _pattern_type +else: + # Python <= 3.6 + from re import _pattern_type + +if TYPE_CHECKING: + from datetime import tzinfo + + from pandas._typing import npt + +# add this line to enable doc tests to run +from snowflake.snowpark.modin import pandas as pd # noqa: F401 +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage + + +@_inherit_docstrings(pandas.core.arrays.categorical.CategoricalAccessor) +class CategoryMethods: + # CategoricalDType is not supported with Snowpark pandas API. Mark all methods + # to be unsupported. 
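# A minimal, generic sketch of the __reduce__/_inflate_light pattern above:
# __reduce__ hands pickle a (callable, args) pair, and unpickling calls that
# callable to rebuild the object. The diff uses a classmethod on the Series;
# a plain module-level factory is used here to keep the sketch self-contained.
import pickle


class Light:
    def __init__(self, payload, name):
        self.payload = payload
        self.name = name

    def __reduce__(self):
        return _inflate_light, (self.payload, self.name)


def _inflate_light(payload, name):
    # Re-create the object from its lightweight representation.
    return Light(payload, name)


clone = pickle.loads(pickle.dumps(Light([1, 2, 3], "demo")))
assert (clone.payload, clone.name) == ([1, 2, 3], "demo")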
+ category_not_supported_message = "CategoricalDType and corresponding methods is not available in Snowpark pandas API yet!" + + def __init__(self, series) -> None: + self._series = series + self._query_compiler = series._query_compiler + + @property + def categories(self): + ErrorMessage.not_implemented(self.category_not_supported_message) + + @categories.setter + def categories(self, categories): + ErrorMessage.not_implemented( + self.category_not_supported_message + ) # pragma: no cover + + @property + def ordered(self): + ErrorMessage.not_implemented(self.category_not_supported_message) + + @property + def codes(self): + ErrorMessage.not_implemented(self.category_not_supported_message) + + def rename_categories(self, new_categories, inplace=False): + ErrorMessage.not_implemented(self.category_not_supported_message) + + def reorder_categories(self, new_categories, ordered=None, inplace=False): + ErrorMessage.not_implemented(self.category_not_supported_message) + + def add_categories(self, new_categories, inplace=False): + ErrorMessage.not_implemented(self.category_not_supported_message) + + def remove_categories(self, removals, inplace=False): + ErrorMessage.not_implemented(self.category_not_supported_message) + + def remove_unused_categories(self, inplace=False): + ErrorMessage.not_implemented(self.category_not_supported_message) + + def set_categories(self, new_categories, ordered=None, rename=False, inplace=False): + ErrorMessage.not_implemented(self.category_not_supported_message) + + def as_ordered(self, inplace=False): + ErrorMessage.not_implemented(self.category_not_supported_message) + + def as_unordered(self, inplace=False): + ErrorMessage.not_implemented(self.category_not_supported_message) + + +@_inherit_docstrings(pandas.core.strings.accessor.StringMethods) +class StringMethods: + def __init__(self, series) -> None: + # Check if dtypes is objects + + self._series = series + self._query_compiler = series._query_compiler + + def casefold(self): + return Series(query_compiler=self._query_compiler.str_casefold()) + + def cat(self, others=None, sep=None, na_rep=None, join=None): + compiler_result = self._query_compiler.str_cat( + others=others, sep=sep, na_rep=na_rep, join=join + ) + # if others is None, result is a string. otherwise, it's a series. + return ( + compiler_result.to_pandas().squeeze() + if others is None + else Series(query_compiler=compiler_result) + ) + + def decode(self, encoding, errors="strict"): + return Series( + query_compiler=self._query_compiler.str_decode(encoding, errors=errors) + ) + + def split( + self, + pat: Optional[str] = None, + n: int = -1, + expand: bool = False, + regex: Optional[bool] = None, + ) -> Series: + """ + Split strings around given separator/delimiter. + + Splits the string in the Series/Index from the beginning, at the specified delimiter string. + + Parameters + ---------- + pat : str, optional + String to split on. If not specified, split on whitespace. + n : int, default -1 (all) + Limit number of splits in output. None, 0 and -1 will be interpreted as return all splits. + expand : bool, default False (Not implemented yet, should be set to False) + Expand the split strings into separate columns. + - If True, return DataFrame/MultiIndex expanding dimensionality. + - If False, return Series/Index, containing lists of strings. 
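# CategoryMethods above raises for every member; for reference, this is the
# behaviour the native pandas .cat accessor provides and that Snowpark pandas
# does not support yet.
import pandas as native_pd

s = native_pd.Series(["a", "b", "a"], dtype="category")
assert list(s.cat.categories) == ["a", "b"]
assert not s.cat.ordered
assert list(s.cat.codes) == [0, 1, 0]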
+ regex : bool, default None (Not implemented yet, should be set to False or None) + Determines if the passed-in pattern is a regular expression: + - If True, assumes the passed-in pattern is a regular expression + - If False or None, treats the pattern as a literal string. + + Returns + ------- + Series, Index, DataFrame or MultiIndex + Type matches caller unless expand=True (see Notes). + + See also + -------- + Series.str.split + Split strings around given separator/delimiter. + Series.str.rsplit + Splits string around given separator/delimiter, starting from the right. + Series.str.join + Join lists contained as elements in the Series/Index with passed delimiter. + str.split + Standard library version for split. + str.rsplit + Standard library version for rsplit. + + Notes + ----- + The handling of the n keyword depends on the number of found splits: + + - If found splits > n, make first n splits only + - If found splits <= n, make all splits + - If for a certain row the number of found splits < n, append None for padding up to n if expand=True + - If using expand=True, Series and Index callers return DataFrame and MultiIndex objects, respectively. + + Examples + -------- + >>> s = pd.Series( + [ + "this is a regular sentence", + "https://docs.python.org/3/tutorial/index.html", + np.nan + ] + ) + s + 0 this is a regular sentence + 1 https://docs.python.org/3/tutorial/index.html + 2 NaN + dtype: object + + In the default setting, the string is split by whitespace. + + >>> s.str.split() + 0 [this, is, a, regular, sentence] + 1 [https://docs.python.org/3/tutorial/index.html] + 2 NaN + dtype: object + + The n parameter can be used to limit the number of splits on the delimiter. + + >>> s.str.split(n=2) + 0 [this, is, a regular sentence] + 1 [https://docs.python.org/3/tutorial/index.html] + 2 NaN + dtype: object + + The pat parameter can be used to split by other characters. + + >>> s.str.split(pat="/") + 0 [this is a regular sentence] + 1 [https:, , docs.python.org, 3, tutorial, index... + 2 NaN + dtype: object + """ + if not pat and pat is not None: + raise ValueError("split() requires a non-empty pattern match.") + + else: + return Series( + query_compiler=self._query_compiler.str_split( + pat=pat, n=n, expand=expand, regex=regex + ) + ) + + def rsplit(self, pat=None, n=-1, expand=False): + if not pat and pat is not None: + raise ValueError("rsplit() requires a non-empty pattern match.") + + else: + return Series( + query_compiler=self._query_compiler.str_rsplit( + pat=pat, n=n, expand=expand + ) + ) + + def get(self, i): + return Series(query_compiler=self._query_compiler.str_get(i)) + + def join(self, sep): + if sep is None: + raise AttributeError("'NoneType' object has no attribute 'join'") + return Series(query_compiler=self._query_compiler.str_join(sep)) + + def get_dummies(self, sep="|"): + return DataFrame(query_compiler=self._query_compiler.str_get_dummies(sep)) + + def contains( + self, + pat: str, + case: bool = True, + flags: int = 0, + na: object = None, + regex: bool = True, + ): + """ + Test if pattern or regex is contained within a string of a Series or Index. + + Return boolean Series or Index based on whether a given pattern or regex is contained within a string of a Series or Index. + + Parameters + ---------- + pat : str + Character sequence or regular expression. + case : bool, default True + If True, case sensitive. + flags : int, default 0 (no flags) + Flags to pass through to the re module, e.g. re.IGNORECASE. 
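# A minimal sketch, using native pandas, of the split() semantics documented
# above: n caps the number of splits and pat switches the delimiter.
import pandas as native_pd

s = native_pd.Series(["this is a regular sentence"])
assert s.str.split(n=2)[0] == ["this", "is", "a regular sentence"]
assert s.str.split(pat="/")[0] == ["this is a regular sentence"]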
+ na : scalar, optional + Fill value for missing values. The default depends on dtype of the array. For object-dtype, numpy.nan is used. For StringDtype, pandas.NA is used. + regex : bool, default True + If True, assumes the pat is a regular expression. + If False, treats the pat as a literal string. + + Returns + ------- + Series or Index of boolean values + A Series or Index of boolean values indicating whether the given pattern is contained within the string of each element of the Series or Index. + + See also + -------- + match + Analogous, but stricter, relying on re.match instead of re.search. + Series.str.startswith + Test if the start of each string element matches a pattern. + Series.str.endswith + Same as startswith, but tests the end of string. + + Examples + -------- + Returning a Series of booleans using only a literal pattern. + + >>> s1 = pd.Series(['Mouse', 'dog', 'house and parrot', '23', np.NaN]) + >>> s1.str.contains('og', regex=False) + 0 False + 1 True + 2 False + 3 False + 4 NaN + dtype: object + + Returning an Index of booleans using only a literal pattern. + + >>> ind = pd.Index(['Mouse', 'dog', 'house and parrot', '23.0', np.NaN]) + >>> ind.str.contains('23', regex=False) + Index([False, False, False, True, nan], dtype='object') + + Specifying case sensitivity using case. + + >>> s1.str.contains('oG', case=True, regex=True) + 0 False + 1 False + 2 False + 3 False + 4 NaN + dtype: object + + Specifying na to be False instead of NaN replaces NaN values with False. If Series or Index does not contain NaN values the resultant dtype will be bool, otherwise, an object dtype. + + >>> s1.str.contains('og', na=False, regex=True) + 0 False + 1 True + 2 False + 3 False + 4 False + dtype: bool + + Returning ‘house’ or ‘dog’ when either expression occurs in a string. + + >>> s1.str.contains('house|dog', regex=True) + 0 False + 1 True + 2 True + 3 False + 4 NaN + dtype: object + + Ignoring case sensitivity using flags with regex. + + >>> import re + >>> s1.str.contains('PARROT', flags=re.IGNORECASE, regex=True) + 0 False + 1 False + 2 True + 3 False + 4 NaN + dtype: object + + Returning any digit using regular expression. + + >>> s1.str.contains('\\d', regex=True) + 0 False + 1 False + 2 False + 3 True + 4 NaN + dtype: object + + Ensure pat is a not a literal pattern when regex is set to True. Note in the following example one might expect only s2[1] and s2[3] to return True. However, ‘.0’ as a regex matches any character followed by a 0. + + >>> s2 = pd.Series(['40', '40.0', '41', '41.0', '35']) + >>> s2.str.contains('.0', regex=True) + 0 True + 1 True + 2 False + 3 True + 4 False + dtype: bool + """ + return Series( + query_compiler=self._query_compiler.str_contains( + pat, case=case, flags=flags, na=na, regex=regex + ) + ) + + def replace( + self, + pat: str, + repl: Union[str, Callable], + n: int = -1, + case: Optional[bool] = None, + flags: int = 0, + regex: bool = True, + ) -> Series: + r""" + Replace each occurrence of pattern/regex in the Series/Index. + + Equivalent to str.replace() or re.sub(), depending on the regex value. + + Parameters + ---------- + pat : str + String can be a character sequence or regular expression. + repl : str or callable + Replacement string or a callable. The callable is passed the regex match object and must return a replacement string to be used. See re.sub(). + n : int, default -1 (all) + Number of replacements to make from start. 
+ case : bool, default None + Determines if replace is case sensitive: + - If True, case sensitive (the default if pat is a string) + - Set to False for case insensitive + - Cannot be set if pat is a compiled regex. + flags : int, default 0 (no flags) + Regex module flags, e.g. re.IGNORECASE. Cannot be set if pat is a compiled regex. + regex : bool, default False + Determines if the passed-in pattern is a regular expression: + - If True, assumes the passed-in pattern is a regular expression. + - If False, treats the pattern as a literal string + - Cannot be set to False if pat is a compiled regex or repl is a callable. + + Returns + ------- + Series or Index of object + A copy of the object with all matching occurrences of pat replaced by repl. + + Raises + ------ + ValueError + - if regex is False and repl is a callable or pat is a compiled regex + - if pat is a compiled regex and case or flags is set + + Notes + ----- + When pat is a compiled regex, all flags should be included in the compiled regex. Use of case, flags, or regex=False with a compiled regex will raise an error. + + Examples + -------- + When pat is a string and regex is True, the given pat is compiled as a regex. When repl is a string, it replaces matching regex patterns as with re.sub(). NaN value(s) in the Series are left as is: + + >>> pd.Series(['foo', 'fuz', np.nan]).str.replace('f.', 'ba', regex=True) + 0 bao + 1 baz + 2 NaN + dtype: object + + When pat is a string and regex is False, every pat is replaced with repl as with str.replace(): + + >>> pd.Series(['f.o', 'fuz', np.nan]).str.replace('f.', 'ba', regex=False) + 0 bao + 1 fuz + 2 NaN + dtype: object + + When repl is a callable, it is called on every pat using re.sub(). The callable should expect one positional argument (a regex object) and return a string. 
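# A minimal sketch, using native pandas, of the replace() rules documented
# above: a callable repl receives each match object, and a repl that is
# neither a string nor a callable is rejected with TypeError.
import pandas as native_pd

s = native_pd.Series(["foo 123", "bar baz"])

# Callable replacement: reverse every lowercase word.
reversed_words = s.str.replace(r"[a-z]+", lambda m: m.group(0)[::-1], regex=True)
assert list(reversed_words) == ["oof 123", "rab zab"]

try:
    s.str.replace("a", 1, regex=False)
except TypeError:
    pass  # "repl must be a string or callable"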
+ + To get the idea: + + >>> pd.Series(['foo', 'fuz', np.nan]).str.replace('f', repr, regex=True) + 0 oo + 1 uz + 2 NaN + dtype: object + + Reverse every lowercase alphabetic word: + + >>> repl = lambda m: m.group(0)[::-1] + >>> ser = pd.Series(['foo 123', 'bar baz', np.nan]) + >>> ser.str.replace(r'[a-z]+', repl, regex=True) + 0 oof 123 + 1 rab zab + 2 NaN + dtype: object + + Using regex groups (extract second group and swap case): + + >>> pat = r"(?P\w+) (?P\w+) (?P\w+)" + >>> repl = lambda m: m.group('two').swapcase() + >>> ser = pd.Series(['One Two Three', 'Foo Bar Baz']) + >>> ser.str.replace(pat, repl, regex=True) + 0 tWO + 1 bAR + dtype: object + + Using a compiled regex with flags + + >>> import re + >>> regex_pat = re.compile(r'FUZ', flags=re.IGNORECASE) + >>> pd.Series(['foo', 'fuz', np.nan]).str.replace(regex_pat, 'bar', regex=True) + 0 foo + 1 bar + 2 NaN + dtype: object + """ + if not (isinstance(repl, str) or callable(repl)): + raise TypeError("repl must be a string or callable") + return Series( + query_compiler=self._query_compiler.str_replace( + pat, repl, n=n, case=case, flags=flags, regex=regex + ) + ) + + def pad(self, width, side="left", fillchar=" "): + if len(fillchar) != 1: + raise TypeError("fillchar must be a character, not str") + return Series( + query_compiler=self._query_compiler.str_pad( + width, side=side, fillchar=fillchar + ) + ) + + def center(self, width, fillchar=" "): + if len(fillchar) != 1: + raise TypeError("fillchar must be a character, not str") + return Series( + query_compiler=self._query_compiler.str_center(width, fillchar=fillchar) + ) + + def ljust(self, width, fillchar=" "): + if len(fillchar) != 1: + raise TypeError("fillchar must be a character, not str") + return Series( + query_compiler=self._query_compiler.str_ljust(width, fillchar=fillchar) + ) + + def rjust(self, width, fillchar=" "): + if len(fillchar) != 1: + raise TypeError("fillchar must be a character, not str") + return Series( + query_compiler=self._query_compiler.str_rjust(width, fillchar=fillchar) + ) + + def zfill(self, width): + return Series(query_compiler=self._query_compiler.str_zfill(width)) + + def wrap(self, width, **kwargs): + if width <= 0: + raise ValueError(f"invalid width {width} (must be > 0)") + return Series(query_compiler=self._query_compiler.str_wrap(width, **kwargs)) + + def slice(self, start=None, stop=None, step=None): + if step == 0: + raise ValueError("slice step cannot be zero") + return Series( + query_compiler=self._query_compiler.str_slice( + start=start, stop=stop, step=step + ) + ) + + def slice_replace(self, start=None, stop=None, repl=None): + return Series( + query_compiler=self._query_compiler.str_slice_replace( + start=start, stop=stop, repl=repl + ) + ) + + def count(self, pat: str, flags: int = 0, **kwargs): + """ + Count occurrences of pattern in each string of the Series/Index. + + This function is used to count the number of times a particular regex pattern is repeated in each of the string elements of the Series. + + Parameters + ---------- + pat : str + Valid regular expression. + flags : int, default 0, meaning no flags + Flags for the re module. + **kwargs + For compatibility with other string methods. Not used. + + Returns + ------- + Series or Index + Same type as the calling object containing the integer counts. + + See also + -------- + re + Standard library module for regular expressions. + str.count + Standard library version, without regular expression support. 
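# A minimal sketch, using native pandas, of the padding and slicing helpers
# above: fillchar must be a single character and slicing follows normal
# Python slice semantics.
import pandas as native_pd

s = native_pd.Series(["cat", "dog"])
assert list(s.str.center(7, fillchar="-")) == ["--cat--", "--dog--"]
assert list(s.str.zfill(5)) == ["00cat", "00dog"]
assert list(s.str.slice(start=0, stop=2)) == ["ca", "do"]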
+ + Notes + ----- + Some characters need to be escaped when passing in pat. eg. '$' has a special meaning in regex and must be escaped when finding this literal character. + + Examples + -------- + >>> s = pd.Series(['A', 'B', 'Aaba', 'Baca', np.nan, 'CABA', 'cat']) + >>> s.str.count('a') + 0 0.0 + 1 0.0 + 2 2.0 + 3 2.0 + 4 NaN + 5 0.0 + 6 1.0 + dtype: float64 + + Escape '$' to find the literal dollar sign. + + >>> s = pd.Series(['$', 'B', 'Aab$', '$$ca', 'C$B$', 'cat']) + >>> s.str.count('\\$') + 0 1 + 1 0 + 2 1 + 3 2 + 4 2 + 5 0 + dtype: int64 + + This is also available on Index + + >>> pd.Index(['A', 'A', 'Aaba', 'cat']).str.count('a') + Int64Index([0, 0, 2, 1], dtype='int64') + """ + if not isinstance(pat, (str, _pattern_type)): + raise TypeError("first argument must be string or compiled pattern") + return Series( + query_compiler=self._query_compiler.str_count(pat, flags=flags, **kwargs) + ) + + def startswith(self, pat, na=np.NaN): + """ + Test if the start of each string element matches a pattern. + + Parameters + ---------- + pat : str or tuple[str, ...] + Character sequence or tuple of strings. Regular expressions are not accepted. + na : object, default NaN + Object shown if element tested is not a string. The default depends on dtype of the array. For object-dtype, numpy.nan is used. For StringDtype, pandas.NA is used. + + Returns + ------- + Series or Index of bool + A Series of booleans indicating whether the given pattern matches the start of each string element. + + See also + -------- + str.startswith + Python standard library string method. + Series.str.endswith + Same as startswith, but tests the end of string. + Series.str.contains + Tests if string element contains a pattern. + + Examples + -------- + >>> s = pd.Series(['bat', 'Bear', 'cat', np.nan]) + >>> s + 0 bat + 1 Bear + 2 cat + 3 NaN + dtype: object + + >>> s.str.startswith('b') + 0 True + 1 False + 2 False + 3 NaN + dtype: object + + >>> s.str.startswith(('b', 'B')) + 0 True + 1 True + 2 False + 3 NaN + dtype: object + + Specifying na to be False instead of NaN. + + >>> s.str.startswith('b', na=False) + 0 True + 1 False + 2 False + 3 False + dtype: bool + """ + return Series(query_compiler=self._query_compiler.str_startswith(pat, na=na)) + + def encode(self, encoding, errors="strict"): + return Series( + query_compiler=self._query_compiler.str_encode(encoding, errors=errors) + ) + + def endswith(self, pat, na=np.NaN): + """ + Test if the end of each string element matches a pattern. + + Parameters + ---------- + pat : str or tuple[str, …] + Character sequence or tuple of strings. Regular expressions are not accepted. + na : object, default NaN + Object shown if element tested is not a string. The default depends on dtype of the array. For object-dtype, numpy.nan is used. For StringDtype, pandas.NA is used. + + Returns + ------- + Series or Index of bool + A Series of booleans indicating whether the given pattern matches the end of each string element. + + See also + -------- + str.endswith + Python standard library string method. + Series.str.startswith + Same as endswith, but tests the start of string. + Series.str.contains + Tests if string element contains a pattern. + + Examples + -------- + >>> s = pd.Series(['bat', 'bear', 'caT', np.nan]) + >>> s + 0 bat + 1 bear + 2 caT + 3 NaN + dtype: object + + >>> s.str.endswith('t') + 0 True + 1 False + 2 False + 3 NaN + dtype: object + + >>> s.str.endswith(('t', 'T')) + 0 True + 1 False + 2 True + 3 NaN + dtype: object + + Specifying na to be False instead of NaN. 
+ + >>> s.str.endswith('t', na=False) + 0 True + 1 False + 2 False + 3 False + dtype: bool + """ + return Series(query_compiler=self._query_compiler.str_endswith(pat, na=na)) + + def findall(self, pat, flags=0, **kwargs): + if not isinstance(pat, (str, _pattern_type)): + raise TypeError("first argument must be string or compiled pattern") + return Series( + query_compiler=self._query_compiler.str_findall(pat, flags=flags, **kwargs) + ) + + def match(self, pat, case=True, flags=0, na=np.NaN): + if not isinstance(pat, (str, _pattern_type)): + raise TypeError("first argument must be string or compiled pattern") + return Series( + query_compiler=self._query_compiler.str_match(pat, flags=flags, na=na) + ) + + def extract(self, pat, flags=0, expand=True): + query_compiler = self._query_compiler.str_extract( + pat, flags=flags, expand=expand + ) + return ( + DataFrame(query_compiler=query_compiler) + if expand or re.compile(pat).groups > 1 + else Series(query_compiler=query_compiler) + ) + + def extractall(self, pat, flags=0): + return Series(query_compiler=self._query_compiler.str_extractall(pat, flags)) + + def len(self): + """ + Get the length of a string. For non-string values this + returns the length of the string representation. + + Returns + ------- + Series + A Series with the length of each value + + Examples + -------- + >>> s = pd.Series(['dog', + ... '', + ... 5, + ... {'foo' : 'bar'}, + ... [2, 3, 5, 7], + ... ('one', 'two', 'three')]) + >>> s.str.len() + 0 3 + 1 0 + 2 1 + 3 13 + 4 9 + 5 21 + dtype: int64 + """ + return Series(query_compiler=self._query_compiler.str_len()) + + def strip(self, to_strip: str = None) -> Series: + """ + Remove leading and trailing characters. + + Strip whitespaces (including newlines) or a set of specified characters from each string in the Series/Index from left and right sides. Replaces any non-strings in Series with NaNs. Equivalent to str.strip(). + + Parameters + ---------- + to_strip : str or None, default None + Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. If None then whitespaces are removed. + + Returns + ------- + Series or Index of object + + See also + -------- + Series.str.strip + Remove leading and trailing characters in Series/Index. + Series.str.lstrip + Remove leading characters in Series/Index. + Series.str.rstrip + Remove trailing characters in Series/Index. + + Examples + -------- + >>> s = pd.Series(['1. Ant. ', '2. Bee!\\n', '3. Cat?\\t', np.nan, 10, True]) + >>> s + 0 1. Ant. + 1 2. Bee!\\n + 2 3. Cat?\\t + 3 NaN + 4 10 + 5 True + dtype: object + + >>> s.str.strip() + 0 1. Ant. + 1 2. Bee! + 2 3. Cat? + 3 NaN + 4 NaN + 5 NaN + dtype: object + + >>> s.str.lstrip('123.') + 0 Ant. + 1 Bee!\\n + 2 Cat?\\t + 3 NaN + 4 NaN + 5 NaN + dtype: object + + >>> s.str.rstrip('.!? \\n\\t') + 0 1. Ant + 1 2. Bee + 2 3. Cat + 3 NaN + 4 NaN + 5 NaN + dtype: object + + >>> s.str.strip('123.!? 
\\n\\t') + 0 Ant + 1 Bee + 2 Cat + 3 NaN + 4 NaN + 5 NaN + dtype: object + """ + return Series(query_compiler=self._query_compiler.str_strip(to_strip=to_strip)) + + def rstrip(self, to_strip=None): + return Series(query_compiler=self._query_compiler.str_rstrip(to_strip=to_strip)) + + def lstrip(self, to_strip=None): + return Series(query_compiler=self._query_compiler.str_lstrip(to_strip=to_strip)) + + def partition(self, sep=" ", expand=True): + if sep is not None and len(sep) == 0: + raise ValueError("empty separator") + + return (DataFrame if expand else Series)( + query_compiler=self._query_compiler.str_partition(sep=sep, expand=expand) + ) + + def removeprefix(self, prefix): + return Series(query_compiler=self._query_compiler.str_removeprefix(prefix)) + + def removesuffix(self, suffix): + return Series(query_compiler=self._query_compiler.str_removesuffix(suffix)) + + def repeat(self, repeats): + return Series(query_compiler=self._query_compiler.str_repeat(repeats)) + + def rpartition(self, sep=" ", expand=True): + if sep is not None and len(sep) == 0: + raise ValueError("empty separator") + + else: + return Series( + query_compiler=self._query_compiler.str_rpartition( + sep=sep, expand=expand + ) + ) + + def lower(self): + return Series(query_compiler=self._query_compiler.str_lower()) + + def upper(self): + return Series(query_compiler=self._query_compiler.str_upper()) + + def title(self): + """ + Convert strings in the Series/Index to be titlecased . + + Returns + ------- + Series or Index of object + + See also + -------- + Series.str.lower + Converts all characters to lowercase. + + Series.str.upper + Converts all characters to uppercase. + + Series.str.title + Converts first character of each word to uppercase and remaining to lowercase. + + Series.str.capitalize + Converts first character to uppercase and remaining to lowercase. + + Series.str.swapcase + Converts uppercase to lowercase and lowercase to uppercase. + + Series.str.casefold + Removes all case distinctions in the string. + + Examples + -------- + >>> s = pd.Series(['lower', 'CAPITALS', 'this is a sentence', 'SwApCaSe']) + >>> s + 0 lower + 1 CAPITALS + 2 this is a sentence + 3 SwApCaSe + dtype: object + + >>> s.str.title() + 0 Lower + 1 Capitals + 2 This Is A Sentence + 3 Swapcase + dtype: object + """ + return Series(query_compiler=self._query_compiler.str_title()) + + def find(self, sub, start=0, end=None): + if not isinstance(sub, str): + raise TypeError(f"expected a string object, not {type(sub).__name__}") + return Series( + query_compiler=self._query_compiler.str_find(sub, start=start, end=end) + ) + + def rfind(self, sub, start=0, end=None): + if not isinstance(sub, str): + raise TypeError(f"expected a string object, not {type(sub).__name__}") + return Series( + query_compiler=self._query_compiler.str_rfind(sub, start=start, end=end) + ) + + def index(self, sub, start=0, end=None): + if not isinstance(sub, str): + raise TypeError(f"expected a string object, not {type(sub).__name__}") + return Series( + query_compiler=self._query_compiler.str_index(sub, start=start, end=end) + ) + + def rindex(self, sub, start=0, end=None): + if not isinstance(sub, str): + raise TypeError(f"expected a string object, not {type(sub).__name__}") + return Series( + query_compiler=self._query_compiler.str_rindex(sub, start=start, end=end) + ) + + def capitalize(self): + """ + Convert strings in the Series/Index to be capitalized. 
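# A minimal sketch, using native pandas, of the partition() contract above:
# expand=True yields a three-column DataFrame, expand=False a Series of
# (head, separator, tail) tuples.
import pandas as native_pd

s = native_pd.Series(["a_b_c", "x_y"])

wide = s.str.partition("_")
assert list(wide.columns) == [0, 1, 2]

narrow = s.str.partition("_", expand=False)
assert narrow[0] == ("a", "_", "b_c")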
+ + Returns + ------- + Series or Index of object + + See also + -------- + Series.str.lower + Converts all characters to lowercase. + + Series.str.upper + Converts all characters to uppercase. + + Series.str.title + Converts first character of each word to uppercase and remaining to lowercase. + + Series.str.capitalize + Converts first character to uppercase and remaining to lowercase. + + Series.str.swapcase + Converts uppercase to lowercase and lowercase to uppercase. + + Series.str.casefold + Removes all case distinctions in the string. + + Examples + -------- + >>> s = pd.Series(['lower', 'CAPITALS', 'this is a sentence', 'SwApCaSe']) + >>> s + 0 lower + 1 CAPITALS + 2 this is a sentence + 3 SwApCaSe + dtype: object + + >>> s.str.capitalize() + 0 Lower + 1 Capitals + 2 This is a sentence + 3 Swapcase + dtype: object + """ + return Series(query_compiler=self._query_compiler.str_capitalize()) + + def swapcase(self): + return Series(query_compiler=self._query_compiler.str_swapcase()) + + def normalize(self, form): + return Series(query_compiler=self._query_compiler.str_normalize(form)) + + def translate(self, table): + return Series(query_compiler=self._query_compiler.str_translate(table)) + + def isalnum(self): + return Series(query_compiler=self._query_compiler.str_isalnum()) + + def isalpha(self): + return Series(query_compiler=self._query_compiler.str_isalpha()) + + def isdigit(self): + """ + Check whether all characters in each string are digits. + + This is equivalent to running the Python string method str.isdigit() for each element of the Series. If a string has zero characters, False is returned for that check. + + Returns + ------- + Series of boolean values with the same length as the original Series. + + Examples + -------- + >>> s = pd.Series(['23', '³', '⅕', '']) + + The `s.str.isdigit` method checks for characters used to form numbers in base 10. + Currently, special digits like superscripted and subscripted digits in unicode are + not checked for. + >>> s.str.isdigit() + 0 True + 1 False + 2 False + 3 False + dtype: bool + """ + return Series(query_compiler=self._query_compiler.str_isdigit()) + + def isspace(self): + return Series(query_compiler=self._query_compiler.str_isspace()) + + def islower(self): + """ + Check whether all characters in each string are lowercase. + + This is equivalent to running the Python string method str.islower() for each element of the Series. If a string has zero characters, False is returned for that check. + + Returns + ------- + Series of boolean values with the same length as the original Series. + + Examples + -------- + >>> s = pd.Series(['leopard', 'Golden Eagle', 'SNAKE', '']) + >>> s.str.islower() + 0 True + 1 False + 2 False + 3 False + dtype: bool + """ + return Series(query_compiler=self._query_compiler.str_islower()) + + def isupper(self): + """ + Check whether all characters in each string are uppercase. + + This is equivalent to running the Python string method str.isupper() for each element of the Series. If a string has zero characters, False is returned for that check. + + Returns + ------- + Series of boolean values with the same length as the original Series. + + Examples + -------- + >>> s = pd.Series(['leopard', 'Golden Eagle', 'SNAKE', '']) + >>> s.str.isupper() + 0 False + 1 False + 2 True + 3 False + dtype: bool + """ + return Series(query_compiler=self._query_compiler.str_isupper()) + + def istitle(self): + """ + Check whether all characters in each string are uppercase. 
+ + This is equivalent to running the Python string method str.isupper() for each element of the Series. If a string has zero characters, False is returned for that check. + + Returns + ------- + Series of boolean values with the same length as the original Series. + + Examples + -------- + >>> s = pd.Series(['leopard', 'Golden Eagle', 'SNAKE', '', 'Snake']) + >>> s.str.istitle() + 0 False + 1 True + 2 False + 3 False + 4 True + dtype: bool + """ + return Series(query_compiler=self._query_compiler.str_istitle()) + + def isnumeric(self): + return Series(query_compiler=self._query_compiler.str_isnumeric()) + + def isdecimal(self): + return Series(query_compiler=self._query_compiler.str_isdecimal()) + + +@_inherit_docstrings(pandas.core.indexes.accessors.CombinedDatetimelikeProperties) +class DatetimeProperties: + def __init__(self, series) -> None: + self._series = series + self._query_compiler = series._query_compiler + + @property + def date(self): + """ + Returns a series of python :class:`datetime.date` objects. + + Namely, the date part of Timestamps without time and timezone information. + + Examples + -------- + For Series: + + >>> s = pd.Series(["2020-01-01 01:23:00", "2020-02-01 12:11:05"]) + >>> s = pd.to_datetime(s) + >>> s + 0 2020-01-01 01:23:00 + 1 2020-02-01 12:11:05 + dtype: datetime64[ns] + >>> s.dt.date + 0 2020-01-01 + 1 2020-02-01 + dtype: object + """ + return Series(query_compiler=self._query_compiler.dt_property("date")) + + @property + def time(self): + return Series(query_compiler=self._query_compiler.dt_time()) + + @property + def timetz(self): + return Series(query_compiler=self._query_compiler.dt_timetz()) + + @property + def year(self): + """ + Returns a series of the years of the datetime. + + Examples + -------- + >>> datetime_series = pd.Series( + ... pd.date_range("2000-01-01", periods=3, freq="YE") + ... ) + >>> datetime_series + 0 2000-12-31 + 1 2001-12-31 + 2 2002-12-31 + dtype: datetime64[ns] + >>> datetime_series.dt.year + 0 2000 + 1 2001 + 2 2002 + dtype: int16 + """ + return Series(query_compiler=self._query_compiler.dt_property("year")) + + @property + def month(self): + """ + Returns a series of the months of the datetime. + + Examples + -------- + >>> datetime_series = pd.Series( + ... pd.date_range("2000-01-01", periods=3, freq="ME") + ... ) + >>> datetime_series + 0 2000-01-31 + 1 2000-02-29 + 2 2000-03-31 + dtype: datetime64[ns] + >>> datetime_series.dt.month + 0 1 + 1 2 + 2 3 + dtype: int8 + """ + return Series(query_compiler=self._query_compiler.dt_property("month")) + + @property + def day(self): + """ + Returns a series of the days of the datetime. + + Examples + -------- + >>> datetime_series = pd.Series( + ... pd.date_range("2000-01-01", periods=3, freq="D") + ... ) + >>> datetime_series + 0 2000-01-01 + 1 2000-01-02 + 2 2000-01-03 + dtype: datetime64[ns] + >>> datetime_series.dt.day + 0 1 + 1 2 + 2 3 + dtype: int8 + """ + return Series(query_compiler=self._query_compiler.dt_property("day")) + + @property + def hour(self): + """ + Returns a series of the hours of the datetime. + + Examples + -------- + >>> datetime_series = pd.Series( + ... pandas.date_range("2000-01-01", periods=3, freq="h") + ... 
) + >>> datetime_series + 0 2000-01-01 00:00:00 + 1 2000-01-01 01:00:00 + 2 2000-01-01 02:00:00 + dtype: datetime64[ns] + >>> datetime_series.dt.hour + 0 0 + 1 1 + 2 2 + dtype: int8 + """ + return Series(query_compiler=self._query_compiler.dt_property("hour")) + + @property + def minute(self): + """ + Returns a series of the minutes of the datetime. + + Examples + -------- + >>> datetime_series = pd.Series( + ... pd.date_range("2000-01-01", periods=3, freq="min") + ... ) + >>> datetime_series + 0 2000-01-01 00:00:00 + 1 2000-01-01 00:01:00 + 2 2000-01-01 00:02:00 + dtype: datetime64[ns] + >>> datetime_series.dt.minute + 0 0 + 1 1 + 2 2 + dtype: int8 + """ + return Series(query_compiler=self._query_compiler.dt_property("minute")) + + @property + def second(self): + """ + Returns a series of the seconds of the datetime. + + Examples + -------- + >>> datetime_series = pd.Series( + ... pd.date_range("2000-01-01", periods=3, freq="s") + ... ) + >>> datetime_series + 0 2000-01-01 00:00:00 + 1 2000-01-01 00:00:01 + 2 2000-01-01 00:00:02 + dtype: datetime64[ns] + >>> datetime_series.dt.second + 0 0 + 1 1 + 2 2 + dtype: int8 + """ + return Series(query_compiler=self._query_compiler.dt_property("second")) + + @property + def microsecond(self): + return Series(query_compiler=self._query_compiler.dt_microsecond()) + + @property + def nanosecond(self): + return Series(query_compiler=self._query_compiler.dt_nanosecond()) + + @property + def week(self): + return Series(query_compiler=self._query_compiler.dt_week()) + + @property + def weekofyear(self): + return Series(query_compiler=self._query_compiler.dt_weekofyear()) + + @property + def dayofweek(self): + return Series(query_compiler=self._query_compiler.dt_dayofweek()) + + @property + def weekday(self): + return Series(query_compiler=self._query_compiler.dt_weekday()) + + @property + def dayofyear(self): + return Series(query_compiler=self._query_compiler.dt_dayofyear()) + + @property + def quarter(self): + """ + Returns a series of the quarters of the datetime. + + Examples + -------- + >>> datetime_series = pd.Series( + ... pd.date_range("2000-01-01", periods=3, freq="3ME") + ... 
) + >>> datetime_series + 0 2000-01-31 + 1 2000-04-30 + 2 2000-07-31 + dtype: datetime64[ns] + >>> datetime_series.dt.quarter + 0 1 + 1 2 + 2 3 + dtype: int8 + """ + return Series(query_compiler=self._query_compiler.dt_property("quarter")) + + @property + def is_month_start(self): + return Series(query_compiler=self._query_compiler.dt_is_month_start()) + + @property + def is_month_end(self): + return Series(query_compiler=self._query_compiler.dt_is_month_end()) + + @property + def is_quarter_start(self): + return Series(query_compiler=self._query_compiler.dt_is_quarter_start()) + + @property + def is_quarter_end(self): + return Series(query_compiler=self._query_compiler.dt_is_quarter_end()) + + @property + def is_year_start(self): + return Series(query_compiler=self._query_compiler.dt_is_year_start()) + + @property + def is_year_end(self): + return Series(query_compiler=self._query_compiler.dt_is_year_end()) + + @property + def is_leap_year(self): + return Series(query_compiler=self._query_compiler.dt_is_leap_year()) + + @property + def daysinmonth(self): + return Series(query_compiler=self._query_compiler.dt_daysinmonth()) + + @property + def days_in_month(self): + return Series(query_compiler=self._query_compiler.dt_days_in_month()) + + @property + def tz(self) -> "tzinfo | None": + dtype = self._series.dtype + if isinstance(dtype, np.dtype): + return None + return dtype.tz + + @property + def freq(self): + return self._query_compiler.dt_freq().to_pandas().squeeze() + + def to_period(self, *args, **kwargs): + return Series(query_compiler=self._query_compiler.dt_to_period(*args, **kwargs)) + + def to_pydatetime(self): + return Series(query_compiler=self._query_compiler.dt_to_pydatetime()).to_numpy() + + def tz_localize(self, *args, **kwargs): + return Series( + query_compiler=self._query_compiler.dt_tz_localize(*args, **kwargs) + ) + + def tz_convert(self, *args, **kwargs): + return Series( + query_compiler=self._query_compiler.dt_tz_convert(*args, **kwargs) + ) + + def normalize(self, *args, **kwargs): + return Series(query_compiler=self._query_compiler.dt_normalize(*args, **kwargs)) + + def strftime(self, *args, **kwargs): + return Series(query_compiler=self._query_compiler.dt_strftime(*args, **kwargs)) + + def round(self, *args, **kwargs): + return Series(query_compiler=self._query_compiler.dt_round(*args, **kwargs)) + + def floor(self, *args, **kwargs): + return Series(query_compiler=self._query_compiler.dt_floor(*args, **kwargs)) + + def ceil(self, *args, **kwargs): + return Series(query_compiler=self._query_compiler.dt_ceil(*args, **kwargs)) + + def month_name(self, *args, **kwargs): + return Series( + query_compiler=self._query_compiler.dt_month_name(*args, **kwargs) + ) + + def day_name(self, *args, **kwargs): + return Series(query_compiler=self._query_compiler.dt_day_name(*args, **kwargs)) + + def total_seconds(self, *args, **kwargs): + return Series( + query_compiler=self._query_compiler.dt_total_seconds(*args, **kwargs) + ) + + def to_pytimedelta(self) -> "npt.NDArray[np.object_]": + res = self._query_compiler.dt_to_pytimedelta() + return res.to_numpy()[:, 0] + + @property + def seconds(self): + return Series(query_compiler=self._query_compiler.dt_seconds()) + + @property + def days(self): + return Series(query_compiler=self._query_compiler.dt_days()) + + @property + def microseconds(self): + return Series(query_compiler=self._query_compiler.dt_microseconds()) + + @property + def nanoseconds(self): + return Series(query_compiler=self._query_compiler.dt_nanoseconds()) + + 
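# A minimal sketch, using native pandas and numpy, of the tz property logic
# above: a plain numpy datetime64 dtype carries no timezone, while a
# timezone-aware column exposes one through its DatetimeTZDtype.
import numpy as np
import pandas as native_pd

naive = native_pd.Series(native_pd.to_datetime(["2020-01-01"]))
aware = naive.dt.tz_localize("UTC")

assert isinstance(naive.dtype, np.dtype)      # -> tz would be None
assert not isinstance(aware.dtype, np.dtype)  # DatetimeTZDtype
assert str(aware.dtype.tz) == "UTC"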
@property + def components(self): + + return DataFrame(query_compiler=self._query_compiler.dt_components()) + + @property + def qyear(self): + return Series(query_compiler=self._query_compiler.dt_qyear()) + + @property + def start_time(self): + return Series(query_compiler=self._query_compiler.dt_start_time()) + + @property + def end_time(self): + return Series(query_compiler=self._query_compiler.dt_end_time()) + + def to_timestamp(self, *args, **kwargs): + return Series( + query_compiler=self._query_compiler.dt_to_timestamp(*args, **kwargs) + ) diff --git a/src/snowflake/snowpark/modin/pandas/shared_docs.py b/src/snowflake/snowpark/modin/pandas/shared_docs.py new file mode 100644 index 00000000000..20bdde253a9 --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/shared_docs.py @@ -0,0 +1,62 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# +# +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the pandas project, under the BSD 3-Clause License + +from __future__ import annotations + +_shared_docs: dict[str, str] = {} + +_shared_docs[ + "aggregate" +] = """ +Aggregate using one or more operations over the specified axis. + +Parameters +---------- +func : function, str, list or dict + Function to use for aggregating the data. If a function, must either + work when passed a {klass} or when passed to {klass}.apply. + + Accepted combinations are: + + - function + - string function name + - list of functions and/or function names, e.g. ``[np.sum, 'mean']`` + - dict of axis labels -> functions, function names or list of such. +{axis} +*args + Positional arguments to pass to `func`. +**kwargs + Keyword arguments to pass to `func`. + +Returns +------- +scalar, Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.Series` or Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.DataFrame` + + The return can be: + + * scalar : when Snowpark pandas Series.agg is called with single function + * Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.Series` : when Snowpark pandas DataFrame.agg is called with a single function + * Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.DataFrame` : when Snowpark pandas DataFrame.agg is called with several functions + + Return scalar, Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.Series` or Snowpark pandas :class:`~snowflake.snowpark.modin.pandas.DataFrame`. + +Notes +----- +The aggregation operations are always performed over an axis, either the +index (default) or the column axis. This behavior is different from +`numpy` aggregation functions (`mean`, `median`, `prod`, `sum`, `std`, +`var`), where the default is to compute the aggregation of the flattened +array, e.g., ``numpy.mean(arr_2d)`` as opposed to +``numpy.mean(arr_2d, axis=0)``. + +`agg` is an alias for `aggregate`. Use the alias. + +Functions that mutate the passed object can produce unexpected +behavior or errors and are not supported. + +A passed user-defined-function will be passed a Series for evaluation. +{examples}""" diff --git a/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py new file mode 100644 index 00000000000..3529355b81b --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/snow_partition_iterator.py @@ -0,0 +1,111 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
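# A minimal sketch, using native pandas, of the agg() return shapes described
# in the shared aggregate docstring above: one function on a Series gives a
# scalar, one function on a DataFrame gives a Series, and a list of functions
# gives a DataFrame.
import pandas as native_pd

s = native_pd.Series([1, 2, 3])
df = native_pd.DataFrame({"a": [1, 2], "b": [3, 4]})

assert s.agg("sum") == 6
assert isinstance(df.agg("sum"), native_pd.Series)
assert isinstance(df.agg(["sum", "mean"]), native_pd.DataFrame)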
+# + +from collections.abc import Iterator +from typing import Any, Callable + +import pandas + +import snowflake.snowpark.modin.pandas.dataframe as DataFrame + +PARTITION_SIZE = 4096 + + +class SnowparkPandasRowPartitionIterator(Iterator): + """ + Iterator on partitioned data used by DataFrame.iterrows and DataFrame.itertuples to iterate over axis=0 or rows. + + SnowparkPandasRowPartitionIterator pulls table data in batches (where number of rows = PARTITION_SIZE) to iterate + over rows. This is to prevent the table from being queried for every single row - the batch of rows pulled in is + converted to a native pandas DataFrame and completely iterated over before pulling in the next batch. This results + in to_pandas() query being made per batch; no joins are ever performed in this implementation. + + However, if enable_partition_with_native_pandas is set to False, it behaves just like the PartitionIterator where + an iloc call is made to the table to pull in every single row. This results in a join query run for every single + row, which is inefficient because a lot more queries are issued. This option should be used when a Snowpark pandas + DataFrame or Series is to be returned to avoid downloading and uploading the same data. + + Parameters + ---------- + df : DataFrame + The dataframe to iterate over. + axis : {0, 1} + Axis to iterate over. + func : callable + The function to get inner iterables from each partition. + enable_partition_with_native_pandas: bool, default False + When True, retrieve the table as partitions. Each partition is a pandas DataFrame which is iterated over until + exhausted, and the next partition is pulled in. + When False, iterate over the Snowpark pandas DataFrame directly row-by-row. + """ + + def __init__( + self, + df: DataFrame, + func: Callable, + enable_partition_with_native_pandas: bool = False, + ) -> None: + self.position = 0 # keep track of position in the iterator + # To avoid making a query per row to extract row data (like in DataFrame.iterrows and DataFrame.itertuples), + # a batch of rows of size PARTITION_SIZE is materialized at a time and converted to a pandas DataFrame. + # This uses fewer queries. Partitions are used instead of materializing the whole table since some tables + # are too large to be materialized in one go. PARTITION_SIZE is arbitrary and can be tuned for performance. + self.df = df + self.func = func + self.enable_partition_with_native_pandas = enable_partition_with_native_pandas + # TODO SNOW-1017263: update to_pandas() to return an iterator and use that directly here. + if self.enable_partition_with_native_pandas: + self.partition = self.get_next_partition() + self.num_rows = -1 # unused + else: + self.partition = None # unused + # The call below triggers eager evaluation for row count - it is used as a stopping condition to raise + # StopIteration for the iterator. + self.num_rows = len(self.df) + + def __iter__(self) -> "SnowparkPandasRowPartitionIterator": + """ + Implement iterator interface. + + Returns + ------- + SnowparkPandasRowPartitionIterator + Iterator object. + """ + return self + + def __next__(self) -> Any: + """ + Implement iterator interface. + + Returns + ------- + Any + Next element in the SnowparkPandasRowPartitionIterator after the callable func is applied. + """ + # self.position is used to get the integer location of rows. 
+ if self.enable_partition_with_native_pandas: + if len(self.partition) <= self.position % PARTITION_SIZE: + raise StopIteration + ser = self.partition.iloc[self.position % PARTITION_SIZE] + self.position += 1 + if self.position and self.position % PARTITION_SIZE == 0: + # Finished iterating through the current partition, fetch the next partition. + self.partition = self.get_next_partition() + return self.func(ser) + else: + if self.position < self.num_rows: + ser = self.df.iloc[self.position] + self.position += 1 + return self.func(ser) + else: + raise StopIteration + + def get_next_partition(self) -> pandas.DataFrame: + """ + Helper method to retrieve a partition of table data of size PARTITION_SIZE number of rows. + """ + return self.df.iloc[ + slice(self.position, self.position + PARTITION_SIZE) + ].to_pandas() diff --git a/src/snowflake/snowpark/modin/pandas/utils.py b/src/snowflake/snowpark/modin/pandas/utils.py new file mode 100644 index 00000000000..f458a1d970b --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/utils.py @@ -0,0 +1,740 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. + +"""Implement utils for pandas component.""" + +from collections.abc import Hashable, Iterator, Sequence +from types import BuiltinFunctionType +from typing import Any, Callable, Optional, Union + +import numpy as np +import pandas +from pandas._typing import ( + AggFuncType, + AggFuncTypeBase, + AggFuncTypeDict, + AnyArrayLike, + IndexLabel, + Scalar, +) +from pandas.core.dtypes.common import is_array_like, is_dict_like, is_list_like +from pandas.errors import SpecificationError +from pandas.util._decorators import doc + +import snowflake.snowpark.modin.pandas as pd +from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, +) +from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + get_pandas_aggr_func_name, +) +from snowflake.snowpark.modin.plugin.compiler import BaseQueryCompiler +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.modin.utils import hashable + +_doc_binary_operation = """ +Return {operation} of {left} and `{right}` (binary operator `{bin_op}`). + +Parameters +---------- +{right} : {right_type} + The second operand to perform computation. + +Returns +------- +{returns} +""" + + +def from_non_pandas(df, index, columns, dtype): + """ + Convert a non-pandas DataFrame into Modin DataFrame. + + Parameters + ---------- + df : object + Non-pandas DataFrame. 
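# A minimal sketch, using native pandas, of the batching idea behind
# SnowparkPandasRowPartitionIterator above: rows are pulled in fixed-size
# partitions and each partition is exhausted before the next slice is
# fetched, instead of issuing one fetch per row. iter_rows_in_partitions is a
# stand-in helper, not part of the plugin.
import pandas as native_pd

PARTITION_SIZE = 3  # small value for illustration; the plugin uses 4096


def iter_rows_in_partitions(df: native_pd.DataFrame, func):
    position = 0
    partition = df.iloc[position : position + PARTITION_SIZE]
    while len(partition) > position % PARTITION_SIZE:
        yield func(partition.iloc[position % PARTITION_SIZE])
        position += 1
        if position % PARTITION_SIZE == 0:
            # Current partition exhausted, fetch the next slice of rows.
            partition = df.iloc[position : position + PARTITION_SIZE]


frame = native_pd.DataFrame({"x": range(7)})
assert [row["x"] for row in iter_rows_in_partitions(frame, lambda r: r)] == list(range(7))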
+ index : object + Index for non-pandas DataFrame. + columns : object + Columns for non-pandas DataFrame. + dtype : type + Data type to force. + + Returns + ------- + modin.pandas.DataFrame + Converted DataFrame. + """ + # from modin.core.execution.dispatching.factories.dispatcher import FactoryDispatcher + + new_qc = FactoryDispatcher.from_non_pandas(df, index, columns, dtype) + if new_qc is not None: + from snowflake.snowpark.modin.pandas import DataFrame + + return DataFrame(query_compiler=new_qc) + return new_qc + + +def from_pandas(df): + """ + Convert a pandas DataFrame to a Modin DataFrame. + + Parameters + ---------- + df : pandas.DataFrame + The pandas DataFrame to convert. + + Returns + ------- + modin.pandas.DataFrame + A new Modin DataFrame object. + """ + # from modin.core.execution.dispatching.factories.dispatcher import FactoryDispatcher + from snowflake.snowpark.modin.pandas import DataFrame + + return DataFrame(query_compiler=FactoryDispatcher.from_pandas(df)) + + +def from_arrow(at): + """ + Convert an Arrow Table to a Modin DataFrame. + + Parameters + ---------- + at : Arrow Table + The Arrow Table to convert from. + + Returns + ------- + DataFrame + A new Modin DataFrame object. + """ + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + from snowflake.snowpark.modin.pandas import DataFrame + + return DataFrame(query_compiler=FactoryDispatcher.from_arrow(at)) + + +def from_dataframe(df): + """ + Convert a DataFrame implementing the dataframe exchange protocol to a Modin DataFrame. + + See more about the protocol in https://data-apis.org/dataframe-protocol/latest/index.html. + + Parameters + ---------- + df : DataFrame + The DataFrame object supporting the dataframe exchange protocol. + + Returns + ------- + DataFrame + A new Modin DataFrame object. + """ + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + from snowflake.snowpark.modin.pandas import DataFrame + + return DataFrame(query_compiler=FactoryDispatcher.from_dataframe(df)) + + +def is_scalar(obj): + """ + Return True if given object is scalar. + + This method works the same as is_scalar method from pandas but + it is optimized for Modin frames. For BasePandasDataset objects + pandas version of is_scalar tries to access missing attribute + causing index scan. This triggers execution for lazy frames and + we avoid it by handling BasePandasDataset objects separately. + + Parameters + ---------- + obj : object + Object to check. + + Returns + ------- + bool + True if given object is scalar and False otherwise. + """ + from pandas.api.types import is_scalar as pandas_is_scalar + + from .base import BasePandasDataset + + return not isinstance(obj, BasePandasDataset) and pandas_is_scalar(obj) + + +def is_full_grab_slice(slc, sequence_len=None): + """ + Check that the passed slice grabs the whole sequence. + + Parameters + ---------- + slc : slice + Slice object to check. + sequence_len : int, optional + Length of the sequence to index with the passed `slc`. + If not specified the function won't be able to check whether + ``slc.stop`` is equal or greater than the sequence length to + consider `slc` to be a full-grab, and so, only slices with + ``.stop is None`` are considered to be a full-grab. 
+ + Returns + ------- + bool + """ + assert isinstance(slc, slice), "slice object required" + return ( + slc.start in (None, 0) + and slc.step in (None, 1) + and ( + slc.stop is None or (sequence_len is not None and slc.stop >= sequence_len) + ) + ) + + +def from_modin_frame_to_mi(df, sortorder=None, names=None): + """ + Make a pandas.MultiIndex from a DataFrame. + + Parameters + ---------- + df : DataFrame + DataFrame to be converted to pandas.MultiIndex. + sortorder : int, default: None + Level of sortedness (must be lexicographically sorted by that + level). + names : list-like, optional + If no names are provided, use the column names, or tuple of column + names if the columns is a MultiIndex. If a sequence, overwrite + names with the given sequence. + + Returns + ------- + pandas.MultiIndex + The pandas.MultiIndex representation of the given DataFrame. + """ + from snowflake.snowpark.modin.pandas import DataFrame + + if isinstance(df, DataFrame): + df = df._to_pandas() + return _original_pandas_MultiIndex_from_frame(df, sortorder, names) + + +def is_label(obj, label, axis=0): + """ + Check whether or not 'obj' contain column or index level with name 'label'. + + Parameters + ---------- + obj : modin.pandas.DataFrame, modin.pandas.Series or modin.core.storage_formats.base.BaseQueryCompiler + Object to check. + label : object + Label name to check. + axis : {0, 1}, default: 0 + Axis to search for `label` along. + + Returns + ------- + bool + True if check is successful, False otherwise. + """ + qc = getattr(obj, "_query_compiler", obj) + return hashable(label) and ( + label in qc.get_axis(axis ^ 1) or label in qc.get_index_names(axis) + ) + + +def check_both_not_none(option1, option2): + """ + Check that both `option1` and `option2` are not None. + + Parameters + ---------- + option1 : Any + First object to check if not None. + option2 : Any + Second object to check if not None. + + Returns + ------- + bool + True if both option1 and option2 are not None, False otherwise. + """ + return not (option1 is None or option2 is None) + + +def _walk_aggregation_func( + key: IndexLabel, value: AggFuncType, depth: int = 0 +) -> Iterator[tuple[IndexLabel, AggFuncTypeBase, Optional[str], bool]]: + """ + Walk over a function from a dictionary-specified aggregation. + + Note: this function is not supposed to be called directly and + is used by ``walk_aggregation_dict``. + + Parameters + ---------- + key : IndexLabel + A key in a dictionary-specified aggregation for the passed `value`. + This means an index label to apply the `value` functions against. + value : AggFuncType + An aggregation function matching the `key`. + depth : int, default: 0 + Specifies a nesting level for the `value` where ``depth=0`` is when + you call the function on a raw dictionary value. + + Yields + ------ + (col: IndexLabel, func: AggFuncTypeBase, func_name: Optional[str], col_renaming_required: bool) + Yield an aggregation function with its metadata: + - `col`: column name to apply the function. + - `func`: aggregation function to apply to the column. + - `func_name`: custom function name that was specified in the dict. + - `col_renaming_required`: whether it's required to rename the + `col` into ``(col, func_name)``. + """ + col_renaming_required = bool(depth) + + if isinstance(value, (list, tuple)): + if depth == 0: + for val in value: + yield from _walk_aggregation_func(key, val, depth + 1) + elif depth == 1: + if len(value) != 2: + raise ValueError( + f"Incorrect rename format. 
Renamer must consist of exactly two elements, got: {len(value)}." + ) + func_name, func = value + yield key, func, func_name, col_renaming_required + else: + # pandas doesn't support this as well + ErrorMessage.not_implemented( + "Nested renaming is not supported." + ) # pragma: no cover + else: + yield key, value, None, col_renaming_required + + +def walk_aggregation_dict( + agg_dict: AggFuncTypeDict, +) -> Iterator[tuple[IndexLabel, AggFuncTypeBase, Optional[str], bool]]: + """ + Walk over an aggregation dictionary. + + Parameters + ---------- + agg_dict : AggFuncTypeDict + + Yields + ------ + (col: IndexLabel, func: AggFuncTypeBase, func_name: Optional[str], col_renaming_required: bool) + Yield an aggregation function with its metadata: + - `col`: column name to apply the function. + - `func`: aggregation function to apply to the column. + - `func_name`: custom function name that was specified in the dict. + - `col_renaming_required`: whether it's required to rename the + `col` into ``(col, func_name)``. + """ + for key, value in agg_dict.items(): + yield from _walk_aggregation_func(key, value) + + +def raise_if_native_pandas_objects(obj: Any) -> None: + """ + Raise TypeError if provided object is of type pandas.Series or pandas.DataFrame + + Args: + obj: object to check + + Raises: + TypeError if provided ``obj`` is either native pandas DataFrame or Series + + """ + if isinstance(obj, (pandas.DataFrame, pandas.Series)): + raise TypeError( + f"{type(obj)} is not supported as 'value' argument. Please convert this to " + "Snowpark pandas objects by calling modin.pandas.Series()/DataFrame()" + ) + + +def replace_external_data_keys_with_empty_pandas_series( + keys: Optional[ + Union[Hashable, AnyArrayLike, Sequence[Union[Hashable, AnyArrayLike]]] + ] = None +) -> Optional[Union[Hashable, pandas.Series, list[Union[Hashable, pandas.Series]]]]: + """ + Replace any array-like key with empty series. + Args: + keys: join key or sequence of join keys. + + Returns: + Join key(s) by replacing array-like join key with empty series. + """ + if keys is None: + return None + if is_array_like(keys): + return create_empty_pandas_series_from_array_like(keys) + if isinstance(keys, (list, tuple)): + return [ + create_empty_pandas_series_from_array_like(key) + if is_array_like(key) + else key + for key in keys + ] + return keys + + +def create_empty_pandas_series_from_array_like(obj: AnyArrayLike) -> pandas.Series: + """ + Create empty (zero rows) native pandas series from given array-like object. + Args: + obj: array-like object + + Returns: + Native pandas series with zero rows. + + """ + assert is_array_like(obj) + # Snowpark pandas series. + if isinstance(obj, pd.Series): + return create_empty_native_pandas_frame(obj).squeeze() + # Everything else first gets converted to pandas.Series + if not isinstance(obj, pandas.Series): + obj = pandas.Series(obj) + # Create empty series by calling head with zero rows. + return obj.head(0) + + +def create_empty_native_pandas_frame( + obj: Union["pd.Series", "pd.DataFrame"] +) -> pandas.DataFrame: + """ + Create an empty native pandas DataFrame using the columns and index labels info from + the given object. Empty here implies zero rows. + + Args: + obj: Snowflake Series or DataFrame. + + Returns: + A native pandas DataFrame with 0 rows in it. 
+ """ + qc = obj._query_compiler + index_names = qc.get_index_names() + index = ( + pandas.MultiIndex.from_tuples(tuples=[], names=index_names) + if len(index_names) > 1 + else pandas.Index(data=[], name=index_names[0]) + ) + return pandas.DataFrame(columns=qc.columns, index=index) + + +def replace_external_data_keys_with_query_compiler( + frame: "pd.DataFrame", + keys: Optional[ + Union[Hashable, AnyArrayLike, Sequence[Union[Hashable, AnyArrayLike]]] + ] = None, +) -> Optional[ + Union[Hashable, BaseQueryCompiler, list[Union[Hashable, BaseQueryCompiler]]] +]: + """ + Replace any array-like join key(s) with query compiler. + + Args: + frame: dataframe, join keys belong to. + keys: join key or sequence of join keys. + + Returns: + List of join keys by replacing array-like join keys with query compiler. + + """ + if keys is None: + return None + if not isinstance(keys, (list, tuple)): + keys = [keys] + replaced_keys = [] + for key in keys: + if is_array_like(key): + raise_if_native_pandas_objects(key) + if not isinstance(key, pd.Series): + key = pd.Series(key) + # Native pandas raises + # ValueError: The truth value of an array with more than one element is ambiguous + # Error message is not very helpful. We instead raise error with + # more helpful message. + if frame.shape[0] != key.shape[0]: + raise ValueError( + "array-like join key must be of same length as dataframe" + ) + replaced_keys.append(key._query_compiler) + else: + replaced_keys.append(key) + return replaced_keys + + +def try_convert_builtin_func_to_str( + fn: Union[AggFuncTypeBase, list[AggFuncTypeBase]], obj: object +) -> Union[AggFuncTypeBase, list[AggFuncTypeBase]]: + """ + Try to convert an aggregation function to a string or list of such if the function is a + builtin function and supported in the current object dir. + + This is mainly required by our server backend aggregation function mapping, which requires the + function to be in string format or numpy function (numpy function is handled differently because + it can potentially have different behavior as builtin function, For example: np.percentile and + percentile have different behavior). For any function that can not find a map in snowflake, it will + go through fallback, includes function that is not a numpy function and can not be converted to + string format. + + Args: + fn : callable, str, or list of above + obj : the object to search for function dir + + Returns: + str, callable or list of above + If `fn` is a callable, return its name if it's a builtin function (i.e. min, max) + and it is a method of the current object, otherwise return `fn` itself. + If `fn` is a string, return it. + If `fn` is an Iterable, return a list of try_convert_func_to_str applied to + each element of `fn`. + """ + + def _try_convert_single_builtin_func_to_str(f): + return ( + f.__name__ + if ( + callable(f) + and isinstance(f, BuiltinFunctionType) + and f.__name__ in dir(obj) + ) + else f + ) + + if is_list_like(fn): + return [_try_convert_single_builtin_func_to_str(f) for f in fn] + else: + return _try_convert_single_builtin_func_to_str(fn) + + +def validate_and_try_convert_agg_func_arg_func_to_str( + agg_func: AggFuncType, obj: object, allow_duplication: bool, axis: int +) -> AggFuncType: + """ + Perform validation on the func argument for aggregation, and try to convert builtin function in agg_func to str. + Following validation is performed: + 1) Argument agg_func can not be None. 
+ 2) If agg_func is dict like, the values of the dict can not be dict like, and if the aggregation is across axis=0, + all keys must be a valid column of the object. When axis=1, we do not check if the labels are present in the index + to avoid the extra query needed to materialize it. + 3) If allow_duplication is False, more than one aggregation function with the same name can not be applied on the + same column. For example: [min, max, min] is not valid. This is mainly used by general aggregation. + + This function also calls try_convert_func_to_str on agg_func to convert the builtin functions used in agg_func to + str but keep the original dict like or list like format. This is mainly required by our server backend aggregation + function mapping, which requires the function to be in string format or numpy function (numpy function is handled + differently because it can potentially have different behavior as builtin function, For example: np.percentile and + percentile have different behavior). For any function that can not find a map in snowflake, it will + go through fallback, includes function that is not a numpy function and can not be converted to string format. + + Args: + agg_func: AggFuncType + The func arg passed for the aggregation + obj: object + The object to search for attributes + allow_duplication: bool + Whether allow duplicated function with the same name. Note that numpy functions has different function + name compare with the equivalent builtin function, for example, np.min and min have different + names ('amin' and 'min'). However, this behavior is changing with python 3.9, + where np.min will have the same name 'min'. + axis: int + The axis across which the aggregation is applied. + + Returns: + Processed aggregation function arg with builtin function converted to name + Raises: + SpecificationError + If nested dict configuration is used when agg_func is dict like or functions with duplicated names. + + """ + if agg_func is None: + # Snowpark pandas only support func argument at this moment. + # TODO (SNOW-902943): pandas allows usage of NamedAgg in kwargs to configure + # tuples of (columns, agg_func) with rename. For example: + # df.groupby('A').agg(b_min=pd.NamedAgg(column='B', aggfunc='min')), which applies + # min function on column 'B', and uses 'b_min' as the new column name. + # Once supported, refine the check to check both. + ErrorMessage.not_implemented( + "Must provide value for 'func' argument, func=None is currently not supported with Snowpark pandas" + ) + + if callable(agg_func): + result_agg_func = try_convert_builtin_func_to_str(agg_func, obj) + elif is_dict_like(agg_func): + # A dict like func input should in format like {'col1': max, 'col2': [min, np.max]}, where each + # entry have key as the data column label, and value as the aggregation functions to apply on + # the column. Following checks and process will be performed if the input is dict like: + # 1) Perform check for the dict entries to make sure all columns belongs to the data columns, and + # no nested dictionary is used in the configuration. + # 2) Perform a processing to the values (aggregation function) to convert the function to string + # format if possible. 
For example, {'col1': max, 'col2': [min, np.max]} will be processed to + # {'col1': 'max', 'col2': ['min', np.max]} + + # check if there is any value also in dictionary format, which is not allowed in pandas + if any(is_dict_like(fn) for fn in agg_func.values()): + raise SpecificationError( + "Value for func argument with nested dict format is not allowed." + ) + if any(is_list_like(fn) and len(fn) == 0 for fn in agg_func.values()): + # A label must have aggregations provided, e.g. df.agg({0: []}) is illegal + raise ValueError("No objects to concatenate") + # check that all columns in the dictionary exists in the data columns of the current dataframe + columns = obj._query_compiler.columns + if axis == 0: + # If axis == 1, we would need a query to materialize the index to check its existence + # so we defer the error checking to later. + for i in agg_func.keys(): + if i not in columns: + raise KeyError(f"Column(s) ['{i}'] do not exist") + + func_dict = { + label: try_convert_builtin_func_to_str(fn, obj) + for label, fn in agg_func.items() + } + + result_agg_func = func_dict + elif is_list_like(agg_func): + # When the input func is in list like format like [min, max, np.sum], perform a processing to the + # aggregation function to convert it to string representation if possible. + result_agg_func = try_convert_builtin_func_to_str(agg_func, obj) + else: + result_agg_func = agg_func + + if not allow_duplication: + # if allow_duplication is False, check is there duplication in the function names, which + # are used as the row label for the aggregation result in dataframe/series aggregation, and + # not allowed in pandas. + found_duplication = False + if is_dict_like(result_agg_func): + for agg_func in result_agg_func.values(): + if is_list_like(agg_func): + agg_func_names = [get_pandas_aggr_func_name(fn) for fn in agg_func] + found_duplication = len(agg_func_names) > len(set(agg_func_names)) + break + elif is_list_like(result_agg_func): + agg_func_names = [get_pandas_aggr_func_name(fn) for fn in result_agg_func] + found_duplication = len(agg_func_names) > len(set(agg_func_names)) + + if found_duplication: + raise SpecificationError("Function names must be unique!") + + return result_agg_func + + +def _doc_binary_op(operation, bin_op, left="Series", right="right", returns="Series"): + """ + Return callable documenting `Series` or `DataFrame` binary operator. + + Parameters + ---------- + operation : str + Operation name. + bin_op : str + Binary operation name. + left : str, default: 'Series' + The left object to document. + right : str, default: 'right' + The right operand name. + returns : str, default: 'Series' + Type of returns. 
+ + Returns + ------- + callable + """ + if left == "Series": + right_type = "Series or scalar value" + elif left == "DataFrame": + right_type = "DataFrame, Series or scalar value" + elif left == "BasePandasDataset": + right_type = "BasePandasDataset or scalar value" + else: + ErrorMessage.not_implemented( + f"Only 'BasePandasDataset', `DataFrame` and 'Series' `left` are allowed, actually passed: {left}" + ) # pragma: no cover + doc_op = doc( + _doc_binary_operation, + operation=operation, + right=right, + right_type=right_type, + bin_op=bin_op, + returns=returns, + left=left, + ) + + return doc_op + + +def get_as_shape_compatible_dataframe_or_series( + other: Union["pd.DataFrame", "pd.Series", Callable, AnyArrayLike, Scalar], + reference_df: "pd.DataFrame", + shape_mismatch_message: Optional[ + str + ] = "Array conditional must be same shape as self", +) -> Union["pd.DataFrame", "pd.Series"]: + """ + Get the "other" type as a shape compatible dataframe or series using the reference_df as a reference for + compatible shape and construction. If there is no shape on the other type then wrap as a numpy array. + + Parameters + ---------- + other : Other type which could be array like + reference_df : Reference dataframe or series + + Returns + ------- + Dataframe or series that contains same values as other + """ + if not hasattr(other, "shape"): + # If an array type is provided that doesn't have a shape, then wrap it so it has a shape. + # For example, if other=[1,2,3] then np.asanyarray will wrap as a numpy array with correct shape, + # ie, np.anyarray(other).shape=(3,) in this case. + other = np.asanyarray(other) + + if len(other.shape) == 0 or other.shape != reference_df.shape: + raise ValueError(shape_mismatch_message) + + if "columns" in reference_df: + other = reference_df.__constructor__( + other, index=reference_df.index, columns=reference_df.columns + ) + else: + other = reference_df.__constructor__(other, index=reference_df.index) + + return other + + +_original_pandas_MultiIndex_from_frame = pandas.MultiIndex.from_frame +pandas.MultiIndex.from_frame = from_modin_frame_to_mi diff --git a/src/snowflake/snowpark/modin/pandas/window.py b/src/snowflake/snowpark/modin/pandas/window.py new file mode 100644 index 00000000000..bf71341a430 --- /dev/null +++ b/src/snowflake/snowpark/modin/pandas/window.py @@ -0,0 +1,463 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Licensed to Modin Development Team under one or more contributor license agreements. +# See the NOTICE file distributed with this work for additional information regarding +# copyright ownership. The Modin Development Team licenses this file to you under the +# Apache License, Version 2.0 (the "License"); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + +# Code in this file may constitute partial or total reimplementation, or modification of +# existing code originally distributed by the Modin project, under the Apache License, +# Version 2.0. 
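+
+# A minimal usage sketch of the Rolling API implemented below (an illustration,
+# not part of this module; assumes an active Snowpark pandas session and uses
+# hypothetical data):
+#
+#   import modin.pandas as pd
+#   import snowflake.snowpark.modin.plugin  # noqa: F401
+#
+#   s = pd.Series([1, 2, 3, 4])
+#   s.rolling(window=2, min_periods=1).mean()
+#
+# Each aggregation call (mean, sum, std, ...) is forwarded to the query
+# compiler as `rolling_<aggregation>` and evaluated lazily in Snowflake.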
+ +"""Implement Window and Rolling public API.""" +from typing import Any, Literal, Optional, Union + +import numpy as np # noqa: F401 +import pandas.core.window.rolling + +from snowflake.snowpark.dataframe import DataFrame as SnowparkDataFrame + +# add these two lines to enable doc tests to run +from snowflake.snowpark.modin import pandas as pd # noqa: F401 +from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta +from snowflake.snowpark.modin.utils import ( + _inherit_docstrings, + doc_replace_dataframe_with_link, +) + + +@_inherit_docstrings( + pandas.core.window.rolling.Window, modify_doc=doc_replace_dataframe_with_link +) +# TODO SNOW-1041934: Add support for more window aggregations +class Window(metaclass=TelemetryMeta): + def __init__( + self, + dataframe, + window: Any = None, + min_periods: int = None, + center: bool = False, + win_type: str = None, + on: str = None, + axis: Union[int, str] = 0, + closed: str = None, + step: int = None, + method: str = "single", + ) -> None: + # TODO: SNOW-1063357: Modin upgrade - modin.pandas.window.Window + self._dataframe = dataframe + self._query_compiler = dataframe._query_compiler + self.window_kwargs = { # pragma: no cover + "window": window, + "min_periods": min_periods, + "center": center, + "win_type": win_type, + "on": on, + "axis": axis, + "closed": closed, + "step": step, + "method": method, + } + self.axis = axis + + def mean(self, *args, **kwargs): + # TODO: SNOW-1063357: Modin upgrade - modin.pandas.window.Window + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.window_mean( + self.axis, self.window_kwargs, *args, **kwargs + ) + ) + + def sum(self, *args, **kwargs): + # TODO: SNOW-1063357: Modin upgrade - modin.pandas.window.Window + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.window_sum( + self.axis, self.window_kwargs, *args, **kwargs + ) + ) + + def var(self, ddof=1, *args, **kwargs): + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.window_var( + self.axis, self.window_kwargs, ddof, *args, **kwargs + ) + ) + + def std(self, ddof=1, *args, **kwargs): + # TODO: SNOW-1063357: Modin upgrade - modin.pandas.window.Window + return self._dataframe.__constructor__( + query_compiler=self._query_compiler.window_std( + self.axis, self.window_kwargs, ddof, *args, **kwargs + ) + ) + + +@_inherit_docstrings( + pandas.core.window.rolling.Rolling, + excluded=[pandas.core.window.rolling.Rolling.__init__], + modify_doc=doc_replace_dataframe_with_link, +) +# TODO SNOW-1041934: Add support for more window aggregations +class Rolling(metaclass=TelemetryMeta): + def __init__( + self, + dataframe, + window: Any, + min_periods: Optional[int] = None, + center: bool = False, + win_type: Optional[str] = None, + on: Optional[str] = None, + axis: Union[int, str] = 0, + closed: Optional[str] = None, + step: Optional[int] = None, + method: str = "single", + ) -> None: + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + # Raise ValueError when invalid parameter values/combinations + if (isinstance(window, int) and window <= 0) or window is None: + raise ValueError("window must be an integer 0 or greater") + if not isinstance(center, bool): + raise ValueError("center must be a boolean") + if min_periods is not None and not isinstance(min_periods, int): + raise ValueError("min_periods must be an integer") + if isinstance(min_periods, int) and min_periods < 0: + raise ValueError("min_periods must be >= 0") + if ( + 
isinstance(min_periods, int) + and isinstance(window, int) + and min_periods > window + ): + raise ValueError(f"min_periods {min_periods} must be <= window {window}") + + self._dataframe = dataframe + self._query_compiler = dataframe._query_compiler + self.rolling_kwargs = { + "window": window, + "min_periods": min_periods, + "center": center, + "win_type": win_type, + "on": on, + "axis": axis, + "closed": closed, + "step": step, + "method": method, + } + self.axis = axis + + def _call_qc_method(self, method_name, *args, **kwargs): + """ + Call a query compiler method for the specified rolling aggregation. + + Parameters + ---------- + method_name : str + Name of the aggregation. + *args : tuple + Positional arguments to pass to the query compiler method. + **kwargs : dict + Keyword arguments to pass to the query compiler method. + + Returns + ------- + BaseQueryCompiler + QueryCompiler holding the result of the aggregation. + """ + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + qc_method = getattr(self._query_compiler, f"rolling_{method_name}") + return qc_method(self.axis, self.rolling_kwargs, *args, **kwargs) + + def _aggregate(self, method_name, *args, **kwargs): + """ + Run the specified rolling aggregation. + + Parameters + ---------- + method_name : str + Name of the aggregation. + *args : tuple + Positional arguments to pass to the aggregation. + **kwargs : dict + Keyword arguments to pass to the aggregation. + + Returns + ------- + DataFrame or Series + Result of the aggregation. + """ + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + qc_result = self._call_qc_method(method_name, *args, **kwargs) + return self._dataframe.__constructor__(query_compiler=qc_result) + + def count(self, numeric_only: bool = False): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate(method_name="count", numeric_only=numeric_only) + + def sem( + self, + ddof: int = 1, + numeric_only: bool = False, + *args: Any, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="sem", ddof=ddof, numeric_only=numeric_only, *args, **kwargs + ) + + def sum( + self, + numeric_only: bool = False, + *args: Any, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="sum", + numeric_only=numeric_only, + engine=engine, + engine_kwargs=engine_kwargs, + *args, + **kwargs, + ) + + def mean( + self, + numeric_only: bool = False, + *args: Any, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="mean", + numeric_only=numeric_only, + engine=engine, + engine_kwargs=engine_kwargs, + *args, + **kwargs, + ) + + def median( + self, + numeric_only: bool = False, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="median", + numeric_only=numeric_only, + engine=engine, + engine_kwargs=engine_kwargs, + **kwargs, + ) + + def var( + self, + ddof: int = 1, + numeric_only: bool = False, + *args: Any, + engine: Optional[Literal["cython", "numba"]] = 
None, + engine_kwargs: Optional[dict[str, bool]] = None, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="var", + ddof=ddof, + numeric_only=numeric_only, + engine=engine, + engine_kwargs=engine_kwargs, + *args, + **kwargs, + ) + + def std( + self, + ddof: int = 1, + numeric_only: bool = False, + *args: Any, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="std", + ddof=ddof, + numeric_only=numeric_only, + engine=engine, + engine_kwargs=engine_kwargs, + *args, + **kwargs, + ) + + def min( + self, + numeric_only: bool = False, + *args: Any, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="min", + numeric_only=numeric_only, + engine=engine, + engine_kwargs=engine_kwargs, + *args, + **kwargs, + ) + + def max( + self, + numeric_only: bool = False, + *args: Any, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="max", + numeric_only=numeric_only, + engine=engine, + engine_kwargs=engine_kwargs, + *args, + **kwargs, + ) + + def corr( + self, + other: Optional[SnowparkDataFrame] = None, + pairwise: Optional[bool] = None, + ddof: int = 1, + numeric_only: bool = False, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="corr", + other=other, + pairwise=pairwise, + ddof=ddof, + numeric_only=numeric_only, + **kwargs, + ) + + def cov( + self, + other: Optional[SnowparkDataFrame] = None, + pairwise: Optional[bool] = None, + ddof: int = 1, + numeric_only: bool = False, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="cov", + other=other, + pairwise=pairwise, + ddof=ddof, + numeric_only=numeric_only, + **kwargs, + ) + + def skew( + self, + numeric_only: bool = False, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate(method_name="skew", numeric_only=numeric_only, **kwargs) + + def kurt( + self, + numeric_only: bool = False, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate(method_name="kurt", numeric_only=numeric_only, **kwargs) + + def apply( + self, + func: Any, + raw: bool = False, + engine: Optional[Literal["cython", "numba"]] = None, + engine_kwargs: Optional[dict[str, bool]] = None, + args: Optional[tuple] = None, + kwargs: Optional[dict] = None, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="apply", + func=func, + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, + args=args, + kwargs=kwargs, + ) + + def aggregate( + self, + func: Union[str, list, dict], + *args: Any, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate(method_name="aggregate", func=func, *args, **kwargs) + + agg = aggregate + + def quantile( + self, + quantile: float, + 
interpolation: str = "linear", + numeric_only: bool = False, + **kwargs: Any, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="quantile", + quantile=quantile, + interpolation=interpolation, + numeric_only=numeric_only, + **kwargs, + ) + + def rank( + self, + method: str = "average", + ascending: bool = True, + pct: bool = False, + numeric_only: bool = False, + **kwargs, + ): + # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling + return self._aggregate( + method_name="rank", + method=method, + ascending=ascending, + pct=pct, + numeric_only=numeric_only, + **kwargs, + ) + + +# TODO: SNOW-1063366: Modin upgrade - modin.pandas.window.Expanding diff --git a/src/snowflake/snowpark/modin/plugin/PANDAS_CHANGELOG.md b/src/snowflake/snowpark/modin/plugin/PANDAS_CHANGELOG.md new file mode 100644 index 00000000000..9d1e7b6b698 --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/PANDAS_CHANGELOG.md @@ -0,0 +1,312 @@ +## 1.15.0a1 (tbd) + +### Bug Fixes +- Fixed overriding of subclasses' property docstrings for modin issue https://github.com/modin-project/modin/issues/7113. + +## 1.14.0a2 (2024-04-18) + +### Behavior Changes +- The `to_pandas` operation converts all integers to int64, instead of int8, int16 etc. To get an exact type, the user needs to explicitly specify precision values for their Snowflake column. This is a general behavior change across all of Snowpark. +- The following API changes are made to align Snowpark pandas with the pandas 2.2.1 API: + - Updated DateOffset strings to pandas 2.2.1 versions. + - As part of this transition, we have a set of transitional API and test bugs: + - SNOW-1320623, SNOW-1321196 - pandas `df.loc` and `__setitem__` have buggy behavior when: + - the column key has duplicates in a specific manner (https://github.com/pandas-dev/pandas/issues/58317), or + - a new row and column are used in the row and column keys (https://github.com/pandas-dev/pandas/issues/58316). + Snowpark pandas deviates from this behavior and will maintain the same behavior as pandas from versions 2.1.x. + - SNOW-1320660 - `qcut` / `cut` with bin preparation is temporarily NotImplemented due to upstream changes. + - SNOW-1321662 - `merge` fails when join is outer and sort is False. + - SNOW-1321682 - `df.melt` w/ duplicated cols. + - SNOW-1318223 - `series.py::_flex_method` list-like other (`pd.Index`) may not be supported in pandas now. + - SNOW-1321719 - `test_bitwise_operators.py` xfails. +- Changed the dtype of the index of empty `DataFrame` and `Series` to be `int64` rather than `object` to match the behavior of pandas. +- Changed the import path of Snowpark pandas package to use Modin 0.28.1 instead. The new recommended import statement is `import modin.pandas as pd; import snowflake.snowpark.modin.plugin`. + +### New Features +- Added support for `axis` argument for `df.where` and `df.mask` when `other` is a Series. +- Added back `_repr_html_` to DataFrame class for pretty printing (partially reverts commit 576ba26586caca3fa063da1fed465c61091b6d9c). +- Added support for `DataFrameGroupBy.nunique`. + +## 1.14.0a1 (2024-04-11) + +### Behavior Changes +- The following API changes are made to align Snowpark pandas with the pandas 2.1.4 API: + - Removed `errors` and `try_cast` arguments from `DataFrame`/`Series.where` and `mask`. + - Added the `dtype_backend` argument to `DataFrame`/`Series.convert_dtypes`; this argument is ignored by Snowpark pandas and only exists for compatibility. 
+ - Removed `is_copy` from `DataFrame`/`Series.take`. + - Removed `squeeze` argument from `DataFrame`/`Series.groupby`. Changed the default value of `group_keys` to `True`, and `observed` to `no_default`. + - Limited the length of generated labels and identifiers to 32 characters + - Removed the `squeeze`, `prefix`, `mangle_dupe_cols`, `error_bad_lines`, and `warn_bad_lines` arguments from `pd.read_csv`. These were previously unsupported by Snowpark pandas, and existed only for compatibility. + - Renamed the `skip_initial_space` argument in `pd.read_csv` to `skipinitialspace`; it remains unsupported and will raise an error if specified. + - Added the `date_format` and `dtype_backend` arguments in `pd.read_csv`. These are currently unsupported and added only for compatibility. `dtype_backend` is ignored and will raise a warning if provided, and `date_format` will raise an error. + - Added the `dtype_backend`, `filesystem`, and `filters` arguments in `pd.read_parquet`. These are currently unsupported and added only for compatibility. `dtype_backend` is ignored and will raise a warning if provided, and `filesystem` and `filters` will raise an error. + - Removed the `numpy` argument from `pd.read_json`. This was previously unsupported, and existed only for compatibility. + - Added the `dtype_backend` and `engine` arguments to `pd.read_json`. These are currently unsupported and added only for compatibility; they are ignored and will raise a warning if provided. + +- The following methods are removed: + - `DataFrame`/`Series.append` + - `Series.is_monotonic` + +### New Features +- Added support for `pd.cut` with `retbins=False` and `labels=False`. +- Added support for `Series.str.strip`. +- Added support for `Series.str.len`. +- Added support for `Series.str.capitalize`. +- Added support for `DataFrame.apply` and `Series.apply` to work with `@udf` decorated functions to allow working with package dependencies. +- Added support for `DataFrameGroupBy.transform`. +- Added support for `DataFrame.idxmax`, `DataFrame.idxmin`, `Series.idxmax`, and `Series.idxmin`. +- Added support for `Series.str.replace`. +- Added support for `Series.str.split`. +- Added support for `Series.str.title` and `Series.str.istitle`. +- Added support for `np.where`, `np.logical_*`, and `np.add` operators via `__array_ufunc__` and `__array_function__`. +- Added support for `DataFrameGroupby.head` and `DataFrameGroupBy.tail`. +- Added support for `DataFrameGroupBy.idxmax` and `DataFrameGroupBy.idxmin` for `GroupBy` `axis = 0`. +- Updated to `snowpark-python` v1.14.0. +- Updated to `pandas` 2.2.1 from 2.1.4. +- Added support for `axis` argument for `df.where` and `df.mask` when `other` is a Series. + +### Bug Fixes +- Fixed broadcast when masking a DataFrame with a Series using `df.where` or `df.mask`. +- Error out when scalar is passed for condition to DataFrame/Series `where` or `mask`. +- Fixed property docstring generation for some classes that use the telemetry metaclass. +- Fixed an issue where creating a Snowpark pandas DataFrame from a Series with a tuple `name`, such as `pd.DataFrame(pd.Series(name=("A", 1)))`, did not create `MultiIndex` columns on the resulting frame. +- Added custom docstrings inplace to avoid module reload errors. +- Added a separate docstring class for BasePandasDataset. +- Fixed docstring overrides for subclasses. + +## 1.13.0a1 (2024-03-15) +### Dependency Updates +- Upgraded `pandas` from 1.5.3 to 2.1.4. 
+ +### Behavior Changes +- Removed support for Python 3.8, as pandas 2.1.4 no longer supports this version. Please upgrade your environment to use Python 3.9 or newer. +- The following API changes are made as a result of moving from pandas 1.5.3 to 2.1.4: + - Removed sized index types like `pd.Int64Index` and `pd.Float64Index`. Index objects are now explicitly constructed with a dtype parameter. + - Changed the default dtype of an empty Series from `float64` to `object`. + - Changed the default value of `numeric_only` to `False` for many operations. Previously, setting `numeric_only=None` would automatically drop non-numeric columns from a frame if possible; this behavior has been removed, and setting `numeric_only=None` gives the same behavior as `numeric_only=False`. + - Removed the `level` parameter from aggregation functions (`sum`, `prod`, `count`, `any`, etc.). + - Removed `Series.append`. Use `pd.concat` instead. + - Removed the `inplace` parameter for `set_axis`. + - Removed the `datetime_is_numeric` parameter for `describe`. All datetime data is now treated as numeric. + - Removed the `loffset` and `base` parameters for `resample` and `Grouper`. Use `offset` and `origin` instead. + - Added a name to the output of `value_counts`. The returned series will be named `count` when `normalize=False`, and `proportion` when `normalize=True`. +- The following errors have changed as a result of moving from pandas 1.5.3 to 2.1.4: + - When attempting to call `DataFrame.aggregate` with a dict where a label has no associated functions (e.g. `df.aggregate({0: []})`), the error message has changed from "no result" to "No objects to concatenate." + - Calling aggregation methods with `numeric_only=True` on non-numeric `Series` objects raises a `TypeError`. + - Calling `DataFrame.aggregate` or `Series.aggregate` with a list of aggregations will not raise an error when invalid keyword arguments are passed. For example, `df.agg(["count"], invalid=0)` will not raise an error in Snowpark pandas even though "invalid" is not a valid argument to the `count` aggregation. + - Calling `GroupBy.shift` with a non-integer value for the `periods` argument now always raises a `TypeError`. Previously, floating point values that happened to be integers (like `2.0` or `-2.0`) were valid. + - Stopped automatically creating a Snowpark session when there is no active Snowpark session. Instead, Snowpark pandas requires a unique active Snowpark session. + +### New Features +- Added `"quantile"` as a valid aggregation in `DataFrame.agg` and `Series.agg`. +- Added support for binary operations between `DataFrame`/`Series` and `Series`/`DataFrame` along `axis=1`. +- Added support for binary operations between a `Series` or `DataFrame` object and a list-like object for `axis=1`. +- Added support for `DataFrame.round` and `Series.round`. +- Added support for `df.melt` and `pd.melt` +- Added support for binary operations between two `DataFrame` objects. +- Added support for `DataFrame.sort_index` and `Series.sort_index` along `axis=0`. +- Added support for `DataFrame.skew` and `Series.skew` along `axis=0` +- Added support for reading `SELECT` SQL Queries into a `DataFrame` object via `pd.read_snowflake` and changed `name` argument of `pd.read_snowflake` to `name_or_query`. +- Added support for `Series.str.startswith` and `Series.str.endswith`. +- Added support for reading SQL Queries with CTEs and CTEs with anonymous stored procedures into a `DataFrame` object via `pd.read_snowflake`. 
+- Added support for `DataFrame.first_valid_index`, `DataFrame.last_valid_index`, `Series.first_valid_index`, and `Series.last_valid_index`. +- Added support for `DataFrame.ffill`, `DataFrame.pad`, `Series.ffill`, and, `Series.pad`. +- Added support for reading `CALL SQL` Queries into a `DataFrame` object via `pd.read_snowflake`. +- Added support for `Series.str.lower` and `Series.str.upper`. +- Added support for `Series.str.isdigit`, `Series.str.islower`, and `Series.str.isupper`. +- Added partial support for `DataFrameGroupBy.apply` on `axis=0`, for `func` returning a `DataFrame`. +- Added partial support for `DataFrameGroupBy.apply` on `axis=0`, for `func` returning an object that is neither a DataFrame nor a Series. +- Added support for `Series.groupby.cumcount`, `Series.groupby.cummax`, `Series.groupby.cummin`, and `Series.groupby.cumsum`. +- Added support for `DataFrame.groupby.cumcount`, `DataFrame.groupby.cummax`, `DataFrame.groupby.cummin`, and `DataFrame.groupby.cumsum`. +- Added support for `pd.qcut` with `retbins=False`. +- Added support for `Series.str.contains` and `Series.str.count`. +- Added partial support for `DataFrameGroupBy.apply` on `axis=0`, for `func` always returning a `Series` with the same index and same name. +- Added support for `DataFrameGroupBy.rank` and `SeriesGroupBy.rank`. + +### Bug Fixes +- Allowed getting the Snowpark pandas session before creating a Snowpark pandas Dataframe or Series. +- Fixed an issue when using `pd.read_snowflake` together with `apply(..., axis=1)` where the row position column could not be disambiguated. +- Fixed the exception that you get when accessing a missing attribute of the Snowpark pandas module. +- Using dataframe or series apply(axis=1) when there are multiple sessions no longer raises an exception. +- Added docstring and doctests to correctly reflect difference between Snowpark pandas and native pandas functionality for `get` method. + +### Improvements +- Improved performance for `DataFrame.apply` and `Series.apply` for `axis=1` for functions passed without type hints by micro-batching rows. +- Restructure Snowpark pandas documentation + +## 1.12.1a1 (2024-02-20) + +### New Features +- Added support for `DataFrame.cummin`, `DataFrame.cummax`, `DataFrame.cumsum`, `Series.cummin`, `Series.cummax`, and `Series.cumsum`. +- Added support for `groups` and `indices` properties of `groupby` object. +- Added support for `DataFrame.add_prefix`, `DataFrame.add_suffix`, `Series.add_prefix`, and `Series.add_suffix`. +- Added support for `DataFrame.rolling` and `Series.rolling` on `axis=0` with integer `window`, `min_periods>=1`, and `center` for aggregations `min`, `max`, `sum`, `mean`, `var`, and `std`. +- Added support for `DataFrame.rank` and `Series.rank` with `pct=True`. +- Added support for `pd.date_range`. +- Added support for the `fill_value` parameter in binary operations. +- Added support for `Dataframe.duplicated` and `Series.duplicated`. +- Added support for `Dataframe.drop_duplicates` and `Series.drop_duplicates`. +- Added support for binary operations between `DataFrame` and `Series` (and vice-versa). +- Added support for binary operations between a `Series` or `DataFrame` object and a list-like object for `axis=0`. + +### Behavior Changes +- Deprecated support for Python 3.8. A future release will upgrade the `pandas` version to 2.1.4, which no longer supports Python 3.8. Users should upgrade Python to 3.9 or later. + +### Improvements +- Added cleanup logic at interpreter shutdown to close all active sessions. 
+- Improved performance for `DataFrame.apply` for `axis=1` by relying on Snowflake vectorized UDFs instead of vectorized UDTFs together with dynamic pivot. + +### Bug Fixes +- Fixed bug for `loc` when the index is unordered and the key is a slice with reversed order. +- Fixed bug for `pd.get_dummies` when input has been sorted, or just read from Snowflake. + +## 1.12.0a1 (2024-02-02) + +### Improvements +- Enabled telemetry for several private methods, e.g., `__getitem__` and `__setitem__`. +- Removed `to_numeric` length check. +- Added parameter type validation for aggregation, includes numeric_only, skipna and min_count. +- Changed `to_pandas` to return decimal numbers as `float64` instead of `object` based on Snowpark 1.12 release. + +### Bug Fixes +- Fixed bug where `loc` get on multiindex prefix matching. +- Removed the `modin.pandas.Session` reference to the Snowpark Session class. +- Removed unnecessary coalescing of join keys for left, right and inner join/merge. + +### New Features +- Added support for `DataFrame.diff` and `Series.diff`. +- Added support for `DataFrame.groupby.shift` and `Series.groupby.shift` +- Added support for `DataFrame.quantile` and `Series.quantile` +- Added support for `min`, `max`, `count`, and `sum` aggregations with `axis=1`. +- Added support for `DataFrame.resample` and `Series.resample` for aggregations: `median`, `sum`, `std`, `var`, `count`. +- Added support for binary operations with `pd.DateOffset` where offset is treated as a timedelta. +- Added support for `DataFrame.fillna` where `value` is a dataframe or `Series.fillna` where `value` is a series or dict. +- Added support for `DataFrame.isin`. +- Added support for `pd.get_dummies` for DataFrames and Series if params `dummy_na`, `drop_first` and `dtype` take default values. +- Added support for `groupby` with `sum`, `DataFrame.sum`, and `Series.sum` for string-typed data. +- Added support for `DataFrame.select_dtypes`. +- Added support for partial string indexing for `DatetimeIndex`. +- Added support for `DataFrame.iterrows` and `DataFrame.itertuples`. +- Added support for `DataFrame.sample` and `Series.sample`. +- Added support for `DataFrame.shift` and `Series.shift` with `axis=0,1`, `fill_value` and `periods`. +- Added support for `DataFrame.rank` and `Series.rank`. +- Added support for `DataFrame.describe` and `Series.describe`. +- Added support for `DataFrame.replace` and `Series.replace`. + +### Bug Fixes +- Fixed bug when `apply` has been called multiple times. +- Fixed bug where `loc` with slice key on a single row dataframe or series. +- Fixed bug where `series.reset_index` triggers eager evaluation. + +## 1.11.1a1 (2023-12-21) + +### Improvements +- Improved performance of `transpose` by removing unnecessary count queries. +- Raised NotImplementedError where setting cell with list like values. +- Reduced the number of queries for `iloc` get with scalar row key +- Improved performance of `insert` by removing count query. +- Improved performance of displaying Dataframe/Series in notebook. As part of this improvement we also removed support for html representation for DataFrames. +- Enabled SQL simplifier. +- Started warning users about all fallbacks to pandas via stored procedures. + +### Bug Fixes +- Fixed bug when `setitem`/`loc` on empty frame returns wrong result. +- Fixed bug where telemetry message can be duplicated. 
+
+## 1.10.0a1 (2023-12-13)
+
+### New Features
+- Added support for setting the Snowpark session for Snowpark pandas DataFrame/Series, via `snowflake.snowpark.modin.pandas.session`.
+- Added support for `ngroups` on `groupby` objects.
+- Added support for `Series.set_axis()` and `DataFrame.set_axis()`.
+- Added support for `Series.dt.month`, `Series.dt.year`, `Series.dt.day` and `Series.dt.quarter`.
+- Added support for `DataFrame.transform` with string and callable parameters.
+- Added support for `DataFrame.abs`, `Series.abs`, `DataFrame.__neg__` and `Series.__neg__`.
+- Added support for `df.resample` and `ser.resample`. Supported resample bins are: `T`, `S`, `H`, and `D`. Supported aggregations are: `max`, `min`, and `mean`.
+- Added support for `pd.read_parquet` using the Snowflake `COPY INTO` SQL command.
+- Added support for `pd.read_json` using the Snowflake `COPY INTO` SQL command.
+- Added support for `DataFrame.value_counts` and `Series.value_counts`.
+- Added support for `DataFrame.all`, `Series.all`, `DataFrame.any` and `Series.any` for integer data.
+- Added support for `Series.mask()` and `DataFrame.mask()`.
+- Added support for `ffill` on `df.resample`.
+- Added support for the `method` parameter of `DataFrame.fillna()` and `Series.fillna()`.
+
+### Improvements
+- Updated with changes from the snowpark-python 1.8.0 release.
+- Rewrote and improved `.iloc` get using a single query with lazy evaluation.
+- Improved warning messages from `.to_datetime`.
+- Improved `.to_datetime` to avoid unnecessary eager evaluation.
+- Improved performance for fallback execution, i.e., running unsupported pandas APIs using stored procedures.
+- Rewrote and improved `.loc` get using a single query with lazy evaluation.
+- Rewrote and improved `.loc` set using a single query with lazy evaluation.
+- Changed the implementation of `pd.read_csv` to use the Snowflake `COPY INTO` SQL command instead of locally executing pandas `read_csv`.
+- Improved performance of groupby by removing unnecessary count queries.
+- Raise NotImplementedError for `pivot_table` when no index is configured.
+- `memory_usage()` will not return an error, but it will return '0' for all values.
+- Rewrote and improved `__getitem__` using a single query with lazy evaluation.
+- Rewrote and improved `__setitem__` using a single query with lazy evaluation.
+- Improved performance of aggregate functions by reducing the query count.
+
+### Bug Fixes
+- Fixed a bug where binary operations between series with duplicate index values produced a wrong result.
+- Fixed a bug in `fillna` so that the fill value is not applied to index columns, and the result stays consistent with the Snowflake type system without explicit casting to variant.
+- Fixed a bug where non-homogeneous columns or indices were not converted correctly in `to_pandas`.
+
+### Changes
+- Error out when an unsupported aggregation function is used.
+
+## 1.7.0a4 (2023-10-10)
+- Improved warning messages from `.to_datetime`.
+
+### New Features
+- Added support for `DataFrame.to_dict` and `Series.to_dict`.
+- Added support for `DataFrame.take` and `Series.take`.
+- Added support for `pd.Series.isin`.
+
+### Improvements
+- Rewrote and improved `.iloc` get with a series key using a single join query with lazy evaluation.
+- Updated docstrings for the `DataFrame.sort_values` and `Series.sort_values` APIs.
+- Updated docstrings for the `DataFrame.reset_index` and `Series.reset_index` APIs.
+- Removed an unnecessary client-side check and fallback for aggregation.
+
+### Bug Fixes
+- Fixed a bug in how `.loc` and `.iloc` handle column indexers.
+
+## 1.7.0a3 (2023-10-04)
+
+### New Features
+- Added support for `Series.dt.date`, `Series.dt.hour`, `Series.dt.minute` and `Series.dt.second`.
+
+### Bug Fixes
+- Fixed a bug where `DataFrame.dropna` used the original row position as new row positions after rows were dropped.
+- Fixed a bug where `.loc` uses a string as the column key.
+- Fixed a bug where `.iloc` pulls the series key's index to the client.
+- Fixed a bug where `DataFrame.join` calls `to_pandas()` unexpectedly.
+- Fixed a bug where some unsupported APIs didn't raise NotImplementedError.
+- Fixed a bug where the binary operations `pow`, `rpow`, `__and__`, `__rand__`, `__or__`, `__ror__`, `__xor__`, and `__rxor__` call the frontend `default_to_pandas`.
+- Fixed a bug where creating a DataFrame from a shared database fails.
+
+## 1.7.0a2 (2023-09-20)
+
+### New Features
+- Added support for `pd.read_csv` by reading CSV files on the client and then uploading the data to Snowflake.
+- Added support for binary arithmetic and comparison operators between series.
+- Added support for `pd.unique`.
+
+### Improvements
+- Improved performance for `head`, `tail`, `_repr_html_`, `loc`, `iloc`, `__getitem__`, `__setitem__`, and `__repr__`.
+- Improved API documentation for Snowpark pandas IO methods.
+- Improved error messages when using the Snowpark pandas API with multiple Snowpark sessions.
+- Improved type conversion performance (from string to datetime).
+
+### Bug Fixes
+- Fixed a bug where an extra temp table was incorrectly created while using `pd.read_snowflake` to read a regular Snowflake table.
+- Fixed a bug where `df.pivot_table` failed when the original dataframe was created from large local data.
+- Fixed a bug when creating a Snowpark pandas DataFrame/Series from local numpy data that is not JSON-serializable.
+- Fixed a bug where `df.apply`, `series.apply` and `df.applymap` incorrectly converted SQL nulls to JSON nulls in Snowflake Variant data.
+- Fixed a bug where aggregation functions with `groupby` did not work on decimal columns.
+- Fixed a bug where the output of `_repr_html_` and `__repr__` did not match pandas behavior.
+
+## 1.7.0a1 (2023-09-15)
+
+Start of Private Preview
diff --git a/src/snowflake/snowpark/modin/plugin/README.md b/src/snowflake/snowpark/modin/plugin/README.md
new file mode 100644
index 00000000000..88714962314
--- /dev/null
+++ b/src/snowflake/snowpark/modin/plugin/README.md
@@ -0,0 +1,108 @@
+## Developer setup for SnowPandas
+This guide is based on https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md, yet modified so it's possible to develop on both SnowPandas and Snowpark-Python in parallel.
+This is necessary due to version conflicts (i.e., the pandas/Arrow versions used).
+
+First, create a new environment for SnowPandas:
+
+```bash
+conda create --name snowpandas-dev python=3.9
+```
+
+Activate the environment via
+```bash
+conda activate snowpandas-dev
+```
+
+Then install all dependencies (from the Snowpark repo root):
+```bash
+python -m pip install -e ".[development, modin-development]"
+pip3 install psutil
+
+# for demo
+pip install jupyter
+pip install matplotlib seaborn
+```
+
+## Folder structure
+The following tree diagram shows the high-level structure of the SnowPandas module within Snowpark.
+```bash
+snowflake
+└── snowpark
+    └── modin
+        └── pandas
+            ├── frontend            ← pandas API frontend layer
+            │   └── dispatching     ← additional patching for I/O
+            ├── translation         ← abstraction layer from the frontend to
+            │                         DF-algebra and the query compiler
+            │   ├── _internal       ← Snowflake-specific internals
+            │   ├── default2pandas  ← SnowPandas fallback to the default pandas implementation
+            │   └── compiler        ← query compiler, Modin -> Snowpark DF/SnowSQL compilation logic
+            └── utils               ← util classes from Modin, logging, …
+```
+
+## Doctests for Modin
+Modin uses the decorator `_inherit_docstrings` to equip functions with the original pandas docstrings, which may contain tests. By simply adding an import statement, these doctests can be run through the Modin shim.
+Yet, not all tests currently pass, either due to missing pandas functionality within Modin or formatting differences between the expected and received output.
+
+For this reason, all Modin doctests have been deactivated within `src/conftest.py`. To activate them, simply comment out the `pytest_ignore_collect` function. Modin doctests can be run from the repo root dir via
+```bash
+pytest -rP src/snowflake/snowpark/modin/pandas --log-cli-level=INFO
+```
+
+## Configuration file for connecting to Snowflake
+The Snowflake Python Connector and the Snowpark Python API now support creating a connection/session from a configuration file.
+The Snowpark pandas API also offers the convenience of implicit session creation from a configuration file.
+This eliminates the need to explicitly create a Snowpark session in your code, allowing you to write your pandas code just as you would normally.
+To achieve this, you'll need to create a configuration file located at `~/.snowflake/connections.toml`.
+The contents of this configuration file should be as follows (following the [TOML](https://toml.io/en/) file format):
+
+```toml
+default_connection_name = "default"
+
+[default]
+account = ""
+user = ""
+password = ""
+role = ""
+database = ""
+schema = ""
+warehouse = ""
+```
+
+The value of `default_connection_name` points to a configuration inside the TOML file, which will be used as the default configuration.
+Note that the keys of a configuration (`account`, `user`, etc.) are the same as the keys of the connection parameters we use in `tests/parameters.py`, and configuration values should be double quoted.
+A short usage sketch showing the implicit session in action is included at the end of this README.
+
+## Git setup
+To set up a development version of the Snowpark pandas API, run the following git commands:
+```
+git clone git@github.com:snowflakedb/snowpandas.git
+cd snowpandas
+git remote add upstream git@github.com:snowflakedb/snowpark-python.git
+git remote set-url --push upstream DISABLE
+
+# This should be the output when invoking the following command:
+# origin    git@github.com:snowflakedb/snowpandas.git (fetch)
+# origin    git@github.com:snowflakedb/snowpandas.git (push)
+# upstream  git@github.com:snowflakedb/snowpark-python.git (fetch)
+# upstream  DISABLE (push)
+git remote -v
+```
+### Branch
+- `pandas-main` is the local main branch which will have all changes for the Snowpark pandas API.
+
+### Incorporate changes from the `upstream` Snowpark Python repo
+Assuming you are on the `pandas-main` branch:
+```
+git checkout -b
+git fetch upstream
+git merge upstream/main
+git push
+```
+Submit a PR to merge your branch into the `pandas-main` branch. This should be done regularly, or whenever there are important changes from Snowpark.
+
+### Before PuPr
+When releasing the Snowpark pandas API, merging this branch into main via a PR should allow for a clean history.
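+
+## Example: implicit session from `connections.toml`
+A minimal sketch of the implicit session creation described in the configuration section above. It assumes the `connections.toml` file has been filled in, that `MY_TABLE` is a placeholder for an existing table in the configured database and schema, and that the import path follows the current recommendation (it may differ for older versions):
+
+```python
+# Importing the plugin registers the Snowflake execution backend with Modin.
+import modin.pandas as pd
+import snowflake.snowpark.modin.plugin  # noqa: F401
+
+# No explicit Session.builder call is needed: a session is created lazily from
+# ~/.snowflake/connections.toml the first time a query has to run.
+df = pd.read_snowflake("MY_TABLE")  # hypothetical table name
+print(df.head())
+```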
diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py new file mode 100644 index 00000000000..076200e9543 --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/__init__.py @@ -0,0 +1,30 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +from packaging import version + +# We need this import here to prevent circular dependency issues, since snowflake.snowpark.modin.pandas +# currently imports some internal utilities from snowflake.snowpark.modin.plugin. Test cases will +# import snowflake.snowpark.modin.plugin before snowflake.snowpark.modin.pandas, so in order to prevent +# circular dependencies from manifesting, apparently snowflake.snowpark.modin.pandas needs to +# be imported first. +from snowflake.snowpark.modin import pandas # noqa: F401 +from snowflake.snowpark.modin.config import DocModule +from snowflake.snowpark.modin.plugin import docstrings + +DocModule.put(docstrings.__name__) + +install_msg = "Run `pip install snowflake-snowpark-python[modin]` to resolve." +try: + import modin +except ModuleNotFoundError: # pragma: no cover + raise ModuleNotFoundError( + "Modin is not installed. " + install_msg + ) # pragma: no cover + +supported_modin_version = "0.28.1" +if version.parse(modin.__version__) != version.parse(supported_modin_version): + raise ImportError( + "Installed Modin version is not supported. " + install_msg + ) # pragma: no cover diff --git a/src/snowflake/snowpark/modin/plugin/_internal/__init__.py b/src/snowflake/snowpark/modin/plugin/_internal/__init__.py new file mode 100644 index 00000000000..0fbef920926 --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/_internal/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py new file mode 100644 index 00000000000..eb54d15b794 --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -0,0 +1,1073 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# +# +# This file contains utils functions used by aggregation functions. 
+# +import functools +from collections import defaultdict +from collections.abc import Hashable, Iterable +from functools import partial +from typing import Any, Callable, Literal, NamedTuple, Optional, Union + +import numpy as np +from pandas._typing import AggFuncType, AggFuncTypeBase +from pandas.core.dtypes.common import ( + is_dict_like, + is_list_like, + is_named_tuple, + is_numeric_dtype, + is_scalar, +) + +from snowflake.snowpark._internal.type_utils import ColumnOrName +from snowflake.snowpark.column import CaseExpr, Column as SnowparkColumn +from snowflake.snowpark.functions import ( + Column, + array_agg, + array_construct, + array_construct_compact, + array_contains, + array_flatten, + array_max, + array_min, + array_position, + builtin, + cast, + coalesce, + col, + count, + count_distinct, + get, + greatest, + iff, + is_null, + least, + listagg, + lit, + max as max_, + mean, + median, + min as min_, + parse_json, + skew, + stddev, + stddev_pop, + sum as sum_, + var_pop, + variance, + when, +) +from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame +from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import ( + OrderedDataFrame, + OrderingColumn, +) +from snowflake.snowpark.modin.plugin._internal.utils import ( + from_pandas_label, + pandas_lit, + to_pandas_label, +) +from snowflake.snowpark.modin.plugin._typing import PandasLabelToSnowflakeIdentifierPair +from snowflake.snowpark.types import ( + BooleanType, + DataType, + DoubleType, + IntegerType, + StringType, +) + +AGG_NAME_COL_LABEL = "AGG_FUNC_NAME" + + +def array_agg_keepna( + column_to_aggregate: ColumnOrName, ordering_columns: Iterable[OrderingColumn] +) -> Column: + """ + Aggregate a column, including nulls, into an array by the given ordering columns. + """ + # array_agg drops nulls, but we can use the solution [1] to work around + # that by turning each element `v` into the array `[v]`... + # except that we can't use array_construct(NULL) and instead have to use + # parse_json(lit("null")) per [2]. + # [1] https://stackoverflow.com/a/77422662 + # [2] https://github.com/snowflakedb/snowflake-connector-python/issues/1388#issuecomment-1371091831 + return array_flatten( + array_agg( + array_construct( + iff( + is_null(column_to_aggregate), + parse_json(lit("null")), + Column(column_to_aggregate), + ) + ) + ).within_group( + [ordering_column.snowpark_column for ordering_column in ordering_columns] + ) + ) + + +def column_quantile( + column: SnowparkColumn, + interpolation: Literal["linear", "lower", "higher", "midpoint", "nearest"], + q: float, +) -> SnowparkColumn: + assert interpolation in ( + "linear", + "nearest", + ), f"unsupported interpolation method '{interpolation}'" + # PERCENTILE_CONT interpolates between the nearest values if needed, while + # PERCENTILE_DISC finds the nearest value + agg_method = "percentile_cont" if interpolation == "linear" else "percentile_disc" + # PERCENTILE_* returns DECIMAL; we cast to DOUBLE + # example sql: SELECT CAST(PERCENTILE_COUNT(0.25) WITHIN GROUP(ORDER BY a) AS DOUBLE) AS a FROM table + return builtin(agg_method)(pandas_lit(q)).within_group(column).cast(DoubleType()) + + +def _columns_coalescing_idxmax_idxmin_helper( + *cols: SnowparkColumn, + axis: Literal[0, 1], + func: Literal["idxmax", "idxmin"], + keepna: bool, + pandas_column_labels: list, + is_groupby: bool = False, +) -> SnowparkColumn: + """ + Computes the index corresponding to the func for each row if axis=1 or column if axis=0. 
+ If all values in a row/column are NaN, then the result will be NaN. + + Parameters + ---------- + *cols: SnowparkColumn + A tuple of Snowpark Columns. + axis: {0, 1} + The axis to apply the func on. + func: {"idxmax", "idxmin"} + The function to apply. + keepna: bool + Whether to skip NaN Values. + pandas_column_labels: list + pandas index/column names. + + Returns + ------- + Callable + """ + if axis == 0: + extremum = max_(*cols) if func == "idxmax" else min_(*cols) + + # TODO SNOW-1316602: Support MultiIndex for DataFrame, Series, and DataFrameGroupBy cases. + if len(pandas_column_labels) > 1: + # The index is a MultiIndex, current logic does not support this. + raise NotImplementedError( + f"{func} is not yet supported when the index is a MultiIndex." + ) + + # TODO SNOW-1270521: max_by and min_by are not guaranteed to break tiebreaks deterministically + extremum_position = ( + get( + builtin("max_by")( + Column(pandas_column_labels[0]), + Column(*cols), + 1, + ), + 0, + ) + if func == "idxmax" + else get( + builtin("min_by")( + Column(pandas_column_labels[0]), + Column(*cols), + 1, + ), + 0, + ) + ) + + if is_groupby and keepna: + # When performing groupby, if a group has any NaN values in its column, the idxmax/idxmin of that column + # will always be NaN. Therefore, we need to check whether there are any NaN values in each group. + return iff( + builtin("count_if")(Column(*cols).is_null()) > 0, + pandas_lit(None), + extremum_position, + ) + else: + # if extremum is null, i.e. there are no columns or all columns are + # null, mark extremum_position as null, because our final expression has + # to evaluate to null. + return builtin("nvl2")(extremum, extremum_position, lit(None)) + + else: + column_array = array_construct(*cols) + # extremum is null if there are no columns or all columns are null. + # otherwise, extremum contains the extremal column, i.e. the max column for + # idxmax and the min column for idxmin. + extremum = (array_max if func == "idxmax" else array_min)(column_array) + # extremum_position is the position of the first column with a value equal + # to extremum. + extremum_position = array_position(extremum, column_array) + + if keepna: + # if any of the columns is null, mark extremum_position as null, + # because our final expression has to evaluate to null. That's how we + # "keep NA." + extremum_position = iff( + array_contains(lit(None), column_array), lit(None), extremum_position + ) + else: + # if extremum is null, i.e. there are no columns or all columns are + # null, mark extremum_position as null, because our final expression has + # to evalute to null. + extremum_position = builtin("nvl2")(extremum, extremum_position, lit(None)) + + # If extremum_position is null, return null. + return builtin("nvl2")( + extremum_position, + # otherwise, we create an array of all the column names using pandas_column_labels + # and get the element of that array that is at extremum_position. + get( + array_construct(*(lit(c) for c in pandas_column_labels)), + cast(extremum_position, "int"), + ), + lit(None), + ) + + +# Map between the pandas input aggregation function (str or numpy function) and +# the corresponding snowflake builtin aggregation function for axis=0. 
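+# For example, both the string "median" and np.median map to the Snowpark `median`
+# function, so an aggregation such as df.agg("median") can be compiled to Snowflake's
+# MEDIAN aggregate instead of falling back to native pandas.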
+SNOWFLAKE_BUILTIN_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { + "count": count, + "mean": mean, + "min": min_, + "max": max_, + "idxmax": functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmax" + ), + "idxmin": functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmin" + ), + "sum": sum_, + "median": median, + "skew": skew, + "std": stddev, + "var": variance, + "booland_agg": builtin("booland_agg"), + "boolor_agg": builtin("boolor_agg"), + np.max: max_, + np.min: min_, + np.sum: sum_, + np.mean: mean, + np.median: median, + np.std: stddev, + np.var: variance, + "array_agg": array_agg, + "quantile": column_quantile, + "nunique": count_distinct, +} + + +class AggFuncInfo(NamedTuple): + """ + Information needed to distinguish between dummy and normal aggregate functions. + """ + + # The aggregate function + func: AggFuncTypeBase + + # If true, the aggregate function is applied to "NULL" rather than a column + is_dummy_agg: bool + + +def _columns_coalescing_min(*cols: SnowparkColumn) -> Callable: + """ + Computes the minimum value in each row, skipping NaN values. If all values in a row are NaN, + then the result will be NaN. + + Example SQL: + SELECT ARRAY_MIN(ARRAY_CONSTRUCT_COMPACT(a, b, c)) AS min + FROM VALUES (10, 1, NULL), (NULL, NULL, NULL) AS t (a, b, c); + + Result: + -------- + | min | + -------- + | 1 | + -------- + | NULL | + -------- + """ + return array_min(array_construct_compact(*cols)) + + +def _columns_coalescing_max(*cols: SnowparkColumn) -> Callable: + """ + Computes the maximum value in each row, skipping NaN values. If all values in a row are NaN, + then the result will be NaN. + + Example SQL: + SELECT ARRAY_MAX(ARRAY_CONSTRUCT_COMPACT(a, b, c)) AS max + FROM VALUES (10, 1, NULL), (NULL, NULL, NULL) AS t (a, b, c); + + Result: + -------- + | max | + -------- + | 10 | + -------- + | NULL | + -------- + """ + return array_max(array_construct_compact(*cols)) + + +def _columns_count(*cols: SnowparkColumn) -> Callable: + """ + Counts the number of non-NULL values in each row. + + Example SQL: + SELECT NVL2(a, 1, 0) + NVL2(b, 1, 0) + NVL2(c, 1, 0) AS count + FROM VALUES (10, 1, NULL), (NULL, NULL, NULL) AS t (a, b, c); + + Result: + --------- + | count | + --------- + | 2 | + --------- + | 0 | + --------- + """ + # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark + # sum_, since Snowpark sum_ gets the sum of all rows within a single column. + # NVL2(col, x, y) returns x if col is NULL, and y otherwise. + return sum(builtin("nvl2")(col, pandas_lit(1), pandas_lit(0)) for col in cols) + + +def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable: + """ + Sums all non-NaN elements in each row. If all elements are NaN, returns 0. + + Example SQL: + SELECT ZEROIFNULL(a) + ZEROIFNULL(b) + ZEROIFNULL(c) AS sum + FROM VALUES (10, 1, NULL), (NULL, NULL, NULL) AS t (a, b, c); + + Result: + ------- + | sum | + ------- + | 11 | + ------- + | 0 | + ------- + """ + # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark + # sum_, since Snowpark sum_ gets the sum of all rows within a single column. + return sum(builtin("zeroifnull")(col) for col in cols) + + +# Map between the pandas input aggregation function (str or numpy function) and +# the corresponding aggregation function for axis=1 when skipna=True. 
The returned aggregation +# function may either be a builtin aggregation function, or a function taking in *arg columns +# that then calls the appropriate builtin aggregations. +SNOWFLAKE_COLUMNS_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { + "count": _columns_count, + "sum": _columns_coalescing_sum, + np.sum: _columns_coalescing_sum, + "min": _columns_coalescing_min, + "max": _columns_coalescing_max, + "idxmax": _columns_coalescing_idxmax_idxmin_helper, + "idxmin": _columns_coalescing_idxmax_idxmin_helper, + np.min: _columns_coalescing_min, + np.max: _columns_coalescing_max, +} + +# These functions are called instead if skipna=False +SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { + "min": least, + "max": greatest, + "idxmax": _columns_coalescing_idxmax_idxmin_helper, + "idxmin": _columns_coalescing_idxmax_idxmin_helper, + # IMPORTANT: count and sum use python builtin sum to invoke __add__ on each column rather than Snowpark + # sum_, since Snowpark sum_ gets the sum of all rows within a single column. + "sum": lambda *cols: sum(cols), + np.sum: lambda *cols: sum(cols), + np.min: least, + np.max: greatest, +} + + +class AggregateColumnOpParameters(NamedTuple): + """ + Parameters/Information needed to apply aggregation on a Snowpark column correctly. + """ + + # Snowflake quoted identifier for the column to apply aggregation on + snowflake_quoted_identifier: ColumnOrName + + # The Snowpark data type for the column to apply aggregation on + data_type: DataType + + # pandas label for the new column produced after aggregation + agg_pandas_label: Optional[Hashable] + + # Snowflake quoted identifier for the new Snowpark column produced after aggregation + agg_snowflake_quoted_identifier: str + + # the snowflake aggregation function to apply on the column + snowflake_agg_func: Callable + + # the columns specifying the order of rows in the column. This is only + # relevant for aggregations that depend on row order, e.g. summing a string + # column. + ordering_columns: Iterable[OrderingColumn] + + +def is_snowflake_agg_func(agg_func: AggFuncTypeBase) -> bool: + return agg_func in SNOWFLAKE_BUILTIN_AGG_FUNC_MAP + + +def get_snowflake_agg_func( + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int = 0 +) -> Optional[Callable]: + """ + Get the corresponding Snowflake/Snowpark aggregation function for the given aggregation function. + If no corresponding snowflake aggregation function can be found, return None. + """ + if axis == 0: + snowflake_agg_func = SNOWFLAKE_BUILTIN_AGG_FUNC_MAP.get(agg_func) + if snowflake_agg_func == stddev or snowflake_agg_func == variance: + # for aggregation function std and var, we only support ddof = 0 or ddof = 1. + # when ddof is 1, std is mapped to stddev, var is mapped to variance + # when ddof is 0, std is mapped to stddev_pop, var is mapped to var_pop + # TODO (SNOW-892532): support std/var for ddof that is not 0 or 1 + ddof = agg_kwargs.get("ddof", 1) + if ddof != 1 and ddof != 0: + return None + if ddof == 0: + return stddev_pop if snowflake_agg_func == stddev else var_pop + elif snowflake_agg_func == column_quantile: + interpolation = agg_kwargs.get("interpolation", "linear") + q = agg_kwargs.get("q", 0.5) + if interpolation not in ("linear", "nearest"): + return None + if not is_scalar(q): + # SNOW-1062878 Because list-like q would return multiple rows, calling quantile + # through the aggregate frontend in this manner is unsupported. 
+ return None + return lambda col: column_quantile(col, interpolation, q) + else: + snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) + + return snowflake_agg_func + + +def generate_rowwise_aggregation_function( + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any] +) -> Optional[Callable]: + """ + Get a callable taking *arg columns to apply for an aggregation. + + Unlike get_snowflake_agg_func, this function may return a wrapped composition of + Snowflake builtin functions depending on the values of the specified kwargs. + """ + snowflake_agg_func = SNOWFLAKE_COLUMNS_AGG_FUNC_MAP.get(agg_func) + if not agg_kwargs.get("skipna", True): + snowflake_agg_func = SNOWFLAKE_COLUMNS_KEEPNA_AGG_FUNC_MAP.get( + agg_func, snowflake_agg_func + ) + min_count = agg_kwargs.get("min_count", 0) + if min_count > 0: + # Create a case statement to check if the number of non-null values exceeds min_count + # when min_count > 0, if the number of not NULL values is < min_count, return NULL. + def agg_func_wrapper(fn: Callable) -> Callable: + return lambda *cols: when( + _columns_count(*cols) < min_count, pandas_lit(None) + ).otherwise(fn(*cols)) + + return snowflake_agg_func and agg_func_wrapper(snowflake_agg_func) + return snowflake_agg_func + + +def is_supported_snowflake_agg_func( + agg_func: AggFuncTypeBase, agg_kwargs: dict[str, Any], axis: int +) -> bool: + """ + check if the aggregation function is supported with snowflake. Current supported + aggregation functions are the functions that can be mapped to snowflake builtin function. + + Args: + agg_func: str or Callable. the aggregation function to check + agg_kwargs: keyword argument passed for the aggregation function, such as ddof, min_count etc. + The value can be different for different aggregation functions. + Returns: + is_valid: bool. Whether it is valid to implement with snowflake or not. + """ + return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None + + +def are_all_agg_funcs_supported_by_snowflake( + agg_funcs: list[AggFuncTypeBase], agg_kwargs: dict[str, Any], axis: int +) -> bool: + """ + Check if all aggregation functions in the given list are snowflake supported + aggregation functions. + + Returns: + True if all functions in the list are snowflake supported aggregation functions, otherwise, + return False. + """ + return all( + is_supported_snowflake_agg_func(func, agg_kwargs, axis) for func in agg_funcs + ) + + +def check_is_aggregation_supported_in_snowflake( + agg_func: AggFuncType, + agg_kwargs: dict[str, Any], + axis: int, +) -> bool: + """ + check if distributed implementation with snowflake is available for the aggregation + based on the input arguments. + + Args: + agg_func: the aggregation function to apply + agg_kwargs: keyword argument passed for the aggregation function, such as ddof, min_count etc. + The value can be different for different aggregation function. + Returns: + bool + Whether the aggregation operation can be executed with snowflake sql engine. + """ + # validate agg_func, only snowflake builtin agg function or dict of snowflake builtin agg + # function can be implemented in distributed way. 
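+    # For example, agg_func={"A": ["min", "max"], "B": "sum"} is supported because every
+    # entry maps to a Snowflake builtin aggregation, whereas a user-defined lambda or a
+    # string with no corresponding Snowflake builtin (e.g. "sem") is not, and this
+    # function returns False for it.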
+ if is_dict_like(agg_func): + return all( + ( + are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) + if is_list_like(value) + else is_supported_snowflake_agg_func(value, agg_kwargs, axis) + ) + for value in agg_func.values() + ) + elif is_list_like(agg_func): + return are_all_agg_funcs_supported_by_snowflake(agg_func, agg_kwargs, axis) + return is_supported_snowflake_agg_func(agg_func, agg_kwargs, axis) + + +def is_snowflake_numeric_type_required(snowflake_agg_func: Callable) -> bool: + """ + Is the given snowflake aggregation function needs to be applied on the numeric column. + """ + return snowflake_agg_func in [ + mean, + median, + skew, + sum_, + stddev, + stddev_pop, + variance, + var_pop, + column_quantile, + ] + + +def drop_non_numeric_data_columns( + query_compiler: "snowflake_query_compiler.SnowflakeQueryCompiler", # type: ignore[name-defined] # noqa: F821 + pandas_labels_for_columns_to_exclude: list[Hashable], +) -> "snowflake_query_compiler.SnowflakeQueryCompiler": # type: ignore[name-defined] # noqa: F821 + """ + Drop the data columns of the internal frame that are non-numeric if numeric_only is True. + + Args: + query_compiler: The query compiler for the internal frame to process on + pandas_labels_for_columns_to_exclude: List of pandas labels to exclude from dropping even if the + corresponding column is non-numeric. + Returns: + SnowflakeQueryCompiler that contains the processed new frame with non-numeric data columns dropped + """ + from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( + SnowflakeQueryCompiler, + ) + + original_frame = query_compiler._modin_frame + # get all data column to retain, a data column is retained if the pandas label for the column + data_column_to_retain: list[PandasLabelToSnowflakeIdentifierPair] = [ + PandasLabelToSnowflakeIdentifierPair( + original_frame.data_column_pandas_labels[i], + original_frame.data_column_snowflake_quoted_identifiers[i], + ) + for i, data_type in enumerate(query_compiler.dtypes.values) + if is_numeric_dtype(data_type) + or ( + original_frame.data_column_pandas_labels[i] + in pandas_labels_for_columns_to_exclude + ) + ] + + # get the original pandas labels and snowflake quoted identifiers for the numeric data columns + new_data_column_pandas_labels: list[Hashable] = [ + col.pandas_label for col in data_column_to_retain + ] + new_data_column_snowflake_quoted_identifiers: list[str] = [ + col.snowflake_quoted_identifier for col in data_column_to_retain + ] + + return SnowflakeQueryCompiler( + InternalFrame.create( + ordered_dataframe=original_frame.ordered_dataframe, + data_column_pandas_labels=new_data_column_pandas_labels, + data_column_snowflake_quoted_identifiers=new_data_column_snowflake_quoted_identifiers, + data_column_pandas_index_names=original_frame.data_column_pandas_index_names, + index_column_pandas_labels=original_frame.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=original_frame.index_column_snowflake_quoted_identifiers, + ) + ) + + +def generate_aggregation_column( + agg_column_op_params: AggregateColumnOpParameters, + agg_kwargs: dict[str, Any], + is_groupby_agg: bool, + index_column_snowflake_quoted_identifier: Optional[list[str]] = None, +) -> SnowparkColumn: + """ + Generate the aggregation column for the given column and aggregation function. + + Args: + agg_column_op_params: AggregateColumnOpParameters. 
The aggregation parameter for a Snowpark column, contains following: + - snowflake_quoted_identifier: the snowflake quoted identifier for the column to apply aggregation on + - data_type: the Snowpark datatype for the column to apply aggregation on + - agg_snowflake_quoted_identifier: The snowflake quoted identifier used for the result column after aggregation + - snowflake_agg_func: The Snowflake aggregation function to apply on the given column + - ordering_columns: the list of snowflake quoted identifiers corresponding to the ordering columns + agg_kwargs: keyword argument passed for the aggregation function, such as ddof, min_count etc. + is_groupby_agg: is the aggregation function applied after groupby or not. + index_column_snowflake_quoted_identifier: The Snowflake quoted identifier corresponding to the index column + + Returns: + SnowparkColumn after the aggregation function. The column is also aliased back to the original name + """ + snowpark_column = agg_column_op_params.snowflake_quoted_identifier + snowflake_agg_func = agg_column_op_params.snowflake_agg_func + if is_snowflake_numeric_type_required(snowflake_agg_func) and isinstance( + agg_column_op_params.data_type, BooleanType + ): + # if the column is a boolean column and the aggregation function requires numeric values, + # we cast the boolean column to integer (True mapped to 1, and False mapped to 0). This is + # to stay consistent with pandas behavior, where boolean type in pandas is treated as numeric type. + snowpark_column = cast( + agg_column_op_params.snowflake_quoted_identifier, IntegerType() + ) + + if snowflake_agg_func == sum_: + if isinstance(agg_column_op_params.data_type, StringType): + agg_snowpark_column = listagg(snowpark_column).within_group( + [ + ordering_column.snowpark_column + for ordering_column in agg_column_op_params.ordering_columns + ] + ) + else: + # There is a slightly different behavior for sum in terms of missing value in pandas and Snowflake, + # where sum on a column with all NaN in pandas result in 0, but sum on a column with all NULL result + # in NULL. Therefore, a post process on the result to replace the NULL result with 0 using coalesce. + agg_snowpark_column = coalesce( + snowflake_agg_func(snowpark_column), pandas_lit(0) + ) + elif snowflake_agg_func in ( + SNOWFLAKE_BUILTIN_AGG_FUNC_MAP["booland_agg"], + SNOWFLAKE_BUILTIN_AGG_FUNC_MAP["boolor_agg"], + ): + # Need to wrap column name in IDENTIFIER, or else bool agg function will treat the name as a string literal + agg_snowpark_column = snowflake_agg_func(builtin("identifier")(snowpark_column)) + elif snowflake_agg_func == array_agg: + # Array aggregation requires the ordering columns, which we have to + # pass in here. + # note that we always assume keepna for array_agg. TODO(SNOW-1040398): + # make keepna treatment consistent across array_agg and other + # aggregation methods. 
+ agg_snowpark_column = array_agg_keepna( + snowpark_column, ordering_columns=agg_column_op_params.ordering_columns + ) + elif ( + isinstance(snowflake_agg_func, partial) + and snowflake_agg_func.func == _columns_coalescing_idxmax_idxmin_helper + ): + agg_snowpark_column = _columns_coalescing_idxmax_idxmin_helper( + snowpark_column, + axis=0, + func=snowflake_agg_func.keywords["func"], + keepna=not agg_kwargs.get("skipna", True), + pandas_column_labels=index_column_snowflake_quoted_identifier, # type: ignore + is_groupby=is_groupby_agg, + ) + elif snowflake_agg_func == count_distinct: + if agg_kwargs.get("dropna", True) is False: + # count_distinct does only count distinct non-NULL values. + # Check if NULL is contained, then add +1 in this case. + if not isinstance(snowpark_column, SnowparkColumn): + snowpark_column = col(snowpark_column) + agg_snowpark_column = snowflake_agg_func(snowpark_column) + iff( + sum_(snowpark_column.is_null().cast(IntegerType())) > pandas_lit(0), + pandas_lit(1), + pandas_lit(0), + ) + else: + agg_snowpark_column = snowflake_agg_func(snowpark_column) + else: + agg_snowpark_column = snowflake_agg_func(snowpark_column) + + # Handle min_count and skipna parameters + min_count = -1 + skipna = True + is_groupby_min_max = is_groupby_agg and snowflake_agg_func in [min_, max_] + if snowflake_agg_func is sum_ or is_groupby_min_max: + # min_count parameter is only valid for groupby min/max/sum, dataframe sum and series sum + min_count = agg_kwargs.get("min_count", -1) + if not is_groupby_agg: + # skipna parameter is valid for all supported none-groupby aggregation function + skipna = agg_kwargs.get("skipna", True) + + if not skipna or min_count > 0: + case_expr: Optional[CaseExpr] = None + if not skipna: + # TODO(SNOW-1040398): Use a different aggregation function map for + # skipna=False, and set the skipna value at an earlier layer. + # when skipna is False, return NULL as far as there is NULL in the column. This is achieved by first + # converting the column to boolean with is_null, and call max on the boolean column. If NULL exists, + # the result of max will be True, otherwise, False. + # For example: [1, NULL, 2, 3] will be [False, True, False, False] with is_null, and max on the boolean + # result is True. + case_expr = when( + max_(is_null(agg_column_op_params.snowflake_quoted_identifier)), + pandas_lit(None), + ) + if min_count > 0: + # when min_count > 0, if the number of not NULL values is < min_count, return NULL. + min_count_cond = ( + count(agg_column_op_params.snowflake_quoted_identifier) < min_count + ) + case_expr = ( + case_expr.when(min_count_cond, pandas_lit(None)) + if (case_expr is not None) + else when(min_count_cond, pandas_lit(None)) + ) + + assert ( + case_expr is not None + ), f"No case expression is constructed with skipna({skipna}), min_count({min_count})" + agg_snowpark_column = case_expr.otherwise(agg_snowpark_column) + + # rename the column to agg_column_quoted_identifier + agg_snowpark_column = agg_snowpark_column.as_( + agg_column_op_params.agg_snowflake_quoted_identifier + ) + + return agg_snowpark_column + + +def aggregate_with_ordered_dataframe( + ordered_dataframe: OrderedDataFrame, + agg_col_ops: list[AggregateColumnOpParameters], + agg_kwargs: dict[str, Any], + groupby_columns: Optional[list[str]] = None, + index_column_snowflake_quoted_identifier: Optional[list[str]] = None, +) -> OrderedDataFrame: + """ + Perform aggregation on the snowpark dataframe based on the given column to aggregation function map. 
+ + Args: + ordered_dataframe: a OrderedDataFrame to perform aggregation on + agg_col_ops: mapping between the columns to apply aggregation on and the corresponding aggregation to apply + agg_kwargs: keyword argument passed for the aggregation function, such as ddof, min_count etc. + groupby_columns: If provided, groupby the dataframe with the given columns before apply aggregate. Otherwise, + no groupby will be performed. + index_column_snowflake_quoted_identifier: The Snowflake quoted identifier corresponding to the index column + + Returns: + OrderedDataFrame with all aggregated columns. + """ + + is_groupby_agg = groupby_columns is not None + agg_list: list[SnowparkColumn] = [ + generate_aggregation_column( + agg_column_op_params=agg_col_op, + agg_kwargs=agg_kwargs, + is_groupby_agg=is_groupby_agg, + index_column_snowflake_quoted_identifier=index_column_snowflake_quoted_identifier, + ) + for agg_col_op in agg_col_ops + ] + + if is_groupby_agg: + agg_ordered_dataframe = ordered_dataframe.group_by(groupby_columns, *agg_list) + else: + agg_ordered_dataframe = ordered_dataframe.agg(*agg_list) + return agg_ordered_dataframe + + +def convert_agg_func_arg_to_col_agg_func_map( + internal_frame: InternalFrame, + agg_func: AggFuncType, + pandas_labels_for_columns_to_exclude_when_agg_on_all: list[Hashable], +) -> dict[ + PandasLabelToSnowflakeIdentifierPair, Union[AggFuncTypeBase, list[AggFuncTypeBase]] +]: + """ + Convert the agg_func arguments to column to aggregation function maps, which is a map between + the Snowpark pandas column (represented as a PandasLabelToSnowflakeIdentifierPair) to the corresponding + aggregation functions needs to be applied on this column. Following rules are applied: + 1) If agg_func is a base aggregation (str or callable) or a list of base aggregation function, then all + aggregation functions are applied on each data column of the internal frame. + 2) If agg_func is already in a dict format (column label to aggregation functions map), only the columns + occur in the dictionary key is considered for aggregation. + + Args: + internal_frame: InternalFrame. The internal frame to apply aggregation on + agg_func: AggFuncType (str or callable, or a list of str or callable, or a dict between label and str or callable or list of str or callable) + The aggregations functions to apply on the internal frame. + pandas_labels_for_columns_to_exclude_when_agg_on_all: List[Hashable] + List of pandas labels for the columns to exclude from aggregation when the aggregation needs to be applied on + all data columns, which is the case when rule 1) described above is applied. + + Returns: + Dict[PandasLabelToSnowflakeIdentifierPair, Union[AggFuncTypeBase, List[AggFuncTypeBase]]] + Map between Snowpandas column and the aggregation functions needs to be applied on the column + """ + col_agg_func_map: dict[ + PandasLabelToSnowflakeIdentifierPair, + Union[AggFuncTypeBase, list[AggFuncTypeBase]], + ] = {} + + if is_dict_like(agg_func): + for label, fn in agg_func.items(): + # for each column configured in the map, look for the corresponding columns + col_quoted_identifiers = ( + internal_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + [label], + include_index=False, + ) + )[0] + + for quoted_identifier in col_quoted_identifiers: + col_agg_func_map[ + PandasLabelToSnowflakeIdentifierPair(label, quoted_identifier) + ] = fn + else: + # if the aggregation function is str or callable or a list of str or callable, apply the aggregations + # functions on each data column. 
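+        # For example, with agg_func=["min", "max"] and data columns "A" and "B", every
+        # column maps to ["min", "max"], e.g. the (hypothetical) entry
+        # PandasLabelToSnowflakeIdentifierPair("A", '"A"') -> ["min", "max"], unless the
+        # label appears in pandas_labels_for_columns_to_exclude_when_agg_on_all.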
+ for label, quoted_identifier in zip( + internal_frame.data_column_pandas_labels, + internal_frame.data_column_snowflake_quoted_identifiers, + ): + if label not in pandas_labels_for_columns_to_exclude_when_agg_on_all: + col_agg_func_map[ + PandasLabelToSnowflakeIdentifierPair(label, quoted_identifier) + ] = agg_func + + return col_agg_func_map + + +def get_agg_func_to_col_map( + col_to_agg_func_map: dict[ + PandasLabelToSnowflakeIdentifierPair, + Union[AggFuncTypeBase, list[AggFuncTypeBase]], + ] +) -> dict[AggFuncTypeBase, list[PandasLabelToSnowflakeIdentifierPair]]: + """ + Convert the column to aggregation function map to aggregation function to columns map, and keeps the order of + the occurrence in the original map. + + For example: + Given col_to_agg_func_map {(col1, "col1") : ["min", "max"], (col2, "col2"): ["max", "sum"]} + The aggregation func to columns map is {"min": [(col1, "col1")], "max": [(col1, "col1"), (col2, "col2")], "sum": [(col2, "col2")]} + """ + agg_func_to_col_map: dict[ + AggFuncTypeBase, list[PandasLabelToSnowflakeIdentifierPair] + ] = defaultdict(list) + for column_identifier, agg_funcs in col_to_agg_func_map.items(): + # iterate over each aggregation function + agg_funcs_list = agg_funcs if is_list_like(agg_funcs) else [agg_funcs] + for agg_func in agg_funcs_list: + agg_func_to_col_map[agg_func].append(column_identifier) + + return agg_func_to_col_map + + +def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str: + """ + Returns the friendly name for the aggr function. For example, if it is a callable, it will return __name__ + otherwise the same string name value. + """ + return ( + getattr(aggfunc, "__name__", str(aggfunc)) + if not isinstance(aggfunc, str) + else aggfunc + ) + + +def generate_pandas_labels_for_agg_result_columns( + pandas_label: Hashable, + num_levels: int, + agg_func_list: list[AggFuncTypeBase], + include_agg_func_in_agg_label: bool, + include_pandas_label_in_agg_label: bool, +) -> list[Hashable]: + """ + Generate the pandas labels for the result columns after apply agg_func to the pandas column with given + pandas label. One aggregation column will be produced for each aggregation function in the given list. If + include_agg_func_in_agg_label is true, the aggregation function name will be appended to the original pandas + label to produce the new pandas label, otherwise the original pandas label is used. + For example: Given pandas label 'A', and agg_func [min, max] + if include_agg_func_in_agg_label is False and include_pandas_label_in_agg_label is True, the result labels will be ['A', 'A'] + if include_agg_func_in_agg_label is True and include_pandas_label_in_agg_label is True, the result labels will be [('A', 'min'), ('A', 'max')] + if include_agg_func_in_agg_label is True and include_pandas_label_in_agg_label is False, the result label will be ('min', 'max') + + Note that include_agg_func_in_agg_label and include_pandas_label_in_agg_label can not be both False. 
+ + Args: + pandas_label: Hashable + The pandas label for the column to apply aggregation function on + num_levels: int + The number of levels for the pandas label + agg_func_list: List[AggFuncTypeBase] + List of aggregation functions to be applied on the pandas column + include_agg_func_in_agg_label: bool + Whether to include the aggregation function in the label for the aggregation result column + include_pandas_label_in_agg_label: bool, + Whether to include the original pandas label in the label for the aggregation result column + + Returns: + List[Hashable] + List of pandas labels for the result aggregation columns, the length is the same as agg_func_list. + """ + assert ( + include_pandas_label_in_agg_label or include_agg_func_in_agg_label + ), "the result aggregation label must at least contain at least the original label or the aggregation function name." + agg_func_column_labels = [] + for agg_func in agg_func_list: + label_tuple = ( + from_pandas_label(pandas_label, num_levels) + if include_pandas_label_in_agg_label + else () + ) + aggr_func_label = ( + (get_pandas_aggr_func_name(agg_func),) + if include_agg_func_in_agg_label + else () + ) + label_tuple = label_tuple + aggr_func_label + agg_func_column_labels.append(to_pandas_label(label_tuple)) + + return agg_func_column_labels + + +def generate_column_agg_info( + internal_frame: InternalFrame, + column_to_agg_func: dict[ + PandasLabelToSnowflakeIdentifierPair, + Union[AggFuncInfo, list[AggFuncInfo]], + ], + agg_kwargs: dict[str, Any], + include_agg_func_only_in_result_label: bool, +) -> tuple[list[AggregateColumnOpParameters], list[Hashable]]: + """ + Generate the ColumnAggregationInfo for the internal frame based on the column_to_agg_func map. + + Args: + internal_frame: InternalFrame + The internal frame to apply aggregation on + column_to_agg_func: Dict[PandasLabelToSnowflakeIdentifierPair, Union[AggFuncInfo, List[AggFuncInfo]]], + Map between the Snowpark pandas column needs to apply aggregation on and the aggregation functions to apply + for the column. The Snowpark pandas column is represented as a pair of the pandas label and the quoted + identifier for the columns. The aggregation function can be marked as dummy. In this case, it will be + applied to "Null" rahter than the column. + agg_kwargs: Dict[str, Any] + keyword argument passed for the aggregation function + include_agg_func_only_in_result_label: bool + should the result label only contains the aggregation function name if it is included in the result label. + + + Returns: + List[AggregateColumnOpParameters] + Each AggregateColumnOpParameters contains information of the quoted identifier for the column to apply + aggregation on, the snowflake aggregation function to apply on the column, and the quoted identifier + and pandas label to use for the result aggregation column. 
+ List[Hashable] + The new index data column index names for the dataframe after aggregation + """ + + quoted_identifier_to_snowflake_type: dict[ + str, DataType + ] = internal_frame.quoted_identifier_to_snowflake_type() + num_levels: int = internal_frame.num_index_levels(axis=1) + # reserve all index column name and ordering column names + identifiers_to_exclude: list[str] = ( + internal_frame.index_column_snowflake_quoted_identifiers + + internal_frame.ordering_column_snowflake_quoted_identifiers + ) + column_agg_ops: list[AggregateColumnOpParameters] = [] + # if any value in the dictionary is a list, the aggregation function name is added as + # an extra level to the final pandas label, otherwise not. When any value in the dictionary is a list, + # the aggregation function name will be added as an extra level for the result label. + agg_func_level_included = any( + is_list_like(fn) and not is_named_tuple(fn) + for fn in column_to_agg_func.values() + ) + pandas_label_level_included = ( + not agg_func_level_included or not include_agg_func_only_in_result_label + ) + + for pandas_label_to_identifier, agg_func in column_to_agg_func.items(): + pandas_label, quoted_identifier = pandas_label_to_identifier + agg_func_list = ( + [agg_func] + if not is_list_like(agg_func) or is_named_tuple(agg_func) + else agg_func + ) + # generate the pandas label and quoted identifier for the result aggregation columns, one + # for each aggregation function to apply. + agg_col_labels = generate_pandas_labels_for_agg_result_columns( + pandas_label_to_identifier.pandas_label, + num_levels, + [func for (func, _) in agg_func_list], + agg_func_level_included, + pandas_label_level_included, + ) + agg_col_identifiers = ( + internal_frame.ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=agg_col_labels, excluded=identifiers_to_exclude + ) + ) + identifiers_to_exclude += agg_col_identifiers + # construct the ColumnAggregationInfo for each aggregation + for func_info, label, identifier in zip( + agg_func_list, agg_col_labels, agg_col_identifiers + ): + (func, is_dummy_agg) = func_info + agg_func_col = pandas_lit(None) if is_dummy_agg else quoted_identifier + snowflake_agg_func = get_snowflake_agg_func(func, agg_kwargs, axis=0) + # once reach here, we require all func have a corresponding snowflake aggregation function. + # check_is_aggregation_supported_in_snowflake can be used to help performing the check. + assert ( + snowflake_agg_func + ), f"no snowflake aggregation function found for {func}" + column_agg_ops.append( + AggregateColumnOpParameters( + snowflake_quoted_identifier=agg_func_col, + data_type=quoted_identifier_to_snowflake_type[quoted_identifier], + agg_pandas_label=label, + agg_snowflake_quoted_identifier=identifier, + snowflake_agg_func=snowflake_agg_func, + ordering_columns=internal_frame.ordering_columns, + ) + ) + + new_data_column_index_names: list[Hashable] = [] + if pandas_label_level_included: + new_data_column_index_names += internal_frame.data_column_pandas_index_names + if agg_func_level_included: + new_data_column_index_names += [None] + + return column_agg_ops, new_data_column_index_names diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py new file mode 100644 index 00000000000..cd35deeeb11 --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -0,0 +1,1258 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
+# +import json +import sys +from collections import namedtuple +from collections.abc import Hashable +from enum import Enum, auto +from typing import Any, Callable, Literal, Optional, Union + +import cloudpickle +import numpy as np +import pandas as native_pd +from pandas._typing import AggFuncType +from pandas.api.types import is_scalar + +from snowflake.snowpark._internal.type_utils import PYTHON_TO_SNOW_TYPE_MAPPINGS +from snowflake.snowpark._internal.udf_utils import get_types_from_type_hints +from snowflake.snowpark.column import Column as SnowparkColumn +from snowflake.snowpark.functions import builtin, col, dense_rank, udf, udtf +from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame +from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import ( + OrderedDataFrame, + OrderingColumn, +) +from snowflake.snowpark.modin.plugin._internal.utils import ( + TempObjectType, + parse_object_construct_snowflake_quoted_identifier_and_extract_pandas_label, + parse_snowflake_object_construct_identifier_to_map, +) +from snowflake.snowpark.modin.utils import MODIN_UNNAMED_SERIES_LABEL +from snowflake.snowpark.session import Session +from snowflake.snowpark.types import ( + ArrayType, + DataType, + IntegerType, + LongType, + MapType, + PandasDataFrameType, + PandasSeriesType, + StringType, + VariantType, +) +from snowflake.snowpark.udf import UserDefinedFunction +from snowflake.snowpark.udtf import UserDefinedTableFunction +from snowflake.snowpark.window import Window + +APPLY_LABEL_COLUMN_QUOTED_IDENTIFIER = '"LABEL"' +APPLY_VALUE_COLUMN_QUOTED_IDENTIFIER = '"VALUE"' +APPLY_ORIGINAL_ROW_POSITION_COLUMN_QUOTED_IDENTIFIER = '"ORIGINAL_ROW_POSITION"' +APPLY_ROW_POSITION_WITHIN_GROUP_COLUMN_QUOTED_IDENTIFIER = '"ROW_POSITION_WITHIN_GROUP"' +APPLY_FIRST_GROUP_KEY_OCCURRENCE_POSITION_QUOTED_IDENTIFIER = ( + '"FIRST_GROUP_KEY_OCCURRENCE_POSITION"' +) + +# Default partition size to use when applying a UDTF. A higher value results in less parallelism, less contention and higher batching. +DEFAULT_UDTF_PARTITION_SIZE = 1000 + +# Use the workaround described below to use functions that are attributes of +# this module in UDFs and UDTFs. Without this workaround, we can't pickle +# those functions. +# https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs +cloudpickle.register_pickle_by_value(sys.modules[__name__]) + + +class GroupbyApplySortMethod(Enum): + """ + A rule for sorting the rows resulting from groupby.apply. + """ + + UNSET = auto() + + # order by order of the input row that each output row originated from. + ORIGINAL_ROW_ORDER = auto() + # order by 1) comparing the group keys to each other 2) resolving + # ties by the order within the result for each group. this is like + # "sort=True" for groupby aggregations. + GROUP_KEY_COMPARISON_ORDER = auto() + # order by 1) ordering by the order in which the group keys appear + # in the original frame 2) resolving ties by the order within the + # result for each group. this is like "sort=false" for groupby + # aggregations. 
+ GROUP_KEY_APPEARANCE_ORDER = auto() + + +def check_return_variant_and_get_return_type(func: Callable) -> tuple[bool, DataType]: + """Check whether the function returns a variant in Snowflake, and get its return type.""" + return_type, _ = get_types_from_type_hints(func, TempObjectType.FUNCTION) + if return_type is None or isinstance( + return_type, (VariantType, PandasSeriesType, PandasDataFrameType) + ): + # By default, we assume it is a series-to-series function + # However, vectorized UDF only allows returning one column + # We will convert the result series to a list, which will be + # returned as a Variant + return_variant = True + else: + return_variant = False + return return_variant, return_type + + +def create_udtf_for_apply_axis_1( + row_position_snowflake_quoted_identifier: str, + func: Union[Callable, UserDefinedFunction], + raw: bool, + result_type: Optional[Literal["expand", "reduce", "broadcast"]], + args: tuple, + column_index: native_pd.Index, + input_types: list[DataType], + session: Session, + **kwargs: Any, +) -> UserDefinedTableFunction: + """ + Creates a wrapper UDTF for `func` to produce narrow table results for row-wise `df.apply` (i.e., `axis=1`). + The UDTF produces 3 columns: row position column, label column and value column. + + The label column maintains a json string from a dict, which contains + a pandas label in the current series, and its occurrence. We need to + record the occurrence to deduplicate the duplicate labels so the later pivot + operation on the label column can create separate columns on duplicate labels. + The value column maintains the value of the result after applying `func`. + + Args: + row_position_snowflake_quoted_identifier: quoted identifier identifying the row position column passed into the UDTF. + func: The UDF to apply row-wise. + raw: pandas parameter controlling apply within the UDTF. + result_type: pandas parameter controlling apply within the UDTF. + args: pandas parameter controlling apply within the UDTF. + column_index: The columns of the callee DataFrame, i.e. df.columns as pd.Index object. + input_types: Snowpark column types of the input data columns. + **kwargs: pandas parameter controlling apply within the UDTF. + + Returns: + Snowpark vectorized UDTF producing 3 columns. + """ + + # If given as Snowpark function, extract packages. + udf_packages = [] + if isinstance(func, UserDefinedFunction): + # TODO: Cover will be achieved with SNOW-1261830. + udf_packages = func._packages # pragma: no cover + func = func.func # pragma: no cover + + class ApplyFunc: + def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover + # First column is row position, set as index. + df = df.set_index(df.columns[0]) + + df.columns = column_index + df = df.apply( + func, axis=1, raw=raw, result_type=result_type, args=args, **kwargs + ) + # When a dataframe is returned from `df.apply`, + # `func` is a series-to-series function, e.g., + # def func(row): + # result = row + 1 + # result.index.name = 'new_index_name' + # return result + # + # For example, the original dataframe is + # a b b + # 0 0 1 2 + # + # the result dataframe from `df.apply` is + # new_index_name a b b + # 0 1 2 3 + # After the transformation below, we will get a dataframe with two + # columns. Each row in the result represents the series result + # at a particular position. 
+ # "LABEL" "VALUE" + # 0 {"pos": 0, "0": "a", "names": ["new_index_name"]} 1 + # 1 {"pos": 1, "0": "b", "names": ["new_index_name"]} 2 + # 2 {"pos": 2, "0": "b", "names": ["new_index_name"]} 3 + # where: + # - `pos` indicates the position within the series. + # - The integer keys like "0" map from index level to the result's + # label at that level. In this case, the result only has one + # index level. + # - `names` contains the names of the result's index levels. + # - VALUE contains the result at this position. + if isinstance(df, native_pd.DataFrame): + result = [] + for row_position_index, series in df.iterrows(): + + for i, (label, value) in enumerate(series.items()): + # If this is a tuple then we store each component with a 0-based + # lookup. For example, (a,b,c) is stored as (0:a, 1:b, 2:c). + if isinstance(label, tuple): + obj_label = {k: v for k, v in enumerate(list(label))} + else: + obj_label = {0: label} + obj_label["names"] = series.index.names + obj_label["pos"] = i + result.append( + [ + row_position_index, + json.dumps(obj_label), + value, + ] + ) + # use object type so the result is json-serializable + result = native_pd.DataFrame( + result, columns=["__row__", "label", "value"], dtype=object + ) + # When a series is returned from `df.apply`, + # `func` is a series-to-scalar function, e.g., `np.sum` + # For example, the original dataframe is + # a b + # 0 1 2 + # and the result series from `df.apply` is + # 0 3 + # dtype: int64 + # After the transformation below, we will get a dataframe with two columns: + # "LABEL" "VALUE" + # 0 {'0': MODIN_UNNAMED_SERIES_LABEL} 3 + elif isinstance(df, native_pd.Series): + result = df.to_frame(name="value") + result.insert(0, "label", json.dumps({"0": MODIN_UNNAMED_SERIES_LABEL})) + result.reset_index(names="__row__", inplace=True) + else: + raise TypeError(f"Unsupported data type {df} from df.apply") + + result["value"] = ( + result["value"].apply(handle_missing_value_in_variant).astype(object) + ) + return result + + ApplyFunc.end_partition._sf_vectorized_input = native_pd.DataFrame # type: ignore[attr-defined] + + packages = list(session.get_packages().values()) + udf_packages + func_udtf = udtf( + ApplyFunc, + output_schema=PandasDataFrameType( + [LongType(), StringType(), VariantType()], + [ + row_position_snowflake_quoted_identifier, + APPLY_LABEL_COLUMN_QUOTED_IDENTIFIER, + APPLY_VALUE_COLUMN_QUOTED_IDENTIFIER, + ], + ), + input_types=[PandasDataFrameType([LongType()] + input_types)], + # We have to use the current pandas version to ensure the behavior consistency + packages=[native_pd] + packages, + session=session, + ) + + return func_udtf + + +def convert_groupby_apply_dataframe_result_to_standard_schema( + func_input_df: native_pd.DataFrame, + func_output_df: native_pd.DataFrame, + input_row_positions: native_pd.Series, + include_index_columns: bool, +) -> native_pd.DataFrame: # pragma: no cover: this function runs inside a UDTF, so coverage tools can't detect that we are testing it. + """ + Take the result of applying the user-provided function to a dataframe, and convert it to a dataframe with known schema that we can output from a vUDTF. + + Args: + func_input_df: The input to `func`, where `func` is the Python function + that the user originally passed to apply(). + func_output_df: The output of `func`. + input_row_positions: The original row positions of the rows that + func_input_df came from. + include_index_columns: Whether to include the result's index columns in + the output. 
+ + Returns: + A 5-column dataframe that represents the function result per the + description in create_udtf_for_groupby_apply. + + """ + result_rows = [] + result_index_names = func_output_df.index.names + is_transform = func_output_df.index.equals(func_input_df.index) + for row_number, (index_label, row) in enumerate(func_output_df.iterrows()): + output_row_number = input_row_positions.iloc[row_number] if is_transform else -1 + if include_index_columns: + if isinstance(index_label, tuple): + for k, v in enumerate(index_label): + result_rows.append( + [ + json.dumps({"index_pos": k, "name": result_index_names[k]}), + row_number, + v, + output_row_number, + ] + ) + else: + result_rows.append( + [ + json.dumps({"index_pos": 0, "name": result_index_names[0]}), + row_number, + index_label, + output_row_number, + ] + ) + for col_number, (label, value) in enumerate(row.items()): + obj_label: dict[Any, Any] = {} + if isinstance(label, tuple): + obj_label = {k: v for k, v in enumerate(list(label))} + else: + obj_label = {0: label} + obj_label["data_pos"] = col_number + obj_label["names"] = row.index.names + result_rows.append( + [ + json.dumps(obj_label), + row_number, + convert_numpy_int_result_to_int(value), + output_row_number, + ] + ) + # use object type so the result is json-serializable + result_df = native_pd.DataFrame( + result_rows, + columns=[ + "label", + "row_position_within_group", + "value", + "original_row_number", + ], + dtype=object, + ) + result_df["value"] = ( + result_df["value"].apply(handle_missing_value_in_variant).astype(object) + ) + result_df["first_position_for_group"] = input_row_positions.iloc[0] + return result_df + + +def create_groupby_transform_func( + func: Callable, by: str, level: Any, *args: Any, **kwargs: Any +) -> Callable: + """ + Helper function to create the groupby lambda required for DataFrameGroupBy.transform. + This is a workaround to prevent pickling DataFrame objects: the pickle module will + try to pickle all objects accessible to the function passed in. + + Args + ---- + func: The function to create the groupby lambda required for DataFrameGroupBy. + by: The column(s) to group by. + level: If the axis is a MultiIndex (hierarchical), group by a particular level or levels. + Do not specify both by and level. + args: Function's positional arguments. + kwargs: Function's keyword arguments. + + + Returns + ------- + A lambda function that can be used in place of func in groupby transform. + """ + # - `dropna` controls whether the NA values should be included as a group/be present + # in the group keys. Therefore, it must be False to ensure that no values are excluded. + # Setting `dropna=True` here raises the IndexError: "cannot do a non-empty take from an empty axes." + # This is because any dfs created from the NA group keys result in empty dfs to work with, + # which cannot be used with the `take` method. + # + # - `group_keys` controls whether the grouped column(s) are included in the index. + # - `sort` controls whether the group keys are sorted. + # - `as_index` controls whether the groupby object has group labels as the index. + + # The index of the result of any transform call is guaranteed to be the original + # index. Therefore, the groupby parameters group_keys, sort, and as_index do not + # affect the result of transform, and are not explicitly specified. 
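+    # As a hypothetical illustration, create_groupby_transform_func(np.sum, "A", None)
+    # returns a callable equivalent to
+    #     lambda df: df.groupby(by="A", level=None, dropna=False).transform(np.sum)
+    # which can be pickled without capturing DataFrame objects from the caller's scope.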
+ + return lambda df: ( + df.groupby(by=by, level=level, dropna=False).transform(func, *args, **kwargs) + ) + + +def create_udtf_for_groupby_apply( + func: Callable, + args: tuple, + kwargs: dict, + data_column_index: native_pd.Index, + index_column_names: list, + input_data_column_types: list[DataType], + input_index_column_types: list[DataType], + session: Session, +) -> UserDefinedTableFunction: + """ + Create a UDTF from the Python function for groupby.apply. + + The UDTF takes as input the following columns in the listed order: + 1. The original row position within the dataframe (not just within the group) + 2. All the index columns + 3. All the data columns + + The UDF returns as output the following columns in the listed order. There is + one row per result row and per result column. + 1. The label for the row or index level value. This is a json string of a dict + representing the label. + + For output rows representing data values, this looks like e.g. if the + data column ('a', 'int_col') is the 4th column, and the entire column + index has names ('l1', 'l2'): + {"data_pos": 4, "0": "a", "1": "int_col", "names": ["l1", "l2"]} + + Note that "names" is common across all data columns. + + For values of an index level, this looks like e.g. if the index level + 3 has name "level_3": + {"index_pos": 3, name: "level_3"} + 2. The row position of this result row within the group. + 3. The value of the index level or the data column at this row. + 4. For transforms, this gives the position of the input row that produced + this result row. We need this for transforms when group_keys=False + because we have to reindex the final result according to original row + position. If `func` is not a transform, this position is -1. + 5. The position of the first row from the input dataframe that fell into + this group. For example, if we are grouping by column "A", we divide + the input dataframe into groups where column A is equal to "a1", where + it's equal to "a2", etc. We then apply `func` to each group. If "a2" + first appears in row position 0, then all output rows resulting from the + "a2" group get a value of 0 for this column. If "a1" first appears in + row position 1, then all output rows resulting from the "a1" group get + a value of 1 for this column. e.g.: + + Input dataframe + --------------- + position A B + 0 a2 b0 + 1 a1 b1 + 2 a2 b2 + + + Input Groups + ------------ + + for group_key == a1: + + A B + a1 b1 + + for group_key == a2: + + A B + a1 b1 + + Output Groups + ------------- + + for group_key == a1: + + first_appearance_position other result columns... + 1 other result values.... + + for group_key == a2: + + first_appearance_position other result columns... + 0 other result values.... + 0 other result values.... + + Args + ---- + func: The function we need to apply to each group + args: Function's positional arguments + kwargs: Function's keyword arguments + data_column_index: Column labels for the input dataframe + index_column_names: Names of the input dataframe's index + input_data_column_types: Types of the input dataframe's data columns + input_index_column_types: Types of the input dataframe's index columns + + Returns + ------- + A UDTF that will apply the provided function to a group and return a + dataframe representing all of the data and metadata of the result. + """ + + class ApplyFunc: + def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def] # pragma: no cover: adding type hint causes an error when creating udtf. 
also, skip coverage for this function because coverage tools can't tell that we're executing this function because we execute it in a UDTF. + """ + Apply the user-provided function to the group represented by this partition. + + Args + ---- + df: The dataframe representing one group + + Returns + ------- + A dataframe representing the result of applying the user-provided + function to this group. + """ + # First column is row position, save it for later. + row_positions = df.iloc[:, 0] + df = df.iloc[:, 1:] + + df.set_index( + [f"ARG{i}" for i in range(2, len(index_column_names) + 2)], inplace=True + ) + df.index.names = index_column_names + df.columns = data_column_index + # Use infer_objects() because integer columns come as floats + # TODO: file snowpark bug about that. Asked about this here: + # https://github.com/snowflakedb/snowpandas/pull/823/files#r1507286892 + input_df = df.infer_objects() + func_result = func(input_df, *args, **kwargs) + if isinstance(func_result, native_pd.Series): + # If function returns series, we have to transpose the series + # and change its metadata a little bit, but after that we can + # continue largely as if the function has returned a dataframe. + # + # If the series has a 1-dimensional index, the series name + # becomes the name of the column index. For example, if + # `func` returned the series native_pd.Series([1], name='a'): + # + # 0 1 + # Name: a, dtype: int64 + # + # The result needs to use the dataframe + # pd.DataFrame([1], columns=pd.Index([0], name='a'): + # + # a 0 + # 0 1 + # + name = func_result.name + func_result.name = None + func_result_as_frame = func_result.to_frame().T + if func_result_as_frame.columns.nlevels == 1: + func_result_as_frame.columns.name = name + return convert_groupby_apply_dataframe_result_to_standard_schema( + input_df, + func_result_as_frame, + row_positions, + # We don't need to include any information + # about the index of `func_result_as_frame`. + # The series only has one index, and that + # index becomes the columns of + # `func_result_as_frame`. + include_index_columns=False, + ) + if isinstance(func_result, native_pd.DataFrame): + return convert_groupby_apply_dataframe_result_to_standard_schema( + input_df, func_result, row_positions, include_index_columns=True + ) + # At this point, we know the function result was not a DataFrame + # or Series + return native_pd.DataFrame( + { + "label": [ + json.dumps({"0": MODIN_UNNAMED_SERIES_LABEL, "data_pos": 0}) + ], + "row_position_within_group": [0], + "value": [convert_numpy_int_result_to_int(func_result)], + "original_row_number": [-1], + "first_position_for_group": [row_positions.iloc[0]], + }, + # use object dtype so result is JSON-serializable + dtype=object, + ) + + input_types = [ + # first input column is the integer row number. the row number integer + # becomes a float inside the UDTF due to SNOW-1184587 + LongType(), + # the next columns are the index columns... + *input_index_column_types, + # ...then come the data columns. 
+ *input_data_column_types, + ] + return udtf( + ApplyFunc, + output_schema=PandasDataFrameType( + [StringType(), IntegerType(), VariantType(), IntegerType(), IntegerType()], + [ + APPLY_LABEL_COLUMN_QUOTED_IDENTIFIER, + APPLY_ROW_POSITION_WITHIN_GROUP_COLUMN_QUOTED_IDENTIFIER, + APPLY_VALUE_COLUMN_QUOTED_IDENTIFIER, + APPLY_ORIGINAL_ROW_POSITION_COLUMN_QUOTED_IDENTIFIER, + APPLY_FIRST_GROUP_KEY_OCCURRENCE_POSITION_QUOTED_IDENTIFIER, + ], + ), + input_types=[PandasDataFrameType(col_types=input_types)], + # We have to specify the local pandas package so that the UDF's pandas + # behavior is consistent with client-side pandas behavior. + packages=[native_pd] + list(session.get_packages().values()), + session=session, + ) + + +def create_udf_for_series_apply( + func: Union[Callable, UserDefinedFunction], + return_type: DataType, + input_type: DataType, + na_action: Optional[Literal["ignore"]], + session: Session, + args: tuple[Any, ...], + **kwargs: Any, +) -> UserDefinedFunction: + """ + Creates Snowpark user defined function to use like a columnar expression from given func or existing Snowpark user defined function. + + Args: + func: a Python function or Snowpark user defined function. + return_type: return type of the function as Snowpark type. + input_type: input type of the function as Snowpark type. + na_action: if "ignore", use strict mode. + session: Snowpark session, should be identical with pd.session + args: positional arguments to pass to the UDF + **kwargs: keyword arguments to pass to the UDF + + Returns: + Snowpark user defined function. + """ + + # Start with session packages. + packages = list(session.get_packages().values()) + + # Snowpark function with annotations, extract underlying func to wrap. + if isinstance(func, UserDefinedFunction): + # Ensure return_type specified is identical. + assert func._return_type == return_type + + # Append packages from function. + packages += func._packages + + # Below the function func is wrapped again, extract here the underlying Python function. + func = func.func + + if isinstance(return_type, VariantType): + + def apply_func(x): # type: ignore[no-untyped-def] # pragma: no cover + result = [] + # When the return type is Variant, the return value must be json-serializable + # Calling tolist() convert np.int*, np.bool*, etc. (which is not + # json-serializable) to python native values + for e in x.apply(func, args=args, **kwargs).tolist(): + result.append(handle_missing_value_in_variant(e)) + return result + + else: + + def apply_func(x): # type: ignore[no-untyped-def] # pragma: no cover + return x.apply(func, args=args, **kwargs) + + func_udf = udf( + apply_func, + return_type=PandasSeriesType(return_type), + input_types=[PandasSeriesType(input_type)], + strict=bool(na_action == "ignore"), + session=session, + packages=packages, + ) + return func_udf + + +def handle_missing_value_in_variant(value: Any) -> Any: + """ + Returns the correct NULL value in a variant column when a UDF is applied. + + Snowflake supports two types of NULL values, JSON NULL and SQL NULL in variant data. + In Snowflake Python UDF, a VARIANT JSON NULL is translated to Python None and A SQL NULL is + translated to a Python object, which has the `is_sql_null` attribute. 
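+ + For example (illustrative): passing None returns None (JSON NULL), while passing + np.nan returns a wrapper object whose `is_sql_null` attribute is True (SQL NULL).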
+ See details in + https://docs.snowflake.com/en/user-guide/semistructured-considerations#null-values + https://docs.snowflake.com/en/developer-guide/udf/python/udf-python-designing#null-values + + In Snowpark pandas apply/applymap API with a variant column, we return JSON NULL if a Python + None is returned in UDF (follow the same as Python UDF), and return SQL null for all other + pandas missing values (np.nan, pd.NA, pd.NaT). Note that pd.NA, pd.NaT are not + json-serializable, so we need to return a json-serializable value anyway (None or SqlNullWrapper()) + """ + + class SqlNullWrapper: + def __init__(self) -> None: + self.is_sql_null = True + + if is_scalar(value) and native_pd.isna(value): + if value is None: + return None + else: + return SqlNullWrapper() + else: + return value + + +def convert_numpy_int_result_to_int(value: Any) -> Any: + """ + If the result is a numpy int, convert it to a python int. + + Use this function to make UDF results JSON-serializable. numpy ints are not + JSON-serializable, but python ints are. Note that this function cannot make + all results JSON-serializable, e.g. it will not convert make + [1, np.int64(3)] or [[np.int64(3)]] serializable by converting the numpy + ints to python ints. However, it's very common for functions to return + numpy integers or dataframes or series thereof, so if we apply this function + to the result (in case the function returns an integer) or each element of + the result (in case the function returns a dataframe or series), we can + make sure that we return a JSON-serializable column to snowflake. + + Args + ---- + value: The value to fix + + Returns + ------- + int(value) if the value is a numpy int, otherwise the value. + """ + return int(value) if np.issubdtype(type(value), np.integer) else value + + +def deduce_return_type_from_function( + func: Union[AggFuncType, UserDefinedFunction] +) -> Optional[DataType]: + """ + Deduce return type if possible from a function, list, dict or type object. List will be mapped to ArrayType(), + dict to MapType(), and if a type object (e.g., str) is given a mapping will be consulted. + Args: + func: callable function, object or Snowpark UserDefinedFunction that can be passed in pandas to reference a function. + + Returns: + Snowpark Datatype or None if no return type could be deduced. + """ + + # Does function have an @udf decorator? Then return type from it directly. + if isinstance(func, UserDefinedFunction): + return func._return_type + + # get the return type of type hints + # PYTHON_TO_SNOW_TYPE_MAPPINGS contains some Python builtin functions that + # can only return the certain type (e.g., `str` will return string) + # if we can't get the type hints from the function, + # use variant as the default, which can hold any type of value + if isinstance(func, list): + return ArrayType() + elif isinstance(func, dict): + return MapType() + elif func in PYTHON_TO_SNOW_TYPE_MAPPINGS: + return PYTHON_TO_SNOW_TYPE_MAPPINGS[func]() + else: + # handle special case 'object' type, in this case use Variant Type. + # Catch potential TypeError exception here from python_type_to_snow_type. + # If it is not the object type, return None to indicate that type hint could not be extracted successfully. 
+ try: + return get_types_from_type_hints(func, TempObjectType.FUNCTION)[0] + except TypeError as te: + if str(te) == "invalid type ": + return VariantType() + return None + + +def sort_apply_udtf_result_columns_by_pandas_positions( + positions: list[int], + pandas_labels: list[Hashable], + snowflake_quoted_identifiers: list[str], +) -> tuple[list[Hashable], list[str]]: + """ + Sort the columns resulting from a UDTF according the position they should take in the resulting pandas dataframe. + + Args + ---- + positions: Positions the columns should take in the resulting pandas dataframe. + pandas_labels: The pandas labels of the columns + snowflake_quoted_identifiers: The snowflake quoted identifiers of the columns. + + Returns: + ------- + tuple where first element has the sorted pandas labels, and second has the sorted quoted identifiers. + """ + # We group the column information together as a tuple (position, pandas + # label, snowflake identifier) to make it easier for sorting as needed. + ColumnInfo = namedtuple( + "ColumnInfo", + ["position", "pandas_label", "snowflake_quoted_identifier"], + ) + + column_info = [ + ColumnInfo(position, pandas_label, snowflake_quoted_identifier) + for position, pandas_label, snowflake_quoted_identifier in zip( + positions, + pandas_labels, + snowflake_quoted_identifiers, + ) + ] + + # Sort based on the column position information. + column_info.sort(key=lambda x: x.position) + + pandas_labels = [info.pandas_label for info in column_info] + snowflake_quoted_identifiers = [ + info.snowflake_quoted_identifier for info in column_info + ] + return pandas_labels, snowflake_quoted_identifiers + + +def get_metadata_from_groupby_apply_pivot_result_column_names( + func_result_snowflake_quoted_identifiers: list[str], +) -> tuple[list[Hashable], list[Hashable], list[str], list[Hashable], list[str]]: + """ + Extract the pandas and snowflake metadata from the column names of the pivot result for groupby.apply. + + Args: + func_result_snowflake_quoted_identifiers: + The identifiers of the columns that represent the function result. + + Returns: + A tuple containing the following, in the order below: + 1. A list containing the names of the column index for the resulting dataframe + 2. A list containing the pandas labels of the data columns in the function result + 3. A list containing the snowflake quoted identifiers of the data columns in the function result. + 4. A list containing the pandas labels of the index columns in the function result + 5. A list containing the snowflake quoted identifiers of the index columns in the function result + + Examples + -------- + # not doing a doctest because it seems to choke on some of the input characters + # due to the escaping. 
+ + input: + + get_metadata_from_groupby_apply_pivot_result_column_names([ + # this representa a data column named ('a', 'group_key') at position 0 + '"\'{""0"": ""a"", ""1"": ""group_key"", ""data_pos"": 0, ""names"": [""c1"", ""c2""]}\'"', + # this represents a data column named ('b', 'int_col') at position 1 + '"\'{""0"": ""b"", ""1"": ""int_col"", ""data_pos"": 1, ""names"": [""c1"", ""c2""]}\'"', + # this repesents a data column named ('b', 'string_col') at position 2 + '"\'{""0"": ""b"", ""1"": ""string_col"", ""data_pos"": 2, ""names"": [""c1"", ""c2""]}\'"', + # this represents an index column for an index level named "i1" + '"\'{""index_pos"": 0, ""name"": ""i1""}\'"', + # this represents an index column for an index level named "i2" + '"\'{""index_pos"": 1, ""name"": ""i2""}\'"' + ]) + + output: + + ( + # these are the column index's names + ['c1', 'c2'], + # these are data column labels + [('a', 'group_key'), ('b', 'int_col'), ('b', 'string_col')], + # these are the snowflake quoted identifiers of the data columns + ['"\'{""0"": ""a"", ""1"": ""group_key"", ""data_pos"": 0, ""names"": [""c1"", ""c2""]}\'"', + '"\'{""0"": ""b"", ""1"": ""int_col"", ""data_pos"": 1, ""names"": [""c1"", ""c2""]}\'"', + '"\'{""0"": ""b"", ""1"": ""string_col"", ""data_pos"": 2, ""names"": [""c1"", ""c2""]}\'"' + ], + # these are the names of the index levels + ['i1', 'i2'], + # these are the snowflake quoted identifiers of the index columns + ['"\'{""index_pos"": 0, ""name"": ""i1""}\'"', '"\'{""index_pos"": 1, ""name"": ""i2""}\'"'] + ) + + """ + index_column_snowflake_quoted_identifiers = [] + data_column_snowflake_quoted_identifiers = [] + data_column_kv_maps = [] + index_column_kv_maps = [] + index_column_pandas_labels = [] + data_column_pandas_labels = [] + column_index_names = None + for identifier in func_result_snowflake_quoted_identifiers: + object_map = parse_snowflake_object_construct_identifier_to_map(identifier) + if "index_pos" in object_map: + index_column_snowflake_quoted_identifiers.append(identifier) + index_column_pandas_labels.append(object_map["name"]) + index_column_kv_maps.append(object_map) + else: + if column_index_names is None: + # if the object map has no 'names', it represents an + # aggregation, i.e. `func` returned a scalar instead of a + # dataframe or series. The result's columns always have a + # single level named `None`. + column_index_names = object_map.get("names", [None]) + ( + data_column_pandas_label, + data_column_kv_map, + ) = parse_object_construct_snowflake_quoted_identifier_and_extract_pandas_label( + identifier, num_levels=len(column_index_names) + ) + data_column_pandas_labels.append(data_column_pandas_label) + data_column_kv_maps.append(data_column_kv_map) + data_column_snowflake_quoted_identifiers.append(identifier) + assert ( + column_index_names is not None + ), "Pivot result should include at least one data column" + + data_column_positions = [kv["data_pos"] for kv in data_column_kv_maps] + index_column_positions = [kv["index_pos"] for kv in index_column_kv_maps] + + # ignore these cases because we have to merge the different column + # indices + # TODO(SNOW-1232208): Handle this case. Note that the pandas behavior for + # this case when func returns a series is contested + # https://github.com/pandas-dev/pandas/issues/54992 + if len(set(data_column_positions)) != len(data_column_positions): + # We can end up here if the column indices differ either in their names + # or in their values. 
For example: + # 1) one group returns a dataframe whose columns are pd.Index(['col_0'], name="group_1_columns"), + # and another group returns a dataframe whose columns are pd.Index(['col_0'], name="group_2_columns"). + # + # In this case, the snowflake labels for each result's 0th column are like + # {0: "col_0", "data_pos": 0, "names": ["group_1_columns"]}, + # {0: "col_0", "data_pos", 0, "names": ["group_2_columns"]} + # + # 2) one group returns a dataframe whose columns are pd.Index(['col_0'], name="columns"), + # and another group returns a dataframe whose columns are pd.Index(['col_1']), name="columns"). + # + # In this case, the snowflake labels for each result's 0th column are like + # {0: "col_0", "data_pos": 0, "names": ["columns"]}, + # {0: "col_1", "data_pos", 0, "names": ["columns"]} + raise NotImplementedError( + "No support for applying a function that returns two dataframes that have different labels for the column at a given position, " + + "a function that returns two dataframes that have different column index names, " + + "or a function that returns two series with different names or conflicting labels for the row at a given position." + ) + if len(set(index_column_positions)) != len(index_column_positions): + raise NotImplementedError( + "No support for applying a function that returns two dataframes that have different names for a given index level" + ) + + ( + data_column_pandas_labels, + data_column_snowflake_quoted_identifiers, + ) = sort_apply_udtf_result_columns_by_pandas_positions( + data_column_positions, + data_column_pandas_labels, + data_column_snowflake_quoted_identifiers, + ) + ( + index_column_pandas_labels, + index_column_snowflake_quoted_identifiers, + ) = sort_apply_udtf_result_columns_by_pandas_positions( + index_column_positions, + index_column_pandas_labels, + index_column_snowflake_quoted_identifiers, + ) + + return ( + column_index_names, + data_column_pandas_labels, + data_column_snowflake_quoted_identifiers, + index_column_pandas_labels, + index_column_snowflake_quoted_identifiers, + ) + + +def groupby_apply_pivot_result_to_final_ordered_dataframe( + ordered_dataframe: OrderedDataFrame, + agg_func: Callable, + by_snowflake_quoted_identifiers_list: list[str], + sort_method: GroupbyApplySortMethod, + as_index: bool, + original_row_position_snowflake_quoted_identifier: str, + group_key_appearance_order_quoted_identifier: str, + row_position_within_group_snowflake_quoted_identifier: str, + data_column_snowflake_quoted_identifiers: list[str], + index_column_snowflake_quoted_identifiers: list[str], + renamed_data_column_snowflake_quoted_identifiers: list[str], + renamed_index_column_snowflake_quoted_identifiers: list[str], + new_index_identifier: str, + func_returned_dataframe: bool, +) -> OrderedDataFrame: + """ + Convert the intermediate groupby.apply result to the final OrderedDataFrame. + + Sort in the correct order and rename index and data columns as needed. Add + an index column if as_index=False. + + Args: + ordered_dataframe: + The intermediate result. + agg_func: + The original function passed to groupby.apply + by_snowflake_quoted_identifiers_list: + identifiers for columns we're grouping by + sort_method: + How to sort the result + as_index: + If true, add group keys as levels in the index. Otherwise, generate a + new index that is equivalent to the new row positions. + original_row_position_snowflake_quoted_identifier: + The label of the original row that each result row originates from. 
+ group_key_appearance_order_quoted_identifier: + The identifier for the column that tells the position of the row + where this group key first occurred in the input dataframe. + row_position_within_group_snowflake_quoted_identifier: + The label of the row position within each group result. + data_column_snowflake_quoted_identifiers: + The identifiers of the data columns of the function results. + index_column_snowflake_quoted_identifiers: + The identifiers of the index columns of the function results. + renamed_data_column_snowflake_quoted_identifiers: + What to rename the data columns to + renamed_index_column_snowflake_quoted_identifiers: + What to rename the index columns to + new_index_identifier: + The identifier for the new index level that we add if as_index=False. + func_returned_dataframe: + Whether `agg_func` returned a pandas DataFrame + Returns: + Ordered dataframe in correct order with all the final snowflake identifiers. + + """ + return_variant, return_type = check_return_variant_and_get_return_type(agg_func) + return ordered_dataframe.sort( + *( + OrderingColumn(x) + for x in ( + *( + by_snowflake_quoted_identifiers_list + if sort_method is GroupbyApplySortMethod.GROUP_KEY_COMPARISON_ORDER + else [ + group_key_appearance_order_quoted_identifier, + ] + if sort_method is GroupbyApplySortMethod.GROUP_KEY_APPEARANCE_ORDER + else [original_row_position_snowflake_quoted_identifier] + ), + row_position_within_group_snowflake_quoted_identifier, + ) + ) + ).select( + *( + # For `func` returning a dataframe: + # if as_index=True: + # the group keys, i.e. the by columns, become the first + # levels of the result index + # If as_index=False: + # We drop the group keys. + # Otherwise: + # We always include the group keys. + by_snowflake_quoted_identifiers_list + if (not func_returned_dataframe or as_index) + else [] + ), + *( + # Whether `func` returns a dataframe or not, when as_index=False, we + # we need to add a new index level that shows where the groups came + # from. + # if sorting by original row order: + # the original row position itself is the new index level. + # Otherwise: + # sort the groups (either in GROUP_KEY_COMPARISON_ORDER or + # in GROUP_KEY_APPEARANCE_ORDER) and assign the + # label i to all rows that came from func(group_i). 
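+ # For example (illustrative): if group key "a2" first appears in the input before + # "a1", then under GROUP_KEY_APPEARANCE_ORDER rows produced from the "a2" group get + # index label 0 and rows from the "a1" group get index label 1.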
+ [ + original_row_position_snowflake_quoted_identifier + if sort_method is GroupbyApplySortMethod.ORIGINAL_ROW_ORDER + else ( + dense_rank().over( + Window.order_by( + *( + SnowparkColumn(col).asc_nulls_last() + for col in ( + by_snowflake_quoted_identifiers_list + if sort_method + is GroupbyApplySortMethod.GROUP_KEY_COMPARISON_ORDER + else [group_key_appearance_order_quoted_identifier] + ) + ) + ) + ) + - 1 + ).as_(new_index_identifier) + ] + if not as_index + else [] + ), + *[ + ( + col(old_quoted_identifier).as_(quoted_identifier) + if return_variant + else col(old_quoted_identifier).cast(return_type).as_(quoted_identifier) + ) + for old_quoted_identifier, quoted_identifier in zip( + data_column_snowflake_quoted_identifiers + + index_column_snowflake_quoted_identifiers, + renamed_data_column_snowflake_quoted_identifiers + + renamed_index_column_snowflake_quoted_identifiers, + ) + ], + ) + + +def groupby_apply_create_internal_frame_from_final_ordered_dataframe( + ordered_dataframe: OrderedDataFrame, + func_returned_dataframe: bool, + as_index: bool, + group_keys: bool, + by_pandas_labels: list[Hashable], + by_snowflake_quoted_identifiers: list[str], + func_result_data_column_pandas_labels: list[Hashable], + func_result_data_column_snowflake_quoted_identifiers: list[str], + func_result_index_column_pandas_labels: list[Hashable], + func_result_index_column_snowflake_quoted_identifiers: list[str], + column_index_names: list[str], + new_index_identifier: str, + original_data_column_pandas_labels: list[Hashable], +) -> InternalFrame: + """ + Create the InternalFrame for the groupby.apply result from the final OrderedDataFrame. + + Designate the appropriate snowflake columns as data columns and index + columns. + + Args: + ordered_dataframe: + The final, sorted OrderedDataFrame with the result of groupby.apply + func_returned_dataframe: + Whether the function returned a pandas DataFrame. + as_index: + Whether to include groups in the index. + group_keys: + The group_keys argument to groupby() + by_pandas_labels: + The labels of the grouping columns. + by_snowflake_quoted_identifiers: + The snowflake identifiers of the grouping columns. + func_result_data_column_pandas_labels: + The pandas labels for the columns resulting from calling func() on + each group. Note that these are assumed to be the same across groups. + func_result_data_column_snowflake_quoted_identifiers: + Snowflake identifiers for the columns resulting from calling func() + on each group. Note that these are assumed to be the same across groups. + func_result_index_column_pandas_labels: + The pandas labels for the index levels resulting from calling func() on + each group. Note that these are assumed to be the same across groups. + func_result_index_column_snowflake_quoted_identifiers: + Snowflake identifiers for the index levels resulting from calling func() + on each group. Note that these are assumed to be the same across groups. + column_index_names: + The names of the result's column index. + new_index_identifier: + If as_index=False, use this identifier for a new index level that + indicates which group each chunk of the result came from. + original_data_column_pandas_labels: + The data column pandas labels of the original dataframe. + + Returns: + An InternalFrame representing the final result. + """ + if not as_index and not func_returned_dataframe: + # If func has not returned a dataframe and as_index=False, we put some + # of the by columns in the result instead of in the index. 
+ # note we only include columns from the original frame, and we don't + # include any index levels that we grouped by: + # https://github.com/pandas-dev/pandas/blob/654c6dd5199cb2d6d522dde4c4efa7836f971811/pandas/core/groupby/groupby.py#L1308-L1311 + data_column_pandas_labels = [] + data_column_snowflake_quoted_identifiers = [] + for label, identifier in zip(by_pandas_labels, by_snowflake_quoted_identifiers): + if label in original_data_column_pandas_labels: + data_column_pandas_labels.append(label) + data_column_snowflake_quoted_identifiers.append(identifier) + # If func returned a scalar (i.e. not a dataframe or series), we need to + # call the column with the function result None instead of + # MODIN_UNNAMED_SERIES_LABEL. + if func_result_data_column_pandas_labels == [MODIN_UNNAMED_SERIES_LABEL]: + data_column_pandas_labels.append(None) + else: + data_column_pandas_labels.extend(func_result_data_column_pandas_labels) + data_column_snowflake_quoted_identifiers.extend( + func_result_data_column_snowflake_quoted_identifiers + ) + else: + # Otherwise, the final result's data columns are exactly the columns + # that `func` returned. + data_column_pandas_labels = func_result_data_column_pandas_labels + data_column_snowflake_quoted_identifiers = ( + func_result_data_column_snowflake_quoted_identifiers + ) + + if (not func_returned_dataframe) or group_keys: + # in these cases, we have to prepend index level(s) that indicate which + # group each chunk came from. If as_index=True, these levels are the + # grouping columns themselves. Otherwise, use the new column containing + # the sequential group numbers. + if as_index: + group_pandas_labels = by_pandas_labels + group_quoted_identifiers = by_snowflake_quoted_identifiers + else: + group_pandas_labels = [None] + group_quoted_identifiers = [new_index_identifier] + else: + group_pandas_labels = [] + group_quoted_identifiers = [] + + return InternalFrame.create( + ordered_dataframe=ordered_dataframe, + data_column_pandas_labels=data_column_pandas_labels, + data_column_pandas_index_names=column_index_names, + data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers, + index_column_pandas_labels=group_pandas_labels + + func_result_index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=group_quoted_identifiers + + func_result_index_column_snowflake_quoted_identifiers, + ) + + +def groupby_apply_sort_method( + sort: bool, + group_keys: bool, + original_row_position_quoted_identifier: str, + ordered_dataframe_before_sort: OrderedDataFrame, + func_returned_dataframe: bool, +) -> GroupbyApplySortMethod: + """ + Get the sort method that groupby.apply should use on the result rows. + + This function implements the following pandas logic from [1], where + "transform" [2] is a function that returns a result whose index is the + same as the index of the dataframe being grouped. + + if func did not return a dataframe, group_keys=True, or this is not a transform: + if sort: + sort in order of increasing group key values + else: + sort in order of first appearance of group key values + else: + reindex result to the original dataframe's order. 
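+ + For example (illustrative): df.groupby("A", group_keys=False).apply(lambda g: g) + returns each group unchanged, so it acts as a transform and the result is reindexed + to the original row order.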
+ + [1] https://github.com/pandas-dev/pandas/blob/e14a9bd41d8cd8ac52c5c958b735623fe0eae064/pandas/core/groupby/groupby.py#L1196 + [2] https://pandas.pydata.org/docs/user_guide/groupby.html#transformation + + Args: + sort: + The `sort` argument to groupby() + group_keys: + The `group_keys` argument to groupby() + is_transform_quoted_identifier: + The snowflake identifier of the column in the ordered dataframe + that tells whether each row comes from a function that acted + like a transform. + ordered_dataframe_before_sort: + Ordered dataframe containing the intermediate, unsorted + groupby.apply result. + func_returned_dataframe: + Whether the user's `func` returned a dataframe. + + Returns: + enum telling how to sort. + + """ + if not func_returned_dataframe or group_keys: + return ( + GroupbyApplySortMethod.GROUP_KEY_COMPARISON_ORDER + if sort + else GroupbyApplySortMethod.GROUP_KEY_APPEARANCE_ORDER + ) + # to distinguish between transforms and non-transforms, we need to + # execute an extra query to compare the index of the result to the + # index of the original dataframe. + # https://github.com/pandas-dev/pandas/issues/57656#issuecomment-1969454704 + # Need to wrap column name in IDENTIFIER, or else bool agg function + # will treat the name as a string literal + is_transform: bool = not ordered_dataframe_before_sort.agg( + builtin("boolor_agg")( + SnowparkColumn(original_row_position_quoted_identifier) == -1 + ).as_("is_transform") + ).collect()[0][0] + return ( + GroupbyApplySortMethod.ORIGINAL_ROW_ORDER + if is_transform + else ( + GroupbyApplySortMethod.GROUP_KEY_COMPARISON_ORDER + if sort + else GroupbyApplySortMethod.GROUP_KEY_APPEARANCE_ORDER + ) + ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py new file mode 100644 index 00000000000..23620933c96 --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py @@ -0,0 +1,540 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
+# +import functools +from collections.abc import Hashable +from dataclasses import dataclass + +from pandas._typing import Callable, Scalar + +from snowflake.snowpark.column import Column as SnowparkColumn +from snowflake.snowpark.functions import col, concat, floor, iff, repeat, when +from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame +from snowflake.snowpark.modin.plugin._internal.join_utils import ( + JoinOrAlignInternalFrameResult, +) +from snowflake.snowpark.modin.plugin._internal.type_utils import infer_object_type +from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.types import ( + DataType, + NullType, + StringType, + _FractionalType, + _IntegralType, +) + +NAN_COLUMN = pandas_lit("nan").cast("float") + +# set of supported binary operations that can be mapped to Snowflake +SUPPORTED_BINARY_OPERATIONS = { + "truediv", + "rtruediv", + "floordiv", + "rfloordiv", + "mod", + "rmod", + "pow", + "rpow", + "__or__", + "__ror__", + "__and__", + "__rand__", + "add", + "radd", + "sub", + "rsub", + "mul", + "rmul", + "eq", + "ne", + "gt", + "lt", + "ge", + "le", +} + + +def compute_modulo_between_snowpark_columns( + first_operand: SnowparkColumn, + first_datatype: DataType, + second_operand: SnowparkColumn, + second_datatype: DataType, +) -> SnowparkColumn: + """ + Compute modulo between two Snowpark columns ``first_operand`` and ``second_operand``. + Supports only numeric values for operands, raises NotImplementedError otherwise. + Module may produce results different from native pandas or Python. + """ + # 0. if f or s is NULL, return NULL (Snowflake's rule) + # 1. s == 0, return nan + # 2. if s != 0, return f % s + # + # Examples + # -------- + # >>> a = pd.Series([7, 7, -7, -7]) + # >>> b = pd.Series([5, -5, 5, -5]) + # >>> a % b + # 0 2.0 + # 1 2.0 + # 2 -2.0 + # 3 -2.0 + # dtype: float64 + + # >>> a = pd.Series([8.9, -0.22, np.nan, -1.02, 3.15, 2.0]) + # >>> b = pd.Series([-2.3, -76.34, 5.3, 5.3, 8.12]) + # >>> a % b + # 0 2.00 + # 1 -0.22 + # 2 NaN + # 3 -1.02 + # 4 3.15 + # 5 NaN + # dtype: float64 + + # Behavior differences + # -------------------- + # Python pandas 1.5 Snowflake + # 7 % 5 = 2 7 % 5 = 2 7 % 5 = 2 + # 7 % -5 = -3 7 % -5 = -3 7 % -5 = 2 + # -7 % 5 = 3 -7 % 5 = 3 -7 % 5 = -2 + # -7 % -5 = -2 -7 % -5 = -2 -7 % -5 = -2 + # + # Snowpark pandas API differs from native pandas results whenever an operand with a negative + # sign is used. + + is_first_operand_numeric_type = ( + isinstance(first_datatype, _IntegralType) + or isinstance(first_datatype, _FractionalType) + or isinstance(first_datatype, NullType) + ) + + is_second_operand_numeric_type = ( + isinstance(second_datatype, _IntegralType) + or isinstance(second_datatype, _FractionalType) + or isinstance(second_datatype, NullType) + ) + + if is_first_operand_numeric_type and is_second_operand_numeric_type: + return ( + when(first_operand.is_null() | second_operand.is_null(), None) + .when(second_operand == 0, NAN_COLUMN) + .otherwise(first_operand % second_operand) + ) + else: + ErrorMessage.not_implemented( + "Modulo does not support non-numeric types, consider using a UDF with apply instead." + ) + + +def compute_power_between_snowpark_columns( + first_operand: SnowparkColumn, + second_operand: SnowparkColumn, +) -> SnowparkColumn: + """ + Compute power between two Snowpark columns ``first_operand`` and ``second_operand``. + """ + # 0. 
if f == 1 or s == 0, return 1 or 1.0 based on f's type (pandas' behavior) + # 1. if f or s is NULL, return NULL (Snowflake's behavior) + # 2. if f is nan, or s is nan, or f < 0 and s can not be cast to int without loss (int(s) != s), return nan + # In Snowflake, if f < 0 and s is not an integer, an invalid floating point operation will be raised. + # E.g., pow(-7, -10.0) is valid, but pow(-7, -10.1) is invalid in snowflake. + # In pandas, pow(-7, -10.1) returns NaN. + # 3. else return f ** s + result = ( + when((first_operand == 1) | (second_operand == 0), 1) + .when(first_operand.is_null() | second_operand.is_null(), None) + .when( + (first_operand == NAN_COLUMN) + | (second_operand == NAN_COLUMN) + | ( + (first_operand < 0) + # it checks whether the value can be cast int without loss + & (second_operand.cast("int") != second_operand) + ), + NAN_COLUMN, + ) + .otherwise(first_operand**second_operand) + ) + return result + + +def is_binary_op_supported(op: str) -> bool: + """ + check whether binary operation is mappable to Snowflake + Args + op: op as string + + Returns: + True if binary operation can be mapped to Snowflake/Snowpark, else False + """ + + return op in SUPPORTED_BINARY_OPERATIONS + + +def compute_binary_op_between_snowpark_columns( + op: str, + first_operand: SnowparkColumn, + first_datatype: Callable[[], DataType], + second_operand: SnowparkColumn, + second_datatype: Callable[[], DataType], +) -> SnowparkColumn: + """ + Compute pandas binary operation for two SnowparkColumns + Args: + op: pandas operation + first_operand: SnowparkColumn for lhs + first_datatype: Callable for Snowpark Datatype for lhs, this is lazy so we can avoid pulling the value if + it is not needed. + second_operand: SnowparkColumn for rhs + second_datatype: Callable for Snowpark DateType for rhs, this is lazy so we can avoid pulling the value if + it is not needed. + + Returns: + SnowparkColumn expr for translated pandas operation + """ + + binary_op_result_column = None + + # some operators and the data types have to be handled specially to align with pandas + # However, it is difficult to fail early if the arithmetic operator is not compatible + # with the data type, so we just let the server raise exception (e.g. a string minus a string). 
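+ # A few illustrative mappings (hypothetical scalar values, for intuition only): + #   op="floordiv", lhs=7, rhs=2   ->  FLOOR(7 / 2) = 3 + #   op="rtruediv", lhs=7, rhs=2   ->  2 / 7 (operands are swapped for reflected ops) + #   op="add" on two string columns -> CONCAT(lhs, rhs)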
+ if op in ["truediv", "rtruediv", "floordiv", "rfloordiv"]: + # rtruediv means b/a, rfloordiv means b//a in Python + if op in ["rtruediv", "rfloordiv"]: + first_operand, second_operand = ( + second_operand, + first_operand, + ) + + binary_op_result_column = first_operand / second_operand + + if op in ["floordiv", "rfloordiv"]: + binary_op_result_column = floor(binary_op_result_column) + elif op in ["mod", "rmod"]: + if op == "rmod": + first_operand, second_operand = ( + second_operand, + first_operand, + ) + binary_op_result_column = compute_modulo_between_snowpark_columns( + first_operand, first_datatype(), second_operand, second_datatype() + ) + elif op in ["pow", "rpow"]: + if op == "rpow": + first_operand, second_operand = ( + second_operand, + first_operand, + ) + binary_op_result_column = compute_power_between_snowpark_columns( + first_operand, second_operand + ) + elif op in ["__or__", "__ror__"]: + binary_op_result_column = first_operand | second_operand + elif op in ["__and__", "__rand__"]: + binary_op_result_column = first_operand & second_operand + elif op in ["add", "radd", "mul", "rmul"]: + + # string/string case (only for add/radd) + if isinstance(second_datatype(), StringType) and isinstance( + first_datatype(), StringType + ): + if "add" == op: + binary_op_result_column = concat(first_operand, second_operand) + elif "radd" == op: + binary_op_result_column = concat(second_operand, first_operand) + + # string/integer case (only for mul/rmul) + if op in ["mul", "rmul"] and ( + ( + isinstance(second_datatype(), _IntegralType) + and isinstance(first_datatype(), StringType) + ) + or ( + isinstance(second_datatype(), StringType) + and isinstance(first_datatype(), _IntegralType) + ) + ): + # Snowflake's repeat doesn't support negative number + # but pandas will return an empty string + + # swap first_operand with second_operand because REPEAT(, ) expects to be string + if isinstance(first_datatype(), _IntegralType): + first_operand, second_operand = second_operand, first_operand + + binary_op_result_column = iff( + second_operand > pandas_lit(0), + repeat(first_operand, second_operand), + pandas_lit(""), + ) + + # If there is no special binary_op_result_column result, it means the operator and + # the data type of the column don't need special handling. Then we get the overloaded + # operator from Snowpark Column class, e.g., __add__ to perform binary operations. + if binary_op_result_column is None: + binary_op_result_column = getattr(first_operand, f"__{op}__")(second_operand) + + return binary_op_result_column + + +def compute_binary_op_between_snowpark_column_and_scalar( + op: str, + first_operand: SnowparkColumn, + datatype: Callable[[], DataType], + second_operand: Scalar, +) -> SnowparkColumn: + """ + Compute the binary operation between a Snowpark column and a scalar. + Args: + op: the name of binary operation + first_operand: The SnowparkColumn for lhs + datatype: Callable for Snowpark data type, this is lazy so we can avoid pulling the value if + it is not needed. 
+ second_operand: Scalar value + + Returns: + The result as a Snowpark column + """ + + def second_datatype() -> DataType: + return infer_object_type(second_operand) + + second_operand = pandas_lit(second_operand) + + return compute_binary_op_between_snowpark_columns( + op, first_operand, datatype, second_operand, second_datatype + ) + + +def compute_binary_op_between_scalar_and_snowpark_column( + op: str, + first_operand: Scalar, + second_operand: SnowparkColumn, + datatype: Callable[[], DataType], +) -> SnowparkColumn: + """ + Compute the binary operation between a scalar and a Snowpark column. + Args: + op: the name of binary operation + first_operand: Scalar value + second_operand: The SnowparkColumn for rhs + datatype: Callable for Snowpark data type, this is lazy so we can avoid pulling the value if + it is not needed. + + Returns: + The result as a Snowpark column + """ + + def first_datatype() -> DataType: + return infer_object_type(first_operand) + + first_operand = pandas_lit(first_operand) + + return compute_binary_op_between_snowpark_columns( + op, first_operand, first_datatype, second_operand, datatype + ) + + +def compute_binary_op_with_fill_value( + op: str, + lhs: SnowparkColumn, + lhs_datatype: Callable[[], DataType], + rhs: SnowparkColumn, + rhs_datatype: Callable[[], DataType], + fill_value: Scalar, +) -> SnowparkColumn: + """ + Helper method for performing binary operations. + 1. Fills NaN/None values in the lhs and rhs with the given fill_value. + 2. Computes the binary operation expression for lhs rhs. + + fill_value replaces NaN/None values when only either lhs or rhs is NaN/None, not both lhs and rhs. + For instance, with fill_value = 100, + 1. Given lhs = None and rhs = 10, lhs is replaced with fill_value. + result = lhs + rhs => None + 10 => 100 (replaced) + 10 = 110 + 2. Given lhs = 3 and rhs = None, rhs is replaced with fill_value. + result = lhs + rhs => 3 + None => 3 + 100 (replaced) = 103 + 3. Given lhs = None and rhs = None, neither lhs nor rhs is replaced since they both are None. + result = lhs + rhs => None + None => None. + + Args: + op: pandas operation to perform between lhs and rhs + lhs: the lhs SnowparkColumn + lhs_datatype: Callable for Snowpark Datatype for lhs + rhs: the rhs SnowparkColumn + rhs_datatype: Callable for Snowpark Datatype for rhs + fill_value: Fill existing missing (NaN) values, and any new element needed for + successful DataFrame alignment, with this value before computation. + + Returns: + SnowparkColumn expression for translated pandas operation + """ + lhs_cond, rhs_cond = lhs, rhs + if fill_value is not None: + fill_value_lit = pandas_lit(fill_value) + lhs_cond = iff(lhs.is_null() & ~rhs.is_null(), fill_value_lit, lhs) + rhs_cond = iff(rhs.is_null() & ~lhs.is_null(), fill_value_lit, rhs) + + return compute_binary_op_between_snowpark_columns( + op, lhs_cond, lhs_datatype, rhs_cond, rhs_datatype + ) + + +def merge_label_and_identifier_pairs( + sorted_column_labels: list[str], + q_frame_sorted: list[tuple[str, str]], + q_missing_sorted: list[tuple[str, str]], +) -> list[tuple[str, str]]: + """ + Helper function to create a merged list of column label/snowflake quoted identifiers. Assume q_frame_sorted and q_missing_sorted are disjoint wrt to labels. + + Example: + Given sorted_column_labels = [1, 2, 3] + and q_frame_sorted = [(1, "A"), (3, "C")] q_missing_sorted = [(2, "B")] + this function will produce as output [(1, "A"), (2, "B"), (3, "C")]. + Each q_frame_sorted and q_missing_sorted are lists of label/identifier pairs. 
+ I.e., [(1, "A"), (3, "C")] should be understood as 1 -> "A", 3 -> "C". + They're each assumed to be sorted with respect to their labels, and all labels must be contained within + the sorted_column_labels variable. + The result is a combined, sorted representation 1 -> "A", 2 -> "B", 3 -> "C" which resembles the merge-step + of a classical mergesort algorithm. + Args: + sorted_column_labels: The labels to merge for + q_frame_sorted: sorted list of label/identifier pairs. All labels must be contained within sorted_column_labels. + q_missing_sorted: sorted list of label/identifier pairs. All labels must be contained within sorted_column_labels. + + Returns: + List of label/identifier pairs. If the labels were projected out, they would form sorted_column_labels. + """ + if len(q_frame_sorted) > 0 and len(q_missing_sorted) > 0: + # merge labels/identifiers + + i_frame = 0 + i_missing = 0 + + pairs = [] + for label in sorted_column_labels: + # Leave merge iff either queue is exhausted. + if i_frame >= len(q_frame_sorted) or i_missing >= len(q_missing_sorted): + break + + if label == q_frame_sorted[i_frame][0]: + pairs.append(q_frame_sorted[i_frame]) + i_frame += 1 + elif label == q_missing_sorted[i_missing][0]: + pairs.append(q_missing_sorted[i_missing]) + i_missing += 1 + # else case not relevant here, because labels of q_frame_sorted and q_missing_sorted must be disjoint. + + if i_frame < len(q_frame_sorted): + pairs += q_frame_sorted[i_frame:] + elif i_missing < len(q_missing_sorted): + pairs += q_missing_sorted[i_missing:] + + return pairs + elif len(q_missing_sorted) == 0: + return q_frame_sorted + else: + return q_missing_sorted + + +@dataclass +class BinaryOperationPair: + # For detailed description of the members, cf. `prepare_binop_pairs_between_dataframe_and_dataframe`. + # This is a helper class to hold the results of this function. + identifier: str + lhs: SnowparkColumn + lhs_datatype: Callable + rhs: SnowparkColumn + rhs_datatype: Callable + + +def prepare_binop_pairs_between_dataframe_and_dataframe( + aligned_rhs_and_lhs: JoinOrAlignInternalFrameResult, + combined_data_labels: list[Hashable], + lhs_frame: InternalFrame, + rhs_frame: InternalFrame, +) -> list[BinaryOperationPair]: + """ + Returns a list of BinaryOperationPair which can be used to carry out a binary operation between two dataframes. + Each BinaryOperationPair consists of the following: + - identifier: an identifier that can be used within align_result to hold the result of a binary operation between two columns + - lhs: a SnowparkColumn expression for the left operand + - lhs_datatype: a function to lazily determine the Snowpark datatype of `lhs` + - rhs: a SnowparkColumn expression for the right operand + - rhs_datatype: a function to lazily determine the Snowpark datatype of `rhs` + + BinaryOperationPair will be returned in the order of `combined_data_labels` + + Args: + aligned_rhs_and_lhs: the align result between other_frame and self_frame + combined_data_labels: the combined data labels to be used for align result. + rhs_frame: a frame representing the right side. + lhs_frame: a frame representing the left side. + + Returns: + List of BinaryOperationPair.
+ """ + # construct list of pairs which label belongs to which quoted identifier + type_map = aligned_rhs_and_lhs.result_frame.quoted_identifier_to_snowflake_type() + left_right_pairs = [] + for label in combined_data_labels: + left_identifier, right_identifier = None, None + + try: + left_idx = lhs_frame.data_column_pandas_labels.index(label) + left_quoted_identifier = lhs_frame.data_column_snowflake_quoted_identifiers[ + left_idx + ] + left_identifier = ( + aligned_rhs_and_lhs.result_column_mapper.left_quoted_identifiers_map[ + left_quoted_identifier + ] + ) + left = col(left_identifier) + # To avoid referencing always the last right_identifier in the loop, use functools.partial + left_typer = functools.partial( + lambda identifier: type_map[identifier], left_identifier + ) # noqa: E731 + except ValueError: + # lhs label not in list. + left = pandas_lit(None) + left_typer = lambda: infer_object_type( # type: ignore[assignment] # noqa: E731 + None + ) + + try: + right_idx = rhs_frame.data_column_pandas_labels.index(label) + right_quoted_identifier = ( + rhs_frame.data_column_snowflake_quoted_identifiers[right_idx] + ) + right_identifier = ( + aligned_rhs_and_lhs.result_column_mapper.right_quoted_identifiers_map[ + right_quoted_identifier + ] + ) + right = col(right_identifier) + # To avoid referencing always the last right_identifier in the loop, use functools.partial + right_typer = functools.partial( + lambda identifier: type_map[identifier], right_identifier + ) # noqa: E731 + except ValueError: + # rhs label not in list + right = pandas_lit(None) + right_typer = lambda: infer_object_type(None) # type: ignore[assignment] # noqa: E731 + + identifier_to_replace = left_identifier or right_identifier + assert identifier_to_replace, "either identifier must be valid" + + # We return a lambda to determine the datatype of each operand lazily as this allows to defer + # invoking a DESCRIBE query as much as possible. + left_right_pairs.append( + BinaryOperationPair( + identifier=identifier_to_replace, + lhs=left, + lhs_datatype=left_typer, + rhs=right, + rhs_datatype=right_typer, + ) + ) + return left_right_pairs diff --git a/src/snowflake/snowpark/modin/plugin/_internal/concat_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/concat_utils.py new file mode 100644 index 00000000000..7da729837bf --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/_internal/concat_utils.py @@ -0,0 +1,347 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# +from collections.abc import Hashable, Sequence +from typing import Literal, Optional, Union + +import pandas as native_pd + +from snowflake.snowpark._internal.type_utils import ColumnOrName +from snowflake.snowpark.functions import array_construct +from snowflake.snowpark.modin.plugin._internal import join_utils +from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame +from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import OrderingColumn +from snowflake.snowpark.modin.plugin._internal.utils import ( + INDEX_LABEL, + append_columns, + pandas_lit, +) +from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage + +CONCAT_POSITION_COLUMN_LABEL = "concat_position" + + +def add_keys_as_column_levels( + columns: native_pd.Index, + frames: list[InternalFrame], + keys: Sequence[Hashable], + names: Union[list[Hashable], None], +) -> native_pd.Index: + """ + Concat all column names from given ``frames``. Also add ``keys`` as outermost + level of column labels. 
+ Args: + columns: Column index of concatenated frame. + frames: A list of internal frames. + keys: A list of hashables to be used as keys. Length of keys must be same as + length of frames. + names: Optional names for levels in column index. + + Returns: + Concatenated column names as native pandas index. + """ + assert len(keys) == len(frames), "Length of keys must be same as length of frames" + + key_values = [] + for key, frame in zip(keys, frames): + key_values.extend([key] * len(frame.data_column_pandas_labels)) + keys_index = native_pd.Index(key_values) + # Add 'keys' as outermost level to column labels. + arrays = [keys_index.get_level_values(i) for i in range(keys_index.nlevels)] + [ + columns.get_level_values(i) for i in range(columns.nlevels) + ] + columns = native_pd.MultiIndex.from_arrays(arrays) + names = names or [] + # Fill with 'None' to match the number of levels in column index + while len(names) < columns.nlevels: + names.append(None) + return columns.set_names(names) + + +def convert_to_single_level_index(frame: InternalFrame, axis: int) -> InternalFrame: + """ + If index on given axis is a MultiIndex, convert it to single level index of tuples. + Do nothing if index on given axis has only one level. + + On axis=1, this is equivalent to the following operation in pandas. + df.columns = df.columns.to_flat_index() + For example, a frame with column index + pd.MultiIndex.from_tuples([('a', 'b'), ('c', 'd')], names=['x', 'y']) + will be converted to a frame with column index + pd.Index([('a', 'b'), ('c', 'd')]) + + Similarly, on axis=0 this is equivalent to the following operation in pandas + df.index = df.index.to_flat_index() + + NOTE: Original level names are lost during this operation and become None. + + Args: + frame: An InternalFrame. + axis: int: {0, 1} + + Returns: + New InternalFrame with single level index. + + """ + assert axis in (0, 1), f"Invalid axis {axis}, allowed values are 0 and 1" + # Because we break up and store a MultiIndex with several Snowpark columns, we can + # perform the single-level index conversion as a no-op. + if frame.num_index_levels(axis=axis) == 1: + return frame + if axis == 1: + return InternalFrame.create( + ordered_dataframe=frame.ordered_dataframe, + data_column_pandas_labels=frame.data_column_pandas_labels, + data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers, + # Setting length of index names to 1 will convert column labels from + # multi-index to single level index. + data_column_pandas_index_names=[None], + index_column_pandas_labels=frame.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + ) + else: + WarningMessage.tuples_stored_as_array( + "MultiIndex values are compressed to single index of tuples. Snowflake" + " backend doesn't support tuples datatype.
Tuple row labels are stored as" + " ARRAY" + ) + index_identifier = ( + frame.ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=[INDEX_LABEL], + )[0] + ) + ordered_dataframe = append_columns( + frame.ordered_dataframe, + index_identifier, + array_construct(*frame.index_column_snowflake_quoted_identifiers), + ) + return InternalFrame.create( + ordered_dataframe=ordered_dataframe, + index_column_pandas_labels=[None], + index_column_snowflake_quoted_identifiers=[index_identifier], + data_column_pandas_labels=frame.data_column_pandas_labels, + data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers, + data_column_pandas_index_names=frame.data_column_pandas_index_names, + ) + + +def union_all( + frame1: InternalFrame, + frame2: InternalFrame, + join: Literal["inner", "outer"], + sort: Optional[bool] = False, +) -> InternalFrame: + """ + Concatenate frames on the index axis by taking their union using the UNION ALL operator. + Snowflake identifiers of output frame are based on snowflake identifiers from first + frame. + Args: + frame1: First frame + frame2: Second frame + join: How to handle column index + 'inner': Output frame contains only overlapping columns from both frames. + 'outer': Output frame contains union of columns from both frames. + sort: Sort column axis if True. + + Returns: + New InternalFrame after taking union of given frames. + """ + columns1 = frame1.data_columns_index + columns2 = frame2.data_columns_index + + if join == "inner": + # Preserves the order from calling index. + # For example: + # pd.Index([3, 1, 2]).intersection(pd.Index([1, 2, 3])) will result in + # pd.Index([3, 1, 2]) + data_column_labels = columns1.intersection(columns2, sort=False) + elif join == "outer": + # Preserves the order from calling index. And for labels not in calling index + # preserves the order from argument index. + # For example: + # pd.Index([3, 1, 2]).union(pd.Index([1, 4, 2, 3, 5])) will result in + # pd.Index([3, 1, 2, 4, 5]) + data_column_labels = columns1.union(columns2, sort=False) + else: + raise AssertionError( + f"Invalid join type '{join}'. Accepted values are 'inner' and 'outer'" + ) + if sort: + data_column_labels = data_column_labels.sort_values() + + frame1 = _select_columns(frame1, data_column_labels.tolist()) + frame2 = _select_columns(frame2, data_column_labels.tolist()) + + frame1, frame2 = join_utils.convert_incompatible_types_to_variant( + frame1, + frame2, + frame1.ordered_dataframe.projected_column_snowflake_quoted_identifiers, + frame2.ordered_dataframe.projected_column_snowflake_quoted_identifiers, + ) + + # select data + index + ordering columns for union all + # it is guaranteed that the ordering columns do not overlap with the index and data columns + # TODO SNOW-956072: remove the following code after removing convert_incompatible_types_to_variant + frame1_identifiers_for_union_all = ( + frame1.index_column_snowflake_quoted_identifiers + + frame1.data_column_snowflake_quoted_identifiers + + frame1.ordering_column_snowflake_quoted_identifiers + ) + frame2_identifiers_for_union_all = ( + frame2.index_column_snowflake_quoted_identifiers + + frame2.data_column_snowflake_quoted_identifiers + + frame2.ordering_column_snowflake_quoted_identifiers + ) + + # In Snowflake's UNION ALL operator, the names of the output columns are based on the + # names of the columns of the first query. So here we copy identifiers from + # first frame.
+ # Reference: https://docs.snowflake.com/en/sql-reference/operators-query + ordered_dataframe = frame1.ordered_dataframe.select( + frame1_identifiers_for_union_all + ).union_all(frame2.ordered_dataframe.select(frame2_identifiers_for_union_all)) + ordered_dataframe = ordered_dataframe.sort(frame1.ordering_columns) + return InternalFrame.create( + ordered_dataframe=ordered_dataframe, + data_column_pandas_labels=frame1.data_column_pandas_labels, + data_column_snowflake_quoted_identifiers=frame1.data_column_snowflake_quoted_identifiers, + data_column_pandas_index_names=frame1.data_column_pandas_index_names, + index_column_pandas_labels=frame1.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=frame1.index_column_snowflake_quoted_identifiers, + ) + + +def add_key_as_index_columns(frame: InternalFrame, key: Hashable) -> InternalFrame: + """ + Add given 'key' as outermost index columns to given 'frame'. + If 'key' is a tuple multiple columns are added for each element in tuple. + + Args: + frame: InternalFrame + key: key to add as index column + + Returns: + A InternalFrame after adding 'key' as index columns. + """ + if not isinstance(key, tuple): + key = tuple([key]) + new_identifiers = frame.ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=[INDEX_LABEL] * len(key), + ) + col_values = [pandas_lit(value) for value in key] + ordered_dataframe = append_columns( + frame.ordered_dataframe, new_identifiers, col_values + ) + + # Add key as outermost index columns. + index_column_pandas_labels = [None] * len(key) + frame.index_column_pandas_labels + index_column_snowflake_quoted_identifiers = ( + new_identifiers + frame.index_column_snowflake_quoted_identifiers + ) + + return InternalFrame.create( + ordered_dataframe=ordered_dataframe, + data_column_pandas_labels=frame.data_column_pandas_labels, + data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers, + data_column_pandas_index_names=frame.data_column_pandas_index_names, + index_column_pandas_labels=index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=index_column_snowflake_quoted_identifiers, + ) + + +def _select_columns( + frame: InternalFrame, data_column_labels: list[Hashable] +) -> InternalFrame: + """ + Select only the given labels from given frame. If any data column label is missing + in frame add new column with NULL values. + + Args: + frame: An InternalFrame + data_column_labels: A list of pandas labels. + + Returns: + New InternalFrame after only with given data columns. + + """ + select_list: list[ColumnOrName] = [] + + # Add index columns + select_list.extend(frame.index_column_snowflake_quoted_identifiers) + + # Add ordering columns + select_list.extend(frame.ordering_column_snowflake_quoted_identifiers) + + snowflake_ids = frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + data_column_labels, include_index=False + ) + # Add data columns + data_column_snowflake_identifiers = [] + # A map to keep track number of times a label is already seen. + # Native pandas fails with IndexError when either frame has duplicate labels, with + # the exception when both frames have exact same lables and exact same order. + # In Snowpark pandas, we don't fail concat when duplicates lables are present but + # try to match as many columns as possible from the frames. 
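+ # Illustrative example (hypothetical labels): if data_column_labels is ["a", "a"] but + # the frame has only one column labeled "a", the first "a" maps to the existing column + # and the second "a" becomes a new column filled with NULL values.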
+ label_count_map: dict[Hashable, int] = {} + for label, id_tuple in zip(data_column_labels, snowflake_ids): + if len(id_tuple) <= label_count_map.get(label, 0): + # if missing add new column to frame with NULL values. + snowflake_id = ( + frame.ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=[label] + )[0] + ) + select_list.append(pandas_lit(None).as_(snowflake_id)) + else: + index = label_count_map.get(label, 0) + snowflake_id = id_tuple[index] + select_list.append(snowflake_id) + label_count_map[label] = index + 1 + + data_column_snowflake_identifiers.append(snowflake_id) + return InternalFrame.create( + ordered_dataframe=frame.ordered_dataframe.select(select_list), + data_column_pandas_labels=data_column_labels, + data_column_snowflake_quoted_identifiers=data_column_snowflake_identifiers, + data_column_pandas_index_names=frame.data_column_pandas_index_names, + index_column_pandas_labels=frame.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + ) + + +def add_global_ordering_columns(frame: InternalFrame, position: int) -> InternalFrame: + """ + To create global ordering for concat (axis=0) operation we first ensure a + row position column for local ordering within the frame. Then add another + column to indicate position of this frame among concat frames given by 'position' + parameter. + Now these two columns can be used to determine global ordering. + Args: + frame: Internal frame. + position: position of this frame among all frames being concatenated. + + Returns: + A new frame with updated ordering columns. + + """ + frame = frame.ensure_row_position_column() + ordered_dataframe = frame.ordered_dataframe.sort( + [OrderingColumn(frame.row_position_snowflake_quoted_identifier)] + ) + identifier = ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=[CONCAT_POSITION_COLUMN_LABEL], + )[0] + ordered_dataframe = append_columns( + ordered_dataframe, identifier, pandas_lit(position) + ) + ordered_dataframe = ordered_dataframe.sort( + OrderingColumn(identifier), *ordered_dataframe.ordering_columns + ) + return InternalFrame.create( + ordered_dataframe=ordered_dataframe, + data_column_pandas_labels=frame.data_column_pandas_labels, + data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers, + data_column_pandas_index_names=frame.data_column_pandas_index_names, + index_column_pandas_labels=frame.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers, + ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/cumulative_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/cumulative_utils.py new file mode 100644 index 00000000000..ce13b0082f7 --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/_internal/cumulative_utils.py @@ -0,0 +1,201 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# +# +# This file contains utils functions used by cumulative aggregation functions. 
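The helpers below rebuild pandas' cumulative aggregations (cumsum, cummax, cumcount, ...) on top of window functions. As a reminder of the target semantics they aim to reproduce, a minimal sketch with plain pandas (illustrative only, not part of this patch):

import pandas as native_pd

s = native_pd.Series([1.0, None, 2.0])
print(list(s.cumsum()))              # [1.0, nan, 3.0]  -> skipna=True keeps the NaN row as NaN
print(list(s.cumsum(skipna=False)))  # [1.0, nan, nan]  -> any prior NaN poisons the running total

df = native_pd.DataFrame({"g": ["a", "a", "b"], "v": [1, 2, 3]})
print(list(df.groupby("g").cumcount()))     # [0, 1, 0]
print(list(df.groupby("g")["v"].cumsum()))  # [1, 3, 3]  -> running totals restart per group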
+# + +import functools +from typing import Any, Callable + +from snowflake.snowpark.column import Column as SnowparkColumn +from snowflake.snowpark.functions import col, iff, sum as sum_sp +from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + drop_non_numeric_data_columns, +) +from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame +from snowflake.snowpark.modin.plugin._internal.groupby_utils import ( + check_is_groupby_supported_by_snowflake, + extract_groupby_column_pandas_labels, +) +from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit +from snowflake.snowpark.modin.plugin.compiler import snowflake_query_compiler +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.modin.utils import MODIN_UNNAMED_SERIES_LABEL +from snowflake.snowpark.window import Window + + +def get_cumagg_col_to_expr_map_axis0( + internal_frame: InternalFrame, + cumagg_func: Callable, + skipna: bool, +) -> dict[SnowparkColumn, SnowparkColumn]: + """ + Map each input column to to a corresponding expression that computes the cumulative aggregation function on that column when axis = 0. + + Args: + internal_frame: InternalFrame. + The internal frame to apply the cumulative aggregation function on. + cumagg_func: Callable + The cumulative aggregation function to apply on the internal frame. + skipna : bool + Exclude NA/null values. If an entire row/column is NA, the result will be NA. + + Returns: + Dict[SnowparkColumn, SnowparkColumn] + Map between Snowpandas column and the corresponding expression that computes the cumulative aggregation function on that column. + """ + window = Window.order_by( + internal_frame._modin_frame.row_position_snowflake_quoted_identifier + ).rows_between(Window.UNBOUNDED_PRECEDING, Window.CURRENT_ROW) + if skipna: + cumagg_col_to_expr_map = { + snowflake_quoted_id: iff( + col(snowflake_quoted_id).is_null(), + pandas_lit(None), + cumagg_func(snowflake_quoted_id).over(window), + ) + for snowflake_quoted_id in internal_frame._modin_frame.data_column_snowflake_quoted_identifiers + } + else: + # When skipna is False and the aggregated values (form prior rows) contain any nulls, then the cumulative aggregate is also null. + # For this reason, we count the number of nulls in the window and compare to zero using the two nested iff's below. + # Note that this could have also been achieved using COUNT_IF(), but as of this writing it has not been supported by Snowpark yet. + cumagg_col_to_expr_map = { + snowflake_quoted_id: iff( + sum_sp( + iff( + col(snowflake_quoted_id).is_null(), pandas_lit(1), pandas_lit(0) + ) + ).over(window) + > pandas_lit(0), + pandas_lit(None), + cumagg_func(snowflake_quoted_id).over(window), + ) + for snowflake_quoted_id in internal_frame._modin_frame.data_column_snowflake_quoted_identifiers + } + return cumagg_col_to_expr_map + + +def get_groupby_cumagg_frame_axis0( + query_compiler: "snowflake_query_compiler.SnowflakeQueryCompiler", + by: Any, + axis: int, + numeric_only: bool, + groupby_kwargs: dict[str, Any], + cumagg_func: Callable, + cumagg_func_name: str, + ascending: bool = True, +) -> InternalFrame: + """ + Return the output internal frame after applying the cumulative aggregation function on the input internal frame when axis = 0. + + Args: + by: mapping, series, callable, label, pd.Grouper, BaseQueryCompiler, list of such. + Used to determine the groups for the groupby. 
+ axis : 0 (index), 1 (columns) + numeric_only: bool + Include only float, int, boolean columns. + groupby_kwargs: Dict[str, Any] + keyword arguments passed for the groupby. + cumagg_func: Callable + The cumulative aggregation function to apply on the internal frame. + cumagg_func_name: str + The name of the cumulative aggregation function to apply on the internal frame. + ascending : bool + If False, process the window in reverse order. Needed for cumcount. + + Returns: + InternalFrame + Output internal frame after applying the cumulative aggregation function. + """ + level = groupby_kwargs.get("level", None) + dropna = groupby_kwargs.get("dropna", True) + + if not check_is_groupby_supported_by_snowflake(by, level, axis): + ErrorMessage.not_implemented( + f"GroupBy {cumagg_func_name} with by = {by}, level = {level} and axis = {axis} is not supported yet in Snowpark pandas." + ) + + if level is not None and level != 0: + ErrorMessage.not_implemented( + f"GroupBy {cumagg_func_name} with level = {level} is not supported yet in Snowpark pandas." + ) + + by_list = extract_groupby_column_pandas_labels(query_compiler, by, level) + + qc = query_compiler + if numeric_only: + qc = drop_non_numeric_data_columns(query_compiler, by_list) + + by_snowflake_quoted_identifiers_list = [ + # Duplicate labels in by result in a ValueError. + entry[0] + for entry in qc._modin_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + by_list + ) + ] + + window = ( + Window.partition_by(by_snowflake_quoted_identifiers_list) + .order_by( + qc._modin_frame.ordered_dataframe.ordering_column_snowflake_quoted_identifiers + ) + .rows_between( + Window.UNBOUNDED_PRECEDING if ascending else Window.CURRENT_ROW, + Window.CURRENT_ROW if ascending else Window.UNBOUNDED_FOLLOWING, + ) + ) + + dropna_cond = functools.reduce( + lambda combined_col, col: combined_col | col, + map( + lambda by_snowflake_quoted_identifier: col( + by_snowflake_quoted_identifier + ).is_null(), + by_snowflake_quoted_identifiers_list, + ), + ) + + pandas_labels = [] + new_columns = [] + if cumagg_func_name == "cumcount": + new_col = cumagg_func("*").over(window) - pandas_lit(1) + if dropna: + new_col = iff(dropna_cond, pandas_lit(None), new_col) + if qc._modin_frame.num_index_columns > 1: + pandas_labels.append( + (MODIN_UNNAMED_SERIES_LABEL,) * qc._modin_frame.num_index_columns + ) + else: + pandas_labels.append(MODIN_UNNAMED_SERIES_LABEL) + new_columns.append(new_col) + else: + for pandas_label, snowflake_quoted_identifier in zip( + qc._modin_frame.data_column_pandas_labels, + qc._modin_frame.data_column_snowflake_quoted_identifiers, + ): + if snowflake_quoted_identifier not in by_snowflake_quoted_identifiers_list: + new_col = iff( + col(snowflake_quoted_identifier).is_null(), + pandas_lit(None), + cumagg_func(snowflake_quoted_identifier).over(window), + ) + if dropna: + new_col = iff(dropna_cond, pandas_lit(None), new_col) + + pandas_labels.append(pandas_label) + new_columns.append(new_col) + + result_frame = qc._modin_frame.project_columns(pandas_labels, new_columns) + if cumagg_func_name == "cumcount": + return InternalFrame.create( + ordered_dataframe=result_frame.ordered_dataframe, + data_column_pandas_labels=[None], + data_column_snowflake_quoted_identifiers=result_frame.data_column_snowflake_quoted_identifiers, + index_column_pandas_labels=result_frame.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=result_frame.index_column_snowflake_quoted_identifiers, + data_column_pandas_index_names=[None], + ) + else: + 
return result_frame diff --git a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py new file mode 100644 index 00000000000..771c63eb2ac --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py @@ -0,0 +1,314 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +from typing import Sequence, Union + +import numpy as np +import pandas +from pandas import Index, IntervalIndex +from pandas._typing import Scalar +from pandas.core.dtypes.common import is_numeric_dtype +from pandas.core.dtypes.inference import is_scalar +from pandas.core.reshape.tile import _is_dt_or_td + +from snowflake.snowpark.functions import col, iff +from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame +from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import ( + DataFrameReference, + OrderedDataFrame, + OrderingColumn, +) +from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit +from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage +from snowflake.snowpark.types import LongType + + +# This function stems from pandas 2.2.x and has been minimally modified to not require +# the full data, but instead work with min/max values solely. It replaces +# The pandas 2.1.x function from pandas.core.reshape.tile import _convert_bin_to_numeric_type. +def _nbins_to_bins(x_min: Scalar, x_max: Scalar, nbins: int, right: bool) -> Index: + """ + If a user passed an integer N for bins, convert this to a sequence of N + equal(ish)-sized bins. + """ + if is_scalar(nbins) and nbins < 1: + raise ValueError("`bins` should be a positive integer.") # pragma: no cover + + # this snippet of original pandas code is handled outside of this function + # if x_idx.size == 0: + # raise ValueError("Cannot cut empty array") + + # retrieve type of original series used in cut. To speed up processing, + # infer from aggrgates as the type won't change when computing min/max. + x_dtype = pandas.Series([x_min, x_max]).dtype + rng = (x_min, x_max) + mn, mx = rng + + if is_numeric_dtype(x_dtype) and (np.isinf(mn) or np.isinf(mx)): + # GH#24314 + raise ValueError( # pragma: no cover + "cannot specify integer `bins` when input data contains infinity" # pragma: no cover + ) # pragma: no cover + + if mn == mx: # adjust end points before binning + if _is_dt_or_td(x_dtype): # pragma: no cover + # original pandas code (commented): + # # using seconds=1 is pretty arbitrary here + # # error: Argument 1 to "dtype_to_unit" has incompatible type + # # "dtype[Any] | ExtensionDtype"; expected "DatetimeTZDtype | dtype[Any]" + # unit = dtype_to_unit(x_dtype) # type: ignore[arg-type] + # td = Timedelta(seconds=1).as_unit(unit) + # # Use DatetimeArray/TimedeltaArray method instead of linspace + # # error: Item "ExtensionArray" of "ExtensionArray | ndarray[Any, Any]" + # # has no attribute "_generate_range" + # bins = x_idx._values._generate_range( # type: ignore[union-attr] + # start=mn - td, end=mx + td, periods=nbins + 1, freq=None, unit=unit + # ) + ErrorMessage.not_implemented( + "no support for datetime types yet." 
+ ) # pragma: no cover + else: + mn -= 0.001 * abs(mn) if mn != 0 else 0.001 # pragma: no cover + mx += 0.001 * abs(mx) if mx != 0 else 0.001 # pragma: no cover + + bins = np.linspace(mn, mx, nbins + 1, endpoint=True) # pragma: no cover + else: # adjust end points after binning + if _is_dt_or_td(x_dtype): + # original pandas code (commented): + # # Use DatetimeArray/TimedeltaArray method instead of linspace + # + # # error: Argument 1 to "dtype_to_unit" has incompatible type + # # "dtype[Any] | ExtensionDtype"; expected "DatetimeTZDtype | dtype[Any]" + # unit = dtype_to_unit(x_dtype) # type: ignore[arg-type] + # # error: Item "ExtensionArray" of "ExtensionArray | ndarray[Any, Any]" + # # has no attribute "_generate_range" + # bins = x_idx._values._generate_range( # type: ignore[union-attr] + # start=mn, end=mx, periods=nbins + 1, freq=None, unit=unit + # ) + ErrorMessage.not_implemented( + "no support for datetime types yet." + ) # pragma: no cover + else: + bins = np.linspace(mn, mx, nbins + 1, endpoint=True) + adj = (mx - mn) * 0.001 # 0.1% of the range + if right: + bins[0] -= adj + else: + bins[-1] += adj + + return Index(bins) + + +def preprocess_bins_for_cut( + x_min: Scalar, + x_max: Scalar, + bins: Union[int, Sequence[Scalar], pandas.IntervalIndex], + right: bool, + include_lowest: bool, + precision: int, +) -> Union[int, Sequence[Scalar], pandas.IntervalIndex]: + """ + Adjusts bins to be directly used with compute_bin_indices function below. bins for both qcut and cut are given either as int which will create equidistant bins, + as list of scalars (typically float), or IntervalIndex (not supported). + + Args: + x_min: minimum value of the data which will be binned + x_max: maximum value of the data which will be binned + bins: the bins according to pandas which will define the buckets + right: if True use left-open intervals (a, b], if False use right-open intervals [a, b) + include_lowest: If True and right is True, adjust the first interval by 10 ** (-precision), i.e. the first interval will be (a-10 ** (-precision), b]. This will include the minimum value in the binning process. + precision: only used together with include_lowest to adjust the first bin (cf. include_lowest) + + Returns: + adjusted bins + """ + # Code is mostly from original pandas and adjusted for Snowpark pandas API. + + if not np.iterable(bins): + # Call adjusted function from pandas 2.2.x branch + bins = _nbins_to_bins(x_min, x_max, bins, right) + + elif isinstance(bins, IntervalIndex): + if bins.is_overlapping: # pragma: no cover + raise ValueError( + "Overlapping IntervalIndex is not accepted." + ) # pragma: no cover + + else: + bins = Index(bins) + if not bins.is_monotonic_increasing: + raise ValueError("bins must increase monotonically.") + + # if include_lowest is True, then expand first bucket by 10 ** (-precision) + # I.e., for right=True, intervals will have the form (a, b]. + # If a is now contained in the values, it will fall into (a - 10**(-precision), b]. + # For right=False, this is irrelevant. The expansion only works for right=True. + if include_lowest and right: + bins = Index([bins[0] - 10 ** (-precision)] + list(bins[1:].values)) + + return bins + + +def compute_bin_indices( + values_frame: InternalFrame, + cuts_frame: InternalFrame, + n_cuts: int, + right: bool = True, +) -> InternalFrame: + """ + Given a frame of cuts, i.e. borders of bins (strictly increasing) compute for the data in values_frame the index of the bin they fall into. 
+ E.g., cuts_frame may contain the following data + 0.0, 3.0, 7.8, 10.0 + This would form the following bins (0.0, 3.0], (3.0, 7.8], (7.8, 10.0]. + Consequently, this function will return indices in the range 0...2, e.g. for the following data + + -10.0, 0.0, 1.0, 5.6, 9.0, 10.0, 11.0 + + the following bin indices + + nan, nan, 0., 1., 2., 2., nan + + Note that NULL (nan) is returned for data which lies outside of the cuts provided. + + Args: + values_frame: an InternalFrame representing a Series, the data to be binned. + cuts_frame: an InternalFrame representing a Series with data being a strictly monotonically + increasing sequence of floating numbers forming the border of bins. + n_cuts: The length of cuts_frame. Passed in as separate parameter to avoid an additional query. + right: if True use left-open intervals (a, b], if False use right-open intervals [a, b). + Returns: + InternalFrame representing a Series with the bin indices. indices will be in the range [0, n_cuts - 1]. + """ + + # There will be 0, ..., len(cuts_frame) - 1 buckets, result will be thus in this range. + # We can find for values the cut they belong to by performing a left <= join. As this feature is not supported + # within OrderedDataFrame yet, we use the Snowpark layer directly. This should have no negative + # consequences when it comes to building lazy graphs, as both cut and qcut are materializing operations. + + cuts_frame = cuts_frame.ensure_row_position_column() + value_frame = values_frame.ensure_row_position_column() + + ( + bucket_data_identifier, + bucket_row_position_identifier, + value_data_identifier, + value_row_position_identifier, + ) = value_frame.ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=["b_data", "b_row_pos", "v_data", "v_row_pos"] + ) + + value_index_identifiers = value_frame.index_column_snowflake_quoted_identifiers + + bucket_snowpark_frame = ( + cuts_frame.ordered_dataframe.to_projected_snowpark_dataframe(True, True, True) + ) + value_snowpark_frame = ( + value_frame.ordered_dataframe.to_projected_snowpark_dataframe(True, True, True) + ) + + # relabel to new identifiers to reference within range join below. + bucket_snowpark_frame = bucket_snowpark_frame.select( + col(cuts_frame.data_column_snowflake_quoted_identifiers[0]).as_( + bucket_data_identifier + ), + col(cuts_frame.row_position_snowflake_quoted_identifier).as_( + bucket_row_position_identifier + ), + ) + + value_snowpark_frame = value_snowpark_frame.select( + *tuple(value_index_identifiers), + col(value_frame.data_column_snowflake_quoted_identifiers[0]).as_( + value_data_identifier + ), + col(value_frame.row_position_snowflake_quoted_identifier).as_( + value_row_position_identifier + ), + ) + + # Perform a left join. The idea is to find all values which fall into an interval + # defined by the cuts/bins in the bucket frame. The closest can be then identified using the + # row position. An alternative to this + # was to use an ASOF join with a proper matching condition. + + if right: + ans = value_snowpark_frame.join( + bucket_snowpark_frame, + value_snowpark_frame[value_data_identifier] + <= bucket_snowpark_frame[bucket_data_identifier], + how="left", + lsuffix="_L", + rsuffix="_R", + ) + + # Result will be v_row_pos and min(b_row_pos) - 1. However, to deal with the edge cases we need to correct + # for the case when the result is in the left-most interval. 
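To make the min(b_row_pos) - 1 step concrete, the same computation can be spelled out in plain Python for a single value (illustrative only, not part of this patch):

cuts = [0.0, 3.0, 7.8, 10.0]  # bin edges, row positions 0..3
value = 5.6
# the `value <= edge` left join matches edges 7.8 and 10.0 (row positions 2 and 3)
matched_positions = [pos for pos, edge in enumerate(cuts) if value <= edge]
bin_index = min(matched_positions) - 1
print(bin_index)  # 1 -> value falls into the interval (3.0, 7.8]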
+ ans = ans.group_by( + value_index_identifiers + + [value_data_identifier, value_row_position_identifier] + ).min(bucket_row_position_identifier) + else: + # For right=False, perform a >= join and use max(b_row_pos) - 1. + ans = value_snowpark_frame.join( + bucket_snowpark_frame, + value_snowpark_frame[value_data_identifier] + >= bucket_snowpark_frame[bucket_data_identifier], + how="left", + lsuffix="_L", + rsuffix="_R", + ) + + # Result will be v_row_pos and max(q_row_pos) - 1. However, to deal with the edge cases we need to correct + # for the case when the result is in the left-most interval. + ans = ans.group_by( + value_index_identifiers + + [value_data_identifier, value_row_position_identifier] + ).max(bucket_row_position_identifier) + + column_names = ans.columns + bin_index_col = col(column_names[-1]) + + if right: + # An index value of 0 means the data is outside of the first bucket. Set to NULL. All others, perform -1. + # For data outside of the last bucket, the left join will automatically fill it with NULL. + correct_index_expr = iff( + bin_index_col != pandas_lit(0), + bin_index_col - pandas_lit(1), + pandas_lit(None), + ).astype(LongType()) + else: + # For right=False, correct for the bin indices exceeding the max value n_cuts - 1. If the index is larger + # than this number, then set to NULL. + correct_index_expr = iff( + bin_index_col >= pandas_lit(n_cuts - 1), pandas_lit(None), bin_index_col + ).astype(LongType()) + + ans = ans.select( + *tuple(value_index_identifiers), + col(value_row_position_identifier), + correct_index_expr, + ) + column_names = ans.columns + new_data_identifier = column_names[-1] + + # Create OrderedDataFrame and InternalFrame and QC out of this. + # Need to restore index as well which has been passed through. + new_ordered_dataframe = OrderedDataFrame( + DataFrameReference(ans), + projected_column_snowflake_quoted_identifiers=value_index_identifiers + + [new_data_identifier], + ordering_columns=[OrderingColumn(value_row_position_identifier)], + row_position_snowflake_quoted_identifier=value_row_position_identifier, + ) + + new_frame = InternalFrame.create( + ordered_dataframe=new_ordered_dataframe, + data_column_pandas_labels=value_frame.data_column_pandas_labels, + data_column_pandas_index_names=value_frame.data_column_index_names, + data_column_snowflake_quoted_identifiers=[new_data_identifier], + index_column_pandas_labels=value_frame.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=value_index_identifiers, + ) + + return new_frame diff --git a/src/snowflake/snowpark/modin/plugin/_internal/frame.py b/src/snowflake/snowpark/modin/plugin/_internal/frame.py new file mode 100644 index 00000000000..598839e01ce --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/_internal/frame.py @@ -0,0 +1,1196 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
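The worked example in the compute_bin_indices docstring (cut_utils.py above) can be cross-checked against native pandas directly; a minimal sketch, not part of this patch:

import pandas as native_pd

cuts = [0.0, 3.0, 7.8, 10.0]
values = [-10.0, 0.0, 1.0, 5.6, 9.0, 10.0, 11.0]
# labels=False returns the bin indices; values outside (0.0, 10.0] map to NaN.
print(native_pd.cut(values, cuts, right=True, labels=False))
# -> nan, nan, 0, 1, 2, 2, nan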
+# +import functools +from collections.abc import Hashable +from dataclasses import dataclass +from logging import getLogger +from typing import Any, Callable, NamedTuple, Optional, Union + +import pandas as pd +from pandas._typing import IndexLabel +from pandas.core.dtypes.common import is_object_dtype + +from snowflake.snowpark._internal.analyzer.analyzer_utils import ( + quote_name_without_upper_casing, +) +from snowflake.snowpark.column import Column as SnowparkColumn +from snowflake.snowpark.functions import col, last_value +from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import ( + OrderedDataFrame, + OrderingColumn, +) +from snowflake.snowpark.modin.plugin._internal.type_utils import TypeMapper +from snowflake.snowpark.modin.plugin._internal.utils import ( + DEFAULT_DATA_COLUMN_LABEL, + INDEX_LABEL, + ROW_POSITION_COLUMN_LABEL, + append_columns, + assert_duplicate_free, + count_rows, + extract_pandas_label_from_snowflake_quoted_identifier, + fill_missing_levels_for_pandas_label, + from_pandas_label, + get_distinct_rows, + is_valid_snowflake_quoted_identifier, + snowpark_to_pandas_helper, + to_pandas_label, +) +from snowflake.snowpark.modin.plugin._typing import ( + LabelIdentifierPair, + LabelTuple, + PandasLabelToSnowflakeIdentifierPair, +) +from snowflake.snowpark.modin.utils import MODIN_UNNAMED_SERIES_LABEL +from snowflake.snowpark.types import DataType +from snowflake.snowpark.window import Window + +logger = getLogger(__name__) + +LEFT_PREFIX = "left" +RIGHT_PREFIX = "right" + + +class UpdatedInternalFrameResult(NamedTuple): + """Contains the updated internal frame and mapping from old ids to new ids.""" + + frame: "InternalFrame" + old_id_to_new_id_mappings: dict[str, str] + + +@dataclass(frozen=True) +class InternalFrame: + """ + internal abstraction of storage format to hold all information necessary to represent + a pandas.DataFrame within Snowflake + """ + + # OrderedDataFrame representation of the state of the data hold by this internal frame + # Ordering columns and row position column are maintained by OrderedDataFrame + ordered_dataframe: OrderedDataFrame + # Map between label and snowflake quoted identifier. + # This map is maintained as an ordered list, which must be in the order of + # pandas index columns + pandas data columns. + # For MultiIndex as df.columns, the pandas label will be a tuple for each column. + # An example of MultiIndex as df.columns: + # pd.MultiIndex.from_tuples([('baz', 'A'), ('baz', 'B'), ('zoo', 'A'), ('zoo', 'B')]) + # the pandas labels of data columns will be [('baz', 'A'), ('baz', 'B'), ('zoo', 'A'), ('zoo', 'B')] + label_to_snowflake_quoted_identifier: tuple[LabelIdentifierPair, ...] + # Number of index columns for the pandas dataframe, where the first num_index_columns elements + # of pandas_label_to_snowflake_quoted_identifier is for the pandas index columns + num_index_columns: int + # Store pandas labels for columns' index name or multiindex names, e.g., the labels is used to generate + # df.columns.names + # The length of data_column_index_names equals to number of multiindex levels. + # For a 3-level MultiIndex, the value can be like ['A', 'B', 'C'] + data_column_index_names: tuple[LabelTuple, ...] 
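The label bookkeeping declared above follows native pandas conventions for a MultiIndex over df.columns: each data column is identified by a tuple label, and the column index names contribute one entry per level (the data_column_index_names field). A small reminder with plain pandas (illustrative only, not part of this patch):

import pandas as native_pd

cols = native_pd.MultiIndex.from_tuples(
    [("baz", "A"), ("baz", "B"), ("zoo", "A"), ("zoo", "B")]
)
print(list(cols))        # [('baz', 'A'), ('baz', 'B'), ('zoo', 'A'), ('zoo', 'B')]
print(list(cols.names))  # [None, None] -> one name per level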
+ + @classmethod + def create( + cls, + *, + ordered_dataframe: OrderedDataFrame, + data_column_pandas_labels: list[Hashable], + data_column_pandas_index_names: list[Hashable], + data_column_snowflake_quoted_identifiers: list[str], + index_column_pandas_labels: list[Hashable], + index_column_snowflake_quoted_identifiers: list[str], + ) -> "InternalFrame": + """ + Args: + ordered_dataframe: underlying ordered dataframe used + data_column_pandas_labels: A list of pandas hashable labels for pandas data columns. + data_column_pandas_index_names: A list of hashable labels for pandas column index names + data_column_snowflake_quoted_identifiers: A list of snowflake quoted identifiers for pandas data columns, + represented by str. These identifiers are used to refer columns in underlying snowpark dataframe to + access data in snowflake. + index_column_pandas_labels: A list of pandas index column labels. + index_column_snowflake_quoted_identifiers: A list of snowflake quoted identifiers for pandas index columns. + """ + + assert len(data_column_snowflake_quoted_identifiers) == len( + data_column_pandas_labels + ), f"data column label identifier length mismatch, labels {data_column_pandas_labels}, identifiers {data_column_snowflake_quoted_identifiers}" + assert len(index_column_snowflake_quoted_identifiers) == len( + index_column_pandas_labels + ), f"index column label identifier length mismatch, labels {index_column_pandas_labels}, identifiers {index_column_snowflake_quoted_identifiers}" + + # List of pandas_label_to_snowflake_quoted_identifier mapping for index columns + index_columns_mapping: list[LabelIdentifierPair] = [ + LabelIdentifierPair( + # index column labels is always flat with only one level + from_pandas_label(pandas_label, num_levels=1), + snowflake_quoted_identifier, + ) + for pandas_label, snowflake_quoted_identifier in zip( + index_column_pandas_labels, + index_column_snowflake_quoted_identifiers, + ) + ] + + # List of pandas_label_to_snowflake_quoted_identifier mapping for data columns + data_columns_mapping: list[LabelIdentifierPair] = [ + LabelIdentifierPair( + from_pandas_label( + pandas_label, + num_levels=len(data_column_pandas_index_names), + ), + snowflake_quoted_identifier, + ) + for pandas_label, snowflake_quoted_identifier in zip( + data_column_pandas_labels, + data_column_snowflake_quoted_identifiers, + ) + ] + + return cls( + ordered_dataframe=ordered_dataframe, + label_to_snowflake_quoted_identifier=tuple( + index_columns_mapping + data_columns_mapping + ), + num_index_columns=len(index_column_snowflake_quoted_identifiers), + data_column_index_names=tuple( + # data_column_index_names is always flat with only one level + from_pandas_label(name, num_levels=1) + for name in data_column_pandas_index_names + ), + ) + + def __post_init__(self) -> None: + # perform checks for dataclass here + + # check there must be at least one index column associated with the dataframe + assert ( + self.num_index_columns >= 1 + ), "At least 1 index column should be presented for the dataframe" + + # the ordering_columns_tuple cannot be empty, because we guarantee the determinism + # for the data order of the dataframe, + assert len(self.ordering_columns) > 0, "ordering_columns cannot be empty" + + # validate data columns + self._validate_data_column_pandas_index_names() + + # make sure that all names required in metadata are present within snowpark_dataframe + # so that the internal frame represents a valid state. 
+ snowflake_quoted_identifiers = ( + self.ordered_dataframe.projected_column_snowflake_quoted_identifiers + ) + + def validate_snowflake_quoted_identifier( + quoted_identifier: str, + column_category: str, + hashable_label: Hashable = None, + ) -> None: + """ + validation for the snowflake quoted identifier, which performs two checks: + 1) the identifier is quoted 2) the identifier exists in the underlying snowpark dataframe + + Returns: + None. Assertion is raised if any check fails. + """ + # generate a properly quoted escaped_name for the error message below. + escaped_name = quoted_identifier.replace("'", "\\'") + assert is_valid_snowflake_quoted_identifier( + quoted_identifier + ), f"Found not-quoted identifier for '{column_category}':'{escaped_name}'" + + assert quoted_identifier in snowflake_quoted_identifiers, ( + f"{column_category}={escaped_name} not found in snowpark dataframe " + f"schema {snowflake_quoted_identifiers}, pandas_label={hashable_label}" + ) + + # validate the snowflake quoted identifier data + index columns + for ( + label, + snowflake_quoted_identifier, + ) in self.label_to_snowflake_quoted_identifier: + validate_snowflake_quoted_identifier( + snowflake_quoted_identifier, + "dataframe column", + to_pandas_label(label), + ) + + # check that snowflake quoted identifier is duplicate free + assert_duplicate_free( + self.index_column_snowflake_quoted_identifiers + + self.data_column_snowflake_quoted_identifiers, + "dataframe columns", + ) + + def _validate_data_column_pandas_index_names(self) -> None: + # the index on column (df.columns) must have a name (can be None) + assert ( + len(self.data_column_pandas_index_names) >= 1 + ), "data_column_pandas_index_names cannot be empty" + + # validate all labels are tuples with the same length + num_levels = len(self.data_column_pandas_index_names) + for label, _ in self.label_to_snowflake_quoted_identifier[ + self.num_index_columns : + ]: + assert num_levels == len( + label + ), f"All tuples in data_column_pandas_labels must have the same length {num_levels}, but got {label}" + + @property + def index_column_snowflake_quoted_identifiers(self) -> list[str]: + """ + Get snowflake quoted identifier for all index columns + Returns: + List of snowflake quoted identifiers for index columns + """ + return [ + col.snowflake_quoted_identifier + for col in self.label_to_snowflake_quoted_identifier[ + : self.num_index_columns + ] + ] + + @property + def data_column_snowflake_quoted_identifiers(self) -> list[str]: + """ + Get snowflake quoted identifier for all data columns + Returns: + List of snowflake quoted identifiers for data columns + """ + return [ + col.snowflake_quoted_identifier + for col in self.label_to_snowflake_quoted_identifier[ + self.num_index_columns : + ] + ] + + def quoted_identifier_to_snowflake_type(self) -> dict[str, DataType]: + return { + f.column_identifier.quoted_name: f.datatype + for f in self.ordered_dataframe.schema.fields + } + + @property + def index_column_pandas_labels(self) -> list[Hashable]: + """ + Get pandas labels for all index columns + Returns: + List of pandas labels for index columns + """ + return [ + to_pandas_label(col.label) + for col in self.label_to_snowflake_quoted_identifier[ + : self.num_index_columns + ] + ] + + @property + def data_column_pandas_labels(self) -> list[Hashable]: + """ + Get pandas labels for all data columns + Returns: + List of pandas labels for data columns + """ + return [ + to_pandas_label(col.label) + for col in self.label_to_snowflake_quoted_identifier[ + 
self.num_index_columns : + ] + ] + + @property + def ordering_column_snowflake_quoted_identifiers(self) -> list[str]: + """ + Get snowflake quoted identifier for ordering columns + Return: + List of snowflake quoted identifier for the ordering columns + """ + + return self.ordered_dataframe.ordering_column_snowflake_quoted_identifiers + + @property + def ordering_columns(self) -> list[OrderingColumn]: + """ + Get list of ordering columns. + Returns: + List of OrderingColumn. + """ + return self.ordered_dataframe.ordering_columns + + @property + def row_position_snowflake_quoted_identifier(self) -> Optional[str]: + return self.ordered_dataframe.row_position_snowflake_quoted_identifier + + @property + def row_count_snowflake_quoted_identifier(self) -> Optional[str]: + return self.ordered_dataframe.row_count_snowflake_quoted_identifier + + @property + def data_column_pandas_index_names(self) -> list[Hashable]: + """Returns pandas labels from column index (df.columns.names).""" + return [to_pandas_label(name) for name in self.data_column_index_names] + + def num_index_levels(self, *, axis: int = 0) -> int: + """ + Returns number of index levels for given `axis`. + + Args: + axis: If axis=0, return number of levels in row labels. + If axis=1, return number of levels in columns labels. + + Returns: + number of index levels for given `axis` + + Raises: + ValueError if `axis` is not valid. + """ + if axis == 0: + return self.num_index_columns + elif axis == 1: + return len(self.data_column_pandas_index_names) + else: + raise ValueError("'axis' can only be 0 or 1") + + def is_multiindex(self, *, axis: int = 0) -> bool: + """ + Returns whether the InternalFrame has a MultiIndex along `axis`. + Args: + axis: If axis=0, return whether the InternalFrame has a MultiIndex as df.index. + If axis=1, return whether the InternalFrame has a MultiIndex as df.columns. + """ + return self.num_index_levels(axis=axis) > 1 + + def is_unnamed_series(self) -> bool: + """ + Check if the InternalFrame is a representation for an unnamed series. An InternalFrame represents an + unnamed series if there is only one data column and the data column has label name MODIN_UNNAMED_SERIES_LABEL. + """ + return ( + len(self.data_column_pandas_labels) == 1 + and self.data_column_pandas_labels[0] == MODIN_UNNAMED_SERIES_LABEL + ) + + @property + def data_columns_index(self) -> pd.Index: + """ + Returns pandas Index object for column index (df.columns). + We can't do the same thing for df.index here because it requires pulling + the data from snowflake and filing a query to snowflake. + """ + if self.is_multiindex(axis=1): + return pd.MultiIndex.from_tuples( + self.data_column_pandas_labels, + names=self.data_column_pandas_index_names, + ) + else: + return pd.Index( + self.data_column_pandas_labels, + name=self.data_column_pandas_index_names[0], + # setting tupleize_cols=False to avoid creating a MultiIndex + # otherwise, when labels are tuples (e.g., [("A", "a"), ("B", "b")]), + # a MultiIndex will be created incorrectly + tupleize_cols=False, + ) + + @property + def index_columns_index(self) -> pd.Index: + """ + Get pandas index. The method eagerly pulls the values from Snowflake because index requires the values to be + filled + + Returns: + The index (row labels) of the DataFrame. 
+ """ + + index_values = snowpark_to_pandas_helper( + self.ordered_dataframe.select( + self.index_column_snowflake_quoted_identifiers + ) + ).values + if self.is_multiindex(axis=0): + value_tuples = [tuple(row) for row in index_values] + return pd.MultiIndex.from_tuples( + value_tuples, names=self.index_column_pandas_labels + ) + else: + # We have one index column. Fill in the type correctly. + index_identifier = self.index_column_snowflake_quoted_identifiers[0] + index_type = TypeMapper.to_pandas( + self.quoted_identifier_to_snowflake_type()[index_identifier] + ) + ret = pd.Index( + [row[0] for row in index_values], + name=self.index_column_pandas_labels[0], + # setting tupleize_cols=False to avoid creating a MultiIndex + # otherwise, when labels are tuples (e.g., [("A", "a"), ("B", "b")]), + # a MultiIndex will be created incorrectly + tupleize_cols=False, + ) + # When pd.Index() failed to reduce dtype to a numpy or pandas extension type, it will be object type. For + # example, an empty dataframe will be object dtype by default, or a variant, or a timestamp column with + # multiple timezones. So here we cast the index to the index_type when ret = pd.Index(...) above cannot + # figure out a non-object dtype. Note that the index_type is a logical type may not be 100% accurate. + if is_object_dtype(ret.dtype) and not is_object_dtype(index_type): + ret = ret.astype(index_type) + return ret + + def get_snowflake_quoted_identifiers_group_by_pandas_labels( + self, + pandas_labels: list[Hashable], + include_index: bool = True, + ) -> list[tuple[str, ...]]: + """ + Map given pandas labels to names in underlying snowpark dataframe. Given labels can be data or index labels. + Single label can map to multiple snowpark names from underlying dataframe. Which is represented by tuples. + We return the result in the same order as input pandas_labels. + + Args: + pandas_labels: A list of pandas labels. + include_index: Include the index columns in addition to data columns, default is True. + + Returns: + A list of tuples for matched identifiers. Each element of list is a tuple of str containing matched + snowflake quoted identifiers for corresponding pandas label in 'pandas_labels'. + Length and order of this list is same as length of given 'pandas_labels'. + """ + + snowflake_quoted_identifiers = [] + for label in pandas_labels: + matched_columns = list( + filter( + lambda col: to_pandas_label(col.label) == label, + self.label_to_snowflake_quoted_identifier[ + (0 if include_index else self.num_index_columns) : + ], + ) + ) + snowflake_quoted_identifiers.append( + tuple(col.snowflake_quoted_identifier for col in matched_columns) + ) + + return snowflake_quoted_identifiers + + def parse_levels_to_integer_levels( + self, levels: IndexLabel, allow_duplicates: bool, axis: int = 0 + ) -> list[int]: + """ + Returns a list of integers representing levels in Index object on given axis. + + Args: + levels: IndexLabel, can be int, level name, or sequence of such. + allow_duplicates: whether allow duplicated levels in the result. When False, the result will not + contain any duplicated levels. Otherwise, the result will contain duplicated level number if + different level value is mapped to the same level number. + axis: DataFrame axis, given levels belong to. Defaults to 0. Allowed values + are 0 or 1. 
+ Returns: + List[int] + A list of integers corresponding to the index levels for the given level, and in the same + order as given level + """ + num_level = self.num_index_levels(axis=axis) + if levels is not None: + if not isinstance(levels, (tuple, list)): + levels = [levels] + result = [] + for key in levels: + if isinstance(key, int): + error_message = f"Too many levels: Index has only {num_level} level{'s' if num_level > 1 else ''}" + # when key < 0, raise IndexError if key < -num_level as native pandas does + # set key to a positive number as native pandas does + if key < 0: + key = key + num_level + if key < 0: + raise IndexError( + f"{error_message}, {key - num_level} is not a valid level number" + ) + # when key > num_level - 1, raise IndexError as native pandas does + elif key > num_level - 1: # level starts from 0 + raise IndexError(f"{error_message}, not {key + 1}") + elif isinstance(key, str): # get level number from label + try: + if axis == 0: + key = self.index_column_pandas_labels.index(key) + else: + key = self.data_column_pandas_index_names.index(key) + # if key doesn't exist, a ValueError will be raised + except ValueError: + if num_level > 1: + raise KeyError(f"Level {key} not found") + else: + raise KeyError( + f"Requested level ({key}) does not match index name ({self.index_column_pandas_labels[0]})" + ) + # do not add key in the result if the key is already in the result and duplication is not allowed + if (key not in result) or allow_duplicates: + result.append(key) + else: + result = list(range(num_level)) + return result + + def get_pandas_labels_for_levels(self, levels: list[int]) -> list[Hashable]: + """ + Get the list of corresponding pandas labels for a list of given integer + Index levels. + Note: duplication in levels is allowed. + """ + return [self.index_column_pandas_labels[level] for level in levels] + + def get_snowflake_identifiers_for_levels(self, levels: list[int]) -> list[str]: + """ + Get the list of corresponding Snowflake identifiers for a list of given integer index levels. + + Note: duplication in levels is allowed. + """ + return [ + self.index_column_snowflake_quoted_identifiers[level] for level in levels + ] + + def get_snowflake_identifiers_and_pandas_labels_from_levels( + self, levels: list[int] + ) -> tuple[list[Hashable], list[str], list[Hashable], list[str]]: + """ + Selects snowflake identifiers and pandas labels from index columns in `levels`. + Also returns snowflake identifiers and pandas labels not in `levels`. + + Args: + levels: A list of integers represents levels in pandas Index. + + Returns: + A tuple contains 4 lists: + 1. The first list contains snowflake identifiers of index columns in `levels`. + 2. The second list contains pandas labels of index columns in `levels`. + 3. The third list contains snowflake identifiers of index columns not in `levels`. + 4. The fourth list contains pandas labels of index columns not in `levels`. 
+ """ + index_column_pandas_labels_in_levels = [] + index_column_snowflake_quoted_identifiers_in_levels = [] + index_column_pandas_labels_not_in_levels = [] + index_column_snowflake_quoted_identifiers_not_in_levels = [] + for idx, (identifier, label) in enumerate( + zip( + self.index_column_snowflake_quoted_identifiers, + self.index_column_pandas_labels, + ) + ): + if idx in levels: + index_column_pandas_labels_in_levels.append(label) + index_column_snowflake_quoted_identifiers_in_levels.append(identifier) + else: + index_column_pandas_labels_not_in_levels.append(label) + index_column_snowflake_quoted_identifiers_not_in_levels.append( + identifier + ) + + return ( + index_column_pandas_labels_in_levels, + index_column_snowflake_quoted_identifiers_in_levels, + index_column_pandas_labels_not_in_levels, + index_column_snowflake_quoted_identifiers_not_in_levels, + ) + + @functools.cached_property + def num_rows(self) -> int: + """ + Returns: + Number of rows in this frame. + """ + return count_rows(self.ordered_dataframe) + + def has_unique_index(self, axis: Optional[int] = 0) -> bool: + """ + Returns true if index has unique values on specified axis. + Args: + axis: {0, 1} defaults to 0 + + Returns: + True if index has unique values on specified axis, otherwise returns False. + + """ + if axis == 1: + return self.data_columns_index.is_unique + else: + # Note: We can't use 'count_distinct' because it ignores null values. + total_rows = self.num_rows + distinct_rows = count_rows( + get_distinct_rows( + self.ordered_dataframe.select( + self.index_column_snowflake_quoted_identifiers + ) + ) + ) + return total_rows == distinct_rows + + def validate_no_duplicated_data_columns_mapped_for_labels( + self, + pandas_labels: list[Hashable], + user_frame_identifier: Optional[str] = None, + ) -> None: + """ + For a given set of pandas labels, verify that there are no multiple data columns in the frame + mapped to the same label in the `pandas_labels`. + + Args: + pandas_labels: set of pandas labels to check for duplicated column mappings + user_frame_identifier: the identifier for the frame that is used in the error message to help user to + identify which input frame has error. For example, it can be 'condition' or 'other' frame for + where API. + Raises: + ValueError: if for a pandas label, there exists more than one data columns in the given frame mapped to the label. + """ + label_identifiers_list = ( + self.get_snowflake_quoted_identifiers_group_by_pandas_labels( + pandas_labels=pandas_labels, include_index=False + ) + ) + labels_with_duplication = [ + pandas_labels[i] + for (i, label_identifiers_tuple) in enumerate(label_identifiers_list) + if len(label_identifiers_tuple) > 1 + ] + if len(labels_with_duplication) > 0: + # The error message raised under duplication cases is different from native pandas. + # Native pandas raises ValueError with message "cannot reindex on an axis with duplicate labels" + # for duplication occurs in the condition frame, and raises InvalidIndexError with no message for + # duplication occurs in other frame. + # Snowpark pandas gives a clear message to the customer about what is the problem with the dataframe. + message = f"Multiple columns are mapped to each label in {labels_with_duplication} in DataFrame" + if user_frame_identifier is not None: + message += f" {user_frame_identifier}" + raise ValueError(message) + + ########################################################################### + # START: Internal Frame mutation APIs. 
+ # APIs that creates a new InternalFrame instance, should only be added below + def ensure_row_position_column(self) -> "InternalFrame": + """ + Ensure row position column is computed for given internal frame. + + Returns: + A new InternalFrame instance with computed virtual index. + """ + return InternalFrame.create( + ordered_dataframe=self.ordered_dataframe.ensure_row_position_column(), + data_column_pandas_labels=self.data_column_pandas_labels, + data_column_snowflake_quoted_identifiers=self.data_column_snowflake_quoted_identifiers, + data_column_pandas_index_names=self.data_column_pandas_index_names, + index_column_pandas_labels=self.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers, + ) + + def ensure_row_count_column(self) -> "InternalFrame": + """ + Ensure row position column is computed for given internal frame. + + Returns: + A new InternalFrame instance with computed virtual index. + """ + return InternalFrame.create( + ordered_dataframe=self.ordered_dataframe.ensure_row_count_column(), + data_column_pandas_labels=self.data_column_pandas_labels, + data_column_snowflake_quoted_identifiers=self.data_column_snowflake_quoted_identifiers, + data_column_pandas_index_names=self.data_column_pandas_index_names, + index_column_pandas_labels=self.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers, + ) + + def append_column( + self, pandas_label: Hashable, value: SnowparkColumn + ) -> "InternalFrame": + """ + Append a column to this frame. The column is added at the end. For a frame with multiindex column, it + automatically fills the missing levels with None. For example, in a table with MultiIndex columns like + ("A", "col1"), ("A", "col2"), ("B", "col1"), ("B", "col2"), appending a count column "cnt" will produce + a column labelled ("cnt", None). + + Args: + pandas_label: pandas label for column to be inserted. + value: SnowparkColumn. + + Returns: + A new InternalFrame with new column. + """ + # +---------------+---------------+---------------+---------------+ +---------------+ + # | ("A", "col1") | ("A", "col2") | ("B", "col1") | ("B", "col2") | | "cnt" | + # +---------------+---------------+---------------+---------------+ + +---------------+ + # | . . . | . . . | . . . | . . . | | . . . | + # +---------------+---------------+---------------+---------------+ +---------------+ + # + # Appending a column "cnt" to the table below will produce the following table: + # +---------------+---------------+---------------+---------------+---------------+ + # | ("A", "col1") | ("A", "col2") | ("B", "col1") | ("B", "col2") | ("cnt", None) | + # +---------------+---------------+---------------+---------------+---------------+ + # | . . . | . . . | . . . | . . . | . . . | + # +---------------+---------------+---------------+---------------+---------------+ + + # Generate label for the column to be appended. + nlevels = self.num_index_levels(axis=1) + pandas_label = fill_missing_levels_for_pandas_label( + pandas_label, nlevels, 0, None + ) + + # Generate snowflake quoted identifier for new column to be added. 
+ new_column_identifier = ( + self.ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=[pandas_label], + )[0] + ) + new_ordered_dataframe = append_columns( + self.ordered_dataframe, new_column_identifier, value + ) + return InternalFrame.create( + ordered_dataframe=new_ordered_dataframe, + data_column_pandas_labels=self.data_column_pandas_labels + [pandas_label], + data_column_snowflake_quoted_identifiers=self.data_column_snowflake_quoted_identifiers + + [new_column_identifier], + data_column_pandas_index_names=self.data_column_pandas_index_names, + index_column_pandas_labels=self.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers, + ) + + def project_columns( + self, + pandas_labels: list[Hashable], + column_objects: list[SnowparkColumn], + ) -> "InternalFrame": + """ + Project new columns with column_objects as the new data columns for the new Internal Frame. + The original index columns, ordering columns and row position columns are still used as the + index columns, ordering columns and row position columns for the new Internal Frame. + + * Note that this is different with append column in the sense that the data columns of the + original data frame will not be part of the data columns of the result dataframe. The data + column of the result dataframe only contains the new projected data columns. + + Args: + pandas_labels: The pandas labels for the newly projected data columns + column_objects: the Snowpark columns used to project the new data columns + + Returns: + A new InternalFrame with the newly projected columns as data column + """ + new_column_identifiers = ( + self.ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=pandas_labels, + ) + ) + new_ordered_dataframe = append_columns( + self.ordered_dataframe, new_column_identifiers, column_objects + ) + return InternalFrame.create( + ordered_dataframe=new_ordered_dataframe, + data_column_pandas_labels=pandas_labels, + data_column_snowflake_quoted_identifiers=new_column_identifiers, + data_column_pandas_index_names=self.data_column_pandas_index_names, + index_column_pandas_labels=self.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers, + ) + + def rename_snowflake_identifiers( + self, old_to_new_identifiers: dict[str, str] + ) -> "InternalFrame": + """ + Rename columns for underlying ordered dataframe. + + Args: + old_to_new_identifiers: A dictionary from old to new identifiers name. + Identifiers which do not occur in dictionary are not renamed. + + Returns: + A new InternalFrame instance after rename. + + Raises: + KeyError if columns are not index or data column of the current internal frame. + """ + if not old_to_new_identifiers: + return self + + ordered_dataframe = self.ordered_dataframe + internal_frame_column_quoted_identifiers = ( + self.index_column_snowflake_quoted_identifiers + + self.data_column_snowflake_quoted_identifiers + ) + for old_id in old_to_new_identifiers: + if old_id not in internal_frame_column_quoted_identifiers: + raise KeyError( + f"Column not found: '{old_id}'." 
+ f" Internal frame has following data and index columns: {internal_frame_column_quoted_identifiers}" + ) + select_list = [] + any_column_to_rename = False + ordering_and_row_position_columns = ( + ordered_dataframe.ordering_column_snowflake_quoted_identifiers + + [ordered_dataframe.row_position_snowflake_quoted_identifier] + if ordered_dataframe.row_position_snowflake_quoted_identifier is not None + else [] + ) + for old_id in ordered_dataframe.projected_column_snowflake_quoted_identifiers: + # Alias to new identifier name if present in 'old_to_new_identifiers', + # otherwise leave unchanged. + new_id = old_to_new_identifiers.get(old_id, old_id) + if old_id == new_id: + # retain the original column + select_list.append(old_id) + else: + select_list.append(col(old_id).as_(new_id)) + # if the old column is part of the ordering or row position columns, retains the column + # as part of the projected columns. + if old_id in ordering_and_row_position_columns: + select_list.append(old_id) + + any_column_to_rename = any_column_to_rename or new_id != old_id + if not any_column_to_rename: + # This is possible when values in 'old_to_new_identifiers' are same as keys. + return self + ordered_dataframe = ordered_dataframe.select(select_list) + + def get_updated_identifiers(identifiers: list[str]) -> list[str]: + """ + Get the new identifier after rename, and if not exist in the rename map, + no rename happens, the original name is returned + + Args: + identifiers: List of identifiers to get updated identifiers. + + Returns: + A list of identifiers after rename, if not exist in the rename map, + original name is returned. + """ + return [old_to_new_identifiers.get(i, i) for i in identifiers] + + return InternalFrame.create( + ordered_dataframe=ordered_dataframe, + data_column_pandas_labels=self.data_column_pandas_labels, + data_column_snowflake_quoted_identifiers=get_updated_identifiers( + self.data_column_snowflake_quoted_identifiers + ), + data_column_pandas_index_names=self.data_column_pandas_index_names, + index_column_pandas_labels=self.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=get_updated_identifiers( + self.index_column_snowflake_quoted_identifiers + ), + ) + + def update_snowflake_quoted_identifiers_with_expressions( + self, + quoted_identifier_to_column_map: dict[str, SnowparkColumn], + ) -> UpdatedInternalFrameResult: + """ + Points Snowflake quoted identifiers to column expression given by `quoted_identifier_to_column_map`. + + This function takes a mapping from existing snowflake quoted identifiers to + new Snowpark column expressions and points the existing quoted identifiers to the + column expressions provided by the mapping. For optimization purposes, + existing expressions are kept as columns. This does not change pandas labels. + + The process involves the following steps: + + 1. Create a list of new snowflake quoted column identifiers from existing snowflake quoted + column identifiers (keys of `quoted_identifier_to_column_map`) to prevent naming conflicts. + 2. Append new Snowpark columns (values of `quoted_identifier_to_column_map`) + to the end of the Snowpark DataFrame with new snowflake quoted column identifiers + generated at step 1. + 3. 
Update index and data column identifiers in the internal frame, by replacing existing + snowflake quoted identifiers (keys of `quoted_identifier_to_column_map`) + with new snowflake quoted column identifiers created in step 1 + + Args: + quoted_identifier_to_column_map (Dict[str, SnowparkColumn]): A dictionary mapping + existing snowflake quoted identifiers to new Snowpark columns. + As keys of a dictionary, all snowflake column identifiers are unique here and + must be index columns and data columns in the original internal frame. + + Returns: + UpdatedInternalFrameResult: A tuple contaning the new InternalFrame with updated column references, and a mapping + of the old column ids to the new column ids. + + Raises: + ValueError if any key of quoted_identifier_to_column_map is not in the data or index columns of the internal frame. + + Example: + `update_snowflake_quoted_identifiers_with_expressions(quoted_identifier_to_column_map={'"A"' : lit(10), '"B"': col('"A"') + col('"B"')}).frame` + The internal frame has pandas labels ['pd_a', 'pd_b', 'pd_a'] (there can be duplicates), + mapping to the snowflake quoted identifiers ['"A"', '"B"', '"C"'], i.e. 'pd_a' -> "A", 'pd_b' -> "B", 'pd_a' -> "C". + Index column identifiers are ['"A"'] and data column identifiers are ['"B"', '"C"']. + Calling this function will now create new identifiers (and keep the old ones), so that + 'pd_a' -> lit(10), 'pd_b' -> col('"A"') + col('"B"'), 'pd_a' -> "C". + For this, the function generates new aliases, e.g. '"A2"' for lit(10), and '"B2"' for col('"A"') + col('"B"'). + Thus, after applying this function the snowpark dataframe backing this internal frame up has + ['"A"', '"B"', '"C"', '"A2"', '"B2"'] as quoted identifiers. + Index column identifiers become ['"A2"'] and data column identifiers are still ['"B2"', '"C"']. 
+ """ + # no-op + if not quoted_identifier_to_column_map: + return UpdatedInternalFrameResult(self, {}) + + all_data_index_identifiers = set( + self.data_column_snowflake_quoted_identifiers + + self.index_column_snowflake_quoted_identifiers + ) + for identifier in quoted_identifier_to_column_map: + if identifier not in all_data_index_identifiers: + raise ValueError(f"{identifier} is not in {all_data_index_identifiers}") + + existing_id_to_new_id_mapping = {} + new_columns = [] + for ( + existing_identifier, + column_expression, + ) in quoted_identifier_to_column_map.items(): + new_identifier = ( + self.ordered_dataframe.generate_snowflake_quoted_identifiers( + pandas_labels=[ + extract_pandas_label_from_snowflake_quoted_identifier( + existing_identifier + ) + ], + )[0] + ) + existing_id_to_new_id_mapping[existing_identifier] = new_identifier + new_columns.append(column_expression) + new_ordered_dataframe = append_columns( + self.ordered_dataframe, + list(existing_id_to_new_id_mapping.values()), + new_columns, + ) + # update index_column_snowflake_quoted_identifiers and data_column_snowflake_quoted_identifiers + # the order of index/data_column_snowflake_quoted_identifiers is not changed so we can still + # keep the correct mapping between quoted identifiers and pandas labels + new_index_column_snowflake_quoted_identifiers = [ + existing_id_to_new_id_mapping.get(identifier, identifier) + for identifier in self.index_column_snowflake_quoted_identifiers + ] + new_data_column_snowflake_quoted_identifiers = [ + existing_id_to_new_id_mapping.get(identifier, identifier) + for identifier in self.data_column_snowflake_quoted_identifiers + ] + + return UpdatedInternalFrameResult( + InternalFrame.create( + ordered_dataframe=new_ordered_dataframe, + data_column_pandas_labels=self.data_column_pandas_labels, + data_column_snowflake_quoted_identifiers=new_data_column_snowflake_quoted_identifiers, + data_column_pandas_index_names=self.data_column_pandas_index_names, + index_column_pandas_labels=self.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=new_index_column_snowflake_quoted_identifiers, + ), + existing_id_to_new_id_mapping, + ) + + def apply_snowpark_function_to_data_columns( + self, snowpark_func: Callable[[Any], SnowparkColumn] + ) -> "InternalFrame": + """ + Apply snowpark function callable to data columns of an InternalFrame. The snowflake quoted identifiers + are preserved. + + Arguments: + snowpark_func: Snowpark function to apply to data columns of underlying snowpark df. + + Returns: + InternalFrame with snowpark_func applies to data columns of original frame, all other columns remain unchanged. + """ + new_internal_frame = self.update_snowflake_quoted_identifiers_with_expressions( + { + snowflake_quoted_identifier: snowpark_func(snowflake_quoted_identifier) + for snowflake_quoted_identifier in self.data_column_snowflake_quoted_identifiers + } + ).frame + return new_internal_frame + + def select_active_columns(self) -> "InternalFrame": + """ + Select active columns of the current internal frame, the active columns include index + data columns, + ordering columns and row position column if exists. This function is used to re-project all active columns + in the ordered dataframe, and drop off unnecessary columns from the projected columns of the ordered dataframe. 
+
+        Returns:
+            A new InternalFrame whose associated ordered dataframe contains the following projected columns:
+                1) index + data columns
+                2) ordering columns
+                3) row position column if it exists
+
+        """
+        active_column_quoted_identifiers = (
+            self.index_column_snowflake_quoted_identifiers
+            + self.data_column_snowflake_quoted_identifiers
+        )
+        # add the missing ordering columns
+        active_column_quoted_identifiers += [
+            quoted_identifier
+            for quoted_identifier in self.ordering_column_snowflake_quoted_identifiers
+            if quoted_identifier not in active_column_quoted_identifiers
+        ]
+
+        if (
+            self.row_position_snowflake_quoted_identifier is not None
+            and self.row_position_snowflake_quoted_identifier
+            not in active_column_quoted_identifiers
+        ):
+            active_column_quoted_identifiers.append(
+                self.row_position_snowflake_quoted_identifier
+            )
+
+        return InternalFrame.create(
+            ordered_dataframe=self.ordered_dataframe.select(
+                active_column_quoted_identifiers
+            ),
+            index_column_pandas_labels=self.index_column_pandas_labels,
+            index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers,
+            data_column_pandas_labels=self.data_column_pandas_labels,
+            data_column_snowflake_quoted_identifiers=self.data_column_snowflake_quoted_identifiers,
+            data_column_pandas_index_names=self.data_column_pandas_index_names,
+        )
+
+    def strip_duplicates(
+        self: "InternalFrame", quoted_identifiers: list[str]
+    ) -> "InternalFrame":
+        """
+        When assigning frames via index operations with duplicate keys, only the last entry is used, as earlier entries are repeatedly overwritten.
+        For example, writing a series to a key [0, 1, 0] with values [1, 2, 3] will put value 2 at position 1 and value 3 at position 0.
+        This function strips the preceding index/value rows to emulate repeated writes.
+
+        Args:
+            quoted_identifiers: the column identifiers to use for creating individual groups from which to take the last element.
+
+        Returns:
+            new internal frame with unique index.
+        """
+
+        frame = self.ensure_row_position_column()
+
+        # To remove the duplicates, first compute, via a window over the index columns, the last row position
+        # per group, then join on it and select only the relevant rows. Note that an EXISTS subquery doesn't work here because
+        # Snowflake fails with a non-supported subquery expression error for LAST_VALUE.
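+        # As a rough illustration (hypothetical data): with an index column idx of [0, 1, 0] at row
+        # positions [0, 1, 2] and values [1, 2, 3], LAST_VALUE(pos) OVER (PARTITION BY idx ORDER BY pos)
+        # yields positions {2, 1}, so key 0 keeps value 3 and key 1 keeps value 2. The query shape used
+        # below is roughly: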
+        # SELECT a.* EXCLUDE (pos) FROM df a JOIN (SELECT DISTINCT LAST_VALUE(pos) OVER
+        # (PARTITION BY (idx, other_idx) ORDER BY pos) AS pos FROM df) b ON a.pos = b.pos;
+
+        assert len(quoted_identifiers) == len(
+            set(quoted_identifiers)
+            & set(frame.ordered_dataframe.projected_column_snowflake_quoted_identifiers)
+        ), "could not find all quoted identifiers in frame"
+
+        relevant_last_value_row_positions_quoted_identifier = (
+            frame.ordered_dataframe.generate_snowflake_quoted_identifiers(
+                pandas_labels=[ROW_POSITION_COLUMN_LABEL],
+            )[0]
+        )
+
+        relevant_last_value_row_positions = get_distinct_rows(
+            frame.ordered_dataframe.select(
+                last_value(col(frame.row_position_snowflake_quoted_identifier))
+                .over(
+                    Window.partition_by(quoted_identifiers).order_by(
+                        frame.row_position_snowflake_quoted_identifier
+                    )
+                )
+                .as_(relevant_last_value_row_positions_quoted_identifier)
+            )
+        )
+
+        joined_ordered_dataframe = frame.ordered_dataframe.join(
+            right=relevant_last_value_row_positions,
+            left_on_cols=[frame.row_position_snowflake_quoted_identifier],
+            right_on_cols=[relevant_last_value_row_positions_quoted_identifier],
+            how="inner",
+        )
+
+        # Because we reuse the row position to select the relevant columns, we need to
+        # generate a new row position column here so that locational indexing after this operation
+        # continues to work correctly.
+        new_ordered_dataframe = joined_ordered_dataframe.ensure_row_position_column()
+        return InternalFrame.create(
+            ordered_dataframe=new_ordered_dataframe,
+            data_column_pandas_labels=frame.data_column_pandas_labels,
+            data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers,
+            data_column_pandas_index_names=frame.data_column_pandas_index_names,
+            index_column_pandas_labels=frame.index_column_pandas_labels,
+            index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers,
+        )
+
+    def filter(
+        self: "InternalFrame", expr: Union[SnowparkColumn, str]
+    ) -> "InternalFrame":
+        """
+        A helper method to apply a filter on the internal frame.
+        Args:
+            expr: the expression of the filter
+
+        Returns:
+            The internal frame after filtering
+        """
+        return InternalFrame.create(
+            ordered_dataframe=self.ordered_dataframe.filter(expr),
+            data_column_pandas_labels=self.data_column_pandas_labels,
+            data_column_snowflake_quoted_identifiers=self.data_column_snowflake_quoted_identifiers,
+            data_column_pandas_index_names=self.data_column_pandas_index_names,
+            index_column_pandas_labels=self.index_column_pandas_labels,
+            index_column_snowflake_quoted_identifiers=self.index_column_snowflake_quoted_identifiers,
+        )
+
+    def normalize_snowflake_quoted_identifiers_with_pandas_label(
+        self,
+    ) -> "InternalFrame":
+        """
+        Normalize the snowflake quoted identifiers for index and data columns based on the pandas label to make sure
+        the quoted identifier is in format of