Skip to content
Snippets Groups Projects
Commit 6bcd6fc6 authored by Nattapon Jaroenchai's avatar Nattapon Jaroenchai
Browse files

Update inference.py

parent f7247d3b
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,7 @@ from unet_util import (UNET_224, Residual_CNN_block,
# Declare h5_image as a global variable to streamline data access across functions
h5_image = None
def prediction_mask(prediction_result, map_name, outputPath):
def prediction_mask(prediction_result, map_name, legend, outputPath):
"""
Apply a mask to the prediction image to isolate the area of interest.
......@@ -62,20 +62,17 @@ def prediction_mask(prediction_result, map_name, outputPath):
contour = sorted(contours, key=cv2.contourArea, reverse=True)[0]
wid, hight = prediction_result.shape[0], prediction_result.shape[1]
mask = np.zeros([wid, hight])
mask = cv2.fillPoly(mask, pts=[contour], color=(1,1,1)).astype(int)
mask = cv2.fillPoly(mask, pts=[contour], color=(1)).astype(int)
# Normalize the float image to be in the range [0, 255] if it's not already
prediction_result_normalized = cv2.normalize(prediction_result, None, 0, 255, cv2.NORM_MINMAX)
prediction_result_uint8 = prediction_result_normalized.astype(np.uint8)
# Threshold prediction results and convert to int
prediction_result_int = (prediction_result > 0.5).astype(int)
# Perform the bitwise operation with the mask also converted to uint8
mask_uint8 = mask.astype(np.uint8)
masked_img = cv2.bitwise_and(prediction_result_uint8, mask_uint8)
masked_img = cv2.bitwise_and(prediction_result_int, mask)
# Save the intermediate images using PIL
Image.fromarray((prediction_result * 255).astype(np.uint8)).save(os.path.join(outputPath, f"{map_name}_prediction_result_x255.tif"))
Image.fromarray(prediction_result_normalized).save(os.path.join(outputPath, f"{map_name}_prediction_result_normalized.tif"))
Image.fromarray(mask_uint8).save(os.path.join(outputPath, f"{map_name}_mask_uint8.tif"))
Image.fromarray((prediction_result * 255).astype(np.uint8)).save(os.path.join(outputPath, f"{map_name}__{legend}_prediction_result_x255.tif"))
Image.fromarray(mask).save(os.path.join(outputPath, f"{map_name}__{legend}_mask_uint8.tif"))
return masked_img
......@@ -126,7 +123,7 @@ def save_results(prediction, outputPath, map_name, legend):
# Convert the prediction to an image
# Note: The prediction array may need to be scaled or converted before saving as an image
prediction_image = Image.fromarray(prediction.astype(np.uint8))
prediction_image = Image.fromarray((prediction*255).astype(np.uint8))
# Save the prediction as a tiff image
prediction_image.save(output_image_path, 'TIFF')
......@@ -242,7 +239,7 @@ def main(args):
# Mask out the map background pixels from the prediction
print("Applying mask to the full prediction.")
masked_prediction = prediction_mask(full_prediction, map_name, args.outputPath)
masked_prediction = prediction_mask(full_prediction, map_name, legend, args.outputPath)
# Save the results
print("Saving results.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment