cloudinwind's blog
RL笔记(20):Decision TransformerBlur image

引言(Introduction)#

在之前的笔记中,无论是有模型还是无模型,在线还是离线,我们解决 RL 问题的核心思路都是基于 动态规划 (Dynamic Programming) 的:

  • 我们要估计价值函数 V(s)V(s)Q(s,a)Q(s,a)
  • 利用贝尔曼方程进行迭代更新(自举)。
  • 通过最大化价值来获得策略。

Decision Transformer (DT) 提出了一种完全不同的视角: 如果我们拥有大量的离线轨迹数据,为什么不把它看作是一个 序列建模 (Sequence Modeling) 问题呢? 就像 GPT 预测下一个单词一样,能不能根据过去的轨迹期望的回报,直接预测下一个动作

核心思想: 强化学习 \approx 在给定期望回报 (Target Return) 条件下的行为克隆 (Conditional Behavior Cloning)


轨迹表示与 Return-to-Go#

为了将 RL 问题转化为 Transformer 可以处理的序列问题,我们需要重新定义模型的输入。

轨迹 (Trajectory)#

一条标准的 RL 轨迹由状态、动作和奖励组成: τ=(s1,a1,r1,s2,a2,r2,,sT,aT,rT)\tau = (s_1, a_1, r_1, s_2, a_2, r_2, \dots, s_T, a_T, r_T)

剩余回报 (Return-to-Go, RTG)#

传统的 RL 使用即时奖励 rtr_t。但在做序列预测时,我们更关心“未来还能拿多少分”。 定义 tt 时刻的 Return-to-Go (RTG) R^t\hat{R}_t 为从当前时刻到回合结束的累积回报:

R^t=k=tTrk\hat{R}_t = \sum_{k=t}^T r_k

模型的输入序列#

DT 的核心创新在于将 RTG 作为一种条件 (Condition) 输入给模型。 输入序列被组织为三元组的序列(K-V-Q 模式):

τinput=(R^1,s1,a1,R^2,s2,a2,,R^T,sT,aT)\tau_{input} = (\hat{R}_1, s_1, a_1, \hat{R}_2, s_2, a_2, \dots, \hat{R}_T, s_T, a_T)
  • 直觉:这就像是在告诉模型:“我现在状态是 s1s_1,我想在未来总共获得 R^1\hat{R}_1 分,请告诉我该做什么动作 a1a_1?”

网络架构#

DT 直接使用了 GPT (Generative Pre-trained Transformer) 的架构,即因果掩码 Transformer (Causal Transformer)。

嵌入层 (Embeddings)#

由于 R,s,aR, s, a 的模态不同(标量、图像/向量、离散/连续),我们需要先将它们映射到同一个维度 dd

  1. 状态嵌入:CNN 或 MLP 处理 sts_t
  2. 动作嵌入:Embedding 层或 MLP 处理 ata_t
  3. 回报嵌入:MLP 处理 R^t\hat{R}_t
  4. 时间步嵌入 (Timestep Embedding):为了让模型知道当前处于轨迹的哪个阶段,额外加入一个可学习的时间位置编码。

上下文窗口 (Context Window)#

由于轨迹可能很长,Transformer 无法处理整个 Episode。DT 使用一个固定的上下文窗口 KK(context length),只把最近的 KK 步输入模型:

Inputt=[R^tK,stK,atK,,R^t,st]\text{Input}_t = [\hat{R}_{t-K}, s_{t-K}, a_{t-K}, \dots, \hat{R}_t, s_t]

预测目标#

模型的目标是预测下一个 token。在 DT 中,我们主要关注预测动作 ata_t

at=DecisionTransformer(R^tK,stK,atK,,R^t,st)a_t = \text{DecisionTransformer}(\hat{R}_{t-K}, s_{t-K}, a_{t-K}, \dots, \hat{R}_t, s_t)

训练与推断#

训练 (Training)#

DT 的训练过程完全是监督学习 (Supervised Learning),不需要计算梯度,不需要贝尔曼误差,不需要目标网络。

从离线数据集 D\mathcal{D} 中采样轨迹片段,最小化预测动作与真实动作的误差:

  • 离散动作:交叉熵损失 (Cross-Entropy Loss)。
  • 连续动作:均方误差 (MSE Loss)。
L(θ)=EτD[t=1T(ata^t)2]\mathcal{L}(\theta) = \mathbb{E}_{\tau \sim \mathcal{D}} \left[ \sum_{t=1}^T (a_t - \hat{a}_t)^2 \right]

💡 注意:虽然数据集中可能包含低分的轨迹(“臭棋”),但模型学习的是条件概率 P(atR^t,st,)P(a_t | \hat{R}_t, s_t, \dots)。也就是说,模型学会了“如果想要低分该怎么做”以及“如果想要高分该怎么做”。

推断 (Inference)#

这是 DT 最神奇的地方。训练完成后,我们可以通过提示 (Prompting) 来控制智能体。

我们给智能体设定一个目标回报 (Target Return) R^target\hat{R}_{target}(通常设为数据集中最高分的那个回报,或者更高)。

自回归生成过程

  1. 初始时刻 t=1t=1:输入 (R^target,s1)(\hat{R}_{target}, s_1),模型输出动作 a1a_1
  2. 环境交互:执行 a1a_1,环境返回即时奖励 r1r_1 和新状态 s2s_2
  3. 更新目标:既然已经拿到了 r1r_1,剩下的目标就要减去它: R^2=R^1r1\hat{R}_2 = \hat{R}_1 - r_1
  4. 下一时刻:输入 (,R^1,s1,a1,R^2,s2)(\dots, \hat{R}_1, s_1, a_1, \hat{R}_2, s_2),预测 a2a_2
  5. 重复直至结束。

理论对比:DT vs. CQL#

为了理解 DT 的位置,我们将它与上一篇笔记中的 CQL 进行对比:

维度Conservative Q-Learning (CQL)Decision Transformer (DT)
核心范式动态规划 (DP)序列建模 (SL)
学习目标逼近最优价值函数 Q(s,a)Q^*(s,a)拟合条件分布 P(aR^,s)P(a\|\hat{R}, s)
训练方式最小化贝尔曼误差 (TD Error)最大化动作似然 (Supervised)
OOD 处理显式惩罚未知动作的 Q 值 (悲观主义)依靠 Transformer 的泛化能力 (无显式约束)
长时序信用分配依靠 QQ 值的自举传播 (Bootstrapping)依靠 Attention 机制直接关联过去与未来
优点理论下界保证,擅长拼接轨迹 (Stitching)训练极其稳定,易于扩展,能够处理稀疏奖励

总结#

Decision Transformer 的提出证明了:只要模型足够强(Transformer),强化学习可以被简化为监督学习。

  • 它不需要复杂的 Actor-Critic 架构。
  • 它不需要处理“过高估计”等死板的数值问题。
  • 它通过 Return-to-Go 实现了类似于“事后诸葛亮”的条件控制:利用这一局最终的得分作为条件,来学习这一局中的动作。
RL笔记(20):Decision Transformer
https://cloudflare.cloudinwind.top/blog/rl-note-20
Author 云之痕
Published at December 29, 2025
Comment seems to stuck. Try to refresh?✨