diff --git a/CHANGELOG.md b/CHANGELOG.md
index b3c9e73159de7..5fd86c76fe6ab 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Add a counter to node stat (and _cat/shards) api to track shard going from idle to non-idle ([#12768](https://github.com/opensearch-project/OpenSearch/pull/12768))
 - [Concurrent Segment Search] Disable concurrent segment search for system indices and throttled requests ([#12954](https://github.com/opensearch-project/OpenSearch/pull/12954))
 - Detect breaking changes on pull requests ([#9044](https://github.com/opensearch-project/OpenSearch/pull/9044))
+- Add cluster primary balance contraint for rebalancing with buffer ([#12656](https://github.com/opensearch-project/OpenSearch/pull/12656))
 
 ### Dependencies
 - Bump `org.apache.commons:commons-configuration2` from 2.10.0 to 2.10.1 ([#12896](https://github.com/opensearch-project/OpenSearch/pull/12896))
diff --git a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationAllocationIT.java b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationAllocationIT.java
index 30edea6551067..669e24f9fb555 100644
--- a/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationAllocationIT.java
+++ b/server/src/internalClusterTest/java/org/opensearch/indices/replication/SegmentReplicationAllocationIT.java
@@ -31,6 +31,9 @@
 import java.util.stream.Collectors;
 
 import static org.opensearch.cluster.routing.ShardRoutingState.STARTED;
+import static org.opensearch.cluster.routing.allocation.allocator.BalancedShardsAllocator.PREFER_PRIMARY_SHARD_BALANCE;
+import static org.opensearch.cluster.routing.allocation.allocator.BalancedShardsAllocator.PREFER_PRIMARY_SHARD_REBALANCE;
+import static org.opensearch.cluster.routing.allocation.allocator.BalancedShardsAllocator.PRIMARY_SHARD_REBALANCE_BUFFER;
 import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
 
 @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 0)
@@ -58,6 +61,20 @@ public void enablePreferPrimaryBalance() {
         );
     }
 
+    public void setAllocationRelocationStrategy(boolean preferPrimaryBalance, boolean preferPrimaryRebalance, float buffer) {
+        assertAcked(
+            client().admin()
+                .cluster()
+                .prepareUpdateSettings()
+                .setPersistentSettings(
+                    Settings.builder()
+                        .put(PREFER_PRIMARY_SHARD_BALANCE.getKey(), preferPrimaryBalance)
+                        .put(PREFER_PRIMARY_SHARD_REBALANCE.getKey(), preferPrimaryRebalance)
+                        .put(PRIMARY_SHARD_REBALANCE_BUFFER.getKey(), buffer)
+                )
+        );
+    }
+
     /**
      * This test verifies that the overall primary balance is attained during allocation. This test verifies primary
      * balance per index and across all indices is maintained.
@@ -87,7 +104,7 @@ public void testGlobalPrimaryAllocation() throws Exception {
         state = client().admin().cluster().prepareState().execute().actionGet().getState();
         logger.info(ShardAllocations.printShardDistribution(state));
         verifyPerIndexPrimaryBalance();
-        verifyPrimaryBalance();
+        verifyPrimaryBalance(0.0f);
     }
 
     /**
@@ -224,6 +241,70 @@ public void testAllocationWithDisruption() throws Exception {
         verifyPerIndexPrimaryBalance();
     }
 
+    /**
+     * Similar to testSingleIndexShardAllocation test but creates multiple indices, multiple nodes adding in and getting
+     * removed. The test asserts post each such event that primary shard distribution is balanced for each index as well as across the nodes
+     * when the PREFER_PRIMARY_SHARD_REBALANCE is set to true
+     */
+    public void testAllocationAndRebalanceWithDisruption() throws Exception {
+        internalCluster().startClusterManagerOnlyNode();
+        final int maxReplicaCount = 2;
+        final int maxShardCount = 2;
+        // Create higher number of nodes than number of shards to reduce chances of SameShardAllocationDecider kicking-in
+        // and preventing primary relocations
+        final int nodeCount = randomIntBetween(5, 10);
+        final int numberOfIndices = randomIntBetween(1, 10);
+        final float buffer = randomIntBetween(1, 4) * 0.10f;
+
+        logger.info("--> Creating {} nodes", nodeCount);
+        final List<String> nodeNames = new ArrayList<>();
+        for (int i = 0; i < nodeCount; i++) {
+            nodeNames.add(internalCluster().startNode());
+        }
+        setAllocationRelocationStrategy(true, true, buffer);
+
+        int shardCount, replicaCount;
+        ClusterState state;
+        for (int i = 0; i < numberOfIndices; i++) {
+            shardCount = randomIntBetween(1, maxShardCount);
+            replicaCount = randomIntBetween(1, maxReplicaCount);
+            logger.info("--> Creating index test{} with primary {} and replica {}", i, shardCount, replicaCount);
+            createIndex("test" + i, shardCount, replicaCount, i % 2 == 0);
+            ensureGreen(TimeValue.timeValueSeconds(60));
+            if (logger.isTraceEnabled()) {
+                state = client().admin().cluster().prepareState().execute().actionGet().getState();
+                logger.info(ShardAllocations.printShardDistribution(state));
+            }
+        }
+        state = client().admin().cluster().prepareState().execute().actionGet().getState();
+        logger.info(ShardAllocations.printShardDistribution(state));
+        verifyPerIndexPrimaryBalance();
+        verifyPrimaryBalance(buffer);
+
+        final int additionalNodeCount = randomIntBetween(1, 5);
+        logger.info("--> Adding {} nodes", additionalNodeCount);
+
+        internalCluster().startNodes(additionalNodeCount);
+        ensureGreen(TimeValue.timeValueSeconds(60));
+        state = client().admin().cluster().prepareState().execute().actionGet().getState();
+        logger.info(ShardAllocations.printShardDistribution(state));
+        verifyPerIndexPrimaryBalance();
+        verifyPrimaryBalance(buffer);
+
+        int nodeCountToStop = additionalNodeCount;
+        while (nodeCountToStop > 0) {
+            internalCluster().stopRandomDataNode();
+            // give replica a chance to promote as primary before terminating node containing the replica
+            ensureGreen(TimeValue.timeValueSeconds(60));
+            nodeCountToStop--;
+        }
+        state = client().admin().cluster().prepareState().execute().actionGet().getState();
+        logger.info("--> Cluster state post nodes stop {}", state);
+        logger.info(ShardAllocations.printShardDistribution(state));
+        verifyPerIndexPrimaryBalance();
+        verifyPrimaryBalance(buffer);
+    }
+
     /**
      * Utility method which ensures cluster has balanced primary shard distribution across a single index.
      * @throws Exception exception
@@ -263,7 +344,7 @@ private void verifyPerIndexPrimaryBalance() throws Exception {
         }, 60, TimeUnit.SECONDS);
     }
 
-    private void verifyPrimaryBalance() throws Exception {
+    private void verifyPrimaryBalance(float buffer) throws Exception {
         assertBusy(() -> {
             final ClusterState currentState = client().admin().cluster().prepareState().execute().actionGet().getState();
             RoutingNodes nodes = currentState.getRoutingNodes();
@@ -278,7 +359,7 @@ private void verifyPrimaryBalance() throws Exception {
                     .filter(ShardRouting::primary)
                     .collect(Collectors.toList())
                     .size();
-                assertTrue(primaryCount <= avgPrimaryShardsPerNode);
+                assertTrue(primaryCount <= (avgPrimaryShardsPerNode * (1 + buffer)));
             }
         }, 60, TimeUnit.SECONDS);
     }
diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationConstraints.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationConstraints.java
index 5375910c57579..6702db4b43e91 100644
--- a/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationConstraints.java
+++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationConstraints.java
@@ -30,9 +30,9 @@ public class AllocationConstraints {
 
     public AllocationConstraints() {
         this.constraints = new HashMap<>();
-        this.constraints.putIfAbsent(INDEX_SHARD_PER_NODE_BREACH_CONSTRAINT_ID, new Constraint(isIndexShardsPerNodeBreached()));
-        this.constraints.putIfAbsent(INDEX_PRIMARY_SHARD_BALANCE_CONSTRAINT_ID, new Constraint(isPerIndexPrimaryShardsPerNodeBreached()));
-        this.constraints.putIfAbsent(CLUSTER_PRIMARY_SHARD_BALANCE_CONSTRAINT_ID, new Constraint(isPrimaryShardsPerNodeBreached()));
+        this.constraints.put(INDEX_SHARD_PER_NODE_BREACH_CONSTRAINT_ID, new Constraint(isIndexShardsPerNodeBreached()));
+        this.constraints.put(INDEX_PRIMARY_SHARD_BALANCE_CONSTRAINT_ID, new Constraint(isPerIndexPrimaryShardsPerNodeBreached()));
+        this.constraints.put(CLUSTER_PRIMARY_SHARD_BALANCE_CONSTRAINT_ID, new Constraint(isPrimaryShardsPerNodeBreached(0.0f)));
     }
 
     public void updateAllocationConstraint(String constraint, boolean enable) {
diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/ConstraintTypes.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/ConstraintTypes.java
index ae2d4a0926194..08fe8f92d1f80 100644
--- a/server/src/main/java/org/opensearch/cluster/routing/allocation/ConstraintTypes.java
+++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/ConstraintTypes.java
@@ -28,6 +28,11 @@ public class ConstraintTypes {
      */
     public final static String CLUSTER_PRIMARY_SHARD_BALANCE_CONSTRAINT_ID = "cluster.primary.shard.balance.constraint";
 
+    /**
+     * Defines a cluster constraint which is breached when a node contains more than avg primary shards across all indices
+     */
+    public final static String CLUSTER_PRIMARY_SHARD_REBALANCE_CONSTRAINT_ID = "cluster.primary.shard.rebalance.constraint";
+
     /**
      * Defines an index constraint which is breached when a node contains more than avg number of shards for an index
      */
@@ -70,14 +75,14 @@ public static Predicate<Constraint.ConstraintParams> isPerIndexPrimaryShardsPerN
     }
 
     /**
-     * Defines a predicate which returns true when a node contains more than average number of primary shards. This
-     * constraint is used in weight calculation during allocation only. When breached a high weight {@link ConstraintTypes#CONSTRAINT_WEIGHT}
-     * is assigned to node resulting in lesser chances of node being selected as allocation target
+     * Defines a predicate which returns true when a node contains more than average number of primary shards with added buffer. This
+     * constraint is used in weight calculation during allocation/rebalance both. When breached a high weight {@link ConstraintTypes#CONSTRAINT_WEIGHT}
+     * is assigned to node resulting in lesser chances of node being selected as allocation/rebalance target
      */
-    public static Predicate<Constraint.ConstraintParams> isPrimaryShardsPerNodeBreached() {
+    public static Predicate<Constraint.ConstraintParams> isPrimaryShardsPerNodeBreached(float buffer) {
         return (params) -> {
             int primaryShardCount = params.getNode().numPrimaryShards();
-            int allowedPrimaryShardCount = (int) Math.ceil(params.getBalancer().avgPrimaryShardsPerNode());
+            int allowedPrimaryShardCount = (int) Math.ceil(params.getBalancer().avgPrimaryShardsPerNode() * (1 + buffer));
             return primaryShardCount >= allowedPrimaryShardCount;
         };
     }
diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/RebalanceConstraints.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/RebalanceConstraints.java
index a4036ec47ec0e..2c2138af18abc 100644
--- a/server/src/main/java/org/opensearch/cluster/routing/allocation/RebalanceConstraints.java
+++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/RebalanceConstraints.java
@@ -14,8 +14,10 @@
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.opensearch.cluster.routing.allocation.ConstraintTypes.CLUSTER_PRIMARY_SHARD_REBALANCE_CONSTRAINT_ID;
 import static org.opensearch.cluster.routing.allocation.ConstraintTypes.INDEX_PRIMARY_SHARD_BALANCE_CONSTRAINT_ID;
 import static org.opensearch.cluster.routing.allocation.ConstraintTypes.isPerIndexPrimaryShardsPerNodeBreached;
+import static org.opensearch.cluster.routing.allocation.ConstraintTypes.isPrimaryShardsPerNodeBreached;
 
 /**
  * Constraints applied during rebalancing round; specify conditions which, if breached, reduce the
@@ -27,9 +29,13 @@ public class RebalanceConstraints {
 
     private Map<String, Constraint> constraints;
 
-    public RebalanceConstraints() {
+    public RebalanceConstraints(RebalanceParameter rebalanceParameter) {
         this.constraints = new HashMap<>();
-        this.constraints.putIfAbsent(INDEX_PRIMARY_SHARD_BALANCE_CONSTRAINT_ID, new Constraint(isPerIndexPrimaryShardsPerNodeBreached()));
+        this.constraints.put(INDEX_PRIMARY_SHARD_BALANCE_CONSTRAINT_ID, new Constraint(isPerIndexPrimaryShardsPerNodeBreached()));
+        this.constraints.put(
+            CLUSTER_PRIMARY_SHARD_REBALANCE_CONSTRAINT_ID,
+            new Constraint(isPrimaryShardsPerNodeBreached(rebalanceParameter.getPreferPrimaryBalanceBuffer()))
+        );
     }
 
     public void updateRebalanceConstraint(String constraint, boolean enable) {
diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/RebalanceParameter.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/RebalanceParameter.java
new file mode 100644
index 0000000000000..35fbaede93ba3
--- /dev/null
+++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/RebalanceParameter.java
@@ -0,0 +1,24 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.cluster.routing.allocation;
+
+/**
+ * RebalanceConstraint Params
+ */
+public class RebalanceParameter {
+    private float preferPrimaryBalanceBuffer;
+
+    public RebalanceParameter(float preferPrimaryBalanceBuffer) {
+        this.preferPrimaryBalanceBuffer = preferPrimaryBalanceBuffer;
+    }
+
+    public float getPreferPrimaryBalanceBuffer() {
+        return preferPrimaryBalanceBuffer;
+    }
+}
diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java
index 41ace0e7661fe..b2443490dd973 100644
--- a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java
+++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/BalancedShardsAllocator.java
@@ -46,6 +46,7 @@
 import org.opensearch.cluster.routing.allocation.ConstraintTypes;
 import org.opensearch.cluster.routing.allocation.MoveDecision;
 import org.opensearch.cluster.routing.allocation.RebalanceConstraints;
+import org.opensearch.cluster.routing.allocation.RebalanceParameter;
 import org.opensearch.cluster.routing.allocation.RoutingAllocation;
 import org.opensearch.cluster.routing.allocation.ShardAllocationDecision;
 import org.opensearch.common.inject.Inject;
@@ -61,6 +62,7 @@
 import java.util.Set;
 
 import static org.opensearch.cluster.routing.allocation.ConstraintTypes.CLUSTER_PRIMARY_SHARD_BALANCE_CONSTRAINT_ID;
+import static org.opensearch.cluster.routing.allocation.ConstraintTypes.CLUSTER_PRIMARY_SHARD_REBALANCE_CONSTRAINT_ID;
 import static org.opensearch.cluster.routing.allocation.ConstraintTypes.INDEX_PRIMARY_SHARD_BALANCE_CONSTRAINT_ID;
 import static org.opensearch.cluster.routing.allocation.ConstraintTypes.INDEX_SHARD_PER_NODE_BREACH_CONSTRAINT_ID;
 
@@ -145,10 +147,29 @@ public class BalancedShardsAllocator implements ShardsAllocator {
         Property.NodeScope
     );
 
+    public static final Setting<Boolean> PREFER_PRIMARY_SHARD_REBALANCE = Setting.boolSetting(
+        "cluster.routing.allocation.rebalance.primary.enable",
+        false,
+        Property.Dynamic,
+        Property.NodeScope
+    );
+
+    public static final Setting<Float> PRIMARY_SHARD_REBALANCE_BUFFER = Setting.floatSetting(
+        "cluster.routing.allocation.rebalance.primary.buffer",
+        0.10f,
+        0.0f,
+        Property.Dynamic,
+        Property.NodeScope
+    );
+
     private volatile boolean movePrimaryFirst;
     private volatile ShardMovementStrategy shardMovementStrategy;
 
     private volatile boolean preferPrimaryShardBalance;
+    private volatile boolean preferPrimaryShardRebalance;
+    private volatile float preferPrimaryShardRebalanceBuffer;
+    private volatile float indexBalanceFactor;
+    private volatile float shardBalanceFactor;
     private volatile WeightFunction weightFunction;
     private volatile float threshold;
 
@@ -158,14 +179,21 @@ public BalancedShardsAllocator(Settings settings) {
 
     @Inject
     public BalancedShardsAllocator(Settings settings, ClusterSettings clusterSettings) {
-        setWeightFunction(INDEX_BALANCE_FACTOR_SETTING.get(settings), SHARD_BALANCE_FACTOR_SETTING.get(settings));
+        setShardBalanceFactor(SHARD_BALANCE_FACTOR_SETTING.get(settings));
+        setIndexBalanceFactor(INDEX_BALANCE_FACTOR_SETTING.get(settings));
+        setPreferPrimaryShardRebalanceBuffer(PRIMARY_SHARD_REBALANCE_BUFFER.get(settings));
+        updateWeightFunction();
         setThreshold(THRESHOLD_SETTING.get(settings));
         setPreferPrimaryShardBalance(PREFER_PRIMARY_SHARD_BALANCE.get(settings));
+        setPreferPrimaryShardRebalance(PREFER_PRIMARY_SHARD_REBALANCE.get(settings));
         setShardMovementStrategy(SHARD_MOVEMENT_STRATEGY_SETTING.get(settings));
         clusterSettings.addSettingsUpdateConsumer(PREFER_PRIMARY_SHARD_BALANCE, this::setPreferPrimaryShardBalance);
         clusterSettings.addSettingsUpdateConsumer(SHARD_MOVE_PRIMARY_FIRST_SETTING, this::setMovePrimaryFirst);
         clusterSettings.addSettingsUpdateConsumer(SHARD_MOVEMENT_STRATEGY_SETTING, this::setShardMovementStrategy);
-        clusterSettings.addSettingsUpdateConsumer(INDEX_BALANCE_FACTOR_SETTING, SHARD_BALANCE_FACTOR_SETTING, this::setWeightFunction);
+        clusterSettings.addSettingsUpdateConsumer(INDEX_BALANCE_FACTOR_SETTING, this::updateIndexBalanceFactor);
+        clusterSettings.addSettingsUpdateConsumer(SHARD_BALANCE_FACTOR_SETTING, this::updateShardBalanceFactor);
+        clusterSettings.addSettingsUpdateConsumer(PRIMARY_SHARD_REBALANCE_BUFFER, this::updatePreferPrimaryShardBalanceBuffer);
+        clusterSettings.addSettingsUpdateConsumer(PREFER_PRIMARY_SHARD_REBALANCE, this::setPreferPrimaryShardRebalance);
         clusterSettings.addSettingsUpdateConsumer(THRESHOLD_SETTING, this::setThreshold);
     }
 
@@ -190,8 +218,35 @@ private void setShardMovementStrategy(ShardMovementStrategy shardMovementStrateg
         }
     }
 
-    private void setWeightFunction(float indexBalance, float shardBalanceFactor) {
-        weightFunction = new WeightFunction(indexBalance, shardBalanceFactor);
+    private void setIndexBalanceFactor(float indexBalanceFactor) {
+        this.indexBalanceFactor = indexBalanceFactor;
+    }
+
+    private void setShardBalanceFactor(float shardBalanceFactor) {
+        this.shardBalanceFactor = shardBalanceFactor;
+    }
+
+    private void setPreferPrimaryShardRebalanceBuffer(float preferPrimaryShardRebalanceBuffer) {
+        this.preferPrimaryShardRebalanceBuffer = preferPrimaryShardRebalanceBuffer;
+    }
+
+    private void updateIndexBalanceFactor(float indexBalanceFactor) {
+        this.indexBalanceFactor = indexBalanceFactor;
+        updateWeightFunction();
+    }
+
+    private void updateShardBalanceFactor(float shardBalanceFactor) {
+        this.shardBalanceFactor = shardBalanceFactor;
+        updateWeightFunction();
+    }
+
+    private void updatePreferPrimaryShardBalanceBuffer(float preferPrimaryShardBalanceBuffer) {
+        this.preferPrimaryShardRebalanceBuffer = preferPrimaryShardBalanceBuffer;
+        updateWeightFunction();
+    }
+
+    private void updateWeightFunction() {
+        weightFunction = new WeightFunction(this.indexBalanceFactor, this.shardBalanceFactor, this.preferPrimaryShardRebalanceBuffer);
     }
 
     /**
@@ -205,6 +260,11 @@ private void setPreferPrimaryShardBalance(boolean preferPrimaryShardBalance) {
         this.weightFunction.updateRebalanceConstraint(INDEX_PRIMARY_SHARD_BALANCE_CONSTRAINT_ID, preferPrimaryShardBalance);
     }
 
+    private void setPreferPrimaryShardRebalance(boolean preferPrimaryShardRebalance) {
+        this.preferPrimaryShardRebalance = preferPrimaryShardRebalance;
+        this.weightFunction.updateRebalanceConstraint(CLUSTER_PRIMARY_SHARD_REBALANCE_CONSTRAINT_ID, preferPrimaryShardRebalance);
+    }
+
     private void setThreshold(float threshold) {
         this.threshold = threshold;
     }
@@ -221,7 +281,8 @@ public void allocate(RoutingAllocation allocation) {
             shardMovementStrategy,
             weightFunction,
             threshold,
-            preferPrimaryShardBalance
+            preferPrimaryShardBalance,
+            preferPrimaryShardRebalance
         );
         localShardsBalancer.allocateUnassigned();
         localShardsBalancer.moveShards();
@@ -242,7 +303,8 @@ public ShardAllocationDecision decideShardAllocation(final ShardRouting shard, f
             shardMovementStrategy,
             weightFunction,
             threshold,
-            preferPrimaryShardBalance
+            preferPrimaryShardBalance,
+            preferPrimaryShardRebalance
         );
         AllocateUnassignedDecision allocateUnassignedDecision = AllocateUnassignedDecision.NOT_TAKEN;
         MoveDecision moveDecision = MoveDecision.NOT_TAKEN;
@@ -348,7 +410,7 @@ static class WeightFunction {
         private AllocationConstraints constraints;
         private RebalanceConstraints rebalanceConstraints;
 
-        WeightFunction(float indexBalance, float shardBalance) {
+        WeightFunction(float indexBalance, float shardBalance, float preferPrimaryBalanceBuffer) {
             float sum = indexBalance + shardBalance;
             if (sum <= 0.0f) {
                 throw new IllegalArgumentException("Balance factors must sum to a value > 0 but was: " + sum);
@@ -357,8 +419,9 @@ static class WeightFunction {
             theta1 = indexBalance / sum;
             this.indexBalance = indexBalance;
             this.shardBalance = shardBalance;
+            RebalanceParameter rebalanceParameter = new RebalanceParameter(preferPrimaryBalanceBuffer);
             this.constraints = new AllocationConstraints();
-            this.rebalanceConstraints = new RebalanceConstraints();
+            this.rebalanceConstraints = new RebalanceConstraints(rebalanceParameter);
             // Enable index shard per node breach constraint
             updateAllocationConstraint(INDEX_SHARD_PER_NODE_BREACH_CONSTRAINT_ID, true);
         }
@@ -495,7 +558,7 @@ public Balancer(
             float threshold,
             boolean preferPrimaryBalance
         ) {
-            super(logger, allocation, shardMovementStrategy, weight, threshold, preferPrimaryBalance);
+            super(logger, allocation, shardMovementStrategy, weight, threshold, preferPrimaryBalance, false);
         }
     }
 
diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java
index 46edd86043ab2..696a83dd624a8 100644
--- a/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java
+++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/allocator/LocalShardsBalancer.java
@@ -61,6 +61,7 @@ public class LocalShardsBalancer extends ShardsBalancer {
     private final ShardMovementStrategy shardMovementStrategy;
 
     private final boolean preferPrimaryBalance;
+    private final boolean preferPrimaryRebalance;
     private final BalancedShardsAllocator.WeightFunction weight;
 
     private final float threshold;
@@ -76,7 +77,8 @@ public LocalShardsBalancer(
         ShardMovementStrategy shardMovementStrategy,
         BalancedShardsAllocator.WeightFunction weight,
         float threshold,
-        boolean preferPrimaryBalance
+        boolean preferPrimaryBalance,
+        boolean preferPrimaryRebalance
     ) {
         this.logger = logger;
         this.allocation = allocation;
@@ -91,6 +93,7 @@ public LocalShardsBalancer(
         sorter = newNodeSorter();
         inEligibleTargetNode = new HashSet<>();
         this.preferPrimaryBalance = preferPrimaryBalance;
+        this.preferPrimaryRebalance = preferPrimaryRebalance;
         this.shardMovementStrategy = shardMovementStrategy;
     }
 
@@ -995,13 +998,18 @@ private boolean tryRelocateShard(BalancedShardsAllocator.ModelNode minNode, Bala
                     continue;
                 }
                 // This is a safety net which prevents un-necessary primary shard relocations from maxNode to minNode when
-                // doing such relocation wouldn't help in primary balance.
+                // doing such relocation wouldn't help in primary balance. The condition won't be applicable when we enable node level
+                // primary rebalance
                 if (preferPrimaryBalance == true
+                    && preferPrimaryRebalance == false
                     && shard.primary()
                     && maxNode.numPrimaryShards(shard.getIndexName()) - minNode.numPrimaryShards(shard.getIndexName()) < 2) {
                     continue;
                 }
-
+                // Relax the above condition to per node to allow rebalancing to attain global balance
+                if (preferPrimaryRebalance == true && shard.primary() && maxNode.numPrimaryShards() - minNode.numPrimaryShards() < 2) {
+                    continue;
+                }
                 final Decision decision = new Decision.Multi().add(allocationDecision).add(rebalanceDecision);
                 maxNode.removeShard(shard);
                 long shardSize = allocation.clusterInfo().getShardSize(shard, ShardRouting.UNAVAILABLE_EXPECTED_SHARD_SIZE);
diff --git a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java
index 7c69a0487b618..92bcbf9e9f3dd 100644
--- a/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java
+++ b/server/src/main/java/org/opensearch/common/settings/ClusterSettings.java
@@ -253,7 +253,9 @@ public void apply(Settings value, Settings current, Settings previous) {
                 AwarenessReplicaBalance.CLUSTER_ROUTING_ALLOCATION_AWARENESS_BALANCE_SETTING,
                 BalancedShardsAllocator.INDEX_BALANCE_FACTOR_SETTING,
                 BalancedShardsAllocator.SHARD_BALANCE_FACTOR_SETTING,
+                BalancedShardsAllocator.PRIMARY_SHARD_REBALANCE_BUFFER,
                 BalancedShardsAllocator.PREFER_PRIMARY_SHARD_BALANCE,
+                BalancedShardsAllocator.PREFER_PRIMARY_SHARD_REBALANCE,
                 BalancedShardsAllocator.SHARD_MOVE_PRIMARY_FIRST_SETTING,
                 BalancedShardsAllocator.SHARD_MOVEMENT_STRATEGY_SETTING,
                 BalancedShardsAllocator.THRESHOLD_SETTING,
diff --git a/server/src/test/java/org/opensearch/cluster/routing/allocation/BalanceConfigurationTests.java b/server/src/test/java/org/opensearch/cluster/routing/allocation/BalanceConfigurationTests.java
index 62dd14e69c402..dc2247f913a0f 100644
--- a/server/src/test/java/org/opensearch/cluster/routing/allocation/BalanceConfigurationTests.java
+++ b/server/src/test/java/org/opensearch/cluster/routing/allocation/BalanceConfigurationTests.java
@@ -72,6 +72,7 @@
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
+import static org.opensearch.cluster.ClusterName.CLUSTER_NAME_SETTING;
 import static org.opensearch.cluster.metadata.IndexMetadata.SETTING_CREATION_DATE;
 import static org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_REPLICAS;
 import static org.opensearch.cluster.metadata.IndexMetadata.SETTING_NUMBER_OF_SHARDS;
@@ -140,10 +141,14 @@ public void testIndexBalance() {
     }
 
     private Settings.Builder getSettingsBuilderForPrimaryBalance() {
-        return getSettingsBuilderForPrimaryBalance(true);
+        return getSettingsBuilderForPrimaryBalance(true, false);
     }
 
-    private Settings.Builder getSettingsBuilderForPrimaryBalance(boolean preferPrimaryBalance) {
+    private Settings.Builder getSettingsBuilderForPrimaryReBalance() {
+        return getSettingsBuilderForPrimaryBalance(true, true);
+    }
+
+    private Settings.Builder getSettingsBuilderForPrimaryBalance(boolean preferPrimaryBalance, boolean preferPrimaryRebalance) {
         final float indexBalance = 0.55f;
         final float shardBalance = 0.45f;
         final float balanceThreshold = 1.0f;
@@ -155,6 +160,7 @@ private Settings.Builder getSettingsBuilderForPrimaryBalance(boolean preferPrima
         );
         settings.put(BalancedShardsAllocator.INDEX_BALANCE_FACTOR_SETTING.getKey(), indexBalance);
         settings.put(BalancedShardsAllocator.PREFER_PRIMARY_SHARD_BALANCE.getKey(), preferPrimaryBalance);
+        settings.put(BalancedShardsAllocator.PREFER_PRIMARY_SHARD_REBALANCE.getKey(), preferPrimaryRebalance);
         settings.put(BalancedShardsAllocator.SHARD_BALANCE_FACTOR_SETTING.getKey(), shardBalance);
         settings.put(BalancedShardsAllocator.THRESHOLD_SETTING.getKey(), balanceThreshold);
         return settings;
@@ -201,7 +207,7 @@ public void testPrimaryBalanceWithoutPreferPrimaryBalanceSetting() {
         int balanceFailed = 0;
 
         AllocationService strategy = createAllocationService(
-            getSettingsBuilderForPrimaryBalance(false).build(),
+            getSettingsBuilderForPrimaryBalance(false, false).build(),
             new TestGatewayAllocator()
         );
         for (int i = 0; i < numberOfRuns; i++) {
@@ -244,6 +250,60 @@ public void testPrimaryBalanceWithPreferPrimaryBalanceSetting() {
         assertTrue(balanceFailed <= 1);
     }
 
+    /**
+     * This test verifies primary shard balance is attained  setting.
+     */
+    public void testPrimaryBalanceNotSolvedForNodeDropWithPreferPrimaryBalanceSetting() {
+        final int numberOfNodes = 4;
+        final int numberOfIndices = 4;
+        final int numberOfShards = 4;
+        final int numberOfReplicas = 1;
+        final int numberOfRuns = 5;
+        final float buffer = 0.10f;
+        int balanceFailed = 0;
+
+        AllocationService strategy = createAllocationService(getSettingsBuilderForPrimaryBalance().build(), new TestGatewayAllocator());
+        for (int i = 0; i < numberOfRuns; i++) {
+            ClusterState clusterState = initCluster(strategy, numberOfIndices, numberOfNodes, numberOfShards, numberOfReplicas);
+            clusterState = removeOneNode(clusterState, strategy);
+            logger.info(ShardAllocations.printShardDistribution(clusterState));
+            try {
+                verifyPrimaryBalance(clusterState, buffer);
+            } catch (AssertionError | Exception e) {
+                balanceFailed++;
+                logger.info("Unexpected assertion failure");
+            }
+        }
+        assertTrue(balanceFailed >= 4);
+    }
+
+    /**
+     * This test verifies primary shard balance is attained with PREFER_PRIMARY_SHARD_BALANCE setting.
+     */
+    public void testPrimaryBalanceSolvedWithPreferPrimaryRebalanceSetting() {
+        final int numberOfNodes = 4;
+        final int numberOfIndices = 4;
+        final int numberOfShards = 4;
+        final int numberOfReplicas = 1;
+        final int numberOfRuns = 5;
+        final float buffer = 0.10f;
+        int balanceFailed = 0;
+
+        AllocationService strategy = createAllocationService(getSettingsBuilderForPrimaryReBalance().build(), new TestGatewayAllocator());
+        for (int i = 0; i < numberOfRuns; i++) {
+            ClusterState clusterState = initCluster(strategy, numberOfIndices, numberOfNodes, numberOfShards, numberOfReplicas);
+            clusterState = removeOneNode(clusterState, strategy);
+            logger.info(ShardAllocations.printShardDistribution(clusterState));
+            try {
+                verifyPrimaryBalance(clusterState, buffer);
+            } catch (Exception e) {
+                balanceFailed++;
+                logger.info("Unexpected assertion failure");
+            }
+        }
+        assertTrue(balanceFailed <= 1);
+    }
+
     /**
      * This test verifies the allocation logic when nodes breach multiple constraints and ensure node breaching min
      * constraints chosen for allocation.
@@ -368,8 +428,7 @@ public void testPrimaryBalanceWithContrainstBreaching() {
      */
     public void testGlobalPrimaryBalance() throws Exception {
         AllocationService strategy = createAllocationService(getSettingsBuilderForPrimaryBalance().build(), new TestGatewayAllocator());
-        ClusterState clusterState = ClusterState.builder(org.opensearch.cluster.ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
-            .build();
+        ClusterState clusterState = ClusterState.builder(CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)).build();
         clusterState = addNode(clusterState, strategy);
         clusterState = addNode(clusterState, strategy);
 
@@ -378,7 +437,30 @@ public void testGlobalPrimaryBalance() throws Exception {
         clusterState = addIndex(clusterState, strategy, "test-index3", 1, 1);
 
         logger.info(ShardAllocations.printShardDistribution(clusterState));
-        verifyPrimaryBalance(clusterState);
+        verifyPrimaryBalance(clusterState, 0.0f);
+    }
+
+    /**
+     * This test verifies global balance by creating indices iteratively and verify primary shards do not pile up on one
+     * @throws Exception generic exception
+     */
+    public void testGlobalPrimaryBalanceWithNodeDrops() throws Exception {
+        final float buffer = 0.10f;
+        AllocationService strategy = createAllocationService(getSettingsBuilderForPrimaryReBalance().build(), new TestGatewayAllocator());
+        ClusterState clusterState = ClusterState.builder(CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY)).build();
+        clusterState = addNodes(clusterState, strategy, 5);
+
+        clusterState = addIndices(clusterState, strategy, 5, 1, 8);
+
+        logger.info(ShardAllocations.printShardDistribution(clusterState));
+        verifyPrimaryBalance(clusterState, buffer);
+
+        clusterState = removeOneNode(clusterState, strategy);
+
+        clusterState = applyAllocationUntilNoChange(clusterState, strategy);
+
+        logger.info(ShardAllocations.printShardDistribution(clusterState));
+        verifyPrimaryBalance(clusterState, buffer);
     }
 
     /**
@@ -538,7 +620,7 @@ private void verifyPerIndexPrimaryBalance(ClusterState currentState) {
         }
     }
 
-    private void verifyPrimaryBalance(ClusterState clusterState) throws Exception {
+    private void verifySkewedPrimaryBalance(ClusterState clusterState, int delta) throws Exception {
         assertBusy(() -> {
             RoutingNodes nodes = clusterState.getRoutingNodes();
             int totalPrimaryShards = 0;
@@ -546,13 +628,36 @@ private void verifyPrimaryBalance(ClusterState clusterState) throws Exception {
                 totalPrimaryShards += index.primaryShardsActive();
             }
             final int avgPrimaryShardsPerNode = (int) Math.ceil(totalPrimaryShards * 1f / clusterState.getRoutingNodes().size());
+            int maxPrimaryShardOnNode = Integer.MIN_VALUE;
+            int minPrimaryShardOnNode = Integer.MAX_VALUE;
             for (RoutingNode node : nodes) {
                 final int primaryCount = node.shardsWithState(STARTED)
                     .stream()
                     .filter(ShardRouting::primary)
                     .collect(Collectors.toList())
                     .size();
-                assertTrue(primaryCount <= avgPrimaryShardsPerNode);
+                maxPrimaryShardOnNode = Math.max(maxPrimaryShardOnNode, primaryCount);
+                minPrimaryShardOnNode = Math.min(minPrimaryShardOnNode, primaryCount);
+            }
+            assertTrue(maxPrimaryShardOnNode - minPrimaryShardOnNode < delta);
+        }, 60, TimeUnit.SECONDS);
+    }
+
+    private void verifyPrimaryBalance(ClusterState clusterState, float buffer) throws Exception {
+        assertBusy(() -> {
+            RoutingNodes nodes = clusterState.getRoutingNodes();
+            int totalPrimaryShards = 0;
+            for (final IndexRoutingTable index : clusterState.getRoutingTable().indicesRouting().values()) {
+                totalPrimaryShards += index.primaryShardsActive();
+            }
+            final int avgPrimaryShardsPerNode = (int) Math.ceil(totalPrimaryShards * 1f / clusterState.getRoutingNodes().size());
+            for (RoutingNode node : nodes) {
+                final int primaryCount = node.shardsWithState(STARTED)
+                    .stream()
+                    .filter(ShardRouting::primary)
+                    .collect(Collectors.toList())
+                    .size();
+                assertTrue(primaryCount <= (avgPrimaryShardsPerNode * (1 + buffer)));
             }
         }, 60, TimeUnit.SECONDS);
     }
@@ -568,8 +673,8 @@ public void testShardBalance() {
             ClusterRebalanceAllocationDecider.CLUSTER_ROUTING_ALLOCATION_ALLOW_REBALANCE_SETTING.getKey(),
             ClusterRebalanceAllocationDecider.ClusterRebalanceType.ALWAYS.toString()
         );
-        settings.put(BalancedShardsAllocator.INDEX_BALANCE_FACTOR_SETTING.getKey(), indexBalance);
         settings.put(BalancedShardsAllocator.SHARD_BALANCE_FACTOR_SETTING.getKey(), shardBalance);
+        settings.put(BalancedShardsAllocator.INDEX_BALANCE_FACTOR_SETTING.getKey(), indexBalance);
         settings.put(BalancedShardsAllocator.THRESHOLD_SETTING.getKey(), balanceThreshold);
 
         AllocationService strategy = createAllocationService(settings.build(), new TestGatewayAllocator());
@@ -635,6 +740,34 @@ private ClusterState addIndex(
         return applyAllocationUntilNoChange(clusterState, strategy);
     }
 
+    private ClusterState addIndices(
+        ClusterState clusterState,
+        AllocationService strategy,
+        int numberOfShards,
+        int numberOfReplicas,
+        int numberOfIndices
+    ) {
+        Metadata.Builder metadataBuilder = Metadata.builder(clusterState.getMetadata());
+        RoutingTable.Builder routingTableBuilder = RoutingTable.builder(clusterState.routingTable());
+
+        for (int i = 0; i < numberOfIndices; i++) {
+            IndexMetadata.Builder index = IndexMetadata.builder("test" + i)
+                .settings(settings(Version.CURRENT))
+                .numberOfShards(numberOfShards)
+                .numberOfReplicas(numberOfReplicas);
+
+            metadataBuilder = metadataBuilder.put(index);
+            routingTableBuilder.addAsNew(index.build());
+        }
+
+        clusterState = ClusterState.builder(clusterState)
+            .metadata(metadataBuilder.build())
+            .routingTable(routingTableBuilder.build())
+            .build();
+        clusterState = strategy.reroute(clusterState, "indices-created");
+        return applyAllocationUntilNoChange(clusterState, strategy);
+    }
+
     private ClusterState initCluster(
         AllocationService strategy,
         int numberOfIndices,
@@ -665,7 +798,7 @@ private ClusterState initCluster(
         for (int i = 0; i < numberOfNodes; i++) {
             nodes.add(newNode("node" + i));
         }
-        ClusterState clusterState = ClusterState.builder(org.opensearch.cluster.ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
+        ClusterState clusterState = ClusterState.builder(CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
             .nodes(nodes)
             .metadata(metadata)
             .routingTable(initialRoutingTable)
@@ -674,6 +807,17 @@ private ClusterState initCluster(
         return applyAllocationUntilNoChange(clusterState, strategy);
     }
 
+    private ClusterState addNodes(ClusterState clusterState, AllocationService strategy, int numberOfNodes) {
+        logger.info("now, start [{}] more node, check that rebalancing will happen because we set it to always", numberOfNodes);
+        DiscoveryNodes.Builder nodes = DiscoveryNodes.builder(clusterState.nodes());
+        for (int i = 0; i < numberOfNodes; i++) {
+            nodes.add(newNode("node" + (clusterState.nodes().getSize() + i)));
+        }
+        clusterState = ClusterState.builder(clusterState).nodes(nodes.build()).build();
+        clusterState = strategy.reroute(clusterState, "reroute");
+        return applyStartedShardsUntilNoChange(clusterState, strategy);
+    }
+
     private ClusterState addNode(ClusterState clusterState, AllocationService strategy) {
         logger.info("now, start 1 more node, check that rebalancing will happen because we set it to always");
         clusterState = ClusterState.builder(clusterState)
@@ -919,7 +1063,7 @@ public ShardAllocationDecision decideShardAllocation(ShardRouting shard, Routing
             nodes.add(node);
         }
 
-        ClusterState clusterState = ClusterState.builder(org.opensearch.cluster.ClusterName.CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
+        ClusterState clusterState = ClusterState.builder(CLUSTER_NAME_SETTING.getDefault(Settings.EMPTY))
             .nodes(nodes)
             .metadata(metadata)
             .routingTable(routingTable)