diff --git a/tests/simple/simple_test.go b/tests/simple/simple_test.go index e715ef84..68d79e69 100644 --- a/tests/simple/simple_test.go +++ b/tests/simple/simple_test.go @@ -49,6 +49,7 @@ var ( flagSkipCheck bool flagSkipTests stringArray flagPipe bool + flagPipePings bool flagPipeBase64 bool rc *runConfig ) @@ -62,7 +63,8 @@ func init() { flag.BoolVar(&flagSkipCheck, "skip_check", false, "force skipping the check phase") flag.Var(&flagSkipTests, "skip_test", "name(s) of tests to skip. can be set multiple times. to skip the following tests: f1/s1/t1, f1/s1/t2, f1/s2/*, f2/s3/t3, you give the arguments --skip_test=f1/s1/t1,t2;s2 --skip_test=f2/s3/t3") flag.BoolVar(&flagPipe, "pipe", false, "Use pipes instead of gRPC") - flag.BoolVar(&flagPipeBase64, "pipe_base64", false, "Use base64 encoded wire format proto in pipes (default JSON).") + flag.BoolVar(&flagPipeBase64, "pipe_base64", true, "Use base64 encoded wire format proto in pipes (if disabled, use JSON).") + flag.BoolVar(&flagPipePings, "pipe_pings", false, "Enable pinging pipe client to subprocess status.") flag.Parse() } @@ -103,7 +105,7 @@ func initRunConfig() (*runConfig, error) { var cli celrpc.ConfClient var err error if flagPipe { - cli, err = celrpc.NewPipeClient(cmd, flagPipeBase64) + cli, err = celrpc.NewPipeClient(cmd, flagPipeBase64, flagPipePings) } else { cli, err = celrpc.NewGrpcClient(cmd) } diff --git a/tools/celrpc/BUILD.bazel b/tools/celrpc/BUILD.bazel index 503410d3..17182842 100644 --- a/tools/celrpc/BUILD.bazel +++ b/tools/celrpc/BUILD.bazel @@ -16,6 +16,7 @@ go_library( "@org_golang_google_grpc//reflection:go_default_library", "@org_golang_google_protobuf//encoding/protojson:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//types/known/emptypb:go_default_library", ], ) diff --git a/tools/celrpc/celrpc.go b/tools/celrpc/celrpc.go index e43ff1d8..65492f10 100644 --- a/tools/celrpc/celrpc.go +++ b/tools/celrpc/celrpc.go @@ -5,6 +5,7 @@ import ( "bufio" "context" "encoding/base64" + "errors" "fmt" "io" "log" @@ -17,6 +18,7 @@ import ( "google.golang.org/grpc/reflection" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/emptypb" confpb "google.golang.org/genproto/googleapis/api/expr/conformance/v1alpha1" ) @@ -38,15 +40,18 @@ type grpcConfClient struct { } // pipe conformance client uses the following protocol: -// * two lines are sent over input -// * first input line is "parse", "check", or "eval" -// * second input line is JSON of the corresponding request -// * one output line is expected, repeat again. +// - two lines are sent over input +// - first input line is "parse", "check", "ping", or "eval" +// - second input line is encoded request message +// - one output line is expected, repeat again. type pipeConfClient struct { - cmd *exec.Cmd - stdOut *bufio.Reader - stdIn io.Writer - useBase64 bool + binary string + cmd_args []string + cmd *exec.Cmd + stdOut *bufio.Reader + stdIn io.Writer + useBase64 bool + pingsEnabled bool } // NewGrpcClient creates a new gRPC ConformanceService client. A server binary @@ -122,60 +127,62 @@ func ExampleNewGrpcClient() { // method returns a non-nil error. // // base64Encode enables base64Encoded messages (b64encode(Any.serializeToString)) -func NewPipeClient(serverCmd string, base64Encode bool) (ConfClient, error) { +// pingsEnabled enables pinging between reqests to test subprocess health +func NewPipeClient(serverCmd string, base64Encode bool, pingsEnabled bool) (ConfClient, error) { c := pipeConfClient{ - useBase64: base64Encode, + useBase64: base64Encode, + pingsEnabled: pingsEnabled, } fields := strings.Fields(serverCmd) if len(fields) < 1 { return &c, fmt.Errorf("server cmd '%s' invalid", serverCmd) } - cmd := exec.Command(fields[0], fields[1:]...) + c.binary = fields[0] + c.cmd_args = fields[1:] + + return &c, c.reset() +} + +// reset restarts the conformance server piped implementation. +func (c *pipeConfClient) reset() error { + if c.binary == "" { + return errors.New("reset on invalid pipe service configuration") + } + cmd := exec.Command(c.binary, c.cmd_args...) out, err := cmd.StdoutPipe() if err != nil { - return &c, err + return err } c.stdIn, err = cmd.StdinPipe() if err != nil { - return &c, err + return err } cmd.Stderr = os.Stderr // share our error stream err = cmd.Start() if err != nil { - return &c, err + return err } // Only assign cmd for stopping if it has successfully started. c.cmd = cmd c.stdOut = bufio.NewReader(out) - return &c, nil + return nil } -// ExampleNewPipeClient creates a new CEL pipe client using a path to a server binary. -// TODO Run from celrpc_test.go. -func ExampleNewPipeClient() { - c, err := NewPipeClient("/path/to/server/binary", false) - defer c.Shutdown() - if err != nil { - log.Fatal("Couldn't create client") - } - parseRequest := confpb.ParseRequest{ - CelSource: "1 + 1", - } - parseResponse, err := c.Parse(context.Background(), &parseRequest) - if err != nil { - log.Fatal("Couldn't parse") - } - parsedExpr := parseResponse.ParsedExpr - evalRequest := confpb.EvalRequest{ - ExprKind: &confpb.EvalRequest_ParsedExpr{ParsedExpr: parsedExpr}, - } - evalResponse, err := c.Eval(context.Background(), &evalRequest) - if err != nil { - log.Fatal("Couldn't eval") +func (c *pipeConfClient) isAlive() bool { + m := emptypb.Empty{} + err := c.pipeCommand("ping", &m, &m) + return err == nil +} + +// checkAlive tests the client process health and restarts it on failure. +func (c *pipeConfClient) checkAlive() error { + if c.isAlive() { + return nil } - fmt.Printf("1 + 1 is %v\n", evalResponse.Result.GetValue().GetInt64Value()) + c.Shutdown() + return c.reset() } func (c *pipeConfClient) marshal(in proto.Message) (string, error) { @@ -203,6 +210,11 @@ func (c *pipeConfClient) unmarshal(encoded string, out proto.Message) error { } func (c *pipeConfClient) pipeCommand(cmd string, in proto.Message, out proto.Message) error { + if c.pingsEnabled && cmd != "ping" { + if err := c.checkAlive(); err != nil { + return err + } + } if _, err := c.stdIn.Write([]byte(cmd + "\n")); err != nil { return err } diff --git a/tools/celrpc/celrpc_test.go b/tools/celrpc/celrpc_test.go index 37e7d122..d4710d8e 100644 --- a/tools/celrpc/celrpc_test.go +++ b/tools/celrpc/celrpc_test.go @@ -37,7 +37,7 @@ func TestPipeParse(t *testing.T) { if err != nil { t.Fatalf("error loading bazel runfile path, %v", err) } - conf, err := NewPipeClient(serverCmd, useBase64) + conf, err := NewPipeClient(serverCmd, useBase64, false /*usePings*/) defer conf.Shutdown() if err != nil { t.Fatalf("error initializing client got %v wanted nil", err) @@ -69,6 +69,76 @@ func TestPipeParse(t *testing.T) { t.Errorf("Issues: got %v expected none", resp.Issues) } + if resp.GetParsedExpr().GetExpr().GetCallExpr().GetFunction() != "_+_" { + t.Errorf("unexpected ast got: %s wanted _+_(1, 1)", resp.GetParsedExpr()) + } + }) + } +} + +func TestPipeCrashRecover(t *testing.T) { + for _, useBase64 := range []bool{false, true} { + t := t + t.Run(fmt.Sprintf("useBase64=%v", useBase64), func(t *testing.T) { + serverCmd, err := bazel.Runfile(serverCmd) + + if useBase64 { + serverCmd = fmt.Sprintf("%s %s", serverCmd, serverBase64Flag) + } + + if err != nil { + t.Fatalf("error loading bazel runfile path, %v", err) + } + conf, err := NewPipeClient(serverCmd, useBase64, true /*usePings*/) + defer conf.Shutdown() + if err != nil { + t.Fatalf("error initializing client got %v wanted nil", err) + } + var resp *confpb.ParseResponse + r := make(chan *confpb.ParseResponse) + e := make(chan error) + go func() { + resp, err := conf.Parse(context.Background(), &confpb.ParseRequest{ + CelSource: "test_crash", + }) + e <- err + r <- resp + }() + + select { + case <-time.After(2 * time.Second): + err = errors.New("timeout") + case err = <-e: + resp = <-r + } + + if err == nil { + t.Fatalf("Expected error from pipe, got nil") + } + + go func() { + resp, err := conf.Parse(context.Background(), &confpb.ParseRequest{ + CelSource: "1 + 1", + }) + e <- err + r <- resp + }() + + select { + case <-time.After(2 * time.Second): + err = errors.New("timeout") + case err = <-e: + resp = <-r + } + + if err != nil { + t.Fatalf("error from pipe: %v", err) + } + + if len(resp.Issues) > 0 { + t.Errorf("Issues: got %v expected none", resp.Issues) + } + if resp.GetParsedExpr().GetExpr().GetCallExpr().GetFunction() != "_+_" { t.Errorf("unexpected ast got: %s wanted _+_(1, 1)", resp.GetParsedExpr()) } @@ -90,7 +160,7 @@ func TestPipeEval(t *testing.T) { if err != nil { t.Fatalf("error loading bazel runfile path, %v", err) } - conf, err := NewPipeClient(serverCmd, useBase64) + conf, err := NewPipeClient(serverCmd, useBase64, false /*usePings*/) defer conf.Shutdown() if err != nil { t.Fatalf("error initializing client got %v wanted nil", err) @@ -140,7 +210,7 @@ func TestPipeCheck(t *testing.T) { if err != nil { t.Fatalf("error loading bazel runfile path, %v", err) } - conf, err := NewPipeClient(serverCmd, useBase64) + conf, err := NewPipeClient(serverCmd, useBase64, false /*usePings*/) defer conf.Shutdown() if err != nil { t.Fatalf("error initializing client got %v wanted nil", err) diff --git a/tools/celrpc/testpipeimpl/main/BUILD.bazel b/tools/celrpc/testpipeimpl/main/BUILD.bazel index 1fc6e5d3..9e1d630f 100644 --- a/tools/celrpc/testpipeimpl/main/BUILD.bazel +++ b/tools/celrpc/testpipeimpl/main/BUILD.bazel @@ -12,6 +12,7 @@ go_binary( deps = [ "@org_golang_google_protobuf//encoding/protojson:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//types/known/emptypb:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/conformance/v1alpha1:go_default_library", "@org_golang_google_genproto_googleapis_rpc//status:go_default_library", diff --git a/tools/celrpc/testpipeimpl/main/main.go b/tools/celrpc/testpipeimpl/main/main.go index fcfa8b14..e05567f4 100644 --- a/tools/celrpc/testpipeimpl/main/main.go +++ b/tools/celrpc/testpipeimpl/main/main.go @@ -15,6 +15,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/emptypb" confpb "google.golang.org/genproto/googleapis/api/expr/conformance/v1alpha1" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -138,6 +139,10 @@ func processLoop() int { } } + if req.CelSource == "test_crash" { + os.Exit(2) + } + if err = c.serialize(writer, &resp); err != nil { fmt.Fprintf(os.Stderr, "error serializing parse resp %v\n", err) return 1 @@ -182,6 +187,17 @@ func processLoop() int { fmt.Fprintf(os.Stderr, "error serializing check resp %v\n", err) return 1 } + case "ping": + req := emptypb.Empty{} + if err := c.unmarshal(msg, &req); err != nil { + fmt.Fprintf(os.Stderr, "bad ping req: %v\n", err) + return 1 + } + resp := emptypb.Empty{} + if err = c.serialize(writer, &resp); err != nil { + fmt.Fprintf(os.Stderr, "error serializing ping resp %v\n", err) + return 1 + } default: fmt.Fprintf(os.Stderr, "unsupported cmd: %s\n", cmd) return 1