diff --git a/inference.py b/inference.py index 2ec4efeaba16e4fe93f2b68bfd600365c7207ebe..5d72037cc3fd143d695600dcbe2d33ae06d529cf 100644 --- a/inference.py +++ b/inference.py @@ -252,6 +252,7 @@ def main(args): # Get the size of the map map_width, map_height, _ = h5_image.get_map_size(map_name) + print("Map size:", h5_image.get_map_size(map_name)) # Calculate the number of patches based on the patch size and border num_rows = math.ceil(map_width / h5_image.patch_size) diff --git a/test_h5image.py b/test_h5image.py index 9382243b766d6ec65c894546f5d18aa34a1e0ff8..52b6d3899c3327377a803f6cc5c28161c9634fd8 100644 --- a/test_h5image.py +++ b/test_h5image.py @@ -3,34 +3,25 @@ import cv2 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np +import rasterio +# Load the HDF5 file using the H5Image class +print("Loading the HDF5 file.") +h5_image = H5Image("/projects/bbym/shared/data/commonPatchData/256/OK_250K.hdf5", mode='r', patch_border=0) -h5_image = H5Image("/projects/bbym/shared/data/commonPatchData/256/CO_Ute.hdf5", mode='r', patch_border=0) - +# Get map details +print("Getting map details.") map_name = h5_image.get_maps()[0] -map_array = np.array(h5_image.get_map(map_name)) - -print("map_array.shape", map_array.shape, "type", type(map_array) ) - -legend = h5_image.get_layers(map_name)[0] -legend_patch = h5_image.get_legend(map_name, legend) -legend_resized = cv2.resize(legend_patch, (256,256)) - -# Convert to uint8 range [0, 255] if necessary -if legend_resized.dtype == tf.float32: - legend_resized = (legend_resized * 255).numpy().astype(np.uint8) - -print(h5_image.get_layers(map_name)) - -print("legend", legend, "legend_resized.shape", legend_resized.shape, "unique values", np.unique(legend_patch)) - -# Create figure -fig, ax = plt.subplots(figsize=(3, 3)) # Adjust the size as needed - -ax.imshow(legend_resized) -ax.set_title('Legend') -ax.axis('off') - -plt.tight_layout() -plt.savefig('/projects/bbym/nathanj/attentionUnet/legend_output.png') +legend = 'Cat_poly' + +# saving image as geotiff +# can also use h5i.save_image(map, 'map') +dset = h5_image.get_map(map_name) +if dset.ndim == 3: + image = dset[...].transpose(2, 0, 1) # rasterio expects bands first +else: + image = np.array(dset[...], ndmin=3) +rasterio.open(f"/projects/bbym/nathanj/attentionUnet/infer_results/{map_name}.tif", 'w', driver='GTiff', compress='lzw', + height=image.shape[1], width=image.shape[2], count=image.shape[0], dtype=image.dtype, + crs=h5_image.get_crs(map_name, legend), transform=h5_image.get_transform(map_name, legend)).write(image)