rocelib.models package

Subpackages

Submodules

rocelib.models.TrainableModel module

class rocelib.models.TrainableModel.TrainableModel(model)

Bases: ABC

Abstract base class that defines the essential methods every model type must implement, providing a template for training, predicting, and evaluating models in a standardized way.

_model

The underlying model object (e.g., a scikit-learn model or a PyTorch model) that this class wraps.

Type:

object

train(X: pd.DataFrame, y: pd.DataFrame) -> None:

Trains the model using the provided feature and target data.

predict(X: pd.DataFrame) -> pd.DataFrame:

Predicts the outcomes for the given instances.

predict_single(X: pd.DataFrame) -> int:

Predicts the outcome for a single instance and returns an integer.

predict_proba(X: pd.DataFrame) -> pd.DataFrame:

Predicts the probabilities of outcomes for the given instances.

predict_proba_tensor(X: torch.Tensor) -> torch.Tensor:

Predicts the probabilities of outcomes for tensor inputs.

evaluate(X: pd.DataFrame, y: pd.DataFrame):

Evaluates the model’s performance on the provided feature and target data.

Properties
----------
model:

Returns the underlying model object.

property model

Returns the underlying model object.

@return: The model object.
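The interface summarised above can be illustrated with a minimal sketch. To keep it runnable without RoCELib installed, it defines a hypothetical stand-in base class, TrainableModelSketch, that mirrors the documented methods and the model property (omitting predict_proba_tensor, which would require PyTorch), plus a hypothetical wrapper, SklearnLogRegModel, around scikit-learn's LogisticRegression. Neither is RoCELib's actual code; the output shapes (a single "prediction" column, one probability column per class) and the use of accuracy in evaluate are assumptions.

```python
from abc import ABC, abstractmethod

import pandas as pd
from sklearn.linear_model import LogisticRegression


class TrainableModelSketch(ABC):
    """Hypothetical stand-in mirroring the documented TrainableModel interface."""

    def __init__(self, model):
        self._model = model  # the wrapped model object (e.g. a scikit-learn estimator)

    @property
    def model(self):
        """Returns the underlying model object."""
        return self._model

    @abstractmethod
    def train(self, X: pd.DataFrame, y: pd.DataFrame, **kwargs) -> None: ...

    @abstractmethod
    def predict(self, X: pd.DataFrame) -> pd.DataFrame: ...

    @abstractmethod
    def predict_single(self, X: pd.DataFrame) -> int: ...

    @abstractmethod
    def predict_proba(self, X: pd.DataFrame) -> pd.DataFrame: ...

    @abstractmethod
    def evaluate(self, X: pd.DataFrame, y: pd.DataFrame): ...


class SklearnLogRegModel(TrainableModelSketch):
    """Hypothetical concrete wrapper around scikit-learn's LogisticRegression."""

    def __init__(self):
        super().__init__(LogisticRegression())

    def train(self, X, y, **kwargs):
        # ravel() turns a one-column target DataFrame into the 1-D array sklearn expects
        self._model.fit(X, y.values.ravel())

    def predict(self, X):
        # Assumed layout: one "prediction" column, indexed like X
        return pd.DataFrame(self._model.predict(X), index=X.index, columns=["prediction"])

    def predict_single(self, X):
        # Assumes X holds exactly one row; returns its predicted label as a plain int
        return int(self._model.predict(X)[0])

    def predict_proba(self, X):
        # One column per class, in the order of the fitted model's classes_
        return pd.DataFrame(
            self._model.predict_proba(X), index=X.index, columns=self._model.classes_
        )

    def evaluate(self, X, y):
        # Assumed metric: mean accuracy, via the estimator's own score()
        return self._model.score(X, y.values.ravel())
```

Keeping the wrapped estimator in _model and exposing it read-only through the model property lets callers reach estimator-specific attributes (such as coef_) without the base class having to know about them.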

abstract train(X, y, **kwargs)

Trains the model using the feature variables X and the target variable y. Each implementing class can decide how to train its model and can add additional parameters, but X and y must be pandas DataFrames.

@return: None

Return type:

TrainedModel
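The abstract train contract leaves the training procedure and any extra parameters to each subclass while fixing the input types. A minimal sketch of that pattern follows; BaseTrainable and TreeModel are hypothetical names, max_depth is an illustrative subclass-specific option passed through to scikit-learn's DecisionTreeClassifier, and the method returns None, following the @return line above.

```python
from abc import ABC, abstractmethod

import pandas as pd
from sklearn.tree import DecisionTreeClassifier


class BaseTrainable(ABC):
    """Hypothetical minimal base holding only the abstract train() contract."""

    def __init__(self, model):
        self._model = model

    @abstractmethod
    def train(self, X: pd.DataFrame, y: pd.DataFrame, **kwargs) -> None:
        """Train on feature DataFrame X and target DataFrame y."""


class TreeModel(BaseTrainable):
    """Hypothetical subclass that adds its own optional training parameters."""

    def __init__(self):
        super().__init__(DecisionTreeClassifier(random_state=0))

    def train(self, X, y, max_depth=None, **kwargs):
        # Enforce the documented contract: both inputs must be DataFrames
        if not isinstance(X, pd.DataFrame) or not isinstance(y, pd.DataFrame):
            raise TypeError("X and y must both be pandas DataFrames")
        # Extra keyword arguments are subclass-specific; here they configure the tree
        self._model.set_params(max_depth=max_depth, **kwargs)
        # ravel() turns the one-column target DataFrame into a 1-D array
        self._model.fit(X, y.values.ravel())
```

A call such as model.train(X, y, max_depth=3) then trains a depth-limited tree, while a NumPy array passed as X is rejected with a TypeError before any fitting happens.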

rocelib.models.Models module

rocelib.models.Models.get_sklearn_model(name)

Retrieves an instance of a scikit-learn model based on the provided name.

@param name: The name of the desired model. Options are:
  • “log_reg” for Logistic Regression

  • “decision_tree” for Decision Tree

  • “svm” for Support Vector Machine

@return: An instance of the requested scikit-learn model. The model class should be a subclass of TrainableModel.

@raises ValueError: If the provided model name does not match any of the predefined options.
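The name-based lookup described above amounts to a dispatch table plus an error for unknown keys. The sketch below is a hypothetical reimplementation (get_sklearn_model_sketch is not part of RoCELib); it assumes "svm" maps to scikit-learn's SVC and returns bare, untrained estimators rather than the TrainableModel subclass instances the real function is documented to return.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

# Hypothetical name-to-constructor table mirroring the documented options
_SKLEARN_MODELS = {
    "log_reg": LogisticRegression,
    "decision_tree": DecisionTreeClassifier,
    "svm": SVC,  # assumption: "Support Vector Machine" means the SVC classifier
}


def get_sklearn_model_sketch(name: str):
    """Return a fresh, untrained estimator for one of the documented model names."""
    try:
        constructor = _SKLEARN_MODELS[name]
    except KeyError:
        # Mirror the documented behaviour: unknown names raise ValueError
        raise ValueError(
            f"Unknown model name {name!r}; expected one of {sorted(_SKLEARN_MODELS)}"
        ) from None
    # Calling the constructor gives each caller its own independent instance
    return constructor()
```

With the library itself, the documented call is rocelib.models.Models.get_sklearn_model("log_reg"); per the documentation above, an unsupported name such as "random_forest" raises ValueError.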

Module contents