diff --git a/src/implementations/BigFloat.jl b/src/implementations/BigFloat.jl index f72ad5b0..0e9034fc 100644 --- a/src/implementations/BigFloat.jl +++ b/src/implementations/BigFloat.jl @@ -19,6 +19,24 @@ end const _MPFRRoundingMode = Base.MPFR.MPFRRoundingMode +# copy + +promote_operation(::typeof(copy), ::Type{BigFloat}) = BigFloat + +function operate_to!(out::BigFloat, ::typeof(copy), in::BigFloat) + ccall( + (:mpfr_set, :libmpfr), + Int32, + (Ref{BigFloat}, Ref{BigFloat}, _MPFRRoundingMode), + out, + in, + Base.MPFR.ROUNDING_MODE[], + ) + return out +end + +operate!(::typeof(copy), x::BigFloat) = x + # zero promote_operation(::typeof(zero), ::Type{BigFloat}) = BigFloat @@ -371,17 +389,7 @@ function buffered_operate_to!( x::AbstractVector{F}, y::AbstractVector{F}, ) where {F<:BigFloat} - local set! = function (out::F, in::F) - ccall( - (:mpfr_set, :libmpfr), - Int32, - (Ref{BigFloat}, Ref{BigFloat}, Base.MPFR.MPFRRoundingMode), - out, - in, - Base.MPFR.ROUNDING_MODE[], - ) - return nothing - end + set! = (o, i) -> operate_to!(o, copy, i) local swap! = function (x::BigFloat, y::BigFloat) ccall((:mpfr_swap, :libmpfr), Cvoid, (Ref{BigFloat}, Ref{BigFloat}), x, y) diff --git a/src/implementations/BigInt.jl b/src/implementations/BigInt.jl index 36aa8c45..9302174c 100644 --- a/src/implementations/BigInt.jl +++ b/src/implementations/BigInt.jl @@ -13,6 +13,17 @@ mutability(::Type{BigInt}) = IsMutable() # https://github.com/JuliaLang/julia/blob/7d41d1eb610cad490cbaece8887f9bbd2a775021/base/gmp.jl#L772 mutable_copy(x::BigInt) = Base.GMP.MPZ.set(x) +# copy + +promote_operation(::typeof(copy), ::Type{BigInt}) = BigInt + +function operate_to!(out::BigInt, ::typeof(copy), in::BigInt) + Base.GMP.MPZ.set!(out, in) + return out +end + +operate!(::typeof(copy), x::BigInt) = x + # zero promote_operation(::typeof(zero), ::Type{BigInt}) = BigInt diff --git a/src/implementations/Rational.jl b/src/implementations/Rational.jl index 2c293138..91c61477 100644 --- a/src/implementations/Rational.jl +++ b/src/implementations/Rational.jl @@ -13,6 +13,18 @@ mutability(::Type{Rational{T}}) where {T} = mutability(T) mutable_copy(x::Rational) = Rational(mutable_copy(x.num), mutable_copy(x.den)) +# copy + +promote_operation(::typeof(copy), ::Type{Q}) where {Q<:Rational} = Q + +function operate_to!(out::Q, ::typeof(copy), in::Q) where {Q<:Rational} + operate_to!(out.num, copy, in.num) + operate_to!(out.den, copy, in.den) + return out +end + +operate!(::typeof(copy), x::Rational) = x + # zero promote_operation(::typeof(zero), ::Type{Rational{T}}) where {T} = Rational{T} diff --git a/src/interface.jl b/src/interface.jl index 5b3452ee..c7c19dfa 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -218,7 +218,10 @@ operate(::typeof(convert), ::Type{T}, x) where {T} = convert(T, x) operate(::typeof(convert), ::Type{T}, x::T) where {T} = copy_if_mutable(x) -function operate(::Union{typeof(+),typeof(*),typeof(gcd),typeof(lcm)}, x) +function operate( + ::Union{typeof(copy),typeof(+),typeof(*),typeof(gcd),typeof(lcm)}, + x, +) return copy_if_mutable(x) end diff --git a/test/copy.jl b/test/copy.jl new file mode 100644 index 00000000..531b5c73 --- /dev/null +++ b/test/copy.jl @@ -0,0 +1,37 @@ +# Copyright (c) 2023 MutableArithmetics.jl contributors +# +# This Source Code Form is subject to the terms of the Mozilla Public License, +# v.2.0. If a copy of the MPL was not distributed with this file, You can obtain +# one at http://mozilla.org/MPL/2.0/. + +@testset "copy: $T" for T in ( + Float64, + BigFloat, + Int, + BigInt, + Rational{Int}, + Rational{BigInt}, +) + @test MA.operate!!(copy, T(2)) == 2 + @test MA.operate_to!!(T(3), copy, T(2)) == 2 + if MA.mutability(T, copy, T) == MA.IsMutable() + @testset "mutable" begin + @testset "correctness" begin + x = T(2) + y = T(3) + @test MA.operate!(copy, x) === x == 2 + @test MA.operate_to!(y, copy, x) === y == 2 + end + @testset "alloc" begin + f = let x = T(2) + () -> MA.operate!(copy, x) + end + g = let x = T(2), y = T(3) + () -> MA.operate_to!(y, copy, x) + end + alloc_test(f, 0) + alloc_test(g, 0) + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 2ffd7205..9a9bb53c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,8 @@ include("interface.jl") include("range.jl") +include("copy.jl") + @testset "Int" begin include("int.jl") end