done?
This commit is contained in:
53
hw2/model.py
Normal file
53
hw2/model.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from hydra.utils import instantiate
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class QNetwork(nn.Module):
|
||||
def __init__(self, state_size, action_size, hidden_size, activation):
|
||||
super(QNetwork, self).__init__()
|
||||
self.q_head = nn.Sequential(
|
||||
nn.Linear(state_size, hidden_size),
|
||||
instantiate(activation),
|
||||
nn.Linear(hidden_size, hidden_size),
|
||||
instantiate(activation),
|
||||
nn.Linear(hidden_size, action_size)
|
||||
)
|
||||
|
||||
def forward(self, state):
|
||||
Qs = self.q_head(state)
|
||||
return Qs
|
||||
|
||||
|
||||
class DuelingQNetwork(nn.Module):
|
||||
def __init__(self, state_size, action_size, hidden_size, activation):
|
||||
super(DuelingQNetwork, self).__init__()
|
||||
self.feature_layer = nn.Sequential(
|
||||
nn.Linear(state_size, hidden_size),
|
||||
instantiate(activation),
|
||||
)
|
||||
self.value_head = nn.Sequential(
|
||||
nn.Linear(hidden_size, hidden_size),
|
||||
instantiate(activation),
|
||||
nn.Linear(hidden_size, 1)
|
||||
)
|
||||
self.advantage_head = nn.Sequential(
|
||||
nn.Linear(hidden_size, hidden_size),
|
||||
instantiate(activation),
|
||||
nn.Linear(hidden_size, action_size)
|
||||
)
|
||||
|
||||
def forward(self, state):
|
||||
"""
|
||||
Get the Q value of the current state and action using dueling network
|
||||
"""
|
||||
############################
|
||||
# YOUR IMPLEMENTATION HERE #
|
||||
|
||||
# using equation (7) on https://arxiv.org/pdf/1511.06581
|
||||
Qs=self.value_head(self.feature_layer(state))+self.advantage_head(self.feature_layer(state))
|
||||
############################
|
||||
return Qs
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user