Example 10: Relativistic Velocity Addition

In this example, we will symbolically regress \(f(u,v)=\frac{u+v}{1+uv}\). In relativity, this is the velocity-addition formula (in units where \(c=1\)), and we know the rapidity trick \(f(u,v)={\rm tanh}({\rm arctanh}\ u+{\rm arctanh}\ v)\). Can we rediscover the rapidity trick with a KAN?
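
This is just the hyperbolic-tangent addition formula: writing \(u={\rm tanh}\ a\) and \(v={\rm tanh}\ b\),

\[\tanh(a+b)=\frac{\tanh a+\tanh b}{1+\tanh a\,\tanh b}=\frac{u+v}{1+uv},\]

so a \([2,1,1]\) KAN can represent \(f\) exactly if both first-layer activations converge to \({\rm arctanh}\) and the second-layer activation converges to \({\rm tanh}\).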

Initialize model and create dataset

from kan import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# initialize a [2,1,1] KAN with grid size 10 and cubic splines (k=3)
model = KAN(width=[2,1,1], grid=10, k=3, device=device)

# create dataset
f = lambda x: (x[:,[0]]+x[:,[1]])/(1+x[:,[0]]*x[:,[1]])
dataset = create_dataset(f, n_var=2, ranges=[-0.9,0.9], device=device)
cuda
checkpoint directory created: ./model
saving model version 0.0
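
As a quick sanity check (this snippet is not part of the original notebook; it assumes the usual \(\texttt{create_dataset}\) output, a dict of \(\texttt{train_input}\), \(\texttt{train_label}\), \(\texttt{test_input}\), \(\texttt{test_label}\) tensors), we can inspect the dataset shapes:

# inspect the dataset dict returned by create_dataset (train/test inputs and labels)
print(dataset['train_input'].shape, dataset['train_label'].shape)
print(dataset['test_input'].shape, dataset['test_label'].shape)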

Train KAN and plot

model.fit(dataset, opt="LBFGS", steps=20);
| train_loss: 2.28e-03 | test_loss: 2.31e-03 | reg: 6.50e+00 | : 100%|█| 20/20 [00:03<00:00,  5.88it
saving model version 0.1
model.plot(beta=10)
../_images/Example_10_relativity-addition_6_0.png

From the plot, the two activation functions in the first layer look the same, supporting the hypothesis that both inputs pass through the same function. Let's now determine what this function is using \(\texttt{suggest_symbolic}\).

model.suggest_symbolic(0,1,0,weight_simple=0.0)
  function  fitting r2    r2 loss  complexity  complexity loss  total loss
0  arctanh    0.999992 -15.786788           4                4  -15.786788
1      tan    0.999825 -12.397871           3                3  -12.397871
2   arccos    0.998852  -9.753944           4                4   -9.753944
3   arcsin    0.998852  -9.753944           4                4   -9.753944
4     sqrt    0.982166  -5.808383           2                2   -5.808383
('arctanh',
 (<function kan.utils.<lambda>(x)>,
  <function kan.utils.<lambda>(x)>,
  4,
  <function kan.utils.<lambda>(x, y_th)>),
 0.999992311000824,
 4)

We can see that \({\rm arctanh}\) is at the top of the suggestion list! So we can fix both first-layer activation functions to \({\rm arctanh}\), retrain the model, and plot it.

model.fix_symbolic(0,0,0,'arctanh')
model.fix_symbolic(0,1,0,'arctanh')
r2 is 0.9999759197235107
saving model version 0.2
r2 is 0.999992311000824
saving model version 0.3
tensor(1.0000, device='cuda:0')
model.fit(dataset, opt="LBFGS", steps=20, update_grid=False);
| train_loss: 7.94e-04 | test_loss: 9.43e-04 | reg: 4.12e+00 | : 100%|█| 20/20 [00:04<00:00,  4.34it
saving model version 0.4
model.plot(beta=10)
../_images/Example_10_relativity-addition_12_0.png

Running \(\texttt{suggest_symbolic}\) on the second-layer activation, we will see that \({\rm tanh}\) is at the top of the suggestion list! So we can fix it to \({\rm tanh}\), retrain the model to near machine precision, plot it, and finally obtain the symbolic formula.

model.suggest_symbolic(1,0,0,weight_simple=0.)
   function  fitting r2    r2 loss  complexity  complexity loss  total loss
0      tanh    0.999998 -16.336284           3                3  -16.336284
1    arctan    0.999435 -10.764618           4                4  -10.764618
2       cos    0.995899  -7.926177           2                2   -7.926177
3       sin    0.995899  -7.926177           2                2   -7.926177
4  gaussian    0.994457  -7.492519           3                3   -7.492519
('tanh',
 (<function kan.utils.<lambda>(x)>,
  <function kan.utils.<lambda>(x)>,
  3,
  <function kan.utils.<lambda>(x, y_th)>),
 0.9999979138374329,
 3)
model.fix_symbolic(1,0,0,'tanh')
r2 is 0.9999979138374329
saving model version 0.5
tensor(1.0000, device='cuda:0')
model.fit(dataset, opt="Adam", lr=1e-3, steps=2000, update_grid=False, singularity_avoiding=True);
| train_loss: 1.97e-06 | test_loss: 2.06e-06 | reg: 0.00e+00 | : 100%|█| 2000/2000 [00:21<00:00, 93.
saving model version 0.6
model.plot()
../_images/Example_10_relativity-addition_17_0.png
formula = model.symbolic_formula()[0][0]
nsimplify(ex_round(formula, 4))
\[\displaystyle \tanh{\left(\operatorname{atanh}{\left(x_{1} \right)} + \operatorname{atanh}{\left(x_{2} \right)} \right)}\]
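
As a final check (not part of the original notebook), we can verify numerically with plain \(\texttt{torch}\) that the recovered formula matches the target function on the training range:

import torch

# sanity check: tanh(arctanh(x1) + arctanh(x2)) vs. (x1 + x2)/(1 + x1*x2) on [-0.9, 0.9]^2
x = torch.rand(1000, 2) * 1.8 - 0.9
pred = torch.tanh(torch.atanh(x[:, 0]) + torch.atanh(x[:, 1]))
target = (x[:, 0] + x[:, 1]) / (1 + x[:, 0] * x[:, 1])
print(torch.max(torch.abs(pred - target)))  # ~1e-7, i.e. float32 round-off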