Skip to content

Commit

Permalink
Added ClusterSolution class with crossOver method
Browse files Browse the repository at this point in the history
  • Loading branch information
leifeld committed Jan 6, 2025
1 parent deff8cb commit 8b2cc50
Showing 1 changed file with 257 additions and 10 deletions.
267 changes: 257 additions & 10 deletions dna/src/main/java/export/Exporter.java
Original file line number Diff line number Diff line change
Expand Up @@ -3790,7 +3790,7 @@ public double[] evaluateBackboneSolution(String[] backboneEntities, int p) {
* @param K The number of communities.
* @return The modularity score.
*/
double modularity(int[] mem, double[][] mat, int K) {
private double modularity(int[] mem, double[][] mat, int K) {
if (mat == null || mat.length == 0 || mat.length != mat[0].length) {
throw new IllegalArgumentException("Matrix must be square and non-empty.");
}
Expand Down Expand Up @@ -3855,7 +3855,7 @@ public double[] evaluateBackboneSolution(String[] backboneEntities, int p) {
* @return The E-I index for the combined contributions of positive and negative
* ties.
*/
double eiIndex(int[] memberships, double[][] mat) {
private double eiIndex(int[] memberships, double[][] mat) {
double external = 0.0, internal = 0.0;
int n = mat.length;

Expand Down Expand Up @@ -3890,7 +3890,7 @@ public double[] evaluateBackboneSolution(String[] backboneEntities, int p) {
* @param arr An int array.
* @return An array of ranks, starting with 0.
*/
int[] calculateRanks(int... arr) {
private int[] calculateRanks(int... arr) {
class Pair {
final int value;
final int index;
Expand Down Expand Up @@ -3923,7 +3923,7 @@ class Pair {
* @param arr A double array.
* @return An array of ranks, starting with 0.
*/
int[] calculateRanks(double... arr) {
private int[] calculateRanks(double... arr) {
class Pair {
final double value;
final int index;
Expand Down Expand Up @@ -3962,7 +3962,7 @@ class Pair {
* @param K Number of groups.
* @return A shuffled array of group memberships.
*/
int[] createRandomMemberships(int N, int K) {
private int[] createRandomMemberships(int N, int K) {
// Preallocate ArrayList with an exact capacity of N to avoid resizing
ArrayList<Integer> membership = new ArrayList<>(N);

Expand All @@ -3985,25 +3985,272 @@ int[] createRandomMemberships(int N, int K) {
}

/**
* For the genetic algorithm: define a class that represents pairs of two
* For the genetic algorithm: Define a class that represents pairs of two
* indices of membership bits (i.e., index of the first node and index of
* the second node in a membership solution, with a maximum of N nodes).
*/
class MembershipPair {
int firstIndex;
int secondIndex;

private class MembershipPair {
int firstIndex; // Index of the first member
int secondIndex; // Index of the second member

/**
* Constructs a MembershipPair with the specified indices.
*
* @param firstIndex the index of the first member
* @param secondIndex the index of the second member
*/
public MembershipPair(int firstIndex, int secondIndex) {
this.firstIndex = firstIndex;
this.secondIndex = secondIndex;
}

/**
* Returns the first index.
*
* @return the first index
*/
public int getFirstIndex() {
return this.firstIndex;
}

/**
* Retrieves the value of the second index.
*
* @return the value of the second index.
*/
public int getSecondIndex() {
return this.secondIndex;
}
}

/**
* This class represents a cluster solution in the genetic algorithm,
* including the membership vector, which contains information on cluster
* membership for each node in the network. It also contains the number of
* nodes N and the number of clusters K.
*/
private class ClusterSolution implements Cloneable {

private final int[] memberships; // cluster memberships of all nodes, starting with 0
private final int N; // number of nodes
private final int K; // number of clusters

/**
* Constructs a ClusterSolution with the specified parameters.
*
* @param n The number of nodes (must be positive).
* @param k The number of clusters (must be positive).
* @param memberships The membership vector (length must equal n, values must be in the range [0, k - 1]).
* @throws IllegalArgumentException If any parameter is invalid.
*/
public ClusterSolution(int n, int k, int[] memberships) {
if (n <= 0) {
throw new IllegalArgumentException("N must be positive.");
}
if (k <= 1) {
throw new IllegalArgumentException("K must be larger than 1.");
}
if (n <= k) {
throw new IllegalArgumentException("N must be larger than K.");
}
if (memberships == null || memberships.length != n) {
throw new IllegalArgumentException("Memberships must have length equal to N.");
}
validateMemberships(memberships, k);
this.N = n;
this.K = k;
this.memberships = memberships.clone(); // defensive copy to avoid external modification
}

/**
* Constructs a ClusterSolution with random memberships.
*
* @param n The number of nodes (must be positive).
* @param k The number of clusters (must be positive).
* @throws IllegalArgumentException If any parameter is invalid.
*/
public ClusterSolution(int n, int k) {
if (n <= 0) {
throw new IllegalArgumentException("N must be positive.");
}
if (k <= 1) {
throw new IllegalArgumentException("K must be larger than 1.");
}
if (n <= k) {
throw new IllegalArgumentException("N must be larger than K.");
}
this.N = n;
this.K = k;
this.memberships = createRandomMemberships(n, k);
}

/**
* Returns the number of clusters.
*
* @return The number of clusters.
*/
public int getK() {
return K;
}

/**
* Returns a copy of the membership vector.
*
* @return A copy of the membership vector.
*/
public int[] getMemberships() {
return memberships.clone(); // defensive copy to avoid external modification
}

/**
* Creates a deep clone of this ClusterSolution.
*
* @return A deep clone of this object.
* @throws CloneNotSupportedException If the object cannot be cloned.
*/
protected ClusterSolution clone() throws CloneNotSupportedException {
return new ClusterSolution(this.N, this.K, this.memberships.clone());
}

/**
* Validates that all memberships are within the range [0, K - 1].
*/
private void validateMemberships(int[] memberships, int k) {
for (int membership : memberships) {
if (membership < 0 || membership >= k) {
throw new IllegalArgumentException("Membership values must be in the range [0, K - 1].");
}
}
}

/**
* Cross-over breeding. Combines the membership vectors of the current solution
* and a foreign solution to produce an offspring with balanced cluster distribution.
*
* @param foreignMemberships A membership vector of a foreign cluster solution.
* @throws IllegalArgumentException If the input vector is invalid or incompatible.
*/
public int[] crossOver(int[] foreignMemberships) {
// Validate input
if (foreignMemberships == null || foreignMemberships.length != this.memberships.length) {
throw new IllegalArgumentException("Incompatible membership vector lengths.");
}
validateMemberships(foreignMemberships, K);

// Step 1: Relabel clusters to align with maximum overlap
int[][] overlapMatrix = calculateOverlapMatrix(this.memberships, foreignMemberships, K);
int[] newMemberships = performRelabeling(this.memberships, foreignMemberships, overlapMatrix);

// Step 2: Perform random crossover between relabeled membership vectors
newMemberships = performCrossover(newMemberships, foreignMemberships);

// Step 3: Adjust cluster distribution to achieve balance
newMemberships = balanceClusterDistribution(newMemberships, K);
return newMemberships;
}

/**
* Calculates the overlap matrix between two membership vectors.
*/
private int[][] calculateOverlapMatrix(int[] memberships1, int[] memberships2, int k) {
int[][] matrix = new int[k][k];
for (int i = 0; i < memberships1.length; i++) {
matrix[memberships1[i]][memberships2[i]]++;
}
return matrix;
}

/**
* Relabels clusters to maximize overlap between two membership vectors.
*/
private int[] performRelabeling(int[] memberships1, int[] memberships2, int[][] overlapMatrix) {
int k = overlapMatrix.length; // Number of clusters
int[] relabelMap = new int[k]; // Map from original cluster to new cluster

// For each row, find the column with the maximum overlap
boolean[] assigned = new boolean[k]; // Track assigned columns
Arrays.fill(relabelMap, -1);

for (int row = 0; row < k; row++) {
double[] rowValues = new double[k];
for (int col = 0; col < k; col++) {
rowValues[col] = overlapMatrix[row][col];
}

// Use `calculateRanks` to rank columns for this row
int[] ranks = calculateRanks(rowValues);

// Assign the best column for this row
for (int rank = 0; rank < ranks.length; rank++) {
int col = ranks[rank];
if (!assigned[col]) {
relabelMap[row] = col;
assigned[col] = true;
break;
}
}
}

// Apply the relabeling map to the original memberships
int[] relabeledMemberships = new int[memberships1.length];
for (int i = 0; i < memberships1.length; i++) {
relabeledMemberships[i] = relabelMap[memberships1[i]];
}

return relabeledMemberships;
}

/**
* Performs crossover by randomly combining bits from two membership vectors.
*/
private int[] performCrossover(int[] memberships1, int[] memberships2) {
Random rand = new Random(); // Optionally pass a seed here for reproducibility
int[] result = new int[memberships1.length];
for (int i = 0; i < memberships1.length; i++) {
result[i] = (rand.nextBoolean()) ? memberships1[i] : memberships2[i];
}
return result;
}

/**
* Balances cluster distribution by adjusting over- and under-represented clusters.
*/
private int[] balanceClusterDistribution(int[] memberships, int k) {
int[] counts = new int[k];
List<List<Integer>> clusterIndices = new ArrayList<>();
for (int i = 0; i < k; i++) clusterIndices.add(new ArrayList<>());

// Populate counts and cluster indices
for (int i = 0; i < memberships.length; i++) {
counts[memberships[i]]++;
clusterIndices.get(memberships[i]).add(i);
}

// Compute target sizes
int base = memberships.length / k; // Base size for each cluster
int extra = memberships.length % k; // Number of clusters with one extra element
int[] maxAllowed = new int[k];
for (int i = 0; i < k; i++) {
maxAllowed[i] = base + ((i < extra) ? 1 : 0); // Add 1 to the first 'extra' clusters (due to modulo)
}

// Balance clusters
for (int i = 0; i < k; i++) {
while (counts[i] > maxAllowed[i]) {
for (int j = 0; j < k; j++) {
if (counts[j] < maxAllowed[j]) {
// Move an element from cluster i to cluster j
int idx = clusterIndices.get(i).remove(0);
memberships[idx] = j; // Swap cluster membership from i to j
counts[i]--;
counts[j]++;
clusterIndices.get(j).add(idx); // Mark index as belonging to cluster j
break;
}
}
}
}
return memberships;
}
}
}

0 comments on commit 8b2cc50

Please sign in to comment.