Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convolution を実装 #53

Merged
merged 10 commits into from
Sep 16, 2020
Merged

Convolution を実装 #53

merged 10 commits into from
Sep 16, 2020

Conversation

terry-u16
Copy link
Contributor

@terry-u16 terry-u16 commented Sep 15, 2020

Convolution を実装しました。 #11

処理の内容はほぼC++版に沿っています。
Convolution.cs の実装に加え、以下の4ファイルをInternal/Mathフォルダ下に追加しました。

  • Butterfly.cs (NTTのメイン処理)
  • PrimitiveRoot.cs (素数 m の最小の原始根を求める)
  • PowMod.cs (x^n mod m を求める)
  • BSF.cs (整数 n に対し、2進数表記における末尾の0の個数を求める)

mod 998244353 のケースについてのみではありますが、以下で動作確認を行いました。

また、簡単にではありますが、以下のようなコードでランダムな入力・数パターンのmodに対しナイーブな解法との整合性を確認しています。

テストコード
using System;
using System.Linq;
using System.Runtime.CompilerServices;
using AtCoder;
using Xunit;

public class ConvolutionTest
{
    [Theory]
    [InlineData(1, 1, 42)]
    [InlineData(4, 6, 42)]
    [InlineData(64, 64, 123)]
    [InlineData(100, 100, 42)]
    [InlineData(1, 10000, 12345)]
    [InlineData(4876, 12878, 26861194601)]
    [InlineData(7314, 3890, 5890635110)]
    public void ConvolutionMod998244353Test(int lengthA, int lengthB, ulong seed) => ConvolutionMod<Mod998244353>(lengthA, lengthB, seed);

    [Theory]
    [InlineData(1, 1, 42)]
    [InlineData(4, 6, 42)]
    [InlineData(64, 64, 123)]
    [InlineData(100, 100, 42)]
    [InlineData(1, 10000, 12345)]
    [InlineData(4876, 12878, 26861194601)]
    [InlineData(7314, 3890, 5890635110)]
    public void ConvolutionMod163577857Test(int lengthA, int lengthB, ulong seed) => ConvolutionMod<Mod163577857>(lengthA, lengthB, seed);

    [Theory]
    [InlineData(1, 1, 42)]
    [InlineData(4, 6, 42)]
    [InlineData(64, 64, 123)]
    [InlineData(100, 100, 42)]
    [InlineData(1, 10000, 12345)]
    [InlineData(4876, 12878, 26861194601)]
    [InlineData(7314, 3890, 5890635110)]
    public void ConvolutionMod469762049Test(int lengthA, int lengthB, ulong seed) => ConvolutionMod<Mod469762049>(lengthA, lengthB, seed);

    private void ConvolutionMod<T>(int lengthA, int lengthB, ulong seed) where T : struct, IStaticMod
    {
        var rand = new XorShift(seed);
        var a = new StaticModInt<T>[lengthA];
        var b = new StaticModInt<T>[lengthB];
        var aRaw = new ulong[lengthA];
        var bRaw = new ulong[lengthB];

        for (int i = 0; i < a.Length; i++)
        {
            aRaw[i] = rand.Next();
            a[i] = StaticModInt<T>.Raw((int)(aRaw[i] % default(T).Mod));
        }

        for (int i = 0; i < b.Length; i++)
        {
            bRaw[i] = rand.Next();
            b[i] = StaticModInt<T>.Raw((int)(bRaw[i] % default(T).Mod));
        }

        var expected = new StaticModInt<T>[a.Length + b.Length - 1];
        for (int i = 0; i < a.Length; i++)
        {
            for (int j = 0; j < b.Length; j++)
            {
                expected[i + j] += a[i] * b[j];
            }
        }

        // 各種オーバーロードについてテスト
        var actualModInt = AtCoder.Math.Convolution(a, b);
        var actualModIntSpan = AtCoder.Math.Convolution((ReadOnlySpan<StaticModInt<T>>)a, b);
        var actualInt = AtCoder.Math.Convolution<T>(a.Select(ai => ai.Value).ToArray(), b.Select(bi => bi.Value).ToArray());
        var actualUInt = AtCoder.Math.Convolution<T>(a.Select(ai => (uint)ai.Value).ToArray(), b.Select(bi => (uint)bi.Value).ToArray());
        var actualLong = AtCoder.Math.Convolution<T>(a.Select(ai => (long)ai.Value).ToArray(), b.Select(bi => (long)bi.Value).ToArray());
        var actualULong = AtCoder.Math.Convolution<T>(aRaw, bRaw);

        Assert.Equal(expected, actualModInt);
        Assert.Equal(expected, actualModIntSpan.ToArray());
        Assert.Equal(expected.Select(ei => ei.Value), actualInt);
        Assert.Equal(expected.Select(ei => (uint)ei.Value), actualUInt);
        Assert.Equal(expected.Select(ei => (long)ei.Value), actualLong);
        Assert.Equal(expected.Select(ei => (ulong)ei.Value), actualULong);
    }

    [Theory]
    [InlineData(0, 0)]
    [InlineData(0, 1)]
    [InlineData(4, 0)]
    [InlineData(0, 123456)]
    public void ConvolutionEmptyTest(int lengthA, int lengthB)
    {
        var aInt = new int[lengthA];
        var bInt = new int[lengthB];
        var aUInt = new uint[lengthA];
        var bUInt = new uint[lengthB];
        var aLong = new long[lengthA];
        var bLong = new long[lengthB];
        var aULong = new ulong[lengthA];
        var bULong = new ulong[lengthB];
        var aMod = new StaticModInt<Mod998244353>[lengthA];
        var bMod = new StaticModInt<Mod998244353>[lengthB];

        var actualInt = AtCoder.Math.Convolution(aInt, bInt);
        var actualUInt = AtCoder.Math.Convolution(aUInt, bUInt);
        var actualLong = AtCoder.Math.Convolution(aLong, bLong);
        var actualULong = AtCoder.Math.Convolution(aULong, bULong);
        var actualModInt = AtCoder.Math.Convolution(aMod, bMod);
        var actualModIntSpan = AtCoder.Math.Convolution((ReadOnlySpan<StaticModInt<Mod998244353>>)aMod, bMod);

        Assert.Equal(Array.Empty<int>(), actualInt);
        Assert.Equal(Array.Empty<uint>(), actualUInt);
        Assert.Equal(Array.Empty<long>(), actualLong);
        Assert.Equal(Array.Empty<ulong>(), actualULong);
        Assert.Equal(Array.Empty<StaticModInt<Mod998244353>>(), actualModInt);
        Assert.Equal(Array.Empty<StaticModInt<Mod998244353>>(), actualModIntSpan.ToArray());
    }

    [Theory]
    [InlineData(1, 1, 42)]
    [InlineData(4, 6, 42)]
    [InlineData(64, 64, 123)]
    [InlineData(100, 100, 42)]
    [InlineData(1, 10000, 12345)]
    [InlineData(4876, 12878, 26861194601)]
    [InlineData(7314, 3890, 5890635110)]
    public void ConvolutionLLTest(int lengthA, int lengthB, ulong seed)
    {
        var rand = new XorShift(seed);
        var a = new long[lengthA];
        var b = new long[lengthB];

        for (int i = 0; i < a.Length; i++)
        {
            a[i] = rand.Next(1_000_000) - 500_000;
        }

        for (int i = 0; i < b.Length; i++)
        {
            b[i] = rand.Next(1_000_000) - 500_000;
        }

        var expected = new long[a.Length + b.Length - 1];
        for (int i = 0; i < a.Length; i++)
        {
            for (int j = 0; j < b.Length; j++)
            {
                expected[i + j] += a[i] * b[j];
            }
        }

        var actual = AtCoder.Math.ConvolutionLong(a, b);

        Assert.Equal(expected, actual);
    }

    /// <summary>
    /// 163577857 = 39×2^22 + 1
    /// </summary>
    readonly struct Mod163577857 : IStaticMod
    {
        public uint Mod => 163577857;
        public bool IsPrime => true;
    }

    /// <summary>
    /// 469762049 = 7×2^26 + 1
    /// </summary>
    readonly struct Mod469762049 : IStaticMod
    {
        public uint Mod => 469762049;
        public bool IsPrime => true;
    }
}

public class XorShift
{
    ulong _x;

    public XorShift() : this((ulong)DateTime.Now.Ticks) { }

    public XorShift(ulong seed)
    {
        _x = seed;
    }

    /// <summary>
    /// [0, (2^64)-1)の乱数を生成します。
    /// </summary>
    /// <returns></returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public ulong Next()
    {
        _x = _x ^ (_x << 13);
        _x = _x ^ (_x >> 7);
        _x = _x ^ (_x << 17);
        return _x;
    }

    /// <summary>
    /// [0, <c>exclusiveMax</c>)の乱数を生成します。
    /// </summary>
    /// <param name="exclusiveMax"></param>
    /// <returns></returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public int Next(int exclusiveMax) => (int)(Next() % (uint)exclusiveMax);

    /// <summary>
    /// [0.0, 1.0)の乱数を生成します。
    /// </summary>
    /// <returns></returns>
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public double NextDouble()
    {
        const ulong max = 1UL << 50;
        const ulong mask = max - 1;
        return (double)(Next() & mask) / max;
    }
}

よろしくお願いします。

@key-moon
Copy link
Contributor

key-moon commented Sep 15, 2020

ありがとうございます!

  • BSF.cs に関しては、InternalBit 内の方が的確かと思いましたがいかがでしょうか?
  • CalcurateSumIE / CalcurateSumE に関して、static フィールド初期化子によるキャッシュをするとより簡潔かと思いました。が、全展開で Expand された時にオーバーヘッドになりうることを思えばこの実装でも良いかもしれません。
static class Test
{
    public static int StaticField = GenerateInt();
    static int GenerateInt() { Console.WriteLine("Init"); return 1; };
    static Test() { }
}

Comment on lines 18 to 26
return m switch
{
2 => 1,
167772161 => 3,
469762049 => 3,
754974721 => 11,
998244353 => 3,
_ => Calculate(m)
};
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

private static Dictionary<int, int> s_cache = new Dictionary<int, int>
{
    { 2, 1 },
    {167772161, 3},
    {469762049, 3},
    {754974721, 11},
    {998244353, 3},
};

を用意して、

if(s_cache.TryGetValue(m, out int r)) return r;
s_cache[m] = Calculate(m);

とキャッシュしても良いかなと思います。

@terry-u16
Copy link
Contributor Author

長いコードにも関わらず、丁寧にご確認頂きありがとうございます!

  • BSF.cs に関しては、InternalBit 内の方が的確かと思いましたがいかがでしょうか?

InternalBit の存在を見落としていました……。確かにそちらの方が良さそうです。

  • CalcurateSumIE / CalcurateSumE に関して、static フィールド初期化子によるキャッシュをするとより簡潔かと思いました。が、全展開で Expand された時にオーバーヘッドになりうることを思えばこの実装でも良いかもしれません。

私も静的フィールドへの直書きと迷って、一応C++版に寄せた書き方にしておくか……程度の動機だったのですが、やっぱりそちらの方が簡潔そうですね。

if(s_cache.TryGetValue(m, out int r)) return r;
s_cache[m] = Calculate(m);

とキャッシュしても良いかなと思います。

constexprが付けられない以上2回呼ばれてしまうので、確かにそちらの方が良さそうです。

修正まで今しばらくお待ち頂ければと思います。よろしくお願いします。

@terry-u16
Copy link
Contributor Author

上記3件を反映しました。改めて愚直解と一致することを確認済みです。
よろしくお願いします。

@key-moon
Copy link
Contributor

確認しました!ありがとうございます。

@key-moon
Copy link
Contributor

一つだけ気になったのですが、 PowMod は Math で実装されているものを使用しないのですか…?(昨日指摘できずに申し訳ありません。)

@terry-u16
Copy link
Contributor Author

C++版のpow_mod_constexprinternal_math.hpp内にあるのを見てこれはInternalだなと思い込んでいましたが、C++だからconstexpr版を追加で用意していただけで、通常版は通常版で普通にありましたね……お恥ずかしい……。
重複分は削除しましたので、よろしくお願いします。

@key-moon
Copy link
Contributor

なるほど、Burret 演算を使ってないのは constexpr を適用できるようにするための名残だったんですね。(昨日見たときは型でも違うのかと思ってスルーしてしまってました。)ありがとうございます🙇

@key-moon key-moon merged commit 595149f into kzrnm:master Sep 16, 2020
@terry-u16 terry-u16 deleted the feature/convolution branch September 16, 2020 10:16
kzrnm added a commit that referenced this pull request Feb 27, 2022
Take consistent with Queue<T>.Enqueue

Fix #53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants