diff --git a/h5Image.py b/h5Image.py
index f1878620663c9eb53fdc3f667da5b3352656624e..354aef0ee3848cbf5d2d181266778f1258340677 100644
--- a/h5Image.py
+++ b/h5Image.py
@@ -1,11 +1,13 @@
 import json
 import logging
 import math
 import os
 import os.path
-import cv2
-import h5py
-import numpy as np
+
+import affine
+import h5py
+import numpy as np
+import rasterio
 
 
 class H5Image:
@@ -50,19 +52,32 @@ class H5Image:
         :param group: parent folder of image
         :return: dataset of image loaded (numpy array)
         """
-        image = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
-        dset = group.create_dataset(name=name, data=image, shape=image.shape, compression=self.compression)
-        dset.attrs.create('CLASS', 'IMAGE', dtype='S6')
-        dset.attrs.create('IMAGE_MINMAXRANGE', [0, 255], dtype=np.uint8)
-        if len(image.shape) == 3 and image.shape[2] == 3:
-            dset.attrs.create('IMAGE_SUBCLASS', 'IMAGE_TRUECOLOR', dtype='S16')
-        elif len(image.shape) == 2 or image.shape[2] == 1:
-            dset.attrs.create('IMAGE_SUBCLASS', 'IMAGE_GRAYSCALE', dtype='S15')
-        else:
-            raise Exception("Unknown image type")
-        dset.attrs.create('IMAGE_VERSION', '1.2', dtype='S4')
-        dset.attrs.create('INTERLACE_MODE', 'INTERLACE_PIXEL', dtype='S16')
-        return dset
+        with rasterio.open(filename) as src:
+            profile = src.profile
+            image = src.read()
+            if len(image.shape) == 3:
+                if image.shape[0] == 1:
+                    image = image[0]
+                elif image.shape[0] == 3:
+                    image = image.transpose(1, 2, 0)
+            dset = group.create_dataset(name=name, data=image, shape=image.shape, compression=self.compression)
+            dset.attrs.create('CLASS', 'IMAGE', dtype='S6')
+            dset.attrs.create('IMAGE_VERSION', '1.2', dtype='S4')
+            dset.attrs.create('INTERLACE_MODE', 'INTERLACE_PIXEL', dtype='S16')
+            dset.attrs.create('IMAGE_MINMAXRANGE', [0, 255], dtype=np.uint8)
+            if len(image.shape) == 3 and image.shape[2] == 3:
+                dset.attrs.create('IMAGE_SUBCLASS', 'IMAGE_TRUECOLOR', dtype='S16')
+            elif len(image.shape) == 2 or image.shape[2] == 1:
+                dset.attrs.create('IMAGE_SUBCLASS', 'IMAGE_GRAYSCALE', dtype='S15')
+            else:
+                raise Exception("Unknown image type")
+            if 'crs' in profile:
+                txt = profile['crs'].to_string()
+                dset.attrs.create('CRS', txt, dtype=f'S{len(txt)}')
+            if 'transform' in profile:
+                txt = affine.dumpsw(profile['transform'])
+                dset.attrs.create('TRANSFORM', txt, dtype=f'S{len(txt)}')
+            return dset
 
     # add an image to the file
     def add_image(self, filename, folder="", mapname=""):
@@ -133,14 +148,49 @@ class H5Image:
             except ValueError as e:
                 logging.warning(f"Error loading {label} : {e}")
         valid_patches = [[int(k.split('_')[0]), int(k.split('_')[1])] for k in layers_patch.keys()]
-        r1 = min(valid_patches, key=lambda value: int(value[0]))[0]
-        r2 = max(valid_patches, key=lambda value: int(value[0]))[0]
-        c1 = min(valid_patches, key=lambda value: int(value[1]))[1]
-        c2 = max(valid_patches, key=lambda value: int(value[1]))[1]
+        if valid_patches:
+            r1 = min(valid_patches, key=lambda value: int(value[0]))[0]
+            r2 = max(valid_patches, key=lambda value: int(value[0]))[0]
+            c1 = min(valid_patches, key=lambda value: int(value[1]))[1]
+            c2 = max(valid_patches, key=lambda value: int(value[1]))[1]
+            group.attrs.update({'corners': [[r1, c1], [r2, c2]]})
         group.attrs.update({'patches': json.dumps(all_patches)})
         group.attrs.update({'layers_patch': json.dumps(layers_patch)})
        group.attrs.update({'valid_patches': json.dumps(valid_patches)})
-        group.attrs.update({'corners': [[r1, c1], [r2, c2]]})
+
+    def save_image(self, mapname, destination, layer=None):
+        """
+        Save the image to disk.
+        The image is saved as a GeoTIFF. If no layer is given, all layers of the map are written
+        to the destination, together with the map's JSON file; otherwise only the given layer
+        is written.
+        :param mapname: the name of the map
+        :param destination: the destination directory
+        :param layer: the name of the layer; if None, all layers are written
+        """
+        if not os.path.exists(destination):
+            os.makedirs(destination)
+        if layer is None:
+            self.save_image(mapname, destination, "map")
+            json_data = json.loads(self.h5f[mapname].attrs['json'])
+            with open(os.path.join(destination, f"{mapname}.json"), "w") as fp:
+                json.dump(json_data, fp, indent=2)
+            for lyr in self.get_layers(mapname):
+                self.save_image(mapname, destination, lyr)
+        else:
+            dset = self.h5f[mapname][layer]
+            if dset.ndim == 3:
+                image = dset[...].transpose(2, 0, 1)  # rasterio expects bands first
+            else:
+                image = np.array(dset[...], ndmin=3)
+            if layer == "map":
+                filename = os.path.join(destination, f"{mapname}.tif")
+            else:
+                filename = os.path.join(destination, f"{mapname}_{layer}.tif")
+            with rasterio.open(filename, 'w', driver='GTiff', compress='lzw',
+                               height=image.shape[1], width=image.shape[2], count=image.shape[0], dtype=image.dtype,
+                               crs=self.get_crs(mapname, layer), transform=self.get_transform(mapname, layer)) as dst:
+                dst.write(image)
 
     # get list of all maps
     def get_maps(self):
@@ -168,6 +218,28 @@
         """
         return self.h5f[mapname]['map'].shape
 
+    def get_crs(self, mapname, layer='map'):
+        """
+        Returns the CRS of the layer (defaults to the map).
+        :param mapname: the name of the map
+        :param layer: the name of the layer, defaults to the map
+        :return: CRS of the layer, or None if none is stored
+        """
+        if 'CRS' in self.h5f[mapname][layer].attrs:
+            return rasterio.CRS.from_string(self.h5f[mapname][layer].attrs['CRS'].decode('utf-8'))
+        return None
+
+    def get_transform(self, mapname, layer='map'):
+        """
+        Returns the transform of the layer (defaults to the map).
+        :param mapname: the name of the map
+        :param layer: the name of the layer, defaults to the map
+        :return: transform of the layer, or None if none is stored
+        """
+        if 'TRANSFORM' in self.h5f[mapname][layer].attrs:
+            return affine.loadsw(self.h5f[mapname][layer].attrs['TRANSFORM'].decode('utf-8'))
+        return None
+
     def get_map_corners(self, mapname):
         """
         Returns the bounds of the map.
diff --git a/inference.py b/inference.py
index f5015bf13824c0da9e75288a6a5b161af012dac9..2ec4efeaba16e4fe93f2b68bfd600365c7207ebe 100644
--- a/inference.py
+++ b/inference.py
@@ -2,6 +2,7 @@ import argparse
 import math
 import cv2
 import os
+import time
 import numpy as np
 from PIL import Image
 import rasterio
@@ -109,7 +110,7 @@ def prediction_mask(prediction_result, map_name):
 
     # Get the map array corresponding to the given map name
     map_array = np.array(h5_image.get_map(map_name))
-    print("map_array", map_array.shape)
+    # print("map_array", map_array.shape)
 
     # Convert the RGB map array to grayscale for further processing
     gray = cv2.cvtColor(map_array, cv2.COLOR_BGR2GRAY)
@@ -198,32 +199,27 @@ def save_results(prediction, map_name, legend, outputPath):
     - legend: The legend associated with the prediction.
     - outputPath: The directory where the results should be saved.
""" + + global h5_image + output_image_path = os.path.join(outputPath, f"{map_name}_{legend}.tif") # Convert the prediction to an image # Note: The prediction array may need to be scaled or converted before saving as an image - prediction_image = Image.fromarray((prediction*255).astype(np.uint8)) + # prediction_image = Image.fromarray((prediction*255).astype(np.uint8)) # Save the prediction as a tiff image - prediction_image.save(output_image_path, 'TIFF') + # prediction_image.save(output_image_path, 'TIFF') - ### Waiting for georeferencing data - # This section will be used in future releases to save georeferenced images. + prediction_image = (prediction*255).astype(np.uint8) - ## Waiting for georeferencing data - # with rasterio.open(map_img_path) as src: - # metadata = src.meta + prediction_image = np.expand_dims(prediction_image, axis=0) - # metadata.update({ - # 'dtype': 'uint8', - # 'count': 1, - # 'height': reconstructed_image.shape[0], - # 'width': reconstructed_image.shape[1], - # 'compress': 'lzw', - # }) + rasterio.open(output_image_path, 'w', driver='GTiff', compress='lzw', + height = prediction_image.shape[1], width = prediction_image.shape[2], count = prediction_image.shape[0], dtype = prediction_image.dtype, + crs = h5_image.get_crs(map_name, legend), transform = h5_image.get_transform(map_name, legend)).write(prediction_image) - # with rasterio.open(output_image_path, 'w', **metadata) as dst: - # dst.write(reconstructed_image, 1) def main(args): """ @@ -243,7 +238,7 @@ def main(args): map_name = h5_image.get_maps()[0] print(f"Map Name: {map_name}") all_map_legends = h5_image.get_layers(map_name) - print(f"All Map Legends: {all_map_legends}") + # print(f"All Map Legends: {all_map_legends}") # Filter the legends based on the feature type if args.featureType == "Polygon": @@ -280,6 +275,7 @@ def main(args): # Loop through the patches and perform inference for legend in (map_legends): print(f"Processing legend: {legend}") + start_time = time.time() # Create an empty array to store the full prediction full_prediction = np.zeros((map_width, map_height)) @@ -318,13 +314,17 @@ def main(args): # Mask out the map background pixels from the prediction print("Applying mask to the full prediction.") - masked_prediction = prediction_mask(full_prediction, map_name, legend, args.outputPath) + masked_prediction = prediction_mask(full_prediction, map_name) save_plot_as_png(masked_prediction, map_name, legend, args.outputPath) # Save the results print("Saving results.") save_results(masked_prediction, map_name, legend, args.outputPath) + + end_time = time.time() + total_time = end_time - start_time + print(f"Execution time for 1 legend: {total_time} seconds") # Close the HDF5 file print("Inference process completed. Closing HDF5 file.")