# Reconstructed listing of the seven files added under
# code/sunlab/common/data/ by git commit b85ee9d64a536937912544c7bbd5b98b635b7e8d
# ("Initial commit", Christian C, 2024-11-11).
#
# The files appear below in patch order, one banner per file.  Because they
# share a single listing here, the intra-package relative imports of each
# file are shown as comments under its banner instead of being executed.

import operator

import numpy


# ---------------------------------------------------------------------------
# code/sunlab/common/data/__init__.py
# ---------------------------------------------------------------------------
# Package re-exports in the original layout:
#   from .basic import *
#   from .dataset import *
#   from .dataset_iterator import *
#   from .shape_dataset import *
#   from .image_dataset import *
#   from .utilities import *


# ---------------------------------------------------------------------------
# code/sunlab/common/data/basic.py
# ---------------------------------------------------------------------------
def _load_dat(*args, **kwargs):
    """# Load a pickled object stored in a `.npy` file

    Forwards every argument to `numpy.load` and returns `.item()` of the
    result.  `allow_pickle` defaults to True but may now be overridden by the
    caller (passing it used to raise a duplicate-keyword TypeError).

    SECURITY: `allow_pickle=True` unpickles the file contents, which can run
    arbitrary code.  Only use this on files from a trusted source."""
    kwargs.setdefault("allow_pickle", True)
    return numpy.load(*args, **kwargs).item()


# Module-level side effect kept from the original: the helper is attached to
# the numpy namespace so callers can write `numpy.load_dat(path)`.
numpy.load_dat = _load_dat


# ---------------------------------------------------------------------------
# code/sunlab/common/data/dataset.py
# ---------------------------------------------------------------------------
# Original import:
#   from .dataset_iterator import DatasetIterator
def _copy_columns(columns):
    """Return a private copy of a column list (`None` becomes `[]`)."""
    if columns is None:
        return []
    if isinstance(columns, list):
        return list(columns)
    return columns


class Dataset:
    """# Dataset Superclass

    Reads a CSV file with pandas, keeps the selected data columns as the
    numpy array `self.dataset` and the selected label columns as
    `self.labels`, multiplies the data by a magnification-dependent prescale
    factor and optionally passes it through a `scaler` callable."""

    # Magnification that `pre_scale` is compared against to build the
    # prescale ratio (`base_scale / pre_scale`).
    base_scale = 10.0

    def __init__(
        self,
        dataset_filename,
        data_columns=None,
        label_columns=None,
        batch_size=None,
        shuffle=False,
        val_split=0.0,
        scaler=None,
        sort_columns=None,
        random_seed=4332,
        pre_scale=10.0,
        **kwargs
    ):
        """# Initialize Dataset
        self.dataset = dataset (N, ...)
        self.labels = labels (N, ...)

        Arguments:
        - dataset_filename: CSV file passed to `pandas.read_csv`
        - data_columns: Columns that make up `self.dataset`
        - label_columns: Columns that make up `self.labels`; if empty or
        missing from the file, `self.labels` is None
        - batch_size: Rows per batch for indexing and the split iterators
        - shuffle: Shuffle the rows before splitting
        - val_split: Fraction of rows held out for validation (0 disables)
        - scaler: Callable applied to the prescaled data
        - sort_columns: The columns to sort the data by initially
        - random_seed: Seed for `numpy.random` (None leaves the state alone)
        - pre_scale: Magnification of the data; the prescale ratio is
        `base_scale / pre_scale`

        Optional Keyword Arguments:
        - prescale_function: The function that takes the ratio and transforms
        the dataset by multiplying the prescale_function output
        - equal_split: If the classifications should be equally split in
        training"""
        from pandas import read_csv
        from numpy import array
        from numpy.random import seed

        # The original tested the imported `seed` function (always truthy)
        # rather than the requested seed value.
        if random_seed is not None:
            seed(random_seed)

        # Basic Dataset Information (copied so defaults are never shared)
        self.data_columns = _copy_columns(data_columns)
        self.label_columns = _copy_columns(label_columns)
        self.source = dataset_filename
        self.dataframe = read_csv(self.source)

        # `sort_columns` is a named parameter, so it can never show up in
        # **kwargs where the original looked for it; sort the dataframe here,
        # before any array is extracted from it.
        if sort_columns is not None:
            self._sort_dataframe(sort_columns)

        # Pre-scaling Transformation
        ratio = self.base_scale / pre_scale
        # One exponent per data column.  The 13 entries line up with
        # ShapeDataset's default column order (presumably 2 for areas, 1 for
        # lengths, 0 for dimensionless ratios - confirm with the data owner).
        prescale_powers = array([2, 1, 1, 0, 2, 1, 0, 0, 1, 1, 1, 1, 1])

        def default_prescale(x):
            return x**prescale_powers

        self.prescale_function = kwargs.get("prescale_function", default_prescale)
        self.prescale_factor = self.prescale_function(ratio)
        assert (
            len(self.data_columns) == self.prescale_factor.shape[0]
        ), "Column Mismatch on Prescale"
        self.original_scale = pre_scale

        # Scaling Transformation
        self.scaled = scaler is not None
        self.scaler = scaler

        # Training Dataset Information
        self.do_split = bool(val_split != 0.0)
        self.validation_split = val_split
        self.batch_size = batch_size
        self.do_shuffle = shuffle
        self.equal_split = kwargs.get("equal_split", False)

        # Classification Labels if they exist
        self.dataset = self.dataframe[self.data_columns].to_numpy()
        if len(self.label_columns) == 0:
            self.labels = None
        elif not all(
            column in self.dataframe.columns for column in self.label_columns
        ):
            import warnings

            warnings.warn(
                "No classification labels found for the dataset", RuntimeWarning
            )
            self.labels = None
        else:
            # numpy from the start (as refresh_dataset does) so shuffling,
            # splitting and the iterators can index labels like the data.
            self.labels = self.dataframe[self.label_columns].to_numpy().squeeze()

        # Initialize the dataset
        # NOTE(review): order kept from the original - split() therefore sees
        # the *unscaled* rows, and the final refresh_dataset() re-reads the
        # dataframe, which discards the shuffled order of self.dataset (the
        # training/validation iterators keep it).  Confirm this is intended.
        if self.do_shuffle:
            self.shuffle()
        if self.do_split:
            self.split()
        self.refresh_dataset()

    def __len__(self):
        """# Get how many cases are in the dataset"""
        return self.dataset.shape[0]

    def __getitem__(self, idx):
        """# Make Dataset Sliceable

        A slice is applied to the rows directly.  Any other index is treated
        as a batch number: batch `idx` covers `batch_size` rows (one row when
        `batch_size` is None).  Returns the data, or `(data, labels)` when
        labels exist.

        Raises IndexError when a non-negative batch number starts past the
        last row, so that plain `for batch in dataset` iteration terminates
        (it used to yield empty batches forever)."""
        if isinstance(idx, slice):
            idx_slice = idx
        else:
            stride = 1 if self.batch_size is None else self.batch_size
            start = idx * stride
            if idx >= 0 and start >= len(self):
                raise IndexError("batch index out of range")
            idx_slice = slice(start, min(len(self), (idx + 1) * stride))
        if self.labels is None:
            return self.dataset[idx_slice, ...]
        return self.dataset[idx_slice, ...], self.labels[idx_slice, ...]

    def scale_data(self, data):
        """# Scale dataset from scaling function

        Multiplies by the prescale factor, then applies `self.scaler` if one
        was given.  Returns the scaled data."""
        data = data * self.prescale_factor
        if self.scaler is not None:
            data = self.scaler(data)
        return data

    def scale(self):
        """# Scale Dataset"""
        self.dataset = self.scale_data(self.dataset)

    def refresh_dataset(self, dataframe=None):
        """# Refresh Dataset

        Regenerate the dataset (and labels, if present) from a dataframe and
        scale it.  Uses `self.dataframe` when none is given.
        Primarily used after a sort or filter."""
        if dataframe is None:
            dataframe = self.dataframe
        self.dataset = dataframe[self.data_columns].to_numpy()
        if self.labels is not None:
            self.labels = dataframe[self.label_columns].to_numpy().squeeze()
        self.scale()

    def _sort_dataframe(self, columns):
        """Sort `self.dataframe` by one column name or a list of names."""
        if isinstance(columns, str):
            columns = [columns]
        assert all(
            column in self.dataframe.columns for column in columns
        ), "Dataframe does not contain some provided columns!"
        self.dataframe = self.dataframe.sort_values(by=columns)

    def sort_on(self, columns):
        """# Sort Dataset on Column(s)

        `columns` is a column name or list of names; None skips the sort.
        The dataset is refreshed from the dataframe afterwards."""
        if columns is not None:
            self._sort_dataframe(columns)
        self.refresh_dataset()

    def filter_on(self, column, value):
        """# Filter Dataset on Column Value(s)

        Keeps the rows whose `column` entry is in `value`.  `value` may be a
        list-like or a single scalar (a scalar used to raise inside `isin`).
        The filtered frame is stored as `self.working_dataset`."""
        from numpy import isscalar

        assert column in self.dataframe.columns, "Column DNE"
        if isscalar(value):
            value = [value]
        self.working_dataset = self.dataframe[self.dataframe[column].isin(value)]
        self.refresh_dataset(self.working_dataset)

    def filter_off(self):
        """# Remove any filter on the dataset"""
        self.refresh_dataset()

    def unique(self, column):
        """# Get unique values in a column(s)"""
        assert column in self.dataframe.columns, "Column DNE"
        from numpy import unique

        return unique(self.dataframe[column])

    def shuffle_data(self, data, labels=None):
        """# Shuffle a dataset

        Applies one random row permutation to `data` and, if given, to
        `labels`.  Returns the shuffled data, or `(data, labels)`."""
        from numpy.random import permutation

        order = permutation(data.shape[0])
        if labels is None:
            return data[order, ...]
        # Validate the arguments themselves (the original compared
        # self.labels with self.dataset regardless of what was passed in).
        assert labels.shape[0] == data.shape[0], "Dataset and Label Shape Mismatch"
        return data[order, ...], labels[order]

    def shuffle(self):
        """# Shuffle the dataset

        Does nothing unless `self.do_shuffle` is set."""
        if not self.do_shuffle:
            return
        if self.labels is None:
            # Previously fell through to the labelled call below, which then
            # failed while unpacking a single array into two names.
            self.dataset = self.shuffle_data(self.dataset)
        else:
            self.dataset, self.labels = self.shuffle_data(self.dataset, self.labels)

    def split(self):
        """# Training/ Validation Splitting

        Builds `self.training` / `self.validation` iterators (and the
        matching `training_data` / `validation_data` arrays).  By default the
        first `1 - validation_split` fraction of rows is used for training.
        With `equal_split`, every class contributes the same number of
        randomly chosen training rows and all remaining rows are used for
        validation.  Does nothing unless `self.do_split` is set."""
        from numpy import delete, hstack, unique, where
        from numpy.random import permutation

        if not self.do_split:
            return
        assert self.validation_split <= 1.0, "Too High"
        assert self.validation_split > 0.0, "Too Low"
        train_count = int(self.dataset.shape[0] * (1 - self.validation_split))
        training_data = self.dataset[:train_count, ...]
        validation_data = self.dataset[train_count:, ...]
        training_labels = None
        validation_labels = None
        if self.labels is not None:
            if self.equal_split:
                # Ensure the split balances the prevalence of each class
                assert len(self.labels.shape) == 1, "1D Classification Only Currently"
                # unique() returns a (values, counts) tuple; the original
                # treated that tuple as an array.
                classes, counts = unique(self.labels, return_counts=True)
                train_count = min(train_count, len(classes) * int(counts.min()))
                # Integer division: the result is used as a slice bound.
                class_size = train_count // len(classes)
                per_class = [
                    permutation(where(self.labels == _class)[0])[:class_size]
                    for _class in classes
                ]
                train_indices = permutation(hstack(per_class))
                training_data = self.dataset[train_indices, ...]
                training_labels = self.labels[train_indices]
                validation_data = delete(self.dataset, train_indices, axis=0)
                validation_labels = delete(self.labels, train_indices, axis=0)
            else:
                training_labels = self.labels[:train_count]
                validation_labels = self.labels[train_count:]
        self.training_data = training_data
        self.validation_data = validation_data
        self.training = DatasetIterator(training_data, training_labels, self.batch_size)
        self.validation = DatasetIterator(
            validation_data, validation_labels, self.batch_size
        )

    def reset_iterators(self):
        """# Reset Train/ Validation Iterators

        Re-runs `split()`, which creates fresh iterators."""
        self.split()


# ---------------------------------------------------------------------------
# code/sunlab/common/data/dataset_iterator.py
# ---------------------------------------------------------------------------
class DatasetIterator:
    """# Dataset Iterator

    Creates a single-pass iterator object on a dataset and labels"""

    def __init__(self, dataset, labels=None, batch_size=None):
        """# Initialize the iterator with the dataset and labels

        - batch_size: How many to include in the iteration (defaults to the
        whole dataset in one batch)"""
        self.dataset = dataset
        self.labels = labels
        self.current = 0
        self.batch_size = (
            batch_size if batch_size is not None else self.dataset.shape[0]
        )

    def __iter__(self):
        """# Iterator Function"""
        return self

    def __next__(self):
        """# Next Iteration

        Slice the dataset and labels to provide the next batch: the data, or
        `(data, labels)` when labels exist.  Raises StopIteration once the
        batch start passes the last row."""
        self.cur = self.current
        self.current += 1
        start = self.cur * self.batch_size
        if start >= self.dataset.shape[0]:
            raise StopIteration
        iterator_slice = slice(start, start + self.batch_size)
        if self.labels is None:
            return self.dataset[iterator_slice, ...]
        return self.dataset[iterator_slice, ...], self.labels[iterator_slice, ...]


# ---------------------------------------------------------------------------
# code/sunlab/common/data/image_dataset.py
# ---------------------------------------------------------------------------
class ImageDataset:
    """# Image Dataset

    Loads every image with a given extension from a directory into one
    stacked numpy array, `self.dataset`."""

    def __init__(
        self,
        base_directory,
        ext="png",
        channels=None,
        batch_size=None,
        shuffle=False,
        rotate=False,
        rotate_p=1.0,
    ):
        """# Image Dataset

        Load a directory of images

        - base_directory: Directory holding the images (with or without a
        trailing separator)
        - ext: File extension to load
        - channels: Indices into the last axis of each loaded image to keep
        (defaults to `[0]`)
        - batch_size: Images per batch (defaults to the whole dataset)
        - shuffle: Shuffle the stacked images
        - rotate: Also add the seven flipped/transposed copies of each image
        - rotate_p: Probability with which each of those copies is added

        Raises ValueError when no matching image is found."""
        from glob import glob
        from os.path import join
        from matplotlib.pyplot import imread
        from numpy import newaxis, vstack
        from numpy.random import permutation, rand

        if channels is None:
            channels = [0]
        self.base_directory = base_directory
        # join() makes "dir" and "dir/" equivalent; sorting makes the image
        # order reproducible (glob order is otherwise unspecified).
        files = sorted(glob(join(self.base_directory, "*." + ext)))
        if not files:
            raise ValueError(
                f"No '*.{ext}' images found in {self.base_directory!r}"
            )
        self.dataset = []
        for file in files:
            # Keep the requested channels and move them in front of the two
            # image axes.
            im = imread(file)[newaxis, :, :, channels].transpose(0, 3, 1, 2)
            self.dataset.append(im)
            # Also add rotations of the image to the dataset
            if rotate:
                # The transposed copies swap the two image axes, so stacking
                # them with the originals only works for square images.
                swapped = im.transpose(0, 1, 3, 2)
                variants = (
                    im[:, :, ::-1, :],
                    im[:, :, :, ::-1],
                    im[:, :, ::-1, ::-1],
                    swapped,
                    swapped[:, :, ::-1, :],
                    swapped[:, :, :, ::-1],
                    swapped[:, :, ::-1, ::-1],
                )
                for variant in variants:
                    if rand() < rotate_p:
                        self.dataset.append(variant)
        self.dataset = vstack(self.dataset)
        if shuffle:
            self.dataset = self.dataset[permutation(self.dataset.shape[0]), ...]
        self.batch_size = (
            batch_size if batch_size is not None else self.dataset.shape[0]
        )

    def torch(self, device=None):
        """# Cast to Torch Tensor

        Returns the dataset as a tensor on `device` (CPU when None)."""
        import torch

        if device is None:
            device = torch.device("cpu")
        return torch.tensor(self.dataset).to(device)

    def numpy(self):
        """# Cast to Numpy Array"""
        return self.dataset

    def __len__(self):
        """# Return Number of Full Batches

        Computed as `images // batch_size`, so a trailing partial batch is
        not counted."""
        return self.dataset.shape[0] // self.batch_size

    def __getitem__(self, index):
        """# Slice By Batch

        A tuple or slice indexes the underlying array directly.  An integer
        (including numpy integers) selects batch `index`.

        Raises IndexError when a non-negative batch number starts past the
        last image (so `for batch in dataset` terminates) and TypeError for
        any other index type (these used to return None silently)."""
        if isinstance(index, (tuple, slice)):
            return self.dataset[index]
        try:
            batch = operator.index(index)
        except TypeError:
            raise TypeError(
                f"indices must be integers, slices or tuples, "
                f"not {type(index).__name__}"
            ) from None
        start = batch * self.batch_size
        if batch >= 0 and start >= self.dataset.shape[0]:
            raise IndexError("batch index out of range")
        return self.dataset[start : start + self.batch_size, ...]


# ---------------------------------------------------------------------------
# code/sunlab/common/data/shape_dataset.py
# ---------------------------------------------------------------------------
# Original import:
#   from .dataset import Dataset
_SHAPE_DATA_COLUMNS = (
    "Area",
    "MjrAxisLength",
    "MnrAxisLength",
    "Eccentricity",
    "ConvexArea",
    "EquivDiameter",
    "Solidity",
    "Extent",
    "Perimeter",
    "ConvexPerim",
    "FibLen",
    "InscribeR",
    "BlebLen",
)
_SHAPE_LABEL_COLUMNS = ("Class",)


class ShapeDataset(Dataset):
    """# Shape Dataset

    A `Dataset` whose default data columns are the 13 shape measurements
    listed in `_SHAPE_DATA_COLUMNS` and whose default label column is
    "Class"."""

    def __init__(
        self,
        dataset_filename,
        data_columns=None,
        label_columns=None,
        batch_size=None,
        shuffle=False,
        val_split=0.0,
        scaler=None,
        sort_columns=None,
        random_seed=4332,
        pre_scale=10,
        **kwargs
    ):
        """# Initialize Dataset
        self.dataset = dataset (N, ...)
        self.labels = labels (N, ...)

        `data_columns` / `label_columns` default to the shape columns and
        "Class"; every argument is forwarded to `Dataset.__init__`.

        Optional Arguments:
        - prescale_function: The function that takes the ratio and transforms
        the dataset by multiplying the prescale_function output
        - sort_columns: The columns to sort the data by initially
        - equal_split: If the classifications should be equally split in
        training"""
        if data_columns is None:
            data_columns = list(_SHAPE_DATA_COLUMNS)
        if label_columns is None:
            label_columns = list(_SHAPE_LABEL_COLUMNS)
        super().__init__(
            dataset_filename,
            data_columns=data_columns,
            label_columns=label_columns,
            batch_size=batch_size,
            shuffle=shuffle,
            val_split=val_split,
            scaler=scaler,
            sort_columns=sort_columns,
            random_seed=random_seed,
            pre_scale=pre_scale,
            **kwargs
        )


# ---------------------------------------------------------------------------
# code/sunlab/common/data/utilities.py
# ---------------------------------------------------------------------------
# Original imports:
#   from .shape_dataset import ShapeDataset
#   from ..scaler.max_abs_scaler import MaxAbsScaler
def import_10x(
    filename,
    magnification=10,
    batch_size=None,
    shuffle=False,
    val_split=0.0,
    scaler=None,
    sort_columns=None,
):
    """# Import a 10x Image Dataset

    `magnification` is accepted for signature compatibility only; as in the
    original, the dataset is always loaded with a magnification of 10.

    Pixel-to-micron: ???"""
    return import_dataset(
        filename,
        10,
        batch_size=batch_size,
        shuffle=shuffle,
        val_split=val_split,
        scaler=scaler,
        sort_columns=sort_columns,
    )


def import_20x(
    filename,
    magnification=10,
    batch_size=None,
    shuffle=False,
    val_split=0.0,
    scaler=None,
    sort_columns=None,
):
    """# Import a 20x Image Dataset

    `magnification` is accepted for signature compatibility only; as in the
    original, the dataset is always loaded with a magnification of 20.

    Pixel-to-micron: ???"""
    return import_dataset(
        filename,
        20,
        batch_size=batch_size,
        shuffle=shuffle,
        val_split=val_split,
        scaler=scaler,
        sort_columns=sort_columns,
    )


def import_dataset(
    filename,
    magnification,
    batch_size=None,
    shuffle=False,
    val_split=0.0,
    scaler=None,
    sort_columns=None,
):
    """# Import a dataset

    Requires a magnification to be specified; it is passed to
    `ShapeDataset` as `pre_scale`.  Returns the `ShapeDataset`."""
    return ShapeDataset(
        filename,
        pre_scale=magnification,
        batch_size=batch_size,
        shuffle=shuffle,
        val_split=val_split,
        scaler=scaler,
        sort_columns=sort_columns,
    )


def import_full_dataset(fname, magnification=20, scaler=None):
    """# Import a Full Dataset

    If a classification file exists (same name with a `.txt` extension, one
    header row, columns frame, cellnumber, x-cent, y-cent, actinedge,
    filopodia, bleb, lamellipodia), it is merged onto the dataset on
    (Frames, CellNum) == (frame, cellnumber) and a "Class" column is added
    holding the index of the largest of the four class columns.  Only rows
    present in both files are kept.  Returns the dataset."""
    from os.path import isfile, splitext
    import pandas as pd
    import numpy as np

    class_columns = ["actinedge", "filopodia", "bleb", "lamellipodia"]
    columns = ["frame", "cellnumber", "x-cent", "y-cent"] + class_columns

    dataset = import_dataset(fname, magnification=magnification, scaler=scaler)
    tfname = splitext(fname)[0] + ".txt"
    if isfile(tfname):
        # ndmin=2 keeps a single-row file two-dimensional.
        class_df = pd.DataFrame(
            np.loadtxt(tfname, skiprows=1, ndmin=2), columns=columns
        )
        full_df = pd.merge(
            dataset.dataframe,
            class_df,
            left_on=["Frames", "CellNum"],
            right_on=["frame", "cellnumber"],
        )
        # Take the argmax over the *merged* rows: the inner merge may drop or
        # reorder rows relative to class_df, which the original indexed.
        full_df["Class"] = np.argmax(full_df[class_columns].to_numpy(), axis=-1)
        dataset.labels = full_df["Class"].to_numpy()
        dataset.dataframe = full_df
    # Rebuild dataset/labels from the (possibly merged) dataframe.
    dataset.filter_off()
    return dataset