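"""Training and evaluation loop for a value-based agent with (prioritized) experience replay.

Provides plotting of train/eval returns, TD losses and Q-values, a greedy evaluation
helper, and a training loop with epsilon-greedy exploration, periodic evaluation,
checkpointing and video recording.
"""
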
import logging
import random
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
from gymnasium.wrappers import RecordVideo, RecordEpisodeStatistics

from buffer import ReplayBuffer, PrioritizedReplayBuffer
from utils import moving_average, merge_videos, get_epsilon

logger = logging.getLogger(__name__)


def visualize(step, title, train_steps, train_returns, eval_steps, eval_returns, losses, q_values):
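    """Plot smoothed train/eval returns, TD loss and Q-values in one figure and save it to results.png."""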
    train_window, loss_window, q_window = 10, 100, 100
    plt.figure(figsize=(20, 6))

    # plot train and eval returns
    plt.subplot(1, 3, 1)
    plt.title('frame %s. score: %s' % (step, np.mean(train_returns[-train_window:])))
    plt.plot(train_steps[train_window - 1:], moving_average(train_returns, train_window), label='train')
    if len(eval_steps) > 0:
        plt.plot(eval_steps, eval_returns, label='eval')
    plt.legend()
    plt.xlabel('step')

    # plot td losses
    plt.subplot(1, 3, 2)
    plt.title('loss')
    plt.plot(moving_average(losses, loss_window))
    plt.xlabel('step')

    # plot q values
    plt.subplot(1, 3, 3)
    plt.title('q_values')
    plt.plot(moving_average(q_values, q_window))
    plt.xlabel('step')

    plt.suptitle(title, fontsize=16)
    plt.savefig('results.png')
    plt.close()


def eval(env, agent, episodes, seed):
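    """Run `episodes` greedy episodes and return the mean and std of the episodic returns."""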
    returns = []
    for episode in range(episodes):
        state, _ = env.reset(seed=episode + seed)
        done, truncated = False, False
        while not (done or truncated):
            state, _, done, truncated, info = env.step(agent.get_action(state))

        # RecordEpisodeStatistics stores the episodic return in the final step's info dict
        returns.append(info['episode']['r'].item())
    return np.mean(returns), np.std(returns)


def train(cfg, env, agent, buffer, seed):
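    """Train `agent` with epsilon-greedy exploration and updates sampled from `buffer`.

    Periodically evaluates the agent, checkpoints the best and final models, records
    evaluation videos, and returns the mean evaluation return of the best checkpoint.
    """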
    # wrap env to record episode returns
    env = RecordEpisodeStatistics(env)
    eval_env = deepcopy(env)
    losses, Qs = [], []
    episode_rewards, train_steps = [], []
    eval_rewards, eval_steps = [], []

    best_reward = -np.inf
    done, truncated = False, False
    state, _ = env.reset(seed=seed)

    # interaction loop: act epsilon-greedily, store the transition, update from replay
    for step in range(1, cfg.timesteps + 1):
        if done or truncated:
            state, _ = env.reset()
            done, truncated = False, False
            # store episode reward; `info` still holds the stats of the episode that just ended
            episode_rewards.append(info['episode']['r'].item())
            train_steps.append(step - 1)

        eps = get_epsilon(step - 1, cfg.eps_min, cfg.eps_max, cfg.eps_steps)

        if random.random() < eps:
            action = env.action_space.sample()
        else:
            action = agent.get_action(state)

        next_state, reward, done, truncated, info = env.step(action)
        buffer.add((state, action, reward, next_state, int(done)))
        state = next_state

        # start updates once the buffer can provide a full batch (plus the n-step horizon)
        if step > cfg.batch_size + cfg.nstep:
            # sample and do one step update
            if isinstance(buffer, PrioritizedReplayBuffer):
                # sample with priorities and update the priorities with td_error
                batch, weights, tree_idxs = buffer.sample(cfg.batch_size)
                loss, td_error, Q = agent.update(batch, step, weights=weights)

                buffer.update_priorities(tree_idxs, td_error.cpu().numpy())
            elif isinstance(buffer, ReplayBuffer):
                batch = buffer.sample(cfg.batch_size)
                loss, _, Q = agent.update(batch, step)
            else:
                raise RuntimeError("Unknown Buffer")

            Qs.append(Q)
            losses.append(loss)

        if step % cfg.eval_interval == 0:
            eval_mean, eval_std = eval(eval_env, agent=agent, episodes=cfg.eval_episodes, seed=seed)
            # resume training from a fresh episode after evaluation
            state, _ = env.reset()
            eval_steps.append(step - 1)
            eval_rewards.append(eval_mean)
            logger.info(f"Step: {step}, Eval mean: {eval_mean}, Eval std: {eval_std}")
            if eval_mean > best_reward:
                best_reward = eval_mean
                agent.save('best_model.pt')

        if step % cfg.plot_interval == 0:
            visualize(step, f'{agent} with {buffer}', train_steps, episode_rewards, eval_steps, eval_rewards, losses, Qs)

    agent.save('final_model.pt')
    visualize(step, f'{agent} with {buffer}', train_steps, episode_rewards, eval_steps, eval_rewards, losses, Qs)

    # record evaluation videos with the final model
    env = RecordVideo(eval_env, 'final_videos', name_prefix='eval', episode_trigger=lambda x: x % 2 == 0 and x < cfg.eval_episodes)
    eval_mean, eval_std = eval(env, agent=agent, episodes=cfg.eval_episodes, seed=seed)

    agent.load('best_model.pt')  # use best model for visualization
    env = RecordVideo(eval_env, 'best_videos', name_prefix='eval', episode_trigger=lambda x: x % 2 == 0 and x < cfg.eval_episodes)
    eval_mean, eval_std = eval(env, agent=agent, episodes=cfg.eval_episodes, seed=seed)

    env.close()
    logger.info(f"Final Eval mean: {eval_mean}, Eval std: {eval_std}")
    merge_videos('final_videos')
    merge_videos('best_videos')
    return eval_mean
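

# The block below is only a minimal usage sketch, not part of the original module: the environment
# name, the config values, and the agent/buffer constructors are assumptions, so adapt them to the
# actual classes defined in this repo (agent and buffer modules) before running.
if __name__ == "__main__":
    from types import SimpleNamespace

    import gymnasium as gym

    logging.basicConfig(level=logging.INFO)

    cfg = SimpleNamespace(
        timesteps=50_000, batch_size=32, nstep=1,                 # update schedule (assumed values)
        eps_max=1.0, eps_min=0.05, eps_steps=10_000,              # epsilon-greedy decay (assumed values)
        eval_interval=5_000, eval_episodes=10, plot_interval=5_000,
    )
    # any discrete-action env works; rgb_array rendering is needed for RecordVideo
    env = gym.make('CartPole-v1', render_mode='rgb_array')
    # hypothetical constructors, replace with the real signatures from this repo:
    # buffer = ReplayBuffer(...)          # or PrioritizedReplayBuffer(...)
    # agent = DQNAgent(...)
    # train(cfg, env, agent, buffer, seed=0)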