Skip to content
Snippets Groups Projects
Commit 5fa5080e authored by Nattapon Jaroenchai's avatar Nattapon Jaroenchai
Browse files

Update inference.py

parent 156a711d
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@ import argparse
import math
import cv2
import os
import time
import numpy as np
from PIL import Image
import rasterio
......@@ -215,8 +216,8 @@ def save_results(prediction, map_name, legend, outputPath):
prediction_image = np.expand_dims(prediction_image, axis=0)
rasterio.open(output_image_path, 'w', driver='GTiff', compress='lzw',
height=prediction_image.shape[0], width=prediction_image.shape[1], count=1, dtype=prediction_image.dtype,
crs=h5_image.get_crs(map_name, legend), transform=h5_image.get_transform(map_name, legend)).write(prediction_image)
height = prediction_image.shape[1], width = prediction_image.shape[2], count = prediction_image.shape[0], dtype = prediction_image.dtype,
crs = h5_image.get_crs(map_name, legend), transform = h5_image.get_transform(map_name, legend)).write(prediction_image)
def main(args):
......@@ -274,6 +275,7 @@ def main(args):
# Loop through the patches and perform inference
for legend in (map_legends):
print(f"Processing legend: {legend}")
start_time = time.time()
# Create an empty array to store the full prediction
full_prediction = np.zeros((map_width, map_height))
......@@ -319,6 +321,10 @@ def main(args):
# Save the results
print("Saving results.")
save_results(masked_prediction, map_name, legend, args.outputPath)
end_time = time.time()
total_time = end_time - start_time
print(f"Execution time for 1 legend: {total_time} seconds")
# Close the HDF5 file
print("Inference process completed. Closing HDF5 file.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment