Coset Documentation
Detailed project documentation for the Coset (Hierarchical Nested-Lattice Quantization) library.
Overview
Coset is a PyTorch library implementing Hierarchical Nested-Lattice Quantization (HNLQ) for quantization-aware training (QAT). It provides:
- E8 and D4 lattice support
- Integration with transformer models (BERT)
- QAT with cold start for stable training
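To make the "hierarchical nested-lattice" idea concrete, here is a minimal plain-Python sketch of HNLQ encoding and decoding on the D4 lattice. It follows the general recipe: quantize to the lattice, store the coset of that point modulo q·D4 as M base-q digit vectors, divide by q, and repeat. It is not taken from coset. The function names, the choice of D4 basis, and the omission of scaling (the beta/alpha parameters in LatticeConfig) are all assumptions made for illustration.

```python
import math
from fractions import Fraction

# Illustrative HNLQ sketch on D4 (hypothetical helpers, not coset's implementation).
# D4 = {z in Z^4 : sum(z) is even}; these basis rows are chosen for easy inversion.
BASIS = [(1, 1, 0, 0), (1, -1, 0, 0), (0, 1, -1, 0), (0, 0, 1, -1)]

def nearest_d4(x):
    """Closest D4 point to x: round every coordinate, then fix the parity."""
    f = [math.floor(v + Fraction(1, 2)) for v in x]  # round half up, coordinatewise
    if sum(f) % 2 != 0:
        # Odd sum: re-round the worst-rounded coordinate in the other direction.
        deltas = [v - fv for v, fv in zip(x, f)]
        i = max(range(len(x)), key=lambda k: abs(deltas[k]))
        f[i] += 1 if deltas[i] >= 0 else -1
    return f

def d4_coords(y):
    """Integer coordinates a with y = sum_k a[k] * BASIS[k] (y must lie in D4)."""
    s = y[0] + y[1] + y[2] + y[3]
    return [s // 2, (y[0] - y[1] - y[2] - y[3]) // 2, -(y[2] + y[3]), -y[3]]

def d4_point(a):
    """Lattice point G a = sum_k a[k] * BASIS[k]."""
    return [sum(a[k] * BASIS[k][j] for k in range(4)) for j in range(4)]

def hnlq_encode(x, q, M):
    """Return M digit vectors (each in Z_q^4) and an overload flag."""
    y = nearest_d4(x)
    digits = []
    for _ in range(M):
        digits.append([a % q for a in d4_coords(y)])  # coset of y modulo q*D4
        y = nearest_d4([Fraction(v, q) for v in y])   # next level: Q(y / q)
    return digits, any(v != 0 for v in y)             # nonzero remainder = overload

def hnlq_decode(digits, q):
    """Rebuild sum_m q^m * (G b_m - q * Q(G b_m / q))."""
    total = [0, 0, 0, 0]
    for m, b in enumerate(digits):
        c = d4_point(b)
        coarse = nearest_d4([Fraction(v, q) for v in c])
        rep = [cv - q * kv for cv, kv in zip(c, coarse)]  # coset rep inside q*Voronoi cell
        total = [t + q**m * r for t, r in zip(total, rep)]
    return total

digits, overload = hnlq_encode([0.9, -2.2, 3.1, 0.4], q=4, M=2)
print(digits, overload)          # [[1, 0, 1, 0], [0, 0, 3, 0]] False
print(hnlq_decode(digits, q=4))  # [1, -2, 3, 0]
```

When no overload is flagged, decoding reproduces the nearest D4 point exactly; larger q or M widens the representable range at the cost of more bits. Exact `Fraction` arithmetic keeps the encoder and decoder tie-breaking consistent.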
Installation
Quick Setup
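```bash
git clone https://github.com/coset/coset.git
cd coset
python3 -m venv venv
source venv/bin/activate
# Install coset in editable (development) mode
pip install -e .
# Additional dependencies (transformers provides the BERT models used below)
pip install transformers torchvision scikit-learn matplotlib
```
API Reference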
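The q and M arguments, which appear in both the layer constructor and LatticeConfig, jointly set the bit budget. Assuming each of the M levels stores one coset index in Z_q^d for a d-dimensional lattice block, as in the standard HNLQ construction, and ignoring side information such as scales, the rate is M·log2(q) bits per weight. The helper below is a hypothetical back-of-the-envelope calculator, not part of coset.

```python
import math

def hnlq_bits(q, M, dim):
    """Hypothetical helper: storage for M levels of a q**dim coset index.

    Returns (bits per dim-dimensional lattice block, bits per weight).
    Side information such as per-layer scales is ignored.
    """
    bits_per_block = M * dim * math.log2(q)
    return bits_per_block, bits_per_block / dim

for name, dim in [("E8", 8), ("D4", 4)]:
    block_bits, weight_bits = hnlq_bits(q=4, M=2, dim=dim)
    ratio = 32 / weight_bits  # compared with float32 weights
    print(f"{name}: {block_bits:.0f} bits per block, {weight_bits:.0f} bits/weight, {ratio:.0f}x vs float32")
# E8: 32 bits per block, 4 bits/weight, 8x vs float32
# D4: 16 bits per block, 4 bits/weight, 8x vs float32
```

With q=4 and M=2, both lattices come to 4 bits per weight. The choice between E8 and D4 changes how well the quantization cells fill space, not the bit count.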
Constructor Functions
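```python
from coset.core.e8.layers import create_e8_hnlq_linear

# Linear layer whose weights are quantized with the E8 lattice
layer = create_e8_hnlq_linear(
    in_dim=768,
    out_dim=10,
    q=4,                     # Quantization parameter
    M=2,                     # Number of hierarchical levels
    warmup_epochs=2,         # Cold-start warmup length, in epochs
    enable_diagnostics=True,
    weight_clip_value=2.0,   # Weight clipping threshold
    theta_trainable=True,    # Scale parameter settings (see Scale Parameter Options)
    theta_init_value=0.0,
)
```
Scale Parameter Options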
Learnable scale parameters (default):
```python
layer = create_e8_hnlq_linear(
    in_dim=768, out_dim=10,
    theta_trainable=True,
    theta_init_value=0.0,
)
```
Fixed scale parameters (deterministic):
```python
layer = create_e8_hnlq_linear(
    in_dim=768, out_dim=10,
    theta_trainable=False,
    theta_init_value=0.0,
)
```
QAT Methods
```python
layer.update_epoch(epoch)  # Update epoch for cold start
layer.enable_quantization()
layer.disable_quantization()

# Diagnostics
diagnostics = layer.get_diagnostic_summary()
quant_error = layer.get_quantization_error()
weight_stats = layer.get_weight_statistics()
```
Configuration
```python
from coset.core.base import LatticeConfig

config = LatticeConfig(
    lattice_type="E8",
    q=4,
    M=2,
    beta=1.0,
    alpha=1.0,
    decoding="full",
    check_overload=False,
    disable_scaling=False,
    disable_overload_protection=True,
)
```
Usage Patterns
Binary Classification with BERT
```python
import torch
from transformers import AutoTokenizer, AutoModel
from coset.core.e8.layers import create_e8_hnlq_linear


class QuantizedBERTClassifier(torch.nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        # Pretrained BERT encoder, frozen so that only the head is trained
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        for param in self.bert.parameters():
            param.requires_grad = False
        # E8-quantized classification head on BERT's 768-dimensional pooled output
        self.classifier = create_e8_hnlq_linear(
            in_dim=768, out_dim=num_classes,
            warmup_epochs=2, enable_diagnostics=True,
            weight_clip_value=2.0, theta_trainable=True
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.classifier(outputs.pooler_output)
        # Raw logits (for BCEWithLogitsLoss) and probabilities (for inference)
        return logits, self.sigmoid(logits)
```
Training Loop
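The loop below calls update_epoch once per epoch so the quantized head can follow its cold-start schedule. The exact schedule is defined inside coset. As a rough mental model only, the sketch assumes quantization stays off while epoch < warmup_epochs and switches on afterwards; the class name HypotheticalColdStartGate is invented for illustration.

```python
class HypotheticalColdStartGate:
    """Hypothetical model of a cold-start schedule (not coset's implementation).

    Assumption: quantization is off while epoch < warmup_epochs and on afterwards.
    """

    def __init__(self, warmup_epochs):
        self.warmup_epochs = warmup_epochs
        self.quantize = False  # start in full precision

    def update_epoch(self, epoch):
        # Hard switch once the warmup period is over
        self.quantize = epoch >= self.warmup_epochs


gate = HypotheticalColdStartGate(warmup_epochs=2)
schedule = []
for epoch in range(5):
    gate.update_epoch(epoch)
    schedule.append(gate.quantize)
print(schedule)  # [False, False, True, True, True]
```

Under this assumption, the full-precision weights get a few epochs to settle before they are snapped to lattice points.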
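```python
import torch
import torch.optim as optim

# train_loader is assumed to be a DataLoader yielding (input_ids, attention_mask, labels)
model = QuantizedBERTClassifier(num_classes=1)
# Only the quantized head is trainable; the BERT encoder is frozen
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)
criterion = torch.nn.BCEWithLogitsLoss()

model.train()
for epoch in range(15):
    model.classifier.update_epoch(epoch)  # Update epoch for cold start
    for batch in train_loader:
        input_ids, attention_mask, labels = batch
        optimizer.zero_grad()
        logits, _ = model(input_ids, attention_mask)
        # BCEWithLogitsLoss needs float targets shaped like the logits: (batch,)
        loss = criterion(logits.squeeze(-1), labels.float())
        loss.backward()
        optimizer.step()
```
Examples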
Run the provided examples:
```bash
python examples/bert_binary_classifier_e8.py
python examples/bert_multiclass_classifier_e8.py
python examples/qat_cold_start_comparison.py
python examples/mnist_mlp_e8.py
```
Performance
| Metric | Value |
|---|---|
| Training Speed | ~4.2s per epoch (batch 128) |
| Parameter Reduction | 99.3% (quantized head only) |
| Binary Classification Accuracy | ~84% |
| Multi-Class Accuracy | ~91% |
Library Components
Core Modules
- coset.core.e8.layers: E8 lattice linear layers
- coset.core.e8.codecs: E8 encoding/decoding
- coset.core.base: Base configuration and utilities
- coset.core.d4.layers: D4 lattice layers
Lattice Types
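- E8: 8-dimensional lattice with the densest sphere packing in 8D; lower quantization error per dimension than D4 at the same rate
- D4: 4-dimensional checkerboard lattice (integer vectors whose coordinates have an even sum)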
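Both lattices admit fast nearest-point search. The sketch below follows Conway and Sloane's classic algorithm: the nearest point of D_n (which includes D4) is found by rounding every coordinate and fixing the parity, and E8 = D8 ∪ (D8 + ½·1) is handled by decoding in both cosets and keeping the closer result. It is a plain-Python illustration with hypothetical function names, not the code in coset.core.e8.codecs.

```python
import math

def nearest_dn(x):
    """Closest point of D_n = {z in Z^n : sum(z) even} (hypothetical helper)."""
    f = [math.floor(v + 0.5) for v in x]  # round each coordinate
    if sum(f) % 2 != 0:
        # Odd sum: re-round the worst-rounded coordinate the other way.
        deltas = [v - fv for v, fv in zip(x, f)]
        i = max(range(len(x)), key=lambda k: abs(deltas[k]))
        f[i] += 1 if deltas[i] >= 0 else -1
    return [float(v) for v in f]

def nearest_e8(x):
    """Closest E8 point: best of the D8 coset and the (D8 + 1/2) coset."""
    a = nearest_dn(x)
    b = [v + 0.5 for v in nearest_dn([v - 0.5 for v in x])]

    def dist2(p):
        return sum((u - v) ** 2 for u, v in zip(x, p))

    return a if dist2(a) <= dist2(b) else b

print(nearest_e8([0.4] * 8))          # [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
print(nearest_e8([0.9] + [0.0] * 7))  # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

The second coset is what separates E8 from D8: a point near the all-halves vector, like the first example, is captured only by the half-integer branch.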
Test Documentation
For detailed test coverage and results, see Test Documentation.
Development
```bash
pytest tests/               # Run the test suite
black coset examples        # Format the code
ruff check coset examples   # Lint the code
```