Skip to content
Snippets Groups Projects
Commit 0d6f888a authored by Nattapon Jaroenchai's avatar Nattapon Jaroenchai
Browse files

Create inference.py

parent d066cb21
Branches additional_input
No related tags found
1 merge request!1updated main for release
Pipeline #2487 failed
import os
import argparse
import numpy as np
import tensorflow as tf
from keras.models import load_model
from unet_util import dice_coef, dice_coef_loss, UNET_224
def load_and_preprocess_img(mapPath, jsonPath, featureType):
# TODO: Implement the function to load and preprocess the map based on the provided paths and feature type.
# This is a placeholder and may need to be adjusted based on the actual data and preprocessing steps.
map_img = tf.io.read_file(mapPath)
map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0
map_img = tf.image.resize(map_img, [256, 256])
# Additional preprocessing based on jsonPath and featureType can be added here.
return map_img
def perform_inference(map_img, model):
prediction = model.predict(map_img)
return prediction
def save_results(prediction, outputPath):
# TODO: Implement the function to save the prediction results to the specified output path.
# This is a placeholder and may need to be adjusted based on the desired output format.
with open(outputPath, 'w') as f:
f.write(str(prediction))
def main(args):
# Load and preprocess the map
map_img = load_and_preprocess_img(args.mapPath, args.jsonPath, args.featureType)
# Load the trained model
model = load_model(args.modelPath, custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef})
# Perform inference
prediction = perform_inference(map_img, model)
# Save the results
save_results(prediction, args.outputPath)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Perform inference on a given map.")
parser.add_argument("--mapPath", required=True, help="Path to the map image.")
parser.add_argument("--jsonPath", required=True, help="Path to the JSON file.")
parser.add_argument("--featureType", choices=["Polygon"], default="Polygon", help="Type of feature to detect. Currently supports only 'Polygon'.")
parser.add_argument("--outputPath", required=True, help="Path to save the inference results.")
parser.add_argument("--modelPath", default="./models_Unet-attentionUnet/Unet-attentionUnet.h5", help="Path to the trained model. Default is './models_Unet-attentionUnet/Unet-attentionUnet.h5'.")
args = parser.parse_args()
main(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment