Skip to content

Commit

Permalink
[SPARK-26364][PYTHON][TESTING] Clean up imports in test_pandas_udf*
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Clean up unconditional import statements and move them to the top.

Conditional imports (pandas, numpy, pyarrow) are left as-is.

## How was this patch tested?

Existing tests.

Closes apache#23314 from icexelloss/clean-up-test-imports.

Authored-by: Li Jin <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
icexelloss authored and HyukjinKwon committed Dec 14, 2018
1 parent 362e472 commit 160e583
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 163 deletions.
16 changes: 4 additions & 12 deletions python/pyspark/sql/tests/test_pandas_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,23 @@

import unittest

from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from pyspark.sql.types import *
from pyspark.sql.utils import ParseException
from pyspark.rdd import PythonEvalType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest

from py4j.protocol import Py4JJavaError


@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message)
class PandasUDFTests(ReusedSQLTestCase):

def test_pandas_udf_basic(self):
from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import pandas_udf, PandasUDFType

udf = pandas_udf(lambda x: x, DoubleType())
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
Expand Down Expand Up @@ -65,10 +66,6 @@ def test_pandas_udf_basic(self):
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

def test_pandas_udf_decorator(self):
from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, DoubleType

@pandas_udf(DoubleType())
def foo(x):
return x
Expand Down Expand Up @@ -114,8 +111,6 @@ def foo(x):
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

def test_udf_wrong_arg(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType

with QuietTest(self.sc):
with self.assertRaises(ParseException):
@pandas_udf('blah')
Expand Down Expand Up @@ -151,9 +146,6 @@ def foo(k, v, w):
return k

def test_stopiteration_in_udf(self):
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from py4j.protocol import Py4JJavaError

def foo(x):
raise StopIteration()

Expand Down
39 changes: 3 additions & 36 deletions python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

import unittest

from pyspark.rdd import PythonEvalType
from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
udf, pandas_udf, PandasUDFType
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
Expand All @@ -31,7 +34,6 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):

@property
def data(self):
from pyspark.sql.functions import array, explode, col, lit
return self.spark.range(10).toDF('id') \
.withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
.withColumn("v", explode(col('vs'))) \
Expand All @@ -40,8 +42,6 @@ def data(self):

@property
def python_plus_one(self):
from pyspark.sql.functions import udf

@udf('double')
def plus_one(v):
assert isinstance(v, (int, float))
Expand All @@ -51,7 +51,6 @@ def plus_one(v):
@property
def pandas_scalar_plus_two(self):
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf('double', PandasUDFType.SCALAR)
def plus_two(v):
Expand All @@ -61,17 +60,13 @@ def plus_two(v):

@property
def pandas_agg_mean_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def avg(v):
return v.mean()
return avg

@property
def pandas_agg_sum_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def sum(v):
return v.sum()
Expand All @@ -80,16 +75,13 @@ def sum(v):
@property
def pandas_agg_weighted_mean_udf(self):
import numpy as np
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def weighted_mean(v, w):
return np.average(v, weights=w)
return weighted_mean

def test_manual(self):
from pyspark.sql.functions import pandas_udf, array

df = self.data
sum_udf = self.pandas_agg_sum_udf
mean_udf = self.pandas_agg_mean_udf
Expand Down Expand Up @@ -118,8 +110,6 @@ def test_manual(self):
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())

def test_basic(self):
from pyspark.sql.functions import col, lit, mean

df = self.data
weighted_mean_udf = self.pandas_agg_weighted_mean_udf

Expand Down Expand Up @@ -150,9 +140,6 @@ def test_basic(self):
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())

def test_unsupported_types(self):
from pyspark.sql.types import DoubleType, MapType
from pyspark.sql.functions import pandas_udf, PandasUDFType

with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
pandas_udf(
Expand All @@ -173,8 +160,6 @@ def mean_and_std_udf(v):
return {v.mean(): v.std()}

def test_alias(self):
from pyspark.sql.functions import mean

df = self.data
mean_udf = self.pandas_agg_mean_udf

Expand All @@ -187,8 +172,6 @@ def test_mixed_sql(self):
"""
Test mixing group aggregate pandas UDF with sql expression.
"""
from pyspark.sql.functions import sum

df = self.data
sum_udf = self.pandas_agg_sum_udf

Expand Down Expand Up @@ -225,8 +208,6 @@ def test_mixed_udfs(self):
"""
Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF.
"""
from pyspark.sql.functions import sum

df = self.data
plus_one = self.python_plus_one
plus_two = self.pandas_scalar_plus_two
Expand Down Expand Up @@ -292,8 +273,6 @@ def test_multiple_udfs(self):
"""
Test multiple group aggregate pandas UDFs in one agg function.
"""
from pyspark.sql.functions import sum, mean

df = self.data
mean_udf = self.pandas_agg_mean_udf
sum_udf = self.pandas_agg_sum_udf
Expand All @@ -315,8 +294,6 @@ def test_multiple_udfs(self):
self.assertPandasEqual(expected1, result1)

def test_complex_groupby(self):
from pyspark.sql.functions import sum

df = self.data
sum_udf = self.pandas_agg_sum_udf
plus_one = self.python_plus_one
Expand Down Expand Up @@ -359,8 +336,6 @@ def test_complex_groupby(self):
self.assertPandasEqual(expected7.toPandas(), result7.toPandas())

def test_complex_expressions(self):
from pyspark.sql.functions import col, sum

df = self.data
plus_one = self.python_plus_one
plus_two = self.pandas_scalar_plus_two
Expand Down Expand Up @@ -434,7 +409,6 @@ def test_complex_expressions(self):
self.assertPandasEqual(expected3, result3)

def test_retain_group_columns(self):
from pyspark.sql.functions import sum
with self.sql_conf({"spark.sql.retainGroupColumns": False}):
df = self.data
sum_udf = self.pandas_agg_sum_udf
Expand All @@ -444,17 +418,13 @@ def test_retain_group_columns(self):
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())

def test_array_type(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = self.data

array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
self.assertEquals(result1.first()['v2'], [1.0, 2.0])

def test_invalid_args(self):
from pyspark.sql.functions import mean

df = self.data
plus_one = self.python_plus_one
mean_udf = self.pandas_agg_mean_udf
Expand All @@ -478,9 +448,6 @@ def test_invalid_args(self):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()

def test_register_vectorized_udf_basic(self):
from pyspark.sql.functions import pandas_udf
from pyspark.rdd import PythonEvalType

sum_pandas_udf = pandas_udf(
lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)

Expand Down
Loading

0 comments on commit 160e583

Please sign in to comment.