Source code for xbtorch

"""
XBTorch root package
=====================

This module provides the root API for XBTorch, including:

- The `XBParams` singleton class for global configuration.
- Helper functions to get/set parameters and initialize the library.
- Lists of supported activation layers, parameterized layers, and parameter-less layers.
"""

from functools import partial

import torch.nn as nn

import xbtorch
import xbtorch.quant.wage_qtorch as wage_qtorch

class XBParams:
    """
    Singleton class to store global XBTorch parameters.

    This class ensures a single global configuration dictionary that
    controls decomposition, device, quantization, weight ranges, and
    accelerators.

    Attributes
    ----------
    _global_dict : dict
        Dictionary storing all global parameters and flags.
    _wage_defaults : dict
        Default settings for WAGE quantization.
    """

    _instance = None
    _global_dict = {
        'initialized': False,
        'decomposition_algorithm': None,
        'device_type': None,
        'pytorch_device': 'cpu',
        'inference_accelerator': None,
        'weight_range': (-1, 1),
    }
    _wage_defaults = {
        'wl_weight': 2,  # 2 = ternary weights
        'wl_activation': 8,
        'wl_grad': 8,
        'wl_error': 8,
        'rounding_weight': 'nearest',
        'rounding_activation': 'nearest',
        'rounding_grad': 'nearest',
        'rounding_error': 'nearest',
    }

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(XBParams, cls).__new__(cls)
        return cls._instance
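    # A minimal sketch of the singleton behaviour (illustrative only, not
    # part of the module source): every instantiation returns the same
    # object, so configuration set through one handle is visible through
    # any other.
    #
    # >>> a = XBParams()
    # >>> b = XBParams()
    # >>> a is b
    # True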
    def set_var(self, key, value):
        """Set a global parameter."""
        self._global_dict[key] = value

    def get_var(self, key, default=None):
        """Return a global parameter, or ``default`` if it is not set."""
        return self._global_dict.get(key, default)
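    # Illustrative only: set_var/get_var give dictionary-style access to
    # the shared configuration. The key name 'my_flag' below is a made-up
    # example, not a parameter the library defines.
    #
    # >>> XBParams().set_var('my_flag', 42)
    # >>> XBParams().get_var('my_flag')
    # 42
    # >>> XBParams().get_var('missing_key', 'fallback')
    # 'fallback'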
    def initialize(self, decomposition_algorithm=None, device_type=None,
                   weight_range=(-1, 1), pytorch_device='cpu',
                   wage_quantize=False, wage_params=None,
                   inference_accelerator=None):
        """
        Initialize the XBTorch environment.

        Sets up the decomposition algorithm, device, weight ranges, WAGE
        quantization, and optional inference accelerators. Also migrates
        device tensors to the selected PyTorch device.

        Parameters
        ----------
        decomposition_algorithm : xbtorch.decomposition.base.GenericDecomposition, optional
            Decomposition algorithm to use for layers (default is None).
        device_type : xbtorch.devices.base.GenericDevice, optional
            Hardware device abstraction (default is None).
        weight_range : tuple of float, optional
            Minimum and maximum allowed weights (default is (-1, 1)).
        pytorch_device : str or torch.device, optional
            PyTorch device for tensor allocation (default is 'cpu').
        wage_quantize : bool, optional
            Whether to enable WAGE quantization (default is False).
        wage_params : dict, optional
            Overrides for the WAGE quantization defaults.
        inference_accelerator : xbtorch.deployment.base.GenericAccelerator, optional
            Inference accelerator to use.

        Raises
        ------
        TypeError
            If the provided decomposition_algorithm, device_type, or
            inference_accelerator is not of the expected type, or if
            weight_range is invalid.
        """
        if decomposition_algorithm and not issubclass(
                type(decomposition_algorithm),
                xbtorch.decomposition.base.GenericDecomposition):
            raise TypeError("Invalid decomposition algorithm provided")

        if device_type and not issubclass(
                type(device_type), xbtorch.devices.base.GenericDevice):
            raise TypeError("Invalid device type provided")

        if wage_quantize:
            wage_params = wage_params or {}  # avoid a mutable default argument
            self._global_dict['wage_quantize'] = True
            self._global_dict['wage_params'] = {}
            for key, val in self._wage_defaults.items():
                # fall back to the default for any key the caller did not override
                self._global_dict['wage_params'][key] = wage_params.get(key, val)

            # weight quantizer (wl_weight == -1 disables it)
            if self._global_dict['wage_params']['wl_weight'] == -1:
                self._global_dict['wage_params']['quantizer_weight'] = None
            else:
                self._global_dict['wage_params']['quantizer_weight'] = lambda x, scale: wage_qtorch.QW(
                    x,
                    self._global_dict['wage_params']['wl_weight'],
                    scale,
                    mode=self._global_dict['wage_params']['rounding_weight']
                )

            # gradient quantizer (wl_grad == -1 disables it and gradient clipping)
            if self._global_dict['wage_params']['wl_grad'] == -1:
                self._global_dict['wage_params']['quantizer_grad'] = None
                self._global_dict['wage_params']['grad_clip'] = None
            else:
                self._global_dict['wage_params']['quantizer_grad'] = lambda x, lr: wage_qtorch.QG(
                    x,
                    self._global_dict['wage_params']['wl_grad'],
                    lr,
                    mode=self._global_dict['wage_params']['rounding_grad']
                )
                self._global_dict['wage_params']['grad_clip'] = lambda x: wage_qtorch.C(
                    x, self._global_dict['wage_params']['wl_weight']
                )

            # activation and error quantizer
            self._global_dict['wage_params']['quantizer_act_error'] = partial(
                wage_qtorch.WAGEQuantizer,
                A_mode=self._global_dict['wage_params']['rounding_activation'],
                E_mode=self._global_dict['wage_params']['rounding_error']
            )

        # weight_range must be a (min, max) tuple with min < max
        if weight_range and (not isinstance(weight_range, tuple)
                             or len(weight_range) != 2
                             or weight_range[0] >= weight_range[1]):
            raise TypeError("Invalid weight range provided")

        if inference_accelerator and not issubclass(
                type(inference_accelerator),
                xbtorch.deployment.base.GenericAccelerator):
            raise TypeError("Invalid inference accelerator provided")

        # Setting the variables
        self._global_dict['initialized'] = True
        self._global_dict['decomposition_algorithm'] = decomposition_algorithm
        self._global_dict['device_type'] = device_type
        self._global_dict['pytorch_device'] = pytorch_device
        self._global_dict['weight_range'] = weight_range
        self._global_dict['inference_accelerator'] = inference_accelerator

        # if a device_type was provided, migrate local tensors if needed
        if pytorch_device != 'cpu':
            if device_type and issubclass(type(device_type), xbtorch.devices.base.TabularDevice):
                device_type.set_G = device_type.set_G.to(pytorch_device)
                device_type.reset_G = device_type.reset_G.to(pytorch_device)
                device_type.set_dG = device_type.set_dG.to(pytorch_device)
                device_type.reset_dG = device_type.reset_dG.to(pytorch_device)
                device_type.set_cdf = device_type.set_cdf.to(pytorch_device)
                device_type.reset_cdf = device_type.reset_cdf.to(pytorch_device)
                device_type.min_conductance = device_type.min_conductance.to(pytorch_device)
                device_type.max_conductance = device_type.max_conductance.to(pytorch_device)
        # print('Initialization complete..\n')


def get_xbtorch_param(key, default=None):
    """
    Retrieve a global XBTorch parameter.

    Convenience function to access the `XBParams` singleton.

    Parameters
    ----------
    key : str
        Name of the parameter to retrieve.
    default : any, optional
        Value to return if the key is not found (default is None).

    Returns
    -------
    any
        The value of the requested parameter.
    """
    return XBParams().get_var(key, default)
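# Illustrative only: modules elsewhere in the package can read the shared
# configuration through this helper without touching the singleton directly.
# 'pytorch_device' defaults to 'cpu' before initialize() is called.
#
# >>> get_xbtorch_param('pytorch_device')
# 'cpu'
# >>> get_xbtorch_param('nonexistent_key', default='n/a')
# 'n/a'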
def initialize(*args, **kwargs):
    """
    Initialize the XBTorch library.

    Convenience function that calls `XBParams.initialize`.

    See Also
    --------
    XBParams.initialize
    """
    XBParams().initialize(*args, **kwargs)
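# A hedged usage sketch: enabling WAGE quantization with a partial override.
# Keys not listed in wage_params keep their defaults from _wage_defaults, and
# setting a word length to -1 disables the corresponding quantizer. The
# decomposition, device, and accelerator objects are omitted here, so they
# remain None; the specific override values below are arbitrary examples.
#
# >>> import xbtorch
# >>> xbtorch.initialize(
# ...     weight_range=(-1, 1),
# ...     pytorch_device='cpu',
# ...     wage_quantize=True,
# ...     wage_params={'wl_weight': 4, 'wl_grad': -1},  # override two defaults
# ... )
# >>> xbtorch.get_xbtorch_param('wage_params')['wl_activation']  # untouched default
# 8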
activation_types = (
    nn.ReLU,
    nn.Sigmoid,
    nn.Tanh,
    nn.LeakyReLU,
    nn.Softmax,
    nn.Softplus,
    nn.Softsign,
    nn.ELU,
    nn.PReLU,
    nn.SELU,
    nn.GELU,
    nn.Hardtanh,
    nn.Hardsigmoid,
    nn.Hardshrink,
    nn.Hardswish,
    nn.LogSigmoid,
    nn.SiLU,
    nn.Mish,
    nn.LogSoftmax,
)

layer_types = (
    nn.Linear,
    nn.Conv2d,
    nn.RNN,
    nn.LSTM,
    # future work
    # nn.Conv3d,
    # nn.ConvTranspose2d,
    # nn.ConvTranspose3d,
    # nn.BatchNorm1d,
    # nn.BatchNorm2d,
    # nn.BatchNorm3d,
    # nn.InstanceNorm1d,
    # nn.InstanceNorm2d,
    # nn.InstanceNorm3d,
    # nn.LayerNorm,
    # nn.GroupNorm,
    # nn.Embedding,
    # nn.GRU,
)

# Parameter-less layers
misc_types = (
    # Dropout
    nn.Dropout,
    nn.Dropout2d,
    nn.Dropout3d,
    # Pooling
    nn.MaxPool1d, nn.MaxPool2d, nn.MaxPool3d,
    nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d,
    nn.AdaptiveMaxPool1d, nn.AdaptiveMaxPool2d, nn.AdaptiveMaxPool3d,
    nn.AdaptiveAvgPool1d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d,
    # Flatten
    nn.Flatten,
)

__all__ = ['initialize', 'XBParams']
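# Illustrative only: since each group is a tuple of nn.Module subclasses, a
# plausible use (an assumption, not confirmed by this module) is sorting a
# model's layers with isinstance checks, e.g. decomposable weight-bearing
# layers vs. pass-through ones.
#
# >>> net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.5))
# >>> [isinstance(m, layer_types) for m in net]
# [True, False, False]
# >>> [isinstance(m, activation_types) for m in net]
# [False, True, False]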