Source code for scalr.data.split.group_splitter

"""This file is an implementation of group splitter."""

from pandas import DataFrame
from sklearn.model_selection import GroupShuffleSplit

from scalr.data.split import StratifiedSplitter


class GroupSplitter(StratifiedSplitter):
    """Splitter that keeps every group of samples inside a single split.

    Samples are grouped by their value in the `stratify` column, and the
    train/test partition is drawn with scikit-learn's `GroupShuffleSplit`,
    so samples that share a group value are never divided between the
    training and the testing side of a split.
    """

    def __init__(self, split_ratio: list[float], stratify: str):
        """Store the split ratio and the grouping column.

        Args:
            split_ratio (list[float]): Ratio to split the number of samples
                in; forwarded unchanged to the parent initializer.
            stratify (str): Name of the metadata column whose values define
                the groups that must stay together.
        """
        super().__init__(split_ratio)
        self.stratify = stratify

    def _split_data_with_stratification(
            self, metadata: DataFrame, target: str,
            test_ratio: float) -> tuple[list[int], list[int]]:
        """Partition `metadata` into train and test indices by group.

        Args:
            metadata (DataFrame): Dataframe containing all samples to split.
            target (str): Name of the classification target column in
                `metadata`.
            test_ratio (float): Passed to `GroupShuffleSplit` as `test_size`.

        Returns:
            (list(int), list(int)): Train indices and test indices, exactly
            as produced by `GroupShuffleSplit.split`.
        """
        # A fixed seed and a single iteration make the split reproducible.
        group_shuffler = GroupShuffleSplit(test_size=test_ratio,
                                           n_splits=1,
                                           random_state=42)

        labels = metadata[target]
        group_ids = metadata[self.stratify]

        # Only one split was requested, so take the first (and only) pair.
        split_iterator = group_shuffler.split(metadata,
                                              labels,
                                              groups=group_ids)
        train_indices, test_indices = next(split_iterator)
        return train_indices, test_indices

    @classmethod
    def get_default_params(cls) -> dict:
        """Return the default constructor parameters for the config."""
        return {
            'split_ratio': [7, 1, 2],
            'stratify': 'donor_id',
        }