Newer
Older
import argparse
import math
import os
import struct

import numpy as np
import tensorflow as tf
from keras.models import load_model

from data_util import DataLoader
from h5Image import H5Image
from unet_util import (UNET_224, Residual_CNN_block,
                       attention_up_and_concatenate,
                       attention_up_and_concatenate2, dice_coef,
                       dice_coef_loss, evaluate_prediction_result, jacard_coef,
                       multiplication, multiplication2)
# NOTE(review): prediction_mask uses cv2 (OpenCV), which is not imported in this file.
# h5_image: module-level H5Image handle, assigned by the main routine and read by the helper functions below.
def _load_map_array(map_name):
    """
    Rebuild the full map image for ``map_name`` by stitching ``h5_image`` patches.

    NOTE(review): the code that produced the map array for prediction_mask is
    missing from this file (only stray line numbers survived). This rebuild uses
    the same row/column tiling as the inference loop in main; confirm against
    H5Image, which may expose a direct full-map accessor.

    Returns:
    - map_array: Array of shape (map_width, map_height, ...) with the patches' dtype,
      or None if the map has no patches.
    """
    map_width, map_height, _ = h5_image.get_map_size(map_name)
    patch_size = h5_image.patch_size
    map_array = None
    for row in range(math.ceil(map_width / patch_size)):
        for col in range(math.ceil(map_height / patch_size)):
            patch = np.asarray(h5_image.get_patch(row, col, map_name))
            if map_array is None:
                # Allocate lazily so the result keeps the patches' dtype and channel count.
                map_array = np.zeros((map_width, map_height) + patch.shape[2:], dtype=patch.dtype)
            x_start, y_start = row * patch_size, col * patch_size
            # Clip the last row/column of patches to the map boundary.
            x_end = min(x_start + patch_size, map_width)
            y_end = min(y_start + patch_size, map_height)
            map_array[x_start:x_end, y_start:y_end] = patch[:x_end - x_start, :y_end - y_start]
    return map_array


def prediction_mask(prediction_result, map_name, map_array=None):
    """
    Zero out prediction pixels that fall outside the map's main drawn area.

    The map image is converted to grayscale and flood-filled white from its four
    corners. Edges are then extracted (adaptive threshold -> Gaussian blur -> Canny
    -> dilate), and the largest contour found is treated as the map area.

    Parameters:
    - prediction_result: Prediction array whose first two dimensions match the map's.
    - map_name: Name of the map inside the global ``h5_image``.
    - map_array: Optional pre-loaded map image (BGR, as expected by
      cv2.COLOR_BGR2GRAY). If omitted, it is rebuilt from ``h5_image`` patches.

    Returns:
    - masked_img: Copy of ``prediction_result`` (same dtype) with every pixel
      outside the largest contour set to 0. If no contour is found, an unmasked
      copy is returned.
    """
    if map_array is None:
        map_array = _load_map_array(map_name)
    # Convert the map array to a grayscale image
    gray = cv2.cvtColor(map_array, cv2.COLOR_BGR2GRAY)
    # Flood fill from each corner with white; presumably so the margin outside
    # the map reads as uniform background and yields no edges.
    height, width = gray.shape[:2]
    for seed in ((0, 0), (0, height - 1), (width - 1, 0), (width - 1, height - 1)):  # (x, y) order
        cv2.floodFill(gray, None, seed, 255)
    # AdaptiveThreshold to remove noise
    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 21, 4)
    # Edge Detection
    thresh_blur = cv2.GaussianBlur(thresh, (11, 11), 0)
    canny = cv2.Canny(thresh_blur, 0, 200)
    canny_dilate = cv2.dilate(canny, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))
    # findContours returns 3 values on OpenCV 3.x and 2 on 4.x; [-2] is the contour list in both.
    contours = cv2.findContours(canny_dilate, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[-2]

    masked_img = np.array(prediction_result, copy=True)
    if not contours:
        # Nothing to mask against; keep the prediction rather than crash.
        return masked_img
    # Keep only the largest detected contour as the map area.
    contour = max(contours, key=cv2.contourArea)
    mask = np.zeros(masked_img.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, pts=[contour], color=1)
    # Boolean indexing instead of cv2.bitwise_and: works for float predictions
    # and does not require matching OpenCV dtypes.
    masked_img[mask == 0] = 0
    return masked_img
def perform_inference(legend_patch, map_patch, model, patch_size=None):
    """
    Predict one map patch for one legend with ``model``.

    Both patches are resized to ``patch_size`` x ``patch_size`` and stacked along
    axis 2 (map channels first, then legend channels). The result is wrapped in a
    batch of one and passed to ``model.predict``.

    Parameters:
    - legend_patch: The legend image for the layer being predicted.
    - map_patch: The map patch to run inference on.
    - model: A trained Keras model.
    - patch_size: Side length the model expects; defaults to ``h5_image.patch_size``.

    Returns:
    - prediction: ``model.predict`` output with all size-1 dimensions squeezed out.
    """
    if patch_size is None:
        patch_size = h5_image.patch_size
    target_size = (patch_size, patch_size)
    legend_resized = tf.image.resize(legend_patch, target_size)
    map_patch_resize = tf.image.resize(map_patch, target_size)
    print("map_patch", map_patch.shape, "legend_patch", legend_resized.shape)
    # Concatenate along axis 2 (the channel axis for (height, width, channels) inputs).
    # Both inputs are already target_size, so no further resize is needed.
    # NOTE(review): no normalization is applied here — confirm this matches training.
    input_patch = tf.concat(axis=2, values=[map_patch_resize, legend_resized])
    # Add a leading batch dimension of 1 for model.predict.
    input_batch = tf.expand_dims(input_patch, axis=0)
    prediction = model.predict(input_batch)
    return prediction.squeeze()
def _write_float32_tiff(path, image):
    """
    Write a 2-D array to ``path`` as an uncompressed, single-strip, little-endian
    32-bit floating-point grayscale TIFF (baseline TIFF 6.0 plus SampleFormat).

    Raises:
    - ValueError: if ``image`` is not a non-empty 2-D array, or the file would
      exceed the 4 GiB limit of classic (non-BigTIFF) TIFF.
    """
    pixels = np.ascontiguousarray(image, dtype="<f4")
    if pixels.ndim != 2 or pixels.size == 0:
        raise ValueError(f"expected a non-empty 2-D array, got shape {pixels.shape}")
    height, width = pixels.shape
    pixel_bytes = pixels.tobytes()
    header_size = 8
    # Layout: header | pixel strip | IFD. The strip is 4*h*w bytes long, so the
    # IFD offset stays on the word boundary TIFF requires.
    ifd_offset = header_size + len(pixel_bytes)
    # (tag, field type, value) with type 3 = SHORT and 4 = LONG; tags must be ascending.
    entries = [
        (256, 4, width),             # ImageWidth
        (257, 4, height),            # ImageLength
        (258, 3, 32),                # BitsPerSample
        (259, 3, 1),                 # Compression: none
        (262, 3, 1),                 # PhotometricInterpretation: BlackIsZero
        (273, 4, header_size),       # StripOffsets
        (277, 3, 1),                 # SamplesPerPixel
        (278, 4, height),            # RowsPerStrip: the whole image is one strip
        (279, 4, len(pixel_bytes)),  # StripByteCounts
        (339, 3, 3),                 # SampleFormat: IEEE floating point
    ]
    if ifd_offset + 2 + 12 * len(entries) + 4 > 0xFFFFFFFF:
        raise ValueError("image too large for a classic (non-BigTIFF) TIFF file")
    ifd = [struct.pack("<H", len(entries))]
    for tag, field_type, value in entries:
        # A count-1 SHORT value is left-justified in the 4-byte value field.
        fmt = "<HHIH2x" if field_type == 3 else "<HHII"
        ifd.append(struct.pack(fmt, tag, field_type, 1, value))
    ifd.append(struct.pack("<I", 0))  # offset of next IFD: none
    with open(path, "wb") as f:
        f.write(b"II*\x00" + struct.pack("<I", ifd_offset))  # little-endian magic 42
        f.write(pixel_bytes)
        f.write(b"".join(ifd))


def save_results(prediction, outputPath, map_name, legend):
    """
    Save a prediction array as a float32 TIFF named ``{map_name}_{legend}.tif``.

    The raw pixel values are written, not a text representation; ``str()`` of a
    large numpy array is truncated with "...". ``outputPath`` is created if missing.

    Parameters:
    - prediction: 2-D prediction array.
    - outputPath: The directory where the results should be saved.
    - map_name: The name of the map.
    - legend: The legend associated with the prediction.

    Returns:
    - output_image_path: Path of the written file.
    """
    if outputPath:
        os.makedirs(outputPath, exist_ok=True)
    output_image_path = os.path.join(outputPath, f"{map_name}_{legend}.tif")
    _write_float32_tiff(output_image_path, prediction)
    # TODO: write georeferenced output once georeferencing data is available, e.g.:
    # with rasterio.open(map_img_path) as src:
    #     metadata = src.meta
    # metadata.update({
    #     'dtype': 'uint8',
    #     'count': 1,
    #     'height': reconstructed_image.shape[0],
    #     'width': reconstructed_image.shape[1],
    #     'compress': 'lzw',
    # })
    # with rasterio.open(output_image_path, 'w', **metadata) as dst:
    #     dst.write(reconstructed_image, 1)
    return output_image_path
def _load_inference_model(model_path):
    """
    Load the Keras model at ``model_path`` with the custom objects it needs.

    Paths containing "attention" are loaded as attention-UNet models, which also
    register the ``multiplication``/``multiplication2`` layers.
    """
    custom_objects = {'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef}
    if "attention" in model_path:
        print(f"Loading model with attention from {model_path}")
        custom_objects['multiplication'] = multiplication
        custom_objects['multiplication2'] = multiplication2
    else:
        print(f"Loading standard model from {model_path}")
    return load_model(model_path, custom_objects=custom_objects)


def main(args):
    """
    Main function to orchestrate the map inference process.

    For each legend layer matching ``args.featureType``, every map patch is run
    through the model together with that legend's patch. The per-patch
    predictions are stitched into a full-map array, masked with prediction_mask(),
    and written by save_results().

    Parameters:
    - args: Parsed command-line arguments (see ``parser`` below).
    """
    # The helper functions read this module-level handle, so bind it globally.
    global h5_image
    h5_image = H5Image(args.mapPath, mode='r', patch_border=0)
    try:
        map_name = h5_image.get_maps()[0]
        print(f"Map Name: {map_name}")
        all_map_legends = h5_image.get_layers(map_name)

        # Filter the legends by the layer-name suffix for the feature type.
        feature_tags = {"Polygon": "_poly", "Point": "_pt", "Line": "_line"}
        if args.featureType == "All":
            map_legends = all_map_legends
        else:
            tag = feature_tags[args.featureType]
            map_legends = [legend for legend in all_map_legends if tag in legend]

        map_width, map_height, _ = h5_image.get_map_size(map_name)
        patch_size = h5_image.patch_size
        # Number of patches needed to tile the map (the last row/column may be partial).
        num_rows = math.ceil(map_width / patch_size)
        num_cols = math.ceil(map_height / patch_size)

        model = _load_inference_model(args.modelPath)

        for legend in map_legends:
            # Create an empty array to store the full prediction for this legend.
            full_prediction = np.zeros((map_width, map_height))
            legend_patch = h5_image.get_legend(map_name, legend)
            for row in range(num_rows):
                for col in range(num_cols):
                    map_patch = h5_image.get_patch(row, col, map_name)
                    prediction = perform_inference(legend_patch, map_patch, model)
                    print(f"Prediction for patch ({row}, {col}) completed.")
                    x_start = row * patch_size
                    y_start = col * patch_size
                    # Clip the last row/column of patches to the map boundary.
                    x_end = min(x_start + patch_size, map_width)
                    y_end = min(y_start + patch_size, map_height)
                    full_prediction[x_start:x_end, y_start:y_end] = (
                        prediction[:x_end - x_start, :y_end - y_start])
            # Mask out the map background pixels from the prediction.
            # NOTE: prediction_mask recomputes the map-area mask for every legend.
            print("Applying mask to the full prediction.")
            masked_prediction = prediction_mask(full_prediction, map_name)
            save_results(masked_prediction, args.outputPath, map_name, legend)
        print("Inference process completed. Closing HDF5 file.")
    finally:
        # NOTE(review): assumes H5Image exposes close() (not visible here); the
        # getattr guard makes this a no-op if it does not.
        close = getattr(h5_image, "close", None)
        if callable(close):
            close()


parser = argparse.ArgumentParser(description="Perform inference on a given map.")
parser.add_argument("--mapPath", required=True, help="Path to the hdf5 file.")
# NOTE(review): --jsonPath is parsed but not used anywhere in this file.
parser.add_argument("--jsonPath", required=True, help="Path to the JSON file that contain positions of legends.")
parser.add_argument("--featureType", choices=["Polygon", "Point", "Line", "All"], default="Polygon", help="Type of feature to detect. Four options: Polygon, Point, Line, and All.")
parser.add_argument("--outputPath", required=True, help="Path to save the inference results. ")
parser.add_argument("--modelPath", default="./inference_model/Unet-attentionUnet.h5", help="Path to the trained model. Default is './inference_model/Unet-attentionUnet.h5'.")

if __name__ == "__main__":
    main(parser.parse_args())

# Test command
# python inference.py --mapPath "/projects/bbym/shared/data/commonPatchData/256/OK_250K.hdf5" --jsonPath "" --featureType "Polygon" --outputPath "/projects/bbym/nathanj/attentionUnet/infer_results" --modelPath "/projects/bbym/nathanj/attentionUnet/inference_model/Unet-attentionUnet.h5"