Skip to content
Snippets Groups Projects
inference.py 9.35 KiB
Newer Older
import argparse
Nattapon Jaroenchai's avatar
Nattapon Jaroenchai committed
import math
import cv2
Nattapon Jaroenchai's avatar
Nattapon Jaroenchai committed
import os
import numpy as np
import rasterio
import tensorflow as tf
from keras.models import load_model
Nattapon Jaroenchai's avatar
Nattapon Jaroenchai committed
from data_util import DataLoader
from h5Image import H5Image
from unet_util import (UNET_224, Residual_CNN_block,
                       attention_up_and_concatenate,
                       attention_up_and_concatenate2, dice_coef,
                       dice_coef_loss, evaluate_prediction_result, jacard_coef,
                       multiplication, multiplication2)
# Shared handle to the opened HDF5 dataset. main() stores an H5Image here, and
# the other functions in this module read it through `global h5_image`
# instead of receiving it as a parameter.
h5_image: "H5Image | None" = None
def prediction_mask(prediction_result, map_array=None):
    """
    Zero out prediction pixels that lie outside the map's main content area.

    The map image is converted to greyscale, its four corners are flood-filled
    with white, and an adaptive threshold + blur + Canny + dilation pipeline
    produces edges. The largest contour among those edges is treated as the
    area of interest; prediction pixels outside it are set to 0.

    Parameters:
    - prediction_result: Prediction array for the whole map (numpy array).
      Its first two dimensions define the size of the mask.
    - map_array: Optional map image to derive the mask from. When omitted, it
      is read from the module-level ``h5_image`` via ``get_map()``.
      NOTE(review): must be an image accepted by ``cv2.COLOR_BGR2GRAY``
      (presumably 3-channel uint8) — confirm against H5Image.get_map.

    Returns:
    - masked_img: Array with the same shape and dtype as ``prediction_result``
      in which pixels outside the largest contour are 0. If no contour is
      detected, ``prediction_result`` is returned unchanged rather than
      discarding the whole prediction.
    """
    global h5_image

    if map_array is None:
        map_array = h5_image.get_map()

    # Convert the map array to a grayscale image (a new array, so the
    # in-place flood fill below does not touch map_array).
    gray = cv2.cvtColor(map_array, cv2.COLOR_BGR2GRAY)

    # Flood fill from each corner with white; presumably this merges the
    # margin around the map into a uniform background before edge detection.
    height, width = gray.shape[:2]
    for seed in ((0, 0), (0, height - 1), (width - 1, 0), (width - 1, height - 1)):
        cv2.floodFill(gray, None, seed, 255)

    # Adaptive threshold to suppress noise, then blur + Canny for edges and
    # dilate so nearby edge fragments join into closed contours.
    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 21, 4)
    thresh_blur = cv2.GaussianBlur(thresh, (11, 11), 0)
    canny = cv2.Canny(thresh_blur, 0, 200)
    canny_dilate = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))

    contours, _ = cv2.findContours(canny_dilate, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
    if not contours:
        # Nothing to mask with; keep the prediction as-is.
        return prediction_result

    # Keep only the largest detected contour.
    contour = max(contours, key=cv2.contourArea)

    # NOTE(review): contour coordinates are in map-image pixel space while the
    # mask is sized from prediction_result; they only line up if both arrays
    # have the same height/width — confirm.
    mask = np.zeros(prediction_result.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, pts=[contour], color=1)

    # Select with numpy instead of cv2.bitwise_and: bitwise_and requires both
    # inputs to share an OpenCV-supported dtype (the float prediction and an
    # int64 mask do not), and AND-ing float bit patterns is not masking.
    keep = mask.astype(bool)
    keep = keep.reshape(keep.shape + (1,) * (prediction_result.ndim - 2))
    masked_img = np.where(keep, prediction_result, 0).astype(prediction_result.dtype, copy=False)

    return masked_img

def perform_inference(legend_patch, map_patch, model):
    """
    Perform inference on a given map patch and legend patch using a trained model.

    The map and legend patches are concatenated along the channel axis,
    linearly rescaled with ``x * 2 - 1``, resized to
    ``(h5_image.patch_size, h5_image.patch_size)`` and passed to the model as
    a batch of one.

    Parameters:
    - legend_patch: The legend patch from the map.
      NOTE(review): must have the same height/width as ``map_patch`` for the
      axis-2 concat to succeed — confirm what H5Image.get_layer returns.
    - map_patch: The map patch for inference.
    - model: The trained deep learning model (Keras model with ``predict``).

    Returns:
    - prediction: The model output for this patch with all size-1 dimensions
      squeezed out.
    """
    global h5_image

    # Stack map and legend channels. The `* 2 - 1` maps [0, 1] onto [-1, 1]
    # (assumes float inputs already scaled to [0, 1] — TODO confirm; uint8
    # inputs would not combine with the float arithmetic).
    input_patch = tf.concat(axis=2, values=[map_patch, legend_patch])
    input_patch = input_patch * 2.0 - 1.0

    # Resize the input patch to match the model's expected input size.
    input_patch = tf.image.resize(input_patch, (h5_image.patch_size, h5_image.patch_size))

    # Add exactly ONE batch dimension: (H, W, C) -> (1, H, W, C). Wrapping the
    # result in a second expand_dims (as before) produced a 5-D input.
    input_batch = tf.expand_dims(input_patch, axis=0)

    prediction = model.predict(input_batch)

    return prediction.squeeze()

def save_results(prediction, outputPath, map_name, legend):
    """
    Save a prediction array as a single-band float32 GeoTIFF.

    The file is written to ``<outputPath>/<map_name>_<legend>.tif``. The
    output directory is created if it does not exist. No georeferencing
    (CRS/transform) is written yet.

    Parameters:
    - prediction: 2-D prediction array (anything ``np.asarray`` accepts).
    - outputPath: The directory where the results should be saved.
    - map_name: The name of the map.
    - legend: The legend associated with the prediction.

    Returns:
    - The path of the written file.

    Raises:
    - ValueError: if ``prediction`` is not 2-D.
    """
    output_image_path = os.path.join(outputPath, f"{map_name}_{legend}.tif")
    if outputPath:
        os.makedirs(outputPath, exist_ok=True)

    # Previously the array was written as `str(prediction)` into a text file.
    # That is not a TIFF, and numpy elides large arrays with "...", losing data.
    data = np.asarray(prediction, dtype=np.float32)
    if data.ndim != 2:
        raise ValueError(f"Expected a 2-D prediction array, got shape {data.shape}.")

    with rasterio.open(
        output_image_path,
        'w',
        driver='GTiff',
        height=data.shape[0],
        width=data.shape[1],
        count=1,
        dtype='float32',
        compress='lzw',
    ) as dst:
        dst.write(data, 1)

    # TODO: once georeferencing data is available, start from the source
    # map's metadata (e.g. `with rasterio.open(map_img_path) as src:
    # metadata = src.meta`) so the CRS and transform are carried over.

    return output_image_path

def _filter_legends(all_map_legends, feature_type):
    """Return the legends whose names contain the tag for ``feature_type``."""
    if feature_type == "All":
        return all_map_legends
    tag_by_type = {"Polygon": "_poly", "Point": "_pt", "Line": "_line"}
    try:
        tag = tag_by_type[feature_type]
    except KeyError:
        raise ValueError(f"Unknown feature type: {feature_type!r}") from None
    return [legend for legend in all_map_legends if tag in legend]


def _load_inference_model(model_path):
    """Load the Keras model, registering the custom objects it needs."""
    custom_objects = {'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef}
    # Attention U-Net checkpoints (detected by name) also need the custom
    # multiplication layers.
    if "attention" in model_path:
        print(f"Loading model with attention from {model_path}")
        custom_objects.update({'multiplication': multiplication,
                               'multiplication2': multiplication2})
    else:
        print(f"Loading standard model from {model_path}")
    return load_model(model_path, custom_objects=custom_objects)


def main(args):
    """
    Main function to orchestrate the map inference process.

    It opens the HDF5 file and selects the legends matching
    ``args.featureType``. For each legend it predicts every patch, stitches
    the patches into a full-map array, masks out the map background and
    saves the result. The HDF5 file is always closed, even on error.

    Parameters:
    - args: Command-line arguments (mapPath, featureType, outputPath, modelPath).
    """
    global h5_image

    # Load the HDF5 file using the H5Image class
    print("Loading the HDF5 file.")
    h5_image = H5Image(args.mapPath, mode='r')
    try:
        # Get map details
        print("Getting map details.")
        map_name = h5_image.get_maps()[0]
        print(f"Map Name: {map_name}")
        all_map_legends = h5_image.get_layers(map_name)
        print(f"All Map Legends: {all_map_legends}")

        map_legends = _filter_legends(all_map_legends, args.featureType)

        # NOTE(review): named width/height, but used below as axis 0 / axis 1
        # of the prediction array (i.e. rows, cols) — confirm the order
        # returned by get_map_size.
        map_width, map_height, _ = h5_image.get_map_size(map_name)
        patch_size = h5_image.patch_size
        print("h5_image.patch_size", patch_size)

        # Number of patches needed to cover the map (last row/col may be partial).
        num_rows = math.ceil(map_width / patch_size)
        num_cols = math.ceil(map_height / patch_size)

        model = _load_inference_model(args.modelPath)

        for legend in map_legends:
            print(f"Processing legend: {legend}")

            # Full-map canvas that the per-patch predictions are stitched into.
            full_prediction = np.zeros((map_width, map_height))

            # NOTE(review): this fetches the whole legend layer once per
            # legend, not per patch — confirm it has patch dimensions as
            # perform_inference's concat expects.
            legend_patch = h5_image.get_layer(map_name, legend)

            for row in range(num_rows):
                for col in range(num_cols):
                    map_patch = h5_image.get_patch(row, col, map_name)
                    prediction = perform_inference(legend_patch, map_patch, model)
                    print(f"Prediction for patch ({row}, {col}) completed.")

                    x_start = row * patch_size
                    y_start = col * patch_size
                    x_end = min(x_start + patch_size, map_width)
                    y_end = min(y_start + patch_size, map_height)
                    # Edge patches extend past the map border; crop the
                    # prediction so its shape matches the target slice.
                    full_prediction[x_start:x_end, y_start:y_end] = \
                        prediction[:x_end - x_start, :y_end - y_start]

            # Mask out the map background pixels from the prediction
            print("Applying mask to the full prediction.")
            masked_prediction = prediction_mask(full_prediction)

            print("Saving results.")
            save_results(masked_prediction, args.outputPath, map_name, legend)

        print("Inference process completed. Closing HDF5 file.")
    finally:
        h5_image.close()

if __name__ == "__main__":
    # Command-line interface setup
    parser = argparse.ArgumentParser(description="Perform inference on a given map.")
    parser.add_argument("--mapPath", required=True, help="Path to the hdf5 file.")
    # NOTE(review): --jsonPath is required but main() never reads it.
    parser.add_argument("--jsonPath", required=True, help="Path to the JSON file that contain positions of legends.")
    parser.add_argument("--featureType", choices=["Polygon", "Point", "Line", "All"], default="Polygon", help="Type of feature to detect. Four options: Polygon, Point, Line, and All.")
    parser.add_argument("--outputPath", required=True, help="Path to save the inference results. ")
    parser.add_argument("--modelPath", default="./inference_model/Unet-attentionUnet.h5", help="Path to the trained model. Default is './inference_model/Unet-attentionUnet.h5'.")

    args = parser.parse_args()
    main(args)

# Test command
# python inference.py --mapPath "/projects/bbym/shared/data/commonPatchData/256/OK_250K.hdf5" --jsonPath "" --featureType "Polygon" --outputPath "/projects/bbym/nathanj/attentionUnet/infer_results" --modelPath "/projects/bbym/nathanj/attentionUnet/inference_model/Unet-attentionUnet.h5"