Coset

Hierarchical Nested-Lattice Quantization for PyTorch

A high-performance PyTorch library implementing Hierarchical Nested-Lattice Quantization (HNLQ) for quantization-aware training (QAT) of transformer models.
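HNLQ stores each lattice point as a short sequence of coarse-to-fine digits. Below is a minimal sketch that assumes the standard nested-lattice construction. First, quantize x/β to a lattice Λ. Then repeatedly split the current point g into a coset representative b = g − q·Q(g/q) of Λ/qΛ and a coarser point Q(g/q). Decoding sums q^m·b_m. That sum telescopes back to the original lattice point whenever the final coarse point is zero; otherwise the encoder flags overload. For brevity the sketch uses the integer lattice Z^n (plain rounding) in place of E8. `hnlq_encode`, `hnlq_decode` and their parameters are hypothetical names, not Coset's API.

```python
import math

def round_zn(v):
    """Nearest point of the integer lattice Z^n (stand-in for E8 in this sketch)."""
    return [math.floor(c + 0.5) for c in v]

def hnlq_encode(x, q=4, levels=3, scale=1.0, quantize=round_zn):
    """Encode x as `levels` digits, each a coset representative of Lambda / q*Lambda.

    Returns (digits, overloaded). When overloaded is False, decoding reproduces
    quantize(x / scale) * scale exactly.
    """
    g = quantize([c / scale for c in x])            # finest lattice point
    digits = []
    for _ in range(levels):
        coarse = quantize([c / q for c in g])       # next, coarser lattice point
        digits.append([a - q * b for a, b in zip(g, coarse)])  # b_m = g - q*Q(g/q)
        g = coarse
    overloaded = any(c != 0 for c in g)             # residue left after the last level
    return digits, overloaded

def hnlq_decode(digits, q=4, scale=1.0):
    """Reconstruct sum_m q**m * b_m, then undo the scale."""
    n = len(digits[0])
    out = [0.0] * n
    for m, b in enumerate(digits):
        for i in range(n):
            out[i] += (q ** m) * b[i]
    return [scale * c for c in out]

digits, overloaded = hnlq_encode([0.3, -1.2, 2.7, 5.6], q=4, levels=3)
print(overloaded, hnlq_decode(digits, q=4))  # False [0.0, -1.0, 3.0, 6.0]
```

Because an n-dimensional lattice has q^n cosets in Λ/qΛ, each level costs log2(q) bits per dimension, so q = 4 with three levels costs 6 bits per coordinate. If `levels` is too small for the input's magnitude, `overloaded` is True and the decoded vector should not be trusted.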

Features

  • E8 Lattice Support: Quantization on the 8-dimensional E8 lattice using optimized algorithms
  • Transformer Integration: Pre-trained BERT with quantized classification heads
  • QAT with Cold Start: Quantization is activated gradually over a warm-up period to keep training stable
  • Constructor-Based API: Easy-to-use layer constructors for different lattices
  • Flexible Scale Parameters: Learnable or fixed scale parameters for quantization
  • Comprehensive Examples: Binary and multi-class classification examples
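E8 is the 8-dimensional lattice formed by the union of D8 (integer vectors with an even coordinate sum) and its translate D8 + (1/2, ..., 1/2). That decomposition yields a fast, exact nearest-point rule due to Conway and Sloane: decode to each of the two cosets and keep the closer result. Below is a minimal pure-Python sketch of that classic decoder. Coset's own optimized implementation is not shown in this README and may differ, and `nearest_dn` / `nearest_e8` are illustrative names.

```python
import math

def nearest_dn(x):
    """Nearest point of D_n = {z in Z^n : sum(z) is even} (Conway and Sloane)."""
    f = [math.floor(v + 0.5) for v in x]  # round every coordinate
    if sum(f) % 2 != 0:
        # Odd parity: re-round the coordinate with the largest rounding error
        # to its other integer neighbour, which fixes the parity at minimal cost.
        k = max(range(len(x)), key=lambda i: abs(x[i] - f[i]))
        f[k] += 1 if x[k] > f[k] else -1
    return [float(v) for v in f]

def nearest_e8(x):
    """Nearest point of E8 = D8 union (D8 + (1/2, ..., 1/2))."""
    if len(x) != 8:
        raise ValueError("E8 quantization expects blocks of 8 values")
    y0 = nearest_dn(x)                                          # integer coset
    y1 = [v + 0.5 for v in nearest_dn([v - 0.5 for v in x])]    # half-integer coset
    d0 = sum((a - b) ** 2 for a, b in zip(x, y0))
    d1 = sum((a - b) ** 2 for a, b in zip(x, y1))
    return y0 if d0 <= d1 else y1

print(nearest_e8([0.4] * 8))  # [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
```

Since the covering radius of E8 is 1, every returned point lies within Euclidean distance 1 of its input, which makes a convenient sanity check.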

Installation

# Clone the repository
git clone https://github.com/coset/coset.git
cd coset

# Create virtual environment
python3 -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -e .
pip install transformers torchvision scikit-learn matplotlib

Quick Start

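import torch
from transformers import AutoTokenizer, AutoModel
from coset.core.e8.layers import create_e8_hnlq_linear

# Quantized BERT classifier: frozen pre-trained encoder + E8 HNLQ linear head
class QuantizedBERTClassifier(torch.nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        # Freeze the encoder so that only the quantized head is trained
        for param in self.bert.parameters():
            param.requires_grad = False
        # in_dim=768 matches the hidden size of bert-base-uncased
        self.classifier = create_e8_hnlq_linear(
            in_dim=768, out_dim=num_classes,
            warmup_epochs=2, enable_diagnostics=True,
            weight_clip_value=2.0, theta_trainable=True
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # pooler_output: [batch, 768] pooled representation of the [CLS] token
        logits = self.classifier(outputs.pooler_output)
        # Return raw logits (e.g. for BCEWithLogitsLoss) and sigmoid probabilities
        return logits, self.sigmoid(logits)

model = QuantizedBERTClassifier(num_classes=1)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Forward pass on a small batch
batch = tokenizer(["a great movie", "a dull plot"], padding=True, return_tensors='pt')
logits, probs = model(batch['input_ids'], batch['attention_mask'])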
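The Quick Start passes `warmup_epochs=2` and `weight_clip_value=2.0`, but this README does not spell out how the cold-start warm-up behaves internally. One plausible scheme blends full-precision and quantized weights with a factor that ramps linearly from 0 to 1 over the warm-up epochs. The sketch below assumes that scheme and uses plain scalar rounding on a fixed grid instead of E8. `warmup_alpha`, `quantize_scalar`, `effective_weights` and the grid step `scale=0.25` are all hypothetical, not Coset's API.

```python
import math

def warmup_alpha(epoch, warmup_epochs):
    """Hypothetical linear ramp: 0.0 at epoch 0, 1.0 from epoch == warmup_epochs on."""
    if warmup_epochs <= 0:
        return 1.0
    return min(1.0, epoch / warmup_epochs)

def quantize_scalar(w, scale, clip):
    """Stand-in quantizer (not E8): clip to [-clip, clip], then round to the grid scale * Z."""
    w = max(-clip, min(clip, w))
    return scale * math.floor(w / scale + 0.5)

def effective_weights(weights, epoch, warmup_epochs=2, scale=0.25, clip=2.0):
    """Blend full-precision and quantized weights during the cold-start phase."""
    alpha = warmup_alpha(epoch, warmup_epochs)
    return [(1 - alpha) * w + alpha * quantize_scalar(w, scale, clip) for w in weights]

w = [0.1, -0.9, 3.0]
for epoch in range(3):
    print(epoch, effective_weights(w, epoch))  # epoch 2 gives [0.0, -1.0, 2.0]
```

Ramping in quantization this way means the first updates see nearly full-precision weights, so early gradients are not dominated by rounding noise. A hard switch after `warmup_epochs` is a simpler alternative.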

Documentation

Examples

Run the example scripts from the repository root, with the virtual environment active:

python examples/bert_binary_classifier_e8.py
python examples/bert_multiclass_classifier_e8.py
python examples/qat_cold_start_comparison.py