From b85ee9d64a536937912544c7bbd5b98b635b7e8d Mon Sep 17 00:00:00 2001 From: Christian C Date: Mon, 11 Nov 2024 12:29:32 -0800 Subject: Initial commit --- code/sunlab/common/plotting/base.py | 270 ++++++++++++++++++++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 code/sunlab/common/plotting/base.py (limited to 'code/sunlab/common/plotting/base.py') diff --git a/code/sunlab/common/plotting/base.py b/code/sunlab/common/plotting/base.py new file mode 100644 index 0000000..aaf4a94 --- /dev/null +++ b/code/sunlab/common/plotting/base.py @@ -0,0 +1,270 @@ +from matplotlib import pyplot as plt + + +def blank_plot(_plt=None, _xticks=False, _yticks=False): + """# Remove Plot Labels""" + if _plt is None: + _plt = plt + _plt.xlabel("") + _plt.ylabel("") + _plt.title("") + tick_params = { + "which": "both", + "bottom": _xticks, + "left": _yticks, + "right": False, + "labelleft": False, + "labelbottom": False, + } + _plt.tick_params(**tick_params) + for child in plt.gcf().get_children(): + if child._label == "": + child.set_yticks([]) + axs = _plt.gcf().get_axes() + try: + axs = axs.flatten() + except: + ... + for ax in axs: + ax.set_xlabel("") + ax.set_ylabel("") + ax.set_title("") + ax.tick_params(**tick_params) + + +def save_plot(name, _plt=None, _xticks=False, _yticks=False, tighten=True): + """# Save Plot in Multiple Formats""" + assert type(name) == str, "Name must be string" + from os.path import dirname + from os import makedirs + + makedirs(dirname(name), exist_ok=True) + if _plt is None: + from matplotlib import pyplot as plt + _plt = plt + _plt.savefig(name + ".png", dpi=1000) + blank_plot(_plt, _xticks=_xticks, _yticks=_yticks) + if tighten: + from matplotlib import pyplot as plt + plt.tight_layout() + _plt.savefig(name + ".pdf") + _plt.savefig(name + ".svg") + + +def scatter_2d(data_2d, labels=None, _plt=None, **matplotlib_kwargs): + """# Scatter 2d Data + + - data_2d: 2d-dataset to plot + - labels: labels for each case + - _plt: Optional specific plot to transform""" + from .colors import Pcolor + + if _plt is None: + _plt = plt + + def _filter(data, labels=None, _filter_on=None): + if labels is None: + return data, False + else: + return data[labels == _filter_on], True + + for _class in [2, 3, 1, 0]: + local_data, has_color = _filter(data_2d, labels, _class) + if has_color: + _plt.scatter( + local_data[:, 0], + local_data[:, 1], + color=Pcolor[_class], + **matplotlib_kwargs + ) + else: + _plt.scatter(local_data[:, 0], local_data[:, 1], **matplotlib_kwargs) + break + return _plt + + +def plot_contour(two_d_mask, color="black", color_map=None, raise_error=False): + """# Plot Contour of Mask""" + from matplotlib.pyplot import contour + from numpy import mgrid + + z = two_d_mask + x, y = mgrid[: z.shape[1], : z.shape[0]] + if color_map is not None: + try: + color = color_map(color) + except Exception as e: + if raise_error: + raise e + try: + contour(x, y, z.T, colors=color, linewidth=0.5) + except Exception as e: + if raise_error: + raise e + + +def gaussian_smooth_plot( + xy, + sigma=0.1, + extent=[-1, 1, -1, 1], + grid_n=100, + grid=None, + do_normalize=True, +): + """# Plot Data with Gaussian Smoothening around point""" + from numpy import array, ndarray, linspace, meshgrid, zeros_like + from numpy import pi, sqrt, exp + from numpy.linalg import norm + + if not type(xy) == ndarray: + xy = array(xy) + if grid is not None: + XY = grid + else: + X = linspace(extent[0], extent[1], grid_n) + Y = linspace(extent[2], extent[3], grid_n) + XY = array(meshgrid(X, Y)).T + smoothed = zeros_like(XY[:, :, 0]) + factor = 1 + if do_normalize: + factor = (sqrt(2 * pi) * sigma) ** 2. + if len(xy.shape) == 1: + smoothed = exp(-((norm(xy - XY, axis=-1) / (sqrt(2) * sigma)) ** 2.0)) / factor + else: + try: + from tqdm.notebook import tqdm + except Exception: + + def tqdm(x): + return x + + for i in tqdm(range(xy.shape[0])): + if xy.shape[-1] == 2: + smoothed += ( + exp(-((norm((xy[i, :] - XY), axis=-1) / (sqrt(2) * sigma)) ** 2.0)) + / factor + ) + elif xy.shape[-1] == 3: + smoothed += ( + exp(-((norm((xy[i, :2] - XY), axis=-1) / (sqrt(2) * sigma)) ** 2.0)) + / factor + * xy[i, 2] + ) + return smoothed, XY + + +def plot_grid_data(xy_grid, data_grid, cbar=False, _plt=None, _cmap="RdBu", grid_min=1): + """# Plot Gridded Data""" + from numpy import nanmin, nanmax + from matplotlib.colors import TwoSlopeNorm + + if _plt is None: + _plt = plt + norm = TwoSlopeNorm( + vmin=nanmin([-grid_min, nanmin(data_grid)]), + vcenter=0, + vmax=nanmax([grid_min, nanmax(data_grid)]), + ) + _plt.pcolor(xy_grid[:, :, 0], xy_grid[:, :, 1], data_grid, cmap="RdBu", norm=norm) + if cbar: + _plt.colorbar() + + +def plot_winding(xy_grid, winding_grid, cbar=False, _plt=None): + plot_grid_data(xy_grid, winding_grid, cbar=cbar, _plt=_plt, grid_min=1e-5) + + +def plot_vorticity(xy_grid, vorticity_grid, cbar=False, save=False, _plt=None): + plot_grid_data(xy_grid, vorticity_grid, cbar=cbar, _plt=_plt, grid_min=1e-1) + + +plt.blank = lambda: blank_plot(plt) +plt.scatter2d = lambda data, labels=None, **matplotlib_kwargs: scatter_2d( + data, labels, plt, **matplotlib_kwargs +) +plt.save = save_plot + + +def interpolate_points(df, columns=["Latent-0", "Latent-1"], kind="quadratic", N=500): + """# Interpolate points""" + from scipy.interpolate import interp1d + import numpy as np + + points = df[columns].to_numpy() + distance = np.cumsum(np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1))) + distance = np.insert(distance, 0, 0) / distance[-1] + interpolator = interp1d(distance, points, kind=kind, axis=0) + alpha = np.linspace(0, 1, N) + interpolated_points = interpolator(alpha) + return interpolated_points.reshape(-1, 1, 2) + + +def plot_trajectory( + df, + Fm=24, + FM=96, + interpolate=False, + interpolation_kind="quadratic", + interpolation_N=500, + columns=["Latent-0", "Latent-1"], + frame_column="Frames", + alpha=0.8, + lw=4, + _plt=None, + _z=None, +): + """# Plot Trajectories + + Interpolation possible""" + import numpy as np + from matplotlib.collections import LineCollection + import matplotlib as mpl + + if _plt is None: + _plt = plt + if type(_plt) == mpl.axes._axes.Axes: + _ca = _plt + else: + try: + _ca = _plt.gca() + except: + _ca = _plt + X = df[columns[0]] + Y = df[columns[1]] + fm, fM = np.min(df[frame_column]), np.max(df[frame_column]) + + if interpolate: + if interpolation_kind == "cubic": + if len(df) <= 3: + return + elif interpolation_kind == "quadratic": + if len(df) <= 2: + return + else: + if len(df) <= 1: + return + points = interpolate_points( + df, kind=interpolation_kind, columns=columns, N=interpolation_N + ) + else: + points = np.array([X, Y]).T.reshape(-1, 1, 2) + + segments = np.concatenate([points[:-1], points[1:]], axis=1) + lc = LineCollection( + segments, + cmap=plt.get_cmap("plasma"), + norm=mpl.colors.Normalize(Fm, FM), + ) + if _z is not None: + from mpl_toolkits.mplot3d.art3d import line_collection_2d_to_3d + + if interpolate: + _z = _z # TODO: Interpolate + line_collection_2d_to_3d(lc, _z) + lc.set_array(np.linspace(fm, fM, points.shape[0])) + lc.set_linewidth(lw) + lc.set_alpha(alpha) + _ca.add_collection(lc) + _ca.autoscale() + _ca.margins(0.04) + return lc -- cgit v1.2.1