Skip to content

Commit

Permalink
Add support for a ping command to pipe clients. (#316)
Browse files Browse the repository at this point in the history
- Add a simple ping protocol to pipe clients conformance service
- Default pipe to use base64 encoding instead of JSON
  • Loading branch information
jnthntatum authored Oct 4, 2023
1 parent 0d5e40d commit 2362c66
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 43 deletions.
6 changes: 4 additions & 2 deletions tests/simple/simple_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ var (
flagSkipCheck bool
flagSkipTests stringArray
flagPipe bool
flagPipePings bool
flagPipeBase64 bool
rc *runConfig
)
Expand All @@ -62,7 +63,8 @@ func init() {
flag.BoolVar(&flagSkipCheck, "skip_check", false, "force skipping the check phase")
flag.Var(&flagSkipTests, "skip_test", "name(s) of tests to skip. can be set multiple times. to skip the following tests: f1/s1/t1, f1/s1/t2, f1/s2/*, f2/s3/t3, you give the arguments --skip_test=f1/s1/t1,t2;s2 --skip_test=f2/s3/t3")
flag.BoolVar(&flagPipe, "pipe", false, "Use pipes instead of gRPC")
flag.BoolVar(&flagPipeBase64, "pipe_base64", false, "Use base64 encoded wire format proto in pipes (default JSON).")
flag.BoolVar(&flagPipeBase64, "pipe_base64", true, "Use base64 encoded wire format proto in pipes (if disabled, use JSON).")
flag.BoolVar(&flagPipePings, "pipe_pings", false, "Enable pinging pipe client to subprocess status.")

flag.Parse()
}
Expand Down Expand Up @@ -103,7 +105,7 @@ func initRunConfig() (*runConfig, error) {
var cli celrpc.ConfClient
var err error
if flagPipe {
cli, err = celrpc.NewPipeClient(cmd, flagPipeBase64)
cli, err = celrpc.NewPipeClient(cmd, flagPipeBase64, flagPipePings)
} else {
cli, err = celrpc.NewGrpcClient(cmd)
}
Expand Down
1 change: 1 addition & 0 deletions tools/celrpc/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ go_library(
"@org_golang_google_grpc//reflection:go_default_library",
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/emptypb:go_default_library",
],
)

Expand Down
88 changes: 50 additions & 38 deletions tools/celrpc/celrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log"
Expand All @@ -17,6 +18,7 @@ import (
"google.golang.org/grpc/reflection"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/emptypb"

confpb "google.golang.org/genproto/googleapis/api/expr/conformance/v1alpha1"
)
Expand All @@ -38,15 +40,18 @@ type grpcConfClient struct {
}

// pipe conformance client uses the following protocol:
// * two lines are sent over input
// * first input line is "parse", "check", or "eval"
// * second input line is JSON of the corresponding request
// * one output line is expected, repeat again.
// - two lines are sent over input
// - first input line is "parse", "check", "ping", or "eval"
// - second input line is encoded request message
// - one output line is expected, repeat again.
type pipeConfClient struct {
cmd *exec.Cmd
stdOut *bufio.Reader
stdIn io.Writer
useBase64 bool
binary string
cmd_args []string
cmd *exec.Cmd
stdOut *bufio.Reader
stdIn io.Writer
useBase64 bool
pingsEnabled bool
}

// NewGrpcClient creates a new gRPC ConformanceService client. A server binary
Expand Down Expand Up @@ -122,60 +127,62 @@ func ExampleNewGrpcClient() {
// method returns a non-nil error.
//
// base64Encode enables base64Encoded messages (b64encode(Any.serializeToString))
func NewPipeClient(serverCmd string, base64Encode bool) (ConfClient, error) {
// pingsEnabled enables pinging between reqests to test subprocess health
func NewPipeClient(serverCmd string, base64Encode bool, pingsEnabled bool) (ConfClient, error) {
c := pipeConfClient{
useBase64: base64Encode,
useBase64: base64Encode,
pingsEnabled: pingsEnabled,
}

fields := strings.Fields(serverCmd)
if len(fields) < 1 {
return &c, fmt.Errorf("server cmd '%s' invalid", serverCmd)
}
cmd := exec.Command(fields[0], fields[1:]...)
c.binary = fields[0]
c.cmd_args = fields[1:]

return &c, c.reset()
}

// reset restarts the conformance server piped implementation.
func (c *pipeConfClient) reset() error {
if c.binary == "" {
return errors.New("reset on invalid pipe service configuration")
}
cmd := exec.Command(c.binary, c.cmd_args...)
out, err := cmd.StdoutPipe()
if err != nil {
return &c, err
return err
}
c.stdIn, err = cmd.StdinPipe()
if err != nil {
return &c, err
return err
}
cmd.Stderr = os.Stderr // share our error stream

err = cmd.Start()
if err != nil {
return &c, err
return err
}
// Only assign cmd for stopping if it has successfully started.
c.cmd = cmd
c.stdOut = bufio.NewReader(out)
return &c, nil
return nil
}

// ExampleNewPipeClient creates a new CEL pipe client using a path to a server binary.
// TODO Run from celrpc_test.go.
func ExampleNewPipeClient() {
c, err := NewPipeClient("/path/to/server/binary", false)
defer c.Shutdown()
if err != nil {
log.Fatal("Couldn't create client")
}
parseRequest := confpb.ParseRequest{
CelSource: "1 + 1",
}
parseResponse, err := c.Parse(context.Background(), &parseRequest)
if err != nil {
log.Fatal("Couldn't parse")
}
parsedExpr := parseResponse.ParsedExpr
evalRequest := confpb.EvalRequest{
ExprKind: &confpb.EvalRequest_ParsedExpr{ParsedExpr: parsedExpr},
}
evalResponse, err := c.Eval(context.Background(), &evalRequest)
if err != nil {
log.Fatal("Couldn't eval")
func (c *pipeConfClient) isAlive() bool {
m := emptypb.Empty{}
err := c.pipeCommand("ping", &m, &m)
return err == nil
}

// checkAlive tests the client process health and restarts it on failure.
func (c *pipeConfClient) checkAlive() error {
if c.isAlive() {
return nil
}
fmt.Printf("1 + 1 is %v\n", evalResponse.Result.GetValue().GetInt64Value())
c.Shutdown()
return c.reset()
}

func (c *pipeConfClient) marshal(in proto.Message) (string, error) {
Expand Down Expand Up @@ -203,6 +210,11 @@ func (c *pipeConfClient) unmarshal(encoded string, out proto.Message) error {
}

func (c *pipeConfClient) pipeCommand(cmd string, in proto.Message, out proto.Message) error {
if c.pingsEnabled && cmd != "ping" {
if err := c.checkAlive(); err != nil {
return err
}
}
if _, err := c.stdIn.Write([]byte(cmd + "\n")); err != nil {
return err
}
Expand Down
76 changes: 73 additions & 3 deletions tools/celrpc/celrpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TestPipeParse(t *testing.T) {
if err != nil {
t.Fatalf("error loading bazel runfile path, %v", err)
}
conf, err := NewPipeClient(serverCmd, useBase64)
conf, err := NewPipeClient(serverCmd, useBase64, false /*usePings*/)
defer conf.Shutdown()
if err != nil {
t.Fatalf("error initializing client got %v wanted nil", err)
Expand Down Expand Up @@ -69,6 +69,76 @@ func TestPipeParse(t *testing.T) {
t.Errorf("Issues: got %v expected none", resp.Issues)
}

if resp.GetParsedExpr().GetExpr().GetCallExpr().GetFunction() != "_+_" {
t.Errorf("unexpected ast got: %s wanted _+_(1, 1)", resp.GetParsedExpr())
}
})
}
}

func TestPipeCrashRecover(t *testing.T) {
for _, useBase64 := range []bool{false, true} {
t := t
t.Run(fmt.Sprintf("useBase64=%v", useBase64), func(t *testing.T) {
serverCmd, err := bazel.Runfile(serverCmd)

if useBase64 {
serverCmd = fmt.Sprintf("%s %s", serverCmd, serverBase64Flag)
}

if err != nil {
t.Fatalf("error loading bazel runfile path, %v", err)
}
conf, err := NewPipeClient(serverCmd, useBase64, true /*usePings*/)
defer conf.Shutdown()
if err != nil {
t.Fatalf("error initializing client got %v wanted nil", err)
}
var resp *confpb.ParseResponse
r := make(chan *confpb.ParseResponse)
e := make(chan error)
go func() {
resp, err := conf.Parse(context.Background(), &confpb.ParseRequest{
CelSource: "test_crash",
})
e <- err
r <- resp
}()

select {
case <-time.After(2 * time.Second):
err = errors.New("timeout")
case err = <-e:
resp = <-r
}

if err == nil {
t.Fatalf("Expected error from pipe, got nil")
}

go func() {
resp, err := conf.Parse(context.Background(), &confpb.ParseRequest{
CelSource: "1 + 1",
})
e <- err
r <- resp
}()

select {
case <-time.After(2 * time.Second):
err = errors.New("timeout")
case err = <-e:
resp = <-r
}

if err != nil {
t.Fatalf("error from pipe: %v", err)
}

if len(resp.Issues) > 0 {
t.Errorf("Issues: got %v expected none", resp.Issues)
}

if resp.GetParsedExpr().GetExpr().GetCallExpr().GetFunction() != "_+_" {
t.Errorf("unexpected ast got: %s wanted _+_(1, 1)", resp.GetParsedExpr())
}
Expand All @@ -90,7 +160,7 @@ func TestPipeEval(t *testing.T) {
if err != nil {
t.Fatalf("error loading bazel runfile path, %v", err)
}
conf, err := NewPipeClient(serverCmd, useBase64)
conf, err := NewPipeClient(serverCmd, useBase64, false /*usePings*/)
defer conf.Shutdown()
if err != nil {
t.Fatalf("error initializing client got %v wanted nil", err)
Expand Down Expand Up @@ -140,7 +210,7 @@ func TestPipeCheck(t *testing.T) {
if err != nil {
t.Fatalf("error loading bazel runfile path, %v", err)
}
conf, err := NewPipeClient(serverCmd, useBase64)
conf, err := NewPipeClient(serverCmd, useBase64, false /*usePings*/)
defer conf.Shutdown()
if err != nil {
t.Fatalf("error initializing client got %v wanted nil", err)
Expand Down
1 change: 1 addition & 0 deletions tools/celrpc/testpipeimpl/main/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ go_binary(
deps = [
"@org_golang_google_protobuf//encoding/protojson:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//types/known/emptypb:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_api//expr/conformance/v1alpha1:go_default_library",
"@org_golang_google_genproto_googleapis_rpc//status:go_default_library",
Expand Down
16 changes: 16 additions & 0 deletions tools/celrpc/testpipeimpl/main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/emptypb"

confpb "google.golang.org/genproto/googleapis/api/expr/conformance/v1alpha1"
exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
Expand Down Expand Up @@ -138,6 +139,10 @@ func processLoop() int {
}
}

if req.CelSource == "test_crash" {
os.Exit(2)
}

if err = c.serialize(writer, &resp); err != nil {
fmt.Fprintf(os.Stderr, "error serializing parse resp %v\n", err)
return 1
Expand Down Expand Up @@ -182,6 +187,17 @@ func processLoop() int {
fmt.Fprintf(os.Stderr, "error serializing check resp %v\n", err)
return 1
}
case "ping":
req := emptypb.Empty{}
if err := c.unmarshal(msg, &req); err != nil {
fmt.Fprintf(os.Stderr, "bad ping req: %v\n", err)
return 1
}
resp := emptypb.Empty{}
if err = c.serialize(writer, &resp); err != nil {
fmt.Fprintf(os.Stderr, "error serializing ping resp %v\n", err)
return 1
}
default:
fmt.Fprintf(os.Stderr, "unsupported cmd: %s\n", cmd)
return 1
Expand Down

0 comments on commit 2362c66

Please sign in to comment.