xbtorch.quant.wage_init

Initialization routines for WAGE quantized networks, adapted from QPyTorch: https://github.com/Tiiiger/QPyTorch/blob/master/examples/WAGE/wage_qtorch.py

Functions

truncated_normal_(tensor, mean=0, std=1)

Fill a tensor with values drawn from a truncated normal distribution.

scale_limit(param, limit, bits_W)

Compute a weight scaling factor based on the WAGE quantization bit width.

wage_init_(tensor, bits_W, factor=2.0, mode="fan_in")

Initialize a tensor for WAGE quantization with appropriate scaling limits.

xbtorch.quant.wage_init.scale_limit(param, limit, bits_W)

Compute a scaling factor for WAGE weight initialization based on the quantization bit-width and the maximum absolute weight.

Parameters:
  • param (torch.nn.Parameter or torch.Tensor) – Parameter tensor whose .weight_scale attribute is set by this function.

  • limit (float) – Maximum absolute value for weight initialization.

  • bits_W (int) – Bit-width of the weight representation.

Returns:

limit – Scaled maximum absolute value for weight initialization.

Return type:

float
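The exact formula is not stated on this page. As a rough illustration only, the sketch below follows the QPyTorch WAGE example that this module cites as its origin; the constant beta = 1.5, the power-of-two rounding, and the helper name sketch_scale_limit are assumptions, not the confirmed xbtorch implementation. The sketch returns the scale as a second value rather than storing it on param.weight_scale.

```python
import math


def sketch_scale_limit(limit, bits_W, beta=1.5):
    """Hypothetical stand-in for scale_limit; returns (scaled_limit, weight_scale)."""
    # Assumed minimum useful limit for this bit width: beta quantization steps,
    # where one step is 1 / 2**(bits_W - 1).
    w_min = beta / 2 ** (bits_W - 1)
    # Assumed power-of-two scale that compensates when the requested limit
    # is smaller than w_min; never below 1.
    scale = max(2.0 ** round(math.log2(w_min / limit)), 1.0)
    # The limit itself is raised to w_min if it was too small.
    return max(w_min, limit), scale


# With 8 bits the requested limit already exceeds w_min, so nothing changes.
print(sketch_scale_limit(0.1, 8))
# With 2 bits the limit is raised to w_min and a scale above 1 is reported.
print(sketch_scale_limit(0.1, 2))
```

Under these assumptions, a low bit-width raises the limit so the initial weights span more than one quantization step, and the scale records how much the limit was inflated.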

xbtorch.quant.wage_init.truncated_normal_(tensor, mean=0, std=1)

Fill the input tensor in place with values drawn from a truncated normal distribution. Values are restricted to within two standard deviations of the mean.

Parameters:
  • tensor (torch.Tensor) – Tensor to fill.

  • mean (float, default=0) – Mean of the normal distribution.

  • std (float, default=1) – Standard deviation of the normal distribution.
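To make the truncation concrete, here is a minimal pure-Python sketch that draws samples by rejection: it keeps a normal draw only if it lies within two standard deviations of the mean. Rejection sampling is a technique chosen for illustration; how truncated_normal_ samples internally is not stated here, and the helper name sketch_truncated_normal and its seed argument are hypothetical.

```python
import random


def sketch_truncated_normal(n, mean=0.0, std=1.0, seed=None):
    """Hypothetical helper: n samples from N(mean, std) kept within 2 std of mean."""
    rng = random.Random(seed)
    lo, hi = mean - 2 * std, mean + 2 * std
    values = []
    while len(values) < n:
        x = rng.gauss(mean, std)
        # Discard draws outside the truncation interval and try again.
        if lo <= x <= hi:
            values.append(x)
    return values


samples = sketch_truncated_normal(5, mean=1.0, std=0.5, seed=0)
# Every sample lies in [0.0, 2.0], i.e. within two standard deviations of 1.0.
print(all(0.0 <= s <= 2.0 for s in samples))
```

Roughly 95% of normal draws fall inside the two-sigma interval, so the rejection loop discards few samples.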

xbtorch.quant.wage_init.wage_init_(tensor, bits_W, factor=2.0, mode='fan_in')

Initialize a tensor in place for WAGE quantization by sampling from a uniform distribution whose limits are determined by the tensor's fan-in and the quantization bit-width.

Parameters:
  • tensor (torch.Tensor) – Tensor to initialize.

  • bits_W (int) – Bit-width of the weight representation.

  • factor (float, default=2.0) – Scaling factor for weight initialization.

  • mode (str, default="fan_in") – Initialization mode. Currently, only "fan_in" is supported.

Raises:
  • NotImplementedError – If mode is not "fan_in".

  • ValueError – If tensor has fewer than 2 dimensions.
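The sketch below illustrates how a fan-in based uniform limit might be derived before the bit-width adjustment is applied. It assumes the usual PyTorch weight layout (out_features, in_features, *kernel_size) and the limit sqrt(3 * factor / fan_in) used in the QPyTorch WAGE example this module cites; neither is confirmed for xbtorch on this page, and the helper names sketch_fan_in and sketch_uniform_limit are hypothetical.

```python
import math


def sketch_fan_in(shape):
    """Hypothetical helper: fan-in of a weight with the given shape."""
    if len(shape) < 2:
        # Mirrors the documented ValueError for tensors with fewer than 2 dims.
        raise ValueError("fan-in needs a tensor with at least 2 dimensions")
    # Product of the kernel dimensions; math.prod of an empty tuple is 1,
    # which covers plain linear layers.
    receptive_field = math.prod(shape[2:])
    return shape[1] * receptive_field


def sketch_uniform_limit(shape, factor=2.0):
    """Assumed float limit for uniform sampling in [-limit, limit]."""
    return math.sqrt(3.0 * factor / sketch_fan_in(shape))


print(sketch_fan_in((64, 128)))        # linear layer: 128 inputs
print(sketch_fan_in((16, 3, 3, 3)))    # conv layer: 3 channels * 3 * 3 kernel
print(sketch_uniform_limit((10, 6)))   # sqrt(6 / 6)
```

With the default factor of 2.0, this limit reduces to sqrt(6 / fan_in). In the cited example, a limit of this kind is then passed through the bit-width adjustment before the tensor is filled uniformly.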