diff --git a/Assets/Scripts/Algorithms.meta b/Assets/Scripts/Algorithms.meta new file mode 100644 index 00000000..125d5297 --- /dev/null +++ b/Assets/Scripts/Algorithms.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 5ce134601efc54aa28f0284924ffbd01 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Scripts/Algorithms/Clustering.meta b/Assets/Scripts/Algorithms/Clustering.meta new file mode 100644 index 00000000..821ce142 --- /dev/null +++ b/Assets/Scripts/Algorithms/Clustering.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 8e4bdae9b86b243de9987ac58f316992 +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Scripts/Algorithms/Clustering/AgglomerativeClusterer.cs b/Assets/Scripts/Algorithms/Clustering/AgglomerativeClusterer.cs new file mode 100644 index 00000000..52967aa0 --- /dev/null +++ b/Assets/Scripts/Algorithms/Clustering/AgglomerativeClusterer.cs @@ -0,0 +1,99 @@ +using System.Collections; +using System.Collections.Generic; +using UnityEngine; + +// The agglomerative clusterer class performs agglomerative clustering with the stopping condition +// given by the size and radius constraints. +public class AgglomerativeClusterer : ISizeAndRadiusConstrainedClusterer { + public AgglomerativeClusterer(List objects, int maxSize, float maxRadius) + : base(objects, maxSize, maxRadius) {} + + // Cluster the game objects. + public override void Cluster() { + // Add a cluster for every game object. + foreach (var obj in _objects) { + Cluster cluster = new Cluster(obj); + cluster.AddObject(obj); + _clusters.Add(cluster); + } + + // Create a set containing all valid cluster indices. + HashSet validClusterIndices = new HashSet(); + for (int i = 0; i < _clusters.Count; ++i) { + validClusterIndices.Add(i); + } + + // Find the pairwise distances between all clusters. + // The upper triangular half of the distances matrix is unused. + float[,] distances = new float[_clusters.Count, _clusters.Count]; + for (int i = 0; i < _clusters.Count; ++i) { + for (int j = 0; j < i; ++j) { + distances[i, j] = Vector3.Distance(_clusters[i].Coordinates, _clusters[j].Coordinates); + } + } + + while (true) { + // Find the minimum distance between two clusters. + float minDistance = Mathf.Infinity; + int clusterIndex1 = -1; + int clusterIndex2 = -1; + for (int i = 0; i < _clusters.Count; ++i) { + for (int j = 0; j < i; ++j) { + if (distances[i, j] < minDistance) { + minDistance = distances[i, j]; + clusterIndex1 = i; + clusterIndex2 = j; + } + } + } + + // Check whether the minimum distance exceeds the maximum cluster radius, in which case the + // algorithm has converged. This produces a conservative solution because the radius of a + // merged cluster is less than or equal to the sum of the original cluster radii. + if (minDistance >= _maxRadius) { + break; + } + + // Check whether merging the two clusters would violate the size constraint. + if (_clusters[clusterIndex1].Size() + _clusters[clusterIndex2].Size() > _maxSize) { + distances[clusterIndex1, clusterIndex2] = Mathf.Infinity; + continue; + } + + // Merge the two clusters together. + int minClusterIndex = Mathf.Min(clusterIndex1, clusterIndex2); + int maxClusterIndex = Mathf.Max(clusterIndex1, clusterIndex2); + _clusters[minClusterIndex].Merge(_clusters[maxClusterIndex]); + _clusters[minClusterIndex].Recenter(); + validClusterIndices.Remove(maxClusterIndex); + + // Update the distances matrix using the distance between the cluster centroids. + // TODO(titan): Change the distance metric to use average or maximum linkage. + for (int i = 0; i < minClusterIndex; ++i) { + if (distances[minClusterIndex, i] < Mathf.Infinity) { + distances[minClusterIndex, i] = + Vector3.Distance(_clusters[minClusterIndex].Coordinates, _clusters[i].Coordinates); + } + } + for (int i = minClusterIndex + 1; i < _clusters.Count; ++i) { + if (distances[i, minClusterIndex] < Mathf.Infinity) { + distances[i, minClusterIndex] = + Vector3.Distance(_clusters[minClusterIndex].Coordinates, _clusters[i].Coordinates); + } + } + for (int i = 0; i < maxClusterIndex; ++i) { + distances[maxClusterIndex, i] = Mathf.Infinity; + } + for (int i = maxClusterIndex + 1; i < _clusters.Count; ++i) { + distances[i, maxClusterIndex] = Mathf.Infinity; + } + } + + // Select only the valid clusters. + for (int i = _clusters.Count - 1; i >= 0; --i) { + if (!validClusterIndices.Contains(i)) { + _clusters.RemoveAt(i); + } + } + } +} diff --git a/Assets/Scripts/Algorithms/Clustering/AgglomerativeClusterer.cs.meta b/Assets/Scripts/Algorithms/Clustering/AgglomerativeClusterer.cs.meta new file mode 100644 index 00000000..47613e98 --- /dev/null +++ b/Assets/Scripts/Algorithms/Clustering/AgglomerativeClusterer.cs.meta @@ -0,0 +1,2 @@ +fileFormatVersion: 2 +guid: e0c0460dd55b54f39bf7ca7878411679 \ No newline at end of file diff --git a/Assets/Scripts/Algorithms/Clustering/Cluster.cs b/Assets/Scripts/Algorithms/Clustering/Cluster.cs new file mode 100644 index 00000000..547d9296 --- /dev/null +++ b/Assets/Scripts/Algorithms/Clustering/Cluster.cs @@ -0,0 +1,88 @@ +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using UnityEngine; + +// The cluster class represents a collection of game objects. +public class Cluster { + // Coordinates of the cluster. + private Vector3 _coordinates = Vector3.zero; + + // List of game objects in the cluster. + private List _objects = new List(); + + public Cluster() {} + public Cluster(in Vector3 coordinates) { + _coordinates = coordinates; + } + public Cluster(in GameObject obj) { + _coordinates = obj.transform.position; + } + + // Get the cluster coordinates. + public Vector3 Coordinates { + get { return _coordinates; } + } + + // Get the list of game objects. + public IReadOnlyList Objects { + get { return _objects; } + } + + // Return the size of the cluster. + public int Size() { + return _objects.Count; + } + + // Check whether the cluster is empty. + public bool IsEmpty() { + return Size() == 0; + } + + // Calculate the radius of the cluster. + public float Radius() { + if (IsEmpty()) { + return 0; + } + + Vector3 centroid = Centroid(); + return _objects.Max(obj => Vector3.Distance(centroid, obj.transform.position)); + } + + // Calculate the centroid of the cluster. + public Vector3 Centroid() { + if (IsEmpty()) { + return Vector3.zero; + } + + Vector3 centroid = Vector3.zero; + foreach (var obj in _objects) { + centroid += obj.transform.position; + } + centroid /= _objects.Count; + return centroid; + } + + // Recenter the cluster's centroid to be the mean of all game objects' positions in the cluster. + public void Recenter() { + _coordinates = Centroid(); + } + + // Add a game object to the cluster. + // This function does not update the centroid of the cluster. + public void AddObject(in GameObject obj) { + _objects.Add(obj); + } + + // Add multiple game objects to the cluster. + // This function does not update the centroid of the cluster. + public void AddObjects(in IReadOnlyList objects) { + _objects.AddRange(objects); + } + + // Merge another cluster into this one. + // This function does not update the centroid of the cluster. + public void Merge(in Cluster cluster) { + AddObjects(cluster.Objects); + } +} diff --git a/Assets/Scripts/Algorithms/Clustering/Cluster.cs.meta b/Assets/Scripts/Algorithms/Clustering/Cluster.cs.meta new file mode 100644 index 00000000..641b1e4a --- /dev/null +++ b/Assets/Scripts/Algorithms/Clustering/Cluster.cs.meta @@ -0,0 +1,2 @@ +fileFormatVersion: 2 +guid: 99d3dec10e6c64a86a7cd57009cf1c31 \ No newline at end of file diff --git a/Assets/Scripts/Algorithms/Clustering/Clusterer.cs b/Assets/Scripts/Algorithms/Clustering/Clusterer.cs new file mode 100644 index 00000000..5be4aa28 --- /dev/null +++ b/Assets/Scripts/Algorithms/Clustering/Clusterer.cs @@ -0,0 +1,47 @@ +using System.Collections; +using System.Collections.Generic; +using UnityEngine; + +// The clusterer class is an interface for clustering algorithms. +public abstract class IClusterer { + // List of game objects to cluster. + protected List _objects = new List(); + + // List of clusters. + protected List _clusters = new List(); + + public IClusterer(List objects) { + _objects = objects; + } + + // Get the list of game objects. + public IReadOnlyList Objects { + get { return _objects; } + } + + // Get the list of clusters. + public IReadOnlyList Clusters { + get { return _clusters; } + } + + // Cluster the game objects. + public abstract void Cluster(); +} + +// The size and radius-constrained clusterer class is an interface for clustering algorithms with +// size and radius constraints. The size is defined as the maximum number of game objects within a +// cluster, and the radius denotes the maximum distance from the cluster's centroid to any of its +// assigned game objects. +public abstract class ISizeAndRadiusConstrainedClusterer : IClusterer { + // Maximum cluster size. + protected readonly int _maxSize = 0; + + // Maximum cluster radius. + protected readonly float _maxRadius = 0; + + public ISizeAndRadiusConstrainedClusterer(List objects, int maxSize, float maxRadius) + : base(objects) { + _maxSize = maxSize; + _maxRadius = maxRadius; + } +} diff --git a/Assets/Scripts/Algorithms/Clustering/Clusterer.cs.meta b/Assets/Scripts/Algorithms/Clustering/Clusterer.cs.meta new file mode 100644 index 00000000..b80bb31d --- /dev/null +++ b/Assets/Scripts/Algorithms/Clustering/Clusterer.cs.meta @@ -0,0 +1,2 @@ +fileFormatVersion: 2 +guid: b622dd33f3bbe4622bd82fbc79f64cf7 \ No newline at end of file diff --git a/Assets/Scripts/Algorithms/Clustering/KMeansClusterer.cs b/Assets/Scripts/Algorithms/Clustering/KMeansClusterer.cs new file mode 100644 index 00000000..e086e526 --- /dev/null +++ b/Assets/Scripts/Algorithms/Clustering/KMeansClusterer.cs @@ -0,0 +1,119 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using UnityEngine; + +// The k-means clusterer class performs k-means clustering. +public class KMeansClusterer : IClusterer { + public const float Epsilon = 1e-3f; + + // Number of clusters. + private int _k = 0; + + // Maximum number of iterations. + private int _maxIterations = 20; + + public KMeansClusterer(List objects, int k, int maxIterations = 20) : base(objects) { + _k = k; + _maxIterations = maxIterations; + } + + // Cluster the game objects. + public override void Cluster() { + // Initialize the clusters with centroids located at random game objects. + // Perform Fisher-Yates shuffling to find k random game objects. + System.Random random = new System.Random(); + for (int i = _objects.Count - 1; i >= _objects.Count - _k; --i) { + int j = random.Next(i + 1); + (_objects[i], _objects[j]) = (_objects[j], _objects[i]); + } + for (int i = _objects.Count - 1; i >= _objects.Count - _k; --i) { + _clusters.Add(new Cluster(_objects[i])); + } + + bool converged = false; + int iteration = 0; + while (!converged && iteration < _maxIterations) { + AssignObjectsToCluster(); + + // Calculate the new clusters as the mean of all assigned game objects. + converged = true; + for (int clusterIndex = 0; clusterIndex < _clusters.Count; ++clusterIndex) { + Cluster newCluster; + if (_clusters[clusterIndex].IsEmpty()) { + int objectIndex = random.Next(_objects.Count); + newCluster = new Cluster(_objects[objectIndex]); + } else { + newCluster = new Cluster(_clusters[clusterIndex].Centroid()); + } + + // Check whether the algorithm has converged by checking whether the cluster has moved. + if (Vector3.Distance(newCluster.Coordinates, _clusters[clusterIndex].Coordinates) > + Epsilon) { + converged = false; + } + + _clusters[clusterIndex] = newCluster; + } + + ++iteration; + } + + AssignObjectsToCluster(); + } + + private void AssignObjectsToCluster() { + // Determine the closest centroid to each game object. + foreach (var obj in _objects) { + float minDistance = Mathf.Infinity; + int minIndex = -1; + for (int clusterIndex = 0; clusterIndex < _clusters.Count; ++clusterIndex) { + float distance = + Vector3.Distance(_clusters[clusterIndex].Coordinates, obj.transform.position); + if (distance < minDistance) { + minDistance = distance; + minIndex = clusterIndex; + } + } + _clusters[minIndex].AddObject(obj); + } + } +} + +// The constrained k-means clusterer class performs k-means clustering under size and radius +// constraints. +public class ConstrainedKMeansClusterer : ISizeAndRadiusConstrainedClusterer { + public ConstrainedKMeansClusterer(List objects, int maxSize, float maxRadius) + : base(objects, maxSize, maxRadius) {} + + // Cluster the game objects. + public override void Cluster() { + int numClusters = (int)Mathf.Ceil(_objects.Count / _maxSize); + KMeansClusterer clusterer; + while (true) { + clusterer = new KMeansClusterer(_objects, numClusters); + clusterer.Cluster(); + + // Count the number of over-populated and over-sized clusters. + int numOverPopulatedClusters = 0; + int numOverSizedClusters = 0; + foreach (var cluster in clusterer.Clusters) { + if (cluster.Size() > _maxSize) { + ++numOverPopulatedClusters; + } + if (cluster.Radius() > _maxRadius) { + ++numOverSizedClusters; + } + } + + // If all clusters satisfy the size and radius constraints, the algorithm has converged. + if (numOverPopulatedClusters == 0 && numOverSizedClusters == 0) { + break; + } + + numClusters += + (int)Mathf.Ceil(Mathf.Max(numOverPopulatedClusters, numOverSizedClusters) / 2.0f); + } + _clusters = new List(clusterer.Clusters); + } +} diff --git a/Assets/Scripts/Algorithms/Clustering/KMeansClusterer.cs.meta b/Assets/Scripts/Algorithms/Clustering/KMeansClusterer.cs.meta new file mode 100644 index 00000000..4823728e --- /dev/null +++ b/Assets/Scripts/Algorithms/Clustering/KMeansClusterer.cs.meta @@ -0,0 +1,2 @@ +fileFormatVersion: 2 +guid: 239318f01460f416792b3fd0e5165fee \ No newline at end of file diff --git a/Assets/Tests/EditMode/AgglomerativeClustererTest.cs b/Assets/Tests/EditMode/AgglomerativeClustererTest.cs new file mode 100644 index 00000000..22e6c659 --- /dev/null +++ b/Assets/Tests/EditMode/AgglomerativeClustererTest.cs @@ -0,0 +1,69 @@ +using NUnit.Framework; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using UnityEngine; +using UnityEngine.TestTools; + +public class AgglomerativeClustererTest { + public static GameObject GenerateObject(in Vector3 position) { + GameObject obj = new GameObject(); + obj.transform.position = position; + return obj; + } + + public static readonly List Objects = new List { + GenerateObject(new Vector3(0, 0, 0)), + GenerateObject(new Vector3(0, 1, 0)), + GenerateObject(new Vector3(0, 1.5f, 0)), + GenerateObject(new Vector3(0, 2.5f, 0)), + }; + + [Test] + public void TestSingleCluster() { + AgglomerativeClusterer clusterer = + new AgglomerativeClusterer(Objects, maxSize: Objects.Count, maxRadius: Mathf.Infinity); + clusterer.Cluster(); + Assert.AreEqual(1, clusterer.Clusters.Count); + Cluster cluster = clusterer.Clusters[0]; + Assert.AreEqual(Objects.Count, cluster.Size()); + Assert.AreEqual(new Vector3(0, 1.25f, 0), cluster.Centroid()); + } + + [Test] + public void TestMaxSizeOne() { + AgglomerativeClusterer clusterer = + new AgglomerativeClusterer(Objects, maxSize: 1, maxRadius: Mathf.Infinity); + clusterer.Cluster(); + Assert.AreEqual(Objects.Count, clusterer.Clusters.Count); + foreach (var cluster in clusterer.Clusters) { + Assert.AreEqual(1, cluster.Size()); + } + } + + [Test] + public void TestZeroRadius() { + AgglomerativeClusterer clusterer = + new AgglomerativeClusterer(Objects, maxSize: Objects.Count, maxRadius: 0); + clusterer.Cluster(); + Assert.AreEqual(Objects.Count, clusterer.Clusters.Count); + foreach (var cluster in clusterer.Clusters) { + Assert.AreEqual(1, cluster.Size()); + } + } + + [Test] + public void TestSmallRadius() { + AgglomerativeClusterer clusterer = + new AgglomerativeClusterer(Objects, maxSize: Objects.Count, maxRadius: 1); + clusterer.Cluster(); + Assert.AreEqual(3, clusterer.Clusters.Count); + List clusters = clusterer.Clusters.OrderBy(cluster => cluster.Coordinates[1]).ToList(); + Assert.AreEqual(1, clusters[0].Size()); + Assert.AreEqual(new Vector3(0, 0, 0), clusters[0].Coordinates); + Assert.AreEqual(2, clusters[1].Size()); + Assert.AreEqual(new Vector3(0, 1.25f, 0), clusters[1].Coordinates); + Assert.AreEqual(1, clusters[2].Size()); + Assert.AreEqual(new Vector3(0, 2.5f, 0), clusters[2].Coordinates); + } +} diff --git a/Assets/Tests/EditMode/AgglomerativeClustererTest.cs.meta b/Assets/Tests/EditMode/AgglomerativeClustererTest.cs.meta new file mode 100644 index 00000000..a1254489 --- /dev/null +++ b/Assets/Tests/EditMode/AgglomerativeClustererTest.cs.meta @@ -0,0 +1,2 @@ +fileFormatVersion: 2 +guid: 04f19bc2ee71841a8bf1d11791707c9b \ No newline at end of file diff --git a/Assets/Tests/EditMode/ClusterTest.cs b/Assets/Tests/EditMode/ClusterTest.cs new file mode 100644 index 00000000..d594ff24 --- /dev/null +++ b/Assets/Tests/EditMode/ClusterTest.cs @@ -0,0 +1,99 @@ +using NUnit.Framework; +using System.Collections; +using System.Collections.Generic; +using UnityEngine; +using UnityEngine.TestTools; + +public class ClusterTest { + public static GameObject GenerateObject(in Vector3 position) { + GameObject obj = new GameObject(); + obj.transform.position = position; + return obj; + } + + public static Cluster GenerateCluster(in IReadOnlyList objects) { + Cluster cluster = new Cluster(); + cluster.AddObjects(objects); + cluster.Recenter(); + return cluster; + } + + [Test] + public void TestSize() { + const int size = 10; + List objects = new List(); + for (int i = 0; i < size; ++i) { + objects.Add(GenerateObject(new Vector3(0, i, 0))); + } + Cluster cluster = GenerateCluster(objects); + Assert.AreEqual(size, cluster.Size()); + } + + [Test] + public void TestIsEmpty() { + Cluster emptyCluster = new Cluster(); + Assert.IsTrue(emptyCluster.IsEmpty()); + + Cluster cluster = new Cluster(); + cluster.AddObject(new GameObject()); + Assert.IsFalse(cluster.IsEmpty()); + } + + [Test] + public void TestRadius() { + const float radius = 5; + List objects = new List(); + objects.Add(GenerateObject(new Vector3(0, radius, 0))); + objects.Add(GenerateObject(new Vector3(0, -radius, 0))); + Cluster cluster = GenerateCluster(objects); + Assert.AreEqual(radius, cluster.Radius()); + } + + [Test] + public void TestCentroid() { + List objects = new List(); + for (int i = -1; i <= 1; ++i) { + for (int j = -1; j <= 1; ++j) { + objects.Add(GenerateObject(new Vector3(i, j, 0))); + } + } + Cluster cluster = GenerateCluster(objects); + Assert.AreEqual(Vector3.zero, cluster.Centroid()); + } + + [Test] + public void TestRecenter() { + List objects = new List(); + for (int i = -1; i <= 1; ++i) { + for (int j = -1; j <= 1; ++j) { + objects.Add(GenerateObject(new Vector3(i, j, 0))); + } + } + Cluster cluster = GenerateCluster(objects); + cluster.AddObject(GenerateObject(new Vector3(10, -10, 0))); + Assert.AreNotEqual(new Vector3(1, -1, 0), cluster.Coordinates); + cluster.Recenter(); + Assert.AreEqual(new Vector3(1, -1, 0), cluster.Coordinates); + } + + [Test] + public void TestMerge() { + const int size = 10; + List objects1 = new List(); + List objects2 = new List(); + for (int i = 0; i < size; ++i) { + objects1.Add(GenerateObject(new Vector3(0, i, 0))); + objects2.Add(GenerateObject(new Vector3(i, 0, 0))); + } + Cluster cluster1 = GenerateCluster(objects1); + Cluster cluster2 = GenerateCluster(objects2); + int size1 = cluster1.Size(); + int size2 = cluster2.Size(); + Vector3 centroid1 = cluster1.Centroid(); + Vector3 centroid2 = cluster2.Centroid(); + cluster1.Merge(cluster2); + cluster1.Recenter(); + Assert.AreEqual(size1 + size2, cluster1.Size()); + Assert.AreEqual((centroid1 + centroid2) / 2, cluster1.Coordinates); + } +} diff --git a/Assets/Tests/EditMode/ClusterTest.cs.meta b/Assets/Tests/EditMode/ClusterTest.cs.meta new file mode 100644 index 00000000..bb4c2b32 --- /dev/null +++ b/Assets/Tests/EditMode/ClusterTest.cs.meta @@ -0,0 +1,2 @@ +fileFormatVersion: 2 +guid: 74be2838017f7416283597f9f3275791 \ No newline at end of file diff --git a/Assets/Tests/EditMode/KMeansClustererTest.cs b/Assets/Tests/EditMode/KMeansClustererTest.cs new file mode 100644 index 00000000..fe51ac91 --- /dev/null +++ b/Assets/Tests/EditMode/KMeansClustererTest.cs @@ -0,0 +1,150 @@ +using NUnit.Framework; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using UnityEngine; +using UnityEngine.TestTools; + +public class KMeansClustererTest { + public static readonly List Objects = new List { + GenerateObject(new Vector3(0, 0, 0)), + GenerateObject(new Vector3(0, 1, 0)), + GenerateObject(new Vector3(0, 1.5f, 0)), + GenerateObject(new Vector3(0, 2.5f, 0)), + }; + + public static GameObject GenerateObject(in Vector3 position) { + GameObject obj = new GameObject(); + obj.transform.position = position; + return obj; + } + + [Test] + public void TestSingleCluster() { + KMeansClusterer clusterer = new KMeansClusterer(Objects, k: 1); + clusterer.Cluster(); + Cluster cluster = clusterer.Clusters[0]; + Assert.AreEqual(Objects.Count, cluster.Size()); + Assert.AreEqual(new Vector3(0, 1.25f, 0), cluster.Coordinates); + Assert.AreEqual(new Vector3(0, 1.25f, 0), cluster.Centroid()); + } + + // Test to reveal improper clearing of cluster memberships. + [Test] + public void TestTwoDistinctClustersWithResetNeeded() { + // Group A: points near (0, 0, 0). + var groupA = new List { + GenerateObject(new Vector3(0, 0, 0)), + GenerateObject(new Vector3(1, 0, 0)), + GenerateObject(new Vector3(0, 1, 0)), + GenerateObject(new Vector3(1, 1, 0)), + }; + + // Group B: points near (10, 10, 10). + var groupB = new List { + GenerateObject(new Vector3(10, 10, 10)), + GenerateObject(new Vector3(11, 10, 10)), + GenerateObject(new Vector3(10, 11, 10)), + GenerateObject(new Vector3(11, 11, 10)), + }; + + // Combine them. + var objects = new List(); + objects.AddRange(groupA); + objects.AddRange(groupB); + + // Create clusterer with k = 2. + KMeansClusterer clusterer = new KMeansClusterer(objects, k: 2); + clusterer.Cluster(); + + // We expect exactly 2 clusters. + Assert.AreEqual(2, clusterer.Clusters.Count); + + // Retrieve the clusters. + Cluster c0 = clusterer.Clusters[0]; + Cluster c1 = clusterer.Clusters[1]; + + // Because the clusters are well-separated, each cluster should contain all points from one + // group or the other, not a mixture. Check via centroids. + var centroid0 = c0.Centroid(); + var centroid1 = c1.Centroid(); + + // One centroid should be near (0.5, 0.5, 0), the other near (10.5, 10.5, 10). + var expectedCentroid0 = new Vector3(0.5f, 0.5f, 0); + var expectedCentroid1 = new Vector3(10.5f, 10.5f, 10); + bool correctPlacement = (centroid0 == expectedCentroid0 && centroid1 == expectedCentroid1) || + (centroid0 == expectedCentroid1 && centroid1 == expectedCentroid0); + Assert.IsTrue( + correctPlacement, + "Centroids not close to the expected group centers. Possible leftover membership from a previous iteration if clusters not cleared."); + + // Additionally, we can count membership to confirm that each cluster got exactly four points + // for a more direct check. + int cluster0Count = c0.Size(); + int cluster1Count = c1.Size(); + Assert.AreEqual(8, cluster0Count + cluster1Count, + "Total membership across clusters does not match the total number of objects."); + + // Even if the clusters swapped roles, each cluster should have 4 points if membership was + // properly reset and re-assigned. + bool clusterCountsValid = cluster0Count == 4 && cluster1Count == 4; + Assert.IsTrue(clusterCountsValid, + $"Cluster sizes not as expected. c0={cluster0Count}, c1={cluster1Count}."); + } +} + +public class ConstrainedKMeansClustererTest { + public static readonly List Objects = new List { + GenerateObject(new Vector3(0, 0, 0)), + GenerateObject(new Vector3(0, 1, 0)), + GenerateObject(new Vector3(0, 1.5f, 0)), + GenerateObject(new Vector3(0, 2.5f, 0)), + }; + + public static GameObject GenerateObject(in Vector3 position) { + GameObject obj = new GameObject(); + obj.transform.position = position; + return obj; + } + + [Test] + public void TestSingleCluster() { + ConstrainedKMeansClusterer clusterer = + new ConstrainedKMeansClusterer(Objects, maxSize: Objects.Count, maxRadius: Mathf.Infinity); + clusterer.Cluster(); + Assert.AreEqual(1, clusterer.Clusters.Count); + Cluster cluster = clusterer.Clusters[0]; + Assert.AreEqual(Objects.Count, cluster.Size()); + Assert.AreEqual(new Vector3(0, 1.25f, 0), cluster.Centroid()); + } + + [Test] + public void TestMaxSizeOne() { + ConstrainedKMeansClusterer clusterer = + new ConstrainedKMeansClusterer(Objects, maxSize: 1, maxRadius: Mathf.Infinity); + clusterer.Cluster(); + Assert.AreEqual(Objects.Count, clusterer.Clusters.Count); + foreach (var cluster in clusterer.Clusters) { + Assert.AreEqual(1, cluster.Size()); + } + } + + [Test] + public void TestZeroRadius() { + ConstrainedKMeansClusterer clusterer = + new ConstrainedKMeansClusterer(Objects, maxSize: Objects.Count, maxRadius: 0); + clusterer.Cluster(); + Assert.AreEqual(Objects.Count, clusterer.Clusters.Count); + foreach (var cluster in clusterer.Clusters) { + Assert.AreEqual(1, cluster.Size()); + } + } + + [Test] + public void TestSmallRadius() { + ConstrainedKMeansClusterer clusterer = + new ConstrainedKMeansClusterer(Objects, maxSize: Objects.Count, maxRadius: 1); + clusterer.Cluster(); + Assert.AreEqual(2, clusterer.Clusters.Count); + } +} diff --git a/Assets/Tests/EditMode/KMeansClustererTest.cs.meta b/Assets/Tests/EditMode/KMeansClustererTest.cs.meta new file mode 100644 index 00000000..77103751 --- /dev/null +++ b/Assets/Tests/EditMode/KMeansClustererTest.cs.meta @@ -0,0 +1,2 @@ +fileFormatVersion: 2 +guid: 0380bb64b0b224eb2a89bd5bfec704d3 \ No newline at end of file