Coset Documentation

Detailed project documentation for the Coset (Hierarchical Nested-Lattice Quantization) library.

Overview

Coset is a PyTorch library implementing Hierarchical Nested-Lattice Quantization (HNLQ) for quantization-aware training (QAT). It provides:

  • E8 and D4 lattice support
  • Integration with transformer models (BERT)
  • QAT with cold start for stable training

Installation

Quick Setup

git clone https://github.com/coset/coset.git
cd coset
python3 -m venv venv
source venv/bin/activate
pip install -e .
pip install transformers torchvision scikit-learn matplotlib

API Reference

Constructor Functions

from coset.core.e8.layers import create_e8_hnlq_linear

layer = create_e8_hnlq_linear(
    in_dim=768,
    out_dim=10,
    q=4,                    # Quantization parameter
    M=2,                    # Hierarchical levels
    warmup_epochs=2,
    enable_diagnostics=True,
    weight_clip_value=2.0,
    theta_trainable=True,
    theta_init_value=0.0
)
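
The parameters q and M together set the storage cost of the quantized weights. The sketch below is a back-of-the-envelope estimate, not part of the Coset API. It assumes each of the M hierarchical levels stores one coset index of L/qL, which has q^d elements for a d-dimensional lattice L. The helper names are hypothetical, and any overhead for scale parameters is ignored.

```python
import math


def bits_per_block(dim: int, q: int, M: int) -> float:
    """Bits needed for one dim-dimensional weight block.

    Assumption: each of the M levels stores one index into L/qL,
    which has q**dim cosets, i.e. dim * log2(q) bits per level.
    """
    return M * dim * math.log2(q)


def bits_per_weight(q: int, M: int) -> float:
    """Average bits per scalar weight under the same assumption."""
    return M * math.log2(q)


E8_DIM = 8  # E8 quantizes weights in blocks of 8

block_bits = bits_per_block(E8_DIM, q=4, M=2)
weight_bits = bits_per_weight(q=4, M=2)
print(block_bits)   # 32.0
print(weight_bits)  # 4.0
```

Under these assumptions, the q=4, M=2 setting used above corresponds to roughly 4 bits per weight, compared with 32 bits for a float32 weight.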

Scale Parameter Options

Learnable scale parameters (default):

layer = create_e8_hnlq_linear(
    in_dim=768, out_dim=10,
    theta_trainable=True,
    theta_init_value=0.0
)

Fixed scale parameters (deterministic):

layer = create_e8_hnlq_linear(
    in_dim=768, out_dim=10,
    theta_trainable=False,
    theta_init_value=0.0
)

QAT Methods

layer.update_epoch(epoch)    # Update epoch for cold start
layer.enable_quantization()
layer.disable_quantization()

# Diagnostics
diagnostics = layer.get_diagnostic_summary()
quant_error = layer.get_quantization_error()
weight_stats = layer.get_weight_statistics()
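
To illustrate how update_epoch and warmup_epochs plausibly interact, here is a minimal stand-in written in plain Python. It is not the library's implementation: the ColdStartGate class and the rule "train in full precision until epoch >= warmup_epochs" are assumptions inferred from the parameter names, so check the layer source for the actual schedule.

```python
class ColdStartGate:
    """Hypothetical stand-in for an epoch-based cold-start gate (not Coset code)."""

    def __init__(self, warmup_epochs: int) -> None:
        if warmup_epochs < 0:
            raise ValueError("warmup_epochs must be non-negative")
        self.warmup_epochs = warmup_epochs
        self.epoch = 0

    def update_epoch(self, epoch: int) -> None:
        # Called once at the start of every training epoch.
        self.epoch = epoch

    @property
    def quantization_active(self) -> bool:
        # Assumed rule: full precision during warmup, quantized afterwards.
        return self.epoch >= self.warmup_epochs


gate = ColdStartGate(warmup_epochs=2)
schedule = []
for epoch in range(4):
    gate.update_epoch(epoch)
    schedule.append(gate.quantization_active)

print(schedule)  # [False, False, True, True]
```

With warmup_epochs=2, the first two epochs would run unquantized, and quantization would apply from the third epoch onward.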

Configuration

from coset.core.base import LatticeConfig

config = LatticeConfig(
    lattice_type="E8",
    q=4,
    M=2,
    beta=1.0,
    alpha=1.0,
    decoding="full",
    check_overload=False,
    disable_scaling=False,
    disable_overload_protection=True
)

Usage Patterns

Binary Classification with BERT

import torch
from transformers import AutoTokenizer, AutoModel
from coset.core.e8.layers import create_e8_hnlq_linear

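class QuantizedBERTClassifier(torch.nn.Module):
    """Frozen BERT encoder with an E8-quantized linear classification head."""

    def __init__(self, num_classes=1):
        super().__init__()
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        # Freeze the BERT backbone so that only the quantized head is trained.
        for param in self.bert.parameters():
            param.requires_grad = False
        # 768 is the hidden size of bert-base-uncased.
        self.classifier = create_e8_hnlq_linear(
            in_dim=768, out_dim=num_classes,
            warmup_epochs=2, enable_diagnostics=True,
            weight_clip_value=2.0, theta_trainable=True
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # pooler_output is BERT's pooled [CLS] representation.
        logits = self.classifier(outputs.pooler_output)
        # Return raw logits (for BCEWithLogitsLoss) and sigmoid probabilities.
        return logits, self.sigmoid(logits)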

Training Loop

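import torch
import torch.optim as optim

# Assumes QuantizedBERTClassifier (defined above) and a train_loader that
# yields (input_ids, attention_mask, labels) batches.
model = QuantizedBERTClassifier()

# Only the quantized head is optimized; the BERT backbone is frozen.
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)
criterion = torch.nn.BCEWithLogitsLoss()

model.train()
for epoch in range(15):
    model.classifier.update_epoch(epoch)  # Advance the cold-start schedule
    for batch in train_loader:
        input_ids, attention_mask, labels = batch
        optimizer.zero_grad()
        logits, _ = model(input_ids, attention_mask)
        # squeeze(-1) keeps the batch dimension even when the batch size is 1;
        # BCEWithLogitsLoss expects float targets.
        loss = criterion(logits.squeeze(-1), labels.float())
        loss.backward()
        optimizer.step()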

Examples

Run the provided examples:

python examples/bert_binary_classifier_e8.py
python examples/bert_multiclass_classifier_e8.py
python examples/qat_cold_start_comparison.py
python examples/mnist_mlp_e8.py

Performance

Metric                          Value
Training Speed                  ~4.2s per epoch (batch size 128)
Parameter Reduction             99.3% (quantized head only)
Binary Classification Accuracy  ~84%
Multi-Class Accuracy            ~91%

Library Components

Core Modules

  • coset.core.e8.layers — E8 lattice linear layers
  • coset.core.e8.codecs — E8 encoding/decoding
  • coset.core.base — Base configuration and utilities
  • coset.core.d4.layers — D4 lattice layers

Lattice Types

  • E8 — 8D optimal lattice, high precision
  • D4 — 4D checkerboard lattice
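
For intuition about what quantizing to a lattice means, the sketch below finds the nearest D4 point to a 4-dimensional vector. D4 is the set of integer vectors whose coordinates sum to an even number. The code uses the standard Conway-Sloane rounding rule for D_n lattices. It is a standalone illustration with hypothetical function names, not the decoder in coset.core.d4, which may differ.

```python
import math


def round_half_up(x: float) -> int:
    # Deterministic rounding; Python's built-in round() rounds ties to even.
    return math.floor(x + 0.5)


def nearest_d4_point(x):
    """Return the closest point of D4 (integer vectors with even sum) to x."""
    if len(x) != 4:
        raise ValueError("D4 points have exactly 4 coordinates")
    rounded = [round_half_up(v) for v in x]
    if sum(rounded) % 2 == 0:
        return rounded  # Plain rounding already landed on the lattice.
    # Odd sum: re-round the coordinate with the largest rounding error
    # in the opposite direction, which flips the parity at minimal cost.
    errors = [v - r for v, r in zip(x, rounded)]
    k = max(range(4), key=lambda i: abs(errors[i]))
    rounded[k] += 1 if errors[k] > 0 else -1
    return rounded


print(nearest_d4_point([0.6, 0.1, 0.1, 0.1]))   # [0, 0, 0, 0]
print(nearest_d4_point([0.9, 0.2, 0.1, 0.1]))   # [1, 1, 0, 0]
print(nearest_d4_point([1.2, 0.9, -0.1, 0.0]))  # [1, 1, 0, 0]
```

In the first call, plain rounding gives [1, 0, 0, 0], which has an odd sum and is therefore not in D4. Re-rounding the first coordinate (the one with the largest error) downward yields the lattice point [0, 0, 0, 0].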

Test Documentation

For detailed test coverage and results, see Test Documentation.

Development

pytest tests/
black coset examples
ruff check coset examples