From 805ae6a30e9597c6e3d7f7240f3de6983288efbb Mon Sep 17 00:00:00 2001 From: reshke Date: Wed, 22 Nov 2023 17:25:19 +0500 Subject: [PATCH] Support new query types (#9) * Support new query types * Fix --- cmd/client/main.go | 63 ++++++++++++++++++++++++++++++++++-- pkg/proc/cat_message.go | 55 +++++++++++++++++++++++++++++++ pkg/proc/command_complete.go | 25 ++++++++++++++ pkg/proc/message.go | 31 ++++++++++++++++++ pkg/proc/proto.go | 56 -------------------------------- pkg/proc/put_message.go | 39 ++++++++++++++++++++++ pkg/proc/ready_for_query.go | 25 ++++++++++++++ 7 files changed, 236 insertions(+), 58 deletions(-) create mode 100644 pkg/proc/cat_message.go create mode 100644 pkg/proc/command_complete.go create mode 100644 pkg/proc/message.go create mode 100644 pkg/proc/put_message.go create mode 100644 pkg/proc/ready_for_query.go diff --git a/cmd/client/main.go b/cmd/client/main.go index 9ba62f2..1c2f1b6 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -14,10 +14,16 @@ import ( var cfgPath string var logLevel string var decrypt bool +var encrypt bool var rootCmd = &cobra.Command{ Use: "", Short: "", +} + +var catCmd = &cobra.Command{ + Use: "cat", + Short: "cat", RunE: func(cmd *cobra.Command, args []string) error { err := config.LoadInstanceConfig(cfgPath) @@ -34,7 +40,7 @@ var rootCmd = &cobra.Command{ } defer con.Close() - msg := proc.ConstructMessage(args[0], decrypt) + msg := proc.NewCatMessage(args[0], decrypt).Encode() _, err = con.Write(msg) if err != nil { return err @@ -51,10 +57,63 @@ var rootCmd = &cobra.Command{ }, } +var putCmd = &cobra.Command{ + Use: "put", + Short: "put", + RunE: func(cmd *cobra.Command, args []string) error { + + err := config.LoadInstanceConfig(cfgPath) + if err != nil { + return err + } + + instanceCnf := config.InstanceConfig() + + con, err := net.Dial("unix", instanceCnf.SocketPath) + + if err != nil { + return err + } + + defer con.Close() + msg := proc.NewPutMessage(args[0], encrypt).Encode() + _, err = con.Write(msg) + if err != nil { + return err + } + + ylogger.Zero.Debug().Bytes("msg", msg).Msg("constructed message") + + _, err = io.Copy(os.Stdin, con) + if err != nil { + return err + } + + msg = proc.NewCommandCompleteMessage().Encode() + _, err = con.Write(msg) + if err != nil { + return err + } + + msg = proc.NewReadyForQueryMessage().Encode() + _, err = con.Write(msg) + if err != nil { + return err + } + + return nil + }, +} + func init() { rootCmd.PersistentFlags().StringVarP(&cfgPath, "config", "c", "/etc/yproxy/yproxy.yaml", "path to yproxy config file") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "", "log level") - rootCmd.PersistentFlags().BoolVarP(&decrypt, "decrypt", "d", false, "decrypt external object or not") + + catCmd.PersistentFlags().BoolVarP(&decrypt, "decrypt", "d", false, "decrypt external object or not") + rootCmd.AddCommand(catCmd) + + putCmd.PersistentFlags().BoolVarP(&encrypt, "encrypt", "e", false, "encrypt external object before put") + rootCmd.AddCommand(putCmd) } func main() { diff --git a/pkg/proc/cat_message.go b/pkg/proc/cat_message.go new file mode 100644 index 0000000..8f8f72b --- /dev/null +++ b/pkg/proc/cat_message.go @@ -0,0 +1,55 @@ +package proc + +import ( + "bytes" + "encoding/binary" +) + +type CatMessage struct { + ProtoMessage + Decrypt bool + Name string +} + +func NewCatMessage(name string, decrypt bool) *CatMessage { + return &CatMessage{ + Name: name, + Decrypt: decrypt, + } +} + +func (c *CatMessage) Encode() []byte { + bt := []byte{ + byte(MessageTypeCat), + 0, + 0, + 0, + } + + if c.Decrypt { + bt[1] = byte(DecryptMessage) + } else { + bt[1] = byte(NoDecryptMessage) + } + + bt = append(bt, []byte(c.Name)...) + bt = append(bt, 0) + ln := len(bt) + 8 + + bs := make([]byte, 8) + binary.BigEndian.PutUint64(bs, uint64(ln)) + return append(bs, bt...) +} + +func GetCatName(b []byte) string { + buff := bytes.NewBufferString("") + + for i := 0; i < len(b); i++ { + if b[i] == 0 { + break + } + buff.WriteByte(b[i]) + } + + return buff.String() +} diff --git a/pkg/proc/command_complete.go b/pkg/proc/command_complete.go new file mode 100644 index 0000000..693aea7 --- /dev/null +++ b/pkg/proc/command_complete.go @@ -0,0 +1,25 @@ +package proc + +import "encoding/binary" + +type CommandCompleteMessage struct { +} + +func NewCommandCompleteMessage() *CommandCompleteMessage { + return &CommandCompleteMessage{} +} + +func (cc *CommandCompleteMessage) Encode() []byte { + bt := []byte{ + byte(MessageTypeCommandComplete), + 0, + 0, + 0, + } + + ln := len(bt) + 8 + + bs := make([]byte, 8) + binary.BigEndian.PutUint64(bs, uint64(ln)) + return append(bs, bt...) +} diff --git a/pkg/proc/message.go b/pkg/proc/message.go new file mode 100644 index 0000000..3231064 --- /dev/null +++ b/pkg/proc/message.go @@ -0,0 +1,31 @@ +package proc + +type ProtoMessage interface { + Decode([]byte) + Encode() []byte +} + +type MessageType byte + +type RequestEncryption byte + +const ( + MessageTypeCat = MessageType(42) + MessageTypePut = MessageType(43) + MessageTypeCommandComplete = MessageType(44) + MessageTypeReadyForQuery = MessageType(45) + + DecryptMessage = RequestEncryption(1) + NoDecryptMessage = RequestEncryption(0) + + EncryptMessage = RequestEncryption(1) + NoEncryptMessage = RequestEncryption(0) +) + +func (m MessageType) String() string { + switch m { + case MessageTypeCat: + return "CAT" + } + return "UNKNOWN" +} diff --git a/pkg/proc/proto.go b/pkg/proc/proto.go index 21f5c8a..234476c 100644 --- a/pkg/proc/proto.go +++ b/pkg/proc/proto.go @@ -1,7 +1,6 @@ package proc import ( - "bytes" "encoding/binary" "fmt" "io" @@ -21,24 +20,6 @@ func NewProtoReader(ycl *client.YClient) *ProtoReader { } } -type MessageType byte - -type RequestEncryption byte - -const ( - MessageTypeCat = MessageType(42) - DecryptMessage = RequestEncryption(1) - NoDecryptMessage = RequestEncryption(0) -) - -func (m MessageType) String() string { - switch m { - case MessageTypeCat: - return "CAT" - } - return "UNKNOWN" -} - const maxMsgLen = 1 << 20 func (r *ProtoReader) ReadPacket() (MessageType, []byte, error) { @@ -71,40 +52,3 @@ func (r *ProtoReader) ReadPacket() (MessageType, []byte, error) { msgType := MessageType(data[0]) return msgType, data, nil } - -func GetCatName(b []byte) string { - buff := bytes.NewBufferString("") - - for i := 0; i < len(b); i++ { - if b[i] == 0 { - break - } - buff.WriteByte(b[i]) - } - - return buff.String() -} - -func ConstructMessage(name string, decrypt bool) []byte { - - bt := []byte{ - byte(MessageTypeCat), - 0, - 0, - 0, - } - - if decrypt { - bt[1] = byte(DecryptMessage) - } else { - bt[1] = byte(NoDecryptMessage) - } - - bt = append(bt, []byte(name)...) - bt = append(bt, 0) - ln := len(bt) + 8 - - bs := make([]byte, 8) - binary.BigEndian.PutUint64(bs, uint64(ln)) - return append(bs, bt...) -} diff --git a/pkg/proc/put_message.go b/pkg/proc/put_message.go new file mode 100644 index 0000000..cfb7788 --- /dev/null +++ b/pkg/proc/put_message.go @@ -0,0 +1,39 @@ +package proc + +import "encoding/binary" + +type PutMessage struct { + ProtoMessage + Encrypt bool + Name string +} + +func NewPutMessage(name string, encrypt bool) *PutMessage { + return &PutMessage{ + Name: name, + Encrypt: encrypt, + } +} + +func (c *PutMessage) Encode() []byte { + bt := []byte{ + byte(MessageTypeCat), + 0, + 0, + 0, + } + + if c.Encrypt { + bt[1] = byte(EncryptMessage) + } else { + bt[1] = byte(NoEncryptMessage) + } + + bt = append(bt, []byte(c.Name)...) + bt = append(bt, 0) + ln := len(bt) + 8 + + bs := make([]byte, 8) + binary.BigEndian.PutUint64(bs, uint64(ln)) + return append(bs, bt...) +} diff --git a/pkg/proc/ready_for_query.go b/pkg/proc/ready_for_query.go new file mode 100644 index 0000000..3972770 --- /dev/null +++ b/pkg/proc/ready_for_query.go @@ -0,0 +1,25 @@ +package proc + +import "encoding/binary" + +type ReadyForQueryMessage struct { +} + +func NewReadyForQueryMessage() *ReadyForQueryMessage { + return &ReadyForQueryMessage{} +} + +func (cc *ReadyForQueryMessage) Encode() []byte { + bt := []byte{ + byte(MessageTypeReadyForQuery), + 0, + 0, + 0, + } + + ln := len(bt) + 8 + + bs := make([]byte, 8) + binary.BigEndian.PutUint64(bs, uint64(ln)) + return append(bs, bt...) +}