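Interpretability 11: sparse initialization

This example builds the same [5,5,5,1] KAN twice, once with `sparse_init=False` and once with `sparse_init=True`, and plots both so the two initializations can be compared visually.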
# Baseline: a KAN with the default (non-sparse) initialization, plotted so it
# can be compared against the sparse_init=True model built in the next cell.
import torch  # used directly below; don't rely on kan's wildcard re-export

from kan import KAN

# Prefer the GPU when one is available; the chosen device is echoed as output.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Layer widths [5, 5, 5, 1]: 5 inputs, two hidden layers of width 5, 1 output.
model = KAN([5, 5, 5, 1], sparse_init=False, device=device)

# 100 random samples with 5 features, placed on the same device as the model.
x = torch.rand(100, 5).to(device)

# Run the model on x so activations are available for plotting
# (NOTE(review): model.plot() appears to need this call first — confirm in pykan docs).
model.get_act(x)
model.plot()
cuda
checkpoint directory created: ./model
saving model version 0.0
data:image/s3,"s3://crabby-images/989a2/989a296eec302d527f9f36ce9e83af79defc516a" alt="../_images/Interp_11_sparse_init_1_1.png"
# Same network as the previous cell, but built with sparse_init=True so its
# plot can be compared against the default-initialized model.
import torch  # used directly below; don't rely on kan's wildcard re-export

from kan import KAN

# Prefer the GPU when one is available; the chosen device is echoed as output.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Identical [5, 5, 5, 1] architecture; only the initialization flag differs.
# NOTE(review): the exact sparsity pattern sparse_init produces is defined
# inside pykan — see its KAN constructor docs.
model = KAN([5, 5, 5, 1], sparse_init=True, device=device)

# 100 random samples with 5 features, placed on the same device as the model.
x = torch.rand(100, 5).to(device)

# Run the model on x so activations are available for plotting
# (NOTE(review): model.plot() appears to need this call first — confirm in pykan docs).
model.get_act(x)
model.plot()
cuda
checkpoint directory created: ./model
saving model version 0.0
data:image/s3,"s3://crabby-images/65a6d/65a6d544b832afbe7bfdfcbd4f528c53016beaac" alt="../_images/Interp_11_sparse_init_2_1.png"