Skip to content

Commit

Permalink
Forward request from custom handler to localserver
Browse files Browse the repository at this point in the history
  • Loading branch information
RebeccaMahany committed May 6, 2024
1 parent bafd23f commit e500e70
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 5 deletions.
85 changes: 82 additions & 3 deletions ee/customprotocol/custom_protocol_handler_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@ package customprotocol
import "C"
import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"sync"

"github.com/kolide/launcher/ee/localserver"
)

var urlInput chan string
Expand Down Expand Up @@ -70,15 +78,86 @@ func (c *customProtocolHandler) Interrupt(_ error) {
c.interrupt <- struct{}{}
}

// handleCustomProtocolRequest receives requests and logs them. In the future,
// it will validate them and forward them to launcher root.
// handleCustomProtocolRequest receives requests, performs a small amount of validation,
// and then forwards them to launcher root's localserver.
func (c *customProtocolHandler) handleCustomProtocolRequest(requestUrl string) error {
c.slogger.Log(context.TODO(), slog.LevelInfo,
"received custom protocol request",
"request_url", requestUrl,
)

// TODO: validate the request and forward it to launcher root
requestPath, err := extractRequestPath(requestUrl)
if err != nil {
return fmt.Errorf("extracting request path from URL: %w", err)
}

// Collect errors to return IFF we are unable to successfully forward to any port
var forwardingResultsLock sync.Mutex
forwardingErrorMsgs := make([]string, 0)
successfullyForwarded := false

// Attempt to forward the request to every port launcher potentially listens on
var wg sync.WaitGroup
for _, p := range localserver.PortList {
wg.Add(1)
p := p
go func() {
defer wg.Done()

err := forwardRequest(p, requestPath)

forwardingResultsLock.Lock()
defer forwardingResultsLock.Unlock()
if err != nil {
forwardingErrorMsgs = append(forwardingErrorMsgs, err.Error())
} else {
successfullyForwarded = true
}
}()
}

wg.Wait()

if !successfullyForwarded {
return fmt.Errorf("unable to successfully forward request to any launcher port: %s", strings.Join(forwardingErrorMsgs, ";"))
}

return nil
}

// extractRequestPath pulls out the path and query from the custom protocol request, discarding the
// scheme/host.
func extractRequestPath(requestUrl string) (string, error) {
// Validate that we received a legitimate-looking URL
parsedUrl, err := url.Parse(requestUrl)
if err != nil {
return "", fmt.Errorf("unparseable url: %w", err)
}

return parsedUrl.RequestURI(), nil
}

// forwardRequest makes the request with the given `reqPath` to localserver at the given `port`.
func forwardRequest(port int, reqPath string) error {
reqUrl := fmt.Sprintf("http://localhost:%d/%s", port, strings.TrimPrefix(reqPath, "/"))
req, err := http.NewRequest(http.MethodGet, reqUrl, nil)
if err != nil {
return fmt.Errorf("creating forward request: %w", err)
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("making request: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("received non-200 status code %d from localhost and could not read response body: %w", resp.StatusCode, err)
}
return fmt.Errorf("received non-200 status code %d from localhost: %s", resp.StatusCode, string(respBytes))
}

return nil
}
Expand Down
53 changes: 53 additions & 0 deletions ee/customprotocol/custom_protocol_handler_darwin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//go:build darwin
// +build darwin

package customprotocol

import (
"testing"

"github.com/stretchr/testify/require"
)

func Test_extractRequestPath(t *testing.T) {
t.Parallel()

for _, tt := range []struct {
testCaseName string
requestUrl string
expectedPath string
expectedError bool
}{
{
testCaseName: "valid request",
requestUrl: "kolide://local/v3/cmd?box=abcd",
expectedPath: "/v3/cmd?box=abcd",
expectedError: false,
},
{
testCaseName: "valid request, no query",
requestUrl: "kolide://local/v4/cmd",
expectedPath: "/v4/cmd",
expectedError: false,
},
{
testCaseName: "invalid request",
requestUrl: string(rune(0x7f)), // invalid control character in URL
expectedPath: "",
expectedError: true,
},
} {
tt := tt
t.Run(tt.testCaseName, func(t *testing.T) {
t.Parallel()

reqPath, err := extractRequestPath(tt.requestUrl)
if tt.expectedError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedPath, reqPath)
}
})
}
}
4 changes: 2 additions & 2 deletions ee/localserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

// Special Kolide Ports
var portList = []int{
var PortList = []int{
12519,
40978,
52115,
Expand Down Expand Up @@ -324,7 +324,7 @@ func (ls *localServer) Interrupt(_ error) {
func (ls *localServer) startListener() (net.Listener, error) {
ctx := context.TODO()

for _, p := range portList {
for _, p := range PortList {
ls.slogger.Log(ctx, slog.LevelDebug,
"trying port",
"port", p,
Expand Down

0 comments on commit e500e70

Please sign in to comment.