Interpretability 2: Advanced MultKAN
In the last tutorial, we introduced multiplication nodes to KANs, which make interpretation easier when the target function involves multiplication. By default, a multiplication node takes two inputs, but the user can make it take more through the mult_arity argument (default: mult_arity=2).
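To make the role of mult_arity concrete, here is a minimal pure-Python sketch of what one layer with multiplication nodes computes. It is an illustration, not pykan's implementation: the helper name mult_layer_outputs is hypothetical, and it assumes that the addition subnodes come first and that each multiplication node consumes the next mult_arity consecutive subnodes.

```python
from math import prod

def mult_layer_outputs(subnodes, n_sum, n_mult, mult_arity=2):
    """Combine subnode values into node values for one layer (illustrative only)."""
    expected = n_sum + n_mult * mult_arity
    if len(subnodes) != expected:
        raise ValueError(f"expected {expected} subnodes, got {len(subnodes)}")
    # Addition nodes: the first n_sum subnodes pass through unchanged.
    sums = list(subnodes[:n_sum])
    # Multiplication nodes: product of each consecutive group of mult_arity subnodes.
    rest = subnodes[n_sum:]
    mults = [prod(rest[i * mult_arity:(i + 1) * mult_arity]) for i in range(n_mult)]
    return sums + mults

# Hidden layer [3, 2] with the default arity of 2 has 3 + 2*2 = 7 subnodes.
print(mult_layer_outputs([1.0, 2.0, 3.0, 2.0, 5.0, 3.0, 4.0], n_sum=3, n_mult=2))
# With arity 3, the same layer needs 3 + 2*3 = 9 subnodes.
print(mult_layer_outputs([1.0, 2.0, 3.0, 2.0, 5.0, 1.0, 3.0, 4.0, 0.5], 3, 2, mult_arity=3))
```

Raising the arity leaves the number of nodes in the layer unchanged (still 3 + 2 = 5 here) and only increases the number of subnodes feeding each multiplication node.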
import torch
from kan import *
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = KAN(width=[2,[3,2],1], device=device)
x = torch.randn(100,2).to(device)
model(x)
model.plot()
cuda
checkpoint directory created: ./model
saving model version 0.0
mult_arity=3
model = KAN(width=[2,[3,2],1], mult_arity=3, device=device)
model(x)
model.plot()
checkpoint directory created: ./model
saving model version 0.0
mult_arity=4
model = KAN(width=[2,[3,2],1], mult_arity=4, device=device)
model(x)
model.plot()
checkpoint directory created: ./model
saving model version 0.0
You may want different multiplication nodes to take different numbers of inputs. This is also possible: pass mult_arity as a list of lists that specifies the arities in each layer, including the input layer, the hidden layer(s), and the output layer.
In the following example, the input and output layers have no multiplication nodes, which corresponds to empty lists. The hidden layer has two multiplication nodes, one with arity 2 and one with arity 3, so the middle list is [2,3].
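Before building such a model, it can help to check that the list of lists lines up with width. The sketch below is a hypothetical helper (subnodes_per_layer is not part of pykan) that assumes one inner list per layer and one arity per multiplication node, as described above, and returns the number of subnodes each layer would need.

```python
def subnodes_per_layer(width, mult_arity):
    """Hypothetical check: validate a list-of-lists mult_arity against width."""
    # An integer width n is read as n addition nodes and no multiplication nodes.
    width = [[w, 0] if isinstance(w, int) else list(w) for w in width]
    if len(mult_arity) != len(width):
        raise ValueError("mult_arity needs exactly one list per layer")
    counts = []
    for layer, ((n_sum, n_mult), arities) in enumerate(zip(width, mult_arity)):
        if len(arities) != n_mult:
            raise ValueError(
                f"layer {layer}: {n_mult} multiplication nodes but {len(arities)} arities"
            )
        # Each addition node has one subnode; each multiplication node has `arity` subnodes.
        counts.append(n_sum + sum(arities))
    return counts

print(subnodes_per_layer([2, [3, 2], 1], [[], [2, 3], []]))  # [2, 8, 1]
```

For the hidden layer, the three addition nodes plus the arity-2 and arity-3 multiplication nodes give 3 + 2 + 3 = 8 subnodes.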
model = KAN(width=[2,[3,2],1], mult_arity=[[],[2,3],[]], device=device)
model(x)
model.plot()
checkpoint directory created: ./model
saving model version 0.0
Make a deeper network
model = KAN(width=[2,[2,2],[1,3],[3,2],[1,1]], mult_arity=2, device=device)
model(x)
model.plot()
checkpoint directory created: ./model
saving model version 0.0
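For a deeper network, writing the per-layer arities by hand gets tedious. As a rough sketch, the hypothetical helper below (expand_mult_arity is not a pykan function) expands a single integer arity into the list-of-lists form for a given width, which is a convenient starting point when only a few nodes need a different arity. It assumes one arity per multiplication node; whether pykan handles the two forms identically internally is not covered here.

```python
def expand_mult_arity(width, mult_arity):
    """Hypothetical helper: spell out an integer arity as one list per layer."""
    expanded = []
    for w in width:
        # An integer width means a layer with no multiplication nodes.
        n_mult = 0 if isinstance(w, int) else w[1]
        expanded.append([mult_arity] * n_mult)
    return expanded

arities = expand_mult_arity([2, [2, 2], [1, 3], [3, 2], [1, 1]], 2)
print(arities)  # [[], [2, 2], [2, 2, 2], [2, 2], [2]]

# Example tweak: give the first multiplication node in the second hidden layer arity 3.
arities[2][0] = 3
print(arities)  # [[], [2, 2], [3, 2, 2], [2, 2], [2]]
```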