Skip to content

Commit

Permalink
Work on Highs callbacks, WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaslundell committed Apr 17, 2024
1 parent 226c884 commit 2ae86c3
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 74 deletions.
2 changes: 1 addition & 1 deletion ThirdParty/HiGHS
Submodule HiGHS updated 162 files
175 changes: 104 additions & 71 deletions src/MIPSolver/MIPSolverHighs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,53 +26,96 @@ namespace SHOT
{

// Callback that correctly indents Highs log messages and prints them using SHOT's logging functionality
static void highsLogCallback(HighsLogType type, const char* message, void* envPtr)
{
auto env = static_cast<Environment*>(envPtr);

if(!env->settings->getSetting<bool>("Console.DualSolver.Show", "Output"))
return;

auto lines = Utilities::splitStringByCharacter(std::string(message), '\n');

for(auto const& line : lines)
env->output->outputInfo(fmt::format(" | {} ", line));
}

static void userInterruptCallback(const int callback_type, const char* message, const HighsCallbackDataOut* data_out,
HighsCallbackDataIn* data_in, void* user_callback_data)
{
auto env = *reinterpret_cast<std::shared_ptr<Environment>*>(user_callback_data);
auto MIPSolver = std::dynamic_pointer_cast<MIPSolverHighs>(env->dualSolver->MIPSolver);

if(callback_type == kCallBackMipFeasibleSolution)
{
std::vector<double> solution(
data_out->mip_solution, data_out->mip_solution + MIPSolver->getNumberOfVariables());

MIPSolver->currentSolutions.push_back(std::make_pair(data_out->objective_function_value, solution));

std::cout << "Sol limit " << MIPSolver->getSolutionLimit() << " number sols "
<< MIPSolver->currentSolutions.size() << std::endl;

data_in->user_interrupt = false;
}
else if(callback_type == kCallbackMipInterrupt)
{
// std::cout << "Sol limit " << MIPSolver->getSolutionLimit() << " number sols "
// << MIPSolver->currentSolutions.size() << std::endl;

if(MIPSolver->currentSolutions.size() >= MIPSolver->getSolutionLimit())
{
std::cout << "sol limit reached\n";
data_in->user_interrupt = false;
}
else
{
data_in->user_interrupt = false;
}
}
}
HighsCallbackFunctionType highsCallback
= [](int callback_type, const std::string& message, const HighsCallbackDataOut* data_out,
HighsCallbackDataIn* data_in, void* user_callback_data) {
HighsMipData callback_data = *(static_cast<HighsMipData*>(user_callback_data));
auto env = callback_data.env;

if(callback_type == kCallbackLogging)
{
if(!env->settings->getSetting<bool>("Console.DualSolver.Show", "Output"))
return;

auto lines = Utilities::splitStringByCharacter(std::string(message), '\n');

for(auto const& line : lines)
env->output->outputInfo(fmt::format(" | {} ", line));

// data_in->user_interrupt = false;
return;
}

auto MIPSolver = std::dynamic_pointer_cast<MIPSolverHighs>(env->dualSolver->MIPSolver);

if(callback_type == kCallbackMipInterrupt)
{
if(MIPSolver->currentSolutions.size() == MIPSolver->getSolutionLimit())
{
env->output->outputDebug(fmt::format(" | solution limit reached "));
data_in->user_interrupt = true;
}
else
{
data_in->user_interrupt = false;
}

return;
}

if(callback_type == kCallbackMipSolution)
{
std::vector<double> solution(
data_out->mip_solution, data_out->mip_solution + MIPSolver->getNumberOfVariables());

double hashValue = Utilities::calculateHash(solution);

for(int i = 0; i < MIPSolver->currentSolutions.size(); i++)
{
if(MIPSolver->currentSolutions[i].hashValue == hashValue)
{
return;
}
}

SolutionPoint currentSolution;
currentSolution.objectiveValue = env->reformulatedProblem->objectiveFunction->calculateValue(solution);
currentSolution.point = solution;
currentSolution.hashValue = hashValue;
MIPSolver->currentSolutions.push_back(currentSolution);

env->output->outputInfo(fmt::format(" | #sols: {} \t obj.val: {:.4f} \t gap: {:.4f} ",
MIPSolver->currentSolutions.size(), data_out->objective_function_value, data_out->mip_gap));

// Sorts the solutions so that the best one is at the first position
if(env->reformulatedProblem->objectiveFunction->properties.isMinimize)
{
std::sort(MIPSolver->currentSolutions.begin(), MIPSolver->currentSolutions.end(),
[](const SolutionPoint& firstSolution, const SolutionPoint& secondSolution) {
return (firstSolution.objectiveValue < secondSolution.objectiveValue);
});
}
else
{
std::sort(MIPSolver->currentSolutions.begin(), MIPSolver->currentSolutions.end(),
[](const SolutionPoint& firstSolution, const SolutionPoint& secondSolution) {
return (firstSolution.objectiveValue > secondSolution.objectiveValue);
});
}

/*for(int i = 0; i < MIPSolver->currentSolutions.size(); i++)
{
std::cout << fmt::format("{:.8f} \t {:.8f} ", MIPSolver->currentSolutions[i].objectiveValue,
MIPSolver->currentSolutions[i].hashValue)
<< std::endl;
}*/

// Strange that we need to set this manually
data_in->user_interrupt = false;

return;
}
};

MIPSolverHighs::MIPSolverHighs(EnvironmentPtr envPtr) { env = envPtr; }

Expand Down Expand Up @@ -288,17 +331,13 @@ void MIPSolverHighs::initializeSolverSettings()
// highsInstance.setOptionValue("highs_debug_level", 3);
// highsInstance.setOptionValue("mip_report_level", 2);

highsInstance.setOptionValue("output_flag", false);
highsInstance.setOptionValue("threads", env->settings->getSetting<int>("MIP.NumberOfThreads", "Dual"));

highsInstance.setLogCallback(highsLogCallback, (void*)env.get());

// void* p_user_callback_data = (void*)(env.get());

auto ptr1 = reinterpret_cast<void*>(&env); /* shared_ptr > void ptr */

highsInstance.setCallback(userInterruptCallback, ptr1);
highsInstance.setCallback(highsCallback, reinterpret_cast<void*>(&highsCallbackData));
highsInstance.startCallback(kCallbackMipSolution);
highsInstance.startCallback(kCallbackMipInterrupt);
highsInstance.startCallback(kCallBackMipFeasibleSolution);
highsInstance.startCallback(kCallbackLogging);
}

int MIPSolverHighs::addLinearConstraint(
Expand Down Expand Up @@ -452,6 +491,9 @@ E_ProblemSolutionStatus MIPSolverHighs::solveProblem()
cachedSolutionHasChanged = true;
currentSolutions.clear();

HighsLp lp = highsInstance.getLp();
highsCallbackData.env = env;

highsReturnStatus = highsInstance.run();
MIPSolutionStatus = getSolutionStatus();

Expand All @@ -469,32 +511,24 @@ bool MIPSolverHighs::repairInfeasibility()

int MIPSolverHighs::increaseSolutionLimit(int increment)
{
this->solLimit += increment;
this->solutionLimit += increment;

this->setSolutionLimit(this->solLimit);

return (this->solLimit);
return (this->solutionLimit);
}

void MIPSolverHighs::setSolutionLimit(long int limit)
{
if(limit > kHighsIInf)
{
highsInstance.setOptionValue("mip_max_improving_sols", kHighsIInf);
this->solLimit = kHighsIInf;
this->solutionLimit = kHighsIInf;
}
else
{
highsInstance.setOptionValue("mip_max_improving_sols", (int)limit);
this->solutionLimit = limit;
}
}

int MIPSolverHighs::getSolutionLimit()
{
HighsInt solutionLimit;
highsInstance.getOptionValue("mip_max_improving_sols", solutionLimit);
return ((int)solutionLimit);
}
int MIPSolverHighs::getSolutionLimit() { return (this->solutionLimit); }

void MIPSolverHighs::setTimeLimit(double seconds) { highsInstance.setOptionValue("time_limit", seconds); }

Expand Down Expand Up @@ -573,7 +607,6 @@ void MIPSolverHighs::addMIPStart(VectorDouble point)
HighsSolution solution;
solution.col_value = point;

auto sol = highsInstance.getSolution();
auto return_status = highsInstance.setSolution(solution);

if(return_status != HighsStatus::kOk)
Expand All @@ -599,7 +632,7 @@ double MIPSolverHighs::getObjectiveValue(int solIdx)

if(isProblemDiscrete && isMIP)
{
objectiveValue = currentSolutions.at(currentSolutions.size() - solIdx - 1).first;
objectiveValue = currentSolutions.at(solIdx).objectiveValue;
}
else
{
Expand Down Expand Up @@ -630,7 +663,7 @@ VectorDouble MIPSolverHighs::getVariableSolution(int solIdx)

if(isProblemDiscrete && isMIP)
{
solution = currentSolutions.at(currentSolutions.size() - solIdx - 1).second;
solution = currentSolutions.at(solIdx).point;
}
else
{
Expand Down
10 changes: 8 additions & 2 deletions src/MIPSolver/MIPSolverHighs.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

namespace SHOT
{
struct HighsMipData
{
EnvironmentPtr env;
};

class MIPSolverHighs : public IMIPSolver, MIPSolverBase
{
Expand Down Expand Up @@ -157,14 +161,16 @@ class MIPSolverHighs : public IMIPSolver, MIPSolverBase

std::string getSolverVersion() override;

std::vector<std::pair<double, VectorDouble>> currentSolutions;
// Objective value, solution point has, solution point
std::vector<SolutionPoint> currentSolutions;

private:
HighsModel highsModel;
Highs highsInstance;
HighsStatus highsReturnStatus;
HighsMipData highsCallbackData;

long int solLimit;
long int solutionLimit;
double timeLimit = 1e100;
double cutOff;
int numberOfThreads = 1;
Expand Down

0 comments on commit 2ae86c3

Please sign in to comment.