Interpretability 6: Test symmetries of trained NN

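import torch
from kan import *
from kan.hypothesis import plot_tree

# Target: a sum of two terms, one depending only on x0^2 + x1^2, the other only on x2^2 + x3^2
f = lambda x: (x[:,[0]]**2 + x[:,[1]]**2) ** 2 + (x[:,[2]]**2 + x[:,[3]]**2) ** 2
# 100 sample points drawn uniformly from [-1, 1]^4
x = torch.rand(100,4) * 2 - 1
# Plot the tree structure found for f on these samples
plot_tree(f, x)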
../_images/Interp_6_test_symmetry_NN_1_0.png
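The target function above depends on its inputs only through x0^2 + x1^2 and x2^2 + x3^2, so it should be unchanged by a rotation inside the (x0, x1) plane or the (x2, x3) plane, but not by a rotation that mixes the two groups. The following is a minimal, library-free sketch of that sanity check. The helper names `rotate_pair` and `max_rotation_gap` are hypothetical, introduced here for illustration, and this is not a description of what `plot_tree` computes internally.

```python
import math
import random

def f(x):
    """Target function from the example, evaluated on a single 4-vector x."""
    return (x[0]**2 + x[1]**2) ** 2 + (x[2]**2 + x[3]**2) ** 2

def rotate_pair(x, i, j, theta):
    """Return a copy of x with coordinates (i, j) rotated by angle theta."""
    y = list(x)
    c, s = math.cos(theta), math.sin(theta)
    y[i] = c * x[i] - s * x[j]
    y[j] = s * x[i] + c * x[j]
    return y

def max_rotation_gap(fn, i, j, n_samples=100, seed=0):
    """Largest |fn(rotated x) - fn(x)| over random points in [-1, 1]^4 and random angles."""
    rng = random.Random(seed)
    gap = 0.0
    for _ in range(n_samples):
        x = [rng.uniform(-1, 1) for _ in range(4)]
        theta = rng.uniform(0, 2 * math.pi)
        gap = max(gap, abs(fn(rotate_pair(x, i, j, theta)) - fn(x)))
    return gap

gap_01 = max_rotation_gap(f, 0, 1)  # rotation within the first group
gap_23 = max_rotation_gap(f, 2, 3)  # rotation within the second group
gap_02 = max_rotation_gap(f, 0, 2)  # rotation mixing the two groups
print(gap_01, gap_23, gap_02)
```

The first two gaps should be at the level of floating-point rounding, while the third should be clearly non-zero. This matches grouping the inputs as {x0, x1} and {x2, x3}.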
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

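# Sample a train/test dataset from f over its 4 input variables
dataset = create_dataset(f, n_var=4, device=device)
# KAN with 4 inputs, two hidden layers of width 5, and 1 output
model = KAN(width=[4,5,5,1], seed=0, device=device)
# Train for 100 steps; the trailing semicolon suppresses the notebook's return-value output
model.fit(dataset, steps=100);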
cuda
checkpoint directory created: ./model
saving model version 0.0
| train_loss: 1.58e-03 | test_loss: 4.79e-03 | reg: 2.38e+01 | : 100%|█| 100/100 [00:20<00:00,  4.93
saving model version 0.1
model.tree(sym_th=1e-2, sep_th=5e-1)
../_images/Interp_6_test_symmetry_NN_3_0.png
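For the target function in this example, the ground truth is known: it is a sum of one term in (x0, x1) and one term in (x2, x3), so every mixed second derivative across the two groups vanishes, while mixed derivatives within a group generally do not. The sketch below checks this with central finite differences in plain Python. The helper `mixed_partial`, the step size `h`, and the test point are assumptions made for illustration. It is an independent sanity check on f, not a reproduction of how `model.tree` or its `sym_th` and `sep_th` thresholds work.

```python
def f(x):
    """Target function from the example, evaluated on a single 4-vector x."""
    return (x[0]**2 + x[1]**2) ** 2 + (x[2]**2 + x[3]**2) ** 2

def mixed_partial(fn, x, i, j, h=1e-3):
    """Central finite-difference estimate of d^2 fn / (dx_i dx_j) at x, for i != j."""
    def shifted(di, dj):
        y = list(x)
        y[i] += di * h
        y[j] += dj * h
        return fn(y)
    return (shifted(1, 1) - shifted(1, -1) - shifted(-1, 1) + shifted(-1, -1)) / (4 * h * h)

point = [0.5, -0.3, 0.7, 0.2]  # arbitrary test point (assumption)

cross = mixed_partial(f, point, 0, 2)   # across groups: expected to be ~0
within = mixed_partial(f, point, 0, 1)  # within a group: analytically 8 * x0 * x1
print(cross, within)
```

At this point the within-group value should be close to 8 * 0.5 * (-0.3) = -1.2, and the cross-group value should be zero up to rounding error. The same helper can be pointed at any callable that maps a list of four floats to a float.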