Interpretability 6: Test symmetries of a trained NN
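import torch
from kan import *
from kan.hypothesis import plot_tree
# Target: f(x) = (x0^2 + x1^2)^2 + (x2^2 + x3^2)^2
f = lambda x: (x[:,[0]]**2 + x[:,[1]]**2) ** 2 + (x[:,[2]]**2 + x[:,[3]]**2) ** 2
# 100 samples drawn uniformly from [-1, 1]^4
x = torch.rand(100,4) * 2 - 1
# Test the symmetries of f itself and plot the resulting tree
plot_tree(f, x)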
data:image/s3,"s3://crabby-images/661cc/661cc23a4591e239cf80a3029d287beba0b65eda" alt="../_images/Interp_6_test_symmetry_NN_1_0.png"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Sample train/test sets from the same f and fit a KAN with layer widths [4, 5, 5, 1]
dataset = create_dataset(f, n_var=4, device=device)
model = KAN(width=[4,5,5,1], seed=0, device=device)
model.fit(dataset, steps=100);
cuda
checkpoint directory created: ./model
saving model version 0.0
| train_loss: 1.58e-03 | test_loss: 4.79e-03 | reg: 2.38e+01 | : 100%|█| 100/100 [00:20<00:00, 4.93
saving model version 0.1
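# Now test the symmetries of the trained model. sym_th and sep_th are the tolerances for the
# symmetry and separability tests; they are relaxed here because the network only approximates f.
model.tree(sym_th=1e-2, sep_th=5e-1)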
data:image/s3,"s3://crabby-images/5d2b5/5d2b5df4b67805a67993af58989b2f5957ad059a" alt="../_images/Interp_6_test_symmetry_NN_3_0.png"