diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..63f5006ee1878895f85bba109bb014d7de0fad73 --- /dev/null +++ b/inference.py @@ -0,0 +1,52 @@ + +import os +import argparse +import numpy as np +import tensorflow as tf +from keras.models import load_model +from unet_util import dice_coef, dice_coef_loss, UNET_224 + +def load_and_preprocess_img(mapPath, jsonPath, featureType): + # TODO: Implement the function to load and preprocess the map based on the provided paths and feature type. + # This is a placeholder and may need to be adjusted based on the actual data and preprocessing steps. + map_img = tf.io.read_file(mapPath) + map_img = tf.cast(tf.io.decode_png(map_img), dtype=tf.float32) / 255.0 + map_img = tf.image.resize(map_img, [256, 256]) + + # Additional preprocessing based on jsonPath and featureType can be added here. + + return map_img + +def perform_inference(map_img, model): + prediction = model.predict(map_img) + return prediction + +def save_results(prediction, outputPath): + # TODO: Implement the function to save the prediction results to the specified output path. + # This is a placeholder and may need to be adjusted based on the desired output format. + with open(outputPath, 'w') as f: + f.write(str(prediction)) + +def main(args): + # Load and preprocess the map + map_img = load_and_preprocess_img(args.mapPath, args.jsonPath, args.featureType) + + # Load the trained model + model = load_model(args.modelPath, custom_objects={'dice_coef_loss': dice_coef_loss, 'dice_coef': dice_coef}) + + # Perform inference + prediction = perform_inference(map_img, model) + + # Save the results + save_results(prediction, args.outputPath) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Perform inference on a given map.") + parser.add_argument("--mapPath", required=True, help="Path to the map image.") + parser.add_argument("--jsonPath", required=True, help="Path to the JSON file.") + parser.add_argument("--featureType", choices=["Polygon"], default="Polygon", help="Type of feature to detect. Currently supports only 'Polygon'.") + parser.add_argument("--outputPath", required=True, help="Path to save the inference results.") + parser.add_argument("--modelPath", default="./models_Unet-attentionUnet/Unet-attentionUnet.h5", help="Path to the trained model. Default is './models_Unet-attentionUnet/Unet-attentionUnet.h5'.") + + args = parser.parse_args() + main(args)