diff --git a/CHANGELOG.md b/CHANGELOG.md
index e486b3a0562c3..54e5e8dbf11d8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -58,6 +58,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Fix `doc_values` only (`index:false`) IP field searching for masks ([#16628](https://github.com/opensearch-project/OpenSearch/pull/16628))
 - Fix stale cluster state custom file deletion ([#16670](https://github.com/opensearch-project/OpenSearch/pull/16670))
 - Bound the size of cache in deprecation logger ([16702](https://github.com/opensearch-project/OpenSearch/issues/16702))
+- [Tiered Caching] Fix bug in cache stats API ([#16560](https://github.com/opensearch-project/OpenSearch/pull/16560))
 
 ### Security
 
diff --git a/modules/cache-common/src/internalClusterTest/java/org/opensearch/cache/common/tier/TieredSpilloverCacheStatsIT.java b/modules/cache-common/src/internalClusterTest/java/org/opensearch/cache/common/tier/TieredSpilloverCacheStatsIT.java
index fe6bd7050a8f3..a858e94ad1609 100644
--- a/modules/cache-common/src/internalClusterTest/java/org/opensearch/cache/common/tier/TieredSpilloverCacheStatsIT.java
+++ b/modules/cache-common/src/internalClusterTest/java/org/opensearch/cache/common/tier/TieredSpilloverCacheStatsIT.java
@@ -10,6 +10,7 @@
 
 import org.opensearch.action.admin.cluster.node.stats.NodesStatsRequest;
 import org.opensearch.action.admin.cluster.node.stats.NodesStatsResponse;
+import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
 import org.opensearch.action.admin.indices.forcemerge.ForceMergeResponse;
 import org.opensearch.action.admin.indices.stats.CommonStatsFlags;
 import org.opensearch.action.search.SearchResponse;
@@ -40,6 +41,7 @@
 import static org.opensearch.cache.common.tier.TieredSpilloverCacheStatsHolder.TIER_DIMENSION_NAME;
 import static org.opensearch.cache.common.tier.TieredSpilloverCacheStatsHolder.TIER_DIMENSION_VALUE_DISK;
 import static org.opensearch.cache.common.tier.TieredSpilloverCacheStatsHolder.TIER_DIMENSION_VALUE_ON_HEAP;
+import static org.opensearch.indices.IndicesService.INDICES_CACHE_CLEAN_INTERVAL_SETTING;
 import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
 import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse;
 
@@ -417,6 +419,55 @@ public void testStatsWithMultipleSegments() throws Exception {
         assertTrue(diskCacheStat.getEvictions() == 0);
     }
 
+    public void testClosingShard() throws Exception {
+        // Closing the shard should totally remove the stats associated with that shard.
+        internalCluster().startNodes(
+            1,
+            Settings.builder()
+                .put(defaultSettings(HEAP_CACHE_SIZE_STRING, getNumberOfSegments()))
+                .put(
+                    TieredSpilloverCacheSettings.TOOK_TIME_POLICY_CONCRETE_SETTINGS_MAP.get(CacheType.INDICES_REQUEST_CACHE).getKey(),
+                    new TimeValue(0, TimeUnit.SECONDS)
+                )
+                .put(INDICES_CACHE_CLEAN_INTERVAL_SETTING.getKey(), new TimeValue(1))
+                .build()
+        );
+        String index = "index";
+        Client client = client();
+        startIndex(client, index);
+
+        // First search one time to see how big a single value will be
+        searchIndex(client, index, 0);
+        // get total stats
+        long singleSearchSize = getTotalStats(client).getSizeInBytes();
+        // Select numbers so we get some values on both heap and disk
+        int itemsOnHeap = HEAP_CACHE_SIZE / (int) singleSearchSize;
+        int itemsOnDisk = 1 + randomInt(30); // The first one we search (to get the size) always goes to disk
+        int expectedEntries = itemsOnHeap + itemsOnDisk;
+
+        for (int i = 1; i < expectedEntries; i++) {
+            // Cause misses
+            searchIndex(client, index, i);
+        }
+        int expectedMisses = itemsOnHeap + itemsOnDisk;
+
+        // Cause some hits
+        int expectedHits = randomIntBetween(itemsOnHeap, expectedEntries); // Select it so some hits come from both tiers
+        for (int i = 0; i < expectedHits; i++) {
+            searchIndex(client, index, i);
+        }
+
+        // Check the new stats API values are as expected
+        assertEquals(
+            new ImmutableCacheStats(expectedHits, expectedMisses, 0, expectedEntries * singleSearchSize, expectedEntries),
+            getTotalStats(client)
+        );
+
+        // Closing the index should close the shard
+        assertAcked(client().admin().indices().delete(new DeleteIndexRequest("index")).get());
+        assertEquals(new ImmutableCacheStats(0, 0, 0, 0, 0), getTotalStats(client));
+    }
+
     private void startIndex(Client client, String indexName) throws InterruptedException {
         assertAcked(
             client.admin()
diff --git a/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java b/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java
index ab5335ca0ca66..38a6915ffd10e 100644
--- a/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java
+++ b/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCache.java
@@ -373,12 +373,10 @@ private V compute(ICacheKey<K> key, LoadAwareCacheLoader<ICacheKey<K>, V> loader
 
         @Override
         public void invalidate(ICacheKey<K> key) {
-            for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
-                if (key.getDropStatsForDimensions()) {
-                    List<String> dimensionValues = statsHolder.getDimensionsWithTierValue(key.dimensions, cacheEntry.getValue().tierName);
-                    statsHolder.removeDimensions(dimensionValues);
-                }
-                if (key.key != null) {
+            if (key.getDropStatsForDimensions()) {
+                statsHolder.removeDimensions(key.dimensions);
+            } else if (key.key != null) {
+                for (Map.Entry<ICache<K, V>, TierInfo> cacheEntry : caches.entrySet()) {
                     try (ReleasableLock ignore = writeLock.acquire()) {
                         cacheEntry.getKey().invalidate(key);
                     }
diff --git a/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCacheStatsHolder.java b/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCacheStatsHolder.java
index b40724430454b..7ea6d3504a52c 100644
--- a/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCacheStatsHolder.java
+++ b/modules/cache-common/src/main/java/org/opensearch/cache/common/tier/TieredSpilloverCacheStatsHolder.java
@@ -43,6 +43,8 @@ public class TieredSpilloverCacheStatsHolder extends DefaultCacheStatsHolder {
     /** Dimension value for on-disk cache, like EhcacheDiskCache. */
     public static final String TIER_DIMENSION_VALUE_DISK = "disk";
 
+    static final List<String> TIER_VALUES = List.of(TIER_DIMENSION_VALUE_ON_HEAP, TIER_DIMENSION_VALUE_DISK);
+
     /**
      * Constructor for the stats holder.
      * @param originalDimensionNames the original dimension names, not including TIER_DIMENSION_NAME
@@ -167,4 +169,17 @@ public void decrementItems(List<String> dimensionValues) {
     void setDiskCacheEnabled(boolean diskCacheEnabled) {
         this.diskCacheEnabled = diskCacheEnabled;
     }
+
+    @Override
+    public void removeDimensions(List<String> dimensionValues) {
+        assert dimensionValues.size() == dimensionNames.size() - 1
+            : "Must specify a value for every dimension except tier when removing from StatsHolder";
+        // As we are removing nodes from the tree, obtain the lock
+        lock.lock();
+        try {
+            removeDimensionsHelper(dimensionValues, statsRoot, 0);
+        } finally {
+            lock.unlock();
+        }
+    }
 }
diff --git a/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheStatsHolderTests.java b/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheStatsHolderTests.java
new file mode 100644
index 0000000000000..4524e1592b005
--- /dev/null
+++ b/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheStatsHolderTests.java
@@ -0,0 +1,369 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.cache.common.tier;
+
+import org.opensearch.common.Randomness;
+import org.opensearch.common.cache.stats.CacheStats;
+import org.opensearch.common.cache.stats.DefaultCacheStatsHolder;
+import org.opensearch.common.cache.stats.ImmutableCacheStats;
+import org.opensearch.common.cache.stats.ImmutableCacheStatsHolder;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+
+import static org.opensearch.cache.common.tier.TieredSpilloverCacheStatsHolder.TIER_DIMENSION_VALUE_DISK;
+import static org.opensearch.cache.common.tier.TieredSpilloverCacheStatsHolder.TIER_DIMENSION_VALUE_ON_HEAP;
+import static org.opensearch.cache.common.tier.TieredSpilloverCacheStatsHolder.TIER_VALUES;
+
+public class TieredSpilloverCacheStatsHolderTests extends OpenSearchTestCase {
+    // These are modified from DefaultCacheStatsHolderTests.java to account for the tiers. Because we can't add a dependency on server.test,
+    // we can't reuse the same code.
+
+    public void testAddAndGet() throws Exception {
+        for (boolean diskTierEnabled : List.of(true, false)) {
+            List<String> dimensionNames = List.of("dim1", "dim2", "dim3", "dim4");
+            TieredSpilloverCacheStatsHolder cacheStatsHolder = new TieredSpilloverCacheStatsHolder(dimensionNames, diskTierEnabled);
+            Map<String, List<String>> usedDimensionValues = getUsedDimensionValues(cacheStatsHolder, 10, diskTierEnabled);
+            Map<List<String>, CacheStats> expected = populateStats(cacheStatsHolder, usedDimensionValues, 1000, 10, diskTierEnabled);
+
+            // test the value in the map is as expected for each distinct combination of values (all leaf nodes)
+            for (List<String> dimensionValues : expected.keySet()) {
+                CacheStats expectedCounter = expected.get(dimensionValues);
+                ImmutableCacheStats actualCacheStats = getNodeStats(dimensionValues, cacheStatsHolder);
+                assertEquals(expectedCounter.immutableSnapshot(), actualCacheStats);
+            }
+
+            // Check overall total matches
+            CacheStats expectedTotal = new CacheStats();
+            for (List<String> dims : expected.keySet()) {
+                CacheStats other = expected.get(dims);
+                boolean countMissesAndEvictionsTowardsTotal = dims.get(dims.size() - 1).equals(TIER_DIMENSION_VALUE_DISK)
+                    || !diskTierEnabled;
+                add(expectedTotal, other, countMissesAndEvictionsTowardsTotal);
+            }
+            assertEquals(expectedTotal.immutableSnapshot(), cacheStatsHolder.getImmutableCacheStatsHolder(null).getTotalStats());
+        }
+    }
+
+    private void add(CacheStats original, CacheStats other, boolean countMissesAndEvictionsTowardsTotal) {
+        // Add other to original, accounting for whether other is from the heap or disk tier
+        long misses = 0;
+        long evictions = 0;
+        if (countMissesAndEvictionsTowardsTotal) {
+            misses = other.getMisses();
+            evictions = other.getEvictions();
+        }
+        CacheStats modifiedOther = new CacheStats(other.getHits(), misses, evictions, other.getSizeInBytes(), other.getItems());
+        original.add(modifiedOther);
+    }
+
+    public void testReset() throws Exception {
+        List<String> dimensionNames = List.of("dim1", "dim2");
+        TieredSpilloverCacheStatsHolder cacheStatsHolder = new TieredSpilloverCacheStatsHolder(dimensionNames, true);
+        Map<String, List<String>> usedDimensionValues = getUsedDimensionValues(cacheStatsHolder, 10, true);
+        Map<List<String>, CacheStats> expected = populateStats(cacheStatsHolder, usedDimensionValues, 100, 10, true);
+
+        cacheStatsHolder.reset();
+        for (List<String> dimensionValues : expected.keySet()) {
+            CacheStats originalCounter = expected.get(dimensionValues);
+            ImmutableCacheStats expectedTotal = new ImmutableCacheStats(
+                originalCounter.getHits(),
+                originalCounter.getMisses(),
+                originalCounter.getEvictions(),
+                0,
+                0
+            );
+
+            ImmutableCacheStats actual = getNodeStats(dimensionValues, cacheStatsHolder);
+            assertEquals(expectedTotal, actual);
+        }
+    }
+
+    public void testDropStatsForDimensions() throws Exception {
+        List<String> dimensionNames = List.of("dim1", "dim2");
+        // Create stats for the following dimension sets
+        List<List<String>> statsToPopulate = List.of(List.of("A1", "B1"), List.of("A2", "B2"), List.of("A2", "B3"));
+        for (boolean diskTierEnabled : List.of(true, false)) {
+            TieredSpilloverCacheStatsHolder cacheStatsHolder = new TieredSpilloverCacheStatsHolder(dimensionNames, diskTierEnabled);
+            setupRemovalTest(cacheStatsHolder, statsToPopulate, diskTierEnabled);
+
+            // Check the resulting total is correct.
+            int numNodes = statsToPopulate.size(); // Number of distinct sets of dimensions (not including tiers)
+            // If disk tier is enabled, we expect hits to be 2 * numNodes (1 heap + 1 disk per combination of dims), otherwise 1 * numNodes.
+            // Misses and evictions should be 1 * numNodes in either case (if disk tier is present, count only the disk misses/evictions, if
+            // disk tier is absent, count the heap ones)
+            long originalHits = diskTierEnabled ? 2 * numNodes : numNodes;
+            ImmutableCacheStats expectedTotal = new ImmutableCacheStats(originalHits, numNodes, numNodes, 0, 0);
+            assertEquals(expectedTotal, cacheStatsHolder.getImmutableCacheStatsHolder(null).getTotalStats());
+
+            // When we invalidate A2, B2, we should lose the node for B2, but not B3 or A2.
+            cacheStatsHolder.removeDimensions(List.of("A2", "B2"));
+
+            // We expect hits to go down by 2 (1 heap + 1 disk) if disk is enabled, and 1 otherwise. Evictions/misses should go down by 1 in
+            // either case.
+            long removedHitsPerRemovedNode = diskTierEnabled ? 2 : 1;
+            expectedTotal = new ImmutableCacheStats(originalHits - removedHitsPerRemovedNode, numNodes - 1, numNodes - 1, 0, 0);
+            assertEquals(expectedTotal, cacheStatsHolder.getImmutableCacheStatsHolder(null).getTotalStats());
+            assertNull(getNodeStats(List.of("A2", "B2", TIER_DIMENSION_VALUE_ON_HEAP), cacheStatsHolder));
+            assertNull(getNodeStats(List.of("A2", "B2", TIER_DIMENSION_VALUE_DISK), cacheStatsHolder));
+            assertNull(getNodeStats(List.of("A2", "B2"), cacheStatsHolder));
+            assertNotNull(getNodeStats(List.of("A2"), cacheStatsHolder));
+            assertNotNull(getNodeStats(List.of("A2", "B3", TIER_DIMENSION_VALUE_ON_HEAP), cacheStatsHolder));
+
+            // When we invalidate A1, B1, we should lose the nodes for B1 and also A1, as it has no more children.
+            cacheStatsHolder.removeDimensions(List.of("A1", "B1"));
+            expectedTotal = new ImmutableCacheStats(originalHits - 2 * removedHitsPerRemovedNode, numNodes - 2, numNodes - 2, 0, 0);
+            assertEquals(expectedTotal, cacheStatsHolder.getImmutableCacheStatsHolder(null).getTotalStats());
+            assertNull(getNodeStats(List.of("A1", "B1", TIER_DIMENSION_VALUE_ON_HEAP), cacheStatsHolder));
+            assertNull(getNodeStats(List.of("A1", "B1", TIER_DIMENSION_VALUE_DISK), cacheStatsHolder));
+            assertNull(getNodeStats(List.of("A1", "B1"), cacheStatsHolder));
+            assertNull(getNodeStats(List.of("A1"), cacheStatsHolder));
+
+            // When we invalidate the last node, all nodes should be deleted except the root node
+            cacheStatsHolder.removeDimensions(List.of("A2", "B3"));
+            assertEquals(new ImmutableCacheStats(0, 0, 0, 0, 0), cacheStatsHolder.getImmutableCacheStatsHolder(null).getTotalStats());
+            // assertEquals(0, cacheStatsHolder.getStatsRoot().getChildren().size());
+        }
+    }
+
+    public void testCount() throws Exception {
+        List<String> dimensionNames = List.of("dim1", "dim2");
+        TieredSpilloverCacheStatsHolder cacheStatsHolder = new TieredSpilloverCacheStatsHolder(dimensionNames, true);
+        Map<String, List<String>> usedDimensionValues = getUsedDimensionValues(cacheStatsHolder, 10, true);
+        Map<List<String>, CacheStats> expected = populateStats(cacheStatsHolder, usedDimensionValues, 100, 10, true);
+
+        long expectedCount = 0L;
+        for (CacheStats counter : expected.values()) {
+            expectedCount += counter.getItems();
+        }
+        assertEquals(expectedCount, cacheStatsHolder.count());
+    }
+
+    public void testConcurrentRemoval() throws Exception {
+        List<String> dimensionNames = List.of("A", "B");
+        TieredSpilloverCacheStatsHolder cacheStatsHolder = new TieredSpilloverCacheStatsHolder(dimensionNames, true);
+
+        // Create stats for the following dimension sets
+        List<List<String>> statsToPopulate = new ArrayList<>();
+        int numAValues = 10;
+        int numBValues = 2;
+        for (int indexA = 0; indexA < numAValues; indexA++) {
+            for (int indexB = 0; indexB < numBValues; indexB++) {
+                statsToPopulate.add(List.of("A" + indexA, "B" + indexB));
+            }
+        }
+        setupRemovalTest(cacheStatsHolder, statsToPopulate, true);
+
+        // Remove a subset of the dimensions concurrently.
+        // Remove both (A0, B0), and (A0, B1), so we expect the intermediate node for A0 to be null afterwards.
+        // For all the others, remove only the B0 value. Then we expect the intermediate nodes for A1 through A9 to be present
+        // and reflect only the stats for their B1 child.
+
+        Thread[] threads = new Thread[numAValues + 1];
+        for (int i = 0; i < numAValues; i++) {
+            int finalI = i;
+            threads[i] = new Thread(() -> { cacheStatsHolder.removeDimensions(List.of("A" + finalI, "B0")); });
+        }
+        threads[numAValues] = new Thread(() -> { cacheStatsHolder.removeDimensions(List.of("A0", "B1")); });
+        for (Thread thread : threads) {
+            thread.start();
+        }
+        for (Thread thread : threads) {
+            thread.join();
+        }
+
+        // intermediate node for A0 should be null
+        assertNull(getNodeStats(List.of("A0"), cacheStatsHolder));
+
+        // leaf nodes for all B0 values should be null since they were removed
+        for (int indexA = 0; indexA < numAValues; indexA++) {
+            assertNull(getNodeStats(List.of("A" + indexA, "B0"), cacheStatsHolder));
+        }
+
+        // leaf nodes for all B1 values, except (A0, B1), should not be null as they weren't removed,
+        // and the intermediate nodes A1 through A9 shouldn't be null as they have remaining children
+        for (int indexA = 1; indexA < numAValues; indexA++) {
+            ImmutableCacheStats b1LeafNodeStats = getNodeStats(List.of("A" + indexA, "B1"), cacheStatsHolder);
+            assertEquals(new ImmutableCacheStats(2, 1, 1, 0, 0), b1LeafNodeStats);
+            ImmutableCacheStats intermediateLevelNodeStats = getNodeStats(List.of("A" + indexA), cacheStatsHolder);
+            assertEquals(b1LeafNodeStats, intermediateLevelNodeStats);
+        }
+    }
+
+    static void setupRemovalTest(
+        TieredSpilloverCacheStatsHolder cacheStatsHolder,
+        List<List<String>> statsToPopulate,
+        boolean diskTierEnabled
+    ) {
+        List<String> tiers = diskTierEnabled ? TIER_VALUES : List.of(TIER_DIMENSION_VALUE_ON_HEAP);
+        for (List<String> dims : statsToPopulate) {
+            // Increment hits, misses, and evictions for set of dimensions, for both heap and disk
+            for (String tier : tiers) {
+                List<String> dimsWithDimension = cacheStatsHolder.getDimensionsWithTierValue(dims, tier);
+                cacheStatsHolder.incrementHits(dimsWithDimension);
+                cacheStatsHolder.incrementMisses(dimsWithDimension);
+                boolean includeInTotal = tier.equals(TIER_DIMENSION_VALUE_DISK) || !diskTierEnabled;
+                cacheStatsHolder.incrementEvictions(dimsWithDimension, includeInTotal);
+            }
+        }
+    }
+
+    /**
+     * Returns the stats from node found by following these dimension values down from the root node.
+     * Returns null if no such node exists.
+     */
+    static ImmutableCacheStats getNodeStats(List<String> dimensionValues, DefaultCacheStatsHolder holder) {
+        String[] levels = holder.getDimensionNames().toArray(new String[0]);
+        ImmutableCacheStatsHolder immutableHolder = holder.getImmutableCacheStatsHolder(levels);
+        return immutableHolder.getStatsForDimensionValues(dimensionValues);
+    }
+
+    static Map<List<String>, CacheStats> populateStats(
+        TieredSpilloverCacheStatsHolder cacheStatsHolder,
+        Map<String, List<String>> usedDimensionValues,
+        int numDistinctValuePairs,
+        int numRepetitionsPerValue,
+        boolean diskTierEnabled
+    ) throws InterruptedException {
+        return populateStats(
+            List.of(cacheStatsHolder),
+            usedDimensionValues,
+            numDistinctValuePairs,
+            numRepetitionsPerValue,
+            diskTierEnabled
+        );
+    }
+
+    static Map<List<String>, CacheStats> populateStats(
+        List<TieredSpilloverCacheStatsHolder> cacheStatsHolders,
+        Map<String, List<String>> usedDimensionValues,
+        int numDistinctValuePairs,
+        int numRepetitionsPerValue,
+        boolean diskTierEnabled
+    ) throws InterruptedException {
+        for (TieredSpilloverCacheStatsHolder statsHolder : cacheStatsHolders) {
+            assertEquals(cacheStatsHolders.get(0).getDimensionNames(), statsHolder.getDimensionNames());
+        }
+        Map<List<String>, CacheStats> expected = new ConcurrentHashMap<>();
+        Thread[] threads = new Thread[numDistinctValuePairs];
+        CountDownLatch countDownLatch = new CountDownLatch(numDistinctValuePairs);
+        Random rand = Randomness.get();
+        List<List<String>> dimensionsForThreads = new ArrayList<>();
+        for (int i = 0; i < numDistinctValuePairs; i++) {
+            dimensionsForThreads.add(getRandomDimList(cacheStatsHolders.get(0).getDimensionNames(), usedDimensionValues, true, rand));
+            int finalI = i;
+            threads[i] = new Thread(() -> {
+                Random threadRand = Randomness.get();
+                List<String> dimensions = dimensionsForThreads.get(finalI);
+                expected.computeIfAbsent(dimensions, (key) -> new CacheStats());
+                for (TieredSpilloverCacheStatsHolder cacheStatsHolder : cacheStatsHolders) {
+                    for (int j = 0; j < numRepetitionsPerValue; j++) {
+                        CacheStats statsToInc = new CacheStats(
+                            threadRand.nextInt(10),
+                            threadRand.nextInt(10),
+                            threadRand.nextInt(10),
+                            threadRand.nextInt(5000),
+                            threadRand.nextInt(10)
+                        );
+                        for (int iter = 0; iter < statsToInc.getHits(); iter++) {
+                            expected.get(dimensions).incrementHits();
+                        }
+                        for (int iter = 0; iter < statsToInc.getMisses(); iter++) {
+                            expected.get(dimensions).incrementMisses();
+                        }
+                        for (int iter = 0; iter < statsToInc.getEvictions(); iter++) {
+                            expected.get(dimensions).incrementEvictions();
+                        }
+                        expected.get(dimensions).incrementSizeInBytes(statsToInc.getSizeInBytes());
+                        for (int iter = 0; iter < statsToInc.getItems(); iter++) {
+                            expected.get(dimensions).incrementItems();
+                        }
+                        populateStatsHolderFromStatsValueMap(cacheStatsHolder, Map.of(dimensions, statsToInc), diskTierEnabled);
+                    }
+                }
+                countDownLatch.countDown();
+            });
+        }
+        for (Thread thread : threads) {
+            thread.start();
+        }
+        countDownLatch.await();
+        return expected;
+    }
+
+    private static List<String> getRandomDimList(
+        List<String> dimensionNames,
+        Map<String, List<String>> usedDimensionValues,
+        boolean pickValueForAllDims,
+        Random rand
+    ) {
+        List<String> result = new ArrayList<>();
+        for (String dimName : dimensionNames) {
+            if (pickValueForAllDims || rand.nextBoolean()) { // if pickValueForAllDims, always pick a value for each dimension, otherwise do
+                // so 50% of the time
+                int index = between(0, usedDimensionValues.get(dimName).size() - 1);
+                result.add(usedDimensionValues.get(dimName).get(index));
+            }
+        }
+        return result;
+    }
+
+    static Map<String, List<String>> getUsedDimensionValues(
+        TieredSpilloverCacheStatsHolder cacheStatsHolder,
+        int numValuesPerDim,
+        boolean diskTierEnabled
+    ) {
+        Map<String, List<String>> usedDimensionValues = new HashMap<>();
+        for (int i = 0; i < cacheStatsHolder.getDimensionNames().size() - 1; i++) { // Have to handle final tier dimension separately
+            List<String> values = new ArrayList<>();
+            for (int j = 0; j < numValuesPerDim; j++) {
+                values.add(UUID.randomUUID().toString());
+            }
+            usedDimensionValues.put(cacheStatsHolder.getDimensionNames().get(i), values);
+        }
+        if (diskTierEnabled) {
+            usedDimensionValues.put(TieredSpilloverCacheStatsHolder.TIER_DIMENSION_NAME, TIER_VALUES);
+        } else {
+            usedDimensionValues.put(TieredSpilloverCacheStatsHolder.TIER_DIMENSION_NAME, List.of(TIER_DIMENSION_VALUE_ON_HEAP));
+        }
+        return usedDimensionValues;
+    }
+
+    public static void populateStatsHolderFromStatsValueMap(
+        TieredSpilloverCacheStatsHolder cacheStatsHolder,
+        Map<List<String>, CacheStats> statsMap,
+        boolean diskTierEnabled
+    ) {
+        for (Map.Entry<List<String>, CacheStats> entry : statsMap.entrySet()) {
+            CacheStats stats = entry.getValue();
+            List<String> dims = entry.getKey();
+            for (int i = 0; i < stats.getHits(); i++) {
+                cacheStatsHolder.incrementHits(dims);
+            }
+            for (int i = 0; i < stats.getMisses(); i++) {
+                cacheStatsHolder.incrementMisses(dims);
+            }
+            for (int i = 0; i < stats.getEvictions(); i++) {
+                boolean includeInTotal = dims.get(dims.size() - 1).equals(TIER_DIMENSION_VALUE_DISK) || !diskTierEnabled;
+                cacheStatsHolder.incrementEvictions(dims, includeInTotal);
+            }
+            cacheStatsHolder.incrementSizeInBytes(dims, stats.getSizeInBytes());
+            for (int i = 0; i < stats.getItems(); i++) {
+                cacheStatsHolder.incrementItems(dims);
+            }
+        }
+    }
+}
diff --git a/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java b/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java
index 1215a2130ac2d..3bb1321f9faf2 100644
--- a/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java
+++ b/modules/cache-common/src/test/java/org/opensearch/cache/common/tier/TieredSpilloverCacheTests.java
@@ -2112,6 +2112,60 @@ public void testTieredCacheDefaultSegmentCount() {
         assertTrue(VALID_SEGMENT_COUNT_VALUES.contains(tieredSpilloverCache.getNumberOfSegments()));
     }
 
+    public void testDropStatsForDimensions() throws Exception {
+        int onHeapCacheSize = randomIntBetween(300, 600);
+        int diskCacheSize = randomIntBetween(700, 1200);
+        int numberOfSegments = getNumberOfSegments();
+        int keyValueSize = 50;
+        MockCacheRemovalListener<String, String> removalListener = new MockCacheRemovalListener<>();
+        TieredSpilloverCache<String, String> tieredSpilloverCache = initializeTieredSpilloverCache(
+            keyValueSize,
+            diskCacheSize,
+            removalListener,
+            Settings.builder()
+                .put(
+                    TieredSpilloverCacheSettings.TIERED_SPILLOVER_ONHEAP_STORE_SIZE.getConcreteSettingForNamespace(
+                        CacheType.INDICES_REQUEST_CACHE.getSettingPrefix()
+                    ).getKey(),
+                    onHeapCacheSize * keyValueSize + "b"
+                )
+                .build(),
+            0,
+            numberOfSegments
+        );
+
+        List<ICacheKey<String>> usedKeys = new ArrayList<>();
+        // Fill the cache, getting some entries + evictions for both tiers
+        int minMisses = (diskCacheSize + onHeapCacheSize) / keyValueSize + 10;
+        int numMisses = onHeapCacheSize + diskCacheSize + randomIntBetween(minMisses, minMisses + 50);
+        for (int iter = 0; iter < numMisses; iter++) {
+            ICacheKey<String> key = getICacheKey(UUID.randomUUID().toString());
+            usedKeys.add(key);
+            LoadAwareCacheLoader<ICacheKey<String>, String> tieredCacheLoader = getLoadAwareCacheLoader();
+            tieredSpilloverCache.computeIfAbsent(key, tieredCacheLoader);
+        }
+        // Also do some random hits
+        Random rand = Randomness.get();
+        int approxNumHits = 30;
+        for (int i = 0; i < approxNumHits; i++) {
+            LoadAwareCacheLoader<ICacheKey<String>, String> tieredCacheLoader = getLoadAwareCacheLoader();
+            ICacheKey<String> key = usedKeys.get(rand.nextInt(usedKeys.size()));
+            tieredSpilloverCache.computeIfAbsent(key, tieredCacheLoader);
+        }
+
+        ImmutableCacheStats totalStats = tieredSpilloverCache.stats().getTotalStats();
+        assertTrue(totalStats.getHits() > 0);
+        assertTrue(totalStats.getMisses() > 0);
+        assertTrue(totalStats.getEvictions() > 0);
+
+        // Since all the keys have the same dimension values, except tiers, we only need to remove that one, and we expect all stats values
+        // should be 0 after that.
+        ICacheKey<String> dropDimensionsKey = new ICacheKey<>(null, getMockDimensions());
+        dropDimensionsKey.setDropStatsForDimensions(true);
+        tieredSpilloverCache.invalidate(dropDimensionsKey);
+        assertEquals(new ImmutableCacheStats(0, 0, 0, 0, 0), tieredSpilloverCache.stats().getTotalStats());
+    }
+
     private List<String> getMockDimensions() {
         List<String> dims = new ArrayList<>();
         for (String dimensionName : dimensionNames) {
diff --git a/server/src/main/java/org/opensearch/common/cache/stats/DefaultCacheStatsHolder.java b/server/src/main/java/org/opensearch/common/cache/stats/DefaultCacheStatsHolder.java
index ea92c8e81b8f0..f83b812af8299 100644
--- a/server/src/main/java/org/opensearch/common/cache/stats/DefaultCacheStatsHolder.java
+++ b/server/src/main/java/org/opensearch/common/cache/stats/DefaultCacheStatsHolder.java
@@ -37,10 +37,10 @@ public class DefaultCacheStatsHolder implements CacheStatsHolder {
     // Non-leaf nodes have stats matching the sum of their children.
     // We use a tree structure, rather than a map with concatenated keys, to save on memory usage. If there are many leaf
     // nodes that share a parent, that parent's dimension value will only be stored once, not many times.
-    private final Node statsRoot;
+    protected final Node statsRoot;
     // To avoid sync problems, obtain a lock before creating or removing nodes in the stats tree.
     // No lock is needed to edit stats on existing nodes.
-    private final Lock lock = new ReentrantLock();
+    protected final Lock lock = new ReentrantLock();
     // The name of the cache type using these stats
     private final String storeName;
 
@@ -188,8 +188,10 @@ public void removeDimensions(List<String> dimensionValues) {
     }
 
     // Returns a CacheStatsCounterSnapshot object for the stats to decrement if the removal happened, null otherwise.
-    private ImmutableCacheStats removeDimensionsHelper(List<String> dimensionValues, Node node, int depth) {
+    protected ImmutableCacheStats removeDimensionsHelper(List<String> dimensionValues, Node node, int depth) {
         if (depth == dimensionValues.size()) {
+            // Remove children, if present.
+            node.children.clear();
             // Pass up a snapshot of the original stats to avoid issues when the original is decremented by other fn invocations
             return node.getImmutableStats();
         }
@@ -208,7 +210,6 @@ private ImmutableCacheStats removeDimensionsHelper(List<String> dimensionValues,
         return statsToDecrement;
     }
 
-    // pkg-private for testing
     Node getStatsRoot() {
         return statsRoot;
     }
@@ -241,7 +242,7 @@ public String getDimensionValue() {
             return dimensionValue;
         }
 
-        protected Map<String, Node> getChildren() {
+        public Map<String, Node> getChildren() {
             // We can safely iterate over ConcurrentHashMap without worrying about thread issues.
             return children;
         }
@@ -280,7 +281,7 @@ long getEntries() {
             return this.stats.getItems();
         }
 
-        ImmutableCacheStats getImmutableStats() {
+        public ImmutableCacheStats getImmutableStats() {
             return this.stats.immutableSnapshot();
         }