Skip to content
Snippets Groups Projects
Commit 6bcd6fc6 authored by Nattapon Jaroenchai's avatar Nattapon Jaroenchai
Browse files

Update inference.py

parent f7247d3b
No related branches found
No related tags found
1 merge request!1updated main for release
......@@ -18,7 +18,7 @@ from unet_util import (UNET_224, Residual_CNN_block,
# Declare h5_image as a global variable to streamline data access across functions
h5_image = None
def prediction_mask(prediction_result, map_name, outputPath):
def prediction_mask(prediction_result, map_name, legend, outputPath):
"""
Apply a mask to the prediction image to isolate the area of interest.
......@@ -62,20 +62,17 @@ def prediction_mask(prediction_result, map_name, outputPath):
contour = sorted(contours, key=cv2.contourArea, reverse=True)[0]
wid, hight = prediction_result.shape[0], prediction_result.shape[1]
mask = np.zeros([wid, hight])
mask = cv2.fillPoly(mask, pts=[contour], color=(1,1,1)).astype(int)
mask = cv2.fillPoly(mask, pts=[contour], color=(1)).astype(int)
# Normalize the float image to be in the range [0, 255] if it's not already
prediction_result_normalized = cv2.normalize(prediction_result, None, 0, 255, cv2.NORM_MINMAX)
prediction_result_uint8 = prediction_result_normalized.astype(np.uint8)
# Threshold prediction results and convert to int
prediction_result_int = (prediction_result > 0.5).astype(int)
# Perform the bitwise operation with the mask also converted to uint8
mask_uint8 = mask.astype(np.uint8)
masked_img = cv2.bitwise_and(prediction_result_uint8, mask_uint8)
masked_img = cv2.bitwise_and(prediction_result_int, mask)
# Save the intermediate images using PIL
Image.fromarray((prediction_result * 255).astype(np.uint8)).save(os.path.join(outputPath, f"{map_name}_prediction_result_x255.tif"))
Image.fromarray(prediction_result_normalized).save(os.path.join(outputPath, f"{map_name}_prediction_result_normalized.tif"))
Image.fromarray(mask_uint8).save(os.path.join(outputPath, f"{map_name}_mask_uint8.tif"))
Image.fromarray((prediction_result * 255).astype(np.uint8)).save(os.path.join(outputPath, f"{map_name}__{legend}_prediction_result_x255.tif"))
Image.fromarray(mask).save(os.path.join(outputPath, f"{map_name}__{legend}_mask_uint8.tif"))
return masked_img
......@@ -126,7 +123,7 @@ def save_results(prediction, outputPath, map_name, legend):
# Convert the prediction to an image
# Note: The prediction array may need to be scaled or converted before saving as an image
prediction_image = Image.fromarray(prediction.astype(np.uint8))
prediction_image = Image.fromarray((prediction*255).astype(np.uint8))
# Save the prediction as a tiff image
prediction_image.save(output_image_path, 'TIFF')
......@@ -242,7 +239,7 @@ def main(args):
# Mask out the map background pixels from the prediction
print("Applying mask to the full prediction.")
masked_prediction = prediction_mask(full_prediction, map_name, args.outputPath)
masked_prediction = prediction_mask(full_prediction, map_name, legend, args.outputPath)
# Save the results
print("Saving results.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment