Source code for rocelib.tasks.ClassificationTask

import os
import zipfile

import pandas as pd
import time
import torch
import numpy as np
from tabulate import tabulate  # For better table formatting
import matplotlib.pyplot as plt

from rocelib.datasets.DatasetLoader import DatasetLoader
from rocelib.models.TrainedModel import TrainedModel
from rocelib.tasks.Task import Task
from typing import List, Dict, Any, Union, Tuple

from rocelib.recourse_methods.ArgEnsembling import ArgEnsembling
from rocelib.recourse_methods.RecourseGenerator import RecourseGenerator
from rocelib.recourse_methods.BinaryLinearSearch import BinaryLinearSearch
from rocelib.recourse_methods.NNCE import NNCE
from rocelib.recourse_methods.KDTreeNNCE import KDTreeNNCE
from rocelib.recourse_methods.MCE import MCE
from rocelib.recourse_methods.Wachter import Wachter
from rocelib.recourse_methods.RNCE import RNCE
from rocelib.recourse_methods.MCER import MCER
from rocelib.recourse_methods.RoCourseNet import RoCourseNet
from rocelib.recourse_methods.STCE import TrexNN
from rocelib.recourse_methods.GuidedBinaryLinearSearch import GuidedBinaryLinearSearch
from rocelib.recourse_methods.ModelMultiplicityMILP import ModelMultiplicityMILP
from rocelib.recourse_methods.APAS import APAS
from rocelib.recourse_methods.DiverseRobustCE import DiverseRobustCE
from rocelib.recourse_methods.PROPLACE import PROPLACE

# Evaluators
from rocelib.evaluations.robustness_evaluations.Evaluator import Evaluator
from rocelib.evaluations.ManifoldEvaluator import ManifoldEvaluator
from rocelib.evaluations.DistanceEvaluator import DistanceEvaluator
from rocelib.evaluations.ValidityEvaluator import ValidityEvaluator
from rocelib.evaluations.robustness_evaluations.MC_Robustness_Implementations.DeltaRobustnessEvaluator import DeltaRobustnessEvaluator
from rocelib.evaluations.robustness_evaluations.MC_Robustness_Implementations.VaRRobustnessEvaluator import VaRRobustnessEvaluator
from rocelib.evaluations.robustness_evaluations.ModelMultiplicityRobustnessEvaluator import ModelMultiplicityRobustnessEvaluator
from rocelib.evaluations.robustness_evaluations.NE_Robustness_Implementations.InvalidationRateRobustnessEvaluator import InvalidationRateRobustnessEvaluator
from rocelib.evaluations.robustness_evaluations.MM_Robustness_Implementations.MultiplicityValidityRobustnessEvaluator import MultiplicityValidityRobustnessEvaluator
from rocelib.evaluations.robustness_evaluations.IC_Robustness_Implementations.SetDistanceRobustnessEvaluator import SetDistanceRobustnessEvaluator


TIMEOUT_SECONDS = 60


class ClassificationTask(Task):
    """
    A specific task type for classification problems that extends the base Task class.

    This class provides methods for generating and evaluating counterfactual explanations
    and for retrieving positive instances from the training data.

    Attributes:
        model: The trained model used for predictions.
        _dataset: The dataset the task operates on.
    """

    def __init__(self, model: TrainedModel, dataset: DatasetLoader, mm_models: Dict[str, TrainedModel] = None):
        super().__init__(model, dataset, mm_models)
        self.methods = {
            "BinaryLinearSearch": BinaryLinearSearch,
            # "GuidedBinaryLinearSearch": GuidedBinaryLinearSearch,
            "MMMILP": ModelMultiplicityMILP,
            "NNCE": NNCE,
            "KDTreeNNCE": KDTreeNNCE,
            "MCE": MCE,
            "Wachter": Wachter,
            "RNCE": RNCE,
            "MCER": MCER,
            "RoCourseNet": RoCourseNet,
            "STCE": TrexNN,
            "APAS": APAS,
            "ArgEnsembling": ArgEnsembling,
            "DiverseRobustCE": DiverseRobustCE,
            "PROPLACE": PROPLACE,
        }
        self.evaluation_metrics = {
            "Distance": DistanceEvaluator,
            "Validity": ValidityEvaluator,
            "ModelMultiplicityRobustness": MultiplicityValidityRobustnessEvaluator,
            "DeltaRobustnessEvaluator": DeltaRobustnessEvaluator,
            "InvalidationRateRobustnessEvaluator": InvalidationRateRobustnessEvaluator,
            "SetDistanceRobustnessEvaluator": SetDistanceRobustnessEvaluator,
            "ManifoldEvaluator": ManifoldEvaluator,
            # "VaRRobustnessEvaluator": VaRRobustnessEvaluator,
        }

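    # Illustrative sketch (not part of the class): the two registries above drive the rest
    # of the task, so a quick way to see what a task supports is to inspect them. `model`
    # and `loader` are assumed to be a TrainedModel and DatasetLoader created elsewhere in
    # rocelib; their construction is not shown on this page.
    #
    #     task = ClassificationTask(model, loader)
    #     print(list(task.methods.keys()))             # available recourse methods
    #     print(list(task.evaluation_metrics.keys()))  # available evaluation metrics
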
    def get_random_positive_instance(self, neg_value, column_name="target") -> pd.Series:
        """
        Retrieves a random positive instance from the training data that does not have the specified negative value.

        This method continues to sample from the training data until a positive instance is found
        whose predicted label is not equal to the negative value.

        @param neg_value: The value considered negative in the target variable.
        @param column_name: The name of the target column used to identify positive instances.

        @return: A Pandas Series representing a random positive instance.
        """
        # Get a random positive instance from the training data
        pos_instance = self._dataset.get_random_positive_instance()

        # Loop until a positive instance whose prediction is positive is found
        while self.model.predict_single(pos_instance) == neg_value:
            pos_instance = self._dataset.get_random_positive_instance()

        return pos_instance

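    # Illustrative sketch: sampling a positive instance whose *prediction* is also positive.
    # Here 0 is assumed to be the negative class label, and `task` is a ClassificationTask
    # built from a trained model and dataset loader.
    #
    #     positive = task.get_random_positive_instance(neg_value=0)
    #     assert task.model.predict_single(positive) != 0
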
    def generate(self, methods: List[str] = None, type="DataFrame", **kwargs) -> Dict[str, Tuple[pd.DataFrame, float]]:
        """
        Generates counterfactual explanations for the specified methods and stores the results.

        @param methods: List of recourse methods (by name) to use for counterfactual generation.
                        If not provided, counterfactuals are generated for all available methods.
        @param type: The datatype to return the counterfactuals in, e.g. "DataFrame", "NPArray" or "Tensor".

        @return: A dictionary mapping each recourse method to a tuple of
                 (counterfactuals in the requested datatype, time taken to generate them).
        """
        if methods is None:
            methods = self.get_recourse_methods()

        for method in methods:
            print(f"Generating for {method}")
            try:
                # Check if the method exists in the dictionary
                if method not in self.methods:
                    raise ValueError(f"Recourse method '{method}' not found. Available methods: {list(self.methods.keys())}")

                # Instantiate the recourse method
                recourse_method = self.methods[method](self)  # Pass the classification task to the method

                # Start timer
                start_time = time.perf_counter()

                res = recourse_method.generate_for_all(**kwargs)  # Generate counterfactuals
                res_correct_type = self.convert_datatype(res, type)  # Convert to the requested datatype

                # End timer
                end_time = time.perf_counter()

                # Store the converted counterfactuals and generation time in the counterfactual explanations dictionary
                self._CEs[method] = [res_correct_type, end_time - start_time]
            except Exception as e:
                print(f"Error generating counterfactuals with method '{method}': {e}")

        return self.CEs

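    # Illustrative sketch: generating counterfactuals for two of the registered methods.
    # The returned dictionary maps each method name to [counterfactuals, generation_time];
    # method names not registered in __init__ raise the ValueError above.
    #
    #     ces = task.generate(methods=["KDTreeNNCE", "MCE"])
    #     kdtree_ces, seconds_taken = ces["KDTreeNNCE"]
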
    def generate_mm(self, methods: List[str] = None, type="DataFrame", **kwargs) -> Dict[str, Dict[str, Tuple[pd.DataFrame, float]]]:
        """
        Generates counterfactual explanations for the specified methods for each of the stored models and stores the results.

        @param methods: List of recourse methods (by name) to use for counterfactual generation.
        @param type: The datatype to return the counterfactuals in, e.g. "DataFrame", "NPArray" or "Tensor".

        @return: A nested dictionary mapping recourse method to model name to a tuple of
                 (counterfactuals, time taken to generate them).
        """
        if methods is None:
            methods = self.get_recourse_methods()

        if not self.mm_flag:
            raise ValueError("Multiple models must be added in order to generate for MM")

        for method in methods:
            print(f"Generating for {method}")
            for i, model_name in enumerate(self.mm_models):
                ces = self.generate_for_model_method(model_name, method, type, **kwargs)

                if i == 0:
                    # Primary model, so also store the results in self._CEs
                    self._CEs[method] = ces

                # Store results in mm_CEs
                if method not in self.mm_CEs:
                    self.mm_CEs[method] = {}
                self.mm_CEs[method][model_name] = ces

        return self.mm_CEs

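    # Illustrative sketch: when the task was constructed with an `mm_models` dictionary,
    # generate_mm produces counterfactuals per method *and* per model, e.g.
    # mm_ces["KDTreeNNCE"]["model_a"] -> [counterfactuals, time]. The model name
    # "model_a" used here is purely hypothetical.
    #
    #     mm_ces = task.generate_mm(methods=["KDTreeNNCE"])
    #     ces_for_model_a, seconds_taken = mm_ces["KDTreeNNCE"]["model_a"]
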
    def generate_for_model_method(self, model_name, method, type, **kwargs) -> Tuple[pd.DataFrame, float]:
        print(f"GENERATING FOR: model: {model_name}, method: {method}")
        try:
            # Check if the method exists in the dictionary
            if method not in self.methods:
                raise ValueError(f"Recourse method '{method}' not found. Available methods: {list(self.methods.keys())}")

            # Instantiate the recourse method for the given model
            task = ClassificationTask(self.mm_models[model_name], dataset=self.dataset, mm_models=self.mm_models)
            recourse_method = self.methods[method](task)  # Pass the classification task to the method

            # Start timer
            start_time = time.perf_counter()

            res = recourse_method.generate_for_all(**kwargs)  # Generate counterfactuals
            res_correct_type = self.convert_datatype(res, type)  # Convert to the requested datatype

            # End timer
            end_time = time.perf_counter()

            # Return the converted counterfactuals together with the generation time
            return [res_correct_type, end_time - start_time]
        except Exception as e:
            print(f"Error generating counterfactuals with method '{method}': {e}")
            return None

    def evaluate(self, methods: List[str] = None, evaluations: List[str] = None, visualisation=False, **kwargs) -> Dict[str, Dict[str, Any]]:
        """
        Evaluates the generated counterfactual explanations using the specified evaluation metrics.

        @param methods: List of recourse methods to evaluate.
        @param evaluations: List of evaluation metrics to apply.
        @param visualisation: If True, display the generated chart in addition to saving it.

        @return: Dictionary containing evaluation results per method and metric.
        """
        if methods is None:
            methods = self.get_recourse_methods()
        if evaluations is None:
            evaluations = self.get_evaluation_metrics()

        evaluation_results = {}

        # Validate evaluation names
        invalid_evaluations = [ev for ev in evaluations if ev not in self.evaluation_metrics]
        if invalid_evaluations:
            raise ValueError(f"Invalid evaluation metrics: {invalid_evaluations}. Available: {list(self.evaluation_metrics.keys())}")

        # Filter out methods that haven't been generated
        valid_methods = [method for method in methods if method in self._CEs]
        if valid_methods != methods:
            print(f"generate has not been called for {list(set(methods) - set(valid_methods))}, so evaluations were not performed for these")

        # Filter out methods that haven't been generated for MM if a model multiplicity metric was requested
        mm_metric = [isinstance(self.evaluation_metrics[metric](self), ModelMultiplicityRobustnessEvaluator) for metric in evaluations]
        if any(mm_metric):
            if not self.mm_flag:
                print("Multiple models must be added to the task in order to evaluate model multiplicity")
                # Remove the MM metrics from this evaluation
                evaluations = [metric for (i, metric) in enumerate(evaluations) if not mm_metric[i]]
            else:
                valid_methods = [method for method in methods if (method in self.mm_CEs and len(self.mm_CEs[method].keys()) == len(self.mm_models))]
                if not valid_methods:
                    print("No valid methods have been generated for MM for evaluation. Call generate_mm for these methods")
                    return evaluation_results
                if valid_methods != methods:
                    print(f"generate_mm has not been called for {list(set(methods) - set(valid_methods))}, so evaluations were not performed for these")

        if not valid_methods:
            print("No valid methods have been generated for evaluation.")
            return evaluation_results

        # Perform evaluation
        for evaluation in evaluations:
            print(f"Evaluation technique {evaluation}")
            evaluator_class = self.evaluation_metrics[evaluation]

            # Create evaluator instance
            evaluator = evaluator_class(self)

            for method in valid_methods:
                try:
                    print(f"Method: {method}")

                    # Retrieve generated counterfactuals
                    counterfactuals = self._CEs[method][0]  # Extract DataFrame from stored list

                    # Ensure counterfactuals are not empty
                    if counterfactuals is None or counterfactuals.empty:
                        print(f"Skipping evaluation for method '{method}' as no counterfactuals were generated.")
                        continue

                    print(f"Shape of CEs for {method}: {counterfactuals.shape}")

                    # Perform evaluation
                    score = evaluator.evaluate(method, **kwargs)

                    # Store results
                    if method not in evaluation_results:
                        evaluation_results[method] = {}
                    evaluation_results[method][evaluation] = score
                except Exception as e:
                    print(f"Error evaluating '{evaluation}' for '{method}': {e}")

        # Print results in table format
        table_data, headers = self._print_evaluation_results(evaluation_results, evaluations)

        csv_filename = "recourse_evaluations.csv"
        time.sleep(2)
        graph_filename = self._visualise_results(evaluation_results, evaluations, visualisation)
        self.save_evals_as_zip(table_data, headers, csv_filename, graph_filename)

        return evaluation_results

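    # Illustrative sketch: evaluating previously generated counterfactuals with two of the
    # registered metrics. Results come back as a nested dict keyed by method then metric,
    # and a CSV/chart bundle is also written via save_evals_as_zip.
    #
    #     task.generate(methods=["KDTreeNNCE"])
    #     results = task.evaluate(methods=["KDTreeNNCE"],
    #                             evaluations=["Distance", "Validity"])
    #     print(results["KDTreeNNCE"]["Validity"])
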
    def _visualise_results(self, evaluation_results: Dict[str, Dict[str, Any]], evaluations: List[str], visualisation=False):
        if not evaluation_results:
            print("No evaluation results to display.")
            return

        if len(evaluations) > 3:
            return self._visualise_results_radar_chart(evaluation_results, evaluations, visualisation)
        else:
            return self._visualise_results_bar_chart(evaluation_results, evaluations, visualisation)

    def _visualise_results_bar_chart(self, evaluation_results: Dict[str, Dict[str, Any]], evaluations: List[str], visualisation=False):
        recourse_methods = list(evaluation_results.keys())

        # Extract metric values
        metric_values = {method: [evaluation_results[method].get(metric, 0) for metric in evaluations]
                         for method in recourse_methods}

        x = np.arange(len(recourse_methods))
        width = 0.2

        fig, ax = plt.subplots(figsize=(8, 6))

        for i, metric in enumerate(evaluations):
            values = [metric_values[method][i] for method in recourse_methods]
            ax.bar(x + i * width, values, width, label=metric)

        ax.set_xticks(x + width / 2)
        ax.set_xticklabels(recourse_methods)
        ax.set_xlabel("Recourse Methods")
        ax.set_ylabel("Metric Values")
        ax.set_title("Bar Chart of Evaluation Metrics")
        ax.legend()

        filename = "/tmp/evaluation_chart.png"
        plt.savefig(filename)  # Save the figure instead of displaying it
        if visualisation:
            plt.show()
        plt.close(fig)  # Ensure the figure is closed properly

        return filename

    def _visualise_results_radar_chart(self, evaluation_results: Dict[str, Dict[str, Any]], evaluations: List[str], visualisation=False):
        """
        Generates a radar chart for evaluation results.

        Parameters:
            evaluation_results (Dict[str, Dict[str, Any]]): A dictionary where keys are recourse methods,
                and values are dictionaries mapping metric names to values.
            evaluations (List[str]): A list of metric names to be visualised (must contain at least 4 metrics).
        """
        assert len(evaluations) >= 4, "There must be at least 4 evaluation metrics to plot a radar chart."

        # Extract recourse methods
        recourse_methods = list(evaluation_results.keys())

        # Extract metric values for each recourse method
        metric_values = {method: [evaluation_results[method].get(metric, 0) for metric in evaluations]
                         for method in recourse_methods}

        # Define radar chart angles
        num_vars = len(evaluations)
        angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False).tolist()

        # Close the radar chart loop
        angles += angles[:1]

        # Create figure
        fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))

        # Plot each recourse method
        for method, values in metric_values.items():
            values += values[:1]  # Close the loop
            ax.plot(angles, values, label=method, linewidth=2)
            ax.fill(angles, values, alpha=0.2)

        # Add labels and legend
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(evaluations, fontsize=12)
        ax.set_yticklabels([])
        ax.set_title("Radar Chart of Evaluation Metrics", fontsize=14, fontweight='bold')
        ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))

        filename = "/tmp/evaluation_chart.png"
        plt.savefig(filename)  # Save the figure instead of displaying it
        plt.pause(0.001)  # Allow the figure to be shown briefly without blocking
        if visualisation:
            plt.show()
        plt.close(fig)  # Close the figure to avoid blocking execution

        return filename

    def _print_evaluation_results(self, evaluation_results: Dict[str, Dict[str, Any]], evaluations: List[str]):
        """
        Prints the evaluation results in a table format.

        @param evaluation_results: Dictionary containing evaluation scores per method and metric.
        @param evaluations: List of evaluation metrics that were actually requested.
        """
        if not evaluation_results:
            print("No evaluation results to display.")
            return

        # Prepare table data
        table_data = []
        headers = ["Recourse Method"] + evaluations  # Only include requested evaluations

        for method, scores in evaluation_results.items():
            row = [method] + [scores.get(metric, "N/A") for metric in evaluations]
            table_data.append(row)

        print("\nEvaluation Results:")
        print(tabulate(table_data, headers=headers, tablefmt="grid"))

        return table_data, headers

    def save_evals_as_zip(self, table_data, headers, csv_filename, graph_filename):
        # Create DataFrame
        df = pd.DataFrame(table_data, columns=headers)

        # Save to CSV
        df.to_csv(csv_filename, index=False)

        # Create a zip file containing the CSV and graph files
        zip_filename = "evaluations.zip"  # Output zip file name
        with zipfile.ZipFile(zip_filename, 'w') as zipf:
            # Add the CSV file to the zip
            zipf.write(csv_filename, os.path.basename(csv_filename))

            # Add the graph image to the zip
            zipf.write(graph_filename, os.path.basename(graph_filename))

    def convert_datatype(self, data: pd.DataFrame, target_type: str):
        """
        Converts a Pandas DataFrame to the specified data type.

        @param data: pd.DataFrame - The input DataFrame.
        @param target_type: str - The target data type: "DataFrame", "NPArray", or "Tensor".

        @return: Converted data in the specified format.
        """
        if not isinstance(data, pd.DataFrame):
            raise ValueError("Input data must be a Pandas DataFrame.")

        target_type = target_type.lower()  # Normalize input for case insensitivity

        if target_type == "dataframe":
            return data
        elif target_type == "nparray":
            return data.to_numpy()
        elif target_type == "tensor":
            return torch.tensor(data.to_numpy(), dtype=torch.float32)
        else:
            raise ValueError("Invalid target_type. Choose from: 'DataFrame', 'NPArray', 'Tensor'.")

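    # Illustrative sketch: converting a small DataFrame to the other supported
    # representations. The target_type strings are case-insensitive.
    #
    #     df = pd.DataFrame({"x1": [0.1, 0.2], "x2": [1.0, 0.5]})
    #     arr = task.convert_datatype(df, "NPArray")   # numpy.ndarray
    #     ten = task.convert_datatype(df, "Tensor")    # torch.float32 tensor
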
    def add_recourse_method(self, method_name: str, method_class: RecourseGenerator):
        """Registers an additional recourse method under the given name."""
        self.methods[method_name] = method_class

    def add_evaluation_metric(self, metric_name: str, metric_class: Evaluator):
        """Registers an additional evaluation metric under the given name."""
        self.evaluation_metrics[metric_name] = metric_class
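
# Illustrative sketch: extending a task with a custom recourse method or metric.
# `MyRecourseMethod` and `MyEvaluator` are hypothetical subclasses of RecourseGenerator
# and Evaluator; once registered they can be used by name in generate() and evaluate()
# just like the built-in ones.
#
#     task.add_recourse_method("MyRecourseMethod", MyRecourseMethod)
#     task.add_evaluation_metric("MyMetric", MyEvaluator)
#     task.generate(methods=["MyRecourseMethod"])
#     task.evaluate(methods=["MyRecourseMethod"], evaluations=["MyMetric"])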