Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for graceful termination #20

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,19 @@ func main() {

l.Println("Started processing events")

process_events := true

sigc := make(chan os.Signal, 1)
signal.Notify(sigc, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
go func() {
for _ = range sigc {
nlClient.Close()
process_events = false
}
}()

//Main loop. Get data from netlink and send it to the json lib for processing
for {
for process_events {
msg, err := nlClient.Receive()
if err != nil {
el.Printf("Error during message receive: %+v\n", err)
Expand All @@ -356,4 +367,7 @@ func main() {

marshaller.Consume(msg)
}

l.Println("Exiting go-audit.")
os.Exit(0)
}
24 changes: 21 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ func NewNetlinkClient(recvSize int) *NetlinkClient {

go func() {
for {
// Bail if our socket is closed, which typically means we're exiting
if n.fd < 0 {
return
}

n.KeepConnection()
time.Sleep(time.Second * 5)
}
Expand Down Expand Up @@ -127,11 +132,11 @@ func (n *NetlinkClient) Receive() (*syscall.NetlinkMessage, error) {
return msg, nil
}

func (n *NetlinkClient) KeepConnection() {
func (n *NetlinkClient) SetPid(pid uint32) {
payload := &AuditStatusPayload{
Mask: 4,
Enabled: 1,
Pid: uint32(syscall.Getpid()),
Pid: pid,
//TODO: Failure: http://lxr.free-electrons.com/source/include/uapi/linux/audit.h#L338
}

Expand All @@ -143,6 +148,19 @@ func (n *NetlinkClient) KeepConnection() {

err := n.Send(packet, payload)
if err != nil {
el.Println("Error occurred while trying to keep the connection:", err)
el.Println("Error occurred while trying to set the audit PID:", err)
}
}

func (n *NetlinkClient) KeepConnection() {
n.SetPid(uint32(syscall.Getpid()))
}

func (n *NetlinkClient) Close() {
// Tell kernel that we're closing
if n.fd >= 0 {
n.SetPid(0)
syscall.Close(n.fd)
n.fd = -1
}
}
38 changes: 37 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"syscall"
"testing"
"time"
)

func TestNetlinkClient_KeepConnection(t *testing.T) {
Expand Down Expand Up @@ -34,7 +35,7 @@ func TestNetlinkClient_KeepConnection(t *testing.T) {
syscall.Close(n.fd)
n.KeepConnection()
assert.Equal(t, "", lb.String(), "Got some log lines we did not expect")
assert.Equal(t, "Error occurred while trying to keep the connection: bad file descriptor\n", elb.String(), "Figured we would have an error")
assert.Equal(t, "Error occurred while trying to set the audit PID: bad file descriptor\n", elb.String(), "Figured we would have an error")
}

func TestNetlinkClient_SendReceive(t *testing.T) {
Expand Down Expand Up @@ -102,6 +103,41 @@ func TestNewNetlinkClient(t *testing.T) {
assert.Equal(t, "", elb.String(), "Did not expect any error messages")
}

func TestNetlinkClient_Close(t *testing.T) {
n := makeNelinkClient(t)

done := make(chan bool)

go func() {
msg, err := n.Receive()
if err != nil {
t.Log("Did not expect an error", err)
done <- false
return
}

expectedData := []byte{4, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
binary.LittleEndian.PutUint32(expectedData[12:16], 0)

assert.Equal(t, uint16(1001), msg.Header.Type, "Header.Type mismatch")
assert.Equal(t, uint16(5), msg.Header.Flags, "Header.Flags mismatch")
assert.Equal(t, uint32(1), msg.Header.Seq, "Header.Seq mismatch")
assert.Equal(t, uint32(56), msg.Header.Len, "Packet size is wrong - this test is brittle though")
assert.EqualValues(t, msg.Data[:40], expectedData, "data was wrong")

done <- true
}()

time.Sleep(100 * time.Millisecond)
n.Close()
if !(<-done) {
t.FailNow()
}

// Make sure fd was closed
assert.Equal(t, -1, n.fd, "Netlink fd was not closed")
}

// Helper to make a client listening on a unix socket
func makeNelinkClient(t *testing.T) *NetlinkClient {
os.Remove("go-audit.test.sock")
Expand Down