Skip to content
Snippets Groups Projects
Commit f6917864 authored by Nattapon Jaroenchai's avatar Nattapon Jaroenchai
Browse files

Update inference.py

parent 32eeda76
No related branches found
No related tags found
No related merge requests found
......@@ -94,56 +94,59 @@ def save_plot_as_png(prediction_result, map_name, legend, outputPath):
# Save the combined visualization to the specified path
plt.savefig(output_image_path)
def prediction_mask(prediction_result, map_name, legend, outputPath):
def prediction_mask(prediction_result, map_name):
"""
Apply a mask to the prediction image to isolate the area of interest.
Parameters:
- pad_unpatch_predicted_threshold: The thresholded prediction array.
- prediction_result: numpy array, The output of the model after prediction.
- map_name: str, The name of the map used for prediction.
Returns:
- masked_img: The masked prediction image.
- masked_img: numpy array, The masked prediction image.
"""
global h5_image
# Get map array
# Get the map array corresponding to the given map name
map_array = np.array(h5_image.get_map(map_name))
print("map_array", map_array.shape)
# Convert the map array to a grayscale image
gray = cv2.cvtColor(map_array, cv2.COLOR_BGR2GRAY) # greyscale image
# Convert the RGB map array to grayscale for further processing
gray = cv2.cvtColor(map_array, cv2.COLOR_BGR2GRAY)
# Detect Background Color
# Identify the most frequent pixel value, which will be used as the background pixel value
pix_hist = cv2.calcHist([gray],[0],None,[256],[0,256])
background_pix_value = np.argmax(pix_hist, axis=None)
# Flood fill borders
# Flood fill from the corners to identify and modify the background regions
height, width = gray.shape[:2]
corners = [[0,0],[0,height-1],[width-1, 0],[width-1, height-1]]
for c in corners:
cv2.floodFill(gray, None, (c[0],c[1]), 255)
# AdaptiveThreshold to remove noise
# Adaptive thresholding to remove small noise and artifacts
thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 21, 4)
# Edge Detection
# Detect edges using the Canny edge detection method
thresh_blur = cv2.GaussianBlur(thresh, (11, 11), 0)
canny = cv2.Canny(thresh_blur, 0, 200)
canny_dilate = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))
# Finding contours for the detected edges.
# Detect contours in the edge-detected image
contours, hierarchy = cv2.findContours(canny_dilate, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
# Keeping only the largest detected contour.
# Retain only the largest contour
contour = sorted(contours, key=cv2.contourArea, reverse=True)[0]
# Create an empty mask of the same size as the prediction_result
wid, hight = prediction_result.shape[0], prediction_result.shape[1]
mask = np.zeros([wid, hight])
mask = cv2.fillPoly(mask, pts=[contour], color=(1)).astype(np.uint8)
# Threshold prediction results and convert to int
# Convert prediction result to a binary format using a threshold
prediction_result_int = (prediction_result > 0.5).astype(np.uint8)
# Perform the bitwise operation with the mask also converted to uint8
# Apply the mask to the thresholded prediction result
masked_img = cv2.bitwise_and(prediction_result_int, mask)
return masked_img
......@@ -153,33 +156,35 @@ def perform_inference(legend_patch, map_patch, model):
Perform inference on a given map patch and legend patch using a trained model.
Parameters:
- legend_patch: The legend patch from the map.
- map_patch: The map patch for inference.
- model: The trained deep learning model.
- legend_patch: numpy array, The legend patch from the map.
- map_patch: numpy array, The map patch for inference.
- model: tensorflow.keras Model, The trained deep learning model.
Returns:
- prediction: The prediction result for the given map patch.
- prediction: numpy array, The prediction result for the given map patch.
"""
global h5_image
# Resize the legend patch to match the h5 image patch size and normalize to [0,1]
legend_resized = cv2.resize(legend_patch, (h5_image.patch_size, h5_image.patch_size))
legend_resized = tf.cast(legend_resized, dtype=tf.float32) / 255.0
# Resize the map patch to match the h5 image patch size and normalize to [0,1]
map_patch_resize = cv2.resize(map_patch, (h5_image.patch_size, h5_image.patch_size))
map_patch_resize = tf.cast(map_patch_resize, dtype=tf.float32) / 255.0
# Concatenate along the third axis and normalize
# Concatenate the map and legend patches along the third axis (channels) and normalize to [-1,1]
input_patch = tf.concat(axis=2, values=[map_patch_resize, legend_resized])
input_patch = input_patch * 2.0 - 1.0
# Resize the input patch to match the model's expected input size
# Resize the concatenated input patch to match the model's expected input size
input_patch_resized = tf.image.resize(input_patch, (h5_image.patch_size, h5_image.patch_size))
# Expand dimensions for prediction
# Expand the dimensions of the input patch for the prediction (models expect a batch dimension)
input_patch_expanded = tf.expand_dims(input_patch_resized, axis=0)
# Get the prediction from the model
prediction = model.predict(input_patch_expanded)
# Obtain the prediction from the trained model
prediction = model.predict(input_patch_expanded, verbose=0)
return prediction.squeeze()
......@@ -189,9 +194,9 @@ def save_results(prediction, map_name, legend, outputPath):
Parameters:
- prediction: The prediction result (should be a 2D or 3D numpy array).
- outputPath: The directory where the results should be saved.
- map_name: The name of the map.
- legend: The legend associated with the prediction.
- outputPath: The directory where the results should be saved.
"""
output_image_path = os.path.join(outputPath, f"{map_name}_{legend}.tif")
......@@ -289,8 +294,8 @@ def main(args):
map_patch = h5_image.get_patch(row, col, map_name)
# Get the prediction for the current patch
print(f"Prediction for patch ({row}, {col}) completed.")
prediction = perform_inference(legend_patch, map_patch, model)
# print(f"Prediction for patch ({row}, {col}) completed.")
# Calculate starting indices for rows and columns
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment