API 2: Plotting
Initialize KAN and create dataset
from kan import *
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# create a KAN: 2D inputs, 1D output, and 5 hidden neurons. cubic spline (k=3), 3 grid intervals (grid=3).
model = KAN(width=[2,5,1], grid=3, k=3, seed=1, device=device)
# create dataset f(x,y) = exp(sin(pi*x)+y^2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2, device=device)
dataset['train_input'].shape, dataset['train_label'].shape
cuda
checkpoint directory created: ./model
saving model version 0.0
(torch.Size([1000, 2]), torch.Size([1000, 1]))
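As a quick sanity check on the target function, here is a minimal pure-Python sketch that evaluates f(x, y) = exp(sin(pi*x) + y^2) at a few hand-picked points. The helper name `target` is hypothetical and not part of the kan package. In the tensor version above, `x[:,[0]]` plays the role of x and `x[:,[1]]` plays the role of y.

```python
import math

def target(x: float, y: float) -> float:
    """Scalar version of f(x, y) = exp(sin(pi * x) + y^2)."""
    return math.exp(math.sin(math.pi * x) + y ** 2)

# Hand-picked points where the value is easy to verify by hand.
print(target(0.0, 0.0))  # sin(0) + 0 = 0, so exp(0)
print(target(0.5, 0.0))  # sin(pi/2) + 0 = 1, so exp(1)
print(target(0.0, 1.0))  # sin(0) + 1 = 1, so exp(1)
```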
Plot KAN at initialization
# plot KAN at initialization
model(dataset['train_input']);
model.plot(beta=100)
[Figure: ../_images/API_2_plotting_4_0.png]
# optionally add variable names and a title
model.plot(beta=100, in_vars=[r'$\alpha$', 'x'], out_vars=['y'], title='My KAN')
[Figure: ../_images/API_2_plotting_5_0.png]
Train KAN with sparsity regularization
# train the model
model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01);
| train_loss: 5.20e-02 | test_loss: 5.35e-02 | reg: 4.93e+00 | : 100%|█| 20/20 [00:03<00:00, 5.22it
saving model version 0.1
\(\beta\) controls the transparency of activations. A larger \(\beta\) makes more activation functions visible. We usually want to choose \(\beta\) so that only the important connections are visually prominent. Transparency is set to \(\tanh(\beta \phi)\), where \(\phi\) is the scale of the activation function (metric='forward_u'), the normalized scale (metric='forward_n'), or the feature attribution score (metric='backward'). By default, \(\beta=3\) and metric='backward'.
model.plot()
[Figure: ../_images/API_2_plotting_9_0.png]
model.plot(beta=100000)
[Figure: ../_images/API_2_plotting_10_0.png]
model.plot(beta=0.1)
[Figure: ../_images/API_2_plotting_11_0.png]
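To see why very large and very small values of \(\beta\) behave so differently, the following minimal sketch applies the \(\tanh(\beta \phi)\) rule to a few made-up scores. The function name `edge_alpha` and the score values are hypothetical illustrations; this is not the plotting code used by the library.

```python
import math

def edge_alpha(phi: float, beta: float = 3.0) -> float:
    """Opacity of one activation under the tanh(beta * phi) rule."""
    return math.tanh(beta * phi)

# Hypothetical scores: one important edge, one moderate, one negligible.
scores = [1.0, 0.05, 0.001]

for beta in (0.1, 3, 100, 100000):
    alphas = [round(edge_alpha(phi, beta), 3) for phi in scores]
    print(f"beta={beta:>6}: {alphas}")
```

With a tiny \(\beta\) every edge is nearly transparent, and with a huge \(\beta\) the tanh saturates so that even negligible edges are drawn fully opaque. An intermediate value keeps the ranking of edges visible.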
Plotting with different metrics: 'forward_n', 'forward_u', 'backward'
model.plot(metric='forward_n', beta=100)
[Figure: ../_images/API_2_plotting_13_0.png]
model.plot(metric='forward_u', beta=100)
[Figure: ../_images/API_2_plotting_14_0.png]
model.plot(metric='backward', beta=100)
[Figure: ../_images/API_2_plotting_15_0.png]
Remove insignificant neurons
model = model.prune()
model.plot()
saving model version 0.2
[Figure: ../_images/API_2_plotting_17_1.png]
Resize the figure using the "scale" parameter. The default is 0.5.
model.plot(scale=0.5)
[Figure: ../_images/API_2_plotting_19_0.png]
model.plot(scale=0.2)
[Figure: ../_images/API_2_plotting_20_0.png]
model.plot(scale=2.0)
[Figure: ../_images/API_2_plotting_21_0.png]
To show the sample distribution in addition to the line, set "sample=True".
model.plot(sample=True)
[Figure: ../_images/API_2_plotting_23_0.png]
The samples are more visible if we use a smaller number of samples
model.get_act(dataset['train_input'][:20])
model.plot(sample=True)
[Figure: ../_images/API_2_plotting_25_0.png]
If a function is set to be symbolic, it is drawn in red
model.fix_symbolic(0,1,0,'x^2')
r2 is 0.9992202520370483
saving model version 0.3
tensor(0.9992, device='cuda:0')
model.plot()
[Figure: ../_images/API_2_plotting_28_0.png]
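The "r2" value printed above indicates how closely the chosen symbolic form matches the learned activation. As a rough illustration, the sketch below computes a coefficient of determination for a least-squares fit of the form y ≈ c*x^2 + d on made-up data. The helper `r2_for_square_fit` and the sample values are hypothetical, and the fitting procedure inside the library may well differ (for example, by also fitting parameters inside the function argument).

```python
def r2_for_square_fit(xs, ys):
    """Least-squares fit ys ~ c * xs**2 + d; returns (c, d, r2)."""
    feats = [x * x for x in xs]
    n = len(xs)
    mean_f = sum(feats) / n
    mean_y = sum(ys) / n
    # Closed-form simple linear regression of ys on the feature x^2.
    cov = sum((f - mean_f) * (y - mean_y) for f, y in zip(feats, ys))
    var = sum((f - mean_f) ** 2 for f in feats)
    c = cov / var
    d = mean_y - c * mean_f
    # R^2 = 1 - (residual sum of squares) / (total sum of squares).
    ss_res = sum((y - (c * f + d)) ** 2 for f, y in zip(feats, ys))
    ss_tot = sum((y - mean_y) ** 2 for y in ys)
    return c, d, 1.0 - ss_res / ss_tot

# Hypothetical "learned activation": nearly quadratic with a small wiggle.
xs = [i / 10 - 1 for i in range(21)]  # grid on [-1, 1]
ys = [2.0 * x * x + 0.5 + 0.01 * (-1) ** i for i, x in enumerate(xs)]

c, d, r2 = r2_for_square_fit(xs, ys)
print(f"c={c:.3f}, d={d:.3f}, r2={r2:.4f}")
```

An r2 close to 1 suggests the symbolic form is a good replacement for the spline. A noticeably lower value suggests trying a different candidate function.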
If a function is set to be both symbolic and numeric (its output is the sum of the symbolic and spline parts), it shows up in purple
model.set_mode(0,1,0,mode='ns')
model.plot(beta=100)
[Figure: ../_images/API_2_plotting_31_0.png]
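The idea of an activation being numeric, symbolic, or both can be sketched with two on/off masks, one per part. This is a simplified illustration under stated assumptions: `edge_output`, the stand-in `numeric` function, and the mask logic are hypothetical and are not taken from the library's source.

```python
def edge_output(x, numeric, symbolic, mode):
    """Combine the spline (numeric) and symbolic parts of one activation.

    mode 'n'  -> numeric part only
    mode 's'  -> symbolic part only
    mode 'ns' -> sum of both parts
    """
    mask_n = 1.0 if "n" in mode else 0.0
    mask_s = 1.0 if "s" in mode else 0.0
    return mask_n * numeric(x) + mask_s * symbolic(x)

# Hypothetical stand-ins: a small linear term in place of a real spline,
# and x^2 as the symbolic function.
numeric = lambda x: 0.1 * x
symbolic = lambda x: x ** 2

for mode in ("n", "s", "ns"):
    print(mode, edge_output(2.0, numeric, symbolic, mode))
```

In 'ns' mode the spline can act as a learnable correction on top of the fixed symbolic form.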