diff --git a/src/irace/__init__.py b/src/irace/__init__.py index a4858eb..da5ef81 100644 --- a/src/irace/__init__.py +++ b/src/irace/__init__.py @@ -15,6 +15,8 @@ from rpy2.robjects.vectors import DataFrame, BoolVector, FloatVector, IntVector, StrVector, ListVector, IntArray, Matrix, ListSexpVector,FloatSexpVector,IntSexpVector,StrSexpVector,BoolSexpVector from rpy2.robjects.functions import SignatureTranslatedFunction from rpy2.rinterface import RRuntimeWarning +import json + rpy2conversion = ro.conversion.get_conversion() irace_converter = ro.default_converter + numpy2ri.converter + pandas2ri.converter @@ -83,6 +85,7 @@ def tmp_r_target_runner(experiment, scenario): py_experiment['configuration'] = OrderedDict( (k,v) for k,v in py_experiment['configuration'].items() if not pd.isna(v) ) + py_experiment['instance'] = context['py_instances'][int(py_experiment['id.instance']) - 1] try: ret = context['py_target_runner'](py_experiment, py_scenario) except: @@ -94,7 +97,6 @@ def tmp_r_target_runner(experiment, scenario): def check_windows(scenario): if scenario.get('parallel', 1) != 1 and os.name == 'nt': raise NotImplementedError('Parallel running on windows is not supported yet. Follow https://github.com/auto-optimization/iracepy/issues/16 for updates. Alternatively, use Linux or MacOS or the irace R package directly.') - class irace: # Import irace R package try: @@ -107,12 +109,20 @@ class irace: def __init__(self, scenario, parameters_table, target_runner): self.scenario = scenario + self.instances = scenario.get('instances', None) + self.context = {} if 'instances' in scenario: - self.scenario['instances'] = np.asarray(scenario['instances']) + self.context.update({ + 'py_instances': self.scenario['instances'], + }) + self.scenario['instances'] = StrVector(list(map(lambda x: json.dumps(x, skipkeys=True, default=self.scenario.get('instanceObjectSerializer', lambda x: '')), self.scenario['instances']))) + self.scenario.pop('instanceObjectSerializer', None) with localconverter(irace_converter_hack): self.parameters = self._pkg.readParameters(text = parameters_table, digits = self.scenario.get('digits', 4)) - self.context = {'py_target_runner' : target_runner, - 'py_scenario': self.scenario } + self.context.update({ + 'py_target_runner' : target_runner, + 'py_scenario': self.scenario, + }) check_windows(scenario) def read_configurations(self, filename=None, text=None): diff --git a/tests/test_data_passable.py b/tests/test_data_passable.py index 390246d..6cdedbb 100644 --- a/tests/test_data_passable.py +++ b/tests/test_data_passable.py @@ -2,6 +2,8 @@ from irace import irace import pandas as pd from multiprocessing import Queue +import pytest +import os q = Queue() @@ -12,6 +14,10 @@ def target_runner(experiment, scenario): else: return dict(cost=1) +def target_runner2(experiment, scenario): + if experiment['id.instance'] == 1: + experiment['instance'].put(1335) + return dict(cost=1) params = ''' one "" c ('0', '1') @@ -34,4 +40,65 @@ def test(): tuner = irace(scenario, params, target_runner) best_conf = tuner.run() assert q.get() == 124 - \ No newline at end of file + +def test_instances(): + q = Queue() + scenario = dict( + instances = [q], + maxExperiments = 180, + debugLevel = 0, + parallel = 1, + logFile = "", + seed = 123 + ) + tuner = irace(scenario, params, target_runner2) + best_conf = tuner.run() + assert q.get() == 1335 + +@pytest.mark.skipif(os.name == 'nt', + reason="Parallel on Windows not supported") +def test_instances2(): + q = Queue() + scenario = dict( + instances = [q], + maxExperiments = 180, + debugLevel = 0, + parallel = 2, + logFile = "", + seed = 123 + ) + tuner = irace(scenario, params, target_runner2) + best_conf = tuner.run() + assert q.get() == 1335 + +def test_default_serializer(): + q = Queue() + scenario = dict( + instances = [q], + maxExperiments = 180, + debugLevel = 0, + parallel = 1, + logFile = "", + seed = 123, + instanceObjectSerializer = lambda x: 'hello world' + ) + tuner = irace(scenario, params, target_runner2) + best_conf = tuner.run() + assert q.get() == 1335 + +@pytest.mark.skipif(os.name == 'nt', + reason="Parallel on Windows not supported") +def test_default_serializer(): + q = Queue() + scenario = dict( + instances = [q], + maxExperiments = 180, + debugLevel = 0, + parallel = 2, + logFile = "", + seed = 123, + instanceObjectSerializer = lambda x: 'hello world' + ) + tuner = irace(scenario, params, target_runner2) + best_conf = tuner.run() + assert q.get() == 1335