import json
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix

# Load JSON file
def load_json(file_path):
    """Read the JSON document stored at *file_path* and return the parsed value.

    The file is decoded as UTF-8 explicitly: without an ``encoding`` argument
    ``open`` falls back to the platform locale, which mis-decodes non-ASCII
    text on systems whose default is not UTF-8.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def load_environmental_context(file_path):
    """Read the environmental-context JSON at *file_path* and return it parsed.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default locale encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Prepare ground truth and predicted labels
def prepare_signs_data(data):
    """Index result records by image name, with both sign lists as sets.

    Each record in *data* must provide ``'image_name'``, ``'matched_signs'``
    (the predictions) and ``'original_annotations'`` (the ground truth).
    If several records share an image name, the last one wins.

    Returns:
        dict mapping image name to
        ``{'matched_signs': set, 'original_annotations': set}``.
    """
    return {
        record['image_name']: {
            'matched_signs': set(record['matched_signs']),
            'original_annotations': set(record['original_annotations']),
        }
        for record in data
    }

# Environmental contexts for which per-context label vectors are collected.
_CONTEXTS = ("DAYTIME", "NIGHTTIME", "SUNNY", "RAINY", "SNOWY", "FOGGY")

# Bare "<N> MPH" labels that are counted as the "SPEED LIMIT <N>" sign.
# Only these exact strings are rewritten; any other label is left untouched.
_SPEED_LIMIT_ALIASES = {
    f"{mph} MPH": f"SPEED LIMIT {mph}"
    for mph in (10, 15, 25, 35, 45, 55, 60, 65, 70)
}


def _normalize_signs(signs):
    """Return *signs* as a set with bare speed labels replaced by their alias."""
    return {_SPEED_LIMIT_ALIASES.get(sign, sign) for sign in signs}


# Prepare true and predicted labels for precision, recall, F1 score, and accuracy
def prepare_labels(data, road_signs, env_context):
    """Build binary ground-truth / prediction vectors for each context.

    For every image in *data* and every sign in *road_signs*, a 1 is recorded
    when the sign is present and a 0 when it is absent, once for the
    annotations and once for the predictions. The pair of values is appended
    to the lists of every context found in that image's *env_context* entry.
    Images with no entry, or with none of the known contexts, contribute
    nothing.

    Args:
        data: mapping of image name to a dict with ``'original_annotations'``
            and ``'matched_signs'`` iterables of sign labels.
        road_signs: sequence of sign labels to score, in output order.
        env_context: mapping of image name to the contexts of that image;
            membership is tested with ``ctx in env_context[img]``.

    Returns:
        ``(y_true, y_pred)``: two dicts keyed by context name, each holding a
        flat list of 0/1 values ordered by image, then by *road_signs*.
    """
    y_true = {ctx: [] for ctx in _CONTEXTS}
    y_pred = {ctx: [] for ctx in _CONTEXTS}

    for img, signs in data.items():
        true_signs = _normalize_signs(signs['original_annotations'])
        predicted_signs = _normalize_signs(signs['matched_signs'])

        # Which contexts apply does not depend on the sign, so decide it once
        # per image instead of once per (sign, context) pair.
        img_contexts = env_context.get(img, [])
        active_contexts = [ctx for ctx in _CONTEXTS if ctx in img_contexts]
        if not active_contexts:
            continue

        true_flags = [1 if sign in true_signs else 0 for sign in road_signs]
        pred_flags = [1 if sign in predicted_signs else 0 for sign in road_signs]

        for ctx in active_contexts:
            y_true[ctx].extend(true_flags)
            y_pred[ctx].extend(pred_flags)

    return y_true, y_pred

# Function to log and print error cases
import json

def log_errors_to_json(data, output_file):
    """Write every image whose predicted sign set differs from its annotations.

    One record is produced per mismatching image, holding the image name, the
    full ground truth and prediction lists, the signs that were annotated but
    not predicted (``missing_signs``) and those predicted but not annotated
    (``extra_signs``). All lists are sorted so that the output file is
    identical from run to run; set iteration order is not stable for strings.

    NOTE(review): labels are compared verbatim here, without the
    "<N> MPH" -> "SPEED LIMIT <N>" rewriting that prepare_labels applies, so
    such pairs are reported as mismatches - confirm that this is intended.

    Args:
        data: mapping of image name to a dict with ``'original_annotations'``
            and ``'matched_signs'`` iterables of sign labels.
        output_file: path of the JSON file to (over)write.
    """
    errors = []
    for img, signs in data.items():
        true_signs = set(signs['original_annotations'])
        predicted_signs = set(signs['matched_signs'])

        if true_signs == predicted_signs:
            continue

        errors.append({
            'image_name': img,
            'ground_truth': sorted(true_signs),
            'predictions': sorted(predicted_signs),
            'missing_signs': sorted(true_signs - predicted_signs),
            'extra_signs': sorted(predicted_signs - true_signs),
        })

    # Save errors to JSON
    with open(output_file, 'w') as f:
        json.dump(errors, f, indent=4)

    print(f"Errors logged to {output_file}")


# Load results and annotations from the same JSON file.
# Each record is expected to carry 'image_name', 'matched_signs' (predictions)
# and 'original_annotations' (ground truth); the path is relative, so it is
# resolved against the current working directory.
file_path = 'AV_OCR/results/results_nlp_matching.json'

data = load_json(file_path)

# Index the records by image name, with both sign lists converted to sets
prepared_data = prepare_signs_data(data)

# List of all possible road signs from annotations.
# Every entry is scored once per image by prepare_labels using exact string
# membership, so the labels are kept verbatim.
# NOTE(review): some entries look like misspellings or OCR artefacts (and one
# is a single space) - presumably they mirror the annotation data; confirm
# before "correcting" any of them.
road_signs = ['BEGIN ONE WAY', 'FREEWAY ENDS', 'LOW OR SOFT SHOULDER', 'NO SHOULDER ON BRIDGE', 'RUNAWAY TRUCK RAMP 2000 FEET', 'NO DRIVING ON SHOULDER', 'CAUTION ICY SPOTS', 'ROAD ENDS', 'REDUCE SPEED NOW', '35 MP.H.', 'ONE LANE ROAD AHEAD', 'NARROW BRIDGE', 'USE TURN SIGNAL', "DON'T TEXT AND DRIVE", 'ONE LANE BRIDGE AHEAD', 'BRIDGE ICES BEFORE ROAD', 'ACTIVE WORK ZONE WHEN FLASHING', 'SOFT SHOULDER', 'WATCH FOR LIVESTOCK', 'TRUCK CROSSING', 'NEXT 3 MILES', 'SIDEWALK CLOSED', 'NEXT 11 MILES', 'LANE ENDS HERGE RIGHT', 'TURN RIGHT ONLY', 'STOP', 'ROAD WORD AHEAD', 'TRUCKS AND BUSES USE LOW GEAR', 'RAMP 35 MPH', 'SHOULDER CLOSED', 'HOV- 2 LANE BUSES & CARPOOLS LEFT LANE', 'NO TURN ON RED', 'GUSTY WINDS MAY EXIST', 'CAUTION', 'DETOUR', 'NEXT 1/2 MILE', 'END ONE WAY', 'SPEED LIMIT PHOTO ENFORCED', 'CENTER LANE ONLY', 'WEIGHT LIMIT 10 TONS', '3-WAY', 'LAST EXIT BEFORE TOLLWAY', 'EXCEPT RIGHT TURN', 'SPEED LIMIT', 'LANE ENDS MERGE LEFT', 'EMERGENCY USE ONLY', '25 MPH', 'SHARE THE ROAD', 'END DETOUR', 'LEFT LANE ENDS', 'RIGHT TURN ONLY', 'HOV LANE', 'BRIDGE MAY BE ICY', 'LEFT TURN YIELD ON GREEN', 'YIELD TO PEDS IN CROSSWALK', 'CONSTRUCTION ENTRANCE ONLY', 'EXIT 25 MPH', 'ROAD CLOSED', 'LEFT TURN YIELD TO ONCOMING TRAFFIC', 'YIELD HERE TO PEDESTRIANS', 'SHOULDER WORK AHEAD', 'STOP HERE ON RED', 'REDUCED SPEED 50 AHEAD', '45 MPH SPEED ZONE AHEAD', 'SPEED LIMIT 45', 'BE PREPARED T0 STOP', 'BRAKE RETARDERS PROHIBITED', 'LEFT LANE MUST TURN LEFT', 'ONE LANE BRIDGE', ' ', 'SPEED LIMIT 55', 'SPEED HUMP', 'TRUCKS USE LOW GEAR', 'WRONG WAY', 'LEFT LANE CLOSED AHEAD', 'DEAD END', 'DRAW BRIDGE AHEAD', 'UNAUTHORIZED VEHICLES PROHIBITED', 'WATCH CHILDREN', 'LEFT TURN YIELD ON FLASHING YELLOW ARROW', 'EXCEPT BIKES', 'ROAD NARROWS', 'PARK & RIDE', 'RIGHT LANE MUST TURN RIGHT', 'MAY USE FULL LANE', 'YIELD', 'BRIDGE MAY BE SLIPPERY', 'SCHOOL ZONE AHEAD', 'ROAD CLOSED AHEAD', 'LEFT LEFT LANE NO TRUCKS', 'ALL WAY', 'SCHOOL BUS STOP AHEAD', 'DETOUR AHEAD', 'NO MOTOR VEHICLES', 'ONE WAY', 
'ROAD MAY FLOOD', 'RIGHT LANE ENDS', 'KEEP RIGHT EXCEPT TO PASS', 'BIKE LANE ENDS', 'WEIGHT LIMIT 4 TONS', 'MERGE RIGHT', 'SPEED LIMIT 35', 'ROAD CLOSED TO THROUGH TRAFFIC', 'LOOSE GRAVEL', 'DEAF CHILD AREA', '5 MPH', 'ALL TRAFFIC MUST EXIT', 'LANE ENDS MERGE RIGHT', 'ALL TRAFFIC MUST TURN RIGHT', 'OVERNIGHT PARKING PROHIBITED', 'END SCHOOL ZONE', 'RUNAWAY TRUCK RAMP', 'REDUCED SPEED ZONE AHEAD', 'SPEED LIMIT 50', 'SHARE LANE WITH BICYCLISTS', 'RIGHT TURN SIGNAL', '4-WAY', 'WATCH FOR BICYCLISTS', '10 MPH', 'BRIDGE CLOSED', 'SPEED LIMIT 25', 'SPEED LIMIT 15', 'LEFT LANE MUST EXIT', 'ANTI ICING IN PROGRESS', 'PAVEMENT ENDS', 'RIGHT LANE ENDS MERGE LEFT', 'DO NOT ENTER', 'EMERGENCY STOPPING ONLY', '15 MPH', 'STEEL DECK', 'NO PARKING ON ANY STREET', 'PEDESTRIAN CROSSING', 'REST AREA', 'BLUE BELT', 'LEFT LANE NO TRUCKS', 'SPEED LIMIT 40', 'SPEED LIMIT 5 MPH', 'CROSS TRAFFIC DOES NOT STOP', 'LEFT TURN ONLY', 'STAY IN LANE', 'PARK AND RIDE', 'OPPOSING TRAFFIC DOES NOT STOP', 'PAT NARROWS', 'DRIVEWAY', 'TWO WAY TRAFFIC AHEAD', 'P OR STOPPING', 'WATCH FOR ROCKS', 'CAUTION MINIMUM MAINTENANCE ONLY', 'NOT TURN ON RED', 'BUSES ONLY', 'DIVIDED HIGHWAY', 'END ROAD WORK', 'PHOTO ENFORCED', "NO TRUCKS BUSSES AND RV'S", 'MOTOR VEHICLES ONLY', 'DO NOT PASS', 'BIKE LANE', 'RIGHT LANE CLOSED AHEAD', 'BUS STOP', 'DO NOT BLOCK INTERSECTION', 'SPEED 30', 'FINES HIGHER', 'SLOW CHILDREN AT PLAY', 'DRAW BRIDGE', '55 MPH', 'LANE CLOSED', 'ROUGH ROAD', 'WEIGHT LIMIT 5 TONS', 'BE PREPARED TO STOP', 'SPEED LIMIT 30', 'PUST BUTTON FOR WALK SIGNAL', '30 MPH', 'CAUTION CONSTRUCTION AREA SIDEWALK CLOSED', 'NO OUTLET', 'PATH NARROWS', 'TRUCK DETOUR', 'BIKE ROUTE', 'ROAD WORK AHEAD', 'BLIND PEDESTRIAN CROSSING', 'SCHOOL', 'PASS WITH CARE', 'HIDDEN DRIVEWAY', 'NEW TRAFFIC PATTERN AHEAD', 'BUMP', 'FINES DOUBLE', 'CAUTION HIDDEN DRIVEWAY', 'YOUR SPEED', 'RADAR ENFORCED', 'TRUCK ROUTE', 'DIP', 'WEIGHT LIMIT', 'SPEED LIMIT 70']


# Load environmental context data.
# prepare_labels looks this mapping up by image name and tests each context
# label for membership in the value it finds.
env_context_file = 'AV_OCR/results/contextvlm_output_finetuned.json'
env_context = load_environmental_context(env_context_file)

# Prepare labels for precision, recall, F1 score, and accuracy:
# one 0/1 list per context for the ground truth and one for the predictions
y_true, y_pred = prepare_labels(prepared_data, road_signs, env_context)

# Calculate and print metrics for each environmental context
contexts = ["DAYTIME", "NIGHTTIME", "SUNNY", "RAINY", "SNOWY", "FOGGY"]

for ctx in contexts:
    print(f"\nMetrics for {ctx}:")
    if y_true[ctx]:
        precision, recall, f1, _ = precision_recall_fscore_support(y_true[ctx], y_pred[ctx], average='macro')
        accuracy = accuracy_score(y_true[ctx], y_pred[ctx])

        print(f'Accuracy: {accuracy:.4f}')
        print(f'Precision (Macro): {precision:.4f}')
        print(f'Recall (Macro): {recall:.4f}')
        print(f'F1 Score (Macro): {f1:.4f}')

        # Pin the label set so the matrix is always 2x2. Without it, a context
        # whose labels are all 0 (or all 1) yields a 1x1 matrix and the
        # four-way unpacking below raises ValueError.
        conf_matrix = confusion_matrix(y_true[ctx], y_pred[ctx], labels=[0, 1])
        # With labels=[0, 1], ravel() yields the cells in the order TN, FP, FN, TP.
        TN, FP, FN, TP = conf_matrix.ravel()
        print(f'True Positives (TP): {TP}')
        print(f'True Negatives (TN): {TN}')
        print(f'False Positives (FP): {FP}')
        print(f'False Negatives (FN): {FN}')
    else:
        print("No data available for this context")


# Record every image whose predicted sign set differs from its annotations
log_errors_to_json(prepared_data,'AV_OCR/results/mismatches.json')