"""This file is a base class for splitter."""
import os
from os import path
from typing import Union
from anndata import AnnData
from anndata.experimental import AnnCollection
import scalr
from scalr.utils import build_object
from scalr.utils import EventLogger
from scalr.utils import read_data
from scalr.utils import write_chunkwise_data
from scalr.utils import write_data
[docs]
class SplitterBase:
    """Base class for splitter, to make Train|Val|Test Splits."""
    def __init__(self):
        self.event_logger = EventLogger('Splitter')
    # Abstract
[docs]
    def generate_train_val_test_split_indices(datapath: str, target: str,
                                              **kwargs) -> dict:
        """Generate a list of indices for train/val/test split of whole dataset.
        Args:
            datapath (str): Path to full data.
            target (str): Target for classification present in `obs`.
            **kwargs: Any other params needed for splitting.
        Returns:
            dict: 'train', 'val' and 'test' indices list.
        """
        pass 
[docs]
    def check_splits(self, datapath: str, data_splits: dict, target: str):
        """This function performs certain checks regarding splits and logs
        the distribution of target classes in each split.
        Args:
            datapath (str): Path to full data.
            data_splits (dict): Split of 'train', 'val' and 'test' indices.
            target (str): Classification target column name in `obs`.
        """
        adata = read_data(datapath)
        metadata = adata.obs
        n_cls = metadata[target].nunique()
        train_inds = data_splits['train']
        val_inds = data_splits['val']
        test_inds = data_splits['test']
        # Check for classes present in splits.
        if len(metadata[target].iloc[train_inds].unique()) != n_cls:
            self.event_logger.warning(
                'All classes are not present in Train set')
        if len(metadata[target].iloc[val_inds].unique()) != n_cls:
            self.event_logger.warning(
                'All classes are not present in Validation set')
        if len(metadata[target].iloc[test_inds].unique()) != n_cls:
            self.event_logger.warning('All classes are not present in Test set')
        # Check for overlapping samples.
        assert len(set(train_inds).intersection(
            test_inds)) == 0, "Test and Train sets contain overlapping samples"
        assert len(
            set(val_inds).intersection(train_inds)
        ) == 0, "Validation and Train sets contain overlapping samples"
        assert len(set(test_inds).intersection(val_inds)
                  ) == 0, "Test and Validation sets contain overlapping samples"
        # LOGGING.
        self.event_logger.info('Train|Validation|Test Splits\n')
        self.event_logger.info(f'Length of train set: {len(train_inds)}')
        self.event_logger.info(f'Distribution of train set: ')
        self.event_logger.info(
            f'{metadata[target].iloc[train_inds].value_counts()}\n')
        self.event_logger.info(f'Length of val set: {len(val_inds)}')
        self.event_logger.info(f'Distribution of val set: ')
        self.event_logger.info(
            f'{metadata[target].iloc[val_inds].value_counts()}\n')
        self.event_logger.info(f'Length of test set: {len(test_inds)}')
        self.event_logger.info(f'Distribution of test set: ')
        self.event_logger.info(
            f'{metadata[target].iloc[test_inds].value_counts()}\n') 
[docs]
    def write_splits(self,
                     full_data: Union[AnnData, AnnCollection],
                     data_split_indices: dict,
                     sample_chunksize: int,
                     dirpath: int,
                     num_workers: int = None):
        """THis function writes the train validation and test splits to the disk.
        Args:
            full_data (Union[AnnData, AnnCollection]): Full data to be split.
            data_split_indices (dict): Indices of each split.
            sample_chunksize (int): Number of samples to be written in one file.
            dirpath (int): Path to write data into.
            num_workers (int): number of jobs to run in parallel for data writing.
        """
        for split in data_split_indices.keys():
            if sample_chunksize:
                split_dirpath = path.join(dirpath, split)
                os.makedirs(split_dirpath, exist_ok=True)
                write_chunkwise_data(full_data,
                                     sample_chunksize,
                                     split_dirpath,
                                     data_split_indices[split],
                                     num_workers=num_workers)
            else:
                filepath = path.join(dirpath, f'{split}.h5ad')
                write_data(full_data[data_split_indices[split]].to_memory(),
                           filepath) 
[docs]
    @classmethod
    def get_default_params(cls) -> dict:
        """Class method to get default params for model_config."""
        return dict() 
 
[docs]
def build_splitter(splitter_config: dict) -> tuple[SplitterBase, dict]:
    """Builder object to get splitter, updated splitter_config."""
    return build_object(scalr.data.split, splitter_config)