Skip to content

Commit

Permalink
Add default starting solution (#7)
Browse files Browse the repository at this point in the history
* Add default starting solution

* fixes
  • Loading branch information
blegat authored Nov 16, 2023
1 parent a5afb31 commit 5b8de35
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 545 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Benoît Legat <[email protected]>"]
version = "0.1.0"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SDPLR_jll = "3a057b76-36a0-51f0-a66f-6d580b8e8efd"

[compat]
Expand Down
57 changes: 49 additions & 8 deletions src/SDPLR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,47 @@ Base.@kwdef struct Parameters
typebd::Cptrdiff_t = 1
end

# See `macros.h`
datablockind(data, block, numblock) = ((data + 1) - 1) * numblock + block

import Random
function default_R(blktype::Vector{Cchar}, blksz, maxranks)
# See `getstorage` in `main.c`
nr = sum(eachindex(blktype)) do k
if blktype[k] == Cchar('s')
return blksz[k] * maxranks[k]
elseif blktype[k] == Cchar('d')
return blksz[k]
else
return 0
end
end
# In `main.c`, it does (rand() / RAND_MAX) - (rand() - RAND_MAX) to take the difference between
# two numbers between 0 and 1. Here, Julia's rand() is already between 0 and 1 so we don't have
# to divide.
Random.seed!(925)
return rand(nr) - rand(nr)
end

function default_maxranks(blktype, blksz, CAinfo_entptr)
numblk = length(blktype)
m = div(length(CAinfo_entptr) - 1, numblk) - 1
# See `getstorage` in `main.c`
return map(eachindex(blktype)) do k
if blktype[k] == Cchar('s')
cons = count(1:m) do i
ind = datablockind(i, k, numblk)
return CAinfo_entptr[ind+1] > CAinfo_entptr[ind]
end
return Csize_t(min(isqrt(2cons) + 1, blksz[k]))
elseif blktype[k] == Cchar('d')
return Csize_t(1)
else
return Csize_t(0)
end
end
end

function solve(
blksz::Vector{Cptrdiff_t},
blktype::Vector{Cchar},
Expand All @@ -32,13 +73,13 @@ function solve(
CArow::Vector{Csize_t},
CAcol::Vector{Csize_t},
CAinfo_entptr::Vector{Csize_t},
CAinfo_type::Vector{Cchar},
params::Parameters,
R::Vector{Cdouble},
lambda::Vector{Cdouble},
maxranks::Vector{Csize_t},
ranks::Vector{Csize_t},
pieces::Vector{Cdouble},
CAinfo_type::Vector{Cchar};
params::Parameters = Parameters(),
maxranks::Vector{Csize_t} = default_maxranks(blktype, blksz, CAinfo_entptr),
ranks::Vector{Csize_t} = copy(maxranks),
R::Vector{Cdouble} = default_R(blktype, blksz, maxranks),
lambda::Vector{Cdouble} = zeros(length(b)),
pieces::Vector{Cdouble} = Cdouble[0, 0, 0, 0, 0, 0, inv(sum(blksz)), 1],
)
numblk = length(blksz)
@assert length(blktype) == numblk
Expand Down Expand Up @@ -80,7 +121,7 @@ function solve(
ranks::Ptr{Csize_t},
pieces::Ptr{Cdouble},
)::Csize_t
return ret
return ret, R, lambda, ranks
end

end # module
11 changes: 4 additions & 7 deletions test/test_vibra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import SDPLR
end
@testset "Solve vibra with sdplrlib" begin
include("vibra.jl")
ret = SDPLR.solve(
ret, R, lambda, ranks = SDPLR.solve(
blksz,
blktype,
b,
Expand All @@ -14,12 +14,9 @@ end
CAcol,
CAinfo_entptr,
CAinfo_type,
SDPLR.Parameters(),
R,
lambda,
maxranks,
ranks,
pieces,
)
@test iszero(ret)
@test length(R) == 477
@test sum(lambda) -40.8133 rtol = 1e-3
@test ranks == Csize_t[9, 9, 1]
end
Loading

0 comments on commit 5b8de35

Please sign in to comment.