Source code for scalr.data.split.stratified_group_splitter

"""This file is an implementation of stratified group splitter."""

from pandas import DataFrame
from sklearn.model_selection import GroupShuffleSplit

from scalr.data.split import SplitterBase
from scalr.utils import read_data


class StratifiedGroupSplitter(SplitterBase):
    """Class for stratified group splitter.

    Generates a split of the data into train, validation, and test sets.
    Samples that share a value in the `stratify` column form a group, and a
    group is never divided between the two sides of a split. The split is
    carried out separately for every class of the target, so each set draws
    its samples from every class available in the data.

    NOTE(review): because each class is split independently, a group whose
    samples span both classes is not guaranteed to be assigned to the same
    set for both of them. Confirm that every group carries a single label.
    """

    def __init__(self, split_ratio: list[float], stratify: str):
        """Initialize splitter with required parameters.

        Args:
            split_ratio (list[float]): Relative sizes of the train, val and
                test sets, in that order. The values are normalised by their
                sum, so they need not add up to 1.
            stratify (str): Column name in `obs` whose values define the
                groups that must be kept together.
        """
        super().__init__()
        self.stratify = stratify
        self.split_ratio = split_ratio

    def _split_data_with_stratification(
            self, metadata: DataFrame, target: str,
            test_ratio: float) -> tuple[list[int], list[int]]:
        """A function to split given metadata into a training and testing set.

        Args:
            metadata (DataFrame): Dataframe containing all samples to be split.
            target (str): Target for classification present in `obs`.
            test_ratio (float): Ratio of samples belonging to the test split.

        Returns:
            (list(int), list(int)): Train and test indices, positional with
                respect to the rows of `metadata`.
        """
        # Fixed seed and a single split keep the result reproducible.
        splitter = GroupShuffleSplit(test_size=test_ratio,
                                     n_splits=1,
                                     random_state=42)
        train_inds, test_inds = next(
            splitter.split(metadata,
                           metadata[target],
                           groups=metadata[self.stratify]))

        return train_inds, test_inds

    def generate_train_val_test_split_indices(self, datapath: str,
                                              target: str) -> dict:
        """A function to generate a list of indices for train/val/test split of the whole dataset.

        Args:
            datapath (str): Path to full data.
            target (str): Target for classification present in `obs`.

        Returns:
            dict: 'train', 'val' and 'test' indices list. Indices are row
                positions in the full dataset.

        Raises:
            ValueError: If `target` is empty, or if the target column holds
                more than two distinct classes.
        """
        if not target:
            raise ValueError('Must provide target for StratifiedGroupSplitter')

        adata = read_data(datapath)
        # Record every sample's row position in the full data. `assign`
        # returns a new frame, so `adata.obs` itself is left unmodified and
        # an existing `true_index` column on it is not overwritten.
        metadata = adata.obs.assign(true_index=range(len(adata.obs)))

        n_cls = metadata[target].nunique()
        if n_cls > 2:
            raise ValueError(
                'StratifiedGroupSplitter only works for binary classification.')

        total_ratio = sum(self.split_ratio)
        train_ratio = self.split_ratio[0] / total_ratio
        val_ratio = self.split_ratio[1] / total_ratio
        # The validation set is carved out of the (train + val) remainder,
        # so its ratio is expressed relative to that remainder.
        val_ratio = val_ratio / (val_ratio + train_ratio)
        test_ratio = self.split_ratio[2] / total_ratio

        train_indices = []
        val_indices = []
        test_indices = []

        for label in metadata[target].unique():
            label_metadata = metadata[metadata[target] == label]

            # Split testing and (train+val) indices.
            relative_train_val_inds, relative_test_inds = (
                self._split_data_with_stratification(label_metadata, target,
                                                     test_ratio))

            train_val_data = label_metadata.iloc[relative_train_val_inds]

            # Get train and val indices, relative to the `train_val_data`.
            relative_train_inds, relative_val_inds = (
                self._split_data_with_stratification(train_val_data, target,
                                                     val_ratio))

            # Get true_indices relative to the entire data.
            test_indices.extend(
                label_metadata.iloc[relative_test_inds]['true_index'].tolist())
            val_indices.extend(
                train_val_data.iloc[relative_val_inds]['true_index'].tolist())
            train_indices.extend(
                train_val_data.iloc[relative_train_inds]['true_index'].tolist())

        data_split = {
            'train': train_indices,
            'val': val_indices,
            'test': test_indices
        }

        return data_split

    @classmethod
    def get_default_params(cls) -> dict:
        """Class method to get default params for model_config."""
        return dict(split_ratio=[7, 1, 2], stratify='donor_id')