-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ADD helper methods to get/create RandomState - work on #2
- Loading branch information
1 parent
e36fe4f
commit 9004e1b
Showing
1 changed file
with
36 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import numpy as np | ||
|
||
|
||
def get_rng(rng=None, self_rng=None): | ||
""" helper function to obtain RandomState. | ||
returns RandomState created from rng | ||
if rng is None returns self_rng if RandomState | ||
if self_rng is None initializes RandomState at random | ||
:param rng: int or RandomState | ||
:param self_rng: RandomState | ||
:return: RandomState | ||
""" | ||
|
||
if rng is not None: | ||
return create_rng(rng) | ||
elif rng is None and self_rng is not None: | ||
return create_rng(self_rng) | ||
else: | ||
return np.random.RandomState() | ||
|
||
|
||
def create_rng(rng): | ||
""" helper to create rng from RandomState or int | ||
:param rng: int or RandomState | ||
:return: RandomState | ||
""" | ||
if rng is None: | ||
return np.random.RandomState() | ||
elif type(rng) == np.random.RandomState: | ||
return rng | ||
elif int(rng) == rng: | ||
return np.random.RandomState(rng) | ||
else: | ||
raise ValueError("%s is neither a number nor a RandomState. " | ||
"Initializing RandomState failed") |