# Copyright 2024 Li, Ziyao
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import *
class SplineLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
        self.init_scale = init_scale
        super().__init__(in_features, out_features, bias=False, **kw)

    def reset_parameters(self) -> None:
        # Small-scale truncated-normal init keeps the spline output near zero at the start.
        nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)

class RadialBasisFunction(nn.Module):
    def __init__(
        self,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        denominator: float = None,  # larger denominators lead to smoother basis
    ):
        super().__init__()
        self.grid_min = grid_min
        self.grid_max = grid_max
        self.num_grids = num_grids
        grid = torch.linspace(grid_min, grid_max, num_grids)
        self.grid = torch.nn.Parameter(grid, requires_grad=False)
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

    def forward(self, x):
        # Gaussian bumps centered at the grid points: exp(-((x - g_i) / h) ** 2)
        return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)
class FastKANLayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        use_layernorm: bool = False,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.1,
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layernorm = None
        if use_layernorm:
            assert input_dim > 1, "Do not use layernorms on 1D inputs. Set `use_layernorm=False`."
            self.layernorm = nn.LayerNorm(input_dim)
        self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
        self.use_base_update = use_base_update
        if use_base_update:
            self.base_activation = base_activation
            self.base_linear = nn.Linear(input_dim, output_dim)

    def forward(self, x, use_layernorm=True):
        if self.layernorm is not None and use_layernorm:
            spline_basis = self.rbf(self.layernorm(x))
        else:
            spline_basis = self.rbf(x)
        # (..., input_dim, num_grids) -> (..., input_dim * num_grids) -> (..., output_dim)
        ret = self.spline_linear(spline_basis.view(*spline_basis.shape[:-2], -1))
        if self.use_base_update:
            # residual-style base update, analogous to the base function in the original KAN
            base = self.base_linear(self.base_activation(x))
            ret = ret + base
        return ret
    def plot_curve(
        self,
        input_index: int,
        output_index: int,
        num_pts: int = 1000,
        num_extrapolate_bins: int = 2
    ):
        '''This function returns the learned curves in a FastKANLayer.
        input_index: the selected index of the input, in [0, input_dim).
        output_index: the selected index of the output, in [0, output_dim).
        num_pts: num of points sampled for the curve.
        num_extrapolate_bins (N_e): num of bins extrapolating from the given grids. The curve
            will be calculated in the range of [grid_min - h * N_e, grid_max + h * N_e].
        '''
        ng = self.rbf.num_grids
        h = self.rbf.denominator
        assert input_index < self.input_dim
        assert output_index < self.output_dim
        w = self.spline_linear.weight[
            output_index, input_index * ng : (input_index + 1) * ng
        ]   # num_grids,
        x = torch.linspace(
            self.rbf.grid_min - num_extrapolate_bins * h,
            self.rbf.grid_max + num_extrapolate_bins * h,
            num_pts
        )   # num_pts,
        with torch.no_grad():
            y = (w * self.rbf(x.to(w.dtype))).sum(-1)
        return x, y
class FastKAN(nn.Module):
    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.1,
    ) -> None:
        super().__init__()
        # one FastKANLayer per consecutive pair of widths, e.g. [2, 5, 1] -> 2->5 and 5->1
        self.layers = nn.ModuleList([
            FastKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
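A quick shape sanity check (a minimal sketch using the classes above; the names are illustrative):
# one FastKANLayer mapping 3 inputs to 4 outputs with 8 RBF grid points
layer = FastKANLayer(3, 4, num_grids=8)
out = layer(torch.randn(16, 3))
print(out.shape)  # torch.Size([16, 4]); the intermediate RBF expansion is (16, 3, 8)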
KANs with RBFs
Replace B-splines with radial basis functions (RBFs): Gaussian RBFs can closely approximate the B-spline basis and should be faster to compute. The base code above is taken from fast-KAN.
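Concretely, each scalar input is expanded onto a fixed grid of Gaussian bumps \( \phi_i(x) = \exp(-((x - g_i)/h)^2) \), and every learned curve is a linear combination of these features (the job of SplineLinear). A minimal sketch using the RadialBasisFunction class defined above:
# expand 5 sample points onto 8 Gaussian basis functions on [-2, 2]
rbf = RadialBasisFunction(grid_min=-2., grid_max=2., num_grids=8)
phi = rbf(torch.linspace(-2., 2., 5))
print(phi.shape)  # torch.Size([5, 8]): one row of basis activations per point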
Function Fitting
We will take some interesting test functions from the Wavelet Regression notebook. The Doppler function is
\[ f(x) = x(1-x)\,\sin\!\left(\frac{2.1\pi}{x + 0.05}\right), \qquad x \sim U[0,1]. \]
(The code below uses a slightly smoothed variant, with \( \sin\!\left(\frac{2\pi}{x + 0.1}\right) \) in place of the sine term.)
import torch
import matplotlib.pyplot as plt
import numpy as np
from kan.utils import create_dataset
torch.set_default_dtype(torch.float64)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device is: ', device)

# Doppler-like test function on [0, 1]
f = lambda x: x[:, [0]] * (1 - x[:, [0]]) * torch.sin((2 * np.pi) / (x[:, [0]] + .1))
dataset = create_dataset(f, n_var=1, device=device, ranges=[0, 1])

print('train input data shape', dataset['train_input'].shape)
print('train label data shape', dataset['train_label'].shape)
plt.scatter(dataset['train_input'], dataset['train_label'])
device is: cpu
train input data shape torch.Size([1000, 1])
train label data shape torch.Size([1000, 1])
# Define model
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# create a KAN: 1D input, two hidden layers of 5 neurons, 1D output
model = FastKAN([1, 5, 5, 1], grid_min=0, grid_max=1, num_grids=5)
model.to(device)

# Define loss and optimizer
learning_rate = 0.01
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
def train_network(model, optimizer, criterion, X_train, y_train, X_test, y_test,
                  num_epochs, train_losses, test_losses):
    for epoch in range(num_epochs):
        model.train()
        # clear out the gradients from the last step
        optimizer.zero_grad()

        # forward pass
        output_train = model(X_train)

        # calculate the loss
        loss_train = criterion(output_train, y_train)

        # backward propagation: calculate gradients
        loss_train.backward()

        # update the weights
        optimizer.step()

        # evaluate on the test split (no gradients needed)
        model.eval()
        with torch.no_grad():
            output_test = model(X_test)
            loss_test = criterion(output_test, y_test)

        train_losses[epoch] = loss_train.item()
        test_losses[epoch] = loss_test.item()

        if (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {loss_train.item():.4f}, Test Loss: {loss_test.item():.4f}")

    return model, train_losses, test_losses
import numpy as np

num_epochs = 1000
train_losses = np.zeros(num_epochs)
test_losses = np.zeros(num_epochs)

X_train = dataset['train_input']
y_train = dataset['train_label']
X_test = dataset['test_input']
y_test = dataset['test_label']

model, train_losses, test_losses = train_network(model, optimizer, criterion, X_train, y_train,
                                                 X_test, y_test, num_epochs, train_losses, test_losses)
Epoch 50/1000, Train Loss: 0.0138, Test Loss: 0.0137
Epoch 100/1000, Train Loss: 0.0074, Test Loss: 0.0070
Epoch 150/1000, Train Loss: 0.0062, Test Loss: 0.0059
Epoch 200/1000, Train Loss: 0.0060, Test Loss: 0.0057
Epoch 250/1000, Train Loss: 0.0059, Test Loss: 0.0057
Epoch 300/1000, Train Loss: 0.0058, Test Loss: 0.0055
Epoch 350/1000, Train Loss: 0.0058, Test Loss: 0.0054
Epoch 400/1000, Train Loss: 0.0058, Test Loss: 0.0054
Epoch 450/1000, Train Loss: 0.0057, Test Loss: 0.0053
Epoch 500/1000, Train Loss: 0.0057, Test Loss: 0.0053
Epoch 550/1000, Train Loss: 0.0057, Test Loss: 0.0053
Epoch 600/1000, Train Loss: 0.0056, Test Loss: 0.0052
Epoch 650/1000, Train Loss: 0.0057, Test Loss: 0.0053
Epoch 700/1000, Train Loss: 0.0059, Test Loss: 0.0053
Epoch 750/1000, Train Loss: 0.0056, Test Loss: 0.0052
Epoch 800/1000, Train Loss: 0.0055, Test Loss: 0.0052
Epoch 850/1000, Train Loss: 0.0055, Test Loss: 0.0051
Epoch 900/1000, Train Loss: 0.0054, Test Loss: 0.0050
Epoch 950/1000, Train Loss: 0.0055, Test Loss: 0.0050
Epoch 1000/1000, Train Loss: 0.0055, Test Loss: 0.0051
# let us look at the reconstruction
X = dataset['train_input']
n = 1000
X[:, 0] = torch.linspace(0, 1, steps=n)
y = f(X)
y = y[:, 0].detach().numpy()

yh = model(X)
yh = yh[:, 0].detach().numpy()

plt.plot(y, color='blue')
plt.plot(yh, color='red')
plt.show()

plt.plot(yh - y, color='black')
plt.ylim(-1, 1)
plt.show()
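Since each FastKANLayer exposes plot_curve, we can also inspect the learned univariate curve on a single edge. A minimal sketch, assuming the trained model above (CPU run, as printed earlier):
# learned curve from input 0 to output 0 of the first layer;
# plot_curve returns (x, y) tensors rather than drawing a figure
xs, ys = model.layers[0].plot_curve(input_index=0, output_index=0)
plt.plot(xs.numpy(), ys.numpy())
plt.show()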