Skip to content

Commit

Permalink
Removed all old chain stuff from FunctionOfVector
Browse files Browse the repository at this point in the history
  • Loading branch information
Gareth Aneurin Tribello committed Sep 16, 2024
1 parent 5e562d7 commit a88ff36
Showing 1 changed file with 23 additions and 135 deletions.
158 changes: 23 additions & 135 deletions src/function/FunctionOfVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,8 @@ class FunctionOfVector : public ActionWithVector {
private:
/// Do the calculation at the end of the run
bool doAtEnd;
/// Is this the first time we are doing the calc
bool firststep;
/// The function that is being computed
T myfunc;
/// The number of derivatives for this action
unsigned nderivatives;
/// A vector that tells us if we have stored the input value
std::vector<bool> stored_arguments;
public:
static void registerKeywords(Keywords&);
/// This method is used to run the calculation with functions such as highest/lowest and sort.
Expand All @@ -60,14 +54,10 @@ class FunctionOfVector : public ActionWithVector {
unsigned getNumberOfDerivatives() override ;
/// Resize vectors that are the wrong size
void prepare() override ;
/// Check if all he actions are required
void areAllTasksRequired( std::vector<ActionWithVector*>& task_reducing_actions );
/// Get the label to write in the graph
std::string writeInGraph() const override { return myfunc.getGraphInfo( getName() ); }
/// This builds the task list for the action
void calculate() override;
/// This ensures that we create some bookeeping stuff during the first step
void setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol ) override ;
/// Calculate the function
void performTask( const unsigned& current, MultiValue& myvals ) const override ;
};
Expand Down Expand Up @@ -99,9 +89,7 @@ template <class T>
FunctionOfVector<T>::FunctionOfVector(const ActionOptions&ao):
Action(ao),
ActionWithVector(ao),
doAtEnd(true),
firststep(true),
nderivatives(0)
doAtEnd(true)
{
// Get the shape of the output
std::vector<unsigned> shape(1); shape[0]=getNumberOfFinalTasks();
Expand All @@ -110,7 +98,7 @@ FunctionOfVector<T>::FunctionOfVector(const ActionOptions&ao):
// Create the task list
if( myfunc.doWithTasks() ) {
doAtEnd=false; if( shape[0]>0 ) done_in_chain=true;
} else { plumed_assert( getNumberOfArguments()==1 ); done_in_chain=false; getPntrToArgument(0)->buildDataStore(); }
} else { plumed_assert( getNumberOfArguments()==1 ); getPntrToArgument(0)->buildDataStore(); }
// Get the names of the components
std::vector<std::string> components( keywords.getOutputComponents() );
// Create the values to hold the output
Expand Down Expand Up @@ -143,35 +131,10 @@ FunctionOfVector<T>::FunctionOfVector(const ActionOptions&ao):
}
// Check if this is a timeseries
unsigned argstart=myfunc.getArgStart();
// for(unsigned i=argstart; i<getNumberOfArguments();++i) {
// if( getPntrToArgument(i)->isTimeSeries() ) {
// for(unsigned i=0; i<getNumberOfComponents(); ++i) getPntrToOutput(i)->makeHistoryDependent();
// break;
// }
// }
// Set the periodicities of the output components
myfunc.setPeriodicityForOutputs( this );
// Check if we can put the function in a chain
for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
// CollectFrames* ab=dynamic_cast<CollectFrames*>( getPntrToArgument(i)->getPntrToAction() );
// if( ab && ab->hasClear() ) { doNotChain=true; getPntrToArgument(i)->buildDataStore( getLabel() ); }
// No chains if we are using a sum or a mean
if( getPntrToArgument(i)->getRank()==0 ) {
FunctionOfVector<Sum>* as = dynamic_cast<FunctionOfVector<Sum>*>( getPntrToArgument(i)->getPntrToAction() );
if(as) done_in_chain=false;
} else {
ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
if( !av ) done_in_chain=false;
else if( av->getNumberOfMasks()>=0 && !myfunc.checkIfMaskAllowed( getArguments() ) ) error("cannot use argument masks in input as not all elements are computed");
}
}
// Don't need to do the calculation in a chain if the input is constant
bool allconstant=true;
for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
if( !getPntrToArgument(i)->isConstant() ) { allconstant=false; break; }
}
if( allconstant ) done_in_chain=false;
nderivatives = buildArgumentStore(myfunc.getArgStart());
// Setup the derivatives
unsigned nderivatives = buildArgumentStore(myfunc.getArgStart());
}

template <class T>
Expand All @@ -189,7 +152,9 @@ void FunctionOfVector<T>::turnOnDerivatives() {

template <class T>
unsigned FunctionOfVector<T>::getNumberOfDerivatives() {
return nderivatives;
unsigned nder = 0, argstart = myfunc.getArgStart();
for(unsigned i=argstart; i<getNumberOfArguments(); ++i) nder += getPntrToArgument(i)->getNumberOfStoredValues();
return nder;
}

template <class T>
Expand All @@ -207,102 +172,35 @@ void FunctionOfVector<T>::prepare() {
ActionWithVector::prepare();
}

template <class T>
void FunctionOfVector<T>::setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol ) {
if( firststep ) {
stored_arguments.resize( getNumberOfArguments() );
std::string control = getFirstActionInChain()->getLabel();
for(unsigned i=0; i<stored_arguments.size(); ++i) {
if( getPntrToArgument(i)->isConstant() ) stored_arguments[i]=false;
else stored_arguments[i] = !getPntrToArgument(i)->ignoreStoredValue( control );
}
firststep=false;
}
ActionWithVector::setupStreamedComponents( headstr, nquants, nmat, maxcol );
}

template <class T>
void FunctionOfVector<T>::performTask( const unsigned& current, MultiValue& myvals ) const {
unsigned argstart=myfunc.getArgStart(); std::vector<double> args( getNumberOfArguments()-argstart);
if( actionInChain() ) {
for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
if( getPntrToArgument(i)->getRank()==0 ) args[i-argstart] = getPntrToArgument(i)->get();
else if( !getPntrToArgument(i)->valueHasBeenSet() ) args[i-argstart] = myvals.get( getPntrToArgument(i)->getPositionInStream() );
else args[i-argstart] = getPntrToArgument(i)->get( myvals.getTaskIndex() );
}
} else {
for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
if( getPntrToArgument(i)->getRank()==1 ) args[i-argstart]=getPntrToArgument(i)->get(current);
else args[i-argstart] = getPntrToArgument(i)->get();
}
for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
if( getPntrToArgument(i)->getRank()==1 ) args[i-argstart]=getPntrToArgument(i)->get(current);
else args[i-argstart] = getPntrToArgument(i)->get();
}
// Calculate the function and its derivatives
std::vector<double> vals( getNumberOfComponents() ); Matrix<double> derivatives( getNumberOfComponents(), args.size() );
myfunc.calc( this, args, vals, derivatives );
// And set the values
for(unsigned i=0; i<vals.size(); ++i) myvals.addValue( getConstPntrToComponent(i)->getPositionInStream(), vals[i] );
for(unsigned i=0; i<vals.size(); ++i) myvals.addValue( i, vals[i] );
// Return if we are not computing derivatives
if( doNotCalculateDerivatives() ) return;
// And now compute the derivatives
// Second condition here is only not true if actionInChain()==True if
// input arguments the only non-constant objects in input are scalars.
// In that case we can use the non chain version to calculate the derivatives
// with respect to the scalar.
if( actionInChain() ) {
for(unsigned j=0; j<args.size(); ++j) {
unsigned istrn = getPntrToArgument(argstart+j)->getPositionInStream();
if( stored_arguments[argstart+j] ) {
unsigned task_index = myvals.getTaskIndex(); if( getPntrToArgument(argstart+j)->getRank()==0 ) task_index=0;
myvals.addDerivative( istrn, task_index, 1.0 ); myvals.updateIndex( istrn, task_index );
}
unsigned arg_deriv_s = arg_deriv_starts[argstart+j];
for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
unsigned kind=myvals.getActiveIndex(istrn,k);
for(int i=0; i<getNumberOfComponents(); ++i) {
unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
myvals.addDerivative( ostrn, arg_deriv_s + kind, derivatives(i,j)*myvals.getDerivative( istrn, kind ) );
}
}
// Ensure we only store one lot of derivative indices
bool found=false; ActionWithValue* aav=getPntrToArgument(argstart+j)->getPntrToAction();
for(unsigned k=0; k<j; ++k) {
if( arg_deriv_starts[argstart+k]==arg_deriv_s ) {
if( getPntrToArgument(argstart+k)->getPntrToAction()!=aav ) {
ActionWithVector* av = dynamic_cast<ActionWithVector*>( getPntrToArgument(argstart+j)->getPntrToAction() );
if( av ) {
for(int i=0; i<getNumberOfComponents(); ++i) av->updateAdditionalIndices( getConstPntrToComponent(i)->getPositionInStream(), myvals );
}
}
found=true; break;
}
}
if( found ) continue;
for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
unsigned kind=myvals.getActiveIndex(istrn,k);
for(int i=0; i<getNumberOfComponents(); ++i) {
unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
myvals.updateIndex( ostrn, arg_deriv_s + kind );
}

unsigned base=0;
for(unsigned j=0; j<args.size(); ++j) {
if( getPntrToArgument(argstart+j)->getRank()==1 ) {
for(int i=0; i<getNumberOfComponents(); ++i) {
myvals.addDerivative( i, base+current, derivatives(i,j) );
myvals.updateIndex( i, base+current );
}
}
} else {
unsigned base=0;
for(unsigned j=0; j<args.size(); ++j) {
if( getPntrToArgument(argstart+j)->getRank()==1 ) {
for(int i=0; i<getNumberOfComponents(); ++i) {
unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
myvals.addDerivative( ostrn, base+current, derivatives(i,j) );
myvals.updateIndex( ostrn, base+current );
}
} else {
for(int i=0; i<getNumberOfComponents(); ++i) {
unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
myvals.addDerivative( ostrn, base, derivatives(i,j) );
myvals.updateIndex( ostrn, base );
}
} else {
for(int i=0; i<getNumberOfComponents(); ++i) {
myvals.addDerivative( i, base, derivatives(i,j) );
myvals.updateIndex( i, base );
}
base += getPntrToArgument(argstart+j)->getNumberOfValues();
}
base += getPntrToArgument(argstart+j)->getNumberOfValues();
}
}

Expand All @@ -313,8 +211,6 @@ unsigned FunctionOfVector<T>::getNumberOfFinalTasks() {
plumed_assert( getPntrToArgument(i)->getRank()<2 );
if( getPntrToArgument(i)->getRank()==1 ) {
if( nelements>0 ) {
// if( getPntrToArgument(i)->isTimeSeries() && getPntrToArgument(i)->getShape()[0]<nelements ) nelements=getPntrToArgument(i)->isTimeSeries();
// else
if(getPntrToArgument(i)->getShape()[0]!=nelements ) error("all vectors input should have the same length");
} else if( nelements==0 ) nelements=getPntrToArgument(i)->getShape()[0];
plumed_assert( !getPntrToArgument(i)->hasDerivatives() );
Expand All @@ -325,12 +221,6 @@ unsigned FunctionOfVector<T>::getNumberOfFinalTasks() {
return nelements;
}

template <class T>
void FunctionOfVector<T>::areAllTasksRequired( std::vector<ActionWithVector*>& task_reducing_actions ) {
if( task_reducing_actions.size()==0 ) return;
if( !myfunc.allComponentsRequired( getArguments(), task_reducing_actions ) ) task_reducing_actions.push_back(this);
}

template <class T>
void FunctionOfVector<T>::runSingleTaskCalculation( const Value* arg, ActionWithValue* action, T& f ) {
// This is used if we are doing sorting actions on a single vector
Expand All @@ -349,8 +239,6 @@ void FunctionOfVector<T>::runSingleTaskCalculation( const Value* arg, ActionWith

template <class T>
void FunctionOfVector<T>::calculate() {
// Everything is done elsewhere
if( actionInChain() ) return;
// This is done if we are calculating a function of multiple cvs
if( !doAtEnd ) runAllTasks();
// This is used if we are doing sorting actions on a single vector
Expand Down

0 comments on commit a88ff36

Please sign in to comment.