Skip to content

Commit

Permalink
both methods for geom_smooth have tests and work
Browse files Browse the repository at this point in the history
  • Loading branch information
rdboyes committed Mar 23, 2024
1 parent 9df22de commit 7f22e32
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 19 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
Loess = "4345ca2d-374a-55d4-8d30-97f9976e7612"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
TidierData = "fe2206b3-d496-4ee9-a338-6a095c4ece80"
Expand Down
3 changes: 2 additions & 1 deletion src/TidierPlots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ using CategoricalArrays
# Data manipulation, expression parsing
using TidierData

# for ... GLMS
# for geom_smooth fits
using GLM
using Loess

include("structs.jl")

Expand Down
40 changes: 25 additions & 15 deletions src/geoms/geom_smooth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,35 +33,45 @@ end
function geom_smooth(args...; kwargs...)
aes_dict, args_dict = extract_aes(args, kwargs)
args_dict["geom_name"] = "geom_smooth"

analysis = stat_loess

if haskey(args_dict, "method")
if args_dict["method"] == "lm"
analysis = stat_linear
return [build_geom(aes_dict,
args_dict,
["x", "y"],
:Lines,
stat_linear),
build_geom(aes_dict,
args_dict,
["x", "lower", "upper"],
:Band,
stat_linear)]
end
end

# geom_smooth returns TWO makie plots:
# :Lines - the center line
# :Band - the uncertainty interval

return [build_geom(aes_dict,
return build_geom(aes_dict,
args_dict,
["x", "y"],
:Lines,
analysis),
build_geom(aes_dict,
args_dict,
["x", "lower", "upper"],
:Band,
analysis)]
stat_loess)
end

function stat_loess(aes_dict::Dict{String, Symbol},
args_dict::Dict{Any, Any}, required_aes::Vector{String}, plot_data::DataFrame)

return (aes_dict, args_dict, required_aes, plot_data)
x = plot_data[!, aes_dict["x"]]
y = plot_data[!, aes_dict["y"]]

model = Loess.loess(x, y; span = .75, degree = 2)
= range(extrema(x)..., length=200)
= Loess.predict(model, x̂)

return_data = DataFrame(
String(aes_dict["x"]) => x̂,
String(aes_dict["y"]) =>
)

return (aes_dict, args_dict, required_aes, return_data)
end

function stat_linear(aes_dict::Dict{String, Symbol},
Expand Down
20 changes: 19 additions & 1 deletion test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.2"
manifest_format = "2.0"
project_hash = "d3db1469a6d665450c389024fc6c41ad54db31ae"
project_hash = "5ca544e0ddaec04c2c0551102da3cb1f1f227a35"

[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
Expand Down Expand Up @@ -593,6 +593,12 @@ version = "1.0.10+0"
deps = ["Random"]
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"

[[deps.GLM]]
deps = ["Distributions", "LinearAlgebra", "Printf", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "StatsModels"]
git-tree-sha1 = "273bd1cd30768a2fddfa3fd63bbc746ed7249e5f"
uuid = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
version = "1.9.0"

[[deps.GeoInterface]]
deps = ["Extents"]
git-tree-sha1 = "d4f85701f569584f2cff7ba67a137d03f0cfb7d0"
Expand Down Expand Up @@ -1076,6 +1082,12 @@ git-tree-sha1 = "110897e7db2d6836be22c18bffd9422218ee6284"
uuid = "d3a379c0-f9a3-5b72-a4c0-6bf4d2e8af0f"
version = "2.12.0+0"

[[deps.Loess]]
deps = ["Distances", "LinearAlgebra", "Statistics", "StatsAPI"]
git-tree-sha1 = "a113a8be4c6d0c64e217b472fb6e61c760eb4022"
uuid = "4345ca2d-374a-55d4-8d30-97f9976e7612"
version = "0.6.3"

[[deps.LogExpFunctions]]
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37"
Expand Down Expand Up @@ -1776,6 +1788,12 @@ version = "1.3.1"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"

[[deps.StatsModels]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsAPI", "StatsBase", "StatsFuns", "Tables"]
git-tree-sha1 = "5cf6c4583533ee38639f73b880f35fc85f2941e0"
uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
version = "0.7.3"

[[deps.StringManipulation]]
deps = ["PrecompileTools"]
git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5"
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
JDF = "babc3d20-cd49-4f60-a736-a8f9c08892d3"
Loess = "4345ca2d-374a-55d4-8d30-97f9976e7612"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TidierData = "fe2206b3-d496-4ee9-a338-6a095c4ece80"
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ using TidierData
using Images
using JDF
using CategoricalArrays
using GLM
using Loess

# functions to compare two images using a difference hash
# essentially copied from ImageHashes.jl, but package is out of date
Expand Down
73 changes: 71 additions & 2 deletions test/test_geoms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,75 @@
@test plot_images_equal(t, m)

end


@testset "geom_smooth" begin
t = ggplot(penguins, aes(x = "bill_length_mm", y = "bill_depth_mm")) +
geom_smooth() + geom_point()

model = Loess.loess(penguins.bill_length_mm, penguins.bill_depth_mm; span = .75, degree = 2)
= range(extrema(penguins.bill_length_mm)..., length=200)
= Loess.predict(model, x̂)

m = Makie.plot(
Makie.SpecApi.GridLayout(
Makie.SpecApi.Axis(
plots = [
Makie.PlotSpec(
:Lines,
x̂,
ŷ),
Makie.PlotSpec(
:Scatter,
penguins.bill_length_mm,
penguins.bill_depth_mm
)
]
)
)
)

@test plot_images_equal(t, m)

t = ggplot(penguins, aes(x = "bill_length_mm", y = "bill_depth_mm")) +
geom_smooth(method = "lm") + geom_point()


function add_intercept_column(x::AbstractVector{T}) where {T}
mat = similar(x, float(T), (length(x), 2))
fill!(view(mat, :, 1), 1)
copyto!(view(mat, :, 2), x)
return mat
end

lin_model = GLM.lm(add_intercept_column(penguins.bill_length_mm), penguins.bill_depth_mm)
= range(extrema(penguins.bill_length_mm)..., length=100)
pred = DataFrame(GLM.predict(lin_model, add_intercept_column(x̂); interval = :confidence))

m = Makie.plot(
Makie.SpecApi.GridLayout(
Makie.SpecApi.Axis(
plots = [
Makie.PlotSpec(
:Lines,
x̂,
pred.prediction),
Makie.PlotSpec(
:Scatter,
penguins.bill_length_mm,
penguins.bill_depth_mm
),
Makie.PlotSpec(
:Band,
x̂,
pred.lower,
pred.upper
)
]
)
)
)

@test plot_images_equal(t, m)

end
end

0 comments on commit 7f22e32

Please sign in to comment.