From b85ee9d64a536937912544c7bbd5b98b635b7e8d Mon Sep 17 00:00:00 2001 From: Christian C Date: Mon, 11 Nov 2024 12:29:32 -0800 Subject: Initial commit --- code/sunlab/suntorch/plotting/model_extensions.py | 34 +++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 code/sunlab/suntorch/plotting/model_extensions.py (limited to 'code/sunlab/suntorch/plotting/model_extensions.py') diff --git a/code/sunlab/suntorch/plotting/model_extensions.py b/code/sunlab/suntorch/plotting/model_extensions.py new file mode 100644 index 0000000..33f0191 --- /dev/null +++ b/code/sunlab/suntorch/plotting/model_extensions.py @@ -0,0 +1,34 @@ +from matplotlib import pyplot as plt +from sunlab.common.data.shape_dataset import ShapeDataset +from sunlab.globals import DIR_ROOT + + +def apply_boundary( + model_loc=DIR_ROOT + "models/current_model/", + border_thickness=3, + include_transition_regions=False, + threshold=0.7, + alpha=1, + _plt=None, +): + """# Apply Boundary to Plot + + Use Pregenerated Boundary by Default for Speed""" + from sunlab.common.scaler import MaxAbsScaler + import numpy as np + + if _plt is None: + _plt = plt + if (model_loc == model_loc) and (border_thickness == 3) and (threshold == 0.7): + XYM = np.load(DIR_ROOT + "extra_data/OutlineXYM.npy") + XY = XYM[:2, :, :] + if include_transition_regions: + outline = XYM[3, :, :] + else: + outline = XYM[2, :, :] + _plt.pcolor(XY[0, :, :], XY[1, :, :], outline, cmap="gray", alpha=alpha) + return + raise NotImplemented("Not Yet Implemented for PyTorch!") + + +plt.apply_boundary = apply_boundary -- cgit v1.2.1