Coset
Hierarchical Nested-Lattice Quantization for PyTorch
A high-performance PyTorch library implementing Hierarchical Nested-Lattice Quantization (HNLQ) for quantization-aware training (QAT) with transformer models.
Features
- E8 Lattice Support: Quantization onto the 8-dimensional E8 lattice with optimized algorithms
- Transformer Integration: Pre-trained BERT with quantized classification heads
- QAT with Cold Start: Gradual quantization activation for stable training
- Constructor-Based API: Easy-to-use layer constructors for different lattices
- Flexible Scale Parameters: Learnable or fixed scale parameters for quantization
- Comprehensive Examples: Binary and multi-class classification examples
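To make the E8 feature more concrete, here is a minimal, dependency-free sketch of nearest-point quantization onto E8, using the classical construction of E8 as the union of D8 and D8 shifted by 1/2 in every coordinate. This is not Coset's implementation; the function names `nearest_d8`, `nearest_e8`, and `quantize_e8`, and the `scale` argument, are hypothetical and chosen only for illustration.

```python
import math


def nearest_d8(x):
    """Nearest point of D8: integer vectors whose coordinates sum to an even number."""
    rounded = [math.floor(v + 0.5) for v in x]  # round to nearest, ties go up
    if sum(rounded) % 2 == 0:
        return rounded
    # Odd sum: re-round the coordinate with the largest rounding error
    # in the other direction, which restores even parity at minimal cost.
    errors = [v - r for v, r in zip(x, rounded)]
    k = max(range(len(x)), key=lambda i: abs(errors[i]))
    rounded[k] += 1 if errors[k] >= 0 else -1
    return rounded


def nearest_e8(x):
    """Nearest point of E8 = D8 union (D8 + 1/2), for an 8-dimensional input."""
    if len(x) != 8:
        raise ValueError("E8 quantization expects exactly 8 coordinates")
    # Candidate from the integer coset.
    a = [float(v) for v in nearest_d8(x)]
    # Candidate from the half-integer coset: shift, quantize, shift back.
    b = [v + 0.5 for v in nearest_d8([u - 0.5 for u in x])]
    dist_a = sum((u - v) ** 2 for u, v in zip(x, a))
    dist_b = sum((u - v) ** 2 for u, v in zip(x, b))
    return a if dist_a <= dist_b else b


def quantize_e8(x, scale=1.0):
    """Quantize x onto the scaled lattice scale * E8 (scale is a hypothetical parameter)."""
    return [scale * v for v in nearest_e8([u / scale for u in x])]


print(nearest_e8([0.4] * 8))  # [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
```

A layer that quantizes weights this way would typically split each weight row into blocks of 8 values and quantize every block independently; the `scale` argument plays the role of the fixed or learnable scale parameter mentioned above, under the assumption that scaling is applied per block.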
Installation
# Clone the repository
git clone https://github.com/coset/coset.git
cd coset
# Create virtual environment
python3 -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install dependencies
pip install -e .
pip install transformers torchvision scikit-learn matplotlib
Quick Start
import torch
from transformers import AutoTokenizer, AutoModel
from coset.core.e8.layers import create_e8_hnlq_linear
# Create a quantized BERT classifier
class QuantizedBERTClassifier(torch.nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        # Freeze the pre-trained encoder so only the quantized head is trained
        for param in self.bert.parameters():
            param.requires_grad = False
        # Quantized classification head on top of BERT's 768-dim pooled output
        self.classifier = create_e8_hnlq_linear(
            in_dim=768, out_dim=num_classes,
            warmup_epochs=2, enable_diagnostics=True,
            weight_clip_value=2.0, theta_trainable=True
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.classifier(outputs.pooler_output)
        # Return both the raw logits and their sigmoid probabilities
        return logits, self.sigmoid(logits)

model = QuantizedBERTClassifier(num_classes=1)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
Documentation
- Full Documentation — API reference, configuration, and usage patterns
- Publication — Research paper and citation information
- Test Documentation — Testing overview and results
Examples
python examples/bert_binary_classifier_e8.py
python examples/bert_multiclass_classifier_e8.py
python examples/qat_cold_start_comparison.py