Interpretability 9: Different plotting metrics

from kan import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

model = KAN(width=[2,5,1], device=device)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2, device=device)
model.fit(dataset, steps = 20, lamb=1e-3);
cuda
checkpoint directory created: ./model
saving model version 0.0
| train_loss: 1.48e-02 | test_loss: 1.53e-02 | reg: 7.01e+00 | : 100%|█| 20/20 [00:04<00:00,  4.64it
saving model version 0.1

Note: To plot the KAN diagram, there are also three options * forward_u: the “norm” of edge, normalized (output std/input std) * forward_n: the “norm” of edge, unnormalized (output std) * backward: the edge attribution score (default)

model.plot(metric='forward_u')
../_images/Interp_9_different_plotting_metrics_3_0.png
model.plot(metric='forward_n')
../_images/Interp_9_different_plotting_metrics_4_0.png
model.plot(metric='backward')
../_images/Interp_9_different_plotting_metrics_5_0.png