from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd

from rocelib.datasets.DatasetLoader import DatasetLoader
from rocelib.models.TrainedModel import TrainedModel
class Task(ABC):
    """
    Base class for a task that pairs a trained model with a dataset and keeps
    track of the counterfactual explanations (CEs) generated for it.

    The generation/evaluation methods defined here are stubs that do nothing
    and return None; they are not decorated with @abstractmethod, so
    overriding them is a convention rather than something enforced at
    instantiation time.

    Attributes:
        _dataset (DatasetLoader): The dataset associated with the task.
        __model (TrainedModel): The primary model of the task.
        __mm_models (Dict[str, TrainedModel] | None): Optional mapping of
            additional models for the multiple-model (MM) setting.
        _CEs (dict): Generated counterfactuals, keyed per method.
        _mm_CEs (dict): Generated counterfactuals, keyed per model, then per method.
        mm_flag (bool): True only when more than one MM model was supplied.
        methods (dict): Registry whose keys are the available recourse method
            names (starts empty; presumably filled in by subclasses).
        evaluation_metrics (dict): Registry whose keys are the available
            evaluation metric names (starts empty; presumably filled in by
            subclasses).
    """

    def __init__(
        self,
        model: "TrainedModel",
        dataset: "DatasetLoader",
        mm_models: Optional[Dict[str, "TrainedModel"]] = None,
    ):
        """
        Initializes the Task with a model and a dataset, and optionally with
        multiple models for MM.

        @param model: An instance of a model that extends TrainedModel.
        @param dataset: An instance of DatasetLoader containing the training data.
        @param mm_models: Optional dictionary mapping a name to an instance of
                          a model that extends TrainedModel. Stored as given
                          (including None when omitted).
        """
        self._dataset = dataset
        self.__model = model
        # Generated counterfactuals per method.
        self._CEs: Dict[str, Tuple[pd.DataFrame, float]] = {}
        # Generated counterfactuals per model, per method.
        self._mm_CEs: Dict[str, Dict[str, Tuple[pd.DataFrame, float]]] = {}
        self.__mm_models: Optional[Dict[str, "TrainedModel"]] = mm_models
        # MM mode needs at least two models; None, an empty mapping, or a
        # single entry all leave the flag off.
        self.mm_flag = bool(mm_models) and len(mm_models) > 1
        self.methods = {}
        self.evaluation_metrics = {}

    def get_random_positive_instance(self, neg_value, column_name="target") -> pd.Series:
        """
        Stub for retrieving a random positive instance from the training data.
        The base implementation does nothing and returns None; subclasses are
        expected to override it.

        @param neg_value: The value considered negative in the target variable.
        @param column_name: The name of the target column.
        @return: A Pandas Series representing a random positive instance
                 (None in this base implementation).
        """
        pass

    def generate(self, methods: List[str]) -> Dict[str, Tuple[pd.DataFrame, float]]:
        """
        Stub for generating counterfactuals with the given methods.
        The base implementation does nothing and returns None.

        @param methods: Names of the recourse methods to run.
        @return: Per the annotation, a mapping from method name to a
                 (DataFrame, float) pair; the meaning of the float is decided
                 by the overriding subclass.
        """
        pass

    def generate_mm(self, methods: List[str]) -> Dict[str, Tuple[pd.DataFrame, float]]:
        """
        Stub for generating counterfactuals in the multiple-model setting.
        The base implementation does nothing and returns None.

        @param methods: Names of the recourse methods to run.
        @return: Defined by the overriding subclass (None here).
        """
        pass

    def evaluate(self, methods: List[str], evaluations: List[str]) -> Dict[str, Dict[str, Any]]:
        """
        Stub for evaluating recourse methods against evaluation metrics.
        The base implementation does nothing and returns None.

        @param methods: Names of the recourse methods to evaluate.
        @param evaluations: Names of the evaluation metrics to apply.
        @return: Defined by the overriding subclass (None here).
        """
        pass

    def get_recourse_methods(self) -> List[str]:
        """
        Lists the names of the registered recourse methods.

        @return: The keys of self.methods, as a list.
        """
        return list(self.methods.keys())

    def get_evaluation_metrics(self) -> List[str]:
        """
        Lists the names of the registered evaluation metrics.

        @return: The keys of self.evaluation_metrics, as a list.
        """
        return list(self.evaluation_metrics.keys())

    @property
    def dataset(self):
        """
        Property to access the dataset.

        @return: The DatasetLoader instance passed at construction.
        """
        return self._dataset

    @property
    def ces(self):
        """
        Property to access the generated counterfactuals (same object as CEs).

        @return: The dictionary of counterfactuals stored per method.
        """
        return self._CEs

    @property
    def model(self):
        """
        Property to access the model.

        @return: The model instance passed at construction.
        """
        return self.__model

    @property
    def mm_models(self):
        """
        Property to access the models used for MM.

        @return: The mm_models mapping passed at construction, or None if
                 none was provided.
        """
        return self.__mm_models

    @property
    def CEs(self):
        """
        Property to access the generated counterfactuals.

        @return: The dictionary of counterfactuals stored per method.
        """
        return self._CEs

    @property
    def mm_CEs(self):
        """
        Property to access the counterfactuals generated in the MM setting.

        @return: The dictionary of counterfactuals stored per model, per method.
        """
        return self._mm_CEs