Skip to content

Commit

Permalink
Add tls (#85)
Browse files Browse the repository at this point in the history
* add TLS encryption

* accept TLS certificates directly, not though files

* send ReplicaID = -1

* fix after review

* add test for TLS

* fix testserver

* fix data race
  • Loading branch information
e-max authored and hashmap committed Feb 2, 2018
1 parent cec0458 commit a11af9f
Show file tree
Hide file tree
Showing 12 changed files with 540 additions and 11 deletions.
29 changes: 23 additions & 6 deletions broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,16 @@ type BrokerConf struct {
// logging frameworks. Used to notify and as replacement for stdlib `log`
// package.
Logger Logger

//Settings for TLS encryption.
//You need to set all these parameters to enable TLS

//TLS CA pem
TLSCa []byte
//TLS certificate
TLSCert []byte
//TLS key
TLSKey []byte
}

// NewBrokerConf returns the default broker configuration.
Expand Down Expand Up @@ -311,7 +321,7 @@ func (b *Broker) fetchMetadata(topics ...string) (*proto.MetadataResp, error) {
if _, ok := checkednodes[nodeID]; ok {
continue
}
conn, err := newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
conn, err := b.getConnection(addr)
if err != nil {
b.conf.Logger.Debug("cannot connect",
"address", addr,
Expand All @@ -336,7 +346,7 @@ func (b *Broker) fetchMetadata(topics ...string) (*proto.MetadataResp, error) {
}

for _, addr := range b.getInitialAddresses() {
conn, err := newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
conn, err := b.getConnection(addr)
if err != nil {
b.conf.Logger.Debug("cannot connect to seed node",
"address", addr,
Expand Down Expand Up @@ -475,7 +485,7 @@ func (b *Broker) muLeaderConnection(topic string, partition int32) (conn *connec
delete(b.metadata.endpoints, tp)
continue
}
conn, err = newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
conn, err = b.getConnection(addr)
if err != nil {
b.conf.Logger.Info("cannot get leader connection: cannot connect to node",
"address", addr,
Expand All @@ -490,6 +500,13 @@ func (b *Broker) muLeaderConnection(topic string, partition int32) (conn *connec
return nil, err
}

func (b *Broker) getConnection(addr string) (*connection, error) {
if b.conf.TLSCa != nil && b.conf.TLSKey != nil && b.conf.TLSCert != nil {
return newTLSConnection(addr, b.conf.TLSCa, b.conf.TLSCert, b.conf.TLSKey, b.conf.DialTimeout, b.conf.ReadTimeout)
}
return newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
}

// coordinatorConnection returns connection to offset coordinator for given group.
//
// Failed connection retry is controlled by broker configuration.
Expand Down Expand Up @@ -526,7 +543,7 @@ func (b *Broker) muCoordinatorConnection(consumerGroup string) (conn *connection
}

addr := fmt.Sprintf("%s:%d", resp.CoordinatorHost, resp.CoordinatorPort)
conn, err := newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
conn, err := b.getConnection(addr)
if err != nil {
b.conf.Logger.Debug("cannot connect to node",
"coordinatorID", resp.CoordinatorID,
Expand All @@ -552,7 +569,7 @@ func (b *Broker) muCoordinatorConnection(consumerGroup string) (conn *connection
// connection to node is cached so it was already checked
continue
}
conn, err := newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
conn, err := b.getConnection(addr)
if err != nil {
b.conf.Logger.Debug("cannot connect to node",
"nodeID", nodeID,
Expand Down Expand Up @@ -583,7 +600,7 @@ func (b *Broker) muCoordinatorConnection(consumerGroup string) (conn *connection
}

addr := fmt.Sprintf("%s:%d", resp.CoordinatorHost, resp.CoordinatorPort)
conn, err = newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
conn, err = b.getConnection(addr)
if err != nil {
b.conf.Logger.Debug("cannot connect to node",
"coordinatorID", resp.CoordinatorID,
Expand Down
40 changes: 40 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package kafka
import (
"bufio"
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"math"
Expand All @@ -29,6 +31,44 @@ type connection struct {
readTimeout time.Duration
}

func newTLSConnection(address string, ca, cert, key []byte, timeout, readTimeout time.Duration) (*connection, error) {
roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM(ca)
if !ok {
return nil, fmt.Errorf("Cannot parse root certificate")
}

certificate, err := tls.X509KeyPair(cert, key)
if err != nil {
return nil, fmt.Errorf("Failed to parse key/cert for TLS: %s", err)
}

conf := &tls.Config{
Certificates: []tls.Certificate{certificate},
RootCAs: roots,
}

dialer := net.Dialer{
Timeout: timeout,
KeepAlive: 30 * time.Second,
}
conn, err := tls.DialWithDialer(&dialer, "tcp", address, conf)
if err != nil {
return nil, err
}
c := &connection{
stop: make(chan struct{}),
nextID: make(chan int32),
rw: conn,
respc: make(map[int32]chan []byte),
logger: &nullLogger{},
readTimeout: readTimeout,
}
go c.nextIDLoop()
go c.readRespLoop()
return c, nil
}

// newConnection returns new, initialized connection or error
func newTCPConnection(address string, timeout, readTimeout time.Duration) (*connection, error) {
dialer := net.Dialer{
Expand Down
172 changes: 172 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package kafka

import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"reflect"
"strings"
Expand All @@ -10,10 +16,120 @@ import (
"github.com/optiopay/kafka/proto"
)

const TLSCaFile = "./testkeys/ca.crt"
const TLSCertFile = "./testkeys/oats.crt"
const TLSKeyFile = "./testkeys/oats.key"

type serializableMessage interface {
Bytes() ([]byte, error)
}

type TLSConf struct {
ca []byte
cert []byte
key []byte
}

func getTLSConf() (*TLSConf, error) {
ca, err := ioutil.ReadFile(TLSCaFile)
if err != nil {
return nil, fmt.Errorf("Cannot read %s", TLSCaFile)
}
cert, err := ioutil.ReadFile(TLSCertFile)
if err != nil {
return nil, fmt.Errorf("Cannot read %s", TLSCertFile)
}

key, err := ioutil.ReadFile(TLSKeyFile)
if err != nil {
return nil, fmt.Errorf("Cannot read %s", TLSKeyFile)
}

return &TLSConf{ca: ca, cert: cert, key: key}, nil

}

//just read request before start to response
func readRequest(r io.Reader) error {
dec := proto.NewDecoder(r)
size := dec.DecodeInt32()
var read int32 = 0
buf := make([]byte, size)

for read < size {
n, err := r.Read(buf)
if err != nil {
return err
}
read += int32(n)
}
return nil
}

func testTLSServer(messages ...serializableMessage) (net.Listener, error) {
tlsConf, err := getTLSConf()
if err != nil {
return nil, err
}

roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM(tlsConf.ca)
if !ok {
return nil, fmt.Errorf("Cannot parse root certificate")
}

certificate, err := tls.X509KeyPair(tlsConf.cert, tlsConf.key)
if err != nil {
return nil, fmt.Errorf("Failed to parse key/cert for TLS: %s", err)
}

conf := &tls.Config{
Certificates: []tls.Certificate{certificate},
RootCAs: roots,
}

_ = conf

ln, err := tls.Listen("tcp4", "localhost:22222", conf)
if err != nil {
return nil, err
}

responses := make([][]byte, len(messages))
for i, m := range messages {
b, err := m.Bytes()
if err != nil {
_ = ln.Close()
return nil, err
}
responses[i] = b
}

go func() {
for {
cli, err := ln.Accept()

if err != nil {
return
}

go func(conn net.Conn) {
err := readRequest(conn)
if err != nil {
log.Panic(err)
}

time.Sleep(time.Millisecond * 50)
for _, resp := range responses {
_, _ = cli.Write(resp)
}
err = cli.Close()
}(cli)
}
}()
return ln, nil
}

func testServer(messages ...serializableMessage) (net.Listener, error) {
ln, err := net.Listen("tcp4", "")
if err != nil {
Expand Down Expand Up @@ -620,3 +736,59 @@ func TestNoServerResponse(t *testing.T) {
t.Fatalf("could not close test server: %s", err)
}
}

func TestTLSConnection(t *testing.T) {
resp1 := &proto.MetadataResp{
CorrelationID: 1,
Brokers: []proto.MetadataRespBroker{
{
NodeID: 666,
Host: "example.com",
Port: 999,
},
},
Topics: []proto.MetadataRespTopic{
{
Name: "foo",
Partitions: []proto.MetadataRespPartition{
{
ID: 7,
Leader: 7,
Replicas: []int32{7},
Isrs: []int32{7},
},
},
},
},
}
ln, err := testTLSServer(resp1)
if err != nil {
t.Fatalf("test server error: %s", err)
}
tlsConf, err := getTLSConf()
if err != nil {
t.Fatalf("cannot get tls parametes: %s", err)
}
_ = tlsConf
conn, err := newTLSConnection(ln.Addr().String(), tlsConf.ca, tlsConf.cert, tlsConf.key, time.Second, time.Second)

if err != nil {
t.Fatalf("could not conect to test server: %s", err)
}
resp, err := conn.Metadata(&proto.MetadataReq{
ClientID: "tester",
Topics: []string{"first", "second"},
})
if err != nil {
t.Fatalf("could not fetch response: %s", err)
}
if !reflect.DeepEqual(resp, resp1) {
t.Fatalf("expected different response %#v", resp)
}
if err := conn.Close(); err != nil {
t.Fatalf("could not close kafka connection: %s", err)
}
if err := ln.Close(); err != nil {
t.Fatalf("could not close test server: %s", err)
}
}
5 changes: 2 additions & 3 deletions kafkatest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func NewServer(middlewares ...Middleware) *Server {
topics: make(map[string]map[int32][]*proto.Message),
offsets: make(map[string]map[int32]map[string]*topicOffset),
middlewares: middlewares,
events: make(chan struct{}),
events: make(chan struct{}, 1000),
}
return s
}
Expand Down Expand Up @@ -370,8 +370,7 @@ func (s *Server) handleProduceRequest(nodeID int32, conn net.Conn, req *proto.Pr
respParts[pi].Offset = int64(len(t[part.ID])) - 1
}
}
close(s.events)
s.events = make(chan struct{})
s.events <- struct{}{}
return resp
}

Expand Down
7 changes: 5 additions & 2 deletions proto/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,9 @@ func (r *FetchReq) Bytes() ([]byte, error) {
enc.Encode(r.CorrelationID)
enc.Encode(r.ClientID)

enc.Encode(r.ReplicaID)
//enc.Encode(r.ReplicaID)
enc.Encode(int32(-1))

enc.Encode(r.MaxWaitTime)
enc.Encode(r.MinBytes)

Expand Down Expand Up @@ -1808,7 +1810,8 @@ func (r *OffsetReq) Bytes() ([]byte, error) {
enc.Encode(r.CorrelationID)
enc.Encode(r.ClientID)

enc.Encode(r.ReplicaID)
//enc.Encode(r.ReplicaID)
enc.Encode(int32(-1))

if r.Version >= KafkaV2 {
enc.Encode(r.IsolationLevel)
Expand Down
Loading

0 comments on commit a11af9f

Please sign in to comment.