diff --git a/Project.toml b/Project.toml index 840bb96..60bf894 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Benoît Legat "] version = "0.1.0" [deps] +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SDPLR_jll = "3a057b76-36a0-51f0-a66f-6d580b8e8efd" [compat] diff --git a/src/SDPLR.jl b/src/SDPLR.jl index b9b0629..c78325f 100644 --- a/src/SDPLR.jl +++ b/src/SDPLR.jl @@ -24,6 +24,47 @@ Base.@kwdef struct Parameters typebd::Cptrdiff_t = 1 end +# See `macros.h` +datablockind(data, block, numblock) = ((data + 1) - 1) * numblock + block + +import Random +function default_R(blktype::Vector{Cchar}, blksz, maxranks) + # See `getstorage` in `main.c` + nr = sum(eachindex(blktype)) do k + if blktype[k] == Cchar('s') + return blksz[k] * maxranks[k] + elseif blktype[k] == Cchar('d') + return blksz[k] + else + return 0 + end + end + # In `main.c`, it does (rand() / RAND_MAX) - (rand() - RAND_MAX) to take the difference between + # two numbers between 0 and 1. Here, Julia's rand() is already between 0 and 1 so we don't have + # to divide. + Random.seed!(925) + return rand(nr) - rand(nr) +end + +function default_maxranks(blktype, blksz, CAinfo_entptr) + numblk = length(blktype) + m = div(length(CAinfo_entptr) - 1, numblk) - 1 + # See `getstorage` in `main.c` + return map(eachindex(blktype)) do k + if blktype[k] == Cchar('s') + cons = count(1:m) do i + ind = datablockind(i, k, numblk) + return CAinfo_entptr[ind+1] > CAinfo_entptr[ind] + end + return Csize_t(min(isqrt(2cons) + 1, blksz[k])) + elseif blktype[k] == Cchar('d') + return Csize_t(1) + else + return Csize_t(0) + end + end +end + function solve( blksz::Vector{Cptrdiff_t}, blktype::Vector{Cchar}, @@ -32,13 +73,13 @@ function solve( CArow::Vector{Csize_t}, CAcol::Vector{Csize_t}, CAinfo_entptr::Vector{Csize_t}, - CAinfo_type::Vector{Cchar}, - params::Parameters, - R::Vector{Cdouble}, - lambda::Vector{Cdouble}, - maxranks::Vector{Csize_t}, - ranks::Vector{Csize_t}, - pieces::Vector{Cdouble}, + CAinfo_type::Vector{Cchar}; + params::Parameters = Parameters(), + maxranks::Vector{Csize_t} = default_maxranks(blktype, blksz, CAinfo_entptr), + ranks::Vector{Csize_t} = copy(maxranks), + R::Vector{Cdouble} = default_R(blktype, blksz, maxranks), + lambda::Vector{Cdouble} = zeros(length(b)), + pieces::Vector{Cdouble} = Cdouble[0, 0, 0, 0, 0, 0, inv(sum(blksz)), 1], ) numblk = length(blksz) @assert length(blktype) == numblk @@ -80,7 +121,7 @@ function solve( ranks::Ptr{Csize_t}, pieces::Ptr{Cdouble}, )::Csize_t - return ret + return ret, R, lambda, ranks end end # module diff --git a/test/test_vibra.jl b/test/test_vibra.jl index 5e7062a..48d3cea 100644 --- a/test/test_vibra.jl +++ b/test/test_vibra.jl @@ -5,7 +5,7 @@ import SDPLR end @testset "Solve vibra with sdplrlib" begin include("vibra.jl") - ret = SDPLR.solve( + ret, R, lambda, ranks = SDPLR.solve( blksz, blktype, b, @@ -14,12 +14,9 @@ end CAcol, CAinfo_entptr, CAinfo_type, - SDPLR.Parameters(), - R, - lambda, - maxranks, - ranks, - pieces, ) @test iszero(ret) + @test length(R) == 477 + @test sum(lambda) ≈ -40.8133 rtol = 1e-3 + @test ranks == Csize_t[9, 9, 1] end diff --git a/test/vibra.jl b/test/vibra.jl index e93fcfe..7753eaa 100644 --- a/test/vibra.jl +++ b/test/vibra.jl @@ -2,44 +2,7 @@ m = 36 numblk = 3 blksz = Cptrdiff_t[24, 25, 36] blktype = Cchar['s', 's', 'd'] -b = [ - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, - 1.000000, -] +b = ones(m) CAinfo_entptr = Csize_t[ 0, 2, @@ -1508,495 +1471,3 @@ CAcol = Csize_t[ 36, ] CAinfo_type = repeat(blktype, m + 1) -R = Cdouble[ - 0.409043, - 0.849300, - 0.716049, - 0.050392, - -0.239808, - 0.497833, - -0.175185, - 0.074342, - 0.040529, - 0.051350, - 0.450247, - -0.149381, - 0.079064, - 0.064391, - -0.136040, - 0.830741, - -0.025521, - -0.600394, - -0.501549, - 0.006896, - 0.018504, - -0.194408, - 0.155992, - 0.519439, - 0.564224, - -0.020216, - 0.426399, - 0.320978, - 0.210292, - -0.172953, - -0.787483, - -0.652744, - 0.337829, - -0.122477, - 0.196859, - 0.094923, - 0.389147, - 0.058682, - -0.219551, - -0.258340, - 0.148910, - 0.359144, - -0.448153, - 0.080772, - 0.065490, - -0.163608, - -0.434394, - 0.465003, - -0.069491, - 0.739714, - -0.773055, - 0.342731, - -0.264848, - 0.161232, - -0.020466, - -0.139725, - -0.375664, - -0.296248, - -0.100460, - -0.061678, - -0.307658, - 0.243038, - -0.041639, - 0.161382, - 0.050567, - 0.016237, - -0.264266, - 0.893604, - -0.239618, - -0.013512, - -0.042273, - -0.131641, - 0.066182, - -0.241754, - 0.356635, - -0.069248, - -0.528718, - 0.306925, - -0.118170, - 0.318512, - 0.369403, - 0.209909, - -0.306282, - 0.576023, - 0.158349, - 0.194015, - 0.409274, - -0.301861, - 0.634485, - 0.441177, - 0.152953, - 0.040097, - -0.454280, - 0.053878, - 0.815328, - 0.359947, - -0.292861, - 0.289869, - -0.326631, - -0.145043, - 0.428403, - -0.052207, - 0.111347, - 0.313132, - 0.102318, - 0.736952, - 0.126238, - -0.120494, - 0.124071, - 0.099508, - 0.305773, - 0.124131, - 0.029310, - -0.586230, - 0.279892, - -0.465604, - -0.682300, - -0.454028, - 0.128122, - 0.174878, - -0.386587, - 0.485711, - -0.311121, - -0.588391, - 0.298600, - -0.247650, - -0.029429, - 0.617285, - 0.290481, - -0.297440, - 0.299782, - -0.721331, - -0.280855, - 0.046528, - 0.683744, - 0.077708, - 0.133236, - 0.036808, - -0.071365, - 0.635065, - -0.040467, - -0.168393, - 0.263895, - 0.929985, - 0.479947, - 0.124750, - -0.130499, - -0.602789, - 0.228059, - 0.435914, - -0.334856, - -0.791040, - -0.922960, - 0.620494, - -0.305815, - -0.767091, - -0.357026, - -0.416179, - -0.486591, - -0.538754, - 0.146351, - 0.773085, - -0.010586, - 0.604497, - -0.385966, - 0.129276, - 0.554032, - 0.077557, - -0.415995, - -0.363254, - -0.046389, - -0.068093, - -0.065827, - -0.496545, - 0.029534, - 0.440938, - 0.550696, - -0.808264, - -0.615681, - -0.242418, - -0.163356, - 0.277530, - -0.774905, - 0.021747, - 0.053138, - 0.087289, - 0.528150, - -0.396666, - -0.465077, - -0.090096, - -0.928511, - -0.259659, - -0.435136, - 0.162295, - -0.271691, - -0.052464, - -0.193264, - -0.044371, - -0.419847, - 0.840930, - -0.364130, - 0.640041, - -0.052586, - 0.626342, - 0.350163, - 0.046793, - -0.112913, - 0.043837, - -0.280789, - -0.598866, - 0.126009, - 0.012473, - -0.426264, - -0.035424, - -0.054523, - -0.212819, - -0.266276, - -0.254633, - 0.670597, - 0.163475, - -0.092047, - -0.557588, - -0.002598, - 0.246991, - -0.307619, - 0.164133, - -0.003205, - -0.248864, - 0.144708, - -0.275580, - -0.478783, - -0.328468, - -0.061177, - -0.063920, - -0.211393, - -0.369820, - -0.690177, - 0.491600, - 0.558842, - 0.023839, - -0.613284, - 0.069613, - -0.635422, - -0.045958, - 0.536461, - -0.400529, - 0.190063, - 0.312620, - 0.687902, - 0.738307, - 0.348255, - -0.826931, - -0.558920, - -0.034697, - -0.596740, - 0.251144, - 0.177114, - 0.173488, - 0.141507, - 0.050978, - -0.194008, - -0.264354, - -0.078549, - 0.124942, - 0.296966, - 0.085862, - 0.231736, - -0.696701, - -0.076401, - -0.327262, - -0.742256, - 0.116836, - 0.244640, - -0.023450, - -0.582146, - 0.064115, - 0.677976, - -0.780099, - -0.151973, - -0.791220, - 0.208092, - -0.209787, - -0.446428, - 0.112644, - 0.362114, - -0.195969, - -0.290736, - 0.054716, - 0.211191, - 0.340052, - 0.058793, - 0.852891, - 0.203927, - -0.663804, - 0.417634, - 0.196947, - 0.839618, - 0.065553, - 0.080673, - -0.129639, - -0.036420, - 0.062645, - 0.595259, - -0.373251, - 0.148627, - 0.435496, - 0.370967, - -0.055811, - 0.361177, - 0.626849, - -0.174638, - 0.404536, - 0.810956, - 0.058170, - -0.351713, - -0.665585, - -0.397077, - -0.715289, - 0.514421, - 0.004633, - -0.121739, - -0.129532, - -0.271486, - -0.175878, - 0.095043, - -0.131497, - 0.274767, - 0.585385, - -0.256871, - -0.022648, - -0.427915, - 0.003283, - -0.291172, - -0.626810, - 0.199455, - -0.113703, - -0.182050, - 0.326943, - 0.571423, - 0.636156, - 0.602036, - -0.413129, - -0.222301, - -0.037474, - -0.184190, - -0.087756, - 0.251646, - -0.437098, - 0.203522, - 0.183722, - -0.686086, - 0.480692, - -0.173216, - -0.321671, - -0.041300, - 0.064673, - -0.180855, - -0.562461, - -0.176201, - 0.117671, - -0.660057, - 0.027503, - -0.054557, - 0.552062, - 0.526497, - 0.551942, - -0.194197, - -0.027725, - -0.159731, - 0.030795, - -0.034911, - 0.359705, - -0.252989, - 0.381077, - 0.197309, - -0.789580, - -0.471336, - 0.128841, - 0.828328, - -0.376927, - 0.416687, - 0.195237, - 0.207650, - -0.144922, - -0.181884, - -0.153111, - -0.018660, - 0.567440, - 0.003838, - 0.299162, - 0.447265, - 0.738400, - -0.252495, - 0.450037, - 0.548144, - 0.930659, - -0.751875, - 0.187322, - -0.706593, - 0.503694, - -0.071728, - -0.105271, - -0.442824, - 0.637964, - 0.410686, - -0.237838, - 0.446708, - 0.833881, - 0.367803, - -0.229954, - 0.357027, - -0.119221, - 0.826600, - 0.060927, - 0.110501, - -0.768468, - 0.915203, - -0.051614, - -0.411570, - 0.033395, - -0.476523, - 0.857068, - -0.703881, - -0.835602, - -0.025133, - -0.420130, - -0.364850, - -0.062954, - 0.528751, - -0.418242, - -0.437407, - 0.678026, - -0.060449, - -0.049823, - 0.212762, - -0.472902, - -0.220872, - -0.293576, - 0.075115, - -0.260831, - -0.194733, - 0.381875, - -0.220311, - -0.097834, - -0.683107, - -0.052308, - 0.297214, - 0.357283, - -0.747871, - -0.235254, - 0.291031, - 0.201130, - -0.323991, - 0.145406, - -0.221624, - -0.198895, - 0.203614, - 0.024059, - 0.501901, - -0.708257, - -0.589850, - -0.301719, - 0.865152, - -0.055871, - 0.363024, - 0.582171, - 0.000718, - 0.724214, - -0.309145, - 0.427851, - -0.245805, - 0.052873, - -0.588064, - -0.058888, - 0.003158, - -0.267370, - -0.527711, -] -lambda = zeros(m) -maxranks = Csize_t[9, 9, 1] -ranks = Csize_t[9, 9, 1] -pieces = Cdouble[ - 0.000000, - 0.000000, - 0.000000, - 0.000000, - 0.000000, - 0.000000, - 0.011765, - 1.000000, -]