Skip to content

Commit

Permalink
Merge pull request #21 from patterns-ai-core/rerank-method
Browse files Browse the repository at this point in the history
Add rerank method and bump version
  • Loading branch information
andreibondarev authored Aug 1, 2024
2 parents 320d9c8 + 449e39a commit 8b25faa
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 34 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
## [Unreleased]

## [0.9.11] - 2024-08-01
- New `rerank()` method

## [0.9.10] - 2024-05-10
- /chat endpoint does not require `message:` parameter anymore

Expand Down
2 changes: 1 addition & 1 deletion Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
cohere-ruby (0.9.10)
cohere-ruby (0.9.11)
faraday (>= 2.0.1, < 3.0)

GEM
Expand Down
59 changes: 37 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Cohere

<p>
<img alt='Weaviate logo' src='https://static.wikia.nocookie.net/logopedia/images/d/d4/Cohere_2023.svg/revision/latest?cb=20230419182227' height='50' />
<img alt='Cohere logo' src='https://static.wikia.nocookie.net/logopedia/images/d/d4/Cohere_2023.svg/revision/latest?cb=20230419182227' height='50' />
+&nbsp;&nbsp;
<img alt='Ruby logo' src='https://user-images.githubusercontent.com/541665/230231593-43861278-4550-421d-a543-fd3553aac4f6.png' height='40' />
</p>
Expand Down Expand Up @@ -42,15 +42,15 @@ client = Cohere::Client.new(

```ruby
client.generate(
prompt: "Once upon a time in a magical land called"
prompt: "Once upon a time in a magical land called"
)
```

### Chat

```ruby
client.chat(
message: "Hey! How are you?"
message: "Hey! How are you?"
)
```

Expand Down Expand Up @@ -90,30 +90,45 @@ client.chat(
)
```



### Embed

```ruby
client.embed(
texts: ["hello!"]
texts: ["hello!"]
)
```

### Rerank

```ruby
docs = [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
]

client.rerank(
texts: ["hello!"]
)
```


### Classify

```ruby
examples = [
{ text: "Dermatologists don't like her!", label: "Spam" },
{ text: "Hello, open to this?", label: "Spam" },
{ text: "I need help please wire me $1000 right now", label: "Spam" },
{ text: "Nice to know you ;)", label: "Spam" },
{ text: "Please help me?", label: "Spam" },
{ text: "Your parcel will be delivered today", label: "Not spam" },
{ text: "Review changes to our Terms and Conditions", label: "Not spam" },
{ text: "Weekly sync notes", label: "Not spam" },
{ text: "Re: Follow up from today's meeting", label: "Not spam" },
{ text: "Pre-read for tomorrow", label: "Not spam" }
{ text: "Dermatologists don't like her!", label: "Spam" },
{ text: "Hello, open to this?", label: "Spam" },
{ text: "I need help please wire me $1000 right now", label: "Spam" },
{ text: "Nice to know you ;)", label: "Spam" },
{ text: "Please help me?", label: "Spam" },
{ text: "Your parcel will be delivered today", label: "Not spam" },
{ text: "Review changes to our Terms and Conditions", label: "Not spam" },
{ text: "Weekly sync notes", label: "Not spam" },
{ text: "Re: Follow up from today's meeting", label: "Not spam" },
{ text: "Pre-read for tomorrow", label: "Not spam" }
]

inputs = [
Expand All @@ -122,40 +137,40 @@ inputs = [
]

client.classify(
examples: examples,
inputs: inputs
examples: examples,
inputs: inputs
)
```

### Tokenize

```ruby
client.tokenize(
text: "hello world!"
text: "hello world!"
)
```

### Detokenize

```ruby
client.detokenize(
tokens: [33555, 1114 , 34]
tokens: [33555, 1114 , 34]
)
```

### Detect language

```ruby
client.detect_language(
texts: ["Здравствуй, Мир"]
texts: ["Здравствуй, Мир"]
)
```

### Summarize

```ruby
client.summarize(
text: "..."
text: "..."
)
```

Expand Down
23 changes: 23 additions & 0 deletions lib/cohere/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,29 @@ def embed(
response.body
end

def rerank(
query:,
documents:,
model: nil,
top_n: nil,
rank_fields: nil,
return_documents: nil,
max_chunks_per_doc: nil
)
response = connection.post("rerank") do |req|
req.body = {
query: query,
documents: documents
}
req.body[:model] = model if model
req.body[:top_n] = top_n if top_n
req.body[:rank_fields] = rank_fields if rank_fields
req.body[:return_documents] = return_documents if return_documents
req.body[:max_chunks_per_doc] = max_chunks_per_doc if max_chunks_per_doc
end
response.body
end

def classify(
inputs:,
examples:,
Expand Down
2 changes: 1 addition & 1 deletion lib/cohere/version.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# frozen_string_literal: true

module Cohere
VERSION = "0.9.10"
VERSION = "0.9.11"
end
49 changes: 39 additions & 10 deletions spec/cohere/client_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
require "spec_helper"

RSpec.describe Cohere::Client do
let(:instance) { described_class.new(api_key: "123") }
subject { described_class.new(api_key: "123") }

describe "#generate" do
let(:generate_result) { JSON.parse(File.read("spec/fixtures/generate_result.json")) }
Expand All @@ -16,7 +16,7 @@
end

it "returns a response" do
expect(instance.generate(
expect(subject.generate(
prompt: "Once upon a time in a magical land called"
).dig("generations").first.dig("text")).to eq(" The Past there was a Game called Warhammer Fantasy Battle.")
end
Expand All @@ -33,12 +33,41 @@
end

it "returns a response" do
expect(instance.embed(
expect(subject.embed(
texts: ["hello!"]
).dig("embeddings")).to eq([[1.2177734, 0.67529297, 2.0742188]])
end
end

describe "#rerank" do
let(:embed_result) { JSON.parse(File.read("spec/fixtures/rerank.json")) }
let(:response) { OpenStruct.new(body: embed_result) }
let(:docs) {
[
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states."
]
}

before do
allow_any_instance_of(Faraday::Connection).to receive(:post)
.with("rerank")
.and_return(response)
end

it "returns a response" do
expect(
subject
.rerank(query: "What is the capital of the United States?", documents: docs)
.dig("results")
.map { |h| h["index"] }
).to eq([3, 4, 2, 0, 1])
end
end

describe "#classify" do
let(:classify_result) { JSON.parse(File.read("spec/fixtures/classify_result.json")) }
let(:response) { OpenStruct.new(body: classify_result) }
Expand All @@ -64,7 +93,7 @@
end

it "returns a response" do
res = instance.classify(
res = subject.classify(
inputs: inputs,
examples: examples
).dig("classifications")
Expand All @@ -85,7 +114,7 @@
end

it "returns a response" do
expect(instance.tokenize(
expect(subject.tokenize(
text: "Hello, world!"
).dig("tokens")).to eq([33555, 1114, 34])
end
Expand All @@ -102,7 +131,7 @@
end

it "returns a response" do
expect(instance.tokenize(
expect(subject.tokenize(
text: "Hello, world!",
model: "base"
).dig("tokens")).to eq([33555, 1114, 34])
Expand All @@ -120,7 +149,7 @@
end

it "returns a response" do
expect(instance.detokenize(
expect(subject.detokenize(
tokens: [33555, 1114, 34]
).dig("text")).to eq("hello world!")
end
Expand All @@ -137,7 +166,7 @@
end

it "returns a response" do
expect(instance.detokenize(
expect(subject.detokenize(
tokens: [33555, 1114, 34],
model: "base"
).dig("text")).to eq("hello world!")
Expand All @@ -155,7 +184,7 @@
end

it "returns a response" do
expect(instance.detect_language(
expect(subject.detect_language(
texts: ["Здравствуй, Мир"]
).dig("results").first.dig("language_code")).to eq("ru")
end
Expand All @@ -172,7 +201,7 @@
end

it "returns a response" do
expect(instance.summarize(
expect(subject.summarize(
text: "Ice cream is a sweetened frozen food typically eaten as a snack or dessert. " \
"It may be made from milk or cream and is flavoured with a sweetener, " \
"either sugar or an alternative, and a spice, such as cocoa or vanilla, " \
Expand Down
33 changes: 33 additions & 0 deletions spec/fixtures/rerank.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"id": "fd2f37a7-78e5-4d43-9230-ca0804f8cab5",
"results": [
{
"index": 3,
"relevance_score": 0.97997653
},
{
"index": 4,
"relevance_score": 0.27963173
},
{
"index": 2,
"relevance_score": 0.10502681
},
{
"index": 0,
"relevance_score": 0.10212547
},
{
"index": 1,
"relevance_score": 0.0721122
}
],
"meta": {
"api_version": {
"version": "1"
},
"billed_units": {
"search_units": 1
}
}
}

0 comments on commit 8b25faa

Please sign in to comment.