from typing import Dict, Text
import numpy as np
from highway_env import utils
from highway_env.envs.common.abstract import AbstractEnv
from highway_env.envs.common.action import Action
from highway_env.road.road import Road, RoadNetwork
from highway_env.vehicle.kinematics import Vehicle
from stable_baselines3 import TD3, PPO
import gymnasium as gym


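# Three-lane highway environment for adversarial testing:
# - the victim vehicle is driven by a pre-trained RL policy (PPO),
# - the attacker vehicle is driven by the action passed to `step`,
# - all other vehicles are controlled by IDM.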
class Environment(AbstractEnv):
    
    @classmethod
    def default_config(cls) -> dict:
        """Return the parent's default configuration with this scenario's overrides.

        The overrides select multi-agent Kinematics observations and
        multi-agent continuous actions, a three-lane road with IDM background
        traffic, and 10 Hz simulation and policy frequencies.
        """
        kinematics_observation = {
            "type": "Kinematics",
            "vehicles_count": 7,
            "absolute": False,
            "order": "sorted",
        }
        continuous_action = {"type": "ContinuousAction"}

        overrides = {
            "observation": {
                "type": "MultiAgentObservation",
                "observation_config": kinematics_observation,
            },
            "action": {
                "type": "MultiAgentAction",
                "action_config": continuous_action,
            },
            "lanes_count": 3,
            "vehicles_count": 5,
            "initial_spacing": 1,
            "ego_spacing": 1,
            "initial_lane_id": None,
            "controlled_vehicles": 1,
            "vehicles_density": 2,
            "simulation_frequency": 10,  # [Hz]
            "policy_frequency": 10,  # [Hz]
            "other_vehicles_type": "highway_env.vehicle.behavior.IDMVehicle",
            "screen_width": 600,  # [px]
            "screen_height": 150,  # [px]
            "centering_position": [0.2, 0.5],
            "right_lane_reward": 0,
            "scaling": 4,
            "offroad_terminal": True,
        }

        config = super().default_config()
        config.update(overrides)
        return config
    
    def __init__(self, render_mode=None, victim_model_path="../data/highway_ppo/model"):
        """
        Build the environment and load the victim's pre-trained policy.

        :param render_mode: stored on the instance; `step` calls `render()`
            after every simulated step whenever it is not None
        :param victim_model_path: path given to `PPO.load` for the victim
            policy; the default is relative to the current working directory
        """
        self.t = 0
        self.action_dim = 2
        self.max_t = 120
        self.out_of_road_ticks = 0
        # Same initial value as in `reset`, so the attribute always exists.
        self.min_ttc = 100

        config = self.default_config()
        # `_get_obs` emits one time feature plus 5 features per vehicle on the
        # road, and `_create_vehicles` spawns the victim, `vehicles_count`
        # background vehicles and the attacker.
        road_vehicle_count = config["vehicles_count"] + 2
        self.state_dim = 5 * road_vehicle_count + 1
        # Episode length in seconds: `max_t` policy steps.
        config["duration"] = self.max_t / config["policy_frequency"]
        super().__init__(config)
        self.rl_model = PPO.load(victim_model_path)

        # Expose flat attacker-centric spaces in place of whatever the parent
        # constructor installed (presumably the multi-agent spaces - confirm).
        self.observation_space = gym.spaces.Box(
            shape=(self.state_dim,),
            low=-float("inf"),
            high=float("inf"),
            dtype=np.float32,
        )
        self.action_space = gym.spaces.Box(
            shape=(self.action_dim,), low=-1, high=1, dtype=np.float32
        )
        self.render_mode = render_mode

    def reset(self, seed=None, options=None):
        """
        Start a new episode and return the initial observation and info.

        :param seed: forwarded to the parent's `reset`
        :param options: accepted for API compatibility
        :return: tuple ``(observation, info)`` from `_get_obs` and `_get_info`
        """
        # NOTE(review): `options` is not forwarded to the parent - confirm
        # that this is intended.
        super().reset(seed=seed)

        # Per-episode bookkeeping.
        self.t = 0
        self.out_of_road_ticks = 0
        self.min_ttc = 100

        # Re-install the flat spaces (presumably the parent reset rebuilds its
        # own spaces - confirm).
        unbounded = float("inf")
        self.observation_space = gym.spaces.Box(
            low=-unbounded,
            high=unbounded,
            shape=(self.state_dim,),
            dtype=np.float32,
        )
        self.action_space = gym.spaces.Box(
            low=-1, high=1, shape=(self.action_dim,), dtype=np.float32
        )

        return self._get_obs(), self._get_info()
    
    def _reset(self):
        self._create_road()
        self._create_vehicles()

    def _create_road(self) -> None:
        """Build a straight road of adjacent lanes and store it in `self.road`.

        The lane count comes from the ``lanes_count`` config entry and the
        speed limit passed to the network is 30.
        """
        lane_count = self.config["lanes_count"]
        network = RoadNetwork.straight_road_network(lane_count, speed_limit=30)
        self.road = Road(
            network=network,
            np_random=self.np_random,
            record_history=self.config["show_trajectories"],
        )

    def _create_vehicles(self) -> None:
        """Populate the road with the victim, background traffic and the attacker.

        Order matters: the victim is appended first and the attacker last, both
        to `self.road.vehicles` and to `self.controlled_vehicles`. Other
        methods index these lists positionally (`calcualte_min_ttc` treats
        ``road.vehicles[0]`` as the victim, `_rewards` treats
        ``controlled_vehicles[1]`` as the attacker).
        """

        def spawn_controlled() -> None:
            # Sample a random placement, then rebuild the vehicle with the
            # class required by the configured action type.
            template = Vehicle.create_random(
                self.road,
                lane_id=self.config["initial_lane_id"],
                spacing=self.config["ego_spacing"],
            )
            controlled = self.action_type.vehicle_class(
                self.road, template.position, template.heading, template.speed
            )
            self.controlled_vehicles.append(controlled)
            self.road.vehicles.append(controlled)

        background_cls = utils.class_from_path(self.config["other_vehicles_type"])
        background_count = self.config["vehicles_count"]

        self.controlled_vehicles = []

        # Victim first.
        spawn_controlled()

        # Background traffic with randomized behavior parameters.
        background_spacing = 1 / self.config["vehicles_density"]
        for _ in range(background_count):
            background = background_cls.create_random(
                self.road, spacing=background_spacing
            )
            background.randomize_behavior()
            self.road.vehicles.append(background)

        # Attacker last.
        spawn_controlled()

    def _reward(self, action) -> float:
        """
        The reward is defined to encourage collision potential of victim vehicles.
        :param action: the last action performed
        :return: the corresponding reward
        """
        rewards = self._rewards(action)
        if rewards["attacker_collision"]:
            reward = -1
        elif rewards["victim_collision"]:
            reward = 1
        else:
            reward = rewards["attacker_on_road"] * 0.1
        return reward

    def _rewards(self, action) -> Dict[Text, float]:
        collision_attacker= self.controlled_vehicles[1].crashed
        collision_victim = max([vehicle.crashed for vehicle in self.road.vehicles])
        
        return {
            "victim_collision": collision_victim,
            "attacker_collision": collision_attacker,
            "attacker_on_road": float(self.controlled_vehicles[1].on_road)
        }
    
    def _get_obs(self): # return the state of the environment
        res = [self.t/self.max_t]
        # add acc of the vehicles
        for vehicle in self.road.vehicles:
            res.append(vehicle.position[0]/100)
            res.append(vehicle.position[1])
            res.append(vehicle.velocity[0]/10)
            res.append(vehicle.velocity[1])
            res.append(vehicle.action['acceleration'])

        return np.array(res, dtype = np.float32)
    
    def _get_info(self):
        return {}
    
    def _get_victim_obs(self):
        return self.observation_type.observe()[0]
    
    def _is_terminated(self) -> bool:
        return False

    def _is_truncated(self) -> bool:
        return False
    
    def calculate_ttc(self, vehicle1, vehicle2):
        if vehicle1.velocity[0] - vehicle2.velocity[0] == 0:
            return 100
        res =  (vehicle2.position[0] - vehicle1.position[0]) / (vehicle1.velocity[0] - vehicle2.velocity[0])
        if res < 0:
            return 100
        return res
    
    def calcualte_min_ttc(self):
        min_ttc = 1000
        victim = self.road.vehicles[0]
        for vehicle in self.road.vehicles[1:-1]:
            if (victim.position[1] < vehicle.position[1]) < 2:
                current_ttc = self.calculate_ttc(victim, vehicle)
                if current_ttc < min_ttc:
                    min_ttc = current_ttc
        return max(min_ttc, 0.1)
        
    def step(self, adv_action):
        rewards = self._rewards([])

        # check if at the ending state
        if self.t >= self.max_t: # run out of time
            return self._get_obs(), 1, True, False, self._get_info() # Assign a small reward to encourage the case to not collision until the end
        if rewards['attacker_collision']: # attacker crashed
            return self._get_obs(), 0, True, False, self._get_info()
        elif rewards['victim_collision']: # victim crashed
            return self._get_obs(), 10, True, False, self._get_info()
        elif not rewards['attacker_on_road']: # running off the road
            return self._get_obs(), self.t/self.max_t * 0.1, True, False, self._get_info()

        victim_action = self.rl_model.predict(self._get_victim_obs(), deterministic=True)[0]

        action = tuple([victim_action, np.array(adv_action)])

        super().step(action)

        if self.render_mode != None:
            self.render()

        self.t += 1

        self.min_ttc = min(self.min_ttc, self.calcualte_min_ttc()) 
        
        return  self._get_obs(), 0.1/self.min_ttc, False, False, self._get_info() # encourage RL to get small ttc