part2 complete
hw3/bash/2-2-experiments.sh (new file)
@@ -0,0 +1,2 @@
+python run.py --env_name CartPole-v1 -n 100 -b 5000 -rtg --exp_name cartpole_rtg_no_baseline
+python run.py --env_name CartPole-v1 -n 100 -b 5000 -rtg -na --use_baseline --exp_name cartpole_na_rtg_baseline
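The -rtg flag makes the policy gradient weight each log-probability by the reward-to-go from that timestep rather than by the whole-trajectory return. A minimal sketch of that computation, assuming a discount gamma (the function name and signature are illustrative, not taken from run.py):

import numpy as np

def discounted_reward_to_go(rewards: np.ndarray, gamma: float = 1.0) -> np.ndarray:
    """For each timestep t, return sum over t' >= t of gamma^(t'-t) * r_{t'}."""
    rtg = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    # accumulate from the end of the trajectory backwards
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

# e.g. discounted_reward_to_go(np.array([1.0, 1.0, 1.0]), gamma=0.9)
# -> [2.71, 1.9, 1.0]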
hw3/bash/2-3-experiments.sh (new file)
@@ -0,0 +1,8 @@
+python run.py --env_name CartPole-v1 -n 100 -b 5000 -rtg --exp_name cartpole_rtg_no_baseline
+python run.py --env_name CartPole-v1 -n 100 -b 5000 -rtg --use_baseline --exp_name cartpole_rtg_baseline
+# with -na (advantage normalization)
+python run.py --env_name CartPole-v1 -n 100 -b 5000 -rtg -na --use_baseline --exp_name cartpole_na_rtg_baseline
+# vary bgs (baseline_gradient_steps, default 5) and blr (baseline_learning_rate, default 5e-3)
+python run.py --env_name CartPole-v1 -n 100 -b 5000 -rtg --use_baseline --baseline_gradient_steps 3 --exp_name cartpole_rtg_baseline_bgs3
+python run.py --env_name CartPole-v1 -n 100 -b 5000 -rtg --use_baseline --baseline_learning_rate 0.001 --exp_name cartpole_rtg_baseline_blr1e-3
+python run.py --env_name CartPole-v1 -n 100 -b 5000 -rtg --use_baseline --baseline_gradient_steps 3 --baseline_learning_rate 0.001 --exp_name cartpole_rtg_baseline_bgs3_blr1e-3
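The last three runs sweep the two baseline-fitting knobs. A sketch of how --baseline_gradient_steps and --baseline_learning_rate would plausibly be wired into the training loop (the critic network and the update_baseline helper here are stand-ins; the actual plumbing lives in run.py):

import torch
from torch import nn

critic = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 1))  # stand-in value net
optimizer = torch.optim.Adam(critic.parameters(), lr=5e-3)  # --baseline_learning_rate

def update_baseline(obs: torch.Tensor, q_values: torch.Tensor, steps: int = 5) -> float:
    """Take `steps` gradient steps (--baseline_gradient_steps) on the value regression."""
    for _ in range(steps):
        loss = torch.mean(torch.square(q_values - critic(obs).squeeze(dim=-1)))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()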
hw3/bash/2-4-experiments.sh (new file)
@@ -0,0 +1,10 @@
+python run.py --env_name HalfCheetah-v4 -n 100 -b 5000 -na --use_baseline --exp_name halfcheetah_na_baseline
+
+python run.py --env_name HalfCheetah-v4 -n 100 -b 5000 -na --use_baseline --baseline_gradient_steps 3 --exp_name halfcheetah_na_baseline_bgs3
+python run.py --env_name HalfCheetah-v4 -n 100 -b 5000 -na --use_baseline --baseline_learning_rate 0.001 --exp_name halfcheetah_na_baseline_blr1e-3
+# with reward-to-go
+python run.py --env_name HalfCheetah-v4 -n 100 -b 5000 -rtg -na --use_baseline --exp_name halfcheetah_na_rtg_baseline
+python run.py --env_name HalfCheetah-v4 -n 100 -b 5000 -rtg -na --use_baseline --baseline_gradient_steps 3 --exp_name halfcheetah_na_rtg_baseline_bgs3
+python run.py --env_name HalfCheetah-v4 -n 100 -b 5000 -rtg -na --use_baseline --baseline_learning_rate 0.001 --exp_name halfcheetah_na_rtg_baseline_blr1e-3
+# Berkeley parameters
+python run.py --env_name HalfCheetah-v4 -n 100 -b 5000 -rtg --use_baseline --baseline_gradient_steps 5 --baseline_learning_rate 0.01 --exp_name halfcheetah_na_rtg_baseline_bgs5_blr1e-2
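The -na flag normalizes the advantages before the policy update. A minimal sketch of the usual meaning (zero mean, unit standard deviation; the epsilon guard is my addition):

import numpy as np

def normalize_advantages(advantages: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # rescale to zero mean and unit variance; lowers gradient variance
    # at the cost of a state-independent shift and scale
    return (advantages - advantages.mean()) / (advantages.std() + eps)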
17 binary files not shown.
@@ -41,7 +41,9 @@ class ValueCritic(nn.Module):
 
         ############################
         # YOUR IMPLEMENTATION HERE #
-        values = self.network(obs)
+        assert isinstance(obs, torch.Tensor), "obs must be a torch tensor"
+        # squeeze the last dimension to get the values as a 1D tensor
+        values = self.network.forward(obs).squeeze(dim=-1)
         ############################
 
         return values
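The squeeze added here is load-bearing: a value head ends in nn.Linear(..., 1), so the raw network output has shape (batch, 1), and an unsqueezed output silently broadcasts against (batch,) targets in the MSE below. A quick demonstration of the failure mode:

import torch

values_2d = torch.randn(5, 1)  # raw network output: (batch, 1)
targets = torch.randn(5)       # Monte Carlo returns: (batch,)

# without squeeze, broadcasting makes the difference a (5, 5) matrix
assert (values_2d - targets).shape == (5, 5)
# with squeeze(dim=-1) the loss stays elementwise
assert (values_2d.squeeze(dim=-1) - targets).shape == (5,)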
@@ -56,7 +58,8 @@ class ValueCritic(nn.Module):
         ############################
         # YOUR IMPLEMENTATION HERE #
         values = self.forward(obs)
-        loss = F.mse_loss(values, q_values)
+        # use mean squared error loss
+        loss = torch.mean(torch.square(q_values - values))
 
         self.optimizer.zero_grad()
         loss.backward()
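The hand-written loss is numerically the same as the F.mse_loss it replaces (both default to a mean reduction), so the swap is stylistic. A quick check:

import torch
import torch.nn.functional as F

values, q_values = torch.randn(10), torch.randn(10)
explicit = torch.mean(torch.square(q_values - values))
builtin = F.mse_loss(values, q_values)  # reduction='mean' by default
assert torch.allclose(explicit, builtin)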
@@ -85,7 +85,7 @@ class PGAgent(nn.Module):
         critic_info: dict = None
         ############################
         # YOUR IMPLEMENTATION HERE #
-
+        critic_info = self.critic.update(obs, q_values)
         ############################
 
         info.update(critic_info)
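info.update(critic_info) only works if the critic's update returns a plain dict of logging scalars. A sketch of the contract this hunk assumes (the "Baseline Loss" key is illustrative; the real key is whatever the starter code logs):

import torch
from torch import nn

class ValueCritic(nn.Module):
    # network/optimizer construction omitted for brevity
    def update(self, obs: torch.Tensor, q_values: torch.Tensor) -> dict:
        values = self.forward(obs)
        loss = torch.mean(torch.square(q_values - values))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # scalars returned here land in the agent's log via info.update(critic_info)
        return {"Baseline Loss": loss.item()}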
@@ -138,7 +138,9 @@ class PGAgent(nn.Module):
         advantages = None
         ############################
         # YOUR IMPLEMENTATION HERE #
-
+        values = self.critic.forward(ptu.from_numpy(obs))
+        q_values = q_values - ptu.to_numpy(values)
+        advantages = q_values.copy()
         ############################
         assert values.shape == q_values.shape
 
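One nit on the new lines: q_values - ptu.to_numpy(values) already allocates a fresh array (and rebinds the local name rather than mutating the caller's data), so the trailing .copy() duplicates it a second time. A self-contained illustration:

import numpy as np

q_values = np.array([2.0, 3.0])   # Monte Carlo returns (illustrative)
values = np.array([0.5, 1.0])     # critic predictions, already converted to numpy

diff = q_values - values          # subtraction allocates a new array
assert diff is not q_values       # the original array is untouched
advantages = diff                 # so a further .copy() is redundant here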