diff --git a/dna/src/main/java/dna/export/Polarisation.java b/dna/src/main/java/dna/export/Polarisation.java index 6a6d8e63..15dd6be3 100644 --- a/dna/src/main/java/dna/export/Polarisation.java +++ b/dna/src/main/java/dna/export/Polarisation.java @@ -184,9 +184,10 @@ public PolarisationResultTimeSeries getResults() { * @param congruenceNetwork A 2D array representing the congruence network. * @param conflictNetwork A 2D array representing the conflict network. * @param normaliseScores Should the result be divided by its theoretical maximum (the sum of the two matrix norms)? + * @param numClusters The number of clusters. * @return The quality of polarization as a double value. */ - private double qualityAbsdiff(int[] memberships, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise) { + private double qualityAbsdiff(int[] memberships, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, int numClusters) { double congruenceNorm = calculateMatrixNorm(congruenceNetwork); double conflictNorm = calculateMatrixNorm(conflictNetwork); @@ -586,10 +587,11 @@ private class GeneticIteration { * @param congruenceNetwork The congruence matrix. * @param conflictNetwork The conflict matrix. * @param normalise Should the quality/fitness scores be normalised? + * @param numClusters The number of clusters. * @param rng The random number generator to use. * @return A list of children cluster solutions. */ - GeneticIteration(ArrayList clusterSolutions, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, Random rng) { + GeneticIteration(ArrayList clusterSolutions, double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, int numClusters, Random rng) { this.clusterSolutions = new ArrayList<>(clusterSolutions); this.normalise = normalise; this.congruenceNetwork = congruenceNetwork.clone(); @@ -608,7 +610,7 @@ private class GeneticIteration { "Number of mutations based on the mutation percentage."); Dna.logger.log(log); - this.q = evaluateQuality(this.congruenceNetwork, this.conflictNetwork, this.normalise); + this.q = evaluateQuality(this.congruenceNetwork, this.conflictNetwork, normalise, numClusters); this.children = eliteRetentionStep(this.clusterSolutions, this.q, this.numElites); this.children = crossoverStep(this.clusterSolutions, this.q, this.children, rng); this.children = mutationStep(this.children, this.numMutations, this.n, rng); @@ -622,13 +624,14 @@ private class GeneticIteration { * @param congruenceNetwork The congruence network matrix. * @param conflictNetwork The conflict network matrix. * @param normalise Normalise the results? + * @param numClusters The number of clusters. * @return An array of quality scores for each cluster solution. */ - private double[] evaluateQuality(double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise) { + private double[] evaluateQuality(double[][] congruenceNetwork, double[][] conflictNetwork, boolean normalise, int numClusters) { double[] q = new double[clusterSolutions.size()]; for (int i = 0; i < clusterSolutions.size(); i++) { int[] mem = clusterSolutions.get(i).getMemberships(); - q[i] = qualityAbsdiff(mem, congruenceNetwork, conflictNetwork, normalise); + q[i] = qualityAbsdiff(mem, congruenceNetwork, conflictNetwork, normalise, numClusters); } return q; } @@ -846,7 +849,7 @@ public PolarisationResultTimeSeries geneticAlgorithm () { // Run through iterations and do the breeding, then collect results and stats lastIndex = numIterations - 1; // choose last possible value here as a default if early convergence does not happen for (int i = 0; i < numIterations; i++) { - GeneticIteration geneticIteration = new GeneticIteration(cs, this.congruence.get(t).getMatrix(), this.conflict.get(t).getMatrix(), this.normaliseScores, rng); + GeneticIteration geneticIteration = new GeneticIteration(cs, this.congruence.get(t).getMatrix(), this.conflict.get(t).getMatrix(), this.normaliseScores, this.numClusters, rng); cs = geneticIteration.getChildren(); // compute summary statistics based on iteration step and retain them @@ -1279,105 +1282,116 @@ private ArrayList[][][] create3dArray(String[] var1Values, Stri /** * Prepare the greedy membership swapping algorithm and run all the iterations. * Take out the maximum quality measure at the last step and create an object - * that stores the polarisation results. + * that stores the polarisation results. Run the algorithm in parallel for all + * time windows. */ private PolarisationResultTimeSeries greedyAlgorithm () { Random rng = (this.randomSeed == 0) ? new Random() : new Random(this.randomSeed); // Initialize random number generator - ArrayList polarisationResults = new ArrayList(); + + ArrayList polarisationResults = ProgressBar + .wrap(IntStream.range(0, Polarisation.this.congruence.size()).parallel(), "Greedy algorithm") + .map(t -> greedyTimeStep(Polarisation.this.congruence.get(t), + Polarisation.this.conflict.get(t), + Polarisation.this.normaliseScores, + Polarisation.this.numClusters, + rng.nextLong())) + .collect(Collectors.toCollection(ArrayList::new)); + + PolarisationResultTimeSeries polarisationResultTimeSeries = new PolarisationResultTimeSeries(polarisationResults); + return polarisationResultTimeSeries; + } + /** + * A single run of the greedy algorithm, for one pair of congruence and conflict + * network, i.e., for one time slice. + * + * @param congruence A Matrix object containing the 2D congruence array. + * @param conflict A Matrix object containing the 2D conflict array. + * @param normaliseScores Normalise the absdiff quality/fitness scores to 1.0? + * @param numClusters The number of clusters. + * @param seed A random seed, which is used to create a new random number generator for this algorithm run. The seed should have been itself generated by a random number generator to ensure variability across time steps and reproducibility. + * @return a PolarisationResult object + */ + private PolarisationResult greedyTimeStep(Matrix congruence, Matrix conflict, boolean normaliseScores, int numClusters, long seed) { // for each time step, run the algorithm over the cluster solutions; retain quality and memberships - double[][] congruenceMatrix, conflictMatrix; - int t, oldI, oldJ; + double[][] congruenceMatrix = congruence.getMatrix(); + double[][] conflictMatrix = conflict.getMatrix(); ArrayList maxQArray = new ArrayList(); - int[] bestMemberships, mem, mem2; - double maxQ, q1, q2; - boolean noChanges; - - try (ProgressBar pb = new ProgressBar("Greedy algorithm", this.congruence.size())) { - for (t = 0; t < congruence.size(); t++) { // go through all time steps of the time window networks - maxQArray.clear(); - congruenceMatrix = congruence.get(t).getMatrix(); - conflictMatrix = conflict.get(t).getMatrix(); - double combinedNorm = calculateMatrixNorm(congruenceMatrix) + calculateMatrixNorm(congruenceMatrix); - - if (congruenceMatrix.length > 0 || combinedNorm == 0.0) { // if the network has no nodes or edges, skip this step and return 0 directly - - // Create initially random cluster solution to update - ClusterSolution cs = new ClusterSolution(congruence.get(t).getMatrix().length, numClusters, rng); - mem = cs.getMemberships(); - - // evaluate quality of initial solution - maxQArray.add(qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, this.normaliseScores)); - bestMemberships = mem.clone(); - maxQ = maxQArray.get(0); - - boolean convergence = false; - while (!convergence) { // run the two nested for-loops repeatedly until there are no more swaps - noChanges = true; - for (int i = 0; i < mem.length; i++) { - for (int j = 1; j < mem.length; j++) { // swap positions i and j in the membership vector and see if leads to higher fitness - if (i < j && mem[i] != mem[j]) { - mem2 = mem.clone(); - oldI = mem2[i]; - oldJ = mem2[j]; - mem2[i] = oldJ; - mem2[j] = oldI; - q1 = qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, this.normaliseScores); - q2 = qualityAbsdiff(mem2, congruenceMatrix, conflictMatrix, this.normaliseScores); - if (q2 > q1) { // candidate solution has higher fitness -> keep it - mem = mem2.clone(); // accept the new solution if it was better than the previous - maxQArray.add(q2); - maxQ = q2; - bestMemberships = mem.clone(); - noChanges = false; - } - } + double combinedNorm = calculateMatrixNorm(congruenceMatrix) + calculateMatrixNorm(congruenceMatrix); + + if (congruenceMatrix.length > 0 || combinedNorm == 0.0) { // if the network has no nodes or edges, skip this step and return 0 directly + + // Create initially random cluster solution to update + Random random = new Random(seed); + ClusterSolution cs = new ClusterSolution(congruenceMatrix.length, numClusters, random); + int[] mem = cs.getMemberships(); + + // evaluate quality of initial solution + maxQArray.add(qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, normaliseScores, numClusters)); + int[] bestMemberships = mem.clone(); + double maxQ = maxQArray.get(0); + + boolean convergence = false; + while (!convergence) { // run the two nested for-loops repeatedly until there are no more swaps + boolean noChanges = true; + for (int i = 0; i < mem.length; i++) { + for (int j = 1; j < mem.length; j++) { // swap positions i and j in the membership vector and see if leads to higher fitness + if (i < j && mem[i] != mem[j]) { + int[] mem2 = mem.clone(); + int oldI = mem2[i]; + int oldJ = mem2[j]; + mem2[i] = oldJ; + mem2[j] = oldI; + double q1 = qualityAbsdiff(mem, congruenceMatrix, conflictMatrix, normaliseScores, numClusters); + double q2 = qualityAbsdiff(mem2, congruenceMatrix, conflictMatrix, normaliseScores, numClusters); + if (q2 > q1) { // candidate solution has higher fitness -> keep it + mem = mem2.clone(); // accept the new solution if it was better than the previous + maxQArray.add(q2); + maxQ = q2; + bestMemberships = mem.clone(); + noChanges = false; } } - if (noChanges) { - convergence = true; - } - } - - double[] maxQArray2 = new double[maxQArray.size()]; - for (int i = 0; i < maxQArray.size(); i++) { - maxQArray2[i] = maxQArray.get(i); } - - // save results in array as a complex object - double[] avgQArray = maxQArray2; - double[] sdQArray = new double[maxQArray.size()]; - PolarisationResult pr = new PolarisationResult( - maxQArray2, - avgQArray, - sdQArray, - maxQ, - bestMemberships, - congruence.get(t).getRowNames(), - true, - congruence.get(t).getStart(), - congruence.get(t).getStop(), - congruence.get(t).getDateTime()); - polarisationResults.add(pr); - } else { // zero result because network is empty - PolarisationResult pr = new PolarisationResult( - new double[] { 0 }, - new double[] { 0 }, - new double[] { 0 }, - 0.0, - new int[0], - new String[0], - true, - congruence.get(t).getStart(), - congruence.get(t).getStop(), - congruence.get(t).getDateTime()); - polarisationResults.add(pr); } - pb.step(); + if (noChanges) { + convergence = true; + } } - } - PolarisationResultTimeSeries polarisationResultTimeSeries = new PolarisationResultTimeSeries(polarisationResults); - return polarisationResultTimeSeries; + double[] maxQArray2 = new double[maxQArray.size()]; + for (int i = 0; i < maxQArray.size(); i++) { + maxQArray2[i] = maxQArray.get(i); + } + + // save results in array as a complex object + double[] avgQArray = maxQArray2; + double[] sdQArray = new double[maxQArray.size()]; + PolarisationResult pr = new PolarisationResult( + maxQArray2, + avgQArray, + sdQArray, + maxQ, + bestMemberships, + congruence.getRowNames(), + true, + congruence.getStart(), + congruence.getStop(), + congruence.getDateTime()); + return pr; + } else { // zero result because network is empty + PolarisationResult pr = new PolarisationResult( + new double[] { 0 }, + new double[] { 0 }, + new double[] { 0 }, + 0.0, + new int[0], + new String[0], + true, + congruence.getStart(), + congruence.getStop(), + congruence.getDateTime()); + return pr; + } } } \ No newline at end of file