Newer
Older
import tensorflow as tf
import matplotlib.pyplot as plt
# Open the HDF5 patch store through the project's H5Image wrapper.
# mode='r' presumably opens it read-only, and patch_border=0 is forwarded to
# H5Image; see that class for the exact semantics.
print("Loading the HDF5 file.")
h5_image = H5Image(
    "/projects/bbym/shared/data/commonPatchData/256/OK_250K.hdf5",
    mode='r',
    patch_border=0,
)

# Legend label passed to get_crs / get_transform when the GeoTIFF is written.
print("Getting map details.")
legend = 'Cat_poly'
# Save the map image as a GeoTIFF.
# (h5_image.save_image(...) is mentioned upstream as an alternative helper.)
# NOTE(review): `map_name` is never assigned in this script; it must be set to
# a map name stored in the HDF5 file before this section runs.
# NOTE(review): `np` (numpy) and `rasterio` are used here but not imported at
# the top of this script; confirm they are imported elsewhere.
dset = h5_image.get_map(map_name)
if dset.ndim == 3:
    # rasterio expects bands first, so move the last axis to the front.
    image = dset[...].transpose(2, 0, 1)
else:
    # Lower-rank data (presumably a single 2-D band) is padded with leading
    # length-1 axes so that image.shape is (bands, height, width).
    image = np.array(dset[...], ndmin=3)

# Use a context manager so the dataset is closed, which flushes and finalizes
# the GeoTIFF even if write() raises.
with rasterio.open(
    f"/projects/bbym/nathanj/attentionUnet/infer_results/{map_name}.tif",
    'w',
    driver='GTiff',
    compress='lzw',
    height=image.shape[1],
    width=image.shape[2],
    count=image.shape[0],
    dtype=image.dtype,
    crs=h5_image.get_crs(map_name, legend),
    transform=h5_image.get_transform(map_name, legend),
) as dst:
    dst.write(image)