API 12: Checkpoint, save & load model
Whenever the KAN model is altered (e.g., by fit or prune), a new version is saved to the model's checkpoint folder (by default 'model'). The version number has the form 'a.b', where a is the round number (starting from zero and incremented by 1 each time model.rewind() is called) and b is the version number within that round.
The initialized model has version 0.0.
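As a reading aid, an 'a.b' label can be treated as a pair of integers (round, version within the round) rather than as a decimal number. The helpers below are a minimal sketch of that idea, not pykan code; `parse_version` and `format_version` are hypothetical names introduced only for illustration.

```python
def parse_version(label):
    """Split an 'a.b' label into (round, version_in_round) integers."""
    round_str, step_str = label.split(".")
    return int(round_str), int(step_str)


def format_version(round_id, step_id):
    """Build an 'a.b' label from its two integer parts."""
    return f"{round_id}.{step_id}"


# Hypothetical example: the version that follows '0.1' within the same round.
round_id, step_id = parse_version("0.1")
print(format_version(round_id, step_id + 1))  # prints 0.2
```

Handling the two parts as integers means that, under this reading, the label after '0.9' would be '0.10' rather than '1.0'.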
from kan import *
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2, device=device)
model = KAN(width=[2,5,1], grid=5, k=3, seed=1, auto_save=True, device=device)
model.get_act(dataset)
model.plot()
cuda
checkpoint directory created: ./model
saving model version 0.0
[Image: ../_images/API_12_checkpoint_save_load_model_3_1.png]
auto_save is on by default.
model.auto_save
True
After fitting, the version becomes 0.1
model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01);
model.plot()
| train_loss: 3.34e-02 | test_loss: 3.29e-02 | reg: 4.93e+00 | : 100%|█| 20/20 [00:03<00:00, 5.10it
saving model version 0.1
[Image: ../_images/API_12_checkpoint_save_load_model_7_2.png]
After pruning, the version becomes 0.2
model = model.prune()
model.plot()
saving model version 0.2
[Image: ../_images/API_12_checkpoint_save_load_model_9_1.png]
Suppose we want to revert to version 0.1: use model = model.rewind('0.1'). This starts a new round, so version 0.1 is renamed to version 1.1.
# revert to version 0.1 (if continuing)
model = model.rewind('0.1')
# revert to version 0.1 (if starting from scratch)
#model = KAN.loadckpt('./model/' + '0.1')
#model.get_act(dataset)
model.plot()
rewind to model version 0.1, renamed as 1.1
[Image: ../_images/API_12_checkpoint_save_load_model_11_1.png]
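For the from-scratch route, the string passed to KAN.loadckpt has to point inside the checkpoint directory. The sketch below assumes that the expected prefix is the folder name, a '/' separator, and the version label; `ckpt_prefix` is a hypothetical helper, not part of pykan.

```python
def ckpt_prefix(folder, version):
    """Join a checkpoint folder and a version label with exactly one '/'.

    Assumption: checkpoints are addressed as '<folder>/<version>'.
    """
    return folder.rstrip("/") + "/" + version


print(ckpt_prefix("./model", "0.1"))   # prints ./model/0.1
print(ckpt_prefix("./model/", "0.1"))  # prints ./model/0.1
print("./model" + "0.1")               # prints ./model0.1 (no separator)
```

Plain string concatenation without a separator yields './model0.1', which names a sibling of the folder rather than something inside it, so a helper like this guards against a missing or doubled slash.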
If we manipulate version 1.1 further, the model rolls forward to version 1.2.
model.fit(dataset, opt="LBFGS", steps=2);
model.plot()
| train_loss: 2.06e-02 | test_loss: 2.18e-02 | reg: 5.48e+00 | : 100%|█| 2/2 [00:00<00:00, 5.83it/s
saving model version 1.2
[Image: ../_images/API_12_checkpoint_save_load_model_13_2.png]
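The sequence of labels seen in this walkthrough (0.0, 0.1, 0.2, then 1.1 after the rewind, then 1.2) can be replayed with a small bookkeeping sketch. `VersionTracker` is a hypothetical toy class, not pykan's implementation; it assumes that a rewind increments the round by one and keeps the in-round number of the target version, which is what the logs above show for 0.1 becoming 1.1.

```python
class VersionTracker:
    """Toy bookkeeping for 'round.version' labels (hypothetical, not pykan)."""

    def __init__(self):
        self.round = 0
        self.step = 0
        self.saved = [self.label]  # the initialized model is version 0.0

    @property
    def label(self):
        return f"{self.round}.{self.step}"

    def alter(self):
        """Any change (fit, prune, ...) saves the next version in this round."""
        self.step += 1
        self.saved.append(self.label)
        return self.label

    def rewind(self, target):
        """Go back to a saved version and start a new round.

        Assumption: the in-round number of the target is kept.
        """
        if target not in self.saved:
            raise ValueError(f"unknown version: {target}")
        self.round += 1
        self.step = int(target.split(".")[1])
        self.saved.append(self.label)
        return self.label


tracker = VersionTracker()
history = [
    tracker.label,          # initialize
    tracker.alter(),        # fit
    tracker.alter(),        # prune
    tracker.rewind("0.1"),  # rewind starts round 1
    tracker.alter(),        # fit again
]
print(history)  # prints ['0.0', '0.1', '0.2', '1.1', '1.2']
```

Nothing is overwritten in this model: versions 0.1 and 0.2 stay in `saved` after the rewind, so earlier states remain available as rewind targets.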