xbtorch.patches.model

Decorators for patching PyTorch models to make them compatible with XBTorch operations.

Functions

xbtorch_model(original_model)

Patch a PyTorch model for XBTorch compatibility.

xbtorch.patches.model.xbtorch_model(original_model)

Patch a PyTorch model for XBTorch compatibility.

This function takes an existing model (with a .model attribute that is an nn.Sequential or similar container) and replaces supported layers (e.g., nn.Linear, nn.Conv2d, nn.RNN, nn.LSTM) with their XBTorch equivalents. It preserves trained parameters by copying the state dictionary, and it attaches additional functionality for hardware-aware training and inference.
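The replacement strategy described above can be sketched in plain PyTorch. This is a minimal illustration, not XBTorch's implementation. `NoisyLinear` is a hypothetical stand-in for an XBTorch layer that adds Gaussian weight noise in training mode, and `swap_linear_layers` is a hypothetical helper. The stand-in subclasses `nn.Linear`, so its parameter names match and `load_state_dict` can copy the trained weights directly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyLinear(nn.Linear):
    """Hypothetical stand-in for an XBTorch layer: nn.Linear plus weight noise."""

    def __init__(self, in_features, out_features, bias=True, noise_std=0.01):
        super().__init__(in_features, out_features, bias=bias)
        self.noise_std = noise_std

    def forward(self, x):
        weight = self.weight
        if self.training and self.noise_std > 0:
            # Perturb the weights on every training forward pass to mimic device noise.
            weight = weight + torch.randn_like(weight) * self.noise_std
        return F.linear(x, weight, self.bias)


def swap_linear_layers(seq: nn.Sequential) -> nn.Sequential:
    """Return a new nn.Sequential with every plain nn.Linear replaced by NoisyLinear."""
    new_layers = []
    for layer in seq:
        if type(layer) is nn.Linear:
            replacement = NoisyLinear(
                layer.in_features, layer.out_features, bias=layer.bias is not None
            )
            # Same parameter names ("weight", "bias"), so trained values copy over.
            replacement.load_state_dict(layer.state_dict())
            new_layers.append(replacement)
        else:
            # Every other module is kept unchanged.
            new_layers.append(layer)
    return nn.Sequential(*new_layers)


torch.manual_seed(0)
original = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
patched = swap_linear_layers(original)
patched.eval()  # noise is only applied in training mode
x = torch.randn(3, 4)
print(torch.allclose(original(x), patched(x)))  # True
```

In evaluation mode the patched network reproduces the original outputs exactly, because only the layer class changed and the parameters were copied.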

Features

  • Device simulation: Supported layers are converted into XBTorch versions whose weight updates are device-aware, modeling non-idealities such as noise and variability.

  • WAGE quantization: If WAGE quantization was enabled during initialization, quantization modules for activations, errors, and weights are inserted automatically.

  • Inference accelerator integration: If an inference accelerator was initialized, the returned model gains methods to toggle hardware-aware inference and to map weights onto simulated crossbar arrays.

  • Seamless PyTorch API: The patched model remains compatible with standard PyTorch training loops, optimizers, and loss functions.
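The WAGE quantization feature can be illustrated with the linear quantizer defined in the original WAGE paper ("Training and Inference with Integers in Deep Neural Networks", Wu et al., ICLR 2018). For a bit-width k, that quantizer uses the step σ(k) = 2^(1−k) and computes Q(x, k) = clip(σ(k)·round(x/σ(k)), −1 + σ(k), 1 − σ(k)). The sketch below implements that formula in plain Python. Whether XBTorch's quantization modules use exactly this rounding and clipping is an assumption, and `wage_quantize` is a hypothetical name.

```python
def wage_quantize(x: float, k: int) -> float:
    """Quantize x to k bits with a WAGE-style linear quantizer.

    sigma = 2 ** (1 - k) is the quantization step; results are clipped to
    [-1 + sigma, 1 - sigma]. Python's round() rounds exact halves to even.
    """
    sigma = 2.0 ** (1 - k)
    q = sigma * round(x / sigma)
    return min(max(q, -1.0 + sigma), 1.0 - sigma)


# With k = 2 the step is 0.5, so the only levels are -0.5, 0.0 and 0.5.
print([wage_quantize(v, 2) for v in (-0.9, -0.2, 0.3, 0.8)])
```

The clipping bounds keep the level set symmetric around zero, which is why a 2-bit setting yields three levels rather than four.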

param original_model:

A PyTorch model instance with a .model attribute (typically an nn.Sequential) that contains the layers to be patched.

type original_model:

torch.nn.Module

returns:

model – The patched model. Depending on the initialization parameters (for example, whether an inference accelerator was configured), the following methods may be attached:

  • model.xb_eval(enable=True/False) – Enable or disable hardware-aware inference mode.

  • model.initialize_array_mappings(output_polling_mode='avg', ...) – Map weights onto the simulated accelerator crossbar.

  • model.get_array_mappings() – Retrieve the conductance array mappings for inspection or reuse.

rtype:

torch.nn.Module

raises RuntimeError:

If XBTorch has not been initialized with xbtorch.initialize().

raises RuntimeError:

If the provided model does not have a .model attribute.

raises ValueError:

If an encountered PyTorch layer type is not supported by XBTorch.
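The function raises a RuntimeError when the model has no .model attribute, so a bare nn.Sequential has to be wrapped first. A minimal wrapper might look like the following. `MLP` and its layer sizes are illustrative only. The final `xbtorch_model` call is left as a comment because it requires a prior `xbtorch.initialize()` call, whose arguments are not covered on this page.

```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Illustrative model that exposes its layers through a `.model` attribute."""

    def __init__(self, in_features=784, hidden=128, num_classes=10):
        super().__init__()
        # The container that xbtorch_model is documented to patch.
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        return self.model(x)


net = MLP()
out = net(torch.randn(2, 784))
print(out.shape)  # torch.Size([2, 10])

# After xbtorch.initialize(...) has been called:
# from xbtorch.patches.model import xbtorch_model
# net = xbtorch_model(net)
```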

Notes

  • Unsupported modules are kept as-is, but a warning is printed.

  • Quantization requires WAGE parameters (bit-widths, rounding) to have been set during initialization.

  • Inference accelerator mappings assume that the crossbar dimensions and the encoding/mapping schemes have been defined during initialization.
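To illustrate what mapping weights onto a crossbar of fixed dimensions can involve, the sketch below uses one common encoding. Each weight is represented by a differential conductance pair, so that w is proportional to G⁺ − G⁻, and the conductance matrix is split into crossbar-sized tiles. This scheme is an assumption chosen for illustration. XBTorch's actual encoding/mapping schemes, its `output_polling_mode` option, and the structure returned by `get_array_mappings()` are not described here. `map_to_conductances`, `tile`, `g_min` and `g_max` are hypothetical names.

```python
def map_to_conductances(weights, g_min=1e-6, g_max=1e-4):
    """Map a weight matrix (list of rows) to a differential pair (G_plus, G_minus).

    Weights are scaled by the largest magnitude so that the largest |w| uses
    the full conductance window; positive parts go on G_plus, negative on G_minus.
    """
    w_max = max(abs(w) for row in weights for w in row) or 1.0
    scale = (g_max - g_min) / w_max
    g_plus = [[g_min + scale * max(w, 0.0) for w in row] for row in weights]
    g_minus = [[g_min + scale * max(-w, 0.0) for w in row] for row in weights]
    return g_plus, g_minus, scale


def tile(matrix, rows, cols):
    """Split a matrix into crossbar-sized blocks; edge blocks may be smaller."""
    return [
        [row[c:c + cols] for row in matrix[r:r + rows]]
        for r in range(0, len(matrix), rows)
        for c in range(0, len(matrix[0]), cols)
    ]


W = [[0.5, -1.0, 0.25],
     [0.0, 0.75, -0.5]]
g_plus, g_minus, scale = map_to_conductances(W)
# The effective weight is recovered from the conductance difference.
recovered = [[(p - m) / scale for p, m in zip(rp, rm)]
             for rp, rm in zip(g_plus, g_minus)]
tiles = tile(g_plus, rows=2, cols=2)  # a 2x3 matrix becomes a 2x2 and a 2x1 tile
print(len(tiles))  # 2
```

Because both devices in a pair share the same offset g_min, the offset cancels in the difference. This is the main reason differential encodings are a popular choice for signed weights.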