CSE5100H2/hw2/agent.py

import os
import torch
import torch.optim as optim
from copy import deepcopy
from model import QNetwork, DuelingQNetwork, NoisyQNetwork
from gymnasium.wrappers import TimeLimit


class DQNAgent:
    def __init__(self, state_size, action_size, cfg, device='cuda'):
        self.device = device
        self.use_double = cfg.use_double
        self.use_dueling = cfg.use_dueling
        self.use_noisy = cfg.use_noisy
        self.noisy_sigma = cfg.noisy_sigma
        self.target_update_interval = cfg.target_update_interval

        # Pick the network architecture; if both flags are set, noisy takes precedence over dueling.
        q_model = QNetwork
        if self.use_dueling:
            q_model = DuelingQNetwork
        if self.use_noisy:
            q_model = NoisyQNetwork
            self.q_net = q_model(state_size, action_size, cfg.hidden_size,
                                 cfg.activation, sigma_init=cfg.noisy_sigma).to(self.device)
        else:
            self.q_net = q_model(state_size, action_size, cfg.hidden_size,
                                 cfg.activation).to(self.device)
        self.target_net = deepcopy(self.q_net).to(self.device)
        self.optimizer = optim.AdamW(self.q_net.parameters(), lr=cfg.lr)
        self.tau = cfg.tau
        # Update the gamma used in the Bellman target for n-step DQN.
        self.gamma = cfg.gamma ** cfg.nstep
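        # n-step note: the replay buffer presumably hands back n-step accumulated
        # rewards, so only the bootstrap term in get_Q_target needs gamma ** nstep.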

    def soft_update(self, target, source):
        """
        Soft update the target network using the source network.
        """
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_((1 - self.tau) * target_param.data + self.tau * source_param.data)
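        # Polyak (soft) averaging: the target network tracks the online network at
        # rate tau; update() triggers this every target_update_interval steps.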

    @torch.no_grad()
    def get_action(self, state):
        """
        Get the greedy action for the current state from the online Q-network.
        """
        ############################
        # YOUR IMPLEMENTATION HERE #
        # Single (unbatched) state: cast to float32 so it matches the network weights.
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        action = torch.argmax(self.q_net(state_tensor), dim=0).cpu().numpy()
        ############################
        return action
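        # Note: this is purely greedy; exploration presumably comes either from an
        # epsilon-greedy schedule in the training loop or, when use_noisy is set,
        # from the parameter noise inside NoisyQNetwork.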

    @torch.no_grad()
    def get_Q_target(self, state, action, reward, done, next_state) -> torch.Tensor:
        """
        Get the target Q value according to the Bellman equation.
        """
        if self.use_double:
            # YOUR IMPLEMENTATION HERE
            reward_tensor = reward.to(self.device)
            next_q_tensor = self.target_net(next_state.to(self.device))
            # Double DQN: the online network selects the next action,
            # the target network evaluates it.
            next_action = torch.argmax(self.q_net(next_state.to(self.device)), dim=1).unsqueeze(1)
            next_q = torch.gather(next_q_tensor, dim=1, index=next_action).squeeze(1)
            q_target = reward_tensor + (1 - done.to(self.device)) * self.gamma * next_q
            return q_target
        else:
            # YOUR IMPLEMENTATION HERE
            reward_tensor = reward.to(self.device)
            # Vanilla DQN: bootstrap from the target network's maximum next-state Q value.
            next_q_tensor = self.target_net(next_state.to(self.device))
            next_q = torch.max(next_q_tensor, dim=1).values
            q_target = reward_tensor + (1 - done.to(self.device)) * self.gamma * next_q
            return q_target
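        # Both branches form the n-step target r + self.gamma * (1 - done) * Q_next,
        # where self.gamma already equals cfg.gamma ** cfg.nstep. Double DQN decouples
        # action selection (online net) from evaluation (target net), which reduces
        # the overestimation bias of taking a plain max over the target network.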

    def get_Q(self, state, action, use_double_net=False) -> torch.Tensor:
        """
        Get the Q value of the current state and action.
        """
        ############################
        # YOUR IMPLEMENTATION HERE #
        # Batched states: evaluate the online network, then pick out each row's
        # Q value for the action that was actually taken.
        q_tensor = self.q_net(state.to(self.device))
        # gather expects int64 indices with the same number of dims as the input.
        action_idx = action.to(self.device).long().view(-1, 1)
        q = q_tensor.gather(1, action_idx).squeeze(1)
        ############################
        return q

    def update(self, batch, step, weights=None):
        state, action, reward, next_state, done = batch
        Q_target = self.get_Q_target(state, action, reward, done, next_state)
        Q = self.get_Q(state, action)

        if weights is None:
            weights = torch.ones_like(Q).to(self.device)

        # Per-sample TD error (detached) for the caller; weighted MSE loss for training.
        td_error = torch.abs(Q - Q_target).detach()
        loss = torch.mean((Q - Q_target) ** 2 * weights)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Softly update the target network every target_update_interval steps.
        if step % self.target_update_interval == 0:
            with torch.no_grad():
                self.soft_update(self.target_net, self.q_net)

        return loss.item(), td_error, Q.mean().item()
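
    # If a prioritized replay buffer is used, the returned td_error presumably
    # feeds its priority updates and `weights` are its importance-sampling
    # corrections; with a uniform buffer the weights default to ones.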

    def save(self, name):
        os.makedirs('models', exist_ok=True)
        torch.save(self.q_net.state_dict(), os.path.join('models', name))

    def load(self, name='best.pt'):
        self.q_net.load_state_dict(torch.load(os.path.join('models', name), map_location=self.device))

    def __repr__(self) -> str:
        use_double = 'Double' if self.use_double else ''
        use_dueling = 'Dueling' if self.use_dueling else ''
        use_noisy = 'Noisy' if self.use_noisy else ''
        prefix = 'Normal' if not (self.use_double or self.use_dueling or self.use_noisy) else ''
        suffix = f' with noisy sigma={self.noisy_sigma}' if self.use_noisy else ''
        return use_double + use_dueling + use_noisy + prefix + 'QNetwork' + suffix
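

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the assignment scaffold). The config
# values, the 'relu' activation string, and the CartPole environment below are
# assumptions made only to show how the pieces above fit together; the replay
# buffer and training loop live elsewhere in the homework.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import gymnasium as gym
    from types import SimpleNamespace

    cfg = SimpleNamespace(use_double=True, use_dueling=False, use_noisy=False,
                          noisy_sigma=0.5, target_update_interval=1,
                          hidden_size=128, activation='relu',
                          lr=3e-4, tau=0.005, gamma=0.99, nstep=3)
    env = gym.make('CartPole-v1')
    agent = DQNAgent(env.observation_space.shape[0], env.action_space.n, cfg,
                     device='cuda' if torch.cuda.is_available() else 'cpu')
    print(agent)  # e.g. "DoubleQNetwork"

    state, _ = env.reset()
    action = agent.get_action(state)  # greedy action for a single state
    # Training would then sample (state, action, reward, next_state, done)
    # batches from a replay buffer and call agent.update(batch, step).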