API 12: Checkpoint, save & load model
Whenever the KAN (model) is altered (e.g., by fit, prune, …), a new version is saved to the checkpoint folder (by default './model'). Versions are labeled 'a.b', where a is the round number (starting from 0 and incremented by 1 each time model.rewind() is called) and b is the version number within that round.
The initialized model has version 0.0.
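Because a label 'a.b' packs two integers (round and version within the round), it is safer to compare labels as integer pairs rather than as decimals or plain strings: as floats, '0.10' and '0.1' would collide, and as strings '0.10' sorts before '0.9'. A minimal sketch, using hypothetical helpers parse_version and sort_versions that are not part of pykan:

```python
def parse_version(version: str) -> tuple[int, int]:
    """Split a version label 'a.b' into (round, version-in-round) integers."""
    round_str, step_str = version.split(".")
    return int(round_str), int(step_str)


def sort_versions(versions: list[str]) -> list[str]:
    """Order labels by round first, then by version within the round."""
    return sorted(versions, key=parse_version)


print(sort_versions(["1.2", "0.10", "0.9", "1.1", "0.0"]))
# ['0.0', '0.9', '0.10', '1.1', '1.2']
```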
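from kan import *
import torch
# use the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# toy target f(x, y) = exp(sin(pi*x) + y^2) and a dataset sampled from it
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2, device=device)
# auto_save=True writes every new version of the model to the checkpoint folder
model = KAN(width=[2,5,1], grid=5, k=3, seed=1, auto_save=True, device=device)
# run a forward pass on the data so the activations are available to plot()
model.get_act(dataset)
model.plot()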
cuda
checkpoint directory created: ./model
saving model version 0.0
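The log shows that a ./model folder is created and version 0.0 is saved into it. To see which versions exist on disk, one could scan that folder. This is a hypothetical sketch: versions_from_names and list_saved_versions are not pykan functions, and the file naming it relies on (each saved file starts with its 'a.b' label followed by '_') is an assumption to check against your own checkpoint folder.

```python
import re
from pathlib import Path

# ASSUMPTION: each file belonging to a saved version starts with its "a.b" label
# followed by "_" (e.g. "0.1_config.yml"); verify this in your own folder.
VERSION_PREFIX = re.compile(r"^(\d+)\.(\d+)(?=_|$)")


def versions_from_names(names) -> list[str]:
    """Collect the distinct 'a.b' labels found in file names, oldest first."""
    found = set()
    for name in names:
        match = VERSION_PREFIX.match(name)
        if match:
            found.add((int(match.group(1)), int(match.group(2))))
    return [f"{a}.{b}" for a, b in sorted(found)]


def list_saved_versions(ckpt_dir: str = "./model") -> list[str]:
    """List the versions that appear to be saved in a checkpoint folder."""
    return versions_from_names(p.name for p in Path(ckpt_dir).iterdir())
```

Files that do not start with a version label (for example a log file) are simply skipped.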
Auto-saving is on (auto_save=True is the default):
model.auto_save
True
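Conceptually, auto-save means that every call which alters the model is followed by storing a snapshot under a new version label. The toy class below is a hypothetical, in-memory illustration of that idea only (pykan writes its checkpoints to disk, and ToyVersionedModel is not part of it). The design point it shows is that each snapshot must be an independent copy; otherwise later in-place edits would silently change versions that were already saved.

```python
import copy


class ToyVersionedModel:
    """Hypothetical stand-in for a model whose changes are auto-saved."""

    def __init__(self, params, auto_save=True):
        self.params = params
        self.auto_save = auto_save
        self.round, self.step = 0, 0
        self.snapshots = {}
        self._save()  # the freshly initialized model is version 0.0

    def _save(self):
        if self.auto_save:
            # deepcopy so later in-place edits cannot alter a saved version
            self.snapshots[f"{self.round}.{self.step}"] = copy.deepcopy(self.params)

    def alter(self, fn):
        """Apply an in-place change (standing in for fit, prune, ...)."""
        fn(self.params)
        self.step += 1
        self._save()


model = ToyVersionedModel({"w": [1.0, 2.0]})
model.alter(lambda p: p["w"].append(3.0))
print(model.snapshots)
# {'0.0': {'w': [1.0, 2.0]}, '0.1': {'w': [1.0, 2.0, 3.0]}}
```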
After fitting, the version becomes 0.1:
model.fit(dataset, opt="LBFGS", steps=20, lamb=0.01);
model.plot()
| train_loss: 3.34e-02 | test_loss: 3.29e-02 | reg: 4.93e+00 | : 100%|█| 20/20 [00:03<00:00, 5.10it
saving model version 0.1
After pruning, the version becomes 0.2:
model = model.prune()
model.plot()
saving model version 0.2
To revert to version 0.1, use model = model.rewind('0.1'). This starts a new round, so the restored version 0.1 is renamed as version 1.1.
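Following the rule stated above (a rewind starts the next round but keeps the in-round version number of the restored checkpoint), the new label can be computed from the current round and the target label. A minimal sketch with a hypothetical rewind_label helper, not a pykan function:

```python
def rewind_label(current_round: int, target: str) -> str:
    """Label given to a restored version: next round, same in-round number."""
    _, target_step = target.split(".")
    return f"{current_round + 1}.{int(target_step)}"


print(rewind_label(0, "0.1"))  # 1.1
print(rewind_label(1, "0.2"))  # 2.2 (a second rewind starts round 2)
```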
# revert to version 0.1 (if continuing)
model = model.rewind('0.1')
# revert to version 0.1 (if starting from scratch: load it from the checkpoint folder)
#model = KAN.loadckpt('./model/' + '0.1')
#model.get_act(dataset)
model.plot()
rewind to model version 0.1, renamed as 1.1
Any further change to version 1.1 rolls the model forward to version 1.2:
model.fit(dataset, opt="LBFGS", steps=2);
model.plot()
| train_loss: 2.06e-02 | test_loss: 2.18e-02 | reg: 5.48e+00 | : 100%|█| 2/2 [00:00<00:00, 5.83it/s
saving model version 1.2
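Putting the rules together, the sequence of labels produced in this tutorial can be replayed with a small hypothetical tracker that does pure bookkeeping (no model, not part of pykan): each alteration increments b, and each rewind increments a while reusing b from the restored version.

```python
class VersionTracker:
    """Hypothetical bookkeeping of 'a.b' labels: a = round, b = version in round."""

    def __init__(self):
        self.round, self.step = 0, 0
        self.log = ["0.0"]  # the initialized model

    def alter(self):
        """Any change such as fit or prune moves to the next version."""
        self.step += 1
        self.log.append(f"{self.round}.{self.step}")

    def rewind(self, target: str):
        """Start a new round, continuing from the target's in-round number."""
        if target not in self.log:
            raise ValueError(f"unknown version {target}")
        self.round += 1
        self.step = int(target.split(".")[1])
        self.log.append(f"{self.round}.{self.step}")


tracker = VersionTracker()
tracker.alter()        # fit    -> 0.1
tracker.alter()        # prune  -> 0.2
tracker.rewind("0.1")  # rewind -> 1.1
tracker.alter()        # fit    -> 1.2
print(tracker.log)     # ['0.0', '0.1', '0.2', '1.1', '1.2']
```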