Skip to content
Snippets Groups Projects
Commit 0bcd7160 authored by Nattapon Jaroenchai's avatar Nattapon Jaroenchai
Browse files

Update inference_pipeline.py

parent 62592431
No related branches found
No related tags found
No related merge requests found
......@@ -156,7 +156,7 @@ def stitch_patches(predictions, patch_locations, original_image_shape, padded_im
logger.info("Stitching of patches completed.")
return full_prediction
def save_prediction_to_tiff(prediction, file_path, transform=None, crs=None):
def save_prediction_to_tiff(prediction, file_path, image_file_name):
"""
Save the prediction result to a TIFF file.
......@@ -166,20 +166,39 @@ def save_prediction_to_tiff(prediction, file_path, transform=None, crs=None):
- transform: rasterio Affine Transform, optional, Geospatial transform of the prediction (default: None).
- crs: str, optional, Coordinate Reference System of the prediction (default: None).
"""
with rasterio.open(
file_path,
'w',
driver='GTiff',
height=prediction.shape[0],
width=prediction.shape[1],
count=1,
dtype=prediction.dtype,
crs=crs,
transform=transform
) as dst:
dst.write(prediction, 1)
prediction_image = (prediction*255).astype(np.uint8)
prediction_image = np.expand_dims(prediction_image, axis=0)
with rasterio.open(image_file_name) as src_dataset:
# Get a copy of the source dataset's profile. Thus our
# destination dataset will have the same dimensions,
# number of bands, data type, and georeferencing as the
# source dataset.
kwds = src_dataset.profile
# Change the format driver for the destination dataset to
# 'GTiff', short for GeoTIFF.
kwds['driver'] = 'GTiff'
# Add GeoTIFF-specific keyword arguments.
kwds['dtype'] = prediction_image.dtype
kwds['height'] = prediction_image.shape[1]
kwds['width'] = prediction_image.shape[2]
kwds['compress'] = 'lzw'
with rasterio.open(file_path, 'w', **kwds) as dst:
dst.write(prediction, 1)
logger.info(f"Saved prediction to {file_path}")
def inference_image(image, image_file_name, legends, model, feature_type, patch_size=256, save_dir='predictions'):
"""
Perform inference on an image using a trained model and extract predictions for each legend.
......@@ -212,7 +231,7 @@ def inference_image(image, image_file_name, legends, model, feature_type, patch_
file_path = os.path.join(save_dir, f"{base_name}_{legend}.tiff")
# Save the prediction as a TIFF file
save_prediction_to_tiff(masked_prediction, file_path)
save_prediction_to_tiff(masked_prediction, file_path, image_file_name)
predictions[legend] = masked_prediction
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment