diff options
Diffstat (limited to 'code/sunlab/suntorch/plotting/model_extensions.py')
-rw-r--r-- | code/sunlab/suntorch/plotting/model_extensions.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/code/sunlab/suntorch/plotting/model_extensions.py b/code/sunlab/suntorch/plotting/model_extensions.py new file mode 100644 index 0000000..33f0191 --- /dev/null +++ b/code/sunlab/suntorch/plotting/model_extensions.py @@ -0,0 +1,34 @@ +from matplotlib import pyplot as plt +from sunlab.common.data.shape_dataset import ShapeDataset +from sunlab.globals import DIR_ROOT + + +def apply_boundary( + model_loc=DIR_ROOT + "models/current_model/", + border_thickness=3, + include_transition_regions=False, + threshold=0.7, + alpha=1, + _plt=None, +): + """# Apply Boundary to Plot + + Use Pregenerated Boundary by Default for Speed""" + from sunlab.common.scaler import MaxAbsScaler + import numpy as np + + if _plt is None: + _plt = plt + if (model_loc == model_loc) and (border_thickness == 3) and (threshold == 0.7): + XYM = np.load(DIR_ROOT + "extra_data/OutlineXYM.npy") + XY = XYM[:2, :, :] + if include_transition_regions: + outline = XYM[3, :, :] + else: + outline = XYM[2, :, :] + _plt.pcolor(XY[0, :, :], XY[1, :, :], outline, cmap="gray", alpha=alpha) + return + raise NotImplemented("Not Yet Implemented for PyTorch!") + + +plt.apply_boundary = apply_boundary |