Commit f6917864 authored by Nattapon Jaroenchai

Update inference.py

parent 32eeda76
Merge request !1: updated main for release
@@ -94,56 +94,59 @@ def save_plot_as_png(prediction_result, map_name, legend, outputPath):
    # Save the combined visualization to the specified path
    plt.savefig(output_image_path)

-def prediction_mask(prediction_result, map_name, legend, outputPath):
+def prediction_mask(prediction_result, map_name):
    """
    Apply a mask to the prediction image to isolate the area of interest.

    Parameters:
-    - pad_unpatch_predicted_threshold: The thresholded prediction array.
+    - prediction_result: numpy array, The output of the model after prediction.
+    - map_name: str, The name of the map used for prediction.

    Returns:
-    - masked_img: The masked prediction image.
+    - masked_img: numpy array, The masked prediction image.
    """
    global h5_image

-    # Get map array
+    # Get the map array corresponding to the given map name
    map_array = np.array(h5_image.get_map(map_name))
+    print("map_array", map_array.shape)

-    # Convert the map array to a grayscale image
-    gray = cv2.cvtColor(map_array, cv2.COLOR_BGR2GRAY) # greyscale image
+    # Convert the RGB map array to grayscale for further processing
+    gray = cv2.cvtColor(map_array, cv2.COLOR_BGR2GRAY)

-    # Detect Background Color
+    # Identify the most frequent pixel value, which will be used as the background pixel value
    pix_hist = cv2.calcHist([gray],[0],None,[256],[0,256])
    background_pix_value = np.argmax(pix_hist, axis=None)

-    # Flood fill borders
+    # Flood fill from the corners to identify and modify the background regions
    height, width = gray.shape[:2]
    corners = [[0,0],[0,height-1],[width-1, 0],[width-1, height-1]]
    for c in corners:
        cv2.floodFill(gray, None, (c[0],c[1]), 255)

-    # AdaptiveThreshold to remove noise
+    # Adaptive thresholding to remove small noise and artifacts
    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 21, 4)

-    # Edge Detection
+    # Detect edges using the Canny edge detection method
    thresh_blur = cv2.GaussianBlur(thresh, (11, 11), 0)
    canny = cv2.Canny(thresh_blur, 0, 200)
    canny_dilate = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))

-    # Finding contours for the detected edges.
+    # Detect contours in the edge-detected image
    contours, hierarchy = cv2.findContours(canny_dilate, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)

-    # Keeping only the largest detected contour.
+    # Retain only the largest contour
    contour = sorted(contours, key=cv2.contourArea, reverse=True)[0]

+    # Create an empty mask of the same size as the prediction_result
    wid, hight = prediction_result.shape[0], prediction_result.shape[1]
    mask = np.zeros([wid, hight])
    mask = cv2.fillPoly(mask, pts=[contour], color=(1)).astype(np.uint8)

-    # Threshold prediction results and convert to int
+    # Convert prediction result to a binary format using a threshold
    prediction_result_int = (prediction_result > 0.5).astype(np.uint8)

-    # Perform the bitwise operation with the mask also converted to uint8
+    # Apply the mask to the thresholded prediction result
    masked_img = cv2.bitwise_and(prediction_result_int, mask)

    return masked_img
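
The masking logic above can be hard to follow inside a diff, so here is a small, self-contained sketch of the same idea (flood-fill the borders, adaptive-threshold, keep the largest contour, rasterize it, and AND it with a thresholded prediction). The synthetic map, the array sizes, and the fake prediction are illustrative only and are not taken from the repository.

import numpy as np
import cv2

# Synthetic "scanned map": a grey map sheet on a white background.
img = np.full((400, 400), 255, dtype=np.uint8)
img[50:350, 80:320] = 128

# Flood fill from the four corners so the border region becomes uniform.
h, w = img.shape[:2]
for x, y in [(0, 0), (0, h - 1), (w - 1, 0), (w - 1, h - 1)]:
    cv2.floodFill(img, None, (x, y), 255)

# Threshold, detect edges, and keep the largest contour as the map outline.
thresh = cv2.adaptiveThreshold(img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                               cv2.THRESH_BINARY_INV, 21, 4)
canny = cv2.Canny(cv2.GaussianBlur(thresh, (11, 11), 0), 0, 200)
canny = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))
contours, _ = cv2.findContours(canny, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
largest = max(contours, key=cv2.contourArea)

# Rasterize the contour into a 0/1 mask and apply it to a fake prediction.
mask = cv2.fillPoly(np.zeros(img.shape, dtype=np.uint8), [largest], 1)
fake_prediction = (np.random.rand(*img.shape) > 0.5).astype(np.uint8)
masked = cv2.bitwise_and(fake_prediction, mask)
print(mask.sum(), masked.shape)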
@@ -153,33 +156,35 @@ def perform_inference(legend_patch, map_patch, model):
    Perform inference on a given map patch and legend patch using a trained model.

    Parameters:
-    - legend_patch: The legend patch from the map.
-    - map_patch: The map patch for inference.
-    - model: The trained deep learning model.
+    - legend_patch: numpy array, The legend patch from the map.
+    - map_patch: numpy array, The map patch for inference.
+    - model: tensorflow.keras Model, The trained deep learning model.

    Returns:
-    - prediction: The prediction result for the given map patch.
+    - prediction: numpy array, The prediction result for the given map patch.
    """
    global h5_image

+    # Resize the legend patch to match the h5 image patch size and normalize to [0,1]
    legend_resized = cv2.resize(legend_patch, (h5_image.patch_size, h5_image.patch_size))
    legend_resized = tf.cast(legend_resized, dtype=tf.float32) / 255.0

+    # Resize the map patch to match the h5 image patch size and normalize to [0,1]
    map_patch_resize = cv2.resize(map_patch, (h5_image.patch_size, h5_image.patch_size))
    map_patch_resize = tf.cast(map_patch_resize, dtype=tf.float32) / 255.0

-    # Concatenate along the third axis and normalize
+    # Concatenate the map and legend patches along the third axis (channels) and normalize to [-1,1]
    input_patch = tf.concat(axis=2, values=[map_patch_resize, legend_resized])
    input_patch = input_patch * 2.0 - 1.0

-    # Resize the input patch to match the model's expected input size
+    # Resize the concatenated input patch to match the model's expected input size
    input_patch_resized = tf.image.resize(input_patch, (h5_image.patch_size, h5_image.patch_size))

-    # Expand dimensions for prediction
+    # Expand the dimensions of the input patch for the prediction (models expect a batch dimension)
    input_patch_expanded = tf.expand_dims(input_patch_resized, axis=0)

-    # Get the prediction from the model
-    prediction = model.predict(input_patch_expanded)
+    # Obtain the prediction from the trained model
+    prediction = model.predict(input_patch_expanded, verbose=0)

    return prediction.squeeze()
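
As a shape-level illustration of the input construction in perform_inference() (resize both patches, stack them into a six-channel tensor, rescale from [0, 1] to [-1, 1], and add a batch dimension), the following standalone sketch uses random arrays in place of real map and legend patches; patch_size and the patch shapes are illustrative stand-ins for the values provided by the HDF5 image wrapper.

import numpy as np
import cv2
import tensorflow as tf

patch_size = 256  # illustrative; in inference.py this comes from h5_image.patch_size

# Stand-ins for an RGB map patch and an RGB legend patch.
map_patch = np.random.randint(0, 256, (300, 300, 3), dtype=np.uint8)
legend_patch = np.random.randint(0, 256, (150, 150, 3), dtype=np.uint8)

# Resize to the model's patch size and scale to [0, 1].
map_resized = tf.cast(cv2.resize(map_patch, (patch_size, patch_size)), tf.float32) / 255.0
legend_resized = tf.cast(cv2.resize(legend_patch, (patch_size, patch_size)), tf.float32) / 255.0

# Stack along the channel axis (3 + 3 = 6 channels) and rescale to [-1, 1].
model_input = tf.concat([map_resized, legend_resized], axis=2) * 2.0 - 1.0

# Add the batch dimension expected by model.predict().
model_input = tf.expand_dims(model_input, axis=0)
print(model_input.shape)  # (1, 256, 256, 6)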
@@ -189,9 +194,9 @@ def save_results(prediction, map_name, legend, outputPath):
    Parameters:
    - prediction: The prediction result (should be a 2D or 3D numpy array).
-    - outputPath: The directory where the results should be saved.
    - map_name: The name of the map.
    - legend: The legend associated with the prediction.
+    - outputPath: The directory where the results should be saved.
    """
    output_image_path = os.path.join(outputPath, f"{map_name}_{legend}.tif")
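
The body of save_results() is truncated in this hunk, so only the output path construction is visible. Purely as an illustration of one way a 0/1 uint8 mask could be written to that .tif path, here is a sketch using rasterio; the choice of rasterio, the helper name save_mask_as_tif, and the absence of georeferencing are assumptions, and the actual implementation in inference.py may differ (for example, by copying the CRS and transform of the source map).

import numpy as np
import rasterio

def save_mask_as_tif(masked_img, output_image_path):
    # Illustrative helper only; not the repository's save_results() implementation.
    # Writes a single-band uint8 raster with no georeferencing attached.
    with rasterio.open(
        output_image_path,
        "w",
        driver="GTiff",
        height=masked_img.shape[0],
        width=masked_img.shape[1],
        count=1,
        dtype="uint8",
    ) as dst:
        dst.write(masked_img.astype(np.uint8), 1)

# Example call with a fake mask (rasterio will warn that the file is not georeferenced).
save_mask_as_tif(np.ones((64, 64), dtype=np.uint8), "example_mask.tif")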
@@ -289,8 +294,8 @@ def main(args):
            map_patch = h5_image.get_patch(row, col, map_name)

            # Get the prediction for the current patch
-            print(f"Prediction for patch ({row}, {col}) completed.")
            prediction = perform_inference(legend_patch, map_patch, model)
+            # print(f"Prediction for patch ({row}, {col}) completed.")

            # Calculate starting indices for rows and columns
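
The hunk above sits inside the per-patch loop of main(), and the code that stitches the per-patch predictions back together is not shown in this commit. The sketch below only illustrates the general pattern hinted at by the "Calculate starting indices" comment (derive row/column offsets from the patch size and paste each patch prediction into a full-size array); the grid dimensions, patch size, and the random stand-in for perform_inference() are all illustrative.

import numpy as np

patch_size = 256           # illustrative
num_rows, num_cols = 4, 3  # illustrative patch grid for one map

full_prediction = np.zeros((num_rows * patch_size, num_cols * patch_size), dtype=np.float32)

for row in range(num_rows):
    for col in range(num_cols):
        # Stand-in for perform_inference(legend_patch, map_patch, model).
        patch_prediction = np.random.rand(patch_size, patch_size).astype(np.float32)

        # Calculate starting indices for rows and columns, then paste the patch.
        row_start, col_start = row * patch_size, col * patch_size
        full_prediction[row_start:row_start + patch_size,
                        col_start:col_start + patch_size] = patch_prediction

print(full_prediction.shape)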