xbtorch.quant.wage_init
Initialization routines for WAGE quantized networks, adapted from QPyTorch: https://github.com/Tiiiger/QPyTorch/blob/master/examples/WAGE/wage_qtorch.py
Functions
- truncated_normal_(tensor, mean=0, std=1)
Fill a tensor with values drawn from a truncated normal distribution.
- scale_limit(param, limit, bits_W)
Compute a weight scaling factor based on the WAGE quantization bit width.
- wage_init_(tensor, bits_W, factor=2.0, mode='fan_in')
Initialize a tensor for WAGE quantization with appropriate scaling limits.
- xbtorch.quant.wage_init.scale_limit(param, limit, bits_W)[source]
Compute a scaling factor for WAGE weight initialization based on the quantization bit-width and the maximum absolute weight.
- Parameters:
param (torch.nn.Parameter or torch.Tensor) – Parameter tensor whose .weight_scale attribute is to be set.
limit (float) – Maximum absolute value for weight initialization.
bits_W (int) – Bit width of the weight representation.
- Returns:
limit – Scaled maximum absolute value for weight initialization.
- Return type:
float
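The rule can be sketched in plain Python. This is a minimal illustration, assuming the function follows the QPyTorch WAGE example this module is adapted from. The constant `beta = 1.5`, the power-of-two rounding, and the name `scale_limit_sketch` are assumptions for illustration, not confirmed details of xbtorch.

```python
import math
from types import SimpleNamespace


def scale_limit_sketch(param, limit, bits_W):
    """Hypothetical stand-in for scale_limit, modelled on the QPyTorch WAGE example."""
    beta = 1.5  # assumption: constant taken from the QPyTorch reference
    # Minimum limit allowed for a bits_W-bit weight representation.
    w_min = beta / (2 ** (bits_W - 1))
    # Power-of-two scale that compensates when the limit is raised to w_min.
    scale = 2.0 ** round(math.log2(w_min / limit))
    param.weight_scale = max(scale, 1.0)
    return max(w_min, limit)


# A bare object stands in for a torch.nn.Parameter here.
param = SimpleNamespace()
new_limit = scale_limit_sketch(param, limit=0.1, bits_W=2)
print(new_limit, param.weight_scale)
```

Under these assumptions, a 2-bit weight with a requested limit of 0.1 has its limit raised to 0.75 and receives a scale of 8.0, while an 8-bit weight keeps the limit of 0.1 and a scale of 1.0.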
- xbtorch.quant.wage_init.truncated_normal_(tensor, mean=0, std=1)[source]
Fill the input tensor with values drawn from a truncated normal distribution. Values are truncated to within two standard deviations of the mean, i.e. the interval [mean - 2*std, mean + 2*std].
- Parameters:
tensor (torch.Tensor) – Tensor to fill.
mean (float, default=0) – Mean of the normal distribution.
std (float, default=1) – Standard deviation of the normal distribution.
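One common way to obtain such samples is rejection sampling: draw from a standard normal, discard anything outside two standard deviations, then shift and scale. The sketch below works on a plain Python list rather than filling a torch.Tensor in place, and the name `truncated_normal_sketch` and its `seed` argument are hypothetical. It illustrates the distribution only and is not necessarily how truncated_normal_ is implemented.

```python
import random


def truncated_normal_sketch(n, mean=0.0, std=1.0, seed=None):
    """Hypothetical list-based illustration using rejection sampling."""
    rng = random.Random(seed)
    values = []
    while len(values) < n:
        z = rng.gauss(0.0, 1.0)  # standard normal draw
        if -2.0 <= z <= 2.0:     # keep draws within two standard deviations
            values.append(mean + std * z)
    return values


samples = truncated_normal_sketch(1000, mean=0.5, std=0.1, seed=0)
print(min(samples), max(samples))
```

With mean=0.5 and std=0.1, every sample falls between 0.3 and 0.7.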
- xbtorch.quant.wage_init.wage_init_(tensor, bits_W, factor=2.0, mode='fan_in')[source]
Initialize a tensor for WAGE quantization using uniform distribution with limits determined by fan-in and quantization bit-width.
- Parameters:
tensor (torch.Tensor) – Tensor to initialize.
bits_W (int) – Bit-width of the weight representation.
factor (float, default=2.0) – Scaling factor for weight initialization.
mode (str, default="fan_in") – Initialization mode. Currently, only "fan_in" is supported.
- Raises:
NotImplementedError – If mode is not "fan_in".
ValueError – If tensor has fewer than 2 dimensions.
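The limit calculation can be sketched from a weight shape alone. This is a minimal illustration, assuming the QPyTorch WAGE formula sqrt(3 * factor / fan_in) and the lower bound 1.5 / 2^(bits_W - 1). Both are taken from the reference this module is adapted from and are assumptions about xbtorch, as is the helper name `wage_limit_sketch`.

```python
import math


def wage_limit_sketch(shape, bits_W, factor=2.0, mode="fan_in"):
    """Hypothetical helper: uniform limit for a weight of the given shape."""
    if mode != "fan_in":
        raise NotImplementedError("only 'fan_in' is supported")
    if len(shape) < 2:
        raise ValueError("tensor must have at least 2 dimensions")
    # fan-in = input channels times receptive field size (1 for linear layers).
    fan_in = shape[1] * math.prod(shape[2:])
    float_limit = math.sqrt(3.0 * factor / fan_in)  # assumption: QPyTorch formula
    w_min = 1.5 / (2 ** (bits_W - 1))               # assumption: beta = 1.5
    return max(w_min, float_limit)


print(wage_limit_sketch((64, 3, 3, 3), bits_W=8))  # conv weight, fan_in = 27
print(wage_limit_sketch((10, 128), bits_W=2))      # linear weight, fan_in = 128
```

Weights would then be drawn uniformly from [-limit, limit]. At low bit-widths the lower bound dominates: the 2-bit linear layer above gets a limit of 0.75 rather than the fan-in value of roughly 0.22.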