import os
import cv2
import json
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pytesseract
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import easyocr
from fuzzywuzzy import process
import certifi
import ssl
import re
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
from difflib import SequenceMatcher

# SSL settings for HTTPS requests.
# Trust certifi's CA bundle: OpenSSL honours SSL_CERT_FILE when a default SSL
# context loads its verify paths. The former
# ``ssl._create_default_https_context = ssl._create_unverified_context``
# override is intentionally gone: it disabled certificate and hostname
# checks for every urllib/http.client HTTPS request in this process
# (presumably including model-weight downloads), exposing them to
# man-in-the-middle tampering.
os.environ['SSL_CERT_FILE'] = certifi.where()

# Paths (relative paths resolve against the current working directory).
folder_path = 'AV_OCR/dataset/'
output_folder = 'AV_OCR/results/'
baseline1_output_file = os.path.join(output_folder, "results_baseline_trocr_apr30data.json")
baseline2_output_file = os.path.join(output_folder, "results_baseline_tesseract_apr30data.json")
baseline3_output_file = os.path.join(output_folder, "results_baseline_easyocr_apr30data.json")
annotations_file = 'AV_OCR/annotations/full_annotations.json'

# Load the annotation entries. RoadSignDataset indexes into this and reads
# 'image_name' and 'matched_signs' from each entry (presumably a list of dicts).
# JSON text is UTF-8 by specification, so don't depend on the locale default.
with open(annotations_file, 'r', encoding='utf-8') as file:
    annotation_data = json.load(file)

# Canonical road-sign texts that OCR output is matched against (passed to
# SignTextMatcher_v2 further down). Matching returns these strings verbatim.
# NOTE(review): several entries look like typos or annotation artefacts, e.g.
# 'LANE ENDS HERGE RIGHT', 'ROAD WORD AHEAD', 'BE PREPARED T0 STOP' (digit
# zero), 'PAT NARROWS', 'P OR STOPPING', 'NOT TURN ON RED',
# 'PUST BUTTON FOR WALK SIGNAL', 'LEFT LEFT LANE NO TRUCKS', plus the blank
# ' ' entry (empty after preprocess_text). They presumably mirror labels in the
# annotation file, so they are left unchanged. Confirm before correcting.
road_signs = ['BEGIN ONE WAY', 'FREEWAY ENDS', 'LOW OR SOFT SHOULDER', 'NO SHOULDER ON BRIDGE', 'RUNAWAY TRUCK RAMP 2000 FEET', 'NO DRIVING ON SHOULDER', 'CAUTION ICY SPOTS', 'ROAD ENDS', 'REDUCE SPEED NOW', '35 MP.H.', 'ONE LANE ROAD AHEAD', 'NARROW BRIDGE', 'USE TURN SIGNAL', "DON'T TEXT AND DRIVE", 'ONE LANE BRIDGE AHEAD', 'BRIDGE ICES BEFORE ROAD', 'ACTIVE WORK ZONE WHEN FLASHING', 'SOFT SHOULDER', 'WATCH FOR LIVESTOCK', 'TRUCK CROSSING', 'NEXT 3 MILES', 'SIDEWALK CLOSED', 'NEXT 11 MILES', 'LANE ENDS HERGE RIGHT', 'TURN RIGHT ONLY', 'STOP', 'ROAD WORD AHEAD', 'TRUCKS AND BUSES USE LOW GEAR', 'RAMP 35 MPH', 'SHOULDER CLOSED', 'HOV- 2 LANE BUSES & CARPOOLS LEFT LANE', 'NO TURN ON RED', 'GUSTY WINDS MAY EXIST', 'CAUTION', 'DETOUR', 'NEXT 1/2 MILE', 'END ONE WAY', 'SPEED LIMIT PHOTO ENFORCED', 'CENTER LANE ONLY', 'WEIGHT LIMIT 10 TONS', '3-WAY', 'LAST EXIT BEFORE TOLLWAY', 'EXCEPT RIGHT TURN', 'SPEED LIMIT', 'LANE ENDS MERGE LEFT', 'EMERGENCY USE ONLY', '25 MPH', 'SHARE THE ROAD', 'END DETOUR', 'LEFT LANE ENDS', 'RIGHT TURN ONLY', 'HOV LANE', 'BRIDGE MAY BE ICY', 'LEFT TURN YIELD ON GREEN', 'YIELD TO PEDS IN CROSSWALK', 'CONSTRUCTION ENTRANCE ONLY', 'EXIT 25 MPH', 'ROAD CLOSED', 'LEFT TURN YIELD TO ONCOMING TRAFFIC', 'YIELD HERE TO PEDESTRIANS', 'SHOULDER WORK AHEAD', 'STOP HERE ON RED', 'REDUCED SPEED 50 AHEAD', '45 MPH SPEED ZONE AHEAD', 'SPEED LIMIT 45', 'BE PREPARED T0 STOP', 'BRAKE RETARDERS PROHIBITED', 'LEFT LANE MUST TURN LEFT', 'ONE LANE BRIDGE', ' ', 'SPEED LIMIT 55', 'SPEED HUMP', 'TRUCKS USE LOW GEAR', 'WRONG WAY', 'LEFT LANE CLOSED AHEAD', 'DEAD END', 'DRAW BRIDGE AHEAD', 'UNAUTHORIZED VEHICLES PROHIBITED', 'WATCH CHILDREN', 'LEFT TURN YIELD ON FLASHING YELLOW ARROW', 'EXCEPT BIKES', 'ROAD NARROWS', 'PARK & RIDE', 'RIGHT LANE MUST TURN RIGHT', 'MAY USE FULL LANE', 'YIELD', 'BRIDGE MAY BE SLIPPERY', 'SCHOOL ZONE AHEAD', 'ROAD CLOSED AHEAD', 'LEFT LEFT LANE NO TRUCKS', 'ALL WAY', 'SCHOOL BUS STOP AHEAD', 'DETOUR AHEAD', 'NO MOTOR VEHICLES', 'ONE WAY', 
'ROAD MAY FLOOD', 'RIGHT LANE ENDS', 'KEEP RIGHT EXCEPT TO PASS', 'BIKE LANE ENDS', 'WEIGHT LIMIT 4 TONS', 'MERGE RIGHT', 'SPEED LIMIT 35', 'ROAD CLOSED TO THROUGH TRAFFIC', 'LOOSE GRAVEL', 'DEAF CHILD AREA', '5 MPH', 'ALL TRAFFIC MUST EXIT', 'LANE ENDS MERGE RIGHT', 'ALL TRAFFIC MUST TURN RIGHT', 'OVERNIGHT PARKING PROHIBITED', 'END SCHOOL ZONE', 'RUNAWAY TRUCK RAMP', 'REDUCED SPEED ZONE AHEAD', 'SPEED LIMIT 50', 'SHARE LANE WITH BICYCLISTS', 'RIGHT TURN SIGNAL', '4-WAY', 'WATCH FOR BICYCLISTS', '10 MPH', 'BRIDGE CLOSED', 'SPEED LIMIT 25', 'SPEED LIMIT 15', 'LEFT LANE MUST EXIT', 'ANTI ICING IN PROGRESS', 'PAVEMENT ENDS', 'RIGHT LANE ENDS MERGE LEFT', 'DO NOT ENTER', 'EMERGENCY STOPPING ONLY', '15 MPH', 'STEEL DECK', 'NO PARKING ON ANY STREET', 'PEDESTRIAN CROSSING', 'REST AREA', 'BLUE BELT', 'LEFT LANE NO TRUCKS', 'SPEED LIMIT 40', 'SPEED LIMIT 5 MPH', 'CROSS TRAFFIC DOES NOT STOP', 'LEFT TURN ONLY', 'STAY IN LANE', 'PARK AND RIDE', 'OPPOSING TRAFFIC DOES NOT STOP', 'PAT NARROWS', 'DRIVEWAY', 'TWO WAY TRAFFIC AHEAD', 'P OR STOPPING', 'WATCH FOR ROCKS', 'CAUTION MINIMUM MAINTENANCE ONLY', 'NOT TURN ON RED', 'BUSES ONLY', 'DIVIDED HIGHWAY', 'END ROAD WORK', 'PHOTO ENFORCED', "NO TRUCKS BUSSES AND RV'S", 'MOTOR VEHICLES ONLY', 'DO NOT PASS', 'BIKE LANE', 'RIGHT LANE CLOSED AHEAD', 'BUS STOP', 'DO NOT BLOCK INTERSECTION', 'SPEED 30', 'FINES HIGHER', 'SLOW CHILDREN AT PLAY', 'DRAW BRIDGE', '55 MPH', 'LANE CLOSED', 'ROUGH ROAD', 'WEIGHT LIMIT 5 TONS', 'BE PREPARED TO STOP', 'SPEED LIMIT 30', 'PUST BUTTON FOR WALK SIGNAL', '30 MPH', 'CAUTION CONSTRUCTION AREA SIDEWALK CLOSED', 'NO OUTLET', 'PATH NARROWS', 'TRUCK DETOUR', 'BIKE ROUTE', 'ROAD WORK AHEAD', 'BLIND PEDESTRIAN CROSSING', 'SCHOOL', 'PASS WITH CARE', 'HIDDEN DRIVEWAY', 'NEW TRAFFIC PATTERN AHEAD', 'BUMP', 'FINES DOUBLE', 'CAUTION HIDDEN DRIVEWAY', 'YOUR SPEED', 'RADAR ENFORCED', 'TRUCK ROUTE', 'DIP', 'WEIGHT LIMIT', 'SPEED LIMIT 70']

# Function to save the results in JSON format
def save_results(results, output_file):
    """Write ``results`` to ``output_file`` as indented JSON.

    Missing parent directories are created first. Otherwise a long OCR run
    would fail at the very end just because the results folder does not
    exist yet.

    Args:
        results: JSON-serialisable object (in this script, a list of
            per-image result dicts).
        output_file: Destination path; overwritten if it already exists.
    """
    parent_dir = os.path.dirname(output_file)
    if parent_dir:  # a bare filename has no directory component to create
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)

# Class to match OCR text to road signs using fuzzy string matching and several similarity measures
class SignTextMatcher_v2:
    """Match free-form OCR output to the closest entry of a fixed road-sign list.

    Each candidate sign is scored with a weighted blend of four measures (see
    ``find_best_match``): word-level Jaccard overlap, ``difflib`` sequence
    ratio, character n-gram TF-IDF similarity, and a numbers-agree flag.
    """

    def __init__(self, road_signs):
        """Fit the TF-IDF model on ``road_signs`` and cache per-sign preprocessing.

        Args:
            road_signs: List of canonical sign texts. Should be non-empty;
                fitting a TF-IDF vectorizer on an empty corpus fails.
        """
        self.road_signs = road_signs
        self.vectorizer = TfidfVectorizer(analyzer='char_wb', ngram_range=(2, 3), lowercase=True)
        # NOTE(review): the vectorizer is fitted on the *raw* sign texts while
        # queries go through preprocess_text() first, so punctuation n-grams
        # (e.g. in "35 MP.H." or "DON'T ...") can never match a query. Kept
        # as-is so existing baseline scores are unchanged.
        self.tfidf_matrix = self.vectorizer.fit_transform(road_signs)
        # The sign list is fixed, so normalise each sign once here instead of
        # on every query.
        self._processed_signs = [self.preprocess_text(sign) for sign in road_signs]

    def preprocess_text(self, text):
        """Clean and standardize text for better matching.

        Uppercases, drops every character other than A-Z, 0-9 and whitespace,
        and collapses runs of whitespace to single spaces.
        """
        text = text.upper().strip()
        text = re.sub(r'[^A-Z0-9\s]', '', text)
        text = ' '.join(text.split())
        return text

    def get_word_overlap_score(self, text1, text2):
        """Return the Jaccard similarity of the two texts' word sets (0 if both are empty)."""
        words1 = set(text1.split())
        words2 = set(text2.split())
        intersection = words1.intersection(words2)
        union = words1.union(words2)
        return len(intersection) / len(union) if union else 0

    def get_sequence_score(self, text1, text2):
        """Return ``difflib.SequenceMatcher`` similarity ratio in [0, 1]."""
        return SequenceMatcher(None, text1, text2).ratio()

    def get_tfidf_similarity(self, text):
        """Return TF-IDF similarities between ``text`` and every road sign.

        The result is a 1-D array aligned with ``self.road_signs``. It equals
        cosine similarity because TfidfVectorizer L2-normalises rows by
        default.
        """
        text_vector = self.vectorizer.transform([text])
        return np.array(text_vector.dot(self.tfidf_matrix.T).toarray())[0]

    def get_number_consistency(self, text1, text2):
        """Return 1.0 if both texts contain the same set of digit runs (or neither has any), else 0.0."""
        numbers1 = set(re.findall(r'\d+', text1))
        numbers2 = set(re.findall(r'\d+', text2))
        if not numbers1 and not numbers2:
            return 1.0
        if not numbers1 or not numbers2:
            return 0.0
        return 1.0 if numbers1 == numbers2 else 0.0

    def find_best_match(self, text, threshold=0.3):
        """Find the best-matching road sign for ``text``.

        score = 0.3 * word_overlap + 0.3 * sequence_ratio
                + 0.2 * tfidf_similarity + 0.2 * number_consistency

        Args:
            text: Raw OCR output.
            threshold: Minimum blended score a match must reach.

        Returns:
            The original (unprocessed) sign string with the highest score.
            Ties go to the earliest sign in ``road_signs``. Returns ``None``
            when ``text`` is empty after preprocessing or no sign reaches
            ``threshold``.
        """
        text = self.preprocess_text(text)
        if not text:
            return None

        # One TF-IDF transform per query. Previously it was recomputed (and
        # the sign's index re-found with list.index) for every sign, which
        # made each query O(n^2).
        tfidf_scores = self.get_tfidf_similarity(text)

        best_score = 0
        best_match = None

        for idx, (sign, processed_sign) in enumerate(zip(self.road_signs, self._processed_signs)):
            word_score = self.get_word_overlap_score(text, processed_sign)
            sequence_score = self.get_sequence_score(text, processed_sign)
            tfidf_score = tfidf_scores[idx]
            number_score = self.get_number_consistency(text, processed_sign)

            final_score = 0.3 * word_score + 0.3 * sequence_score + 0.2 * tfidf_score + 0.2 * number_score

            # Strict '>' keeps the first sign on ties.
            if final_score > best_score:
                best_score = final_score
                best_match = sign

        return best_match if best_score >= threshold else None


# Shared matcher used by match_road_sign(). It is built once at import time
# because its constructor fits the TF-IDF vectorizer on the whole sign list.
matcher = SignTextMatcher_v2(road_signs)

def match_road_sign(texts):
    """Map each OCR string in ``texts`` to its best road-sign match.

    Uses the module-level ``matcher``. Strings with no match above its
    threshold are dropped, so the returned list may be shorter than
    ``texts``. Order is preserved.
    """
    best_matches = (matcher.find_best_match(candidate) for candidate in texts)
    return [sign for sign in best_matches if sign is not None]

class RoadSignDataset(Dataset):
    """Map-style dataset yielding one annotated image per index.

    Each annotation entry must provide ``'image_name'`` (resolved against
    ``folder_path``) and ``'matched_signs'`` (its ground-truth sign texts).
    """

    def __init__(self, annotation_data, folder_path):
        """
        Args:
            annotation_data: Indexable sequence of annotation dicts (see class docstring).
            folder_path: Directory that image names are joined onto.
        """
        self.annotation_data = annotation_data
        self.folder_path = folder_path

    def __len__(self):
        return len(self.annotation_data)

    def __getitem__(self, idx):
        """Load the image for entry ``idx`` with OpenCV.

        Returns:
            Dict with ``'image'`` (array from ``cv2.imread`` with its default
            flags, i.e. 3-channel 8-bit BGR), ``'image_name'`` and
            ``'original_annotations'`` (the entry's ``'matched_signs'``).

        Raises:
            FileNotFoundError: If the image file does not exist.
            OSError: If OpenCV cannot decode the file.
        """
        entry = self.annotation_data[idx]
        image_name = entry['image_name']
        full_image_path = os.path.join(self.folder_path, image_name)
        # cv2.imread reports failure by returning None instead of raising.
        # Without these checks the problem surfaced later as an opaque
        # TypeError in custom_collate (torch.from_numpy(None)).
        if not os.path.isfile(full_image_path):
            raise FileNotFoundError(f"Image not found: {full_image_path}")
        image = cv2.imread(full_image_path)
        if image is None:
            raise OSError(f"OpenCV could not decode image: {full_image_path}")
        return {
            'image': image,
            'image_name': image_name,
            'original_annotations': entry['matched_signs']
        }

def custom_collate(batch):
    """Collate dataset samples into a single dict for the DataLoader.

    Each HWC numpy image becomes a CHW tensor. The tensors are stacked along
    a new leading batch dimension, so all images in a batch must share one
    size. Names and annotations stay as plain lists in sample order.
    """
    chw_images, names, annotations = [], [], []
    for sample in batch:
        chw_images.append(torch.from_numpy(sample['image']).permute(2, 0, 1))
        names.append(sample['image_name'])
        annotations.append(sample['original_annotations'])

    return {
        'image': torch.stack(chw_images),
        'image_name': names,
        'original_annotations': annotations,
    }



def _batch_to_hwc_arrays(images):
    """Turn a (B, C, H, W) uint8 image tensor into a list of C-contiguous (H, W, C) numpy arrays.

    Channel order is preserved. Batches built by custom_collate from
    RoadSignDataset come from cv2.imread and are therefore BGR.
    """
    return [np.ascontiguousarray(img.permute(1, 2, 0).byte().numpy()) for img in images]


def _build_records(image_names, ocr_texts, matched_signs, original_annotations):
    """Assemble the per-image result dicts written to each baseline's JSON file."""
    return [
        {'image_name': name, 'ocr_texts': [text], 'matched_signs': signs, 'original_annotations': annot}
        for name, text, signs, annot in zip(image_names, ocr_texts, matched_signs, original_annotations)
    ]


def process_batch(batch, processor, model, reader):
    """Run TrOCR, Tesseract and EasyOCR on one collated batch and match each output to a road sign.

    Args:
        batch: Dict from custom_collate with ``'image'`` (uint8 tensor
            (B, 3, H, W) in BGR order, as read by cv2.imread),
            ``'image_name'`` and ``'original_annotations'``.
        processor: TrOCRProcessor used for image preprocessing and decoding.
        model: TrOCR VisionEncoderDecoderModel; inputs are moved to its device.
        reader: easyocr.Reader instance.

    Returns:
        Dict with keys ``'trocr'``, ``'tesseract'`` and ``'easyocr'``. Each
        maps to a list of per-image records (``image_name``, ``ocr_texts``,
        ``matched_signs``, ``original_annotations``).
    """
    import time

    image_names = batch['image_name']
    original_annotations = batch['original_annotations']
    # The pixels are BGR (cv2.imread), so convert explicitly per engine.
    # Previously BGR was fed to TrOCR as if it were RGB, and COLOR_RGB2GRAY
    # was applied to BGR data for Tesseract.
    bgr_images = _batch_to_hwc_arrays(batch['image'])

    # ViT + TrOCR: the pretrained processor/model expect RGB input.
    start = time.perf_counter()
    rgb_images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in bgr_images]
    pixel_values = processor(images=rgb_images, return_tensors="pt").pixel_values.to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    texts_trocr = processor.batch_decode(generated_ids, skip_special_tokens=True)
    matched_trocr = [match_road_sign([text]) for text in texts_trocr]
    print(f"ViT + TrOCR processing time: {time.perf_counter() - start:.4f} seconds")

    # Tesseract on grayscale images.
    start = time.perf_counter()
    gray_images = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in bgr_images]
    texts_tesseract = [pytesseract.image_to_string(img).strip() for img in gray_images]
    matched_tesseract = [match_road_sign([text]) for text in texts_tesseract]
    print(f"Tesseract processing time: {time.perf_counter() - start:.4f} seconds")

    # EasyOCR: the BGR arrays are passed unchanged, as before.
    # NOTE(review): EasyOCR's ndarray path presumably assumes OpenCV-style BGR; confirm.
    start = time.perf_counter()
    results_easyocr = reader.readtext_batched(bgr_images)
    # Each detection is (bbox, text, confidence); join all text pieces per image.
    texts_easyocr = [" ".join(res[1] for res in result) if result else "" for result in results_easyocr]
    matched_easyocr = [match_road_sign([text]) for text in texts_easyocr]
    print(f"EasyOCR processing time: {time.perf_counter() - start:.4f} seconds")

    return {
        'trocr': _build_records(image_names, texts_trocr, matched_trocr, original_annotations),
        'tesseract': _build_records(image_names, texts_tesseract, matched_tesseract, original_annotations),
        'easyocr': _build_records(image_names, texts_easyocr, matched_easyocr, original_annotations),
    }

def main():
    """Run the three OCR baselines over the annotated dataset and save one JSON file per engine.

    Results collected so far are always saved, even if the loop is cut
    short by Ctrl-C or by an exception raised while loading data (e.g. an
    unreadable image). In the latter case the exception is re-raised after
    saving. Errors raised inside process_batch are reported and that batch
    is skipped.
    """
    dataset = RoadSignDataset(annotation_data, folder_path)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=2, pin_memory=True, collate_fn=custom_collate)

    processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-printed')
    model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-printed').to('cuda')
    model.eval()

    reader = easyocr.Reader(['en'])

    baseline1_results = []  # TrOCR
    baseline2_results = []  # Tesseract
    baseline3_results = []  # EasyOCR

    try:
        for batch in tqdm(dataloader):
            try:
                results = process_batch(batch, processor, model, reader)
            except Exception as e:
                # Best effort per batch: report which images failed and carry on.
                print(f"Error processing batch {batch['image_name']}: {e}")
                continue
            baseline1_results.extend(results['trocr'])
            baseline2_results.extend(results['tesseract'])
            baseline3_results.extend(results['easyocr'])
    except KeyboardInterrupt:
        # Also covers interrupts that arrive while the DataLoader is fetching,
        # which the old per-batch handler could not see.
        print("Keyboard interrupt detected. Exiting...")
    finally:
        # Persist whatever was processed, even if an error is propagating.
        save_results(baseline1_results, baseline1_output_file)
        save_results(baseline2_results, baseline2_output_file)
        save_results(baseline3_results, baseline3_output_file)

    print("Baseline outputs generated using TrOCR, Tesseract, and EasyOCR.")

# Start a run only when executed as a script. Importing this module still
# runs the module-level setup above (annotation loading, matcher
# construction) but not main().
if __name__ == "__main__":
    main()
