Skip to content
Snippets Groups Projects
Commit 211a5fd8 authored by Nattapon Jaroenchai's avatar Nattapon Jaroenchai
Browse files

Update inference.py

parent 7f72fdca
No related branches found
No related tags found
No related merge requests found
......@@ -6,18 +6,58 @@ import numpy as np
from PIL import Image
import rasterio
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from keras.models import load_model
from data_util import DataLoader
from h5Image import H5Image
from unet_util import (UNET_224, Residual_CNN_block,
attention_up_and_concatenate,
attention_up_and_concatenate2, dice_coef,
dice_coef_loss, evaluate_prediction_result, jacard_coef,
multiplication, multiplication2)
attention_up_and_concatenate,
attention_up_and_concatenate2, dice_coef,
dice_coef_loss, evaluate_prediction_result, jacard_coef,
multiplication, multiplication2)
# Declare h5_image as a global variable to streamline data access across functions
h5_image = None
def save_plot_as_png(prediction_result, legend_img, map_name, legend, outputPath):
global h5_image
true_seg = h5_image.get_layer(map_name, legend)
full_map = h5_image.get_map(map_name)
legend_resized = tf.image.resize(legend_img, (h5_image.patch_size, h5_image.patch_size))
output_image_path = os.path.join(outputPath, f"{map_name}_{legend}_visual.png")
fig, axarr = plt.subplots(1, 4, figsize=(20,5))
# Load images
img1 = true_seg
img2 = prediction_result
img3 = full_map
img4 = legend_resized
# Display images
axarr[0].imshow(img1)
axarr[0].set_title('True segmentation')
axarr[0].axis('off')
axarr[1].imshow(img2)
axarr[1].set_title('Predicted segmentation')
axarr[1].axis('off')
axarr[2].imshow(img3)
axarr[2].set_title('Map')
axarr[2].axis('off')
axarr[3].imshow(img4)
axarr[3].set_title('Legend')
axarr[3].axis('off')
plt.tight_layout()
plt.savefig(output_image_path)
def prediction_mask(prediction_result, map_name, legend, outputPath):
"""
Apply a mask to the prediction image to isolate the area of interest.
......@@ -69,11 +109,6 @@ def prediction_mask(prediction_result, map_name, legend, outputPath):
# Perform the bitwise operation with the mask also converted to uint8
masked_img = cv2.bitwise_and(prediction_result_int, mask)
# Save the intermediate images using PIL
Image.fromarray((prediction_result * 255).astype(np.uint8)).save(os.path.join(outputPath, f"{map_name}__{legend}_prediction_result_x255.tif"))
Image.fromarray(mask).save(os.path.join(outputPath, f"{map_name}__{legend}_mask_uint8.tif"))
return masked_img
def perform_inference(legend_patch, map_patch, model):
......@@ -91,7 +126,11 @@ def perform_inference(legend_patch, map_patch, model):
global h5_image
legend_resized = tf.image.resize(legend_patch, (h5_image.patch_size, h5_image.patch_size))
legend_resized = tf.cast(tf.io.decode_png(legend_resized), dtype=tf.float32) / 255.0
map_patch_resize = tf.image.resize(map_patch, (h5_image.patch_size, h5_image.patch_size))
map_patch_resize = tf.cast(tf.io.decode_png(map_patch_resize), dtype=tf.float32) / 255.0
print("map_patch", map_patch.shape, "legend_patch", legend_resized.shape)
# Concatenate along the third axis and normalize
......@@ -109,7 +148,7 @@ def perform_inference(legend_patch, map_patch, model):
return prediction.squeeze()
def save_results(prediction, outputPath, map_name, legend):
def save_results(prediction, map_name, legend, outputPath):
"""
Save the prediction results to a specified output path.
......@@ -241,9 +280,11 @@ def main(args):
print("Applying mask to the full prediction.")
masked_prediction = prediction_mask(full_prediction, map_name, legend, args.outputPath)
save_plot_as_png(masked_prediction, legend_patch, map_name, legend, args.outputPath)
# Save the results
print("Saving results.")
save_results(masked_prediction, args.outputPath, map_name, legend)
save_results(masked_prediction, map_name, legend, args.outputPath)
# Close the HDF5 file
print("Inference process completed. Closing HDF5 file.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment