# Source code for densetorch.data.datasets

import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

from .utils import KEYS_TO_DTYPES


def transpose_normals(normals):
    """Move the channel axis of a normals map to the front.

    Args:
        normals: object exposing ``permute`` (e.g. a torch tensor) laid out
            as HxWxC.

    Returns:
        The result of ``normals.permute(2, 0, 1)``, i.e. a CxHxW layout.
    """
    return normals.permute(2, 0, 1)


def convert_normals_to_xzy(normals):
    """Decode an 8-bit encoded normals map into signed components.

    The input is divided by 255 (producing a new array, so the argument is
    left untouched), then the first three channels are remapped:
    channels 0 (x) and 1 (z) become ``v * 2 - 1`` and channel 2 (y) becomes
    ``1 - v * 2``.

    Args:
        normals: 3-dimensional array indexed as ``[row, col, channel]`` with
            at least three channels.

    Returns:
        The rescaled array with the remapped channels.
    """
    decoded = normals / 255.0
    # (gain, bias) per channel; y is flipped relative to x and z.
    affine_per_channel = ((2.0, -1.0), (2.0, -1.0), (-2.0, 1.0))
    for channel, (gain, bias) in enumerate(affine_per_channel):
        decoded[:, :, channel] = decoded[:, :, channel] * gain + bias
    return decoded


def set_normals_ignore_indices(
    normals, depth, normals_ignore_index, depth_ignore_index
):
    """Mark normals as ignored wherever depth is ignored.

    Every position where ``depth`` equals ``depth_ignore_index`` is
    overwritten in ``normals`` with ``normals_ignore_index``. The update is
    done in place and the same ``normals`` object is returned.

    Args:
        normals: array/tensor indexable by the boolean mask ``depth == ...``.
        depth: array/tensor compared element-wise with ``depth_ignore_index``.
        normals_ignore_index: value written into the ignored positions.
        depth_ignore_index: depth value that marks an ignored position.

    Returns:
        ``normals`` after the in-place update.
    """
    normals[depth == depth_ignore_index] = normals_ignore_index
    return normals


class MMDataset(Dataset):
    """Multi-Modality dataset.

    Works with any datasets that contain image and any number of
    2D-annotations.

    Args:
        data_file (string): Path to the data file with annotations.
        data_dir (string): Directory with all the images.
        data_list_sep (string): Separator between the columns in data_file.
        data_list_columns (list of strings): Column names in data_file.
        masks_names (list of strings): Keys for each annotation mask
            (e.g., 'segm', 'depth').
        ignore_indices (list of floats or ints): Ignore values for each
            annotation mask in the same order as `masks_names`.
        depth_scale (float): Scaling factor for depth.
        transform (callable, optional): Optional transform to be applied
            on a sample.

    Raises:
        ValueError: if the number of columns in data_file is not equal to
            `len(data_list_columns)`.

    """

    def __init__(
        self,
        data_file,
        data_dir,
        data_list_sep,
        data_list_columns,
        masks_names,
        ignore_indices,
        depth_scale=1.0,
        transform=None,
    ):
        self.datalist = []
        with open(data_file, encoding="utf-8") as f:
            for line in f:
                # Drop the line terminator only (including a Windows "\r"),
                # so whitespace separators such as "\t" are left intact.
                columns = line.rstrip("\r\n").split(data_list_sep)
                # Raise explicitly rather than assert: the check must
                # survive `python -O` and match the documented ValueError.
                if len(columns) != len(data_list_columns):
                    raise ValueError(
                        f"Inconsistent number of columns, got {len(columns):d} "
                        f"for the following names: {data_list_columns}"
                    )
                self.datalist.append(dict(zip(data_list_columns, columns)))
        self.root_dir = data_dir
        self.transform = transform
        self.masks_names = masks_names
        self.ignore_indices = ignore_indices
        self.depth_scale = depth_scale

    def __len__(self):
        """Return the number of rows read from the data file."""
        return len(self.datalist)

    def __getitem__(self, idx):
        """Load the image and all annotation masks of the `idx`-th row.

        Returns a dict with the key "image" plus one key per entry of
        `masks_names`.
        """
        sample_paths = {
            key: os.path.join(self.root_dir, path)
            for key, path in self.datalist[idx].items()
        }
        sample = {"image": self.read_image(sample_paths["image"])}
        for key in self.masks_names:
            with Image.open(sample_paths[key]) as mask_img:
                mask = np.array(mask_img)
            if key != "normals":
                assert (
                    len(mask.shape) == 2
                ), "Segm and depth masks must be encoded without colourmap"
            else:
                mask = convert_normals_to_xzy(mask)
            sample[key] = mask

        if self.transform:
            sample["names"] = self.masks_names
            sample = self.transform(sample)
            # the names key can be removed by the transformation
            if "names" in sample:
                del sample["names"]

        # NOTE(review): `.to(...)` below is called on whatever the transform
        # returned, so the masks are presumably expected to be torch tensors
        # at this point -- confirm against the transforms in use.
        for key in self.masks_names:
            if key in KEYS_TO_DTYPES:
                scale = self.depth_scale if key == "depth" else 1
                sample[key] = sample[key].to(KEYS_TO_DTYPES[key]) / scale

        # The ignore regions of surface normals are dependent on the ignore
        # regions of depth.
        # NOTE(review): depth has already been divided by `depth_scale` when
        # it is compared with the depth ignore index -- with a non-zero
        # ignore index and depth_scale != 1 the comparison may never match;
        # confirm the intended order.
        if "depth" in self.masks_names and "normals" in self.masks_names:
            depth_ignore_index = self.ignore_indices[
                self.masks_names.index("depth")
            ]
            normals_ignore_index = self.ignore_indices[
                self.masks_names.index("normals")
            ]
            sample["normals"] = transpose_normals(
                set_normals_ignore_indices(
                    sample["normals"],
                    sample["depth"],
                    normals_ignore_index,
                    depth_ignore_index,
                )
            )
        elif "normals" in self.masks_names:
            sample["normals"] = transpose_normals(sample["normals"])
        return sample

    @staticmethod
    def read_image(x):
        """Simple image reader

        Args:
            x (str): path to image.

        Returns image as `np.array`.

        """
        # Context manager so the underlying file handle is always released.
        with Image.open(x) as img:
            img_arr = np.array(img, dtype=np.float32)
        if len(img_arr.shape) == 2:  # grayscale
            # Replicate the single channel three times: HxW -> HxWx3.
            img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0)
        return img_arr