import torch
import torch.nn as nn
# This is inspired by Kolmogorov-Arnold Networks but uses Chebyshev polynomials instead of spline coefficients
class ChebyKANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, degree):
        super(ChebyKANLayer, self).__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.degree = degree

        self.cheby_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))
        nn.init.normal_(self.cheby_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))
        self.register_buffer("arange", torch.arange(0, degree + 1, 1))

    def forward(self, x):
        # Since Chebyshev polynomials are defined on [-1, 1],
        # we need to normalize x to [-1, 1] using tanh
        x = torch.tanh(x)
        # View and repeat input degree + 1 times
        x = x.view((-1, self.inputdim, 1)).expand(
            -1, -1, self.degree + 1
        )  # shape = (batch_size, inputdim, degree + 1)
        # Apply acos
        x = x.acos()
        # Multiply by arange [0 .. degree]
        x *= self.arange
        # Apply cos
        x = x.cos()
        # Compute the Chebyshev interpolation
        y = torch.einsum(
            "bid,iod->bo", x, self.cheby_coeffs
        )  # shape = (batch_size, outdim)
        y = y.view(-1, self.outdim)
        return y
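The acos/cos sequence in forward is simply a direct evaluation of the Chebyshev basis: on [-1, 1] the degree-n Chebyshev polynomial satisfies T_n(z) = cos(n arccos z), so the einsum computes

\[ y_{b,o} = \sum_{i=1}^{\text{input\_dim}} \sum_{n=0}^{\text{degree}} c_{i,o,n}\, \cos\!\big(n \arccos(\tanh(x_{b,i}))\big) \]

i.e. each output is a learned linear combination of Chebyshev polynomials evaluated at the tanh-squashed inputs.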
KANs with Chebyshev
Replace B-splines with Chebyshev polynomials, which should be faster to compute. The ChebyKANLayer base code above is taken from the ChebyKAN repository; a quick shape check follows below.
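As a quick sanity check of the layer's input/output shapes (a minimal sketch; the batch size and dimensions here are arbitrary):

layer = ChebyKANLayer(input_dim=3, output_dim=2, degree=4)
out = layer(torch.randn(8, 3))   # batch of 8 samples, 3 input features
print(out.shape)                 # torch.Size([8, 2])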
Function Fitting
We will take some interesting test functions from the Wavelet Regression notebook.
The Doppler function is \[ f(x) = x(1-x)\,\sin\!\left(\frac{2.1\pi}{x+0.05}\right), \qquad x \sim U[0,1] \]
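Note that the code below uses slightly different constants (2π and 0.1 rather than 2.1π and 0.05). Plotting the formula exactly as written above is a one-liner (a small illustrative sketch, separate from the dataset code below):

import torch
import numpy as np
import matplotlib.pyplot as plt

x = torch.linspace(0, 1, 1000)
plt.plot(x, x * (1 - x) * torch.sin(2.1 * np.pi / (x + 0.05)))
plt.show()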
import torch
import matplotlib.pyplot as plt
import numpy as np
from kan.utils import create_dataset
torch.set_default_dtype(torch.float64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device is: ', device)

f = lambda x: x[:,[0]]*(1-x[:,[0]])*torch.sin((2*np.pi)/(x[:,[0]]+.1))

dataset = create_dataset(f, n_var=1, device=device, ranges=[0,1])

print('train input data shape', dataset['train_input'].shape)
print('train label data shape', dataset['train_label'].shape)
plt.scatter(dataset['train_input'], dataset['train_label'])
device is: cpu
train input data shape torch.Size([1000, 1])
train label data shape torch.Size([1000, 1])
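Here create_dataset comes from the pykan package; roughly, it samples inputs in the given ranges, applies f to produce labels, and returns train/test splits as a dictionary. If pykan is not available, a minimal stand-in under that assumption (a hypothetical make_dataset helper, not pykan's actual implementation) could be:

def make_dataset(f, n_var=1, n=1000, ranges=(0, 1), device='cpu'):
    # hypothetical stand-in for kan.utils.create_dataset (assumes uniform sampling)
    lo, hi = ranges
    x_train = lo + (hi - lo) * torch.rand(n, n_var, device=device)
    x_test = lo + (hi - lo) * torch.rand(n, n_var, device=device)
    return {'train_input': x_train, 'train_label': f(x_train),
            'test_input': x_test, 'test_label': f(x_test)}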
# Define model
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
class ChebyKAN(nn.Module):
    def __init__(self):
        super(ChebyKAN, self).__init__()
        self.chebykan1 = ChebyKANLayer(1, 5, 4)
        self.ln1 = nn.LayerNorm(5)  # To avoid gradient vanishing caused by tanh
        self.chebykan2 = ChebyKANLayer(5, 5, 4)
        self.ln2 = nn.LayerNorm(5)
        self.chebykan3 = ChebyKANLayer(5, 1, 4)

    def forward(self, x):
        x = self.chebykan1(x)
        x = self.ln1(x)
        x = self.chebykan2(x)
        x = self.ln2(x)
        x = self.chebykan3(x)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# create a KAN: 1D inputs, 1D output, and 5 hidden neurons.
model = ChebyKAN()
model.to(device)

# Define loss and optimizer
learning_rate = 0.01
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
def train_network(model, optimizer, criterion, X_train, y_train, X_test, y_test, num_epochs, train_losses, test_losses):
    for epoch in range(num_epochs):
        model.train()
        # clear out the gradients from the last step loss.backward()
        optimizer.zero_grad()

        # forward feed
        output_train = model(X_train)

        # calculate the loss
        loss_train = criterion(output_train, y_train)

        # backward propagation: calculate gradients
        loss_train.backward()

        # update the weights
        optimizer.step()

        model.eval()
        output_test = model(X_test)
        loss_test = criterion(output_test, y_test)

        train_losses[epoch] = loss_train.item()
        test_losses[epoch] = loss_test.item()

        if (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {loss_train.item():.4f}, Test Loss: {loss_test.item():.4f}")

    return model, train_losses, test_losses
import numpy as np
num_epochs = 1000
train_losses = np.zeros(num_epochs)
test_losses = np.zeros(num_epochs)

X_train = dataset['train_input']
y_train = dataset['train_label']
X_test = dataset['test_input']
y_test = dataset['test_label']

mlp, train_losses, test_losses = train_network(model, optimizer, criterion, X_train, y_train, X_test, y_test, num_epochs, train_losses, test_losses)
Epoch 50/1000, Train Loss: 0.0079, Test Loss: 0.0079
Epoch 100/1000, Train Loss: 0.0049, Test Loss: 0.0046
Epoch 150/1000, Train Loss: 0.0041, Test Loss: 0.0037
Epoch 200/1000, Train Loss: 0.0027, Test Loss: 0.0025
Epoch 250/1000, Train Loss: 0.0018, Test Loss: 0.0016
Epoch 300/1000, Train Loss: 0.0014, Test Loss: 0.0014
Epoch 350/1000, Train Loss: 0.0085, Test Loss: 0.0093
Epoch 400/1000, Train Loss: 0.0013, Test Loss: 0.0013
Epoch 450/1000, Train Loss: 0.0007, Test Loss: 0.0008
Epoch 500/1000, Train Loss: 0.0005, Test Loss: 0.0005
Epoch 550/1000, Train Loss: 0.0004, Test Loss: 0.0005
Epoch 600/1000, Train Loss: 0.0006, Test Loss: 0.0006
Epoch 650/1000, Train Loss: 0.0002, Test Loss: 0.0002
Epoch 700/1000, Train Loss: 0.0002, Test Loss: 0.0003
Epoch 750/1000, Train Loss: 0.0001, Test Loss: 0.0001
Epoch 800/1000, Train Loss: 0.0001, Test Loss: 0.0002
Epoch 850/1000, Train Loss: 0.0001, Test Loss: 0.0001
Epoch 900/1000, Train Loss: 0.0001, Test Loss: 0.0001
Epoch 950/1000, Train Loss: 0.0001, Test Loss: 0.0001
Epoch 1000/1000, Train Loss: 0.0003, Test Loss: 0.0003
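Because train_losses and test_losses were recorded every epoch, the full loss curves can also be plotted (a log scale helps, since the MSE drops by roughly two orders of magnitude):

plt.plot(train_losses, label='train')
plt.plot(test_losses, label='test')
plt.yscale('log')
plt.xlabel('epoch')
plt.ylabel('MSE')
plt.legend()
plt.show()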
# let us look at the reconstruction
X = dataset['train_input'].clone()  # clone so the stored training inputs are not overwritten
n = 1000
X[:,0] = torch.linspace(0, 1, steps=n)
y = f(X)
y = y[:,0].detach().numpy()
yh = model.forward(X)
yh = yh[:,0].detach().numpy()

plt.plot(y, color='blue')
plt.plot(yh, color='red')
plt.show()

plt.plot(yh - y, color='black')
plt.ylim(-1, 1)
plt.show()
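To accompany the plots with a single number, the final test error can be evaluated with the same criterion used during training (a small sketch; this value is not from the original run):

model.eval()
with torch.no_grad():
    test_mse = criterion(model(X_test), y_test).item()
print(f'final test MSE: {test_mse:.6f}')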