From cd71ccb2e0a623aa61afef20b64c18969078fe80 Mon Sep 17 00:00:00 2001 From: Jack Dermody Date: Fri, 26 Jul 2024 07:50:42 +1000 Subject: [PATCH] added exponential distribution and vector graph - also refactored numerical analysis --- BrightData.Cuda/CudaProvider.cs | 6 +- BrightData.UnitTests/AnalysisTests.cs | 8 +- BrightData.UnitTests/SpanTests.cs | 12 + BrightData.UnitTests/VectorSetTests.cs | 37 ++- .../Analysis/CastToDoubleNumericAnalysis.cs | 22 +- .../Analysis/LinearBinnedFrequencyAnalysis.cs | 21 +- BrightData/Analysis/NumericAnalyser.cs | 138 ++++++------ .../OnlineStandardDeviationAnalysis.cs | 48 ++++ BrightData/Analysis/StaticAnalysers.cs | 7 +- BrightData/BrightData.xml | 194 +++++++++++++++- .../Distribution/ExponentialDistribution.cs | 20 ++ BrightData/ExtensionMethods.Analysis.cs | 22 +- BrightData/ExtensionMethods.DataTable.cs | 10 +- BrightData/ExtensionMethods.Distributions.cs | 8 + BrightData/ExtensionMethods.Span.cs | 35 +++ BrightData/ExtensionMethods.cs | 7 + BrightData/Interfaces.VectorIndexing.cs | 22 +- BrightData/Interfaces.cs | 24 +- .../Helper/IndexedFixedSizeGraphNode.cs | 212 ++++++++++++++++++ .../VectorIndexing/Helper/VectorGraph.cs | 123 ++++++++++ .../IndexStrategy/FlatVectorIndex.cs | 26 ++- .../IndexStrategy/RandomProjectionIndex.cs | 12 +- .../Storage/InMemoryVectorStorage.cs | 15 ++ .../LinearAlgebra/VectorIndexing/VectorSet.cs | 16 +- 24 files changed, 915 insertions(+), 130 deletions(-) create mode 100644 BrightData/Analysis/OnlineStandardDeviationAnalysis.cs create mode 100644 BrightData/Distribution/ExponentialDistribution.cs create mode 100644 BrightData/LinearAlgebra/VectorIndexing/Helper/IndexedFixedSizeGraphNode.cs create mode 100644 BrightData/LinearAlgebra/VectorIndexing/Helper/VectorGraph.cs diff --git a/BrightData.Cuda/CudaProvider.cs b/BrightData.Cuda/CudaProvider.cs index 5976c273..3022146a 100644 --- a/BrightData.Cuda/CudaProvider.cs +++ b/BrightData.Cuda/CudaProvider.cs @@ -662,7 +662,7 @@ internal float FindStdDev(IDeviceMemoryPtr a, uint size, float mean, uint ai = 1 if (ptr != a) ptr.Release(); - return Convert.ToSingle(System.Math.Sqrt(total.Sum() / inputSize)); + return MathF.Sqrt(total.Sum() / inputSize); } return 0f; } @@ -701,7 +701,7 @@ internal float EuclideanDistance(IDeviceMemoryPtr a, IDeviceMemoryPtr b, uint si { var ret = Allocate(size, stream); Invoke(_euclideanDistance, stream, size, a.DevicePointer, b.DevicePointer, ret.DevicePointer, size, ai, bi, ci); - return Convert.ToSingle(System.Math.Sqrt(SumValues(ret, size))); + return MathF.Sqrt(SumValues(ret, size)); } internal float ManhattanDistance(IDeviceMemoryPtr a, IDeviceMemoryPtr b, uint size, uint ai = 1, uint bi = 1, uint ci = 1, CuStream* stream = null) @@ -730,7 +730,7 @@ internal float CosineDistance(IDeviceMemoryPtr a, IDeviceMemoryPtr b, uint size, else if (bb.Equals(0f)) return 0.0f; else - return 1f - (ab / (float)System.Math.Sqrt(aa) / (float)System.Math.Sqrt(bb)); + return 1f - (ab / MathF.Sqrt(aa) / MathF.Sqrt(bb)); } finally { buffer.Release(); diff --git a/BrightData.UnitTests/AnalysisTests.cs b/BrightData.UnitTests/AnalysisTests.cs index e7c423de..f861607e 100644 --- a/BrightData.UnitTests/AnalysisTests.cs +++ b/BrightData.UnitTests/AnalysisTests.cs @@ -39,25 +39,25 @@ public void DateAnalysisNoMostFrequent() [Fact] public void IntegerAnalysis() { - var analysis = new[] { 1, 2, 3 }.Analyze(); + var analysis = new[] { 1, 2, 3 }.AnalyzeAsDoubles(); analysis.Min.Should().Be(1); analysis.Max.Should().Be(3); analysis.Median.Should().Be(2); analysis.NumDistinct.Should().Be(3); - analysis.Total.Should().Be(3); + analysis.Count.Should().Be(3); analysis.SampleStdDev.Should().Be(1); } [Fact] public void IntegerAnalysis2() { - var analysis = new[] { 1, 2, 2, 3 }.Analyze(); + var analysis = new[] { 1, 2, 2, 3 }.AnalyzeAsDoubles(); analysis.Min.Should().Be(1); analysis.Max.Should().Be(3); analysis.Median.Should().Be(2); analysis.NumDistinct.Should().Be(3); analysis.Mode.Should().Be(2); - analysis.Total.Should().Be(4); + analysis.Count.Should().Be(4); } [Fact] diff --git a/BrightData.UnitTests/SpanTests.cs b/BrightData.UnitTests/SpanTests.cs index 8dad21db..381ce691 100644 --- a/BrightData.UnitTests/SpanTests.cs +++ b/BrightData.UnitTests/SpanTests.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using BrightData.UnitTests.Helper; using FluentAssertions; using Xunit; @@ -54,5 +55,16 @@ public void SearchSpan() } resultCount.Should().Be(1); } + + [Fact] + public void GetRankedIndices() + { + Span span = stackalloc float[32]; + for (var i = 0; i < 32; i++) + span[i] = 16 - i; + var indices = span.GetRankedIndices(); + indices.Length.Should().Be(32); + indices.Should().ContainInConsecutiveOrder(32.AsRange().Select(i => 31 - i)); + } } } diff --git a/BrightData.UnitTests/VectorSetTests.cs b/BrightData.UnitTests/VectorSetTests.cs index b69ebd69..80dc3ca8 100644 --- a/BrightData.UnitTests/VectorSetTests.cs +++ b/BrightData.UnitTests/VectorSetTests.cs @@ -1,14 +1,24 @@ -using BrightData.UnitTests.Helper; +using System; +using BrightData.UnitTests.Helper; using System.Linq; using BrightData.LinearAlgebra.VectorIndexing; +using BrightData.LinearAlgebra.VectorIndexing.Helper; using BrightData.Types; using FluentAssertions; using Xunit; +using Xunit.Abstractions; namespace BrightData.UnitTests { public class VectorSetTests : UnitTestBase { + readonly ITestOutputHelper _testOutputHelper; + + public VectorSetTests(ITestOutputHelper testOutputHelper) + { + _testOutputHelper = testOutputHelper; + } + [Fact] public void Average() { @@ -53,5 +63,30 @@ public void Closest() score[0].Should().Be(2); score[1].Should().Be(1); } + + [Fact] + public void TestVectorGraphNode() + { + var node = new IndexedFixedSizeGraphNode(1); + node.Index.Should().Be(1); + node.NeighbourIndices.Length.Should().Be(0); + + node.TryAddNeighbour(2, 0.9f); + node.NeighbourIndices[0].Should().Be(2); + node.NeighbourWeights[0].Should().Be(0.9f); + + node.TryAddNeighbour(3, 0.8f); + node.NeighbourIndices[0].Should().Be(3); + node.NeighbourWeights[0].Should().Be(0.8f); + + for(var i = 4U; i <= 10; i++) + node.TryAddNeighbour(i, 1f - 0.1f * i); + node.NeighbourIndices.Length.Should().Be(8); + node.NeighbourIndices[0].Should().Be(10); + node.NeighbourIndices[1].Should().Be(9); + + node.TryAddNeighbour(20, 0.5f).Should().BeTrue(); + node.TryAddNeighbour(20, 0.5f).Should().BeFalse(); + } } } diff --git a/BrightData/Analysis/CastToDoubleNumericAnalysis.cs b/BrightData/Analysis/CastToDoubleNumericAnalysis.cs index 3431898a..236b219a 100644 --- a/BrightData/Analysis/CastToDoubleNumericAnalysis.cs +++ b/BrightData/Analysis/CastToDoubleNumericAnalysis.cs @@ -1,4 +1,5 @@ using System; +using System.Numerics; using BrightData.Converter; using BrightData.Types; @@ -8,12 +9,13 @@ namespace BrightData.Analysis /// Used to cast other numeric types to doubles for numeric analysis /// /// - internal class CastToDoubleNumericAnalysis(uint writeCount = Consts.MaxWriteCount) : IDataAnalyser - where T : struct + internal class CastToDoubleNumericAnalysis(uint writeCount = Consts.MaxWriteCount) : IDataAnalyser, INumericAnalysis + where T : unmanaged, INumber { readonly ConvertToDouble _converter = new(); + ulong _count; - public NumericAnalyser Analysis { get; } = new(writeCount); + public NumericAnalyser Analysis { get; } = new(writeCount); public void Add(T val) { @@ -35,5 +37,19 @@ public void WriteTo(MetaData metadata) { Analysis.WriteTo(metadata); } + + public T L1Norm => T.CreateSaturating(Analysis.L1Norm); + public T L2Norm => T.CreateSaturating(Analysis.L2Norm); + public T Min => T.CreateSaturating(Analysis.Min); + public T Max => T.CreateSaturating(Analysis.Max); + public T Mean => T.CreateSaturating(Analysis.Mean); + public T? SampleVariance => Analysis.SampleVariance.HasValue ? T.CreateSaturating(Analysis.SampleVariance.Value) : null; + public T? PopulationVariance => Analysis.PopulationVariance.HasValue ? T.CreateSaturating(Analysis.PopulationVariance.Value) : null; + public uint NumDistinct => Analysis.NumDistinct; + public T? SampleStdDev => Analysis.SampleStdDev.HasValue ? T.CreateSaturating(Analysis.SampleStdDev.Value) : null; + public T? PopulationStdDev => Analysis.PopulationStdDev.HasValue ? T.CreateSaturating(Analysis.PopulationStdDev.Value) : null; + public ulong Count => Analysis.Count; + public T? Median => Analysis.Median.HasValue ? T.CreateSaturating(Analysis.Median.Value) : null; + public T? Mode => Analysis.Mode.HasValue ? T.CreateSaturating(Analysis.Mode.Value) : null; } } diff --git a/BrightData/Analysis/LinearBinnedFrequencyAnalysis.cs b/BrightData/Analysis/LinearBinnedFrequencyAnalysis.cs index 2befe0d3..2edbc867 100644 --- a/BrightData/Analysis/LinearBinnedFrequencyAnalysis.cs +++ b/BrightData/Analysis/LinearBinnedFrequencyAnalysis.cs @@ -1,20 +1,22 @@ using System; using System.Collections.Generic; +using System.Numerics; namespace BrightData.Analysis { /// /// Binned frequency analysis /// - internal class LinearBinnedFrequencyAnalysis(double min, double max, uint numBins) + internal class LinearBinnedFrequencyAnalysis(T min, T max, uint numBins) + where T : unmanaged, INumber, IBinaryFloatingPointIeee754 { - readonly double _step = (max - min) / numBins; + readonly T _step = (max - min) / T.CreateTruncating(numBins); readonly ulong[] _bins = new ulong[numBins]; ulong _belowRange = 0, _aboveRange = 0; - public void Add(double val) + public void Add(T val) { - if (double.IsNaN(val)) + if (T.IsNaN(val)) return; if (val < min) @@ -29,23 +31,24 @@ public void Add(double val) } } - public IEnumerable<(double Start, double End, ulong Count)> ContinuousFrequency + public IEnumerable<(T Start, T End, ulong Count)> ContinuousFrequency { get { if (_belowRange > 0) - yield return (double.NegativeInfinity, min, _belowRange); + yield return (T.NegativeInfinity, min, _belowRange); var index = 0; foreach (var c in _bins) { + var val = T.CreateTruncating(index); yield return ( - min + (index * _step), - min + (index + 1) * _step, + min + (val * _step), + min + (val + T.One) * _step, c ); ++index; } if(_aboveRange > 0) - yield return (max, double.PositiveInfinity, _aboveRange); + yield return (max, T.PositiveInfinity, _aboveRange); } } } diff --git a/BrightData/Analysis/NumericAnalyser.cs b/BrightData/Analysis/NumericAnalyser.cs index be2922c4..0f6011e5 100644 --- a/BrightData/Analysis/NumericAnalyser.cs +++ b/BrightData/Analysis/NumericAnalyser.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; +using System.Globalization; using System.Linq; +using System.Numerics; using BrightData.Types; namespace BrightData.Analysis @@ -8,27 +10,22 @@ namespace BrightData.Analysis /// /// Numeric analysis /// - internal class NumericAnalyser(uint writeCount = Consts.MaxWriteCount) : IDataAnalyser + internal class NumericAnalyser(uint writeCount = Consts.MaxWriteCount) : OnlineStandardDeviationAnalysis, IDataAnalyser, INumericAnalysis + where T: unmanaged, INumber, IMinMaxValue, IBinaryFloatingPointIeee754, IConvertible { - readonly SortedDictionary _distinct = []; - double _mean, _m2, _min = double.MaxValue, _max = double.MinValue, _mode, _l1, _l2; - ulong _total, _highestCount; + readonly SortedDictionary _distinct = []; + T _mode, _l2; + ulong _highestCount; - public virtual void Add(double val) + public override void Add(T val) { - ++_total; - - // online std deviation and mean - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm - var delta = val - _mean; - _mean += (delta / _total); - _m2 += delta * (val - _mean); + base.Add(val); // find the min and the max - if (val < _min) - _min = val; - if (val > _max) - _max = val; + if (val < Min) + Min = val; + if (val > Max) + Max = val; // add to distinct values if (_distinct.TryGetValue(val, out var count)) @@ -42,61 +39,48 @@ public virtual void Add(double val) } // calculate norms - _l1 += Math.Abs(val); + L1Norm += T.Abs(val); _l2 += val * val; } - public void Append(ReadOnlySpan span) + public void Append(ReadOnlySpan span) { foreach(var item in span) Add(item); } - public double L1Norm => _l1; - public double L2Norm => Math.Sqrt(_l2); - public double Min => _min; - public double Max => _max; - public double Mean => _mean; - public double? SampleVariance => _total > 1 ? _m2 / (_total - 1) : null; - public double? PopulationVariance => _total > 0 ? _m2 / _total : null; + public T L1Norm { get; private set; } + public T L2Norm => T.Sqrt(_l2); + public T Min { get; private set; } = T.MaxValue; + public T Max { get; private set; } = T.MinValue; public uint NumDistinct => (uint)_distinct.Count; - public double? SampleStdDev { - get - { - var variance = SampleVariance; - if (variance.HasValue) - return Math.Sqrt(variance.Value); - return null; - } - } - - public double? PopulationStdDev + public T? Median { get { - var variance = PopulationVariance; - if (variance.HasValue) - return Math.Sqrt(variance.Value); - return null; + T? ret = null; + if (_distinct.Count > 0) { + if (Count % 2 == 1) + return SortedValues.Skip((int) (Count / 2)).First(); + return CalculateAverage(SortedValues.Skip((int)(Count / 2 - 1)).Take(2)); + } + return ret; } } - public double? Median + static T CalculateAverage(IEnumerable values) { - get - { - double? ret = null; - if (_distinct.Count > 0) { - if (_total % 2 == 1) - return SortedValues.Skip((int) (_total / 2)).First(); - return SortedValues.Skip((int) (_total / 2 - 1)).Take(2).Average(); - } - return ret; + var ret = T.Zero; + var count = 0; + foreach (var value in values) { + ret += value; + ++count; } + return ret / T.CreateTruncating(count); } - IEnumerable SortedValues + IEnumerable SortedValues { get { @@ -107,7 +91,7 @@ IEnumerable SortedValues } } - public double? Mode + public T? Mode { get { @@ -121,47 +105,55 @@ public void AddObject(object obj) { var str = obj.ToString(); if (str != null) { - var val = double.Parse(str); + var val = T.Parse(str, null); Add(val); } } + static double? CreateNullable(T? value) + { + if (value.HasValue) + return double.CreateChecked(value.Value); + return null; + } + public void WriteTo(MetaData metadata) { metadata.Set(Consts.HasBeenAnalysed, true); metadata.Set(Consts.IsNumeric, true); - metadata.Set(Consts.L1Norm, L1Norm); - metadata.Set(Consts.L2Norm, L2Norm); - metadata.Set(Consts.Min, Min); - metadata.Set(Consts.Max, Max); - metadata.Set(Consts.Mean, Mean); - metadata.Set(Consts.Total, _total); - metadata.SetIfNotNull(Consts.SampleVariance, SampleVariance); - metadata.SetIfNotNull(Consts.SampleStdDev, SampleStdDev); - metadata.SetIfNotNull(Consts.PopulationVariance, PopulationVariance); - metadata.SetIfNotNull(Consts.PopulationStdDev, PopulationStdDev); - metadata.SetIfNotNull(Consts.Median, Median); - metadata.SetIfNotNull(Consts.Mode, Mode); + metadata.Set(Consts.L1Norm, double.CreateChecked(L1Norm)); + metadata.Set(Consts.L2Norm, double.CreateChecked(L2Norm)); + metadata.Set(Consts.Min, double.CreateChecked(Min)); + metadata.Set(Consts.Max, double.CreateChecked(Max)); + metadata.Set(Consts.Mean, double.CreateChecked(Mean)); + metadata.Set(Consts.Total, Count); + metadata.SetIfNotNull(Consts.SampleVariance, CreateNullable(SampleVariance)); + metadata.SetIfNotNull(Consts.SampleStdDev, CreateNullable(SampleStdDev)); + metadata.SetIfNotNull(Consts.PopulationVariance, CreateNullable(PopulationVariance)); + metadata.SetIfNotNull(Consts.PopulationStdDev, CreateNullable(PopulationStdDev)); + metadata.SetIfNotNull(Consts.Median, CreateNullable(Median)); + metadata.SetIfNotNull(Consts.Mode, CreateNullable(Mode)); metadata.Set(Consts.NumDistinct, NumDistinct); - var total = (double) _total; + var total = T.CreateTruncating(Count); var range = Max - Min; - if (range > 0) { - var bin = new LinearBinnedFrequencyAnalysis(Min, Max, 10); + if (range > T.Zero) { + var bin = new LinearBinnedFrequencyAnalysis(Min, Max, 10); var index = 0U; foreach (var item in _distinct.OrderByDescending(kv => kv.Value)) { if (index++ < writeCount) - metadata.Set($"{Consts.FrequencyPrefix}{item.Key}", item.Value / total); + metadata.Set($"{Consts.FrequencyPrefix}{item.Key}", double.CreateChecked(T.CreateTruncating(item.Value) / total)); for (ulong i = 0; i < item.Value; i++) bin.Add(item.Key); } + var formatProvider = CultureInfo.InvariantCulture.NumberFormat; foreach (var (s, e, c) in bin.ContinuousFrequency) { - if (c == 0 && (double.IsNegativeInfinity(s) || double.IsPositiveInfinity(e))) + if (c == 0 && (T.IsNegativeInfinity(s) || T.IsPositiveInfinity(e))) continue; - var start = double.IsNegativeInfinity(s) ? "-∞" : s.ToString("G17"); - var end = double.IsPositiveInfinity(e) ? "∞" : e.ToString("G17"); - metadata.Set($"{Consts.FrequencyRangePrefix}{start}/{end}", c / total); + var start = T.IsNegativeInfinity(s) ? "-∞" : s.ToString("G17", formatProvider); + var end = T.IsPositiveInfinity(e) ? "∞" : e.ToString("G17", formatProvider); + metadata.Set($"{Consts.FrequencyRangePrefix}{start}/{end}", double.CreateChecked(T.CreateTruncating(c) / total)); } } } diff --git a/BrightData/Analysis/OnlineStandardDeviationAnalysis.cs b/BrightData/Analysis/OnlineStandardDeviationAnalysis.cs new file mode 100644 index 00000000..24db120a --- /dev/null +++ b/BrightData/Analysis/OnlineStandardDeviationAnalysis.cs @@ -0,0 +1,48 @@ +using System; +using System.Numerics; + +namespace BrightData.Analysis +{ + internal class OnlineStandardDeviationAnalysis : IStandardDeviationAnalysis + where T: unmanaged, INumber, IMinMaxValue, IBinaryFloatingPointIeee754, IConvertible + { + T _m2; + + public virtual void Add(T value) + { + ++Count; + + // online std deviation and mean + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm + var delta = value - Mean; + Mean += (delta / T.CreateTruncating(Count)); + _m2 += delta * (value - Mean); + } + + public ulong Count { get; private set; } = 0; + public T Mean { get; private set; } + public T? SampleVariance => Count > 1 ? _m2 / T.CreateTruncating(Count - 1) : null; + public T? PopulationVariance => Count > 0 ? _m2 / T.CreateTruncating(Count) : null; + + public T? SampleStdDev { + get + { + var variance = SampleVariance; + if (variance.HasValue) + return T.Sqrt(variance.Value); + return null; + } + } + + public T? PopulationStdDev + { + get + { + var variance = PopulationVariance; + if (variance.HasValue) + return T.Sqrt(variance.Value); + return null; + } + } + } +} diff --git a/BrightData/Analysis/StaticAnalysers.cs b/BrightData/Analysis/StaticAnalysers.cs index 8c41c2a9..5a9047bb 100644 --- a/BrightData/Analysis/StaticAnalysers.cs +++ b/BrightData/Analysis/StaticAnalysers.cs @@ -1,4 +1,5 @@ using System; +using System.Numerics; using BrightData.Helper; namespace BrightData.Analysis @@ -20,7 +21,9 @@ public static class StaticAnalysers /// /// Number of items to write in histogram /// - public static IDataAnalyser CreateNumericAnalyser(uint writeCount = Consts.MaxWriteCount) where T:struct => new CastToDoubleNumericAnalysis(writeCount); + public static IDataAnalyser CreateNumericAnalyser(uint writeCount = Consts.MaxWriteCount) where T: unmanaged, IMinMaxValue, IBinaryFloatingPointIeee754, IConvertible => new NumericAnalyser(writeCount); + + public static IDataAnalyser CreateNumericAnalyserCastToDouble(uint writeCount = Consts.MaxWriteCount) where T: unmanaged, INumber => new CastToDoubleNumericAnalysis(writeCount); /// /// Creates an analyzer that will convert each item to a string @@ -56,7 +59,7 @@ public static class StaticAnalysers /// /// /// - public static IDataAnalyser CreateNumericAnalyser(uint writeCount = Consts.MaxWriteCount) => new NumericAnalyser(writeCount); + public static IDataAnalyser CreateNumericAnalyser(uint writeCount = Consts.MaxWriteCount) => new NumericAnalyser(writeCount); /// /// Creates a string analyzer diff --git a/BrightData/BrightData.xml b/BrightData/BrightData.xml index 6be107df..226c3e14 100644 --- a/BrightData/BrightData.xml +++ b/BrightData/BrightData.xml @@ -60,12 +60,12 @@ Index based type analysis - + Binned frequency analysis - + Binned frequency analysis @@ -117,12 +117,12 @@ - + Numeric analysis - + Numeric analysis @@ -2205,6 +2205,14 @@ + + + Analyzes numbers in a sequence + + + + + Analyzes dates in a sequence @@ -2801,6 +2809,13 @@ Upper bound + + + Generates a range of positive integers + + + + Generates a range of positive integers @@ -4043,6 +4058,14 @@ Standard deviation + + + Create an exponential distribution + + + + + Hardware dependent size of a numeric vector of floats @@ -5270,6 +5293,22 @@ + + + Returns the index of each element of span, ordered from lowest to highest + + + + + + + + Returns the index of each element of span, ordered from lowest to highest + + + + + Creates a read only vector from the span @@ -9387,7 +9426,12 @@ - Vectors are projected into a random lower dimensional space + Vectors are randomly projected into a random lower dimensional space + + + + + A nearest neighbour graph is created to improve index performance @@ -9447,6 +9491,19 @@ + + + Passes each vector to the callback, possible in parallel + + + + + + + Returns all vectors + + + A vector set index @@ -11975,6 +12032,133 @@ + + + Fixed size indexed graph node that maintains weighted list of neighbours as a max heap + + + + + + + Fixed size indexed graph node that maintains weighted list of neighbours as a max heap + + + + + + + + + + Max number of neighbours + + + + + Current number of neighbours + + + + + The smallest neighbour weight + + + + + The largest neighbour weight + + + + + The index of the neighbour with the smallest weight + + + + + The index of the neighbour with the largest weight + + + + + Tries to add a new neighbour - will succeed if there aren't already max neighbours with a smaller weight + + + + + + + + Sorted list of neighbour indices + + + + + Sorted list of neighbour weights + + + + + Returns a neighbour weight + + + + + + Enumerates the weighted neighbours + + + + + Creates a graph of vectors with a fixed size set of neighbours + + + + + + Gets the neighbours for a node, sorted by distance + + + + + + + Gets the weights for the node's neighbours + + + + + + + Creates + + + + + + + + + + Writes the graph to disk + + + + + + Loads a vector graph from disk + + + + + + + Find distance between each vector in the set and each input vector + + + + + Represents a set of vectors diff --git a/BrightData/Distribution/ExponentialDistribution.cs b/BrightData/Distribution/ExponentialDistribution.cs new file mode 100644 index 00000000..a6ee0b6f --- /dev/null +++ b/BrightData/Distribution/ExponentialDistribution.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace BrightData.Distribution +{ + internal class ExponentialDistribution(BrightDataContext context, float lambda) : IContinuousDistribution + { + public float Sample() + { + float r; + do { + r = context.NextRandomFloat(); + } while (r == 0f); + return -MathF.Log(r) / lambda; + } + } +} diff --git a/BrightData/ExtensionMethods.Analysis.cs b/BrightData/ExtensionMethods.Analysis.cs index 739648b4..115be4f1 100644 --- a/BrightData/ExtensionMethods.Analysis.cs +++ b/BrightData/ExtensionMethods.Analysis.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Numerics; using BrightData.Analysis; using BrightData.Analysis.Readers; using BrightData.Types; @@ -81,13 +82,28 @@ public static partial class ExtensionMethods /// /// /// - public static NumericAnalysis Analyze(this IEnumerable data) - where T : struct + public static INumericAnalysis Analyze(this IEnumerable data) + where T : unmanaged, INumber, IMinMaxValue, IBinaryFloatingPointIeee754, IConvertible + { + var analysis = new NumericAnalyser(); + foreach (var item in data) + analysis.Add(item); + return analysis; + } + + /// + /// Analyzes numbers in a sequence + /// + /// + /// + /// + public static INumericAnalysis AnalyzeAsDoubles(this IEnumerable data) + where T : unmanaged, INumber { var analysis = new CastToDoubleNumericAnalysis(); foreach (var item in data) analysis.Add(item); - return analysis.GetMetaData().GetNumericAnalysis(); + return analysis; } /// diff --git a/BrightData/ExtensionMethods.DataTable.cs b/BrightData/ExtensionMethods.DataTable.cs index 4717b442..320956e7 100644 --- a/BrightData/ExtensionMethods.DataTable.cs +++ b/BrightData/ExtensionMethods.DataTable.cs @@ -240,11 +240,11 @@ public static IDataAnalyser GetAnalyser(this BrightDataType type, MetaData metaD { BrightDataType.Double => StaticAnalysers.CreateNumericAnalyser(writeCount), BrightDataType.Float => StaticAnalysers.CreateNumericAnalyser(writeCount), - BrightDataType.Decimal => StaticAnalysers.CreateNumericAnalyser(writeCount), - BrightDataType.SByte => StaticAnalysers.CreateNumericAnalyser(writeCount), - BrightDataType.Int => StaticAnalysers.CreateNumericAnalyser(writeCount), - BrightDataType.Long => StaticAnalysers.CreateNumericAnalyser(writeCount), - BrightDataType.Short => StaticAnalysers.CreateNumericAnalyser(writeCount), + BrightDataType.Decimal => StaticAnalysers.CreateNumericAnalyserCastToDouble(writeCount), + BrightDataType.SByte => StaticAnalysers.CreateNumericAnalyserCastToDouble(writeCount), + BrightDataType.Int => StaticAnalysers.CreateNumericAnalyserCastToDouble(writeCount), + BrightDataType.Long => StaticAnalysers.CreateNumericAnalyserCastToDouble(writeCount), + BrightDataType.Short => StaticAnalysers.CreateNumericAnalyserCastToDouble(writeCount), BrightDataType.Date => StaticAnalysers.CreateDateAnalyser(), BrightDataType.BinaryData => StaticAnalysers.CreateFrequencyAnalyser(writeCount), BrightDataType.DateOnly => StaticAnalysers.CreateFrequencyAnalyser(writeCount), diff --git a/BrightData/ExtensionMethods.Distributions.cs b/BrightData/ExtensionMethods.Distributions.cs index 93871625..5fbbbb91 100644 --- a/BrightData/ExtensionMethods.Distributions.cs +++ b/BrightData/ExtensionMethods.Distributions.cs @@ -79,5 +79,13 @@ public partial class ExtensionMethods /// Standard deviation /// public static IContinuousDistribution CreateNormalDistribution(this BrightDataContext context, float mean = 0f, float stdDev = 1f) => new NormalDistribution(context, mean, stdDev); + + /// + /// Create an exponential distribution + /// + /// + /// + /// + public static IContinuousDistribution CreateExponentialDistribution(this BrightDataContext context, float lambda = 1f) => new ExponentialDistribution(context, lambda); } } diff --git a/BrightData/ExtensionMethods.Span.cs b/BrightData/ExtensionMethods.Span.cs index 4bdd39af..b2f6e530 100644 --- a/BrightData/ExtensionMethods.Span.cs +++ b/BrightData/ExtensionMethods.Span.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -1619,6 +1620,40 @@ static unsafe void CacheTranspose(T* from, uint rows, uint columns, uint rb, } } + /// + /// Returns the index of each element of span, ordered from lowest to highest + /// + /// + /// + /// + public static uint[] GetRankedIndices(this ReadOnlySpan span) where T : unmanaged, INumber + { + var len = span.Length; + using var temp = SpanOwner.Allocate(len); + var copy = temp.Span; + span.CopyTo(copy); + + using var indices = SpanOwner.Allocate(len); + var indicesSpan = indices.Span; + for (var i = 0; i < len; i++) + indicesSpan[i] = (uint)i; + + copy.Sort(indicesSpan); + + var ret = len.AsRange().ToArray(); + for(var i = 0; i < len; i++) + ret[indicesSpan[i]] = (uint)i; + return ret; + } + + /// + /// Returns the index of each element of span, ordered from lowest to highest + /// + /// + /// + /// + public static uint[] GetRankedIndices(this Span span) where T : unmanaged, INumber => GetRankedIndices((ReadOnlySpan)span); + /// /// Creates a read only vector from the span /// diff --git a/BrightData/ExtensionMethods.cs b/BrightData/ExtensionMethods.cs index bddc3ed4..f9f7ee8b 100644 --- a/BrightData/ExtensionMethods.cs +++ b/BrightData/ExtensionMethods.cs @@ -236,6 +236,13 @@ public static ICanConvert GetFloatConverter(this BrightDataContext /// public static IEnumerable AsRange(this int count) => Enumerable.Range(0, count).Select(i => (uint)i); + /// + /// Generates a range of positive integers + /// + /// + /// + public static IEnumerable AsRange(this Range range) => Enumerable.Range(range.Start.Value, (range.End.Value - range.Start.Value)).Select(i => (uint)i); + /// /// Generates a range of positive integers /// diff --git a/BrightData/Interfaces.VectorIndexing.cs b/BrightData/Interfaces.VectorIndexing.cs index 964ab783..6c90a8fa 100644 --- a/BrightData/Interfaces.VectorIndexing.cs +++ b/BrightData/Interfaces.VectorIndexing.cs @@ -18,9 +18,14 @@ public enum VectorIndexStrategy Flat, /// - /// Vectors are projected into a random lower dimensional space + /// Vectors are randomly projected into a random lower dimensional space /// - RandomProjection + RandomProjection, + + /// + /// A nearest neighbour graph is created to improve index performance + /// + NearestNeighbours } /// @@ -82,6 +87,19 @@ public interface IStoreVectors : IStoreVectors, IDisposable /// /// void ForEach(IndexedSpanCallback callback); + + /// + /// Passes each vector to the callback, possible in parallel + /// + /// + /// + void ForEach(IEnumerable indices, IndexedSpanCallback callback); + + /// + /// Returns all vectors + /// + /// + ReadOnlyMemory[] GetAll(); } /// diff --git a/BrightData/Interfaces.cs b/BrightData/Interfaces.cs index e4a09183..c5bf4f6c 100644 --- a/BrightData/Interfaces.cs +++ b/BrightData/Interfaces.cs @@ -163,7 +163,29 @@ public interface IDataAnalyser : IAppendBlocks, IDataAnalyser where T : no void Add(T obj); } -/// + public interface IStandardDeviationAnalysis where T : unmanaged + { + T Mean { get; } + T? SampleVariance { get; } + T? PopulationVariance { get; } + T? SampleStdDev { get; } + T? PopulationStdDev { get; } + ulong Count { get; } + } + + public interface INumericAnalysis : IStandardDeviationAnalysis + where T: unmanaged + { + T L1Norm { get; } + T L2Norm { get; } + T Min { get; } + T Max { get; } + uint NumDistinct { get; } + T? Median { get; } + T? Mode { get; } + } + + /// /// Types of data normalization /// public enum NormalizationType : byte diff --git a/BrightData/LinearAlgebra/VectorIndexing/Helper/IndexedFixedSizeGraphNode.cs b/BrightData/LinearAlgebra/VectorIndexing/Helper/IndexedFixedSizeGraphNode.cs new file mode 100644 index 00000000..f8e9822f --- /dev/null +++ b/BrightData/LinearAlgebra/VectorIndexing/Helper/IndexedFixedSizeGraphNode.cs @@ -0,0 +1,212 @@ +using System; +using System.Collections.Generic; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace BrightData.LinearAlgebra.VectorIndexing.Helper +{ + /// + /// Fixed size indexed graph node that maintains weighted list of neighbours as a max heap + /// + /// + /// + public record struct IndexedFixedSizeGraphNode(uint Index) + where T : unmanaged, INumber, IMinMaxValue + { + /// + /// Max number of neighbours + /// + public const int MaxNeighbours = 8; + [InlineArray(MaxNeighbours)] + internal struct IndexFixedSize + { + public uint _element0; + } + [InlineArray(MaxNeighbours)] + internal struct DistanceFixedSize + { + public T _element0; + } + readonly IndexFixedSize _neighbourIndices = new(); + readonly DistanceFixedSize _neighbourWeights = new(); + + /// + /// Current number of neighbours + /// + public byte NeighbourCount { get; private set; } = 0; + + /// + /// The smallest neighbour weight + /// + public readonly T MinDistance => NeighbourCount > 0 ? NeighbourWeights[0] : T.MaxValue; + + /// + /// The largest neighbour weight + /// + public readonly T MaxDistance => NeighbourCount > 0 ? NeighbourWeights[NeighbourCount - 1] : T.MinValue; + + /// + /// The index of the neighbour with the smallest weight + /// + public readonly uint MinNeighbourIndex => NeighbourCount > 0 ? NeighbourIndices[0] : uint.MaxValue; + + /// + /// The index of the neighbour with the largest weight + /// + public readonly uint MaxNeighbourIndex => NeighbourCount > 0 ? NeighbourIndices[NeighbourCount - 1] : uint.MaxValue; + + /// + /// Tries to add a new neighbour - will succeed if there aren't already max neighbours with a smaller weight + /// + /// + /// + /// + public unsafe bool TryAddNeighbour2(uint neighbourIndex, T neighbourWeight) + { + var isFull = NeighbourCount == MaxNeighbours; + fixed (uint* indices = &_neighbourIndices._element0) + fixed (T* weights = &_neighbourWeights._element0) { + // check to see if it should be inserted + if (isFull && weights[NeighbourCount - 1] <= neighbourWeight) + return false; + + byte insertPosition = 0; + var foundInsertPosition = false; + for (byte i = 0; i < NeighbourCount; i++) { + // check that the neighbour has not already been added + if (indices[i] == neighbourIndex) + return false; + + // see if we should insert here + if (weights[i] > neighbourWeight) { + insertPosition = i; + foundInsertPosition = true; + break; + } + } + + if (!foundInsertPosition) { + // there is no room left + if (isFull) + return false; + + // insert at end + insertPosition = NeighbourCount; + } + else { + // shuffle to make room + for (var i = NeighbourCount - (isFull ? 2 : 1); i >= insertPosition; i--) { + indices[i + 1] = indices[i]; + weights[i + 1] = weights[i]; + } + } + + // insert the item + indices[insertPosition] = neighbourIndex; + weights[insertPosition] = neighbourWeight; + if (!isFull) + ++NeighbourCount; + } + return true; + } + + public bool TryAddNeighbour(uint neighbourIndex, T neighbourWeight) + { + var isFull = NeighbourCount == MaxNeighbours; + var indices = MemoryMarshal.CreateSpan(ref Unsafe.As(ref Unsafe.AsRef(in _neighbourIndices)), MaxNeighbours); + var weights = MemoryMarshal.CreateSpan(ref Unsafe.As(ref Unsafe.AsRef(in _neighbourWeights)), MaxNeighbours); + + // check to see if it should be inserted + if (isFull && weights[NeighbourCount - 1] <= neighbourWeight) + return false; + + // Use binary search to find the insertion position + int left = 0, + right = NeighbourCount - 1, + insertPosition = NeighbourCount + ; + while (left <= right) { + var mid = left + (right - left) / 2; + if (weights[mid] > neighbourWeight) { + insertPosition = mid; + right = mid - 1; + } + else if (weights[mid] < neighbourWeight) { + left = mid + 1; + } + else { + // Check if the neighbour already exists + if (indices[mid] == neighbourIndex) + return false; + + left = mid + 1; + } + } + + // Check if the neighbour already exists in the left partition + for (var i = insertPosition-1; i >= 0; i--) { + if (weights[i] < neighbourWeight) + break; + if (indices[i] == neighbourIndex) + return false; + } + + if (insertPosition == NeighbourCount) { + // there is no room left + if (isFull) + return false; + } + else { + // shuffle to make room + for (var i = NeighbourCount - (isFull ? 2 : 1); i >= insertPosition; i--) { + indices[i + 1] = indices[i]; + weights[i + 1] = weights[i]; + } + } + + // insert the item + indices[insertPosition] = neighbourIndex; + weights[insertPosition] = neighbourWeight; + if (!isFull) + ++NeighbourCount; + return true; + } + + /// + /// Sorted list of neighbour indices + /// + public readonly ReadOnlySpan NeighbourIndices => MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref Unsafe.AsRef(in _neighbourIndices)), NeighbourCount); + + /// + /// Sorted list of neighbour weights + /// + public readonly ReadOnlySpan NeighbourWeights => MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As(ref Unsafe.AsRef(in _neighbourWeights)), NeighbourCount); + + /// + /// Returns a neighbour weight + /// + /// + public readonly (uint NeighbourIndex, T NeighbourWeight) this[byte index] + { + get + { + if (index < NeighbourCount) + return (NeighbourIndices[index], NeighbourWeights[index]); + return (uint.MaxValue, T.MinValue); + } + } + + /// + /// Enumerates the weighted neighbours + /// + public readonly IEnumerable<(uint NeighbourIndex, T NeighbourWeight)> WeightedNeighbours + { + get + { + for (byte i = 0; i < NeighbourCount; i++) + yield return this[i]; + } + } + } +} diff --git a/BrightData/LinearAlgebra/VectorIndexing/Helper/VectorGraph.cs b/BrightData/LinearAlgebra/VectorIndexing/Helper/VectorGraph.cs new file mode 100644 index 00000000..d7150987 --- /dev/null +++ b/BrightData/LinearAlgebra/VectorIndexing/Helper/VectorGraph.cs @@ -0,0 +1,123 @@ +using System; +using System.IO; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; +using CommunityToolkit.HighPerformance; + +namespace BrightData.LinearAlgebra.VectorIndexing.Helper +{ + /// + /// Creates a graph of vectors with a fixed size set of neighbours + /// + /// + public class VectorGraph + where T : unmanaged, IBinaryFloatingPointIeee754, IMinMaxValue + { + readonly IndexedFixedSizeGraphNode[] _nodes; + + VectorGraph(IndexedFixedSizeGraphNode[] nodes) + { + _nodes = nodes; + } + + /// + /// Gets the neighbours for a node, sorted by distance + /// + /// + /// + public ReadOnlySpan GetNeighbours(uint vectorIndex) => _nodes[vectorIndex].NeighbourIndices; + + /// + /// Gets the weights for the node's neighbours + /// + /// + /// + public ReadOnlySpan GetNeighbourWeights(uint vectorIndex) => _nodes[vectorIndex].NeighbourWeights; + + /// + /// Creates + /// + /// + /// + /// + /// + /// + [SkipLocalsInit] + public static unsafe VectorGraph Build( + IStoreVectors vectors, + DistanceMetric distanceMetric, + bool shortCircuitIfNodeNeighboursAreFull = true, + Action? onNode = null) + { + var size = vectors.Size; + var distance = size <= 1024 + ? stackalloc T[(int)size] + : new T[size]; + + var ret = GC.AllocateUninitializedArray>((int)size); + for (var i = 0U; i < size; i++) + ret[i] = new(i); + + for (var i = 0U; i < size; i++) + { + if (shortCircuitIfNodeNeighboursAreFull && ret[i].NeighbourCount == IndexedFixedSizeGraphNode.MaxNeighbours) + continue; + + // find the distance between this node and each of its neighbours + fixed (T* dest = distance) + { + var destPtr = dest; + var currentIndex = i; + vectors.ForEach((x, j) => + { + if(currentIndex != j) + destPtr[j] = T.Abs(x.FindDistance(vectors[currentIndex], distanceMetric)); + }); + } + + // find top N closest neighbours + var maxHeap = new IndexedFixedSizeGraphNode(); + for (var j = 0; j < size; j++) { + if (i == j) + continue; + var d = distance[j]; + maxHeap.TryAddNeighbour((uint)j, d); + } + + // connect the closest nodes + foreach (var (index, d) in maxHeap.WeightedNeighbours) + { + ret[index].TryAddNeighbour(i, d); + ret[i].TryAddNeighbour(index, d); + } + onNode?.Invoke(i); + } + + return new(ret); + } + + /// + /// Writes the graph to disk + /// + /// + public async Task WriteToDisk(string filePath) + { + using var fileHandle = File.OpenHandle(filePath, FileMode.Create, FileAccess.Write, FileShare.None, FileOptions.Asynchronous); + await RandomAccess.WriteAsync(fileHandle, _nodes.AsMemory().AsBytes(), 0); + } + + /// + /// Loads a vector graph from disk + /// + /// + /// + public static async Task> LoadFromDisk(string filePath) + { + using var fileHandle = File.OpenHandle(filePath); + var ret = GC.AllocateUninitializedArray>((int)(RandomAccess.GetLength(fileHandle) / Unsafe.SizeOf>())); + await RandomAccess.ReadAsync(fileHandle, ret.AsMemory().AsBytes(), 0); + return new(ret); + } + } +} diff --git a/BrightData/LinearAlgebra/VectorIndexing/IndexStrategy/FlatVectorIndex.cs b/BrightData/LinearAlgebra/VectorIndexing/IndexStrategy/FlatVectorIndex.cs index 2056fb47..3b25da62 100644 --- a/BrightData/LinearAlgebra/VectorIndexing/IndexStrategy/FlatVectorIndex.cs +++ b/BrightData/LinearAlgebra/VectorIndexing/IndexStrategy/FlatVectorIndex.cs @@ -36,19 +36,31 @@ public unsafe IEnumerable Rank(ReadOnlySpan vector, DistanceMetric dist public uint[] Closest(ReadOnlyMemory[] vector, DistanceMetric distanceMetric) { - // find distance between each vector in the set and each input vector var size = Storage.Size; - var distance = new T[size, vector.Length]; + var distance = GetDistance(vector, distanceMetric); + + // find the closest input vector index for each vector in the set + var ret = new uint[size]; + Parallel.For(0, size, i => ret[i] = ((ReadOnlySpan)distance.AsSpan2D().GetRowSpan((int)i)).MinimumIndex()); + return ret; + } + + /// + /// Find distance between each vector in the set and each input vector + /// + /// + /// + /// + T[,] GetDistance(ReadOnlyMemory[] vector, DistanceMetric distanceMetric) + { + var size = Storage.Size; + var ret = new T[size, vector.Length]; Parallel.For(0, size * vector.Length, i => { var dataIndex = (uint)i % size; var vectorIndex = (uint)i / size; - distance[dataIndex, vectorIndex] = Storage[dataIndex].FindDistance(vector[vectorIndex].Span, distanceMetric); + ret[dataIndex, vectorIndex] = Storage[dataIndex].FindDistance(vector[vectorIndex].Span, distanceMetric); }); - - // find the closest input vector index for each vector in the set - var ret = new uint[size]; - Parallel.For(0, size, i => ret[i] = ((ReadOnlySpan)distance.AsSpan2D().GetRowSpan((int)i)).MinimumIndex()); return ret; } } diff --git a/BrightData/LinearAlgebra/VectorIndexing/IndexStrategy/RandomProjectionIndex.cs b/BrightData/LinearAlgebra/VectorIndexing/IndexStrategy/RandomProjectionIndex.cs index f7cf53ed..839bce87 100644 --- a/BrightData/LinearAlgebra/VectorIndexing/IndexStrategy/RandomProjectionIndex.cs +++ b/BrightData/LinearAlgebra/VectorIndexing/IndexStrategy/RandomProjectionIndex.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; using System.Numerics; +using System.Runtime.Intrinsics; using System.Text; using System.Threading.Tasks; using BrightData.LinearAlgebra.VectorIndexing.Storage; @@ -50,12 +51,15 @@ public IEnumerable Rank(ReadOnlySpan vector, DistanceMetric distanceMet public uint[] Closest(ReadOnlyMemory[] vector, DistanceMetric distanceMetric) { - var vectors = vector.Select(x => { + return _projectionIndex.Closest(Project(vector), distanceMetric); + } + + ReadOnlyMemory[] Project(ReadOnlyMemory[] vectors) => + vectors.Select(x => { using var vector2 = _lap.CreateVector(x); using var projection = _randomProjection.Multiply(vector2); return new ReadOnlyMemory(projection.ReadOnlySegment.ToNewArray()); - }).ToArray(); - return _projectionIndex.Closest(vectors, distanceMetric); - } + }).ToArray() + ; } } diff --git a/BrightData/LinearAlgebra/VectorIndexing/Storage/InMemoryVectorStorage.cs b/BrightData/LinearAlgebra/VectorIndexing/Storage/InMemoryVectorStorage.cs index a5c88aa2..10a2904f 100644 --- a/BrightData/LinearAlgebra/VectorIndexing/Storage/InMemoryVectorStorage.cs +++ b/BrightData/LinearAlgebra/VectorIndexing/Storage/InMemoryVectorStorage.cs @@ -1,5 +1,6 @@ using System; using System.Buffers; +using System.Collections.Generic; using System.Numerics; using System.Threading.Tasks; @@ -44,5 +45,19 @@ public void ForEach(IndexedSpanCallback callback) { Parallel.For(0, Size, i => callback(this[(uint)i], (uint)i)); } + + public void ForEach(IEnumerable indices, IndexedSpanCallback callback) + { + Parallel.ForEach(indices, i => callback(this[i], i)); + } + + public ReadOnlyMemory[] GetAll() + { + var size = Size; + var ret = new ReadOnlyMemory[size]; + for(var i = 0U; i < size; i++) + ret[i] = this[i].ToArray(); + return ret; + } } } diff --git a/BrightData/LinearAlgebra/VectorIndexing/VectorSet.cs b/BrightData/LinearAlgebra/VectorIndexing/VectorSet.cs index e5c5c72e..b1e84827 100644 --- a/BrightData/LinearAlgebra/VectorIndexing/VectorSet.cs +++ b/BrightData/LinearAlgebra/VectorIndexing/VectorSet.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Numerics; using System.Runtime.Intrinsics; +using System.Threading.Tasks; using BrightData.Helper; using BrightData.LinearAlgebra.VectorIndexing.IndexStrategy; using BrightData.LinearAlgebra.VectorIndexing.Storage; @@ -28,10 +29,7 @@ public class VectorSet : IHaveSize, IDisposable /// public VectorSet(uint vectorSize, VectorIndexStrategy indexType = VectorIndexStrategy.Flat, VectorStorageType storageType = VectorStorageType.InMemory, uint? capacity = null) { - IStoreVectors storage = storageType switch { - VectorStorageType.InMemory => new InMemoryVectorStorage(vectorSize, capacity), - _ => throw new NotSupportedException() - }; + var storage = GetStorage(storageType, vectorSize, capacity); if (indexType == VectorIndexStrategy.Flat) _index = new FlatVectorIndex(storage); else @@ -40,16 +38,18 @@ public VectorSet(uint vectorSize, VectorIndexStrategy indexType = VectorIndexStr public VectorSet(LinearAlgebraProvider lap, uint vectorSize, uint projectionSize, VectorIndexStrategy indexType = VectorIndexStrategy.RandomProjection, VectorStorageType storageType = VectorStorageType.InMemory, uint? capacity = null, int s = 3) { - IStoreVectors storage = storageType switch { - VectorStorageType.InMemory => new InMemoryVectorStorage(vectorSize, capacity), - _ => throw new NotSupportedException() - }; + var storage = GetStorage(storageType, vectorSize, capacity); if (indexType == VectorIndexStrategy.RandomProjection) _index = new RandomProjectionIndex(lap, storage, projectionSize, capacity, s); else throw new NotSupportedException(); } + public static IStoreVectors GetStorage(VectorStorageType storageType, uint vectorSize, uint? capacity) => storageType switch { + VectorStorageType.InMemory => new InMemoryVectorStorage(vectorSize, capacity), + _ => throw new NotSupportedException() + }; + /// public void Dispose() {