This commit is contained in:
Zheyuan Wu
2025-10-12 00:55:07 -05:00
commit 250f763f1f
467 changed files with 19784 additions and 0 deletions

65
hw2/utils.py Normal file
View File

@@ -0,0 +1,65 @@
import os
import glob
import torch
import shutil
import random
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Discrete, Box
from moviepy.editor import VideoFileClip, concatenate_videoclips
def moving_average(a, n):
"""
Return an array of the moving average of a with window size n
"""
if len(a) <= n:
return a
ret = np.cumsum(a, dtype=float)
ret[n:] = ret[n:] - ret[:-n]
return ret[n - 1:] / n
def get_epsilon(step, eps_min, eps_max, eps_steps):
"""
Return the linearly descending epsilon of the current step for the epsilon-greedy policy. After eps_steps, epsilon will keep at eps_min
"""
############################
# YOUR IMPLEMENTATION HERE #
return max(eps_min, eps_max - (step / eps_steps) * (eps_max - eps_min))
############################
def merge_videos(video_dir):
"""
Merge videos in the video_dir into a single video
"""
videos = glob.glob(os.path.join(video_dir, "*.mp4"))
videos = sorted(videos, key=lambda x: int(x.split("-")[-1].split(".")[0]))
clip = concatenate_videoclips([VideoFileClip(video) for video in videos])
clip.write_videofile(f"{video_dir}.mp4")
shutil.rmtree(video_dir)
def set_seed_everywhere(env: gym.Env, seed=0):
"""
Set seed for all randomness sources
"""
env.action_space.seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
def get_space_shape(space):
"""
Return the shape of the gym.Space object
"""
if isinstance(space, Discrete):
return space.n
elif isinstance(space, Box):
if len(space.shape) == 1:
return space.shape[0]
else:
return space.shape
else:
raise ValueError(f"Space not supported: {space}")