kan package

kan.KAN module

kan.MultKAN.KAN

alias of MultKAN

class kan.MultKAN.MultKAN(*args: Any, **kwargs: Any)

Bases: Module

KAN class

Attributes:

grid : int

the number of grid intervals

k : int

spline order

act_fun : a list of KANLayers

symbolic_fun : a list of Symbolic_KANLayer

depth : int

depth of KAN

width : list

number of neurons in each layer. Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 hidden layers of 5 neurons each. With multiplication nodes, [2,[5,3],[5,1],3] means that besides the [2,5,5,3] KAN, there are 3 multiplication nodes in layer 1 and 1 multiplication node in layer 2.

mult_arity : int, or list of int lists

multiplication arity for each multiplication node (the number of numbers to be multiplied)

grid : int

the number of grid intervals

k : int

the order of piecewise polynomial

base_fun : fun

residual function b(x). Each activation function is phi(x) = sb_scale * b(x) + sp_scale * spline(x).

symbolic_fun : a list of Symbolic_KANLayer

Symbolic_KANLayers

symbolic_enabled : bool

If False, the symbolic front is not computed (to save time). Default: True.

width_in : list

The number of input neurons for each layer

width_out : list

The number of output neurons for each layer

base_fun_name : str

The name of the base function b(x)

grid_eps : float

The parameter that interpolates between a uniform grid and an adaptive grid (based on sample quantiles)

node_bias : a list of 1D torch.float

node_scale : a list of 1D torch.float

subnode_bias : a list of 1D torch.float

subnode_scale : a list of 1D torch.float

symbolic_enabled : bool

when symbolic_enabled = False, the symbolic branch (symbolic_fun) will be ignored in computation (set to zero)

affine_trainable : bool

indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)

sp_trainable : bool

indicate whether the overall magnitude of splines is trainable

sb_trainable : bool

indicate whether the overall magnitude of the base function is trainable

save_act : bool

indicate whether intermediate activations are saved in forward pass

node_scores : None or list of 1D torch.float

node attribution score

edge_scores : None or list of 2D torch.float

edge attribution score

subnode_scores : None or list of 1D torch.float

subnode attribution score

cache_data : None or 2D torch.float

cached input data

acts : None or a list of 2D torch.float

activations on nodes

auto_save : bool

indicate whether to automatically save a checkpoint once the model is modified

state_id : int

the state of the model (used to save checkpoint)

ckpt_path : str

the folder to store checkpoints

round : int

the number of times rewind() has been called

device : str

__init__(width=None, grid=3, k=3, mult_arity=2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu')

initialize a KAN model

Args:

width : list of int

Without multiplication nodes: \([n_0, n_1, .., n_{L-1}]\) specifies the number of neurons in each layer (including inputs/outputs). With multiplication nodes: \([[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]\) specifies the number of addition/multiplication nodes in each layer (including inputs/outputs).

grid : int

number of grid intervals. Default: 3.

k : int

order of piecewise polynomial. Default: 3.

mult_arity : int, or list of int lists

multiplication arity for each multiplication node (the number of numbers to be multiplied)

noise_scale : float

initial injected noise to spline.

base_fun : str

the residual function b(x). Default: 'silu'

symbolic_enabled : bool

compute (True) or skip (False) symbolic computations (for efficiency). Default: True.

affine_trainable : bool

whether affine parameters are updated. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias

grid_eps : float

When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.

grid_range : list/np.array of shape (2,)

setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default update_grid=True).

sp_trainable : bool

If true, scale_sp is trainable. Default: True.

sb_trainable : bool

If true, scale_base is trainable. Default: True.

device : str

device

seed : int

random seed

save_act : bool

indicate whether intermediate activations are saved in forward pass

sparse_init : bool

sparse initialization (True) or normal dense initialization. Default: False.

auto_save : bool

indicate whether to automatically save a checkpoint once the model is modified

state_id : int

the state of the model (used to save checkpoint)

ckpt_path : str

the folder to store checkpoints. Default: './model'

round : int

the number of times rewind() has been called

device : str

Returns:

self

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
checkpoint directory created: ./model
saving model version 0.0
attribute(l=None, i=None, out_score=None, plot=True)

get attribution scores

Args:

l : None or int

layer index

i : None or int

neuron index

out_score : None or 1D torch.float

specify output scores

plot : bool

when plot = True, display the bar plot

Returns:

attribution scores

Example

>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.attribute()
>>> model.feature_score
auto_swap()

automatically swap neurons such that connection costs are minimized

auto_swap_l(l)
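
A minimal usage sketch for auto_swap (assuming a model that has already been fitted, so that the swap has something meaningful to optimize; the call reorders neurons in place):

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.auto_swap()
>>> model.plot()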
auto_symbolic(a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple=0.8, r2_threshold=0.0)

automatic symbolic regression for all edges

Args:

a_range : tuple

search range of a

b_range : tuple

search range of b

lib : list of str

library of candidate symbolic functions

verbose : int

larger values print more information

weight_simple : float

a weight that prioritizes simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity

r2_threshold : float

If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold

Returns:

None

Example

>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.auto_symbolic()
checkout(model_id)

check out an old version

Args:

model_id : str

in format ‘{a}.{b}’ where a is the round number, b is the version number in that round

Returns:

MultKAN

Example

Same use as rewind, although checkout doesn’t change states
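
A minimal sketch (the model below is assumed to have auto-saved versions 0.0 and 0.1, the latter produced by fit()):

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20);
>>> model = model.checkout('0.0')  # load the freshly initialized version without starting a new round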

property connection_cost
copy()

deepcopy

Args:

path : str

the path where checkpoints are saved

Returns:

MultKAN

Example

>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> model2 = model.copy()
>>> model2.act_fun[0].coef.data *= 2
>>> print(model2.act_fun[0].coef.data)
>>> print(model.act_fun[0].coef.data)
disable_symbolic_in_fit(lamb)

during fitting, disable the symbolic branch if either condition holds: lamb = 0, or none of the symbolic functions is active

evaluate(dataset)
expand_depth()

expand network depth by adding an identity layer at the end. For usage, please refer to the tutorial interp_3_KAN_compiler.ipynb.

Args:

None (the method takes no arguments)

Returns:

None

expand_width(layer_id, n_added_nodes, sum_bool=True, mult_arity=2)

expand network width. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.

Args:

layer_id : int

layer index

n_added_nodes : int

the number of added nodes

sum_bool : bool

if sum_bool == True, added nodes are addition nodes; otherwise multiplication nodes

mult_arity : int

multiplication arity (the number of numbers to be multiplied)

Returns:

None

feature_interaction(l, neuron_th=0.01, feature_th=0.01)

get feature interaction

Args:

l : int

layer index

neuron_th : float

threshold to determine whether a neuron is active

feature_th : float

threshold to determine whether a feature is active

Returns:

dictionary

Example

>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.attribute()
>>> model.feature_interaction(1)
property feature_score

attribution scores for inputs

fit(dataset, opt='LBFGS', steps=100, log=1, lamb=0.0, lamb_l1=1.0, lamb_entropy=2.0, lamb_coef=0.0, lamb_coefdiff=0.0, update_grid=True, grid_update_num=10, loss_fn=None, lr=1.0, start_grid_update_step=-1, stop_grid_update_step=50, batch=-1, metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000.0, reg_metric='edge_forward_spline_n', display_metrics=None)

training

Args:

dataset : dict

contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']

opt : str

"LBFGS" or "Adam"

steps : int

training steps

log : int

logging frequency

lamb : float

overall penalty strength

lamb_l1 : float

l1 penalty strength

lamb_entropy : float

entropy penalty strength

lamb_coef : float

coefficient magnitude penalty strength

lamb_coefdiff : float

penalty strength on the difference of nearby coefficients (smoothness)

update_grid : bool

If True, update grid regularly before stop_grid_update_step

grid_update_num : int

the number of grid updates before stop_grid_update_step

start_grid_update_step : int

no grid updates before this training step

stop_grid_update_step : int

no grid updates after this training step

loss_fn : function

loss function

lr : float

learning rate

batch : int

batch size; if -1, use the full batch.

save_fig_freq : int

save a figure every (save_fig_freq) steps

singularity_avoiding : bool

indicate whether to avoid singularity for the symbolic part

y_th : float

singularity threshold (anything above the threshold is considered singular and is softened in some ways)

reg_metric : str

regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}

metrics : a list of metrics (as functions)

the metrics to be computed during training

display_metrics : a list of functions

the metrics to be displayed in the tqdm progress bar

Returns:

results : dict

results['train_loss'], 1D array of training losses (RMSE); results['test_loss'], 1D array of test losses (RMSE); results['reg'], 1D array of regularization; plus any other metrics specified in metrics

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
# Most examples in the tutorials involve the fit() method. Please check them for usage.
fix_symbolic(l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True, random=False, log_history=True)

set (l,i,j) activation to be symbolic (specified by fun_name)

Args:

l : int

layer index

i : int

input neuron index

j : int

output neuron index

fun_name : str

function name

fit_params_bool : bool

obtaining affine parameters through fitting (True) or setting default values (False)

a_range : tuple

sweeping range of a

b_range : tuple

sweeping range of b

verbose : bool

If True, more information is printed.

random : bool

initialize affine parameters randomly or as [1,0,1,0]

log_history : bool

indicate whether to log history when the function is called

Returns:

None or r2 (coefficient of determination)

Example 1

>>> # when fit_params_bool = False
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
>>> print(model.act_fun[0].mask.reshape(2,5))
>>> print(model.symbolic_fun[0].mask.reshape(2,5))

Example 2

>>> # when fit_params_bool = True
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # obtain activations (otherwise model does not have attributes acts)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True)
>>> print(model.act_fun[0].mask.reshape(2,5))
>>> print(model.symbolic_fun[0].mask.reshape(2,5))
forward(x, singularity_avoiding=False, y_th=10.0)

forward pass

Args:

x : 2D torch.tensor

inputs

singularity_avoiding : bool

whether to avoid singularity for the symbolic branch

y_th : float

the threshold for singularity

Returns:

the output of the model : 2D torch.float

Example1

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> x = torch.rand(100,2)
>>> model(x).shape

Example2

>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> x = torch.tensor([[1],[-0.01]])
>>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False)
>>> print(model(x))
>>> print(model(x, singularity_avoiding=True))
>>> print(model(x, singularity_avoiding=True, y_th=1.))
get_act(x=None)

collect intermediate activations
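
A minimal sketch: after calling get_act, intermediate quantities such as model.acts are populated.

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> x = torch.rand(100,2)
>>> model.get_act(x)
>>> len(model.acts)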

get_fun(l, i, j)

get function (l,i,j)

get_params()

Get parameters

get_range(l, i, j, verbose=True)

Get the input range and output range of the (l,i,j) activation

Args:

l : int

layer index

i : int

input neuron index

j : int

output neuron index

Returns:

x_min : float

minimum of input

x_max : float

maximum of input

y_min : float

minimum of output

y_max : float

maximum of output

Example

>>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # do a forward pass to obtain model.acts
>>> model.get_range(0,0,0)
get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)

Get regularization. This wrapper may seem unnecessary, but a subclass may want to override get_reg while leaving reg unchanged.

history(k='all')

get history
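
A minimal sketch (assuming auto_save is on, so that calls such as fit() have been logged as versions):

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20);
>>> model.history()  # show the recorded versions and the methods that created them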

initialize_from_another_model(another_model, x)

initialize from another model with the same width, whose 'grid' parameter can be different. Note this is equivalent to refine() when we don't need to keep another_model.

Args:

another_model : MultKAN x : 2D torch.float

Returns:

self

Example

>>> from kan import *
>>> model1 = KAN(width=[2,5,1], grid=3)
>>> model2 = KAN(width=[2,5,1], grid=10)
>>> x = torch.rand(100,2)
>>> model2.initialize_from_another_model(model1, x)
initialize_grid_from_another_model(model, x)

initialize grid from another model

Args:

model : MultKAN

parent model

x : 2D torch.tensor

inputs

Returns:

None

Example

>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> print(model.act_fun[0].grid)
>>> x = torch.linspace(-10,10,steps=101)[:,None]
>>> model2 = KAN(width=[1,1], grid=10, k=3, seed=0)
>>> model2.initialize_grid_from_another_model(model, x)
>>> print(model2.act_fun[0].grid)
static loadckpt(path='model')

load checkpoint from path

Args:

path : str

the path where checkpoints are saved

Returns:

MultKAN

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.saveckpt('./mark')
>>> KAN.loadckpt('./mark')
log_history(method_name)
module(start_layer, chain)

specify network modules

Args:

start_layer : int

the earliest layer of the module

chain : str

specify neurons in the module

Returns:

None

property n_edge

the number of active edges

property n_mult

The number of multiplication nodes for each layer

property n_sum

The number of addition nodes for each layer

node_attribute()
perturb(mag=1.0, mode='non-intrusive')

perturb a network. For usage, please refer to the tutorial interp_3_KAN_compiler.ipynb.

Args:

mag : float

perturbation magnitude

mode : str

perturbation mode, choices = {'non-intrusive', 'all', 'minimal'}

Returns:

None
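
A minimal sketch (perturb modifies the model in place; the mode shown here is illustrative):

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.perturb(mag=0.5, mode='all')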

plot(folder='./figures', beta=3, metric='backward', scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None, varscale=1.0)

plot KAN

Args:

folder : str

the folder to store pngs

beta : float

positive number. control the transparency of each activation. transparency = tanh(beta*l1).

mask : bool

If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions.

mode : str

"supervised" or "unsupervised". If "supervised", l1 is measured by absolute value (not subtracting mean); if "unsupervised", l1 is measured by standard deviation (subtracting mean).

scale : float

control the size of the diagram

in_vars : None or list of str

the name(s) of input variables

out_vars : None or list of str

the name(s) of output variables

title : None or str

title

varscale : float

the size of input variables

Returns:

Figure

Example

>>> # see more interactive examples in demos
>>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # do a forward pass to obtain model.acts
>>> model.plot()
prune(node_th, edge_th)

prune the network (both nodes and edges)

Args:

node_th : float

if the attribution score of a node is below node_th, it is considered dead and will be set to zero.

edge_th : float

if the attribution score of an edge is below edge_th, it is considered dead and will be set to zero.

Returns:

pruned network : MultKAN

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune()
>>> model.plot()
prune_edge(threshold=0.03, log_history=True)

pruning edges

Args:

threshold : float

if the attribution score of an edge is below the threshold, it is considered dead and will be set to zero.

Returns:

pruned network : MultKAN

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune_edge()
>>> model.plot()
prune_input(threshold=0.01, active_inputs=None, log_history=True)

prune inputs

Args:

threshold : float

if the attribution score of an input feature is below threshold, it is considered irrelevant.

active_inputs : None or list

if a list is passed, the manual mode will disregard attribution score and prune as instructed.

Returns:

pruned network : MultKAN

Example1

>>> # automatic
>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
>>> model = model.prune_input()
>>> model.plot()

Example2

>>> # manual (specify active_inputs explicitly)
>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
>>> model = model.prune_input(active_inputs=[0,1])
>>> model.plot()
prune_node(threshold=0.01, mode='auto', active_neurons_id=None, log_history=True)

pruning nodes

Args:

threshold : float

if the attribution score of a neuron is below the threshold, it is considered dead and will be removed

mode : str

‘auto’ or ‘manual’. with ‘auto’, nodes are automatically pruned using threshold. with ‘manual’, active_neurons_id should be passed in.

Returns:

pruned network : MultKAN

Example

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune_node()
>>> model.plot()
refine(new_grid)

grid refinement

Args:

new_grid : int

the number of grid intervals after refinement

Returns:

a refined model : MultKAN

Example

>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> print(model.grid)
>>> x = torch.rand(100,2)
>>> model.get_act(x)
>>> model = model.refine(10)
>>> print(model.grid)
checkpoint directory created: ./model
saving model version 0.0
5
saving model version 0.1
10
reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)

Get regularization

Args:

reg_metric : str

the regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}

lamb_l1 : float

l1 penalty strength

lamb_entropy : float

entropy penalty strength

lamb_coef : float

coefficient penalty strength

lamb_coefdiff : float

coefficient smoothness penalty strength

Returns:

reg_ : torch.float

Example

>>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.rand(100,2)
>>> model.get_act(x)
>>> model.reg('edge_forward_spline_n', 1.0, 2.0, 1.0, 1.0)
remove_edge(l, i, j, log_history=True)

remove activation phi(l,i,j) (set its mask to zero)

remove_node(l, i, mode='all', log_history=True)

remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero)
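
A minimal sketch of manual surgery with these two helpers (the indices are illustrative):

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.remove_edge(0, 0, 1)   # mask out activation phi(0,0,1)
>>> model.remove_node(1, 2)      # mask out neuron (1,2) together with its incoming/outgoing edges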

rewind(model_id)

rewind to an old version

Args:

model_id : str

in format ‘{a}.{b}’ where a is the round number, b is the version number in that round

Returns:

MultKAN

Example

Please refer to the tutorial 'API 12: Checkpoint, save & load model'.
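
A minimal sketch (assuming fit() has auto-saved version 0.1; rewinding starts a new round, so subsequent versions are numbered 1.x):

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20);
>>> model = model.rewind('0.1')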

saveckpt(path='model')

save the current model to files (configuration file and state file)

Args:

path : str

the path where checkpoints are saved

Returns:

None

Example

>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.saveckpt('./mark')
# There will be three files appearing in the current folder: mark_cache_data, mark_config.yml, mark_state
set_mode(l, i, j, mode, mask_n=None)
speed(compile=False)

turn on KAN’s speed mode
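
A minimal sketch (speed mode skips bookkeeping such as saving activations, so it is intended for inference rather than for plotting or pruning):

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.speed()
>>> x = torch.rand(100,2)
>>> model(x).shape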

suggest_symbolic(l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True, r2_loss_fun=<function MultKAN.<lambda>>, c_loss_fun=<function MultKAN.<lambda>>, weight_simple=0.8)

suggest symbolic function

Args:

l : int

layer index

i : int

input neuron index

j : int

output neuron index

a_range : tuple

search range of a

b_range : tuple

search range of b

lib : list of str

library of candidate symbolic functions

topk : int

the number of top functions displayed

verbose : bool

if verbose = True, print more information

r2_loss_fun : function

function : r2 -> "bits"

c_loss_fun : function

function : c -> "bits"

weight_simple : float

the simplicity weight: the higher, the more simplicity is preferred over performance

Returns:

best_name (str), best_fun (function), best_r2 (float), best_c (float)

Example

>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.suggest_symbolic(0,1,0)
swap(l, i1, i2, log_history=True)
symbolic_formula(var=None, normalizer=None, output_normalizer=None)

get symbolic formula

Args:

var : None or a list of sympy expressions

input variables

normalizer : [mean, std]

output_normalizer : [mean, std]

Returns:

the symbolic formulas (the example below extracts the formula of the first output via [0][0])

Example

>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.auto_symbolic()
>>> model.symbolic_formula()[0][0]
to(device)

move the model to device

Args:

device : str or device

Returns:

self

Example

>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.to(device)
tree(x=None, in_var=None, style='tree', sym_th=0.001, sep_th=0.1, skip_sep_test=False, verbose=False)

turn KAN into a tree
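
A minimal sketch (assuming a trained model; the inputs passed to tree are those used to probe the structure):

>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, seed=0)
>>> f = lambda x: (x[:,[0]]**2 + x[:,[1]]**2) ** 2 + x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20);
>>> model.tree(dataset['train_input'])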

unfix_symbolic(l, i, j, log_history=True)

unfix the (l,i,j) activation function.

unfix_symbolic_all(log_history=True)

unfix all activation functions.
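
A minimal sketch showing how to fix an edge symbolically and then undo it:

>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
>>> model.unfix_symbolic(0,1,3)     # release just this edge
>>> model.unfix_symbolic_all()      # release every symbolically fixed edge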

update_grid(x)

call update_grid_from_samples. This seems unnecessary but we retain it for the sake of classes that might inherit from MultKAN

update_grid_from_samples(x)

update grid from samples

Args:

x : 2D torch.tensor

inputs

Returns:

None

Example

>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> print(model.act_fun[0].grid)
>>> x = torch.linspace(-10,10,steps=101)[:,None]
>>> model.update_grid_from_samples(x)
>>> print(model.act_fun[0].grid)
property width_in

The number of input nodes for each layer

property width_out

The number of output subnodes for each layer

kan.KANLayer module

class kan.KANLayer.KANLayer(*args: Any, **kwargs: Any)

Bases: Module

KANLayer class

Attributes:

in_dim: int

input dimension

out_dim: int

output dimension

num: int

the number of grid intervals

k: int

the piecewise polynomial order of splines

noise_scale: float

spline scale at initialization

coef: 2D torch.tensor

coefficients of B-spline bases

scale_base_mu: float

magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = scale_base_mu

scale_base_sigma: float

magnitude of the residual function b(x) is drawn from N(mu, sigma^2), sigma = scale_base_sigma

scale_sp: float

magnitude of the spline function spline(x)

base_fun: fun

residual function b(x)

mask: 1D torch.float

mask of spline functions. setting some element of the mask to zero means setting the corresponding activation to zero function.

grid_eps: float in [0,1]

a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.

device: str

device

__init__(in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data=True, device='cpu', sparse_init=False)

initialize a KANLayer

Args:

in_dim : int

input dimension. Default: 3.

out_dim : int

output dimension. Default: 2.

num : int

the number of grid intervals = G. Default: 5.

k : int

the order of piecewise polynomial. Default: 3.

noise_scale : float

the scale of noise injected at initialization. Default: 0.5.

scale_base_mu : float

the scale of the residual function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).

scale_base_sigma : float

the scale of the residual function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).

scale_sp : float

the scale of the spline function spline(x).

base_fun : function

residual function b(x). Default: torch.nn.SiLU()

grid_eps : float

When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.

grid_range : list/np.array of shape (2,)

setting the range of grids. Default: [-1,1].

sp_trainable : bool

If true, scale_sp is trainable

sb_trainable : bool

If true, scale_base is trainable

device : str

device

sparse_init : bool

if sparse_init = True, sparse initialization is applied.

Returns:

self

Example

>>> from kan.KANLayer import *
>>> model = KANLayer(in_dim=3, out_dim=5)
>>> (model.in_dim, model.out_dim)
forward(x)

KANLayer forward given input x

Args:

x : 2D torch.float

inputs, shape (number of samples, input dimension)

Returns:

y : 2D torch.float

outputs, shape (number of samples, output dimension)

preacts : 3D torch.float

fan out x into activations, shape (number of samples, output dimension, input dimension)

postacts : 3D torch.float

the outputs of activation functions with preacts as inputs

postspline : 3D torch.float

the outputs of spline functions with preacts as inputs

Example

>>> from kan.KANLayer import *
>>> model = KANLayer(in_dim=3, out_dim=5)
>>> x = torch.normal(0,1,size=(100,3))
>>> y, preacts, postacts, postspline = model(x)
>>> y.shape, preacts.shape, postacts.shape, postspline.shape
get_subset(in_id, out_id)

get a smaller KANLayer from a larger KANLayer (used for pruning)

Args:

in_id : list

id of selected input neurons

out_id : list

id of selected output neurons

Returns:

spb : KANLayer

Example

>>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
>>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3])
>>> kanlayer_small.in_dim, kanlayer_small.out_dim
(2, 3)
initialize_grid_from_parent(parent, x, mode='sample')

update grid from a parent KANLayer & samples

Args:

parent : KANLayer

a parent KANLayer (whose grid is usually coarser than the current model)

x : 2D torch.float

inputs, shape (number of samples, input dimension)

Returns:

None

Example

>>> batch = 100
>>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
>>> print(parent_model.grid.data)
>>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3)
>>> x = torch.normal(0,1,size=(batch, 1))
>>> model.initialize_grid_from_parent(parent_model, x)
>>> print(model.grid.data)
swap(i1, i2, mode='in')

swap the i1 neuron with the i2 neuron in input (if mode == ‘in’) or output (if mode == ‘out’)

Args:

i1 : int

i2 : int

mode : str

mode = ‘in’ or ‘out’

Returns:

None

Example

>>> from kan.KANLayer import *
>>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
>>> print(model.coef)
>>> model.swap(0,1,mode='in')
>>> print(model.coef)
to(device)
update_grid_from_samples(x, mode='sample')

update grid from samples

Args:

x : 2D torch.float

inputs, shape (number of samples, input dimension)

Returns:

None

Example

>>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
>>> print(model.grid.data)
>>> x = torch.linspace(-3,3,steps=100)[:,None]
>>> model.update_grid_from_samples(x)
>>> print(model.grid.data)

kan.LBFGS module

class kan.LBFGS.LBFGS(*args: Any, **kwargs: Any)

Bases: Optimizer

Implements L-BFGS algorithm.

Heavily inspired by minFunc.

Warning

This optimizer doesn’t support per-parameter options and parameter groups (there can be only one).

Warning

Right now all parameters have to be on a single device. This will be improved in the future.

Note

This is a very memory intensive optimizer (it requires additional param_bytes * (history_size + 1) bytes). If it doesn’t fit in memory try reducing the history size, or use a different algorithm.

Args:

lr (float): learning rate (default: 1)

max_iter (int): maximal number of iterations per optimization step (default: 20)

max_eval (int): maximal number of function evaluations per optimization step (default: max_iter * 1.25)

tolerance_grad (float): termination tolerance on first-order optimality (default: 1e-7)

tolerance_change (float): termination tolerance on function value/parameter changes (default: 1e-9)

history_size (int): update history size (default: 100)

line_search_fn (str): either 'strong_wolfe' or None (default: None)

__init__(params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, tolerance_ys=1e-32, history_size=100, line_search_fn=None)
step(closure)

Perform a single optimization step.

Args:
closure (Callable): A closure that reevaluates the model and returns the loss.
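
A minimal usage sketch (the standard PyTorch closure pattern; the quadratic objective below is only illustrative):

>>> import torch
>>> from kan.LBFGS import LBFGS
>>> w = torch.nn.Parameter(torch.tensor([2.0]))
>>> opt = LBFGS([w], lr=1, max_iter=20, line_search_fn='strong_wolfe')
>>> def closure():
...     opt.zero_grad()
...     loss = ((w - 3.0)**2).sum()   # minimized at w = 3
...     loss.backward()
...     return loss
>>> opt.step(closure)
>>> w.data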

kan.Symbolic_KANLayer module

class kan.Symbolic_KANLayer.Symbolic_KANLayer(*args: Any, **kwargs: Any)

Bases: Module

Symbolic_KANLayer class

Attributes:

in_dim : int

input dimension

out_dim : int

output dimension

funs : 2D array of torch functions (or lambda functions)

symbolic functions (torch)

funs_avoid_singularity : 2D array of torch functions (or lambda functions) with singularity avoidance

funs_name : 2D array of str

names of symbolic functions

funs_sympy : 2D array of sympy functions (or lambda functions)

symbolic functions (sympy)

affine : 3D array of floats

affine transformations of inputs and outputs

__init__(in_dim=3, out_dim=2, device='cpu')

initialize a Symbolic_KANLayer (activation functions are initialized to be identity functions)

Args:

in_dim : int

input dimension

out_dim : int

output dimension

device : str

device

Returns:

self

Example

>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=3)
>>> len(sb.funs), len(sb.funs[0])
fix_symbolic(i, j, fun_name, x=None, y=None, random=False, a_range=(-10, 10), b_range=(-10, 10), verbose=True)

fix an activation function to be symbolic

Args:

i : int

the id of input neuron

j : int

the id of output neuron

fun_name : str

the name of the symbolic function

x : 1D array

preactivations

y : 1D array

postactivations

a_range : tuple

sweeping range of a

b_range : tuple

sweeping range of b

verbose : bool

print more information if True

Returns:

r2 (coefficient of determination)

Example 1

>>> # when x & y are not provided. Affine parameters are set to a = 1, b = 0, c = 1, d = 0
>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
>>> sb.fix_symbolic(2,1,'sin')
>>> print(sb.funs_name)
>>> print(sb.affine)

Example 2

>>> # when x & y are provided, fit_params() is called to find the best fit coefficients
>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
>>> batch = 100
>>> x = torch.linspace(-1,1,steps=batch)
>>> noises = torch.normal(0,1,(batch,)) * 0.02
>>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
>>> sb.fix_symbolic(2,1,'sin',x,y)
>>> print(sb.funs_name)
>>> print(sb.affine[1,2,:].data)
forward(x, singularity_avoiding=False, y_th=10.0)

Args:

x : 2D array

inputs, shape (batch, input dimension)

singularity_avoiding : bool

if True, funs_avoid_singularity is used; if False, funs is used.

y_th : float

the singularity threshold

Returns:

y : 2D array

outputs, shape (batch, output dimension)

postacts : 3D array

activations after activation functions but before being summed on nodes

Example

>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=5)
>>> x = torch.normal(0,1,size=(100,3))
>>> y, postacts = sb(x)
>>> y.shape, postacts.shape
(torch.Size([100, 5]), torch.Size([100, 5, 3]))
get_subset(in_id, out_id)

get a smaller Symbolic_KANLayer from a larger Symbolic_KANLayer (used for pruning)

Args:

in_id : list

id of selected input neurons

out_id : list

id of selected output neurons

Returns:

spb : Symbolic_KANLayer

Example

>>> sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10)
>>> sb_small = sb_large.get_subset([0,9],[1,2,3])
>>> sb_small.in_dim, sb_small.out_dim
swap(i1, i2, mode='in')

swap the i1 neuron with the i2 neuron in input (if mode == ‘in’) or output (if mode == ‘out’)

to(device)

move to device

kan.spline module

kan.spline.B_batch(x, grid, k=0, extend=True, device='cpu')

evaluate x on B-spline bases

Args:

x : 2D torch.tensor

inputs, shape (number of splines, number of samples)

grid : 2D torch.tensor

grids, shape (number of splines, number of grid points)

k : int

the piecewise polynomial order of splines.

extend : bool

If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True

device : str

device

Returns:

spline values : 3D torch.tensor

shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order.

Example

>>> from kan.spline import B_batch
>>> x = torch.rand(100,2)
>>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
>>> B_batch(x, grid, k=3).shape
kan.spline.coef2curve(x_eval, grid, coef, k, device='cpu')

converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis).

Args:

x_eval : 2D torch.tensor

shape (batch, in_dim)

grid : 2D torch.tensor

shape (in_dim, G+2k). G: the number of grid intervals; k: spline order.

coef : 3D torch.tensor

shape (in_dim, out_dim, G+k)

k : int

the piecewise polynomial order of splines.

device : str

device

Returns:

y_eval : 3D torch.tensor

shape (number of samples, in_dim, out_dim)

kan.spline.curve2coef(x_eval, y_eval, grid, k, lamb=1e-08)

converting B-spline curves to B-spline coefficients using least squares.

Args:

x_eval : 2D torch.tensor

shape (in_dim, out_dim, number of samples)

y_eval : 2D torch.tensor

shape (in_dim, out_dim, number of samples)

grid : 2D torch.tensor

shape (in_dim, G+2k)

k : int

spline order

lamb : float

regularized least squares lambda

Returns:

coef : 3D torch.tensor

shape (in_dim, out_dim, G+k)

kan.spline.extend_grid(grid, k_extend=0)

extend grid
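
A minimal sketch (extend_grid pads k_extend extra grid points on each side of the given grid):

>>> import torch
>>> from kan.spline import extend_grid
>>> grid = torch.linspace(-1, 1, steps=6)[None, :]   # shape (1, 6)
>>> extend_grid(grid, k_extend=3).shape              # 3 extra points per side -> (1, 12)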

kan.utils module

kan.utils.add_symbolic(name, fun, c=1, fun_singularity=None)

add a symbolic function to library

Args:

name : str

name of the function

fun : fun

torch function or lambda function

Returns:

None

Example

>>> print(SYMBOLIC_LIB['Bessel'])
KeyError: 'Bessel'
>>> add_symbolic('Bessel', torch.special.bessel_j0)
>>> print(SYMBOLIC_LIB['Bessel'])
(<built-in function special_bessel_j0>, Bessel)
kan.utils.augment_input(orig_vars, aux_vars, x)

augment inputs

Args:

orig_vars : list of sympy symbols

aux_vars : list of auxiliary symbols

x : inputs

Returns:

augmented inputs

Example

>>> from kan.utils import *
>>> from sympy import *
>>> orig_vars = a, b = symbols('a b')
>>> aux_vars = [a + b, a * b]
>>> x = torch.rand(100, 2)
>>> augment_input(orig_vars, aux_vars, x).shape
kan.utils.batch_hessian(model, x, create_graph=False)

hessian

Args:

func : function or model

x : inputs

create_graph : bool

Returns:

hessian

Example

>>> from kan.utils import batch_hessian
>>> x = torch.normal(0,1,size=(100,2))
>>> model = lambda x: x[:,[0]]**2 + x[:,[1]]**2
>>> batch_hessian(model, x)
kan.utils.batch_jacobian(func, x, create_graph=False, mode='scalar')

jacobian

Args:

func : function or model

x : inputs

create_graph : bool

Returns:

jacobian

Example

>>> from kan.utils import batch_jacobian
>>> x = torch.normal(0,1,size=(100,2))
>>> model = lambda x: x[:,[0]] + x[:,[1]]
>>> batch_jacobian(model, x)
kan.utils.create_dataset(f, n_var=2, f_mode='col', ranges=[-1, 1], train_num=1000, test_num=1000, normalize_input=False, normalize_label=False, device='cpu', seed=0)

create dataset

Args:

f : function

the symbolic formula used to create the synthetic dataset

ranges : list or np.array; shape (2,) or (n_var, 2)

the range of input variables. Default: [-1,1].

train_num : int

the number of training samples. Default: 1000.

test_num : int

the number of test samples. Default: 1000.

normalize_input : bool

If True, apply normalization to inputs. Default: False.

normalize_label : bool

If True, apply normalization to labels. Default: False.

device : str

device. Default: 'cpu'.

seed : int

random seed. Default: 0.

Returns:

dataset : dict

Train/test inputs/labels are dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']

Example

>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2, train_num=100)
>>> dataset['train_input'].shape
torch.Size([100, 2])
kan.utils.create_dataset_from_data(inputs, labels, train_ratio=0.8, device='cpu')

create dataset from data

Args:

inputs : 2D torch.float

labels : 2D torch.float

train_ratio : float

the ratio of training fraction

device : str

Returns:

dataset (dictionary)

Example

>>> from kan.utils import create_dataset_from_data
>>> x = torch.normal(0,1,size=(100,2))
>>> y = torch.normal(0,1,size=(100,1))
>>> dataset = create_dataset_from_data(x, y)
>>> dataset['train_input'].shape
kan.utils.ex_round(ex1, n_digit)

round the numbers in an expression to a given number of decimal digits

Args:

ex1 : sympy expression

n_digit : int

Returns:

ex2 : sympy expression

Example

>>> from kan.utils import *
>>> from sympy import *
>>> input_vars = a, b = symbols('a b')
>>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
>>> ex_round(expression, 2)
kan.utils.f_arccos(x, y_th)
kan.utils.f_arcsin(x, y_th)
kan.utils.f_arctanh(x, y_th)
kan.utils.f_exp(x, y_th)
kan.utils.f_inv(x, y_th)
kan.utils.f_inv2(x, y_th)
kan.utils.f_inv3(x, y_th)
kan.utils.f_inv4(x, y_th)
kan.utils.f_inv5(x, y_th)
kan.utils.f_invsqrt(x, y_th)
kan.utils.f_log(x, y_th)
kan.utils.f_power1d5(x, y_th)
kan.utils.f_sqrt(x, y_th)
kan.utils.f_tan(x, y_th)
kan.utils.fit_params(x, y, fun, a_range=(-10, 10), b_range=(-10, 10), grid_number=101, iteration=3, verbose=True, device='cpu')

fit a, b, c, d such that

\[|y-(cf(ax+b)+d)|^2\]

is minimized. Both x and y are 1D array. Sweep a and b, find the best fitted model.

Args:

x : 1D array

x values

y : 1D array

y values

fun : function

symbolic function

a_range : tuple

sweeping range of a

b_range : tuple

sweeping range of b

grid_number : int

number of steps along a and b

iteration : int

number of zooming iterations

verbose : bool

print extra information if True

device : str

device

Returns:

a_best : float

best fitted a

b_best : float

best fitted b

c_best : float

best fitted c

d_best : float

best fitted d

r2_best : float

best r2 (coefficient of determination)

Example

>>> num = 100
>>> x = torch.linspace(-1,1,steps=num)
>>> noises = torch.normal(0,1,(num,)) * 0.02
>>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
>>> fit_params(x, y, torch.sin)
r2 is 0.9999727010726929
(tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000))
kan.utils.get_derivative(model, inputs, labels, derivative='hessian', loss_mode='pred', reg_metric='w', lamb=0.0, lamb_l1=1.0, lamb_entropy=0.0)

compute the jacobian/hessian of the loss with respect to the model parameters

Args:

inputs : 2D torch.float

labels : 2D torch.float

derivative : str

‘jacobian’ or ‘hessian’

device : str

Returns:

jacobian or hessian
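
A minimal sketch (hessian of the loss with respect to the flattened model parameters; a small model is used here so the full hessian stays manageable):

>>> from kan import *
>>> from kan.utils import get_derivative, create_dataset
>>> model = KAN(width=[2,3,1], grid=3, k=3, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> H = get_derivative(model, dataset['train_input'], dataset['train_label'], derivative='hessian')
>>> H.shape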

kan.utils.model2param(model)

turn model parameters into a flattened vector
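
A minimal sketch:

>>> from kan import *
>>> from kan.utils import model2param
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model2param(model).shape   # a single 1D vector containing all parameters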

kan.utils.sparse_mask(in_dim, out_dim)

get sparse mask
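
A minimal sketch (the mask is a 0/1 matrix selecting a sparse subset of the in_dim x out_dim connections):

>>> from kan.utils import sparse_mask
>>> mask = sparse_mask(5, 3)
>>> mask.shape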

kan.compiler module

kan.compiler.expr2kan(input_variables, expr, grid=5, k=3, auto_save=False)

compile a symbolic formula to a MultKAN

Args:

input_variables : a list of sympy symbols

expr : sympy expression

grid : int

the number of grid intervals

k : int

spline order

auto_save : bool

if auto_save = True, models are automatically saved

Returns:

MultKAN

Example

>>> from kan.compiler import *
>>> from sympy import *
>>> input_vars = a, b = symbols('a b')
>>> expression = exp(sin(pi*a) + b**2)
>>> model = kanpiler(input_vars, expression)
>>> x = torch.rand(100,2) * 2 - 1
>>> model(x)
>>> model.plot()
kan.compiler.kanpiler(input_variables, expr, grid=5, k=3, auto_save=False)

compile a symbolic formula to a MultKAN

Args:

input_variables : a list of sympy symbols

expr : sympy expression

grid : int

the number of grid intervals

k : int

spline order

auto_save : bool

if auto_save = True, models are automatically saved

Returns:

MultKAN

Example

>>> from kan.compiler import *
>>> from sympy import *
>>> input_vars = a, b = symbols('a b')
>>> expression = exp(sin(pi*a) + b**2)
>>> model = kanpiler(input_vars, expression)
>>> x = torch.rand(100,2) * 2 - 1
>>> model(x)
>>> model.plot()
kan.compiler.next_nontrivial_operation(expr, scale=1, bias=0)

remove the affine part of an expression

Args:

expr : sympy expression

scale : float

bias : float

Returns:

expr : sympy expression

scale : float

bias : float

Example

>>> from kan.compiler import *
>>> from sympy import *
>>> input_vars = a, b = symbols('a b')
>>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
>>> next_nontrivial_operation(expression)
kan.compiler.sf2kan(input_variables, expr, grid=5, k=3, auto_save=False)

compile a symbolic formula to a MultKAN

Args:

input_variables : a list of sympy symbols

expr : sympy expression

grid : int

the number of grid intervals

k : int

spline order

auto_save : bool

if auto_save = True, models are automatically saved

Returns:

MultKAN

Example

>>> from kan.compiler import *
>>> from sympy import *
>>> input_vars = a, b = symbols('a b')
>>> expression = exp(sin(pi*a) + b**2)
>>> model = kanpiler(input_vars, expression)
>>> x = torch.rand(100,2) * 2 - 1
>>> model(x)
>>> model.plot()

kan.hypothesis module

kan.hypothesis.batch_grad_normgrad(model, x, group, create_graph=False)
kan.hypothesis.detect_separability(model, x, mode='add', score_th=0.01, res_th=0.01, n_clusters=None, bias=0.0, verbose=False)

detect function separability

Args:

model : MultKAN, MLP or python function

x : 2D torch.float

inputs

mode : str

mode = 'add' or mode = 'mul'

score_th : float

threshold of score

res_th : float

threshold of residue

n_clusters : None or int

the number of clusters

bias : float

bias (for multiplicative separability)

verbose : bool

Returns:

results (dictionary)

Example1

>>> from kan.hypothesis import *
>>> model = lambda x: x[:,[0]] ** 2 + torch.exp(x[:,[1]]+x[:,[2]])
>>> x = torch.normal(0,1,size=(100,3))
>>> detect_separability(model, x, mode='add')

Example2

>>> from kan.hypothesis import *
>>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
>>> x = torch.normal(0,1,size=(100,3))
>>> detect_separability(model, x, mode='mul')
kan.hypothesis.get_dependence(model, x, group)
kan.hypothesis.get_molecule(model, x, sym_th=0.001, verbose=True)

how variables are combined hierarchically

Args:

model : MultKAN, MLP or python function

x : 2D torch.float

inputs

sym_th : float

threshold of symmetry

verbose : bool

Returns:

list

Example

>>> from kan.hypothesis import *
>>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
>>> x = torch.normal(0,1,size=(100,8))
>>> get_molecule(model, x, verbose=False)
[[[0], [1], [2], [3], [4], [5], [6], [7]],
 [[0, 1], [2, 3], [4, 5], [6, 7]],
 [[0, 1, 2, 3], [4, 5, 6, 7]],
 [[0, 1, 2, 3, 4, 5, 6, 7]]]
kan.hypothesis.get_tree_node(model, x, moleculess, sep_th=0.01, skip_test=True)

get tree nodes

Args:

model : MultKAN, MLP or python function

x : 2D torch.float

inputs

sep_th : float

threshold of separability

skip_test : bool

if True, don’t test the property of each module (to save time)

Returns:

arities : list of numbers

properties : list of strings

Example

>>> from kan.hypothesis import *
>>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
>>> x = torch.normal(0,1,size=(100,8))
>>> moleculess = get_molecule(model, x, verbose=False)
>>> get_tree_node(model, x, moleculess, skip_test=False)
kan.hypothesis.plot_tree(model, x, in_var=None, style='tree', sym_th=0.001, sep_th=0.1, skip_sep_test=False, verbose=False)

get tree graph

Args:

model : MultKAN, MLP or python function

x : 2D torch.float

inputs

in_var : list of symbols

input variables

style : str

'tree' or 'box'

sym_th : float

threshold of symmetry

sep_th : float

threshold of separability

skip_sep_test : bool

if True, don’t test the property of each module (to save time)

verbose : bool

Returns:

a tree graph

Example

>>> from kan.hypothesis import *
>>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
>>> x = torch.normal(0,1,size=(100,8))
>>> plot_tree(model, x)
kan.hypothesis.test_general_separability(model, x, groups, threshold=0.01)

test function separability

Args:

model : MultKAN, MLP or python function

x : 2D torch.float

inputs

mode : str

mode = 'add' or mode = 'mul'

score_th : float

threshold of score

res_th : float

threshold of residue

bias : float

bias (for multiplicative separability)

verbose : bool

Returns:

bool

Example

>>> from kan.hypothesis import *
>>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]**2+x[:,[2]]**2)**2
>>> x = torch.normal(0,1,size=(100,3))
>>> print(test_general_separability(model, x, [[1],[0,2]])) # False
>>> print(test_general_separability(model, x, [[0],[1,2]])) # True
kan.hypothesis.test_separability(model, x, groups, mode='add', threshold=0.01, bias=0)

test function separability

Args:

model : MultKAN, MLP or python function

x : 2D torch.float

inputs

mode : str

mode = 'add' or mode = 'mul'

score_th : float

threshold of score

res_th : float

threshold of residue

bias : float

bias (for multiplicative separability)

verbose : bool

Returns:

bool

Example

>>> from kan.hypothesis import *
>>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
>>> x = torch.normal(0,1,size=(100,3))
>>> print(test_separability(model, x, [[0],[1,2]], mode='mul')) # True
>>> print(test_separability(model, x, [[0],[1,2]], mode='add')) # False
kan.hypothesis.test_symmetry(model, x, group, dependence_th=0.001)

test symmetry

Args:

model : MultKAN, MLP or python function

x : 2D torch.float

inputs

group : a list of indices

dependence_th : float

threshold of dependence

Returns:

bool

Example

>>> from kan.hypothesis import *
>>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
>>> x = torch.normal(0,1,size=(100,3))
>>> print(test_symmetry(model, x, [1,2])) # True
>>> print(test_symmetry(model, x, [0,2])) # False
kan.hypothesis.test_symmetry_var(model, x, input_vars, symmetry_var)

test symmetry

Args:

model : MultKAN, MLP or python function

x : 2D torch.float

inputs

input_vars : list of sympy symbols

symmetry_var : sympy expression

Returns:

cosine similarity

Example

>>> from kan.hypothesis import *
>>> from sympy import *
>>> model = lambda x: x[:,[0]] * (x[:,[1]] + x[:,[2]])
>>> x = torch.normal(0,1,size=(100,8))
>>> input_vars = a, b, c = symbols('a b c')
>>> symmetry_var = b + c
>>> test_symmetry_var(model, x, input_vars, symmetry_var);
>>> symmetry_var = b * c
>>> test_symmetry_var(model, x, input_vars, symmetry_var);