# Source code for dpsql.backend.spark_sql_backend

import logging
from typing import cast

from pandas import DataFrame, Series
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col,
    count,
    countDistinct,
    greatest,
    least,
    lit,
    rand,
    row_number,
)
from pyspark.sql.functions import sum as spark_sum
from pyspark.sql.window import Window

from ..aggregation import Aggregation
from ..dp_params import DPParams
from ..errors import (
    AggregationError,
    ExecutionBackendError,
    UnsupportedBackendError,
)
from ..utils import safely_get_threshold
from .sql_backend import DataFrameLike, SQLBackend

logger = logging.getLogger(__name__)


class SparkSQLBackend(SQLBackend):
    """SQL backend that executes queries on Apache Spark with differential
    privacy mechanisms.

    Args:
        spark_session (SparkSession): The Spark session
    """

    def __init__(self, spark_session: SparkSession):
        # All SQL execution and DataFrame creation goes through this session.
        self.spark = spark_session
        logger.debug("SparkSQLBackend initialized with SparkSession")
[docs] def create_inner_df(self, inner_sql: str) -> SparkDataFrame: logger.debug("Spark create_inner_df: sql_head=%s", inner_sql[:160]) try: return self.spark.sql(inner_sql) except Exception as e: # noqa raise ExecutionBackendError( "Failed to execute Spark SQL", context={"sql_head": inner_sql[:140]}, hint="Check SQL syntax and table existence", cause=e, ) from e
[docs] def contribution_bound( self, inner_df: DataFrameLike, privacy_unit: str, params: DPParams ) -> DataFrameLike: logger.debug( "Spark contribution_bound: privacy_unit=%s bound=%s", privacy_unit, params.contribution_bound, ) if not isinstance(inner_df, SparkDataFrame): raise UnsupportedBackendError( "Expected Spark DataFrame", context={"actual_type": type(inner_df).__name__}, hint="Provide a pyspark.sql.DataFrame", ) # Add random order column to shuffle the records df_with_random = inner_df.withColumn("random_order", rand()) window_spec = Window.partitionBy(privacy_unit).orderBy("random_order") df_with_row_num = df_with_random.withColumn( "row_num", row_number().over(window_spec) ) # Filter at most contribution_bound records for each privacy_unit filtered_df = df_with_row_num.filter( col("row_num") <= params.contribution_bound ).drop("row_num", "random_order") logger.debug("Spark contribution_bound applied") return filtered_df
[docs] def apply_aggregation( self, agg_type: Aggregation, column_name: list[str], df: DataFrameLike, group_by: list[str], clipping_threshold: list[tuple[float, float]] | None = None, ) -> Series: logger.debug( "Spark apply_aggregation: agg=%s columns=%s group_by=%s clip=%s", getattr(agg_type, "name", str(agg_type)), column_name, group_by, clipping_threshold, ) if not isinstance(df, SparkDataFrame): raise UnsupportedBackendError( "Expected Spark DataFrame", context={"actual_type": type(df).__name__}, hint="Provide a pyspark.sql.DataFrame", ) if agg_type == Aggregation.COUNT: res = df.groupBy(group_by).agg(count(column_name[0])) elif agg_type == Aggregation.COUNT_DISTINCT: res = df.groupBy(group_by).agg(countDistinct(column_name[0])) elif agg_type == Aggregation.SUM: if clipping_threshold is None: raise AggregationError( "Missing `clipping_threshold` for `SUM` aggregation", context={"agg_type": agg_type.name, "column": column_name[0]}, hint="Provide a list of (lower_bound, upper_bound) tuples", ) lower_bound, upper_bound = safely_get_threshold( clipping_threshold, 0, agg_type.name ) res = df.groupBy(group_by).agg( spark_sum( least(lit(upper_bound), greatest(lit(lower_bound), column_name[0])) ) ) elif agg_type == Aggregation.SQUARED_SUM: if clipping_threshold is None: raise AggregationError( "Missing `clipping_threshold` for `SQUARED_SUM` aggregation", context={"agg_type": agg_type.name, "column": column_name[0]}, hint="Provide a list of (lower_bound, upper_bound) tuples", ) lower_bound, upper_bound = safely_get_threshold( clipping_threshold, 0, agg_type.name ) res = df.groupBy(group_by).agg( spark_sum( least(lit(upper_bound), greatest(lit(lower_bound), column_name[0])) ** 2 ) ) elif agg_type == Aggregation.PRODUCT_SUM: lower_bound_x, upper_bound_x = safely_get_threshold( clipping_threshold, 0, agg_type.name ) lower_bound_y, upper_bound_y = safely_get_threshold( clipping_threshold, 1, agg_type.name ) # Calculate the product of the first two columns col1_clipped = 
least( lit(upper_bound_x), greatest(lit(lower_bound_x), column_name[0]) ) col2_clipped = least( lit(upper_bound_y), greatest(lit(lower_bound_y), column_name[1]) ) res = df.groupBy(group_by).agg(spark_sum(col1_clipped * col2_clipped)) else: raise AggregationError( "Unsupported aggregation type", context={"agg_type": getattr(agg_type, "name", str(agg_type))}, hint="Use COUNT, COUNT_DISTINCT, SUM, SQUARED_SUM, PRODUCT_SUM", ) pd_res = cast(DataFrame, res.toPandas()) # cast DataFramneLike to DataFrame logger.debug("Spark aggregation result converted to pandas Series") if group_by: return pd_res.set_index(group_by).iloc[:, 0] # convert to Series else: return pd_res.iloc[:, 0] # convert to Series without group_by
[docs] def use_database(self, database_name: str | None) -> None: logger.info("Spark use_database: %s", database_name) """ Use a database in the Spark session. Args: database_name (str): The name of the database to use. """ if database_name is None: raise UnsupportedBackendError( "Missing database name", hint="Provide a database name after USE", ) try: self.spark.sql(f"USE {database_name}") except Exception as e: # noqa raise ExecutionBackendError( "Failed to switch Spark database", context={"database": database_name}, hint="Verify existence with SHOW DATABASES", cause=e, ) from e
[docs] def get_table_name(self) -> list[str]: logger.debug("Spark get_table_name") """ Get the list of tables in the database. Returns: list[str]: The list of table names. """ try: tables = self.spark.sql("SHOW TABLES").collect() except Exception as e: # noqa raise ExecutionBackendError( "Failed to list Spark tables", hint="Check the current database context", cause=e, ) from e return [table.tableName for table in tables]
[docs] def filter_by_selected_keys( self, df: DataFrameLike, group_by: list[str], selected_keys: list[tuple[str, ...]], ) -> DataFrameLike: logger.debug( "Spark filter_by_selected_keys: group_by=%s selected_keys_count=%s", group_by, len(selected_keys), ) if not isinstance(df, SparkDataFrame): raise UnsupportedBackendError( "df must be a Spark DataFrame", context={"actual_type": type(df).__name__}, ) if len(group_by) == 0: # If no group by columns, return the original DataFrame return df elif len(selected_keys) == 0: # If selected_keys is empty, return an empty DataFrame with the same schema return self.spark.createDataFrame([], df.schema) else: try: ref_df = self.spark.createDataFrame(selected_keys, schema=group_by) filtered_df = df.join(ref_df, on=group_by, how="inner") except Exception as e: raise ExecutionBackendError( "Failed to filter keys", context={"group_by": group_by, "key_count": len(selected_keys)}, hint="Validate key column schema and existence", cause=e, ) from e logger.debug("Spark keys filtered") return filtered_df
[docs] def get_column_name(self, table_name: str) -> list[str]: logger.debug("Spark get_column_name: table=%s", table_name) """ Get the list of columns in the table. Args: table_name (str): The name of the table. Returns: list[str]: The list of column names. """ try: columns = self.spark.sql(f"DESCRIBE {table_name}").collect() except Exception as e: # noqa raise ExecutionBackendError( "Failed to describe table", context={"table": table_name}, hint="Verify table existence and permissions", cause=e, ) from e return [column.col_name for column in columns]
[docs] def create_temporary_table( self, df: DataFrame, table_name: str, index: bool = True ) -> None: logger.info( "Spark create_temporary_table: table=%s index=%s", table_name, index ) try: spark_df = self.spark.createDataFrame(df.reset_index() if index else df) spark_df.createTempView(table_name) except Exception as e: # noqa raise ExecutionBackendError( "Failed to create temporary table", context={"table": table_name, "index_included": index}, hint="Check name conflicts and schema consistency", cause=e, ) from e