xbtorch.patches.model
Decorators for patching PyTorch models to make them compatible with XBTorch operations.
Functions

xbtorch_model(original_model) – Patch a PyTorch model for XBTorch compatibility.
- xbtorch.patches.model.xbtorch_model(original_model)
Patch a PyTorch model for XBTorch compatibility.
This function takes an existing model (with a .model attribute that is an nn.Sequential or similar container) and replaces supported layers (e.g., nn.Linear, nn.Conv2d, nn.RNN, nn.LSTM) with their XBTorch equivalents. It preserves trained parameters by copying the state dictionary, and it attaches additional functionality for hardware-aware training and inference.

Features
Device simulation: Layers are converted into XBTorch versions that support device-aware weight updates (noise, variability).
WAGE quantization: If enabled during initialization, activation, error, and weight quantization modules are automatically inserted.
Inference accelerator integration: If an inference accelerator was initialized, the returned model gains methods to toggle hardware-aware inference and to map weights onto simulated crossbar arrays.
Seamless PyTorch API: All layers remain compatible with PyTorch training, optimizers, and loss functions.
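The layer-replacement pattern described above can be sketched in plain Python. This is an illustrative, self-contained sketch only: the stub classes (FakeLinear, XBLinear) and the patch_container helper are hypothetical stand-ins, not part of xbtorch or torch.

```python
# Illustrative sketch of the replace-supported-layers pattern.
# FakeLinear / XBLinear / patch_container are hypothetical stand-ins.

class FakeLinear:
    """Stand-in for a supported layer (e.g. nn.Linear): holds parameters."""
    def __init__(self, weight):
        self.weight = weight

class XBLinear(FakeLinear):
    """Stand-in for the XBTorch replacement layer."""

def patch_container(layers):
    """Swap supported layer types for their XBTorch stand-ins,
    copying parameters over; keep unsupported layers and warn."""
    replacements = {FakeLinear: XBLinear}
    patched = []
    for layer in layers:
        new_cls = replacements.get(type(layer))
        if new_cls is None:
            # Mirrors the documented behavior: unsupported modules
            # are kept as-is, with a warning.
            print(f"warning: {type(layer).__name__} not supported, kept as-is")
            patched.append(layer)
        else:
            patched.append(new_cls(layer.weight))  # parameters preserved
    return patched
```

The real function operates on the model's .model container and copies the full state dictionary; the sketch only shows the type-dispatch-and-copy shape of that operation.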
- param original_model:
A PyTorch model instance with a .model attribute (typically an nn.Sequential) that contains the layers to be patched.
- type original_model:
torch.nn.Module
- returns:
model – The patched model. Additional methods may be attached depending on initialization parameters:
- model.xb_eval(enable=True/False) – Toggle hardware-aware inference mode.
- model.initialize_array_mappings(output_polling_mode='avg', ...) – Map weights onto the simulated accelerator crossbar.
- model.get_array_mappings() – Retrieve conductance array mappings for inspection or reuse.
- rtype:
torch.nn.Module
- raises RuntimeError:
If XBTorch has not been initialized with xbtorch.initialize().
- raises RuntimeError:
If the provided model does not have a .model attribute.
- raises ValueError:
If an encountered PyTorch layer type is not supported by XBTorch.
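The precondition checks behind the two RuntimeError cases can be sketched as follows. The flag name `_INITIALIZED` and the helper `check_preconditions` are hypothetical; the real checks live inside xbtorch.

```python
# Hedged sketch of the documented precondition checks.
# _INITIALIZED and check_preconditions are hypothetical names.

_INITIALIZED = False  # would be set by xbtorch.initialize()

def check_preconditions(original_model):
    """Raise the errors documented above before any patching happens."""
    if not _INITIALIZED:
        raise RuntimeError("call xbtorch.initialize() before patching a model")
    if not hasattr(original_model, "model"):
        raise RuntimeError("the model must expose a .model attribute")
```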
Notes
Unsupported modules are kept as-is, but a warning is printed.
Quantization requires WAGE parameters (bit-widths, rounding) to have been set during initialization.
Inference accelerator mappings assume that crossbar dimensions and encoding/mapping schemes are defined during initialization.
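To make the attached-method interface concrete, here is an interface-only stub that mimics the methods listed under the return value. The class PatchedModelStub and its internal attributes are invented for illustration; a real patched model comes from xbtorch_model, not from this class.

```python
# Interface-only stub mimicking the methods a patched model gains.
# PatchedModelStub and its attributes are hypothetical.

class PatchedModelStub:
    def __init__(self):
        self._hw_inference = False
        self._mappings = None

    def xb_eval(self, enable=True):
        # Toggle hardware-aware inference mode on or off.
        self._hw_inference = enable

    def initialize_array_mappings(self, output_polling_mode='avg'):
        # Map weights onto the simulated crossbar (stubbed as a dict).
        self._mappings = {"polling": output_polling_mode}

    def get_array_mappings(self):
        # Retrieve the conductance array mappings for inspection or reuse.
        return self._mappings

# Typical call order after patching: enable hardware-aware inference,
# then build and inspect the crossbar mappings.
model = PatchedModelStub()
model.xb_eval(True)
model.initialize_array_mappings(output_polling_mode='avg')
```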