API 4: Initialization
Initialization is the first step toward good training. Each activation function is initialized as \(\phi(x)={\rm scale\_base}*b(x) + {\rm scale\_sp}*{\rm spline}(x)\), with the following settings:
1. \(b(x)\) is the base function (default: 'silu'), which can be set with \({\rm base\_fun}\);
2. scale_sp is sampled from N(0, noise_scale^2);
3. scale_base is sampled from N(scale_base_mu, scale_base_sigma^2);
4. sparse initialization: if sparse_init = True, most scale_base and scale_sp are set to zero.
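The definition above can be sketched for a single scalar activation in plain Python. This is a minimal sketch under stated assumptions: the spline is passed in as an arbitrary callable, and `silu`, `sample_scales` and `phi` are hypothetical helpers written for this page, not part of the kan package. The sampling follows the distributions exactly as listed above.

```python
import math
import random


def silu(x: float) -> float:
    """Default base function b(x) = x * sigmoid(x), written to avoid exp overflow."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def sample_scales(noise_scale: float, scale_base_mu: float, scale_base_sigma: float,
                  rng: random.Random) -> tuple[float, float]:
    """Hypothetical helper: draw (scale_base, scale_sp) from the distributions listed above."""
    scale_base = rng.gauss(scale_base_mu, scale_base_sigma)  # N(scale_base_mu, scale_base_sigma^2)
    scale_sp = rng.gauss(0.0, noise_scale)                   # N(0, noise_scale^2)
    return scale_base, scale_sp


def phi(x: float, scale_base: float, scale_sp: float, spline, base_fun=silu) -> float:
    """phi(x) = scale_base * b(x) + scale_sp * spline(x)."""
    return scale_base * base_fun(x) + scale_sp * spline(x)


rng = random.Random(0)
sb, ssp = sample_scales(noise_scale=0.1, scale_base_mu=0.0, scale_base_sigma=1.0, rng=rng)
print(sb, ssp, phi(0.5, sb, ssp, spline=math.sin))  # math.sin stands in for a real spline
```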
Default setup
from kan import KAN, create_dataset
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = KAN(width=[2,5,1], grid=5, k=3, seed=0, device=device)
x = torch.normal(0,1,size=(100,2)).to(device)
model(x) # forward is needed to collect activations for plotting
model.plot()
cuda
checkpoint directory created: ./model
saving model version 0.0
data:image/s3,"s3://crabby-images/c6ed8/c6ed823e8978ec72c8411ee57b8969008ecd90dc" alt="../_images/API_4_initialization_3_1.png"
Case 1: Initialize all activation functions to be exactly linear. We need to set base_fun = 'identity' and noise_scale = 0., so that b(x) = x and scale_sp = 0.
model = KAN(width=[2,5,1], grid=5, k=3, seed=0, base_fun='identity', noise_scale=0., device=device)
x = torch.normal(0,1,size=(100,2)).to(device)
model(x) # forward is needed to collect activations for plotting
model.plot()
data:image/s3,"s3://crabby-images/04624/04624162abdb48c2b5ac8216724b61e8d664dc94" alt="../_images/API_4_initialization_5_1.png"
Case 2: Noisy spline initialization (not recommended, shown only for illustration). Set noise_scale to a large value.
model = KAN(width=[2,5,1], grid=5, k=3, seed=0, noise_scale=0.3, device=device)
x = torch.normal(0,1,size=(100,2)).to(device)
model(x) # forward is needed to collect activations for plotting
model.plot()
data:image/s3,"s3://crabby-images/b85e2/b85e200d1d1acc4f3335d1ffae1b9b1e19f110d7" alt="../_images/API_4_initialization_7_1.png"
model = KAN(width=[2,5,1], grid=5, k=3, seed=0, noise_scale=10., device=device)
x = torch.normal(0,1,size=(100,2)).to(device)
model(x) # forward is needed to collect activations for plotting
model.plot()
data:image/s3,"s3://crabby-images/38014/38014a02f6b68c5ec4dd7806caea84e2fc40f2b9" alt="../_images/API_4_initialization_8_1.png"
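Case 3: Set the mean and spread of scale_base with scale_base_mu and scale_base_sigma. Below, scale_base_mu=5 and scale_base_sigma=0, so every scale_base is initialized to the same value.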
model = KAN(width=[2,5,1], grid=5, k=3, seed=0, scale_base_mu=5, scale_base_sigma=0, device=device)
x = torch.normal(0,1,size=(100,2)).to(device)
model(x) # forward is needed to collect activations for plotting
model.plot()
data:image/s3,"s3://crabby-images/16e3a/16e3a9d5054d90324fbe26aba1a391bcdde63b8c" alt="../_images/API_4_initialization_10_1.png"
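Case 4: Sparse initialization. With sparse_init=True, most scale_base and scale_sp are set to zero.
model = KAN(width=[2,5,1], grid=5, k=3, seed=0, sparse_init=True, device=device)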
x = torch.normal(0,1,size=(100,2)).to(device)
model(x) # forward is needed to collect activations for plotting
model.plot()
data:image/s3,"s3://crabby-images/7fdd5/7fdd5ae71623c5d09119df95a0e0171d09750da0" alt="../_images/API_4_initialization_11_1.png"