import argparse
import math
import cv2
import os
import numpy as np
from PIL import Image
import rasterio
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from keras.models import load_model
import matplotlib.gridspec as gridspec
from data_util import DataLoader
from h5Image import H5Image
from unet_util import (UNET_224, Residual_CNN_block,
                        attention_up_and_concatenate2, dice_coef,
                        dice_coef_loss, evaluate_prediction_result, jacard_coef,
                        multiplication, multiplication2)
# Module-level handle to the opened H5Image. main() assigns it, and the helper
# functions below read it instead of receiving the handle as a parameter.
h5_image = None
def save_plot_as_png(prediction_result, map_name, legend, outputPath):
    """Save a four-panel comparison figure for one legend of one map.

    The panels show, left to right: the ground-truth layer, the predicted
    segmentation, the full map, and the legend patch resized to the
    ground-truth layer's height and width. The figure is written to
    ``<outputPath>/<map_name>_<legend>_visual.png``.

    Args:
        prediction_result: Predicted segmentation to display; any array that
            ``imshow`` accepts.
        map_name: Name of the map inside the opened HDF5 file.
        legend: Name of the legend/layer being visualised.
        outputPath: Directory that receives the PNG; created if missing.
    """
    true_seg = h5_image.get_layer(map_name, legend)
    full_map = h5_image.get_map(map_name)
    legend_patch = h5_image.get_legend(map_name, legend)
    legend_resized = tf.image.resize(legend_patch, (true_seg.shape[0], true_seg.shape[1]))
    print("legend_resized.shape", legend_resized.shape, "unique values", np.unique(legend_patch), np.unique(legend_resized))

    # os.makedirs("") raises, so only create the directory when one is given.
    if outputPath:
        os.makedirs(outputPath, exist_ok=True)
    output_image_path = os.path.join(outputPath, f"{map_name}_{legend}_visual.png")

    # tf.image.resize returns float32. Cast back to uint8 so imshow does not
    # clip the values to [0, 1].
    # NOTE(review): assumes the legend patch holds 0-255 values, as the
    # division by 255 in perform_inference suggests - confirm.
    legend_display = np.clip(np.asarray(legend_resized), 0, 255).astype(np.uint8)

    panels = (
        ('True segmentation', true_seg),
        ('Predicted segmentation', prediction_result),
        ('Map', full_map),
        ('Legend', legend_display),
    )

    # A single figure is created and always closed, so repeated calls (one per
    # legend) do not accumulate open figures.
    fig = plt.figure(figsize=(20, 5))
    try:
        gs = gridspec.GridSpec(1, len(panels), width_ratios=[1] * len(panels))
        for index, (title, image) in enumerate(panels):
            ax = fig.add_subplot(gs[index])
            ax.imshow(image)
            ax.set_title(title)
        fig.savefig(output_image_path)
    finally:
        plt.close(fig)

def prediction_mask(prediction_result, map_name, legend, outputPath):
    """Threshold a prediction and zero it outside the map's largest contour.

    The map image is converted to grayscale, its four corners are flood-filled
    with white, and edges are extracted (adaptive threshold, Gaussian blur,
    Canny, dilation). The largest resulting contour by area is filled to form
    a binary mask, which is AND-ed with the prediction thresholded at 0.5.

    Args:
        prediction_result: 2D array of prediction scores.
            NOTE(review): presumably the same height and width as the map,
            since the contour is in map pixel coordinates - confirm.
        map_name: Name of the map inside the opened HDF5 file.
        legend: Unused; kept for interface compatibility.
        outputPath: Unused; kept for interface compatibility.

    Returns:
        A uint8 array shaped like ``prediction_result`` holding 1 where the
        prediction exceeds 0.5 inside the filled contour and 0 elsewhere. If
        no contour is found, the thresholded prediction is returned unmasked.
    """
    map_array = np.array(h5_image.get_map(map_name))
    print("map_array", map_array.shape)

    # Threshold first so the no-contour fallback below can return it directly.
    prediction_binary = (prediction_result > 0.5).astype(np.uint8)

    # NOTE(review): BGR2GRAY assumes BGR channel order; if the stored map is
    # RGB the grey weights differ slightly - confirm against H5Image.
    gray = cv2.cvtColor(map_array, cv2.COLOR_BGR2GRAY)

    # Flood-fill from the four corners (x, y) with white.
    height, width = gray.shape[:2]
    for corner in ((0, 0), (0, height - 1), (width - 1, 0), (width - 1, height - 1)):
        cv2.floodFill(gray, None, corner, 255)

    # Adaptive threshold, then blur + Canny + dilation to get closed edges.
    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 21, 4)
    thresh_blur = cv2.GaussianBlur(thresh, (11, 11), 0)
    canny = cv2.Canny(thresh_blur, 0, 200)
    canny_dilate = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))

    # findContours returns (contours, hierarchy) in OpenCV 4 and
    # (image, contours, hierarchy) in OpenCV 3; [-2] is the contours in both.
    contours = cv2.findContours(canny_dilate, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]
    if not contours:
        print("No contour found in the map; returning the unmasked prediction.")
        return prediction_binary

    largest_contour = max(contours, key=cv2.contourArea)

    # Build the mask directly as uint8 so it matches the prediction's dtype.
    mask = np.zeros(prediction_result.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, pts=[largest_contour], color=1)

    return cv2.bitwise_and(prediction_binary, mask)

def perform_inference(legend_patch, map_patch, model):
    """Run the model on one map patch paired with its legend patch.

    Both inputs are resized to ``h5_image.patch_size`` square and divided by
    255, concatenated along the channel axis (map first, legend second),
    mapped through ``x * 2 - 1`` (i.e. to [-1, 1] when the inputs are 0-255),
    and passed to the model as a batch of one.

    Args:
        legend_patch: The legend patch from the map.
        map_patch: The map patch for inference.
        model: The trained model; must expose ``predict``.

    Returns:
        The model output for the patch with singleton dimensions removed.
    """
    target_size = (h5_image.patch_size, h5_image.patch_size)

    legend_resized = tf.cast(tf.image.resize(legend_patch, target_size), dtype=tf.float32) / 255.0
    map_patch_resized = tf.cast(tf.image.resize(map_patch, target_size), dtype=tf.float32) / 255.0

    # Both tensors are already patch_size x patch_size here, so no further
    # resize of the concatenated input is needed.
    input_patch = tf.concat(axis=2, values=[map_patch_resized, legend_resized])
    input_patch = input_patch * 2.0 - 1.0

    # Add the batch dimension expected by the model.
    prediction = model.predict(tf.expand_dims(input_patch, axis=0))

    return prediction.squeeze()

def save_results(prediction, map_name, legend, outputPath):
    """Write a prediction to ``<outputPath>/<map_name>_<legend>.tif``.

    Args:
        prediction: The prediction result as a numpy array. It is multiplied
            by 255 and cast to uint8 before saving, so it is expected to hold
            values in [0, 1].
        map_name: The name of the map.
        legend: The legend associated with the prediction.
        outputPath: The directory where the result is saved; created if
            missing.
    """
    # os.makedirs("") raises, so only create the directory when one is given.
    if outputPath:
        os.makedirs(outputPath, exist_ok=True)
    output_image_path = os.path.join(outputPath, f"{map_name}_{legend}.tif")

    # Scale to the 0-255 range of an 8-bit image.
    prediction_image = Image.fromarray((prediction * 255).astype(np.uint8))

    # Save the prediction as a TIFF image.
    prediction_image.save(output_image_path, 'TIFF')

    # TODO: once georeferencing data is available, write the result as a
    # georeferenced raster (e.g. with rasterio) instead of a plain TIFF.

# Substring that marks each feature type inside a layer name.
_FEATURE_MARKERS = {"Polygon": "_poly", "Point": "_pt", "Line": "_line"}


def _select_legends(all_map_legends, feature_type):
    """Return the layer names that match the requested feature type."""
    if feature_type == "All":
        return all_map_legends
    if feature_type not in _FEATURE_MARKERS:
        raise ValueError(f"Unknown featureType: {feature_type!r}")
    marker = _FEATURE_MARKERS[feature_type]
    return [legend for legend in all_map_legends if marker in legend]


def _load_inference_model(model_path):
    """Load the Keras model, registering the custom objects it may reference."""
    # Entries the saved model does not reference are simply unused, so the
    # loss/metric functions are registered for both model variants.
    custom_objects = {'dice_coef_loss': dice_coef_loss,
                      'dice_coef': dice_coef,
                      'jacard_coef': jacard_coef}
    if "attention" in model_path:
        # The attention U-Net additionally needs its custom multiplication ops.
        print(f"Loading model with attention from {model_path}")
        custom_objects.update({'multiplication': multiplication,
                               'multiplication2': multiplication2})
    else:
        print(f"Loading standard model from {model_path}")
    return load_model(model_path, custom_objects=custom_objects)


def _predict_full_map(map_name, legend, model, map_width, map_height):
    """Stitch per-patch predictions for one legend into a full-size array."""
    patch_size = h5_image.patch_size
    full_prediction = np.zeros((map_width, map_height))
    legend_patch = h5_image.get_legend(map_name, legend)

    for row in range(math.ceil(map_width / patch_size)):
        x_start = row * patch_size
        # Clamp so the last row of patches does not run past the map.
        x_end = min(x_start + patch_size, map_width)
        for col in range(math.ceil(map_height / patch_size)):
            y_start = col * patch_size
            y_end = min(y_start + patch_size, map_height)

            map_patch = h5_image.get_patch(row, col, map_name)
            prediction = perform_inference(legend_patch, map_patch, model)
            print(f"Prediction for patch ({row}, {col}) completed.")

            # Edge patches are cropped to the part that lies inside the map.
            full_prediction[x_start:x_end, y_start:y_end] = prediction[:x_end - x_start, :y_end - y_start]
    return full_prediction


def main(args):
    """Orchestrate the map inference process.

    Opens the HDF5 file, takes its first map, selects the legends that match
    ``args.featureType``, loads the model, and for every selected legend
    predicts the full map patch by patch, masks the prediction, and writes
    the visualisation and the result image to ``args.outputPath``.

    Args:
        args: Parsed command-line arguments; reads ``mapPath``,
            ``featureType``, ``modelPath`` and ``outputPath``.

    Raises:
        ValueError: If ``args.featureType`` is not a supported value.
    """
    global h5_image

    # Load the HDF5 file using the H5Image class
    print("Loading the HDF5 file.")
    h5_image = H5Image(args.mapPath, mode='r', patch_border=0)
    try:
        # Only the first map in the file is processed.
        print("Getting map details.")
        map_name = h5_image.get_maps()[0]
        print(f"Map Name: {map_name}")
        all_map_legends = h5_image.get_layers(map_name)
        print(f"All Map Legends: {all_map_legends}")

        map_legends = _select_legends(all_map_legends, args.featureType)

        # The first two entries are used as the extents of the prediction
        # array's first and second axes respectively.
        map_width, map_height, _ = h5_image.get_map_size(map_name)

        model = _load_inference_model(args.modelPath)

        for legend in map_legends:
            print(f"Processing legend: {legend}")
            full_prediction = _predict_full_map(map_name, legend, model, map_width, map_height)

            # Mask out the map background pixels from the prediction
            print("Applying mask to the full prediction.")
            masked_prediction = prediction_mask(full_prediction, map_name, legend, args.outputPath)
            save_plot_as_png(masked_prediction, map_name, legend, args.outputPath)

            print("Saving results.")
            save_results(masked_prediction, map_name, legend, args.outputPath)

        print("Inference process completed. Closing HDF5 file.")
    finally:
        # H5Image's API is not visible here, so close only if it exposes close().
        close = getattr(h5_image, "close", None)
        if callable(close):
            close()
if __name__ == "__main__":
    # Command-line interface setup
    parser = argparse.ArgumentParser(description="Perform inference on a given map.")
    parser.add_argument("--mapPath", required=True, help="Path to the hdf5 file.")
    parser.add_argument("--jsonPath", required=True, help="Path to the JSON file that contain positions of legends.")
    parser.add_argument("--featureType", choices=["Polygon", "Point", "Line", "All"], default="Polygon", help="Type of feature to detect. Four options: Polygon, Point, Line, or All.")
    parser.add_argument("--outputPath", required=True, help="Path to save the inference results. ")
    parser.add_argument("--modelPath", default="./inference_model/Unet-attentionUnet.h5", help="Path to the trained model. Default is './inference_model/Unet-attentionUnet.h5'.")
    args = parser.parse_args()

    # Run the pipeline; without this call the script only parses arguments.
    main(args)

# Test command (replace <script>.py with this file's name)
# python <script>.py --mapPath "/projects/bbym/shared/data/commonPatchData/256/OK_250K.hdf5" --jsonPath "" --featureType "Polygon" --outputPath "/projects/bbym/nathanj/attentionUnet/infer_results" --modelPath "/projects/bbym/nathanj/attentionUnet/inference_model/Unet-attentionUnet.h5"