Skip to content

Commit

Permalink
do svm and ranger
Browse files Browse the repository at this point in the history
  • Loading branch information
thierrymoudiki committed Jan 18, 2024
1 parent b57eff7 commit a21d107
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 8 deletions.
47 changes: 45 additions & 2 deletions src/Ahead.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,51 @@ module Ahead
return rcopy(R"ahead::eatf(y=$y, h=$h, level=$level)")
end

function dynrmf(y; h=5, level=95)
return rcopy(R"ahead::dynrmf(y=$y, h=$h, level=$level)")
function dynrmf(y; h=5, level=95, method="autoridge", kwargs...)

if method == "autoridge"

return rcopy(R"ahead::dynrmf(y=$y, h=$h, level=$level)")

elseif method == "svm"

kwargs = Dict(kwargs)
if haskey(kwargs, "kernel")
kernel = kwargs["kernel"]
else
kernel = "radial"
end
R"z <- dynrmf(fdeaths, h=20, level=95, fit_func = e1071::svm,
fit_params = list(kernel = $kernel), predict_func = predict)"
return rcopy(R"z")

elseif method == "randomforest"

kwargs = Dict(kwargs)
if haskey(kwargs, "num_trees")
num_trees = kwargs["num_trees"]
else
num_trees = 500
end
R"""
fit_func <- function(x, y, ...)
{
df <- data.frame(y=y, x) # naming of columns is mandatory for `predict`
ranger::ranger(y ~ ., data=df, ...)
};
predict_func <- function(obj, newx)
{
colnames(newx) <- paste0("X", 1:ncol(newx)) # mandatory, linked to df in fit_func
predict(object=obj, data=newx)$predictions # only accepts a named newx
};
z <- ahead::dynrmf(y=$y, h=$h, level=$level,
fit_func = fit_func,
fit_params = list(num.trees = $num_trees),
predict_func = predict_func);
"""
return rcopy(R"z")

end
end

function loessf(y; h=5, level=95)
Expand Down
26 changes: 20 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,33 @@ using Test
@test val[:method] == "mean"
end

@testset "----- Testing dynrmf -----" begin
@testset "----- Testing dynrmf (autoridge) -----" begin
val = Ahead.dynrmf(y, h=6)
println(val)
@test isapprox(round(val[:residuals][1]), 0)
@test val[:method] == "DynRM 1"
end
end

@testset "----- Testing eatf -----" begin
val = Ahead.eatf(y, h=6)
@testset "----- Testing dynrmf (svm) -----" begin
val = Ahead.dynrmf(y, h=6, kernel="linear")
println(val)
@test isapprox(round(val[:residuals][1]), 0)
@test val[:x] == y
@test val[:method] == "EAT"
end
end

@testset "----- Testing dynrmf (ranger) -----" begin
val = Ahead.dynrmf(y, h=6, num_trees=50)
println(val)
@test isapprox(round(val[:residuals][1]), 0)
@test val[:x] == y
end

#@testset "----- Testing eatf -----" begin
# val = Ahead.eatf(y, h=6)
# println(val)
# @test val[:x] == y
# @test val[:method] == "EAT"
#end

@testset "----- Testing loessf -----" begin
val = Ahead.loessf(y, h=6)
Expand Down

0 comments on commit a21d107

Please sign in to comment.