Source code for scalr.feature.feature_subsetting

"""This file contains implementation for model training on feature subsets."""

from copy import deepcopy
import os
from os import path
from typing import Union

from anndata import AnnData
from anndata.experimental import AnnCollection
from torch import nn

from scalr.model_training_pipeline import ModelTrainingPipeline
from scalr.utils import EventLogger
from scalr.utils import FlowLogger


class FeatureSubsetting:
    """Class for FeatureSubsetting.

    It trains a model for each subsetted dataset, each containing at most
    `feature_subsetsize` genes as features.
    """

    def __init__(self,
                 feature_subsetsize: int,
                 chunk_model_config: dict,
                 chunk_model_train_config: dict,
                 train_data: Union[AnnData, AnnCollection],
                 val_data: Union[AnnData, AnnCollection],
                 target: str,
                 mappings: dict,
                 dirpath: Union[str, None] = None,
                 device: str = 'cpu'):
        """Initialize required parameters for feature subset training.

        Args:
            feature_subsetsize (int): Number of features in one subset.
            chunk_model_config (dict): Chunked model config.
            chunk_model_train_config (dict): Chunked model training config.
            train_data (Union[AnnData, AnnCollection]): Train dataset.
            val_data (Union[AnnData, AnnCollection]): Validation dataset.
            target (str): Target to train model.
            mappings (dict): Mapping of target to labels.
            dirpath (str, optional): Dirpath to store chunked model weights.
                Defaults to None, but a real path is required before
                `train_chunked_models()` is called.
            device (str, optional): Device to train models on.
                Defaults to 'cpu'.
        """
        self.event_logger = EventLogger('FeatureSubsetting')

        self.feature_subsetsize = feature_subsetsize
        self.chunk_model_config = chunk_model_config
        self.chunk_model_train_config = chunk_model_train_config
        self.train_data = train_data
        self.val_data = val_data
        self.target = target
        self.mappings = mappings
        self.dirpath = dirpath
        self.device = device

    def train_chunked_models(self) -> list[nn.Module]:
        """Trains a model for each subset data.

        The features are walked in consecutive windows of
        `feature_subsetsize` columns; the last window may be shorter. Each
        window gets its own `<dirpath>/chunked_models/<chunk index>`
        directory and its own `ModelTrainingPipeline` run.

        Returns:
            list[nn.Module]: List of models for each subset.

        Raises:
            TypeError: If `dirpath` is None (no place to store the models).
            ValueError: If `feature_subsetsize` is not a positive number.
        """
        # Fail fast with an actionable message instead of the opaque
        # TypeError that `path.join(None, ...)` would raise.
        if self.dirpath is None:
            raise TypeError(
                'dirpath must be a path to store chunked model weights, '
                'got None')
        # A zero step makes `range` raise and a negative one makes it empty,
        # which would silently train no model at all.
        if self.feature_subsetsize <= 0:
            raise ValueError(
                'feature_subsetsize must be a positive integer, got '
                f'{self.feature_subsetsize!r}')

        self.event_logger.info('Feature subset models training')

        chunked_models_dirpath = path.join(self.dirpath, 'chunked_models')
        os.makedirs(chunked_models_dirpath, exist_ok=True)

        models = []
        n_features = len(self.train_data.var_names)
        chunk_starts = range(0, n_features, self.feature_subsetsize)
        for i, start in enumerate(chunk_starts):
            self.event_logger.info(f'\nChunk {i}')
            chunk_dirpath = path.join(chunked_models_dirpath, str(i))
            os.makedirs(chunk_dirpath, exist_ok=True)

            stop = start + self.feature_subsetsize
            train_features_subset = self.train_data[:, start:stop]
            val_features_subset = self.val_data[:, start:stop]

            # Only the model config is copied before being handed to the
            # pipeline; the training config is passed as-is.
            chunk_model_config = deepcopy(self.chunk_model_config)

            model_trainer = ModelTrainingPipeline(
                chunk_model_config, self.chunk_model_train_config,
                chunk_dirpath, self.device)

            model_trainer.set_data_and_targets(train_features_subset,
                                               val_features_subset,
                                               self.target, self.mappings)

            model_trainer.build_model_training_artifacts()
            best_model = model_trainer.train()

            # The configs reported by the pipeline replace the stored ones,
            # so the next chunk (and `get_updated_configs`) see them.
            (self.chunk_model_config, self.chunk_model_train_config
            ) = model_trainer.get_updated_config()
            models.append(best_model)

        return models

    def get_updated_configs(self):
        """Returns updated configs.

        Returns:
            tuple: `(chunk_model_config, chunk_model_train_config)` as
            currently stored on the instance; `train_chunked_models()`
            overwrites both after every trained chunk.
        """
        return self.chunk_model_config, self.chunk_model_train_config