CSE5100H2/hw2/core.py

from copy import deepcopy
import logging
import random

import matplotlib.pyplot as plt
import numpy as np
from gymnasium.wrappers import RecordEpisodeStatistics, RecordVideo

from buffer import PrioritizedReplayBuffer, ReplayBuffer
from utils import get_epsilon, merge_videos, moving_average

logger = logging.getLogger(__name__)
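

# `get_epsilon` is imported from utils and is not defined in this file. The helper below is a
# minimal sketch of the linear decay schedule it is assumed to implement (an assumption for
# illustration only; the training loop below calls the imported get_epsilon, not this sketch).
def _linear_epsilon_sketch(step, eps_min, eps_max, eps_steps):
    # anneal linearly from eps_max down to eps_min over eps_steps, then hold at eps_min
    frac = min(step / eps_steps, 1.0)
    return eps_max + frac * (eps_min - eps_max)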


def visualize(step, title, train_steps, train_returns, eval_steps, eval_returns, losses, q_values):
    """Plot smoothed train/eval returns, TD loss, and Q-values, and save the figure to results.png."""
    train_window, loss_window, q_window = 10, 100, 100
    plt.figure(figsize=(20, 6))
    # plot train and eval returns
    plt.subplot(1, 3, 1)
    plt.title('frame %s. score: %s' % (step, np.mean(train_returns[-train_window:])))
    plt.plot(train_steps[train_window - 1:], moving_average(train_returns, train_window), label='train')
    if len(eval_steps) > 0:
        plt.plot(eval_steps, eval_returns, label='eval')
    plt.legend()
    plt.xlabel('step')
    # plot td losses
    plt.subplot(1, 3, 2)
    plt.title('loss')
    plt.plot(moving_average(losses, loss_window))
    plt.xlabel('step')
    # plot q values
    plt.subplot(1, 3, 3)
    plt.title('q_values')
    plt.plot(moving_average(q_values, q_window))
    plt.xlabel('step')
    plt.suptitle(title, fontsize=16)
    plt.savefig('results.png')
    plt.close()
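

# `moving_average` is also imported from utils and not shown here. The helper below is a minimal
# sketch of the rolling mean it is assumed to compute (illustrative only; visualize() calls the
# imported version). With mode='valid' the output has len(values) - window + 1 points, which
# lines up with the train_steps[train_window - 1:] slicing used above.
def _moving_average_sketch(values, window):
    values = np.asarray(values, dtype=np.float64)
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode='valid')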


def eval(env, agent, episodes, seed):
    """Run `episodes` evaluation episodes and return the mean and std of episode returns."""
    returns = []
    for episode in range(episodes):
        state, _ = env.reset(seed=episode + seed)
        done, truncated = False, False
        while not (done or truncated):
            state, _, done, truncated, info = env.step(agent.get_action(state))
        # RecordEpisodeStatistics puts the episode return in info['episode'] at episode end
        returns.append(info['episode']['r'].item())
    return np.mean(returns), np.std(returns)
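

# The agent passed to eval() and train() is constructed outside this file. From the calls in
# this module it is assumed to expose the interface sketched below; this Protocol is for
# documentation only and is not used anywhere (names and exact signatures are assumptions).
from typing import Any, Protocol, Tuple


class _AgentInterfaceSketch(Protocol):
    def get_action(self, state) -> Any: ...
    def update(self, batch, step: int, weights=None) -> Tuple[float, Any, float]: ...
    def save(self, path: str) -> None: ...
    def load(self, path: str) -> None: ...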


def train(cfg, env, agent, buffer, seed):
    """Run the training loop with epsilon-greedy exploration, periodic evaluation, plotting, and video recording."""
    # wrap env to record episode returns
    env = RecordEpisodeStatistics(env)
    eval_env = deepcopy(env)
    losses, Qs = [], []
    episode_rewards, train_steps = [], []
    eval_rewards, eval_steps = [], []
    best_reward = -np.inf
    done, truncated = False, False
    state, _ = env.reset(seed=seed)
    for step in range(1, cfg.timesteps + 1):
        if done or truncated:
            state, _ = env.reset()
            done, truncated = False, False
            # store the episode reward reported by RecordEpisodeStatistics on the last step
            episode_rewards.append(info['episode']['r'].item())
            train_steps.append(step - 1)
        # epsilon-greedy action selection with a decaying epsilon
        eps = get_epsilon(step - 1, cfg.eps_min, cfg.eps_max, cfg.eps_steps)
        if random.random() < eps:
            action = env.action_space.sample()
        else:
            action = agent.get_action(state)
        next_state, reward, done, truncated, info = env.step(action)
        buffer.add((state, action, reward, next_state, int(done)))
        state = next_state
        if step > cfg.batch_size + cfg.nstep:
            # sample and do one step update
            if isinstance(buffer, PrioritizedReplayBuffer):
                # sample with priorities and update the priorities with the new td_error
                batch, weights, tree_idxs = buffer.sample(cfg.batch_size)
                loss, td_error, Q = agent.update(batch, step, weights=weights)
                buffer.update_priorities(tree_idxs, td_error.cpu().numpy())
            elif isinstance(buffer, ReplayBuffer):
                batch = buffer.sample(cfg.batch_size)
                loss, _, Q = agent.update(batch, step)
            else:
                raise RuntimeError("Unknown Buffer")
            Qs.append(Q)
            losses.append(loss)
        if step % cfg.eval_interval == 0:
            # evaluate periodically and keep the best checkpoint
            eval_mean, eval_std = eval(eval_env, agent=agent, episodes=cfg.eval_episodes, seed=seed)
            state, _ = env.reset()
            eval_steps.append(step - 1)
            eval_rewards.append(eval_mean)
            logger.info(f"Step: {step}, Eval mean: {eval_mean}, Eval std: {eval_std}")
            if eval_mean > best_reward:
                best_reward = eval_mean
                agent.save('best_model.pt')
        if step % cfg.plot_interval == 0:
            visualize(step, f'{agent} with {buffer}', train_steps, episode_rewards, eval_steps, eval_rewards, losses, Qs)
    agent.save('final_model.pt')
    visualize(step, f'{agent} with {buffer}', train_steps, episode_rewards, eval_steps, eval_rewards, losses, Qs)
    # record videos of the final model, then of the best checkpoint
    env = RecordVideo(eval_env, 'final_videos', name_prefix='eval', episode_trigger=lambda x: x % 2 == 0 and x < cfg.eval_episodes)
    eval_mean, eval_std = eval(env, agent=agent, episodes=cfg.eval_episodes, seed=seed)
    agent.load('best_model.pt')  # use best model for visualization
    env = RecordVideo(eval_env, 'best_videos', name_prefix='eval', episode_trigger=lambda x: x % 2 == 0 and x < cfg.eval_episodes)
    eval_mean, eval_std = eval(env, agent=agent, episodes=cfg.eval_episodes, seed=seed)
    env.close()
    logger.info(f"Final Eval mean: {eval_mean}, Eval std: {eval_std}")
    merge_videos('final_videos')
    merge_videos('best_videos')
    return eval_mean
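

# A minimal invocation sketch (assumption): the environment id, buffer construction, and agent
# come from other files in this homework and are only hinted at here. The cfg attributes below
# are exactly the ones train() reads above; their values are illustrative placeholders rather
# than the assignment's settings.
if __name__ == '__main__':
    from types import SimpleNamespace
    import gymnasium as gym

    logging.basicConfig(level=logging.INFO)
    cfg = SimpleNamespace(
        timesteps=50_000, eps_min=0.05, eps_max=1.0, eps_steps=10_000,
        nstep=1, batch_size=32, eval_interval=1_000, eval_episodes=10,
        plot_interval=5_000,
    )
    env = gym.make('CartPole-v1', render_mode='rgb_array')  # rgb_array rendering is required by RecordVideo
    # buffer = ReplayBuffer(...)  # constructor signature lives in buffer.py and is not shown here
    # agent = ...                 # the agent class is defined outside this file
    # train(cfg, env, agent, buffer, seed=0)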