diff --git a/create_prediction_map.py b/create_prediction_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ba386f9d68d31da57add3371761952c23523f6b
--- /dev/null
+++ b/create_prediction_map.py
@@ -0,0 +1,115 @@
+
+
+import os
+import rasterio
+import numpy as np
+from PIL import Image
+import tensorflow as tf
+from data_util import DataLoader
+# import segmentation_models as sm
+# from keras.models import load_model
+from tensorflow.keras.models import load_model
+from unet_util import dice_coef_loss, dice_coef, jacard_coef, dice_coef_loss, Residual_CNN_block, multiplication, attention_up_and_concatenate, multiplication2, attention_up_and_concatenate2, UNET_224, evaluate_prediction_result
+
+# Set the limit to a larger value than the default
+Image.MAX_IMAGE_PIXELS = 200000000  # For example, allow up to 200 million pixels
+
+def load_image_and_predict(map_file_name, prediction_path, model):
+    """
+    Load map image, find corresponding legend images, create inputs, predict, and reconstruct images.
+
+    Parameters:
+    map_file_name (str): Name of the map image file (e.g., 'AR_Maumee.tif').
+    prediction_path (str): Path to save the predicted image.
+    """
+    # Set the paths
+    map_dir = '/projects/bbym/shared/data/cma/validation/'
+    map_img_path = os.path.join(map_dir, map_file_name)
+    json_file_name = os.path.splitext(map_file_name)[0] + '.json'
+    json_file_path = os.path.join(map_dir, json_file_name)
+
+    patch_size=(256, 256, 3)
+    overlap=30
+
+    # Instantiate DataLoader and get processed data
+    data_loader = DataLoader(map_img_path, json_file_path, patch_size, overlap)
+    processed_data = data_loader.get_processed_data()
+
+    # 'poly_legends', 'pt_legends', 'line_legends'
+    for legend in ['poly_legends']:
+        for legend_img, legend_label in processed_data[legend]:
+            # Convert legend_img back to uint8 and scale values to 0-255
+            legend_img_uint8 = tf.cast(legend_img * 255, dtype=tf.uint8).numpy()
+            legend_img_pil = Image.fromarray(legend_img_uint8)
+
+            output_legend_img_path = os.path.join(prediction_path, f"{os.path.splitext(map_file_name)[0]}_{legend_label}.png")
+            legend_img_pil.save(output_legend_img_path, 'PNG')
+
+            map_patches = processed_data['map_patches']
+            total_row, total_col, _, _, _, _ = map_patches.shape
+            predicted_patches = np.zeros((total_row, total_col, patch_size[0], patch_size[1]))
+
+            for i in range(total_row):
+                for j in range(total_col):
+                    single_map_patch = map_patches[i, j, :, :][0]
+
+                    # Concatenate along the third axis and normalize
+                    input_patch = tf.concat(axis=2, values=[single_map_patch, legend_img])
+                    input_patch = input_patch * 2.0 - 1.0
+                    
+                    # Resize the input patch
+                    input_patch_resized = tf.image.resize(input_patch, patch_size[:2])
+                    
+                    # Expand dimensions for prediction
+                    input_patch_expanded = tf.expand_dims(input_patch_resized, axis=0)
+                    
+                    # Make prediction and store it
+                    predicted_patch = model.predict(input_patch_expanded, verbose = 0)
+                    predicted_patches[i, j, :, :] = predicted_patch.squeeze()
+
+            reconstructed_image = data_loader.reconstruct_data(predicted_patches)
+            reconstructed_image = (reconstructed_image * 255).astype(np.uint8)
+
+            output_image_path = os.path.join(prediction_path, f"{os.path.splitext(map_file_name)[0]}_{legend_label}.tif")
+
+            with rasterio.open(map_img_path) as src:
+                metadata = src.meta
+
+            metadata.update({
+                'dtype': 'uint8',
+                'count': 1,
+                'height': reconstructed_image.shape[0],
+                'width': reconstructed_image.shape[1],
+                'compress': 'lzw',
+            })
+
+            with rasterio.open(output_image_path, 'w', **metadata) as dst:
+                dst.write(reconstructed_image, 1)
+
+            print(f"Predicted image saved at: {output_image_path}")
+
+################################################################
+##### Prepare the model configurations #########################
+################################################################
+name_id = 'unproecessed_legends' #You can change the id for each run so that all models and stats are saved separately.
+prediction_path = './predicts_'+name_id+'/'
+model_path = './models_'+name_id+'/'
+
+# Avaiable backbones for Unet architechture
+# 'vgg16' 'vgg19' 'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152' 'inceptionv3'
+# 'inceptionresnetv2' 'densenet121' 'densenet169' 'densenet201' 'seresnet18' 'seresnet34'
+# 'seresnet50' 'seresnet101' 'seresnet152', and 'attentionUnet'
+backend = 'attentionUnet'
+name = 'Unet-'+ backend
+
+finetune = False
+if (finetune): name += "_ft"
+
+model = load_model(model_path+name+'.h5',
+                    custom_objects={'multiplication': multiplication,
+                                'multiplication2': multiplication2,
+                                'dice_coef_loss':dice_coef_loss,
+                                'dice_coef':dice_coef,})
+
+# Example of how to use the function
+load_image_and_predict('AR_Maumee.tif', prediction_path, model)
diff --git a/data_util.py b/data_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcc7243721ab780772ff807556c763e90b175c54
--- /dev/null
+++ b/data_util.py
@@ -0,0 +1,163 @@
+import os
+import numpy as np
+import json
+from PIL import Image
+from patchify import patchify
+import tensorflow as tf
+
+# # Example of how to use the class
+# data_loader = DataLoader('/path/to/AR_Maumee.tif', '/path/to/AR_Maumee.json', patch_size=(256, 256, 3), overlap=30)
+# processed_data = data_loader.get_processed_data()
+# reconstructed_image = data_loader.reconstruct_data(processed_data['map_patches'])
+
+class DataLoader:
+    """
+    DataLoader class to load and process TIFF images and corresponding JSON files.
+
+    Attributes:
+    tiff_path (str): Path to the input TIFF image file.
+    json_path (str): Path to the input JSON file.
+    patch_size (tuple of int): Size of the patches to extract from the TIFF image.
+    overlap (int): Overlapping pixels between patches.
+    map_img (numpy.ndarray): Loaded and normalized map image.
+    orig_size (tuple of int): Original size of the map image.
+    json_data (dict): Loaded JSON data.
+    processed_data (dict): Processed data including map patches and legends.
+
+    Methods:
+    load_and_process(): Loads and processes the TIFF image and JSON data.
+    load_tiff(): Loads and normalizes the TIFF image.
+    load_json(): Loads JSON data.
+    process_legends(label_suffix, resize_to=(256, 256)): Processes legends based on label suffix.
+    process_data(): Extracts and processes data from the loaded TIFF and JSON.
+    get_processed_data(): Returns the processed data.
+    """
+    def __init__(self, tiff_path, json_path, patch_size=(256, 256, 3), overlap=30):
+        """
+        Initializes DataLoader with specified file paths, patch size, and overlap.
+
+        Parameters:
+        tiff_path (str): Path to the input TIFF image file.
+        json_path (str): Path to the input JSON file.
+        patch_size (tuple of int, optional): Size of patches to extract from image. Default is (256, 256, 3).
+        overlap (int, optional): Number of overlapping pixels between patches. Default is 30.
+        """
+        self.tiff_path = tiff_path
+        self.json_path = json_path
+        self.patch_size = patch_size
+        self.overlap = overlap
+        self.map_img = None
+        self.orig_size = None
+        self.json_data = None
+        self.processed_data = None
+        self.load_and_process()
+
+    def load_and_process(self):
+        """Loads and processes the TIFF image and JSON data."""
+        self.load_tiff()
+        self.load_json()
+        self.process_data()
+
+    def load_tiff(self):
+        """Loads and normalizes the TIFF image."""
+        print(f"Loading and normalizing map image: {self.tiff_path}")
+        self.map_img = Image.open(self.tiff_path)
+        self.orig_size = self.map_img.size  
+        self.map_img = np.array(self.map_img) / 255.0
+
+    def load_json(self):
+        """Loads JSON data."""
+        with open(self.json_path, 'r') as json_file:
+            self.json_data = json.load(json_file)
+
+    def process_legends(self, label_suffix, resize_to=(256, 256)):
+        """
+        Processes legends based on the label suffix.
+
+        Parameters:
+        label_suffix (str): Suffix in the label to identify the type of legend.
+        resize_to (tuple of int, optional): Size to resize the legends to. Default is (256, 256).
+
+        Returns:
+        list of tuples: List of processed legends and their corresponding labels.
+        """
+        legends = [shape for shape in self.json_data['shapes'] if label_suffix in shape['label']]
+        processed_legends = []
+        for legend in legends:
+            points = np.array(legend['points'])
+            top_left = points.min(axis=0)
+            bottom_right = points.max(axis=0)
+            legend_img = self.map_img[int(top_left[1]):int(bottom_right[1]), int(top_left[0]):int(bottom_right[0]), :]
+            legend_img = tf.image.resize(legend_img, resize_to)
+            processed_legends.append((legend_img.numpy(), legend['label']))
+        return processed_legends
+
+    def process_data(self):
+        """Extracts and processes data from the loaded TIFF and JSON."""
+        step_size = self.patch_size[0] - self.overlap
+        
+        pad_x = (step_size - (self.map_img.shape[1] % step_size)) % step_size
+        pad_y = (step_size - (self.map_img.shape[0] % step_size)) % step_size
+        self.map_img = np.pad(self.map_img, ((0, pad_y), (0, pad_x), (0, 0)), mode='constant')
+
+        print(f"Patchifying map image with overlap...")
+        map_patches = patchify(self.map_img, self.patch_size, step=step_size)
+
+        poly_legends = self.process_legends('_poly')
+        pt_legends = self.process_legends('_pt')
+        line_legends = self.process_legends('_line')
+
+        self.processed_data = {
+            "map_patches": map_patches,
+            "poly_legends": poly_legends,
+            "pt_legends": pt_legends,
+            "line_legends": line_legends,
+            "original_size": self.orig_size,
+        }
+
+    def get_processed_data(self):
+        """
+        Returns the processed data.
+
+        Raises:
+        ValueError: If data has not been loaded and processed.
+
+        Returns:
+        dict: Processed data including map patches and legends.
+        """
+        if not self.processed_data:
+            raise ValueError("Data should be loaded and processed first")
+        return self.processed_data
+
+
+    def reconstruct_data(self, patches):
+        """
+        Reconstructs an image from overlapping patches, keeping the maximum value
+        for overlapping pixels.
+        
+        Parameters:
+        patches (numpy array): Image patches.
+        
+        Returns:
+        numpy array: Reconstructed image.
+        """
+        assert self.overlap >= 0, "Overlap should be non-negative"
+        step = patches.shape[2] - self.overlap
+        img_shape = self.orig_size[::-1]  # Reverse the original size tuple to match the array shape
+        img = np.zeros(img_shape)  # Initialize the image with zeros
+        
+        for i in range(patches.shape[0]):
+            for j in range(patches.shape[1]):
+                x_start = i * step
+                y_start = j * step
+                x_end = min(x_start + patches.shape[2], img_shape[0])
+                y_end = min(y_start + patches.shape[3], img_shape[1])
+                
+                img[x_start:x_end, y_start:y_end] = np.maximum(
+                    img[x_start:x_end, y_start:y_end], 
+                    patches[i, j, :x_end-x_start, :y_end-y_start]
+                )
+        
+        return img
+
+