import unittest
import carla
import math
from unittest.mock import Mock

# Import your TTC function here
from openpilot_env import calculate_ttc

class TestTTCFunction(unittest.TestCase):
    """Unit tests for ``calculate_ttc`` driven by mocked CARLA vehicles.

    Each test builds two independent mock vehicles with ``setup_two_vehicles``,
    overrides what they report for velocity, location and yaw, and checks the
    value returned by ``calculate_ttc(v1, v2)``.

    The expected values encode this file's model of the function (TODO confirm
    against ``openpilot_env.calculate_ttc``): centre-to-centre gap minus both
    vehicles' ``bounding_box.extent.x`` (2 m each by default), divided by the
    closing speed; ``None`` when the vehicles are not closing on each other.
    """

    def test_vehicles_same_direction_same_speed(self):
        # Same heading and same speed: the gap never closes, so no TTC.
        v1, v2 = self.setup_two_vehicles()
        v1.get_velocity.return_value = carla.Vector3D(10, 0, 0)
        v2.get_velocity.return_value = carla.Vector3D(10, 0, 0)
        v1.get_location.return_value = carla.Location(0, 0, 0)
        v2.get_location.return_value = carla.Location(20, 0, 0)
        v1.get_transform().rotation.yaw = 0
        v2.get_transform().rotation.yaw = 0

        ttc = calculate_ttc(v1, v2)
        self.assertIsNone(ttc)

    def test_vehicles_same_direction_different_speeds(self):
        # Same heading, v1 (15 m/s) behind v2 (10 m/s): closing at 5 m/s
        # over a 20 m centre gap.
        v1, v2 = self.setup_two_vehicles()
        v1.get_velocity.return_value = carla.Vector3D(15, 0, 0)
        v2.get_velocity.return_value = carla.Vector3D(10, 0, 0)
        v1.get_location.return_value = carla.Location(0, 0, 0)
        v2.get_location.return_value = carla.Location(20, 0, 0)
        v1.get_transform().rotation.yaw = 0
        v2.get_transform().rotation.yaw = 0

        ttc = calculate_ttc(v1, v2)
        self._assert_ttc(ttc, (20 - (2 + 2)) / 5)

    def test_vehicles_following_same_speed(self):
        # Close following (10 m centre gap) at equal speed: still no TTC.
        v1, v2 = self.setup_two_vehicles()
        v1.get_velocity.return_value = carla.Vector3D(10, 0, 0)
        v2.get_velocity.return_value = carla.Vector3D(10, 0, 0)
        v1.get_location.return_value = carla.Location(0, 0, 0)
        v2.get_location.return_value = carla.Location(10, 0, 0)
        v1.get_transform().rotation.yaw = 0
        v2.get_transform().rotation.yaw = 0

        ttc = calculate_ttc(v1, v2)
        self.assertIsNone(ttc)

    def test_vehicles_following_different_speeds(self):
        # Close following (10 m centre gap), v1 closing on v2 at 5 m/s.
        v1, v2 = self.setup_two_vehicles()
        v1.get_velocity.return_value = carla.Vector3D(10, 0, 0)
        v2.get_velocity.return_value = carla.Vector3D(5, 0, 0)
        v1.get_location.return_value = carla.Location(0, 0, 0)
        v2.get_location.return_value = carla.Location(10, 0, 0)
        v1.get_transform().rotation.yaw = 0
        v2.get_transform().rotation.yaw = 0

        ttc = calculate_ttc(v1, v2)
        self._assert_ttc(ttc, (10 - (2 + 2)) / 5)

    def test_vehicles_different_lanes_bounding_boxes_collide(self):
        # Vehicles offset sideways by 2 m whose boxes still overlap laterally:
        # the summed half-widths (1.5 m + 1.5 m = 3 m) exceed the 2 m offset.
        v1, v2 = self.setup_two_vehicles()
        # Both vehicles use the same extent. Each has its own bounding-box mock,
        # so these two assignments are independent of each other.
        v1.bounding_box.extent = carla.Vector3D(2, 1.5, 1.5)
        v2.bounding_box.extent = carla.Vector3D(2, 1.5, 1.5)
        v1.get_velocity.return_value = carla.Vector3D(12, 0, 0)
        v2.get_velocity.return_value = carla.Vector3D(10, 0, 0)
        v1.get_location.return_value = carla.Location(0, 0, 0)
        v2.get_location.return_value = carla.Location(10, 2, 0)  # 2 m lateral offset
        v1.get_transform().rotation.yaw = 0
        v2.get_transform().rotation.yaw = 0

        ttc = calculate_ttc(v1, v2)
        # Closing speed is 12 - 10 = 2 m/s.
        self._assert_ttc(ttc, (10 - (2 + 2)) / 2)

    def setup_two_vehicles(self):
        """Return two independent mock vehicles ``(v1, v2)``.

        Each vehicle gets its own bounding-box, transform and rotation mocks, so
        configuring one vehicle (extent, yaw, ...) never leaks into the other.
        Defaults: ``bounding_box.extent`` is ``Vector3D(2, 1, 1)``; velocity and
        location are placeholder mocks that each test is expected to override.
        """
        return self._make_vehicle(), self._make_vehicle()

    @staticmethod
    def _make_vehicle():
        """Build one mock ``carla.Vehicle`` with its own nested mocks."""
        vehicle = Mock(spec=carla.Vehicle)
        vehicle.bounding_box = Mock(spec=carla.BoundingBox)
        vehicle.bounding_box.extent = carla.Vector3D(2, 1, 1)
        vehicle.get_location.return_value = Mock(spec=carla.Location)
        vehicle.get_velocity.return_value = Mock(spec=carla.Vector3D)
        transform = Mock(spec=carla.Transform)
        transform.rotation = Mock(spec=carla.Rotation)
        # get_transform() returns the same transform on every call, so
        # `vehicle.get_transform().rotation.yaw = ...` persists for the test.
        vehicle.get_transform.return_value = transform
        return vehicle

    def _assert_ttc(self, ttc, expected):
        """Assert that ``ttc`` is not None and is within 2 decimal places of ``expected``.

        The None check comes first so a missing TTC fails with a clear message
        instead of a ``TypeError`` raised inside ``assertAlmostEqual``.
        """
        self.assertIsNotNone(ttc, "expected a TTC value, got None")
        self.assertAlmostEqual(ttc, expected, places=2)

if __name__ == "__main__":
    # Running this file directly discovers and runs every TestCase it defines.
    unittest.main()