Diffstat (limited to 'code/sunlab/common')
-rw-r--r--  code/sunlab/common/__init__.py  5
-rw-r--r--  code/sunlab/common/data/__init__.py  6
-rw-r--r--  code/sunlab/common/data/basic.py  6
-rw-r--r--  code/sunlab/common/data/dataset.py  255
-rw-r--r--  code/sunlab/common/data/dataset_iterator.py  34
-rw-r--r--  code/sunlab/common/data/image_dataset.py  75
-rw-r--r--  code/sunlab/common/data/shape_dataset.py  57
-rw-r--r--  code/sunlab/common/data/utilities.py  119
-rw-r--r--  code/sunlab/common/distribution/__init__.py  7
-rw-r--r--  code/sunlab/common/distribution/adversarial_distribution.py  35
-rw-r--r--  code/sunlab/common/distribution/gaussian_distribution.py  23
-rw-r--r--  code/sunlab/common/distribution/o_gaussian_distribution.py  38
-rw-r--r--  code/sunlab/common/distribution/s_gaussian_distribution.py  40
-rw-r--r--  code/sunlab/common/distribution/swiss_roll_distribution.py  42
-rw-r--r--  code/sunlab/common/distribution/symmetric_uniform_distribution.py  21
-rw-r--r--  code/sunlab/common/distribution/uniform_distribution.py  21
-rw-r--r--  code/sunlab/common/distribution/x_gaussian_distribution.py  38
-rw-r--r--  code/sunlab/common/mathlib/__init__.py  1
-rw-r--r--  code/sunlab/common/mathlib/base.py  57
-rw-r--r--  code/sunlab/common/mathlib/lyapunov.py  54
-rw-r--r--  code/sunlab/common/mathlib/random_walks.py  83
-rw-r--r--  code/sunlab/common/plotting/__init__.py  2
-rw-r--r--  code/sunlab/common/plotting/base.py  270
-rw-r--r--  code/sunlab/common/plotting/colors.py  38
-rw-r--r--  code/sunlab/common/scaler/__init__.py  2
-rw-r--r--  code/sunlab/common/scaler/adversarial_scaler.py  44
-rw-r--r--  code/sunlab/common/scaler/max_abs_scaler.py  48
-rw-r--r--  code/sunlab/common/scaler/quantile_scaler.py  52
28 files changed, 1473 insertions, 0 deletions
diff --git a/code/sunlab/common/__init__.py b/code/sunlab/common/__init__.py
new file mode 100644
index 0000000..cb6716c
--- /dev/null
+++ b/code/sunlab/common/__init__.py
@@ -0,0 +1,5 @@
+from .data import *
+from .distribution import *
+from .scaler import *
+from .mathlib import *
+from .plotting import *
diff --git a/code/sunlab/common/data/__init__.py b/code/sunlab/common/data/__init__.py
new file mode 100644
index 0000000..3e26874
--- /dev/null
+++ b/code/sunlab/common/data/__init__.py
@@ -0,0 +1,6 @@
+from .basic import *
+from .dataset import *
+from .dataset_iterator import *
+from .shape_dataset import *
+from .image_dataset import *
+from .utilities import *
diff --git a/code/sunlab/common/data/basic.py b/code/sunlab/common/data/basic.py
new file mode 100644
index 0000000..bb2e912
--- /dev/null
+++ b/code/sunlab/common/data/basic.py
@@ -0,0 +1,6 @@
+import numpy
+
+
+numpy.load_dat = lambda *args, **kwargs: numpy.load(
+ *args, **kwargs, allow_pickle=True
+).item()
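Usage sketch for the `load_dat` helper above (file name hypothetical; assumes the `.npy` file holds a pickled dictionary saved with `numpy.save`):

    import numpy
    import sunlab.common.data.basic  # importing the module attaches numpy.load_dat

    numpy.save("metrics.npy", {"accuracy": 0.97})  # dict is stored as a 0-d object array
    metrics = numpy.load_dat("metrics.npy")        # .item() unwraps it back into a dict
    print(metrics["accuracy"])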
diff --git a/code/sunlab/common/data/dataset.py b/code/sunlab/common/data/dataset.py
new file mode 100644
index 0000000..8589abf
--- /dev/null
+++ b/code/sunlab/common/data/dataset.py
@@ -0,0 +1,255 @@
+from .dataset_iterator import DatasetIterator
+
+
+class Dataset:
+ """# Dataset Superclass"""
+
+ base_scale = 10.0
+
+ def __init__(
+ self,
+ dataset_filename,
+ data_columns=[],
+ label_columns=[],
+ batch_size=None,
+ shuffle=False,
+ val_split=0.0,
+ scaler=None,
+ sort_columns=None,
+ random_seed=4332,
+ pre_scale=10.0,
+ **kwargs
+ ):
+ """# Initialize Dataset
+ self.dataset = dataset (N, ...)
+ self.labels = labels (N, ...)
+
+ Optional Arguments:
+ - prescale_function: The function that takes the ratio and transforms
+ the dataset by multiplying the prescale_function output
+ - sort_columns: The columns to sort the data by initially
+ - equal_split: If the classifications should be equally split in
+ training"""
+ from pandas import read_csv
+ from numpy import array, all
+ from numpy.random import seed
+
+        if random_seed is not None:
+            seed(random_seed)
+
+ # Basic Dataset Information
+ self.data_columns = data_columns
+ self.label_columns = label_columns
+ self.source = dataset_filename
+ self.dataframe = read_csv(self.source)
+
+ # Pre-scaling Transformation
+ prescale_ratio = self.base_scale / pre_scale
+ ratio = prescale_ratio
+ prescale_powers = array([2, 1, 1, 0, 2, 1, 0, 0, 1, 1, 1, 1, 1])
+ if "prescale_function" in kwargs.keys():
+ prescale_function = kwargs["prescale_function"]
+ else:
+
+ def prescale_function(x):
+ return x**prescale_powers
+
+ self.prescale_function = prescale_function
+ self.prescale_factor = self.prescale_function(ratio)
+ assert (
+ len(data_columns) == self.prescale_factor.shape[0]
+ ), "Column Mismatch on Prescale"
+ self.original_scale = pre_scale
+
+ # Scaling Transformation
+ self.scaled = scaler is not None
+ self.scaler = scaler
+
+ # Training Dataset Information
+        self.do_split = val_split != 0.0
+ self.validation_split = val_split
+ self.batch_size = batch_size
+ self.do_shuffle = shuffle
+ self.equal_split = False
+ if "equal_split" in kwargs.keys():
+ self.equal_split = kwargs["equal_split"]
+
+ # Classification Labels if they exist
+ self.dataset = self.dataframe[self.data_columns].to_numpy()
+ if len(self.label_columns) == 0:
+ self.labels = None
+ elif not all([column in self.dataframe.columns for column in label_columns]):
+ import warnings
+
+ warnings.warn(
+ "No classification labels found for the dataset", RuntimeWarning
+ )
+ self.labels = None
+ else:
+ self.labels = self.dataframe[self.label_columns].squeeze()
+
+ # Initialize the dataset
+ if "sort_columns" in kwargs.keys():
+ self.sort(kwargs["sort_columns"])
+ if self.do_shuffle:
+ self.shuffle()
+ if self.do_split:
+ self.split()
+ self.refresh_dataset()
+
+ def __len__(self):
+ """# Get how many cases are in the dataset"""
+ return self.dataset.shape[0]
+
+ def __getitem__(self, idx):
+ """# Make Dataset Sliceable"""
+ idx_slice = None
+ slice_stride = 1 if self.batch_size is None else self.batch_size
+ # If we pass a slice, return the slice
+ if type(idx) == slice:
+ idx_slice = idx
+ # If we pass an int, return a batch-size slice
+ else:
+ idx_slice = slice(
+ idx * slice_stride, min([len(self), (idx + 1) * slice_stride])
+ )
+ if self.labels is None:
+ return self.dataset[idx_slice, ...]
+ return self.dataset[idx_slice, ...], self.labels[idx_slice, ...]
+
+ def scale_data(self, data):
+ """# Scale dataset from scaling function"""
+ data = data * self.prescale_factor
+ if not (self.scaler is None):
+ data = self.scaler(data)
+ return data
+
+ def scale(self):
+ """# Scale Dataset"""
+ self.dataset = self.scale_data(self.dataset)
+
+ def refresh_dataset(self, dataframe=None):
+ """# Refresh Dataset
+
+ Regenerate the dataset from a dataframe.
+ Primarily used after a sort or filter."""
+ if dataframe is None:
+ dataframe = self.dataframe
+ self.dataset = dataframe[self.data_columns].to_numpy()
+ if self.labels is not None:
+ self.labels = dataframe[self.label_columns].to_numpy().squeeze()
+ self.scale()
+
+ def sort_on(self, columns):
+ """# Sort Dataset on Column(s)"""
+ from numpy import all
+
+ if type(columns) == str:
+ columns = [columns]
+ if columns is not None:
+ assert all(
+ [column in self.dataframe.columns for column in columns]
+ ), "Dataframe does not contain some provided columns!"
+ self.dataframe = self.dataframe.sort_values(by=columns)
+ self.refresh_dataset()
+
+ def filter_on(self, column, value):
+ """# Filter Dataset on Column Value(s)"""
+ assert column in self.dataframe.columns, "Column DNE"
+ self.working_dataset = self.dataframe[self.dataframe[column].isin(value)]
+ self.refresh_dataset(self.working_dataset)
+
+ def filter_off(self):
+ """# Remove any filter on the dataset"""
+ self.refresh_dataset()
+
+ def unique(self, column):
+ """# Get unique values in a column(s)"""
+ assert column in self.dataframe.columns, "Column DNE"
+ from numpy import unique
+
+ return unique(self.dataframe[column])
+
+ def shuffle_data(self, data, labels=None):
+ """# Shuffle a dataset"""
+ from numpy.random import permutation
+
+ shuffled = permutation(data.shape[0])
+ if labels is not None:
+ assert (
+ self.labels.shape[0] == self.dataset.shape[0]
+ ), "Dataset and Label Shape Mismatch"
+ shuf_data = data[shuffled, ...]
+            shuf_labels = labels[shuffled, ...]
+ return shuf_data, shuf_labels
+ return data[shuffled, ...]
+
+ def shuffle(self):
+ """# Shuffle the dataset"""
+        if self.do_shuffle:
+            if self.labels is None:
+                self.dataset = self.shuffle_data(self.dataset)
+            else:
+                self.dataset, self.labels = self.shuffle_data(
+                    self.dataset, self.labels
+                )
+
+ def split(self):
+ """# Training/ Validation Splitting"""
+ from numpy import floor, unique, where, hstack, delete
+ from numpy.random import permutation
+
+ equal_classes = self.equal_split
+ if not self.do_split:
+ return
+ assert self.validation_split <= 1.0, "Too High"
+ assert self.validation_split > 0.0, "Too Low"
+ train_count = int(floor(self.dataset.shape[0] * (1 - self.validation_split)))
+ training_data = self.dataset[:train_count, ...]
+ training_labels = None
+ validation_data = self.dataset[train_count:, ...]
+ validation_labels = None
+ if self.labels is not None:
+ if equal_classes:
+ # Ensure the split balances the prevalence of each class
+ assert len(self.labels.shape) == 1, "1D Classification Only Currently"
+                # unique() returns a (classes, counts) tuple; use each part explicitly
+                classes, counts = unique(self.labels, return_counts=True)
+                train_count = min([train_count, classes.shape[0] * min(counts)])
+                class_size = train_count // classes.shape[0]
+                class_indices = [
+                    permutation(where(self.labels == _class)[0])[:class_size]
+                    for _class in classes
+                ]
+                train_class_indices = permutation(hstack(class_indices).squeeze())
+                training_data = self.dataset[train_class_indices, ...]
+                training_labels = self.labels[train_class_indices]
+                validation_data = delete(self.dataset, train_class_indices, axis=0)
+                validation_labels = delete(
+                    self.labels, train_class_indices, axis=0
+                ).squeeze()
+ else:
+ training_labels = self.labels[:train_count]
+ if len(training_labels.shape) > 1:
+ training_labels = self.labels[:train_count, ...]
+ validation_labels = self.labels[train_count:]
+ if len(validation_labels.shape) > 1:
+ validation_labels = self.labels[train_count:, ...]
+ self.training_data = training_data
+ self.validation_data = validation_data
+ self.training = DatasetIterator(training_data, training_labels, self.batch_size)
+ self.validation = DatasetIterator(
+ validation_data, validation_labels, self.batch_size
+ )
+
+ def reset_iterators(self):
+ """# Reset Train/ Validation Iterators"""
+ self.split()
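A hedged usage sketch of `Dataset` (CSV path and columns are hypothetical; a custom `prescale_function` is passed because the default expects the thirteen morphology columns used by `ShapeDataset` below):

    import numpy as np
    from sunlab.common.data.dataset import Dataset

    ds = Dataset(
        "cells.csv",                                 # hypothetical CSV
        data_columns=["Area", "Perimeter"],
        label_columns=["Class"],
        batch_size=32,
        shuffle=True,
        val_split=0.2,
        prescale_function=lambda ratio: np.ones(2),  # identity prescale for 2 columns
    )
    data, labels = ds[0]                             # first batch
    for batch_data, batch_labels in ds.training:     # iterate the training split
        pass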
diff --git a/code/sunlab/common/data/dataset_iterator.py b/code/sunlab/common/data/dataset_iterator.py
new file mode 100644
index 0000000..7c91caa
--- /dev/null
+++ b/code/sunlab/common/data/dataset_iterator.py
@@ -0,0 +1,34 @@
+class DatasetIterator:
+ """# Dataset Iterator
+
+ Creates an iterator object on a dataset and labels"""
+
+ def __init__(self, dataset, labels=None, batch_size=None):
+ """# Initialize the iterator with the dataset and labels
+
+ - batch_size: How many to include in the iteration"""
+ self.dataset = dataset
+ self.labels = labels
+ self.current = 0
+ self.batch_size = (
+ batch_size if batch_size is not None else self.dataset.shape[0]
+ )
+
+    def __iter__(self):
+        """# Iterator Function
+
+        Resets the batch counter so the dataset can be iterated repeatedly"""
+        self.current = 0
+        return self
+
+ def __next__(self):
+ """# Next Iteration
+
+ Slice the dataset and labels to provide"""
+ self.cur = self.current
+ self.current += 1
+ if self.cur * self.batch_size < self.dataset.shape[0]:
+ iterator_slice = slice(
+ self.cur * self.batch_size, (self.cur + 1) * self.batch_size
+ )
+ if self.labels is None:
+ return self.dataset[iterator_slice, ...]
+ return self.dataset[iterator_slice, ...], self.labels[iterator_slice, ...]
+ raise StopIteration
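A minimal sketch of the iterator on synthetic arrays:

    import numpy as np
    from sunlab.common.data.dataset_iterator import DatasetIterator

    data = np.random.rand(100, 13)
    labels = np.random.randint(0, 4, size=100)
    for batch, batch_labels in DatasetIterator(data, labels, batch_size=32):
        print(batch.shape, batch_labels.shape)  # (32, 13) ... final batch (4, 13)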
diff --git a/code/sunlab/common/data/image_dataset.py b/code/sunlab/common/data/image_dataset.py
new file mode 100644
index 0000000..46e77b6
--- /dev/null
+++ b/code/sunlab/common/data/image_dataset.py
@@ -0,0 +1,75 @@
+class ImageDataset:
+ def __init__(
+ self,
+ base_directory,
+ ext="png",
+ channels=[0],
+ batch_size=None,
+ shuffle=False,
+ rotate=False,
+ rotate_p=1.,
+ ):
+ """# Image Dataset
+
+ Load a directory of images"""
+ from glob import glob
+ from matplotlib.pyplot import imread
+ from numpy import newaxis, vstack
+ from numpy.random import permutation, rand
+
+ self.base_directory = base_directory
+ files = glob(self.base_directory + "*." + ext)
+ self.dataset = []
+ for file in files:
+ im = imread(file)[newaxis, :, :, channels].transpose(0, 3, 1, 2)
+ self.dataset.append(im)
+            # Also add mirrored/transposed copies (reflections and rotations)
+ if rotate:
+ if rand() < rotate_p:
+ self.dataset.append(im[:, :, ::-1, :])
+ if rand() < rotate_p:
+ self.dataset.append(im[:, :, :, ::-1])
+ if rand() < rotate_p:
+ self.dataset.append(im[:, :, ::-1, ::-1])
+ if rand() < rotate_p:
+ self.dataset.append(im.transpose(0, 1, 3, 2))
+ if rand() < rotate_p:
+ self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, ::-1, :])
+ if rand() < rotate_p:
+ self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, :, ::-1])
+ if rand() < rotate_p:
+ self.dataset.append(im.transpose(0, 1, 3, 2)[:, :, ::-1, ::-1])
+ self.dataset = vstack(self.dataset)
+ if shuffle:
+ self.dataset = self.dataset[permutation(self.dataset.shape[0]), ...]
+ self.batch_size = (
+ batch_size if batch_size is not None else self.dataset.shape[0]
+ )
+
+ def torch(self, device=None):
+ """# Cast to Torch Tensor"""
+ import torch
+
+ if device is None:
+ device = torch.device("cpu")
+ return torch.tensor(self.dataset).to(device)
+
+ def numpy(self):
+ """# Cast to Numpy Array"""
+ return self.dataset
+
+ def __len__(self):
+ """# Return Number of Cases
+
+ (or Number in each Batch)"""
+ return self.dataset.shape[0] // self.batch_size
+
+ def __getitem__(self, index):
+ """# Slice By Batch"""
+ if type(index) == tuple:
+ return self.dataset[index]
+ elif type(index) == int:
+ return self.dataset[
+ index * self.batch_size : (index + 1) * self.batch_size, ...
+ ]
+        raise TypeError("Index must be an int (batch index) or a tuple (raw slice)")
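Hypothetical usage, assuming a directory of equal-sized RGB PNGs (`channels=[0]` keeps only the red channel):

    from sunlab.common.data.image_dataset import ImageDataset

    ds = ImageDataset("images/", ext="png", channels=[0], batch_size=16,
                      shuffle=True, rotate=True, rotate_p=0.5)
    print(len(ds))       # number of whole batches
    batch = ds[0]        # array of shape (16, 1, H, W)
    tensor = ds.torch()  # whole dataset as a CPU torch tensor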
diff --git a/code/sunlab/common/data/shape_dataset.py b/code/sunlab/common/data/shape_dataset.py
new file mode 100644
index 0000000..5a68736
--- /dev/null
+++ b/code/sunlab/common/data/shape_dataset.py
@@ -0,0 +1,57 @@
+from .dataset import Dataset
+
+
+class ShapeDataset(Dataset):
+ """# Shape Dataset"""
+
+ def __init__(
+ self,
+ dataset_filename,
+ data_columns=[
+ "Area",
+ "MjrAxisLength",
+ "MnrAxisLength",
+ "Eccentricity",
+ "ConvexArea",
+ "EquivDiameter",
+ "Solidity",
+ "Extent",
+ "Perimeter",
+ "ConvexPerim",
+ "FibLen",
+ "InscribeR",
+ "BlebLen",
+ ],
+ label_columns=["Class"],
+ batch_size=None,
+ shuffle=False,
+ val_split=0.0,
+ scaler=None,
+ sort_columns=None,
+ random_seed=4332,
+ pre_scale=10,
+ **kwargs
+ ):
+ """# Initialize Dataset
+ self.dataset = dataset (N, ...)
+ self.labels = labels (N, ...)
+
+ Optional Arguments:
+ - prescale_function: The function that takes the ratio and transforms
+ the dataset by multiplying the prescale_function output
+ - sort_columns: The columns to sort the data by initially
+ - equal_split: If the classifications should be equally split in
+ training"""
+ super().__init__(
+ dataset_filename,
+ data_columns=data_columns,
+ label_columns=label_columns,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ val_split=val_split,
+ scaler=scaler,
+ sort_columns=sort_columns,
+ random_seed=random_seed,
+ pre_scale=pre_scale,
+ **kwargs
+ )
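Usage sketch (CSV name hypothetical): the thirteen default `data_columns` line up with the default prescale powers in `Dataset`, so no extra arguments are needed:

    from sunlab.common.data.shape_dataset import ShapeDataset

    ds = ShapeDataset("shapes.csv", batch_size=64, shuffle=True, val_split=0.1)
    print(len(ds), ds.dataset.shape)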
diff --git a/code/sunlab/common/data/utilities.py b/code/sunlab/common/data/utilities.py
new file mode 100644
index 0000000..6b4e6f3
--- /dev/null
+++ b/code/sunlab/common/data/utilities.py
@@ -0,0 +1,119 @@
+from .shape_dataset import ShapeDataset
+from ..scaler.max_abs_scaler import MaxAbsScaler
+
+
+def import_10x(
+ filename,
+ batch_size=None,
+ shuffle=False,
+ val_split=0.0,
+ scaler=None,
+ sort_columns=None,
+):
+ """# Import a 10x Image Dataset
+
+ Pixel-to-micron: ???"""
+ magnification = 10
+ dataset = ShapeDataset(
+ filename,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ pre_scale=magnification,
+ val_split=val_split,
+ scaler=scaler,
+ sort_columns=sort_columns,
+ )
+ return dataset
+
+
+def import_20x(
+ filename,
+ batch_size=None,
+ shuffle=False,
+ val_split=0.0,
+ scaler=None,
+ sort_columns=None,
+):
+ """# Import a 20x Image Dataset
+
+ Pixel-to-micron: ???"""
+ magnification = 20
+ dataset = ShapeDataset(
+ filename,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ pre_scale=magnification,
+ val_split=val_split,
+ scaler=scaler,
+ sort_columns=sort_columns,
+ )
+ return dataset
+
+
+def import_dataset(
+ filename,
+ magnification,
+ batch_size=None,
+ shuffle=False,
+ val_split=0.0,
+ scaler=None,
+ sort_columns=None,
+):
+ """# Import a dataset
+
+    Requires a magnification to be specified"""
+ dataset = ShapeDataset(
+ filename,
+ pre_scale=magnification,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ val_split=val_split,
+ scaler=scaler,
+ sort_columns=sort_columns,
+ )
+ return dataset
+
+
+def import_full_dataset(fname, magnification=20, scaler=None):
+ """# Import a Full Dataset
+
+    If a classification file exists (a .txt with 'frame' and 'cellnumber'
+    columns plus one column per class), import it as well and derive a
+    'Class' label"""
+ from os.path import isfile
+ import pandas as pd
+ import numpy as np
+
+ cfname = fname
+ tfname = cfname[:-3] + "txt"
+ columns = [
+ "frame",
+ "cellnumber",
+ "x-cent",
+ "y-cent",
+ "actinedge",
+ "filopodia",
+ "bleb",
+ "lamellipodia",
+ ]
+ if isfile(tfname):
+ dataset = import_dataset(cfname, magnification=magnification, scaler=scaler)
+ class_df = np.loadtxt(tfname, skiprows=1)
+ class_df = pd.DataFrame(class_df, columns=columns)
+ full_df = pd.merge(
+ dataset.dataframe,
+ class_df,
+ left_on=["Frames", "CellNum"],
+ right_on=["frame", "cellnumber"],
+ )
+ full_df["Class"] = np.argmax(
+ class_df[["actinedge", "filopodia", "bleb", "lamellipodia"]].to_numpy(),
+ axis=-1,
+ )
+ dataset.labels = full_df["Class"].to_numpy()
+ else:
+ dataset = import_dataset(cfname, magnification=magnification, scaler=scaler)
+ full_df = dataset.dataframe
+ dataset.dataframe = full_df
+ dataset.filter_off()
+ return dataset
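End-to-end sketch (file names hypothetical; `import_full_dataset` looks for a sibling `exp1.txt` next to `exp1.csv`):

    from sunlab.common.data.utilities import import_full_dataset

    dataset = import_full_dataset("exp1.csv", magnification=20)
    print(dataset.dataset.shape)
    if dataset.labels is not None:
        print(dataset.labels[:10])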
diff --git a/code/sunlab/common/distribution/__init__.py b/code/sunlab/common/distribution/__init__.py
new file mode 100644
index 0000000..a23cb0c
--- /dev/null
+++ b/code/sunlab/common/distribution/__init__.py
@@ -0,0 +1,7 @@
+from .gaussian_distribution import *
+from .x_gaussian_distribution import *
+from .o_gaussian_distribution import *
+from .s_gaussian_distribution import *
+from .uniform_distribution import *
+from .symmetric_uniform_distribution import *
+from .swiss_roll_distribution import *
diff --git a/code/sunlab/common/distribution/adversarial_distribution.py b/code/sunlab/common/distribution/adversarial_distribution.py
new file mode 100644
index 0000000..675c00e
--- /dev/null
+++ b/code/sunlab/common/distribution/adversarial_distribution.py
@@ -0,0 +1,35 @@
+class AdversarialDistribution:
+ """# Distribution Class to use in Adversarial Autoencoder
+
+ For any distribution to be implemented, make sure to ensure each of the
+ methods are implemented"""
+
+ def __init__(self, N):
+ """# Initialize the distribution for N-dimensions"""
+ self.dims = N
+ return
+
+ def get_full_name(self):
+ """# Return a human-readable name of the distribution"""
+ return self.full_name
+
+ def get_name(self):
+ """# Return a shortened name of the distribution
+
+        Preferably, the name should include characters that can be included in
+ a file name"""
+ return self.name
+
+ def __str__(self):
+ """# Returns the short name"""
+ return self.get_name()
+
+ def __repr__(self):
+ """# Returns the short name"""
+ return self.get_name()
+
+ def __call__(self, *args):
+ """# Magic method when calling the distribution
+
+ This method is going to be called when you use `dist(...)`"""
+ raise NotImplementedError("This distribution has not been implemented yet")
diff --git a/code/sunlab/common/distribution/gaussian_distribution.py b/code/sunlab/common/distribution/gaussian_distribution.py
new file mode 100644
index 0000000..e478ab6
--- /dev/null
+++ b/code/sunlab/common/distribution/gaussian_distribution.py
@@ -0,0 +1,23 @@
+from .adversarial_distribution import *
+
+
+class GaussianDistribution(AdversarialDistribution):
+ """# Gaussian Distribution"""
+
+ def __init__(self, N):
+ """# Gaussian Distribution Initialization
+
+ Initializes the name and dimensions"""
+ super().__init__(N)
+ self.full_name = f"{N}-Dimensional Gaussian Distribution"
+ self.name = "G"
+
+ def __call__(self, *args):
+ """# Magic method when calling the distribution
+
+ This method is going to be called when you use gauss(N1,...,Nm)"""
+ import numpy as np
+
+ return np.random.multivariate_normal(
+ mean=np.zeros(self.dims), cov=np.eye(self.dims), size=[*args]
+ )
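Sampling sketch shared by the distribution classes:

    from sunlab.common.distribution.gaussian_distribution import GaussianDistribution

    gauss = GaussianDistribution(2)
    z = gauss(1000)  # samples with shape (1000, 2)
    print(gauss.get_full_name(), str(gauss), z.shape)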
diff --git a/code/sunlab/common/distribution/o_gaussian_distribution.py b/code/sunlab/common/distribution/o_gaussian_distribution.py
new file mode 100644
index 0000000..1222ca1
--- /dev/null
+++ b/code/sunlab/common/distribution/o_gaussian_distribution.py
@@ -0,0 +1,38 @@
+from .adversarial_distribution import *
+
+
+class OGaussianDistribution(AdversarialDistribution):
+ """# O Gaussian Distribution"""
+
+ def __init__(self, N):
+ """# O Gaussian Distribution Initialization
+
+ Initializes the name and dimensions"""
+ super().__init__(N)
+ assert self.dims == 2, "This Distribution only Supports 2-Dimensions"
+ self.full_name = "2-Dimensional O-Gaussian Distribution"
+ self.name = "OG"
+
+ def __call__(self, *args):
+ """# Magic method when calling the distribution
+
+        This method is going to be called when you use ogauss(case_count)"""
+ import numpy as np
+
+ assert len(args) == 1, "Only 1 argument supported"
+ N = args[0]
+ sample_base = np.zeros((4 * N, 2))
+ sample_base[0 * N : (0 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[1, 1], cov=[[1, 0], [0, 1]], size=[N]
+ )
+ sample_base[1 * N : (1 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[-1, -1], cov=[[1, 0], [0, 1]], size=[N]
+ )
+ sample_base[2 * N : (2 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[-1, 1], cov=[[1, 0], [0, 1]], size=[N]
+ )
+ sample_base[3 * N : (3 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[1, -1], cov=[[1, 0], [0, 1]], size=[N]
+ )
+ np.random.shuffle(sample_base)
+ return sample_base[:N, :]
diff --git a/code/sunlab/common/distribution/s_gaussian_distribution.py b/code/sunlab/common/distribution/s_gaussian_distribution.py
new file mode 100644
index 0000000..cace57f
--- /dev/null
+++ b/code/sunlab/common/distribution/s_gaussian_distribution.py
@@ -0,0 +1,40 @@
+from .adversarial_distribution import *
+
+
+class SGaussianDistribution(AdversarialDistribution):
+ """# S Gaussian Distribution"""
+
+ def __init__(self, N, scale=0):
+ """# S Gaussian Distribution Initialization
+
+ Initializes the name and dimensions"""
+ super().__init__(N)
+ assert self.dims == 2, "This Distribution only Supports 2-Dimensions"
+ self.full_name = "2-Dimensional S-Gaussian Distribution"
+ self.name = "SG"
+ self.scale = scale
+
+ def __call__(self, *args):
+ """# Magic method when calling the distribution
+
+        This method is going to be called when you use sgauss(case_count)"""
+ import numpy as np
+
+ assert len(args) == 1, "Only 1 argument supported"
+ N = args[0]
+ sample_base = np.zeros((4 * N, 2))
+ scale = self.scale
+ sample_base[0 * N : (0 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[1, 1], cov=[[1, scale], [scale, 1]], size=[N]
+ )
+ sample_base[1 * N : (1 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[-1, -1], cov=[[1, scale], [scale, 1]], size=[N]
+ )
+ sample_base[2 * N : (2 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[-1, 1], cov=[[1, -scale], [-scale, 1]], size=[N]
+ )
+ sample_base[3 * N : (3 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[1, -1], cov=[[1, -scale], [-scale, 1]], size=[N]
+ )
+ np.random.shuffle(sample_base)
+ return sample_base[:N, :]
diff --git a/code/sunlab/common/distribution/swiss_roll_distribution.py b/code/sunlab/common/distribution/swiss_roll_distribution.py
new file mode 100644
index 0000000..613bfc5
--- /dev/null
+++ b/code/sunlab/common/distribution/swiss_roll_distribution.py
@@ -0,0 +1,42 @@
+from .adversarial_distribution import *
+
+
+class SwissRollDistribution(AdversarialDistribution):
+ """# Swiss Roll Distribution"""
+
+ def __init__(self, N, scaling_factor=0.25, noise_level=0.15):
+ """# Swiss Roll Distribution Initialization
+
+ Initializes the name and dimensions"""
+ super().__init__(N)
+ assert (self.dims == 2) or (
+ self.dims == 3
+ ), "This Distribution only Supports 2,3-Dimensions"
+ self.full_name = f"{self.dims}-Dimensional Swiss Roll Distribution Distribution"
+ self.name = f"SR{self.dims}"
+ self.noise_level = noise_level
+ self.scale = scaling_factor
+
+ def __call__(self, *args):
+ """# Magic method when calling the distribution
+
+        This method is going to be called when you use swiss_roll(case_count)"""
+ import numpy as np
+
+ assert len(args) == 1, "Only 1 argument supported"
+ N = args[0]
+ noise = self.noise_level
+ scaling_factor = self.scale
+
+ t = 3 * np.pi / 2 * (1 + 2 * np.random.rand(1, N))
+ h = 21 * np.random.rand(1, N)
+ RANDOM = np.random.randn(3, N) * noise
+ data = (
+ np.concatenate(
+ (scaling_factor * t * np.cos(t), h, scaling_factor * t * np.sin(t))
+ )
+ + RANDOM
+ )
+ if self.dims == 2:
+ return data.T[:, [0, 2]]
+ return data.T[:, [0, 2, 1]]
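Sketch for the Swiss roll (the 2-D case returns the (x, z) cross-section; the 3-D case includes the height axis):

    from sunlab.common.distribution.swiss_roll_distribution import SwissRollDistribution

    roll2 = SwissRollDistribution(2)
    roll3 = SwissRollDistribution(3, scaling_factor=0.25, noise_level=0.15)
    print(roll2(500).shape, roll3(500).shape)  # (500, 2) (500, 3)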
diff --git a/code/sunlab/common/distribution/symmetric_uniform_distribution.py b/code/sunlab/common/distribution/symmetric_uniform_distribution.py
new file mode 100644
index 0000000..c3a4db0
--- /dev/null
+++ b/code/sunlab/common/distribution/symmetric_uniform_distribution.py
@@ -0,0 +1,21 @@
+from .adversarial_distribution import *
+
+
+class SymmetricUniformDistribution(AdversarialDistribution):
+ """# Symmetric Uniform Distribution on [-1, 1)"""
+
+ def __init__(self, N):
+ """# Symmetric Uniform Distribution Initialization
+
+ Initializes the name and dimensions"""
+ super().__init__(N)
+ self.full_name = f"{N}-Dimensional Symmetric Uniform Distribution"
+ self.name = "SU"
+
+ def __call__(self, *args):
+ """# Magic method when calling the distribution
+
+ This method is going to be called when you use suniform(N1,...,Nm)"""
+ import numpy as np
+
+ return np.random.rand(*args, self.dims) * 2.0 - 1.0
diff --git a/code/sunlab/common/distribution/uniform_distribution.py b/code/sunlab/common/distribution/uniform_distribution.py
new file mode 100644
index 0000000..3e23e67
--- /dev/null
+++ b/code/sunlab/common/distribution/uniform_distribution.py
@@ -0,0 +1,21 @@
+from .adversarial_distribution import *
+
+
+class UniformDistribution(AdversarialDistribution):
+ """# Uniform Distribution on [0, 1)"""
+
+ def __init__(self, N):
+ """# Uniform Distribution Initialization
+
+ Initializes the name and dimensions"""
+ super().__init__(N)
+ self.full_name = f"{N}-Dimensional Uniform Distribution"
+ self.name = "U"
+
+ def __call__(self, *args):
+ """# Magic method when calling the distribution
+
+ This method is going to be called when you use uniform(N1,...,Nm)"""
+ import numpy as np
+
+ return np.random.rand(*args, self.dims)
diff --git a/code/sunlab/common/distribution/x_gaussian_distribution.py b/code/sunlab/common/distribution/x_gaussian_distribution.py
new file mode 100644
index 0000000..b4330aa
--- /dev/null
+++ b/code/sunlab/common/distribution/x_gaussian_distribution.py
@@ -0,0 +1,38 @@
+from .adversarial_distribution import *
+
+
+class XGaussianDistribution(AdversarialDistribution):
+ """# X Gaussian Distribution"""
+
+ def __init__(self, N):
+ """# X Gaussian Distribution Initialization
+
+ Initializes the name and dimensions"""
+ super().__init__(N)
+ assert self.dims == 2, "This Distribution only Supports 2-Dimensions"
+ self.full_name = "2-Dimensional X-Gaussian Distribution"
+ self.name = "XG"
+
+ def __call__(self, *args):
+ """# Magic method when calling the distribution
+
+ This method is going to be called when you use xgauss(case_count)"""
+ import numpy as np
+
+ assert len(args) == 1, "Only 1 argument supported"
+ N = args[0]
+ sample_base = np.zeros((4 * N, 2))
+ sample_base[0 * N : (0 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[1, 1], cov=[[1, 0.7], [0.7, 1]], size=[N]
+ )
+ sample_base[1 * N : (1 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[-1, -1], cov=[[1, 0.7], [0.7, 1]], size=[N]
+ )
+ sample_base[2 * N : (2 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[-1, 1], cov=[[1, -0.7], [-0.7, 1]], size=[N]
+ )
+ sample_base[3 * N : (3 + 1) * N, :] = np.random.multivariate_normal(
+ mean=[1, -1], cov=[[1, -0.7], [-0.7, 1]], size=[N]
+ )
+ np.random.shuffle(sample_base)
+ return sample_base[:N, :]
diff --git a/code/sunlab/common/mathlib/__init__.py b/code/sunlab/common/mathlib/__init__.py
new file mode 100644
index 0000000..9b5ed21
--- /dev/null
+++ b/code/sunlab/common/mathlib/__init__.py
@@ -0,0 +1 @@
+from .base import *
diff --git a/code/sunlab/common/mathlib/base.py b/code/sunlab/common/mathlib/base.py
new file mode 100644
index 0000000..38ab14c
--- /dev/null
+++ b/code/sunlab/common/mathlib/base.py
@@ -0,0 +1,57 @@
+import numpy as np
+
+
+def angle(a, b):
+ """# Get Angle Between Row Vectors"""
+ from numpy import arctan2, pi
+
+ theta_a = arctan2(a[:, 1], a[:, 0])
+ theta_b = arctan2(b[:, 1], b[:, 0])
+    # Wrap the element-wise difference into the principal range [-π, π]
+    d_theta = (theta_b - theta_a + pi) % (2 * pi) - pi
+    return d_theta
+
+
+def normalize(column):
+ """# Normalize Column Vector"""
+ from numpy.linalg import norm
+
+ return column / norm(column, axis=0)
+
+
+def winding(xy_grid, trajectory_start, trajectory_end):
+ """# Get Winding Number on Grid"""
+ from numpy import zeros, cross, clip, arcsin
+
+ trajectories = trajectory_end - trajectory_start
+ winding = zeros((xy_grid.shape[0]))
+ for idx, trajectory in enumerate(trajectories):
+ r = xy_grid - trajectory_start[idx]
+        sin_theta = clip(cross(normalize(trajectory), normalize(r)), -1, 1)
+        theta = arcsin(sin_theta)
+ winding += theta
+ return winding
+
+
+def vorticity(xy_grid, trajectory_start, trajectory_end):
+ """# Get Vorticity Number on Grid"""
+ from numpy import zeros, cross
+
+ trajectories = trajectory_end - trajectory_start
+ vorticity = zeros((xy_grid.shape[0]))
+ for idx, trajectory in enumerate(trajectories):
+ r = xy_grid - trajectory_start[idx]
+ vorticity += cross(normalize(trajectory), normalize(r))
+ return vorticity
+
+
+def data_range(data):
+ """# Get the range of values for each row"""
+ from numpy import min, max
+
+ return min(data, axis=0), max(data, axis=0)
+
+
+np.normalize = normalize
+np.range = data_range
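Sketch of the grid diagnostics on synthetic trajectory segments:

    import numpy as np
    from sunlab.common.mathlib.base import winding, vorticity

    grid = np.random.rand(400, 2) * 2 - 1           # sample points in [-1, 1]^2
    starts = np.random.rand(50, 2) * 2 - 1          # trajectory segment starts
    ends = starts + np.random.randn(50, 2) * 0.1    # ... and ends
    print(winding(grid, starts, ends).shape, vorticity(grid, starts, ends).shape)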
diff --git a/code/sunlab/common/mathlib/lyapunov.py b/code/sunlab/common/mathlib/lyapunov.py
new file mode 100644
index 0000000..3c747f1
--- /dev/null
+++ b/code/sunlab/common/mathlib/lyapunov.py
@@ -0,0 +1,54 @@
+def trajectory_to_distances(x):
+ """X: [N,N_t,N_d]
+ ret [N,N_t]"""
+ from numpy import zeros
+ from numpy.linalg import norm
+    from itertools import combinations
+
+ x = [x[idx, ...] for idx in range(x.shape[0])]
+ pairwise_trajectories = combinations(x, 2)
+ _N_COMB = len(list(pairwise_trajectories))
+ N_max = x[0].shape[0]
+ distances = zeros((_N_COMB, N_max))
+ pairwise_trajectories = combinations(x, 2)
+ for idx, (a_t, b_t) in enumerate(pairwise_trajectories):
+ distances[idx, :] = norm(a_t[:N_max, :] - b_t[:N_max, :], axis=-1)
+ return distances
+
+
+def Lyapunov_d(X):
+ """X: [N,N_t]
+ λ_n = ln(|dX_n|/|dX_0|)/n; n = [1,2,...]"""
+ from numpy import zeros, log, repeat
+
+ Y = zeros((X.shape[0], X.shape[1] - 1))
+ Y = log(X[:, 1:] / repeat([X[:, 0]], Y.shape[1], axis=0).T) / (
+ repeat([range(Y.shape[1])], Y.shape[0], axis=0) + 1
+ )
+ return Y
+
+
+def Lyapunov_t(X):
+ """X: [N,N_t,N_d]"""
+ return Lyapunov_d(trajectory_to_distances(X))
+
+
+Lyapunov = Lyapunov_d
+
+
+def RelativeDistance_d(X):
+ """X: [N,N_t]
+ λ_n = ln(|dX_n|/|dX_0|)/n; n = [1,2,...]"""
+ from numpy import zeros, log, repeat
+
+ Y = zeros((X.shape[0], X.shape[1] - 1))
+ Y = log(X[:, 1:] / repeat([X[:, 0]], Y.shape[1], axis=0).T)
+ return Y
+
+
+def RelativeDistance_t(X):
+ """X: [N,N_t,N_d]"""
+ return RelativeDistance_d(trajectory_to_distances(X))
+
+
+RelativeDistance = RelativeDistance_d
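Sketch: pairwise divergence-rate estimates for an ensemble of random walks:

    import numpy as np
    from sunlab.common.mathlib.lyapunov import Lyapunov_t

    X = np.cumsum(np.random.randn(5, 200, 2), axis=1)  # 5 trajectories, 200 steps, 2-D
    lam = Lyapunov_t(X)                                # (10, 199): one row per pair
    print(lam.shape)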
diff --git a/code/sunlab/common/mathlib/random_walks.py b/code/sunlab/common/mathlib/random_walks.py
new file mode 100644
index 0000000..3aa3bcb
--- /dev/null
+++ b/code/sunlab/common/mathlib/random_walks.py
@@ -0,0 +1,83 @@
+def get_levy_flight(T=50, D=2, t0=0.1, alpha=3, periodic=False):
+ from numpy import vstack
+ from mistree import get_levy_flight as get_flight
+
+ if D == 2:
+ x, y = get_flight(T, mode="2D", periodic=periodic, t_0=t0, alpha=alpha)
+ xy = vstack([x, y]).T
+ elif D == 3:
+ x, y, z = get_flight(T, mode="3D", periodic=periodic, t_0=t0, alpha=alpha)
+ xy = vstack([x, y, z]).T
+ else:
+ raise ValueError(f"Dimension {D} not supported!")
+ return xy
+
+
+def get_levy_flights(N=10, T=50, D=2, t0=0.1, alpha=3, periodic=False):
+ from numpy import moveaxis, array
+
+ trajectories = []
+ for _ in range(N):
+ xy = get_levy_flight(T=T, D=D, t0=t0, alpha=alpha, periodic=periodic)
+ trajectories.append(xy)
+ return moveaxis(array(trajectories), 0, 1)
+
+
+def get_jitter_levy_flights(
+ N=10, T=50, D=2, t0=0.1, alpha=3, periodic=False, noise=5e-2
+):
+ from numpy.random import randn
+
+ trajectories = get_levy_flights(
+ N=N, T=T, D=D, t0=t0, alpha=alpha, periodic=periodic
+ )
+ return trajectories + randn(*trajectories.shape) * noise
+
+
+def get_gaussian_random_walk(T=50, D=2, R=5, step_size=0.5, soft=None):
+ from numpy import array, sin, cos, exp, zeros, pi
+ from numpy.random import randn, uniform, rand
+ from numpy.linalg import norm
+
+ def is_in(x, R=1):
+ from numpy.linalg import norm
+
+ return norm(x) < R
+
+ X = zeros((T, D))
+ for t in range(1, T):
+ while True:
+ if D == 2:
+ angle = uniform(0, pi * 2)
+ step = randn(1) * step_size
+ X[t, :] = X[t - 1, :] + array([cos(angle), sin(angle)]) * step
+ else:
+ X[t, :] = X[t - 1, :] + randn(D) / D * step_size
+ if soft is None:
+ if is_in(X[t, :], R):
+ break
+ elif rand() < exp(-(norm(X[t, :]) - R) * soft):
+ break
+ return X
+
+
+def get_gaussian_random_walks(N=10, T=50, D=2, R=5, step_size=0.5, soft=None):
+ from numpy import moveaxis, array
+
+ trajectories = []
+ for _ in range(N):
+ xy = get_gaussian_random_walk(T=T, D=D, R=R, step_size=step_size, soft=soft)
+ trajectories.append(xy)
+ return moveaxis(array(trajectories), 0, 1)
+
+
+def get_gaussian_sample(T=50, D=2):
+ from numpy.random import randn
+
+ return randn(T, D)
+
+
+def get_gaussian_samples(N=10, T=50, D=2):
+ from numpy.random import randn
+
+ return randn(T, N, D)
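Sketch (the Lévy helpers require the `mistree` package; the Gaussian walks are self-contained):

    from sunlab.common.mathlib.random_walks import (
        get_gaussian_random_walks,
        get_levy_flights,
    )

    walks = get_gaussian_random_walks(N=10, T=50, D=2, R=5)  # shape (50, 10, 2)
    flights = get_levy_flights(N=10, T=50, D=2)              # shape (50, 10, 2)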
diff --git a/code/sunlab/common/plotting/__init__.py b/code/sunlab/common/plotting/__init__.py
new file mode 100644
index 0000000..d6873aa
--- /dev/null
+++ b/code/sunlab/common/plotting/__init__.py
@@ -0,0 +1,2 @@
+from .colors import *
+from .base import *
diff --git a/code/sunlab/common/plotting/base.py b/code/sunlab/common/plotting/base.py
new file mode 100644
index 0000000..aaf4a94
--- /dev/null
+++ b/code/sunlab/common/plotting/base.py
@@ -0,0 +1,270 @@
+from matplotlib import pyplot as plt
+
+
+def blank_plot(_plt=None, _xticks=False, _yticks=False):
+ """# Remove Plot Labels"""
+ if _plt is None:
+ _plt = plt
+ _plt.xlabel("")
+ _plt.ylabel("")
+ _plt.title("")
+ tick_params = {
+ "which": "both",
+ "bottom": _xticks,
+ "left": _yticks,
+ "right": False,
+ "labelleft": False,
+ "labelbottom": False,
+ }
+ _plt.tick_params(**tick_params)
+    for child in _plt.gcf().get_children():
+ if child._label == "<colorbar>":
+ child.set_yticks([])
+ axs = _plt.gcf().get_axes()
+ try:
+ axs = axs.flatten()
+ except:
+ ...
+ for ax in axs:
+ ax.set_xlabel("")
+ ax.set_ylabel("")
+ ax.set_title("")
+ ax.tick_params(**tick_params)
+
+
+def save_plot(name, _plt=None, _xticks=False, _yticks=False, tighten=True):
+ """# Save Plot in Multiple Formats"""
+ assert type(name) == str, "Name must be string"
+ from os.path import dirname
+ from os import makedirs
+
+    if dirname(name):
+        makedirs(dirname(name), exist_ok=True)
+    if _plt is None:
+        _plt = plt
+    _plt.savefig(name + ".png", dpi=1000)
+    blank_plot(_plt, _xticks=_xticks, _yticks=_yticks)
+    if tighten:
+        plt.tight_layout()
+ _plt.savefig(name + ".pdf")
+ _plt.savefig(name + ".svg")
+
+
+def scatter_2d(data_2d, labels=None, _plt=None, **matplotlib_kwargs):
+ """# Scatter 2d Data
+
+ - data_2d: 2d-dataset to plot
+ - labels: labels for each case
+ - _plt: Optional specific plot to transform"""
+ from .colors import Pcolor
+
+ if _plt is None:
+ _plt = plt
+
+ def _filter(data, labels=None, _filter_on=None):
+ if labels is None:
+ return data, False
+ else:
+ return data[labels == _filter_on], True
+
+ for _class in [2, 3, 1, 0]:
+ local_data, has_color = _filter(data_2d, labels, _class)
+ if has_color:
+ _plt.scatter(
+ local_data[:, 0],
+ local_data[:, 1],
+ color=Pcolor[_class],
+ **matplotlib_kwargs
+ )
+ else:
+ _plt.scatter(local_data[:, 0], local_data[:, 1], **matplotlib_kwargs)
+ break
+ return _plt
+
+
+def plot_contour(two_d_mask, color="black", color_map=None, raise_error=False):
+ """# Plot Contour of Mask"""
+ from matplotlib.pyplot import contour
+ from numpy import mgrid
+
+ z = two_d_mask
+ x, y = mgrid[: z.shape[1], : z.shape[0]]
+ if color_map is not None:
+ try:
+ color = color_map(color)
+ except Exception as e:
+ if raise_error:
+ raise e
+ try:
+        contour(x, y, z.T, colors=color, linewidths=0.5)
+ except Exception as e:
+ if raise_error:
+ raise e
+
+
+def gaussian_smooth_plot(
+ xy,
+ sigma=0.1,
+ extent=[-1, 1, -1, 1],
+ grid_n=100,
+ grid=None,
+ do_normalize=True,
+):
+ """# Plot Data with Gaussian Smoothening around point"""
+ from numpy import array, ndarray, linspace, meshgrid, zeros_like
+ from numpy import pi, sqrt, exp
+ from numpy.linalg import norm
+
+ if not type(xy) == ndarray:
+ xy = array(xy)
+ if grid is not None:
+ XY = grid
+ else:
+ X = linspace(extent[0], extent[1], grid_n)
+ Y = linspace(extent[2], extent[3], grid_n)
+ XY = array(meshgrid(X, Y)).T
+ smoothed = zeros_like(XY[:, :, 0])
+ factor = 1
+ if do_normalize:
+ factor = (sqrt(2 * pi) * sigma) ** 2.
+ if len(xy.shape) == 1:
+ smoothed = exp(-((norm(xy - XY, axis=-1) / (sqrt(2) * sigma)) ** 2.0)) / factor
+ else:
+ try:
+ from tqdm.notebook import tqdm
+ except Exception:
+
+ def tqdm(x):
+ return x
+
+ for i in tqdm(range(xy.shape[0])):
+ if xy.shape[-1] == 2:
+ smoothed += (
+ exp(-((norm((xy[i, :] - XY), axis=-1) / (sqrt(2) * sigma)) ** 2.0))
+ / factor
+ )
+ elif xy.shape[-1] == 3:
+ smoothed += (
+ exp(-((norm((xy[i, :2] - XY), axis=-1) / (sqrt(2) * sigma)) ** 2.0))
+ / factor
+ * xy[i, 2]
+ )
+ return smoothed, XY
+
+
+def plot_grid_data(xy_grid, data_grid, cbar=False, _plt=None, _cmap="RdBu", grid_min=1):
+ """# Plot Gridded Data"""
+ from numpy import nanmin, nanmax
+ from matplotlib.colors import TwoSlopeNorm
+
+ if _plt is None:
+ _plt = plt
+ norm = TwoSlopeNorm(
+ vmin=nanmin([-grid_min, nanmin(data_grid)]),
+ vcenter=0,
+ vmax=nanmax([grid_min, nanmax(data_grid)]),
+ )
+    _plt.pcolor(xy_grid[:, :, 0], xy_grid[:, :, 1], data_grid, cmap=_cmap, norm=norm)
+ if cbar:
+ _plt.colorbar()
+
+
+def plot_winding(xy_grid, winding_grid, cbar=False, _plt=None):
+ plot_grid_data(xy_grid, winding_grid, cbar=cbar, _plt=_plt, grid_min=1e-5)
+
+
+def plot_vorticity(xy_grid, vorticity_grid, cbar=False, save=False, _plt=None):
+ plot_grid_data(xy_grid, vorticity_grid, cbar=cbar, _plt=_plt, grid_min=1e-1)
+
+
+plt.blank = lambda: blank_plot(plt)
+plt.scatter2d = lambda data, labels=None, **matplotlib_kwargs: scatter_2d(
+ data, labels, plt, **matplotlib_kwargs
+)
+plt.save = save_plot
+
+
+def interpolate_points(df, columns=["Latent-0", "Latent-1"], kind="quadratic", N=500):
+ """# Interpolate points"""
+ from scipy.interpolate import interp1d
+ import numpy as np
+
+ points = df[columns].to_numpy()
+ distance = np.cumsum(np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1)))
+ distance = np.insert(distance, 0, 0) / distance[-1]
+ interpolator = interp1d(distance, points, kind=kind, axis=0)
+ alpha = np.linspace(0, 1, N)
+ interpolated_points = interpolator(alpha)
+ return interpolated_points.reshape(-1, 1, 2)
+
+
+def plot_trajectory(
+ df,
+ Fm=24,
+ FM=96,
+ interpolate=False,
+ interpolation_kind="quadratic",
+ interpolation_N=500,
+ columns=["Latent-0", "Latent-1"],
+ frame_column="Frames",
+ alpha=0.8,
+ lw=4,
+ _plt=None,
+ _z=None,
+):
+ """# Plot Trajectories
+
+ Interpolation possible"""
+ import numpy as np
+ from matplotlib.collections import LineCollection
+ import matplotlib as mpl
+
+ if _plt is None:
+ _plt = plt
+ if type(_plt) == mpl.axes._axes.Axes:
+ _ca = _plt
+ else:
+ try:
+ _ca = _plt.gca()
+ except:
+ _ca = _plt
+ X = df[columns[0]]
+ Y = df[columns[1]]
+ fm, fM = np.min(df[frame_column]), np.max(df[frame_column])
+
+ if interpolate:
+ if interpolation_kind == "cubic":
+ if len(df) <= 3:
+ return
+ elif interpolation_kind == "quadratic":
+ if len(df) <= 2:
+ return
+ else:
+ if len(df) <= 1:
+ return
+ points = interpolate_points(
+ df, kind=interpolation_kind, columns=columns, N=interpolation_N
+ )
+ else:
+ points = np.array([X, Y]).T.reshape(-1, 1, 2)
+
+ segments = np.concatenate([points[:-1], points[1:]], axis=1)
+ lc = LineCollection(
+ segments,
+ cmap=plt.get_cmap("plasma"),
+ norm=mpl.colors.Normalize(Fm, FM),
+ )
+ if _z is not None:
+ from mpl_toolkits.mplot3d.art3d import line_collection_2d_to_3d
+
+ if interpolate:
+ _z = _z # TODO: Interpolate
+ line_collection_2d_to_3d(lc, _z)
+ lc.set_array(np.linspace(fm, fM, points.shape[0]))
+ lc.set_linewidth(lw)
+ lc.set_alpha(alpha)
+ _ca.add_collection(lc)
+ _ca.autoscale()
+ _ca.margins(0.04)
+ return lc
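Sketch of the plotting helpers (output path and dataframe are hypothetical):

    import numpy as np
    import pandas as pd
    from matplotlib import pyplot as plt
    from sunlab.common.plotting.base import plot_trajectory  # also patches plt

    plt.scatter2d(np.random.randn(200, 2), np.random.randint(0, 4, size=200), s=4)
    plt.save("figures/embedding")  # writes .png, then blanked .pdf/.svg

    df = pd.DataFrame({"Latent-0": np.random.randn(30),
                       "Latent-1": np.random.randn(30),
                       "Frames": np.arange(30)})
    plot_trajectory(df, interpolate=True)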
diff --git a/code/sunlab/common/plotting/colors.py b/code/sunlab/common/plotting/colors.py
new file mode 100644
index 0000000..c4fc727
--- /dev/null
+++ b/code/sunlab/common/plotting/colors.py
@@ -0,0 +1,38 @@
+class PhenotypeColors:
+ """# Phenotype Colorings
+
+ Standardization for the different phenotype colors"""
+
+ def __init__(self):
+ """# Empty Construtor"""
+ pass
+
+ def get_basic_colors(self, transition=False):
+ """# Return the Color Names
+
+ - transition: Returns the color for the transition class too"""
+ if transition:
+ return ["yellow", "purple", "green", "blue", "cyan"]
+ return ["yellow", "purple", "green", "blue"]
+
+ def get_colors(self, transition=False):
+ """# Return the Color Names
+
+ - transition: Returns the color for the transition class too"""
+ if transition:
+ return ["#ffff00", "#ff3cfa", "#11f309", "#213ff0", "cyan"]
+ return ["#ffff00", "#ff3cfa", "#11fe09", "#213ff0"]
+
+ def get_colormap(self, transition=False):
+ """# Return the Matplotlib Colormap
+
+ - transition: Returns the color for the transition class too"""
+ from matplotlib.colors import ListedColormap as LC
+
+ return LC(self.get_colors(transition))
+
+
+# Basic Exports
+Pcolor = PhenotypeColors().get_colors()
+Pmap = PhenotypeColors().get_colormap()
+Pmapx = PhenotypeColors().get_colormap(True)
diff --git a/code/sunlab/common/scaler/__init__.py b/code/sunlab/common/scaler/__init__.py
new file mode 100644
index 0000000..2a2281a
--- /dev/null
+++ b/code/sunlab/common/scaler/__init__.py
@@ -0,0 +1,2 @@
+from .max_abs_scaler import *
+from .quantile_scaler import *
diff --git a/code/sunlab/common/scaler/adversarial_scaler.py b/code/sunlab/common/scaler/adversarial_scaler.py
new file mode 100644
index 0000000..7f61725
--- /dev/null
+++ b/code/sunlab/common/scaler/adversarial_scaler.py
@@ -0,0 +1,44 @@
+class AdversarialScaler:
+ """# Scaler Class to use in Adversarial Autoencoder
+
+ For any scaler to be implemented, make sure to ensure each of the methods
+ are implemented:
+ - __init__ (optional)
+ - init
+ - load
+ - save
+ - __call__"""
+
+ def __init__(self, base_directory):
+ """# Scaler initialization
+
+ - Initialize the base directory of the model where it will live"""
+ self.base_directory = base_directory
+
+ def init(self, data):
+ """# Scaler initialization
+
+ Initialize the scaler transformation with the data
+ Should always return self in subclasses"""
+ raise NotImplementedError("Scaler initialization has not been implemented yet")
+
+ def load(self):
+ """# Scaler loading
+
+ Load the data scaler model from a file
+ Should always return self in subclasses"""
+ raise NotImplementedError("Scaler loading has not been implemented yet")
+
+ def save(self):
+ """# Scaler saving
+
+ Save the data scaler model"""
+ raise NotImplementedError("Scaler saving has not been implemented yet")
+
+ def transform(self, *args, **kwargs):
+ """# Scale the given data"""
+ return self.__call__(*args, **kwargs)
+
+ def __call__(self, *args, **kwargs):
+ """# Scale the given data"""
+ raise NotImplementedError("Scaler has not been implemented yet")
diff --git a/code/sunlab/common/scaler/max_abs_scaler.py b/code/sunlab/common/scaler/max_abs_scaler.py
new file mode 100644
index 0000000..56ea589
--- /dev/null
+++ b/code/sunlab/common/scaler/max_abs_scaler.py
@@ -0,0 +1,48 @@
+from .adversarial_scaler import AdversarialScaler
+
+
+class MaxAbsScaler(AdversarialScaler):
+ """# MaxAbsScaler
+
+ Scale the data based on the maximum-absolute value in each column"""
+
+ def __init__(self, base_directory):
+ """# MaxAbsScaler initialization
+
+ - Initialize the base directory of the model where it will live
+ - Initialize the scaler model"""
+ super().__init__(base_directory)
+ from sklearn.preprocessing import MaxAbsScaler as MAS
+
+ self.scaler_base = MAS()
+ self.scaler = None
+
+ def init(self, data):
+ """# Scaler initialization
+
+ Initialize the scaler transformation with the data"""
+ self.scaler = self.scaler_base.fit(data)
+ return self
+
+ def load(self):
+ """# Scaler loading
+
+ Load the data scaler model from a file"""
+ from pickle import load
+
+ with open(f"{self.base_directory}/portable/maxabs_scaler.pkl", "rb") as fhandle:
+ self.scaler = load(fhandle)
+ return self
+
+ def save(self):
+ """# Scaler saving
+
+ Save the data scaler model"""
+ from pickle import dump
+
+ with open(f"{self.base_directory}/portable/maxabs_scaler.pkl", "wb") as fhandle:
+ dump(self.scaler, fhandle)
+
+ def __call__(self, *args, **kwargs):
+ """# Scale the given data"""
+ return self.scaler.transform(*args, **kwargs)
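Fit/save/load round trip (directory hypothetical; note both scalers persist under `<base>/portable/`):

    import os
    import numpy as np
    from sunlab.common.scaler.max_abs_scaler import MaxAbsScaler

    os.makedirs("models/aae/portable", exist_ok=True)
    data = np.random.rand(100, 13)
    scaler = MaxAbsScaler("models/aae").init(data)
    scaler.save()
    reloaded = MaxAbsScaler("models/aae").load()
    scaled = reloaded(data)  # columns scaled into [-1, 1]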
diff --git a/code/sunlab/common/scaler/quantile_scaler.py b/code/sunlab/common/scaler/quantile_scaler.py
new file mode 100644
index 0000000..a0f53fd
--- /dev/null
+++ b/code/sunlab/common/scaler/quantile_scaler.py
@@ -0,0 +1,52 @@
+from .adversarial_scaler import AdversarialScaler
+
+
+class QuantileScaler(AdversarialScaler):
+ """# QuantileScaler
+
+ Scale the data based on the quantile distributions of each column"""
+
+ def __init__(self, base_directory):
+ """# QuantileScaler initialization
+
+ - Initialize the base directory of the model where it will live
+ - Initialize the scaler model"""
+ super().__init__(base_directory)
+ from sklearn.preprocessing import QuantileTransformer as QS
+
+ self.scaler_base = QS()
+ self.scaler = None
+
+ def init(self, data):
+ """# Scaler initialization
+
+ Initialize the scaler transformation with the data"""
+ self.scaler = self.scaler_base.fit(data)
+ return self
+
+ def load(self):
+ """# Scaler loading
+
+ Load the data scaler model from a file"""
+ from pickle import load
+
+ with open(
+ f"{self.base_directory}/portable/quantile_scaler.pkl", "rb"
+ ) as fhandle:
+ self.scaler = load(fhandle)
+ return self
+
+ def save(self):
+ """# Scaler saving
+
+ Save the data scaler model"""
+ from pickle import dump
+
+ with open(
+ f"{self.base_directory}/portable/quantile_scaler.pkl", "wb"
+ ) as fhandle:
+ dump(self.scaler, fhandle)
+
+ def __call__(self, *args, **kwargs):
+ """# Scale the given data"""
+ return self.scaler.transform(*args, **kwargs)