aboutsummaryrefslogtreecommitdiff
path: root/code/sunlab/common/distribution/gaussian_distribution.py
blob: e478ab676bb6605334dcdb8321566bac31fcbd48 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from .adversarial_distribution import *


class GaussianDistribution(AdversarialDistribution):
    """# Gaussian Distribution

    Standard Gaussian (zero mean, identity covariance) in `N` dimensions,
    implemented as an `AdversarialDistribution` subclass."""

    def __init__(self, N):
        """# Gaussian Distribution Initialization

        Initializes the name and dimensions.

        - N: dimensionality of each sample; forwarded to the base class
          (which presumably stores it as `self.dims` — `__call__` relies on
          that attribute; TODO confirm in AdversarialDistribution)"""
        super().__init__(N)
        self.full_name = f"{N}-Dimensional Gaussian Distribution"
        self.name = "G"

    def __call__(self, *args):
        """# Magic method when calling the distribution

        This method is going to be called when you use gauss(N1,...,Nm).

        - args: batch dimensions N1,...,Nm of the sample array (may be empty)
        - returns: numpy array of shape (N1, ..., Nm, dims) whose entries are
          i.i.d. standard-normal draws from numpy's global random state;
          with no arguments a single sample of shape (dims,) is returned"""
        import numpy as np

        # Zero mean + identity covariance means each coordinate is an
        # independent N(0, 1) variable, so draw them directly. Going through
        # multivariate_normal would SVD-factorize and validate the identity
        # covariance on every call for no benefit.
        return np.random.standard_normal(size=(*args, self.dims))