xbtorch.quant.wage_qtorch
WAGE quantization routines adapted from QPyTorch example: https://github.com/Tiiiger/QPyTorch/blob/master/examples/WAGE/wage_qtorch.py
This module implements fixed-point quantization for weights, activations, and gradients, as well as a WAGEQuantizer wrapper for PyTorch modules.
Functions
- shift(x)
Scales tensor x to prevent overflow during fixed-point quantization.
- C(x, bits)
Clamps tensor x to the representable range given a bit width.
- QW(x, bits, scale=1.0, mode='nearest')
Quantizes weights x using fixed-point representation with scaling.
- QG(x, bits_G, lr, mode='nearest')
Quantizes gradients x to a fixed-point representation for WAGE updates.
Classes
- WAGEQuantizer
PyTorch module that performs activation and error quantization using WAGE.
- xbtorch.quant.wage_qtorch.C(x, bits)[source]
Clamps the input to a representable range given the number of bits.
- Parameters:
x (torch.Tensor) – Input tensor.
bits (int) – Number of quantization bits.
- Returns:
Clamped tensor.
- Return type:
torch.Tensor
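A minimal usage sketch of C; the tensor values and bit-width below are illustrative.

```python
import torch
from xbtorch.quant.wage_qtorch import C

x = torch.tensor([-1.5, -0.3, 0.0, 0.7, 2.0])
# Clamp to the range representable with 8 fixed-point bits;
# out-of-range values are saturated, in-range values pass through.
x_clamped = C(x, bits=8)
```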
- xbtorch.quant.wage_qtorch.QG(x, bits_G, lr, mode='nearest')[source]
Quantizes gradients x for WAGE updates.
- Parameters:
x (torch.Tensor) – Gradient tensor.
bits_G (int) – Bit-width of gradient representation.
lr (float) – Learning rate scaling factor.
mode (str, optional) – Rounding mode: “nearest” or “stochastic”.
- Returns:
Quantized gradients.
- Return type:
torch.Tensor
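A minimal usage sketch of QG; the gradient tensor and hyperparameter values are illustrative.

```python
import torch
from xbtorch.quant.wage_qtorch import QG

grad = torch.randn(64, 64) * 1e-3  # stand-in for a weight gradient
# Quantize the gradient to an 8-bit fixed-point representation for the
# WAGE update, folding in the learning rate and using stochastic rounding.
grad_q = QG(grad, bits_G=8, lr=0.1, mode="stochastic")
```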
- xbtorch.quant.wage_qtorch.QW(x, bits, scale=1.0, mode='nearest')[source]
Quantizes weights x to fixed-point representation.
- Parameters:
x (torch.Tensor) – Weight tensor to be quantized.
bits (int) – Bit-width of weight representation.
scale (float, optional) – Scaling factor to normalize the layer weights.
mode (str, optional) – Rounding mode: “nearest” or “stochastic”.
- Returns:
Quantized weights.
- Return type:
torch.Tensor
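A minimal usage sketch of QW; the weight tensor and bit-width are illustrative.

```python
import torch
from xbtorch.quant.wage_qtorch import QW

w = torch.empty(128, 64).uniform_(-0.5, 0.5)  # stand-in for layer weights
# Quantize the weights to 8 fixed-point bits with nearest rounding;
# `scale` normalizes the layer's weight magnitude before quantization.
w_q = QW(w, bits=8, scale=1.0, mode="nearest")
```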
- class xbtorch.quant.wage_qtorch.WAGEQuantizer(bits_A, bits_E, A_mode='nearest', E_mode='nearest')[source]
Bases: Module
PyTorch module that performs WAGE quantization for activations and errors using QPyTorch’s quantizer.
- Parameters:
bits_A (int) – Bit-width for activation quantization.
bits_E (int) – Bit-width for error/gradient quantization.
A_mode (str, optional) – Rounding mode for activations.
E_mode (str, optional) – Rounding mode for errors.
- forward(x)[source]
Defines the computation performed at every call. Quantizes the activation tensor x on the forward pass; error quantization is applied to the incoming gradient during the backward pass.
Note
Although the forward computation is defined within this function, the Module instance itself should be called instead of forward(), since the instance call runs the registered hooks while a direct call to forward() silently ignores them.
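A minimal usage sketch of WAGEQuantizer as an in-network quantization stage; the layer shape, batch size, and bit-widths are illustrative.

```python
import torch
import torch.nn as nn
from xbtorch.quant.wage_qtorch import WAGEQuantizer

# Insert the quantizer after a layer: 8-bit activations on the forward
# pass, 8-bit errors (gradients w.r.t. activations) on the backward pass.
layer = nn.Linear(100, 10)
quantizer = WAGEQuantizer(bits_A=8, bits_E=8, A_mode="nearest", E_mode="stochastic")

x = torch.randn(32, 100, requires_grad=True)
y = quantizer(layer(x))   # activations quantized here
y.sum().backward()        # error quantization applied during the backward pass
```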