"""This file is an implementation of stratified group splitter."""
from pandas import DataFrame
from sklearn.model_selection import GroupShuffleSplit
from scalr.data.split import SplitterBase
from scalr.utils import read_data
[docs]
class StratifiedGroupSplitter(SplitterBase):
"""Class for stratified group splitter.
Generates split of data into train, validation, and test
sets. Stratification ensures samples have the same value for `stratify`
column, can not belong to different sets. Also, it ensures every split
contains samples from each class available in the data.
"""
def __init__(self, split_ratio: list[float], stratify: str):
"""Initialize splitter with required parameters.
Args:
split_ratio (list[float]): ratio to split number of samples in
stratify (str): column name to metadata the split upon in `obs`
"""
super().__init__()
self.stratify = stratify
self.split_ratio = split_ratio
[docs]
def _split_data_with_stratification(
self, metadata: DataFrame, target: str,
test_ratio: float) -> tuple[list[int], list[int]]:
"""A function to split given metadata into a training and testing set.
Args:
metadata (DataFrame): Dataframe containing all samples to be split.
target (str): Target for classification present in `obs`.
test_ratio (float): Ratio of samples belonging to the test split.
Returns:
(list(int), list(int)): Two lists consisting of train and test indices.
"""
splitter = GroupShuffleSplit(test_size=test_ratio,
n_splits=1,
random_state=42)
train_inds, test_inds = next(
splitter.split(metadata,
metadata[target],
groups=metadata[self.stratify]))
return train_inds, test_inds
[docs]
def generate_train_val_test_split_indices(self, datapath: str,
target: str) -> dict:
"""A function to generate a list of indices for train/val/test split of the whole dataset.
Args:
datapath (str): Path to full data.
target (str): Target for classification present in `obs`.
Returns:
dict: 'train', 'val' and 'test' indices list.
"""
if not target:
raise ValueError('Must provide target for StratifiedGroupSplitter')
adata = read_data(datapath)
metadata = adata.obs
metadata['true_index'] = range(len(metadata))
n_cls = metadata[target].nunique()
if n_cls > 2:
raise ValueError(
'StratifiedGroupSplitter only works for binary classification.')
total_ratio = sum(self.split_ratio)
train_ratio = self.split_ratio[0] / total_ratio
val_ratio = self.split_ratio[1] / total_ratio
val_ratio = val_ratio / (val_ratio + train_ratio)
test_ratio = self.split_ratio[2] / total_ratio
train_indices = []
val_indices = []
test_indices = []
for label in metadata[target].unique():
label_metadata = metadata[metadata[target] == label]
# Split testing and (train+val) indices.
relative_train_val_inds, relative_test_inds = self._split_data_with_stratification(
label_metadata, target, test_ratio)
train_val_data = label_metadata.iloc[relative_train_val_inds]
# Get train and val indices, relative to the `train_val_data`.
relative_train_inds, relative_val_inds = self._split_data_with_stratification(
train_val_data, target, val_ratio)
# Get true_indices relative to the entire data.
test_indices.extend(
label_metadata.iloc[relative_test_inds]['true_index'].tolist())
val_indices.extend(
train_val_data.iloc[relative_val_inds]['true_index'].tolist())
train_indices.extend(
train_val_data.iloc[relative_train_inds]['true_index'].tolist())
data_split = {
'train': train_indices,
'val': val_indices,
'test': test_indices
}
return data_split
[docs]
@classmethod
def get_default_params(cls) -> dict:
"""Class method to get default params for model_config."""
return dict(split_ratio=[7, 1, 2], stratify='donor_id')