Skip to content

Commit

Permalink
Merge pull request #21 from invenia/wct/relax-tests
Browse files Browse the repository at this point in the history
Relax test assumptions
  • Loading branch information
willtebbutt authored Aug 27, 2020
2 parents 1c87cf7 + 3eab6af commit 99cfe11
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Models"
uuid = "e6388cff-ecff-480c-9b53-83211bf7812a"
authors = ["Invenia Technical Computing Corporation"]
version = "0.2.2"
version = "0.2.3"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
10 changes: 5 additions & 5 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ testing downstream dependencies, and [`test_interface`](@ref) for testing the Mo
been correctly implemented.
"""
module TestUtils
using Distributions: Normal, MultivariateNormal
using Distributions
using Models
using NamedDims
using StatsBase
Expand Down Expand Up @@ -75,7 +75,7 @@ function FakeTemplate{DistributionEstimate, SingleOutput}()
FakeTemplate{DistributionEstimate, SingleOutput}() do num_variates, inputs
@assert(num_variates == 1, "$num_variates != 1")
inputs = NamedDimsArray{(:features, :observations)}(inputs)
return Normal.(zeros(size(inputs, :observations)))
return NoncentralT.(3.0, zeros(size(inputs, :observations)))
end
end

Expand All @@ -88,7 +88,7 @@ distribution (with zero-vector mean and identity covariance matrix) for each obs
function FakeTemplate{DistributionEstimate, MultiOutput}()
FakeTemplate{DistributionEstimate, MultiOutput}() do num_variates, inputs
std_dev = ones(num_variates)
return [MultivariateNormal(std_dev) for _ in 1:size(inputs, 2)]
return [Product(Normal.(0, std_dev)) for _ in 1:size(inputs, 2)]
end
end

Expand Down Expand Up @@ -158,7 +158,7 @@ function test_interface(
inputs=rand(5, 5), outputs=rand(1, 5),
)
predictions = test_common(template, inputs, outputs)
@test predictions isa Vector{<:Normal{<:Real}}
@test predictions isa AbstractVector{<:ContinuousUnivariateDistribution}
@test length(predictions) == size(outputs, 2)
@test all(length.(predictions) .== size(outputs, 1))
end
Expand All @@ -168,7 +168,7 @@ function test_interface(
inputs=rand(5, 5), outputs=rand(3, 5)
)
predictions = test_common(template, inputs, outputs)
@test predictions isa Vector{<:MultivariateNormal{<:Real}}
@test predictions isa AbstractVector{<:ContinuousMultivariateDistribution}
@test length(predictions) == size(outputs, 2)
@test all(length.(predictions) .== size(outputs, 1))
end
Expand Down

2 comments on commit 99cfe11

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/20385

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.3 -m "<description of version>" 99cfe110c430c41c88461fb7a80fbd9e9041bc04
git push origin v0.2.3

Please sign in to comment.