Ali BaniAsad

Robotics Engineer at Fasta | Robust RL & Embedded AI Researcher | Seeking PhD Positions


Robust RL with Transformer Detection

Zero-Sum Robust Deep RL for Real-Time Embedded Control

Transformer-Based Disturbance Detection & Dual-Policy Mixing on NVIDIA Jetson

Ongoing Research Ali BaniAsad – Fasta Robotics

PyTorch · JAX · TensorRT · ROS 2 · Jetson


Abstract

Deploying reinforcement learning policies on physical robots requires robustness to unmodeled disturbances, parameter uncertainty, and the sim-to-real gap. This project presents a methodology for robust control that frames the problem as a two-player zero-sum Markov game in which a learned adversary systematically probes the weaknesses of the control policy during training.

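The architecture maintains two distinct policies: a nominal-optimal policy $\pi^{\text{opt}}$, trained without an adversary to maximize return under nominal dynamics, and a robust policy $\pi^{\text{rob}}$, co-trained against the learned adversary.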

At runtime, a lightweight transformer encoder processes a sliding window of recent observations to produce a real-time disturbance estimate – both a detection probability and an estimated magnitude. A learned gating function converts this estimate into a continuous mixing coefficient $\alpha_t \in [0,1]$ that blends the two policies:

\[a_t = \alpha_t \, \pi^{\text{rob}}(s_t) + (1 - \alpha_t) \, \pi^{\text{opt}}(s_t)\]

All neural network components are exported through ONNX to TensorRT engines optimized for the NVIDIA Jetson platform and integrated into a ROS 2 control node, with a target of sub-two-millisecond inference latency at control rates up to 500 Hz.

[Figure: Zero-Sum Robust RL system overview. Dual-policy architecture with transformer-based disturbance detection and real-time gating on embedded GPU.]


Key Contributions

Dual-Policy Mixing

Separate optimal and robust policies blended via a continuous gating coefficient -- preserving nominal performance while providing worst-case robustness on demand.

Transformer Disturbance Detector

Lightweight transformer encoder on a sliding window of observations outputs detection probability and estimated magnitude in real time.

End-to-End Deployment Pipeline

ONNX → TensorRT (FP16/INT8) → ROS 2 node on NVIDIA Jetson Orin with <2 ms inference latency.


Problem Formulation

Two-Player Zero-Sum Markov Game

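The environment is modeled as a game $(S, \mathcal{A}^c, \mathcal{A}^d, P, R, \gamma)$, where $S$ is the state space, $\mathcal{A}^c$ the controller's action space, $\mathcal{A}^d$ the adversary's (disturbance) action space, $P$ the transition kernel, $R$ the reward function (the controller receives $r(s, a^c, a^d)$ and the adversary receives its negative), and $\gamma \in [0,1)$ the discount factor.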

A Nash equilibrium $(\pi^{c,*}, \pi^{d,*})$ satisfies:

\[V^{\pi^{c,*}, \pi^d}(s) \ge V^{\pi^{c,*}, \pi^{d,*}}(s) \ge V^{\pi^c, \pi^{d,*}}(s) \quad \forall\, s, \pi^c, \pi^d\]

The minimax value function:

\[V^*(s) = \max_{\pi^c} \min_{\pi^d} \mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^t r(s_t, a_t^c, a_t^d) \,\middle|\, s_0 = s\right]\]
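For intuition, the max-min operator can be evaluated by hand on a one-step game with finitely many pure actions. The sketch below uses a made-up 3×3 payoff table (rows are controller actions, columns are adversary actions; every number is hypothetical) and checks the standard fact that the maximin value never exceeds the minimax value. It is an illustration only, not part of the project's code.

```python
# Hypothetical one-step zero-sum game: payoff[i][j] is the controller's reward
# when the controller plays action i and the adversary plays action j.
payoff = [
    [3.0, -1.0, 2.0],
    [1.0,  0.5, 1.5],
    [4.0, -2.0, 0.0],
]

def maximin(table):
    """Controller commits first: best row under the worst-case column."""
    return max(min(row) for row in table)

def minimax(table):
    """Adversary commits first: worst column under the best-response row."""
    columns = list(zip(*table))
    return min(max(col) for col in columns)

lower = maximin(payoff)
upper = minimax(payoff)
print(lower, upper)
```

In this table the two values coincide, so the pair (row 1, column 1) is a pure-strategy saddle point: neither player gains by deviating alone, which is the one-step analogue of the Nash condition stated earlier.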

Disturbance Model

Two instantiations of adversarial disturbance:

  1. Additive force/torque perturbation: $\ddot{q} = f(q, \dot{q}, a^c) + B\, a^d$
  2. External body force: wrench applied to a specified body (pushes, wind, contact)

The adversary’s action is bounded: $\|a^d\|_2 \le \epsilon$, where $\epsilon > 0$ is the disturbance budget with curriculum scheduling from $0$ to $\epsilon_{\max}$.
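One way to enforce the norm bound and the budget curriculum is sketched below. The linear ramp, the function names, and the numeric values are assumptions made for illustration; the text only states that the budget grows from zero to its maximum under a curriculum.

```python
import math

def project_to_l2_ball(a_d, eps):
    """Scale a_d back onto the L2 ball of radius eps if it lies outside."""
    norm = math.sqrt(sum(x * x for x in a_d))
    if norm <= eps or norm == 0.0:
        return list(a_d)
    scale = eps / norm
    return [x * scale for x in a_d]

def curriculum_eps(step, total_steps, eps_max):
    """Linear ramp of the budget from 0 to eps_max (schedule shape is assumed)."""
    frac = min(max(step / total_steps, 0.0), 1.0)
    return frac * eps_max

# Halfway through a hypothetical 10,000-step curriculum with eps_max = 2.0.
eps = curriculum_eps(step=5_000, total_steps=10_000, eps_max=2.0)
# A raw adversary action of norm 5 is rescaled to norm eps.
clipped = project_to_l2_ball([3.0, 4.0], eps)
print(eps, clipped)
```

Projection keeps the direction chosen by the adversary and only shrinks its magnitude, so the adversary remains free to pick the most damaging direction within the current budget.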


Architecture

Dual-Policy Training

Phase 1: Optimal Policy

Train $\pi^{\text{opt}}$ with SAC, TD3, or PPO on the nominal environment (no adversary) to maximize expected return under nominal dynamics.

Phase 2: Robust Policy + Adversary

Co-train $\pi^{\text{rob}}$ and the adversary $\pi^{\text{adv}}$ in alternation. The adversary's output is squashed to the budget, $a^d = \epsilon \cdot \tanh(\text{NN}_\psi(o))$, and $\epsilon$ is ramped from $0$ to $\epsilon_{\max}$ by the curriculum.
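The alternation in Phase 2 can be illustrated on a toy saddle-point problem. In the sketch below, two scalars stand in for the controller and adversary weights, and the payoff $J(\theta, \psi) = -\theta^2 + 2\theta\psi + \psi^2$ is a made-up function that is concave in $\theta$ and convex in $\psi$. Real training would replace the exact gradients with RL updates (for example SAC or PPO); nothing here is the project's training code.

```python
def payoff_grads(theta, psi):
    """Gradients of the toy payoff J = -theta**2 + 2*theta*psi + psi**2."""
    dJ_dtheta = -2.0 * theta + 2.0 * psi
    dJ_dpsi = 2.0 * theta + 2.0 * psi
    return dJ_dtheta, dJ_dpsi

theta, psi = 1.0, -0.5   # hypothetical scalar "controller" and "adversary" parameters
lr = 0.1

for _ in range(200):
    # Controller step: gradient ascent on J with the adversary held fixed.
    g_theta, _unused = payoff_grads(theta, psi)
    theta += lr * g_theta
    # Adversary step: gradient descent on J with the updated controller held fixed.
    _unused, g_psi = payoff_grads(theta, psi)
    psi -= lr * g_psi

print(theta, psi)
```

Both parameters spiral into the saddle point at the origin. On non-convex RL objectives such convergence is not guaranteed, which is one reason curricula and careful update scheduling are commonly used.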

Transformer Disturbance Detector

The detector is a lightweight transformer encoder processing a sliding window of $L$ observations:

| Parameter | Value |
|---|---|
| Layers $N_L$ | 2–4 |
| Model dimension $d_{\text{model}}$ | 64–128 |
| Attention heads $N_H$ | 2–4 |
| Window length $L$ | 16 |
| Parameters | ~50k |
| Latency (Jetson Orin, TensorRT) | < 0.5 ms |
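The sliding window feeding the detector can be kept in a fixed-length buffer. The sketch below is one possible arrangement; the class name, the zero-padding before the buffer fills, and the observation dimension are assumptions, with only the window length of 16 taken from the table.

```python
from collections import deque

class ObservationWindow:
    """Fixed-length buffer holding the most recent observations, oldest first."""

    def __init__(self, length, obs_dim):
        self.length = length
        # Zero-padding before the buffer fills is an assumption of this sketch.
        self._buf = deque([[0.0] * obs_dim for _ in range(length)], maxlen=length)

    def push(self, obs):
        self._buf.append(list(obs))  # the oldest entry is dropped automatically

    def as_sequence(self):
        """Return a length x obs_dim nested list, ready to be stacked into a tensor."""
        return [list(o) for o in self._buf]

window = ObservationWindow(length=16, obs_dim=3)
for t in range(20):
    window.push([float(t), 0.0, 0.0])
seq = window.as_sequence()
print(len(seq), seq[0][0], seq[-1][0])
```

After 20 pushes the buffer holds time steps 4 through 19, so the detector always sees exactly $L$ rows regardless of how long the controller has been running.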

Output heads: a detection probability $p_t$ and an estimated disturbance magnitude $\hat{\delta}_t$.

Training loss (supervised on logged adversarial data):

\[\mathcal{L} = \mathcal{L}_{\text{BCE}}(p_t, y_t) + \lambda \, \mathcal{L}_{\text{MSE}}(\hat{\delta}_t, \delta_t)\]
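As a concrete reading of this loss, the sketch below evaluates it on a tiny hand-made batch in plain Python. The batch values, the weight $\lambda = 0.5$, and the function names are hypothetical; an actual training run would use the framework's built-in losses on tensors.

```python
import math

def bce(p, y, tiny=1e-7):
    """Binary cross-entropy for one predicted probability p and label y in {0, 1}."""
    p = min(max(p, tiny), 1.0 - tiny)  # clamp to avoid log(0)
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))

def detector_loss(p_batch, y_batch, delta_hat_batch, delta_batch, lam=1.0):
    """Mean BCE on detection plus lam times mean squared error on magnitude."""
    n = len(p_batch)
    l_bce = sum(bce(p, y) for p, y in zip(p_batch, y_batch)) / n
    l_mse = sum((dh - d) ** 2 for dh, d in zip(delta_hat_batch, delta_batch)) / n
    return l_bce + lam * l_mse

# Hypothetical mini-batch: two windows, the second one disturbed.
loss = detector_loss(
    p_batch=[0.1, 0.8], y_batch=[0.0, 1.0],
    delta_hat_batch=[0.0, 1.5], delta_batch=[0.0, 2.0],
    lam=0.5,
)
print(loss)
```

The weight $\lambda$ trades detection accuracy against magnitude accuracy; its value here is a placeholder, not one reported for the project.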

Gating Rule

The mixing coefficient is computed via a learned gating function:

Linear gating:

\[\alpha_t = \sigma(w_p \cdot p_t + w_\delta \cdot \hat{\delta}_t + b)\]

MLP gating (more expressive):

\[\alpha_t = f_\phi(p_t, \hat{\delta}_t)\]

When $\alpha_t \approx 0$: nominal-optimal policy. When $\alpha_t \approx 1$: robust policy. Soft mixing avoids chattering at the decision boundary.
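The linear gate and the action blend can be sketched together in a few lines. The gate weights and the two stand-in action vectors below are hypothetical values chosen so the two regimes are visible; in the described system the gate parameters are learned and the actions come from the two policy networks.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def linear_gate(p, delta_hat, w_p, w_delta, b):
    """alpha_t = sigmoid(w_p * p_t + w_delta * delta_hat_t + b)."""
    return sigmoid(w_p * p + w_delta * delta_hat + b)

def mix_actions(alpha, a_rob, a_opt):
    """a_t = alpha * pi_rob(s_t) + (1 - alpha) * pi_opt(s_t), element-wise."""
    return [alpha * r + (1.0 - alpha) * o for r, o in zip(a_rob, a_opt)]

# Hypothetical gate weights; in the described system they are learned.
W_P, W_DELTA, BIAS = 8.0, 2.0, -5.0

a_opt = [1.0, -1.0]   # stand-in output of the nominal-optimal policy
a_rob = [0.2, -0.2]   # stand-in output of the robust policy

calm = linear_gate(p=0.05, delta_hat=0.0, w_p=W_P, w_delta=W_DELTA, b=BIAS)
hit = linear_gate(p=0.95, delta_hat=1.5, w_p=W_P, w_delta=W_DELTA, b=BIAS)
print(calm, mix_actions(calm, a_rob, a_opt))
print(hit, mix_actions(hit, a_rob, a_opt))
```

With these placeholder weights the gate stays near 0 when no disturbance is detected and moves close to 1 under a confident detection, and the blended action moves smoothly between the two policy outputs rather than jumping.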


Deployment Pipeline

Export Path

PyTorch / JAX  -->  torch.jit.trace / jax.export
       |
    ONNX (opset 17)
       |
    TensorRT (trtexec)
       |-- FP16 for policy networks
       |-- INT8 with calibration for transformer
       |
    ROS 2 Node (Jetson Orin)

Inference Latency Budget (500 Hz on Jetson Orin MAXN)

| Component | Target Latency |
|---|---|
| Transformer forward pass ($\mathcal{T}_\theta$) | < 0.50 ms |
| Policy forward pass ($\pi^{\text{opt}}$) | < 0.30 ms |
| Policy forward pass ($\pi^{\text{rob}}$) | < 0.30 ms |
| Gating computation ($f_\phi$) | < 0.05 ms |
| ROS message serialization + publish | < 0.20 ms |
| Total | < 1.35 ms |

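The two policy forward passes run concurrently on separate CUDA streams, reducing effective policy inference from 0.60 ms to ~0.35 ms and the total from 1.35 ms to ~1.10 ms. Measured against the 2 ms control period at 500 Hz, the sequential budget of 1.35 ms leaves a 0.65 ms margin for latency spikes.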
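The budget arithmetic can be checked in a few lines. The per-component numbers below are the targets from the table and the ~0.35 ms concurrent figure quoted in the text; they are design targets, not measurements, and the variable names are illustrative.

```python
# Per-component latency targets in microseconds, taken from the budget table.
budget_us = {
    "transformer": 500,
    "policy_opt": 300,
    "policy_rob": 300,
    "gating": 50,
    "ros_publish": 200,
}

CONTROL_RATE_HZ = 500
period_us = 1_000_000 // CONTROL_RATE_HZ          # 2000 us per control cycle

sequential_us = sum(budget_us.values())           # policies run back to back
# With both policies on separate CUDA streams, the text quotes ~350 us for the pair.
CONCURRENT_POLICIES_US = 350
concurrent_us = (sequential_us
                 - budget_us["policy_opt"] - budget_us["policy_rob"]
                 + CONCURRENT_POLICIES_US)

print(sequential_us, period_us - sequential_us)   # sequential total and margin
print(concurrent_us, period_us - concurrent_us)   # concurrent total and margin
```

Working in integer microseconds avoids floating-point rounding in the sums. The sequential case gives 1350 µs with a 650 µs margin, and the concurrent case gives 1100 µs with a 900 µs margin.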

ROS 2 Integration


RL Algorithms

Four deep RL algorithms serve as candidate training methods for both the controller and adversary:

| Algorithm | Type | Key Feature |
|---|---|---|
| DDPG | Off-policy, Deterministic | Deterministic policy gradient + replay buffer |
| TD3 | Off-policy, Deterministic | Twin critics, delayed updates, target smoothing |
| SAC | Off-policy, Stochastic | Maximum entropy + automatic temperature tuning |
| PPO | On-policy, Stochastic | Clipped surrogate objective, GAE |

Comparison with Prior Work

| Feature | RARL | SA-MDP | Domain Rand. | Ours |
|---|---|---|---|---|
| Adversarial training | Yes | Yes | -- | Yes |
| Dual-policy mixing | -- | -- | -- | Yes |
| Runtime detection | -- | -- | -- | Yes |
| Transformer detector | -- | -- | -- | Yes |
| TensorRT deployment | -- | -- | -- | Yes |
| ROS 2 integration | -- | -- | -- | Yes |

Evaluation Plan

Benchmark Environments (MuJoCo / Gymnasium)

Disturbances: random external forces, parameter perturbation ($\pm$30%), and additive observation noise.
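Two of these disturbance types are straightforward to sample in an evaluation harness. The sketch below assumes a uniform distribution for the ±30% parameter perturbation and zero-mean Gaussian observation noise. Those distributions, the noise scale, and the parameter names are assumptions, since the text specifies only the perturbation range.

```python
import random

def perturb_parameters(nominal, rel_range=0.30, rng=random):
    """Scale each nominal parameter by a factor drawn uniformly from [1 - r, 1 + r]."""
    return {k: v * rng.uniform(1.0 - rel_range, 1.0 + rel_range)
            for k, v in nominal.items()}

def add_observation_noise(obs, sigma, rng=random):
    """Add zero-mean Gaussian noise with standard deviation sigma to each entry."""
    return [x + rng.gauss(0.0, sigma) for x in obs]

rng = random.Random(0)  # seeded for repeatability
# Hypothetical nominal parameters; names and values are placeholders.
nominal = {"mass": 1.0, "friction": 0.8, "damping": 0.1}
perturbed = perturb_parameters(nominal, rel_range=0.30, rng=rng)
noisy_obs = add_observation_noise([0.0, 1.0, -1.0], sigma=0.01, rng=rng)
print(perturbed)
print(noisy_obs)
```

Passing an explicit seeded generator keeps evaluation episodes reproducible across the policies being compared.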

Planned Ablation Studies

  1. No adversary – nominal policy only (performance ceiling, robustness floor)
  2. No transformer detector – fixed $\alpha = 0.5$ (value of runtime awareness)
  3. Hard switching vs. soft mixing – binary $\alpha \in \{0,1\}$ vs. continuous blending
  4. Window length – $L \in \{8, 16, 32, 64\}$ (detection accuracy vs. compute)
  5. Adversary budget – vary $\epsilon$ for robustness-performance Pareto frontier
  6. Single vs. population adversary – generalization to unseen disturbances

Baselines


Technology Stack

| Category | Tools |
|---|---|
| Training | PyTorch, Gymnasium, MuJoCo (planned: JAX + MJX) |
| RL Algorithms | DDPG, TD3, SAC, PPO (zero-sum MARL variants) |
| Detection | Transformer encoder (causal attention, ~50k params) |
| Export | ONNX (opset 17), TensorRT (FP16/INT8) |
| Deployment | NVIDIA Jetson Orin, CUDA streams, ROS 2 Humble |
| Data | HDF5 / Apache Arrow, standardized logging spec |

Status

This project describes the full methodology and system design. Empirical results across all benchmark environments and on physical robot hardware are in progress and will be reported in a follow-up study.

