"""Wrapper that adapts a model's output so it can be used for SHAP calculation."""
from collections.abc import Mapping

from torch import nn
class CustomShapModel(nn.Module):
    """Wrap a model so that it yields a single output for SHAP.

    If the wrapped model returns a mapping of named outputs (for example a
    plain ``dict``), the entry stored under ``key`` is returned; any other
    output is returned unchanged.
    """

    def __init__(self, model, key='cls_output'):
        """Initialize required parameters for the SHAP wrapper.

        Args:
            model: Trained model used for SHAP calculation. When it is an
                ``nn.Module`` it is registered as a submodule, so calls such
                as ``.to()`` / ``.eval()`` on the wrapper propagate to it.
            key: Key selecting the entry to return when ``model`` produces a
                mapping of outputs.
        """
        super().__init__()
        self.model = model
        self.key = key

    def forward(self, x):
        """Pass ``x`` through the wrapped model and return its output.

        Args:
            x: Input forwarded unchanged to the wrapped model (presumably a
                tensor — whatever the wrapped model accepts).

        Returns:
            ``output[self.key]`` if the model returns a mapping, otherwise
            the model's output unchanged.

        Raises:
            KeyError: If the model returns a mapping that lacks ``self.key``.
        """
        output = self.model(x)
        # Any Mapping (dict, OrderedDict subclasses, read-only views, ...)
        # is treated as a collection of named outputs.
        if not isinstance(output, Mapping):
            return output
        try:
            return output[self.key]
        except KeyError:
            # Re-raise with the available keys so a misconfigured ``key`` is
            # easy to diagnose, while keeping the same exception type.
            raise KeyError(
                f'{self.key!r} not found in model output; '
                f'available keys: {list(output)}'
            ) from None