Source code for scalr.feature.feature_subsetting

"""This file contains implementation for model training on feature subsets."""

from copy import deepcopy
import os
from os import path
from typing import Union

from anndata import AnnData
from anndata.experimental import AnnCollection
from joblib import delayed
from joblib import Parallel
from torch import nn

from scalr.model_training_pipeline import ModelTrainingPipeline
from scalr.utils import EventLogger
from scalr.utils import FlowLogger
from scalr.utils import read_data
from scalr.utils import write_chunkwise_data


class FeatureSubsetting:
    """Class for FeatureSubsetting.

    It trains one model per feature-subsetted dataset, each subset containing
    `feature_subsetsize` genes as features.
    """

    def __init__(self,
                 feature_subsetsize: int,
                 chunk_model_config: dict,
                 chunk_model_train_config: dict,
                 train_data: Union[AnnData, AnnCollection],
                 val_data: Union[AnnData, AnnCollection],
                 target: str,
                 mappings: dict,
                 dirpath: str = None,
                 device: str = 'cpu',
                 num_workers: int = 1,
                 sample_chunksize: int = None):
        """Initialize required parameters for feature subset training.

        Args:
            feature_subsetsize (int): Number of features in one subset.
            chunk_model_config (dict): Chunked model config.
            chunk_model_train_config (dict): Chunked model training config.
            train_data (Union[AnnData, AnnCollection]): Train dataset.
            val_data (Union[AnnData, AnnCollection]): Validation dataset.
            target (str): Target to train the model on.
            mappings (dict): Mapping of the target to its labels.
            dirpath (str, optional): Directory path to store chunked model
                weights. Defaults to None.
            device (str, optional): Device to train models on.
                Defaults to 'cpu'.
            num_workers (int, optional): Number of parallel processes launched
                to train multiple feature subsets simultaneously.
                Defaults to a single process.
            sample_chunksize (int, optional): Number of samples to load in
                memory at once. Required when `num_workers` > 1.
        """
        self.feature_subsetsize = feature_subsetsize
        self.chunk_model_config = chunk_model_config
        self.chunk_model_train_config = chunk_model_train_config
        self.train_data = train_data
        self.val_data = val_data
        self.target = target
        self.mappings = mappings
        self.dirpath = dirpath
        self.device = device
        self.num_workers = num_workers if num_workers else 1
        self.sample_chunksize = sample_chunksize
        self.total_features = len(self.train_data.var_names)

        # Note that EventLogger does not work with parallel training.
        # You may use tensorboard logging to track model training instead.
        if self.num_workers == 1:
            self.event_logger = EventLogger('FeatureSubsetting')
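For orientation, here is a minimal sketch of how this class might be instantiated on a toy AnnData object. The `chunk_model_config`, `chunk_model_train_config`, and `mappings` values below are placeholders, not real scaLR configs; their actual schemas are defined by the scaLR configuration files, and the output path is made up.

```python
import anndata as ad
import numpy as np
import pandas as pd

from scalr.feature.feature_subsetting import FeatureSubsetting

# Toy dataset: 100 cells x 20 genes, with a 'cell_type' target column.
adata = ad.AnnData(
    X=np.random.rand(100, 20).astype(np.float32),
    obs=pd.DataFrame({'cell_type': np.random.choice(['A', 'B'], size=100)}),
)

feature_subsetter = FeatureSubsetting(
    feature_subsetsize=6,            # 20 genes -> subsets of 6, 6, 6, 2
    chunk_model_config={},           # placeholder: real schema set by scaLR configs
    chunk_model_train_config={},     # placeholder: real schema set by scaLR configs
    train_data=adata[:80],
    val_data=adata[80:],
    target='cell_type',
    mappings={},                     # placeholder: target-to-label mappings
    dirpath='outputs/feature_subsetting',  # hypothetical output directory
    num_workers=1,                   # single process: subsets are sliced in memory
)
```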
    def write_feature_subsetted_data(self):
        """Write chunks of feature-subsetted data, to enable parallel training
        of models using different chunks of data."""
        if self.num_workers == 1:
            return

        self.feature_chunked_data_dirpath = path.join(self.dirpath,
                                                      'chunked_data')
        os.makedirs(self.feature_chunked_data_dirpath, exist_ok=True)

        for i, start in enumerate(
                range(0, self.total_features, self.feature_subsetsize)):
            feature_subset_inds = list(
                range(start,
                      min(start + self.feature_subsetsize,
                          self.total_features)))

            write_chunkwise_data(self.train_data,
                                 self.sample_chunksize,
                                 path.join(self.feature_chunked_data_dirpath,
                                           'train', str(i)),
                                 feature_inds=feature_subset_inds,
                                 num_workers=self.num_workers)
            write_chunkwise_data(self.val_data,
                                 self.sample_chunksize,
                                 path.join(self.feature_chunked_data_dirpath,
                                           'val', str(i)),
                                 feature_inds=feature_subset_inds,
                                 num_workers=self.num_workers)

        del self.train_data
        del self.val_data
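The chunking above is plain striding over the feature axis with a clipped final chunk; chunk `i` is written under `<dirpath>/chunked_data/train/<i>` and `<dirpath>/chunked_data/val/<i>`. A standalone sketch of the index arithmetic:

```python
# Standalone illustration of the subset-index arithmetic used above.
total_features = 20
feature_subsetsize = 6

for i, start in enumerate(range(0, total_features, feature_subsetsize)):
    inds = list(range(start, min(start + feature_subsetsize, total_features)))
    print(f'chunk {i}: features {inds[0]}..{inds[-1]} ({len(inds)} features)')

# chunk 0: features 0..5 (6 features)
# chunk 1: features 6..11 (6 features)
# chunk 2: features 12..17 (6 features)
# chunk 3: features 18..19 (2 features)  <- last chunk absorbs the remainder
```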
    def train_chunked_models(self) -> list[nn.Module]:
        """Train a model on each feature-subsetted dataset.

        Returns:
            list[nn.Module]: List of trained models, one per subset.
        """
        if self.num_workers == 1:
            self.event_logger.info('Feature subset models training')

        chunked_models_dirpath = path.join(self.dirpath, 'chunked_models')
        os.makedirs(chunked_models_dirpath, exist_ok=True)

        def train_chunked_model(i, start):
            if self.num_workers == 1:
                self.event_logger.info(f'\nChunk {i}')

            chunk_dirpath = path.join(chunked_models_dirpath, str(i))
            os.makedirs(chunk_dirpath, exist_ok=True)

            if self.num_workers > 1:
                # Parallel mode: read the feature-subsetted chunks written
                # to disk by `write_feature_subsetted_data`.
                train_features_subset = read_data(
                    path.join(self.feature_chunked_data_dirpath, 'train',
                              str(i)))
                val_features_subset = read_data(
                    path.join(self.feature_chunked_data_dirpath, 'val',
                              str(i)))
            else:
                # Single-process mode: slice the feature subset in memory.
                train_features_subset = self.train_data[:, start:start +
                                                        self.feature_subsetsize]
                val_features_subset = self.val_data[:, start:start +
                                                    self.feature_subsetsize]

            chunk_model_config = deepcopy(self.chunk_model_config)
            model_trainer = ModelTrainingPipeline(chunk_model_config,
                                                  self.chunk_model_train_config,
                                                  chunk_dirpath, self.device)

            model_trainer.set_data_and_targets(train_features_subset,
                                               val_features_subset,
                                               self.target, self.mappings)
            model_trainer.build_model_training_artifacts()
            best_model = model_trainer.train()

            # Note: when `num_workers` > 1, joblib runs this function in a
            # worker process, so these config updates are not propagated
            # back to the parent object.
            self.chunk_model_config, self.chunk_model_train_config = (
                model_trainer.get_updated_config())

            return i, best_model

        parallel = Parallel(n_jobs=self.num_workers)
        models = parallel(
            delayed(train_chunked_model)(i, start)
            for i, start in enumerate(
                range(0, self.total_features, self.feature_subsetsize)))

        # The parallel loop returns (chunk number, model) pairs; sorting by
        # chunk number restores subset order before the models are extracted.
        models = sorted(models)
        models = [model[1] for model in models]

        return models
    def get_updated_configs(self):
        """Returns updated configs."""
        return self.chunk_model_config, self.chunk_model_train_config
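Putting it together, a typical call sequence might look like the sketch below, continuing the hypothetical `feature_subsetter` from the earlier example (it assumes valid configs were supplied; the placeholder configs above would not actually train). The number of returned models is the ceiling of total features over the subset size.

```python
feature_subsetter.write_feature_subsetted_data()   # no-op when num_workers == 1
models = feature_subsetter.train_chunked_models()  # one trained model per subset

n_chunks = -(-feature_subsetter.total_features //
             feature_subsetter.feature_subsetsize)  # ceiling division
assert len(models) == n_chunks

# Configs with any values filled in or updated during training:
model_config, train_config = feature_subsetter.get_updated_configs()
```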