import logging
from typing import cast
from pandas import DataFrame, Series
from pyspark.sql import DataFrame as SparkDataFrame
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
col,
count,
countDistinct,
greatest,
least,
lit,
rand,
row_number,
)
from pyspark.sql.functions import sum as spark_sum
from pyspark.sql.window import Window
from ..aggregation import Aggregation
from ..dp_params import DPParams
from ..errors import (
AggregationError,
ExecutionBackendError,
UnsupportedBackendError,
)
from ..utils import safely_get_threshold
from .sql_backend import DataFrameLike, SQLBackend
logger = logging.getLogger(__name__)
[docs]
class SparkSQLBackend(SQLBackend):
    """
    Backend for executing SparkSQL queries with differential privacy mechanisms.

    Args:
        spark_session (SparkSession): The Spark session used to run all SQL
            statements and DataFrame operations issued by this backend.
    """

    def __init__(self, spark_session: SparkSession) -> None:
        # Keep a reference to the session; every backend method routes its
        # queries through `self.spark`.
        self.spark = spark_session
        logger.debug("SparkSQLBackend initialized with SparkSession")
[docs]
def create_inner_df(self, inner_sql: str) -> SparkDataFrame:
logger.debug("Spark create_inner_df: sql_head=%s", inner_sql[:160])
try:
return self.spark.sql(inner_sql)
except Exception as e: # noqa
raise ExecutionBackendError(
"Failed to execute Spark SQL",
context={"sql_head": inner_sql[:140]},
hint="Check SQL syntax and table existence",
cause=e,
) from e
[docs]
def contribution_bound(
self, inner_df: DataFrameLike, privacy_unit: str, params: DPParams
) -> DataFrameLike:
logger.debug(
"Spark contribution_bound: privacy_unit=%s bound=%s",
privacy_unit,
params.contribution_bound,
)
if not isinstance(inner_df, SparkDataFrame):
raise UnsupportedBackendError(
"Expected Spark DataFrame",
context={"actual_type": type(inner_df).__name__},
hint="Provide a pyspark.sql.DataFrame",
)
# Add random order column to shuffle the records
df_with_random = inner_df.withColumn("random_order", rand())
window_spec = Window.partitionBy(privacy_unit).orderBy("random_order")
df_with_row_num = df_with_random.withColumn(
"row_num", row_number().over(window_spec)
)
# Filter at most contribution_bound records for each privacy_unit
filtered_df = df_with_row_num.filter(
col("row_num") <= params.contribution_bound
).drop("row_num", "random_order")
logger.debug("Spark contribution_bound applied")
return filtered_df
[docs]
def apply_aggregation(
self,
agg_type: Aggregation,
column_name: list[str],
df: DataFrameLike,
group_by: list[str],
clipping_threshold: list[tuple[float, float]] | None = None,
) -> Series:
logger.debug(
"Spark apply_aggregation: agg=%s columns=%s group_by=%s clip=%s",
getattr(agg_type, "name", str(agg_type)),
column_name,
group_by,
clipping_threshold,
)
if not isinstance(df, SparkDataFrame):
raise UnsupportedBackendError(
"Expected Spark DataFrame",
context={"actual_type": type(df).__name__},
hint="Provide a pyspark.sql.DataFrame",
)
if agg_type == Aggregation.COUNT:
res = df.groupBy(group_by).agg(count(column_name[0]))
elif agg_type == Aggregation.COUNT_DISTINCT:
res = df.groupBy(group_by).agg(countDistinct(column_name[0]))
elif agg_type == Aggregation.SUM:
if clipping_threshold is None:
raise AggregationError(
"Missing `clipping_threshold` for `SUM` aggregation",
context={"agg_type": agg_type.name, "column": column_name[0]},
hint="Provide a list of (lower_bound, upper_bound) tuples",
)
lower_bound, upper_bound = safely_get_threshold(
clipping_threshold, 0, agg_type.name
)
res = df.groupBy(group_by).agg(
spark_sum(
least(lit(upper_bound), greatest(lit(lower_bound), column_name[0]))
)
)
elif agg_type == Aggregation.SQUARED_SUM:
if clipping_threshold is None:
raise AggregationError(
"Missing `clipping_threshold` for `SQUARED_SUM` aggregation",
context={"agg_type": agg_type.name, "column": column_name[0]},
hint="Provide a list of (lower_bound, upper_bound) tuples",
)
lower_bound, upper_bound = safely_get_threshold(
clipping_threshold, 0, agg_type.name
)
res = df.groupBy(group_by).agg(
spark_sum(
least(lit(upper_bound), greatest(lit(lower_bound), column_name[0]))
** 2
)
)
elif agg_type == Aggregation.PRODUCT_SUM:
lower_bound_x, upper_bound_x = safely_get_threshold(
clipping_threshold, 0, agg_type.name
)
lower_bound_y, upper_bound_y = safely_get_threshold(
clipping_threshold, 1, agg_type.name
)
# Calculate the product of the first two columns
col1_clipped = least(
lit(upper_bound_x), greatest(lit(lower_bound_x), column_name[0])
)
col2_clipped = least(
lit(upper_bound_y), greatest(lit(lower_bound_y), column_name[1])
)
res = df.groupBy(group_by).agg(spark_sum(col1_clipped * col2_clipped))
else:
raise AggregationError(
"Unsupported aggregation type",
context={"agg_type": getattr(agg_type, "name", str(agg_type))},
hint="Use COUNT, COUNT_DISTINCT, SUM, SQUARED_SUM, PRODUCT_SUM",
)
pd_res = cast(DataFrame, res.toPandas()) # cast DataFramneLike to DataFrame
logger.debug("Spark aggregation result converted to pandas Series")
if group_by:
return pd_res.set_index(group_by).iloc[:, 0] # convert to Series
else:
return pd_res.iloc[:, 0] # convert to Series without group_by
[docs]
def use_database(self, database_name: str | None) -> None:
logger.info("Spark use_database: %s", database_name)
"""
Use a database in the Spark session.
Args:
database_name (str): The name of the database to use.
"""
if database_name is None:
raise UnsupportedBackendError(
"Missing database name",
hint="Provide a database name after USE",
)
try:
self.spark.sql(f"USE {database_name}")
except Exception as e: # noqa
raise ExecutionBackendError(
"Failed to switch Spark database",
context={"database": database_name},
hint="Verify existence with SHOW DATABASES",
cause=e,
) from e
[docs]
def get_table_name(self) -> list[str]:
logger.debug("Spark get_table_name")
"""
Get the list of tables in the database.
Returns:
list[str]: The list of table names.
"""
try:
tables = self.spark.sql("SHOW TABLES").collect()
except Exception as e: # noqa
raise ExecutionBackendError(
"Failed to list Spark tables",
hint="Check the current database context",
cause=e,
) from e
return [table.tableName for table in tables]
[docs]
def filter_by_selected_keys(
self,
df: DataFrameLike,
group_by: list[str],
selected_keys: list[tuple[str, ...]],
) -> DataFrameLike:
logger.debug(
"Spark filter_by_selected_keys: group_by=%s selected_keys_count=%s",
group_by,
len(selected_keys),
)
if not isinstance(df, SparkDataFrame):
raise UnsupportedBackendError(
"df must be a Spark DataFrame",
context={"actual_type": type(df).__name__},
)
if len(group_by) == 0:
# If no group by columns, return the original DataFrame
return df
elif len(selected_keys) == 0:
# If selected_keys is empty, return an empty DataFrame with the same schema
return self.spark.createDataFrame([], df.schema)
else:
try:
ref_df = self.spark.createDataFrame(selected_keys, schema=group_by)
filtered_df = df.join(ref_df, on=group_by, how="inner")
except Exception as e:
raise ExecutionBackendError(
"Failed to filter keys",
context={"group_by": group_by, "key_count": len(selected_keys)},
hint="Validate key column schema and existence",
cause=e,
) from e
logger.debug("Spark keys filtered")
return filtered_df
[docs]
def get_column_name(self, table_name: str) -> list[str]:
logger.debug("Spark get_column_name: table=%s", table_name)
"""
Get the list of columns in the table.
Args:
table_name (str): The name of the table.
Returns:
list[str]: The list of column names.
"""
try:
columns = self.spark.sql(f"DESCRIBE {table_name}").collect()
except Exception as e: # noqa
raise ExecutionBackendError(
"Failed to describe table",
context={"table": table_name},
hint="Verify table existence and permissions",
cause=e,
) from e
return [column.col_name for column in columns]
[docs]
def create_temporary_table(
self, df: DataFrame, table_name: str, index: bool = True
) -> None:
logger.info(
"Spark create_temporary_table: table=%s index=%s", table_name, index
)
try:
spark_df = self.spark.createDataFrame(df.reset_index() if index else df)
spark_df.createTempView(table_name)
except Exception as e: # noqa
raise ExecutionBackendError(
"Failed to create temporary table",
context={"table": table_name, "index_included": index},
hint="Check name conflicts and schema consistency",
cause=e,
) from e