Dissecting Kolmogorov-Arnold Network

Fei Cheung
11 min read · May 19, 2024


Recently, the Kolmogorov-Arnold Network (KAN) has garnered significant interest within the machine learning community, presenting a promising alternative to the traditional Multi-Layer Perceptron (MLP). Given that MLPs are fundamental components of many models, including Transformers, which power the latest large language model (LLM) applications, this development could lead to substantial changes in the field.

Believing that details matter, I spent some time exploring and dissecting the KAN architecture. Through this deep dive, I learned something new and gained insights into the excitement surrounding this innovation. In this blog post, I aim to share my findings and personal thoughts on KAN, highlighting what I find intriguing and where I see potential challenges.

TL;DR: The paper builds KAN on the Kolmogorov–Arnold representation theorem (KART). After dissecting KAN, my conclusion is that an MLP is a specific form of KAN.

What is the Kolmogorov–Arnold representation theorem?

The theoretical background of the Kolmogorov-Arnold Network (KAN) originates from the Kolmogorov–Arnold representation theorem. According to Wikipedia,

Kolmogorov–Arnold representation theorem (or superposition theorem) states that every multivariate continuous function $f\colon [0,1]^n \to \mathbb{R}$ can be represented as a superposition of the two-argument addition of continuous functions of one variable.

The theorem is named after two distinguished mathematicians, Vladimir Arnold and Andrey Kolmogorov.

Specifically:

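$$f(x_1, \dots, x_n) = \sum_{q=0}^{2n} \Phi_q\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right)$$

where the $\phi_{q,p}\colon [0,1] \to \mathbb{R}$ and $\Phi_q\colon \mathbb{R} \to \mathbb{R}$ are continuous functions of one variable.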

In simplified terms, any multivariate continuous function f can be decomposed into compositions and sums of single-variable functions.

This is a pretty strong claim. To see it in action, let’s look at an example:

[Figure: a simple example]
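The figure’s example is not reproduced in this text, so here is a hand-picked illustration of my own (not necessarily the one in the figure): the product xy = ((x + y)^2 - (x - y)^2) / 4 uses only single-variable functions and addition. The inner functions are x ↦ x, y ↦ y and y ↦ -y; the outer functions are s ↦ s^2/4 and s ↦ -s^2/4. A quick numerical check:

```python
import numpy as np

# Inner univariate functions: inner[q] = (phi_q1, phi_q2), applied to x and y.
inner = [
    (lambda x: x, lambda y: y),    # outer term 0 receives x + y
    (lambda x: x, lambda y: -y),   # outer term 1 receives x - y
]
# Outer univariate functions Phi_q.
outer = [
    lambda s: s ** 2 / 4,
    lambda s: -(s ** 2) / 4,
]

def f_decomposed(x, y):
    """sum_q Phi_q(phi_q1(x) + phi_q2(y)): only univariate functions and addition."""
    return sum(Phi(phi_x(x) + phi_y(y)) for Phi, (phi_x, phi_y) in zip(outer, inner))

rng = np.random.default_rng(0)
x, y = rng.normal(size=1000), rng.normal(size=1000)
print(np.allclose(f_decomposed(x, y), x * y))  # True
```

KART guarantees that 2n + 1 = 5 outer terms always suffice for n = 2; this particular f happens to need only two.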

The summation form of the equation may appear intimidating. Let’s plug in n = 2 (which gives 2n + 1 = 5 outer functions Φ_q) and illustrate the computation as a graph, which I find more comfortable to read. It is essentially the same graph as in the paper.

[Figure: computational graph for KAN]
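The n = 2 graph can also be written out directly: each of the 2n + 1 = 5 hidden nodes q sums two inner functions φ_{q,1}(x_1) and φ_{q,2}(x_2), and the output sums the five outer functions Φ_q. The sin and tanh functions below are arbitrary placeholders chosen only to make the sketch runnable; in KART they would be specific functions determined by f.

```python
import numpy as np

n = 2
num_hidden = 2 * n + 1  # 5 hidden nodes, as in the figure

# Placeholder univariate functions (assumption, for illustration only).
# inner[q][p] is phi_{q,p}; outer[q] is Phi_q. Default args bind q and p at creation.
inner = [[(lambda t, a=q, b=p: np.sin((a + 1) * t + b)) for p in range(n)]
         for q in range(num_hidden)]
outer = [(lambda s, a=q: (a + 1) * np.tanh(s)) for q in range(num_hidden)]

def kart_graph(x):
    """x: array of shape (n,). Returns sum_q Phi_q(sum_p phi_{q,p}(x_p))."""
    hidden = [sum(inner[q][p](x[p]) for p in range(n)) for q in range(num_hidden)]
    return sum(outer[q](hidden[q]) for q in range(num_hidden))

print(sum(len(row) for row in inner), len(outer))  # 10 5
print(kart_graph(np.array([0.3, -0.7])))
```

Read left to right, this is a [2, 5, 1] network: 10 edge functions in the first layer and 5 in the second.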

To summarize, we only need to find (learn) single-variable functions of each input dimension to fit the desired multivariate function. Looking at the graph and the theorem, it starts to resemble something very familiar: the MLP.

Now let’s look at the traditional MLP.

[Figure: an MLP with a single hidden layer of 5 neurons, for comparison with KART]

The two computational graphs look very similar, which is what motivates the authors’ proposal:

We can construct a new type of neural network layer with KART.

The red squares mark the MLP and KAN layers. Both layers have a [2, 5] configuration, and the whole toy network has a [2, 5, 1] configuration.

The example in the graph shows a KAN and MLP layer with 2 inputs and 5 outputs. For MLP, the learnable parameters are the weight matrix W and bias b, and we can choose our favorite activation function (e.g., ReLU).

For KAN, our learnable parameters will reside within the function ϕ, for which we have not yet chosen a specific form. The equation for KAN is actually quite cumbersome with numerous summations. I find the graph representation more intuitive, and interested readers can refer to the original paper for the full equations.

KAN vs MLP

Looking at the KAN and MLP computations, aren’t they just neural networks with a different kind of non-linearity? What’s the big deal?

KAN allows us to escape the curse of dimensionality in function fitting.

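Let’s examine the equations for MLP and KAN layers. An MLP layer applies a weight matrix W, a bias b, and a fixed activation σ:

$$\mathbf{x}_{l+1} = \sigma(W_l \mathbf{x}_l + \mathbf{b}_l)$$

A KAN layer instead places a learnable univariate function on every edge and only sums at the nodes:

$$x_{l+1,j} = \sum_{i=1}^{n_l} \phi_{l,j,i}(x_{l,i})$$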

In an MLP, the fixed activation function is applied to a weighted combination of all inputs, whereas in a KAN each learnable function takes only a single input. This distinction lets us use well-established one-dimensional function-fitting techniques without worrying about the curse of dimensionality. The remaining question is: what function should we choose for KAN?
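To make the contrast concrete, here is a minimal numpy sketch of one [2, 5] layer of each kind. The per-edge cubic polynomials in the KAN layer are my own placeholder (the paper uses B-splines); the point is only where the learnable parameters live and what each output depends on.

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim, out_dim, batch = 2, 5, 4
x = rng.normal(size=(batch, in_dim))

# MLP layer: a fixed non-linearity (ReLU) applied to a weighted sum of ALL inputs.
W = rng.normal(size=(out_dim, in_dim))
b = rng.normal(size=out_dim)
mlp_out = np.maximum(x @ W.T + b, 0.0)                            # (batch, out_dim)

# KAN layer: one learnable univariate function per edge (j, i), then a plain sum.
# Assumption for illustration: phi_{j,i} is a cubic polynomial with its own coefficients.
degree = 3
coef = rng.normal(size=(out_dim, in_dim, degree + 1))             # coefficients of phi_{j,i}
powers = np.stack([x ** p for p in range(degree + 1)], axis=-1)   # (batch, in_dim, degree + 1)
phi = np.einsum('bip,jip->bji', powers, coef)                     # phi_{j,i}(x_i): (batch, out_dim, in_dim)
kan_out = phi.sum(axis=2)                                         # (batch, out_dim)

print(mlp_out.shape, kan_out.shape)  # (4, 5) (4, 5)
```

The KAN layer has no separate weight matrix: any scaling is absorbed into the edge functions φ_{j,i}.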

Spline

In the paper, the authors choose splines for these functions. Splines are a well-studied topic in computer graphics. In fact, the font you are reading right now is constructed from splines.

[Figure source: https://raphlinus.github.io/curves/2018/12/21/new-spline.html]

Splines consist of multiple segments, each parameterized by a polynomial. Let’s use cubic splines as an example.

A cubic polynomial, p(t) = a + bt + ct^2 + dt^3, has four coefficients, so each segment has four unknowns to solve for.

One way to do this is to use four fixed points to fit the polynomial for one segment. Then, we can stitch multiple segments together to form the complete spline.
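Fitting one segment through four points is a 4×4 linear system. A minimal numpy sketch (the four points are made up for illustration):

```python
import numpy as np

# Four points the segment must pass through (hypothetical values).
ts = np.array([0.0, 1.0, 2.0, 3.0])
ys = np.array([1.0, 3.0, 2.0, 4.0])

# p(t) = a + b t + c t^2 + d t^3: one linear equation per point, four unknowns.
V = np.vander(ts, N=4, increasing=True)   # each row is [1, t, t^2, t^3]
a, b, c, d = np.linalg.solve(V, ys)

def p(t):
    return a + b * t + c * t ** 2 + d * t ** 3

print(np.allclose(p(ts), ys))  # True
```

Stitching segments fitted this way keeps the curve continuous where neighboring segments share an endpoint, but nothing forces their slopes to agree there.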

We are not limited to parameterizing a segment by the points it passes through. We can also specify derivatives of different orders to achieve specific properties. For example, we can make the entire spline’s first derivative continuous by requiring matching derivatives at the knots (the points where segments join). Splines are a significant topic in computer graphics, games, and even physics. I highly recommend watching this excellent YouTube video for a more in-depth explanation of splines.
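One common way to get first-derivative continuity is cubic Hermite interpolation: each segment is pinned down by its two endpoint values and two endpoint slopes (again four unknowns), and neighboring segments share the value and slope at their common knot. The knots, values, and slopes below are made up; the finite-difference check confirms that the slopes on both sides of the interior knot agree.

```python
def hermite_segment(t, t0, t1, p0, p1, m0, m1):
    """Cubic on [t0, t1] with values p0, p1 and slopes m0, m1 at its endpoints."""
    h = t1 - t0
    s = (t - t0) / h                      # local parameter in [0, 1]
    h00 = 2 * s**3 - 3 * s**2 + 1
    h10 = s**3 - 2 * s**2 + s
    h01 = -2 * s**3 + 3 * s**2
    h11 = s**3 - s**2
    return h00 * p0 + h10 * h * m0 + h01 * p1 + h11 * h * m1

# Knots, values and slopes (hypothetical). Segment i covers [knots[i], knots[i + 1]].
knots = [0.0, 1.0, 2.0]
values = [0.0, 2.0, 1.0]
slopes = [1.0, 0.5, -1.0]

def spline(t):
    i = 0 if t < knots[1] else 1
    return hermite_segment(t, knots[i], knots[i + 1],
                           values[i], values[i + 1], slopes[i], slopes[i + 1])

eps = 1e-6
left_slope = (spline(1.0) - spline(1.0 - eps)) / eps    # derivative from segment 0
right_slope = (spline(1.0 + eps) - spline(1.0)) / eps   # derivative from segment 1
print(round(left_slope, 4), round(right_slope, 4))      # 0.5 0.5
```

The second derivative is not matched, though: with these numbers it jumps from -8 to -6 at t = 1.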

Returning to the function for KAN: if we choose cubic splines, the learnable parameters are the polynomial coefficients, the control points, or the derivatives, depending on how we parameterize the spline. We also control the location of the knots.

B-Spline

The paper selects B-splines for the learnable function. B-splines have several interesting properties:

  1. Local Control: each coefficient only influences the curve over the support of its basis function, which spans k + 1 neighboring knot intervals for a degree-k B-spline. By contrast, in some other spline formulations (for example, an interpolating natural cubic spline), moving one point changes the entire curve.
  2. Continuity: with non-repeated knots, a degree-k B-spline is $C^{k-1}$ continuous, i.e. its continuity is two lower than its order k + 1. For example, a B-spline of order 4 (degree 3 polynomials) has continuous first and second derivatives.
  3. Non-Interpolative Control Points: B-splines do not necessarily pass through their control points. This does not hurt our model, since the control points are simply the learnable coefficients.
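Local control can be checked numerically. The sketch below is my own minimal Cox-de Boor implementation in numpy (not pykan’s code); it bumps a single coefficient of a cubic B-spline on 10 uniform intervals and reports where the curve changed.

```python
import numpy as np

def bspline_basis(i, k, t, x):
    """Cox-de Boor recursion: i-th degree-k B-spline basis on knot vector t, evaluated at x."""
    if k == 0:
        return ((t[i] <= x) & (x < t[i + 1])).astype(float)
    return ((x - t[i]) / (t[i + k] - t[i]) * bspline_basis(i, k - 1, t, x)
            + (t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1]) * bspline_basis(i + 1, k - 1, t, x))

def bspline_curve(coef, k, t, x):
    """S(x) = sum_i coef[i] * B_{i,k}(x)."""
    return sum(c * bspline_basis(i, k, t, x) for i, c in enumerate(coef))

k = 3                                     # cubic
t = np.linspace(-0.3, 1.3, 17)            # 10 intervals on [0, 1], extended by k knots per side
x = np.linspace(0.0, 1.0, 1001)
rng = np.random.default_rng(0)
coef = rng.normal(size=len(t) - 1 - k)    # 13 coefficients

bumped = coef.copy()
bumped[6] += 1.0                          # change a single coefficient
diff = bspline_curve(bumped, k, t, x) - bspline_curve(coef, k, t, x)

changed = x[diff != 0]
print(f"curve changed only on [{changed.min():.2f}, {changed.max():.2f}]")
print(f"support of basis 6: [{t[6]:.2f}, {t[10]:.2f}]")
```

Only k + 1 = 4 of the 10 intervals are affected; everywhere else the difference is exactly zero.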

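Now let’s look at the B-spline definition. From Wikipedia, a B-spline of degree k with knots t_i and coefficients c_i is

$$S(t) = \sum_i c_i B_{i,k}(t),$$

where the basis functions are defined recursively (the Cox–de Boor recursion):

$$B_{i,0}(t) = \begin{cases} 1 & t_i \le t < t_{i+1} \\ 0 & \text{otherwise} \end{cases}, \qquad B_{i,k}(t) = \frac{t - t_i}{t_{i+k} - t_i} B_{i,k-1}(t) + \frac{t_{i+k+1} - t}{t_{i+k+1} - t_{i+1}} B_{i+1,k-1}(t)$$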

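In general, a B-spline curve is controlled by the coefficients c_i, the polynomial degree k, and the knot locations t_i. The recursion may seem intimidating, so let’s examine a specific example to gain some intuition. For instance, when k = 1:

$$B_{i,1}(t) = \frac{t - t_i}{t_{i+1} - t_i} B_{i,0}(t) + \frac{t_{i+2} - t}{t_{i+2} - t_{i+1}} B_{i+1,0}(t)$$

Substituting t_i, t_{i+1}, and t_{i+2} for t gives 0, 1, and 0, and the function is a straight line on each of the two intervals in between: a “hat” function.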
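To check that the degree-1 basis is a straight line on each interval, the helper below (a hypothetical name of my own) evaluates both terms of the k = 1 recursion directly, on deliberately non-uniform knots:

```python
import numpy as np

def B_i1(t, ti, ti1, ti2):
    """Degree-1 B-spline basis on knots ti < ti1 < ti2."""
    rising = (t - ti) / (ti1 - ti) * ((ti <= t) & (t < ti1))      # B_{i,0} is 1 on [ti, ti1)
    falling = (ti2 - t) / (ti2 - ti1) * ((ti1 <= t) & (t < ti2))  # B_{i+1,0} is 1 on [ti1, ti2)
    return rising + falling

knots = (0.0, 1.0, 3.0)   # non-uniform spacing on purpose
print([B_i1(t, *knots) for t in (0.0, 0.5, 1.0, 2.0, 3.0)])
# [0.0, 0.5, 1.0, 0.5, 0.0]
```

The function also accepts numpy arrays for t, which is handy for plotting.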

Observe that the basis function is defined recursively and is non-zero only on k + 1 consecutive intervals (spanned by k + 2 knots). With k = 2, for example, each basis function is a piecewise quadratic spanning 3 intervals and 4 knots. The above illustration shows only one basis function; the B-spline curve is defined across the entire domain. Let’s plot some graphs and visualize all basis functions to better understand their behavior.

# Assumes the pykan release this post was written against, where B_batch takes
# x of shape (num_splines, num_samples); later pykan releases may have changed this.
import torch
import matplotlib.pyplot as plt
from kan.spline import B_batch

# Example usage
num_spline = 1  # Number of splines
num_samples = 1000  # Number of samples
num_grid_interval = 7  # Number of grid intervals (num_grid_interval + 1 grid points)
k = 1  # Piecewise polynomial degree

# Generate input data
x = torch.linspace(0, 1, steps=num_samples).unsqueeze(0).repeat(num_spline, 1)
grid = torch.linspace(0, 1, steps=num_grid_interval + 1).unsqueeze(0).repeat(num_spline, 1)

# Evaluate B-spline bases: shape (num_spline, num_grid_interval + k, num_samples)
spline_values = B_batch(x, grid, k=k, extend=True)

# Plot the B-spline basis functions
plt.figure(figsize=(10, 6))
for i in range(spline_values.shape[1]):
    plt.plot(x[0], spline_values[0, i], label=f'Basis {i}')

plt.title('B-spline Basis Functions')
plt.xlabel('x')
plt.ylabel('Basis Value')
plt.legend()
plt.grid(True)
plt.show()
[Figure: B-spline basis functions for k = 1, grid = 7]

These functions form a basis: every spline on this grid is a linear combination of them. An interesting property is that they sum to 1 at every point of the grid range (the partition of unity), which holds for every polynomial degree.
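The sum-to-one property can be checked for several degrees with a small Cox-de Boor implementation of my own (numpy, not pykan). The grid is extended by k knots on each side, mirroring what `extend=True` does in the plot above:

```python
import numpy as np

def bspline_basis(i, k, t, x):
    """Cox-de Boor recursion for the i-th degree-k basis function on knots t."""
    if k == 0:
        return ((t[i] <= x) & (x < t[i + 1])).astype(float)
    return ((x - t[i]) / (t[i + k] - t[i]) * bspline_basis(i, k - 1, t, x)
            + (t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1]) * bspline_basis(i + 1, k - 1, t, x))

num_intervals = 7                                  # same grid size as the plot above
h = 1.0 / num_intervals
x = np.linspace(0.0, 1.0, 500, endpoint=False)     # sample points inside [0, 1)
results = []
for k in range(4):
    # extend the uniform grid by k knots on each side
    t = np.linspace(-k * h, 1 + k * h, num_intervals + 1 + 2 * k)
    n_basis = len(t) - 1 - k                       # = num_intervals + k
    total = sum(bspline_basis(i, k, t, x) for i in range(n_basis))
    results.append((k, n_basis, np.allclose(total, 1.0)))
    print(k, n_basis, np.allclose(total, 1.0))     # degree, number of bases, sums to 1?
```

The number of basis functions, num_intervals + k, is also why the coefficient vector in the next snippet has 8 entries for 7 intervals and k = 1.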

Now, let’s perform a linear combination of the basis functions to form a spline.

import torch
import matplotlib.pyplot as plt
from kan.spline import coef2curve  # same pykan release as above

# Example usage
num_spline = 1
num_sample = 100
num_grid_interval = 7
k = 1

# Generate evenly spaced input data
x_eval = torch.linspace(0, 1, steps=num_sample).unsqueeze(0).repeat(num_spline, 1)
grid = torch.linspace(0, 1, steps=num_grid_interval + 1).unsqueeze(0).repeat(num_spline, 1)
# One coefficient per basis function: num_grid_interval + k of them
# coef = torch.randn(num_spline, num_grid_interval + k)
coef = torch.tensor([[1., 5., 3., 2., 1., 4., 4., 7.]])

# Evaluate B-spline curves: shape (num_spline, num_sample)
y_eval = coef2curve(x_eval, grid, coef, k=k)

# Plot the B-spline curves
plt.figure(figsize=(8, 6))
for i in range(num_spline):
    plt.plot(x_eval[i], y_eval[i], label=f"Spline {i+1}")

plt.title("B-spline Curves")
plt.xlabel("X")
plt.ylabel("Y")
plt.legend()
plt.grid(True)
plt.show()

The parameters here are the coefficients and the grid points. Using a higher polynomial degree k gives a smoother curve.
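The effect of the degree on smoothness shows up at the knots. This sketch (my own numpy Cox-de Boor code, with made-up coefficients) measures the jump in slope at the interior knot x = 0.5: the degree-1 spline has a kink there, while the cubic one has a continuous derivative.

```python
import numpy as np

def bspline_basis(i, k, t, x):
    """Cox-de Boor recursion for the i-th degree-k basis function on knots t."""
    if k == 0:
        return ((t[i] <= x) & (x < t[i + 1])).astype(float)
    return ((x - t[i]) / (t[i + k] - t[i]) * bspline_basis(i, k - 1, t, x)
            + (t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1]) * bspline_basis(i + 1, k - 1, t, x))

def spline(x, coef, k, num_intervals=4):
    """B-spline on [0, 1] with an extended uniform grid (num_intervals + k coefficients)."""
    x = np.asarray(x, dtype=float)
    h = 1.0 / num_intervals
    t = np.linspace(-k * h, 1 + k * h, num_intervals + 1 + 2 * k)
    return sum(c * bspline_basis(i, k, t, x) for i, c in enumerate(coef))

# Hypothetical coefficients: 4 + k of them for a 4-interval grid.
coefs = {1: [0.0, 1.0, 3.0, 0.0, 2.0], 3: [0.0, 1.0, 3.0, 0.0, 2.0, 1.0, -1.0]}
knot, eps = 0.5, 1e-5   # 0.5 is an interior knot of the grid
slope_jump = {}
for k, coef in coefs.items():
    left = (spline(knot, coef, k) - spline(knot - eps, coef, k)) / eps
    right = (spline(knot + eps, coef, k) - spline(knot, coef, k)) / eps
    slope_jump[k] = abs(right - left)
    print(f"k={k}: slope jump at x=0.5 is {slope_jump[k]:.2e}")
```

For k = 3, the small residual comes only from the finite-difference approximation.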

Code

The above visualization already demonstrates the most critical aspect of using B-splines for the KAN layer.

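The KAN layer’s definition, as outlined in the paper, is straightforward: each edge function φ combines a fixed base function b(x) with a spline. In the form implemented by the pykan code below:

$$\phi(x) = w_b\, b(x) + w_s\, \mathrm{spline}(x), \qquad b(x) = \mathrm{silu}(x) = \frac{x}{1 + e^{-x}}, \qquad \mathrm{spline}(x) = \sum_i c_i B_i(x)$$

A forward pass of a single KAN layer evaluates this equation on every edge and sums the results at each output node. To evaluate the spline, we can use coef2curve(), which takes the input x, the grid locations, and the coefficients, and returns the spline output y.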

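# Copied from pykan, modified for readability.
# __init__ is a simplified reconstruction (not pykan's exact code) so the class runs;
# it assumes the early pykan API used throughout this post.
import torch
import torch.nn as nn
from kan.spline import coef2curve

class KANLayer(nn.Module):
    def __init__(self, in_dim=2, out_dim=5, num=5, k=3, noise_scale=0.1, device='cpu'):
        super().__init__()
        self.in_dim, self.out_dim, self.k, self.device = in_dim, out_dim, k, device
        self.size = in_dim * out_dim  # one learnable function per edge
        # uniform grid on [-1, 1] with `num` intervals, one row per edge (not trained)
        grid = torch.linspace(-1, 1, steps=num + 1).repeat(self.size, 1)
        self.grid = nn.Parameter(grid, requires_grad=False)
        # spline coefficients c_i: num + k per edge (simplified random init)
        self.coef = nn.Parameter(noise_scale * torch.randn(self.size, num + k))
        # w_b and w_s in phi(x) = w_b * b(x) + w_s * spline(x)
        self.scale_base = nn.Parameter(torch.ones(self.size))
        self.scale_sp = nn.Parameter(torch.ones(self.size))
        self.base_fun = nn.SiLU()
        # 1 = edge active, 0 = pruned
        self.mask = nn.Parameter(torch.ones(self.size), requires_grad=False)
        # every edge uses its own grid and coefficients
        self.weight_sharing = torch.arange(self.size)

    def forward(self, x):
        batch = x.shape[0]

        # replicate each input across its out_dim outgoing edges: shape (size, batch)
        x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, device=self.device)).reshape(batch, self.size).permute(1, 0)

        # b(x): the base function (SiLU)
        base = self.base_fun(x).permute(1, 0)

        # coef2curve() takes the input, grid, and coefficients, and outputs the spline value y
        y = coef2curve(x_eval=x, grid=self.grid[self.weight_sharing], coef=self.coef[self.weight_sharing], k=self.k, device=self.device)  # shape (size, batch)
        y = y.permute(1, 0)  # shape (batch, size)

        # phi(x) = w_b * b(x) + w_s * spline(x), then zero out pruned edges
        y = self.scale_base.unsqueeze(dim=0) * base + self.scale_sp.unsqueeze(dim=0) * y
        y = self.mask[None, :] * y

        # sum over the in_dim incoming edges of each output node
        y = torch.sum(y.reshape(batch, self.out_dim, self.in_dim), dim=2)  # shape (batch, out_dim)

        return y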

Let’s delve into coef2curve(). It computes a linear combination of the coefficients and the basis functions.

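# Copied from pykan, modified for readability
import torch
from kan.spline import B_batch  # B_batch is shown next

def coef2curve(x_eval, grid, coef, k, device="cpu"):
    # x_eval: (num_splines, num_samples), grid: (num_splines, num_grid_points), coef: (num_splines, num_basis)
    if coef.dtype != x_eval.dtype:
        coef = coef.to(x_eval.dtype)
    # i: spline, j: basis function, k: sample
    # linear combination of the coefficients and the basis functions
    y_eval = torch.einsum('ij,ijk->ik', coef, B_batch(x_eval, grid, k, device=device))
    return y_eval  # shape (num_splines, num_samples)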

Finally, let’s examine B_batch(). It implements exactly the recursive B-spline basis definition from above.

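# Copied from pykan, modified for readability
import torch

def B_batch(x, grid, k=0, extend=True, device='cpu'):
    # x: (num_splines, num_samples), grid: (num_splines, num_grid_points)
    if extend:
        # pad k uniformly spaced knots on each side of the grid
        h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
        for _ in range(k):
            grid = torch.cat([grid[:, [0]] - h, grid, grid[:, [-1]] + h], dim=1)
    grid = grid.unsqueeze(dim=2).to(device)
    x = x.unsqueeze(dim=1).to(device)

    if k == 0:
        value = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
    else:
        # The recursive (Cox-de Boor) definition comes in here
        B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device)
        value = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)]) * B_km1[:, :-1] + (
            grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * B_km1[:, 1:]
    return value  # shape (num_splines, num_basis, num_samples)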

With all components in place, we can run the forward pass on the toy problem. PyTorch’s autograd computes the gradients with respect to the coefficients, so the optimization works just like for any other network.
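For a single edge, the spline output is linear in its coefficients, so training boils down to minimizing a loss over c. As a minimal sketch of what the optimizer does (numpy with a hand-written gradient standing in for PyTorch’s autograd; the degree-1 hat basis, the sin target, and the learning rate are my own illustrative choices):

```python
import numpy as np

knots = np.linspace(0.0, 1.0, 11)                   # 10 uniform intervals on [0, 1]
h = knots[1] - knots[0]

def hat_basis(x):
    """B[n, j]: degree-1 B-spline (hat) centred at knots[j], evaluated at x[n]."""
    return np.maximum(0.0, 1.0 - np.abs(x[:, None] - knots[None, :]) / h)

x = np.linspace(0.0, 1.0, 200)
target = np.sin(2 * np.pi * x)                      # arbitrary target for illustration
B = hat_basis(x)                                    # fixed: only the coefficients are learned
coef = np.zeros(len(knots))

lr = 1.0
losses = []
for step in range(500):
    residual = B @ coef - target
    losses.append(np.mean(residual ** 2))
    grad = 2.0 * B.T @ residual / len(x)            # gradient of the MSE w.r.t. coef
    coef -= lr * grad

print(f"MSE: {losses[0]:.4f} -> {losses[-1]:.6f}")
```

In a full KAN, the same coefficients sit inside a deeper composition, which is why an autograd framework is used instead of a hand-derived gradient.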

Important properties and results

In the paper, the authors analyze toy problems and present some interesting properties of KAN. I highlight a few that I found most interesting.

Avoid catastrophic forgetting

Each B-spline coefficient only affects a local region of the input. When the network is trained on data from a new input region, the coefficients responsible for other regions are left unchanged. This has the potential to offer a solution to catastrophic forgetting.
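The locality argument can be seen in a small continual-learning sketch of my own (not the paper’s experiment): a 1D model is trained on task A (left part of [0, 1]) and then only on task B (right part). With a local degree-1 B-spline (hat) basis, the coefficients used by task A receive exactly zero gradient during task B. For contrast, a global polynomial basis, used here only as a simple stand-in for a model whose parameters affect the whole input range (it is not an MLP), drifts on task A.

```python
import numpy as np

knots = np.linspace(0.0, 1.0, 11)
h = knots[1] - knots[0]

def hat_features(x):
    """Local basis: degree-1 B-spline (hat) functions centred on the knots."""
    return np.maximum(0.0, 1.0 - np.abs(x[:, None] - knots[None, :]) / h)

def poly_features(x, degree=5):
    """Global basis (stand-in for non-local parameters, NOT an MLP): 1, x, ..., x^5."""
    return np.stack([x ** p for p in range(degree + 1)], axis=1)

def train(features, coef, x, y, lr, steps):
    """Plain gradient descent on the mean squared error."""
    F = features(x)
    for _ in range(steps):
        coef = coef - lr * 2.0 * F.T @ (F @ coef - y) / len(x)
    return coef

x_a = np.linspace(0.0, 0.35, 100)   # task A: left part of the domain
x_b = np.linspace(0.65, 1.0, 100)   # task B: right part, learned afterwards
y_a, y_b = np.sin(2 * np.pi * x_a), np.cos(2 * np.pi * x_b)

changes = {}
for name, features, lr in [("hat (local)", hat_features, 1.0), ("poly (global)", poly_features, 0.1)]:
    coef = np.zeros(features(x_a).shape[1])
    coef = train(features, coef, x_a, y_a, lr, 2000)    # learn task A
    before = features(x_a) @ coef
    coef = train(features, coef, x_b, y_b, lr, 2000)    # then learn task B only
    after = features(x_a) @ coef
    changes[name] = np.max(np.abs(after - before))
    print(name, "max change on task A:", changes[name])
```

The hat model’s predictions on task A do not move at all, because no hat function overlaps both regions.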

As depicted in the paper, in a toy continual learning problem involving a 1D regression task with 5 Gaussian peaks, KANs can effectively avoid catastrophic forgetting, while MLPs display severe catastrophic forgetting.

Interpretation

Neural networks are traditionally hard to interpret. Because a KAN is built from single-variable functions that can be plotted and inspected one by one, it is much easier to reason about. In the paper, the authors demonstrate symbolic regression with KANs.

[Figure: symbolic regression, from the paper]

By inspecting each individual single-variable function (and replacing it with a matching symbolic form such as sin or x^2), we can reconstruct the function that generated the toy dataset.
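The core of that “snapping” step can be sketched as a small model-selection problem. The helper and the candidate library below are hypothetical and much simpler than pykan’s routine (which, as far as I know, also searches over a scaling and shift of the input); this sketch only fits an outer scale a and offset d for each candidate g and keeps the best fit.

```python
import numpy as np

# Hypothetical candidate library.
CANDIDATES = {
    "x": lambda x: x,
    "x^2": lambda x: x ** 2,
    "sin": np.sin,
    "exp": np.exp,
}

def snap_to_symbolic(x, y):
    """Fit y ~ a * g(x) + d for each candidate g by least squares; return the best (name, a, d, r2)."""
    best = None
    for name, g in CANDIDATES.items():
        A = np.stack([g(x), np.ones_like(x)], axis=1)
        (a, d), *_ = np.linalg.lstsq(A, y, rcond=None)
        residual = y - A @ np.array([a, d])
        r2 = 1.0 - residual.var() / y.var()
        if best is None or r2 > best[3]:
            best = (name, a, d, r2)
    return best

# Pretend these are samples of one learned edge function phi(x) from a trained KAN.
x = np.linspace(-2, 2, 200)
y = 2.0 * np.sin(x) + 0.5
name, a, d, r2 = snap_to_symbolic(x, y)
print(name, round(a, 3), round(d, 3), round(r2, 6))  # sin 2.0 0.5 1.0
```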

Paper Summary

We reviewed the different elements of KAN and gave an introduction to splines, following the original paper. I highly recommend that interested readers read the original paper for more details and for the examples and applications not covered here.

That wraps up the walkthrough of the paper, but the learning doesn’t stop here.

Thoughts and Questions

Question: What’s the difference between KAN and MLP?

When learning about MLPs, we learn that one key component is non-linearity. Without non-linearity, applying multiple weight matrices reduces to applying a single matrix. In MLPs, we put the learnable parameters in the weight matrices.
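The collapse of stacked linear layers is easy to verify numerically. A minimal numpy sketch with random weights, using the [2, 5, 1] shapes of the toy network:

```python
import numpy as np

rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(5, 2)), rng.normal(size=5)
W2, b2 = rng.normal(size=(1, 5)), rng.normal(size=1)
x = rng.normal(size=(10, 2))

# Two stacked layers WITHOUT a non-linearity ...
two_layers = (x @ W1.T + b1) @ W2.T + b2
# ... equal one layer with W = W2 W1 and b = W2 b1 + b2.
W, b = W2 @ W1, W2 @ b1 + b2
one_layer = x @ W.T + b
print(np.allclose(two_layers, one_layer))  # True

# With a ReLU in between, the collapse no longer holds in general.
relu_net = np.maximum(x @ W1.T + b1, 0) @ W2.T + b2
print(np.allclose(relu_net, one_layer))
```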

However, from a different perspective, we can absorb the parameters into the activation function. Thus, we can treat MLP as learning a specific form of activation.

If we shift where we draw the inputs and outputs of each step in a traditional MLP, a KAN-like structure appears immediately.

KAN can be viewed as a reordering of the MLP’s computation: we absorb the weights into the (no longer fixed) activation functions, leaving only a pure summation at each node. This is quite natural, given KART’s statement that any multivariate continuous function f can be decomposed into compositions and sums of single-variable functions.

Therefore, an MLP is a more restricted form of KAN, one in which we pick a specific form of activation function. KAN’s extra freedom lets us inject more inductive bias into the network, potentially making it more efficient.
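The regrouping can be checked numerically on the [2, 5, 1] toy network. In this sketch (my own construction), layer-1 edges are linear functions w·x_i plus an even share of the bias, and the ReLU moves into the layer-2 edge functions w·relu(z_j) plus a bias share; each node then only sums, exactly as in a KAN.

```python
import numpy as np

rng = np.random.default_rng(1)
n_in, n_hidden = 2, 5
W1, b1 = rng.normal(size=(n_hidden, n_in)), rng.normal(size=n_hidden)
W2, b2 = rng.normal(size=(1, n_hidden)), rng.normal(size=1)

def relu(z):
    return np.maximum(z, 0.0)

def mlp(x):
    """Standard [2, 5, 1] MLP with a ReLU after the first linear layer."""
    return relu(x @ W1.T + b1) @ W2.T + b2

def kan_forward(x):
    # layer 1: edge (j, i) applies phi_{j,i}(x_i) = W1[j, i] * x_i + b1[j] / n_in; nodes only sum
    z = np.stack([sum(W1[j, i] * x[:, i] + b1[j] / n_in for i in range(n_in))
                  for j in range(n_hidden)], axis=1)
    # layer 2: the ReLU lives inside phi_{0,j}(z_j) = W2[0, j] * relu(z_j) + b2[0] / n_hidden
    y = sum(W2[0, j] * relu(z[:, j]) + b2[0] / n_hidden for j in range(n_hidden))
    return y[:, None]

x = rng.normal(size=(8, n_in))
print(np.allclose(mlp(x), kan_forward(x)))  # True
```

Viewed this way, the MLP’s edge functions are restricted to the families w·x + const and w·relu(z) + const, whereas a KAN lets each edge learn an arbitrary spline.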

I haven’t worked through the math, but my guess is that fitting with ReLU has the same effect as a degree-1 (k = 1) KAN: a degree-1 B-spline is a piecewise-linear fit, and an MLP with ReLU can represent the same piecewise-linear function when its weights and biases scale and shift the ReLUs onto the knots.
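One direction of this guess is easy to check: any degree-1 spline equals a one-hidden-layer ReLU network with one hidden unit per knot. In the sketch below (knots and coefficients are made up), hidden unit m computes relu(x - knots[m]), and its output weight is the change of slope at that knot; `np.interp` serves as the reference, since a degree-1 spline with hat basis functions centered on the knots is piecewise-linear interpolation of the coefficients.

```python
import numpy as np

knots = np.array([0.0, 0.2, 0.5, 0.6, 1.0])   # hypothetical, non-uniform on purpose
coef = np.array([1.0, 3.0, 2.0, 2.5, 0.0])     # value of the degree-1 spline at each knot

def linear_spline(x):
    """Degree-1 spline with hat basis on `knots`: piecewise-linear interpolation of coef."""
    return np.interp(x, knots, coef)

# Equivalent ReLU network: hidden unit m computes relu(x - knots[m]) (weight 1, bias -knots[m]).
slopes = np.diff(coef) / np.diff(knots)                   # slope of each segment
out_w = np.concatenate([[slopes[0]], np.diff(slopes)])    # slope change at knots[0 .. M-2]
hidden_b = -knots[:-1]

def relu_net(x):
    hidden = np.maximum(x[:, None] + hidden_b[None, :], 0.0)  # shape (N, M - 1)
    return coef[0] + hidden @ out_w

x = np.linspace(0.0, 1.0, 1001)
print(np.allclose(linear_spline(x), relu_net(x)))  # True
```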

Question: Grid Size

The original paper treats grid locations as a hyperparameter rather than a learnable parameter. The grid is restricted to be uniform, and during training the number of grid intervals increases on a specific schedule, enabling “coarse to fine” learning. Each refinement adds more parameters to the system, since the spline gets more segments.
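The paper describes initializing the finer spline by fitting it to the coarser one, so what has been learned carries over. A minimal sketch under my own simplifying assumptions (degree-1 hat basis, uniform nested grids, plain least squares rather than pykan’s routine); in this nested case the refinement is lossless:

```python
import numpy as np

coarse_knots = np.linspace(0.0, 1.0, 4)          # 3 intervals
fine_knots = np.linspace(0.0, 1.0, 10)           # 9 intervals; contains every coarse knot
coarse_coef = np.array([0.0, 2.0, -1.0, 1.0])    # hypothetical learned coefficients

def hat_matrix(x, knots):
    """B[n, j]: degree-1 B-spline (hat) centred at knots[j], evaluated at x[n] (uniform knots)."""
    h = knots[1] - knots[0]
    return np.maximum(0.0, 1.0 - np.abs(x[:, None] - knots[None, :]) / h)

# Sample the coarse spline, then least-squares fit the coefficients on the finer grid.
x = np.linspace(0.0, 1.0, 500)
y_coarse = hat_matrix(x, coarse_knots) @ coarse_coef
fine_coef, *_ = np.linalg.lstsq(hat_matrix(x, fine_knots), y_coarse, rcond=None)

y_fine = hat_matrix(x, fine_knots) @ fine_coef
print(np.max(np.abs(y_fine - y_coarse)) < 1e-10)  # True: same curve, 10 coefficients instead of 4
```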

[Figure source: https://www.brnt.eu/phd/node11.html]

Relaxing this constraint and making the knot locations learnable should give us more interpretability, even with a smaller grid size. For a spline to change more drastically, its knots need to be closer together. The learned knot locations would therefore hint at where the output changes drastically, giving us a sense of the sensitive regions of the domain, or even shedding light on underfitting.

Potential: overfitting without catastrophic forgetting

Because splines act locally, a KAN can fit specific regions of the input domain in fine detail without disturbing the rest. This could benefit coordinate-based neural networks such as NeRF, which deliberately overfit a single 3D scene.

Potential: other forms of function

In KAN, we are not restricted to B-splines. For example, we may not want the second derivatives to be continuous: when modeling a control system, we may want to represent a discrete, discontinuous force as the control signal. We may also use other non-parametric regression methods, such as Gaussian processes, as our function. This leaves us the freedom to inject our own inductive bias into the system.
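Swapping the function family can be as simple as swapping the basis. In this sketch, `edge_function` is a hypothetical helper of my own: it compares degree-0 B-splines (piecewise constant, able to represent a jump such as a discontinuous control force) with Gaussian radial basis functions (smooth everywhere; a simple stand-in, not a Gaussian process).

```python
import numpy as np

def step_basis(x, knots=np.linspace(0.0, 1.0, 6)):
    """Degree-0 B-splines: indicator of each interval [knots[j], knots[j+1]). Allows jumps."""
    return ((knots[None, :-1] <= x[:, None]) & (x[:, None] < knots[None, 1:])).astype(float)

def rbf_basis(x, centers=np.linspace(0.0, 1.0, 6), width=0.15):
    """Gaussian radial basis functions: infinitely smooth."""
    return np.exp(-((x[:, None] - centers[None, :]) / width) ** 2)

def edge_function(x, basis, coef):
    """Hypothetical helper: phi(x) = sum_j coef[j] * basis_j(x), with a swappable basis."""
    return basis(x) @ coef

pts = np.array([0.4 - 1e-9, 0.4 + 1e-9])             # just left and right of x = 0.4
step_vals = edge_function(pts, step_basis, np.array([0.0, 1.0, -1.0, 0.0, 0.0]))
rbf_vals = edge_function(pts, rbf_basis, np.array([0.0, 1.0, -1.0, 0.0, 0.0, 0.0]))
step_jump = abs(step_vals[1] - step_vals[0])
rbf_jump = abs(rbf_vals[1] - rbf_vals[0])
print(step_jump, rbf_jump)   # jump of 2.0 for the step basis, essentially 0 for the RBF basis
```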

Closing

Reviewing KAN gave me a new perspective on MLPs and reminded me of the importance of non-linearity. I also gained a more in-depth understanding of splines and their potential in machine learning.

References

Original paper: https://arxiv.org/pdf/2404.19756
Repository: https://github.com/KindXiaoming/pykan

Excellent videos about splines:

Articles and books about splines:

MIT book
Florent Brunet PhD thesis


Fei Cheung

Hongkonger, Maker, Teacher. Interested in all kinds of stuff, from physics to Machine Learning. Lead Engineer in 2019 CES innovation award honoree.