Skip to content

Commit

Permalink
fix: store the bias as a vector
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 23, 2024
1 parent 2364a46 commit c544375
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ end
function initialparameters(rng::AbstractRNG, b::Bilinear{use_bias}) where {use_bias}
if use_bias
return (weight=b.init_weight(rng, b.out_dims, b.in1_dims, b.in2_dims),
bias=b.init_bias(rng, b.out_dims, 1)) # TODO: In v1.0 make it a vector
bias=b.init_bias(rng, b.out_dims))
else
return (weight=b.init_weight(rng, b.out_dims, b.in1_dims, b.in2_dims),)
end
Expand All @@ -524,7 +524,7 @@ function (b::Bilinear{use_bias})((x, y)::Tuple{<:AbstractVecOrMat, <:AbstractVec
Wy = reshape(reshape(ps.weight, (:, d_y)) * y, (d_z, d_x, :))
Wyx = reshape(batched_mul(Wy, reshape(x, (d_x, 1, :))), (d_z, :))

return bias_activation!!(b.activation, Wyx, _vec(_getproperty(ps, Val(:bias)))), st
return bias_activation!!(b.activation, Wyx, _getproperty(ps, Val(:bias))), st
end

function (b::Bilinear)((x, y)::Tuple{<:AbstractArray, <:AbstractArray}, ps, st::NamedTuple)
Expand Down
12 changes: 6 additions & 6 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ function initialparameters(rng::AbstractRNG, c::Conv{N, use_bias}) where {N, use
weight = _convfilter(
rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight, groups=c.groups)
!use_bias && return (; weight)
return (; weight, bias=c.init_bias(rng, ntuple(_ -> 1, N)..., c.out_chs, 1)) # TODO: flatten in v1
return (; weight, bias=c.init_bias(rng, c.out_chs))
end

function parameterlength(c::Conv{N, use_bias}) where {N, use_bias}
Expand All @@ -116,7 +116,7 @@ end
cdims = DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)
return (
fused_conv_bias_activation(
c.activation, ps.weight, y, _vec(_getproperty(ps, Val(:bias))), cdims),
c.activation, ps.weight, y, _getproperty(ps, Val(:bias)), cdims),
st)
end

Expand Down Expand Up @@ -594,7 +594,7 @@ end
function initialparameters(rng::AbstractRNG, c::CrossCor{N, use_bias}) where {N, use_bias}
weight = _convfilter(rng, c.kernel_size, c.in_chs => c.out_chs; init=c.init_weight)
!use_bias && return (; weight)
return (; weight, bias=c.init_bias(rng, ntuple(_ -> 1, N)..., c.out_chs, 1)) # TODO: flatten in v1
return (; weight, bias=c.init_bias(rng, c.out_chs))
end

function parameterlength(c::CrossCor{N, use_bias}) where {N, use_bias}
Expand All @@ -607,7 +607,7 @@ end
DenseConvDims(y, ps.weight; c.stride, padding=c.pad, c.dilation); F=true)
return (
fused_conv_bias_activation(
c.activation, ps.weight, y, _vec(_getproperty(ps, Val(:bias))), cdims),
c.activation, ps.weight, y, _getproperty(ps, Val(:bias)), cdims),
st)
end

Expand Down Expand Up @@ -720,7 +720,7 @@ function initialparameters(
weight = _convfilter(
rng, c.kernel_size, c.out_chs => c.in_chs; init=c.init_weight, c.groups)
!use_bias && return (; weight)
return (; weight, bias=c.init_bias(rng, ntuple(_ -> 1, N)..., c.out_chs, 1)) # TODO: flatten in v1
return (; weight, bias=c.init_bias(rng, c.out_chs))
end

function parameterlength(c::ConvTranspose{N, use_bias}) where {N, use_bias}
Expand All @@ -734,7 +734,7 @@ end
y, ps.weight; c.stride, padding=c.pad, c.dilation, c.groups)
return (
bias_activation!!(c.activation, _conv_transpose(y, ps.weight, cdims),
_vec(_getproperty(ps, Val(:bias)))),
_getproperty(ps, Val(:bias))),
st)
end

Expand Down
9 changes: 3 additions & 6 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,7 @@ function initialparameters(
end

function initialstates(rng::AbstractRNG, ::RNNCell)
# FIXME(@avik-pal): Take PRNGs seriously
randn(rng, 1)
randn(rng, 1) # FIXME(@avik-pal): Take PRNGs seriously
return (rng=replicate(rng),)
end

Expand Down Expand Up @@ -423,8 +422,7 @@ function initialparameters(rng::AbstractRNG,
end

function initialstates(rng::AbstractRNG, ::LSTMCell)
# FIXME(@avik-pal): Take PRNGs seriously
randn(rng, 1)
randn(rng, 1) # FIXME(@avik-pal): Take PRNGs seriously
return (rng=replicate(rng),)
end

Expand Down Expand Up @@ -592,8 +590,7 @@ function initialparameters(
end

function initialstates(rng::AbstractRNG, ::GRUCell)
# FIXME(@avik-pal): Take PRNGs seriously
randn(rng, 1)
randn(rng, 1) # FIXME(@avik-pal): Take PRNGs seriously
return (rng=replicate(rng),)
end

Expand Down

0 comments on commit c544375

Please sign in to comment.