Example 6: Solving a Partial Differential Equation (PDE)

We aim to solve the 2D Poisson equation \(\nabla^2 f(x,y) = -2\pi^2 \sin(\pi x)\sin(\pi y)\) on \([-1,1]^2\), with the boundary condition \(f(-1,y)=f(1,y)=f(x,-1)=f(x,1)=0\). The ground-truth solution is \(f(x,y)=\sin(\pi x)\sin(\pi y)\).
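As a quick sanity check (a sketch added here, not part of the original example), sympy can confirm that this \(f\) satisfies both the PDE and the boundary conditions:

import sympy as sp

x, y = sp.symbols('x y')
f = sp.sin(sp.pi*x)*sp.sin(sp.pi*y)

# Laplacian of f should equal the source term -2*pi^2*sin(pi*x)*sin(pi*y)
lap = sp.diff(f, x, 2) + sp.diff(f, y, 2)
print(sp.simplify(lap + 2*sp.pi**2*f))  # expected: 0

# f should vanish on all four sides of [-1, 1]^2
print([sp.simplify(f.subs(v, b)) for v in (x, y) for b in (-1, 1)])  # expected: [0, 0, 0, 0]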

from kan import *
import torch
from torch import autograd
import matplotlib.pyplot as plt
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

dim = 2
np_i = 21 # number of interior points (along each dimension)
np_b = 21 # number of boundary points (along each dimension)
ranges = [-1, 1]

model = KAN(width=[2,2,1], grid=5, k=3, seed=1, device=device) # 2 inputs, 2 hidden nodes, 1 output; 5 grid intervals, cubic (k=3) splines

def batch_jacobian(func, x, create_graph=False):
    # x: (batch, length). Returns per-sample Jacobians of shape (batch, out_dim, length).
    # Summing over the batch lets autograd compute all per-sample Jacobians in one call,
    # since output sample i depends only on input sample i.
    def _func_sum(x):
        return func(x).sum(dim=0)
    return autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)
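A minimal usage check for this helper (a hypothetical example, not in the original): for the elementwise map f(x) = x**2, each per-sample Jacobian should be diag(2x).

f_test = lambda x: x**2
x_test = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
J = batch_jacobian(f_test, x_test) # shape (batch=2, out_dim=2, in_dim=2)
assert torch.allclose(torch.diagonal(J, dim1=1, dim2=2), 2*x_test)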

# define solution
sol_fun = lambda x: torch.sin(torch.pi*x[:,[0]])*torch.sin(torch.pi*x[:,[1]])
source_fun = lambda x: -2*torch.pi**2 * torch.sin(torch.pi*x[:,[0]])*torch.sin(torch.pi*x[:,[1]])

# interior
sampling_mode = 'random' # 'random' or 'mesh'

x_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)
y_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)
X, Y = torch.meshgrid(x_mesh, y_mesh, indexing="ij")
if sampling_mode == 'mesh':
    #mesh
    x_i = torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)
else:
    #random
    x_i = torch.rand((np_i**2,2))*2-1

x_i = x_i.to(device)

# boundary, 4 sides of the square
helper = lambda X, Y: torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)
xb1 = helper(X[0], Y[0])     # x = -1 side
xb2 = helper(X[-1], Y[0])    # x = +1 side
xb3 = helper(X[:,0], Y[:,0]) # y = -1 side
xb4 = helper(X[:,0], Y[:,-1]) # y = +1 side
x_b = torch.cat([xb1, xb2, xb3, xb4], dim=0)

x_b = x_b.to(device)
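A quick assertion (added as a sanity check, not in the original) that every collected point really lies on the boundary of \([-1,1]^2\):

# each boundary point must have at least one coordinate equal to +/-1
assert torch.all(x_b.abs().max(dim=1).values == 1.0)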

steps = 20
alpha = 0.01 # weight of the PDE residual loss relative to the boundary loss
log = 1 # print progress every `log` steps

def train():
    optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)

    pbar = tqdm(range(steps), desc='description', ncols=100)

    for _ in pbar:
        def closure():
            global pde_loss, bc_loss
            optimizer.zero_grad()
            # interior loss: PDE residual at the interior points
            sol_D1_fun = lambda x: batch_jacobian(model, x, create_graph=True)[:,0,:]
            sol_D1 = sol_D1_fun(x_i) # gradient of the KAN output, shape (batch, 2)
            sol_D2 = batch_jacobian(sol_D1_fun, x_i, create_graph=True) # Hessian, shape (batch, 2, 2)
            lap = torch.sum(torch.diagonal(sol_D2, dim1=1, dim2=2), dim=1, keepdim=True) # Laplacian = trace of the Hessian
            source = source_fun(x_i)
            pde_loss = torch.mean((lap - source)**2)

            # boundary loss: enforce the Dirichlet boundary condition
            bc_true = sol_fun(x_b)
            bc_pred = model(x_b)
            bc_loss = torch.mean((bc_pred-bc_true)**2)

            loss = alpha * pde_loss + bc_loss
            loss.backward()
            return loss

        if _ % 5 == 0 and _ < 50:
            model.update_grid_from_samples(x_i) # periodically adapt the spline grids to the current inputs

        optimizer.step(closure)
        sol = sol_fun(x_i)
        l2 = torch.mean((model(x_i) - sol)**2) # L2 error against the exact solution

        if _ % log == 0:
            pbar.set_description("pde loss: %.2e | bc loss: %.2e | l2: %.2e " % (pde_loss.cpu().detach().numpy(), bc_loss.cpu().detach().numpy(), l2.cpu().detach().numpy()))

train()
cuda
checkpoint directory created: ./model
saving model version 0.0
pde loss: 2.23e+00 | bc loss: 5.99e-03 | l2: 3.78e-03 : 100%|███████| 20/20 [00:22<00:00,  1.11s/it]
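Before interpreting the network, it can help to compare the trained KAN against the exact solution on the mesh. A minimal plotting sketch (assuming the X, Y, np_i, and sol_fun defined above):

with torch.no_grad():
    x_grid = torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0).to(device)
    pred = model(x_grid).reshape(np_i, np_i).cpu()
    true = sol_fun(x_grid).reshape(np_i, np_i).cpu()

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].contourf(X, Y, pred)
axes[0].set_title('KAN prediction')
axes[1].contourf(X, Y, true)
axes[1].set_title('exact solution')
plt.show()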

Plot the trained KAN

model.plot(beta=10)
[Image: diagram of the trained KAN]

Fix the first-layer activations to be linear functions and the second-layer activations to be sine functions (caveat: this step is quite sensitive to hyperparameters).

# first layer: fix all four activations to the linear function 'x'
model.fix_symbolic(0,0,0,'x')
model.fix_symbolic(0,0,1,'x')
model.fix_symbolic(0,1,0,'x')
model.fix_symbolic(0,1,1,'x')
r2 is 0.8357976675033569
r2 is not very high, please double check if you are choosing the correct symbolic function.
saving model version 0.1
r2 is 0.8300805687904358
r2 is not very high, please double check if you are choosing the correct symbolic function.
saving model version 0.2
r2 is 0.8376883268356323
r2 is not very high, please double check if you are choosing the correct symbolic function.
saving model version 0.3
r2 is 0.8372848629951477
r2 is not very high, please double check if you are choosing the correct symbolic function.
saving model version 0.4
tensor(0.8373)
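The second layer is fixed analogously: with width=[2,2,1] it has two activations, and 'sin' is one of pykan's built-in symbolic functions (output omitted here):

model.fix_symbolic(1,0,0,'sin')
model.fix_symbolic(1,1,0,'sin')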

After setting all activations to be symbolic, we train the model further (the affine parameters of each symbolic function are still trainable). The model can now reach machine precision!

train()
pde loss: 1.71e+01 | bc loss: 1.14e-02 | l2: 1.37e-01 :  50%|███▌   | 10/20 [00:11<00:11,  1.20s/it]
[Training was interrupted manually (KeyboardInterrupt); full traceback omitted.]

Print out the symbolic formula

formula = model.symbolic_formula()[0][0]
ex_round(formula,6)
\[\displaystyle - 0.5 \sin{\left(3.141592 x_{1} + 3.141593 x_{2} - 4.712389 \right)} + 0.5 \sin{\left(3.141593 x_{1} - 3.141592 x_{2} + 1.570797 \right)}\]
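Rounding the fitted constants to their exact values (\(\pi\), \(3\pi/2\), \(\pi/2\)), a short sympy check (added here, not part of the original output) confirms that this expression reduces to the ground truth \(\sin(\pi x_1)\sin(\pi x_2)\):

import sympy as sp

x1, x2 = sp.symbols('x_1 x_2')
expr = -sp.Rational(1,2)*sp.sin(sp.pi*x1 + sp.pi*x2 - 3*sp.pi/2) + sp.Rational(1,2)*sp.sin(sp.pi*x1 - sp.pi*x2 + sp.pi/2)
print(sp.simplify(expr - sp.sin(sp.pi*x1)*sp.sin(sp.pi*x2)))  # expected: 0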