from typing import Dict, Text
import numpy as np

import gymnasium as gym

from metadrive.envs import BaseEnv
from metadrive.engine.engine_utils import close_engine
from scipy.__config__ import show
from sympy import det
close_engine()

from metadrive.component.map.pg_map import MapGenerateMethod
from metadrive.component.map.base_map import BaseMap
from metadrive import MetaDriveEnv
from metadrive.manager import BaseManager
from metadrive.component.vehicle.vehicle_type import DefaultVehicle
from metadrive import MultiAgentMetaDrive
from metadrive.examples import expert

from PIL import Image
import os
import random
import logging

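# This environment is a four-lane highway scenario with four vehicles:
# one attacker vehicle ("agent1") is controlled by the RL policy,
# and the other three (victim) vehicles are driven by the `expert` policy from metadrive.examples.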
class Environment(gym.Env):
    """Adversarial four-lane highway environment built on ``MultiAgentMetaDrive``.

    Four vehicles are spawned on a straight ("SSS") map.  ``agent1`` is the
    attacker and is driven by the action passed to :meth:`step`; ``agent0``,
    ``agent2`` and ``agent3`` are the victims and are driven by the ``expert``
    policy imported from ``metadrive.examples``.

    Action layout (``action_space`` has ``action_dim + 1`` entries):
    ``[exit_flag, a0, a1]``.  ``exit_flag > 0.5`` ends the episode; ``a0`` and
    ``a1`` are forwarded to MetaDrive after ``a1`` is shifted by ``+0.5`` and
    capped at ``2.0`` (presumably steering and throttle -- TODO confirm against
    MetaDrive's action convention).

    Observation layout (``state_dim == 21``): normalised time ``t / max_t``
    followed by five scaled kinematic features for each of the four vehicles.

    NOTE(review): on early-exit paths :meth:`step` returns a *float* in the
    fifth ("info") slot instead of a dict.  Callers appear to consume it as a
    numeric weight, so this is kept as-is for compatibility.
    """

    # Directory that receives the GIFs recorded when a victim crashes.
    DEMO_DIR = "../saved_demos"

    @classmethod
    def default_config(cls) -> dict:
        """Return the MetaDrive configuration used to build the scenario."""
        config = {
            "num_agents": 4,
            "map_config": {
                BaseMap.GENERATE_TYPE: MapGenerateMethod.BIG_BLOCK_SEQUENCE,
                BaseMap.GENERATE_CONFIG: "SSS",
                "lane_num": 4,    # Set number of lanes to 4
                "lane_width": 3,  # Optional: Set lane width (default is 3.5 meters)
            },
            "use_render": False,
            "prefer_track_agent": "agent0",

            "agent_configs": {
                "agent0": {
                    "spawn_lane_index": ('>', '>>', 0),
                    "spawn_longitude": 0,
                    "use_special_color": False
                },
                "agent1": {
                    "spawn_lane_index": ('>', '>>', 1),
                    "spawn_longitude": 25,
                    "use_special_color": True
                },
                "agent2": {
                    "spawn_lane_index": ('>', '>>', 2),
                    "spawn_longitude": 15,
                    "use_special_color": False
                },
                "agent3": {
                    "spawn_longitude": 10,
                    "spawn_lane_index": ('>', '>>', 3),
                    "use_special_color": False
                }
            },
            "horizon": 100000,  # Max steps per agent
            "log_level": logging.ERROR,
            "physics_world_step_size": 0.1
        }
        return config

    def __init__(self, render_mode=None):
        """Build the underlying MetaDrive env and declare the gym spaces.

        :param render_mode: when not ``None``, every step renders a top-down
            frame which is buffered in ``self.frames``.
        """
        self.t = 0
        self.action_dim = 2
        self.max_t = 200
        self.out_of_road_ticks = 0
        self.config = self.default_config()
        self.state_dim = 21
        self.expert = expert
        self.frames = []

        self.metaenv = MultiAgentMetaDrive(self.config)
        self.time_step = self.metaenv.config['physics_world_step_size']
        self.observation_space = gym.spaces.Box(
            shape=(self.state_dim,), low=-float("inf"), high=float("inf"), dtype=np.float32
        )
        # One extra entry in front of the driving action carries the exit flag.
        self.action_space = gym.spaces.Box(
            shape=(self.action_dim + 1,), low=-2, high=2, dtype=np.float32
        )
        self.render_mode = render_mode

        # Create the directory the crash GIFs are written to.  This must be
        # the same path that step() writes into.
        os.makedirs(self.DEMO_DIR, exist_ok=True)

    def reset(self, seed=None, options=None):
        """Reset the scenario and return ``(observation, info)``.

        :param seed: forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
        :param options: accepted for API compatibility; unused.
        """
        super().reset(seed=seed)

        self.metaenv.reset()
        if "agent4" in self.metaenv.agents:
            # NOTE(review): an unexpected fifth agent is present, so the env
            # is rebuilt.  The new env is not reset here and the old one is
            # not closed -- confirm this is what MetaDrive expects.
            self.metaenv = MultiAgentMetaDrive(self.config)

        self.frames = []
        self.t = 0
        self.out_of_road_ticks = 0

        agents = self.metaenv.agents
        self.controlled_vehicle = agents["agent1"]
        self.victim_vehicles = [agents["agent0"], agents["agent2"], agents["agent3"]]
        # Insertion order fixes the per-vehicle feature order in _get_obs().
        self.all_vehicles = {
            "agent0": agents["agent0"],
            "agent1": agents["agent1"],
            "agent2": agents["agent2"],
            "agent3": agents["agent3"],
        }

        return self._get_obs(), self._get_info()

    def _reward(self, action) -> float:
        """
        The reward is defined to encourage collision potential of victim vehicles.
        :param action: the last action performed
        :return: -1 if the attacker collided, 1 if a victim collided, otherwise
            0.1 while the attacker is on the road and 0.0 when it is not
        """
        rewards = self._rewards(action)
        if rewards["attacker_collision"]:
            reward = -1
        elif rewards["victim_collision"]:
            reward = 1
        else:
            reward = rewards["attacker_on_road"] * 0.1
        return reward

    def _rewards(self, action) -> Dict[Text, float]:
        """Collect the raw collision / on-road signals for the current state.

        :param action: unused; kept for signature compatibility.
        :return: dict with ``victim_collision``, ``attacker_collision`` and
            ``attacker_on_road`` entries.
        """
        collision_attacker = self.controlled_vehicle.crash_vehicle
        collision_victim = any(vehicle.crash_vehicle for vehicle in self.victim_vehicles)

        return {
            "victim_collision": collision_victim,
            "attacker_collision": collision_attacker,
            "attacker_on_road": float(not self.metaenv._is_out_of_road(self.controlled_vehicle))
        }

    def _get_obs(self):
        """Return the flat float32 state vector (time + 5 features per vehicle)."""
        res = [self.t / self.max_t]
        for vehicle in self.all_vehicles.values():
            res.append(vehicle.position[0] / 100)
            res.append(vehicle.position[1] / 10)
            res.append(vehicle.velocity[0] / 10)
            res.append(vehicle.velocity[1])
            # Finite-difference acceleration along axis 0, scaled by 1/10.
            res.append((vehicle.velocity[0] - vehicle.last_velocity[0]) / self.time_step / 10)

        return np.array(res, dtype=np.float32)

    def _get_info(self, frame=None, terminated=False):
        """Build the info dict returned by reset() and by a regular step()."""
        return {"frame": frame,
                "terminated": terminated
                }

    def _get_victim_obs(self):
        # NOTE(review): ``self.observation_type`` is never assigned in this
        # class, so calling this raises AttributeError unless a subclass or
        # caller sets it.  Left untouched; confirm whether it is still needed.
        return self.observation_type.observe()[0]

    def _is_terminated(self) -> bool:
        """Always ``False``; termination is decided inside :meth:`step`."""
        return False

    def _is_truncated(self) -> bool:
        """Always ``False``; truncation is decided inside :meth:`step`."""
        return False

    def get_state(self, obs):
        """Return ``obs`` unchanged (the observation is already the state)."""
        return obs

    def get_error(self, samples):
        """Return ``0`` regardless of ``samples``."""
        return 0

    def _time_weight(self, offset=1e-1):
        """Value placed in the fifth slot of step()'s early-exit returns.

        Grows exponentially with the elapsed step count ``self.t`` and reaches
        ``1000 + offset`` at ``t == max_t``.
        """
        return 1000 * np.exp(10 * self.t / self.max_t - 10) + offset

    def _save_crash_gif(self):
        """Write the buffered frames to ``DEMO_DIR`` as an animated GIF.

        Does nothing when no frames were recorded (rendering disabled), since
        there is nothing to save in that case.
        """
        if not self.frames:
            return
        os.makedirs(self.DEMO_DIR, exist_ok=True)
        imgs = [Image.fromarray(img) for img in self.frames]
        # The number of existing files disambiguates GIFs with the same t.
        num = len(os.listdir(self.DEMO_DIR))
        path = os.path.join(self.DEMO_DIR, f"victim_crash_{self.t}_{num}.gif")
        imgs[0].save(path, save_all=True, append_images=imgs[1:], duration=50, loop=0)

    def step(self, adv_action):
        """Advance the scenario by one step.

        :param adv_action: ``[exit_flag, a0, a1]``; see the class docstring.
            The caller's object is not modified.
        :return: ``(obs, reward, terminated, truncated, extra)``.  ``reward``
            is 20000 when a victim has crashed and 0 otherwise.  ``extra`` is
            a float weight on every early-exit path and the info dict from
            ``_get_info`` on a regular step.
        """
        if "agent4" in self.metaenv.agents:  # drop the trajectory
            return self._get_obs(), 0, True, True, 1e-1

        rewards = self._rewards([])
        exit_flag = adv_action[0]
        # Copy the driving part: slicing a NumPy array yields a view, and the
        # shift below must not write back into the caller's action buffer.
        adv_action = np.array(adv_action[1:])
        adv_action[1] = min(adv_action[1] + 0.5, 2.0)

        obs = self._get_obs()

        # Terminal checks are evaluated on the state *before* stepping.
        if self.t + 1 >= self.max_t:  # run out of time
            # Assign a small reward to encourage the case to not collide until the end
            return obs, 0, True, False, self._time_weight()
        if rewards['attacker_collision']:  # attacker crashed
            return obs, 0, True, False, self._time_weight()
        elif rewards['victim_collision']:  # victim crashed
            print("Victim crashed")
            self._save_crash_gif()
            return obs, 20000, True, False, self._time_weight()
        elif not rewards['attacker_on_road']:  # running off the road
            return obs, 0, True, False, self._time_weight()
        elif exit_flag > 0.5:  # policy asked to stop
            return obs, 0, True, False, self._time_weight()

        # Attacker uses the supplied action; every other vehicle that is still
        # active uses the expert policy.
        active_agents = self.metaenv.agents
        actions = {}
        for agent in self.all_vehicles:
            if agent not in active_agents:
                continue
            if agent == "agent1":
                actions[agent] = adv_action
            else:
                actions[agent] = self.expert(active_agents[agent], deterministic=True)

        _, _, terminated, _, _ = self.metaenv.step(actions)

        if any(terminated.values()):
            # Returns the pre-step observation and does not advance self.t.
            return obs, 0, False, False, self._time_weight(offset=1.0)

        current_position = self.controlled_vehicle.position

        # Must be defined even when rendering is off: it is read in the
        # return statement below.
        frame = None
        if self.render_mode is not None:
            self.metaenv.render_mode = "topdown"

            frame = self.metaenv.render(mode="topdown",
                                        scaling=4,  # 4 pixels per meter
                                        camera_position=current_position,
                                        window=True,
                                        screen_size=(500, 500),
                                        )
            self.metaenv.engine.top_down_renderer.position = current_position
            self.frames.append(frame)

        self.t += 1

        return self._get_obs(), 0, False, False, self._get_info(
            frame=frame, terminated=any(terminated.values())
        )