Source code for dpsql.engine.engine

import logging

from pandas import DataFrame

from ..accountant import Accountant
from ..aggregation import Aggregation
from ..backend import SQLBackend
from ..dp_params import DPParams, generate_dpparams
from ..errors import (
    EngineError,
    InsufficientPrivacyBudgetError,
    InvalidPrivacyParametersError,
    UnsupportedBackendError,
)
from ..validator import Validator

# Module-level logger, named after this module's import path via __name__.
logger = logging.getLogger(__name__)


class Engine:
    """Tie together a validator, a privacy accountant and a SQL backend.

    The engine keeps the schemas and privacy unit columns of the databases
    registered through :meth:`register_database` and hands them to the
    validator when a query is executed through :meth:`execute_query`.
    """

    def __init__(
        self, accountant: Accountant, sql_backend: SQLBackend, validator: Validator
    ):
        """
        Store the collaborators and start with no registered database.

        Args:
            accountant (Accountant): Checks and updates the privacy budget.
            sql_backend (SQLBackend): Executes SQL and exposes table/column names.
            validator (Validator): Validates queries and extracts their parts.
        """
        self.accountant = accountant
        self.sql_backend = sql_backend
        self.validator = validator
        # "table" or "database.table" -> name of that table's privacy unit column.
        self.privacy_unit_columns: dict[str, str] = {}
        # database name as given to register_database (may be None)
        # -> {table name: [column names]}.
        self.db_schema: dict[str | None, dict[str, list[str]]] = {}
        logger.info("Initializing Engine")
        # Compare against None explicitly: a plain truthiness test would call
        # __bool__/__len__ on the collaborator objects.
        logger.debug(
            "Components: backend=%s accountant=%s validator=%s",
            type(self.sql_backend).__name__,
            type(self.accountant).__name__ if self.accountant is not None else None,
            type(self.validator).__name__ if self.validator is not None else None,
        )

    def get_db_schema(self, database_name: str | None) -> dict[str, list[str]]:
        """
        Get the schema of a database.

        The backend is switched to ``database_name`` with ``use_database``
        before its table and column names are read.

        Args:
            database_name (str | None): The name of the database; passed
                unchanged to the backend.

        Returns:
            dict: The schema of the database.
                ex) {"table1": ["column1", "column2"], "table2": ["column1"]}
        """
        self.sql_backend.use_database(database_name)
        tables = self.sql_backend.get_table_name()
        db_schema: dict[str, list[str]] = {}
        logger.info("Fetching DB schema")
        logger.debug("Target database: %s", database_name)
        for table_name in tables:
            columns = self.sql_backend.get_column_name(table_name)
            db_schema[table_name] = columns
            logger.debug("Schema: %s -> %s", table_name, columns)
        return db_schema

    def register_database(
        self,
        database_name: str | None = None,
        privacy_unit_columns: dict[str, str] | None = None,
    ) -> None:
        """
        Register a database with privacy unit columns in Engine.

        Every entry of ``privacy_unit_columns`` is checked against the schema
        fetched from the backend before any engine state is modified, so a
        failed registration leaves the engine unchanged.

        Args:
            database_name (str | None): The name of the database.
            privacy_unit_columns (dict | None): The privacy unit columns of the
                database.
                ex) {"table1": "column1"}
                if there is no privacy unit column in a table, give {}
                (None is treated as {}).

        Returns:
            None

        Raises:
            EngineError: If a listed table is not in the database, or a listed
                privacy unit column is not in its table.
        """
        if privacy_unit_columns is None:
            privacy_unit_columns = {}
        db_schema = self.get_db_schema(database_name)
        logger.info("Registering database")
        logger.debug(
            "Database=%s privacy_unit_columns=%s", database_name, privacy_unit_columns
        )

        # Phase 1: validate everything; nothing on self is touched yet.
        for table_name, privacy_unit_column in privacy_unit_columns.items():
            if table_name not in db_schema:
                logger.error(
                    "Missing table in database: table=%s database=%s",
                    table_name,
                    database_name,
                )
                raise EngineError(
                    "Missing table in database",
                    context={"database": database_name, "table": table_name},
                    hint="Verify table name via get_db_schema() or SHOW TABLES",
                )
            if privacy_unit_column not in db_schema[table_name]:
                logger.error(
                    "Missing privacy unit column: table=%s column=%s available=%s",
                    table_name,
                    privacy_unit_column,
                    db_schema[table_name],
                )
                raise EngineError(
                    "Missing privacy unit column in table",
                    context={
                        "database": database_name,
                        "table": table_name,
                        "column": privacy_unit_column,
                        "available_columns": db_schema[table_name],
                    },
                    hint="Confirm the column exists or "
                    "adjust privacy_unit_columns mapping",
                )

        # Phase 2: commit the schema and the privacy unit columns.
        self.db_schema[database_name] = db_schema
        logger.debug("Registered DB schema for %s", database_name)
        # Tables are keyed as "database.table", or just "table" when no
        # database name was given.
        prefix = "" if database_name is None else f"{database_name}."
        for table_name, privacy_unit_column in privacy_unit_columns.items():
            qualified_name = f"{prefix}{table_name}"
            self.privacy_unit_columns[qualified_name] = privacy_unit_column
            logger.debug(
                "Registered privacy unit: %s -> %s",
                qualified_name,
                privacy_unit_column,
            )

    def execute_query(
        self,
        query: str,
        dpparams: DPParams | None,
        temporary_table_name: str | None = None,
    ) -> DataFrame:
        """
        Execute a SQL query with privacy parameters.

        The query is validated first, the privacy budget is checked, the inner
        SQL is run on the backend, and only after that (and after the optional
        temporary table has been created) is the budget updated.

        Args:
            query (str): The SQL query to execute.
            dpparams (DPParams | None): The differential privacy parameters.
                Exactly one of this argument and in-query privacy parameters
                must be present.
            temporary_table_name (str | None): The name of the temporary table
                to save the result.

        Returns:
            DataFrame: The result of the query execution.

        Raises:
            InvalidPrivacyParametersError: If privacy parameters are given
                neither explicitly nor in the query, or are given in both.
            InsufficientPrivacyBudgetError: If the accountant's budget check
                fails for the requested aggregations.
        """
        (
            intermediate_privacy_unit,
            inner_sql,
            final_result_columns,
            group_by_columns,
            ordering_terms,
            limit,
            offset,
            privacy_params,
        ) = self.validator.validate_and_get_final_select_items(
            query, self.db_schema, self.privacy_unit_columns
        )
        logger.info("Executing private query")
        logger.debug("Query head: %s", query[:200])

        # Filter out non-aggregated columns
        agg_columns = [
            agg_column
            for agg_column in final_result_columns
            if agg_column.aggregation_type != Aggregation.NONE
        ]
        agg_funcs = [agg_column.aggregation_type for agg_column in agg_columns]

        if dpparams is None:
            if privacy_params is None:
                logger.error("Missing privacy parameters for query")
                raise InvalidPrivacyParametersError(
                    "Missing privacy parameters for query",
                    context={"query_head": query[:120]},
                    hint="Add PRIVATE_QUERY OPTIONS or pass dpparams explicitly",
                )
            dpparams = generate_dpparams(privacy_params, agg_columns)
            logger.debug("Generated DPParams from privacy_params")
        elif privacy_params is not None:
            logger.error("Duplicate privacy parameter sources (explicit and in-query)")
            raise InvalidPrivacyParametersError(
                "Duplicate privacy parameter sources "
                "(both explicit and in-query provided)",
                context={"query_head": query[:120]},
                hint="Use only one of dpparams argument or PRIVATE_QUERY OPTIONS",
            )

        requested_aggregations = [a.name for a in agg_funcs]
        logger.debug("Aggregations: %s", requested_aggregations)
        if not self.accountant.check_budget(agg_funcs, dpparams):
            # Logged like the other error paths in this class before raising.
            logger.error(
                "Privacy budget exceeded: aggregations=%s", requested_aggregations
            )
            raise InsufficientPrivacyBudgetError(
                "Privacy budget exceeded",
                context={
                    "requested_aggregations": requested_aggregations,
                    "epsilon_total": self.accountant.epsilon,
                    "delta_total": self.accountant.delta,
                },
                hint="Reduce per-query budget or increase global accountant budget",
            )

        logger.info("Executing inner SQL on backend")
        logger.debug("Inner SQL head: %s", inner_sql[:200])
        result = self.sql_backend.execute_sql(
            intermediate_privacy_unit,
            dpparams,
            inner_sql,
            agg_columns,
            group_by_columns,
            ordering_terms,
            limit,
            offset,
        )
        logger.info("Backend execution succeeded")

        if temporary_table_name is not None:
            logger.info("Creating temporary table: %s", temporary_table_name)
            # Any failure here (e.g. UnsupportedBackendError) propagates to the
            # caller unchanged; in that case the accountant is not updated and
            # no result is returned.
            self.sql_backend.create_temporary_table(result, temporary_table_name)
            logger.debug("Temporary table created: %s", temporary_table_name)

        self.accountant.update_budget(agg_funcs, dpparams)
        logger.info("Privacy budget updated")
        return result