import os
import time
import math
import pickle
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam

from nns import ContinuousBackwardPolicy, ContinuousForwardPolicy, DiscreteForwardPolicy, DiscreteBackwardPolicy
from memory import Memory

from abc import abstractmethod
from torch.utils.tensorboard import SummaryWriter

# GFlowNet for sparse-reward environments, trained with the trajectory balance loss.
class GFlowNet:
    """GFlowNet agent trained off-policy with the trajectory balance (TB) loss.

    Rollouts are generated with the forward policy (from ``env`` one episode at
    a time, or from the vectorised ``data_env`` when given), pushed into a
    replay ``Memory``, and replayed with ``memory.priority_sample`` to minimise

        (sum log P_F + log Z - sum log P_B - log R)^2,

    where ``R = rew + temperature * augmented_rew``. ``temperature`` is annealed
    from its initial value towards ~0 over training by ``_update_learning_rate``.
    """

    def __init__(self, env,
                 learning_rate=1e-3, batch_size=32, buffer_size=10000,
                 train_freq=16, gradient_steps=10, learning_starts=100,
                 explorative_policy=None, explorative_num=100,
                 temperature=1,
                 device='auto', continuous=True, tensorboard_log=None, verbose=False,
                 hidden_sizes=None,
                 activation_fn=torch.nn.ReLU,
                 initial_z=0.1,
                 num_val_samples=0,
                 model_dir=None,
                 validation_env=None,
                 data_env=None,
                 no_decay=False):
        """Build the policies, log Z, replay memory, optimisers and logger.

        Args:
            env: Environment used for rollouts and, in the discrete case, for
                action masks, ``get_state`` and ``get_error`` (via
                ``env.unwrapped``). ``observation_space.shape[0]`` gives the
                state dimension; the action dimension comes from
                ``action_space.shape[0]`` (continuous) or ``action_space.n``
                (discrete).
            learning_rate: Adam learning rate for both policies. ``log Z`` uses
                ``100 * learning_rate``.
            batch_size: Trajectories sampled per gradient step.
            buffer_size: Capacity passed to ``Memory``.
            train_freq: Episodes collected per training iteration.
            gradient_steps: Gradient steps per call to ``train``.
            learning_starts: Minimum number of collected episodes before
                training starts.
            explorative_policy: Optional policy used to pre-fill the buffer in
                ``learn``.
            explorative_num: Episodes to collect with ``explorative_policy``.
            temperature: Initial weight of the augmented reward in ``log R``.
            device: Torch device; ``'auto'`` picks CUDA when available.
            continuous: Select continuous or discrete policy networks.
            tensorboard_log: Log directory passed to ``SummaryWriter``.
            verbose: Print progress information.
            hidden_sizes: Hidden layer sizes of both policies (default
                ``[256, 256]``).
            activation_fn: Activation class for the policy networks.
            initial_z: Initial value of the learnable ``logZ`` tensor.
            num_val_samples: Samples drawn per validation; 0 disables
                validation.
            model_dir: Directory where ``learn`` writes ``val_errs.pkl``;
                ``None`` skips writing it.
            validation_env: Optional vectorised env used to generate
                validation samples.
            data_env: Optional vectorised env used for parallel collection.
            no_decay: Stored on the instance but not read anywhere in this
                class.
        """
        if hidden_sizes is None:
            hidden_sizes = [256, 256]

        # Hyperparameters.
        self.learning_rate = learning_rate
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.temperature = temperature
        self.temperature_init = temperature
        self.gradient_steps = gradient_steps
        self.learning_starts = learning_starts
        self.explorative_num = explorative_num
        self.train_freq = train_freq
        self.device = device
        self.verbose = verbose
        self.num_val_samples = num_val_samples
        self.model_dir = model_dir
        self.validation_env = validation_env
        self.data_env = data_env
        self.no_decay = no_decay

        if self.device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.state_dim = env.observation_space.shape[0]
        self.continuous = continuous

        if continuous:
            self.action_dim = env.action_space.shape[0]
            self.forward_policy = ContinuousForwardPolicy(
                self.state_dim, self.action_dim - 1, hidden_sizes, activation_fn, device=self.device
            ).to(self.device)
            self.backward_policy = ContinuousBackwardPolicy(
                self.state_dim, self.action_dim - 1, hidden_sizes, activation_fn, device=self.device
            ).to(self.device)
        else:
            self.action_dim = env.action_space.n
            self.forward_policy = DiscreteForwardPolicy(
                self.state_dim, self.action_dim, hidden_sizes, activation_fn, device=self.device
            ).to(self.device)
            self.backward_policy = DiscreteBackwardPolicy(
                self.state_dim, self.action_dim - 1, hidden_sizes, activation_fn, device=self.device
            ).to(self.device)

        self.logZ = torch.tensor(initial_z, requires_grad=True, device=self.device, dtype=torch.float32)

        self.memory = Memory(self.buffer_size, device=self.device)

        # Note: the input and output should be the same as the forward policy.
        self.explorative_policy = explorative_policy

        self.forward_optim = Adam(self.forward_policy.parameters(), lr=self.learning_rate)
        self.backward_optim = Adam(self.backward_policy.parameters(), lr=self.learning_rate)
        # Following the original TB paper, log Z uses a larger learning rate
        # than the forward and backward policies.
        self.logZ_optim = Adam([self.logZ], lr=self.learning_rate * 100)

        # 'initial_lr' / 'min_lr' are read by _update_learning_rate.
        self.forward_optim.param_groups[0]['initial_lr'] = self.learning_rate
        self.backward_optim.param_groups[0]['initial_lr'] = self.learning_rate
        self.logZ_optim.param_groups[0]['initial_lr'] = self.learning_rate * 100

        self.forward_optim.param_groups[0]['min_lr'] = self.learning_rate / 100
        self.backward_optim.param_groups[0]['min_lr'] = self.learning_rate / 100
        self.logZ_optim.param_groups[0]['min_lr'] = self.learning_rate

        self._n_updates = 0

        self.logger = {
            'delta_t': time.time_ns(),
            't_so_far': 0,          # timesteps so far
            'i_so_far': 0,          # iterations so far
            'e_so_far': 0,          # episodes so far
            'batch_lens': [],       # episodic lengths in batch
            'batch_rews': [],       # episodic returns in batch
            'buffer_lens': [],      # episode lengths from the latest collection
            'buffer_rews': [],      # final rewards from the latest collection
            'traj_losses': [],      # trajectory balance losses in current iteration
            'traj_losses_std': [],  # trajectory balance losses std in current iteration
            'log_Z': [],            # log_Z in current iteration
        }

        self.env = env

        self.writer = SummaryWriter(tensorboard_log)

    def collect_rollouts(self):
        """Collect ``train_freq`` non-truncated episodes from ``self.env`` one at a time.

        Each kept episode is pushed to memory via ``memory.push_traj``.
        Episodes that end truncated are discarded: they are neither stored nor
        counted.

        Returns:
            ``(num_collected_steps, num_collected_episodes)`` over kept episodes.
        """
        buffer_lens = []
        buffer_rews_report = []

        num_collected_steps, num_collected_episodes = 0, 0

        with torch.no_grad():
            while num_collected_episodes < self.train_freq:
                ep_acts = []
                ep_obs = []
                ep_next_obs = []
                ep_rews = []

                # Reset the environment (obs is short for observation).
                obs, _ = self.env.reset()
                done = False
                truncated = False

                # Stop on termination OR truncation. A truncated episode must
                # not be stepped further; it is dropped below.
                while not (done or truncated):
                    ep_obs.append(obs)
                    if self.continuous:
                        action = self.predict(obs)
                    else:
                        action = self.predict(obs, self.env.unwrapped.get_forward_action_masks(obs))
                    obs, rew, done, truncated, augmented_rew = self.env.step(action)
                    ep_acts.append(action)
                    ep_next_obs.append(obs)
                    ep_rews.append(rew)

                if not truncated:
                    num_collected_steps += len(ep_obs)
                    num_collected_episodes += 1

                    buffer_lens.append(len(ep_obs))
                    buffer_rews_report.append(ep_rews[-1])

                    ep_rtgs = ep_rews.copy()

                    # The info of the final step is stored as the augmented reward.
                    self.memory.push_traj(ep_obs, ep_acts, ep_next_obs, ep_rtgs, augmented_rew)

        # Log the episodic returns and episodic lengths in this buffer.
        self.logger['buffer_rews'] = buffer_rews_report
        self.logger['buffer_lens'] = buffer_lens

        return num_collected_steps, num_collected_episodes

    def collect_rollouts_parallel(self):
        """Collect one episode per sub-environment of the vectorised ``self.data_env``.

        Assumes ``data_env`` runs exactly ``train_freq`` sub-environments that
        auto-reset when done (TODO confirm against the env setup). For a
        sub-env that finishes at a step, its next observation is replaced by
        the observation before that step, and its rewards are zeroed on every
        later step. A sub-env is only marked finished when it is done and not
        truncated.

        Returns:
            ``(num_collected_steps, num_collected_episodes)``. Here
            ``num_collected_steps`` accumulates, after every step, the number
            of sub-envs that are still unfinished.
        """
        buffer_rews_report = np.zeros(self.train_freq)
        num_collected_steps, num_collected_episodes = 0, 0
        buffer_lens = np.zeros(self.train_freq, dtype=int)
        with torch.no_grad():
            dones = np.zeros(self.train_freq, dtype=bool)
            obs_buffer = []
            action_buffer = []
            next_obs_buffer = []
            reward_buffer = []
            augmented_rew_buffer = np.ones(self.train_freq, dtype=np.float32)
            obs, _ = self.data_env.reset()

            while not np.all(dones):
                obs_buffer.append(obs)
                if self.continuous:
                    action = self.predict(obs)
                else:
                    action = self.predict(obs, self.env.unwrapped.get_forward_action_masks(obs))

                prev_obs = obs
                obs, rew, d, t, augmented_rew = self.data_env.step(action)  # if done, automatically reset
                action_buffer.append(action)
                reward_buffer.append(rew * (~dones))
                obs[d, :] = prev_obs[d, :]
                next_obs_buffer.append(obs)

                newly_done = d & (~dones) & (~t)
                if np.any(newly_done):
                    buffer_lens[newly_done] = len(obs_buffer)
                    buffer_rews_report[newly_done] = reward_buffer[-1][newly_done]
                    augmented_rew_buffer[newly_done] = augmented_rew['final_info'][newly_done]
                    dones[newly_done] = True

                num_collected_steps += np.sum(~dones)

            num_collected_episodes += np.sum(dones)

            self.memory.push_trajs(obs_buffer, action_buffer, next_obs_buffer, reward_buffer,
                                   buffer_lens, augmented_rew_buffer)
        self.logger['buffer_rews'] = buffer_rews_report
        self.logger['buffer_lens'] = buffer_lens

        return num_collected_steps, num_collected_episodes

    def get_rollouts(self, buffer_obs, buffer_acts, buffer_next_obs, buffer_rtgs, buffer_log_probs, buffer_log_probs2):
        """Yield shuffled minibatches of transitions (currently unused).

        Each yielded tuple indexes all six buffers with the same random
        indices. Batches have ``batch_size`` items, except that the last one
        holds the remainder.
        """
        buffer_size = len(buffer_obs)
        indices = np.random.permutation(buffer_size)

        start_idx = 0
        while start_idx + self.batch_size < buffer_size:
            idx = indices[start_idx: start_idx + self.batch_size]
            yield buffer_obs[idx], buffer_acts[idx], buffer_next_obs[idx], buffer_rtgs[idx], buffer_log_probs[idx], buffer_log_probs2[idx]
            start_idx += self.batch_size
        if start_idx < buffer_size:
            idx = indices[start_idx:]
            yield buffer_obs[idx], buffer_acts[idx], buffer_next_obs[idx], buffer_rtgs[idx], buffer_log_probs[idx], buffer_log_probs2[idx]

    def learn(self, total_iterations, t_start=0, i_start=0, e_start=0):
        """Run the off-policy training loop until ``total_iterations`` updates.

        Each loop iteration collects rollouts. Once the memory holds at least
        ``batch_size`` items and ``learning_starts`` episodes have been
        collected, it calls ``train`` and counts one iteration. When
        ``num_val_samples`` is non-zero, a validation error is computed at most
        once per iteration count that is a multiple of 100. The errors are
        pickled to ``<model_dir>/val_errs.pkl`` at the end, if ``model_dir``
        is set.

        Args:
            total_iterations: Number of training iterations to reach.
            t_start, i_start, e_start: Counters to resume from (timesteps,
                iterations, episodes).
        """
        if self.verbose:
            print(f"Learning... Using {self.batch_size} episodes per updates, ", end='')
            print(f"at least {self.train_freq} episodes per update for {total_iterations} iterations.")
        t_so_far = t_start  # Timesteps simulated so far
        i_so_far = i_start  # Iterations ran so far
        e_so_far = e_start  # Episodes simulated so far

        # Collect data from the explorative policy.
        if self.explorative_policy is not None:
            explorative_e_so_far = 0
            while explorative_e_so_far < self.explorative_num:
                # NOTE(review): collect_rollouts_from_explorative_policy is not
                # defined in this class; presumably a subclass provides it.
                _, tmp_e = self.collect_rollouts_from_explorative_policy()
                explorative_e_so_far += tmp_e

            # Print a summary of the collected data.
            self._log_summary()

        if self.explorative_policy is None and self.explorative_num > 0:
            print("Warning: Explorative policy is not provided, collect data from random policy.")

        val_errs = []
        # Iteration count at which validation last ran. i_so_far does not
        # advance during warm-up, so without this guard the same iteration
        # would be validated (and appended to val_errs) over and over.
        last_val_iter = None
        while i_so_far < total_iterations:
            if self.data_env is not None:
                tmp_t, tmp_e = self.collect_rollouts_parallel()
            else:
                tmp_t, tmp_e = self.collect_rollouts()
            e_so_far += tmp_e

            self.logger['e_so_far'] = e_so_far

            if len(self.memory) >= self.batch_size and e_so_far >= self.learning_starts:
                self.train()
                i_so_far += 1

                t_so_far += tmp_t
                self.logger['t_so_far'] = t_so_far
                self.logger['i_so_far'] = i_so_far

                # Learning-rate decay is currently disabled: the list is empty.
                # Candidates: self.logZ_optim, self.forward_optim, self.backward_optim.
                gfn_optimizers = []
                # Still called to anneal the temperature.
                self._update_learning_rate(gfn_optimizers, i_so_far, total_iterations)

                self._log_summary()

            if self.num_val_samples != 0 and i_so_far % 100 == 0 and i_so_far != last_val_iter:
                last_val_iter = i_so_far
                val_err = self._validate()
                print(f"Temperate: = {self.temperature}")
                print(f"Validation error = {val_err}")
                val_errs.append(val_err)

        if self.model_dir is not None:
            os.makedirs(self.model_dir, exist_ok=True)
            with open(os.path.join(self.model_dir, "val_errs.pkl"), 'wb') as f:
                pickle.dump(val_errs, f)

    def _validate(self):
        """Draw ``num_val_samples`` terminal states and return ``env.unwrapped.get_error`` on them.

        With ``validation_env`` set, trajectories are generated with the
        current forward policy. Otherwise the most recent entries of
        ``memory.visited_end_states`` are used.
        """
        if self.validation_env is not None:
            samples = []
            with torch.no_grad():
                s, _ = self.validation_env.reset()
                dones = np.zeros((s.shape[0],), dtype=bool)
                while len(samples) < self.num_val_samples:
                    if self.continuous:
                        a = self.predict(s)
                    else:
                        # Use the original env to get the action mask.
                        action_mask = self.env.unwrapped.get_forward_action_masks(s)
                        a = self.predict(s, action_mask)

                    prev_s = s
                    # The env auto-resets finished sub-envs, so the terminal
                    # state is taken from prev_s.
                    s, _, d, _, _ = self.validation_env.step(a)
                    finished = d & (~dones)
                    if np.any(finished):
                        samples += list(self.env.unwrapped.get_state(prev_s[finished, :]))
                        # NOTE(review): prev_s is rebound on the next iteration,
                        # so this write only matters if the env reuses its
                        # observation buffer — confirm whether it is needed.
                        prev_s[d, :] = s[d, :]
                        # Stop sub-envs from producing more samples than needed.
                        remaining_samples = self.num_val_samples - len(samples) - np.sum(~dones)
                        if remaining_samples < 0:
                            indices_to_mark = np.where(finished)[0][: -remaining_samples]
                            dones[indices_to_mark] = True
            samples = samples[-self.num_val_samples:]
        else:
            end_states = self.memory.visited_end_states
            recent = end_states[-min(self.num_val_samples, len(end_states)):]
            samples = self.env.unwrapped.get_state(np.array(recent))
        return self.env.unwrapped.get_error(samples)

    def _update_learning_rate(self, gfn_optimizers, i_so_far, total_iterations):
        """Decay optimiser learning rates linearly and anneal the temperature.

        With ``r = 1 - i_so_far / total_iterations``:
        - Each param group's ``lr`` becomes
          ``max(min_lr, initial_lr * r)``. This only happens when
          ``temperature_init > 0``; it is unused in the paper.
        - ``temperature`` becomes
          ``max(temperature_init / (1 + 10 ** (10 - 20 * r)), 1e-10)``.
          This is a logistic schedule: about ``temperature_init`` at the start,
          half of it at mid-training, and at least ``1e-10`` at the end.

        Args:
            gfn_optimizers: An optimiser or a list of optimisers whose param
                groups carry ``'initial_lr'`` and ``'min_lr'``.
            i_so_far: Iterations completed.
            total_iterations: Total planned iterations.
        """
        if not isinstance(gfn_optimizers, list):
            gfn_optimizers = [gfn_optimizers]

        discount_ratio_gfn = 1 - i_so_far / (total_iterations + 1e-8)

        if self.temperature_init > 0:  # diminishing the learning rate, unused in the paper
            for optimizer in gfn_optimizers:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = max(param_group['min_lr'], param_group['initial_lr'] * discount_ratio_gfn)

        # Smooth decay from temperature_init towards 0. A segmented alternative:
        # max(self.temperature_init * math.pow(10, -15 * (1 - min(discount_ratio_gfn, 0.75) / 0.75))
        self.temperature = max(self.temperature_init * 1 / (1 + math.pow(10, 10 - 20 * discount_ratio_gfn)), 1e-10)

    def predict(self, obs, action_mask=None, deterministic=False):
        """Query the forward policy and return its action as a NumPy array.

        ``action_mask`` is only forwarded in the discrete case. The meaning of
        ``deterministic`` is defined by the policy classes (presumably "take
        the mode instead of sampling").
        """
        if self.continuous:
            action = self.forward_policy(obs, deterministic=deterministic)
        else:
            action = self.forward_policy(obs, action_mask, deterministic=deterministic)

        return action.detach().cpu().numpy()

    def _log_summary(self):
        """Print (if verbose) and write to TensorBoard the stats accumulated since the last call.

        Afterwards it resets the per-iteration lists, including ``log_Z``, so
        each report covers only the most recent iteration.
        """
        delta_t = self.logger['delta_t']
        self.logger['delta_t'] = time.time_ns()
        delta_t = (self.logger['delta_t'] - delta_t) / 1e9
        delta_t = round(delta_t, 2)

        t_so_far = self.logger['t_so_far']
        i_so_far = self.logger['i_so_far']
        e_so_far = self.logger['e_so_far']
        # Empty lists yield NaN (with a NumPy warning), e.g. before any training.
        avg_ep_lens = round(np.mean(self.logger['buffer_lens']), 2)
        avg_ep_rews = round(np.mean(self.logger['buffer_rews']), 2)
        avg_traj_loss = round(np.mean(self.logger['traj_losses']), 5)
        avg_traj_loss_std = round(np.mean(self.logger['traj_losses_std']), 5)
        avg_log_Z = round(np.mean(self.logger['log_Z']), 5)

        if self.verbose:
            print(flush=True)
            print(f"-------------------- Epoch #{i_so_far} --------------------", flush=True)
            print(f"Average Episodic Length: {avg_ep_lens}", flush=True)
            print(f"Average Episodic Return: {avg_ep_rews}", flush=True)
            print(f"Average Traj Loss: {avg_traj_loss:.5f}", flush=True)
            print(f"Average Traj Loss Std: {avg_traj_loss_std:.5f}", flush=True)
            print(f"Average log_Z: {avg_log_Z:.5f}", flush=True)

            print(f"Timesteps So Far: {t_so_far}", flush=True)
            print(f"Episode So Far: {e_so_far}", flush=True)
            print(f"Iteration took: {delta_t} secs", flush=True)
            print(f"------------------------------------------------------", flush=True)
            print(flush=True)

        self.writer.add_scalar('train/episodic_length', avg_ep_lens, i_so_far)
        self.writer.add_scalar('train/episodic_return', avg_ep_rews, i_so_far)
        self.writer.add_scalar('train/traj_loss', avg_traj_loss, i_so_far)
        self.writer.add_scalar('train/traj_loss_std', avg_traj_loss_std, i_so_far)
        self.writer.add_scalar('train/log_Z', avg_log_Z, i_so_far)

        # Reset iteration-specific logging data. log_Z is appended to once per
        # gradient step in train(), so it must be cleared too, or it would grow
        # without bound and report a cumulative mean.
        self.logger['batch_lens'] = []
        self.logger['batch_rews'] = []
        self.logger['traj_losses'] = []
        self.logger['traj_losses_std'] = []
        self.logger['log_Z'] = []

    def train(self):
        """Run ``gradient_steps`` trajectory-balance updates on replayed trajectories.

        For each step, a batch is drawn with
        ``memory.priority_sample(batch_size, temperature)`` and padded to a
        common length. The per-trajectory TB loss
        ``(log P_F + log Z - log P_B - log R)^2`` is then computed, with
        ``R = rew + temperature * augmented_rew``. This loss:
        - is written back as the new replay priority;
        - is adjusted for above-mean-reward trajectories (promote/depress
          below);
        - is averaged and minimised with all three optimisers.

        Raises:
            ValueError: If a summed log-probability is infinite or the loss is
                inf/NaN.
        """
        if self.gradient_steps <= 0:
            # Nothing to optimise; also avoids dividing by zero when logging.
            return

        traj_losses = 0
        traj_losses_std = 0
        for _ in range(self.gradient_steps):
            # Sample from the replay buffer.
            batch_obs, batch_acts, batch_next_obs, batch_rews, batch_augmented_rews, batch_idx = \
                self.memory.priority_sample(self.batch_size, self.temperature)

            # Pad variable-length trajectories so they are evaluated in one pass.
            batch_obs_pad = pad_sequence(batch_obs, batch_first=True)
            batch_acts_pad = pad_sequence(batch_acts, batch_first=True)
            batch_next_obs_pad = pad_sequence(batch_next_obs, batch_first=True)

            lengths = torch.tensor(np.array([len(obs) for obs in batch_obs]), device=self.device)

            # Forward pass.
            if self.continuous:
                logPF = self.forward_policy.evaluate_actions(batch_obs_pad, batch_acts_pad, lengths=lengths)
                logPB = self.backward_policy.evaluate_actions(batch_next_obs_pad, batch_acts_pad, lengths=lengths)
            else:
                forward_mask = self.env.unwrapped.get_forward_action_masks(batch_obs_pad)
                backward_mask = self.env.unwrapped.get_backward_action_masks(batch_next_obs_pad)
                logPF = self.forward_policy.evaluate_actions(batch_obs_pad, batch_acts_pad, forward_mask, lengths=lengths)
                logPB = self.backward_policy.evaluate_actions(batch_next_obs_pad, batch_acts_pad, backward_mask, lengths=lengths)

            # Sum per-step log-probabilities over the time dimension.
            logPF = torch.sum(logPF, dim=1)
            logPB = torch.sum(logPB, dim=1)

            batch_rews = torch.tensor(batch_rews, device=self.device)
            batch_augmented_rews = torch.tensor(batch_augmented_rews, device=self.device)
            log_reward = torch.log(batch_rews + self.temperature * batch_augmented_rews)

            if torch.any(torch.isinf(logPF)) or torch.any(torch.isinf(logPB)):
                raise ValueError("Infinite logprobs found")

            loss = (logPF + self.logZ - logPB - log_reward).pow(2)

            if torch.any(torch.isinf(loss)) or torch.any(torch.isnan(loss)):
                raise ValueError(f"Invalid loss found, loss: {loss}")

            traj_losses += loss.mean().detach().item()
            traj_losses_std += loss.std().detach().item()

            # The plain TB loss is used as the replay priority.
            self.memory.update_priority(batch_idx, np.abs(loss.detach().cpu().numpy()))

            # High-reward records only: there is a vast space with nearly 0
            # reward, with no hope of exploring it enough to estimate PB there.
            rew_filter = (batch_rews.squeeze() + self.temperature * batch_augmented_rews.squeeze()) > \
                (batch_rews.mean() + self.temperature * batch_augmented_rews.mean())
            batch_norm = - logPF.detach() - logPB.detach() + log_reward.detach()

            if torch.sum(rew_filter) > 1:
                filter_norm = batch_norm[rew_filter]
                batch_mean = filter_norm.mean()
                batch_std = filter_norm.std()

                # True for an under-explored high-reward trajectory.
                batch_filter1 = (batch_norm > batch_mean) & rew_filter

                if torch.any(batch_filter1) and self.verbose:
                    print("Will promote:")
                    print("reward:", batch_rews[batch_filter1])
                    print("logPF:", logPF[batch_filter1])
                    print("logPB:", logPB[batch_filter1])
                    print("This", batch_norm[batch_filter1], "Mean:", batch_mean, "Std:", batch_std)

                loss = loss - (logPF + logPB) * (batch_filter1)

                # True for a relatively old trajectory in a frequently visited area.
                batch_filter2 = (batch_norm < batch_mean) & rew_filter

                if torch.any(batch_filter2) and self.verbose:
                    print("Will depress:")
                    print("reward:", batch_rews[batch_filter2])
                    print("logPF:", logPF[batch_filter2])
                    print("logPB:", logPB[batch_filter2])
                    print("This", batch_norm[batch_filter2], "Mean:", batch_mean, "Std:", batch_std)

                # Release PF (the exploration budget) and fit PB by maximum likelihood.
                loss = loss * (~batch_filter2) - (logPB - logPF) * (batch_filter2)

            loss = loss.mean()

            self.forward_optim.zero_grad()
            self.backward_optim.zero_grad()
            self.logZ_optim.zero_grad()

            loss.backward()

            self.forward_optim.step()
            self.backward_optim.step()
            self.logZ_optim.step()

            self.logger['log_Z'].append(self.logZ.item())

        # One update per gradient step taken in this call.
        self._n_updates += self.gradient_steps
        self.logger['traj_losses'].append(traj_losses / self.gradient_steps)
        self.logger['traj_losses_std'].append(traj_losses_std / self.gradient_steps)

    def save(self, model_dir):
        """Save policies, log Z, optimiser states and progress counters under ``model_dir``."""
        os.makedirs(model_dir, exist_ok=True)
        torch.save(self.forward_policy.state_dict(), f'{model_dir}/forward_policy.pth')
        torch.save(self.backward_policy.state_dict(), f'{model_dir}/backward_policy.pth')
        torch.save(self.logZ, f'{model_dir}/logZ.pth')
        torch.save(self.forward_optim.state_dict(), f'{model_dir}/forward_optim.pth')
        torch.save(self.backward_optim.state_dict(), f'{model_dir}/backward_optim.pth')
        torch.save(self.logZ_optim.state_dict(), f'{model_dir}/logZ_optim.pth')
        current_progress = (self.logger['t_so_far'], self.logger['i_so_far'], self.logger['e_so_far'])
        with open(f'{model_dir}/progress.pkl', 'wb') as f:
            pickle.dump(current_progress, f)

    def save_replay_buffer(self, model_dir):
        """Delegate saving of the replay memory to ``Memory.save``."""
        self.memory.save(model_dir)

    def load(self, model_dir, load_optim=False):
        """Load policies and log Z (and optionally optimisers and progress) from ``model_dir``.

        All tensors are mapped onto ``self.device``. The ``progress.pkl`` file
        is unpickled, so only load checkpoints you trust.

        Returns:
            ``(t_so_far, i_so_far, e_so_far)`` from ``progress.pkl`` when
            ``load_optim`` is true, otherwise ``(0, 0, 0)``.
        """
        self.forward_policy.load_state_dict(torch.load(f'{model_dir}/forward_policy.pth', map_location=self.device))
        self.backward_policy.load_state_dict(torch.load(f'{model_dir}/backward_policy.pth', map_location=self.device))
        # Copy in place rather than rebinding self.logZ. logZ_optim holds a
        # reference to this exact tensor, so after a rebind it would keep
        # updating a stale copy.
        loaded_logZ = torch.load(f'{model_dir}/logZ.pth', map_location=self.device)
        with torch.no_grad():
            self.logZ.copy_(loaded_logZ)

        if load_optim:
            self.forward_optim.load_state_dict(torch.load(f'{model_dir}/forward_optim.pth', map_location=self.device))
            self.backward_optim.load_state_dict(torch.load(f'{model_dir}/backward_optim.pth', map_location=self.device))
            self.logZ_optim.load_state_dict(torch.load(f'{model_dir}/logZ_optim.pth', map_location=self.device))

            with open(f'{model_dir}/progress.pkl', 'rb') as f:
                t_so_far, i_so_far, e_so_far = pickle.load(f)

            return t_so_far, i_so_far, e_so_far
        return 0, 0, 0

    def load_replay_buffer(self, model_dir):
        """Delegate loading of the replay memory to ``Memory.load``."""
        self.memory.load(model_dir)