Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for prompts, notifications and context.Context #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions example/main.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package main

import (
"context"
"crypto/sha256"
"errors"
"fmt"
"time"

"golang.org/x/time/rate"

Expand Down Expand Up @@ -37,21 +39,51 @@ func main() {
},
}

s := mcp.NewServer(serverInfo, tools)
prompts := []mcp.PromptDefinition{
{
Metadata: mcp.Prompt{
Name: "example",
Description: ptr("An example prompt template"),
Arguments: []mcp.PromptArgument{
{
Name: "text",
Description: ptr("Text to process"),
Required: ptr(true),
},
},
},
Process: processPrompt,
RateLimit: rate.NewLimiter(rate.Every(time.Second), 5),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The spec doesn't require rate limiting of prompt gets in the same way that it does for tool calls. However having a rate limit seems reasonable.

},
}

s := mcp.NewServer(serverInfo, tools, prompts)
s.Serve()
}

func computeSHA256(params mcp.CallToolRequestParams) (mcp.CallToolResult, error) {
func ptr[T any](t T) *T {
return &t
}

// Update the computeSHA256 function to send a single notification
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment looks like a TODO item. It could be removed or changed to document the function.

func computeSHA256(ctx context.Context, n mcp.Notifier, params mcp.CallToolRequestParams) (mcp.CallToolResult, error) {
txt := params.Arguments["text"].(string)

if len(txt) == 0 {
return mcp.CallToolResult{}, errors.New("failed to compute checksum: text cannot be empty")
}

err := n.Notify(ctx, "test/notification", map[string]any{
"message": "Processing text",
})
if err != nil {
fmt.Printf("Failed to send notification: %v\n", err)
}

h := sha256.New()
h.Write([]byte(txt))

checksum := fmt.Sprintf("%x", h.Sum(nil))

var noError bool
return mcp.CallToolResult{
Content: []any{
Expand All @@ -63,3 +95,28 @@ func computeSHA256(params mcp.CallToolRequestParams) (mcp.CallToolResult, error)
IsError: &noError,
}, nil
}

func processPrompt(ctx context.Context, n mcp.Notifier, params mcp.GetPromptRequestParams) (mcp.GetPromptResult, error) {
if params.Arguments["text"] == "" {
return mcp.GetPromptResult{}, errors.New("input text cannot be empty")
}

err := n.Notify(ctx, "test/notification", map[string]any{
"message": "Processing text",
})
if err != nil {
fmt.Printf("Failed to send notification: %v\n", err)
}

return mcp.GetPromptResult{
Messages: []mcp.PromptMessage{
{
Role: mcp.RoleAssistant,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would a role of user be more natural for a prompt with a single message?

Content: mcp.TextContent{
Type: "text",
Text: "Processed: " + params.Arguments["text"],
},
},
},
}, nil
}
129 changes: 122 additions & 7 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,64 @@ import (

const SupportedProtocolVersion = "2024-11-05"

// Notifier provides a method for sending MCP notifications
type Notifier interface {
Notify(ctx context.Context, method string, params any) error
}

// connNotifier implements Notifier using a jsonrpc2.Conn
type connNotifier struct{ *jsonrpc2.Conn }

func (n *connNotifier) Notify(ctx context.Context, method string, params any) error {
return n.Conn.Notify(ctx, method, params)
}

type ToolDefinition struct {
Metadata Tool
Execute func(CallToolRequestParams) (CallToolResult, error)
Execute func(context.Context, Notifier, CallToolRequestParams) (CallToolResult, error)
RateLimit *rate.Limiter
}

type PromptDefinition struct {
Metadata Prompt
Process func(context.Context, Notifier, GetPromptRequestParams) (GetPromptResult, error)
RateLimit *rate.Limiter
}

type handler struct {
serverInfo Implementation
toolMetadata []Tool
tools map[string]ToolDefinition
serverInfo Implementation
toolMetadata []Tool
tools map[string]ToolDefinition
promptMetadata []Prompt
prompts map[string]PromptDefinition
}

type Server struct {
handler *handler
}

func NewServer(serverInfo Implementation, tools []ToolDefinition) *Server {
func NewServer(serverInfo Implementation, tools []ToolDefinition, prompts []PromptDefinition) *Server {
toolMetadata := make([]Tool, 0, len(tools))
toolFuncs := make(map[string]ToolDefinition, len(tools))
for _, t := range tools {
toolMetadata = append(toolMetadata, t.Metadata)
toolFuncs[t.Metadata.Name] = t
}
return &Server{handler: &handler{serverInfo: serverInfo, toolMetadata: toolMetadata, tools: toolFuncs}}

promptMetadata := make([]Prompt, 0, len(prompts))
promptFuncs := make(map[string]PromptDefinition, len(prompts))
for _, p := range prompts {
promptMetadata = append(promptMetadata, p.Metadata)
promptFuncs[p.Metadata.Name] = p
}

return &Server{handler: &handler{
serverInfo: serverInfo,
toolMetadata: toolMetadata,
tools: toolFuncs,
promptMetadata: promptMetadata,
prompts: promptFuncs,
}}
}

func (s *Server) Serve() {
Expand All @@ -56,6 +90,10 @@ func (h *handler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2
h.handleListTools(ctx, conn, req)
case "tools/call":
h.handleToolCall(ctx, conn, req)
case "prompts/list":
h.handleListPrompts(ctx, conn, req)
case "prompts/get":
h.handleGetPrompt(ctx, conn, req)
default:
h.replyWithJSONRPCError(ctx, conn, req, &jsonrpc2.Error{
Code: jsonrpc2.CodeMethodNotFound,
Expand All @@ -74,6 +112,9 @@ func (h *handler) handleInitialize(ctx context.Context, conn *jsonrpc2.Conn, req
Tools: &ServerCapabilitiesTools{
ListChanged: &unsupported,
},
Prompts: &ServerCapabilitiesPrompts{
ListChanged: &unsupported,
},
},
}
h.replyWithResult(ctx, conn, req, response)
Expand Down Expand Up @@ -128,7 +169,8 @@ func (h *handler) handleToolCall(ctx context.Context, conn *jsonrpc2.Conn, req *
}
}

response, err := t.Execute(params)
notifier := &connNotifier{Conn: conn}
response, err := t.Execute(ctx, notifier, params)
if err != nil {
h.replyWithToolError(ctx, conn, req, err.Error())
return
Expand All @@ -137,6 +179,79 @@ func (h *handler) handleToolCall(ctx context.Context, conn *jsonrpc2.Conn, req *
h.replyWithResult(ctx, conn, req, response)
}

func (h *handler) handleListPrompts(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
var params ListPromptsRequestParams
if req.Params != nil {
if err := json.Unmarshal(*req.Params, &params); err != nil || params.Cursor != nil {
h.replyWithJSONRPCError(ctx, conn, req, &jsonrpc2.Error{
Code: jsonrpc2.CodeInvalidParams,
Message: "Invalid params",
})
return
}
}
h.replyWithResult(ctx, conn, req, ListPromptsResult{Prompts: h.promptMetadata})
}

func (h *handler) handleGetPrompt(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
var params GetPromptRequestParams
if err := json.Unmarshal(*req.Params, &params); err != nil {
h.replyWithJSONRPCError(ctx, conn, req, &jsonrpc2.Error{
Code: jsonrpc2.CodeInvalidParams,
Message: "Invalid params",
})
return
}

p, ok := h.prompts[params.Name]
if !ok {
h.replyWithJSONRPCError(ctx, conn, req, &jsonrpc2.Error{
Code: jsonrpc2.CodeInvalidParams,
Message: fmt.Sprintf("Unknown prompt: %s", params.Name),
})
return
}

if !p.RateLimit.Allow() {
h.replyWithPromptError(ctx, conn, req, "rate limit exceeded")
return
}

for _, arg := range p.Metadata.Arguments {
if arg.Required != nil && *arg.Required {
if _, ok := params.Arguments[arg.Name]; !ok {
h.replyWithJSONRPCError(ctx, conn, req, &jsonrpc2.Error{
Code: jsonrpc2.CodeInvalidParams,
Message: fmt.Sprintf("Missing required argument: %s", arg.Name),
})
return
}
}
}

notifier := &connNotifier{Conn: conn}
result, err := p.Process(ctx, notifier, params)
if err != nil {
h.replyWithPromptError(ctx, conn, req, err.Error())
return
}

h.replyWithResult(ctx, conn, req, result)
}

func (h *handler) replyWithPromptError(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request, errMsg string) {
result := GetPromptResult{
Messages: []PromptMessage{{
Role: RoleAssistant,
Content: TextContent{
Type: "text",
Text: errMsg,
},
}},
}
h.replyWithResult(ctx, conn, req, result)
}

func (h *handler) replyWithJSONRPCError(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request, rpcErr *jsonrpc2.Error) {
if err := conn.ReplyWithError(ctx, req.ID, rpcErr); err != nil {
slog.Error("problem replying with error", "method", req.Method, "error", err)
Expand Down
Loading