Skip to content
Snippets Groups Projects
Commit 211a5fd8 authored by Nattapon Jaroenchai's avatar Nattapon Jaroenchai
Browse files

Update inference.py

parent 7f72fdca
No related branches found
No related tags found
1 merge request!1updated main for release
......@@ -6,18 +6,58 @@ import numpy as np
from PIL import Image
import rasterio
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from keras.models import load_model
from data_util import DataLoader
from h5Image import H5Image
from unet_util import (UNET_224, Residual_CNN_block,
attention_up_and_concatenate,
attention_up_and_concatenate2, dice_coef,
dice_coef_loss, evaluate_prediction_result, jacard_coef,
multiplication, multiplication2)
attention_up_and_concatenate,
attention_up_and_concatenate2, dice_coef,
dice_coef_loss, evaluate_prediction_result, jacard_coef,
multiplication, multiplication2)
# Declare h5_image as a global variable to streamline data access across functions
h5_image = None
def save_plot_as_png(prediction_result, legend_img, map_name, legend, outputPath):
global h5_image
true_seg = h5_image.get_layer(map_name, legend)
full_map = h5_image.get_map(map_name)
legend_resized = tf.image.resize(legend_img, (h5_image.patch_size, h5_image.patch_size))
output_image_path = os.path.join(outputPath, f"{map_name}_{legend}_visual.png")
fig, axarr = plt.subplots(1, 4, figsize=(20,5))
# Load images
img1 = true_seg
img2 = prediction_result
img3 = full_map
img4 = legend_resized
# Display images
axarr[0].imshow(img1)
axarr[0].set_title('True segmentation')
axarr[0].axis('off')
axarr[1].imshow(img2)
axarr[1].set_title('Predicted segmentation')
axarr[1].axis('off')
axarr[2].imshow(img3)
axarr[2].set_title('Map')
axarr[2].axis('off')
axarr[3].imshow(img4)
axarr[3].set_title('Legend')
axarr[3].axis('off')
plt.tight_layout()
plt.savefig(output_image_path)
def prediction_mask(prediction_result, map_name, legend, outputPath):
"""
Apply a mask to the prediction image to isolate the area of interest.
......@@ -69,11 +109,6 @@ def prediction_mask(prediction_result, map_name, legend, outputPath):
# Perform the bitwise operation with the mask also converted to uint8
masked_img = cv2.bitwise_and(prediction_result_int, mask)
# Save the intermediate images using PIL
Image.fromarray((prediction_result * 255).astype(np.uint8)).save(os.path.join(outputPath, f"{map_name}__{legend}_prediction_result_x255.tif"))
Image.fromarray(mask).save(os.path.join(outputPath, f"{map_name}__{legend}_mask_uint8.tif"))
return masked_img
def perform_inference(legend_patch, map_patch, model):
......@@ -91,7 +126,11 @@ def perform_inference(legend_patch, map_patch, model):
global h5_image
legend_resized = tf.image.resize(legend_patch, (h5_image.patch_size, h5_image.patch_size))
legend_resized = tf.cast(tf.io.decode_png(legend_resized), dtype=tf.float32) / 255.0
map_patch_resize = tf.image.resize(map_patch, (h5_image.patch_size, h5_image.patch_size))
map_patch_resize = tf.cast(tf.io.decode_png(map_patch_resize), dtype=tf.float32) / 255.0
print("map_patch", map_patch.shape, "legend_patch", legend_resized.shape)
# Concatenate along the third axis and normalize
......@@ -109,7 +148,7 @@ def perform_inference(legend_patch, map_patch, model):
return prediction.squeeze()
def save_results(prediction, outputPath, map_name, legend):
def save_results(prediction, map_name, legend, outputPath):
"""
Save the prediction results to a specified output path.
......@@ -241,9 +280,11 @@ def main(args):
print("Applying mask to the full prediction.")
masked_prediction = prediction_mask(full_prediction, map_name, legend, args.outputPath)
save_plot_as_png(masked_prediction, legend_patch, map_name, legend, args.outputPath)
# Save the results
print("Saving results.")
save_results(masked_prediction, args.outputPath, map_name, legend)
save_results(masked_prediction, map_name, legend, args.outputPath)
# Close the HDF5 file
print("Inference process completed. Closing HDF5 file.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment