import argparse
import math
import os

import cv2
import numpy as np
from PIL import Image
import rasterio
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from keras.models import load_model
import matplotlib.gridspec as gridspec

from data_util import DataLoader
from h5Image import H5Image
from unet_util import (UNET_224, Residual_CNN_block,
                       attention_up_and_concatenate,
                       attention_up_and_concatenate2, dice_coef,
                       dice_coef_loss, evaluate_prediction_result,
                       jacard_coef, multiplication, multiplication2)

# Declare h5_image as a global variable to streamline data access across
# functions. It is assigned in main() and read by every other function.
h5_image = None

# Substring that identifies the layers of each feature type; "All" is handled
# separately because it applies no filter.
_LEGEND_SUFFIXES = {"Polygon": "_poly", "Point": "_pt", "Line": "_line"}


def save_plot_as_png(prediction_result, map_name, legend, outputPath):
    """
    Visualize and save the true segmentation, predicted segmentation, full
    map and legend side by side in a single PNG image.

    Parameters:
    - prediction_result: 2D numpy array representing the predicted segmentation.
    - map_name: string, the name of the map.
    - legend: string, the name of the legend.
    - outputPath: string, the directory where the output image will be saved.

    Returns:
    - None. The image is written to "<outputPath>/<map_name>_<legend>_visual.png".
    """
    # Fetch the three reference images from the global h5 image object
    true_seg = h5_image.get_layer(map_name, legend)
    full_map = h5_image.get_map(map_name)
    legend_patch = h5_image.get_legend(map_name, legend)

    # Resize the legend to a fixed size for display
    legend_resized = cv2.resize(legend_patch, (256, 256))

    # cv2.resize returns a numpy array, so the dtype check and the conversion
    # to the uint8 range [0, 255] must use numpy (there is no .numpy() method)
    if legend_resized.dtype == np.float32:
        legend_resized = (legend_resized * 255).astype(np.uint8)

    output_image_path = os.path.join(outputPath, f"{map_name}_{legend}_visual.png")

    panels = (
        (true_seg, 'True segmentation'),
        (prediction_result, 'Predicted segmentation'),
        (full_map, 'Map'),
        (legend_resized, 'Legend'),
    )

    # One row of four equally sized subplots
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    try:
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')

        # Adjust layout to ensure there's no overlap, then save
        fig.tight_layout()
        fig.savefig(output_image_path)
    finally:
        # This function runs once per legend; without closing, every figure
        # would stay alive in pyplot's registry for the rest of the run.
        plt.close(fig)


def prediction_mask(prediction_result, map_name):
    """
    Apply a mask to the prediction image to isolate the area of interest.

    The map is converted to grayscale, flood-filled from its four corners,
    thresholded and edge-detected; the largest resulting contour is used as
    the region in which predictions are kept.

    Parameters:
    - prediction_result: numpy array, The output of the model after prediction.
    - map_name: str, The name of the map used for prediction.

    Returns:
    - masked_img: numpy uint8 array of 0/1 values, the prediction thresholded
      at 0.5 and restricted to the largest contour. If no contour is found,
      the thresholded prediction is returned unmasked.
    """
    # Get the map array corresponding to the given map name
    map_array = np.array(h5_image.get_map(map_name))

    # Convert the map array to grayscale for further processing
    gray = cv2.cvtColor(map_array, cv2.COLOR_BGR2GRAY)

    # Flood fill from the corners to identify and modify the background regions
    # (cv2 seed points are given as (x, y), i.e. (column, row))
    height, width = gray.shape[:2]
    corners = ((0, 0), (0, height - 1), (width - 1, 0), (width - 1, height - 1))
    for corner in corners:
        cv2.floodFill(gray, None, corner, 255)

    # Adaptive thresholding to remove small noise and artifacts
    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY_INV, 21, 4)

    # Detect edges using the Canny edge detection method
    thresh_blur = cv2.GaussianBlur(thresh, (11, 11), 0)
    canny = cv2.Canny(thresh_blur, 0, 200)
    canny_dilate = cv2.dilate(
        canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))

    # Detect contours in the edge-detected image
    contours, _ = cv2.findContours(canny_dilate, cv2.RETR_LIST,
                                   cv2.CHAIN_APPROX_NONE)

    # Convert prediction result to a binary format using a threshold
    prediction_result_int = (prediction_result > 0.5).astype(np.uint8)

    # Without any contour there is no region to mask against; keep the
    # thresholded prediction instead of failing on an empty sequence.
    if not contours:
        return prediction_result_int

    # Retain only the largest contour
    contour = max(contours, key=cv2.contourArea)

    # Build a 0/1 mask with the same 2D shape as the prediction
    mask = np.zeros(prediction_result.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, pts=[contour], color=1)

    # Apply the mask to the thresholded prediction result
    masked_img = cv2.bitwise_and(prediction_result_int, mask)

    return masked_img


def perform_inference(legend_patch, map_patch, model):
    """
    Perform inference on a given map patch and legend patch using a trained model.

    Parameters:
    - legend_patch: numpy array, The legend patch from the map.
    - map_patch: numpy array, The map patch for inference.
    - model: tensorflow.keras Model, The trained deep learning model.

    Returns:
    - prediction: numpy array, The prediction result for the given map patch,
      with singleton dimensions (including the batch dimension) squeezed out.
    """
    target_size = (h5_image.patch_size, h5_image.patch_size)

    # Resize the legend patch to the h5 image patch size and normalize to [0,1]
    legend_resized = cv2.resize(legend_patch, target_size)
    legend_resized = tf.cast(legend_resized, dtype=tf.float32) / 255.0

    # Resize the map patch to the h5 image patch size and normalize to [0,1]
    map_patch_resize = cv2.resize(map_patch, target_size)
    map_patch_resize = tf.cast(map_patch_resize, dtype=tf.float32) / 255.0

    # Concatenate the map and legend patches along the third axis (channels)
    # and rescale to [-1,1]
    input_patch = tf.concat(axis=2, values=[map_patch_resize, legend_resized])
    input_patch = input_patch * 2.0 - 1.0

    # Resize the concatenated input to the patch size. Both inputs were
    # already resized to this size above, so this only guards the model input.
    input_patch_resized = tf.image.resize(input_patch, target_size)

    # Expand the dimensions of the input patch (models expect a batch dimension)
    input_patch_expanded = tf.expand_dims(input_patch_resized, axis=0)

    # Obtain the prediction from the trained model
    prediction = model.predict(input_patch_expanded, verbose=0)

    return prediction.squeeze()


def save_results(prediction, map_name, legend, outputPath):
    """
    Save the prediction results as an LZW-compressed GeoTIFF.

    Parameters:
    - prediction: The prediction result (a 2D or 3D numpy array). A 2D array
      is written as a single band; a 3D array is treated as channels-last and
      each channel is written as one band.
    - map_name: The name of the map.
    - legend: The legend associated with the prediction.
    - outputPath: The directory where the results should be saved.

    Raises:
    - ValueError: if the prediction is neither 2D nor 3D.
    """
    output_image_path = os.path.join(outputPath, f"{map_name}_{legend}.tif")

    # Scale the prediction to the uint8 range [0, 255]
    image = (np.asarray(prediction) * 255).astype(np.uint8)

    # rasterio expects band-first data of shape (count, height, width)
    if image.ndim == 2:
        image = image[np.newaxis, ...]
    elif image.ndim == 3:
        image = np.moveaxis(image, -1, 0)
    else:
        raise ValueError(
            f"prediction must be a 2D or 3D array, got {image.ndim} dimensions")

    # The context manager closes the dataset so the file is flushed to disk
    with rasterio.open(
        output_image_path,
        'w',
        driver='GTiff',
        compress='lzw',
        height=image.shape[1],
        width=image.shape[2],
        count=image.shape[0],
        dtype=image.dtype,
        crs=h5_image.get_crs(map_name, legend),
        transform=h5_image.get_transform(map_name, legend),
    ) as dataset:
        dataset.write(image)


def _select_legends(all_map_legends, feature_type):
    """Return the legends of all_map_legends that match feature_type."""
    if feature_type == "All":
        return all_map_legends
    try:
        suffix = _LEGEND_SUFFIXES[feature_type]
    except KeyError:
        raise ValueError(f"Unknown feature type: {feature_type!r}") from None
    return [legend for legend in all_map_legends if suffix in legend]


def _load_inference_model(model_path):
    """Load the Keras model, adding the attention custom objects when the path contains "attention"."""
    custom_objects = {'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef}
    if "attention" in model_path:
        # The attention Unet additionally needs its custom attention layers
        print(f"Loading model with attention from {model_path}")
        custom_objects.update({'multiplication': multiplication,
                               'multiplication2': multiplication2})
    else:
        print(f"Loading standard model from {model_path}")
    return load_model(model_path, custom_objects=custom_objects)


def _predict_full_map(map_name, legend, model, map_rows, map_cols):
    """Run patch-wise inference for one legend and stitch the patches into one array."""
    patch_size = h5_image.patch_size
    num_rows = math.ceil(map_rows / patch_size)
    num_cols = math.ceil(map_cols / patch_size)

    # Create an empty array to store the full prediction
    full_prediction = np.zeros((map_rows, map_cols))
    legend_patch = h5_image.get_legend(map_name, legend)

    for row in range(num_rows):
        # Clamp the end index so the last patch does not run past the map
        x_start = row * patch_size
        x_end = min(x_start + patch_size, map_rows)
        for col in range(num_cols):
            y_start = col * patch_size
            y_end = min(y_start + patch_size, map_cols)

            map_patch = h5_image.get_patch(row, col, map_name)
            prediction = perform_inference(legend_patch, map_patch, model)

            # Crop the prediction to the part that lies inside the map
            full_prediction[x_start:x_end, y_start:y_end] = (
                prediction[:x_end - x_start, :y_end - y_start])

    return full_prediction


def main(args):
    """
    Main function to orchestrate the map inference process.

    Opens the HDF5 file, selects the legends of the first map that match the
    requested feature type, runs patch-wise inference for each legend, masks
    the result and saves both a PNG visualization and a GeoTIFF.

    Parameters:
    - args: Command-line arguments with the attributes mapPath, featureType,
      outputPath and modelPath.

    Raises:
    - ValueError: if args.featureType is not Polygon, Point, Line or All.
    """
    global h5_image

    # Fail early on an unsupported feature type, before any file is opened
    if args.featureType != "All" and args.featureType not in _LEGEND_SUFFIXES:
        raise ValueError(f"Unknown feature type: {args.featureType!r}")

    # Make sure the results can be written
    os.makedirs(args.outputPath, exist_ok=True)

    # Load the HDF5 file using the H5Image class
    print("Loading the HDF5 file.")
    h5_image = H5Image(args.mapPath, mode='r', patch_border=0)

    try:
        # Get map details (only the first map in the file is processed)
        print("Getting map details.")
        map_name = h5_image.get_maps()[0]
        print(f"Map Name: {map_name}")
        all_map_legends = h5_image.get_layers(map_name)

        # Filter the legends based on the feature type
        map_legends = _select_legends(all_map_legends, args.featureType)

        # The first two values of the map size are used as the row and column
        # extents of the stitched prediction array
        map_rows, map_cols, _ = h5_image.get_map_size(map_name)

        model = _load_inference_model(args.modelPath)

        for legend in map_legends:
            print(f"Processing legend: {legend}")
            full_prediction = _predict_full_map(
                map_name, legend, model, map_rows, map_cols)

            # Mask out the map background pixels from the prediction
            print("Applying mask to the full prediction.")
            masked_prediction = prediction_mask(full_prediction, map_name)

            save_plot_as_png(masked_prediction, map_name, legend, args.outputPath)

            # Save the results
            print("Saving results.")
            save_results(masked_prediction, map_name, legend, args.outputPath)

        print("Inference process completed. Closing HDF5 file.")
    finally:
        # Close the HDF5 file even when inference fails part-way through
        h5_image.close()


if __name__ == "__main__":
    # Command-line interface setup
    parser = argparse.ArgumentParser(description="Perform inference on a given map.")
    parser.add_argument("--mapPath", required=True, help="Path to the hdf5 file.")
    parser.add_argument("--featureType", choices=["Polygon", "Point", "Line", "All"], default="Polygon", help="Type of feature to detect. Four options: Polygon, Point, Line, and All.")
    parser.add_argument("--outputPath", required=True, help="Path to save the inference results. ")
    parser.add_argument("--modelPath", default="./inference_model/Unet-attentionUnet.h5", help="Path to the trained model. Default is './inference_model/Unet-attentionUnet.h5'.")
    args = parser.parse_args()
    main(args)

# Test command
# python inference.py --mapPath "/projects/bbym/shared/data/commonPatchData/256/OK_250K.hdf5" --featureType "Polygon" --outputPath "/projects/bbym/nathanj/attentionUnet/infer_results" --modelPath "/projects/bbym/nathanj/attentionUnet/inference_model/Unet-attentionUnet.h5"