Skip to content

Commit

Permalink
[Tiered Caching] Enabling serialization for IndicesRequestCache key o…
Browse files Browse the repository at this point in the history
…bject

Signed-off-by: Sagar Upadhyaya <[email protected]>
  • Loading branch information
sgup432 committed Sep 29, 2023
1 parent 7dc6683 commit df8b26e
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,45 @@ public void testProfileDisableCache() throws Exception {
}
}

public void testCacheWithInvalidation() throws Exception {
Client client = client();
assertAcked(
client.admin()
.indices()
.prepareCreate("index")
.setMapping("k", "type=keyword")
.setSettings(
Settings.builder()
.put(IndicesRequestCache.INDEX_CACHE_REQUEST_ENABLED_SETTING.getKey(), true)
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
)
.get()
);
indexRandom(true, client.prepareIndex("index").setSource("k", "hello"));
ensureSearchable("index");
SearchResponse resp = client.prepareSearch("index").setRequestCache(true).setQuery(QueryBuilders.termQuery("k", "hello")).get();
assertSearchResponse(resp);
OpenSearchAssertions.assertAllSuccessful(resp);
assertThat(resp.getHits().getTotalHits().value, equalTo(1L));

assertCacheState(client, "index", 0, 1);
// Index but don't refresh
indexRandom(false, client.prepareIndex("index").setSource("k", "hello2"));
resp = client.prepareSearch("index").setRequestCache(true).setQuery(QueryBuilders.termQuery("k", "hello")).get();
assertSearchResponse(resp);
// Should expect hit as here as refresh didn't happen
assertCacheState(client, "index", 1, 1);

// Explicit refresh would invalidate cache
refresh();
// Hit same query again
resp = client.prepareSearch("index").setRequestCache(true).setQuery(QueryBuilders.termQuery("k", "hello")).get();
assertSearchResponse(resp);
// Should expect miss as key has changed due to change in IndexReader.CacheKey (due to refresh)
assertCacheState(client, "index", 1, 2);
}

private static void assertCacheState(Client client, String index, long expectedHits, long expectedMisses) {
RequestCacheStats requestCacheStats = client.admin()
.indices()
Expand All @@ -648,6 +687,7 @@ private static void assertCacheState(Client client, String index, long expectedH
Arrays.asList(expectedHits, expectedMisses, 0L),
Arrays.asList(requestCacheStats.getHitCount(), requestCacheStats.getMissCount(), requestCacheStats.getEvictions())
);

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.core.index.shard.ShardId;

import java.io.IOException;
import java.util.UUID;

/**
* A {@link org.apache.lucene.index.FilterDirectoryReader} that exposes
Expand All @@ -51,11 +52,14 @@ public final class OpenSearchDirectoryReader extends FilterDirectoryReader {
private final ShardId shardId;
private final FilterDirectoryReader.SubReaderWrapper wrapper;

private DelegatingCacheHelper delegatingCacheHelper;

private OpenSearchDirectoryReader(DirectoryReader in, FilterDirectoryReader.SubReaderWrapper wrapper, ShardId shardId)
throws IOException {
super(in, wrapper);
this.wrapper = wrapper;
this.shardId = shardId;
this.delegatingCacheHelper = new DelegatingCacheHelper(in.getReaderCacheHelper());
}

/**
Expand All @@ -68,7 +72,53 @@ public ShardId shardId() {
@Override
public CacheHelper getReaderCacheHelper() {
// safe to delegate since this reader does not alter the index
return in.getReaderCacheHelper();
return this.delegatingCacheHelper;
}

public DelegatingCacheHelper getDelegatingCacheHelper() {
return this.delegatingCacheHelper;
}

public class DelegatingCacheHelper implements CacheHelper {
CacheHelper cacheHelper;
DelegatingCacheKey serializableCacheKey;

DelegatingCacheHelper(CacheHelper cacheHelper) {
this.cacheHelper = cacheHelper;
this.serializableCacheKey = new DelegatingCacheKey(cacheHelper.getKey());
}

@Override
public CacheKey getKey() {
return this.cacheHelper.getKey();
}

public DelegatingCacheKey getDelegatingCacheKey() {
return this.serializableCacheKey;
}

@Override
public void addClosedListener(ClosedListener listener) {
this.cacheHelper.addClosedListener(listener);
}
}

public class DelegatingCacheKey {
CacheKey cacheKey;
private final UUID uniqueId;

DelegatingCacheKey(CacheKey cacheKey) {
this.cacheKey = cacheKey;
this.uniqueId = UUID.randomUUID();
}

public CacheKey getCacheKey() {
return this.cacheKey;
}

public UUID getId() {
return uniqueId;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.common.unit.ByteSizeValue;

import java.io.Closeable;
Expand Down Expand Up @@ -108,8 +111,9 @@ public final class IndicesRequestCache implements RemovalListener<IndicesRequest
private final ByteSizeValue size;
private final TimeValue expire;
private final Cache<Key, BytesReference> cache;
private final IndicesService indicesService;

IndicesRequestCache(Settings settings) {
IndicesRequestCache(Settings settings, IndicesService indicesService) {
this.size = INDICES_CACHE_QUERY_SIZE.get(settings);
this.expire = INDICES_CACHE_QUERY_EXPIRE.exists(settings) ? INDICES_CACHE_QUERY_EXPIRE.get(settings) : null;
long sizeInBytes = size.getBytes();
Expand All @@ -121,6 +125,7 @@ public final class IndicesRequestCache implements RemovalListener<IndicesRequest
cacheBuilder.setExpireAfterAccess(expire);
}
cache = cacheBuilder.build();
this.indicesService = indicesService;
}

@Override
Expand All @@ -145,13 +150,19 @@ BytesReference getOrCompute(
BytesReference cacheKey
) throws Exception {
assert reader.getReaderCacheHelper() != null;
final Key key = new Key(cacheEntity, reader.getReaderCacheHelper().getKey(), cacheKey);
assert reader.getReaderCacheHelper() instanceof OpenSearchDirectoryReader.DelegatingCacheHelper;

OpenSearchDirectoryReader.DelegatingCacheHelper delegatingCacheHelper = (OpenSearchDirectoryReader.DelegatingCacheHelper) reader
.getReaderCacheHelper();
String readerCacheKeyUniqueId = delegatingCacheHelper.getDelegatingCacheKey().getId().toString();
assert readerCacheKeyUniqueId != null;
final Key key = new Key(cacheEntity, cacheKey, readerCacheKeyUniqueId);
Loader cacheLoader = new Loader(cacheEntity, loader);
BytesReference value = cache.computeIfAbsent(key, cacheLoader);
if (cacheLoader.isLoaded()) {
key.entity.onMiss();
// see if its the first time we see this reader, and make sure to register a cleanup key
CleanupKey cleanupKey = new CleanupKey(cacheEntity, reader.getReaderCacheHelper().getKey());
CleanupKey cleanupKey = new CleanupKey(cacheEntity, readerCacheKeyUniqueId);
if (!registeredClosedListeners.containsKey(cleanupKey)) {
Boolean previous = registeredClosedListeners.putIfAbsent(cleanupKey, Boolean.TRUE);
if (previous == null) {
Expand All @@ -172,15 +183,22 @@ BytesReference getOrCompute(
*/
void invalidate(CacheEntity cacheEntity, DirectoryReader reader, BytesReference cacheKey) {
assert reader.getReaderCacheHelper() != null;
cache.invalidate(new Key(cacheEntity, reader.getReaderCacheHelper().getKey(), cacheKey));
String readerCacheKeyUniqueId = null;
if (reader instanceof OpenSearchDirectoryReader) {
IndexReader.CacheHelper cacheHelper = ((OpenSearchDirectoryReader) reader).getDelegatingCacheHelper();
readerCacheKeyUniqueId = ((OpenSearchDirectoryReader.DelegatingCacheHelper) cacheHelper).getDelegatingCacheKey()
.getId()
.toString();
}
cache.invalidate(new Key(cacheEntity, cacheKey, readerCacheKeyUniqueId));
}

/**
* Loader for the request cache
*
* @opensearch.internal
*/
private static class Loader implements CacheLoader<Key, BytesReference> {
protected static class Loader implements CacheLoader<Key, BytesReference> {

private final CacheEntity entity;
private final CheckedSupplier<BytesReference, IOException> loader;
Expand All @@ -207,7 +225,7 @@ public BytesReference load(Key key) throws Exception {
/**
* Basic interface to make this cache testable.
*/
interface CacheEntity extends Accountable {
interface CacheEntity extends Accountable, Writeable {

/**
* Called after the value was loaded.
Expand Down Expand Up @@ -240,24 +258,31 @@ interface CacheEntity extends Accountable {
* Called when this entity instance is removed
*/
void onRemoval(RemovalNotification<Key, BytesReference> notification);

}

/**
* Unique key for the cache
*
* @opensearch.internal
*/
static class Key implements Accountable {
private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(Key.class);
class Key implements Accountable, Writeable {
private final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(Key.class);

public final CacheEntity entity; // use as identity equality
public final IndexReader.CacheKey readerCacheKey;
public final String readerCacheKeyUniqueId;
public final BytesReference value;

Key(CacheEntity entity, IndexReader.CacheKey readerCacheKey, BytesReference value) {
Key(CacheEntity entity, BytesReference value, String readerCacheKeyUniqueId) {
this.entity = entity;
this.readerCacheKey = Objects.requireNonNull(readerCacheKey);
this.value = value;
this.readerCacheKeyUniqueId = Objects.requireNonNull(readerCacheKeyUniqueId);
}

Key(StreamInput in) throws IOException {
this.entity = in.readOptionalWriteable(in1 -> indicesService.new IndexShardCacheEntity(in1));
this.readerCacheKeyUniqueId = in.readOptionalString();
this.value = in.readBytesReference();
}

@Override
Expand All @@ -276,7 +301,7 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Key key = (Key) o;
if (Objects.equals(readerCacheKey, key.readerCacheKey) == false) return false;
if (Objects.equals(readerCacheKeyUniqueId, key.readerCacheKeyUniqueId) == false) return false;
if (!entity.getCacheIdentity().equals(key.entity.getCacheIdentity())) return false;
if (!value.equals(key.value)) return false;
return true;
Expand All @@ -285,19 +310,26 @@ public boolean equals(Object o) {
@Override
public int hashCode() {
int result = entity.getCacheIdentity().hashCode();
result = 31 * result + readerCacheKey.hashCode();
result = 31 * result + readerCacheKeyUniqueId.hashCode();
result = 31 * result + value.hashCode();
return result;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalWriteable(entity);
out.writeOptionalString(readerCacheKeyUniqueId);
out.writeBytesReference(value);
}
}

private class CleanupKey implements IndexReader.ClosedListener {
final CacheEntity entity;
final IndexReader.CacheKey readerCacheKey;
final String readerCacheKeyUniqueId;

private CleanupKey(CacheEntity entity, IndexReader.CacheKey readerCacheKey) {
private CleanupKey(CacheEntity entity, String readerCacheKeyUniqueId) {
this.entity = entity;
this.readerCacheKey = readerCacheKey;
this.readerCacheKeyUniqueId = readerCacheKeyUniqueId;
}

@Override
Expand All @@ -315,15 +347,15 @@ public boolean equals(Object o) {
return false;
}
CleanupKey that = (CleanupKey) o;
if (Objects.equals(readerCacheKey, that.readerCacheKey) == false) return false;
if (Objects.equals(readerCacheKeyUniqueId, that.readerCacheKeyUniqueId) == false) return false;
if (!entity.getCacheIdentity().equals(that.entity.getCacheIdentity())) return false;
return true;
}

@Override
public int hashCode() {
int result = entity.getCacheIdentity().hashCode();
result = 31 * result + Objects.hashCode(readerCacheKey);
result = 31 * result + Objects.hashCode(readerCacheKeyUniqueId);
return result;
}
}
Expand All @@ -336,7 +368,7 @@ synchronized void cleanCache() {
for (Iterator<CleanupKey> iterator = keysToClean.iterator(); iterator.hasNext();) {
CleanupKey cleanupKey = iterator.next();
iterator.remove();
if (cleanupKey.readerCacheKey == null || cleanupKey.entity.isOpen() == false) {
if (cleanupKey.readerCacheKeyUniqueId == null || cleanupKey.entity.isOpen() == false) {
// null indicates full cleanup, as does a closed shard
currentFullClean.add(cleanupKey.entity.getCacheIdentity());
} else {
Expand All @@ -349,7 +381,7 @@ synchronized void cleanCache() {
if (currentFullClean.contains(key.entity.getCacheIdentity())) {
iterator.remove();
} else {
if (currentKeysToClean.contains(new CleanupKey(key.entity, key.readerCacheKey))) {
if (currentKeysToClean.contains(new CleanupKey(key.entity, key.readerCacheKeyUniqueId))) {
iterator.remove();
}
}
Expand Down
21 changes: 17 additions & 4 deletions server/src/main/java/org/opensearch/indices/IndicesService.java
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ public IndicesService(
this.shardsClosedTimeout = settings.getAsTime(INDICES_SHARDS_CLOSED_TIMEOUT, new TimeValue(1, TimeUnit.DAYS));
this.analysisRegistry = analysisRegistry;
this.indexNameExpressionResolver = indexNameExpressionResolver;
this.indicesRequestCache = new IndicesRequestCache(settings);
this.indicesRequestCache = new IndicesRequestCache(settings, this);
this.indicesQueryCache = new IndicesQueryCache(settings);
this.mapperRegistry = mapperRegistry;
this.namedWriteableRegistry = namedWriteableRegistry;
Expand Down Expand Up @@ -1746,14 +1746,21 @@ private BytesReference cacheShardLevelResult(
*
* @opensearch.internal
*/
static final class IndexShardCacheEntity extends AbstractIndexShardCacheEntity {
private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(IndexShardCacheEntity.class);
public final class IndexShardCacheEntity extends AbstractIndexShardCacheEntity {
private final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(IndexShardCacheEntity.class);
private final IndexShard indexShard;

protected IndexShardCacheEntity(IndexShard indexShard) {
public IndexShardCacheEntity(IndexShard indexShard) {
this.indexShard = indexShard;
}

public IndexShardCacheEntity(StreamInput in) throws IOException {
Index index = in.readOptionalWriteable(Index::new);
int shardId = in.readVInt();
IndexService indexService = indices.get(index.getUUID());
this.indexShard = Optional.ofNullable(indexService).map(indexService1 -> indexService1.getShard(shardId)).orElse(null);
}

@Override
protected ShardRequestCache stats() {
return indexShard.requestCache();
Expand All @@ -1775,6 +1782,12 @@ public long ramBytesUsed() {
// across many entities
return BASE_RAM_BYTES_USED;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalWriteable(indexShard.shardId().getIndex());
out.writeVInt(indexShard.shardId().id());
}
}

@FunctionalInterface
Expand Down
Loading

0 comments on commit df8b26e

Please sign in to comment.