import os
import torch
import torch.optim as optim
from copy import deepcopy
from model import QNetwork, DuelingQNetwork, NoisyQNetwork
from gymnasium.wrappers import TimeLimit


class DQNAgent:
    """DQN agent with optional Double, Dueling and Noisy variants.

    ``q_net`` is the online network trained with AdamW on a weighted squared
    TD error. ``target_net`` starts as a deep copy of it and is moved towards
    it with a soft update every ``cfg.target_update_interval`` calls to
    :meth:`update`.
    """

    def __init__(self, state_size, action_size, cfg, device='cuda'):
        self.device = device
        self.use_double = cfg.use_double
        self.use_dueling = cfg.use_dueling
        self.use_noisy = cfg.use_noisy
        self.noisy_sigma = cfg.noisy_sigma
        self.target_update_interval = cfg.target_update_interval

        # NOTE(review): when both use_dueling and use_noisy are set, the noisy
        # network replaces the dueling one (and __repr__ still prints both) --
        # confirm this is intended.
        q_model = QNetwork
        if self.use_dueling:
            q_model = DuelingQNetwork
        if self.use_noisy:
            q_model = NoisyQNetwork
            self.q_net = q_model(state_size, action_size, cfg.hidden_size, cfg.activation,
                                 sigma_init=cfg.noisy_sigma).to(self.device)
        else:
            self.q_net = q_model(state_size, action_size, cfg.hidden_size,
                                 cfg.activation).to(self.device)
        self.target_net = deepcopy(self.q_net).to(self.device)
        self.optimizer = optim.AdamW(self.q_net.parameters(), lr=cfg.lr)
        self.tau = cfg.tau
        # n-step DQN: the bootstrapped term is discounted by gamma ** nstep.
        self.gamma = cfg.gamma ** cfg.nstep

    def soft_update(self, target, source):
        """Blend ``source`` into ``target`` in place: t <- (1 - tau) * t + tau * s."""
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_((1 - self.tau) * target_param.data + self.tau * source_param.data)

    @torch.no_grad()
    def get_action(self, state):
        """Return the greedy action (argmax of the online Q values) for ``state``.

        ``state`` may be a NumPy array, a list or a tensor. It is converted to
        the dtype and device of the network parameters. For a single unbatched
        state the result is a 0-d NumPy integer array; for a batch of states the
        argmax is taken along the last (action) axis, one action per row.
        """
        param_dtype = next(self.q_net.parameters()).dtype
        state_t = torch.as_tensor(state, dtype=param_dtype, device=self.device)
        q_values = self.q_net(state_t)
        return torch.argmax(q_values, dim=-1).cpu().numpy()

    @torch.no_grad()
    def get_Q_target(self, state, action, reward, done, next_state) -> torch.Tensor:
        """Compute the TD target ``reward + (1 - done) * gamma * Q_target(s', a*)``.

        Plain DQN takes ``a*`` as the argmax of the target network. Double DQN
        selects ``a*`` with the online network and evaluates it with the target
        network. ``state`` and ``action`` are unused; they are kept for
        interface compatibility.
        """
        next_state = next_state.to(self.device)
        next_q_all = self.target_net(next_state)
        if self.use_double:
            next_action = torch.argmax(self.q_net(next_state), dim=1, keepdim=True)
            next_q = next_q_all.gather(1, next_action).squeeze(1)
        else:
            next_q = next_q_all.max(dim=1).values
        # Cast first: boolean done masks do not support `1 - done`.
        not_done = 1 - done.to(device=self.device, dtype=next_q.dtype)
        return reward.to(self.device) + not_done * self.gamma * next_q

    def get_Q(self, state, action, use_double_net=False) -> torch.Tensor:
        """Return Q(s, a) from the online network for a batch of transitions.

        ``action`` may have shape (B,) or (B, 1). It is cast to int64 because
        ``torch.gather`` expects a LongTensor index. ``use_double_net`` is
        unused; it is kept for interface compatibility.
        """
        q_all = self.q_net(state.to(self.device))
        action_idx = action.reshape(-1, 1).to(device=self.device, dtype=torch.long)
        return q_all.gather(1, action_idx).squeeze(1)

    def update(self, batch, step, weights=None):
        """Take one optimizer step on the weighted squared TD error.

        Args:
            batch: tuple ``(state, action, reward, next_state, done)``.
            step: step counter; the target network is soft-updated whenever
                ``step`` is a multiple of ``target_update_interval``.
            weights: optional per-sample loss weights (defaults to ones).
                Presumably importance-sampling weights from a prioritized
                buffer -- confirm with the caller.

        Returns:
            ``(loss, td_error, mean_q)``: the scalar loss, the detached
            per-sample absolute TD error, and the mean predicted Q value.
        """
        state, action, reward, next_state, done = batch
        Q_target = self.get_Q_target(state, action, reward, done, next_state)
        Q = self.get_Q(state, action)

        if weights is None:
            weights = torch.ones_like(Q)
        else:
            # Caller-provided weights may live on another device (e.g. CPU buffer).
            weights = torch.as_tensor(weights, device=self.device)
        td_error = torch.abs(Q - Q_target).detach()
        loss = torch.mean((Q - Q_target) ** 2 * weights)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if not step % self.target_update_interval:
            with torch.no_grad():
                self.soft_update(self.target_net, self.q_net)

        return loss.item(), td_error, Q.mean().item()

    def save(self, name):
        """Save the online network's state dict to ``models/<name>``."""
        os.makedirs('models', exist_ok=True)
        torch.save(self.q_net.state_dict(), os.path.join('models', name))

    def load(self, name='best.pt'):
        """Load ``models/<name>`` into the online network and sync the target network.

        ``map_location`` lets a checkpoint saved on GPU be loaded on another
        device (e.g. CPU).
        """
        state_dict = torch.load(os.path.join('models', name), map_location=self.device)
        self.q_net.load_state_dict(state_dict)
        # Keep the target network consistent with the restored weights.
        self.target_net.load_state_dict(self.q_net.state_dict())

    def __repr__(self) -> str:
        use_double = 'Double' if self.use_double else ''
        use_dueling = 'Dueling' if self.use_dueling else ''
        use_noisy = 'Noisy' if self.use_noisy else ''
        prefix = 'Normal' if not self.use_double and not self.use_dueling and not self.use_noisy else ''
        suffix = f' with noisy sigma={self.noisy_sigma}' if self.use_noisy else ''
        return use_double + use_dueling + use_noisy + prefix + 'QNetwork' + suffix