Skip to content

Commit

Permalink
Add new msg type handlers (#11)
Browse files Browse the repository at this point in the history
* Add new msg type handlers

* Fixes

* more fixes
  • Loading branch information
reshke authored Dec 15, 2023
1 parent 8172f0f commit dcd3c2d
Show file tree
Hide file tree
Showing 12 changed files with 256 additions and 11 deletions.
9 changes: 8 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@

build:
mkdir -p devbin
go build -o devbin/yproxy ./cmd/yproxy
go build -o devbin/client ./cmd/client
go build -o devbin/client ./cmd/client

####################### TESTS #######################

unittest:
go test -race ./pkg/proc/...

23 changes: 20 additions & 3 deletions cmd/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,26 @@ var putCmd = &cobra.Command{

ylogger.Zero.Debug().Bytes("msg", msg).Msg("constructed message")

_, err = io.Copy(os.Stdin, con)
if err != nil {
return err
const SZ = 65536
chunk := make([]byte, SZ)
for {
n, err := os.Stdin.Read(chunk)
if n > 0 {
msg := proc.NewCopyDataMessage()
msg.Sz = uint64(n)
copy(msg.Data, chunk[:n])

_, err = con.Write(msg.Encode())
if err != nil {
return err
}
}

if err == io.EOF {
break
} else {
return err
}
}

msg = proc.NewCommandCompleteMessage().Encode()
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/yezzey-gp/yproxy

go 1.21.1
go 1.21

require (
github.com/BurntSushi/toml v1.3.2
Expand Down
17 changes: 17 additions & 0 deletions pkg/crypt/crypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

type Crypter interface {
Decrypt(reader io.Reader) (io.Reader, error)
Encrypt(writer io.WriteCloser) (io.WriteCloser, error)
}

type GPGCrypter struct {
Expand Down Expand Up @@ -82,3 +83,19 @@ func (g *GPGCrypter) Decrypt(reader io.Reader) (io.Reader, error) {

return md.UnverifiedBody, nil
}

func (g *GPGCrypter) Encrypt(writer io.WriteCloser) (io.WriteCloser, error) {
err := g.loadSecret()
if err != nil {
return nil, err
}
ylogger.Zero.Debug().Str("gpg path", g.cnf.GPGKeyPath).Msg("loaded gpg key")

encryptedWriter, err := openpgp.Encrypt(writer, g.PubKey, nil, nil, nil)

if err != nil {
return nil, errors.WithStack(err)
}

return encryptedWriter, nil
}
5 changes: 5 additions & 0 deletions pkg/proc/command_complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package proc
import "encoding/binary"

type CommandCompleteMessage struct {
ProtoMessage
}

func NewCommandCompleteMessage() *CommandCompleteMessage {
Expand All @@ -23,3 +24,7 @@ func (cc *CommandCompleteMessage) Encode() []byte {
binary.BigEndian.PutUint64(bs, uint64(ln))
return append(bs, bt...)
}

func (c *CommandCompleteMessage) Decode(body []byte) error {
return nil
}
41 changes: 41 additions & 0 deletions pkg/proc/copy_data.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package proc

import (
"encoding/binary"
)

type CopyDataMessage struct {
ProtoMessage
Sz uint64
Data []byte
}

func NewCopyDataMessage() *CopyDataMessage {
return &CopyDataMessage{}
}

func (cc *CopyDataMessage) Encode() []byte {
bt := make([]byte, 4+cc.Sz)

bt[0] = byte(MessageTypeCopyData)

// sizeof(sz) + data
ln := len(bt) + 8 + 8 + int(cc.Sz)

bs := make([]byte, 8)
binary.BigEndian.PutUint64(bs, uint64(ln))

binary.BigEndian.PutUint64(bt[4:], uint64(cc.Sz))

// check data len more than cc.sz?
copy(bt[4+8:], cc.Data[:cc.Sz])

return append(bs, bt...)
}

func (cc *CopyDataMessage) Decode(data []byte) {
msgLenBuf := data[4:12]
cc.Sz = binary.BigEndian.Uint64(msgLenBuf)
cc.Data = make([]byte, cc.Sz)
copy(cc.Data, data[12:12+cc.Sz])
}
60 changes: 59 additions & 1 deletion pkg/proc/interaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ package proc

import (
"io"
"sync"

"github.com/yezzey-gp/yproxy/pkg/client"
"github.com/yezzey-gp/yproxy/pkg/crypt"
"github.com/yezzey-gp/yproxy/pkg/storage"
"github.com/yezzey-gp/yproxy/pkg/ylogger"
)

func ProcConn(s storage.StorageReader, cr crypt.Crypter, ycl *client.YClient) error {
func ProcConn(s storage.StorageInteractor, cr crypt.Crypter, ycl *client.YClient) error {
pr := NewProtoReader(ycl)
tp, body, err := pr.ReadPacket()
if err != nil {
Expand Down Expand Up @@ -46,6 +47,63 @@ func ProcConn(s storage.StorageReader, cr crypt.Crypter, ycl *client.YClient) er

case MessageTypePut:

msg := PutMessage{}
msg.Decode(body)

var w io.WriteCloser

r, w := io.Pipe()

if msg.Encrypt {
var err error
w, err = cr.Encrypt(w)
if err != nil {
_ = ycl.ReplyError(err, "failed to encrypt")

return ycl.Conn.Close()
}
}

wg := sync.WaitGroup{}
wg.Add(1)

go func() {
defer wg.Done()
for {
tp, body, err := pr.ReadPacket()
if err != nil {
_ = ycl.ReplyError(err, "failed to compelete request")

_ = ycl.Conn.Close()
return
}

switch tp {
case MessageTypeCopyData:
msg := CopyDataMessage{}
msg.Decode(body)
w.Write(msg.Data)
case MessageTypeCommandComplete:
msg := CommandCompleteMessage{}
msg.Decode(body)
case MessageTypeReadyForQuery:
msg := ReadyForQueryMessage{}
msg.Decode(body)
return
}
}
}()

err := s.PutFileToDest(msg.Name, r)

wg.Wait()

if err != nil {
_ = ycl.ReplyError(err, "failed to upload")

return ycl.Conn.Close()
}

default:

_ = ycl.ReplyError(nil, "wrong request type")
Expand Down
1 change: 1 addition & 0 deletions pkg/proc/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const (
MessageTypePut = MessageType(43)
MessageTypeCommandComplete = MessageType(44)
MessageTypeReadyForQuery = MessageType(45)
MessageTypeCopyData = MessageType(46)

DecryptMessage = RequestEncryption(1)
NoDecryptMessage = RequestEncryption(0)
Expand Down
30 changes: 30 additions & 0 deletions pkg/proc/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,33 @@ func TestCatMsg(t *testing.T) {
assert.Equal(msg.Decrypt, msg2.Decrypt)
}
}

func TestPutMsg(t *testing.T) {
assert := assert.New(t)

type tcase struct {
name string
encrypt bool
err error
}

for _, tt := range []tcase{
{
"nam1",
true,
nil,
},
} {

msg := proc.NewPutMessage(tt.name, tt.encrypt)
body := msg.Encode()

msg2 := proc.CatMessage{}

err := msg2.Decode(body[8:])

assert.NoError(err)
assert.Equal(msg.Name, msg2.Name)
assert.Equal(msg.Encrypt, msg2.Decrypt)
}
}
26 changes: 25 additions & 1 deletion pkg/proc/put_message.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package proc

import "encoding/binary"
import (
"bytes"
"encoding/binary"
)

type PutMessage struct {
ProtoMessage
Expand Down Expand Up @@ -37,3 +40,24 @@ func (c *PutMessage) Encode() []byte {
binary.BigEndian.PutUint64(bs, uint64(ln))
return append(bs, bt...)
}

func (c *PutMessage) GetPutName(b []byte) string {
buff := bytes.NewBufferString("")

for i := 0; i < len(b); i++ {
if b[i] == 0 {
break
}
buff.WriteByte(b[i])
}

return buff.String()
}

func (c *PutMessage) Decode(body []byte) error {
if body[1] == byte(EncryptMessage) {
c.Encrypt = true
}
c.Name = c.GetPutName(body[4:])
return nil
}
5 changes: 5 additions & 0 deletions pkg/proc/ready_for_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package proc
import "encoding/binary"

type ReadyForQueryMessage struct {
ProtoMessage
}

func NewReadyForQueryMessage() *ReadyForQueryMessage {
Expand All @@ -23,3 +24,7 @@ func (cc *ReadyForQueryMessage) Encode() []byte {
binary.BigEndian.PutUint64(bs, uint64(ln))
return append(bs, bt...)
}

func (c *ReadyForQueryMessage) Decode(body []byte) error {
return nil
}
48 changes: 44 additions & 4 deletions pkg/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/yezzey-gp/yproxy/config"
"github.com/yezzey-gp/yproxy/pkg/ylogger"
)
Expand All @@ -15,19 +16,31 @@ type StorageReader interface {
CatFileFromStorage(name string) (io.Reader, error)
}

type S3StorageReader struct {
type StorageWriter interface {
PutFileToDest(name string, r io.Reader) error
}

type StorageInteractor interface {
StorageReader
StorageWriter
}

type S3StorageInteractor struct {
StorageReader
StorageWriter

pool SessionPool
cnf *config.Storage
}

func NewStorage(cnf *config.Storage) StorageReader {
return &S3StorageReader{
func NewStorage(cnf *config.Storage) StorageInteractor {
return &S3StorageInteractor{
pool: NewSessionPool(cnf),
cnf: cnf,
}
}

func (s *S3StorageReader) CatFileFromStorage(name string) (io.Reader, error) {
func (s *S3StorageInteractor) CatFileFromStorage(name string) (io.Reader, error) {
// XXX: fix this
sess, err := s.pool.GetSession(context.TODO())
if err != nil {
Expand All @@ -47,3 +60,30 @@ func (s *S3StorageReader) CatFileFromStorage(name string) (io.Reader, error) {
object, err := sess.GetObject(input)
return object.Body, err
}

func (s *S3StorageInteractor) PutFileToDest(name string, r io.Reader) error {
sess, err := s.pool.GetSession(context.TODO())
if err != nil {
ylogger.Zero.Err(err).Msg("failed to acquire s3 session")
return nil
}

objectPath := path.Join(s.cnf.StoragePrefix, name)

up := s3manager.NewUploaderWithClient(sess, func(uploader *s3manager.Uploader) {
uploader.PartSize = int64(1 << 20)
uploader.Concurrency = 1
})

_, err = up.Upload(
&s3manager.UploadInput{

Bucket: aws.String(s.cnf.StorageBucket),
Key: aws.String(objectPath),
Body: r,
StorageClass: aws.String("STANDARD"),
},
)

return err
}

0 comments on commit dcd3c2d

Please sign in to comment.