diff --git a/h5Image.py b/h5Image.py
index f1878620663c9eb53fdc3f667da5b3352656624e..354aef0ee3848cbf5d2d181266778f1258340677 100644
--- a/h5Image.py
+++ b/h5Image.py
@@ -1,11 +1,14 @@
 import json
-import logging
+
+import affine
+import h5py
+import numpy as np
+import rasterio
+
 import math
 import os
 import os.path
-import cv2
-import h5py
-import numpy as np
+import logging
 
 
 class H5Image:
@@ -50,19 +53,33 @@ class H5Image:
         :param group: parent folder of image
         :return: dataset of image loaded (numpy array)
         """
-        image = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
-        dset = group.create_dataset(name=name, data=image, shape=image.shape, compression=self.compression)
-        dset.attrs.create('CLASS', 'IMAGE', dtype='S6')
-        dset.attrs.create('IMAGE_MINMAXRANGE', [0, 255], dtype=np.uint8)
-        if len(image.shape) == 3 and image.shape[2] == 3:
-            dset.attrs.create('IMAGE_SUBCLASS', 'IMAGE_TRUECOLOR', dtype='S16')
-        elif len(image.shape) == 2 or image.shape[2] == 1:
-            dset.attrs.create('IMAGE_SUBCLASS', 'IMAGE_GRAYSCALE', dtype='S15')
-        else:
-            raise Exception("Unknown image type")
-        dset.attrs.create('IMAGE_VERSION', '1.2', dtype='S4')
-        dset.attrs.create('INTERLACE_MODE', 'INTERLACE_PIXEL', dtype='S16')
-        return dset
+        with rasterio.open(filename) as src:
+            profile = src.profile
+            image = src.read()
+            if len(image.shape) == 3:
+                if image.shape[0] == 1:
+                    image = image[0]
+                elif image.shape[0] == 3:
+                    image = image.transpose(1, 2, 0)
+            dset = group.create_dataset(name=name, data=image, shape=image.shape, compression=self.compression)
+            dset.attrs.create('CLASS', 'IMAGE', dtype='S6')
+            dset.attrs.create('IMAGE_VERSION', '1.2', dtype='S4')
+            dset.attrs.create('INTERLACE_MODE', 'INTERLACE_PIXEL', dtype='S16')
+            dset.attrs.create('IMAGE_MINMAXRANGE', [0, 255], dtype=np.uint8)
+            if len(image.shape) == 3 and image.shape[2] == 3:
+                dset.attrs.create('IMAGE_SUBCLASS', 'IMAGE_TRUECOLOR', dtype='S16')
+            elif len(image.shape) == 2 or image.shape[2] == 1:
+                dset.attrs.create('IMAGE_SUBCLASS', 'IMAGE_GRAYSCALE', dtype='S15')
+            else:
+                raise Exception("Unknown image type")
+            if 'crs' in profile:
+                txt = src.profile['crs'].to_string()
+                dset.attrs.create('CRS', txt, dtype=f'S{len(txt)}')
+            if 'transform' in profile:
+                txt = affine.dumpsw(src.profile['transform'])
+                dset.attrs.create('TRANSFORM', txt, dtype=f'S{len(txt)}')
+            return dset
+        return None
 
     # add an image to the file
     def add_image(self, filename, folder="", mapname=""):
@@ -133,14 +150,46 @@ class H5Image:
             except ValueError as e:
                 logging.warning(f"Error loading {label} : {e}")
         valid_patches = [[int(k.split('_')[0]), int(k.split('_')[1])] for k in layers_patch.keys()]
-        r1 = min(valid_patches, key=lambda value: int(value[0]))[0]
-        r2 = max(valid_patches, key=lambda value: int(value[0]))[0]
-        c1 = min(valid_patches, key=lambda value: int(value[1]))[1]
-        c2 = max(valid_patches, key=lambda value: int(value[1]))[1]
+        if valid_patches:
+            r1 = min(valid_patches, key=lambda value: int(value[0]))[0]
+            r2 = max(valid_patches, key=lambda value: int(value[0]))[0]
+            c1 = min(valid_patches, key=lambda value: int(value[1]))[1]
+            c2 = max(valid_patches, key=lambda value: int(value[1]))[1]
+            group.attrs.update({'corners': [[r1, c1], [r2, c2]]})
         group.attrs.update({'patches': json.dumps(all_patches)})
         group.attrs.update({'layers_patch': json.dumps(layers_patch)})
         group.attrs.update({'valid_patches': json.dumps(valid_patches)})
-        group.attrs.update({'corners': [[r1, c1], [r2, c2]]})
+
+    def save_image(self, mapname, destination, layer=None):
+        """
+        Save the image to disk. The image is saved as a tiff file, if no layer is given
+        it will write all layers for the map, and the json file to the destination, otherwise
+        it will just write the layer.
+        :param mapname: the name of the map
+        :param destination: the destination directory
+        :param layer: the name of the layer, if empty all layers are written
+        """
+        if not os.path.exists(destination):
+            os.makedirs(destination)
+        if layer is None:
+            self.save_image(mapname, destination, "map")
+            json_data = json.loads(self.h5f[mapname].attrs['json'])
+            json.dump(json_data, open(os.path.join(destination, f"{mapname}.json"), "w"), indent=2)
+            for layer in self.get_layers(mapname):
+                self.save_image(mapname, destination, layer)
+        else:
+            dset = self.h5f[mapname][layer]
+            if dset.ndim == 3:
+                image = dset[...].transpose(2, 0, 1)  # rasterio expects bands first
+            else:
+                image = np.array(dset[...], ndmin=3)
+            if layer == "map":
+                filename = os.path.join(destination, f"{mapname}.tif")
+            else:
+                filename = os.path.join(destination, f"{mapname}_{layer}.tif")
+            rasterio.open(filename, 'w', driver='GTiff', compress='lzw',
+                            height=image.shape[1], width=image.shape[2], count=image.shape[0], dtype=image.dtype,
+                            crs=self.get_crs(mapname, layer), transform=self.get_transform(mapname, layer)).write(image)
 
     # get list of all maps
     def get_maps(self):
@@ -168,6 +217,28 @@ class H5Image:
         """
         return self.h5f[mapname]['map'].shape
 
+    def get_crs(self, mapname, layer='map'):
+        """
+        Returns the crs of the layer (defaults to the map).
+        :param mapname: the name of the map
+        :param layer: the name of the layer, defaults to the map
+        :return: crs of the map
+        """
+        if 'CRS' in self.h5f[mapname][layer].attrs:
+            return rasterio.CRS.from_string(self.h5f[mapname][layer].attrs['CRS'].decode('utf-8'))
+        return None
+
+    def get_transform(self, mapname, layer='map'):
+        """
+        Returns the transform of the layer (defaults to the map).
+        :param mapname: the name of the map
+        :param layer: the name of the layer, defaults to the map
+        :return: transform of the map
+        """
+        if 'TRANSFORM' in self.h5f[mapname][layer].attrs:
+            return affine.loadsw(self.h5f[mapname][layer].attrs['TRANSFORM'].decode('utf-8'))
+        return None
+
     def get_map_corners(self, mapname):
         """
         Returns the bounds of the map.
diff --git a/inference.py b/inference.py
index f5015bf13824c0da9e75288a6a5b161af012dac9..2ec4efeaba16e4fe93f2b68bfd600365c7207ebe 100644
--- a/inference.py
+++ b/inference.py
@@ -2,6 +2,7 @@ import argparse
 import math
 import cv2
 import os
+import time
 import numpy as np
 from PIL import Image
 import rasterio
@@ -109,7 +110,7 @@ def prediction_mask(prediction_result, map_name):
 
     # Get the map array corresponding to the given map name
     map_array = np.array(h5_image.get_map(map_name))
-    print("map_array", map_array.shape)
+    # print("map_array", map_array.shape)
 
     # Convert the RGB map array to grayscale for further processing
     gray = cv2.cvtColor(map_array, cv2.COLOR_BGR2GRAY)
@@ -198,32 +199,26 @@ def save_results(prediction, map_name, legend, outputPath):
     - legend: The legend associated with the prediction.
     - outputPath: The directory where the results should be saved.
     """
+
+    global h5_image
+
     output_image_path = os.path.join(outputPath, f"{map_name}_{legend}.tif")
 
     # Convert the prediction to an image
     # Note: The prediction array may need to be scaled or converted before saving as an image
-    prediction_image = Image.fromarray((prediction*255).astype(np.uint8))
+    # prediction_image = Image.fromarray((prediction*255).astype(np.uint8))
 
     # Save the prediction as a tiff image
-    prediction_image.save(output_image_path, 'TIFF')
+    # prediction_image.save(output_image_path, 'TIFF')
 
-    ### Waiting for georeferencing data
-    # This section will be used in future releases to save georeferenced images.
+    prediction_image = (prediction*255).astype(np.uint8)
 
-    ## Waiting for georeferencing data
-    # with rasterio.open(map_img_path) as src:
-    #     metadata = src.meta
+    prediction_image = np.expand_dims(prediction_image, axis=0)
 
-    # metadata.update({
-    #     'dtype': 'uint8',
-    #     'count': 1,
-    #     'height': reconstructed_image.shape[0],
-    #     'width': reconstructed_image.shape[1],
-    #     'compress': 'lzw',
-    # })
+    rasterio.open(output_image_path, 'w', driver='GTiff', compress='lzw',
+                height = prediction_image.shape[1], width = prediction_image.shape[2], count = prediction_image.shape[0], dtype = prediction_image.dtype,
+                crs = h5_image.get_crs(map_name, legend), transform = h5_image.get_transform(map_name, legend)).write(prediction_image)
 
-    # with rasterio.open(output_image_path, 'w', **metadata) as dst:
-    #     dst.write(reconstructed_image, 1)
 
 def main(args):
     """
@@ -243,7 +238,7 @@ def main(args):
     map_name = h5_image.get_maps()[0]
     print(f"Map Name: {map_name}")
     all_map_legends = h5_image.get_layers(map_name)
-    print(f"All Map Legends: {all_map_legends}")
+    # print(f"All Map Legends: {all_map_legends}")
     
     # Filter the legends based on the feature type
     if args.featureType == "Polygon":
@@ -280,6 +275,7 @@ def main(args):
     # Loop through the patches and perform inference
     for legend in (map_legends):
         print(f"Processing legend: {legend}")
+        start_time = time.time()
         
         # Create an empty array to store the full prediction
         full_prediction = np.zeros((map_width, map_height))
@@ -318,13 +314,17 @@ def main(args):
         
         # Mask out the map background pixels from the prediction
         print("Applying mask to the full prediction.")
-        masked_prediction = prediction_mask(full_prediction, map_name, legend, args.outputPath)
+        masked_prediction = prediction_mask(full_prediction, map_name)
 
         save_plot_as_png(masked_prediction, map_name, legend, args.outputPath)
 
         # Save the results
         print("Saving results.")
         save_results(masked_prediction, map_name, legend, args.outputPath)
+        
+        end_time = time.time()
+        total_time = end_time - start_time
+        print(f"Execution time for 1 legend: {total_time} seconds")
     
     # Close the HDF5 file
     print("Inference process completed. Closing HDF5 file.")