kan package
kan.KAN module
- class kan.MultKAN.MultKAN(*args: Any, **kwargs: Any)
Bases:
Module
KAN class
Attributes:
- grid : int
the number of grid intervals
- k : int
spline order
- act_fun : a list of KANLayers
- symbolic_fun : a list of Symbolic_KANLayer
- depth : int
depth of KAN
- width : list
number of neurons in each layer. Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 hidden layers of 5 neurons each. With multiplication nodes, [2,[5,3],[5,1],3] means that, besides the [2,5,5,3] KAN, there are 3 multiplication nodes in layer 1 and 1 multiplication node in layer 2.
- mult_arity : int, or list of int lists
multiplication arity for each multiplication node (the number of numbers to be multiplied)
- base_fun : fun
residual function b(x). Each activation function takes the form phi(x) = sb_scale * b(x) + sp_scale * spline(x).
- symbolic_enabled : bool
If False, the symbolic branch (symbolic_fun) is not computed (its contribution is set to zero), to save time. Default: True.
- width_in : list
The number of input neurons for each layer
- width_out : list
The number of output neurons for each layer
- base_fun_name : str
The name of the base function b(x)
- grid_eps : float
The parameter that interpolates between a uniform grid and an adaptive grid (based on sample quantiles)
- node_bias : a list of 1D torch.float
- node_scale : a list of 1D torch.float
- subnode_bias : a list of 1D torch.float
- subnode_scale : a list of 1D torch.float
- affine_trainable : bool
indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
- sp_trainable : bool
indicate whether the overall magnitude of splines is trainable
- sb_trainable : bool
indicate whether the overall magnitude of the base function is trainable
- save_act : bool
indicate whether intermediate activations are saved in the forward pass
- node_scores : None or list of 1D torch.float
node attribution scores
- edge_scores : None or list of 2D torch.float
edge attribution scores
- subnode_scores : None or list of 1D torch.float
subnode attribution scores
- cache_data : None or 2D torch.float
cached input data
- acts : None or a list of 2D torch.float
activations on nodes
- auto_save : bool
indicate whether to automatically save a checkpoint once the model is modified
- state_id : int
the state of the model (used to save checkpoints)
- ckpt_path : str
the folder where checkpoints are stored
- round : int
the number of times rewind() has been called
device : str
- __init__(width=None, grid=3, k=3, mult_arity=2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu')
initialize a KAN model
Args:
- width : list of int
Without multiplication nodes: \([n_0, n_1, .., n_{L-1}]\) specify the number of neurons in each layer (including inputs/outputs) With multiplication nodes: \([[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]\) specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
- grid : int
number of grid intervals. Default: 3.
- k : int
order of piecewise polynomial. Default: 3.
- mult_arity : int, or list of int lists
multiplication arity for each multiplication node (the number of numbers to be multiplied)
- noise_scale : float
initial injected noise to spline.
- base_fun : str
the residual function b(x). Default: 'silu'
- symbolic_enabled : bool
compute (True) or skip (False) symbolic computations (for efficiency). Default: True.
- affine_trainable : bool
whether affine parameters are updated. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
- grid_eps : float
When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
- grid_range : list/np.array of shape (2,)
the range of grids. Default: [-1,1]. This argument is unimportant if fit(update_grid=True) (update_grid=True by default).
- sp_trainable : bool
If True, scale_sp is trainable. Default: True.
- sb_trainable : bool
If True, scale_base is trainable. Default: True.
- device : str
device
- seed : int
random seed
- save_act : bool
indicate whether intermediate activations are saved in the forward pass
- sparse_init : bool
sparse initialization (True) or normal dense initialization (False). Default: False.
- auto_save : bool
indicate whether to automatically save a checkpoint once the model is modified
- state_id : int
the state of the model (used to save checkpoints)
- ckpt_path : str
the folder to store checkpoints. Default: './model'
- round : int
the number of times rewind() has been called
Returns:
self
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
checkpoint directory created: ./model
saving model version 0.0
- attribute(l=None, i=None, out_score=None, plot=True)
get attribution scores
Args:
- l : None or int
layer index
- i : None or int
neuron index
- out_score : None or 1D torch.float
specify output scores
- plot : bool
when plot = True, display the bar chart
Returns:
attribution scores
Example
>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.attribute()
>>> model.feature_score
- auto_swap()
automatically swap neurons such that connection costs are minimized
- auto_swap_l(l)
automatically swap neurons in layer l such that connection costs are minimized
- auto_symbolic(a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple=0.8, r2_threshold=0.0)
automatic symbolic regression for all edges
Args:
- a_range : tuple
search range of a
- b_range : tuple
search range of b
- lib : list of str
library of candidate symbolic functions
- verbose : int
larger values mean more verbose output
- weight_simple : float
a weight that prioritizes simplicity (low complexity) over performance (high r2) - set to 0.0 to ignore complexity
- r2_threshold : float
If r2 is below this threshold, the edge will not be fixed with any symbolic function - set to 0.0 to ignore this threshold
Returns:
None
Example
>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.auto_symbolic()
- checkout(model_id)
check out an old version
Args:
- model_id : str
in format ‘{a}.{b}’ where a is the round number, b is the version number in that round
Returns:
MultKAN
Example
Same usage as rewind(), although checkout() does not change the model's state.
- property connection_cost
- copy()
deepcopy
Returns:
MultKAN
Example
>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> model2 = model.copy()
>>> model2.act_fun[0].coef.data *= 2
>>> print(model2.act_fun[0].coef.data)
>>> print(model.act_fun[0].coef.data)
- disable_symbolic_in_fit(lamb)
during fitting, disable the symbolic branch if either condition holds: lamb = 0, or none of the symbolic functions is active
- evaluate(dataset)
- expand_depth()
expand network depth by adding an identity layer at the end. For usage, please refer to the tutorial interp_3_KAN_compiler.ipynb.
Returns:
None
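Example
A minimal sketch (assumed usage, not from the upstream docstring): expand_depth() appends an identity layer, so the width list gains one layer.
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.expand_depth()
>>> print(model.width)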
- expand_width(layer_id, n_added_nodes, sum_bool=True, mult_arity=2)
expand network width. For usage, please refer to tutorials interp_3_KAN_compiler.ipynb.
Args:
- layer_id : int
layer index
- n_added_nodes : int
the number of added nodes
- sum_bool : bool
if sum_bool == True, added nodes are addition nodes; otherwise multiplication nodes
- mult_arity : int
multiplication arity (the number of numbers to be multiplied)
Returns:
None
- feature_interaction(l, neuron_th=0.01, feature_th=0.01)
get feature interaction
Args:
- l : int
layer index
- neuron_th : float
threshold to determine whether a neuron is active
- feature_th : float
threshold to determine whether a feature is active
Returns:
dictionary
Example
>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.attribute()
>>> model.feature_interaction(1)
- property feature_score
attribution scores for inputs
- fit(dataset, opt='LBFGS', steps=100, log=1, lamb=0.0, lamb_l1=1.0, lamb_entropy=2.0, lamb_coef=0.0, lamb_coefdiff=0.0, update_grid=True, grid_update_num=10, loss_fn=None, lr=1.0, start_grid_update_step=-1, stop_grid_update_step=50, batch=-1, metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000.0, reg_metric='edge_forward_spline_n', display_metrics=None)
training
Args:
- dataset : dic
contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
- opt : str
"LBFGS" or "Adam"
- steps : int
training steps
- log : int
logging frequency
- lamb : float
overall penalty strength
- lamb_l1 : float
l1 penalty strength
- lamb_entropy : float
entropy penalty strength
- lamb_coef : float
coefficient magnitude penalty strength
- lamb_coefdiff : float
penalty strength on the difference of nearby coefficients (smoothness)
- update_grid : bool
If True, update the grid regularly before stop_grid_update_step
- grid_update_num : int
the number of grid updates before stop_grid_update_step
- start_grid_update_step : int
no grid updates before this training step
- stop_grid_update_step : int
no grid updates after this training step
- loss_fn : function
loss function
- lr : float
learning rate
- batch : int
batch size; if -1, use the full batch.
- save_fig_freq : int
save a figure every save_fig_freq steps
- singularity_avoiding : bool
indicate whether to avoid singularities in the symbolic part
- y_th : float
singularity threshold (anything above the threshold is considered singular and is softened in some ways)
- reg_metric : str
regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}
- metrics : a list of metrics (as functions)
the metrics to be computed during training
- display_metrics : a list of functions
the metrics to be displayed in the tqdm progress bar
Returns:
- results : dic
results['train_loss']: 1D array of training losses (RMSE); results['test_loss']: 1D array of test losses (RMSE); results['reg']: 1D array of regularization values; plus any other metrics specified in metrics.
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
# Most examples in the tutorials involve the fit() method. Please check them for more usage patterns.
- fix_symbolic(l, i, j, fun_name, fit_params_bool=True, a_range=(-10, 10), b_range=(-10, 10), verbose=True, random=False, log_history=True)
set (l,i,j) activation to be symbolic (specified by fun_name)
Args:
- l : int
layer index
- i : int
input neuron index
- j : int
output neuron index
- fun_name : str
function name
- fit_params_bool : bool
obtain affine parameters through fitting (True) or set them to default values (False)
- a_range : tuple
sweeping range of a
- b_range : tuple
sweeping range of b
- verbose : bool
If True, more information is printed.
- random : bool
initialize affine parameters randomly or as [1,0,1,0]
- log_history : bool
indicate whether to log history when the function is called
Returns:
None or r2 (coefficient of determination)
Example 1
>>> # when fit_params_bool = False
>>> model = KAN(width=[2,5,1], grid=5, k=3)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
>>> print(model.act_fun[0].mask.reshape(2,5))
>>> print(model.symbolic_fun[0].mask.reshape(2,5))
Example 2
>>> # when fit_params_bool = True
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # obtain activations (otherwise model does not have attribute acts)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=True)
>>> print(model.act_fun[0].mask.reshape(2,5))
>>> print(model.symbolic_fun[0].mask.reshape(2,5))
- forward(x, singularity_avoiding=False, y_th=10.0)
forward pass
Args:
- x : 2D torch.tensor
inputs
- singularity_avoiding : bool
whether to avoid singularity for the symbolic branch
- y_th : float
the threshold for singularity
Returns:
2D torch.float; the model outputs
Example 1
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> x = torch.rand(100,2)
>>> model(x).shape
Example 2
>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> x = torch.tensor([[1],[-0.01]])
>>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False)
>>> print(model(x))
>>> print(model(x, singularity_avoiding=True))
>>> print(model(x, singularity_avoiding=True, y_th=1.))
- get_act(x=None)
collect intermediate activations
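Example
A minimal sketch (assumed usage): after calling get_act, intermediate activations are available in model.acts.
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> x = torch.rand(100,2)
>>> model.get_act(x)
>>> len(model.acts)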
- get_fun(l, i, j)
get function (l,i,j)
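Example
A minimal sketch (assumed usage): run a forward pass first so activations are cached, then inspect a single edge function.
>>> from kan import *
>>> model = KAN(width=[2,3,1], grid=5, k=3, seed=0)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # do a forward pass to obtain model.acts
>>> model.get_fun(0,0,0)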
- get_params()
Get parameters
- get_range(l, i, j, verbose=True)
Get the input range and output range of the (l,i,j) activation
Args:
- l : int
layer index
- i : int
input neuron index
- j : int
output neuron index
Returns:
- x_min : float
minimum of input
- x_max : float
maximum of input
- y_min : float
minimum of output
- y_max : float
maximum of output
Example
>>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # do a forward pass to obtain model.acts
>>> model.get_range(0,0,0)
- get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
Get regularization. This looks redundant, but a subclass may want to override get_reg while keeping reg.
- history(k='all')
get history
- initialize_from_another_model(another_model, x)
initialize from another model of the same width, but their ‘grid’ parameter can be different. Note this is equivalent to refine() when we don’t want to keep another_model
Args:
- another_model : MultKAN
- x : 2D torch.float
Returns:
self
Example
>>> from kan import *
>>> model1 = KAN(width=[2,5,1], grid=3)
>>> model2 = KAN(width=[2,5,1], grid=10)
>>> x = torch.rand(100,2)
>>> model2.initialize_from_another_model(model1, x)
- initialize_grid_from_another_model(model, x)
initialize grid from another model
Args:
- model : MultKAN
parent model
- x : 2D torch.tensor
inputs
Returns:
None
Example
>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> print(model.act_fun[0].grid)
>>> x = torch.linspace(-10,10,steps=101)[:,None]
>>> model2 = KAN(width=[1,1], grid=10, k=3, seed=0)
>>> model2.initialize_grid_from_another_model(model, x)
>>> print(model2.act_fun[0].grid)
- static loadckpt(path='model')
load checkpoint from path
Args:
- path : str
the path where checkpoints are saved
Returns:
MultKAN
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.saveckpt('./mark')
>>> KAN.loadckpt('./mark')
- log_history(method_name)
- module(start_layer, chain)
specify network modules
Args:
- start_layer : int
the earliest layer of the module
- chain : str
specify neurons in the module
Returns:
None
- property n_edge
the number of active edges
- property n_mult
The number of multiplication nodes for each layer
- property n_sum
The number of addition nodes for each layer
- node_attribute()
- perturb(mag=1.0, mode='non-intrusive')
perturb a network. For usage, please refer to the tutorial interp_3_KAN_compiler.ipynb.
Args:
- mag : float
perturbation magnitude
- mode : str
perturbation mode, choices = {'non-intrusive', 'all', 'minimal'}
Returns:
None
- plot(folder='./figures', beta=3, metric='backward', scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None, varscale=1.0)
plot KAN
Args:
- folder : str
the folder to store pngs
- beta : float
positive number; controls the transparency of each activation: transparency = tanh(beta*l1).
- mask : bool
If True, plot with mask (need to run prune() first to obtain the mask). If False (default), plot all activation functions.
- mode : str
"supervised" or "unsupervised". If "supervised", l1 is measured by absolute value (not subtracting the mean); if "unsupervised", l1 is measured by standard deviation (subtracting the mean).
- scale : float
control the size of the diagram
- in_vars : None or list of str
the name(s) of input variables
- out_vars : None or list of str
the name(s) of output variables
- title : None or str
title
- varscale : float
the size of input variables
Returns:
Figure
Example
>>> # see more interactive examples in demos
>>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
>>> x = torch.normal(0,1,size=(100,2))
>>> model(x) # do a forward pass to obtain model.acts
>>> model.plot()
- prune(node_th=0.01, edge_th=0.03)
prune the network (both nodes and edges)
Args:
- node_th : float
if the attribution score of a node is below node_th, it is considered dead and will be set to zero.
- edge_th : float
if the attribution score of an edge is below edge_th, it is considered dead and will be set to zero.
Returns:
pruned network : MultKAN
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune()
>>> model.plot()
- prune_edge(threshold=0.03, log_history=True)
pruning edges
Args:
- threshold : float
if the attribution score of an edge is below the threshold, it is considered dead and will be set to zero.
Returns:
pruned network : MultKAN
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune_edge()
>>> model.plot()
- prune_input(threshold=0.01, active_inputs=None, log_history=True)
prune inputs
Args:
- threshold : float
if the attribution score of an input feature is below the threshold, it is considered irrelevant.
- active_inputs : None or list
if a list is passed, the manual mode will disregard attribution scores and prune as instructed.
Returns:
pruned network : MultKAN
Example 1
>>> # automatic
>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
>>> model = model.prune_input()
>>> model.plot()
Example 2
>>> # manual
>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.plot()
>>> model = model.prune_input(active_inputs=[0,1])
>>> model.plot()
- prune_node(threshold=0.01, mode='auto', active_neurons_id=None, log_history=True)
pruning nodes
Args:
- threshold : float
if the attribution score of a neuron is below the threshold, it is considered dead and will be removed
- mode : str
'auto' or 'manual'. With 'auto', nodes are automatically pruned using the threshold. With 'manual', active_neurons_id should be passed in.
Returns:
pruned network : MultKAN
Example
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.prune_node()
>>> model.plot()
- refine(new_grid)
grid refinement
Args:
- new_grid : int
the number of grid intervals after refinement
Returns:
a refined model : MultKAN
Example
>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
checkpoint directory created: ./model
saving model version 0.0
>>> print(model.grid)
5
>>> x = torch.rand(100,2)
>>> model.get_act(x)
>>> model = model.refine(10)
saving model version 0.1
>>> print(model.grid)
10
- reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
Get regularization
Args:
- reg_metric : str
the regularization metric, one of 'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'
- lamb_l1 : float
l1 penalty strength
- lamb_entropy : float
entropy penalty strength
- lamb_coef : float
coefficient penalty strength
- lamb_coefdiff : float
coefficient smoothness strength
Returns:
reg_ : torch.float
Example
>>> model = KAN(width=[2,3,1], grid=5, k=3, noise_scale=1.)
>>> x = torch.rand(100,2)
>>> model.get_act(x)
>>> model.reg('edge_forward_spline_n', 1.0, 2.0, 1.0, 1.0)
- remove_edge(l, i, j, log_history=True)
remove activation phi(l,i,j) (set its mask to zero)
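Example
A minimal sketch (assumed usage): mask out the edge from input neuron 1 to hidden neuron 3.
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.remove_edge(0,1,3)
>>> model.plot()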
- remove_node(l, i, mode='all', log_history=True)
remove neuron (l,i) (set the masks of all incoming and outgoing activation functions to zero)
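Example
A minimal sketch (assumed usage): removing a hidden neuron masks all of its incoming and outgoing activation functions.
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.remove_node(1,2)
>>> model.plot()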
- rewind(model_id)
rewind to an old version
Args:
- model_id : str
in format ‘{a}.{b}’ where a is the round number, b is the version number in that round
Returns:
MultKAN
Example
Please refer to the tutorial API 12: Checkpoint, save & load model, and the sketch below.
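A minimal sketch (assumed usage): fitting saves a new version, and rewind('0.1') returns to it while starting a new round.
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model = model.rewind('0.1')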
- saveckpt(path='model')
save the current model to files (configuration file and state file)
Args:
- path : str
the path where checkpoints are saved
Returns:
None
Example
>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.saveckpt('./mark')
# There will be three files appearing in the current folder: mark_cache_data, mark_config.yml, mark_state
- set_mode(l, i, j, mode, mask_n=None)
- speed(compile=False)
turn on KAN’s speed mode
- suggest_symbolic(l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=None, topk=5, verbose=True, r2_loss_fun=<function MultKAN.<lambda>>, c_loss_fun=<function MultKAN.<lambda>>, weight_simple=0.8)
suggest symbolic function
Args:
- l : int
layer index
- i : int
neuron index in layer l
- j : int
neuron index in layer l+1 (output neuron index)
- a_range : tuple
search range of a
- b_range : tuple
search range of b
- lib : list of str
library of candidate symbolic functions
- topk : int
the number of top functions displayed
- verbose : bool
if verbose = True, print more information
- r2_loss_fun : function
function : r2 -> "bits"
- c_loss_fun : function
function : c -> "bits"
- weight_simple : float
the simplicity weight: the higher, the more simplicity is preferred over performance
Returns:
best_name (str), best_fun (function), best_r2 (float), best_c (float)
Example
>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.suggest_symbolic(0,1,0)
- swap(l, i1, i2, log_history=True)
- symbolic_formula(var=None, normalizer=None, output_normalizer=None)
get symbolic formula
Args:
- var : None or a list of sympy expressions
input variables
- normalizer : [mean, std]
- output_normalizer : [mean, std]
Returns:
a list of symbolic formulas and a list of input variables (e.g., symbolic_formula()[0][0] is the formula of the first output)
Example
>>> from kan import *
>>> model = KAN(width=[2,1,1], grid=5, k=3, noise_scale=0.0, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]])+x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.auto_symbolic()
>>> model.symbolic_formula()[0][0]
- to(device)
move the model to device
Args:
device : str or device
Returns:
self
Example
>>> from kan import *
>>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.to(device)
- tree(x=None, in_var=None, style='tree', sym_th=0.001, sep_th=0.1, skip_sep_test=False, verbose=False)
turn KAN into a tree
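Example
A minimal sketch (assumed usage, mirroring plot_tree in kan.hypothesis): fit a model, then draw its tree decomposition.
>>> from kan import *
>>> model = KAN(width=[3,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
>>> f = lambda x: 1 * x[:,[0]]**2 + 0.3 * x[:,[1]]**2 + 0.0 * x[:,[2]]**2
>>> dataset = create_dataset(f, n_var=3)
>>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
>>> model.tree(dataset['train_input'])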
- unfix_symbolic(l, i, j, log_history=True)
unfix the (l,i,j) activation function.
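Example
A minimal sketch (assumed usage): fix an edge to a symbolic function, then revert it to a learnable spline.
>>> from kan import *
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model.fix_symbolic(0,1,3,'sin',fit_params_bool=False)
>>> model.unfix_symbolic(0,1,3)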
- unfix_symbolic_all(log_history=True)
unfix all activation functions.
- update_grid(x)
call update_grid_from_samples. This seems unnecessary but we retain it for the sake of classes that might inherit from MultKAN
- update_grid_from_samples(x)
update grid from samples
Args:
- x : 2D torch.tensor
inputs
Returns:
None
Example
>>> from kan import *
>>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
>>> print(model.act_fun[0].grid)
>>> x = torch.linspace(-10,10,steps=101)[:,None]
>>> model.update_grid_from_samples(x)
>>> print(model.act_fun[0].grid)
- property width_in
The number of input nodes for each layer
- property width_out
The number of output subnodes for each layer
kan.KANLayer module
- class kan.KANLayer.KANLayer(*args: Any, **kwargs: Any)
Bases:
Module
KANLayer class
Attributes:
- in_dim: int
input dimension
- out_dim: int
output dimension
- num: int
the number of grid intervals
- k: int
the piecewise polynomial order of splines
- noise_scale: float
spline scale at initialization
- coef: 2D torch.tensor
coefficients of B-spline bases
- scale_base_mu: float
magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = scale_base_mu
- scale_base_sigma: float
magnitude of the residual function b(x) is drawn from N(mu, sigma^2), sigma = scale_base_sigma
- scale_sp: float
magnitude of the spline function spline(x)
- base_fun: fun
residual function b(x)
- mask: 1D torch.float
mask of spline functions. setting some element of the mask to zero means setting the corresponding activation to zero function.
- grid_eps: float in [0,1]
a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
- device: str
device
- __init__(in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data=True, device='cpu', sparse_init=False)
initialize a KANLayer
Args:
- in_dim : int
input dimension. Default: 3.
- out_dim : int
output dimension. Default: 2.
- num : int
the number of grid intervals = G. Default: 5.
- k : int
the order of piecewise polynomial. Default: 3.
- noise_scale : float
the scale of noise injected at initialization. Default: 0.5.
- scale_base_mu : float
the scale of the residual function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
- scale_base_sigma : float
the scale of the residual function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
- scale_sp : float
the scale of the spline function spline(x).
- base_fun : function
residual function b(x). Default: torch.nn.SiLU()
- grid_eps : float
When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
- grid_range : list/np.array of shape (2,)
the range of grids. Default: [-1,1].
- sp_trainable : bool
If True, scale_sp is trainable
- sb_trainable : bool
If True, scale_base is trainable
- device : str
device
- sparse_init : bool
if sparse_init = True, sparse initialization is applied.
Returns:
self
Example
>>> from kan.KANLayer import *
>>> model = KANLayer(in_dim=3, out_dim=5)
>>> (model.in_dim, model.out_dim)
- forward(x)
KANLayer forward given input x
Args:
- x : 2D torch.float
inputs, shape (number of samples, input dimension)
Returns:
- y : 2D torch.float
outputs, shape (number of samples, output dimension)
- preacts : 3D torch.float
fan out x into activations, shape (number of samples, output dimension, input dimension)
- postacts : 3D torch.float
the outputs of activation functions with preacts as inputs
- postspline : 3D torch.float
the outputs of spline functions with preacts as inputs
Example
>>> from kan.KANLayer import *
>>> model = KANLayer(in_dim=3, out_dim=5)
>>> x = torch.normal(0,1,size=(100,3))
>>> y, preacts, postacts, postspline = model(x)
>>> y.shape, preacts.shape, postacts.shape, postspline.shape
- get_subset(in_id, out_id)
get a smaller KANLayer from a larger KANLayer (used for pruning)
Args:
- in_id : list
ids of selected input neurons
- out_id : list
ids of selected output neurons
Returns:
spb : KANLayer
Example
>>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
>>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3])
>>> kanlayer_small.in_dim, kanlayer_small.out_dim
(2, 3)
- initialize_grid_from_parent(parent, x, mode='sample')
update grid from a parent KANLayer & samples
Args:
- parent : KANLayer
a parent KANLayer (whose grid is usually coarser than the current model)
- x : 2D torch.float
inputs, shape (number of samples, input dimension)
Returns:
None
Example
>>> batch = 100
>>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
>>> print(parent_model.grid.data)
>>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3)
>>> x = torch.normal(0,1,size=(batch, 1))
>>> model.initialize_grid_from_parent(parent_model, x)
>>> print(model.grid.data)
- swap(i1, i2, mode='in')
swap the i1 neuron with the i2 neuron in input (if mode == ‘in’) or output (if mode == ‘out’)
Args:
- i1 : int
- i2 : int
- mode : str
mode = 'in' or 'out'
Returns:
None
Example
>>> from kan.KANLayer import *
>>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
>>> print(model.coef)
>>> model.swap(0,1,mode='in')
>>> print(model.coef)
- to(device)
- update_grid_from_samples(x, mode='sample')
update grid from samples
Args:
- x : 2D torch.float
inputs, shape (number of samples, input dimension)
Returns:
None
Example
>>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
>>> print(model.grid.data)
>>> x = torch.linspace(-3,3,steps=100)[:,None]
>>> model.update_grid_from_samples(x)
>>> print(model.grid.data)
kan.LBFGS module
- class kan.LBFGS.LBFGS(*args: Any, **kwargs: Any)
Bases:
Optimizer
Implements L-BFGS algorithm.
Heavily inspired by minFunc.
Warning
This optimizer doesn’t support per-parameter options and parameter groups (there can be only one).
Warning
Right now all parameters have to be on a single device. This will be improved in the future.
Note
This is a very memory intensive optimizer (it requires an additional param_bytes * (history_size + 1) bytes). If it doesn't fit in memory, try reducing the history size, or use a different algorithm.
Args:
- lr (float): learning rate (default: 1)
- max_iter (int): maximal number of iterations per optimization step (default: 20)
- max_eval (int): maximal number of function evaluations per optimization step (default: max_iter * 1.25)
- tolerance_grad (float): termination tolerance on first order optimality (default: 1e-7)
- tolerance_change (float): termination tolerance on function value/parameter changes (default: 1e-9)
- history_size (int): update history size (default: 100)
- line_search_fn (str): either 'strong_wolfe' or None (default: None)
- __init__(params, lr=1, max_iter=20, max_eval=None, tolerance_grad=1e-07, tolerance_change=1e-09, tolerance_ys=1e-32, history_size=100, line_search_fn=None)
- step(closure)
Perform a single optimization step.
Args:
- closure (Callable): A closure that reevaluates the model and returns the loss.
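Example
A minimal usage sketch (assumed, not from the upstream docstring): kan.LBFGS follows the closure-based interface of torch.optim.LBFGS; MultKAN.fit() builds a similar loop internally.
>>> from kan.LBFGS import LBFGS
>>> import torch
>>> model = torch.nn.Linear(2, 1)  # any torch.nn.Module works here
>>> x, y = torch.rand(100, 2), torch.rand(100, 1)
>>> opt = LBFGS(model.parameters(), lr=1.0, max_iter=20, line_search_fn='strong_wolfe')
>>> def closure():
...     opt.zero_grad()  # clear old gradients
...     loss = torch.mean((model(x) - y)**2)
...     loss.backward()
...     return loss
>>> opt.step(closure)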
kan.Symbolic_KANLayer module
- class kan.Symbolic_KANLayer.Symbolic_KANLayer(*args: Any, **kwargs: Any)
Bases:
Module
Symbolic_KANLayer class
Attributes:
- in_dim : int
input dimension
- out_dim : int
output dimension
- funs : 2D array of torch functions (or lambda functions)
symbolic functions (torch)
- funs_avoid_singularity : 2D array of torch functions (or lambda functions) with singularity avoidance
- funs_name : 2D array of str
names of symbolic functions
- funs_sympy : 2D array of sympy functions (or lambda functions)
symbolic functions (sympy)
- affine : 3D array of floats
affine transformations of inputs and outputs
- __init__(in_dim=3, out_dim=2, device='cpu')
initialize a Symbolic_KANLayer (activation functions are initialized to be identity functions)
Args:
- in_dim : int
input dimension
- out_dim : int
output dimension
- device : str
device
Returns:
self
Example
>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=3)
>>> len(sb.funs), len(sb.funs[0])
- fix_symbolic(i, j, fun_name, x=None, y=None, random=False, a_range=(-10, 10), b_range=(-10, 10), verbose=True)
fix an activation function to be symbolic
Args:
- i : int
the id of the input neuron
- j : int
the id of the output neuron
- fun_name : str
the name of the symbolic function
- x : 1D array
preactivations
- y : 1D array
postactivations
- a_range : tuple
sweeping range of a
- b_range : tuple
sweeping range of b
- verbose : bool
print more information if True
Returns:
r2 (coefficient of determination)
Example 1
>>> # when x & y are not provided, affine parameters are set to a = 1, b = 0, c = 1, d = 0
>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
>>> sb.fix_symbolic(2,1,'sin')
>>> print(sb.funs_name)
>>> print(sb.affine)
Example 2
>>> # when x & y are provided, fit_params() is called to find the best fit coefficients
>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
>>> batch = 100
>>> x = torch.linspace(-1,1,steps=batch)
>>> noises = torch.normal(0,1,(batch,)) * 0.02
>>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
>>> sb.fix_symbolic(2,1,'sin',x,y)
>>> print(sb.funs_name)
>>> print(sb.affine[1,2,:].data)
- forward(x, singularity_avoiding=False, y_th=10.0)
Args:
- x : 2D array
inputs, shape (batch, input dimension)
- singularity_avoiding : bool
if True, funs_avoid_singularity is used; if False, funs is used.
- y_th : float
the singularity threshold
Returns:
- y : 2D array
outputs, shape (batch, output dimension)
- postacts : 3D array
activations after activation functions but before being summed on nodes
Example
>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=5)
>>> x = torch.normal(0,1,size=(100,3))
>>> y, postacts = sb(x)
>>> y.shape, postacts.shape
(torch.Size([100, 5]), torch.Size([100, 5, 3]))
- get_subset(in_id, out_id)
get a smaller Symbolic_KANLayer from a larger Symbolic_KANLayer (used for pruning)
Args:
- in_id : list
ids of selected input neurons
- out_id : list
ids of selected output neurons
Returns:
spb : Symbolic_KANLayer
Example
>>> sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10)
>>> sb_small = sb_large.get_subset([0,9],[1,2,3])
>>> sb_small.in_dim, sb_small.out_dim
- swap(i1, i2, mode='in')
swap the i1 neuron with the i2 neuron in input (if mode == ‘in’) or output (if mode == ‘out’)
- to(device)
move to device
kan.spline module
- kan.spline.B_batch(x, grid, k=0, extend=True, device='cpu')
evaluate x on B-spline bases
Args:
- x : 2D torch.tensor
inputs, shape (number of splines, number of samples)
- grid : 2D torch.tensor
grids, shape (number of splines, number of grid points)
- k : int
the piecewise polynomial order of splines.
- extend : bool
If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True
- device : str
device
Returns:
- spline values3D torch.tensor
shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order.
Example
>>> from kan.spline import B_batch
>>> x = torch.rand(100,2)
>>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
>>> B_batch(x, grid, k=3).shape
- kan.spline.coef2curve(x_eval, grid, coef, k, device='cpu')
convert B-spline coefficients to B-spline curves: evaluate x on B-spline curves (summing up B_batch results over the B-spline basis).
Args:
- x_eval : 2D torch.tensor
shape (batch, in_dim)
- grid : 2D torch.tensor
shape (in_dim, G+2k). G: the number of grid intervals; k: spline order.
- coef : 3D torch.tensor
shape (in_dim, out_dim, G+k)
- k : int
the piecewise polynomial order of splines.
- device : str
device
Returns:
- y_eval : 3D torch.tensor
shape (number of samples, in_dim, out_dim)
- kan.spline.curve2coef(x_eval, y_eval, grid, k, lamb=1e-08)
convert B-spline curves to B-spline coefficients using least squares.
Args:
- x_eval : 2D torch.tensor
shape (number of samples, in_dim)
- y_eval : 3D torch.tensor
shape (number of samples, in_dim, out_dim)
- grid : 2D torch.tensor
shape (in_dim, G+2k)
- k : int
spline order
- lamb : float
regularized least squares lambda
Returns:
- coef : 3D torch.tensor
shape (in_dim, out_dim, G+k)
- kan.spline.extend_grid(grid, k_extend=0)
extend the grid by k_extend points on both ends
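Example
A minimal sketch (assumed usage): a grid of 6 points extended by k_extend=3 on each side yields 6 + 2*3 = 12 points.
>>> from kan.spline import extend_grid
>>> import torch
>>> grid = torch.linspace(-1,1,steps=6)[None,:]  # one spline, 5 grid intervals
>>> extend_grid(grid, k_extend=3).shape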
kan.utils module
- kan.utils.add_symbolic(name, fun, c=1, fun_singularity=None)
add a symbolic function to library
Args:
- name : str
name of the function
- fun : fun
torch function or lambda function
Returns:
None
Example
>>> print(SYMBOLIC_LIB['Bessel'])
KeyError: 'Bessel'
>>> add_symbolic('Bessel', torch.special.bessel_j0)
>>> print(SYMBOLIC_LIB['Bessel'])
(<built-in function special_bessel_j0>, Bessel)
- kan.utils.augment_input(orig_vars, aux_vars, x)
augment inputs
Args:
- orig_vars : list of sympy symbols
- aux_vars : list of auxiliary symbols
- x : inputs
Returns:
augmented inputs
Example
>>> from kan.utils import *
>>> from sympy import *
>>> orig_vars = a, b = symbols('a b')
>>> aux_vars = [a + b, a * b]
>>> x = torch.rand(100, 2)
>>> augment_input(orig_vars, aux_vars, x).shape
- kan.utils.batch_hessian(model, x, create_graph=False)
compute the hessian
Args:
- model : function or model
- x : inputs
- create_graph : bool
Returns:
hessian
Example
>>> from kan.utils import batch_hessian
>>> x = torch.normal(0,1,size=(100,2))
>>> model = lambda x: x[:,[0]]**2 + x[:,[1]]**2
>>> batch_hessian(model, x)
- kan.utils.batch_jacobian(func, x, create_graph=False, mode='scalar')
jacobian
Args:
- func : function or model
- x : inputs
- create_graph : bool
Returns:
jacobian
Example
>>> from kan.utils import batch_jacobian
>>> x = torch.normal(0,1,size=(100,2))
>>> model = lambda x: x[:,[0]] + x[:,[1]]
>>> batch_jacobian(model, x)
- kan.utils.create_dataset(f, n_var=2, f_mode='col', ranges=[-1, 1], train_num=1000, test_num=1000, normalize_input=False, normalize_label=False, device='cpu', seed=0)
create dataset
Args:
- f : function
the symbolic formula used to create the synthetic dataset
- n_var : int
the number of input variables. Default: 2.
- ranges : list or np.array; shape (2,) or (n_var, 2)
the range of input variables. Default: [-1,1].
- train_num : int
the number of training samples. Default: 1000.
- test_num : int
the number of test samples. Default: 1000.
- normalize_input : bool
If True, apply normalization to inputs. Default: False.
- normalize_label : bool
If True, apply normalization to labels. Default: False.
- device : str
device. Default: 'cpu'.
- seed : int
random seed. Default: 0.
Returns:
- dataset : dic
Train/test inputs/labels are dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
Example
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2, train_num=100)
>>> dataset['train_input'].shape
torch.Size([100, 2])
- kan.utils.create_dataset_from_data(inputs, labels, train_ratio=0.8, device='cpu')
create dataset from data
Args:
- inputs : 2D torch.float
- labels : 2D torch.float
- train_ratio : float
the ratio of the training fraction
device : str
Returns:
dataset (dictionary)
Example
>>> from kan.utils import create_dataset_from_data
>>> x = torch.normal(0,1,size=(100,2))
>>> y = torch.normal(0,1,size=(100,1))
>>> dataset = create_dataset_from_data(x, y)
>>> dataset['train_input'].shape
- kan.utils.ex_round(ex1, n_digit)
round the numbers in an expression to a given number of decimal places
Args:
- ex1 : sympy expression
- n_digit : int
Returns:
ex2 : sympy expression
Example
>>> from kan.utils import *
>>> from sympy import *
>>> input_vars = a, b = symbols('a b')
>>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
>>> ex_round(expression, 2)
The following f_* helpers are singularity-avoiding variants of elementary functions; each takes inputs x and a singularity threshold y_th, and is used in place of the plain function when singularity_avoiding=True.
- kan.utils.f_arccos(x, y_th)
- kan.utils.f_arcsin(x, y_th)
- kan.utils.f_arctanh(x, y_th)
- kan.utils.f_exp(x, y_th)
- kan.utils.f_inv(x, y_th)
- kan.utils.f_inv2(x, y_th)
- kan.utils.f_inv3(x, y_th)
- kan.utils.f_inv4(x, y_th)
- kan.utils.f_inv5(x, y_th)
- kan.utils.f_invsqrt(x, y_th)
- kan.utils.f_log(x, y_th)
- kan.utils.f_power1d5(x, y_th)
- kan.utils.f_sqrt(x, y_th)
- kan.utils.f_tan(x, y_th)
- kan.utils.fit_params(x, y, fun, a_range=(-10, 10), b_range=(-10, 10), grid_number=101, iteration=3, verbose=True, device='cpu')
fit a, b, c, d such that
\[|y - (c f(ax+b) + d)|^2\]
is minimized. Both x and y are 1D arrays. Sweep a and b, and find the best fitted model.
Args:
- x : 1D array
x values
- y : 1D array
y values
- fun : function
symbolic function
- a_range : tuple
sweeping range of a
- b_range : tuple
sweeping range of b
- grid_number : int
number of steps along a and b
- iteration : int
number of zooming in
- verbose : bool
print extra information if True
- device : str
device
Returns:
- a_best : float
best fitted a
- b_best : float
best fitted b
- c_best : float
best fitted c
- d_best : float
best fitted d
- r2_best : float
best r2 (coefficient of determination)
Example
>>> num = 100
>>> x = torch.linspace(-1,1,steps=num)
>>> noises = torch.normal(0,1,(num,)) * 0.02
>>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
>>> fit_params(x, y, torch.sin)
r2 is 0.9999727010726929
(tensor([2.9982, 1.9996, 5.0053, 0.7011]), tensor(1.0000))
- kan.utils.get_derivative(model, inputs, labels, derivative='hessian', loss_mode='pred', reg_metric='w', lamb=0.0, lamb_l1=1.0, lamb_entropy=0.0)
compute the jacobian/hessian of the loss with respect to model parameters
Args:
- inputs : 2D torch.float
- labels : 2D torch.float
- derivative : str
'jacobian' or 'hessian'
- device : str
Returns:
jacobian or hessian
- kan.utils.model2param(model)
turn model parameters into a flattened vector
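Example
A minimal sketch (assumed usage): flatten all trainable parameters of a model into one vector.
>>> from kan import *
>>> from kan.utils import model2param
>>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
>>> model2param(model).shape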
- kan.utils.sparse_mask(in_dim, out_dim)
get sparse mask
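Example
A minimal sketch (assumed usage): a sparse connectivity mask between 5 input and 3 output neurons.
>>> from kan.utils import sparse_mask
>>> sparse_mask(5, 3)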
kan.compiler module
- kan.compiler.expr2kan(input_variables, expr, grid=5, k=3, auto_save=False)
compile a symbolic formula to a MultKAN
Args:
- input_variables : a list of sympy symbols
- expr : sympy expression
- grid : int
the number of grid intervals
- k : int
spline order
- auto_save : bool
if auto_save = True, models are automatically saved
Returns:
MultKAN
Example
>>> from kan.compiler import *
>>> from sympy import *
>>> input_vars = a, b = symbols('a b')
>>> expression = exp(sin(pi*a) + b**2)
>>> model = kanpiler(input_vars, expression)
>>> x = torch.rand(100,2) * 2 - 1
>>> model(x)
>>> model.plot()
- kan.compiler.kanpiler(input_variables, expr, grid=5, k=3, auto_save=False)
compile a symbolic formula to a MultKAN
Args:
- input_variables : a list of sympy symbols
- expr : sympy expression
- grid : int
the number of grid intervals
- k : int
spline order
- auto_save : bool
if auto_save = True, models are automatically saved
Returns:
MultKAN
Example
>>> from kan.compiler import *
>>> from sympy import *
>>> input_vars = a, b = symbols('a b')
>>> expression = exp(sin(pi*a) + b**2)
>>> model = kanpiler(input_vars, expression)
>>> x = torch.rand(100,2) * 2 - 1
>>> model(x)
>>> model.plot()
- kan.compiler.next_nontrivial_operation(expr, scale=1, bias=0)
remove the affine part of an expression
Args:
- expr : sympy expression
- scale : float
- bias : float
Returns:
- expr : sympy expression
- scale : float
- bias : float
Example
>>> from kan.compiler import *
>>> from sympy import *
>>> input_vars = a, b = symbols('a b')
>>> expression = 3.14534242 * exp(sin(pi*a) + b**2) - 2.32345402
>>> next_nontrivial_operation(expression)
- kan.compiler.sf2kan(input_variables, expr, grid=5, k=3, auto_save=False)
compile a symbolic formula to a MultKAN
Args:
- input_variables : a list of sympy symbols
- expr : sympy expression
- grid : int
the number of grid intervals
- k : int
spline order
- auto_save : bool
if auto_save = True, models are automatically saved
Returns:
MultKAN
Example
>>> from kan.compiler import *
>>> from sympy import *
>>> input_vars = a, b = symbols('a b')
>>> expression = exp(sin(pi*a) + b**2)
>>> model = kanpiler(input_vars, expression)
>>> x = torch.rand(100,2) * 2 - 1
>>> model(x)
>>> model.plot()
kan.hypothesis module
- kan.hypothesis.batch_grad_normgrad(model, x, group, create_graph=False)
- kan.hypothesis.detect_separability(model, x, mode='add', score_th=0.01, res_th=0.01, n_clusters=None, bias=0.0, verbose=False)
detect function separability
Args:
- model : MultKAN, MLP or python function
- x : 2D torch.float
inputs
- mode : str
mode = 'add' or mode = 'mul'
- score_th : float
threshold of score
- res_th : float
threshold of residue
- n_clusters : None or int
the number of clusters
- bias : float
bias (for multiplicative separability)
- verbose : bool
Returns:
results (dictionary)
Example 1
>>> from kan.hypothesis import *
>>> model = lambda x: x[:,[0]] ** 2 + torch.exp(x[:,[1]]+x[:,[2]])
>>> x = torch.normal(0,1,size=(100,3))
>>> detect_separability(model, x, mode='add')
Example 2
>>> from kan.hypothesis import *
>>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
>>> x = torch.normal(0,1,size=(100,3))
>>> detect_separability(model, x, mode='mul')
- kan.hypothesis.get_dependence(model, x, group)
- kan.hypothesis.get_molecule(model, x, sym_th=0.001, verbose=True)
detect how variables are combined hierarchically
Args:
- model : MultKAN, MLP or python function
- x : 2D torch.float
inputs
- sym_th : float
threshold of symmetry
- verbose : bool
Returns:
list
Example
>>> from kan.hypothesis import *
>>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
>>> x = torch.normal(0,1,size=(100,8))
>>> get_molecule(model, x, verbose=False)
[[[0], [1], [2], [3], [4], [5], [6], [7]], [[0, 1], [2, 3], [4, 5], [6, 7]], [[0, 1, 2, 3], [4, 5, 6, 7]], [[0, 1, 2, 3, 4, 5, 6, 7]]]
- kan.hypothesis.get_tree_node(model, x, moleculess, sep_th=0.01, skip_test=True)
get tree nodes
Args:
- model : MultKAN, MLP or python function
- x : 2D torch.float
inputs
- sep_th : float
threshold of separability
- skip_test : bool
if True, don't test the property of each module (to save time)
Returns:
- arities : list of numbers
- properties : list of strings
Example
>>> from kan.hypothesis import *
>>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
>>> x = torch.normal(0,1,size=(100,8))
>>> moleculess = get_molecule(model, x, verbose=False)
>>> get_tree_node(model, x, moleculess, skip_test=False)
- kan.hypothesis.plot_tree(model, x, in_var=None, style='tree', sym_th=0.001, sep_th=0.1, skip_sep_test=False, verbose=False)
get tree graph
Args:
- model : MultKAN, MLP or python function
- x : 2D torch.float
inputs
- in_var : list of symbols
input variables
- style : str
'tree' or 'box'
- sym_th : float
threshold of symmetry
- sep_th : float
threshold of separability
- skip_sep_test : bool
if True, don't test the property of each module (to save time)
- verbose : bool
Returns:
a tree graph
Example
>>> from kan.hypothesis import *
>>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
>>> x = torch.normal(0,1,size=(100,8))
>>> plot_tree(model, x)
- kan.hypothesis.test_general_separability(model, x, groups, threshold=0.01)
test general separability
Args:
- model : MultKAN, MLP or python function
- x : 2D torch.float
inputs
- groups : list of lists of indices
how input variables are grouped
- threshold : float
threshold of separability
Returns:
bool
Example
>>> from kan.hypothesis import *
>>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]**2+x[:,[2]]**2)**2
>>> x = torch.normal(0,1,size=(100,3))
>>> print(test_general_separability(model, x, [[1],[0,2]])) # False
>>> print(test_general_separability(model, x, [[0],[1,2]])) # True
- kan.hypothesis.test_separability(model, x, groups, mode='add', threshold=0.01, bias=0)
test function separability
Args:
model : MultKAN, MLP or python function x : 2D torch.float
inputs
- modestr
mode = ‘add’ or mode = ‘mul’
- score_thfloat
threshold of score
- res_thfloat
threshold of residue
- biasfloat
bias (for multiplicative separability)
verbose : bool
Returns:
bool
Example
>>> from kan.hypothesis import *
>>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
>>> x = torch.normal(0,1,size=(100,3))
>>> print(test_separability(model, x, [[0],[1,2]], mode='mul')) # True
>>> print(test_separability(model, x, [[0],[1,2]], mode='add')) # False
- kan.hypothesis.test_symmetry(model, x, group, dependence_th=0.001)
test symmetry
Args:
- model : MultKAN, MLP or python function
- x : 2D torch.float
inputs
- group : a list of indices
- dependence_th : float
threshold of dependence
Returns:
bool
Example
>>> from kan.hypothesis import *
>>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
>>> x = torch.normal(0,1,size=(100,3))
>>> print(test_symmetry(model, x, [1,2])) # True
>>> print(test_symmetry(model, x, [0,2])) # False
- kan.hypothesis.test_symmetry_var(model, x, input_vars, symmetry_var)
test symmetry
Args:
- model : MultKAN, MLP or python function
- x : 2D torch.float
inputs
- input_vars : list of sympy symbols
- symmetry_var : sympy expression
Returns:
cosine similarity
Example
>>> from kan.hypothesis import *
>>> from sympy import *
>>> model = lambda x: x[:,[0]] * (x[:,[1]] + x[:,[2]])
>>> x = torch.normal(0,1,size=(100,8))
>>> input_vars = a, b, c = symbols('a b c')
>>> symmetry_var = b + c
>>> test_symmetry_var(model, x, input_vars, symmetry_var);
>>> symmetry_var = b * c
>>> test_symmetry_var(model, x, input_vars, symmetry_var);