diff --git a/create_prediction_map.py b/create_prediction_map.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba386f9d68d31da57add3371761952c23523f6b --- /dev/null +++ b/create_prediction_map.py @@ -0,0 +1,115 @@ + + +import os +import rasterio +import numpy as np +from PIL import Image +import tensorflow as tf +from data_util import DataLoader +# import segmentation_models as sm +# from keras.models import load_model +from tensorflow.keras.models import load_model +from unet_util import dice_coef_loss, dice_coef, jacard_coef, dice_coef_loss, Residual_CNN_block, multiplication, attention_up_and_concatenate, multiplication2, attention_up_and_concatenate2, UNET_224, evaluate_prediction_result + +# Set the limit to a larger value than the default +Image.MAX_IMAGE_PIXELS = 200000000 # For example, allow up to 200 million pixels + +def load_image_and_predict(map_file_name, prediction_path, model): + """ + Load map image, find corresponding legend images, create inputs, predict, and reconstruct images. + + Parameters: + map_file_name (str): Name of the map image file (e.g., 'AR_Maumee.tif'). + prediction_path (str): Path to save the predicted image. + """ + # Set the paths + map_dir = '/projects/bbym/shared/data/cma/validation/' + map_img_path = os.path.join(map_dir, map_file_name) + json_file_name = os.path.splitext(map_file_name)[0] + '.json' + json_file_path = os.path.join(map_dir, json_file_name) + + patch_size=(256, 256, 3) + overlap=30 + + # Instantiate DataLoader and get processed data + data_loader = DataLoader(map_img_path, json_file_path, patch_size, overlap) + processed_data = data_loader.get_processed_data() + + # 'poly_legends', 'pt_legends', 'line_legends' + for legend in ['poly_legends']: + for legend_img, legend_label in processed_data[legend]: + # Convert legend_img back to uint8 and scale values to 0-255 + legend_img_uint8 = tf.cast(legend_img * 255, dtype=tf.uint8).numpy() + legend_img_pil = Image.fromarray(legend_img_uint8) + + output_legend_img_path = os.path.join(prediction_path, f"{os.path.splitext(map_file_name)[0]}_{legend_label}.png") + legend_img_pil.save(output_legend_img_path, 'PNG') + + map_patches = processed_data['map_patches'] + total_row, total_col, _, _, _, _ = map_patches.shape + predicted_patches = np.zeros((total_row, total_col, patch_size[0], patch_size[1])) + + for i in range(total_row): + for j in range(total_col): + single_map_patch = map_patches[i, j, :, :][0] + + # Concatenate along the third axis and normalize + input_patch = tf.concat(axis=2, values=[single_map_patch, legend_img]) + input_patch = input_patch * 2.0 - 1.0 + + # Resize the input patch + input_patch_resized = tf.image.resize(input_patch, patch_size[:2]) + + # Expand dimensions for prediction + input_patch_expanded = tf.expand_dims(input_patch_resized, axis=0) + + # Make prediction and store it + predicted_patch = model.predict(input_patch_expanded, verbose = 0) + predicted_patches[i, j, :, :] = predicted_patch.squeeze() + + reconstructed_image = data_loader.reconstruct_data(predicted_patches) + reconstructed_image = (reconstructed_image * 255).astype(np.uint8) + + output_image_path = os.path.join(prediction_path, f"{os.path.splitext(map_file_name)[0]}_{legend_label}.tif") + + with rasterio.open(map_img_path) as src: + metadata = src.meta + + metadata.update({ + 'dtype': 'uint8', + 'count': 1, + 'height': reconstructed_image.shape[0], + 'width': reconstructed_image.shape[1], + 'compress': 'lzw', + }) + + with rasterio.open(output_image_path, 'w', **metadata) as dst: + dst.write(reconstructed_image, 1) + + print(f"Predicted image saved at: {output_image_path}") + +################################################################ +##### Prepare the model configurations ######################### +################################################################ +name_id = 'unproecessed_legends' #You can change the id for each run so that all models and stats are saved separately. +prediction_path = './predicts_'+name_id+'/' +model_path = './models_'+name_id+'/' + +# Avaiable backbones for Unet architechture +# 'vgg16' 'vgg19' 'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152' 'inceptionv3' +# 'inceptionresnetv2' 'densenet121' 'densenet169' 'densenet201' 'seresnet18' 'seresnet34' +# 'seresnet50' 'seresnet101' 'seresnet152', and 'attentionUnet' +backend = 'attentionUnet' +name = 'Unet-'+ backend + +finetune = False +if (finetune): name += "_ft" + +model = load_model(model_path+name+'.h5', + custom_objects={'multiplication': multiplication, + 'multiplication2': multiplication2, + 'dice_coef_loss':dice_coef_loss, + 'dice_coef':dice_coef,}) + +# Example of how to use the function +load_image_and_predict('AR_Maumee.tif', prediction_path, model) diff --git a/data_util.py b/data_util.py new file mode 100644 index 0000000000000000000000000000000000000000..bcc7243721ab780772ff807556c763e90b175c54 --- /dev/null +++ b/data_util.py @@ -0,0 +1,163 @@ +import os +import numpy as np +import json +from PIL import Image +from patchify import patchify +import tensorflow as tf + +# # Example of how to use the class +# data_loader = DataLoader('/path/to/AR_Maumee.tif', '/path/to/AR_Maumee.json', patch_size=(256, 256, 3), overlap=30) +# processed_data = data_loader.get_processed_data() +# reconstructed_image = data_loader.reconstruct_data(processed_data['map_patches']) + +class DataLoader: + """ + DataLoader class to load and process TIFF images and corresponding JSON files. + + Attributes: + tiff_path (str): Path to the input TIFF image file. + json_path (str): Path to the input JSON file. + patch_size (tuple of int): Size of the patches to extract from the TIFF image. + overlap (int): Overlapping pixels between patches. + map_img (numpy.ndarray): Loaded and normalized map image. + orig_size (tuple of int): Original size of the map image. + json_data (dict): Loaded JSON data. + processed_data (dict): Processed data including map patches and legends. + + Methods: + load_and_process(): Loads and processes the TIFF image and JSON data. + load_tiff(): Loads and normalizes the TIFF image. + load_json(): Loads JSON data. + process_legends(label_suffix, resize_to=(256, 256)): Processes legends based on label suffix. + process_data(): Extracts and processes data from the loaded TIFF and JSON. + get_processed_data(): Returns the processed data. + """ + def __init__(self, tiff_path, json_path, patch_size=(256, 256, 3), overlap=30): + """ + Initializes DataLoader with specified file paths, patch size, and overlap. + + Parameters: + tiff_path (str): Path to the input TIFF image file. + json_path (str): Path to the input JSON file. + patch_size (tuple of int, optional): Size of patches to extract from image. Default is (256, 256, 3). + overlap (int, optional): Number of overlapping pixels between patches. Default is 30. + """ + self.tiff_path = tiff_path + self.json_path = json_path + self.patch_size = patch_size + self.overlap = overlap + self.map_img = None + self.orig_size = None + self.json_data = None + self.processed_data = None + self.load_and_process() + + def load_and_process(self): + """Loads and processes the TIFF image and JSON data.""" + self.load_tiff() + self.load_json() + self.process_data() + + def load_tiff(self): + """Loads and normalizes the TIFF image.""" + print(f"Loading and normalizing map image: {self.tiff_path}") + self.map_img = Image.open(self.tiff_path) + self.orig_size = self.map_img.size + self.map_img = np.array(self.map_img) / 255.0 + + def load_json(self): + """Loads JSON data.""" + with open(self.json_path, 'r') as json_file: + self.json_data = json.load(json_file) + + def process_legends(self, label_suffix, resize_to=(256, 256)): + """ + Processes legends based on the label suffix. + + Parameters: + label_suffix (str): Suffix in the label to identify the type of legend. + resize_to (tuple of int, optional): Size to resize the legends to. Default is (256, 256). + + Returns: + list of tuples: List of processed legends and their corresponding labels. + """ + legends = [shape for shape in self.json_data['shapes'] if label_suffix in shape['label']] + processed_legends = [] + for legend in legends: + points = np.array(legend['points']) + top_left = points.min(axis=0) + bottom_right = points.max(axis=0) + legend_img = self.map_img[int(top_left[1]):int(bottom_right[1]), int(top_left[0]):int(bottom_right[0]), :] + legend_img = tf.image.resize(legend_img, resize_to) + processed_legends.append((legend_img.numpy(), legend['label'])) + return processed_legends + + def process_data(self): + """Extracts and processes data from the loaded TIFF and JSON.""" + step_size = self.patch_size[0] - self.overlap + + pad_x = (step_size - (self.map_img.shape[1] % step_size)) % step_size + pad_y = (step_size - (self.map_img.shape[0] % step_size)) % step_size + self.map_img = np.pad(self.map_img, ((0, pad_y), (0, pad_x), (0, 0)), mode='constant') + + print(f"Patchifying map image with overlap...") + map_patches = patchify(self.map_img, self.patch_size, step=step_size) + + poly_legends = self.process_legends('_poly') + pt_legends = self.process_legends('_pt') + line_legends = self.process_legends('_line') + + self.processed_data = { + "map_patches": map_patches, + "poly_legends": poly_legends, + "pt_legends": pt_legends, + "line_legends": line_legends, + "original_size": self.orig_size, + } + + def get_processed_data(self): + """ + Returns the processed data. + + Raises: + ValueError: If data has not been loaded and processed. + + Returns: + dict: Processed data including map patches and legends. + """ + if not self.processed_data: + raise ValueError("Data should be loaded and processed first") + return self.processed_data + + + def reconstruct_data(self, patches): + """ + Reconstructs an image from overlapping patches, keeping the maximum value + for overlapping pixels. + + Parameters: + patches (numpy array): Image patches. + + Returns: + numpy array: Reconstructed image. + """ + assert self.overlap >= 0, "Overlap should be non-negative" + step = patches.shape[2] - self.overlap + img_shape = self.orig_size[::-1] # Reverse the original size tuple to match the array shape + img = np.zeros(img_shape) # Initialize the image with zeros + + for i in range(patches.shape[0]): + for j in range(patches.shape[1]): + x_start = i * step + y_start = j * step + x_end = min(x_start + patches.shape[2], img_shape[0]) + y_end = min(y_start + patches.shape[3], img_shape[1]) + + img[x_start:x_end, y_start:y_end] = np.maximum( + img[x_start:x_end, y_start:y_end], + patches[i, j, :x_end-x_start, :y_end-y_start] + ) + + return img + +