commit 250f763f1f
Author: Zheyuan Wu
Date: 2025-10-12 00:55:07 -05:00
467 changed files with 19784 additions and 0 deletions

BIN
hw2/DRL_Homework_2.pdf Normal file


24
hw2/README.md Normal file

@@ -0,0 +1,24 @@
# Installation
Since we are using PyTorch for hw2, we recommend using conda to manage the environment. Please refer to the [miniconda](https://docs.conda.io/en/latest/miniconda.html) homepage for a compact conda installation.
You have two options for creating the hw2 environment:
* For Mac users or a CPU-only installation, remove the `pytorch-cuda` term in either of the options below.
* To create a new conda environment, simply run `conda env create -f environment.yml`.
* If you want to install the packages into the environment you created for hw1, follow the steps below:
```bash
conda activate <hw1-env-name>
# we are using PyTorch 2.0!
# remove the pytorch-cuda=11.7 term if you are a mac user or want a cpu-only installation
conda install pytorch==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install gymnasium[classic_control]==0.27.1
pip install matplotlib==3.7.1
# for hyperparameter management
pip install hydra-core==1.3.2
# for video recording
pip install moviepy==1.0.3
```
That's it! If you encounter any trouble creating the environment, please let us know :-)
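
If you want to sanity-check the installation, the short snippet below should run without errors inside the activated environment (the CUDA check prints `False` on Mac or CPU-only setups):

```python
import torch
import gymnasium as gym

print(torch.__version__)           # should start with 2.0.0
print(torch.cuda.is_available())   # False on Mac / CPU-only installs

# verify the classic-control extra is present
env = gym.make("CartPole-v1")
print(env.observation_space.shape, env.action_space.n)  # (4,) 2
```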

127
hw2/agent.py Normal file

@@ -0,0 +1,127 @@
import os
import torch
import torch.optim as optim
from copy import deepcopy
from model import QNetwork, DuelingQNetwork


class DQNAgent:
    def __init__(self, state_size, action_size, cfg, device='cuda'):
        self.device = device
        self.use_double = cfg.use_double
        self.use_dueling = cfg.use_dueling
        self.target_update_interval = cfg.target_update_interval
        q_model = DuelingQNetwork if self.use_dueling else QNetwork
        self.q_net = q_model(state_size, action_size, cfg.hidden_size, cfg.activation).to(self.device)
        self.target_net = deepcopy(self.q_net).to(self.device)
        self.optimizer = optim.AdamW(self.q_net.parameters(), lr=cfg.lr)
        self.tau = cfg.tau
        # the effective discount for the n-step Bellman target is gamma ** nstep
        self.gamma = cfg.gamma ** cfg.nstep
    def soft_update(self, target, source):
        """
        Soft update the target network using the source network
        """
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_((1 - self.tau) * target_param.data + self.tau * source_param.data)
    @torch.no_grad()
    def get_action(self, state):
        """
        Get the action according to the current state and Q value
        """
        ############################
        # YOUR IMPLEMENTATION HERE #
        # a single (unbatched) state comes in as a numpy array; act greedily w.r.t. the online Q network
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        action = torch.argmax(self.q_net(state_tensor), dim=0).cpu().numpy()
        ############################
        return action
    @torch.no_grad()
    def get_Q_target(self, state, action, reward, done, next_state) -> torch.Tensor:
        """
        Get the target Q value according to the Bellman equation
        """
        reward = reward.to(self.device)
        done = done.to(self.device)
        next_state = next_state.to(self.device)
        if self.use_double:
            # YOUR IMPLEMENTATION HERE
            # Double DQN: select the greedy action with the online network,
            # then evaluate that action with the target network
            next_actions = self.q_net(next_state).argmax(dim=1, keepdim=True)
            next_q = self.target_net(next_state).gather(1, next_actions).squeeze(1)
        else:
            # YOUR IMPLEMENTATION HERE
            # vanilla DQN: bootstrap from the max Q value of the target network
            next_q = torch.max(self.target_net(next_state), dim=1).values
        # Bellman target; (1 - done) masks out the bootstrap term at episode ends
        q_target = reward + (1 - done) * self.gamma * next_q
        return q_target
    def get_Q(self, state, action, use_double_net=False) -> torch.Tensor:
        """
        Get the Q value of the current state and action
        """
        ############################
        # YOUR IMPLEMENTATION HERE #
        # pick the network, then gather the Q value of the taken action for each sample in the batch
        net = self.target_net if use_double_net else self.q_net
        q_values = net(state.to(self.device))
        # gather expects int64 indices of shape (batch_size, 1)
        action_idx = action.to(dtype=torch.int64, device=self.device)
        q = q_values.gather(1, action_idx).squeeze(1)
        ############################
        return q
    def update(self, batch, step, weights=None):
        state, action, reward, next_state, done = batch
        Q_target = self.get_Q_target(state, action, reward, done, next_state)
        Q = self.get_Q(state, action)
        if weights is None:
            weights = torch.ones_like(Q).to(self.device)
        td_error = torch.abs(Q - Q_target).detach()
        loss = torch.mean((Q - Q_target) ** 2 * weights)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if not step % self.target_update_interval:
            with torch.no_grad():
                self.soft_update(self.target_net, self.q_net)
        return loss.item(), td_error, Q.mean().item()

    def save(self, name):
        os.makedirs('models', exist_ok=True)
        torch.save(self.q_net.state_dict(), os.path.join('models', name))

    def load(self, name='best.pt'):
        self.q_net.load_state_dict(torch.load(os.path.join('models', name)))

    def __repr__(self) -> str:
        use_double = 'Double' if self.use_double else ''
        use_dueling = 'Dueling' if self.use_dueling else ''
        prefix = 'Normal' if not self.use_double and not self.use_dueling else ''
        return use_double + use_dueling + prefix + 'QNetwork'
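
For reference, here is a minimal sketch of how `DQNAgent` can be constructed and queried outside the Hydra entry point; the inline config is a hypothetical stand-in that only mirrors the fields `__init__` reads (the real values live in `cfgs/config.yaml`):

```python
import numpy as np
from omegaconf import OmegaConf
from agent import DQNAgent  # run from the hw2 directory

cfg = OmegaConf.create({
    'use_double': False, 'use_dueling': False, 'target_update_interval': 3,
    'hidden_size': 64, 'activation': {'_target_': 'torch.nn.ELU'},
    'lr': 2e-3, 'tau': 0.1, 'gamma': 0.99, 'nstep': 1,
})
agent = DQNAgent(state_size=4, action_size=2, cfg=cfg, device='cpu')
print(agent)                                             # NormalQNetwork
print(agent.get_action(np.zeros(4, dtype=np.float32)))   # 0 or 1
```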

202
hw2/buffer.py Normal file

@@ -0,0 +1,202 @@
import torch
import numpy as np
from collections import deque


def get_buffer(cfg, **args):
    assert isinstance(cfg.nstep, int) and cfg.nstep > 0, 'nstep must be a positive integer'
    if not cfg.use_per:
        if cfg.nstep == 1:
            return ReplayBuffer(cfg.capacity, **args)
        else:
            return NStepReplayBuffer(cfg.capacity, cfg.nstep, cfg.gamma, **args)
    else:
        if cfg.nstep == 1:
            return PrioritizedReplayBuffer(cfg.capacity, cfg.per_eps, cfg.per_alpha, cfg.per_beta, **args)
        else:
            return PrioritizedNStepReplayBuffer(cfg.capacity, cfg.per_eps, cfg.per_alpha, cfg.per_beta, cfg.nstep, cfg.gamma, **args)
class ReplayBuffer:
    def __init__(self, capacity, state_size, device):
        self.device = device
        self.state = torch.empty(capacity, state_size, dtype=torch.float)
        self.action = torch.empty(capacity, 1, dtype=torch.float)
        self.reward = torch.empty(capacity, dtype=torch.float)
        self.next_state = torch.empty(capacity, state_size, dtype=torch.float)
        self.done = torch.empty(capacity, dtype=torch.int)
        self.idx = 0
        self.size = 0
        self.capacity = capacity

    def __repr__(self) -> str:
        return 'NormalReplayBuffer'

    def add(self, transition):
        state, action, reward, next_state, done = transition
        # store transition in the buffer and update the index and size of the buffer
        # you may need to convert the data type to torch.tensor
        ############################
        # YOUR IMPLEMENTATION HERE #
        # the storage tensors live on the CPU, so convert without moving them to self.device
        self.state[self.idx] = torch.as_tensor(state, dtype=torch.float)
        self.action[self.idx] = torch.as_tensor(action, dtype=torch.float)
        self.reward[self.idx] = torch.as_tensor(reward, dtype=torch.float)
        self.next_state[self.idx] = torch.as_tensor(next_state, dtype=torch.float)
        self.done[self.idx] = torch.as_tensor(done, dtype=torch.int)
        # advance the ring-buffer index and grow the size until the buffer is full
        self.idx = (self.idx + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)
        ############################
    def sample(self, batch_size):
        # sample batch_size data from the buffer without replacement
        sample_idxs = np.random.choice(self.size, batch_size, replace=False)
        # get a batch of data from the buffer according to the sample_idxs
        ############################
        # YOUR IMPLEMENTATION HERE #
        # the batch stays on the CPU here; the agent moves each tensor to its device when it uses it
        idxs = torch.as_tensor(sample_idxs, dtype=torch.int64)
        batch = (
            self.state[idxs],
            self.action[idxs],
            self.reward[idxs],
            self.next_state[idxs],
            self.done[idxs],
        )
        ############################
        return batch
class NStepReplayBuffer(ReplayBuffer):
    def __init__(self, capacity, n_step, gamma, state_size, device):
        super().__init__(capacity, state_size, device=device)
        self.n_step = n_step
        self.n_step_buffer = deque([], maxlen=n_step)
        self.gamma = gamma

    def __repr__(self) -> str:
        return f'{self.n_step}StepReplayBuffer'

    def n_step_handler(self):
        """Get n-step state, action, reward and done for the transition, discard those rewards after done=True"""
        ############################
        # YOUR IMPLEMENTATION HERE #
        state, action, reward, done = self.n_step_buffer[0]
        # accumulate the discounted rewards of the following transitions,
        # stopping at the first terminal so no reward leaks across episode boundaries
        gamma = self.gamma
        for _, _, r_i, done_i in list(self.n_step_buffer)[1:]:
            if done:
                break
            reward += gamma * r_i
            done = done_i
            gamma *= self.gamma
        ############################
        return state, action, reward, done

    def add(self, transition):
        state, action, reward, next_state, done = transition
        self.n_step_buffer.append((state, action, reward, done))
        if len(self.n_step_buffer) < self.n_step:
            return
        state, action, reward, done = self.n_step_handler()
        super().add((state, action, reward, next_state, done))
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, capacity, eps, alpha, beta, state_size, device):
        self.weights = np.zeros(capacity, dtype=np.float32)  # stores the priorities used for sampling
        self.eps = eps  # minimal priority for stability
        self.alpha = alpha  # determines how much prioritization is used, α = 0 corresponds to the uniform case
        self.beta = beta  # determines the amount of importance-sampling correction, β = 1 fully compensates for the non-uniform probabilities
        self.max_priority = eps  # priority for new samples, init as eps
        super().__init__(capacity, state_size, device=device)

    def add(self, transition):
        """
        Add a new experience to memory, and set its priority to the current max_priority.
        """
        ############################
        # YOUR IMPLEMENTATION HERE #
        # record the slot the transition will occupy before super().add advances self.idx
        priority_idx = self.idx
        super().add(transition)
        self.weights[priority_idx] = self.max_priority
        ############################

    def sample(self, batch_size):
        """
        Sample a batch of experiences from the buffer with priority, and calculate the importance-sampling
        weights that correct the bias introduced by the non-uniform sampling in the Q-learning update
        Returns:
            batch: a batch of experiences as in the normal replay buffer
            weights: torch.Tensor (batch_size, ), importance sampling weights for each sample
            sample_idxs: numpy.ndarray (batch_size, ), the indexes of the sample in the buffer
        """
        ############################
        # YOUR IMPLEMENTATION HERE #
        # sample proportionally to the stored priorities (with replacement, in case the buffer is still small)
        priorities = torch.as_tensor(self.weights[:self.size], dtype=torch.float)
        idxs = torch.multinomial(priorities, batch_size, replacement=True)
        sample_idxs = idxs.numpy()
        # the batch stays on the CPU here; the agent moves each tensor to its device when it uses it
        batch = (
            self.state[idxs],
            self.action[idxs],
            self.reward[idxs],
            self.next_state[idxs],
            self.done[idxs],
        )
        # importance-sampling weights w_i = (N * P(i)) ** (-beta), normalized by the max for stability
        probs = priorities[idxs] / priorities.sum()
        weights = (self.size * probs) ** (-self.beta)
        weights = (weights / weights.max()).to(self.device)
        ############################
        return batch, weights, sample_idxs

    def update_priorities(self, data_idxs, priorities: np.ndarray):
        priorities = (priorities + self.eps) ** self.alpha
        self.weights[data_idxs] = priorities
        self.max_priority = self.weights.max()

    def __repr__(self) -> str:
        return 'PrioritizedReplayBuffer'
# Avoid diamond inheritance: inherit from PrioritizedReplayBuffer and re-implement the n-step logic
class PrioritizedNStepReplayBuffer(PrioritizedReplayBuffer):
    def __init__(self, capacity, eps, alpha, beta, n_step, gamma, state_size, device):
        ############################
        # YOUR IMPLEMENTATION HERE #
        super().__init__(capacity, eps, alpha, beta, state_size, device)
        self.n_step = n_step
        self.n_step_buffer = deque([], maxlen=n_step)
        self.gamma = gamma
        ############################

    def __repr__(self) -> str:
        return f'Prioritized{self.n_step}StepReplayBuffer'

    def add(self, transition):
        ############################
        # YOUR IMPLEMENTATION HERE #
        state, action, reward, next_state, done = transition
        self.n_step_buffer.append((state, action, reward, done))
        if len(self.n_step_buffer) < self.n_step:
            return
        state, action, reward, done = self.n_step_handler()
        super().add((state, action, reward, next_state, done))
        ############################

    # define the other necessary class methods as you need
    def n_step_handler(self):
        """Get n-step state, action, reward and done for the transition, discard those rewards after done=True"""
        ############################
        # YOUR IMPLEMENTATION HERE #
        state, action, reward, done = self.n_step_buffer[0]
        # accumulate discounted rewards until the first terminal, as in NStepReplayBuffer
        gamma = self.gamma
        for _, _, r_i, done_i in list(self.n_step_buffer)[1:]:
            if done:
                break
            reward += gamma * r_i
            done = done_i
            gamma *= self.gamma
        ############################
        return state, action, reward, done
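
As a quick check of the n-step bookkeeping in `n_step_handler`, the standalone sketch below accumulates a 3-step return by hand with made-up rewards; the agent later bootstraps this return with `gamma ** nstep`, which is exactly why `DQNAgent.__init__` sets `self.gamma = cfg.gamma ** cfg.nstep`.

```python
# hypothetical 3-step window of (reward, done) pairs, mirroring the deque contents
window = [(1.0, False), (1.0, False), (1.0, True)]
gamma = 0.99

reward, done = window[0]
g = gamma
for r_i, done_i in window[1:]:
    if done:
        break
    reward += g * r_i
    done = done_i
    g *= gamma

print(reward)  # 1 + 0.99 + 0.99**2 = 2.9701
print(done)    # True: the episode terminated inside the window
```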

45
hw2/cfgs/config.yaml Normal file

@@ -0,0 +1,45 @@
seed: 42
env_name: CartPole-v1

train:
  nstep: ${buffer.nstep}
  timesteps: 50_000
  batch_size: 128
  test_every: 2500
  eps_max: 1
  eps_min: 0.05
  eps_steps: 12_500
  start_steps: 0
  plot_interval: 2000
  eval_interval: 2000
  eval_episodes: 10

agent:
  gamma: 0.99
  lr: 0.002
  tau: 0.1
  nstep: ${buffer.nstep}
  target_update_interval: 3
  hidden_size: 64
  activation:
    _target_: torch.nn.ELU
    # you can define other parameters of the __init__ function (if any) for the object here
  use_dueling: False
  use_double: False

buffer:
  capacity: 50_000
  use_per: False
  nstep: 1
  gamma: ${agent.gamma}
  per_alpha: 0.7
  per_beta: 0.4
  per_eps: 0.01

hydra:
  job:
    chdir: true
  run:
    dir: ./runs/${now:%Y-%m-%d}/${now:%H-%M-%S}_${hydra.job.override_dirname}
  sweep:
    dir: ./sweeps/${now:%Y-%m-%d}/${now:%H-%M-%S}_${hydra.job.override_dirname}
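
The `activation` entry uses Hydra's `_target_` convention: `model.py` passes it to `hydra.utils.instantiate`, which imports `torch.nn.ELU` and constructs it. A minimal sketch of that mechanism on its own (assuming the `hydra-core` and `omegaconf` packages from the environment above):

```python
from omegaconf import OmegaConf
from hydra.utils import instantiate

activation_cfg = OmegaConf.create({"_target_": "torch.nn.ELU"})
print(instantiate(activation_cfg))  # ELU(alpha=1.0)
```

Any value in this config can also be overridden from the command line, e.g. `python main.py agent.use_double=True buffer.nstep=3 buffer.use_per=True`.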

128
hw2/core.py Normal file

@@ -0,0 +1,128 @@
from copy import deepcopy
import random
import logging
import numpy as np
from buffer import ReplayBuffer, PrioritizedReplayBuffer
import matplotlib.pyplot as plt
from utils import moving_average, merge_videos, get_epsilon
from gymnasium.wrappers import RecordVideo, RecordEpisodeStatistics

logger = logging.getLogger(__name__)


def visualize(step, title, train_steps, train_returns, eval_steps, eval_returns, losses, q_values):
    train_window, loss_window, q_window = 10, 100, 100
    plt.figure(figsize=(20, 6))
    # plot train and eval returns
    plt.subplot(1, 3, 1)
    plt.title('frame %s. score: %s' % (step, np.mean(train_returns[-train_window:])))
    plt.plot(train_steps[train_window - 1:], moving_average(train_returns, train_window), label='train')
    if len(eval_steps) > 0:
        plt.plot(eval_steps, eval_returns, label='eval')
    plt.legend()
    plt.xlabel('step')
    # plot td losses
    plt.subplot(1, 3, 2)
    plt.title('loss')
    plt.plot(moving_average(losses, loss_window))
    plt.xlabel('step')
    # plot q values
    plt.subplot(1, 3, 3)
    plt.title('q_values')
    plt.plot(moving_average(q_values, q_window))
    plt.xlabel('step')
    plt.suptitle(title, fontsize=16)
    plt.savefig('results.png')
    plt.close()


def eval(env, agent, episodes, seed):
    returns = []
    for episode in range(episodes):
        state, _ = env.reset(seed=episode + seed)
        done, truncated = False, False
        while not (done or truncated):
            state, _, done, truncated, info = env.step(agent.get_action(state))
        returns.append(info['episode']['r'].item())
    return np.mean(returns), np.std(returns)
def train(cfg, env, agent, buffer, seed):
    # wrap env to record episode returns
    env = RecordEpisodeStatistics(env)
    eval_env = deepcopy(env)
    losses, Qs = [], []
    episode_rewards, train_steps = [], []
    eval_rewards, eval_steps = [], []
    best_reward = -np.inf
    done, truncated = False, False
    state, _ = env.reset(seed=seed)
    for step in range(1, cfg.timesteps + 1):
        if done or truncated:
            state, _ = env.reset()
            done, truncated = False, False
            # store episode reward
            episode_rewards.append(info['episode']['r'].item())
            train_steps.append(step - 1)
        eps = get_epsilon(step - 1, cfg.eps_min, cfg.eps_max, cfg.eps_steps)
        if random.random() < eps:
            action = env.action_space.sample()
        else:
            action = agent.get_action(state)
        next_state, reward, done, truncated, info = env.step(action)
        buffer.add((state, action, reward, next_state, int(done)))
        state = next_state

        if step > cfg.batch_size + cfg.nstep:
            # sample and do one step update
            if isinstance(buffer, PrioritizedReplayBuffer):
                # sample with priorities and update the priorities with td_error
                batch, weights, tree_idxs = buffer.sample(cfg.batch_size)
                loss, td_error, Q = agent.update(batch, step, weights=weights)
                buffer.update_priorities(tree_idxs, td_error.cpu().numpy())
            elif isinstance(buffer, ReplayBuffer):
                batch = buffer.sample(cfg.batch_size)
                loss, _, Q = agent.update(batch, step)
            else:
                raise RuntimeError("Unknown Buffer")
            Qs.append(Q)
            losses.append(loss)

        if step % cfg.eval_interval == 0:
            eval_mean, eval_std = eval(eval_env, agent=agent, episodes=cfg.eval_episodes, seed=seed)
            state, _ = env.reset()
            eval_steps.append(step - 1)
            eval_rewards.append(eval_mean)
            logger.info(f"Step: {step}, Eval mean: {eval_mean}, Eval std: {eval_std}")
            if eval_mean > best_reward:
                best_reward = eval_mean
                agent.save('best_model.pt')
        if step % cfg.plot_interval == 0:
            visualize(step, f'{agent} with {buffer}', train_steps, episode_rewards, eval_steps, eval_rewards, losses, Qs)

    agent.save('final_model.pt')
    visualize(step, f'{agent} with {buffer}', train_steps, episode_rewards, eval_steps, eval_rewards, losses, Qs)
    env = RecordVideo(eval_env, 'final_videos', name_prefix='eval', episode_trigger=lambda x: x % 2 == 0 and x < cfg.eval_episodes)
    eval_mean, eval_std = eval(env, agent=agent, episodes=cfg.eval_episodes, seed=seed)
    agent.load('best_model.pt')  # use best model for visualization
    env = RecordVideo(eval_env, 'best_videos', name_prefix='eval', episode_trigger=lambda x: x % 2 == 0 and x < cfg.eval_episodes)
    eval_mean, eval_std = eval(env, agent=agent, episodes=cfg.eval_episodes, seed=seed)
    env.close()
    logger.info(f"Final Eval mean: {eval_mean}, Eval std: {eval_std}")
    merge_videos('final_videos')
    merge_videos('best_videos')
    return eval_mean
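
In the prioritized branch above, the absolute TD errors flow straight back into `buffer.update_priorities`, where they become priorities `(|td_error| + per_eps) ** per_alpha`. A tiny numeric check with the default `per_eps=0.01`, `per_alpha=0.7`:

```python
import numpy as np

td_error = np.array([0.0, 0.5, 2.0])
per_eps, per_alpha = 0.01, 0.7
print((td_error + per_eps) ** per_alpha)  # approximately [0.040, 0.624, 1.630]
```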

15
hw2/environment.yml Normal file

@@ -0,0 +1,15 @@
name: drl_hw2
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - python=3.10
  - pytorch=2.0.0
  - pytorch-cuda=11.7  # Comment this line if you are a mac user or want a cpu-only installation
  - pip=23.0.1
  - pip:
      - gymnasium[classic-control]==0.27.1
      - hydra-core==1.3.2
      - matplotlib==3.7.1
      - moviepy==1.0.3


BIN
hw2/gallery/All-In-One.png Normal file

BIN
hw2/gallery/DQN.png Normal file

BIN
hw2/gallery/Double DQN.png Normal file

BIN
hw2/gallery/Dueling DQN.png Normal file

BIN
hw2/gallery/NStep + PER.png Normal file

BIN
hw2/gallery/NStep.png Normal file

BIN
hw2/gallery/PER.png Normal file

30
hw2/main.py Normal file

@@ -0,0 +1,30 @@
import hydra
import utils
import torch
import logging
from agent import DQNAgent
from core import train
from buffer import get_buffer
import gymnasium as gym

logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@hydra.main(config_path="cfgs", config_name="config", version_base="1.3")
def main(cfg):
    env = gym.make(cfg.env_name, render_mode="rgb_array")
    utils.set_seed_everywhere(env, cfg.seed)
    state_size = utils.get_space_shape(env.observation_space)
    action_size = utils.get_space_shape(env.action_space)
    buffer = get_buffer(cfg.buffer, state_size=state_size, device=device)
    agent = DQNAgent(state_size=state_size, action_size=action_size, cfg=cfg.agent, device=device)
    logger.info(f"Training for {cfg.train.timesteps} timesteps with {agent} and {buffer}")
    eval_mean = train(cfg.train, env, agent, buffer, seed=cfg.seed)
    logger.info(f"Finish training with eval mean: {eval_mean}")


if __name__ == "__main__":
    main()

53
hw2/model.py Normal file

@@ -0,0 +1,53 @@
from hydra.utils import instantiate
import torch
import torch.nn as nn


class QNetwork(nn.Module):
    def __init__(self, state_size, action_size, hidden_size, activation):
        super(QNetwork, self).__init__()
        self.q_head = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            instantiate(activation),
            nn.Linear(hidden_size, hidden_size),
            instantiate(activation),
            nn.Linear(hidden_size, action_size)
        )

    def forward(self, state):
        Qs = self.q_head(state)
        return Qs


class DuelingQNetwork(nn.Module):
    def __init__(self, state_size, action_size, hidden_size, activation):
        super(DuelingQNetwork, self).__init__()
        self.feature_layer = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            instantiate(activation),
        )
        self.value_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            instantiate(activation),
            nn.Linear(hidden_size, 1)
        )
        self.advantage_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            instantiate(activation),
            nn.Linear(hidden_size, action_size)
        )
    def forward(self, state):
        """
        Get the Q value of the current state and action using dueling network
        """
        ############################
        # YOUR IMPLEMENTATION HERE #
        # equation (7) of https://arxiv.org/pdf/1511.06581: Q(s, a) = V(s) + A(s, a);
        # compute the shared features once and feed them to both heads
        features = self.feature_layer(state)
        Qs = self.value_head(features) + self.advantage_head(features)
        ############################
        return Qs
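
Note that equation (9) of the dueling DQN paper subtracts the mean advantage so that V and A are identifiable; if you want to try that variant, only the aggregation of the two heads changes. A self-contained sketch of that aggregation (the toy tensors are made up for illustration):

```python
import torch

def dueling_aggregation(value: torch.Tensor, advantage: torch.Tensor) -> torch.Tensor:
    """Combine the value and advantage heads as in eq. (9): subtract the mean advantage."""
    return value + advantage - advantage.mean(dim=-1, keepdim=True)

# toy check with a batch of 2 states and 3 actions
value = torch.tensor([[1.0], [0.5]])
advantage = torch.tensor([[1.0, 2.0, 3.0], [0.0, 0.0, 3.0]])
print(dueling_aggregation(value, advantage))
# tensor([[ 0.0,  1.0,  2.0],
#         [-0.5, -0.5,  2.5]])
```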

65
hw2/utils.py Normal file

@@ -0,0 +1,65 @@
import os
import glob
import torch
import shutil
import random
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Discrete, Box
from moviepy.editor import VideoFileClip, concatenate_videoclips


def moving_average(a, n):
    """
    Return an array of the moving average of a with window size n
    """
    if len(a) <= n:
        return a
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n


def get_epsilon(step, eps_min, eps_max, eps_steps):
    """
    Return the linearly decaying epsilon at the current step for the epsilon-greedy policy.
    After eps_steps, epsilon stays at eps_min
    """
    ############################
    # YOUR IMPLEMENTATION HERE #
    return max(eps_min, eps_max - (step / eps_steps) * (eps_max - eps_min))
    ############################


def merge_videos(video_dir):
    """
    Merge videos in the video_dir into a single video
    """
    videos = glob.glob(os.path.join(video_dir, "*.mp4"))
    videos = sorted(videos, key=lambda x: int(x.split("-")[-1].split(".")[0]))
    clip = concatenate_videoclips([VideoFileClip(video) for video in videos])
    clip.write_videofile(f"{video_dir}.mp4")
    shutil.rmtree(video_dir)


def set_seed_everywhere(env: gym.Env, seed=0):
    """
    Set seed for all randomness sources
    """
    env.action_space.seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)


def get_space_shape(space):
    """
    Return the shape of the gym.Space object
    """
    if isinstance(space, Discrete):
        return space.n
    elif isinstance(space, Box):
        if len(space.shape) == 1:
            return space.shape[0]
        else:
            return space.shape
    else:
        raise ValueError(f"Space not supported: {space}")
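
With the default config (`eps_max: 1`, `eps_min: 0.05`, `eps_steps: 12_500`), the schedule implemented by `get_epsilon` decays linearly and then flattens at `eps_min`; a quick check, assuming you run it from the hw2 directory:

```python
from utils import get_epsilon

eps_min, eps_max, eps_steps = 0.05, 1.0, 12_500
for step in (0, 6_250, 12_500, 50_000):
    print(step, get_epsilon(step, eps_min, eps_max, eps_steps))
# 0      1.0
# 6250   ~0.525
# 12500  ~0.05
# 50000  0.05   (clamped at eps_min)
```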