Zheyuan Wu
2025-10-12 00:55:07 -05:00
commit 250f763f1f
467 changed files with 19784 additions and 0 deletions

hw2/model.py Normal file

@@ -0,0 +1,53 @@
from hydra.utils import instantiate
import torch
import torch.nn as nn


class QNetwork(nn.Module):
    """Plain MLP that maps a state to one Q value per action."""

    def __init__(self, state_size, action_size, hidden_size, activation):
        super(QNetwork, self).__init__()
        # `activation` is a Hydra config node; instantiate() builds the activation module.
        self.q_head = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            instantiate(activation),
            nn.Linear(hidden_size, hidden_size),
            instantiate(activation),
            nn.Linear(hidden_size, action_size),
        )

    def forward(self, state):
        Qs = self.q_head(state)
        return Qs


class DuelingQNetwork(nn.Module):
    """Dueling architecture: shared features feed separate value and advantage heads."""

    def __init__(self, state_size, action_size, hidden_size, activation):
        super(DuelingQNetwork, self).__init__()
        self.feature_layer = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            instantiate(activation),
        )
        self.value_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            instantiate(activation),
            nn.Linear(hidden_size, 1),
        )
        self.advantage_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            instantiate(activation),
            nn.Linear(hidden_size, action_size),
        )

    def forward(self, state):
        """
        Get the Q value of the current state and action using the dueling network.
        """
        ############################
        # YOUR IMPLEMENTATION HERE #
        # Using equation (7) of https://arxiv.org/pdf/1511.06581:
        #   Q(s, a) = V(s) + A(s, a)
        # Compute the shared features once and feed both heads.
        features = self.feature_layer(state)
        Qs = self.value_head(features) + self.advantage_head(features)
        ############################
        return Qs
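
For reference, a minimal usage sketch under stated assumptions: the `_target_` string, the layer sizes, and the batch shape below are illustrative choices, not values taken from this commit. `instantiate(activation)` resolves a Hydra config node such as this one into an `nn.Module`.

import torch
from omegaconf import OmegaConf
from hw2.model import DuelingQNetwork

# Hypothetical activation config; hydra.utils.instantiate resolves the
# `_target_` path into an nn.ReLU module inside the network.
activation_cfg = OmegaConf.create({"_target_": "torch.nn.ReLU"})

net = DuelingQNetwork(state_size=8, action_size=4, hidden_size=64,
                      activation=activation_cfg)
q_values = net(torch.randn(32, 8))       # shape (32, 4): one Q value per action
greedy_actions = q_values.argmax(dim=1)  # greedy policy over the Q values

Note that the forward pass above follows equation (7), Q(s, a) = V(s) + A(s, a), as the homework comment asks; equation (9) of the same paper instead uses Q(s, a) = V(s) + (A(s, a) - mean over a' of A(s, a')) to make the value and advantage heads identifiable.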