updates
This commit is contained in:
53
hw2/agent.py
53
hw2/agent.py
@@ -2,7 +2,7 @@ import os
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from copy import deepcopy
|
||||
from model import QNetwork, DuelingQNetwork
|
||||
from model import QNetwork, DuelingQNetwork, NoisyQNetwork
|
||||
from gymnasium.wrappers import TimeLimit
|
||||
|
||||
class DQNAgent:
|
||||
@@ -10,10 +10,17 @@ class DQNAgent:
|
||||
self.device = device
|
||||
self.use_double = cfg.use_double
|
||||
self.use_dueling = cfg.use_dueling
|
||||
self.use_noisy = cfg.use_noisy
|
||||
self.noisy_sigma = cfg.noisy_sigma
|
||||
self.target_update_interval = cfg.target_update_interval
|
||||
q_model = DuelingQNetwork if self.use_dueling else QNetwork
|
||||
|
||||
self.q_net = q_model(state_size, action_size, cfg.hidden_size, cfg.activation).to(self.device)
|
||||
q_model = QNetwork
|
||||
if self.use_dueling:
|
||||
q_model = DuelingQNetwork
|
||||
if self.use_noisy:
|
||||
q_model = NoisyQNetwork
|
||||
self.q_net = q_model(state_size, action_size, cfg.hidden_size, cfg.activation, sigma_init=cfg.noisy_sigma).to(self.device)
|
||||
else:
|
||||
self.q_net = q_model(state_size, action_size, cfg.hidden_size, cfg.activation).to(self.device)
|
||||
self.target_net = deepcopy(self.q_net).to(self.device)
|
||||
self.optimizer = optim.AdamW(self.q_net.parameters(), lr=cfg.lr)
|
||||
|
||||
@@ -51,12 +58,14 @@ class DQNAgent:
|
||||
if self.use_double:
|
||||
# YOUR IMPLEMENTATION HERE
|
||||
reward_tensor = reward.to(self.device)
|
||||
# update from batch states via q_net
|
||||
next_q_tensor = self.q_net(next_state.to(self.device))
|
||||
next_q_tensor = self.target_net(next_state.to(self.device))
|
||||
next_action = torch.argmax(self.q_net(next_state.to(self.device)), dim=1).unsqueeze(1)
|
||||
# print(next_q_tensor.shape, next_action.shape)
|
||||
# return the max Q value
|
||||
next_q = torch.max(next_q_tensor, dim=1).values
|
||||
next_q = torch.gather(next_q_tensor, dim=1, index=next_action).squeeze(1)
|
||||
q_target = reward_tensor + (1-done.to(self.device)) * self.gamma * next_q
|
||||
return q_target
|
||||
|
||||
else:
|
||||
# YOUR IMPLEMENTATION HERE
|
||||
reward_tensor = reward.to(self.device)
|
||||
@@ -73,22 +82,14 @@ class DQNAgent:
|
||||
"""
|
||||
############################
|
||||
# YOUR IMPLEMENTATION HERE #
|
||||
if use_double_net:
|
||||
# get from target net
|
||||
q_tensor = self.target_net(state.to(self.device))
|
||||
action_idx = action.squeeze(1).to(dtype=torch.int32).to(self.device)
|
||||
# select corresponding action, do not use index_select... That don't works
|
||||
q = q_tensor.gather(1, action_idx.unsqueeze(1)).squeeze(1)
|
||||
return q
|
||||
else:
|
||||
# elegant python move by Jack Wu. Fantastic...
|
||||
# q= self.q_net(state.to(self.device))[:, action.int()]
|
||||
# update from batch states
|
||||
q_tensor = self.q_net(state.to(self.device))
|
||||
action_idx = action.squeeze(1).to(dtype=torch.int32).to(self.device)
|
||||
# select corresponding action, do not use index_select... That don't works
|
||||
q = q_tensor.gather(1, action_idx.unsqueeze(1)).squeeze(1)
|
||||
return q
|
||||
# elegant python move by Jack Wu. Fantastic...
|
||||
# q= self.q_net(state.to(self.device))[:, action.int()]
|
||||
# update from batch states
|
||||
q_tensor = self.q_net(state.to(self.device))
|
||||
action_idx = action.squeeze(1).to(dtype=torch.int32).to(self.device)
|
||||
# select corresponding action, do not use index_select... That don't works
|
||||
q = q_tensor.gather(1, action_idx.unsqueeze(1)).squeeze(1)
|
||||
return q
|
||||
############################
|
||||
|
||||
def update(self, batch, step, weights=None):
|
||||
@@ -123,5 +124,7 @@ class DQNAgent:
|
||||
def __repr__(self) -> str:
|
||||
use_double = 'Double' if self.use_double else ''
|
||||
use_dueling = 'Dueling' if self.use_dueling else ''
|
||||
prefix = 'Normal' if not self.use_double and not self.use_dueling else ''
|
||||
return use_double + use_dueling + prefix + 'QNetwork'
|
||||
use_noisy = 'Noisy' if self.use_noisy else ''
|
||||
prefix = 'Normal' if not self.use_double and not self.use_dueling and not self.use_noisy else ''
|
||||
suffix = f'with noisy sigma={self.noisy_sigma}' if self.use_noisy else ''
|
||||
return use_double + use_dueling + use_noisy+ prefix + 'QNetwork' + suffix
|
||||
|
||||
Reference in New Issue
Block a user