Skip to content

Commit

Permalink
#3 and #4 Working on clickthrough rates.
Browse files Browse the repository at this point in the history
Signed-off-by: jzonthemtn <[email protected]>
  • Loading branch information
jzonthemtn committed Sep 19, 2024
1 parent dab5944 commit b20f08a
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.searchevaluationframework.model.ClickthroughRate;
import org.opensearch.searchevaluationframework.model.Judgment;

import java.util.Collection;
import java.util.Map;
Expand All @@ -19,7 +20,11 @@ public static void main(String[] args) throws Exception {
final Map<Integer, Double> rankAggregatedClickThrough = openSearchEvaluationFramework.getRankAggregatedClickThrough();

// Calculate the click-through rate for query/doc pairs.
final Collection<ClickthroughRate> clickthroughRates = openSearchEvaluationFramework.getClickthroughRate();
final Map<String, Collection<ClickthroughRate>> clickthroughRates = openSearchEvaluationFramework.getClickthroughRate();

// TODO: Generate the implicit judgments.
// Format: datetime, query_id, query, document, judgment
final Collection<Judgment> judgments = openSearchEvaluationFramework.getJudgments(rankAggregatedClickThrough, clickthroughRates);

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.searchevaluationframework.model.ClickthroughRate;
import org.opensearch.searchevaluationframework.model.Judgment;
import org.opensearch.searchevaluationframework.model.UbiEvent;

import java.io.IOException;
Expand All @@ -31,14 +32,15 @@

public class OpenSearchEvaluationFramework {

// OpenSearch indexes.
public static final String INDEX_UBI_EVENTS = "ubi_events";
public static final String INDEX_UBI_QUERIES = "ubi_queries";
public static final String INDEX_RANK_AGGREGATED_CTR = "rank_aggregated_ctr";
public static final String INDEX_QUERY_DOC_CTR = "click_through_rates";
public static final String INDEX_JUDGMENT = "judgments";

public static final String EVENT_CLICK = "click";


private final RestHighLevelClient client;

public OpenSearchEvaluationFramework() {
Expand All @@ -48,11 +50,21 @@ public OpenSearchEvaluationFramework() {

}

public Collection<ClickthroughRate> getClickthroughRate() throws IOException {
public Collection<Judgment> getJudgments(final Map<Integer, Double> rankAggregatedClickThrough, Map<String, Collection<ClickthroughRate>> clickthroughRates) throws IOException {

final Collection<Judgment> judgments = new LinkedList<>();

indexJudgments(judgments);

return judgments;

}

public Map<String, Collection<ClickthroughRate>> getClickthroughRate() throws IOException {

// For each query:
// - Get each document returned in that query (in the QueryResponse object).
// - Calculate the clickthrough rate for the document. (clicks/impressions)
// - Calculate the click-through rate for the document. (clicks/impressions)

final String query = "{\"match_all\":{}}";
final BoolQueryBuilder queryBuilder = new BoolQueryBuilder().must(new WrapperQueryBuilder(query));
Expand All @@ -68,24 +80,23 @@ public Collection<ClickthroughRate> getClickthroughRate() throws IOException {
String scrollId = searchResponse.getScrollId();
SearchHit[] searchHits = searchResponse.getHits().getHits();

final Collection<ClickthroughRate> clickthroughRates = new LinkedList<>();
final Map<String, Collection<ClickthroughRate>> queriesToClickthroughRates = new HashMap<>();

while (searchHits != null && searchHits.length > 0) {

for (final SearchHit hit : searchHits) {

final UbiEvent ubiEvent = new UbiEvent(hit);
final ClickthroughRate clickthroughRate = new ClickthroughRate(ubiEvent.getQueryId());

final Collection<ClickthroughRate> clickthroughRates = queriesToClickthroughRates.getOrDefault(ubiEvent.getQueryId(), new LinkedList<>());
final ClickthroughRate clickthroughRate = clickthroughRates.stream().filter(p -> p.getObjectId().equals(ubiEvent.getObjectId())).findFirst().orElse(new ClickthroughRate(ubiEvent.getObjectId()));

if (StringUtils.equalsIgnoreCase(ubiEvent.getActionName(), EVENT_CLICK)) {
clickthroughRate.logClick();
} else {
clickthroughRate.logEvent();
}

clickthroughRates.add(clickthroughRate);
System.out.println(clickthroughRate.toString());

}

final SearchScrollRequest scrollRequest = new SearchScrollRequest(scrollId);
Expand All @@ -98,9 +109,9 @@ public Collection<ClickthroughRate> getClickthroughRate() throws IOException {

}

index(clickthroughRates);
indexClickthroughRates(queriesToClickthroughRates);

return clickthroughRates;
return queriesToClickthroughRates;

}

Expand Down Expand Up @@ -156,25 +167,25 @@ public Map<Integer, Double> getRankAggregatedClickThrough() throws IOException {
}

// Now for each position, divide its value by the total number of events.
// This is the click-through rate.
for(final Integer i : rankAggregatedClickThrough.keySet()) {
rankAggregatedClickThrough.put(i, rankAggregatedClickThrough.get(i) / totalEvents);
}

// Clear the scroll
final ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
clearScrollRequest.addScrollId(scrollId);
client.clearScroll(clearScrollRequest, RequestOptions.DEFAULT);

System.out.println("Rank-aggregated click through: " + rankAggregatedClickThrough);
System.out.println("Number of total events: " + totalEvents);

index(rankAggregatedClickThrough);
indexRankAggregatedClickthrough(rankAggregatedClickThrough);

return rankAggregatedClickThrough;

}

private void index(final Map<Integer, Double> rankAggregatedClickThrough) throws IOException {
private void indexRankAggregatedClickthrough(final Map<Integer, Double> rankAggregatedClickThrough) throws IOException {

if(!rankAggregatedClickThrough.isEmpty()) {

Expand All @@ -198,21 +209,52 @@ private void index(final Map<Integer, Double> rankAggregatedClickThrough) throws

}

private void index(final Collection<ClickthroughRate> clickthroughRates) throws IOException {
private void indexClickthroughRates(final Map<String, Collection<ClickthroughRate>> clickthroughRates) throws IOException {

if(!clickthroughRates.isEmpty()) {

final BulkRequest request = new BulkRequest();

for (final ClickthroughRate clickthroughRate : clickthroughRates) {
for(final String queryId : clickthroughRates.keySet()) {

for(final ClickthroughRate clickthroughRate : clickthroughRates.get(queryId)) {

final Map<String, Object> jsonMap = new HashMap<>();
jsonMap.put("query_id", queryId);
jsonMap.put("clicks", clickthroughRate.getClicks());
jsonMap.put("events", clickthroughRate.getEvents());
jsonMap.put("ctr", clickthroughRate.getClickthroughRate());

final IndexRequest indexRequest = new IndexRequest(INDEX_QUERY_DOC_CTR).id(UUID.randomUUID().toString()).source(jsonMap);

request.add(indexRequest);

}

}

client.bulk(request, RequestOptions.DEFAULT);

}

}

private void indexJudgments(final Collection<Judgment> judgments) throws IOException {

if(!judgments.isEmpty()) {

final BulkRequest request = new BulkRequest();

for (final Judgment judgment : judgments) {

final Map<String, Object> jsonMap = new HashMap<>();
jsonMap.put("query_id", clickthroughRate.getQueryId());
jsonMap.put("clicks", clickthroughRate.getClicks());
jsonMap.put("events", clickthroughRate.getEvents());
jsonMap.put("ctr", clickthroughRate.getClickthroughRate());
jsonMap.put("timestamp", judgment.getTimestamp());
jsonMap.put("query_id", judgment.getQueryId());
jsonMap.put("query", judgment.getQuery());
jsonMap.put("document", judgment.getDocument());
jsonMap.put("judgment", judgment.getJudgment());

final IndexRequest indexRequest = new IndexRequest(INDEX_QUERY_DOC_CTR).id(UUID.randomUUID().toString()).source(jsonMap);
final IndexRequest indexRequest = new IndexRequest(INDEX_JUDGMENT).id(UUID.randomUUID().toString()).source(jsonMap);

request.add(indexRequest);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

public class ClickthroughRate {

private final String queryId;
private final String objectId;
private int clicks;
private int events;

public ClickthroughRate(String queryId) {
this.queryId = queryId;
public ClickthroughRate(final String objectId) {
this.objectId = objectId;
}

public void logClick() {
Expand All @@ -25,11 +25,7 @@ public double getClickthroughRate() {

@Override
public String toString() {
return "queryId: " + queryId + ", clicks: " + clicks + ", events: " + events + ", ctr: " + getClickthroughRate();
}

public String getQueryId() {
return queryId;
return "clicks: " + clicks + ", events: " + events + ", ctr: " + getClickthroughRate();
}

public int getClicks() {
Expand All @@ -40,4 +36,8 @@ public int getEvents() {
return events;
}

public String getObjectId() {
return objectId;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.opensearch.searchevaluationframework.model;

public class Judgment {

private final long timestamp;
private final String queryId;
private final String query;
private final String document;
private final double judgment;

public Judgment(final long timestamp, final String queryId, final String query, final String document, final double judgment) {
this.timestamp = timestamp;
this.queryId = queryId;
this.query = query;
this.document = document;
this.judgment = judgment;
}

public long getTimestamp() {
return timestamp;
}

public String getQueryId() {
return queryId;
}

public String getQuery() {
return query;
}

public String getDocument() {
return document;
}

public double getJudgment() {
return judgment;
}

}

0 comments on commit b20f08a

Please sign in to comment.