企业网站内容如何更新,软件开发公司网站模板,小程序多少钱一年,怎么创建公司强化学习原理python篇05——DQN DQN 算法定义DQN网络初始化环境开始训练可视化结果 本章全篇参考赵世钰老师的教材 Mathmatical-Foundation-of-Reinforcement-Learning Deep Q-learning 章节#xff0c;请各位结合阅读#xff0c;本合集只专注于数学概念的代码实现。
DQN 算… 强化学习原理python篇05——DQN DQN 算法定义DQN网络初始化环境开始训练可视化结果 本章全篇参考赵世钰老师的教材 Mathmatical-Foundation-of-Reinforcement-Learning Deep Q-learning 章节请各位结合阅读本合集只专注于数学概念的代码实现。
DQN 算法
1使用随机权重 w ← 1.0 w←1.0 w←1.0初始化目标网络 Q ( s , a , w ) Q(s, a, w) Q(s,a,w)和网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w) Q Q Q和 Q ^ \hat Q Q^相同清空回放缓冲区。
2以概率ε选择一个随机动作a否则 a a r g m a x Q ( s , a , w ) aargmaxQ(s,a,w) aargmaxQ(s,a,w)。
3在模拟器中执行动作a观察奖励r和下一个状态s’。
4将转移过程(s, a, r, s’)存储在回放缓冲区中。
5从回放缓冲区中采样一个随机的小批量转移过程。
6对于回放缓冲区中的每个转移过程如果片段在此步结束则计算目标 y r yr yr否则计算 y r γ m a x Q ^ ( s , a , w ) yr\gamma max \hat Q(s, a, w) yrγmaxQ^(s,a,w) 。
7计算损失 L ( Q ( s , a , w ) – y ) 2 L(Q(s, a, w)–y)^2 L(Q(s,a,w)–y)2。
8固定网络 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w)不变通过最小化模型参数的损失使用SGD算法更新 Q ( s , a ) Q(s, a) Q(s,a)。
9每N步将权重从目标网络 Q Q Q复制到 Q ^ ( s , a , w ) \hat Q(s, a, w) Q^(s,a,w) 。
10从步骤2开始重复直到收敛为止。
定义DQN网络
import collections
import copy
import random
from collections import defaultdict
import math
import gym
import gym.spaces
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gym.envs.toy_text import frozen_lake
from torch.utils.tensorboard import SummaryWriterclass Net(nn.Module):def __init__(self, obs_size, hidden_size, q_table_size):super(Net, self).__init__()self.net nn.Sequential(# 输入为状态样本为1*nnn.Linear(obs_size, hidden_size),nn.ReLU(),# nn.Linear(hidden_size, hidden_size),# nn.ReLU(),nn.Linear(hidden_size, q_table_size),)def forward(self, state):return self.net(state)class DQN:def __init__(self, env, tgt_net, net):self.env envself.tgt_net tgt_netself.net netdef generate_train_data(self, batch_size, epsilon):state, _ env.reset()train_data []while len(train_data)batch_size*2:q_table_tgt self.tgt_net(torch.Tensor(state)).detach()if np.random.uniform(0, 1, 1) epsilon:action self.env.action_space.sample()else:action int(torch.argmax(q_table_tgt))new_state, reward,terminated, truncted, info env.step(action)train_data.append([state, action, reward, new_state, terminated])state new_stateif terminated:state, _ env.reset()continuerandom.shuffle(train_data) return train_data[:batch_size]def calculate_y_hat_and_y(self, batch):# 6对于回放缓冲区中的每个转移过程如果片段在此步结束则计算目标$yr$否则计算$yr\gamma max \hat Q(s, a, w)$ 。y []state_space []action_space []for state, action, reward, new_state, terminated in batch:# y值if terminated:y.append(reward)else:# 下一步的 qtable 的最大值q_table_net self.net(torch.Tensor(np.array([new_state]))).detach()y.append(reward gamma * float(torch.max(q_table_net)))# y hat的值state_space.append(state)action_space.append(action)idx [list(range(len(action_space))), action_space]y_hat self.tgt_net(torch.Tensor(np.array(state_space)))[idx]return y_hat, torch.tensor(y)def update_net_parameters(self, updateTrue):self.net.load_state_dict(self.tgt_net.state_dict())
初始化环境 # 初始化环境
env gym.make(CartPole-v1)
# env DiscreteOneHotWrapper(env)hidden_num 64
# 定义网络
net Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
tgt_net Net(env.observation_space.shape[0],hidden_num, env.action_space.n)
dqn DQN(envenv, netnet, tgt_nettgt_net)# 初始化参数
# dqn.init_net_and_target_net_weight()# 定义优化器
opt optim.Adam(tgt_net.parameters(), lr0.001)# 定义损失函数
loss nn.MSELoss()# 记录训练过程
# writer SummaryWriter(log_dirlogs/DQN, commentDQN)
开始训练
gamma 0.8
for i in range(10000):batch dqn.generate_train_data(256, 0.8)y_hat, y dqn.calculate_y_hat_and_y(batch)opt.zero_grad()l loss(y_hat, y)l.backward()opt.step()print(MSE: {}.format(l.item()))if i % 5 0:dqn.update_net_parameters(updateTrue)输出
MSE: 0.027348674833774567
MSE: 0.1803671419620514
MSE: 0.06523636728525162
MSE: 0.08363766968250275
MSE: 0.062360599637031555
MSE: 0.004909628536552191
MSE: 0.05730309337377548
MSE: 0.03543371334671974
MSE: 0.08458714932203293可视化结果
env gym.make(CartPole-v1, render_mode human)
env gym.wrappers.RecordVideo(env, video_foldervideo)state, info env.reset()
total_rewards 0while True:q_table_state dqn.tgt_net(torch.Tensor(state)).detach()# if np.random.uniform(0, 1, 1) 0.9:# action env.action_space.sample()# else:action int(torch.argmax(q_table_state))state, reward, terminated, truncted, info env.step(action)if terminated:break