# Make a function that takes an image and converts it into a numpy array
def image_to_numpy(image_path):
    """Load the image at *image_path* and return its pixels as a NumPy array.

    Args:
        image_path: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        numpy.ndarray: the decoded pixel data, with whatever shape/dtype
        ``numpy.array`` produces for the image's PIL mode.
    """
    import numpy as np
    # ``import PIL as Image`` would bind the *package*, which has no ``open``
    # function; ``Image.open`` lives in the ``PIL.Image`` submodule.
    from PIL import Image

    # Close the underlying file handle as soon as the pixels are copied out;
    # ``np.array`` forces the lazy decode while the file is still open.
    with Image.open(image_path) as image:
        return np.array(image)

# Make a function that overlays the context_output list on the image
def make_context_image(image_numpy, context_output, save_path, save_path_context_output_jsons):
    """Overlay each entry of *context_output* on a copy of *image_numpy*.

    The image height is split into ``len(context_output)`` equal bands. Entry
    ``i`` is drawn right-aligned at the top of band ``i``, in white text with a
    black outline.

    If *save_path* is given, the annotated image is written there as
    ``<YYYY-mm-dd_HH-MM-SS>.jpg``. If *save_path_context_output_jsons* is also
    given, *context_output* is dumped to that directory as ``<same timestamp>.json``.

    Args:
        image_numpy: pixel array accepted by ``PIL.Image.fromarray``. It is
            copied, never modified.
        context_output: list of items to draw. Non-string items are drawn via
            ``str()``. They must be JSON-serialisable if the JSON dump is requested.
        save_path: directory for the annotated JPEG, or ``None`` to skip saving.
            Created if missing.
        save_path_context_output_jsons: directory for the JSON dump, or ``None``.
            Only used when *save_path* is not ``None``. Created if missing.

    Returns:
        PIL.Image.Image: the annotated image.
    """
    import datetime
    import json
    import os

    from PIL import Image, ImageDraw

    # Work on a copy so the caller's array is never touched.
    img = Image.fromarray(image_numpy.copy())
    draw = ImageDraw.Draw(img)
    font = _load_overlay_font(20)

    text_color = (255, 255, 255)
    bg_color = (0, 0, 0)
    stroke_width = 2

    width, height = img.size

    # With an empty list there is nothing to draw; the unmodified copy is
    # still saved/returned below.
    if context_output:
        band_height = height / len(context_output)
        for i, item in enumerate(context_output):
            text = str(item)
            # Right-align: place the origin so the outlined ink ends at the image edge.
            x = width - _text_right_edge(font, text, stroke_width)
            draw.text((x, band_height * i), text, font=font, fill=text_color,
                      stroke_width=stroke_width, stroke_fill=bg_color)

    if save_path is not None:
        # One timestamp shared by the image and the JSON so the two can be paired.
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

        os.makedirs(save_path, exist_ok=True)
        # JPEG cannot store alpha/palette/high-bit-depth modes. Convert only
        # the copy being written; the returned image keeps its original mode.
        jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
        to_save = img if img.mode in jpeg_modes else img.convert("RGB")
        # os.path.join works whether or not the directory has a trailing slash.
        to_save.save(os.path.join(save_path, timestamp + ".jpg"))

        if save_path_context_output_jsons is not None:
            os.makedirs(save_path_context_output_jsons, exist_ok=True)
            json_path = os.path.join(save_path_context_output_jsons, timestamp + ".json")
            with open(json_path, "w", encoding="utf-8") as outfile:
                json.dump(context_output, outfile)

    return img


def _load_overlay_font(size):
    """Return FreeMono at *size* if installed, else Pillow's bundled default font."""
    from PIL import ImageFont

    try:
        return ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf", size)
    except OSError:
        # The font file is distro-specific. Degrade to Pillow's default font
        # rather than failing the whole overlay.
        return ImageFont.load_default()


def _text_right_edge(font, text, stroke_width=0):
    """Return the x-extent (from the draw origin) of *text*'s ink, including its outline."""
    if hasattr(font, "getbbox"):
        # FreeTypeFont.getsize was removed in Pillow 10; getbbox replaces it.
        return font.getbbox(text, stroke_width=stroke_width)[2]
    # Fallback for old Pillow versions that predate getbbox.
    return font.getsize(text)[0]


# Make a function that runs the context detection on an image in CADRE
def run_context_detection(image_numpy, vqa_pipeline=None,
                          context_questions_path='./context_questions_ttsdet.json'):
    """Answer the context questions for one frame and return the positive outputs.

    Args:
        image_numpy: pixel array accepted by ``PIL.Image.fromarray``.
        vqa_pipeline: a pipeline previously returned by ``load_vqa_model``.
            When ``None`` (the default, as before), a fresh one is loaded on
            every call. Pass one in to avoid reloading the model per frame.
        context_questions_path: JSON file holding the context questions.

    Returns:
        list: the ``output_on_yes`` value of every question answered
        positively, in evaluation order.
    """
    from PIL import Image

    image = Image.fromarray(image_numpy)

    if vqa_pipeline is None:
        vqa_pipeline = load_vqa_model()

    context_questions = load_context_questions(context_questions_path)

    # ``evaluate`` appends the positive results into this list in place.
    context_output = []
    evaluate(context_output, vqa_pipeline, image, context_questions)

    # NOTE: to visualise/persist a result, call
    # make_context_image(image_numpy, context_output, 'DrivingContexts/save_path/',
    #                    'DrivingContexts/context_output_jsons/')
    return context_output

# Make the load_vqa_model function
def load_vqa_model(checkpoint_path='model_finetuned.pth',
                   model_name="dandelin/vilt-b32-finetuned-vqa"):
    """Build a ViLT visual-question-answering pipeline with fine-tuned weights.

    The base model and processor are loaded from *model_name*. The weights
    stored under ``'model_state_dict'`` in *checkpoint_path* are then loaded
    on top of them. Everything is placed on CUDA when available, else CPU.

    Args:
        checkpoint_path: file readable by ``torch.load`` containing a dict with
            a ``'model_state_dict'`` entry.
        model_name: Hugging Face model id for the base model and processor.

    Returns:
        A ``transformers`` "visual-question-answering" pipeline.
    """
    import warnings
    from collections import OrderedDict

    import torch
    from transformers import ViltForQuestionAnswering, ViltProcessor, pipeline

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    model = ViltForQuestionAnswering.from_pretrained(model_name)
    processor = ViltProcessor.from_pretrained(model_name)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    state_dict = checkpoint['model_state_dict']

    # Keys saved from a (Distributed)DataParallel wrapper typically carry a
    # leading "module." prefix. Strip only that prefix; a plain str.replace
    # would also mangle keys containing "module." in the middle.
    prefix = "module."
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

    # strict=False tolerates mismatches. Report them, though, so a wrong
    # checkpoint does not silently leave the model at its base weights.
    incompatible = model.load_state_dict(new_state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            "Checkpoint %r did not fully match the model: %d missing key(s), "
            "%d unexpected key(s)"
            % (checkpoint_path, len(incompatible.missing_keys),
               len(incompatible.unexpected_keys))
        )

    model.to(device)

    # Use the detected device. A hard-coded device=0 would demand CUDA even
    # on CPU-only machines.
    vqa_pipeline = pipeline(
        "visual-question-answering",
        model=model,
        tokenizer=processor.tokenizer,
        image_processor=processor.image_processor,
        device=device,
    )

    return vqa_pipeline

# Make a function to load the context questions
def load_context_questions(context_questions_path):
    """Load and return the context-question definitions from a JSON file.

    Args:
        context_questions_path: path to the JSON file.

    Returns:
        The decoded JSON value. Callers in this module treat it as a dict
        keyed by (sub)category name.
    """
    import json

    # JSON text is UTF-8 by specification. Don't depend on the platform's
    # locale encoding (e.g. cp1252 on Windows), which garbles non-ASCII questions.
    with open(context_questions_path, encoding="utf-8") as json_file:
        return json.load(json_file)

# Get the keys of a dictionary. Use this to get the keys of the context questions dictionary (which are the context categories)
def get_keys(dictionary):
    """Return a new list of *dictionary*'s keys, in the dictionary's iteration order."""
    return list(dictionary.keys())

# Make a function that evaluates the context questions
def evaluate(context_output, vqa_pipeline, image, context_questions):
    """Evaluate every top-level entry of *context_questions* against *image*.

    All top-level keys are handed, in order, to ``evaluate_sub_categories``,
    which appends positive results to *context_output* in place. Returns ``None``.
    """
    evaluate_sub_categories(
        context_output,
        vqa_pipeline,
        image,
        context_questions,
        list(context_questions.keys()),
    )

# Get the category from subcategory
def get_category_from_subcategory(context_output, vqa_pipeline, image_path, context_questions, subcategory):
    """Return the top-level key of *context_questions* whose value contains *subcategory*.

    If no entry contains *subcategory* but *subcategory* is itself a top-level
    key, this tries to evaluate that entry's sub-keys and returns ``None``.
    Otherwise it falls off the end and also returns ``None``.

    ``image_path`` is forwarded unchanged as the ``image`` argument of
    ``evaluate_sub_category``.

    NOTE(review): no live call site is visible. The only reference, in
    ``evaluate_sub_category``, is commented out.
    """
    # Find the key in context_questions that contains the subcategory
    keys = get_keys(context_questions)

    for key in keys:
        if subcategory in context_questions[key]:
            return key
    
    #print("Error: Subcategory " + subcategory + " not found")
    #print("Looking for category " + subcategory)

    # Check if subcategory is a key in context_questions
    category = subcategory
    if category in keys:
       #print("Found " + category + " in context_questions dictionary")
       # Evaluate the first subcategory of category
       sub_categories = get_keys(context_questions[category])
       # #print that we are now evaluating the first subcategory
       ##print("Now Evaluating: " + sub_categories[0])
       # Evaluate the first subcategory
       # NOTE(review): BUG — this passes the whole ``sub_categories`` *list* as
       # the single ``sub_category`` argument. evaluate_sub_category then does
       # ``context_questions[sub_category]``, and lists are unhashable, so this
       # raises TypeError. The comment above suggests ``sub_categories[0]`` was
       # intended; possibly evaluate_sub_categories (plural) was meant instead.
       # Confirm before re-enabling.
       evaluate_sub_category(context_output, vqa_pipeline, image_path, context_questions, sub_categories)
       return None

# Make the get_answer function
def get_answer(vqa_pipeline, image, question):
    """Ask *question* about *image* and return the top answer plus its latency.

    Args:
        vqa_pipeline: callable invoked as ``vqa_pipeline(image, question, top_k=1)``.
            Callers in this module read ``result[0]['answer']`` and
            ``result[0]['score']``, so it is expected to return a list of dicts.
        image: image object passed straight through to the pipeline.
        question: the question text.

    Returns:
        list: the pipeline's own result list (mutated in place), with
        ``{'time_taken': seconds}`` appended as its last element.
    """
    import time

    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock, which can jump (NTP, DST), so it is unsuitable for durations.
    start = time.perf_counter()
    answer = vqa_pipeline(image, question, top_k=1)
    time_taken = time.perf_counter() - start

    answer.append({'time_taken': time_taken})
    return answer

# Make a function that adds the output_on_yes information to the context_output list
def add_context_info(context_output, sub_category, output_on_yes):
    """Record a detected context by appending *output_on_yes* to *context_output*.

    The list is mutated in place and ``None`` is returned. ``sub_category`` is
    currently unused.
    """
    context_output.extend([output_on_yes])

# Make a function that evaluates a list of categories
def evaluate_sub_categories(context_output, vqa_pipeline, image, context_questions, sub_categories):
    """Run ``evaluate_sub_category`` once per name in *sub_categories*, in order.

    Positive results accumulate in *context_output*. Returns ``None``.
    """
    for name in sub_categories:
        evaluate_sub_category(
            context_output, vqa_pipeline, image, context_questions, name
        )

# Make a function that evaluates a subcategory
def evaluate_sub_category(context_output, vqa_pipeline, image, context_questions, sub_category):
    """Ask the question configured for *sub_category* and record a positive result.

    ``context_questions[sub_category]`` must provide these keys:
    ``question``, ``answer_positive``, ``answer_negative``, ``yes_min_score``,
    ``no_min_score``, ``next_streams_on_yes``, ``next_streams_on_no``,
    ``next_streams_on_failure`` and ``output_on_yes``. A missing key raises
    ``KeyError`` before the model is queried.

    If the top answer equals ``answer_positive`` with a score strictly above
    ``yes_min_score``, ``output_on_yes`` is appended to *context_output*.
    Returns ``None``.

    NOTE(review): the negative/no-score/next-stream values are read, and so
    validated, but are not used. Follow-up evaluation of
    ``next_streams_on_yes`` is currently disabled.
    """
    entry = context_questions[sub_category]

    # Read in the original order so a missing key reports the same KeyError.
    question = entry['question']
    positive_answer = entry['answer_positive']
    _negative_answer = entry['answer_negative']
    yes_threshold = entry['yes_min_score']
    _no_threshold = entry['no_min_score']
    _next_on_yes = entry['next_streams_on_yes']
    _next_on_no = entry['next_streams_on_no']
    _next_on_failure = entry['next_streams_on_failure']
    output_on_yes = entry['output_on_yes']

    result = get_answer(vqa_pipeline, image, question)
    top = result[0]

    # Guard clause: anything short of a confident positive answer records nothing.
    if not (top['answer'] == positive_answer and top['score'] > yes_threshold):
        return None

    add_context_info(context_output, sub_category, output_on_yes)

