Robust RL with Transformer Detection
Zero-Sum Robust Deep RL for Real-Time Embedded Control
Transformer-Based Disturbance Detection & Dual-Policy Mixing on NVIDIA Jetson
Ongoing research · Ali BaniAsad – Fasta Robotics
Abstract
Deploying reinforcement learning policies on physical robots requires robustness to unmodeled disturbances, parameter uncertainty, and the sim-to-real gap. This project presents a methodology for robust control that frames the problem as a two-player zero-sum Markov game in which a learned adversary systematically probes the weaknesses of the control policy during training.
The architecture maintains two distinct policies:
- Optimal policy $\pi^{\text{opt}}$: trained on the nominal environment to maximize expected return
- Robust policy $\pi^{\text{rob}}$: co-trained against the adversary to maximize worst-case return
At runtime, a lightweight transformer encoder processes a sliding window of recent observations to produce a real-time disturbance estimate consisting of a detection probability and an estimated magnitude. A learned gating function converts this estimate into a continuous mixing coefficient $\alpha_t \in [0,1]$ that blends the two policies:
\[a_t = \alpha_t \, \pi^{\text{rob}}(s_t) + (1 - \alpha_t) \, \pi^{\text{opt}}(s_t)\]
All neural network components are exported through ONNX to TensorRT engines optimized for the NVIDIA Jetson platform and integrated into a ROS 2 control node, targeting an inference latency below 2 ms at control rates up to 500 Hz.
System overview: dual-policy architecture with transformer-based disturbance detection and real-time gating on embedded GPU.
Key Contributions
Dual-Policy Mixing
Separate optimal and robust policies blended via a continuous gating coefficient, preserving nominal performance while providing worst-case robustness on demand.
Transformer Disturbance Detector
Lightweight transformer encoder on a sliding window of observations outputs detection probability and estimated magnitude in real time.
End-to-End Deployment Pipeline
ONNX → TensorRT (FP16/INT8) → ROS 2 node on NVIDIA Jetson Orin, targeting <2 ms inference latency.
Problem Formulation
Two-Player Zero-Sum Markov Game
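The environment is modeled as a game $(\mathcal{S}, \mathcal{A}^c, \mathcal{A}^d, P, R, \gamma)$ with state space $\mathcal{S}$, controller and adversary action spaces $\mathcal{A}^c$ and $\mathcal{A}^d$, transition kernel $P(s' \mid s, a^c, a^d)$, reward $R(s, a^c, a^d)$, and discount factor $\gamma \in [0, 1)$. The two players are: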
- Controller $\pi^c$: maximizes expected discounted return
- Adversary $\pi^d$: minimizes it (receives $-R$)
A Nash equilibrium $(\pi^{c,*}, \pi^{d,*})$ satisfies:
\[V^{\pi^{c,*}, \pi^d}(s) \ge V^{\pi^{c,*}, \pi^{d,*}}(s) \ge V^{\pi^c, \pi^{d,*}}(s) \quad \forall\, s, \pi^c, \pi^d\]
The minimax value function is:
\[V^*(s) = \max_{\pi^c} \min_{\pi^d} \mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^t R(s_t, a_t^c, a_t^d) \,\middle|\, s_0 = s\right]\]
Disturbance Model
Two instantiations of adversarial disturbance:
- Additive force/torque perturbation: $\ddot{q} = f(q, \dot{q}, a^c) + B\, a^d$, where $B$ maps the disturbance $a^d$ into the equations of motion
- External body force: wrench applied to a specified body (pushes, wind, contact)
The adversary’s action is bounded, $\|a^d\|_2 \le \epsilon$, where the disturbance budget $\epsilon$ is increased by a curriculum from $0$ to $\epsilon_{\max} > 0$ over the course of training.
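A minimal sketch of how the budget could be enforced and scheduled, assuming a linear curriculum (the text does not specify the schedule's shape) and enforcement by radial rescaling onto the $\ell_2$ ball. The function names `linear_curriculum` and `project_to_l2_ball` are hypothetical.

```python
import math


def linear_curriculum(step: int, total_steps: int, eps_max: float) -> float:
    """Budget eps grows linearly from 0 to eps_max, then stays there (assumed schedule)."""
    frac = min(max(step / total_steps, 0.0), 1.0)
    return frac * eps_max


def project_to_l2_ball(a_d: list[float], eps: float) -> list[float]:
    """Rescale a^d so that ||a^d||_2 <= eps; actions already inside the ball are unchanged."""
    norm = math.sqrt(sum(x * x for x in a_d))
    if norm <= eps:
        return list(a_d)
    scale = eps / norm  # here norm > eps >= 0, so there is no division by zero
    return [x * scale for x in a_d]


if __name__ == "__main__":
    eps = linear_curriculum(step=2_500, total_steps=10_000, eps_max=2.0)  # 25% of eps_max
    raw = [3.0, 4.0]  # ||raw||_2 = 5, far outside the current budget
    print(eps, project_to_l2_ball(raw, eps))
```

Radial rescaling keeps the disturbance direction chosen by the adversary and only shrinks its magnitude, so early in the curriculum the adversary still learns where to push, just not how hard.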
Architecture
Dual-Policy Training
Phase 1: Optimal Policy
Train $\pi^{\text{opt}}$ with SAC, TD3, or PPO on the nominal environment (no adversary) to maximize expected return under nominal dynamics.
Phase 2: Robust Policy + Adversary
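Co-train $\pi^{\text{rob}}$ and the adversary $\pi^{\text{adv}}$ (the player $\pi^d$ in the game above) in alternation, annealing $\epsilon$ with the curriculum. The adversary outputs $a^d = \epsilon \cdot \tanh(\text{NN}_\psi(o))$. Elementwise $\tanh$ bounds each component by $\epsilon$ ($\|a^d\|_\infty \le \epsilon$), so $\|a^d\|_2$ can reach $\epsilon\sqrt{\dim \mathcal{A}^d}$; meeting the $\ell_2$ budget additionally requires a rescaling such as $a^d \leftarrow a^d \cdot \min(1, \epsilon / \|a^d\|_2)$.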
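To illustrate the alternation pattern only, the sketch below runs alternating gradient ascent (controller) and descent (adversary) on a hypothetical scalar objective $J(u, d) = -u^2 + ud + d^2$, which is concave in $u$ and convex in $d$ with a saddle point at $(0, 0)$. The saddle satisfies the same inequality as the Nash condition above: $J(0, d) = d^2 \ge J(0, 0) = 0 \ge -u^2 = J(u, 0)$. In the RL setting each "step" would instead be a block of actor-critic or policy-gradient updates on rollouts, and the objective is not concave-convex in general, so convergence is not guaranteed the way it is for this toy.

```python
def saddle_objective(u: float, d: float) -> float:
    """Toy zero-sum payoff: concave in u (controller), convex in d (adversary)."""
    return -u * u + u * d + d * d


def alternating_gda(u: float, d: float, lr: float = 0.1, iters: int = 200) -> tuple[float, float]:
    """Alternate one ascent step for the controller and one descent step for the adversary."""
    for _ in range(iters):
        # Controller phase: adversary frozen, ascend dJ/du = -2u + d.
        u = u + lr * (-2.0 * u + d)
        # Adversary phase: controller frozen at its new value, descend dJ/dd = u + 2d.
        d = d - lr * (u + 2.0 * d)
    return u, d


if __name__ == "__main__":
    u_star, d_star = alternating_gda(u=1.0, d=1.0)
    print(u_star, d_star, saddle_objective(u_star, d_star))
```

With `lr=0.1` the alternating update is a linear map whose eigenvalues have modulus 0.8, so both players contract toward the saddle point geometrically.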
Transformer Disturbance Detector
The detector is a lightweight transformer encoder processing a sliding window of $L$ observations:
| Parameter | Value |
|---|---|
| Layers $N_L$ | 2–4 |
| Model dimension $d_{\text{model}}$ | 64–128 |
| Attention heads $N_H$ | 2–4 |
| Window length $L$ | 16 |
| Parameters | ~50k |
| Latency (Jetson Orin, TensorRT) | < 0.5 ms |
Output heads:
- Detection branch: linear $\to$ sigmoid $\to$ $p_t \in [0,1]$
- Magnitude branch: linear $\to$ softplus $\to$ $\hat{\delta}_t \ge 0$
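The input buffering and the two output heads can be sketched without the encoder itself. This assumes the window is zero-padded at start-up and that the encoder reduces the $L \times d$ window to a single pooled feature vector `h`; both are assumptions, and the attention layers are omitted. `ObservationWindow` and `detector_heads` are hypothetical names.

```python
import math
from collections import deque


class ObservationWindow:
    """Sliding window holding the L most recent observations (zero-padded at start-up)."""

    def __init__(self, length: int, obs_dim: int) -> None:
        self._buf = deque([[0.0] * obs_dim for _ in range(length)], maxlen=length)

    def push(self, obs: list[float]) -> list[list[float]]:
        self._buf.append(list(obs))  # the oldest observation is evicted automatically
        return list(self._buf)  # shape (L, obs_dim), oldest first


def sigmoid(z: float) -> float:
    # Branch on the sign so math.exp never overflows.
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softplus(z: float) -> float:
    # Numerically stable log(1 + e^z) = max(z, 0) + log1p(e^{-|z|}).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def detector_heads(h, w_p, b_p, w_d, b_d):
    """Map a pooled encoder feature h to (p_t, delta_hat_t) via two linear layers."""
    z_p = sum(w * x for w, x in zip(w_p, h)) + b_p
    z_d = sum(w * x for w, x in zip(w_d, h)) + b_d
    return sigmoid(z_p), softplus(z_d)
```

The sigmoid keeps $p_t$ in $[0,1]$ and the softplus keeps $\hat{\delta}_t \ge 0$ while remaining smooth, which is why these activations pair naturally with the BCE and MSE terms of the training loss.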
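Training loss (supervised on logged adversarial rollouts, where $y_t \in \{0,1\}$ indicates whether the adversary is active at step $t$, $\delta_t$ is the logged disturbance magnitude (e.g. $\|a_t^d\|_2$), and $\lambda > 0$ weights the regression term):
\[\mathcal{L} = \mathcal{L}_{\text{BCE}}(p_t, y_t) + \lambda \, \mathcal{L}_{\text{MSE}}(\hat{\delta}_t, \delta_t)\]
Gating Rule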
The mixing coefficient is computed via a learned gating function:
Linear gating:
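\[\alpha_t = \sigma(w_p \cdot p_t + w_\delta \cdot \hat{\delta}_t + b)\]
where $\sigma$ is the logistic sigmoid and $w_p$, $w_\delta$, $b$ are learned scalars.
MLP gating (more expressive):
\[\alpha_t = f_\phi(p_t, \hat{\delta}_t)\]
where $f_\phi$ is a small MLP whose output must lie in $[0,1]$, e.g. through a final sigmoid.
When $\alpha_t \approx 0$, the action comes from the nominal-optimal policy; when $\alpha_t \approx 1$, from the robust policy. Soft mixing avoids the chattering that hard switching produces when the detector output hovers near the decision boundary.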
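A minimal sketch of the linear gate followed by the action blend from the abstract. The weights `W_P`, `W_DELTA`, `B` and the example actions are hypothetical placeholders for learned values, and the sketch assumes that for stochastic policies (SAC, PPO) the blend is applied to their deterministic mean actions.

```python
import math


def sigmoid(z: float) -> float:
    # Overflow-safe logistic function.
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def linear_gate(p: float, delta_hat: float, w_p: float, w_delta: float, b: float) -> float:
    """alpha_t = sigma(w_p * p_t + w_delta * delta_hat_t + b), always in (0, 1)."""
    return sigmoid(w_p * p + w_delta * delta_hat + b)


def mix_actions(alpha: float, a_rob: list[float], a_opt: list[float]) -> list[float]:
    """a_t = alpha * a_rob + (1 - alpha) * a_opt, applied per action dimension."""
    return [alpha * r + (1.0 - alpha) * o for r, o in zip(a_rob, a_opt)]


if __name__ == "__main__":
    W_P, W_DELTA, B = 8.0, 2.0, -4.0  # hypothetical gate weights (learned in the method)
    a_opt, a_rob = [0.6, -0.2], [0.1, 0.3]  # hypothetical policy outputs
    for p, delta in [(0.02, 0.0), (0.5, 0.3), (0.97, 1.2)]:
        alpha = linear_gate(p, delta, W_P, W_DELTA, B)
        print(f"p={p:.2f} delta={delta:.1f} alpha={alpha:.3f} a={mix_actions(alpha, a_rob, a_opt)}")
```

Because the blend is a convex combination, if both policies emit actions inside the same box (e.g. $[-1, 1]^m$ after a `tanh` output), the mixed action stays inside that box for every $\alpha_t$.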
Deployment Pipeline
Export Path
```text
PyTorch (torch.onnx.export)  /  JAX (jax2tf -> tf2onnx)
   |
ONNX (opset 17)
   |
TensorRT (trtexec)
   |-- FP16 for policy networks
   |-- INT8 with calibration for transformer
   |
ROS 2 Node (Jetson Orin)
```
Inference Latency Budget (500 Hz on Jetson Orin MAXN)
| Component | Target Latency |
|---|---|
| Transformer forward pass ($\mathcal{T}_\theta$) | < 0.50 ms |
| Policy forward pass ($\pi^{\text{opt}}$) | < 0.30 ms |
| Policy forward pass ($\pi^{\text{rob}}$) | < 0.30 ms |
| Gating computation ($f_\phi$) | < 0.05 ms |
| ROS message serialization + publish | < 0.20 ms |
| Total | < 1.35 ms |
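Against the 2 ms control period at 500 Hz, the sequential total of 1.35 ms leaves a 0.65 ms margin for latency spikes. Running the two policy forward passes concurrently on separate CUDA streams reduces effective policy inference from 0.60 ms to ~0.35 ms, bringing the total to ~1.10 ms and widening the margin to ~0.90 ms.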
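The budget arithmetic can be checked with a few lines. The per-component numbers are the targets from the table and the ~0.35 ms concurrent figure is the estimate stated in the text; nothing here is a measurement, and `latency_report` is a hypothetical helper.

```python
# Per-component latency targets (ms), copied from the budget table.
BUDGET_MS = {
    "transformer": 0.50,
    "policy_opt": 0.30,
    "policy_rob": 0.30,
    "gating": 0.05,
    "ros_publish": 0.20,
}
CONTROL_HZ = 500.0
CONCURRENT_POLICY_MS = 0.35  # stated effective time for both policies on separate CUDA streams


def latency_report(budget: dict[str, float], hz: float, concurrent_policy_ms: float) -> dict[str, float]:
    period_ms = 1000.0 / hz
    sequential_ms = sum(budget.values())
    # Replace the two sequential policy passes with their concurrent effective time.
    concurrent_ms = sequential_ms - budget["policy_opt"] - budget["policy_rob"] + concurrent_policy_ms
    return {
        "period_ms": period_ms,
        "sequential_ms": sequential_ms,
        "sequential_margin_ms": period_ms - sequential_ms,
        "concurrent_ms": concurrent_ms,
        "concurrent_margin_ms": period_ms - concurrent_ms,
    }


if __name__ == "__main__":
    for key, value in latency_report(BUDGET_MS, CONTROL_HZ, CONCURRENT_POLICY_MS).items():
        print(f"{key:>22}: {value:.2f} ms")
```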
ROS 2 Integration
- Sensor subscriber: joint states, IMU, proprioceptive observations (best-effort QoS)
- Timer callback: triggers inference at configured control frequency (200–500 Hz)
- Action publisher: pre-allocated messages, zero dynamic allocation
- NITROS: zero-copy GPU tensor transfer between nodes (Isaac ROS)
RL Algorithms
Four deep RL algorithms serve as candidate training methods for both the controller and adversary:
| Algorithm | Type | Key Feature |
|---|---|---|
| DDPG | Off-policy, Deterministic | Deterministic policy gradient + replay buffer |
| TD3 | Off-policy, Deterministic | Twin critics, delayed updates, target smoothing |
| SAC | Off-policy, Stochastic | Maximum entropy + automatic temperature tuning |
| PPO | On-policy, Stochastic | Clipped surrogate objective, GAE |
Comparison with Prior Work
| Feature | RARL | SA-MDP | Domain Rand. | Ours |
|---|---|---|---|---|
| Adversarial training | Yes | Yes | -- | Yes |
| Dual-policy mixing | -- | -- | -- | Yes |
| Runtime detection | -- | -- | -- | Yes |
| Transformer detector | -- | -- | -- | Yes |
| TensorRT deployment | -- | -- | -- | Yes |
| ROS 2 integration | -- | -- | -- | Yes |
Evaluation Plan
Benchmark Environments (MuJoCo / Gymnasium)
- HalfCheetah-v4: 17-dim obs, 6-dim action
- Walker2d-v4: 17-dim obs, 6-dim action
- Ant-v4: 27-dim obs, 8-dim action
- Humanoid-v4: 376-dim obs, 17-dim action
Disturbances: random external forces, parameter perturbation ($\pm$30%), and additive observation noise.
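The three disturbance families can be sampled along the following lines. This is a sketch under stated assumptions: parameter scaling is drawn uniformly within ±30%, observation noise is zero-mean Gaussian, and pushes are 3-D forces with uniformly random direction. The parameter names (`torso_mass`, `floor_friction`), force limit and noise level are hypothetical, and applying the samples to a MuJoCo model is not shown.

```python
import math
import random


def perturb_parameters(nominal: dict[str, float], rel: float, rng: random.Random) -> dict[str, float]:
    """Scale each physical parameter by a factor drawn uniformly from [1 - rel, 1 + rel]."""
    return {name: value * rng.uniform(1.0 - rel, 1.0 + rel) for name, value in nominal.items()}


def add_observation_noise(obs: list[float], sigma: float, rng: random.Random) -> list[float]:
    """Additive zero-mean Gaussian observation noise with standard deviation sigma."""
    return [x + rng.gauss(0.0, sigma) for x in obs]


def random_external_force(max_newtons: float, rng: random.Random) -> list[float]:
    """3-D force with uniformly random direction and magnitude uniform in [0, max_newtons]."""
    v = [rng.gauss(0.0, 1.0) for _ in range(3)]  # isotropic Gaussian gives a uniform direction
    norm = math.sqrt(sum(x * x for x in v)) or 1.0  # guard against an all-zero draw
    magnitude = rng.uniform(0.0, max_newtons)
    return [magnitude * x / norm for x in v]


if __name__ == "__main__":
    rng = random.Random(0)
    print(perturb_parameters({"torso_mass": 3.5, "floor_friction": 0.9}, rel=0.30, rng=rng))
    print(random_external_force(max_newtons=50.0, rng=rng))
    print(add_observation_noise([0.0] * 17, sigma=0.01, rng=rng)[:3])
```

Passing an explicit `random.Random` instance keeps each evaluation episode reproducible from its seed.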
Planned Ablation Studies
- No adversary – nominal policy only (performance ceiling, robustness floor)
- No transformer detector – fixed $\alpha = 0.5$ (value of runtime awareness)
- Hard switching vs. soft mixing – binary $\alpha \in \{0,1\}$ vs. continuous blending
- Window length – $L \in \{8, 16, 32, 64\}$ (detection accuracy vs. compute)
- Adversary budget – vary $\epsilon$ for robustness-performance Pareto frontier
- Single vs. population adversary – generalization to unseen disturbances
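The hard-vs-soft ablation can be previewed with a toy chattering measure: the total variation $\sum_t |\alpha_t - \alpha_{t-1}|$ of the gate under a synthetic detector output that oscillates around the 0.5 threshold. The soft gate here is the linear gate with $w_\delta = 0$, written as $\sigma(k(p - 0.5))$ with a hypothetical slope `k = 4`; the input sequence is illustrative, not data.

```python
import math


def hard_gate(p: float, threshold: float = 0.5) -> float:
    """Binary switching: alpha in {0, 1}."""
    return 1.0 if p > threshold else 0.0


def soft_gate(p: float, k: float = 4.0, threshold: float = 0.5) -> float:
    """Continuous switching: alpha = sigmoid(k * (p - threshold))."""
    return 1.0 / (1.0 + math.exp(-k * (p - threshold)))


def total_variation(alphas: list[float]) -> float:
    """Sum of |alpha_t - alpha_{t-1}|: a simple chattering measure."""
    return sum(abs(b - a) for a, b in zip(alphas, alphas[1:]))


if __name__ == "__main__":
    # Synthetic detector output hovering at the boundary: 0.55, 0.45, 0.55, ...
    p_seq = [0.5 + 0.05 * (-1) ** t for t in range(100)]
    print("hard TV:", total_variation([hard_gate(p) for p in p_seq]))
    print("soft TV:", total_variation([soft_gate(p) for p in p_seq]))
```

Each flip of $p_t$ across the threshold costs the hard gate a full unit of variation, while the soft gate moves by only $\sigma(0.2) - \sigma(-0.2) = \tanh(0.1) \approx 0.1$, roughly a tenfold reduction for this input.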
Baselines
- SAC/TD3 (nominal)
- RARL (single robust policy, no gating)
- SA-MDP (state-adversarial, no gating)
- Domain Randomization (no adversary, no gating)
Technology Stack
| Category | Tools |
|---|---|
| Training | PyTorch, Gymnasium, MuJoCo (planned: JAX + MJX) |
| RL Algorithms | DDPG, TD3, SAC, PPO (zero-sum MARL variants) |
| Detection | Transformer encoder (causal attention, ~50k params) |
| Export | ONNX (opset 17), TensorRT (FP16/INT8) |
| Deployment | NVIDIA Jetson Orin, CUDA streams, ROS 2 Humble |
| Data | HDF5 / Apache Arrow, standardized logging spec |
Status
This project describes the full methodology and system design. Empirical results across all benchmark environments and on physical robot hardware are in progress and will be reported in a follow-up study.