Skip to content
Snippets Groups Projects
Commit f7247d3b authored by Nattapon Jaroenchai's avatar Nattapon Jaroenchai
Browse files

Update inference.py

parent b2c1cdea
No related branches found
No related tags found
1 merge request!1updated main for release
......@@ -18,7 +18,7 @@ from unet_util import (UNET_224, Residual_CNN_block,
# Declare h5_image as a global variable to streamline data access across functions
h5_image = None
def prediction_mask(prediction_result, map_name):
def prediction_mask(prediction_result, map_name, outputPath):
"""
Apply a mask to the prediction image to isolate the area of interest.
......@@ -72,6 +72,11 @@ def prediction_mask(prediction_result, map_name):
mask_uint8 = mask.astype(np.uint8)
masked_img = cv2.bitwise_and(prediction_result_uint8, mask_uint8)
# Save the intermediate images using PIL
Image.fromarray((prediction_result * 255).astype(np.uint8)).save(os.path.join(outputPath, f"{map_name}_prediction_result_x255.tif"))
Image.fromarray(prediction_result_normalized).save(os.path.join(outputPath, f"{map_name}_prediction_result_normalized.tif"))
Image.fromarray(mask_uint8).save(os.path.join(outputPath, f"{map_name}_mask_uint8.tif"))
return masked_img
def perform_inference(legend_patch, map_patch, model):
......@@ -237,7 +242,7 @@ def main(args):
# Mask out the map background pixels from the prediction
print("Applying mask to the full prediction.")
masked_prediction = prediction_mask(full_prediction, map_name)
masked_prediction = prediction_mask(full_prediction, map_name, args.outputPath)
# Save the results
print("Saving results.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment