Demo 6: Training Hyperparameters
Regularization helps interpretability by making KANs sparser, but it may require some hyperparameter tuning. Let’s see how hyperparameters affect training.
Load KAN and create_dataset
from kan import KAN, create_dataset
import torch
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape
(torch.Size([1000, 2]), torch.Size([1000, 1]))
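As a quick sanity check on the target function, the sketch below evaluates \(f(x_1,x_2)=\exp(\sin(\pi x_1)+x_2^2)\) at a few hand-picked points using only the standard library. The helper name `target` is ours for illustration and is not part of `kan`; the sample points assume inputs roughly in \([-1,1]^2\).

```python
import math

def target(x1: float, x2: float) -> float:
    """Scalar version of the lambda f above: exp(sin(pi*x1) + x2**2)."""
    return math.exp(math.sin(math.pi * x1) + x2 ** 2)

# A few points inside [-1, 1]^2 (the sampling range is an assumption here)
for x1, x2 in [(0.0, 0.0), (0.5, 0.0), (0.0, 1.0), (-0.5, 0.0)]:
    print(f"f({x1:+.1f}, {x2:+.1f}) = {target(x1, x2):.4f}")
```

On this range the function varies between about \(e^{-1}\) and \(e^{2}\), which gives a sense of scale for the losses reported below.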
Default setup
# train the model
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.1);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 1.69e-01 | test loss: 1.50e-01 | reg: 5.01e+00 : 100%|██| 20/20 [00:12<00:00, 1.59it/s]
![../_images/API_6_training_hyperparameter_4_1.png](../_images/API_6_training_hyperparameter_4_1.png)
![../_images/API_6_training_hyperparameter_4_2.png](../_images/API_6_training_hyperparameter_4_2.png)
Parameter 1: \(\lambda\), overall penalty strength.
Previously \(\lambda=0.1\); now we try different values of \(\lambda\).
\(\lambda=0\)
# train the model
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.00);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 4.16e-03 | test loss: 5.00e-03 | reg: 1.24e+01 : 100%|██| 20/20 [00:10<00:00, 1.86it/s]
![../_images/API_6_training_hyperparameter_7_1.png](../_images/API_6_training_hyperparameter_7_1.png)
![../_images/API_6_training_hyperparameter_7_2.png](../_images/API_6_training_hyperparameter_7_2.png)
\(\lambda=0.1\), \(\lambda_{\rm ent}=10\)
# train the model
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.1, lamb_entropy=10.);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 6.01e-01 | test loss: 5.65e-01 | reg: 1.78e+01 : 100%|██| 20/20 [00:13<00:00, 1.51it/s]
![../_images/API_6_training_hyperparameter_9_1.png](../_images/API_6_training_hyperparameter_9_1.png)
![../_images/API_6_training_hyperparameter_9_2.png](../_images/API_6_training_hyperparameter_9_2.png)
\(\lambda=1\)
# train the model
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=1);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 1.09e+00 | test loss: 1.02e+00 | reg: 5.18e+00 : 100%|██| 20/20 [00:11<00:00, 1.67it/s]
![../_images/API_6_training_hyperparameter_11_1.png](../_images/API_6_training_hyperparameter_11_1.png)
![../_images/API_6_training_hyperparameter_11_2.png](../_images/API_6_training_hyperparameter_11_2.png)
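To compare the runs in this section at a glance, the following sketch tabulates the test losses copied from the training logs above. It only includes the three runs that leave `lamb_entropy` at its default; the run that also sets `lamb_entropy=10.` is left out so that only `lamb` varies. The dictionary name `reported` is ours.

```python
# Test losses copied from the training logs above (lamb_entropy left at its default)
reported = {
    0.0: 5.00e-03,
    0.1: 1.50e-01,
    1.0: 1.02e+00,
}

for lamb, test_loss in sorted(reported.items()):
    print(f"lamb={lamb:<4} test loss={test_loss:.2e}")

# Check whether the loss grows with lamb across these three runs
losses = [reported[lamb] for lamb in sorted(reported)]
is_monotone = all(a < b for a, b in zip(losses, losses[1:]))
print("monotone increase:", is_monotone)
```

The test loss alone does not say which run is most interpretable: a larger `lamb` trades fit accuracy for sparsity, so the plots should be read alongside these numbers.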
Parameter 2: (relative) penalty strength of entropy \(\lambda_{\rm ent}\).
The absolute magnitude is \(\lambda\lambda_{\rm ent}\). Previously we set \(\lambda=0.1\) and \(\lambda_{\rm ent}=10.0\). Below we fix \(\lambda=0.1\) and vary \(\lambda_{\rm ent}\).
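To make the role of the two coefficients concrete, here is the training objective written out as a sketch. The notation follows the regularized loss of the KAN paper as we understand it; \(\lambda_{1}\) (the coefficient of the L1 term) is our label for a quantity this demo does not vary.

```latex
\ell_{\rm total} = \ell_{\rm pred}
  + \lambda \sum_{l} \Big( \lambda_{1}\, |\Phi_l|_1 + \lambda_{\rm ent}\, S(\Phi_l) \Big),
\qquad
S(\Phi_l) = -\sum_{i,j} p_{ij} \log p_{ij},
\quad
p_{ij} = \frac{|\phi_{l,i,j}|_1}{|\Phi_l|_1}
```

The entropy term therefore enters the loss with coefficient \(\lambda\lambda_{\rm ent}\), which is the absolute magnitude mentioned above.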
\(\lambda_{\rm ent}=0.0\)
# train the model
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.1, lamb_entropy=0.0);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 8.90e-02 | test loss: 8.40e-02 | reg: 1.68e+00 : 100%|██| 20/20 [00:12<00:00, 1.65it/s]
![../_images/API_6_training_hyperparameter_14_1.png](../_images/API_6_training_hyperparameter_14_1.png)
![../_images/API_6_training_hyperparameter_14_2.png](../_images/API_6_training_hyperparameter_14_2.png)
\(\lambda_{\rm ent}=10.0\)
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.1, lamb_entropy=10.0);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 6.03e-01 | test loss: 5.67e-01 | reg: 1.77e+01 : 100%|██| 20/20 [00:10<00:00, 1.89it/s]
![../_images/API_6_training_hyperparameter_16_1.png](../_images/API_6_training_hyperparameter_16_1.png)
![../_images/API_6_training_hyperparameter_16_2.png](../_images/API_6_training_hyperparameter_16_2.png)
\(\lambda_{\rm ent}=100.0\)
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.1, lamb_entropy=100.0);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 1.60e+00 | test loss: 1.54e+00 | reg: 2.69e+02 : 100%|██| 20/20 [00:11<00:00, 1.67it/s]
![../_images/API_6_training_hyperparameter_18_1.png](../_images/API_6_training_hyperparameter_18_1.png)
![../_images/API_6_training_hyperparameter_18_2.png](../_images/API_6_training_hyperparameter_18_2.png)
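To illustrate what an entropy penalty rewards, the sketch below computes the L1 norm and the entropy of two toy sets of activation magnitudes with the same L1 norm. The numbers are made up for illustration, and the helper `l1_and_entropy` is a plain-Python stand-in, not the routine `kan` uses internally.

```python
import math

def l1_and_entropy(magnitudes):
    """Return (L1 norm, entropy) for a list of non-negative activation magnitudes."""
    total = sum(magnitudes)
    probs = [m / total for m in magnitudes]
    # 0 * log(0) is treated as 0, the usual convention for entropy
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    return total, entropy

dense = [1.0, 1.0, 1.0, 1.0]    # four equally active edges
sparse = [3.7, 0.1, 0.1, 0.1]   # same L1 norm, one dominant edge

print("dense :", l1_and_entropy(dense))
print("sparse:", l1_and_entropy(sparse))
```

Both sets have the same L1 norm, but the sparse one has much lower entropy, so an entropy penalty favors concentrating activity on a few edges where an L1 penalty alone would not distinguish the two.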
Parameter 3: Grid size \(G\).
Previously we set \(G=5\). Below we vary \(G\), keeping \(\lambda=0.01\) and \(\lambda_{\rm ent}=2\) fixed.
\(G=1\)
model = KAN(width=[2,5,1], grid=1, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=2.);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 1.41e-01 | test loss: 1.33e-01 | reg: 1.01e+01 : 100%|██| 20/20 [00:06<00:00, 2.95it/s]
![../_images/API_6_training_hyperparameter_21_1.png](../_images/API_6_training_hyperparameter_21_1.png)
![../_images/API_6_training_hyperparameter_21_2.png](../_images/API_6_training_hyperparameter_21_2.png)
\(G=3\)
model = KAN(width=[2,5,1], grid=3, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=2.);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 6.18e-02 | test loss: 5.66e-02 | reg: 5.93e+00 : 100%|██| 20/20 [00:11<00:00, 1.76it/s]
![../_images/API_6_training_hyperparameter_23_1.png](../_images/API_6_training_hyperparameter_23_1.png)
![../_images/API_6_training_hyperparameter_23_2.png](../_images/API_6_training_hyperparameter_23_2.png)
\(G=5\)
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=2.);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 7.47e-02 | test loss: 6.52e-02 | reg: 6.12e+00 : 100%|██| 20/20 [00:12<00:00, 1.58it/s]
![../_images/API_6_training_hyperparameter_25_1.png](../_images/API_6_training_hyperparameter_25_1.png)
![../_images/API_6_training_hyperparameter_25_2.png](../_images/API_6_training_hyperparameter_25_2.png)
\(G=10\)
model = KAN(width=[2,5,1], grid=10, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=2.);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 8.08e-02 | test loss: 7.24e-02 | reg: 5.89e+00 : 100%|██| 20/20 [00:13<00:00, 1.44it/s]
![../_images/API_6_training_hyperparameter_27_1.png](../_images/API_6_training_hyperparameter_27_1.png)
![../_images/API_6_training_hyperparameter_27_2.png](../_images/API_6_training_hyperparameter_27_2.png)
\(G=20\)
model = KAN(width=[2,5,1], grid=20, k=3, seed=0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=2.);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 5.14e-02 | test loss: 5.50e-02 | reg: 7.70e+00 : 100%|██| 20/20 [00:16<00:00, 1.23it/s]
![../_images/API_6_training_hyperparameter_29_1.png](../_images/API_6_training_hyperparameter_29_1.png)
![../_images/API_6_training_hyperparameter_29_2.png](../_images/API_6_training_hyperparameter_29_2.png)
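One way to read the grid sweep is in terms of model capacity. The sketch below counts spline coefficients for the `[2,5,1]` network at each grid size, assuming one activation per edge and \(G+k\) B-spline coefficients per activation (the standard count for \(G\) intervals and order \(k\)). The helper `spline_coeffs` is hypothetical, and it ignores any extra per-edge scale or base parameters the library may also train.

```python
def spline_coeffs(width, grid, k):
    """Count B-spline coefficients: one activation per edge, grid + k coefficients each."""
    edges = sum(n_in * n_out for n_in, n_out in zip(width, width[1:]))
    return edges * (grid + k)

for G in [1, 3, 5, 10, 20]:
    print(f"G={G:>2}: {spline_coeffs([2, 5, 1], G, k=3)} spline coefficients")
```

Under these assumptions the count grows linearly in \(G\), while the losses reported above do not improve monotonically with it, so a finer grid is not automatically better at a fixed number of training steps.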
Parameter 4: seed.
Previously we used seed = 0. Below we vary the seed.
\({\rm seed} = 1\)
model = KAN(width=[2,5,1], grid=5, k=3, seed=1, noise_scale_base=0.0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 5.58e-02 | test loss: 5.50e-02 | reg: 8.48e+00 : 100%|██| 20/20 [00:13<00:00, 1.50it/s]
![../_images/API_6_training_hyperparameter_32_1.png](../_images/API_6_training_hyperparameter_32_1.png)
![../_images/API_6_training_hyperparameter_32_2.png](../_images/API_6_training_hyperparameter_32_2.png)
\({\rm seed} = 42\)
model = KAN(width=[2,5,1], grid=5, k=3, seed=42, noise_scale_base=0.0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 1.43e-01 | test loss: 1.25e-01 | reg: 1.85e+01 : 100%|██| 20/20 [00:12<00:00, 1.65it/s]
![../_images/API_6_training_hyperparameter_34_1.png](../_images/API_6_training_hyperparameter_34_1.png)
![../_images/API_6_training_hyperparameter_34_2.png](../_images/API_6_training_hyperparameter_34_2.png)
\({\rm seed} = 2024\)
model = KAN(width=[2,5,1], grid=5, k=3, seed=2024, noise_scale_base=0.0)
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);
model.plot()
model.prune()
model.plot(mask=True)
train loss: 1.50e-01 | test loss: 1.39e-01 | reg: 2.37e+01 : 100%|██| 20/20 [00:12<00:00, 1.57it/s]
![../_images/API_6_training_hyperparameter_36_1.png](../_images/API_6_training_hyperparameter_36_1.png)
![../_images/API_6_training_hyperparameter_36_2.png](../_images/API_6_training_hyperparameter_36_2.png)
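To quantify how much the outcome depends on the seed, the sketch below summarizes the losses copied from the three training logs in this section. The variable names are ours, and three runs are too few for firm conclusions.

```python
import statistics

# (train loss, test loss) copied from the three training logs above, keyed by seed
runs = {
    1: (5.58e-02, 5.50e-02),
    42: (1.43e-01, 1.25e-01),
    2024: (1.50e-01, 1.39e-01),
}

test_losses = [test for _, test in runs.values()]
mean = statistics.mean(test_losses)
spread = statistics.stdev(test_losses)  # sample standard deviation
ratio = max(test_losses) / min(test_losses)

print(f"mean test loss = {mean:.3f}, std = {spread:.3f}")
print(f"max/min ratio  = {ratio:.2f}")
```

With all other settings held fixed, the test loss differs by more than a factor of two between seeds, so it can be worth trying a few seeds before concluding that a hyperparameter setting is good or bad.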