import json
import logging
import math
import os
import random
import subprocess
import time
from cgitb import reset
from operator import is_
from pdb import Restart
from re import sub
from tracemalloc import start
from turtle import pos, up

import carla
import gymnasium as gym
import numpy as np

from gym_envs.utils import *

logging.basicConfig(level=logging.INFO)

# Shell command that kills every process whose command line matches "bridge.py"
# inside the openpilot_client container (i.e. stops the openpilot CARLA bridge).
SHUTDOWN_OP = (
    "docker exec openpilot_client "
    "/bin/bash -c 'pkill -f bridge.py'"
)

def make_launch_op(location, town):
    """Build the shell command that launches openpilot's CARLA bridge.

    The command runs ``./bridge.py`` inside the ``openpilot_client`` container
    (``/home/batman/openpilot/tools/sim``) with ``--auto-engage``, discards its
    stdout, and is backgrounded by the trailing ``&``.

    Args:
        location: spawn-location string (e.g. "x,y,z,pitch,yaw,roll") passed to
            ``--spawn_location``.
        town: CARLA map name passed to ``--town``.

    Returns:
        The command string, suitable for ``os.system``.
    """
    # NOTE(review): the location is wrapped as \'...\' inside the double-quoted
    # bash -c argument; by POSIX quoting rules bridge.py appears to receive the
    # single quotes literally — confirm bridge.py strips them.
    bridge_cmd = (
        "cd /home/batman/openpilot/tools/sim && ./bridge.py"
        f" --town {town}"
        f" --spawn_location  \\'{location}\\'"
        " --auto-engage >/dev/null "
    )
    return 'docker exec  openpilot_client /bin/bash -c "' + bridge_cmd + '"&'

class Environment(gym.Env):
    """Gymnasium environment for attacking an openpilot-driven vehicle in CARLA.

    An RL-controlled "attacker" vehicle (spawned from a scene file) drives near a
    "victim" vehicle (the first ``vehicle.tesla.model3`` actor found after
    openpilot's bridge is relaunched). An episode ends with a positive reward when
    the victim collides (10), exceeds 35 (5), or gets within a TTC of 1 to the
    attacker (3 / 2). Otherwise the step reward is ``0.1 / min_ttc``.

    Observation (float32, length ``state_dim``):
    - ``[t / max_t]``;
    - the attacker's absolute ``[x, y, vx, vy, ax, ay]``;
    - the same six quantities for the victim and for every NPC, each relative
      to the attacker.

    ``state_dim = 19 = 1 + 6 + 6 + 6`` therefore matches scenes with exactly one
    NPC. NOTE(review): confirm this for other scene files.

    Action: two values clipped to ``[-1, 1]``. ``a[0] > 0`` is throttle,
    ``a[0] <= 0`` is a brake of magnitude ``-a[0]``, and ``a[1]`` is steering.

    ``connect_to_carla``, ``start_carla`` and ``restart_opdocker`` are presumably
    provided by ``gym_envs.utils`` (star import).
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self, render_mode=None):
        self.state_dim = 19
        self.action_dim = 2
        self.max_a2 = [1, 1]
        self.min_a2 = [-1, -1]
        self.max_t = 120  # maximum number of step() calls per episode

        # Connect to CARLA; if that fails, start CARLA and retry until connected.
        try:
            self.client, self.world = connect_to_carla()
        except Exception as e:
            print(f"An error occurred: {e}")
            print("Restarting CARLA...")
            start_carla()
            self._connect_until_success(restart_carla=False)

        self.sensors = []
        # (Re)start openpilot's manager inside the container.
        kill_manager_process()
        start_op_process()
        # Attacker constraint flags (attribute name kept as-is for existing readers).
        self.constriants = self._fresh_constraints()
        time.sleep(5)

        self.observation_space = gym.spaces.Box(shape=(self.state_dim,), low=-float("inf"), high=float("inf"), dtype=np.float32)
        self.action_space = gym.spaces.Box(shape=(self.action_dim,), low=-1, high=1, dtype=np.float32)
        assert render_mode is None or (render_mode in self.metadata["render_modes"])
        self.render_mode = render_mode

    @staticmethod
    def _fresh_constraints():
        """Return a new constraint-flag dict with every flag cleared."""
        return {
            "collision_attacker": False,
            "safe_distance": False,
            "speeding": False,
        }

    def _connect_until_success(self, restart_carla):
        """Retry ``connect_to_carla`` every 5 s until it succeeds.

        Args:
            restart_carla: if true, call ``start_carla`` before every attempt.
        """
        while True:
            try:
                if restart_carla:
                    start_carla()
                self.client, self.world = connect_to_carla()
                return
            except Exception as e:
                print(f"Connection failed: {e}. Retrying in 5 seconds...")
                time.sleep(5)

    def victim_collision_callback(self):
        """Collision-sensor callback for the victim: set ``collision_victim``."""
        self.collision_victim = True

    def attacker_collision_callback(self):
        """Collision-sensor callback for the attacker: set its collision flag and constraint."""
        self.collision_attacker = True
        self.constriants["collision_attacker"] = True

    def reset(self, seed=None, options=None):
        """Start a new episode.

        The sequence is:
        1. Load the scene file ``../data/setups/C2.json`` (relative to the
           working directory).
        2. Make sure openpilot is running.
        3. Reload the map and relaunch the bridge.
        4. Spawn the attacker and NPCs, attach collision sensors, and wait
           briefly.
        5. Give every vehicle its initial speed along its heading.

        Returns:
            ``(observation, info)`` as required by the Gymnasium API.
        """
        super().reset(seed=seed)

        attacker, victim_init, npc_list, target_obj, weather, town_map, time_step, calc_rob = load_scene_config(
            "../data/setups/C2.json",
            weather_randomization=True,
            speed_randomization=True,
            location_randomization=False,
        )
        # NOTE(review): `weather` is drawn but never applied to the world (no
        # set_weather call); target_obj, time_step and calc_rob are unused here.

        if not is_ui_process_running():
            kill_manager_process()
            start_op_process()

        # restart_op reloads the map (client.load_world), so actors from the previous
        # episode are not destroyed explicitly here. On failure, restart the openpilot
        # container and CARLA, reconnect, and let the loop retry restart_op.
        while True:
            try:
                victim = self.restart_op(victim_init, town_map)
                break
            except Exception as e:
                print(f"An error occurred: {e}")
                print("Restarting CARLA...")
                print("Restarting OP Docker...")
                restart_opdocker()
                self._connect_until_success(restart_carla=True)

        self.constriants = self._fresh_constraints()
        self._reset_vehicle(attacker, victim, npc_list)  # spawn attacker and NPCs for the new episode
        self.min_ttc = 100
        self.t = 0

        # Place the spectator camera behind and above the victim.
        spectator = self.world.get_spectator()
        transform = victim.get_transform()
        transform.location.x -= 10
        transform.location.z += 10
        transform.rotation.pitch -= 15
        spectator.set_transform(transform)

        # Collision sensors. Flags are cleared *before* listen() so that an event
        # arriving immediately is not overwritten.
        blueprint_library = self.world.get_blueprint_library()
        self.collision_victim = False
        self.collision_sensor = self.world.spawn_actor(
            blueprint_library.find('sensor.other.collision'), carla.Transform(), attach_to=self.victim)
        self.collision_sensor.listen(lambda event: self.victim_collision_callback())

        self.collision_attacker = False
        self.collision_sensor2 = self.world.spawn_actor(
            blueprint_library.find('sensor.other.collision'), carla.Transform(), attach_to=self.attacker)
        self.collision_sensor2.listen(lambda event: self.attacker_collision_callback())

        # Wait for openpilot to start (detected via NPC motion, bounded by a timeout).
        self._wait_for_npc_motion(timeout=10)

        self.attacker.set_target_velocity(self.parse_speed_string(self.attacker, attacker.speed))
        self.victim.set_target_velocity(self.parse_speed_string(self.victim, victim_init.speed))
        for npc_agent, npc in zip(self.npcs, npc_list):
            # try_spawn_actor yields None for NPCs that failed to spawn; skip them.
            if npc_agent is not None and npc.role != 'stopped_vehicle':
                npc_agent.set_target_velocity(self.parse_speed_string(npc_agent, npc.speed))

        return self._get_obs(), self._get_info()

    def _wait_for_npc_motion(self, timeout=10):
        """Poll once per second until an NPC vehicle moves or *timeout* seconds pass.

        NPC vehicles are all actors whose type id contains 'vehicle', other than the
        attacker and the victim. The timeout is checked independently of the actor
        scan, so the wait also ends when no NPC vehicle exists.
        """
        start_time = time.time()
        while True:
            time.sleep(1)
            if time.time() - start_time > timeout:
                return
            for actor in self.world.get_actors():
                if 'vehicle' not in actor.type_id or actor.id in (self.attacker.id, self.victim.id):
                    continue
                velocity = actor.get_velocity()
                if abs(velocity.x) > 1e-4 or abs(velocity.y) > 1e-4:
                    return

    def restart_op(self, victim, map):
        """Reload *map* and relaunch openpilot's bridge at the victim's spawn pose.

        Args:
            victim: scene vehicle whose ``pos`` is a Python expression evaluating to
                a transform with ``location`` and ``rotation``.
            map: CARLA map name.

        Returns:
            The first ``vehicle.tesla.model3`` actor in the world after the relaunch.
            Raises IndexError if none exists yet; ``reset`` catches this and retries.
        """
        os.system(SHUTDOWN_OP)
        self.world = self.client.load_world(map)
        time.sleep(5)
        # NOTE(review): eval() executes the scene-file string as Python; only use
        # trusted scene files.
        victim_pos = eval(victim.pos)
        loc, rot = victim_pos.location, victim_pos.rotation
        # x and z are rounded to 2 decimals; y and the rotation are passed unrounded.
        pos_set = ",".join(
            str(value) for value in (round(loc.x, 2), loc.y, round(loc.z, 2), rot.pitch, rot.yaw, rot.roll)
        )
        os.system(make_launch_op(pos_set, map))
        time.sleep(2)
        self.world = self.client.get_world()
        return self.world.get_actors().filter('vehicle.tesla.model3')[0]

    def _reset_vehicle(self, attacker, victim, npc_list):
        """Resolve blueprints and spawn the attacker and NPCs for a new episode.

        Each ``model`` attribute on *attacker* and on the *npc_list* entries is
        replaced by the matching blueprint. The input is a blueprint id, or None
        for ``vehicle.audi.tt``. The method then:
        - records *victim* as ``self.victim``;
        - spawns the attacker with ``spawn_actor``;
        - spawns each NPC with ``try_spawn_actor``, so ``self.npcs`` holds None
          for an NPC that failed to spawn.
        """
        self.world = self.client.get_world()
        blueprint_library = self.world.get_blueprint_library()
        default_model = 'vehicle.audi.tt'
        for spec in (attacker, *npc_list):
            spec.model = blueprint_library.find(default_model if spec.model is None else spec.model)

        self.victim = victim
        victim_location = self.victim.get_transform().location
        print(f"Victim: {victim_location.x}, {victim_location.y}")
        print(f"Attacker: {attacker.pos}")
        # NOTE(review): eval() on scene-file strings; trusted input only.
        self.attacker = self.world.spawn_actor(attacker.model, eval(attacker.pos))
        self.npcs = [self.world.try_spawn_actor(npc.model, eval(npc.pos)) for npc in npc_list]

    def step(self, a2):
        """Advance one step: check termination, apply the attacker action, and score.

        Args:
            a2: attacker action ``[throttle_or_brake, steer]``, clipped to ``[-1, 1]``.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``truncated`` is True only
            when the victim misbehaved while openpilot's UI process was not running.
            That sample should be dropped, and openpilot is restarted.
        """
        # The attacker broke its own constraints or drifted > 30 away: end with no reward.
        if self.check_constraints() or self.attacker.get_location().distance(self.victim.get_location()) > 30:
            return self._get_obs(), 0, True, False, self._get_info()

        # Victim misbehaviour, checked in priority order, each with its reward.
        if self.collision_victim:
            return self._end_on_victim_violation(10)
        if self.over_speed(self.victim):
            return self._end_on_victim_violation(5)
        # TODO: decide which vehicle pairs these distance checks should use; for now
        # they assume the bad behaviour is between the victim and the attacker.
        if self.lane_change_bad_distance(self.victim, self.attacker, 1):
            return self._end_on_victim_violation(3)
        if self.follow_bad_distance(self.victim, self.attacker, 1):
            return self._end_on_victim_violation(2)
        if self.wrong_direction(self.victim):
            return self._end_on_victim_violation(1)

        if self.t >= self.max_t:
            return self._get_obs(), 0, True, False, self._get_info()

        a2 = np.clip(a2, self.min_a2, self.max_a2)
        self.apply_commands(self.attacker, a2, 1)
        self.t += 1

        # Dense shaping: track the smallest TTC between any two non-attacker vehicles.
        self.min_ttc = min(self.min_ttc, self._min_pairwise_ttc(initial=100))
        return self._get_obs(), 0.1 / self.min_ttc, False, False, self._get_info()

    def _end_on_victim_violation(self, reward):
        """Return the terminal step tuple for a victim violation worth *reward*.

        If openpilot's UI process is not running, openpilot is restarted and the
        step is additionally marked truncated.
        """
        if is_ui_process_running():
            return self._get_obs(), reward, True, False, self._get_info()
        kill_manager_process()
        start_op_process()
        return self._get_obs(), reward, True, True, self._get_info()

    def _non_attacker_vehicles(self):
        """Return all actors whose type id contains 'vehicle', excluding the attacker."""
        return [
            actor for actor in self.world.get_actors()
            if 'vehicle' in actor.type_id and actor.id != self.attacker.id
        ]

    def _min_pairwise_ttc(self, initial=100):
        """Return the smallest ``calculate_ttc`` over ordered pairs of non-attacker vehicles.

        Pairs whose TTC is None are skipped. *initial* is returned when nothing is smaller.
        """
        vehicles = self._non_attacker_vehicles()
        best = initial
        for first in vehicles:
            for second in vehicles:
                if first.id == second.id:
                    continue
                ttc = calculate_ttc(first, second)
                if ttc is not None and ttc < best:
                    best = ttc
        return best

    def get_distance_among_victims(self):
        """Return the minimum distance between two non-attacker vehicles, capped at 10."""
        vehicles = self._non_attacker_vehicles()
        min_distance = 10
        for first in vehicles:
            for second in vehicles:
                if first.id != second.id:
                    min_distance = min(min_distance, first.get_location().distance(second.get_location()))
        return min_distance

    def close(self):
        """Stop openpilot's bridge inside the container."""
        os.system(SHUTDOWN_OP)

    def _get_obs(self):
        """Build the float32 observation vector (layout in the class docstring)."""
        attacker_state = self._planar_kinematics(self.attacker)
        obs = [self.t / self.max_t, *attacker_state]
        obs += self._relative_to(self.victim, attacker_state)
        for npc in self.npcs:
            if npc is None:
                # NPC failed to spawn: zero-pad so the vector keeps a fixed length.
                obs += [0.0] * 6
            else:
                obs += self._relative_to(npc, attacker_state)
        return np.array(obs, dtype=np.float32)

    @staticmethod
    def _planar_kinematics(actor):
        """Return ``[x, y, vx, vy, ax, ay]`` of *actor* in world coordinates."""
        location = actor.get_location()
        velocity = actor.get_velocity()
        acceleration = actor.get_acceleration()
        return [location.x, location.y, velocity.x, velocity.y, acceleration.x, acceleration.y]

    @classmethod
    def _relative_to(cls, actor, reference_state):
        """Return *actor*'s planar kinematics minus *reference_state*, element-wise."""
        return [value - ref for value, ref in zip(cls._planar_kinematics(actor), reference_state)]

    def _get_info(self):
        """Return the (empty) info dict required by the Gymnasium API."""
        return {}

    def get_vehicle_state(self, vehicle, attacker_x, attacker_y):
        """Return ``[dx, dy, vx, vy, ax, ay]`` for *vehicle*.

        ``dx, dy`` are ``(attacker - vehicle) / 10``. Velocity and acceleration are
        the vehicle's own world-frame values. Not called by the current observation
        builder.
        """
        x, y = (attacker_x - vehicle.get_location().x) / 10, (attacker_y - vehicle.get_location().y) / 10
        v_x, v_y = vehicle.get_velocity().x, vehicle.get_velocity().y
        a_x, a_y = vehicle.get_acceleration().x, vehicle.get_acceleration().y
        return [x, y, v_x, v_y, a_x, a_y]

    def snap_to_ground(self, location, z_offset=0.5):
        """Set *location.z* to the nearest map waypoint's z plus *z_offset*; return it (mutated)."""
        waypoint = self.world.get_map().get_waypoint(location)
        location.z = waypoint.transform.location.z + z_offset
        return location

    def apply_commands(self, vehicle, commands, duration=1):
        """Apply ``[throttle_or_brake, steer]`` to *vehicle* for *duration* seconds, then release.

        A positive first component is throttle; otherwise its negation is the brake.
        Values are converted to plain floats before building ``carla.VehicleControl``.
        """
        throttle_or_brake = float(commands[0])
        steer = float(commands[1])
        if throttle_or_brake > 0:
            control = carla.VehicleControl(throttle=throttle_or_brake, steer=steer, brake=0, hand_brake=False, reverse=False)
        else:
            control = carla.VehicleControl(throttle=0, steer=steer, brake=-throttle_or_brake, hand_brake=False, reverse=False)
        vehicle.apply_control(control)
        time.sleep(duration)
        vehicle.apply_control(carla.VehicleControl(throttle=0, steer=0, brake=0, hand_brake=False, reverse=False))

    def update_constraints(self, vehicle, speed_limit=50, safe_ttc=1.5):
        """Refresh the "speeding" constraint flag for *vehicle*.

        Args:
            vehicle: actor whose 3-D speed is compared with *speed_limit*.
            speed_limit: threshold in the units of ``get_velocity()``.
            safe_ttc: currently unused (the safe-distance check is disabled).
        """
        velocity = vehicle.get_velocity()
        speed_magnitude = math.sqrt(velocity.x ** 2 + velocity.y ** 2 + velocity.z ** 2)
        self.constriants["speeding"] = speed_magnitude > speed_limit
        # TODO: safe-distance check is disabled. Intended behaviour: set
        # constriants["safe_distance"] when calculate_ttc(vehicle, other) < safe_ttc.

    def parse_speed_string(self, agent, speed_str):
        """Return a horizontal ``carla.Vector3D`` of magnitude ``float(speed_str)`` along *agent*'s yaw."""
        speed = float(speed_str)
        yaw = np.deg2rad(agent.get_transform().rotation.yaw)
        return carla.Vector3D(speed * np.cos(yaw), speed * np.sin(yaw), 0)

    def check_constraints(self):
        """Update the attacker's constraint flags and return True if any flag is set."""
        self.update_constraints(self.attacker)
        return any(self.constriants.values())

    # added based on Acero
    def over_speed(self, vehicle):
        """Return True (and log) if *vehicle*'s 3-D speed exceeds 35."""
        speed = vehicle.get_velocity()
        speed_magnitude = np.sqrt(speed.x ** 2 + speed.y ** 2 + speed.z ** 2)
        if speed_magnitude > 35:
            print(f"Vehicle is over speed limit, speed: {speed_magnitude}")
            return True
        return False

    @staticmethod
    def _ttc_below(vehicle, other, threshold):
        """Return True (and log) if ``calculate_ttc(vehicle, other)`` is defined and < *threshold*."""
        ttc = calculate_ttc(vehicle, other)
        # calculate_ttc returns None when the relative velocity is zero; that is not a violation.
        if ttc is not None and ttc < threshold:
            print(f"Vehicle is too close to NPC, ttc: {ttc}")
            return True
        return False

    def lane_change_bad_distance(self, vehicle, npc, threshold):
        """Return True if *vehicle*'s TTC to *npc* is below *threshold*."""
        return self._ttc_below(vehicle, npc, threshold)

    def follow_bad_distance(self, vehicle, front, threshold):
        """Return True if *vehicle*'s TTC to *front* is below *threshold*."""
        return self._ttc_below(vehicle, front, threshold)

    def wrong_direction(self, vehicle):
        """Wrong-direction check; currently disabled and always returns False."""
        waypoint = self.world.get_map().get_waypoint(vehicle.get_location(), project_to_road=True, lane_type=(carla.LaneType.Driving | carla.LaneType.Sidewalk))
        # TODO: disabled check — previously `if waypoint.lane_id < 0: return True`.
        return False
    
    
    
def calculate_ttc(vehicle1, vehicle2):
    """Estimate the time-to-collision (TTC) between two vehicles.

    Both vehicles are assumed to keep their current velocities. The relative
    position (vehicle2 - vehicle1) and the closing velocity (vehicle1 - vehicle2)
    are rotated into vehicle1's heading frame (x = longitudinal, y = lateral).
    Both bounding boxes are then treated as aligned with that frame. A crash
    along an axis is predicted when:
    - the vehicles are closing on that axis, and
    - at the moment the gap on that axis closes, the offset on the other axis
      is within the summed half-extents.

    Vehicles are duck-typed: they need ``bounding_box.extent.{x,y}``,
    ``get_location()``, ``get_velocity()`` and ``get_transform().rotation.yaw``
    (degrees), as CARLA actors provide.

    Returns:
        None if the relative planar velocity is exactly zero;
        ``float('inf')`` if no collision is predicted;
        otherwise the TTC in seconds, clamped below at 0.1.
    """
    bb1 = vehicle1.bounding_box
    bb2 = vehicle2.bounding_box

    loc1 = vehicle1.get_location()
    vel1 = vehicle1.get_velocity()
    loc2 = vehicle2.get_location()
    vel2 = vehicle2.get_velocity()

    # Closing velocity of vehicle1 toward vehicle2, and offset of vehicle2 from vehicle1.
    rel_vx, rel_vy = vel1.x - vel2.x, vel1.y - vel2.y
    rel_px, rel_py = loc2.x - loc1.x, loc2.y - loc1.y

    sum_half_widths = bb1.extent.y + bb2.extent.y
    sum_half_lengths = bb1.extent.x + bb2.extent.x

    # Rotate into vehicle1's frame.
    yaw1 = math.radians(vehicle1.get_transform().rotation.yaw)
    cos_yaw1 = math.cos(yaw1)
    sin_yaw1 = math.sin(yaw1)
    pos_long = rel_px * cos_yaw1 + rel_py * sin_yaw1
    pos_lat = -rel_px * sin_yaw1 + rel_py * cos_yaw1
    vel_long = rel_vx * cos_yaw1 + rel_vy * sin_yaw1
    vel_lat = -rel_vx * sin_yaw1 + rel_vy * cos_yaw1

    # No relative motion: TTC is undefined.
    if vel_long == 0 and vel_lat == 0:
        return None

    ttc_long = float('inf')
    ttc_lat = float('inf')
    if pos_long * vel_long > 0:  # closing in longitudinally (either side)
        # Use magnitudes so a vehicle behind (negative offset) gets the same gap logic.
        t = (abs(pos_long) - sum_half_lengths) / abs(vel_long)
        # t <= 0 means the boxes already overlap longitudinally.
        if t > 0 and abs(pos_lat - vel_lat * t) < sum_half_widths:
            ttc_long = t
    if pos_lat * vel_lat > 0:  # closing in laterally (either side)
        t = (abs(pos_lat) - sum_half_widths) / abs(vel_lat)
        if t > 0 and abs(pos_long - vel_long * t) < sum_half_lengths:
            ttc_lat = t

    ttc = min(ttc_long, ttc_lat)
    return max(ttc, 0.1) if ttc > 0 else None

def is_ui_process_running(container_name='openpilot_client'):
    """Report whether openpilot's UI appears to be running in *container_name*.

    Runs ``docker top <container> | grep ui`` through the shell. The result is
    True only if grep exits 0 with non-empty output, i.e. some line of the
    container's process listing contains the lowercase substring "ui". Any
    exception raised while running the command is printed and yields False.
    """
    command = f"docker top {container_name} | grep ui"
    try:
        completed = subprocess.run(command, shell=True, capture_output=True, text=True)
    except Exception as err:
        print(f"An error occurred: {err}")
        return False
    return completed.returncode == 0 and bool(completed.stdout)
    

def kill_manager_process(container_name="openpilot_client"):
    """Kill every process matching ``manager.py`` inside a docker container.

    Uses ``docker exec <container> pgrep -f manager.py`` to find the PIDs, then
    ``docker exec <container> kill <pid>...``. pgrep may report several PIDs, one
    per line. They are passed as separate arguments, with no shell involved, so
    a multi-line PID list cannot break the command.

    Args:
        container_name: target container (defaults to ``openpilot_client``).

    Returns:
        True if ``kill`` exited successfully. False if no matching process was
        found, ``kill`` failed, or running docker raised (the error is printed).
    """
    try:
        result = subprocess.run(
            ["docker", "exec", container_name, "pgrep", "-f", "manager.py"],
            capture_output=True,
            text=True,
        )
        pids = result.stdout.split()
        if result.returncode != 0 or not pids:
            print("No process named 'manager.py' found in the container.")
            return False

        kill_result = subprocess.run(
            ["docker", "exec", container_name, "kill", *pids],
            capture_output=True,
            text=True,
        )
        return kill_result.returncode == 0
    except Exception as e:
        print(f"An error occurred: {e}")
        return False
    
def start_op_process():
    """Launch openpilot (``launch_openpilot.sh``) inside the ``openpilot_client`` container.

    The shell command ends with ``&``, so ``docker exec`` is backgrounded and the
    shell exits immediately. The exit status checked here therefore reflects only
    that launching shell; it does not show whether openpilot actually came up.
    Output is not captured.

    Returns:
        True if the launching shell exited with status 0. False otherwise, or if
        ``subprocess.run`` raised (the error is printed).
    """
    command = "docker exec openpilot_client /bin/bash -c 'cd /home/batman/openpilot/tools/sim && ./launch_openpilot.sh>/dev/null '&"
    try:
        result = subprocess.run(command, shell=True, capture_output=False, text=True)
    except Exception as e:
        print(f"An error occurred: {e}")
        return False

    if result.returncode != 0:
        # Output is not captured, so result.stderr is None; report the exit status instead.
        print(f"An error occurred: launcher shell exited with status {result.returncode}")
        return False

    print("UI process started successfully.")
    return True
    
def random_weather():
    """Draw random CARLA weather from a simple statistical model.

    - sun altitude ~ 90 * U(0, 1)
    - precipitation deposits ~ 50 * U(0, 1)
    - cloudiness: with probability 0.5 it is 30 * Beta(2, 2) (clear-ish),
      otherwise 60 + 40 * Beta(2, 2) (overcast)
    - precipitation equals cloudiness when cloudiness >= 70, else 0

    Returns:
        ``carla.WeatherParameters`` built from the values above.
    """
    # Draw order matters for reproducibility under a fixed seed.
    sun_altitude = 90 * random.uniform(0, 1)
    deposits = 50 * random.uniform(0, 1)
    beta_sample = np.random.beta(2, 2)
    overcast_draw = random.uniform(0, 1)

    cloudiness = 30 * beta_sample if overcast_draw < 0.5 else 40 * beta_sample + 60
    precipitation = cloudiness if cloudiness >= 70 else 0

    return carla.WeatherParameters(
        cloudiness=cloudiness,
        precipitation=precipitation,
        precipitation_deposits=deposits,
        sun_altitude_angle=sun_altitude,
    )

def load_scene_config(config, weather_randomization=False, speed_randomization=False, location_randomization=False):
    """Read an attack-scene JSON file and build the scene description.

    Keys read from the file: "Map", "VictimTransform", "VictimSpeed",
    "AttackTransform", "AttackSpeed", "timestep", "MissionID", and "NPC". "NPC"
    is a mapping from role name to ``[model, transform, speed, control_list]``.

    Args:
        config: path to the JSON file.
        weather_randomization: if True, draw weather with ``random_weather()``;
            otherwise weather is None.
        speed_randomization: if True, the victim speed is U(3, 5) and the
            attacker speed is U(5, 10) instead of the file values.
        location_randomization: if True, perturb the attacker's spawn string via
            ``random_transform_vehicle`` (the result is converted with ``str``).

    Returns:
        ``(attacker, victim, npc_list, target_obj, weather, map, time_step, calc_rob)``.
        - ``target_obj`` is the role of the first NPC, or None if there are no
          NPCs.
        - ``calc_rob`` is "TTC" if "C" occurs in MissionID, else "L2".
    """
    class SceneVehicle:
        """Plain record describing one vehicle in the scene."""

        def __init__(self, model, pos, speed, control_list=None, role=None):
            self.model = model  # blueprint id (None -> caller's default)
            self.pos = pos  # transform as a Python expression string
            self.speed = speed
            self.control_list = control_list
            self.role = role

    with open(config, 'r') as f:
        scene_config = json.load(f)

    # Previously left unbound when randomization was off (UnboundLocalError on return).
    weather = random_weather() if weather_randomization else None

    town_map = scene_config["Map"]

    victim = SceneVehicle(None, scene_config["VictimTransform"], scene_config["VictimSpeed"])
    if speed_randomization:
        victim.speed = random.uniform(3, 5)

    attacker = SceneVehicle(None, scene_config["AttackTransform"], scene_config["AttackSpeed"])
    if speed_randomization:
        attacker.speed = random.uniform(5, 10)
    if location_randomization:
        attacker.pos = str(random_transform_vehicle(attacker.pos))

    time_step = scene_config["timestep"]
    npc_list = [
        SceneVehicle(spec[0], spec[1], spec[2], spec[3], role)
        for role, spec in scene_config['NPC'].items()
    ]

    calc_rob = "TTC" if "C" in scene_config["MissionID"] else "L2"
    target_obj = npc_list[0].role if npc_list else None
    return attacker, victim, npc_list, target_obj, weather, town_map, time_step, calc_rob

    
if __name__ == "__main__":
    # Manual entry point: restart openpilot's manager inside the container
    # (kill the running manager.py first, then relaunch openpilot).
    for _action in (kill_manager_process, start_op_process):
        _action()