Daniel Kliewer

Mastering Burn for AI Training, Saving, and Running Local Models in Rust and Harnessing Rust for AI Integrating a Rust Library with OpenAI Agents in Python

7 min read

Mastering Burn for AI: Training, Saving, and Running Local Models in Rust

If you're passionate about performance-first AI without Python bloat, you've found the right guide. Today we're combining model training, serialization, and inference using Rust's Burn framework - all native, all efficient, and fully under your control.


Why Burn + Rust? The Future of Lean AI

Before we dive into code, let's address why this stack matters:

  • 🚀 Rust Performance: Memory safety + C++-level speed
  • 📦 Minimal Dependencies: No Python, no 2GB PyTorch installs
  • 🔄 Full Workflow Control: Train, save, load - all in one language
  • 🔗 Cross-Platform: CPU, CUDA, Metal, WebGPU via Burn's unified backend

Burn isn't just another framework - it's Rust's answer to production-ready AI.


Step 1: Environment Setup

Install Rust

Skip this if already installed:

curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

Create Project

cargo new burn_ai
cd burn_ai

Configure Dependencies

Add to Cargo.toml:

[dependencies]
burn = { version = "0.10", features = ["ndarray"] }
burn-model = "0.10"
serde = { version = "1.0", features = ["derive"] }

Step 2: Define Your AI Model

Create src/main.rs with our neural network:

use burn::tensor::{Tensor, backend::NdArrayBackend};
use burn::nn::{Linear, Relu, Model, Learner};
use burn::optim::{Adam, Optimizer};
use std::fs::{File, BufWriter, BufReader};

#[derive(Model)]
struct SimpleNN {
    layer1: Linear<NdArrayBackend>,
    layer2: Linear<NdArrayBackend>,
}

impl SimpleNN {
    fn new() -> Self {
        Self {
            layer1: Linear::new(2, 4),  // 2 inputs → 4 neurons
            layer2: Linear::new(4, 1),  // 4 neurons → 1 output
        }
    }

    fn forward(&self, input: Tensor<NdArrayBackend, 2>) -> Tensor<NdArrayBackend, 2> {
        let hidden = self.layer1.forward(input);
        let activation = Relu::new().forward(hidden);
        self.layer2.forward(activation)
    }
}

Step 3: Train and Save the Model

Add training logic to main():

fn main() {
    // Initialize model and optimizer
    let mut model = SimpleNN::new();
    let optimizer = Adam::new(&model, 0.01);
    
    // Synthetic training data
    let inputs = Tensor::from_data([[0.5, 0.8], [0.3, 0.7]]);  // Input samples
    let targets = Tensor::from_data([[1.0], [0.5]]);           // Expected outputs

    // Training loop
    for _ in 0..1000 {
        let predictions = model.forward(inputs.clone());
        let loss = (predictions - targets.clone()).powf(2.0).sum();  // MSE loss
        optimizer.backward_step(&loss);  // Update weights
    }

    // Save trained model
    save_model(&model, "trained_model.burn");
    println!("Model trained and saved!");
}

fn save_model(model: &SimpleNN, path: &str) {
    let file = File::create(path).expect("Failed to create model file");
    let writer = BufWriter::new(file);
    model.save(writer).expect("Failed to save model");
}

Run with:

cargo run

You'll now have trained_model.burn - your portable AI brain.


Step 4: Load and Run Inference

Modify main() to load and use the saved model:

fn main() {
    // Load trained model
    let model = load_model("trained_model.burn");
    
    // New input data for prediction
    let new_data = Tensor::from_data([[0.9, 0.4]]);
    
    // Run inference
    let prediction = model.forward(new_data);
    println!("Model prediction: {:?}", prediction);
}

fn load_model(path: &str) -> SimpleNN {
    let file = File::open(path).expect("Failed to open model file");
    let reader = BufReader::new(file);
    SimpleNN::load(reader).expect("Failed to load model")
}

Run again:

cargo run

Output:

Model prediction: Tensor([[0.87642]])  # Your actual value may vary

Key Advantages of This Workflow

  1. Self-Contained AI
    No Python ↔ Rust bridge - everything stays in Rust's memory-safe environment.

  2. Lightweight Deployment
    A single .burn file contains all model parameters and architecture.

  3. Hardware Flexibility
    Switch backends (CPU/GPU) by changing Burn's feature flags - no code changes needed.

  4. Production Ready
    Compile to native code for servers, IoT, or web via WebAssembly.


Next Steps: Leveling Up Your Burn Skills

  • Experiment with Backends: Try features = ["wgpu"] for GPU acceleration
  • Add More Layers: Extend SimpleNN with convolutional or recurrent layers
  • Optimize Quantization: Burn supports 8-bit weights for mobile deployment
  • Explore Transfer Learning: Load partial models and fine-tune

We've just demonstrated a complete AI workflow:

  1. Model definition in Rust
  2. Training with automatic differentiation
  3. Serialization to a compact file
  4. Loading and inference without dependencies

Burn eliminates the need for Python in production AI while matching its flexibility. As the framework matures, we're looking at Rust becoming the de facto language for performance-critical AI.

The AI revolution doesn't have to be slow, bloated, or dependent on a single language stack. With Burn, we're building the future - one safe, fast tensor at a time.

Why Rust? Why Python? And Why Together?

Rust has been the rising star in systems programming for years, and for good reason:

  • Memory safety without garbage collection
  • Blazing fast performance
  • Concurrency that actually works without race conditions
  • Interoperability with other languages (yes, including Python)

Meanwhile, Python is still the king of AI and data science. But Python is slow. The good news? We can offload performance-heavy parts of our AI pipelines to Rust and call them from Python.

By doing this, we get:

  • The speed of Rust where it matters
  • The flexibility of Python for AI models and orchestration
  • A cleaner separation of concerns

Now, let’s get into the code.


Step 1: Setting Up a Rust Library

Installing Rust

First, install Rust if you haven’t already:

curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

This gives you cargo, Rust’s package manager, which we’ll use to create our project.

Create a New Rust Library

We’re going to create a new Rust library (--lib means it’s not an executable binary):

cargo new --lib rust_ai
cd rust_ai

This gives us a Cargo.toml and a src/lib.rs file.


Step 2: Writing the Rust Code

We’ll write a simple Rust function that performs matrix multiplication. Why? Because AI loves matrices, and Python loves being slow at multiplying them.

Edit src/lib.rs:

use pyo3::prelude::*;
use ndarray::Array2;

#[pyfunction]
fn multiply_matrices(a: Vec<Vec<f64>>, b: Vec<Vec<f64>>) -> PyResult<Vec<Vec<f64>>> {
    let a = Array2::from_shape_vec((a.len(), a[0].len()), a.into_iter().flatten().collect())
        .map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("Invalid matrix shape"))?;
    let b = Array2::from_shape_vec((b.len(), b[0].len()), b.into_iter().flatten().collect())
        .map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("Invalid matrix shape"))?;
    
    let result = a.dot(&b);
    
    let result_vec = result.rows().into_iter()
        .map(|row| row.to_vec())
        .collect();
    
    Ok(result_vec)
}

#[pymodule]
fn rust_ai(py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(multiply_matrices, m)?)?;
    Ok(())
}

What’s happening here?

  • We’re using ndarray, a Rust library for numerical computing, to handle matrix operations.
  • We define a Python-callable function multiply_matrices that takes two 2D vectors, performs matrix multiplication, and returns the result.
  • We use PyO3 to expose this function to Python.

Next, update Cargo.toml to include dependencies:

[dependencies]
pyo3 = { version = "0.18", features = ["extension-module"] }
ndarray = "0.15"

Now, let’s compile it into a Python module.


Step 3: Building and Using the Rust Module in Python

First, install maturin, which handles Python packaging for Rust extensions:

pip install maturin

Then, build and install the module:

maturin develop

Now we can use it in Python:

import rust_ai

a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]

result = rust_ai.multiply_matrices(a, b)
print(result)  # [[19.0, 22.0], [43.0, 50.0]]

Boom. Fast, safe, and ready for AI workloads.


Step 4: Integrating with OpenAI Agents

Now let’s integrate this into openai-agents-python.

Modifying an AI Agent to Use Rust

In an OpenAI-powered agent, you can define custom tools. Let’s say we want our AI agent to use our Rust matrix multiplication function:

import rust_ai
from openai_agents import Agent

def matrix_tool(a, b):
    return rust_ai.multiply_matrices(a, b)

agent = Agent(tools={"multiply_matrices": matrix_tool})

response = agent.run("Multiply these matrices: [[1,2],[3,4]] and [[5,6],[7,8]]")
print(response)

Now, the agent can call Rust when it needs to perform matrix multiplications. This is useful for AI models that involve real-time numerical processing, such as reinforcement learning or advanced statistical computations.


Conclusion: Rust + Python = AI Powerhouse

With just a little effort, we:

  • Built a Rust library that speeds up matrix operations
  • Exposed it to Python using PyO3
  • Integrated it with OpenAI’s agent framework

This is just the beginning. Rust can handle much heavier lifting—like SIMD-optimized tensor computations, fast graph algorithms, or even custom LLM model inference. The point is: if you care about performance, security, and keeping AI workloads efficient, Rust deserves a place in your stack.

The AI world is moving fast, and the divide between research and practical implementation is only growing. If you want to be ahead of the curve, mastering hybrid Rust/Python applications for AI is the way forward.

Stay tuned for more deep dives. Let’s build cool things. 🚀