diff --git a/Makefile b/Makefile index a0bd848..562f978 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,11 @@ + build: mkdir -p devbin go build -o devbin/yproxy ./cmd/yproxy - go build -o devbin/client ./cmd/client \ No newline at end of file + go build -o devbin/client ./cmd/client + +####################### TESTS ####################### + +unittest: + go test -race ./pkg/proc/... + diff --git a/cmd/client/main.go b/cmd/client/main.go index 1c2f1b6..c4dffe7 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -84,9 +84,26 @@ var putCmd = &cobra.Command{ ylogger.Zero.Debug().Bytes("msg", msg).Msg("constructed message") - _, err = io.Copy(os.Stdin, con) - if err != nil { - return err + const SZ = 65536 + chunk := make([]byte, SZ) + for { + n, err := os.Stdin.Read(chunk) + if n > 0 { + msg := proc.NewCopyDataMessage() + msg.Sz = uint64(n) + copy(msg.Data, chunk[:n]) + + _, err = con.Write(msg.Encode()) + if err != nil { + return err + } + } + + if err == io.EOF { + break + } else { + return err + } } msg = proc.NewCommandCompleteMessage().Encode() diff --git a/go.mod b/go.mod index 7708e36..b48d556 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/yezzey-gp/yproxy -go 1.21.1 +go 1.21 require ( github.com/BurntSushi/toml v1.3.2 diff --git a/pkg/crypt/crypt.go b/pkg/crypt/crypt.go index b57b03f..07708df 100644 --- a/pkg/crypt/crypt.go +++ b/pkg/crypt/crypt.go @@ -14,6 +14,7 @@ import ( type Crypter interface { Decrypt(reader io.Reader) (io.Reader, error) + Encrypt(writer io.WriteCloser) (io.WriteCloser, error) } type GPGCrypter struct { @@ -82,3 +83,19 @@ func (g *GPGCrypter) Decrypt(reader io.Reader) (io.Reader, error) { return md.UnverifiedBody, nil } + +func (g *GPGCrypter) Encrypt(writer io.WriteCloser) (io.WriteCloser, error) { + err := g.loadSecret() + if err != nil { + return nil, err + } + ylogger.Zero.Debug().Str("gpg path", g.cnf.GPGKeyPath).Msg("loaded gpg key") + + encryptedWriter, err := openpgp.Encrypt(writer, g.PubKey, nil, nil, nil) + + if err != nil { + return nil, errors.WithStack(err) + } + + return encryptedWriter, nil +} diff --git a/pkg/proc/command_complete.go b/pkg/proc/command_complete.go index 693aea7..c4768e2 100644 --- a/pkg/proc/command_complete.go +++ b/pkg/proc/command_complete.go @@ -3,6 +3,7 @@ package proc import "encoding/binary" type CommandCompleteMessage struct { + ProtoMessage } func NewCommandCompleteMessage() *CommandCompleteMessage { @@ -23,3 +24,7 @@ func (cc *CommandCompleteMessage) Encode() []byte { binary.BigEndian.PutUint64(bs, uint64(ln)) return append(bs, bt...) } + +func (c *CommandCompleteMessage) Decode(body []byte) error { + return nil +} diff --git a/pkg/proc/copy_data.go b/pkg/proc/copy_data.go new file mode 100644 index 0000000..0fc1b18 --- /dev/null +++ b/pkg/proc/copy_data.go @@ -0,0 +1,41 @@ +package proc + +import ( + "encoding/binary" +) + +type CopyDataMessage struct { + ProtoMessage + Sz uint64 + Data []byte +} + +func NewCopyDataMessage() *CopyDataMessage { + return &CopyDataMessage{} +} + +func (cc *CopyDataMessage) Encode() []byte { + bt := make([]byte, 4+cc.Sz) + + bt[0] = byte(MessageTypeCopyData) + + // sizeof(sz) + data + ln := len(bt) + 8 + 8 + int(cc.Sz) + + bs := make([]byte, 8) + binary.BigEndian.PutUint64(bs, uint64(ln)) + + binary.BigEndian.PutUint64(bt[4:], uint64(cc.Sz)) + + // check data len more than cc.sz? + copy(bt[4+8:], cc.Data[:cc.Sz]) + + return append(bs, bt...) +} + +func (cc *CopyDataMessage) Decode(data []byte) { + msgLenBuf := data[4:12] + cc.Sz = binary.BigEndian.Uint64(msgLenBuf) + cc.Data = make([]byte, cc.Sz) + copy(cc.Data, data[12:12+cc.Sz]) +} diff --git a/pkg/proc/interaction.go b/pkg/proc/interaction.go index ccaae2e..87ebaa7 100644 --- a/pkg/proc/interaction.go +++ b/pkg/proc/interaction.go @@ -2,6 +2,7 @@ package proc import ( "io" + "sync" "github.com/yezzey-gp/yproxy/pkg/client" "github.com/yezzey-gp/yproxy/pkg/crypt" @@ -9,7 +10,7 @@ import ( "github.com/yezzey-gp/yproxy/pkg/ylogger" ) -func ProcConn(s storage.StorageReader, cr crypt.Crypter, ycl *client.YClient) error { +func ProcConn(s storage.StorageInteractor, cr crypt.Crypter, ycl *client.YClient) error { pr := NewProtoReader(ycl) tp, body, err := pr.ReadPacket() if err != nil { @@ -46,6 +47,63 @@ func ProcConn(s storage.StorageReader, cr crypt.Crypter, ycl *client.YClient) er case MessageTypePut: + msg := PutMessage{} + msg.Decode(body) + + var w io.WriteCloser + + r, w := io.Pipe() + + if msg.Encrypt { + var err error + w, err = cr.Encrypt(w) + if err != nil { + _ = ycl.ReplyError(err, "failed to encrypt") + + return ycl.Conn.Close() + } + } + + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + defer wg.Done() + for { + tp, body, err := pr.ReadPacket() + if err != nil { + _ = ycl.ReplyError(err, "failed to compelete request") + + _ = ycl.Conn.Close() + return + } + + switch tp { + case MessageTypeCopyData: + msg := CopyDataMessage{} + msg.Decode(body) + w.Write(msg.Data) + case MessageTypeCommandComplete: + msg := CommandCompleteMessage{} + msg.Decode(body) + case MessageTypeReadyForQuery: + msg := ReadyForQueryMessage{} + msg.Decode(body) + return + } + } + }() + + err := s.PutFileToDest(msg.Name, r) + + wg.Wait() + + if err != nil { + _ = ycl.ReplyError(err, "failed to upload") + + return ycl.Conn.Close() + } + default: _ = ycl.ReplyError(nil, "wrong request type") diff --git a/pkg/proc/message.go b/pkg/proc/message.go index 3231064..1cf6466 100644 --- a/pkg/proc/message.go +++ b/pkg/proc/message.go @@ -14,6 +14,7 @@ const ( MessageTypePut = MessageType(43) MessageTypeCommandComplete = MessageType(44) MessageTypeReadyForQuery = MessageType(45) + MessageTypeCopyData = MessageType(46) DecryptMessage = RequestEncryption(1) NoDecryptMessage = RequestEncryption(0) diff --git a/pkg/proc/message_test.go b/pkg/proc/message_test.go index f85da76..f9b1a92 100644 --- a/pkg/proc/message_test.go +++ b/pkg/proc/message_test.go @@ -36,3 +36,33 @@ func TestCatMsg(t *testing.T) { assert.Equal(msg.Decrypt, msg2.Decrypt) } } + +func TestPutMsg(t *testing.T) { + assert := assert.New(t) + + type tcase struct { + name string + encrypt bool + err error + } + + for _, tt := range []tcase{ + { + "nam1", + true, + nil, + }, + } { + + msg := proc.NewPutMessage(tt.name, tt.encrypt) + body := msg.Encode() + + msg2 := proc.CatMessage{} + + err := msg2.Decode(body[8:]) + + assert.NoError(err) + assert.Equal(msg.Name, msg2.Name) + assert.Equal(msg.Encrypt, msg2.Decrypt) + } +} diff --git a/pkg/proc/put_message.go b/pkg/proc/put_message.go index cfb7788..03972a8 100644 --- a/pkg/proc/put_message.go +++ b/pkg/proc/put_message.go @@ -1,6 +1,9 @@ package proc -import "encoding/binary" +import ( + "bytes" + "encoding/binary" +) type PutMessage struct { ProtoMessage @@ -37,3 +40,24 @@ func (c *PutMessage) Encode() []byte { binary.BigEndian.PutUint64(bs, uint64(ln)) return append(bs, bt...) } + +func (c *PutMessage) GetPutName(b []byte) string { + buff := bytes.NewBufferString("") + + for i := 0; i < len(b); i++ { + if b[i] == 0 { + break + } + buff.WriteByte(b[i]) + } + + return buff.String() +} + +func (c *PutMessage) Decode(body []byte) error { + if body[1] == byte(EncryptMessage) { + c.Encrypt = true + } + c.Name = c.GetPutName(body[4:]) + return nil +} diff --git a/pkg/proc/ready_for_query.go b/pkg/proc/ready_for_query.go index 3972770..99e6a2f 100644 --- a/pkg/proc/ready_for_query.go +++ b/pkg/proc/ready_for_query.go @@ -3,6 +3,7 @@ package proc import "encoding/binary" type ReadyForQueryMessage struct { + ProtoMessage } func NewReadyForQueryMessage() *ReadyForQueryMessage { @@ -23,3 +24,7 @@ func (cc *ReadyForQueryMessage) Encode() []byte { binary.BigEndian.PutUint64(bs, uint64(ln)) return append(bs, bt...) } + +func (c *ReadyForQueryMessage) Decode(body []byte) error { + return nil +} diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index a496b27..3cdbac3 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/yezzey-gp/yproxy/config" "github.com/yezzey-gp/yproxy/pkg/ylogger" ) @@ -15,19 +16,31 @@ type StorageReader interface { CatFileFromStorage(name string) (io.Reader, error) } -type S3StorageReader struct { +type StorageWriter interface { + PutFileToDest(name string, r io.Reader) error +} + +type StorageInteractor interface { + StorageReader + StorageWriter +} + +type S3StorageInteractor struct { + StorageReader + StorageWriter + pool SessionPool cnf *config.Storage } -func NewStorage(cnf *config.Storage) StorageReader { - return &S3StorageReader{ +func NewStorage(cnf *config.Storage) StorageInteractor { + return &S3StorageInteractor{ pool: NewSessionPool(cnf), cnf: cnf, } } -func (s *S3StorageReader) CatFileFromStorage(name string) (io.Reader, error) { +func (s *S3StorageInteractor) CatFileFromStorage(name string) (io.Reader, error) { // XXX: fix this sess, err := s.pool.GetSession(context.TODO()) if err != nil { @@ -47,3 +60,30 @@ func (s *S3StorageReader) CatFileFromStorage(name string) (io.Reader, error) { object, err := sess.GetObject(input) return object.Body, err } + +func (s *S3StorageInteractor) PutFileToDest(name string, r io.Reader) error { + sess, err := s.pool.GetSession(context.TODO()) + if err != nil { + ylogger.Zero.Err(err).Msg("failed to acquire s3 session") + return nil + } + + objectPath := path.Join(s.cnf.StoragePrefix, name) + + up := s3manager.NewUploaderWithClient(sess, func(uploader *s3manager.Uploader) { + uploader.PartSize = int64(1 << 20) + uploader.Concurrency = 1 + }) + + _, err = up.Upload( + &s3manager.UploadInput{ + + Bucket: aws.String(s.cnf.StorageBucket), + Key: aws.String(objectPath), + Body: r, + StorageClass: aws.String("STANDARD"), + }, + ) + + return err +}