Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new msg type handlers #11

Merged
merged 3 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@

build:
mkdir -p devbin
go build -o devbin/yproxy ./cmd/yproxy
go build -o devbin/client ./cmd/client
go build -o devbin/client ./cmd/client

####################### TESTS #######################

unittest:
go test -race ./pkg/proc/...

23 changes: 20 additions & 3 deletions cmd/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,26 @@ var putCmd = &cobra.Command{

ylogger.Zero.Debug().Bytes("msg", msg).Msg("constructed message")

_, err = io.Copy(os.Stdin, con)
if err != nil {
return err
const SZ = 65536
chunk := make([]byte, SZ)
for {
n, err := os.Stdin.Read(chunk)
if n > 0 {
msg := proc.NewCopyDataMessage()
msg.Sz = uint64(n)
copy(msg.Data, chunk[:n])

_, err = con.Write(msg.Encode())
if err != nil {
return err
}
}

if err == io.EOF {
break
} else {
return err
}
}

msg = proc.NewCommandCompleteMessage().Encode()
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/yezzey-gp/yproxy

go 1.21.1
go 1.21

require (
github.com/BurntSushi/toml v1.3.2
Expand Down
17 changes: 17 additions & 0 deletions pkg/crypt/crypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

type Crypter interface {
Decrypt(reader io.Reader) (io.Reader, error)
Encrypt(writer io.WriteCloser) (io.WriteCloser, error)
}

type GPGCrypter struct {
Expand Down Expand Up @@ -82,3 +83,19 @@ func (g *GPGCrypter) Decrypt(reader io.Reader) (io.Reader, error) {

return md.UnverifiedBody, nil
}

func (g *GPGCrypter) Encrypt(writer io.WriteCloser) (io.WriteCloser, error) {
err := g.loadSecret()
if err != nil {
return nil, err
}
ylogger.Zero.Debug().Str("gpg path", g.cnf.GPGKeyPath).Msg("loaded gpg key")

encryptedWriter, err := openpgp.Encrypt(writer, g.PubKey, nil, nil, nil)

if err != nil {
return nil, errors.WithStack(err)
}

return encryptedWriter, nil
}
5 changes: 5 additions & 0 deletions pkg/proc/command_complete.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package proc
import "encoding/binary"

type CommandCompleteMessage struct {
ProtoMessage
}

func NewCommandCompleteMessage() *CommandCompleteMessage {
Expand All @@ -23,3 +24,7 @@ func (cc *CommandCompleteMessage) Encode() []byte {
binary.BigEndian.PutUint64(bs, uint64(ln))
return append(bs, bt...)
}

func (c *CommandCompleteMessage) Decode(body []byte) error {
return nil
}
41 changes: 41 additions & 0 deletions pkg/proc/copy_data.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package proc

import (
"encoding/binary"
)

type CopyDataMessage struct {
ProtoMessage
Sz uint64
Data []byte
}

func NewCopyDataMessage() *CopyDataMessage {
return &CopyDataMessage{}
}

func (cc *CopyDataMessage) Encode() []byte {
bt := make([]byte, 4+cc.Sz)

bt[0] = byte(MessageTypeCopyData)

// sizeof(sz) + data
ln := len(bt) + 8 + 8 + int(cc.Sz)

bs := make([]byte, 8)
binary.BigEndian.PutUint64(bs, uint64(ln))

binary.BigEndian.PutUint64(bt[4:], uint64(cc.Sz))

// check data len more than cc.sz?
copy(bt[4+8:], cc.Data[:cc.Sz])

return append(bs, bt...)
}

func (cc *CopyDataMessage) Decode(data []byte) {
msgLenBuf := data[4:12]
cc.Sz = binary.BigEndian.Uint64(msgLenBuf)
cc.Data = make([]byte, cc.Sz)
copy(cc.Data, data[12:12+cc.Sz])
}
60 changes: 59 additions & 1 deletion pkg/proc/interaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ package proc

import (
"io"
"sync"

"github.com/yezzey-gp/yproxy/pkg/client"
"github.com/yezzey-gp/yproxy/pkg/crypt"
"github.com/yezzey-gp/yproxy/pkg/storage"
"github.com/yezzey-gp/yproxy/pkg/ylogger"
)

func ProcConn(s storage.StorageReader, cr crypt.Crypter, ycl *client.YClient) error {
func ProcConn(s storage.StorageInteractor, cr crypt.Crypter, ycl *client.YClient) error {
pr := NewProtoReader(ycl)
tp, body, err := pr.ReadPacket()
if err != nil {
Expand Down Expand Up @@ -46,6 +47,63 @@ func ProcConn(s storage.StorageReader, cr crypt.Crypter, ycl *client.YClient) er

case MessageTypePut:

msg := PutMessage{}
msg.Decode(body)

var w io.WriteCloser

r, w := io.Pipe()

if msg.Encrypt {
var err error
w, err = cr.Encrypt(w)
if err != nil {
_ = ycl.ReplyError(err, "failed to encrypt")

return ycl.Conn.Close()
}
}

wg := sync.WaitGroup{}
wg.Add(1)

go func() {
defer wg.Done()
for {
tp, body, err := pr.ReadPacket()
if err != nil {
_ = ycl.ReplyError(err, "failed to compelete request")

_ = ycl.Conn.Close()
return
}

switch tp {
case MessageTypeCopyData:
msg := CopyDataMessage{}
msg.Decode(body)
w.Write(msg.Data)
case MessageTypeCommandComplete:
msg := CommandCompleteMessage{}
msg.Decode(body)
case MessageTypeReadyForQuery:
msg := ReadyForQueryMessage{}
msg.Decode(body)
return
}
}
}()

err := s.PutFileToDest(msg.Name, r)

wg.Wait()

if err != nil {
_ = ycl.ReplyError(err, "failed to upload")

return ycl.Conn.Close()
}

default:

_ = ycl.ReplyError(nil, "wrong request type")
Expand Down
1 change: 1 addition & 0 deletions pkg/proc/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const (
MessageTypePut = MessageType(43)
MessageTypeCommandComplete = MessageType(44)
MessageTypeReadyForQuery = MessageType(45)
MessageTypeCopyData = MessageType(46)

DecryptMessage = RequestEncryption(1)
NoDecryptMessage = RequestEncryption(0)
Expand Down
30 changes: 30 additions & 0 deletions pkg/proc/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,33 @@ func TestCatMsg(t *testing.T) {
assert.Equal(msg.Decrypt, msg2.Decrypt)
}
}

func TestPutMsg(t *testing.T) {
assert := assert.New(t)

type tcase struct {
name string
encrypt bool
err error
}

for _, tt := range []tcase{
{
"nam1",
true,
nil,
},
} {

msg := proc.NewPutMessage(tt.name, tt.encrypt)
body := msg.Encode()

msg2 := proc.CatMessage{}

err := msg2.Decode(body[8:])

assert.NoError(err)
assert.Equal(msg.Name, msg2.Name)
assert.Equal(msg.Encrypt, msg2.Decrypt)
}
}
26 changes: 25 additions & 1 deletion pkg/proc/put_message.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package proc

import "encoding/binary"
import (
"bytes"
"encoding/binary"
)

type PutMessage struct {
ProtoMessage
Expand Down Expand Up @@ -37,3 +40,24 @@ func (c *PutMessage) Encode() []byte {
binary.BigEndian.PutUint64(bs, uint64(ln))
return append(bs, bt...)
}

func (c *PutMessage) GetPutName(b []byte) string {
buff := bytes.NewBufferString("")

for i := 0; i < len(b); i++ {
if b[i] == 0 {
break
}
buff.WriteByte(b[i])
}

return buff.String()
}

func (c *PutMessage) Decode(body []byte) error {
if body[1] == byte(EncryptMessage) {
c.Encrypt = true
}
c.Name = c.GetPutName(body[4:])
return nil
}
5 changes: 5 additions & 0 deletions pkg/proc/ready_for_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package proc
import "encoding/binary"

type ReadyForQueryMessage struct {
ProtoMessage
}

func NewReadyForQueryMessage() *ReadyForQueryMessage {
Expand All @@ -23,3 +24,7 @@ func (cc *ReadyForQueryMessage) Encode() []byte {
binary.BigEndian.PutUint64(bs, uint64(ln))
return append(bs, bt...)
}

func (c *ReadyForQueryMessage) Decode(body []byte) error {
return nil
}
48 changes: 44 additions & 4 deletions pkg/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/yezzey-gp/yproxy/config"
"github.com/yezzey-gp/yproxy/pkg/ylogger"
)
Expand All @@ -15,19 +16,31 @@ type StorageReader interface {
CatFileFromStorage(name string) (io.Reader, error)
}

type S3StorageReader struct {
type StorageWriter interface {
PutFileToDest(name string, r io.Reader) error
}

type StorageInteractor interface {
StorageReader
StorageWriter
}

type S3StorageInteractor struct {
StorageReader
StorageWriter

pool SessionPool
cnf *config.Storage
}

func NewStorage(cnf *config.Storage) StorageReader {
return &S3StorageReader{
func NewStorage(cnf *config.Storage) StorageInteractor {
return &S3StorageInteractor{
pool: NewSessionPool(cnf),
cnf: cnf,
}
}

func (s *S3StorageReader) CatFileFromStorage(name string) (io.Reader, error) {
func (s *S3StorageInteractor) CatFileFromStorage(name string) (io.Reader, error) {
// XXX: fix this
sess, err := s.pool.GetSession(context.TODO())
if err != nil {
Expand All @@ -47,3 +60,30 @@ func (s *S3StorageReader) CatFileFromStorage(name string) (io.Reader, error) {
object, err := sess.GetObject(input)
return object.Body, err
}

func (s *S3StorageInteractor) PutFileToDest(name string, r io.Reader) error {
sess, err := s.pool.GetSession(context.TODO())
if err != nil {
ylogger.Zero.Err(err).Msg("failed to acquire s3 session")
return nil
}

objectPath := path.Join(s.cnf.StoragePrefix, name)

up := s3manager.NewUploaderWithClient(sess, func(uploader *s3manager.Uploader) {
uploader.PartSize = int64(1 << 20)
uploader.Concurrency = 1
})

_, err = up.Upload(
&s3manager.UploadInput{

Bucket: aws.String(s.cnf.StorageBucket),
Key: aws.String(objectPath),
Body: r,
StorageClass: aws.String("STANDARD"),
},
)

return err
}