Skip to content

Commit

Permalink
master: valvector: ad_join: add jac_sparsity.
Browse files Browse the repository at this point in the history
  • Loading branch information
bradbell committed Mar 1, 2024
1 parent 76b7723 commit d617328
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
26 changes: 26 additions & 0 deletions example/valvector/ad_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ bool ad_join(void)
{ // ok
bool ok = true;
//
// sparsity_type
typedef CppAD::sparse_rc< CPPAD_TESTVECTOR(size_t) > sparsity_type;
//
// scalar_type
typedef valvector::scalar_type scalar_type;
//
Expand Down Expand Up @@ -74,6 +77,29 @@ bool ad_join(void)
// ok
for(size_t j = 0; j < n; ++j)
ok &= dw[0][j] == scalar_type(2) * x[0][j];
//
// jac_pattern
sparsity_type identity_pattern(n, n, n);
for(size_t k = 0; k < n; ++k)
identity_pattern.set(k, k, k);
bool transpose = false;
bool dependency = false;
bool internal_bool = false;
sparsity_type jac_pattern;
f.for_jac_sparsity(
identity_pattern, transpose, dependency, internal_bool, jac_pattern
);
//
// ok
ok &= jac_pattern.nnz() == n;
ok &= jac_pattern.nr() == m;
ok &= jac_pattern.nc() == n;
CPPAD_TESTVECTOR(size_t) col_major = jac_pattern.col_major();
for(size_t k = 0; k < n; ++k)
{ ok &= jac_pattern.row()[k] == 0;
ok &= jac_pattern.col()[k] == k;
}
//
return ok;
}
// END C++
40 changes: 40 additions & 0 deletions include/cppad/example/valvector/ad_join.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,46 @@ class valvector_ad_join_atom : public CppAD::atomic_four<valvector> {
}
return ok;
}
// ------------------------------------------------------------------------
// jac_sparsity
bool jac_sparsity(
size_t call_id ,
bool dependency ,
const CppAD::vector<bool>& ident_zero_x ,
const CppAD::vector<bool>& select_x ,
const CppAD::vector<bool>& select_y ,
CppAD::sparse_rc< CppAD::vector<size_t> >& pattern_out ) override
{ //
// ok
bool ok = true;
//
// m, n
size_t m = select_y.size();
size_t n = select_x.size();
//
assert( call_id == 0 );
assert( m == 1 );
//
// nnz
size_t nnz = 0;
if( select_y[0] )
{ for(size_t j = 0; j < n; ++j)
{ if( select_x[j] )
++nnz;
}
}
//
// pattern_out
pattern_out.resize(m, n, nnz);
size_t k = 0;
if( select_y[0] )
{ for(size_t j = 0; j < n; ++j)
{ if( select_x[j] )
pattern_out.set(k++, 0, j);
}
}
return ok;
}
};

class valvector_ad_join {
Expand Down
3 changes: 2 additions & 1 deletion include/cppad/example/valvector/ad_split.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,13 @@ class valvector_ad_split_atom : public CppAD::atomic_four<valvector> {
// ok
bool ok = true;
//
// m
// m, n
size_t m = select_y.size();
size_t n = select_x.size();
//
assert( call_id == 0 );
assert( n == 1 );
//
// nnz
size_t nnz = 0;
if( select_x[0] )
Expand Down

0 comments on commit d617328

Please sign in to comment.