Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/some_improvements'
Browse files Browse the repository at this point in the history
  • Loading branch information
sepehr78 committed Jul 2, 2024
2 parents e7009b0 + 772bd08 commit e00f828
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 16 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Continuous integration: builds the package, runs its tests and uploads coverage.
name: CI
# Triggers: pushes to `main`, any pushed tag, and every pull request.
on:
  push:
    branches:
      - main
    tags: ['*']
  pull_request:
concurrency:
  # Skip intermediate builds: always.
  # Cancel intermediate builds: only if it is a pull request build.
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
  test:
    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
    runs-on: ${{ matrix.os }}
    strategy:
      # Let the remaining matrix entries finish even if one of them fails.
      fail-fast: false
      matrix:
        version:
          - '1'
        os:
          - ubuntu-latest
        arch:
          - x64
    steps:
      - uses: actions/checkout@v3
      - uses: julia-actions/setup-julia@v1
        with:
          version: ${{ matrix.version }}
          arch: ${{ matrix.arch }}
      - uses: julia-actions/cache@v1
      - uses: julia-actions/julia-buildpkg@v1
      - uses: julia-actions/julia-runtest@v1
      - uses: julia-actions/julia-processcoverage@v1
      # Upload the `lcov.info` coverage file; the token is read from the
      # CODECOV_TOKEN repository secret.
      - uses: codecov/codecov-action@v4
        env:
          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
        with:
          file: lcov.info

8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"



[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
37 changes: 21 additions & 16 deletions src/RecursiveCausalDiscovery.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Computes the Markov boundary matrix for all variables in-place.
- `data::AbstractMatrix`: Data matrix where each column is a variable.
- `ci_test`: Conditional independence test to use.
"""
function find_markov_boundary_matrix!(markov_boundary_matrix::BitMatrix, data::Matrix{Float64}, ci_test::Function)
function find_markov_boundary_matrix!(markov_boundary_matrix::BitMatrix, data::AbstractMatrix, ci_test::Function)
num_vars = size(data, 2)

@threads for i in 1:(num_vars - 1)
Expand All @@ -38,18 +38,23 @@ RSL class for learning graph structure.
- `markov_boundary_matrix`: Matrix indicating whether variable i is in the Markov boundary of j.
- `skip_rem_check_vec`: Used to keep track of which variables to skip when checking for removability. Speeds up the algorithm.
"""
struct RSL{T, F}
    # Data matrix; `size(data, 2)` is used as the number of variables.
    data::Matrix{T}
    # Conditional independence test, stored with its concrete type `F` so the
    # field is not abstractly typed.
    ci_test::F
    markov_boundary_matrix::BitMatrix
    skip_rem_check_vec::BitVector

    # Inner constructor: stores `data` as given (no copy or element conversion)
    # and starts with an all-false Markov boundary matrix and all-false skip flags.
    function RSL(data::Matrix{T}, ci_test::Function) where T
        num_vars = size(data, 2)
        return new{T, typeof(ci_test)}(
            data, ci_test, falses(num_vars, num_vars), falses(num_vars)
        )
    end
end


"""
    _to_matrix(x)

Return `x` in matrix form. A `Matrix` is handed back unchanged (no copy);
any other input is converted with `Tables.matrix`.
"""
function _to_matrix(x)
    return Tables.matrix(x)
end

function _to_matrix(x::Matrix)
    return x
end


"""
learn_and_get_skeleton(data, ci_test)::Graph
Expand All @@ -64,7 +69,7 @@ Runs the algorithm on the data to learn and return the learned skeleton graph.
"""
function learn_and_get_skeleton(data, ci_test::Function; mkbd_ci_test::Function=ci_test)::SimpleGraph
Tables.istable(data) || throw(ArgumentError("Argument does not support Tables.jl"))
data_mat = Tables.matrix(data)
data_mat = _to_matrix(data)
num_vars = size(data_mat, 2)

rsl = RSL(data_mat, ci_test)
Expand All @@ -80,20 +85,20 @@ function learn_and_get_skeleton(data, ci_test::Function; mkbd_ci_test::Function=

for i in 1:(num_vars - 1)
# only consider variables that are left and have skip check set to False
var_to_check_arr = var_arr[var_left_bool_arr .& .!rsl.skip_rem_check_vec]
var_to_check_arr = @views var_arr[var_left_bool_arr .& .!rsl.skip_rem_check_vec]

# sort the variables by the size of their markov boundary
mb_size = sum(rsl.markov_boundary_matrix[:, var_to_check_arr], dims = 1)[1, :]
mb_size = @views sum(rsl.markov_boundary_matrix[:, var_to_check_arr], dims = 1)[1, :]
sort_indices = sortperm(mb_size)
sorted_var_arr = var_to_check_arr[sort_indices]
sorted_var_arr = @views var_to_check_arr[sort_indices]

# find a removable variable
removable_var = find_removable!(rsl, sorted_var_arr)

if removable_var == REMOVABLE_NOT_FOUND
# if no removable found, then pick the variable with the smallest markov boundary from var_left_bool_arr
var_left_arr = findall(var_left_bool_arr)
mb_size_all = sum(rsl.markov_boundary_matrix[var_left_arr, :], dims = 2)
mb_size_all = @views sum(rsl.markov_boundary_matrix[var_left_arr, :], dims = 2)
removable_var = var_left_arr[argmin(mb_size_all)]

rsl.skip_rem_check_vec .= false
Expand Down Expand Up @@ -191,7 +196,7 @@ Find a removable variable in the given list of variables.
# Returns
- `Int`: Index of the removable variable.
"""
function find_removable!(rsl::RSL, var_idx_list::Vector{Int})::Int
function find_removable!(rsl::RSL, var_idx_list::AbstractVector{Int})::Int
for var_idx in var_idx_list
if is_removable(rsl, var_idx)
return var_idx
Expand All @@ -208,7 +213,7 @@ Update the Markov boundary matrix after removing a variable.
- `var_idx::Int`: Index of the variable to remove.
- `var_neighbors::AbstractVector{Int}`: Vector containing the indices of the neighbors of `var_idx`.
"""
function update_markov_boundary_matrix!(rsl::RSL, var_idx::Int, var_neighbors::Vector{Int})
function update_markov_boundary_matrix!(rsl::RSL, var_idx::Int, var_neighbors::AbstractVector{Int})
var_markov_boundary = findall(@view rsl.markov_boundary_matrix[:, var_idx])

# For every variable in the Markov boundary of var_idx, remove it from the Markov boundary and update flag
Expand All @@ -225,10 +230,10 @@ function update_markov_boundary_matrix!(rsl::RSL, var_idx::Int, var_neighbors::V
var_y_idx = var_neighbors[ne_idx_y]
var_z_idx = var_neighbors[ne_idx_z]

var_y_markov_boundary = findall(rsl.markov_boundary_matrix[:, var_y_idx])
var_z_markov_boundary = findall(rsl.markov_boundary_matrix[:, var_z_idx])
var_y_markov_boundary = findall(@view rsl.markov_boundary_matrix[:, var_y_idx])
var_z_markov_boundary = findall(@view rsl.markov_boundary_matrix[:, var_z_idx])

if sum(rsl.markov_boundary_matrix[:, var_y_idx]) < sum(rsl.markov_boundary_matrix[:, var_z_idx])
if @views sum(rsl.markov_boundary_matrix[:, var_y_idx]) < sum(rsl.markov_boundary_matrix[:, var_z_idx])
cond_set = setdiff(var_y_markov_boundary, [var_z_idx])
else
cond_set = setdiff(var_z_markov_boundary, [var_y_idx])
Expand Down
17 changes: 17 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using RecursiveCausalDiscovery

using Test


# To run the tests interactively, start Julia from the package's root directory.
# First run: using Pkg; Pkg.activate("."); using TestEnv; TestEnv.activate(); include("test/runtests.jl")
# Subsequent runs in the same session: include("test/runtests.jl")


# NOTE(review): the assertion below compares the integer literal `1` with a
# `SimpleGraph` literal, so it looks like a placeholder rather than a real check.
# Presumably the left-hand side should be the skeleton learned from a fixture
# dataset — TODO confirm and replace `1` with the actual call.
# NOTE(review): only `RecursiveCausalDiscovery` and `Test` are loaded in this file;
# verify that `SimpleGraph` is in scope here (e.g. re-exported by the package).
@testset "Some unit tests" begin
@test 1 == SimpleGraph{Int64}(91, [Int64[], [50, 77], [63, 76, 97], [98], [21, 79, 82, 99], [29, 64], [16, 17, 45, 76, 96], [35, 45, 84], Int64[], [61], [60, 71], [30, 41, 73], [36], [24, 95], [54], [7, 43, 90], [7, 29, 97], [32], [33, 85], Int64[], [5], Int64[], [73], [14, 76], [100], Int64[], [70], [32, 62], [6, 17, 55], [12], Int64[], [18, 28], [19], [58], [8], [13, 48], [61, 86], Int64[], [99], [68, 88], [12, 46, 65], Int64[], [16, 53, 76, 84, 90], [55, 78, 97], [7, 8], [41, 71, 76], [73, 74, 77], [36], [59], [2], [68], Int64[], [43], [15], [29, 44, 76, 85, 98], [80], [70, 88], [34], [49, 87], [11, 88, 97], [10, 37], [28], [3, 90], [6], [41], Int64[], [69, 74], [40, 51, 85], [67], [27, 57], [11, 46], [100], [12, 23, 47], [47, 67, 82], Int64[], [3, 7, 24, 43, 46, 55, 89], [2, 47, 87], [44, 93], [5], [56], [82], [5, 74, 81], [87], [8, 43, 85], [19, 55, 68, 84, 89], [37], [59, 77, 83], [40, 57, 60], [76, 85], [16, 43, 63, 99], [95], Int64[], [78], Int64[], [14, 91], [7], [3, 17, 44, 60], [4, 55], [5, 39, 90], [25, 72]])
end

# Placeholder test set: its single assertion is `true`, so it exercises no package code.
@testset "Some smaller finer tests" begin
@test true
end

0 comments on commit e00f828

Please sign in to comment.