import os
import cv2
import json
from ultralytics import YOLO
from paddleocr import PaddleOCR
from fuzzywuzzy import process
from tqdm import tqdm
import re
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
from difflib import SequenceMatcher
import re
from collections import Counter
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
import cv2
# Paths (relative, so they resolve against the working directory the script is run from)
folder_path = 'AV_OCR/dataset/'
output_folder = 'AV_OCR/results/'
# NOTE(review): yolo_output_folder and ocr_output_file are created/defined here but not
# referenced by the processing loop in this file — confirm whether they are still needed.
yolo_output_folder = os.path.join(output_folder, "yolo_outputs")
ocr_output_file = os.path.join(output_folder, "ocr_outputs.json")
final_output_file = os.path.join(output_folder, "results_nlp_matching_may1.json")
annotations_file = 'AV_OCR/annotations/full_annotations.json'

# Create necessary folders
os.makedirs(output_folder, exist_ok=True)
os.makedirs(yolo_output_folder, exist_ok=True)

# Load the annotation records. The processing loop reads each record's
# 'image_name' (file name under folder_path) and 'matched_signs' (ground truth).
# JSON text is UTF-8, so decode it explicitly instead of relying on the locale default.
with open(annotations_file, 'r', encoding='utf-8') as file:
    annotation_data = json.load(file)

# Vocabulary of road-sign texts that OCR output is matched against. The matchers
# return the chosen entry exactly as written here, and it is saved verbatim in the
# results JSON.
# NOTE(review): several entries look like typos ('LANE ENDS HERGE RIGHT', 'ROAD WORD AHEAD',
# 'BE PREPARED T0 STOP', 'PUST BUTTON FOR WALK SIGNAL', 'PAT NARROWS', 'NOT TURN ON RED',
# 'LEFT LEFT LANE NO TRUCKS') and there is a blank ' ' entry. They presumably mirror the
# annotation labels, so they are left untouched — confirm before correcting them.
road_signs = ['BEGIN ONE WAY', 'FREEWAY ENDS', 'LOW OR SOFT SHOULDER', 'NO SHOULDER ON BRIDGE', 'RUNAWAY TRUCK RAMP 2000 FEET', 'NO DRIVING ON SHOULDER', 'CAUTION ICY SPOTS', 'ROAD ENDS', 'REDUCE SPEED NOW', '35 MP.H.', 'ONE LANE ROAD AHEAD', 'NARROW BRIDGE', 'USE TURN SIGNAL', "DON'T TEXT AND DRIVE", 'ONE LANE BRIDGE AHEAD', 'BRIDGE ICES BEFORE ROAD', 'ACTIVE WORK ZONE WHEN FLASHING', 'SOFT SHOULDER', 'WATCH FOR LIVESTOCK', 'TRUCK CROSSING', 'NEXT 3 MILES', 'SIDEWALK CLOSED', 'NEXT 11 MILES', 'LANE ENDS HERGE RIGHT', 'TURN RIGHT ONLY', 'STOP', 'ROAD WORD AHEAD', 'TRUCKS AND BUSES USE LOW GEAR', 'RAMP 35 MPH', 'SHOULDER CLOSED', 'HOV- 2 LANE BUSES & CARPOOLS LEFT LANE', 'NO TURN ON RED', 'GUSTY WINDS MAY EXIST', 'CAUTION', 'DETOUR', 'NEXT 1/2 MILE', 'END ONE WAY', 'SPEED LIMIT PHOTO ENFORCED', 'CENTER LANE ONLY', 'WEIGHT LIMIT 10 TONS', '3-WAY', 'LAST EXIT BEFORE TOLLWAY', 'EXCEPT RIGHT TURN', 'LANE ENDS MERGE LEFT', 'EMERGENCY USE ONLY', '25 MPH', 'SHARE THE ROAD', 'END DETOUR', 'LEFT LANE ENDS', 'RIGHT TURN ONLY', 'HOV LANE', 'BRIDGE MAY BE ICY', 'LEFT TURN YIELD ON GREEN', 'YIELD TO PEDS IN CROSSWALK', 'CONSTRUCTION ENTRANCE ONLY', 'EXIT 25 MPH', 'ROAD CLOSED', 'LEFT TURN YIELD TO ONCOMING TRAFFIC', 'YIELD HERE TO PEDESTRIANS', 'SHOULDER WORK AHEAD', 'STOP HERE ON RED', 'REDUCED SPEED 50 AHEAD', '45 MPH SPEED ZONE AHEAD', 'SPEED LIMIT 45', 'BE PREPARED T0 STOP', 'BRAKE RETARDERS PROHIBITED', 'LEFT LANE MUST TURN LEFT', 'ONE LANE BRIDGE', ' ', 'SPEED LIMIT 55', 'SPEED HUMP', 'TRUCKS USE LOW GEAR', 'WRONG WAY', 'LEFT LANE CLOSED AHEAD', 'DEAD END', 'DRAW BRIDGE AHEAD', 'UNAUTHORIZED VEHICLES PROHIBITED', 'WATCH CHILDREN', 'LEFT TURN YIELD ON FLASHING YELLOW ARROW', 'EXCEPT BIKES', 'ROAD NARROWS', 'PARK & RIDE', 'RIGHT LANE MUST TURN RIGHT', 'MAY USE FULL LANE', 'YIELD', 'BRIDGE MAY BE SLIPPERY', 'SCHOOL ZONE AHEAD', 'ROAD CLOSED AHEAD', 'LEFT LEFT LANE NO TRUCKS', 'ALL WAY', 'SCHOOL BUS STOP AHEAD', 'DETOUR AHEAD', 'NO MOTOR VEHICLES', 'ONE WAY', 'ROAD MAY FLOOD', 
'RIGHT LANE ENDS', 'KEEP RIGHT EXCEPT TO PASS', 'BIKE LANE ENDS', 'WEIGHT LIMIT 4 TONS', 'MERGE RIGHT', 'SPEED LIMIT 35', 'ROAD CLOSED TO THROUGH TRAFFIC', 'LOOSE GRAVEL', 'DEAF CHILD AREA', '5 MPH', 'ALL TRAFFIC MUST EXIT', 'LANE ENDS MERGE RIGHT', 'ALL TRAFFIC MUST TURN RIGHT', 'OVERNIGHT PARKING PROHIBITED', 'END SCHOOL ZONE', 'RUNAWAY TRUCK RAMP', 'REDUCED SPEED ZONE AHEAD', 'SPEED LIMIT 50', 'SHARE LANE WITH BICYCLISTS', 'RIGHT TURN SIGNAL', '4-WAY', 'WATCH FOR BICYCLISTS', '10 MPH', 'BRIDGE CLOSED', 'SPEED LIMIT 25', 'SPEED LIMIT 15', 'LEFT LANE MUST EXIT', 'ANTI ICING IN PROGRESS', 'PAVEMENT ENDS', 'RIGHT LANE ENDS MERGE LEFT', 'DO NOT ENTER', 'EMERGENCY STOPPING ONLY', '15 MPH', 'STEEL DECK', 'NO PARKING ON ANY STREET', 'PEDESTRIAN CROSSING', 'REST AREA', 'BLUE BELT', 'LEFT LANE NO TRUCKS', 'SPEED LIMIT 40', 'SPEED LIMIT 5 MPH', 'CROSS TRAFFIC DOES NOT STOP', 'LEFT TURN ONLY', 'STAY IN LANE', 'PARK AND RIDE', 'OPPOSING TRAFFIC DOES NOT STOP', 'PAT NARROWS', 'DRIVEWAY', 'TWO WAY TRAFFIC AHEAD', 'WATCH FOR ROCKS', 'CAUTION MINIMUM MAINTENANCE ONLY', 'NOT TURN ON RED', 'BUSES ONLY', 'DIVIDED HIGHWAY', 'END ROAD WORK', 'PHOTO ENFORCED', "NO TRUCKS BUSSES AND RV'S", 'MOTOR VEHICLES ONLY', 'DO NOT PASS', 'BIKE LANE', 'RIGHT LANE CLOSED AHEAD', 'BUS STOP', 'DO NOT BLOCK INTERSECTION', 'SPEED 30', 'FINES HIGHER', 'SLOW CHILDREN AT PLAY', 'DRAW BRIDGE', '55 MPH', 'LANE CLOSED', 'ROUGH ROAD', 'WEIGHT LIMIT 5 TONS', 'BE PREPARED TO STOP', 'SPEED LIMIT 30', 'PUST BUTTON FOR WALK SIGNAL', '30 MPH', 'CAUTION CONSTRUCTION AREA SIDEWALK CLOSED', 'NO OUTLET', 'PATH NARROWS', 'TRUCK DETOUR', 'BIKE ROUTE', 'ROAD WORK AHEAD', 'BLIND PEDESTRIAN CROSSING', 'SCHOOL', 'PASS WITH CARE', 'HIDDEN DRIVEWAY', 'NEW TRAFFIC PATTERN AHEAD', 'BUMP', 'FINES DOUBLE', 'CAUTION HIDDEN DRIVEWAY', 'YOUR SPEED', 'RADAR ENFORCED', 'TRUCK ROUTE', 'DIP', 'WEIGHT LIMIT', 'SPEED LIMIT 70']


def compute_iou(box1, box2):
    """
    Return the Intersection over Union of two axis-aligned boxes.

    Both boxes are indexable as [x_min, y_min, x_max, y_max]. When the union
    area is not positive (e.g. two zero-area boxes), 0 is returned instead of
    dividing by zero.
    """
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    # Negative extents mean the boxes do not overlap on that axis; clamp to 0.
    overlap = max(0, right - left) * max(0, bottom - top)

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    union = _area(box1) + _area(box2) - overlap
    if union > 0:
        return overlap / union
    return 0

def merge_overlapping_bboxes(bboxes, iou_threshold=0.5):
    """
    Merge bounding boxes whose IoU with a growing group box exceeds a threshold.

    Each box not yet consumed seeds a group. Any remaining box whose IoU with
    the group's enclosing box is strictly greater than ``iou_threshold`` is
    absorbed, and the enclosing box expands to cover it. The scan repeats until
    the group stops growing, so a box rejected before the enclosing box
    expanded still gets another chance. Without this, the result would depend
    on input order.

    Args:
        bboxes: Sequence of boxes in [x_min, y_min, x_max, y_max] format
            (e.g. the (N, 4) array produced by the detector).
        iou_threshold: Strict IoU lower bound for merging two boxes.

    Returns:
        np.ndarray of shape (M, 4) with one enclosing box per group; shape
        (0, 4) when ``bboxes`` is empty.
    """
    merged_bboxes = []
    used = set()

    for i, seed in enumerate(bboxes):
        if i in used:
            continue
        used.add(i)
        superset = list(seed)

        grew = True
        while grew:
            grew = False
            for j, candidate in enumerate(bboxes):
                if j in used:
                    continue
                if compute_iou(superset, candidate) > iou_threshold:
                    # Expand the enclosing box to cover the absorbed candidate.
                    superset = [
                        min(superset[0], candidate[0]),  # x_min
                        min(superset[1], candidate[1]),  # y_min
                        max(superset[2], candidate[2]),  # x_max
                        max(superset[3], candidate[3]),  # y_max
                    ]
                    used.add(j)
                    grew = True

        merged_bboxes.append(superset)

    # reshape keeps the (M, 4) layout even when there were no boxes at all.
    return np.array(merged_bboxes).reshape(-1, 4)


class SignTextMatcher:
    """Match OCR text to the closest road sign using character n-gram TF-IDF.

    The vectorizer is fit on the raw ``road_signs`` strings. Queries are
    preprocessed, then scored against every sign by the dot product of their
    TF-IDF vectors. This is cosine similarity under TfidfVectorizer's default
    L2 row normalisation.
    """

    def __init__(self, road_signs):
        self.road_signs = road_signs
        self.vectorizer = TfidfVectorizer(analyzer='char_wb', ngram_range=(2, 3), lowercase=True)
        # One TF-IDF row per entry of road_signs, in the same order.
        self.tfidf_matrix = self.vectorizer.fit_transform(road_signs)

    def preprocess_text(self, text):
        """Upper-case, drop everything except A-Z, 0-9 and whitespace, collapse spaces."""
        text = text.upper().strip()
        text = re.sub(r'[^A-Z0-9\s]', '', text)
        text = ' '.join(text.split())
        return text

    def get_tfidf_similarity(self, text):
        """Return a 1-D array of similarities between ``text`` and each road sign."""
        text_vector = self.vectorizer.transform([text])
        return np.array(text_vector.dot(self.tfidf_matrix.T).toarray())[0]

    def find_best_match(self, text, threshold=0.5):
        """
        Return the road sign with the highest TF-IDF similarity to ``text``.

        Args:
            text (str): Raw OCR text.
            threshold (float): Minimum similarity required to accept a match.

        Returns:
            The matched sign string when a sign scores above 0 and at least
            ``threshold``. Otherwise the tuple ``(None, 0)``, which is also
            returned for text that is empty after preprocessing. The two
            return shapes differ, so callers must distinguish them. On ties,
            the earliest sign in ``road_signs`` wins.
        """
        text = self.preprocess_text(text)
        if not text:
            return None, 0

        # Transform the query once and read each sign's score from the result,
        # instead of re-transforming and doing a list.index lookup per sign.
        scores = self.get_tfidf_similarity(text)

        best_score = 0
        best_match = None
        for sign, score in zip(self.road_signs, scores):
            if score > best_score:
                best_score = score
                best_match = sign

        # best_match stays None when no sign scored above 0; report "no match"
        # in the usual sentinel form even if threshold <= 0.
        if best_match is None or best_score < threshold:
            return None, 0

        return best_match

class SignTextMatcher_v2:
    """Match OCR text to a road sign using a weighted blend of similarity metrics.

    The score combines word-set overlap, ``SequenceMatcher`` ratio,
    character n-gram TF-IDF similarity and a digit-consistency check.
    """

    def __init__(self, road_signs):
        self.road_signs = road_signs
        self.vectorizer = TfidfVectorizer(
            analyzer='char_wb',  # Character n-grams including word boundaries
            ngram_range=(2, 3),  # Use both bigrams and trigrams
            lowercase=True
        )
        # Pre-compute TF-IDF matrix for road signs (one row per entry, same order)
        self.tfidf_matrix = self.vectorizer.fit_transform(road_signs)

    def preprocess_text(self, text):
        """Clean and standardize text for better matching.

        Upper-cases, removes every character other than A-Z, 0-9 and
        whitespace, and collapses whitespace runs to single spaces.
        """
        text = text.upper().strip()
        text = re.sub(r'[^A-Z0-9\s]', '', text)
        text = ' '.join(text.split())
        return text

    def get_word_overlap_score(self, text1, text2):
        """Return the Jaccard similarity of the two texts' word sets (0 if both are empty)."""
        words1 = set(text1.split())
        words2 = set(text2.split())
        union = words1 | words2
        return len(words1 & words2) / len(union) if union else 0

    def get_length_similarity(self, text1, text2):
        """
        Calculate length similarity between two texts.

        Args:
            text1 (str): First text.
            text2 (str): Second text.

        Returns:
            float: 1 - |len1 - len2| / max(len1, len2), or 0 when both are empty.
        """
        len1, len2 = len(text1), len(text2)
        return 1 - abs(len1 - len2) / max(len1, len2) if max(len1, len2) > 0 else 0

    def get_sequence_score(self, text1, text2):
        """Calculate sequence similarity using SequenceMatcher.ratio()."""
        return SequenceMatcher(None, text1, text2).ratio()

    def get_tfidf_similarity(self, text):
        """Return a 1-D array of TF-IDF similarities between ``text`` and each road sign.

        With TfidfVectorizer's default L2 normalisation this is cosine similarity.
        """
        text_vector = self.vectorizer.transform([text])
        return np.array(text_vector.dot(self.tfidf_matrix.T).toarray())[0]

    def get_number_consistency(self, text1, text2):
        """Return 1.0 if both texts contain the same set of digit runs, else 0.0.

        Two texts without any digits count as consistent; digits on only one
        side count as inconsistent.
        """
        numbers1 = set(re.findall(r'\d+', text1))
        numbers2 = set(re.findall(r'\d+', text2))
        # Set equality covers "no numbers on either side" and "same numbers";
        # numbers present on only one side make the sets differ.
        return 1.0 if numbers1 == numbers2 else 0.0

    def find_best_match(self, text, threshold=0.4):
        """
        Find the best matching road sign using multiple similarity metrics.

        The score for each sign is
        0.3 * word overlap + 0.3 * sequence ratio + 0.2 * TF-IDF + 0.2 * number consistency.
        Both the text and the sign are preprocessed for the word, sequence and
        number metrics.

        Args:
            text (str): The text to match.
            threshold (float): Minimum combined score to accept a match.

        Returns:
            The matched sign string (exactly as listed in ``road_signs``) when
            the best score is at least ``threshold``. Otherwise the tuple
            ``(None, 0)``, which is also returned for text that is empty after
            preprocessing. The two return shapes differ, so callers must
            distinguish them. On ties, the earliest sign wins.
        """
        text = self.preprocess_text(text)
        if not text:
            return None, 0

        # The TF-IDF scores depend only on the query: compute them once rather
        # than re-transforming the query (and searching with list.index) per sign.
        tfidf_scores = self.get_tfidf_similarity(text)

        best_score = 0
        best_match = None

        for sign, tfidf_score in zip(self.road_signs, tfidf_scores):
            processed_sign = self.preprocess_text(sign)

            word_score = self.get_word_overlap_score(text, processed_sign)
            sequence_score = self.get_sequence_score(text, processed_sign)
            number_score = self.get_number_consistency(text, processed_sign)

            # Weight the different scores; adjust these weights as needed.
            final_score = (
                0.3 * word_score +      # Word overlap
                0.3 * sequence_score +   # Sequence similarity
                0.2 * tfidf_score +      # TF-IDF similarity
                0.2 * number_score       # Number consistency
            )

            if final_score > best_score:
                best_score = final_score
                best_match = sign

        # best_match stays None if no sign scored above 0; report "no match" in
        # the usual sentinel form even when threshold <= 0.
        if best_match is None or best_score < threshold:
            return None, 0

        return best_match


def find_best_match_leven(text):
    """
    Return the entry of the module-level ``road_signs`` that fuzzywuzzy rates
    closest to ``text``.

    Uses ``process.extractOne`` with its default processor and scorer. The
    similarity score is discarded, and a falsy match is normalised to None.
    """
    candidate, _similarity = process.extractOne(text, road_signs)
    if not candidate:
        return None
    return candidate

import time

# Load the detector, the OCR engine and the text-to-sign matcher
model = YOLO('AV_OCR/src/sign_boundary_det/best_model.pt')
ocr = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)
matcher = SignTextMatcher_v2(road_signs)

# One result record per successfully processed image
save_results = []

# total_time accumulates detection + OCR + matching time (image loading excluded);
# timed_images counts the images that contributed to it, so the average is not
# diluted by images that failed.
total_time = 0
timed_images = 0
failed_images = []

for idx, entry in enumerate(tqdm(annotation_data)):
    image_name = None  # reset so error messages never show a previous image's name
    try:
        image_name = entry['image_name']
        full_image_path = os.path.join(folder_path, image_name)
        high_res_image = cv2.imread(full_image_path)
        if high_res_image is None:
            # cv2.imread reports missing/unreadable files by returning None, not by raising.
            raise FileNotFoundError(f"could not read image {full_image_path}")

        start_time = time.time()
        results = model(high_res_image, verbose=False, conf=0.2)
        bboxes = results[0].boxes.xyxy.cpu().numpy()  # Get xyxy bounding boxes
        merged_bboxes = merge_overlapping_bboxes(bboxes)

        # Crop every merged detection for OCR. (2x INTER_CUBIC upsampling of the
        # crops was tried previously and is currently disabled.)
        cropped_images = []
        for bbox in merged_bboxes:
            x_min, y_min, x_max, y_max = map(int, bbox)
            cropped_images.append(high_res_image[y_min:y_max, x_min:x_max])

        texts = []
        for img in cropped_images:
            result = ocr.ocr(img)
            # result[0] is None when nothing was recognised in the crop.
            if result and result[0] is not None:
                # Each line's res[1][0] is taken as its text — assumes PaddleOCR's
                # (box, (text, confidence)) layout; confirm for the installed version.
                texts.append(" ".join(res[1][0] for res in result[0]))

        # Find the closest matches for the detected text using matcher
        # (find_best_match_leven(text) is an alternative fuzzywuzzy-based matcher).
        matched_signs = []
        for text in texts:
            matched_sign = matcher.find_best_match(text)
            # find_best_match returns the sign string on success and the
            # (None, 0) sentinel when nothing clears its threshold.
            if isinstance(matched_sign, str):
                matched_signs.append(matched_sign)

        total_time += time.time() - start_time
        timed_images += 1

        # Save OCR texts and matched signs as separate fields, and retain original annotations
        save_results.append({
            'image_name': image_name,
            'ocr_texts': texts,               # OCR results for the detected bounding boxes
            'matched_signs': matched_signs,   # Matched road signs from OCR
            'original_annotations': entry['matched_signs']  # Original matched signs from annotation file
        })
    except KeyboardInterrupt:
        raise
    except Exception as e:
        # Best-effort: skip the image, but report it instead of silently dropping it.
        label = image_name if image_name is not None else f"entry #{idx}"
        failed_images.append(label)
        print(f"Error processing {label}: {e}", flush=True)

# Report timing over the images that were actually processed (avoids dividing by
# zero for an empty annotation file and counting failed images).
if timed_images:
    print(f"Average time taken: {total_time/timed_images}")
else:
    print("Average time taken: n/a (no images were processed successfully)")
if failed_images:
    print(f"{len(failed_images)} image(s) failed to process and were skipped.")

# Save the OCR outputs, matched signs, and annotations to a JSON file
with open(final_output_file, 'w') as f:
    json.dump(save_results, f, indent=4)

print("Processed and saved OCR, matched signs, and original annotations separately.")

