Interpretability 3: KAN Compiler

We have shown in many examples how to extract symbolic formulas from KANs. Now we consider the reverse task: compiling a symbolic formula into a KAN. This can be useful for several reasons. One use case: we have prior knowledge that is an approximate ground truth (e.g., empirical or constitutive laws), and we want to build this knowledge into a neural network and then fine-tune the network on real data.

from kan.compiler import kanpiler
from sympy import *
from kan.utils import create_dataset
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

input_variables = x, y = symbols('x y')
expr = exp(sin(pi*x) + y**2)

# compile the symbolic expression into a KAN
model = kanpiler(input_variables, expr).to(device)

# the same function as expr, written in torch, used to generate data
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,0]) + x[:,1]**2)
dataset = create_dataset(f, n_var=2, device=device)
model.get_act(dataset) # forward pass to collect activations for plotting

model.plot()
[figure: diagram of the compiled KAN for exp(sin(pi*x) + y**2)]

If you want more complicated formulas, you can load an equation from the Feynman dataset.

from kan.feynman import get_feynman_dataset
import matplotlib.pyplot as plt

problem_id = 36 # problem_id in 1-120
input_variables, expr, f, ranges = get_feynman_dataset(problem_id)
n_var = len(input_variables)
model = kanpiler(input_variables, expr)

dataset = create_dataset(f, n_var=n_var, ranges=ranges)
model.get_act(dataset)
#model.plot(in_vars=input_variables, out_vars=[expr], beta=10000, title='P{}'.format(problem_id))
model.plot(in_vars=input_variables, out_vars=[symbols('omega')], beta=10000)
#plt.savefig('./fig1.pdf', bbox_inches='tight', dpi=200)
[figure: compiled KAN for Feynman problem 36]

We can check that the model indeed achieves zero loss (near machine precision) on the data:

torch.mean((model(dataset['train_input'])-dataset['train_label'])**2)
tensor(1.5383e-15, grad_fn=<MeanBackward0>)

Suppose we have a dataset for which the symbolic formula is only an approximate ground truth, and we want to fine-tune the model on the real data. The compiled model has the symbolic front turned on and the spline front turned off, so only the affine parameters in the symbolic functions are trainable. Depending on how much expressive power you need, there are three options:

  • If you want to keep the symbolic functions and train only their affine parameters, nothing needs to be done.

  • If you want the functions themselves to be trainable, call model.perturb(). With mode='minimal', only the currently active functions become trainable and the currently dead functions stay dead; with mode='all' (the default), the currently dead functions are also allowed to become active. A minimal sketch follows this list.

  • If you think the ground truth is more complicated than the current network, first expand the network using expand_width and/or expand_depth, and then call model.perturb().
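
For example, a minimal sketch of the second option; the perturbation magnitude here is illustrative, not a prescribed value:

model.perturb(mag=0.1, mode='minimal') # splines on currently active edges become trainable
# dead edges stay dead; use mode='all' (the default) to also allow them to activate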

In the following, we walk through the most complicated case, where we expand the network first.

step 1: expand depth, adding an extra linear layer at the end

model.expand_depth()
model.get_act(dataset)
model.plot()
[figure: network after expand_depth]

step 2: add two addition nodes in layer 1.

model.expand_width(1, 2)
model.get_act(dataset)
model.plot()
[figure: network after adding two addition nodes in layer 1]

step 3: add two multiplication nodes in layer 2, with arities 2 and 3.

model.expand_width(2, 2, sum_bool=False, mult_arity=[2,3]) # sum_bool=False adds multiplication nodes instead of addition nodes
model.get_act(dataset)
model.plot()
[figure: network after adding two multiplication nodes in layer 2]

step 4: now we perturb all edges (mode='minimal' perturbs only the currently active edges; mode='all' perturbs all edges).

model.perturb(mag=0.1, mode='all')
model.get_act(dataset)
model.plot(metric='forward_n')
# purple means both symbolic front (red) and spline front (black) are active
[figure: plot colored by the forward_n metric]
model.plot(beta=1000)
[figure: plot with beta=1000]
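
With both fronts now trainable, the model can be fine-tuned on data. A minimal sketch, assuming the usual MultKAN training API; the optimizer and step count are illustrative:

model.fit(dataset, opt='LBFGS', steps=20) # train both the symbolic and spline fronts

In practice, dataset would be replaced by real measurements rather than samples of the approximate formula.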