Skip to content

Commit

Permalink
pass instances
Browse files Browse the repository at this point in the history
  • Loading branch information
DE0CH committed Jan 9, 2023
1 parent fb16f87 commit 5d18b8a
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 5 deletions.
18 changes: 14 additions & 4 deletions src/irace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from rpy2.robjects.vectors import DataFrame, BoolVector, FloatVector, IntVector, StrVector, ListVector, IntArray, Matrix, ListSexpVector,FloatSexpVector,IntSexpVector,StrSexpVector,BoolSexpVector
from rpy2.robjects.functions import SignatureTranslatedFunction
from rpy2.rinterface import RRuntimeWarning
import json


rpy2conversion = ro.conversion.get_conversion()
irace_converter = ro.default_converter + numpy2ri.converter + pandas2ri.converter
Expand Down Expand Up @@ -83,6 +85,7 @@ def tmp_r_target_runner(experiment, scenario):
py_experiment['configuration'] = OrderedDict(
(k,v) for k,v in py_experiment['configuration'].items() if not pd.isna(v)
)
py_experiment['instance'] = context['py_instances'][int(py_experiment['id.instance']) - 1]
try:
ret = context['py_target_runner'](py_experiment, py_scenario)
except:
Expand All @@ -94,7 +97,6 @@ def tmp_r_target_runner(experiment, scenario):
def check_windows(scenario):
if scenario.get('parallel', 1) != 1 and os.name == 'nt':
raise NotImplementedError('Parallel running on windows is not supported yet. Follow https://github.com/auto-optimization/iracepy/issues/16 for updates. Alternatively, use Linux or MacOS or the irace R package directly.')

class irace:
# Import irace R package
try:
Expand All @@ -107,12 +109,20 @@ class irace:

def __init__(self, scenario, parameters_table, target_runner):
self.scenario = scenario
self.instances = scenario.get('instances', None)
self.context = {}
if 'instances' in scenario:
self.scenario['instances'] = np.asarray(scenario['instances'])
self.context.update({
'py_instances': self.scenario['instances'],
})
self.scenario['instances'] = StrVector(list(map(lambda x: json.dumps(x, skipkeys=True, default=self.scenario.get('instanceObjectSerializer', lambda x: '<not serializable>')), self.scenario['instances'])))
self.scenario.pop('instanceObjectSerializer', None)
with localconverter(irace_converter_hack):
self.parameters = self._pkg.readParameters(text = parameters_table, digits = self.scenario.get('digits', 4))
self.context = {'py_target_runner' : target_runner,
'py_scenario': self.scenario }
self.context.update({
'py_target_runner' : target_runner,
'py_scenario': self.scenario,
})
check_windows(scenario)

def read_configurations(self, filename=None, text=None):
Expand Down
69 changes: 68 additions & 1 deletion tests/test_data_passable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from irace import irace
import pandas as pd
from multiprocessing import Queue
import pytest
import os

q = Queue()

Expand All @@ -12,6 +14,10 @@ def target_runner(experiment, scenario):
else:
return dict(cost=1)

def target_runner2(experiment, scenario):
if experiment['id.instance'] == 1:
experiment['instance'].put(1335)
return dict(cost=1)

params = '''
one "" c ('0', '1')
Expand All @@ -34,4 +40,65 @@ def test():
tuner = irace(scenario, params, target_runner)
best_conf = tuner.run()
assert q.get() == 124


def test_instances():
q = Queue()
scenario = dict(
instances = [q],
maxExperiments = 180,
debugLevel = 0,
parallel = 1,
logFile = "",
seed = 123
)
tuner = irace(scenario, params, target_runner2)
best_conf = tuner.run()
assert q.get() == 1335

@pytest.mark.skipif(os.name == 'nt',
reason="Parallel on Windows not supported")
def test_instances2():
q = Queue()
scenario = dict(
instances = [q],
maxExperiments = 180,
debugLevel = 0,
parallel = 2,
logFile = "",
seed = 123
)
tuner = irace(scenario, params, target_runner2)
best_conf = tuner.run()
assert q.get() == 1335

def test_default_serializer():
q = Queue()
scenario = dict(
instances = [q],
maxExperiments = 180,
debugLevel = 0,
parallel = 1,
logFile = "",
seed = 123,
instanceObjectSerializer = lambda x: 'hello world'
)
tuner = irace(scenario, params, target_runner2)
best_conf = tuner.run()
assert q.get() == 1335

@pytest.mark.skipif(os.name == 'nt',
reason="Parallel on Windows not supported")
def test_default_serializer():
q = Queue()
scenario = dict(
instances = [q],
maxExperiments = 180,
debugLevel = 0,
parallel = 2,
logFile = "",
seed = 123,
instanceObjectSerializer = lambda x: 'hello world'
)
tuner = irace(scenario, params, target_runner2)
best_conf = tuner.run()
assert q.get() == 1335

0 comments on commit 5d18b8a

Please sign in to comment.