import sys
import torch
import argparse
import gymnasium as gym
from gflownet import GFlowNet
from stable_baselines3 import PPO, SAC, TD3, DQN
import random
import numpy as np
import zipfile
import json
import gym_envs


# Use case: python train.py -m gfn -s meta --seed 10 -t 10000
def get_arguments(argv):
    """Parse the command-line options of the training script.

    Args:
        argv: List of argument strings without the program name
            (typically ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description='Adversarial driving maneuvers training')

    # Models
    parser.add_argument('-m', '--method', type=str, default='gfn',
                        help='The method to use for training, options: ppo, sac, td3, dqn, gfn, gfnsub, gafn')
    # NOTE(review): the default 'gmm' is not among the simulations accepted by the
    # script entry point in this file, so -s has to be passed explicitly.
    parser.add_argument('-s', '--sim', type=str, default='gmm',
                        help='The simulation environment to use, options: toy, highway, openpilot, meta')
    parser.add_argument('-se', '--seed', type=int, default=42, help='The random seed to use')
    parser.add_argument('-d', '--discrete', action='store_true', help='Whether to use discrete action space')

    # Loader for stored models
    parser.add_argument('-md', '--model_dir', type=str, default='', help='The model to load in')
    parser.add_argument('-o', '--overwrite', action='store_true',
                        help='Whether to overwrite the model directory')

    # Hyperparameters for training algs
    parser.add_argument('-t', '--total_iterations', type=int, default=100000,
                        help='The total iterations to train for')
    parser.add_argument('--gamma', type=float, default=0.99, help='The discount factor for RL')
    parser.add_argument('-lr', '--learning_rate', type=float, default=5e-4, help='The learning rate')
    parser.add_argument('-b', '--batch_size', type=int, default=64, help='The batch size')

    # for PPO
    # (integer defaults are kept as-is: they end up verbatim in the run name)
    parser.add_argument('-e', '--ent_coef', type=float, default=0, help='The entropy coefficient')
    parser.add_argument('-tp', '--timesteps_per_epoch', type=int, default=1024,
                        help='The timesteps per training epoch')
    parser.add_argument('--clip_range', type=float, default=0.2, help='The clip range')

    # for SAC, TD3, DQN, and GFlowNet
    parser.add_argument('--learning_starts', type=int, default=0,
                        help='The number of timesteps to start learning')
    parser.add_argument('--train_freq', type=int, default=5, help='The training frequency')
    parser.add_argument('--buffer_size', type=int, default=100000, help='The buffer size')
    parser.add_argument('--gradient_steps', type=int, default=10, help='The number of gradient steps')

    # for GFlowNet
    parser.add_argument('-em', '--explorative_model', type=str, default='',
                        help='The explorative policy to load in')
    parser.add_argument('-en', '--explorative_num', type=int, default=0,
                        help='The number of samples to generate from the explorative policy')
    parser.add_argument('-ex', '--explorative_as_initial', action='store_true',
                        help='Whether to use the explorative policy as the initial policy')
    parser.add_argument('-rs', '--rl_start', type=float, default=0,
                        help='The start of rl updates in GFN training')
    parser.add_argument('-rl', '--rl_length', type=float, default=0,
                        help='The ratio of rl updates in GFN training')

    # for printing the training logs
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='Whether to print out the training process')

    # for validation during training and saving intermediate models
    parser.add_argument('-nvs', '--num_val_samples', type=int, default=0,
                        help='Number of samples to use for validation during training and saving '
                             'intermediate models. No validation is done if set to 0.')

    # Neural network architecture
    parser.add_argument('--nn_hidden_sizes', type=int, nargs='+', default=[64, 64],
                        help='Sizes of hidden layers in the neural network, e.g. 64 64 for two layers of size 64')

    parser.add_argument('-p', '--parallel', action='store_true',
                        help='Whether to use parallel environments for training')
    parser.add_argument('-hf', '--high_fidelity', action='store_true',
                        help='Whether to generate new trajectories using the current model in validation')

    return parser.parse_args(argv)

def seed_all(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also sets the cuDNN backend flags to deterministic mode with
    benchmark mode switched off.

    Args:
        seed: Integer seed passed to every generator.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    torch.manual_seed(seed)

def _saved_num_timesteps(checkpoint_path):
    """Return the ``num_timesteps`` value stored in a stable-baselines3 checkpoint zip."""
    with zipfile.ZipFile(checkpoint_path, 'r') as archive:
        return json.loads(archive.read('data'))['num_timesteps']


def _resume_sb3_model(model, checkpoint_path, with_replay_buffer):
    """Restore a stable-baselines3 model in place and return its saved step count.

    ``load`` on a stable-baselines3 algorithm is a classmethod that returns a
    new model, so calling it on an existing instance and discarding the result
    restores nothing. ``set_parameters`` loads the saved parameters into
    ``model`` itself while keeping the hyperparameters it was built with.

    Args:
        model: The freshly constructed stable-baselines3 model.
        checkpoint_path: Path of the saved ``.zip`` checkpoint.
        with_replay_buffer: Whether to also load ``<checkpoint>.pkl``.

    Returns:
        The ``num_timesteps`` recorded in the checkpoint.
    """
    print("Loading model from ", checkpoint_path)
    model.set_parameters(checkpoint_path)
    if with_replay_buffer:
        model.load_replay_buffer(checkpoint_path.replace('.zip', '') + '.pkl')
    t_start = _saved_num_timesteps(checkpoint_path)
    # Continue the step counter so that learn(..., reset_num_timesteps=False)
    # carries on from the checkpoint instead of from zero.
    model.num_timesteps = t_start
    print("Resume training from time step: = ", t_start)
    return t_start


# Script entry point: parse the options, build the environment and the model,
# optionally resume from a checkpoint, train, then save model and replay buffer.
if __name__ == '__main__':
    import os  # local import: the file-level import header is left untouched

    hyperparameters = vars(get_arguments(sys.argv[1:]))
    hyperparameters['continuous'] = not hyperparameters['discrete']

    print(hyperparameters)

    seed_all(hyperparameters['seed'])

    sim = hyperparameters['sim']
    method = hyperparameters['method']
    resume_path = hyperparameters['model_dir']
    hidden_sizes = hyperparameters['nn_hidden_sizes']

    validation_env = None
    data_env = None

    # sim name -> (registered env id, extra gym.make kwargs, vectorised data_env allowed)
    # openpilot and meta never get a vectorised data_env (openpilot is
    # presumably not parallelizable).
    env_specs = {
        'toy': ('adv_driving/toy-v0', {}, True),
        'highway': ('adv_driving/gym-highway-v0', {}, True),
        'openpilot': ('adv_driving/openpilot-v0', {}, False),
        'meta': ('adv_driving/metadrive_highway_env', {'render_mode': 'topdown'}, False),
    }
    if sim not in env_specs:
        print("Invalid simulation environment.")
        sys.exit(1)

    env_id, make_kwargs, vectorizable = env_specs[sim]
    env = gym.make(env_id, **make_kwargs)
    if vectorizable and hyperparameters['parallel']:
        data_env = gym.make_vec(env_id, num_envs=hyperparameters['train_freq'], vectorization_mode='sync')

    if method not in ['ppo', 'sac', 'td3', 'dqn', 'gfn', 'gfnsub', 'gafn']:
        print("Invalid method.")
        sys.exit(1)

    if resume_path != '' and not hyperparameters['overwrite']:
        print("Model directory is not empty, please make sure this training won't cause any data loss then add -o.")
        sys.exit(1)

    # The run name encodes the hyperparameters; it determines the output paths.
    model_name = f"{method}_{hyperparameters['total_iterations']}_{hidden_sizes}_{hyperparameters['seed']}"
    base_suffix = f"_{hyperparameters['learning_rate']}_{hyperparameters['batch_size']}_{hyperparameters['gamma']}"
    on_policy_suffix = base_suffix + f"_{hyperparameters['ent_coef']}_{hyperparameters['clip_range']}_{hyperparameters['timesteps_per_epoch']}"
    off_policy_suffix = base_suffix + f"_{hyperparameters['buffer_size']}_{hyperparameters['train_freq']}_{hyperparameters['gradient_steps']}_{hyperparameters['learning_starts']}"
    output_root = f"../output/{sim}"

    activation_fn = torch.nn.LeakyReLU

    t_start = 0
    if method in ('ppo', 'sac', 'td3', 'dqn'):
        # NOTE(review): dqn shares PPO's naming suffix (ent_coef, clip_range,
        # timesteps_per_epoch) although it does not use those options. This is
        # kept unchanged so existing output paths stay valid.
        model_name += off_policy_suffix if method in ('sac', 'td3') else on_policy_suffix
        log_dir = f"{output_root}/{model_name}"
        model_dir = f"{output_root}/{model_name}.zip"

        shared_kwargs = dict(
            learning_rate=hyperparameters['learning_rate'],
            batch_size=hyperparameters['batch_size'],
            gamma=hyperparameters['gamma'],
            seed=hyperparameters['seed'],
            tensorboard_log=log_dir,
            verbose=hyperparameters['verbose'],
        )
        replay_kwargs = dict(
            buffer_size=hyperparameters['buffer_size'],
            train_freq=hyperparameters['train_freq'],
            gradient_steps=hyperparameters['gradient_steps'],
            learning_starts=hyperparameters['learning_starts'],
        )

        if method == 'ppo':
            policy_kwargs = dict(activation_fn=activation_fn,
                                 net_arch=dict(pi=hidden_sizes, vf=hidden_sizes))
            model = PPO("MlpPolicy", env,
                        n_steps=hyperparameters['timesteps_per_epoch'],
                        ent_coef=hyperparameters['ent_coef'],
                        clip_range=hyperparameters['clip_range'],
                        policy_kwargs=policy_kwargs,
                        **shared_kwargs)
        elif method == 'dqn':
            policy_kwargs = dict(activation_fn=activation_fn, net_arch=hidden_sizes)
            model = DQN("MlpPolicy", env, policy_kwargs=policy_kwargs,
                        **replay_kwargs, **shared_kwargs)
        else:
            # sac and td3 are built the same way, with a fixed critic architecture
            policy_kwargs = dict(activation_fn=activation_fn,
                                 net_arch=dict(pi=hidden_sizes, qf=[400, 300]))
            algorithm = SAC if method == 'sac' else TD3
            model = algorithm("MlpPolicy", env, policy_kwargs=policy_kwargs,
                              **replay_kwargs, **shared_kwargs)

        if resume_path != '':
            # PPO has no replay buffer to restore
            t_start = _resume_sb3_model(model, resume_path, with_replay_buffer=(method != 'ppo'))

    elif method == 'gfn':
        # load a model trained by stable-baselines3, the idea is to get the policy and then use only the actor part
        # not used in the paper
        model_path = hyperparameters['explorative_model']
        if model_path == '':
            explorative_policy = None
            hyperparameters['explorative_num'] = 0
        elif 'ppo' in model_path:
            explorative_policy = PPO.load(model_path)
        elif 'sac' in model_path:
            explorative_policy = SAC.load(model_path)
        elif 'td3' in model_path:
            explorative_policy = TD3.load(model_path)
        else:
            # the algorithm is inferred from the file name; without a match
            # explorative_policy would be left undefined
            print("Cannot infer the explorative policy type (ppo, sac or td3) from ", model_path)
            sys.exit(1)

        model_name += off_policy_suffix + f"_False_{hyperparameters['explorative_num']}"

        log_dir = f"{output_root}/{model_name}"
        model_dir = f"{output_root}/{model_name}"

        # skip runs whose results already exist
        if os.path.exists(model_dir + '/forward_policy.pth'):
            print("Model results exists, exit")
            sys.exit(0)

        model = GFlowNet(env=env,
                         learning_rate=hyperparameters['learning_rate'],
                         batch_size=hyperparameters['batch_size'],
                         buffer_size=hyperparameters['buffer_size'],
                         train_freq=hyperparameters['train_freq'],
                         gradient_steps=hyperparameters['gradient_steps'],
                         learning_starts=hyperparameters['learning_starts'],
                         explorative_policy=explorative_policy,
                         explorative_num=hyperparameters['explorative_num'],
                         continuous=hyperparameters['continuous'],
                         tensorboard_log=log_dir,
                         verbose=hyperparameters['verbose'],
                         hidden_sizes=hidden_sizes,
                         activation_fn=activation_fn,
                         num_val_samples=hyperparameters['num_val_samples'],
                         model_dir=model_dir,
                         validation_env=validation_env,
                         data_env=data_env)

        t_start, i_start, e_start = 0, 0, 0

        if resume_path != '':
            print("Loading model from ", resume_path)
            t_start, i_start, e_start = model.load(resume_path, True)
            print("Resume training from time step: = ", t_start)
            model.load_replay_buffer(resume_path + '.pkl')

    else:
        # 'gfnsub' and 'gafn' pass the validation above, but no model is built
        # for them here; stop with a clear message instead of a NameError below.
        print(f"Method '{method}' is not implemented in this script.")
        sys.exit(1)

    # save hyperparameters (the log directory may not exist before training starts)
    os.makedirs(log_dir, exist_ok=True)
    with open(log_dir + '/hyperparameters.json', 'w') as f:
        json.dump(hyperparameters, f)

    if 'gfn' not in method:
        # for the models in stable-baselines3
        model.learn(hyperparameters['total_iterations'] - t_start, reset_num_timesteps=False)
    else:
        # for our GFN model
        model.learn(hyperparameters['total_iterations'], t_start, i_start, e_start)

    model.save(model_dir)
    if method != 'ppo':
        model.save_replay_buffer(model_dir.replace('.zip', '') + '.pkl')
      
      