Skip to content

Commit

Permalink
Merge pull request #45 from cirKITers/init-strategy
Browse files Browse the repository at this point in the history
Init strategy
  • Loading branch information
majafranz authored Sep 30, 2024
2 parents 42278cc + 34a7b01 commit e246996
Showing 1 changed file with 56 additions and 44 deletions.
100 changes: 56 additions & 44 deletions qml_essentials/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,54 +88,17 @@ def __init__(
self.degree = 1

log.info(f"Number of implicit layers set to {impl_n_layers}.")

params_shape: Tuple[int, int] = (
# calculate the shape of the parameter vector here, we will re-use this in init.
self._params_shape: Tuple[int, int] = (
impl_n_layers,
self.pqc.n_params_per_layer(self.n_qubits),
)
# this will also be re-used in the init method,
# however, only if nothing is provided
self._inialization_strategy = initialization

def set_control_params(params: np.ndarray, value: float) -> np.ndarray:
indices = self.pqc.get_control_indices(self.n_qubits)
if indices is None:
warnings.warn(
f"Specified {initialization} but circuit\
does not contain controlled rotation gates.\
Parameters are intialized randomly.",
UserWarning,
)
else:
params[:, indices[0] : indices[1] : indices[2]] = (
np.ones_like(params[:, indices[0] : indices[1] : indices[2]])
* value
)
return params

rng = np.random.default_rng(random_seed)
if initialization == "random":
self.params: np.ndarray = rng.uniform(
0, 2 * np.pi, params_shape, requires_grad=True
)
elif initialization == "zeros":
self.params: np.ndarray = np.zeros(params_shape, requires_grad=True)
elif initialization == "pi":
self.params: np.ndarray = np.ones(params_shape, requires_grad=True) * np.pi
elif initialization == "zero-controlled":
self.params: np.ndarray = rng.uniform(
0, 2 * np.pi, params_shape, requires_grad=True
)
self.params = set_control_params(self.params, 0)
elif initialization == "pi-controlled":
self.params: np.ndarray = rng.uniform(
0, 2 * np.pi, params_shape, requires_grad=True
)
self.params = set_control_params(self.params, np.pi)
else:
raise Exception("Invalid initialization method")

log.info(
f"Initialized parameters with shape {self.params.shape}\
using strategy {initialization}."
)
# ..here! where we only require a seed
self.initialize_params(random_seed)

# Initialize two circuits, one with the default device and
# one with the mixed device
Expand Down Expand Up @@ -238,6 +201,55 @@ def shots(self, value: Optional[int]) -> None:
value = None
self._shots = value

def initialize_params(self, random_seed, initialization: str = None) -> None:
# use existing strategy if not specified
initialization = initialization or self._inialization_strategy

def set_control_params(params: np.ndarray, value: float) -> np.ndarray:
indices = self.pqc.get_control_indices(self.n_qubits)
if indices is None:
warnings.warn(
f"Specified {initialization} but circuit\
does not contain controlled rotation gates.\
Parameters are intialized randomly.",
UserWarning,
)
else:
params[:, indices[0] : indices[1] : indices[2]] = (
np.ones_like(params[:, indices[0] : indices[1] : indices[2]])
* value
)
return params

rng = np.random.default_rng(random_seed)
if initialization == "random":
self.params: np.ndarray = rng.uniform(
0, 2 * np.pi, self._params_shape, requires_grad=True
)
elif initialization == "zeros":
self.params: np.ndarray = np.zeros(self._params_shape, requires_grad=True)
elif initialization == "pi":
self.params: np.ndarray = (
np.ones(self._params_shape, requires_grad=True) * np.pi
)
elif initialization == "zero-controlled":
self.params: np.ndarray = rng.uniform(
0, 2 * np.pi, self._params_shape, requires_grad=True
)
self.params = set_control_params(self.params, 0)
elif initialization == "pi-controlled":
self.params: np.ndarray = rng.uniform(
0, 2 * np.pi, self._params_shape, requires_grad=True
)
self.params = set_control_params(self.params, np.pi)
else:
raise Exception("Invalid initialization method")

log.info(
f"Initialized parameters with shape {self.params.shape}\
using strategy {initialization}."
)

def _iec(
self,
inputs: np.ndarray,
Expand Down

0 comments on commit e246996

Please sign in to comment.