diff --git a/ext/LuxDynamicExpressionsExt.jl b/ext/LuxDynamicExpressionsExt.jl
index 4102f29f4d..980dd426e1 100644
--- a/ext/LuxDynamicExpressionsExt.jl
+++ b/ext/LuxDynamicExpressionsExt.jl
@@ -4,7 +4,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent
 using DynamicExpressions: DynamicExpressions, Node, OperatorEnum, eval_grad_tree_array
 using FastClosures: @closure
 using ForwardDiff: ForwardDiff
-using Lux: Lux, NAME_TYPE, Chain, Parallel, WrappedFunction, DynamicExpressionsLayer
+using Lux: Lux, NAME_TYPE, Chain, Parallel, DynamicExpressionsLayer
 using LuxDeviceUtils: LuxCPUDevice
 
 const CRC = ChainRulesCore
diff --git a/ext/LuxFluxExt.jl b/ext/LuxFluxExt.jl
index dbd89e0d6c..9bbe1c0bfd 100644
--- a/ext/LuxFluxExt.jl
+++ b/ext/LuxFluxExt.jl
@@ -96,8 +96,7 @@ function Lux.__from_flux_adaptor(l::Flux.Conv; preserve_ps_st::Bool=false, kwarg
     groups = l.groups
     pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
     if preserve_ps_st
-        _bias = l.bias isa Bool ? nothing :
-                reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
+        _bias = l.bias isa Bool ? nothing : vec(copy(l.bias))
         return Lux.Conv(k, in_chs * groups => out_chs, l.σ; l.stride, pad, l.dilation,
             groups, init_weight=Returns(Lux._maybe_flip_conv_weight(l.weight)),
             init_bias=Returns(_bias), use_bias=!(l.bias isa Bool))
@@ -114,8 +113,7 @@ function Lux.__from_flux_adaptor(
     groups = l.groups
     pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
     if preserve_ps_st
-        _bias = l.bias isa Bool ? nothing :
-                reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
+        _bias = l.bias isa Bool ? nothing : vec(copy(l.bias))
         return Lux.ConvTranspose(k, in_chs * groups => out_chs, l.σ; l.stride, pad,
             l.dilation, groups, use_bias=!(l.bias isa Bool),
             init_weight=Returns(Lux._maybe_flip_conv_weight(l.weight)),
@@ -131,8 +129,7 @@ function Lux.__from_flux_adaptor(l::Flux.CrossCor; preserve_ps_st::Bool=false, k
     in_chs, out_chs = size(l.weight)[(end - 1):end]
     pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
     if preserve_ps_st
-        _bias = l.bias isa Bool ? nothing :
-                reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
+        _bias = l.bias isa Bool ? nothing : vec(copy(l.bias))
         return Lux.CrossCor(k, in_chs => out_chs, l.σ; l.stride, pad, l.dilation,
             init_weight=Returns(copy(l.weight)), init_bias=Returns(_bias),
             use_bias=!(l.bias isa Bool))
diff --git a/test/layers/conv_tests.jl b/test/layers/conv_tests.jl
index 60b1386b43..5462145e8e 100644
--- a/test/layers/conv_tests.jl
+++ b/test/layers/conv_tests.jl
@@ -187,7 +187,7 @@ end
         display(layer)
         ps, st = Lux.setup(rng, layer)
         @test ps.weight isa aType{Float64, 4}
-        @test ps.bias isa aType{Float16, 4}
+        @test ps.bias isa aType{Float16, 1}
     end
 
     @testset "Depthwise Conv" begin
@@ -447,7 +447,7 @@ end
         display(layer)
         ps, st = Lux.setup(rng, layer)
         @test ps.weight isa aType{Float64, 4}
-        @test ps.bias isa aType{Float16, 4}
+        @test ps.bias isa aType{Float16, 1}
     end
 
     @testset "CrossCor SamePad kernelsize $k" for k in (
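
Note (not part of the diff): with this change, Flux-to-Lux conversion with `preserve_ps_st=true` stores conv biases as plain vectors via `vec(copy(l.bias))` instead of reshaping them to `(1, ..., 1, out_chs, 1)` arrays, and the tests now expect a 1-d bias. A minimal sketch of the observable effect, assuming the public `FromFluxAdaptor` entry point dispatches to `Lux.__from_flux_adaptor` and that `Adapt`, `Flux`, `Lux`, and `Random` are available; exact keyword names are per the Lux docs at the time of this PR:

```julia
using Adapt, Flux, Lux, Random

# A Flux conv layer with a trainable bias (3×3 kernel, 2 => 4 channels).
fconv = Flux.Conv((3, 3), 2 => 4)

# Convert to Lux, preserving Flux's parameters and states.
lconv = adapt(Lux.FromFluxAdaptor(; preserve_ps_st=true), fconv)
ps, st = Lux.setup(Random.default_rng(), lconv)

ps.bias isa AbstractVector  # true; previously a (1, 1, 4, 1) array
size(ps.bias) == (4,)       # one entry per output channel

# The converted layer still runs a normal forward pass.
x = randn(Float32, 5, 5, 2, 1)
y, _ = lconv(x, ps, st)
size(y)  # (3, 3, 4, 1) with the default stride and padding
```

The vector form matches what a natively constructed `Lux.Conv`'s `init_bias` produces, so converted and native layers end up with identically shaped parameter trees.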