31 lines
1018 B
Python
31 lines
1018 B
Python
import hydra
|
|
import utils
|
|
import torch
|
|
import logging
|
|
from agent import DQNAgent
|
|
from core import train
|
|
from buffer import get_buffer
|
|
import gymnasium as gym
|
|
logger = logging.getLogger(__name__)
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
@hydra.main(config_path="cfgs", config_name="config", version_base="1.3")
|
|
def main(cfg):
|
|
env = gym.make(cfg.env_name, render_mode="rgb_array")
|
|
utils.set_seed_everywhere(env, cfg.seed)
|
|
|
|
state_size = utils.get_space_shape(env.observation_space)
|
|
action_size = utils.get_space_shape(env.action_space)
|
|
|
|
buffer = get_buffer(cfg.buffer, state_size=state_size, device=device)
|
|
agent = DQNAgent(state_size=state_size, action_size=action_size, cfg=cfg.agent, device=device)
|
|
|
|
logger.info(f"Training for {cfg.train.timesteps} timesteps with {agent} and {buffer}")
|
|
eval_mean = train(cfg.train, env, agent, buffer, seed=cfg.seed)
|
|
logger.info(f"Finish training with eval mean: {eval_mean}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|