diff --git a/inference.py b/inference.py
index 2ec4efeaba16e4fe93f2b68bfd600365c7207ebe..5d72037cc3fd143d695600dcbe2d33ae06d529cf 100644
--- a/inference.py
+++ b/inference.py
@@ -252,6 +252,7 @@ def main(args):
     
     # Get the size of the map
     map_width, map_height, _ = h5_image.get_map_size(map_name)
+    print("Map size:", h5_image.get_map_size(map_name))
     
     # Calculate the number of patches based on the patch size and border
     num_rows = math.ceil(map_width / h5_image.patch_size)
diff --git a/test_h5image.py b/test_h5image.py
index 9382243b766d6ec65c894546f5d18aa34a1e0ff8..52b6d3899c3327377a803f6cc5c28161c9634fd8 100644
--- a/test_h5image.py
+++ b/test_h5image.py
@@ -3,34 +3,25 @@ import cv2
 import tensorflow as tf
 import matplotlib.pyplot as plt
 import numpy as np
+import rasterio
 
+# Load the HDF5 file using the H5Image class
+print("Loading the HDF5 file.")
+h5_image = H5Image("/projects/bbym/shared/data/commonPatchData/256/OK_250K.hdf5", mode='r', patch_border=0)
 
-h5_image = H5Image("/projects/bbym/shared/data/commonPatchData/256/CO_Ute.hdf5", mode='r', patch_border=0)
-
+# Get map details
+print("Getting map details.")
 map_name = h5_image.get_maps()[0]
-map_array = np.array(h5_image.get_map(map_name))
-
-print("map_array.shape", map_array.shape, "type", type(map_array) )
-
-legend = h5_image.get_layers(map_name)[0]
-legend_patch = h5_image.get_legend(map_name, legend)
-legend_resized = cv2.resize(legend_patch, (256,256))
-
-# Convert to uint8 range [0, 255] if necessary
-if legend_resized.dtype == tf.float32:
-    legend_resized = (legend_resized * 255).numpy().astype(np.uint8)
-
-print(h5_image.get_layers(map_name))
-
-print("legend", legend, "legend_resized.shape", legend_resized.shape, "unique values", np.unique(legend_patch))
-
-# Create figure
-fig, ax = plt.subplots(figsize=(3, 3)) # Adjust the size as needed
-
-ax.imshow(legend_resized)
-ax.set_title('Legend')
-ax.axis('off')
-
-plt.tight_layout()
-plt.savefig('/projects/bbym/nathanj/attentionUnet/legend_output.png')
+legend = 'Cat_poly'
+
+# saving image as geotiff
+# can also use h5i.save_image(map, 'map')
+dset = h5_image.get_map(map_name)
+if dset.ndim == 3:
+    image = dset[...].transpose(2, 0, 1)  # rasterio expects bands first
+else:
+    image = np.array(dset[...], ndmin=3)
+rasterio.open(f"/projects/bbym/nathanj/attentionUnet/infer_results/{map_name}.tif", 'w', driver='GTiff', compress='lzw',
+                height=image.shape[1], width=image.shape[2], count=image.shape[0], dtype=image.dtype,
+                crs=h5_image.get_crs(map_name, legend), transform=h5_image.get_transform(map_name, legend)).write(image)