# DQN agent module (source-viewer listing: 131 lines, 5.2 KiB, Python)
import os
|
|
import torch
|
|
import torch.optim as optim
|
|
from copy import deepcopy
|
|
from model import QNetwork, DuelingQNetwork, NoisyQNetwork
|
|
from gymnasium.wrappers import TimeLimit
|
|
|
|
class DQNAgent:
    """DQN agent with optional Double-DQN targets, Dueling and Noisy Q-networks.

    Owns an online Q-network (``q_net``), a target network (``target_net``,
    initialised as a deep copy of ``q_net``) and an AdamW optimiser over the
    online network's parameters.
    """

    def __init__(self, state_size, action_size, cfg, device='cuda'):
        """Build the networks and the optimiser from ``cfg``.

        Args:
            state_size: Observation size passed to the Q-network constructor.
            action_size: Number of discrete actions passed to the constructor.
            cfg: Config object. Reads ``use_double``, ``use_dueling``,
                ``use_noisy``, ``noisy_sigma``, ``target_update_interval``,
                ``hidden_size``, ``activation``, ``lr``, ``tau``, ``gamma``
                and ``nstep``.
            device: Torch device the networks are moved to.
        """
        self.device = device
        self.use_double = cfg.use_double
        self.use_dueling = cfg.use_dueling
        self.use_noisy = cfg.use_noisy
        self.noisy_sigma = cfg.noisy_sigma
        self.target_update_interval = cfg.target_update_interval

        # Model selection: Noisy overrides Dueling when both flags are set.
        q_model = QNetwork
        if self.use_dueling:
            q_model = DuelingQNetwork
        if self.use_noisy:
            q_model = NoisyQNetwork
            self.q_net = q_model(state_size, action_size, cfg.hidden_size, cfg.activation,
                                 sigma_init=cfg.noisy_sigma).to(self.device)
        else:
            self.q_net = q_model(state_size, action_size, cfg.hidden_size, cfg.activation).to(self.device)
        self.target_net = deepcopy(self.q_net).to(self.device)

        self.optimizer = optim.AdamW(self.q_net.parameters(), lr=cfg.lr)

        self.tau = cfg.tau
        # n-step DQN bootstraps after nstep transitions, so the discount applied
        # to the bootstrapped value is gamma ** nstep. (Presumably the replay
        # buffer already supplies the discounted n-step return as ``reward`` --
        # confirm against the buffer implementation.)
        self.gamma = cfg.gamma ** cfg.nstep

    def soft_update(self, target, source):
        """Polyak-average ``source`` into ``target`` in place.

        Each target parameter becomes ``(1 - tau) * target + tau * source``.
        Only parameters are blended; module buffers are not touched.
        """
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_((1 - self.tau) * target_param.data + self.tau * source_param.data)

    @torch.no_grad()
    def get_action(self, state):
        """Return the greedy action(s) under the online Q-network.

        Args:
            state: A single state (1-D) or a batch of states (2-D); anything
                ``torch.as_tensor`` accepts (list, numpy array, tensor).

        Returns:
            numpy array of action indices (argmax over the last dimension of
            the Q-values): 0-d for a single state, shape ``(batch,)`` for a batch,
            assuming the network maps the last input dim to per-action values.
        """
        # Cast to the network's own dtype so float64 numpy observations do not
        # cause a dtype mismatch inside the network; as_tensor also avoids the
        # copy-construct warning when ``state`` is already a tensor.
        param_dtype = next(self.q_net.parameters()).dtype
        state_tensor = torch.as_tensor(state, dtype=param_dtype, device=self.device)
        # argmax over the last (action) dim works for single and batched states.
        action = torch.argmax(self.q_net(state_tensor), dim=-1)
        return action.cpu().numpy()

    @torch.no_grad()
    def get_Q_target(self, state, action, reward, done, next_state) -> torch.Tensor:
        """Compute the bootstrapped TD target ``r + (1 - done) * gamma * Q'(s', a')``.

        ``Q'`` is always the target network. ``a'`` is chosen by the target
        network itself (vanilla DQN) or by the online network (Double DQN).

        Args:
            state: Unused; kept for interface compatibility.
            action: Unused; kept for interface compatibility.
            reward: Rewards, shape ``(batch,)`` or ``(batch, 1)``.
            done: Terminal flags (bool, int or float), same shapes as ``reward``.
            next_state: Batch of next states.

        Returns:
            Tensor of shape ``(batch,)`` on ``self.device``.
        """
        next_state = next_state.to(self.device)
        next_q_all = self.target_net(next_state)

        if self.use_double:
            # Double DQN: the online net selects a', the target net evaluates it.
            next_action = torch.argmax(self.q_net(next_state), dim=1, keepdim=True)
            next_q = next_q_all.gather(1, next_action).squeeze(1)
        else:
            next_q = next_q_all.max(dim=1).values

        # Align dtype/shape with next_q: a bool ``done`` cannot be subtracted
        # from 1, and a (batch, 1) tensor would broadcast to (batch, batch).
        reward = reward.to(device=self.device, dtype=next_q.dtype).reshape(next_q.shape)
        not_done = 1 - done.to(device=self.device, dtype=next_q.dtype).reshape(next_q.shape)
        return reward + not_done * self.gamma * next_q

    def get_Q(self, state, action, use_double_net=False) -> torch.Tensor:
        """Return ``Q(s, a)`` from the online network for a batch.

        Args:
            state: Batch of states.
            action: Action indices, shape ``(batch,)`` or ``(batch, 1)``; may be
                float or int. Converted to int64, the index dtype ``gather``
                requires.
            use_double_net: Unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``(batch,)`` (differentiable w.r.t. ``q_net``).
        """
        q_all = self.q_net(state.to(self.device))
        action_idx = action.to(device=self.device).long().reshape(-1, 1)
        return q_all.gather(1, action_idx).squeeze(1)

    def update(self, batch, step, weights=None):
        """Take one optimiser step on the weighted squared TD error.

        Args:
            batch: Tuple ``(state, action, reward, next_state, done)``.
            step: Step counter; the target network is soft-updated whenever
                ``step % target_update_interval == 0``.
            weights: Optional per-sample loss weights (tensor, array or list of
                ``batch`` values); defaults to all ones.

        Returns:
            ``(loss, td_error, mean_q)``: ``loss`` and ``mean_q`` are Python
            floats, ``td_error`` is a detached tensor of ``|Q - Q_target|``.
        """
        state, action, reward, next_state, done = batch

        Q_target = self.get_Q_target(state, action, reward, done, next_state)
        Q = self.get_Q(state, action)

        if weights is None:
            weights = torch.ones_like(Q)
        else:
            # Move caller weights to Q's device/dtype and flatten (batch, 1) to
            # (batch,) so the product below stays element-wise.
            weights = torch.as_tensor(weights, dtype=Q.dtype, device=Q.device).reshape(Q.shape)

        td = Q - Q_target
        td_error = td.abs().detach()
        loss = torch.mean(td.pow(2) * weights)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if step % self.target_update_interval == 0:
            with torch.no_grad():
                self.soft_update(self.target_net, self.q_net)

        return loss.item(), td_error, Q.mean().item()

    def save(self, name):
        """Save the online network's ``state_dict`` to ``models/<name>`` (directory created if missing)."""
        os.makedirs('models', exist_ok=True)
        torch.save(self.q_net.state_dict(), os.path.join('models', name))

    def load(self, name='best.pt'):
        """Load ``models/<name>`` into the online network and sync the target network.

        ``map_location`` lets a checkpoint saved on one device (e.g. CUDA) be
        loaded on ``self.device``; the target network is synced so it does not
        keep stale weights if training resumes.
        """
        state_dict = torch.load(os.path.join('models', name), map_location=self.device)
        self.q_net.load_state_dict(state_dict)
        self.target_net.load_state_dict(self.q_net.state_dict())

    def __repr__(self) -> str:
        """Return e.g. ``'DoubleDuelingQNetwork'`` or ``'NoisyQNetwork with noisy sigma=0.5'``."""
        variants = (
            (self.use_double, 'Double'),
            (self.use_dueling, 'Dueling'),
            (self.use_noisy, 'Noisy'),
        )
        prefix = ''.join(label for enabled, label in variants if enabled) or 'Normal'
        suffix = f' with noisy sigma={self.noisy_sigma}' if self.use_noisy else ''
        return prefix + 'QNetwork' + suffix