def act(self, state, evaluate=False):
    """Pick an action for a single state and report its log-probability.

    The state is converted with ``self.to_tensor``, given a leading batch
    dimension and passed through ``self.actor``, whose output is used as the
    probability vector of a categorical distribution.

    Args:
        state: One observation accepted by ``self.to_tensor``.
        evaluate: When true, take the most probable action instead of
            sampling from the distribution.

    Returns:
        ``(action, log_prob)`` as a Python int and a Python float.
    """
    with torch.no_grad():
        batch = self.to_tensor(state).unsqueeze(0)
        action_probs = self.actor(batch)
        policy = torch.distributions.Categorical(action_probs)
        # Greedy choice for evaluation, stochastic choice while collecting data.
        chosen = action_probs.argmax(dim=-1) if evaluate else policy.sample()
        chosen_log_prob = policy.log_prob(chosen)
    return chosen.item(), chosen_log_prob.item()
entropy + + self.actor_opt.zero_grad() + actor_loss.backward() + torch.nn.utils.clip_grad_norm_( + self.actor.parameters(), self.max_grad_norm) + self.actor_opt.step() + + # Critic 损失 + values_pred = self.critic(s_batch).squeeze(-1) + critic_loss = self.value_coef * nn.MSELoss()(values_pred, ret_batch) + + self.critic_opt.zero_grad() + critic_loss.backward() + torch.nn.utils.clip_grad_norm_( + self.critic.parameters(), self.max_grad_norm) + self.critic_opt.step() + + total_actor_loss += actor_loss.item() + total_critic_loss += critic_loss.item() + + self.storage.clear() + self.last_loss_info = { + "actor_loss": total_actor_loss / max(self.k_epochs, 1), + "critic_loss": total_critic_loss / max(self.k_epochs, 1), + } + return self.last_loss_info + + # ================================================================ + # GAE + # ================================================================ + + def _compute_gae(self, rewards, values, next_value, dones): + """ + 计算 Generalized Advantage Estimation。 + + rewards: list of float + values: np.ndarray (T,) V(s_t) + next_value: float V(s_{T+1}) + dones: np.ndarray (T,) + """ + T = len(rewards) + vals = np.append(values, next_value) + advantages = np.zeros(T, dtype=np.float32) + gae = 0.0 + for t in reversed(range(T)): + delta = rewards[t] + self.gamma * vals[t + 1] * (1 - dones[t]) - vals[t] + gae = delta + self.gamma * self.lambd * (1 - dones[t]) * gae + advantages[t] = gae + return advantages + + # ================================================================ + # 持久化 + # ================================================================ + + def _save_checkpoint(self, checkpoint, path): + checkpoint.update({ + "actor": self.actor.state_dict(), + "critic": self.critic.state_dict(), + "actor_opt": self.actor_opt.state_dict(), + "critic_opt": self.critic_opt.state_dict(), + }) + torch.save(checkpoint, path) + + def _load_checkpoint(self, checkpoint): + self.actor.load_state_dict(checkpoint["actor"]) + 
self.critic.load_state_dict(checkpoint["critic"]) + self.actor_opt.load_state_dict(checkpoint["actor_opt"]) + self.critic_opt.load_state_dict(checkpoint["critic_opt"]) diff --git a/critical/attention_agent.py b/critical/attention_agent.py new file mode 100644 index 00000000..54709836 --- /dev/null +++ b/critical/attention_agent.py @@ -0,0 +1,162 @@ +# rl_algorithms/dqn/attention_agent.py +# Attention-DQN 智能体:多头注意力机制 + DQN(毕设创新算法) + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +from rl_algorithms.base_agent import BaseAgent +from rl_algorithms.dqn.attention_network import AttentionQNetwork +from rl_algorithms.dqn.replay_buffer import ReplayBuffer +from config.dqn_config import ( + STATE_SIZE, ACTION_SIZE, HIDDEN_SIZES, + LEARNING_RATE, GAMMA, TAU, TARGET_UPDATE_FREQ, + MEMORY_SIZE, BATCH_SIZE, MIN_REPLAY_SIZE, + EPSILON_START, EPSILON_MIN, EPSILON_DECAY, + TRAIN_EVERY_N_STEPS, + NUM_ATTENTION_HEADS, ATTENTION_HEAD_DIM, ATTENTION_DROPOUT, +) + + +class AttentionDQNAgent(BaseAgent): + """ + Attention-DQN 智能体(毕设创新算法)。 + + 在标准 DQN 基础上引入多头注意力机制,使智能体能够自适应地 + 关注状态的不同特征维度,在低能见度、突发危险等场景中表现更优。 + + 适用场景: #2 浓雾巡航, #3 夜间行驶, #7 鬼探头, + #9 夜间行人横穿, #10 雾天鬼探头 + """ + + def __init__(self, state_size=None, action_size=None): + state_size = state_size or STATE_SIZE + action_size = action_size or ACTION_SIZE + super().__init__(state_size, action_size, name="AttentionDQN") + + # 注意力 Q 网络 + self.q_net = AttentionQNetwork( + state_size, action_size, HIDDEN_SIZES, + NUM_ATTENTION_HEADS, ATTENTION_HEAD_DIM, ATTENTION_DROPOUT, + ).to(self.device) + self.target_net = AttentionQNetwork( + state_size, action_size, HIDDEN_SIZES, + NUM_ATTENTION_HEADS, ATTENTION_HEAD_DIM, ATTENTION_DROPOUT, + ).to(self.device) + self.target_net.load_state_dict(self.q_net.state_dict()) + + self.optimizer = optim.Adam(self.q_net.parameters(), lr=LEARNING_RATE) + self.loss_fn = nn.MSELoss() + + # 回放池 + self.memory = ReplayBuffer(MEMORY_SIZE) + self.batch_size = BATCH_SIZE + + 
def act(self, state, evaluate=False):
    """Choose an action epsilon-greedily and record the attention weights.

    With probability ``self.epsilon`` (only when ``evaluate`` is false) a
    uniformly random action is returned and ``self.last_attention`` is reset
    to ``None``. Otherwise the action with the highest Q-value is returned
    and the attention weights produced by ``self.q_net`` are stored in
    ``self.last_attention`` as a NumPy array for later analysis.

    The greedy forward pass is run with the network in evaluation mode so
    that the attention dropout does not randomize the Q-values, the selected
    action, or the stored attention map. The network's previous
    training/eval mode is restored afterwards.

    Args:
        state: One observation accepted by ``self.to_tensor``.
        evaluate: When true, never explore.

    Returns:
        The selected action index as an int.
    """
    if not evaluate and np.random.random() < self.epsilon:
        self.last_attention = None
        return np.random.randint(self.action_size)

    was_training = self.q_net.training
    self.q_net.eval()
    try:
        with torch.no_grad():
            state_t = self.to_tensor(state).unsqueeze(0)
            q_values, attn = self.q_net(state_t)
    finally:
        # Hand the network back in the mode the training loop expects.
        self.q_net.train(was_training)

    self.last_attention = attn.detach().cpu().numpy()
    return q_values.argmax(dim=-1).item()
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input.
    The embedding is split evenly across ``num_heads`` heads, so
    ``embed_dim`` must be a multiple of ``num_heads``.

    Args:
        embed_dim: Size of the input and output feature dimension.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        # Fail at construction time; otherwise the mismatch only surfaces as
        # an opaque shape error inside forward()'s view().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                "embed_dim (%s) must be a positive multiple of num_heads (%s)"
                % (embed_dim, num_heads))

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Dot products are divided by sqrt(head_dim) before the softmax.
        self.scale = self.head_dim ** 0.5

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to a batch of token sequences.

        Args:
            x: Tensor of shape ``(B, N, embed_dim)``.

        Returns:
            ``(out, attn)`` where ``out`` has shape ``(B, N, embed_dim)`` and
            ``attn`` has shape ``(B, H, N, N)``. ``attn`` is returned after
            dropout, so in training mode its rows need not sum to one.
        """
        B, N, D = x.shape
        H = self.num_heads
        d = self.head_dim

        # Project, then split the feature axis into heads: (B, H, N, d).
        q = self.q_proj(x).view(B, N, H, d).transpose(1, 2)
        k = self.k_proj(x).view(B, N, H, d).transpose(1, 2)
        v = self.v_proj(x).view(B, N, H, d).transpose(1, 2)

        attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale  # (B, H, N, N)
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)                               # (B, H, N, d)
        # Merge the heads back into a single feature axis.
        out = out.transpose(1, 2).contiguous().view(B, N, D)      # (B, N, D)
        out = self.out_proj(out)
        return out, attn
class BaseAgent(ABC):
    """Abstract base class shared by the RL agents.

    It owns device selection, training counters, reward/loss logs and
    checkpoint I/O.

    Subclasses must implement:
        act(state) -> action
        train()    -> loss-info dict or None

    Subclasses may override ``store``, ``update_epsilon`` and the
    ``_save_checkpoint`` / ``_load_checkpoint`` hooks to persist their own
    networks and optimizers. ``save`` and ``load`` themselves are concrete.
    """

    def __init__(self, state_size, action_size, name="BaseAgent"):
        self.state_size = state_size
        self.action_size = action_size
        self.name = name

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Training progress counters (persisted by save/load).
        self.episode = 0
        self.total_steps = 0
        self.train_steps = 0

        # In-memory logs.
        self.loss_history = []
        self.reward_history = []

    @abstractmethod
    def act(self, state):
        """Select an action for the given state."""

    @abstractmethod
    def train(self):
        """Run one learning update; return a loss-info dict or None."""

    def store(self, *args):
        """Store one transition.

        No-op by default; subclasses override it to push into their own
        replay buffer or rollout storage.
        """

    def update_epsilon(self):
        """Decay the exploration rate. No-op by default; overridden by the DQN agents."""

    def start_episode(self):
        """Call at the start of every episode; increments the episode counter."""
        self.episode += 1

    def end_episode(self, total_reward):
        """Call at the end of an episode to log its total reward."""
        self.reward_history.append(total_reward)

    # ================================================================
    # Checkpointing
    # ================================================================

    def save(self, path):
        """Write a checkpoint to ``path``, creating parent directories.

        The base counters are collected here; the actual write is delegated
        to ``_save_checkpoint`` so subclasses can add their own entries.
        """
        directory = os.path.dirname(path)
        # A bare file name has no directory part, and os.makedirs("") raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        checkpoint = {
            "episode": self.episode,
            "total_steps": self.total_steps,
            "train_steps": self.train_steps,
        }
        self._save_checkpoint(checkpoint, path)

    def load(self, path):
        """Restore a checkpoint written by ``save``.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError("检查点不存在: %s" % path)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(path, map_location=self.device)
        self.episode = checkpoint.get("episode", 0)
        self.total_steps = checkpoint.get("total_steps", 0)
        self.train_steps = checkpoint.get("train_steps", 0)
        self._load_checkpoint(checkpoint)

    def _save_checkpoint(self, checkpoint, path):
        """Hook: subclasses add their own entries to ``checkpoint`` and save it."""
        torch.save(checkpoint, path)

    def _load_checkpoint(self, checkpoint):
        """Hook: subclasses restore their own entries from ``checkpoint``."""

    # ================================================================
    # Utilities
    # ================================================================

    def to_tensor(self, x, dtype=torch.float32):
        """Convert array-like ``x`` to a tensor of ``dtype`` on ``self.device``."""
        return torch.tensor(np.asarray(x), dtype=dtype, device=self.device)
def standard_clip(ratio, eps_clip):
    """Standard PPO clip: ``clamp(ratio, 1 - eps_clip, 1 + eps_clip)``.

    The gradient is cut off abruptly at both bounds.
    """
    return torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)


def smooth_clip(ratio, eps_clip, alpha=0.1):
    """Clip the importance ratio with a quadratic blend next to each bound.

    Outside ``[1 - eps_clip, 1 + eps_clip]`` the result equals the hard
    clamp. Inside a band of width ``alpha`` adjacent to each bound the value
    follows a quadratic in the distance to the bound, so the output is
    continuous and its slope falls to 0 exactly at the bound (it is 2 at the
    inner edge of the band). Between the two bands the ratio is returned
    unchanged.

    ``alpha`` is capped at ``eps_clip``: a wider band would make the lower
    and upper bands overlap around ``ratio == 1``, so an unchanged policy
    would no longer map to exactly 1. With ``alpha <= 0`` the function
    reduces to the hard clamp.

    Args:
        ratio: Tensor of importance ratios pi_new / pi_old. Not modified.
        eps_clip: Clip range epsilon.
        alpha: Width of each transition band.

    Returns:
        A new tensor with the same shape as ``ratio``.
    """
    low = 1.0 - eps_clip    # lower bound
    high = 1.0 + eps_clip   # upper bound

    clipped = torch.clamp(ratio, low, high)

    width = min(alpha, eps_clip)
    if width <= 0:
        # No transition band; also avoids dividing by zero below.
        return clipped

    # Normalized distance to each bound: 0 at the bound, 1 at the band's inner edge.
    t_low = (ratio - low) / width
    t_high = (high - ratio) / width

    in_low_band = (ratio > low) & (ratio < low + width)
    in_high_band = (ratio > high - width) & (ratio < high)

    # torch.where keeps this out-of-place and free of data-dependent branches.
    clipped = torch.where(in_low_band, low + width * t_low ** 2, clipped)
    clipped = torch.where(in_high_band, high - width * t_high ** 2, clipped)
    return clipped
def act(self, state, evaluate=False):
    """Select an action epsilon-greedily.

    Args:
        state: One observation accepted by ``self.to_tensor``.
        evaluate: When true, exploration is switched off and the greedy
            action is always returned (used for evaluation/testing).

    Returns:
        The chosen action index.
    """
    explore = (not evaluate) and np.random.random() < self.epsilon
    if explore:
        return np.random.randint(self.action_size)

    with torch.no_grad():
        batch = self.to_tensor(state).unsqueeze(0)
        greedy = self.q_net(batch).argmax(dim=-1)
    return greedy.item()
torch.tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1) + ns = self.to_tensor(next_states) + d = torch.tensor(dones, dtype=torch.float32, device=self.device).unsqueeze(1) + + # 当前 Q 值 + q = self.q_net(s).gather(1, a) + + # 目标 Q 值(使用 target 网络) + with torch.no_grad(): + next_q = self.target_net(ns).max(dim=1, keepdim=True)[0] + target_q = r + self.gamma * next_q * (1 - d) + + loss = self.loss_fn(q, target_q) + + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.q_net.parameters(), 10.0) + self.optimizer.step() + + # 软更新目标网络 + if self.train_steps % self.target_update_freq == 0: + self._soft_update() + + self.last_loss = loss.item() + self.update_epsilon() + return {"loss": self.last_loss, "epsilon": self.epsilon} + + def update_epsilon(self): + self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay) + + # ================================================================ + # 持久化 + # ================================================================ + + def _save_checkpoint(self, checkpoint, path): + checkpoint.update({ + "q_net": self.q_net.state_dict(), + "target_net": self.target_net.state_dict(), + "optimizer": self.optimizer.state_dict(), + "epsilon": self.epsilon, + }) + torch.save(checkpoint, path) + + def _load_checkpoint(self, checkpoint): + self.q_net.load_state_dict(checkpoint["q_net"]) + self.target_net.load_state_dict(checkpoint["target_net"]) + self.optimizer.load_state_dict(checkpoint["optimizer"]) + self.epsilon = checkpoint.get("epsilon", self.epsilon) + + # ================================================================ + # 内部 + # ================================================================ + + def _soft_update(self): + """目标网络软更新:θ_target = τ·θ + (1-τ)·θ_target""" + for tp, p in zip(self.target_net.parameters(), self.q_net.parameters()): + tp.data.copy_(self.tau * p.data + (1 - self.tau) * tp.data) diff --git a/critical/network.py b/critical/network.py new file mode 100644 index 
class Actor(nn.Module):
    """PPO actor: maps a state to a probability distribution over discrete actions.

    The network is a feature trunk (``self.shared``) followed by a policy
    head (``self.head``) that ends in a softmax. Any constructor argument
    left falsy falls back to the matching value from the PPO config.
    """

    def __init__(self, state_size=None, action_size=None,
                 hidden_sizes=None, head_hidden=None):
        super().__init__()
        self.state_size = state_size or STATE_SIZE
        self.action_size = action_size or ACTION_SIZE
        trunk_widths = hidden_sizes or HIDDEN_SIZES
        head_widths = head_hidden or ACTOR_HIDDEN
        # One activation instance is reused after every hidden layer.
        activation = _get_activation(ACTIVATION)

        # Feature trunk: state_size -> hidden_sizes[0] -> ... -> hidden_sizes[-1]
        trunk_dims = [self.state_size, *trunk_widths]
        trunk_layers = []
        for fan_in, fan_out in zip(trunk_dims, trunk_dims[1:]):
            trunk_layers += [nn.Linear(fan_in, fan_out), activation]
        self.shared = nn.Sequential(*trunk_layers)

        # Policy head: trunk output -> head_hidden... -> action_size -> softmax
        head_dims = [trunk_dims[-1], *head_widths]
        head_layers = []
        for fan_in, fan_out in zip(head_dims, head_dims[1:]):
            head_layers += [nn.Linear(fan_in, fan_out), activation]
        head_layers += [
            nn.Linear(head_dims[-1], self.action_size),
            nn.Softmax(dim=-1),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return action probabilities for the state batch ``x``."""
        return self.head(self.shared(x))
x): + features = self.shared(x) + return self.head(features) diff --git a/critical/ppo_agent.py b/critical/ppo_agent.py new file mode 100644 index 00000000..b4f4142e --- /dev/null +++ b/critical/ppo_agent.py @@ -0,0 +1,200 @@ +# rl_algorithms/ppo/agent.py +# 标准 PPO 智能体:Actor-Critic + GAE + 标准裁剪 + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +from rl_algorithms.base_agent import BaseAgent +from rl_algorithms.ppo.network import Actor, Critic +from rl_algorithms.ppo.storage import RolloutStorage +from rl_algorithms.ppo.clip_utils import standard_clip +from config.ppo_config import ( + STATE_SIZE, ACTION_SIZE, + LR_ACTOR, LR_CRITIC, + GAMMA, LAMBDA, EPS_CLIP, + UPDATE_EVERY, UPDATE_POLICY_TIMES, BATCH_SIZE, + VALUE_LOSS_COEF, ENTROPY_COEF, MAX_GRAD_NORM, +) + + +class PPOAgent(BaseAgent): + """ + 标准 PPO 智能体(Proximal Policy Optimization)。 + + 适用场景: #1 大雨跟车, #6 行人横穿, #8 行人闯红灯 + """ + + def __init__(self, state_size=None, action_size=None): + state_size = state_size or STATE_SIZE + action_size = action_size or ACTION_SIZE + super().__init__(state_size, action_size, name="PPO") + + self.actor = Actor(state_size, action_size).to(self.device) + self.critic = Critic(state_size).to(self.device) + + self.actor_opt = optim.Adam(self.actor.parameters(), lr=LR_ACTOR) + self.critic_opt = optim.Adam(self.critic.parameters(), lr=LR_CRITIC) + + self.gamma = GAMMA + self.lambd = LAMBDA + self.eps_clip = EPS_CLIP + self.update_every = UPDATE_EVERY + self.k_epochs = UPDATE_POLICY_TIMES + self.batch_size = BATCH_SIZE + self.value_coef = VALUE_LOSS_COEF + self.entropy_coef = ENTROPY_COEF + self.max_grad_norm = MAX_GRAD_NORM + + self.storage = RolloutStorage() + + # 上次训练的 loss 信息 + self.last_loss_info = {} + + # ================================================================ + # 核心接口 + # ================================================================ + + def act(self, state, evaluate=False): + """ + 采样动作。返回 (action, log_prob)。 + evaluate=True 
def train(self):
    """Run one PPO update over the stored rollout.

    Returns ``None`` without doing anything until the rollout storage holds
    at least ``self.update_every`` transitions. Otherwise it:

    1. estimates advantages with GAE and derives value targets as
       ``advantages + V(s)``;
    2. makes ``self.k_epochs`` passes of shuffled mini-batch updates,
       optimizing the clipped-surrogate actor loss (minus an entropy bonus)
       and the scaled MSE critic loss with their own optimizers;
    3. clears the storage.

    Returns:
        dict with ``"actor_loss"`` and ``"critic_loss"``, each the mean loss
        per mini-batch update of this call, or ``None`` if no update ran.
        The same dict is kept in ``self.last_loss_info``.
    """
    if len(self.storage) < self.update_every:
        return None

    self.train_steps += 1

    # NOTE(review): get_all() presumably returns parallel per-step sequences;
    # confirm against RolloutStorage.
    states, actions, old_log_probs, rewards, next_states, dones = \
        self.storage.get_all()

    s = self.to_tensor(states)
    a = torch.tensor(actions, dtype=torch.long, device=self.device)
    old_lp = torch.tensor(old_log_probs, dtype=torch.float32, device=self.device)

    # GAE and returns, computed once from the pre-update critic.
    with torch.no_grad():
        values = self.critic(s).squeeze(-1)
        # Bootstrap from the state that follows the last stored transition.
        next_val = self.critic(
            self.to_tensor(next_states[-1:])).squeeze(-1).item()

    advantages = self._compute_gae(
        rewards, values.detach().cpu().numpy(), next_val, dones)
    advantages = self.to_tensor(advantages)
    returns = advantages + values.detach()

    # Normalize advantages. std() of a single element is NaN, which would
    # poison the actor update, so a one-sample rollout is left unscaled.
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    mse = nn.MSELoss()
    actor_loss_sum = 0.0
    critic_loss_sum = 0.0
    num_updates = 0

    n = len(states)
    for _ in range(self.k_epochs):
        # Fresh shuffle for every epoch.
        indices = torch.randperm(n)
        for start in range(0, n, self.batch_size):
            idx = indices[start:start + self.batch_size]

            s_batch = s[idx]
            a_batch = a[idx]
            old_lp_batch = old_lp[idx]
            adv_batch = advantages[idx]
            ret_batch = returns[idx]

            # Actor loss: clipped surrogate objective minus entropy bonus.
            probs = self.actor(s_batch)
            dist = torch.distributions.Categorical(probs)
            new_lp = dist.log_prob(a_batch)
            entropy = dist.entropy().mean()

            ratio = torch.exp(new_lp - old_lp_batch)
            clipped = standard_clip(ratio, self.eps_clip)
            actor_loss = -torch.min(
                ratio * adv_batch, clipped * adv_batch).mean()
            actor_loss = actor_loss - self.entropy_coef * entropy

            self.actor_opt.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.actor.parameters(), self.max_grad_norm)
            self.actor_opt.step()

            # Critic loss: scaled MSE against the GAE returns.
            values_pred = self.critic(s_batch).squeeze(-1)
            critic_loss = self.value_coef * mse(values_pred, ret_batch)

            self.critic_opt.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.critic.parameters(), self.max_grad_norm)
            self.critic_opt.step()

            actor_loss_sum += actor_loss.item()
            critic_loss_sum += critic_loss.item()
            num_updates += 1

    self.storage.clear()
    # Average over the updates actually performed. Dividing by the epoch
    # count would inflate the figure by the number of mini-batches per epoch.
    denom = max(num_updates, 1)
    self.last_loss_info = {
        "actor_loss": actor_loss_sum / denom,
        "critic_loss": critic_loss_sum / denom,
    }
    return self.last_loss_info
class ReplayBuffer:
    """Fixed-capacity FIFO experience replay buffer.

    Transitions are kept in a ``deque`` with ``maxlen=capacity``, so pushing
    into a full buffer silently drops the oldest transition.
    """

    # Output dtype of each column returned by sample(), in tuple order:
    # states, actions, rewards, next_states, dones.
    _COLUMN_DTYPES = (np.float32, np.int64, np.float32, np.float32, np.float32)

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        """Append one transition tuple to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            ``(states, actions, rewards, next_states, dones)`` as NumPy
            arrays; ``actions`` is int64, every other array is float32.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one column per field.
        states, actions, rewards, next_states, dones = (
            np.array(column, dtype=dtype)
            for column, dtype in zip(zip(*picked), self._COLUMN_DTYPES)
        )
        return states, actions, rewards, next_states, dones

    def is_ready(self, min_size):
        """Return whether at least ``min_size`` transitions are stored."""
        return len(self) >= min_size
EPS_CLIP, + UPDATE_EVERY, UPDATE_POLICY_TIMES, BATCH_SIZE, + VALUE_LOSS_COEF, ENTROPY_COEF, MAX_GRAD_NORM, + SMOOTH_ENABLED, SMOOTH_EPS_LOW, SMOOTH_EPS_HIGH, SMOOTH_BETA, +) + + +class SmoothPPOAgent(BaseAgent): + """ + Smooth-PPO 智能体(毕设创新算法)。 + + 双重创新: + 1. 网络层:LayerNorm 平滑特征分布,减少梯度震荡 + 2. 裁剪层:平滑裁剪函数替代硬截断,在边界处连续过渡 + + 适用场景: #4 前车急刹, #5 旁车加塞(行为自然度要求高) + """ + + def __init__(self, state_size=None, action_size=None): + state_size = state_size or STATE_SIZE + action_size = action_size or ACTION_SIZE + super().__init__(state_size, action_size, name="SmoothPPO") + + # 平滑网络(含 LayerNorm) + self.actor = SmoothActor(state_size, action_size).to(self.device) + self.critic = SmoothCritic(state_size).to(self.device) + + self.actor_opt = optim.Adam(self.actor.parameters(), lr=LR_ACTOR) + self.critic_opt = optim.Adam(self.critic.parameters(), lr=LR_CRITIC) + + self.gamma = GAMMA + self.lambd = LAMBDA + self.eps_clip = EPS_CLIP + self.update_every = UPDATE_EVERY + self.k_epochs = UPDATE_POLICY_TIMES + self.batch_size = BATCH_SIZE + self.value_coef = VALUE_LOSS_COEF + self.entropy_coef = ENTROPY_COEF + self.max_grad_norm = MAX_GRAD_NORM + + # 平滑裁剪参数 + self.smooth_enabled = SMOOTH_ENABLED + self.smooth_alpha = (SMOOTH_EPS_HIGH - SMOOTH_EPS_LOW) / 2.0 + + self.storage = RolloutStorage() + self.last_loss_info = {} + + # ================================================================ + # 核心接口 + # ================================================================ + + def act(self, state, evaluate=False): + with torch.no_grad(): + state_t = self.to_tensor(state).unsqueeze(0) + probs = self.actor(state_t) + dist = torch.distributions.Categorical(probs) + if evaluate: + action = probs.argmax(dim=-1) + else: + action = dist.sample() + log_prob = dist.log_prob(action) + return action.item(), log_prob.item() + + def store(self, state, action, log_prob, reward, next_state, done): + self.storage.push(state, action, log_prob, reward, next_state, done) + + def train(self): + """收集足够步数后执行 
Smooth-PPO 更新""" + if len(self.storage) < self.update_every: + return None + + self.train_steps += 1 + + states, actions, old_log_probs, rewards, next_states, dones = \ + self.storage.get_all() + + s = self.to_tensor(states) + a = torch.tensor(actions, dtype=torch.long, device=self.device) + old_lp = torch.tensor(old_log_probs, dtype=torch.float32, device=self.device) + + with torch.no_grad(): + values = self.critic(s).squeeze(-1) + next_val = self.critic( + self.to_tensor(next_states[-1:])).squeeze(-1).item() + + advantages = self._compute_gae( + rewards, values.detach().cpu().numpy(), next_val, dones) + advantages = self.to_tensor(advantages) + returns = advantages + values.detach() + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + total_actor_loss = 0.0 + total_critic_loss = 0.0 + + n = len(states) + for _ in range(self.k_epochs): + indices = torch.randperm(n) + for start in range(0, n, self.batch_size): + idx = indices[start:start + self.batch_size] + s_b = s[idx]; a_b = a[idx]; old_lp_b = old_lp[idx] + adv_b = advantages[idx]; ret_b = returns[idx] + + probs = self.actor(s_b) + dist = torch.distributions.Categorical(probs) + new_lp = dist.log_prob(a_b) + entropy = dist.entropy().mean() + + ratio = torch.exp(new_lp - old_lp_b) + # 使用平滑裁剪(核心创新) + clipped = smooth_clip(ratio, self.eps_clip, self.smooth_alpha) + actor_loss = -torch.min( + ratio * adv_b, clipped * adv_b).mean() + actor_loss = actor_loss - self.entropy_coef * entropy + + self.actor_opt.zero_grad() + actor_loss.backward() + torch.nn.utils.clip_grad_norm_( + self.actor.parameters(), self.max_grad_norm) + self.actor_opt.step() + + values_pred = self.critic(s_b).squeeze(-1) + critic_loss = self.value_coef * nn.MSELoss()(values_pred, ret_b) + + self.critic_opt.zero_grad() + critic_loss.backward() + torch.nn.utils.clip_grad_norm_( + self.critic.parameters(), self.max_grad_norm) + self.critic_opt.step() + + total_actor_loss += actor_loss.item() + total_critic_loss += 
critic_loss.item() + + self.storage.clear() + self.last_loss_info = { + "actor_loss": total_actor_loss / max(self.k_epochs, 1), + "critic_loss": total_critic_loss / max(self.k_epochs, 1), + } + return self.last_loss_info + + # ================================================================ + # GAE + # ================================================================ + + def _compute_gae(self, rewards, values, next_value, dones): + T = len(rewards) + vals = np.append(values, next_value) + advantages = np.zeros(T, dtype=np.float32) + gae = 0.0 + for t in reversed(range(T)): + delta = rewards[t] + self.gamma * vals[t + 1] * (1 - dones[t]) - vals[t] + gae = delta + self.gamma * self.lambd * (1 - dones[t]) * gae + advantages[t] = gae + return advantages + + # ================================================================ + # 持久化 + # ================================================================ + + def _save_checkpoint(self, checkpoint, path): + checkpoint.update({ + "actor": self.actor.state_dict(), + "critic": self.critic.state_dict(), + "actor_opt": self.actor_opt.state_dict(), + "critic_opt": self.critic_opt.state_dict(), + }) + torch.save(checkpoint, path) + + def _load_checkpoint(self, checkpoint): + self.actor.load_state_dict(checkpoint["actor"]) + self.critic.load_state_dict(checkpoint["critic"]) + self.actor_opt.load_state_dict(checkpoint["actor_opt"]) + self.critic_opt.load_state_dict(checkpoint["critic_opt"]) diff --git a/critical/smooth_network.py b/critical/smooth_network.py new file mode 100644 index 00000000..71fa5e0b --- /dev/null +++ b/critical/smooth_network.py @@ -0,0 +1,92 @@ +# rl_algorithms/ppo/smooth_network.py +# Smooth-PPO 网络:带 LayerNorm 的平滑策略 Actor-Critic(毕设创新) + +import torch +import torch.nn as nn + +from config.ppo_config import ( + STATE_SIZE, ACTION_SIZE, + HIDDEN_SIZES, ACTOR_HIDDEN, CRITIC_HIDDEN, ACTIVATION, +) + + +def _get_activation(name): + return {"relu": nn.ReLU(), "tanh": nn.Tanh()}.get(name, nn.Tanh()) + + +class 
SmoothActor(nn.Module): + """ + Smooth-PPO Actor 网络(创新点)。 + + 与标准 Actor 的区别:在每一层后加入 LayerNorm, + 使特征分布更加平滑稳定,减少梯度震荡,提升行为自然度。 + """ + + def __init__(self, state_size=None, action_size=None, + hidden_sizes=None, head_hidden=None): + super().__init__() + self.state_size = state_size or STATE_SIZE + self.action_size = action_size or ACTION_SIZE + hs = hidden_sizes or HIDDEN_SIZES + hh = head_hidden or ACTOR_HIDDEN + act = _get_activation(ACTIVATION) + + shared = [] + prev = self.state_size + for h in hs: + shared.append(nn.Linear(prev, h)) + shared.append(nn.LayerNorm(h)) + shared.append(act) + prev = h + self.shared = nn.Sequential(*shared) + + head = [] + prev = hs[-1] if hs else self.state_size + for h in hh: + head.append(nn.Linear(prev, h)) + head.append(nn.LayerNorm(h)) + head.append(act) + prev = h + head.append(nn.Linear(prev, self.action_size)) + head.append(nn.Softmax(dim=-1)) + self.head = nn.Sequential(*head) + + def forward(self, x): + return self.head(self.shared(x)) + + +class SmoothCritic(nn.Module): + """ + Smooth-PPO Critic 网络。 + + 同样加入 LayerNorm 以平滑价值估计。 + """ + + def __init__(self, state_size=None, hidden_sizes=None, head_hidden=None): + super().__init__() + self.state_size = state_size or STATE_SIZE + hs = hidden_sizes or HIDDEN_SIZES + hh = head_hidden or CRITIC_HIDDEN + act = _get_activation(ACTIVATION) + + shared = [] + prev = self.state_size + for h in hs: + shared.append(nn.Linear(prev, h)) + shared.append(nn.LayerNorm(h)) + shared.append(act) + prev = h + self.shared = nn.Sequential(*shared) + + head = [] + prev = hs[-1] if hs else self.state_size + for h in hh: + head.append(nn.Linear(prev, h)) + head.append(nn.LayerNorm(h)) + head.append(act) + prev = h + head.append(nn.Linear(prev, 1)) + self.head = nn.Sequential(*head) + + def forward(self, x): + return self.head(self.shared(x)) diff --git a/critical/storage.py b/critical/storage.py new file mode 100644 index 00000000..9f91f08a --- /dev/null +++ b/critical/storage.py @@ -0,0 +1,55 @@ +# 
# rl_algorithms/ppo/storage.py
# Trajectory storage (rollout buffer) used for GAE advantage computation.

import numpy as np


class RolloutStorage:
    """Rollout buffer for PPO.

    During the rollout phase it collects ``(s, a, log_prob, r, s', done)``
    step by step; during the update phase ``get_all`` hands everything back
    as numpy arrays for GAE and the multi-epoch policy update.
    """

    def __init__(self):
        """Create an empty buffer (one Python list per field)."""
        self.states = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.next_states = []
        self.dones = []

    def push(self, state, action, log_prob, reward, next_state, done):
        """Append one transition.

        ``state`` and ``next_state`` are copied into fresh ``float32``
        arrays; ``log_prob``, ``reward`` and ``done`` are coerced to
        ``float``; ``action`` is stored as given.
        """
        # np.array always copies, whereas np.asarray returns the caller's own
        # object when it is already a float32 ndarray. Without the copy, an
        # environment that reuses its observation buffer in place would
        # silently overwrite every transition stored so far.
        self.states.append(np.array(state, dtype=np.float32))
        self.actions.append(action)
        self.log_probs.append(float(log_prob))
        self.rewards.append(float(reward))
        self.next_states.append(np.array(next_state, dtype=np.float32))
        self.dones.append(float(done))

    def get_all(self):
        """Return the whole trajectory as numpy arrays.

        Returns:
            ``(states, actions, log_probs, rewards, next_states, dones)``;
            actions are ``int64``, every other array is ``float32``.
        """
        return (
            np.array(self.states, dtype=np.float32),
            np.array(self.actions, dtype=np.int64),
            np.array(self.log_probs, dtype=np.float32),
            np.array(self.rewards, dtype=np.float32),
            np.array(self.next_states, dtype=np.float32),
            np.array(self.dones, dtype=np.float32),
        )

    def clear(self):
        """Drop every stored transition (the lists are emptied in place)."""
        self.states.clear()
        self.actions.clear()
        self.log_probs.clear()
        self.rewards.clear()
        self.next_states.clear()
        self.dones.clear()

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.states)

    @property
    def total_reward(self):
        """Sum of all rewards currently stored."""
        return sum(self.rewards)