diff --git a/CHANGELOG.md b/CHANGELOG.md index bfefa57d2..a15174f7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## 0.10.7dev +* [Feature] Add Spark Connection as a dialect for JupySQL ([#965](https://github.com/ploomber/jupysql/issues/965)) (by [@gilandose](https://github.com/gilandose)) + ## 0.10.6 (2023-12-21) * [Fix] Fix error when `%sql` includes a query with negative numbers ([#958](https://github.com/ploomber/jupysql/issues/958)) diff --git a/doc/_toc.yml b/doc/_toc.yml index 2d9850b10..f667e807b 100644 --- a/doc/_toc.yml +++ b/doc/_toc.yml @@ -43,6 +43,7 @@ parts: - file: integrations/duckdb-native - file: integrations/compatibility - file: integrations/chdb + - file: integrations/spark - caption: API Reference chapters: diff --git a/doc/api/configuration.md b/doc/api/configuration.md index e2bb114a5..254ea712a 100644 --- a/doc/api/configuration.md +++ b/doc/api/configuration.md @@ -234,6 +234,26 @@ value enables the ones from previous values plus new ones: - `2`: All feedback - Footer to distinguish pandas/polars data frames from JupySQL's result sets +## `lazy_execution` + +```{versionadded} 0.10.7 +This option only works when connecting to Spark +``` + +Default: `False` + +Return a lazy relation to the dataset (a Spark DataFrame) rather than executing the query and returning a JupySQL result set.
+ +```{code-cell} ipython3 +%config SqlMagic.lazy_execution = True +df = %sql SELECT * FROM languages +``` + +```{code-cell} ipython3 +%config SqlMagic.lazy_execution = False +res = %sql SELECT * FROM languages +``` + ## `named_parameters` ```{versionadded} 0.9 diff --git a/doc/conf.py b/doc/conf.py index 5e1792154..39080d77d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -27,6 +27,7 @@ "integrations/oracle.ipynb", "integrations/snowflake.ipynb", "integrations/redshift.ipynb", + "integrations/spark.ipynb", ] nb_execution_in_temp = True nb_execution_show_tb = True diff --git a/doc/integrations/compatibility.md b/doc/integrations/compatibility.md index 4e6b36432..d59760a98 100644 --- a/doc/integrations/compatibility.md +++ b/doc/integrations/compatibility.md @@ -114,4 +114,20 @@ These table reflects the compatibility status of JupySQL `>=0.7` - Listing tables with `%sqlcmd tables` ✅ - Listing columns with `%sqlcmd columns` ✅ - Parametrized SQL queries via `{{parameter}}` ✅ -- Interactive SQL queries via `--interact` ✅ \ No newline at end of file +- Interactive SQL queries via `--interact` ✅ + +## Spark + +- Running queries with `%%sql` ✅ +- CTEs with `%%sql --save NAME` ✅ +- Plotting with `%%sqlplot boxplot` ❓ +- Plotting with `%%sqlplot bar` ✅ +- Plotting with `%%sqlplot pie` ✅ +- Plotting with `%%sqlplot histogram` ✅ +- Plotting with `ggplot` ✅ +- Profiling tables with `%sqlcmd profile` ✅ +- Listing tables with `%sqlcmd tables` ❌ +- Listing columns with `%sqlcmd columns` ❌ +- Parametrized SQL queries via `{{parameter}}` ✅ +- Interactive SQL queries via `--interact` ✅ +- Persisting Dataframes via `--persist` ✅ \ No newline at end of file diff --git a/doc/integrations/spark.ipynb b/doc/integrations/spark.ipynb new file mode 100644 index 000000000..4f150500d --- /dev/null +++ b/doc/integrations/spark.ipynb @@ -0,0 +1,1399 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Spark\n", + "\n", + "This tutorial will 
show you how to get a Spark instance up and running locally to integrate with JupySQL. You can run this in a Jupyter notebook. We'll use [Spark Connect](https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_connect.html), which is the new thin client for Spark." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pre-requisites\n", + "\n", + "To run this tutorial, you need to install the following Python packages:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install jupysql pyspark==3.4.1 arrow pyarrow==12.0.1 pandas grpcio-status --quiet" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Start Spark instance\n", + "\n", + "We fetch the `wh1isper/sparglim-server` image and start a Spark Connect server in a Docker container (this will take a few seconds)." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12f699ee8e8e35ab10186f3c39024a7e443691bb4213e56ca3c2e90cd80daf1b\n" + ] + } + ], + "source": [ + "%%bash\n", + "docker run -p 15002:15002 -p 4040:4040 -d --name spark wh1isper/sparglim-server" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our Spark instance is running, let's load some data!" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load sample data\n", + "\n", + "Now, let's fetch some sample data.
We'll be using the [NYC taxi dataset](https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page):" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from pyspark.sql.connect.session import SparkSession\n", + "\n", + "spark = SparkSession.builder.remote(\"sc://localhost\").getOrCreate()\n", + "\n", + "df = pd.read_parquet(\n", + " \"https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2021-01.parquet\"\n", + ")\n", + "sparkDf = spark.createDataFrame(df.head(10000))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set [eagerEval](https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html#Viewing-Data) on to print dataframes. This makes Spark print dataframes eagerly in notebook environments, rather than its default lazy execution, which requires `.show()` to see the data. In Spark 3.4.1 we need to override the pretty-printer, as below, but in 3.5.0 it will print in HTML." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def __pretty_(self, p, cycle):\n", + " self.show(truncate=False)\n", + "\n", + "\n", + "from pyspark.sql.connect.dataframe import DataFrame\n", + "\n", + "DataFrame._repr_pretty_ = __pretty_\n", + "spark.conf.set(\"spark.sql.repl.eagerEval.enabled\", True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Add dataset to temporary view to allow querying:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "sparkDf.createOrReplaceTempView(\"taxi\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Query\n", + "\n", + "Now, let's start JupySQL, connect to Spark, and query the data!"
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The sql extension is already loaded. To reload it, use:\n", + " %reload_ext sql\n" + ] + } + ], + "source": [ + "%load_ext sql" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%sql spark" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{important}\n", + "If the cell above fails, you might have some missing packages. Message us on [Slack](https://ploomber.io/community) and we'll help you!\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "List the tables in the database:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
namespaceviewNameisTemporary
taxiTrue
" + ], + "text/plain": [ + "+-----------+----------+-------------+\n", + "| namespace | viewName | isTemporary |\n", + "+-----------+----------+-------------+\n", + "| | taxi | True |\n", + "+-----------+----------+-------------+" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sql show views in default" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can turn on `lazy_execution` to avoid executing the Spark plan and return a Spark DataFrame instead" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "%config SqlMagic.lazy_execution = True" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+---------+--------+-----------+\n", + "|namespace|viewName|isTemporary|\n", + "+---------+--------+-----------+\n", + "| |taxi |true |\n", + "+---------+--------+-----------+\n", + "\n" + ] + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sql show views in default" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "%config SqlMagic.lazy_execution = False" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "List columns in the taxi table:" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { +
"name": "stdout", + "output_type": "stream", + "text": [ + "root\n", + " |-- VendorID: long (nullable = true)\n", + " |-- tpep_pickup_datetime: timestamp (nullable = true)\n", + " |-- tpep_dropoff_datetime: timestamp (nullable = true)\n", + " |-- passenger_count: double (nullable = true)\n", + " |-- trip_distance: double (nullable = true)\n", + " |-- RatecodeID: double (nullable = true)\n", + " |-- store_and_fwd_flag: string (nullable = true)\n", + " |-- PULocationID: long (nullable = true)\n", + " |-- DOLocationID: long (nullable = true)\n", + " |-- payment_type: long (nullable = true)\n", + " |-- fare_amount: double (nullable = true)\n", + " |-- extra: double (nullable = true)\n", + " |-- mta_tax: double (nullable = true)\n", + " |-- tip_amount: double (nullable = true)\n", + " |-- tolls_amount: double (nullable = true)\n", + " |-- improvement_surcharge: double (nullable = true)\n", + " |-- total_amount: double (nullable = true)\n", + " |-- congestion_surcharge: double (nullable = true)\n", + " |-- airport_fee: double (nullable = true)\n", + "\n" + ] + } + ], + "source": [ + "df = %sql select * from taxi\n", + "df.sqlaproxy.dataframe.printSchema()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Query our data:" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
count(1)
10000
" + ], + "text/plain": [ + "+----------+\n", + "| count(1) |\n", + "+----------+\n", + "| 10000 |\n", + "+----------+" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT COUNT(*) FROM taxi" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parameterize queries" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "threshold = 10" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
count(1)
9476
" + ], + "text/plain": [ + "+----------+\n", + "| count(1) |\n", + "+----------+\n", + "| 9476 |\n", + "+----------+" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT COUNT(*) FROM taxi\n", + "WHERE trip_distance < {{threshold}}" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "threshold = 0.5" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
count(1)
642
" + ], + "text/plain": [ + "+----------+\n", + "| count(1) |\n", + "+----------+\n", + "| 642 |\n", + "+----------+" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT COUNT(*) FROM taxi\n", + "WHERE trip_distance < {{threshold}}" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## CTEs" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Skipping execution..." + ], + "text/plain": [ + "Skipping execution..." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%sql --save many_passengers --no-execute\n", + "SELECT *\n", + "FROM taxi\n", + "WHERE passenger_count > 3\n", + "-- remove top 1% outliers for better visualization\n", + "AND trip_distance < 18.93" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
min(trip_distance)avg(trip_distance)max(trip_distance)
0.03.109138187221396318.46
" + ], + "text/plain": [ + "+--------------------+--------------------+--------------------+\n", + "| min(trip_distance) | avg(trip_distance) | max(trip_distance) |\n", + "+--------------------+--------------------+--------------------+\n", + "| 0.0 | 3.1091381872213963 | 18.46 |\n", + "+--------------------+--------------------+--------------------+" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql --save trip_stats --with many_passengers\n", + "SELECT MIN(trip_distance), AVG(trip_distance), MAX(trip_distance)\n", + "FROM many_passengers" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is what JupySQL executes:" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WITH `many_passengers` AS (\n", + "SELECT *\n", + "FROM taxi\n", + "WHERE passenger_count > 3\n", + "\n", + "AND trip_distance < 18.93)\n", + "SELECT MIN(trip_distance), AVG(trip_distance), MAX(trip_distance)\n", + "FROM many_passengers\n" + ] + } + ], + "source": [ + "query = %sqlcmd snippets trip_stats\n", + "print(query)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Profiling" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
Following statistics are not available in\n", + " SparkSession: STD, 25%, 50%, 75%
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
VendorIDtpep_pickup_datetimetpep_dropoff_datetimepassenger_counttrip_distanceRatecodeIDstore_and_fwd_flagPULocationIDDOLocationIDpayment_typefare_amountextramta_taxtip_amounttolls_amountimprovement_surchargetotal_amountcongestion_surchargeairport_fee
count1000010000100001000010000100001000010000100001000010000100001000010000100001000010000100000
unique287668745712436217323042288350418395930
topnan2021-01-01 00:41:192021-01-02 00:00:00nannannanNnannannannannannannannannannannanNone
freqnan47nannannan9808nannannannannannannannannannannan0
mean1.6901nannan1.50803.10021.0712nan158.5551154.72961.381911.88220.82590.48641.78460.22460.294516.96962.1063nan
std0.4625nannan1.13543.59701.0755nan70.928875.25040.555210.84201.11670.10412.43511.27300.057012.50230.9562nan
min1nannan0.00.01.0nan111-100.0-0.5-0.5-1.07-6.12-0.3-100.3-2.5nan
25%1.0000nannan1.00001.04001.0000nan100.000083.00001.00006.00000.00000.50000.00000.00000.300010.30002.5000nan
50%2.0000nannan1.00001.93001.0000nan152.0000151.00001.00008.50000.50000.50001.54000.00000.300013.55002.5000nan
75%2.0000nannan2.00003.60001.0000nan234.0000234.00002.000013.50002.50000.50002.65000.00000.300019.30002.5000nan
max2nannan6.045.9299.0nan2652654121.03.50.580.025.50.3137.762.5nan
" + ], + "text/plain": [ + "+--------+----------+----------------------+-----------------------+-----------------+---------------+------------+--------------------+--------------+--------------+--------------+-------------+--------+---------+------------+--------------+-----------------------+--------------+----------------------+-------------+\n", + "| | VendorID | tpep_pickup_datetime | tpep_dropoff_datetime | passenger_count | trip_distance | RatecodeID | store_and_fwd_flag | PULocationID | DOLocationID | payment_type | fare_amount | extra | mta_tax | tip_amount | tolls_amount | improvement_surcharge | total_amount | congestion_surcharge | airport_fee |\n", + "+--------+----------+----------------------+-----------------------+-----------------+---------------+------------+--------------------+--------------+--------------+--------------+-------------+--------+---------+------------+--------------+-----------------------+--------------+----------------------+-------------+\n", + "| count | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 0 |\n", + "| unique | 2 | 8766 | 8745 | 7 | 1243 | 6 | 2 | 173 | 230 | 4 | 228 | 8 | 3 | 504 | 18 | 3 | 959 | 3 | 0 |\n", + "| top | nan | 2021-01-01 00:41:19 | 2021-01-02 00:00:00 | nan | nan | nan | N | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | None |\n", + "| freq | nan | 4 | 7 | nan | nan | nan | 9808 | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | 0 |\n", + "| mean | 1.6901 | nan | nan | 1.5080 | 3.1002 | 1.0712 | nan | 158.5551 | 154.7296 | 1.3819 | 11.8822 | 0.8259 | 0.4864 | 1.7846 | 0.2246 | 0.2945 | 16.9696 | 2.1063 | nan |\n", + "| std | 0.4625 | nan | nan | 1.1354 | 3.5970 | 1.0755 | nan | 70.9288 | 75.2504 | 0.5552 | 10.8420 | 1.1167 | 0.1041 | 2.4351 | 1.2730 | 0.0570 | 12.5023 | 0.9562 | nan |\n", + "| min | 1 | nan | nan | 0.0 | 0.0 | 1.0 | nan | 1 | 1 | 1 | -100.0 | -0.5 | 
-0.5 | -1.07 | -6.12 | -0.3 | -100.3 | -2.5 | nan |\n", + "| 25% | 1.0000 | nan | nan | 1.0000 | 1.0400 | 1.0000 | nan | 100.0000 | 83.0000 | 1.0000 | 6.0000 | 0.0000 | 0.5000 | 0.0000 | 0.0000 | 0.3000 | 10.3000 | 2.5000 | nan |\n", + "| 50% | 2.0000 | nan | nan | 1.0000 | 1.9300 | 1.0000 | nan | 152.0000 | 151.0000 | 1.0000 | 8.5000 | 0.5000 | 0.5000 | 1.5400 | 0.0000 | 0.3000 | 13.5500 | 2.5000 | nan |\n", + "| 75% | 2.0000 | nan | nan | 2.0000 | 3.6000 | 1.0000 | nan | 234.0000 | 234.0000 | 2.0000 | 13.5000 | 2.5000 | 0.5000 | 2.6500 | 0.0000 | 0.3000 | 19.3000 | 2.5000 | nan |\n", + "| max | 2 | nan | nan | 6.0 | 45.92 | 99.0 | nan | 265 | 265 | 4 | 121.0 | 3.5 | 0.5 | 80.0 | 25.5 | 0.3 | 137.76 | 2.5 | nan |\n", + "+--------+----------+----------------------+-----------------------+-----------------+---------------+------------+--------------------+--------------+--------------+--------------+-------------+--------+---------+------------+--------------+-----------------------+--------------+----------------------+-------------+" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sqlcmd profile -t taxi" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plotting" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%sqlplot histogram --table taxi --column trip_distance --bins 10" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%sqlplot boxplot --table taxi --column trip_distance" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Removing NULLs, if there exists any from payment_type" + ], + "text/plain": [ + "Removing NULLs, if there exists any from payment_type" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%sqlplot bar --table taxi --column payment_type" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Removing NULLs, if there exists any from payment_type" + ], + "text/plain": [ + "Removing NULLs, if there exists any from payment_type" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%sqlplot pie --table taxi --column payment_type" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "from sql.ggplot import ggplot, aes, geom_histogram" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "(ggplot(table=\"taxi\", mapping=aes(x=\"trip_distance\")) + geom_histogram(bins=10))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clean up\n", + "\n", + "To stop and remove the container:" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES\n", + "12f699ee8e8e wh1isper/sparglim-server \"tini -- sparglim-se…\" About a minute ago Up About a minute 0.0.0.0:4040->4040/tcp, 0.0.0.0:15002->15002/tcp spark\n", + "f019407c6426 docker.dev.slicelife.com/onelogin-aws-assume-role:stable \"onelogin-aws-assume…\" 2 weeks ago Up 2 weeks heuristic_tu\n" + ] + } + ], + "source": [ + "! docker container ls" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%capture out\n", + "! docker container ls --filter ancestor=wh1isper/sparglim-server --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Container id: 12f699ee8e8e\n" + ] + } + ], + "source": [ + "container_id = out.stdout.strip()\n", + "print(f\"Container id: {container_id}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12f699ee8e8e\n" + ] + } + ], + "source": [ + "! docker container stop {container_id}" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12f699ee8e8e\n" + ] + } + ], + "source": [ + "! 
docker container rm {container_id}" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "myst": { + "html_meta": { + "description lang=en": "Query using Spark SQL from Jupyter via JupySQL", + "keywords": "jupyter, sql, jupysql, spark", + "property=og:locale": "en_US" + } + }, + "vscode": { + "interpreter": { + "hash": "8de7291ac4f217ed756f77e1d71d41823fff9c4ffb13df0a183e9309929ad9aa" + } + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/noxfile.py b/noxfile.py index 92c2b8fb6..5a62c0f26 100644 --- a/noxfile.py +++ b/noxfile.py @@ -35,6 +35,8 @@ "pyodbc==4.0.34", "sqlalchemy-pytds", "python-tds", + "pyspark>=3.4.1", + "grpcio-status", ] diff --git a/setup.py b/setup.py index e1ba5911a..1c03419f8 100644 --- a/setup.py +++ b/setup.py @@ -71,6 +71,9 @@ "redshift-connector", "sqlalchemy-redshift", "clickhouse-sqlalchemy", + # following two dependencies required for spark + "pyspark", + "grpcio-status", ] setup( diff --git a/src/sql/_testing.py b/src/sql/_testing.py index 14a5e9675..041994e28 100644 --- a/src/sql/_testing.py +++ b/src/sql/_testing.py @@ -210,6 +210,10 @@ def get_tmp_dir(): "docker_ct": None, "query": {}, }, + "spark": { + "alias": "SparkSession", + "drivername": "SparkSession", + }, "clickhouse": { "drivername": "clickhouse+native", "username": "username", diff --git a/src/sql/command.py b/src/sql/command.py index 7b7bd168b..675420d74 100644 --- a/src/sql/command.py +++ b/src/sql/command.py @@ -5,7 +5,7 @@ from sql import parse, exceptions from sql.store import store -from 
sql.connection import ConnectionManager, is_pep249_compliant +from sql.connection import ConnectionManager, is_pep249_compliant, is_spark from sql.util import validate_nonidentifier_connection @@ -49,7 +49,11 @@ def __init__(self, magic, user_ns, line, cell) -> None: if ( one_arg and self.args.line[0] in user_ns - and (isinstance(user_ns[self.args.line[0]], Engine) or is_dbapi_connection_) + and ( + isinstance(user_ns[self.args.line[0]], Engine) + or is_dbapi_connection_ + or is_spark(user_ns[self.args.line[0]]) + ) ): line_for_command = [] add_conn = True diff --git a/src/sql/connection/__init__.py b/src/sql/connection/__init__.py index 7c48e624b..4d9dfb10a 100644 --- a/src/sql/connection/__init__.py +++ b/src/sql/connection/__init__.py @@ -2,7 +2,9 @@ ConnectionManager, SQLAlchemyConnection, DBAPIConnection, + SparkConnectConnection, is_pep249_compliant, + is_spark, PLOOMBER_DOCS_LINK_STR, default_alias_for_engine, ResultSetCollection, @@ -14,7 +16,9 @@ "ConnectionManager", "SQLAlchemyConnection", "DBAPIConnection", + "SparkConnectConnection", "is_pep249_compliant", + "is_spark", "PLOOMBER_DOCS_LINK_STR", "default_alias_for_engine", "ResultSetCollection", diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 678ca30de..fc4552aaa 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -16,6 +16,9 @@ InternalError, ProgrammingError, ) + +from sql.run.sparkdataframe import handle_spark_dataframe + from IPython.core.error import UsageError import sqlglot import sqlparse @@ -257,6 +260,10 @@ def set( ) elif is_pep249_compliant(descriptor): cls.current = DBAPIConnection(descriptor, config=config, alias=alias) + elif is_spark(descriptor): + cls.current = SparkConnectConnection( + descriptor, config=config, alias=alias + ) else: existing = rough_dict_get(cls.connections, descriptor) if existing and existing.alias == alias: @@ -1060,6 +1067,82 @@ def to_table(self, table_name, data_frame, if_exists, index, 
schema=None): ) +class SparkConnectConnection(AbstractConnection): + is_dbapi_connection = False + + @telemetry.log_call("SparkConnectConnection", payload=True) + def __init__(self, payload, connection, alias=None, config=None): + try: + payload["engine"] = type(connection) + except Exception as e: + payload["engine_parsing_error"] = str(e) + self._driver = None + + # TODO: implement the dialect blacklist and add unit tests + self._requires_manual_commit = True if config is None else config.autocommit + + self._connection = connection + self._connection_class_name = type(connection).__name__ + + # calling init from AbstractConnection must be the last thing we do as it + # register the connection + super().__init__(alias=alias or self._connection_class_name) + + self.name = self._connection_class_name + + @property + def dialect(self): + """Returns a string with the SQL dialect name""" + return "spark2" + + def raw_execute(self, query, parameters=None): + """Run the query without any pre-processing""" + return handle_spark_dataframe(self._connection.sql(query)) + + def _get_database_information(self): + """ + Get the dialect, driver, and database server version info of current + connection + """ + return { + "dialect": self.dialect, + "driver": self._connection_class_name, + "server_version_info": self._connection.version, + } + + @property + def url(self): + """Returns None since Spark connections don't have a url""" + return None + + @property + def connection_sqlalchemy(self): + """ + Raises NotImplementedError since Spark connections don't have a SQLAlchemy + connection object + """ + raise NotImplementedError( + "This feature is only available for SQLAlchemy connections" + ) + + def to_table(self, table_name, data_frame, if_exists, index, schema=None): + mode = ( + "overwrite" + if if_exists == "replace" + else "append" + if if_exists == "append" + else "error" + ) + self._connection.createDataFrame(data_frame).write.mode(mode).saveAsTable( + 
f"{schema}.{table_name}" if schema else table_name + ) + + def close(self): + """Override of the abstract close as SparkSession is usually + shared with pyspark""" + pass + + def _check_if_duckdb_dbapi_connection(conn): """Check if the connection is a native duckdb connection""" # NOTE: duckdb defines df and pl to efficiently convert results to @@ -1154,6 +1237,26 @@ def is_pep249_compliant(conn): return True +def is_spark(conn): + """Check if it is a SparkSession by checking for available methods""" + + sparksession_methods = [ + "table", + "read", + "createDataFrame", + "sql", + "stop", + "catalog", + "version", + ] + for method_name in sparksession_methods: + # Checking whether the connection object has the method + if not hasattr(conn, method_name): + return False + + return True + + def default_alias_for_engine(engine): if not engine.url.username: # keeping this for compatibility diff --git a/src/sql/error_handler.py b/src/sql/error_handler.py index 0eac9292a..cd1dfb3cd 100644 --- a/src/sql/error_handler.py +++ b/src/sql/error_handler.py @@ -50,6 +50,7 @@ def _detailed_message_with_error_type(error, query): "error in your sql syntax", "incorrect syntax", "invalid sql", + "syntax_error", ] not_found_substrings = [ r"(\btable with name\b).+(\bdoes not exist\b)", diff --git a/src/sql/magic.py b/src/sql/magic.py index d34d32e2c..17a2a2a49 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -147,6 +147,15 @@ class SqlMagic(Magics, Configurable): config=True, help="Verbosity level. 0=minimal, 1=normal, 2=all", ) + lazy_execution = Bool( + default_value=False, + config=True, + help="Whether to evaluate using ResultSet which will " + "cause the plan to execute or just return a lazily " + "executed plan allowing validating schemas, " + "without expensive compute." 
+ "Currently only supported for Spark Connection.", + ) named_parameters = Bool( default_value=False, config=True, diff --git a/src/sql/plot.py b/src/sql/plot.py index 10e7e7895..d3cecb759 100644 --- a/src/sql/plot.py +++ b/src/sql/plot.py @@ -268,10 +268,9 @@ def _min_max(conn, table, column, with_=None, use_backticks=False): """ if use_backticks: template_ = template_.replace('"', "`") - + table = table.replace('"', "`") template = Template(template_) query = template.render(table=table, column=column) - min_, max_ = conn.execute(query, with_).fetchone() return min_, max_ @@ -628,6 +627,7 @@ def _histogram( if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) @@ -663,6 +663,7 @@ def _histogram( if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) @@ -681,6 +682,7 @@ def _histogram( if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) @@ -835,6 +837,7 @@ def _bar(table, column, with_=None, conn=None): if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) query = template.render(table=table, x_=x_, height_=height_) @@ -854,6 +857,7 @@ def _bar(table, column, with_=None, conn=None): if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) query = template.render(table=table, column=column) @@ -1022,6 +1026,7 @@ def _pie(table, column, with_=None, conn=None): """ if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) query = template.render(table=table, labels_=labels_, size_=size_) @@ -1037,6 +1042,7 @@ def _pie(table, column, with_=None, conn=None): """ if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = 
Template(template_) query = template.render(table=table, column=column) diff --git a/src/sql/run/resultset.py b/src/sql/run/resultset.py index 8451150aa..4b977a8e2 100644 --- a/src/sql/run/resultset.py +++ b/src/sql/run/resultset.py @@ -434,7 +434,10 @@ def fetchmany(self, size): raise RuntimeError(f"Error running the query: {str(e)}") from e self.mark_fetching_as_done() return - + # spark doesn't support cursor + if hasattr(self._sqlaproxy, "dataframe"): + self._results = [] + self._pretty_table.clear() self._extend_results(returned) if len(returned) < size: @@ -458,6 +461,9 @@ def fetch_for_repr_if_needed(self): def fetchall(self): if not self._done_fetching(): + if hasattr(self._sqlaproxy, "dataframe"): + self._results = [] + self._pretty_table.clear() self._extend_results(self.sqlaproxy.fetchall()) self.mark_fetching_as_done() @@ -500,6 +506,8 @@ def _convert_to_data_frame( # maybe create accessors in the connection objects? if result_set._conn.is_dbapi_connection: native_connection = result_set.sqlaproxy + elif hasattr(result_set.sqlaproxy, "dataframe"): + return result_set.sqlaproxy.dataframe.toPandas() else: native_connection = result_set._conn._connection.connection diff --git a/src/sql/run/run.py b/src/sql/run/run.py index 11312c450..a1e34aa7d 100644 --- a/src/sql/run/run.py +++ b/src/sql/run/run.py @@ -52,6 +52,8 @@ def run_statements(conn, sql, config, parameters=None): # regular query else: result = conn.raw_execute(statement, parameters=parameters) + if is_spark(conn.dialect) and config.lazy_execution: + return result.dataframe if ( config.feedback >= 1 @@ -69,6 +71,10 @@ def is_postgres_or_redshift(dialect): return "postgres" in str(dialect) or "redshift" in str(dialect) +def is_spark(dialect): + return "spark" in str(dialect) + + def select_df_type(resultset, config): """ Converts the input resultset to either a Pandas DataFrame diff --git a/src/sql/run/sparkdataframe.py b/src/sql/run/sparkdataframe.py new file mode 100644 index 000000000..81644b1e2 
--- /dev/null +++ b/src/sql/run/sparkdataframe.py @@ -0,0 +1,52 @@ +try: + from pyspark.sql import DataFrame + from pyspark.sql.connect.dataframe import DataFrame as CDataFrame +except ModuleNotFoundError: + DataFrame = None + CDataFrame = None + +from sql import exceptions + + +def handle_spark_dataframe(dataframe, should_cache=False): + """Execute a ResultSet sqlaproxy using pysark module.""" + if not DataFrame and not CDataFrame: + raise exceptions.MissingPackageError("pysark not installed") + + return SparkResultProxy(dataframe, dataframe.columns, should_cache) + + +class SparkResultProxy(object): + """A fake class that pretends to behave like the ResultProxy from + SqlAlchemy. + """ + + dataframe = None + + def __init__(self, dataframe, headers, should_cache): + self.dataframe = dataframe + self.fetchall = dataframe.collect + self.rowcount = dataframe.count() + self.keys = lambda: headers + self.cursor = SparkCursor(headers) + self.returns_rows = True + if should_cache: + self.dataframe.cache() + + def fetchmany(self, size): + return self.dataframe.take(size) + + def fetchone(self): + return self.dataframe.head() + + def close(self): + self.dataframe.unpersist() + + +class SparkCursor(object): + """Class to extend to give SqlAlchemy Cursor like behaviour""" + + description = None + + def __init__(self, headers) -> None: + self.description = headers diff --git a/src/sql/stats.py b/src/sql/stats.py index 0fc12e87f..b03252154 100644 --- a/src/sql/stats.py +++ b/src/sql/stats.py @@ -45,7 +45,6 @@ def _summary_stats_one_by_one(conn, table, column, with_=None): other = list(conn.execute(query, with_).fetchone()) keys = ["q1", "med", "q3", "mean", "N"] - return {k: float(v) for k, v in zip(keys, percentiles + other)} diff --git a/src/sql/util.py b/src/sql/util.py index b0d9c0c3b..4f7fd6d4f 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -543,6 +543,9 @@ def is_non_sqlalchemy_error(error): "pyodbc.ProgrammingError", # Clickhouse errors "DB::Exception:", + # 
Pyspark + "UNRESOLVED_ROUTINE", + "PARSE_SYNTAX_ERROR", ] return any(msg in str(error) for msg in specific_db_errors) diff --git a/src/tests/integration/conftest.py b/src/tests/integration/conftest.py index e0a0b5e7f..cad899583 100644 --- a/src/tests/integration/conftest.py +++ b/src/tests/integration/conftest.py @@ -2,6 +2,7 @@ from pathlib import Path import shutil import pandas as pd +from pyspark.sql import SparkSession import pytest from sqlalchemy import MetaData, Table, create_engine, text import uuid @@ -288,6 +289,49 @@ def setup_duckDB_native(test_table_name_dict): conn.close() +@pytest.fixture(scope="session") +def setup_spark(test_table_name_dict): + import os + import shutil + import sys + + os.environ["PYSPARK_PYTHON"] = sys.executable + os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable + spark = SparkSession.builder.master("local[1]").enableHiveSupport().getOrCreate() + load_generic_testing_data_spark(spark, test_table_name_dict) + yield spark + spark.stop() + shutil.rmtree("metastore_db", ignore_errors=True) + shutil.rmtree("spark-warehouse", ignore_errors=True) + os.remove("derby.log") + + +def load_generic_testing_data_spark(spark: SparkSession, test_table_name_dict): + spark.createDataFrame( + pd.DataFrame( + {"taxi_driver_name": ["Eric Ken", "John Smith", "Kevin Kelly"] * 15} + ) + ).createOrReplaceTempView(test_table_name_dict["taxi"]) + spark.createDataFrame( + pd.DataFrame({"x": range(0, 5), "y": range(5, 10)}) + ).createOrReplaceTempView(test_table_name_dict["plot_something"]) + spark.createDataFrame( + pd.DataFrame({"numbers_elements": [1, 2, 3] * 20}) + ).createOrReplaceTempView(test_table_name_dict["numbers"]) + + +@pytest.fixture +def ip_with_spark(ip_empty, setup_spark): + alias = "SparkSession" + + ip_empty.push({"conn": setup_spark}) + # Select database engine, use different sqlite database endpoint + ip_empty.run_cell("%sql " + "conn" + " --alias " + alias) + yield ip_empty + # Disconnect database + ip_empty.run_cell("%sql -x " + 
alias) + + def load_generic_testing_data_duckdb_native(ip, test_table_name_dict): ip.run_cell("import pandas as pd") ip.run_cell( diff --git a/src/tests/integration/test_connection.py b/src/tests/integration/test_connection.py index 5d867b6e6..63dc71ac9 100644 --- a/src/tests/integration/test_connection.py +++ b/src/tests/integration/test_connection.py @@ -8,7 +8,12 @@ import pytest -from sql.connection import SQLAlchemyConnection, DBAPIConnection, ConnectionManager +from sql.connection import ( + SQLAlchemyConnection, + DBAPIConnection, + ConnectionManager, + SparkConnectConnection, +) from sql import _testing from sql.connection import connection @@ -92,6 +97,7 @@ def test_connection_properties(dynamic_db, request, Constructor, alias, dialect) partial(DBAPIConnection, alias="another-alias"), "another-alias", ], + ["setup_spark", SparkConnectConnection, "SparkSession"], ], ) def test_connection_identifiers( diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py index 47325ed16..95722758a 100644 --- a/src/tests/integration/test_generic_db_operations.py +++ b/src/tests/integration/test_generic_db_operations.py @@ -21,6 +21,7 @@ "ip_with_Snowflake", "ip_with_oracle", "ip_with_clickhouse", + "ip_with_spark", ] @@ -54,6 +55,7 @@ def mock_log_api(monkeypatch): ("ip_with_clickhouse", "", "LIMIT 3"), ("ip_with_oracle", "", "FETCH FIRST 3 ROWS ONLY"), ("ip_with_MSSQL", "TOP 3", ""), + ("ip_with_spark", "", "LIMIT 3"), ], ) def test_run_query( @@ -93,6 +95,7 @@ def test_run_query( "ip_with_Snowflake", "ip_with_redshift", "ip_with_clickhouse", + "ip_with_spark", ], ) def test_handle_multiple_open_result_sets( @@ -151,6 +154,7 @@ def test_handle_multiple_open_result_sets( "No engine for table " ), ), + ("ip_with_spark", "--no-index"), ], ) def test_create_table_with_indexed_df( @@ -218,6 +222,7 @@ def get_connection_count(ip_with_dynamic_db): ("ip_with_MSSQL", 1), ("ip_with_Snowflake", 1), ("ip_with_clickhouse", 
1), + ("ip_with_spark", 1), ], ) def test_active_connection_number(ip_with_dynamic_db, expected, request): @@ -273,6 +278,7 @@ def test_close_and_connect( ("ip_with_Snowflake", "snowflake", "snowflake"), ("ip_with_oracle", "oracle", "oracledb"), ("ip_with_clickhouse", "clickhouse", "native"), + ("ip_with_spark", "spark2", "SparkSession"), ], ) def test_telemetry_execute_command_has_connection_info( @@ -337,6 +343,7 @@ def test_telemetry_execute_command_has_connection_info( ("ip_with_Snowflake"), ("ip_with_duckDB_native"), ("ip_with_redshift"), + ("ip_with_spark"), pytest.param( "ip_with_MSSQL", marks=pytest.mark.xfail(reason="sqlglot does not support SQL server"), @@ -419,6 +426,9 @@ def test_sqlplot_histogram(ip_with_dynamic_db, cell, request, test_table_name_di reason="Plotting from snippet not working in clickhouse" ), ), + pytest.param( + "ip_with_spark", marks=pytest.mark.xfail(reason=BOX_PLOT_FAIL_REASON) + ), ], ) def test_sqlplot_boxplot(ip_with_dynamic_db, cell, request, test_table_name_dict): @@ -442,6 +452,7 @@ def test_sqlplot_boxplot(ip_with_dynamic_db, cell, request, test_table_name_dict "ip_with_duckDB", "ip_with_redshift", "ip_with_MSSQL", + "ip_with_spark", ], ) def test_sqlplot_bar(ip_with_dynamic_db, request, test_table_name_dict): @@ -464,7 +475,13 @@ def test_sqlplot_bar(ip_with_dynamic_db, request, test_table_name_dict): @pytest.mark.parametrize( "ip_with_dynamic_db", - ["ip_with_postgreSQL", "ip_with_duckDB", "ip_with_redshift", "ip_with_MSSQL"], + [ + "ip_with_postgreSQL", + "ip_with_duckDB", + "ip_with_redshift", + "ip_with_MSSQL", + "ip_with_spark", + ], ) def test_sqlplot_pie(ip_with_dynamic_db, request, test_table_name_dict): plt.cla() @@ -517,6 +534,7 @@ def test_sqlplot_pie(ip_with_dynamic_db, request, test_table_name_dict): reason="Plotting from snippet not working in clickhouse" ), ), + "ip_with_spark", ], ) def test_sqlplot_using_schema(ip_with_dynamic_db, request): @@ -569,6 +587,7 @@ def 
test_sqlplot_using_schema(ip_with_dynamic_db, request): ("ip_with_Snowflake"), ("ip_with_oracle"), ("ip_with_clickhouse"), + ("ip_with_spark"), ], ) def test_sqlcmd_test(ip_with_dynamic_db, request, test_table_name_dict): @@ -604,6 +623,7 @@ def test_sqlcmd_test(ip_with_dynamic_db, request, test_table_name_dict): ), ("ip_with_oracle"), ("ip_with_clickhouse"), + ("ip_with_spark"), ], ) def test_profile_data_mismatch(ip_with_dynamic_db, request, capsys): @@ -786,6 +806,25 @@ def test_profile_data_mismatch(ip_with_dynamic_db, request, capsys): }, "Following statistics are not available in", ), + ( + "ip_with_spark", + "taxi", + ["taxi_driver_name"], + { + "count": [45], + "mean": [math.nan], + "min": ["Eric Ken"], + "max": ["Kevin Kelly"], + "unique": [3], + "freq": [15], + "top": ["Eric Ken"], + "std": [math.nan], + "25%": [math.nan], + "50%": [math.nan], + "75%": [math.nan], + }, + None, + ), ], ) def test_sqlcmd_profile( @@ -847,6 +886,10 @@ def test_sqlcmd_profile( ("ip_with_MSSQL"), ("ip_with_Snowflake"), ("ip_with_clickhouse"), + pytest.param( + "ip_with_spark", + marks=pytest.mark.xfail(reason="Not Implemented"), + ), ], ) def test_sqlcmd_columns(ip_with_dynamic_db, table, request, test_table_name_dict): @@ -873,6 +916,10 @@ def test_sqlcmd_columns(ip_with_dynamic_db, table, request, test_table_name_dict ("ip_with_MSSQL"), ("ip_with_Snowflake"), ("ip_with_clickhouse"), + pytest.param( + "ip_with_spark", + marks=pytest.mark.xfail(reason="Not Implemented"), + ), ], ) def test_sqlcmd_tables(ip_with_dynamic_db, request): @@ -927,6 +974,7 @@ def test_sql_query(ip_with_dynamic_db, cell, request, test_table_name_dict): "ip_with_Snowflake", "ip_with_oracle", "ip_with_clickhouse", + "ip_with_spark", ], ) def test_sql_query_cte(ip_with_dynamic_db, request, test_table_name_dict, cell): @@ -957,6 +1005,7 @@ def test_sql_query_cte(ip_with_dynamic_db, request, test_table_name_dict, cell): "ip_with_clickhouse", marks=pytest.mark.xfail(reason="Not yet implemented"), ), + 
"ip_with_spark", ], ) def test_sql_error_suggests_using_cte(ip_with_dynamic_db, request): @@ -987,6 +1036,7 @@ def test_sql_error_suggests_using_cte(ip_with_dynamic_db, request): "ip_with_MSSQL", "ip_with_oracle", "ip_with_clickhouse", + "ip_with_spark", ], ) def test_results_sets_are_closed(ip_with_dynamic_db, request, test_table_name_dict): @@ -1024,6 +1074,7 @@ def test_results_sets_are_closed(ip_with_dynamic_db, request, test_table_name_di "ip_with_MSSQL", "ip_with_oracle", "ip_with_clickhouse", + "ip_with_spark", ], ) @pytest.mark.parametrize( @@ -1150,6 +1201,7 @@ def test_autocommit_retrieve_existing_resultssets_duckdb_from( CREATE_TABLE, marks=pytest.mark.xfail(reason="Not working yet"), ), + ("ip_with_spark", CREATE_TABLE), ], ) def test_autocommit_create_table_single_cell( @@ -1222,6 +1274,7 @@ def test_autocommit_create_table_single_cell( CREATE_TABLE, marks=pytest.mark.xfail(reason="Not working yet"), ), + ("ip_with_spark", CREATE_TABLE), ], ) def test_autocommit_create_table_multiple_cells( @@ -1408,6 +1461,20 @@ def test_autocommit_create_table_multiple_cells( ["Table with name mysnip does not exist!"], "RuntimeError", ), + ( + "ip_with_spark", + "mysnippet", + [ + "Cannot resolve function `not_a_function` on search path", + ], + "RuntimeError", + ), + ( + "ip_with_spark", + "mysnip", + ["Cannot resolve function `not_a_function` on search path"], + "RuntimeError", + ), ], ids=[ "no-typo-postgreSQL", @@ -1428,6 +1495,8 @@ def test_autocommit_create_table_multiple_cells( "with-typo-redshift", "no-typo-duckDB-native", "with-typo-duckDB-native", + "no-typo-spark", + "with-typo-spark", ], ) def test_query_snippet_invalid_function_error_message( @@ -1456,7 +1525,7 @@ def test_query_snippet_invalid_function_error_message( # Save result and test error message result_error = excinfo.value.error_type result_msg = str(excinfo.value) - + print(result_msg) assert error_type == result_error assert all(msg in result_msg for msg in error_msgs) @@ -1502,6 +1571,7 @@ 
def test_query_snippet_invalid_function_error_message( "No engine for table " ), ), + ("ip_with_spark", "--no-index"), ], ) def test_persist_in_schema(ip_with_dynamic_db, args, request, test_table_name_dict): diff --git a/src/tests/integration/test_stats.py b/src/tests/integration/test_stats.py index d8f93439d..fe33b18f5 100644 --- a/src/tests/integration/test_stats.py +++ b/src/tests/integration/test_stats.py @@ -1,7 +1,7 @@ import pytest from sql.stats import _summary_stats -from sql.connection import SQLAlchemyConnection +from sql.connection import SQLAlchemyConnection, SparkConnectConnection @pytest.mark.parametrize( @@ -26,3 +26,23 @@ def test_summary_stats(fixture_name, request, test_table_name_dict): "mean": 2.0, "N": 5.0, } + + +@pytest.mark.parametrize( + "fixture_name", + [ + "setup_spark", + ], +) +def test_summary_stats_spark(fixture_name, request, test_table_name_dict): + conn = SparkConnectConnection(request.getfixturevalue(fixture_name)) + table = test_table_name_dict["plot_something"] + column = "x" + + assert _summary_stats(conn, table, column) == { + "q1": 1.0, + "med": 2.0, + "q3": 3.0, + "mean": 2.0, + "N": 5.0, + } diff --git a/src/tests/test_connection.py b/src/tests/test_connection.py index 9a9a0a399..8369ff690 100644 --- a/src/tests/test_connection.py +++ b/src/tests/test_connection.py @@ -13,6 +13,7 @@ from sqlalchemy.engine import Engine from sqlalchemy import exc + from sql.connection import connection as connection_module import sql.connection from sql.connection import ( @@ -21,6 +22,7 @@ ConnectionManager, is_pep249_compliant, default_alias_for_engine, + is_spark, ResultSetCollection, detect_duckdb_summarize_or_select, ) @@ -41,6 +43,34 @@ def mock_database(monkeypatch, cleanup): monkeypatch.setattr(sqlalchemy, "create_engine", Mock()) +def mock_sparksession(): + mock = Mock( + spec=[ + "table", + "read", + "createDataFrame", + "sql", + "stop", + "catalog", + "version", + ] + ) + return mock + + +def mock_not_sparksession(): + mock = 
Mock( + spec=[ + "read", + "readStream", + "createDataFrame", + "sql", + "version", + ] + ) + return mock + + @pytest.fixture def mock_postgres(monkeypatch, cleanup): monkeypatch.setitem(sys.modules, "psycopg2", Mock()) @@ -457,6 +487,24 @@ def test_is_pep249_compliant(conn, expected): assert is_pep249_compliant(conn) is expected +@pytest.mark.parametrize( + "descriptor, expected", + [ + [sqlite3.connect(""), False], + [duckdb.connect(""), False], + [create_engine("sqlite://"), False], + [mock_sparksession(), True], + [mock_not_sparksession(), False], + [None, False], + [object(), False], + ["not_a_valid_connection", False], + [0, False], + ], +) +def test_is_spark(descriptor, expected): + assert is_spark(descriptor) is expected + + def test_close_all(ip_empty, monkeypatch): connections = {} monkeypatch.setattr(ConnectionManager, "connections", connections) @@ -590,6 +638,22 @@ def test_set_dbapi(monkeypatch, callable_, key): assert ConnectionManager.current == conn +@pytest.mark.parametrize( + "spark, key", + [ + [mock_sparksession(), "Mock"], + ], +) +def test_set_spark(monkeypatch, spark, key): + connections = {} + monkeypatch.setattr(ConnectionManager, "connections", connections) + + conn = ConnectionManager.set(spark, displaycon=False) + + assert connections == {key: conn} + assert ConnectionManager.current == conn + + def test_set_with_alias(monkeypatch): connections = {} monkeypatch.setattr(ConnectionManager, "connections", connections)