"""This file is an implementation of a linear scorer."""
from typing import Union
from anndata import AnnData
from anndata.experimental import AnnCollection
import numpy as np
import torch
from torch import nn
from scalr.feature.scoring import ScoringBase
[docs]
class LinearScorer(ScoringBase):
    """Class for the linear scorer.
    
    This Scorer is only applicable for linear (single-layer) models.
    It directly uses the weights as the score for each feature.
    """
    def __init__(self):
        pass
[docs]
    def generate_scores(self, model: nn.Module, *args, **kwargs) -> np.ndarray:
        """A function to generate and return the weights of the model as a score."""
        return model.state_dict()['out_layer.weight'].cpu().detach().numpy()