.. _hello-kan:

Hello, KAN!
===========

Kolmogorov-Arnold representation theorem
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The Kolmogorov-Arnold representation theorem states that if :math:`f` is a
multivariate continuous function on a bounded domain, then it can be
written as a finite composition of continuous functions of a single
variable and the binary operation of addition. More specifically, for a
continuous :math:`f : [0,1]^n \to \mathbb{R}`,

.. math:: f(x) = f(x_1,\ldots,x_n)=\sum_{q=1}^{2n+1}\Phi_q\left(\sum_{p=1}^n \phi_{q,p}(x_p)\right)

where :math:`\phi_{q,p}:[0,1]\to\mathbb{R}` and
:math:`\Phi_q:\mathbb{R}\to\mathbb{R}`. In a sense, Kolmogorov and Arnold
showed that the only true multivariate function is addition, since every
other function can be written using univariate functions and addition.
However, this two-layer, width-:math:`(2n+1)` Kolmogorov-Arnold
representation may not be smooth, owing to its limited expressive power.
We augment its expressive power by generalizing it to arbitrary depths
and widths.
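
As a concrete instance of writing a multivariate function with only
univariate functions and addition, the product :math:`xy` equals
:math:`\frac{1}{4}(x+y)^2-\frac{1}{4}(x-y)^2`. The sketch below casts this
identity in the form of the sum above, using two outer terms rather than
the full :math:`2n+1`. The names ``phi``, ``Phi``, and ``ka_product`` are
illustrative only and are not part of any library.

.. code:: python

    import random

    # Inner univariate functions phi[q][p], applied to input coordinate p.
    phi = [
        [lambda x: x, lambda y: y],    # q = 0: inner sum is x + y
        [lambda x: x, lambda y: -y],   # q = 1: inner sum is x - y
    ]
    # Outer univariate functions Phi[q], applied to each inner sum.
    Phi = [
        lambda u: u * u / 4.0,
        lambda u: -u * u / 4.0,
    ]

    def ka_product(x, y):
        """Evaluate sum_q Phi_q(sum_p phi_{q,p}(x_p)) for the functions above."""
        inputs = (x, y)
        return sum(
            Phi[q](sum(phi[q][p](inputs[p]) for p in range(2)))
            for q in range(2)
        )

    # Compare against ordinary multiplication on random points in [0, 1]^2.
    random.seed(0)
    max_err = 0.0
    for _ in range(1000):
        x, y = random.random(), random.random()
        max_err = max(max_err, abs(ka_product(x, y) - x * y))
    print(max_err)
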

Kolmogorov-Arnold Network (KAN)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The Kolmogorov-Arnold representation can be written in matrix form

.. math:: f(x)={\bf \Phi}_{\rm out}\circ{\bf \Phi}_{\rm in}\circ {\bf x}

where

.. math:: {\bf \Phi}_{\rm in}= \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n}(\cdot) \\ \vdots & & \vdots \\ \phi_{2n+1,1}(\cdot) & \cdots & \phi_{2n+1,n}(\cdot) \end{pmatrix},\quad {\bf \Phi}_{\rm out}=\begin{pmatrix} \Phi_1(\cdot) & \cdots & \Phi_{2n+1}(\cdot)\end{pmatrix}

We notice that both :math:`{\bf \Phi}_{\rm in}` and
:math:`{\bf \Phi}_{\rm out}` are special cases of the following function
matrix :math:`{\bf \Phi}` (with :math:`n_{\rm in}` inputs and
:math:`n_{\rm out}` outputs), which we call a Kolmogorov-Arnold layer:

.. math:: {\bf \Phi}= \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n_{\rm in}}(\cdot) \\ \vdots & & \vdots \\ \phi_{n_{\rm out},1}(\cdot) & \cdots & \phi_{n_{\rm out},n_{\rm in}}(\cdot) \end{pmatrix}

:math:`{\bf \Phi}_{\rm in}` corresponds to
:math:`n_{\rm in}=n, n_{\rm out}=2n+1`, and :math:`{\bf \Phi}_{\rm out}`
corresponds to :math:`n_{\rm in}=2n+1, n_{\rm out}=1`.
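
To make the action of a function matrix concrete, here is a minimal
plain-Python sketch, assuming a layer is stored as a nested list of
callables with shape :math:`(n_{\rm out}, n_{\rm in})`. The helper name
``kan_layer`` and the particular functions are hypothetical and chosen
only for illustration; this is not how the ``kan`` package stores its
layers.

.. code:: python

    import math

    def kan_layer(phi, x):
        """Apply a function matrix phi of shape (n_out, n_in) to a vector x.

        Output j is sum_i phi[j][i](x[i]): every edge applies its own
        univariate function, and every output node only adds.
        """
        return [sum(f(xi) for f, xi in zip(row, x)) for row in phi]

    # A toy layer with n_in = 2 and n_out = 3 (functions chosen arbitrarily).
    phi = [
        [math.sin, math.cos],
        [math.exp, lambda t: t ** 2],
        [abs, math.tanh],
    ]

    y = kan_layer(phi, [0.0, 1.0])
    print(len(y))  # one output per row of the function matrix
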

After defining the layer, we can construct a Kolmogorov-Arnold network
simply by stacking layers! Let’s say we have :math:`L` layers, with the
:math:`l^{\rm th}` layer :math:`{\bf \Phi}_l` having shape
:math:`(n_{l+1}, n_{l})`. Then the whole network is

.. math:: {\rm KAN}({\bf x})={\bf \Phi}_{L-1}\circ\cdots \circ{\bf \Phi}_1\circ{\bf \Phi}_0\circ {\bf x}

In contrast, a Multi-Layer Perceptron interleaves linear layers
:math:`{\bf W}_l` with nonlinearities :math:`\sigma`:

.. math:: {\rm MLP}({\bf x})={\bf W}_{L-1}\circ\sigma\circ\cdots\circ {\bf W}_1\circ\sigma\circ {\bf W}_0\circ {\bf x}
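
The stacking of KAN layers can be sketched in a few lines of plain Python.
The example assumes each layer is a nested list of callables of shape
:math:`(n_{l+1}, n_l)`; ``kan_layer`` and ``kan`` are hypothetical helper
names, not library functions. The two layers are hand-picked so that the
composition equals :math:`\exp(\sin(\pi x)+y^2)`, the target function used
in the walkthrough below.

.. code:: python

    import math

    def kan_layer(phi, x):
        """Output j is the sum over inputs i of phi[j][i](x[i])."""
        return [sum(f(xi) for f, xi in zip(row, x)) for row in phi]

    def kan(layers, x):
        """Compose the layers: Phi_{L-1} o ... o Phi_1 o Phi_0 applied to x."""
        for phi in layers:
            x = kan_layer(phi, x)
        return x

    # Layer 0 has shape (1, 2); layer 1 has shape (1, 1).
    layers = [
        [[lambda t: math.sin(math.pi * t), lambda t: t ** 2]],
        [[math.exp]],
    ]

    out = kan(layers, [0.5, 0.5])[0]
    print(out)

Unlike the MLP, there is no separate weight matrix or fixed nonlinearity
here: all of the nonlinearity lives in the per-edge functions.
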

A KAN can be easily visualized. (1) A KAN is simply a stack of KAN layers.
(2) Each KAN layer can be visualized as a fully-connected layer, with a
1D function placed on each edge. Let’s see an example below.

Get started with KANs
~~~~~~~~~~~~~~~~~~~~~

Initialize KAN

.. code:: ipython3

    import torch
    from kan import *

    # create a KAN: 2D inputs, 1D output, and 5 hidden neurons,
    # with cubic splines (k=3) on 5 grid intervals (grid=5)
    model = KAN(width=[2,5,1], grid=5, k=3, seed=0)

Create dataset

.. code:: ipython3

    # create dataset f(x,y) = exp(sin(pi*x)+y^2)
    f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    dataset = create_dataset(f, n_var=2)
    dataset['train_input'].shape, dataset['train_label'].shape




.. parsed-literal::

    (torch.Size([1000, 2]), torch.Size([1000, 1]))
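
For readers who want to see what such a dataset contains, the following is
a rough plain-Python stand-in. It is a sketch under stated assumptions,
not the implementation of ``create_dataset``: the helper ``make_dataset``
is hypothetical, the sampling range :math:`[-1, 1]`, the test-set size,
and the ``test_*`` keys are assumptions, and nested lists take the place
of tensors. Only the ``train_input`` and ``train_label`` keys and their
shapes are taken from the output above.

.. code:: python

    import math
    import random

    def make_dataset(f, n_var, train_num=1000, test_num=1000,
                     low=-1.0, high=1.0, seed=0):
        """Hypothetical helper: sample inputs uniformly and label them with f."""
        rng = random.Random(seed)

        def sample(n):
            xs = [[rng.uniform(low, high) for _ in range(n_var)] for _ in range(n)]
            ys = [[f(x)] for x in xs]  # one label column per sample
            return xs, ys

        train_x, train_y = sample(train_num)
        test_x, test_y = sample(test_num)
        return {"train_input": train_x, "train_label": train_y,
                "test_input": test_x, "test_label": test_y}

    # Same target as above: f(x, y) = exp(sin(pi*x) + y^2)
    f = lambda x: math.exp(math.sin(math.pi * x[0]) + x[1] ** 2)
    data = make_dataset(f, n_var=2)
    shape = (len(data["train_input"]), len(data["train_input"][0]))
    print(shape)
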



Plot KAN at initialization

.. code:: ipython3

    # plot KAN at initialization; run a forward pass first so that
    # the activations to be drawn are available
    model(dataset['train_input']);
    model.plot(beta=100)



.. image:: intro_files/intro_15_0.png


Train KAN with sparsity regularization

.. code:: ipython3

    # train the model
    model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);


.. parsed-literal::

    train loss: 1.57e-01 | test loss: 1.31e-01 | reg: 2.05e+01 : 100%|██| 20/20 [00:18<00:00,  1.06it/s]
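
This section does not spell out what ``lamb`` and ``lamb_entropy``
control. As a rough illustration of what a sparsity regularizer of this
kind can look like, the sketch below combines an L1 term with an entropy
term over per-edge activation magnitudes. The function
``sparsity_penalty`` is hypothetical, and the roles given to the two
parameters (``lamb`` scaling the whole penalty, ``lamb_entropy`` weighting
the entropy term) are assumptions, not a description of the library's
code.

.. code:: python

    import math

    def sparsity_penalty(edge_magnitudes, lamb=0.01, lamb_entropy=10.0):
        """Illustrative penalty: lamb * (L1 + lamb_entropy * entropy).

        edge_magnitudes holds one non-negative number per edge, e.g. the
        mean absolute value of that edge's activation function.
        """
        total = sum(edge_magnitudes)  # L1 term
        if total == 0:
            return 0.0
        probs = [m / total for m in edge_magnitudes if m > 0]
        entropy = -sum(p * math.log(p) for p in probs)
        return lamb * (total + lamb_entropy * entropy)

    # Same total magnitude, spread over four edges or concentrated on one.
    dense = sparsity_penalty([1.0, 1.0, 1.0, 1.0])
    sparse = sparsity_penalty([4.0, 0.0, 0.0, 0.0])
    print(dense > sparse)

With the total magnitude held fixed, the entropy term is what separates
the two cases: it favors networks in which a few edges carry most of the
signal, which is what makes the later pruning step effective.
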


Plot trained KAN

.. code:: ipython3

    model.plot()



.. image:: intro_files/intro_19_0.png


Prune KAN and replot (keep the original shape)

.. code:: ipython3

    model.prune()
    model.plot(mask=True)



.. image:: intro_files/intro_21_0.png


Prune KAN and replot (get a smaller shape)

.. code:: ipython3

    model = model.prune()
    model(dataset['train_input'])
    model.plot()



.. image:: intro_files/intro_23_0.png


Continue training and replot

.. code:: ipython3

    model.train(dataset, opt="LBFGS", steps=50);


.. parsed-literal::

    train loss: 4.74e-03 | test loss: 4.80e-03 | reg: 2.98e+00 : 100%|██| 50/50 [00:07<00:00,  7.03it/s]


.. code:: ipython3

    model.plot()



.. image:: intro_files/intro_26_0.png


Automatically or manually set activation functions to be symbolic

.. code:: ipython3

    mode = "auto" # "manual"
    
    if mode == "manual":
        # manual mode
        model.fix_symbolic(0,0,0,'sin');
        model.fix_symbolic(0,1,0,'x^2');
        model.fix_symbolic(1,0,0,'exp');
    elif mode == "auto":
        # automatic mode
        lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
        model.auto_symbolic(lib=lib)


.. parsed-literal::

    fixing (0,0,0) with sin, r2=0.999987252534279
    fixing (0,1,0) with x^2, r2=0.9999996536741071
    fixing (1,0,0) with exp, r2=0.9999988529417926
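
The ``r2`` values above read as coefficients of determination between each
learned activation and its best-matching symbolic candidate. The sketch
below shows how such a score can rank candidates. It is only an
illustration: ``r_squared`` is a hypothetical helper, the "learned"
activation is a synthetic stand-in, and the sketch does not fit any
scaling or shift parameters, which a real symbolic-fitting routine would
presumably need to do.

.. code:: python

    import math

    def r_squared(y_true, y_pred):
        """Coefficient of determination: 1 - SS_res / SS_tot."""
        mean = sum(y_true) / len(y_true)
        ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
        ss_tot = sum((t - mean) ** 2 for t in y_true)
        return 1.0 - ss_res / ss_tot

    xs = [i / 50.0 - 1.0 for i in range(101)]  # grid on [-1, 1]
    # Stand-in for a learned activation: sin(pi*x) plus a small wiggle.
    learned = [math.sin(math.pi * x) + 0.001 * math.cos(7.0 * x) for x in xs]

    candidates = {
        "sin": lambda x: math.sin(math.pi * x),
        "x^2": lambda x: x ** 2,
        "tanh": math.tanh,
    }
    scores = {name: r_squared(learned, [g(x) for x in xs])
              for name, g in candidates.items()}
    best = max(scores, key=scores.get)
    print(best, scores[best])
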


Continue training to almost machine precision

.. code:: ipython3

    model.train(dataset, opt="LBFGS", steps=50);


.. parsed-literal::

    train loss: 2.02e-10 | test loss: 1.13e-10 | reg: 2.98e+00 : 100%|██| 50/50 [00:02<00:00, 22.59it/s]


Obtain the symbolic formula

.. code:: ipython3

    model.symbolic_formula()[0][0]




.. math::

    \displaystyle 1.0 e^{1.0 x_{2}^{2} + 1.0 \sin{\left(3.14 x_{1} \right)}}
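
As a quick sanity check, the printed formula can be compared with the
target :math:`\exp(\sin(\pi x)+y^2)` in plain Python. The sketch below
copies the displayed coefficients literally, including ``3.14``; the
function names and the evaluation grid on :math:`[-1,1]^2` are
assumptions made for illustration.

.. code:: python

    import math

    def target(x1, x2):
        """The function the dataset was generated from."""
        return math.exp(math.sin(math.pi * x1) + x2 ** 2)

    def printed_formula(x1, x2):
        """Coefficients copied literally from the displayed formula."""
        return 1.0 * math.exp(1.0 * x2 ** 2 + 1.0 * math.sin(3.14 * x1))

    grid = [i / 10.0 - 1.0 for i in range(21)]  # 21 points on [-1, 1]
    max_rel_err = max(
        abs(target(a, b) - printed_formula(a, b)) / target(a, b)
        for a in grid for b in grid
    )
    print(max_rel_err)

The relative error is small but far larger than the ``1e-10`` losses
reported after the final training run. This suggests, although it does not
prove, that the displayed coefficients are rounded for printing and that
the fitted value is closer to :math:`\pi` than to ``3.14``.
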