Skip to content

Commit

Permalink
Implement adaptive localization
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Nov 11, 2022
1 parent 5bb7b6c commit fe5a485
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 49 deletions.
73 changes: 51 additions & 22 deletions src/ert/_c_wrappers/analysis/analysis_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class VariableInfo(TypedDict):
DEFAULT_IES_DEC_STEPLENGTH = 2.50
DEFAULT_ENKF_TRUNCATION = 0.98
DEFAULT_IES_INVERSION = 0
DEFAULT_LOCALIZATION = False
DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD = 0.4


class AnalysisMode(str, Enum):
Expand All @@ -46,6 +48,18 @@ def get_mode_variables(mode: AnalysisMode) -> Dict[str, "VariableInfo"]:
"step": 0.01,
"labelname": "Singular value truncation",
},
"LOCALIZATION": {
"type": bool,
"value": DEFAULT_LOCALIZATION,
"labelname": "Switch for adaptive localization",
},
"LOCALIZATION_CORRELATION_THRESHOLD": {
"type": float,
"min": 0.0,
"value": DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD,
"max": 1.0,
"labelname": "Threshold defining high correlation",
},
}
ies_variables: Dict[str, "VariableInfo"] = {
"IES_MAX_STEPLENGTH": {
Expand Down Expand Up @@ -152,30 +166,39 @@ def set_var(self, var_name: str, value: Union[float, int, bool, str]):
self.handle_special_key_set(var_name, value)
elif var_name in self._variables:
var = self._variables[var_name]
try:
new_value = var["type"](value)
if new_value > var["max"]:
var["value"] = var["max"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using max value {var['max']}"

if var["type"] is not bool:
try:
new_value = var["type"](value)
if new_value > var["max"]:
var["value"] = var["max"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using max value {var['max']}"
)
elif new_value < var["min"]:
var["value"] = var["min"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using min value {var['min']}"
)
else:
var["value"] = new_value

except ValueError:
raise ValueError(
f"Variable {var_name} expected type {var['type']}"
f" received value `{value}` of type `{type(value)}`"
)
elif new_value < var["min"]:
var["value"] = var["min"]
logger.warning(
f"New value {new_value} for key"
f" {var_name} is out of [{var['min']}, {var['max']}] "
f"using min value {var['min']}"
else:
if not isinstance(var["value"], bool):
raise ValueError(
f"Variable {var_name} expected type {var['type']}"
f" received value `{value}` of type `{type(value)}`"
)
else:
var["value"] = new_value

except ValueError:
raise ValueError(
f"Variable {var_name} expected type {var['type']}"
f" received value `{value}` of type `{type(value)}`"
)
var["value"] = var["type"](value)
else:
raise KeyError(f"Variable {var_name} not found in module")

Expand All @@ -190,6 +213,12 @@ def inversion(self, value):
def get_truncation(self) -> float:
return self.get_variable_value("ENKF_TRUNCATION")

def localization(self) -> bool:
return self.get_variable_value("LOCALIZATION")

def localization_correlation_threshold(self) -> float:
return self.get_variable_value("LOCALIZATION_CORRELATION_THRESHOLD")

def get_steplength(self, iteration_nr: int) -> float:
"""
This is an implementation of Eq. (49), which calculates a suitable
Expand Down
131 changes: 104 additions & 27 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import iterative_ensemble_smoother as ies
import numpy as np
import numpy.typing as npt
from iterative_ensemble_smoother.experimental import (
ensemble_smoother_update_step_row_scaling,
)
Expand Down Expand Up @@ -131,6 +132,29 @@ def _create_temporary_parameter_storage(
return temporary_storage


def correlated_parameter_response_pairs(
A: npt.NDArray[np.float_], Y: npt.NDArray[np.float_], correlation_threshold: float
) -> npt.NDArray[np.int_]:
N = A.shape[1]
Y_prime = Y - Y.mean(axis=1, keepdims=True)
C_YY = Y_prime @ Y_prime.T / (N - 1)
Sigma_Y = np.diag(np.sqrt(np.diag(C_YY)))

A_prime = A - A.mean(axis=1, keepdims=True)
C_AA = A_prime @ A_prime.T / (N - 1)

# State-measurement covariance matrix
C_AY = A_prime @ Y_prime.T / (N - 1)
Sigma_A = np.diag(np.sqrt(np.diag(C_AA)))

# State-measurement correlation matrix
c_AY = np.linalg.inv(Sigma_A) @ C_AY @ np.linalg.inv(Sigma_Y)

# _, corr_idx_Y = np.where(np.triu(np.abs(c_AY)) > correlation_threshold)
_, corr_idx_Y = np.where(np.abs(c_AY) > correlation_threshold)
return corr_idx_Y


def analysis_ES(
updatestep: "UpdateConfiguration",
obs: "EnkfObs",
Expand All @@ -154,7 +178,7 @@ def analysis_ES(
# Looping over local analysis update_step
for update_step in updatestep:

S, observation_handle = update.load_observations_and_responses(
Y, observation_handle = update.load_observations_and_responses(
source_fs,
obs,
alpha,
Expand All @@ -178,21 +202,52 @@ def analysis_ES(
A_with_rowscaling = _get_row_scaling_A_matrices(
temp_storage, update_step.row_scaling_parameters
)
noise = rng.standard_normal(size=(len(observation_values), S.shape[1]))

if A is not None:
A = ies.ensemble_smoother_update_step(
S,
A,
observation_errors,
observation_values,
noise,
module.get_truncation(),
ies.InversionType(module.inversion),
)
_save_to_temporary_storage(temp_storage, update_step.parameters, A)
if module.localization():
A_ES_loc = []
for i in range(A.shape[0]):
N = A.shape[1]
A_chunk = A[i, :].reshape(1, N)
corr_idx_Y = correlated_parameter_response_pairs(
A_chunk, Y, module.localization_correlation_threshold()
)
Y_loc = Y[corr_idx_Y, :]
observation_errors_loc = observation_errors[corr_idx_Y]
observation_values_loc = observation_values[corr_idx_Y]
noise = rng.standard_normal(
size=(len(observation_values_loc), Y.shape[1])
)

A_loc = ies.ensemble_smoother_update_step(
Y_loc,
A,
observation_errors_loc,
observation_values_loc,
noise,
module.get_truncation(),
ies.InversionType(module.inversion),
)
A_ES_loc.append(A_loc)
_save_to_temporary_storage(
temp_storage, update_step.parameters, np.vstack(A_loc)
)
else:
noise = rng.standard_normal(size=(len(observation_values), Y.shape[1]))
A = ies.ensemble_smoother_update_step(
Y,
A,
observation_errors,
observation_values,
noise,
module.get_truncation(),
ies.InversionType(module.inversion),
)
_save_to_temporary_storage(temp_storage, update_step.parameters, A)

if A_with_rowscaling:
A_with_rowscaling = ensemble_smoother_update_step_row_scaling(
S,
Y,
A_with_rowscaling,
observation_errors,
observation_values,
Expand Down Expand Up @@ -235,7 +290,7 @@ def analysis_IES(
# Looping over local analysis update_step
for update_step in updatestep:

S, observation_handle = update.load_observations_and_responses(
Y, observation_handle = update.load_observations_and_responses(
source_fs,
obs,
alpha,
Expand All @@ -258,19 +313,41 @@ def analysis_IES(

A = _get_A_matrix(temp_storage, update_step.parameters)

noise = rng.standard_normal(size=(len(observation_values), S.shape[1]))
A = iterative_ensemble_smoother.update_step(
S,
A,
observation_errors,
observation_values,
noise,
ensemble_mask=np.array(ens_mask),
observation_mask=observation_mask,
inversion=ies.InversionType(module.inversion),
truncation=module.get_truncation(),
)
_save_to_temporary_storage(temp_storage, update_step.parameters, A)
if module.localization():
corr_idx_Y = correlated_parameter_response_pairs(
A, Y, module.localization_correlation_threshold()
)
Y_loc = Y[corr_idx_Y, :]
observation_errors_loc = observation_errors[corr_idx_Y]
observation_values_loc = observation_values[corr_idx_Y]
noise = rng.standard_normal(size=(len(observation_values_loc), Y.shape[1]))

A_loc = iterative_ensemble_smoother.update_step(
Y_loc,
A,
observation_errors_loc,
observation_values_loc,
noise,
ensemble_mask=np.array(ens_mask),
observation_mask=observation_mask,
inversion=ies.InversionType(module.inversion),
truncation=module.get_truncation(),
)
_save_to_temporary_storage(temp_storage, update_step.parameters, A_loc)
else:
noise = rng.standard_normal(size=(len(observation_values), Y.shape[1]))
A = iterative_ensemble_smoother.update_step(
Y,
A,
observation_errors,
observation_values,
noise,
ensemble_mask=np.array(ens_mask),
observation_mask=observation_mask,
inversion=ies.InversionType(module.inversion),
truncation=module.get_truncation(),
)
_save_to_temporary_storage(temp_storage, update_step.parameters, A)

_save_temporary_storage_to_disk(
target_fs, ensemble_config, temp_storage, iens_active_index
Expand Down
2 changes: 2 additions & 0 deletions test-data/poly_example/poly.ert
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
JOBNAME poly_%d

ANALYSIS_SET_VAR STD_ENKF LOCALIZATION True

QUEUE_SYSTEM LOCAL
QUEUE_OPTION LOCAL MAX_RUNNING 50

Expand Down
3 changes: 3 additions & 0 deletions test-data/snake_oil/snake_oil.ert
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ NUM_REALIZATIONS 25
ANALYSIS_SET_VAR IES_ENKF IES_INVERSION 1
ANALYSIS_SET_VAR STD_ENKF IES_INVERSION 1

ANALYSIS_SET_VAR STD_ENKF LOCALIZATION True
ANALYSIS_SET_VAR STD_ENKF LOCALIZATION_CORRELATION_THRESHOLD 0.0

DEFINE <STORAGE> storage/<CONFIG_FILE_BASE>
RANDOM_SEED 3593114179000630026631423308983283277868

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
DEFAULT_IES_INVERSION,
DEFAULT_IES_MAX_STEPLENGTH,
DEFAULT_IES_MIN_STEPLENGTH,
DEFAULT_LOCALIZATION,
DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD,
get_mode_variables,
)

Expand All @@ -21,12 +23,16 @@ def test_analysis_module_default_values():
"IES_DEC_STEPLENGTH": DEFAULT_IES_DEC_STEPLENGTH,
"IES_INVERSION": DEFAULT_IES_INVERSION,
"ENKF_TRUNCATION": DEFAULT_ENKF_TRUNCATION,
"LOCALIZATION": DEFAULT_LOCALIZATION,
"LOCALIZATION_CORRELATION_THRESHOLD": DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD, # noqa
}

es_am = AnalysisModule.ens_smoother_module()
assert es_am.variable_value_dict() == {
"IES_INVERSION": DEFAULT_IES_INVERSION,
"ENKF_TRUNCATION": DEFAULT_ENKF_TRUNCATION,
"LOCALIZATION": DEFAULT_LOCALIZATION,
"LOCALIZATION_CORRELATION_THRESHOLD": DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD, # noqa
}


Expand Down

0 comments on commit fe5a485

Please sign in to comment.