杰瑞科技汇

How Do You Implement Custom Aggregation with a Spark UDAF in Python?

This article looks at Spark UDAFs (User-Defined Aggregate Functions) in Python, a feature for aggregations that go beyond built-in functions such as sum(), avg(), or count().

We'll cover:

  1. What is a UDAF? (The "Why")
  2. The Old Way: the UserDefinedAggregateFunction class (Scala/Java API)
  3. The Modern & Recommended Way: pyspark.sql.functions.pandas_udf() (grouped aggregate variant, since Spark 2.4)
  4. A Complete, Practical Example
  5. Key Differences and When to Use What

What is a UDAF?

A User-Defined Aggregate Function (UDAF) is a custom aggregation operation that you define. Unlike a standard UDF (User-Defined Function) that operates on a single row at a time, a UDAF operates on a group of rows and produces a single aggregated output for that group.

Think of it as extending Spark's built-in aggregators (sum, avg, max) to handle your own custom logic.

Classic Examples of UDAFs:

  • Calculating the median of a group of numbers.
  • Finding the top-N most frequent items in a group.
  • Concatenating strings with a custom separator.
  • Calculating a custom weighted average.
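
As a rough illustration of what such aggregations compute, here is a plain-Python sketch (no Spark involved) that reduces each group of rows to a single value. The sample rows and the helper names `aggregate_by_key` and `top_2` are made up for this example.

```python
from collections import Counter, defaultdict
from statistics import median

def aggregate_by_key(rows, agg):
    """Group (key, value) pairs by key and reduce each group to one value."""
    groups = defaultdict(list)
    for key, value in rows:
        groups[key].append(value)
    return {key: agg(values) for key, values in groups.items()}

def top_2(values):
    """Return the two most frequent items in a group."""
    return [item for item, _ in Counter(values).most_common(2)]

# Hypothetical sample data: (group key, value)
numbers = [("a", 1), ("a", 5), ("a", 3), ("b", 10), ("b", 20)]
tags = [("a", "x"), ("a", "y"), ("a", "x"), ("a", "z"), ("a", "y"), ("a", "x")]

medians = aggregate_by_key(numbers, median)
frequent = aggregate_by_key(tags, top_2)
joined = aggregate_by_key(tags[:3], " | ".join)
print(medians)   # {'a': 3, 'b': 15.0}
print(frequent)  # {'a': ['x', 'y']}
print(joined)    # {'a': 'x | y | x'}
```

Each call takes many rows per key and returns exactly one result per key, which is the defining property of an aggregate function.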

The Old Way: UserDefinedAggregateFunction (The Declarative Approach)

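Before Spark 2.4, a UDAF meant a class that extends UserDefinedAggregateFunction. That class belongs to Spark's Scala/Java API (org.apache.spark.sql.expressions); PySpark does not ship a Python version of it, so Python users had to implement the aggregation on the JVM side and register it before calling it from Python. The approach is verbose because several abstract methods must be defined, but it relies only on Spark's own APIs and needs no external libraries. Since Spark 3.0 the class is deprecated in favor of Aggregator. The Python code in this section mirrors the interface purely to illustrate how the pieces fit together.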

Key Components of the UserDefinedAggregateFunction class:

  • inputSchema(): Defines the schema of the input columns.
  • bufferSchema(): Defines the schema of the intermediate "accumulator" or "buffer" that holds the state of the aggregation.
  • dataType(): Defines the data type of the final result.
  • deterministic(): A boolean indicating if the function always returns the same output for the same input.
  • initialize(): Initializes the accumulator/buffer for a new aggregation.
  • update(): Updates the accumulator with a row from the input DataFrame.
  • merge(): Merges two accumulators together (used in distributed environments).
  • evaluate(): Computes the final result from the accumulator.

Example: Calculating the average of a column.

This is a simple example to illustrate the structure. In practice, you'd just use F.avg().

# NOTE: PySpark has no pyspark.sql.udaf module and no Python
# UserDefinedAggregateFunction base class (the class is Scala/Java only).
# This plain-Python class mirrors the buffer lifecycle so the example runs as-is.
# The declarative methods (inputSchema, bufferSchema, dataType, deterministic)
# only describe types to Spark's planner and are omitted here.
from functools import reduce
# --- 1. Define the aggregation logic ---
class AverageUDAF:
    # Initialize the buffer with zeros
    def initialize(self):
        return (0.0, 0) # sum, count
    # Update the buffer with one input value, skipping nulls
    def update(self, buffer, value):
        if value is not None:
            return (buffer[0] + value, buffer[1] + 1)
        return buffer
    # Merge two buffers (e.g., from different partitions)
    def merge(self, buffer1, buffer2):
        return (buffer1[0] + buffer2[0], buffer1[1] + buffer2[1])
    # Calculate the final result
    def evaluate(self, buffer):
        return buffer[0] / buffer[1] if buffer[1] != 0 else 0.0
# --- 2. Simulate how an engine would drive it ---
udaf = AverageUDAF()
# The values 1..5, split across two partitions
partitions = [[1.0, 2.0], [3.0, 4.0, 5.0]]
# Each partition folds its rows into its own buffer
buffers = [reduce(udaf.update, part, udaf.initialize()) for part in partitions]
# The partial buffers are merged, then evaluated once
final_buffer = reduce(udaf.merge, buffers)
print(udaf.evaluate(final_buffer))
# 3.0

This works, but as you can see, it's quite a lot of boilerplate code for a simple average.
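
One detail worth spelling out is why the buffer stores a running sum and count rather than a running average. The sketch below, in plain Python with made-up partitions, suggests that averaging the per-partition averages gives the wrong answer when partitions differ in size, while merging (sum, count) buffers does not.

```python
# Hypothetical, unevenly sized partitions of the values 1..5
partitions = [[1.0, 2.0], [3.0, 4.0, 5.0]]

# Naive approach: average each partition, then average the averages
partial_avgs = [sum(p) / len(p) for p in partitions]
avg_of_avgs = sum(partial_avgs) / len(partial_avgs)

# Buffer approach: keep (sum, count) per partition, merge, divide once at the end
buffers = [(sum(p), len(p)) for p in partitions]
total, count = map(sum, zip(*buffers))
merged_avg = total / count

print(avg_of_avgs)  # 2.75 (wrong)
print(merged_avg)   # 3.0 (correct)
```

This is the reason merge() exists as a separate step: the buffer has to hold enough state for partial results to combine correctly, whatever the partitioning.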


The Modern & Recommended Way: pandas_udf() (The Vectorized Approach)

Since Spark 2.4, the recommended way to write a UDAF in Python is a grouped aggregate pandas UDF, created with pandas_udf. Spark uses Apache Arrow to hand each input column to Python as a pandas Series, so the aggregation runs as vectorized pandas code, which is typically much faster than the row-at-a-time processing of ordinary Python UDFs.

Three kinds of pandas UDF matter here:

  1. Scalar Pandas UDF: Takes pandas Series (columns) and returns a pandas Series of the same length. This is NOT for aggregation.
  2. Grouped Map Pandas UDF: Takes each group as a pandas DataFrame and returns a pandas DataFrame. It transforms groups rather than reducing them to one value (in Spark 3.0+ this is written with groupBy().applyInPandas()).
  3. Grouped Aggregate Pandas UDF: Takes each group's columns as pandas Series and returns a single value per group. This is what we want for UDAFs.

The rest of this article uses the Grouped Aggregate Pandas UDF.
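
The flavors differ mainly in the shape of what goes in and what comes out. The following sketch uses plain pandas only (no Spark) to show those shapes; the function names and sample values are invented for illustration.

```python
import pandas as pd

def scalar_style(s: pd.Series) -> pd.Series:
    """Series in, Series of the same length out (a row-wise transformation)."""
    return s * 2

def grouped_map_style(pdf: pd.DataFrame) -> pd.DataFrame:
    """One group's DataFrame in, a DataFrame out."""
    return pdf.assign(centered=pdf["v"] - pdf["v"].mean())

def grouped_agg_style(s: pd.Series) -> float:
    """One group's Series in, a single value out (an aggregation)."""
    return float(s.mean())

# Stands in for the rows of one group
group = pd.DataFrame({"v": [1.0, 2.0, 6.0]})

doubled = scalar_style(group["v"])
centered = grouped_map_style(group)
mean_value = grouped_agg_style(group["v"])
print(doubled.tolist())               # [2.0, 4.0, 12.0]
print(centered["centered"].tolist())  # [-2.0, -1.0, 3.0]
print(mean_value)                     # 3.0
```

Only the last shape, many rows in and one value out, behaves like a UDAF.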

Syntax for a Grouped Aggregate Pandas UDF:

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
# Define the function with the @pandas_udf decorator
# The decorator specifies the return data type
# Spark 3.0+ infers "grouped aggregate" from the type hints (Series -> scalar);
# on Spark 2.4, pass functionType=PandasUDFType.GROUPED_AGG instead
@pandas_udf(returnType=DoubleType())
def my_aggregation_function(s: pd.Series) -> float:
    # Your aggregation logic here
    # 's' is a pandas Series holding one column of the entire group
    return s.mean() # Example: calculate the mean

How it works:

  1. Spark groups the data as it normally would.
  2. For each group, it passes each input column to your Python function as a pandas Series containing all of that group's rows.
  3. Your function performs the aggregation using highly optimized Pandas operations.
  4. The result is returned to Spark and combined into the final DataFrame.
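
These steps can be mimicked locally to build intuition. The sketch below is a pandas-only analogue, not Spark's actual execution path: `groupby` plays the role of Spark's grouping, and the hypothetical `spread` function stands in for the decorated UDF.

```python
import pandas as pd

def spread(s: pd.Series) -> float:
    """Aggregation logic: range (max - min) of one group's values."""
    return float(s.max() - s.min())

# Hypothetical input rows
pdf = pd.DataFrame({
    "key": ["a", "a", "b", "b", "b"],
    "v": [1.0, 4.0, 10.0, 12.0, 17.0],
})

results = {}
for key, group in pdf.groupby("key"):  # step 1: group the rows
    series = group["v"]                 # step 2: the whole group as one Series
    results[key] = spread(series)       # step 3: run the aggregation once per group

# step 4: combine the per-group results into one table
result = pd.DataFrame({"key": list(results), "spread": list(results.values())})
print(result)
```

Because the function receives a whole group at once, with no partial aggregation, each group has to fit in the memory of the worker that processes it.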

A Complete, Practical Example: Weighted Average

Let's calculate a weighted average. The formula is: Weighted Average = Σ (value * weight) / Σ (weight)

We'll use the modern pandas_udf approach.

Step 1: Setup Data

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import *
import pandas as pd
spark = SparkSession.builder \
    .appName("PandasUDAFExample") \
    .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
    .getOrCreate()
# Sample data: product, category, value, weight
data = [
    ("A", "Electronics", 100, 0.5),
    ("B", "Electronics", 200, 1.0),
    ("C", "Electronics", 150, 0.75),
    ("D", "Furniture", 300, 2.0),
    ("E", "Furniture", 400, 1.5),
    ("F", "Furniture", 350, 1.8),
    ("A", "Electronics", 110, 0.6), # Another entry for product A
]
df = spark.createDataFrame(data, ["product", "category", "value", "weight"])
df.show()

Output:

+-------+-----------+-----+------+
|product|   category|value|weight|
+-------+-----------+-----+------+
|      A|Electronics|  100|   0.5|
|      B|Electronics|  200|   1.0|
|      C|Electronics|  150|  0.75|
|      D|  Furniture|  300|   2.0|
|      E|  Furniture|  400|   1.5|
|      F|  Furniture|  350|   1.8|
|      A|Electronics|  110|   0.6|
+-------+-----------+-----+------+

Step 2: Define the Weighted Average UDAF

# Define the return type
weighted_avg_return_type = DoubleType()
# Define the pandas_udf
@pandas_udf(weighted_avg_return_type)
def weighted_avg_udf(values: pd.Series, weights: pd.Series) -> float:
    """
    Calculates the weighted average of a series of values.
    :param values: A pandas Series of values.
    :param weights: A pandas Series of corresponding weights.
    :return: A single float representing the weighted average.
    """
    # Handle potential division by zero or empty series
    if weights.sum() == 0 or len(values) == 0:
        return 0.0
    return (values * weights).sum() / weights.sum()

Step 3: Use the UDAF in a GroupBy Operation

# Group by category and apply the UDAF
# Note: The UDAF takes multiple columns as input
result_df = df.groupBy("category").agg(
    weighted_avg_udf(col("value"), col("weight")).alias("weighted_avg_value")
)
result_df.show()

Output (values rounded to four decimal places here; row order may vary):

+-----------+------------------+
|   category|weighted_avg_value|
+-----------+------------------+
|  Furniture|          345.2830|
|Electronics|          150.3509|
+-----------+------------------+

Explanation of the result for "Electronics":

  • Values: [100, 200, 150, 110]
  • Weights: [0.5, 1.0, 0.75, 0.6]
  • Sum of (value * weight) = (100*0.5) + (200*1.0) + (150*0.75) + (110*0.6) = 50 + 200 + 112.5 + 66 = 428.5
  • Sum of weights = 0.5 + 1.0 + 0.75 + 0.6 = 2.85
  • Weighted Average = 428.5 / 2.85 ≈ 150.3509

The same steps for "Furniture" give 1830 / 5.3 ≈ 345.2830, so both rows match the hand calculation.
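
The arithmetic can also be double-checked without Spark. This is a minimal standard-library sketch that reuses the sample rows (product column dropped) and applies the same formula per category; the variable names are chosen for this example only.

```python
from collections import defaultdict

# Same (category, value, weight) triples as the sample data
rows = [
    ("Electronics", 100, 0.5),
    ("Electronics", 200, 1.0),
    ("Electronics", 150, 0.75),
    ("Furniture", 300, 2.0),
    ("Furniture", 400, 1.5),
    ("Furniture", 350, 1.8),
    ("Electronics", 110, 0.6),
]

weighted_sum = defaultdict(float)  # running sum of value * weight per category
weight_sum = defaultdict(float)    # running sum of weights per category
for category, value, weight in rows:
    weighted_sum[category] += value * weight
    weight_sum[category] += weight

weighted_avg = {c: weighted_sum[c] / weight_sum[c] for c in weighted_sum}
for category, avg in weighted_avg.items():
    print(f"{category}: {avg:.4f}")
# Electronics: 150.3509
# Furniture: 345.2830
```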


Key Differences and When to Use What

| Feature | UserDefinedAggregateFunction (Old Way) | pandas_udf (Recommended Way) |
| --- | --- | --- |
| Performance | Processes one row at a time through update(), but runs inside the JVM and supports partial aggregation through merge(). | Uses Apache Arrow and pandas for vectorized operations, which is much faster than row-at-a-time Python UDFs. Each group is loaded into memory in full. |
| Dependencies | Spark's Scala/Java API only. No external libraries needed, but the aggregation cannot be written in Python. | Requires pandas and PyArrow. |
| Syntax | Verbose. Requires defining a class with multiple abstract methods (initialize, update, etc.). | Clean and simple. Uses a standard Python function with a decorator. |
| Data Handling | Works with Spark's internal Row objects and basic types. | Works with pandas Series, which are well suited to complex numerical/string operations. |
| Flexibility | Can handle any Spark data type. | Best for numerical and string data types that pandas supports well. |
| When to Use | When you cannot use pandas and can write the aggregation in Scala/Java; legacy codebases. | Almost always for new Python UDAFs; when performance matters and the alternative is a row-at-a-time Python UDF; when the aggregation logic is complex and benefits from pandas/NumPy. |

Conclusion:

For new Spark development in Python, strongly prefer pandas_udf for creating UDAFs. It is easier to write, much faster than row-at-a-time Python UDFs, and works directly with the Python data science ecosystem (pandas, NumPy). The older UserDefinedAggregateFunction class belongs to the Scala/Java API, where it has been deprecated since Spark 3.0 and is kept largely for backward compatibility.
