Commit 141db921 authored by Nattapon Jaroenchai

create data loader class for input processing

parent e23c2f24
Branch: additional_input
import os

import numpy as np
import rasterio
import tensorflow as tf
from PIL import Image
from tensorflow.keras.models import load_model

from data_util import DataLoader
from unet_util import (dice_coef_loss, dice_coef, jacard_coef, Residual_CNN_block,
                       multiplication, attention_up_and_concatenate, multiplication2,
                       attention_up_and_concatenate2, UNET_224,
                       evaluate_prediction_result)

# Raise PIL's default decompression-bomb limit so large map sheets can be opened.
Image.MAX_IMAGE_PIXELS = 200000000  # allow up to 200 million pixels
def load_image_and_predict(map_file_name, prediction_path, model):
    """
    Load a map image, find the corresponding legend images, build model inputs,
    predict, and reconstruct the full-size prediction for each legend.

    Parameters:
        map_file_name (str): Name of the map image file (e.g., 'AR_Maumee.tif').
        prediction_path (str): Directory in which to save the predicted images.
        model (tf.keras.Model): Trained segmentation model used for prediction.
    """
    # Set the paths
    map_dir = '/projects/bbym/shared/data/cma/validation/'
    map_img_path = os.path.join(map_dir, map_file_name)
    json_file_name = os.path.splitext(map_file_name)[0] + '.json'
    json_file_path = os.path.join(map_dir, json_file_name)

    patch_size = (256, 256, 3)
    overlap = 30

    # Instantiate DataLoader and get the processed data
    data_loader = DataLoader(map_img_path, json_file_path, patch_size, overlap)
    processed_data = data_loader.get_processed_data()

    map_patches = processed_data['map_patches']
    total_row, total_col, _, _, _, _ = map_patches.shape

    # Other available legend types: 'poly_legends', 'pt_legends', 'line_legends'
    for legend in ['poly_legends']:
        for legend_img, legend_label in processed_data[legend]:
            # Convert legend_img back to uint8, scaling values to 0-255
            legend_img_uint8 = tf.cast(legend_img * 255, dtype=tf.uint8).numpy()
            legend_img_pil = Image.fromarray(legend_img_uint8)
            output_legend_img_path = os.path.join(
                prediction_path, f"{os.path.splitext(map_file_name)[0]}_{legend_label}.png")
            legend_img_pil.save(output_legend_img_path, 'PNG')

            predicted_patches = np.zeros((total_row, total_col, patch_size[0], patch_size[1]))
            for i in range(total_row):
                for j in range(total_col):
                    single_map_patch = map_patches[i, j, 0]  # drop the singleton axis -> (256, 256, 3)

                    # Concatenate map patch and legend along the channel axis, then
                    # rescale from [0, 1] to [-1, 1]. The cast keeps dtypes consistent:
                    # the map patch is float64 while the resized legend is float32.
                    input_patch = tf.concat(
                        [tf.cast(single_map_patch, tf.float32), legend_img], axis=2)
                    input_patch = input_patch * 2.0 - 1.0

                    # Resize the input patch to the model's expected spatial size
                    input_patch_resized = tf.image.resize(input_patch, patch_size[:2])

                    # Add a batch dimension for prediction
                    input_patch_expanded = tf.expand_dims(input_patch_resized, axis=0)

                    # Make a prediction and store it
                    predicted_patch = model.predict(input_patch_expanded, verbose=0)
                    predicted_patches[i, j, :, :] = predicted_patch.squeeze()

            reconstructed_image = data_loader.reconstruct_data(predicted_patches)
            reconstructed_image = (reconstructed_image * 255).astype(np.uint8)

            output_image_path = os.path.join(
                prediction_path, f"{os.path.splitext(map_file_name)[0]}_{legend_label}.tif")
            with rasterio.open(map_img_path) as src:
                metadata = src.meta
            metadata.update({
                'dtype': 'uint8',
                'count': 1,
                'height': reconstructed_image.shape[0],
                'width': reconstructed_image.shape[1],
                'compress': 'lzw',
            })

            with rasterio.open(output_image_path, 'w', **metadata) as dst:
                dst.write(reconstructed_image, 1)

            print(f"Predicted image saved at: {output_image_path}")
################################################################
##### Prepare the model configurations #########################
################################################################
name_id = 'unproecessed_legends'  # Change the id for each run so that all models and stats are saved separately.
prediction_path = './predicts_' + name_id + '/'
model_path = './models_' + name_id + '/'

# Available backbones for the U-Net architecture:
# 'vgg16', 'vgg19', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'inceptionv3',
# 'inceptionresnetv2', 'densenet121', 'densenet169', 'densenet201', 'seresnet18', 'seresnet34',
# 'seresnet50', 'seresnet101', 'seresnet152', and 'attentionUnet'
backend = 'attentionUnet'
name = 'Unet-' + backend

finetune = False
if finetune:
    name += "_ft"

model = load_model(model_path + name + '.h5',
                   custom_objects={'multiplication': multiplication,
                                   'multiplication2': multiplication2,
                                   'dice_coef_loss': dice_coef_loss,
                                   'dice_coef': dice_coef})

# Example of how to use the function
load_image_and_predict('AR_Maumee.tif', prediction_path, model)
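# The same call can be looped over several validation sheets; the second file
# name below is a hypothetical placeholder, not a confirmed input:
# for map_name in ['AR_Maumee.tif', 'CO_SanLuis.tif']:
#     load_image_and_predict(map_name, prediction_path, model)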
import os
import json

import numpy as np
import tensorflow as tf
from PIL import Image
from patchify import patchify

# Example of how to use the class:
# data_loader = DataLoader('/path/to/AR_Maumee.tif', '/path/to/AR_Maumee.json', patch_size=(256, 256, 3), overlap=30)
# processed_data = data_loader.get_processed_data()
# # reconstruct_data expects patches of shape (rows, cols, patch_h, patch_w), so
# # select one channel and drop the singleton patch axis from 'map_patches' first:
# reconstructed_image = data_loader.reconstruct_data(processed_data['map_patches'][:, :, 0, :, :, 0])
class DataLoader:
    """
    DataLoader class to load and process TIFF map images and their corresponding
    JSON legend annotations.

    Attributes:
        tiff_path (str): Path to the input TIFF image file.
        json_path (str): Path to the input JSON file.
        patch_size (tuple of int): Size of the patches to extract from the TIFF image.
        overlap (int): Number of overlapping pixels between patches.
        map_img (numpy.ndarray): Loaded and normalized map image.
        orig_size (tuple of int): Original size of the map image.
        json_data (dict): Loaded JSON data.
        processed_data (dict): Processed data including map patches and legends.

    Methods:
        load_and_process(): Loads and processes the TIFF image and JSON data.
        load_tiff(): Loads and normalizes the TIFF image.
        load_json(): Loads JSON data.
        process_legends(label_suffix, resize_to=(256, 256)): Processes legends matching a label suffix.
        process_data(): Extracts and processes data from the loaded TIFF and JSON.
        get_processed_data(): Returns the processed data.
        reconstruct_data(patches): Reconstructs a full-size image from overlapping patches.
    """
    def __init__(self, tiff_path, json_path, patch_size=(256, 256, 3), overlap=30):
        """
        Initializes DataLoader with specified file paths, patch size, and overlap.

        Parameters:
            tiff_path (str): Path to the input TIFF image file.
            json_path (str): Path to the input JSON file.
            patch_size (tuple of int, optional): Size of patches to extract from the image. Default is (256, 256, 3).
            overlap (int, optional): Number of overlapping pixels between patches. Default is 30.
        """
        self.tiff_path = tiff_path
        self.json_path = json_path
        self.patch_size = patch_size
        self.overlap = overlap
        self.map_img = None
        self.orig_size = None
        self.json_data = None
        self.processed_data = None
        self.load_and_process()
    def load_and_process(self):
        """Loads and processes the TIFF image and JSON data."""
        self.load_tiff()
        self.load_json()
        self.process_data()

    def load_tiff(self):
        """Loads and normalizes the TIFF image."""
        print(f"Loading and normalizing map image: {self.tiff_path}")
        self.map_img = Image.open(self.tiff_path)
        self.orig_size = self.map_img.size
        self.map_img = np.array(self.map_img) / 255.0

    def load_json(self):
        """Loads JSON data."""
        with open(self.json_path, 'r') as json_file:
            self.json_data = json.load(json_file)
    def process_legends(self, label_suffix, resize_to=(256, 256)):
        """
        Processes legends based on the label suffix.

        Parameters:
            label_suffix (str): Suffix in the label that identifies the type of legend.
            resize_to (tuple of int, optional): Size to resize the legends to. Default is (256, 256).

        Returns:
            list of tuples: List of processed legend images and their corresponding labels.
        """
        legends = [shape for shape in self.json_data['shapes'] if label_suffix in shape['label']]
        processed_legends = []
        for legend in legends:
            # Take the axis-aligned bounding box of the annotated points and crop
            # the legend swatch out of the normalized map image.
            points = np.array(legend['points'])
            top_left = points.min(axis=0)
            bottom_right = points.max(axis=0)
            legend_img = self.map_img[int(top_left[1]):int(bottom_right[1]), int(top_left[0]):int(bottom_right[0]), :]
            legend_img = tf.image.resize(legend_img, resize_to)
            processed_legends.append((legend_img.numpy(), legend['label']))
        return processed_legends
    def process_data(self):
        """Extracts and processes data from the loaded TIFF and JSON."""
        # Pad the image so its height and width are multiples of the step size.
        step_size = self.patch_size[0] - self.overlap
        pad_x = (step_size - (self.map_img.shape[1] % step_size)) % step_size
        pad_y = (step_size - (self.map_img.shape[0] % step_size)) % step_size
        self.map_img = np.pad(self.map_img, ((0, pad_y), (0, pad_x), (0, 0)), mode='constant')

        print("Patchifying map image with overlap...")
        # With a 3-D patch size on an H x W x 3 image, patchify returns an array of
        # shape (rows, cols, 1, patch_h, patch_w, 3), keeping a singleton patch axis.
        map_patches = patchify(self.map_img, self.patch_size, step=step_size)

        poly_legends = self.process_legends('_poly')
        pt_legends = self.process_legends('_pt')
        line_legends = self.process_legends('_line')

        self.processed_data = {
            "map_patches": map_patches,
            "poly_legends": poly_legends,
            "pt_legends": pt_legends,
            "line_legends": line_legends,
            "original_size": self.orig_size,
        }
    def get_processed_data(self):
        """
        Returns the processed data.

        Raises:
            ValueError: If data has not been loaded and processed.

        Returns:
            dict: Processed data including map patches and legends.
        """
        if not self.processed_data:
            raise ValueError("Data should be loaded and processed first")
        return self.processed_data
    def reconstruct_data(self, patches):
        """
        Reconstructs an image from overlapping patches, keeping the maximum value
        for overlapping pixels.

        Parameters:
            patches (numpy.ndarray): Image patches of shape (rows, cols, patch_h, patch_w).

        Returns:
            numpy.ndarray: Reconstructed image.
        """
        assert self.overlap >= 0, "Overlap should be non-negative"
        step = patches.shape[2] - self.overlap
        img_shape = self.orig_size[::-1]  # Reverse PIL's (width, height) to match the array shape
        img = np.zeros(img_shape)  # Initialize the output image with zeros

        for i in range(patches.shape[0]):
            for j in range(patches.shape[1]):
                x_start = i * step
                y_start = j * step
                x_end = min(x_start + patches.shape[2], img_shape[0])
                y_end = min(y_start + patches.shape[3], img_shape[1])
                # Crop the patch at the image boundary and merge with the running maximum
                img[x_start:x_end, y_start:y_end] = np.maximum(
                    img[x_start:x_end, y_start:y_end],
                    patches[i, j, :x_end - x_start, :y_end - y_start]
                )
        return img
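# A minimal sketch of the max-merge reconstruction on synthetic patches; the
# paths and values below are illustrative assumptions, not pipeline outputs:
# loader = DataLoader('/path/to/map.tif', '/path/to/map.json')
# rows, cols = loader.get_processed_data()['map_patches'].shape[:2]
# dummy = np.random.rand(rows, cols, 256, 256)   # (rows, cols, patch_h, patch_w)
# merged = loader.reconstruct_data(dummy)        # overlapping pixels keep the maximum
# assert merged.shape == loader.orig_size[::-1]  # back to (height, width)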