Skip to content

Commit

Permalink
Simplifying custom MPI datatypes and ops
Browse files Browse the repository at this point in the history
  • Loading branch information
poulson committed Nov 23, 2016
1 parent 5956981 commit 4eebd49
Show file tree
Hide file tree
Showing 5 changed files with 432 additions and 409 deletions.
45 changes: 26 additions & 19 deletions examples/number_theory/SVPChallenge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
Copyright (c) 2009-2016, Jack Poulson
All rights reserved.
This file is part of Elemental and is under the BSD 2-Clause License,
which can be found in the LICENSE file in the root directory, or at
This file is part of Elemental and is under the BSD 2-Clause License,
which can be found in the LICENSE file in the root directory, or at
http://opensource.org/licenses/BSD-2-Clause
*/
#include <El.hpp>
Expand All @@ -28,9 +28,9 @@ int main( int argc, char* argv[] )
Input("--inputBasisFile","input basis file",
string("../data/number_theory/SVPChallenge40.txt"));
const bool trans = Input("--transpose","transpose input?",true);
const string outputBasisFile =
const string outputBasisFile =
Input("--outputBasisFile","output basis file",string("BKZ"));
const string shortestVecFile =
const string shortestVecFile =
Input
("--shortestVecFile","shortest vector file",string("shortest"));
const Real delta = Input("--delta","delta for LLL",Real(0.9999));
Expand Down Expand Up @@ -73,13 +73,13 @@ int main( int argc, char* argv[] )
const bool timeLLL = Input("--timeLLL","time LLL?",false);
const bool timeBKZ = Input("--timeBKZ","time BKZ?",true);
const bool progressLLL =
Input("--progressLLL","print LLL progress?",false);
Input("--progressLLL","print LLL progress?",false);
const bool progressBKZ =
Input("--progressBKZ","print BKZ progress?",true);
Input("--progressBKZ","print BKZ progress?",true);
const bool print = Input("--print","output all matrices?",true);
const bool logFailedEnums =
Input("--logFailedEnums","log failed enumerations in BKZ?",false);
const bool logStreakSizes =
const bool logStreakSizes =
Input("--logStreakSizes","log enum streak sizes in BKZ?",false);
const bool logNontrivialCoords =
Input("--logNontrivialCoords","log nontrivial enum coords?",false);
Expand Down Expand Up @@ -107,18 +107,18 @@ int main( int argc, char* argv[] )
mpfr::SetPrecision( prec );
#endif

Matrix<Real> B;
Matrix<Real> B;
if( trans )
{
Matrix<Real> BTrans;
Read( BTrans, inputBasisFile );
Transpose( BTrans, B );
Transpose( BTrans, B );
}
else
Read( B, inputBasisFile );
const Int m = B.Height();
const Int n = B.Width();
const Real BOrigOne = OneNorm( B );
const Real BOrigOne = OneNorm( B );
Output("|| B_orig ||_1 = ",BOrigOne);
if( print )
Print( B, "BOrig" );
Expand Down Expand Up @@ -155,7 +155,7 @@ int main( int argc, char* argv[] )
return 45;
*/
};
auto enumTypeLambda =
auto enumTypeLambda =
[&]( Int j )
{
if( j <= 3 )
Expand All @@ -178,7 +178,7 @@ int main( int argc, char* argv[] )
ctrl.startCol = startColBKZ;
ctrl.enumCtrl.enumType = FULL_ENUM;
ctrl.enumCtrl.time = timeEnum;
ctrl.enumCtrl.innerProgress = innerEnumProgress;
ctrl.enumCtrl.innerProgress = innerEnumProgress;
ctrl.enumCtrl.phaseLength = phaseLength;
ctrl.enumCtrl.enqueueProb = enqueueProb;
ctrl.enumCtrl.progressLevel = progressLevel;
Expand Down Expand Up @@ -210,6 +210,7 @@ int main( int argc, char* argv[] )
ctrl.enumCtrl.customMaxOneNorms = true;
const Int startIndex = Max(n/2-1,0);
const Int numPhases = ((n-startIndex)+phaseLength-1) / phaseLength;
Output("numPhases=",numPhases);
ctrl.enumCtrl.minInfNorms.resize( numPhases, 0 );
ctrl.enumCtrl.maxInfNorms.resize( numPhases, 1 );
ctrl.enumCtrl.minOneNorms.resize( numPhases, 0 );
Expand Down Expand Up @@ -248,7 +249,7 @@ int main( int argc, char* argv[] )
auto info = BKZ( B, R, ctrl );
const double runTime = mpi::Time() - startTime;
Output
(" BKZ(",blocksize,",",delta,",",eta,") took ",runTime," seconds");
(" BKZ(",blocksize,",",delta,",",eta,") took ",runTime," seconds");
Output(" achieved delta: ",info.delta);
Output(" achieved eta: ",info.eta);
Output(" num swaps: ",info.numSwaps);
Expand All @@ -261,7 +262,7 @@ int main( int argc, char* argv[] )
Output(" targetRatio*GH(L): ",challenge);
if( print )
{
Print( B, "B" );
Print( B, "B" );
Print( R, "R" );
}
Write( B, outputBasisFile, ASCII, "BKZ" );
Expand Down Expand Up @@ -289,7 +290,7 @@ int main( int argc, char* argv[] )

if( !succeeded || fullEnum )
{
const Int start = 0;
const Int start = 0;
const Int numCols = n;
const Range<Int> subInd( start, start+numCols );
auto BSub = B( ALL, subInd );
Expand All @@ -304,11 +305,17 @@ int main( int argc, char* argv[] )
timer.Start();
Real result;
if( fullEnum )
result =
ShortestVectorEnumeration( BSub, RSub, target, v, enumCtrl );
{
result =
ShortestVectorEnumeration( BSub, RSub, target, v, enumCtrl );
Output("shortest vector result = ",result);
}
else
result =
ShortVectorEnumeration( BSub, RSub, target, v, enumCtrl );
{
result =
ShortVectorEnumeration( BSub, RSub, target, v, enumCtrl );
Output("short vector result = ",result);
}
Output("Enumeration: ",timer.Stop()," seconds");
if( result < target )
{
Expand Down
38 changes: 32 additions & 6 deletions include/El/core/imports/mpi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,40 @@ const Op BINARY_XOR = MPI_BXOR;
template<typename T>
struct Types
{
static bool createdTypeBeforeResize;
static El::mpi::Datatype typeBeforeResize;

static bool createdType;
static El::mpi::Datatype type;
// CAUTION: These are not defined for all types
static El::mpi::Op sumOp, prodOp,
minOp, maxOp,
userOp, userCommOp;

static bool haveSumOp;
static bool createdSumOp;
static El::mpi::Op sumOp;

static bool haveProdOp;
static bool createdProdOp;
static El::mpi::Op prodOp;

static bool haveMinOp;
static bool createdMinOp;
static El::mpi::Op minOp;

static bool haveMaxOp;
static bool createdMaxOp;
static El::mpi::Op maxOp;

static bool haveUserOp;
static bool createdUserOp;
static El::mpi::Op userOp;

static bool haveUserCommOp;
static bool createdUserCommOp;
static El::mpi::Op userCommOp;

static function<T(const T&,const T&)> userFunc, userCommFunc;

// Internally called once per type between MPI_Init and MPI_Finalize
static void Destroy();
};

template<typename T>
Expand Down Expand Up @@ -1075,8 +1103,6 @@ void VerifySendsAndRecvs
void CreateCustom() EL_NO_RELEASE_EXCEPT;
void DestroyCustom() EL_NO_RELEASE_EXCEPT;

template<typename T> Datatype& TypeMap() EL_NO_EXCEPT;

#ifdef EL_HAVE_MPC
void CreateBigIntFamily();
void DestroyBigIntFamily();
Expand Down
28 changes: 14 additions & 14 deletions src/core/environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
2013, Jeff Hammond
All rights reserved.
This file is part of Elemental and is under the BSD 2-Clause License,
which can be found in the LICENSE file in the root directory, or at
This file is part of Elemental and is under the BSD 2-Clause License,
which can be found in the LICENSE file in the root directory, or at
http://opensource.org/licenses/BSD-2-Clause
*/
#include <El-lite.hpp>
Expand Down Expand Up @@ -35,7 +35,7 @@ void PrintVersion( ostream& os )

void PrintConfig( ostream& os )
{
os <<
os <<
"Elemental configuration:\n" <<
" Math libraries: " << EL_MATH_LIBS << "\n"
#ifdef EL_HAVE_FLA_BSVD
Expand Down Expand Up @@ -148,7 +148,7 @@ void Initialize( int& argc, char**& argv )
("Cannot initialize elemental after finalizing MPI");
}
#ifdef EL_HYBRID
const Int provided =
const Int provided =
mpi::InitializeThread
( argc, argv, mpi::THREAD_MULTIPLE );
const int commRank = mpi::Rank( mpi::COMM_WORLD );
Expand Down Expand Up @@ -191,16 +191,16 @@ void Initialize( int& argc, char**& argv )

InitializeRandom();

// Create the types and ops
// NOTE: mpfr::SetPrecision created the BigFloat types
// Create the types and ops.
// mpfr::SetPrecision within InitializeRandom created the BigFloat types
mpi::CreateCustom();
}

void Finalize()
{
DEBUG_CSE
if( ::numElemInits <= 0 )
{
{
cerr << "Finalized Elemental more times than initialized" << endl;
return;
}
Expand All @@ -214,7 +214,7 @@ void Finalize()
::args = 0;

Grid::FinalizeDefault();

// Destroy the types and ops
mpi::DestroyCustom();

Expand Down Expand Up @@ -242,10 +242,10 @@ void Finalize()
}

Args& GetArgs()
{
{
if( args == 0 )
throw std::runtime_error("No available instance of Args");
return *::args;
return *::args;
}

void Args::HandleVersion( ostream& os ) const
Expand Down Expand Up @@ -292,15 +292,15 @@ void ReportException( const exception& e, ostream& os )
{
if( string(e.what()) != "" )
{
os << "Process " << mpi::Rank()
os << "Process " << mpi::Rank()
<< " caught an unrecoverable exception with message:\n"
<< e.what() << endl;
}
DEBUG_ONLY(DumpCallStack(os))
mpi::Abort( mpi::COMM_WORLD, 1 );
}
catch( exception& castExcept )
{
catch( exception& castExcept )
{
if( string(e.what()) != "" )
{
os << "Process " << mpi::Rank() << " caught error message:\n"
Expand Down Expand Up @@ -352,7 +352,7 @@ void Union
{
both.resize( first.size()+second.size() );
auto it = std::set_union
( first.cbegin(), first.cend(),
( first.cbegin(), first.cend(),
second.cbegin(), second.cend(),
both.begin() );
both.resize( Int(it-both.begin()) );
Expand Down
Loading

0 comments on commit 4eebd49

Please sign in to comment.