Source code for scalr.nn.model.shap_model

"""This file is an implementation of the custom SHAP model."""

from collections.abc import Mapping

from torch import nn


class CustomShapModel(nn.Module):
    """Class for a custom model for SHAP.

    Wraps a trained model so that its forward pass yields a single output.
    If the wrapped model returns a mapping (e.g. a dict of named outputs),
    only the entry stored under ``key`` is returned; any other output is
    returned unchanged.
    """

    def __init__(self, model, key='cls_output'):
        """Initialize required parameters for SHAP model.

        Args:
            model: Trained model used for SHAP calculation. It is called as
                ``model(x)`` in ``forward``.
            key: Key selected from the model output when the model returns
                a dict (or any other mapping). Defaults to ``'cls_output'``.
        """
        super().__init__()
        self.model = model
        self.key = key

    def forward(self, x):
        """Pass input through the model and return the selected output.

        Args:
            x: Input passed unchanged to the wrapped model (presumably a
                tensor).

        Returns:
            ``self.model(x)[self.key]`` when the model returns a mapping,
            otherwise ``self.model(x)`` as-is.

        Raises:
            KeyError: If the model returns a mapping that has no entry for
                ``self.key``.
        """
        output = self.model(x)
        # Accept any Mapping, not only plain dicts, so mapping-like output
        # containers are unwrapped as well.
        if isinstance(output, Mapping):
            try:
                output = output[self.key]
            except KeyError as err:
                # Surface which key was requested and which ones exist, so a
                # misconfigured ``key`` is easy to diagnose.
                raise KeyError(
                    f"Model output has no key {self.key!r}; "
                    f"available keys: {list(output.keys())}"
                ) from err
        return output