import argparse
import math
import cv2
import os
import numpy as np
import rasterio
import tensorflow as tf
from keras.models import load_model
from data_util import DataLoader
from h5Image import H5Image
from unet_util import (UNET_224, Residual_CNN_block,
                       attention_up_and_concatenate,
                       attention_up_and_concatenate2, dice_coef,
                       dice_coef_loss, evaluate_prediction_result, jacard_coef,
                       multiplication, multiplication2)

# Module-level handle to the opened HDF5 map file. It starts out unset; main()
# assigns it, and the helper functions reach it through `global h5_image`
# instead of receiving it as an argument.
h5_image: "H5Image | None" = None

def _largest_map_contour(map_array):
    """
    Return the largest edge contour detected in a map image, or None if none is found.

    Parameters:
    - map_array: The map image; it is converted with cv2.COLOR_BGR2GRAY, so it
      is assumed to be a 3-channel BGR image (TODO confirm dtype is uint8).

    Returns:
    - The contour with the largest cv2.contourArea, or None when no contours exist.
    """
    # Convert the map array to a grayscale image (a new array, so the flood fill
    # below does not modify the caller's map).
    gray = cv2.cvtColor(map_array, cv2.COLOR_BGR2GRAY)

    # Flood fill from each corner with white. Presumably this flattens the margin
    # around the map so it does not produce edges (TODO confirm intent).
    # Seed points are (x, y).
    height, width = gray.shape[:2]
    for seed in ((0, 0), (0, height - 1), (width - 1, 0), (width - 1, height - 1)):
        cv2.floodFill(gray, None, seed, 255)

    # AdaptiveThreshold to remove noise.
    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 21, 4)

    # Edge detection: blur, Canny, then dilate so broken edges join up.
    thresh_blur = cv2.GaussianBlur(thresh, (11, 11), 0)
    canny = cv2.Canny(thresh_blur, 0, 200)
    canny_dilate = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))

    contours, _ = cv2.findContours(canny_dilate, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
    if not contours:
        return None
    # Keep only the largest detected contour (first one wins on ties).
    return max(contours, key=cv2.contourArea)


def prediction_mask(prediction_result):
    """
    Apply a mask to the prediction image to isolate the area of interest.

    The map image is read from the global ``h5_image``. The largest contour
    detected in it is filled to form a binary mask, and every prediction pixel
    outside that mask is set to 0.

    Parameters:
    - prediction_result: The full-map prediction array. Its first two axes must
      use the same pixel grid as the map returned by ``h5_image.get_map()``.

    Returns:
    - masked_img: A copy of ``prediction_result`` with the same shape and dtype,
      with pixels outside the detected map area set to 0. If no contour is
      detected, the copy is returned unmasked.
    """
    global h5_image

    prediction = np.asarray(prediction_result)
    masked_img = prediction.copy()

    contour = _largest_map_contour(h5_image.get_map())
    if contour is None:
        print("Warning: no map contour detected; returning the prediction unmasked.")
        return masked_img

    rows, cols = prediction.shape[:2]
    mask = np.zeros((rows, cols), dtype=np.uint8)
    cv2.fillPoly(mask, pts=[contour], color=(1, 1, 1))

    # Zero out pixels outside the contour with boolean indexing. A bitwise AND
    # would require matching dtypes and would operate on raw float bit
    # patterns, which does not preserve prediction values.
    masked_img[mask == 0] = 0
    return masked_img

def perform_inference(legend_patch, map_patch, model):
    """
    Run the model on one map patch paired with a legend patch.

    Steps: resize the legend to the patch size, stack it onto the map patch
    along axis 2, rescale with ``x * 2 - 1``, resize to the patch size again,
    add a leading batch axis, and call ``model.predict``.

    Parameters:
    - legend_patch: The legend image for the feature being detected.
    - map_patch: One patch of the map. Its first two dimensions must equal the
      patch size so it can be concatenated with the resized legend.
      NOTE(review): the ``* 2 - 1`` step assumes values in [0, 1] and a float
      dtype compatible with tf.image.resize output — confirm against H5Image.
    - model: The trained Keras model.

    Returns:
    - The prediction for the patch with all size-1 axes squeezed out.
    """
    global h5_image

    side = h5_image.patch_size
    target_size = (side, side)

    legend_scaled = tf.image.resize(legend_patch, target_size)
    print("map_patch", map_patch.shape, "legend_patch", legend_patch.shape)

    # Stack map and legend channels, then map [0, 1] onto [-1, 1].
    stacked = tf.concat(axis=2, values=[map_patch, legend_scaled])
    normalized = stacked * 2.0 - 1.0

    # Match the model's expected spatial size, then add the batch dimension.
    batch = tf.expand_dims(tf.image.resize(normalized, target_size), axis=0)

    return model.predict(batch).squeeze()

def save_results(prediction, outputPath, map_name, legend):
    """
    Save the prediction as a single-band TIFF at ``<outputPath>/<map_name>_<legend>.tif``.

    No CRS or geotransform is attached yet, so the file is not georeferenced.
    rasterio may emit a ``NotGeoreferencedWarning`` for this.

    Parameters:
    - prediction: 2-D prediction array (rows x columns).
    - outputPath: The directory where the results are saved; created if missing.
    - map_name: The name of the map (file-name prefix).
    - legend: The legend associated with the prediction (file-name suffix).

    Returns:
    - output_image_path: The path of the written file.

    Raises:
    - ValueError: If ``prediction`` is not 2-D.
    """
    data = np.asarray(prediction)
    if data.ndim != 2:
        raise ValueError(f"Expected a 2-D prediction array, got shape {data.shape}")
    if data.dtype == np.bool_:
        # GDAL has no boolean band type; store 0/1 as uint8.
        data = data.astype(np.uint8)

    os.makedirs(outputPath, exist_ok=True)
    output_image_path = os.path.join(outputPath, f"{map_name}_{legend}.tif")

    # Write a real raster rather than the array's text repr. numpy's str()
    # summarizes large arrays with "...", which would lose data.
    profile = {
        'driver': 'GTiff',
        'height': data.shape[0],
        'width': data.shape[1],
        'count': 1,
        'dtype': data.dtype.name,
        'compress': 'lzw',
    }
    with rasterio.open(output_image_path, 'w', **profile) as dst:
        dst.write(data, 1)

    # TODO: once georeferencing data is available, copy the source map raster's
    # metadata (e.g. `rasterio.open(map_img_path).meta`, including crs/transform)
    # into `profile` so the output is georeferenced.
    return output_image_path

def _select_legends(all_map_legends, feature_type):
    """Return the legend names whose name contains the tag for ``feature_type``; "All" keeps every legend."""
    tag_by_type = {"Polygon": "_poly", "Point": "_pt", "Line": "_line"}
    if feature_type == "All":
        return list(all_map_legends)
    if feature_type not in tag_by_type:
        raise ValueError(
            f"Unknown feature type {feature_type!r}; expected 'Polygon', 'Point', 'Line' or 'All'."
        )
    tag = tag_by_type[feature_type]
    return [legend for legend in all_map_legends if tag in legend]


def _load_inference_model(model_path):
    """Load the Keras model at ``model_path`` with the custom objects it needs."""
    custom_objects = {'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef}
    # Models whose path mentions "attention" additionally need the custom
    # `multiplication` / `multiplication2` objects.
    if "attention" in model_path:
        print(f"Loading model with attention from {model_path}")
        custom_objects.update({'multiplication': multiplication, 'multiplication2': multiplication2})
    else:
        print(f"Loading standard model from {model_path}")
    return load_model(model_path, custom_objects=custom_objects)


def _predict_full_map(map_name, legend_patch, model, map_rows, map_cols):
    """Predict every patch of the map for one legend and stitch the results into a (map_rows, map_cols) array."""
    patch_size = h5_image.patch_size
    num_rows = math.ceil(map_rows / patch_size)
    num_cols = math.ceil(map_cols / patch_size)

    # Allocate a canvas that spans whole patches in both directions so the last,
    # possibly partial, row/column of patches fits. Crop back to the map size
    # afterwards.
    canvas = np.zeros((num_rows * patch_size, num_cols * patch_size))
    for row in range(num_rows):
        for col in range(num_cols):
            map_patch = h5_image.get_patch(row, col, map_name)
            prediction = perform_inference(legend_patch, map_patch, model)
            print(f"Prediction for patch ({row}, {col}) completed.")

            row_start = row * patch_size
            col_start = col * patch_size
            canvas[row_start:row_start + patch_size, col_start:col_start + patch_size] = prediction

    return canvas[:map_rows, :map_cols]


def main(args):
    """
    Main function to orchestrate the map inference process.

    Opens the HDF5 file into the global ``h5_image`` and selects legends by
    ``args.featureType``. For each selected legend it predicts all patches,
    masks the stitched prediction, and saves it. The HDF5 file is closed even
    if an error occurs.

    Parameters:
    - args: Parsed command-line arguments (mapPath, featureType, outputPath, modelPath).

    Raises:
    - ValueError: If ``args.featureType`` is not one of Polygon, Point, Line, All.
    """
    global h5_image

    # Load the HDF5 file using the H5Image class
    print("Loading the HDF5 file.")
    h5_image = H5Image(args.mapPath, mode='r')
    try:
        # Get map details
        print("Getting map details.")
        map_name = h5_image.get_maps()[0]
        print(f"Map Name: {map_name}")
        all_map_legends = h5_image.get_layers(map_name)
        print(f"All Map Legends: {all_map_legends}")

        map_legends = _select_legends(all_map_legends, args.featureType)

        # The first two values of get_map_size are used as the extents of the
        # prediction's axis 0 (patch rows) and axis 1 (patch columns), matching
        # get_patch(row, col). NOTE(review): the original named them width/height;
        # confirm which H5Image actually returns first.
        map_rows, map_cols, _ = h5_image.get_map_size(map_name)

        model = _load_inference_model(args.modelPath)

        for legend in map_legends:
            print(f"Processing legend: {legend}")
            legend_patch = h5_image.get_legend(map_name, legend)
            full_prediction = _predict_full_map(map_name, legend_patch, model, map_rows, map_cols)

            # Mask out the map background pixels from the prediction
            print("Applying mask to the full prediction.")
            masked_prediction = prediction_mask(full_prediction)

            print("Saving results.")
            save_results(masked_prediction, args.outputPath, map_name, legend)

        print("Inference process completed. Closing HDF5 file.")
    finally:
        h5_image.close()

def build_arg_parser():
    """
    Build the command-line parser for this inference script.

    Returns:
    - argparse.ArgumentParser accepting --mapPath, --jsonPath, --featureType,
      --outputPath and --modelPath.
    """
    parser = argparse.ArgumentParser(description="Perform inference on a given map.")
    parser.add_argument("--mapPath", required=True, help="Path to the hdf5 file.")
    # NOTE: main() does not currently read --jsonPath; it stays required for CLI compatibility.
    parser.add_argument("--jsonPath", required=True, help="Path to the JSON file that contains positions of legends.")
    parser.add_argument("--featureType", choices=["Polygon", "Point", "Line", "All"], default="Polygon",
                        help="Type of feature to detect. Four options: Polygon, Point, Line, and All. Default is 'Polygon'.")
    parser.add_argument("--outputPath", required=True, help="Path to save the inference results.")
    parser.add_argument("--modelPath", default="./inference_model/Unet-attentionUnet.h5",
                        help="Path to the trained model. Default is './inference_model/Unet-attentionUnet.h5'.")
    return parser


if __name__ == "__main__":
    # Command-line interface setup
    main(build_arg_parser().parse_args())

# Test command
# python inference.py --mapPath "/projects/bbym/shared/data/commonPatchData/256/OK_250K.hdf5" --jsonPath "" --featureType "Polygon" --outputPath "/projects/bbym/nathanj/attentionUnet/infer_results" --modelPath "/projects/bbym/nathanj/attentionUnet/inference_model/Unet-attentionUnet.h5"