from matplotlib.patches import Rectangle
import numpy as np
import scipy
import gymnasium as gym


class CarFollowingModel:
    """Stochastic, delayed IDM driver (MA-IDM) for one following vehicle.

    Each call to ``update`` does four things:

    * computes an Intelligent Driver Model (IDM) acceleration from a noisy
      perception of the leading vehicle;
    * clips that acceleration;
    * adds one sample of a Gaussian-process (GP) noise sequence drawn at
      construction time;
    * releases the result after a fixed reaction delay.
    """

    def __init__(self, max_t, rng=None):
        """Initialize the model with the pooled MA-IDM parameters.

        Args:
            max_t: length of the GP noise sequence, i.e. the maximum number of
                ``update`` calls before ``a_GP_sim`` has to be re-sampled.
            rng: optional random source providing ``normal`` and
                ``multivariate_normal`` (e.g. ``np.random.default_rng(seed)``)
                for reproducible runs. ``None`` (default) uses NumPy's global
                ``np.random`` module.
        """
        # Hyperparameters of the pooled MA-IDM model in
        # https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10415310
        v0 = 12.244
        s0 = 2.833
        T = 1.133
        alpha = 0.362
        beta = 2.140
        sigma_v = 0.005
        sigma_x = 0.026
        sigma_k = 0.390
        delay = 1.607
        delta = 4

        self.rng = np.random if rng is None else rng

        # Reaction-delay buffer: int(delay * 10) update steps of 0.1 s each.
        self.a_future = [0] * int(delay * 10)
        self.max_a = beta  # maximum acceleration
        self.min_a = -3  # lower clip bound of the IDM acceleration, a very important parameter
        self.comfort_b = alpha  # comfortable deceleration
        self.target_v = v0  # desired speed, unit: 1 m/s
        self.min_s = s0  # minimum space gap, unit: 1 m
        self.min_T = T  # minimum time gap, unit: 1 s
        self.delay = delay  # reaction delay, unit: 1 s (buffered as int(delay * 10) steps of 0.1 s)
        self.delta = delta  # IDM acceleration exponent

        # Standard deviations of the perception noise on the leader's speed / position.
        self.sigma_v = sigma_v
        self.sigma_x = sigma_x

        # GP acceleration noise: covariance over max_t evenly spaced inputs on [0, 100].
        # NOTE(review): the reaction delay doubles as the kernel length scale -- confirm intended.
        X = np.expand_dims(np.linspace(0, 100, max_t), 1)
        self.cov_k = self.exponentiated_quadratic(X, X, sigma_k, delay)
        self.a_GP_sim = self.rng.multivariate_normal(np.zeros(max_t), self.cov_k, size=1)[0].tolist()

    def update(self, x1, v1, a1, l1, x2, v2):
        """Advance the follower one step and return its delayed acceleration.

        Vehicle 1 (position ``x1``, speed ``v1``, length ``l1``) follows
        vehicle 2 (position ``x2``, speed ``v2``). The method runs these steps:

        1. Add Gaussian noise to the follower's perception of the leader's
           speed and position.
        2. Compute the IDM acceleration from that perception and clip it to
           ``[min_a, max_a]``.
        3. Add the next GP noise sample.
        4. Push the result into the delay buffer and return the oldest
           buffered value.

        ``a1`` is accepted for interface compatibility but is not used.

        Raises:
            IndexError: if called more than ``max_t`` times since
                ``a_GP_sim`` was last sampled.
        """
        # Perceived (noisy) state of the leading vehicle; only the local copies change.
        v2 += self.rng.normal(0, self.sigma_v)
        x2 += self.rng.normal(0, self.sigma_x)

        dv = v1 - v2  # approaching rate: positive when closing in on the leader
        # Net gap; assumes x is the rear of a vehicle so the follower's front is x1 + l1.
        s = x2 - x1 - l1
        # IDM desired gap: s* = s0 + max(0, v*T + v*dv / (2*sqrt(a*b))).
        desired_gap = self.min_s + max(v1 * self.min_T + dv * v1 / (2 * np.sqrt(self.comfort_b * self.max_a)), 0)
        tmp_a = self.max_a * (1 - (v1 / self.target_v) ** self.delta - (desired_gap / s) ** 2)
        # GP noise is added after clipping, so the result may leave [min_a, max_a].
        tmp_a = np.clip(tmp_a, self.min_a, self.max_a) + self.a_GP_sim.pop(0)
        self.a_future.append(tmp_a)
        return self.a_future.pop(0)

    def exponentiated_quadratic(self, xa, xb, sigma_k, length_scale):
        """Return the exponentiated-quadratic (RBF) kernel matrix between xa and xb.

        k(a, b) = sigma_k**2 * exp(-||a - b||**2 / (2 * length_scale**2))

        Args:
            xa: 2-D array of shape (n, d), as required by ``cdist``.
            xb: 2-D array of shape (m, d).
            sigma_k: kernel standard deviation (the diagonal equals sigma_k**2).
            length_scale: kernel length scale.

        Returns:
            (n, m) covariance matrix.
        """
        # Squared Euclidean distances, scaled by -1 / (2 * length_scale**2).
        sq_norm = -0.5 * scipy.spatial.distance.cdist(xa, xb, 'sqeuclidean') / length_scale**2
        return sigma_k**2 * np.exp(sq_norm)

# A simple car-following environment with three vehicles in one lane. The front vehicle (index 2) is the attacker; it wants the rear vehicle (index 0) to crash into the middle vehicle (index 1).
class Environment(gym.Env):
    """Three-vehicle, single-lane car-following environment.

    Vehicles are indexed from rear to front: 0 (rear victim), 1 (middle
    victim) and 2 (attacker). The agent controls only the attacker.
    Vehicles 0 and 1 each follow the vehicle ahead of them with a
    ``CarFollowingModel``; the attacker's acceleration is the agent's action.
    The episode pays reward 10 when vehicle 0 runs into vehicle 1.

    Randomness:

    * Initial conditions and the per-episode GP noise come from
      ``self.np_random``, which ``reset(seed=...)`` seeds.
    * The car-following models draw their observation noise internally from
      NumPy's global RNG. ``reset`` therefore also reseeds that RNG whenever
      an explicit seed is given.
    """

    metadata = {"render_modes": ["matplotlib"]}

    def __init__(self, render_mode=None, ax=None):
        """Create the environment.

        Args:
            render_mode: ``None`` or ``"matplotlib"``.
            ax: matplotlib axes to draw into; required when ``render_mode`` is
                ``"matplotlib"``.
        """
        self.l0 = self.l1 = self.l2 = 5  # vehicle length
        # Placeholder initial state; reset() overwrites it.
        self.x0 = 0
        self.x1 = self.x0 + self.l0 + 2.5 + self.np_random.random() * 10
        self.x2 = self.x0 + self.l0 + 2.5 + 10 + self.np_random.random() * 10
        self.a0 = self.a1 = self.a2 = 0  # accelerations of vehicles 0, 1 and 2
        self.v2 = self.np_random.random() * 10  # uniform distribution between 0 and 10 m/s
        self.v1 = 0
        self.v0 = 0

        self.t = 0  # step counter (0.1 s per step)
        self.l_max = 300  # maximum length of the road, 300 m
        self.max_t = 120  # maximum number of steps per episode
        self.min_dist = self.l_max  # smallest gap observed between vehicles 0 and 1

        self.model_0 = CarFollowingModel(self.max_t)  # drives vehicle 0 (follows vehicle 1)
        self.model_1 = CarFollowingModel(self.max_t)  # drives vehicle 1 (follows the attacker)
        # t, (x, v, a) for each of the 3 vehicles, and both models' delay buffers.
        self.state_dim = 1 + 3 * 3 + len(self.model_0.a_future) + len(self.model_1.a_future)
        self.action_dim = 1
        self.max_a2 = 1  # scale / clip bound applied to the attacker's normalized action

        self.observation_space = gym.spaces.Box(shape=(self.state_dim,), low=-float("inf"), high=float("inf"), dtype=np.float32)
        self.action_space = gym.spaces.Box(shape=(self.action_dim,), low=-1, high=1, dtype=np.float32)

        assert render_mode is None or (render_mode in self.metadata["render_modes"] and ax is not None)
        self.render_mode = render_mode
        self.ax = ax

    def reset(self, seed=None, options=None):
        """Start a new episode with random initial conditions.

        Args:
            seed: optional seed for ``self.np_random``. When given, NumPy's
                global RNG is reseeded from ``self.np_random`` as well, because
                the car-following models draw their observation noise from it.
            options: unused.

        Returns:
            ``(observation, info)``.
        """
        super().reset(seed=seed)
        if seed is not None:
            np.random.seed(int(self.np_random.integers(2**32)))
        rng = self.np_random

        # Vehicles placed rear (0) to front (2) with net gaps in [2, 22) m.
        self.x0 = 0
        self.x1 = self.x0 + self.l0 + 2 + rng.random() * 20
        self.x2 = self.x1 + self.l1 + 2 + rng.random() * 20
        self.a0 = self.a1 = self.a2 = 0
        self.v2 = rng.random() * 10  # uniform distribution between 0 and 10 m/s
        self.v1 = min(rng.random() * 10, self.v2)  # victim cannot be faster than the attacker in the beginning
        self.v0 = min(rng.random() * 10, self.v1)

        # Empty the reaction-delay buffers and draw fresh GP noise for this episode.
        self.model_0.a_future = [0] * int(self.model_0.delay * 10)
        self.model_1.a_future = [0] * int(self.model_1.delay * 10)
        self.model_0.a_GP_sim = rng.multivariate_normal(np.zeros(self.max_t), self.model_0.cov_k).tolist()
        self.model_1.a_GP_sim = rng.multivariate_normal(np.zeros(self.max_t), self.model_1.cov_k).tolist()

        self.t = 0
        self.min_dist = self.l_max

        return self._get_obs(), self._get_info()

    def step(self, a2):
        """Advance the simulation by one 0.1 s step.

        Termination is checked *before* moving. The terminal transition
        therefore returns the state in which the crash or time-out was
        detected.

        Args:
            a2: normalized attacker action, nominally in [-1, 1]. It may be a
                Python/NumPy scalar, a 0-d array, or a 1-element array/list.
                It is scaled by ``max_a2``, clipped to ``[-max_a2, max_a2]``
                and multiplied by 3 to give the attacker's acceleration.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always False. The reward is:

            * 10 when vehicle 0 has hit vehicle 1;
            * ``max(t - max_t / 10, 0) / max_t`` when the attacker has reached
              the end of the road, vehicle 1 has hit the attacker, or the step
              limit is reached;
            * 0 otherwise.
        """
        # Non-victim termination: road end, vehicle 1 hit the attacker, or time out.
        if (self.x2 + self.l2 >= self.l_max) or (self.x1 + self.l1 >= self.x2) or (self.t >= self.max_t):
            return self._get_obs(), max(self.t - self.max_t / 10, 0) / self.max_t, True, False, self._get_info()

        if self.x0 + self.l0 >= self.x1:
            return self._get_obs(), 10, True, False, self._get_info()  # victim crashed

        # The victims' decisions are known by the attacker; this assumption is what lets the GFN apply theoretically.
        self.a1 = self.model_1.update(self.x1, self.v1, self.a1, self.l1, self.x2, self.v2)
        self.a0 = self.model_0.update(self.x0, self.v0, self.a0, self.l0, self.x1, self.v1)

        # Normalize every accepted action form to one float so the state stays scalar.
        a2 = float(np.asarray(a2, dtype=np.float64).reshape(-1)[0]) * self.max_a2
        self.a2 = np.clip(a2, -self.max_a2, self.max_a2) * 3

        # No backward movement.
        v0 = max(self.v0 + self.a0 * 0.1, 0)
        v1 = max(self.v1 + self.a1 * 0.1, 0)
        v2 = max(self.v2 + self.a2 * 0.1, 0)

        # Trapezoidal position update over a 0.1 s step.
        self.x0 += (self.v0 + v0) * 0.05
        self.x1 += (self.v1 + v1) * 0.05
        self.x2 += (self.v2 + v2) * 0.05

        # Effective accelerations after the no-reverse clamp.
        self.a0 = (v0 - self.v0) * 10
        self.a1 = (v1 - self.v1) * 10
        self.a2 = (v2 - self.v2) * 10

        self.v0 = v0
        self.v1 = v1
        self.v2 = v2

        self.t += 1

        self.min_dist = min(self.min_dist, self.x1 - self.x0 - self.l0)

        return self._get_obs(), 0, False, False, self._get_info()

    def _get_obs(self):
        """Return the float32 observation vector.

        Layout: ``[t/max_t, x0/l_max, v0, a0, x1/l_max, v1, a1, x2/l_max, v2, a2]``
        followed by ``model_0.a_future`` and ``model_1.a_future``.
        """
        return np.array([self.t / self.max_t, self.x0 / self.l_max, self.v0, self.a0,
                         self.x1 / self.l_max, self.v1, self.a1,
                         self.x2 / self.l_max, self.v2, self.a2] + self.model_0.a_future + self.model_1.a_future,
                        dtype=np.float32)

    def _get_info(self):
        """Return the (empty) info dict."""
        return {}

    def _render_frame(self):
        """Draw each vehicle on ``self.ax`` as a (length x 3) rectangle at y = 1.

        Returns:
            List of the three added patches (vehicles 0, 1, 2).
        """
        # NOTE(review): rectangles are truncated at x = 250 although l_max is 300;
        # presumably the plot's x-limit -- confirm.
        im0 = self.ax.add_patch(Rectangle((self.x0, 1), min(250 - self.x0, self.l0), 3))  # victim vehicle 0 (rear)
        im1 = self.ax.add_patch(Rectangle((self.x1, 1), min(250 - self.x1, self.l1), 3))  # victim vehicle 1 (middle)
        im2 = self.ax.add_patch(Rectangle((self.x2, 1), min(250 - self.x2, self.l2), 3))  # attacker
        return [im0, im1, im2]

    def render(self):
        """Draw the current frame when ``render_mode`` is ``"matplotlib"``; otherwise return None."""
        if self.render_mode == "matplotlib":
            return self._render_frame()