import gymnasium as gym
import highway_env
from stable_baselines3 import TD3

# Instantiate the registered "highway-v0" environment, requesting the
# "rgb_array" render mode.
env = gym.make(
    "highway-v0",
    render_mode="rgb_array",
)

# This config should be the same as the one in gym_highway_env.py, except ...
# NOTE(review): the sentence above was left unfinished in the original file;
# confirm which settings are meant to differ from gym_highway_env.py.
#
# Environment configuration handed to highway-env. Every key appears exactly
# once: the original literal listed "policy_frequency" twice (10, then 2) and
# Python silently kept the later value, so only the effective value (2) is
# retained here.
config = {
    "observation": {
        "type": "Kinematics",
        "vehicles_count": 7,
        "absolute": False,
        "order": "sorted",
    },
    "action": {
        "type": "ContinuousAction",
    },
    "lanes_count": 3,
    "vehicles_count": 6,
    "initial_spacing": 1,
    "initial_lane_id": None,
    "controlled_vehicles": 1,
    "vehicles_density": 2,
    "simulation_frequency": 10,  # [Hz]
    "policy_frequency": 2,  # [Hz]
    "other_vehicles_type": "highway_env.vehicle.behavior.IDMVehicle",
    "centering_position": [0.8, 0.5],
    "right_lane_reward": 0,
    "show_trajectories": True,
    "offroad_terminal": True,
    # Optional, currently disabled: "offscreen_rendering": False
}

# Hand the settings to the innermost (unwrapped) environment object, then
# reset once before the agent is constructed around the environment.
env.unwrapped.configure(config)
env.reset()

# TD3 agent using the "MlpPolicy" policy with net_arch=[128, 128];
# seeded, with TensorBoard logs written under ../data/log/highway_td3/.
model = TD3(
    "MlpPolicy",
    env,
    learning_rate=5e-4,
    buffer_size=10000,
    learning_starts=100,
    batch_size=32,
    gamma=0.99,
    train_freq=1,
    gradient_steps=1,
    policy_kwargs={"net_arch": [128, 128]},
    tensorboard_log="../data/log/highway_td3/",
    verbose=1,
    seed=42,
)

# Train for 40,000 timesteps, then persist the trained agent to disk.
model.learn(total_timesteps=int(4e4))
model.save("../data/highway_td3")

# Load the saved model and replay it episode after episode.
# The outer loop has no exit condition, so the script only stops on an
# interrupt (e.g. Ctrl-C) or an error; the try/finally makes sure the
# environment is closed in either case instead of being left open.
model = TD3.load("../data/highway_td3")
try:
    while True:
        done = truncated = False
        obs, info = env.reset()
        # Step until the episode terminates or is truncated.
        while not (done or truncated):
            action, _states = model.predict(obs, deterministic=True)
            obs, reward, done, truncated, info = env.step(action)
            env.render()
finally:
    env.close()