Skip to content

Commit

Permalink
Implement semantic search API for unified resources (#31881)
Browse files Browse the repository at this point in the history
* Implement semantic search API for unified resources

* use require.eventually to complete test faster

* reset mod and sum

* Update lib/auth/assist/assistv1/test/service_test.go

Co-authored-by: Jakub Nyckowski <[email protected]>

* use correct proto numbering

* reset e

* make grpc

---------

Co-authored-by: Jakub Nyckowski <[email protected]>
  • Loading branch information
xacrimon and jakule authored Oct 2, 2023
1 parent ba90737 commit fb77d4f
Show file tree
Hide file tree
Showing 10 changed files with 756 additions and 277 deletions.
356 changes: 263 additions & 93 deletions api/gen/proto/go/assist/v1/assist.pb.go

Large diffs are not rendered by default.

39 changes: 39 additions & 0 deletions api/gen/proto/go/assist/v1/assist_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions api/proto/teleport/assist/v1/assist.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package teleport.assist.v1;

import "google/protobuf/empty.proto";
import "google/protobuf/timestamp.proto";
import "teleport/legacy/client/proto/authservice.proto";

option go_package = "github.com/gravitational/teleport/api/gen/proto/go/assist/v1;assist";

Expand Down Expand Up @@ -150,6 +151,22 @@ message GetAssistantEmbeddingsResponse {
repeated EmbeddedDocument embeddings = 1;
}

// SearchUnifiedResourcesRequest is a request to search for one or more resource kinds using similiarity search.
message SearchUnifiedResourcesRequest {
// query is the query used for similarity search.
string query = 1;
// limit is the number of embeddings to return (also known as k).
int32 limit = 2;
// kinds is the kind of embeddings to return (ex, node). Returns all supported kinds if empty.
repeated string kinds = 3;
}

// SearchUnifiedResourcesResponse is a response from the assistant service with a similarity-ordered list of resources.
message SearchUnifiedResourcesResponse {
// resources is the list of resources.
repeated proto.PaginatedResource resources = 1;
}

// AssistService is a service that provides an ability to communicate with the Teleport Assist.
service AssistService {
// CreateNewConversation creates a new conversation and returns the UUID of it.
Expand All @@ -172,6 +189,9 @@ service AssistService {

// IsAssistEnabled returns true if the assist is enabled or not on the auth level.
rpc IsAssistEnabled(IsAssistEnabledRequest) returns (IsAssistEnabledResponse);

// SearchUnifiedResources returns a similarity-ordered list of resources from the unified resource cache.
rpc SearchUnifiedResources(SearchUnifiedResourcesRequest) returns (SearchUnifiedResourcesResponse);
}

// AssistEmbeddingService is a service that provides an ability to communicate with the Assist Embedding service.
Expand Down
80 changes: 29 additions & 51 deletions lib/ai/embeddingprocessor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,53 +41,6 @@ import (
"github.com/gravitational/teleport/lib/utils"
)

// MockEmbedder returns embeddings based on the sha256 hash function. Those
// embeddings have no semantic meaning but ensure different embedded content
// provides different embeddings.
type MockEmbedder struct {
timesCalled map[string]int
}

func (m *MockEmbedder) ComputeEmbeddings(_ context.Context, input []string) ([]embedding.Vector64, error) {
result := make([]embedding.Vector64, len(input))
for i, text := range input {
name := strings.Split(text, "\n")[0]
m.timesCalled[name]++
hash := sha256.Sum256([]byte(text))
vector := make(embedding.Vector64, len(hash))
for j, x := range hash {
vector[j] = 1 / float64(int(x)+1)
}
result[i] = vector
}
return result, nil
}

type mockResourceGetter struct {
services.Presence
services.AccessLists
}

func (m *mockResourceGetter) GetDatabaseServers(_ context.Context, _ string, _ ...services.MarshalOption) ([]types.DatabaseServer, error) {
return nil, nil
}

func (m *mockResourceGetter) GetKubernetesServers(_ context.Context) ([]types.KubeServer, error) {
return nil, nil
}

func (m *mockResourceGetter) GetApplicationServers(_ context.Context, _ string) ([]types.AppServer, error) {
return nil, nil
}

func (m *mockResourceGetter) GetWindowsDesktops(_ context.Context, _ types.WindowsDesktopFilter) ([]types.WindowsDesktop, error) {
return nil, nil
}

func (m *mockResourceGetter) ListSAMLIdPServiceProviders(_ context.Context, _ int, _ string) ([]types.SAMLIdPServiceProvider, string, error) {
return nil, "", nil
}

func TestNodeEmbeddingGeneration(t *testing.T) {
t.Parallel()

Expand All @@ -104,8 +57,8 @@ func TestNodeEmbeddingGeneration(t *testing.T) {
})
require.NoError(t, err)

embedder := MockEmbedder{
timesCalled: make(map[string]int),
embedder := ai.MockEmbedder{
TimesCalled: make(map[string]int),
}
events := local.NewEventsService(bk)
accessLists, err := local.NewAccessListService(bk, clock)
Expand Down Expand Up @@ -163,7 +116,7 @@ func TestNodeEmbeddingGeneration(t *testing.T) {
nodesAcquired,
embeddings.GetAllEmbeddings(ctx))

for k, v := range embedder.timesCalled {
for k, v := range embedder.TimesCalled {
require.Equal(t, 1, v, "expected %v to be computed once, was %d", k, v)
}

Expand All @@ -184,7 +137,7 @@ func TestNodeEmbeddingGeneration(t *testing.T) {
return len(items) == numInitialNodes+1
}, 7*time.Second, 200*time.Millisecond)

for k, v := range embedder.timesCalled {
for k, v := range embedder.TimesCalled {
expected := 1
if strings.Contains(k, "node1") {
expected = 2
Expand Down Expand Up @@ -331,3 +284,28 @@ func Test_batchReducer_Add(t *testing.T) {
})
}
}

type mockResourceGetter struct {
services.Presence
services.AccessLists
}

func (m *mockResourceGetter) GetDatabaseServers(_ context.Context, _ string, _ ...services.MarshalOption) ([]types.DatabaseServer, error) {
return nil, nil
}

func (m *mockResourceGetter) GetKubernetesServers(_ context.Context) ([]types.KubeServer, error) {
return nil, nil
}

func (m *mockResourceGetter) GetApplicationServers(_ context.Context, _ string) ([]types.AppServer, error) {
return nil, nil
}

func (m *mockResourceGetter) GetWindowsDesktops(_ context.Context, _ types.WindowsDesktopFilter) ([]types.WindowsDesktop, error) {
return nil, nil
}

func (m *mockResourceGetter) ListSAMLIdPServiceProviders(_ context.Context, _ int, _ string) ([]types.SAMLIdPServiceProvider, string, error) {
return nil, "", nil
}
52 changes: 52 additions & 0 deletions lib/ai/mock_embedder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai

import (
"context"
"crypto/sha256"
"strings"
"sync"

"github.com/gravitational/teleport/lib/ai/embedding"
)

// MockEmbedder returns embeddings based on the sha256 hash function. Those
// embeddings have no semantic meaning but ensure different embedded content
// provides different embeddings.
type MockEmbedder struct {
mu sync.Mutex
TimesCalled map[string]int
}

func (m *MockEmbedder) ComputeEmbeddings(_ context.Context, input []string) ([]embedding.Vector64, error) {
m.mu.Lock()
defer m.mu.Unlock()

result := make([]embedding.Vector64, len(input))
for i, text := range input {
name := strings.Split(text, "\n")[0]
m.TimesCalled[name]++
hash := sha256.Sum256([]byte(text))
vector := make(embedding.Vector64, len(hash))
for j, x := range hash {
vector[j] = 1 / float64(int(x)+1)
}
result[i] = vector
}
return result, nil
}
Loading

0 comments on commit fb77d4f

Please sign in to comment.