diff --git a/cli.go b/cli.go new file mode 100644 index 0000000..95d8e6b --- /dev/null +++ b/cli.go @@ -0,0 +1,77 @@ +package unicreds + +import ( + "encoding/csv" + "io" + + "github.com/olekukonko/tablewriter" +) + +const ( + // TableFormatTerm format the table for a terminal session + TableFormatTerm = iota // 0 + // TableFormatCSV format the table as CSV + TableFormatCSV // 1 +) + +// TableWriter enables writing of tables in a variety of formats +type TableWriter struct { + tableFormat int + headers []string + rows [][]string + wr io.Writer +} + +// NewTable create a new table writer +func NewTable(wr io.Writer) *TableWriter { + return &TableWriter{wr: wr} +} + +// SetHeaders set the column headers +func (tw *TableWriter) SetHeaders(headers []string) { + tw.headers = headers +} + +// SetFormat set the format +func (tw *TableWriter) SetFormat(tableFormat int) { + tw.tableFormat = tableFormat +} + +func (tw *TableWriter) Write(row []string) { + tw.rows = append(tw.rows, row) +} + +// BulkWrite append an array of rows to the buffer +func (tw *TableWriter) BulkWrite(rows [][]string) { + tw.rows = append(tw.rows, rows...) +} + +// Render render the table out to the supplied writer +func (tw *TableWriter) Render() error { + switch tw.tableFormat { + case TableFormatTerm: + table := tablewriter.NewWriter(tw.wr) + table.SetHeader(tw.headers) + table.AppendBulk(tw.rows) + table.Render() + case TableFormatCSV: + w := csv.NewWriter(tw.wr) + + if err := w.Write(tw.headers); err != nil { + return err + } + + for _, r := range tw.rows { + if err := w.Write(r); err != nil { + return err + } + } + w.Flush() + + if err := w.Error(); err != nil { + return err + } + } + + return nil +} diff --git a/cli_test.go b/cli_test.go new file mode 100644 index 0000000..451a8cc --- /dev/null +++ b/cli_test.go @@ -0,0 +1,50 @@ +package unicreds + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRender(t *testing.T) { + + tt := []struct { + tableFormat int + output string + headers []string + rows [][]string + }{ + { + tableFormat: TableFormatTerm, + output: `+------------+-----------+ +| NAME | VERSION | ++------------+-----------+ +| testlogin1 | testpass1 | +| testlogin2 | testpass2 | ++------------+-----------+ +`, + headers: []string{"Name", "Version"}, + rows: [][]string{{"testlogin1", "testpass1"}, {"testlogin2", "testpass2"}}, + }, + { + tableFormat: TableFormatCSV, + output: "Name,Version\ntestlogin1,testpass1\ntestlogin2,testpass2\n", + headers: []string{"Name", "Version"}, + rows: [][]string{{"testlogin1", "testpass1"}, {"testlogin2", "testpass2"}}, + }, + } + + for _, tv := range tt { + var b bytes.Buffer + + table := NewTable(&b) + table.SetHeaders(tv.headers) + table.SetFormat(tv.tableFormat) + table.BulkWrite(tv.rows) + table.Render() + + assert.Equal(t, tv.output, b.String()) + } + +} diff --git a/cmd/unicred/main.go b/cmd/unicred/main.go index c05e478..a8d48dc 100644 --- a/cmd/unicred/main.go +++ b/cmd/unicred/main.go @@ -11,6 +11,7 @@ import ( var ( app = kingpin.New("unicreds", "A credential/secret storage command line tool.") debug = app.Flag("debug", "Enable debug mode.").Bool() + csv = app.Flag("csv", "Enable csv output for table data.").Bool() alias = app.Flag("alias", "KMS key alias.").Default("alias/credstash").String() @@ -22,11 +23,13 @@ var ( cmdList = app.Command("list", "List all credentials names and version.") - cmdPut = app.Command("put", "Put a credential in the store.") - cmdPutName = cmdPut.Arg("credential", "The name of the credential to get.").Required().String() - cmdPutSecret = cmdPut.Arg("value", "The value of the credential to store.").Required().String() + cmdPut = app.Command("put", "Put a credential in the store.") + cmdPutName = cmdPut.Arg("credential", "The name of the credential to get.").Required().String() + cmdPutSecret = cmdPut.Arg("value", "The value of the credential to store.").Required().String() + cmdPutVersion = cmdPut.Arg("version", "The version to store with the credential.").Int() - cmdDelete = app.Command("delete", "Delete a credential from the store.") + cmdDelete = app.Command("delete", "Delete a credential from the store.") + cmdDeleteName = cmdDelete.Arg("credential", "The name of the credential to get.").Required().String() // Version app version Version = "1.0.0" @@ -43,33 +46,58 @@ func main() { } fmt.Printf("%+v\n", cred.Secret) case cmdPut.FullCommand(): - err := unicreds.PutSecret(*cmdPutName, *cmdPutSecret, "") + var version string + if *cmdPutVersion != 0 { + version = fmt.Sprintf("%d", *cmdPutVersion) + } + err := unicreds.PutSecret(*cmdPutName, *cmdPutSecret, version) if err != nil { printFatalError(err) } + fmt.Printf("%s has been stored\n", *cmdPutName) case cmdList.FullCommand(): creds, err := unicreds.ListSecrets() if err != nil { printFatalError(err) } + + table := unicreds.NewTable(os.Stdout) + table.SetHeaders([]string{"Name", "Version"}) + + if *csv { + table.SetFormat(unicreds.TableFormatCSV) + } + for _, cred := range creds { - fmt.Printf("%s\t%s\n", cred.Name, cred.Version) + table.Write([]string{cred.Name, cred.Version}) } + table.Render() case cmdGetAll.FullCommand(): creds, err := unicreds.ListSecrets() if err != nil { printFatalError(err) } + + table := unicreds.NewTable(os.Stdout) + table.SetHeaders([]string{"Name", "Secret"}) + + if *csv { + table.SetFormat(unicreds.TableFormatCSV) + } + for _, cred := range creds { - fmt.Printf("%s\t%s\n", cred.Name, cred.Secret) + table.Write([]string{cred.Name, cred.Secret}) } + table.Render() case cmdDelete.FullCommand(): - printFatalError(fmt.Errorf("Command %s not implemented", cmdDelete.FullCommand())) + err := unicreds.DeleteSecret(*cmdDeleteName) + if err != nil { + printFatalError(err) + } } } + func printFatalError(err error) { fmt.Fprintf(os.Stderr, "error occured: %v\n", err) os.Exit(1) } - -//func printFatal(msg, arg string) diff --git a/ds.go b/ds.go index 6e14d73..279a0f3 100644 --- a/ds.go +++ b/ds.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" ) const ( @@ -15,13 +16,15 @@ const ( Table = "credential-store" // Region the AWS region dynamodb table - Region = "ap-southeast-2" + Region = "us-west-2" // KmsKey default KMS key alias name KmsKey = "alias/credstash" ) var ( + dynamoSvc dynamodbiface.DynamoDBAPI + // ErrSecretNotFound returned when unable to find the specified secret in dynamodb ErrSecretNotFound = errors.New("Secret Not Found") @@ -29,6 +32,10 @@ var ( ErrHmacValidationFailed = errors.New("Secret HMAC validation failed") ) +func init() { + dynamoSvc = dynamodb.New(session.New(), aws.NewConfig()) +} + // Credential managed credential information type Credential struct { Name string `ds:"name"` @@ -46,9 +53,8 @@ type DecryptedCredential struct { // CreateDBTable create the table which stores credentials func CreateDBTable() (err error) { - svc := dynamodb.New(session.New(), &aws.Config{Region: aws.String(Region)}) - res, err := svc.CreateTable(&dynamodb.CreateTableInput{ + res, err := dynamoSvc.CreateTable(&dynamodb.CreateTableInput{ AttributeDefinitions: []*dynamodb.AttributeDefinition{ { AttributeName: aws.String("name"), @@ -88,9 +94,8 @@ func CreateDBTable() (err error) { // GetSecret retrieve the secret from dynamodb using the name func GetSecret(name string) (*DecryptedCredential, error) { - svc := dynamodb.New(session.New(), &aws.Config{Region: aws.String(Region)}) - res, err := svc.Query(&dynamodb.QueryInput{ + res, err := dynamoSvc.Query(&dynamodb.QueryInput{ TableName: aws.String(Table), ExpressionAttributeNames: map[string]*string{ "#N": aws.String("name"), @@ -127,9 +132,8 @@ func GetSecret(name string) (*DecryptedCredential, error) { // ListSecrets return a list of secrets func ListSecrets() ([]*DecryptedCredential, error) { - svc := dynamodb.New(session.New(), &aws.Config{Region: aws.String(Region)}) - res, err := svc.Scan(&dynamodb.ScanInput{ + res, err := dynamoSvc.Scan(&dynamodb.ScanInput{ TableName: aws.String(Table), AttributesToGet: []*string{ aws.String("name"), @@ -168,7 +172,6 @@ func ListSecrets() ([]*DecryptedCredential, error) { // PutSecret retrieve the secret from dynamodb func PutSecret(name, secret, version string) error { - svc := dynamodb.New(session.New(), &aws.Config{Region: aws.String(Region)}) if version == "" { version = "1" @@ -200,7 +203,7 @@ func PutSecret(name, secret, version string) error { return err } - _, err = svc.PutItem(&dynamodb.PutItemInput{ + _, err = dynamoSvc.PutItem(&dynamodb.PutItemInput{ TableName: aws.String(Table), Item: data, }) @@ -208,6 +211,58 @@ func PutSecret(name, secret, version string) error { return err } +// DeleteSecret delete a secret +func DeleteSecret(name string) error { + + res, err := dynamoSvc.Query(&dynamodb.QueryInput{ + TableName: aws.String(Table), + ExpressionAttributeNames: map[string]*string{ + "#N": aws.String("name"), + }, + ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{ + ":name": &dynamodb.AttributeValue{ + S: aws.String(name), + }, + }, + KeyConditionExpression: aws.String("#N = :name"), + ConsistentRead: aws.Bool(true), + ScanIndexForward: aws.Bool(false), // descending order + }) + + if err != nil { + return err + } + + for _, item := range res.Items { + cred := new(Credential) + + err = Decode(item, cred) + if err != nil { + return err + } + + fmt.Printf("deleting name=%s version=%s\n", cred.Name, cred.Version) + + _, err = dynamoSvc.DeleteItem(&dynamodb.DeleteItemInput{ + TableName: aws.String(Table), + Key: map[string]*dynamodb.AttributeValue{ + "name": &dynamodb.AttributeValue{ + S: aws.String(cred.Name), + }, + "version": &dynamodb.AttributeValue{ + S: aws.String(cred.Version), + }, + }, + }) + + if err != nil { + return err + } + } + + return nil +} + func decryptCredential(cred *Credential) (*DecryptedCredential, error) { wrappedKey, err := base64.StdEncoding.DecodeString(cred.Key) diff --git a/kms.go b/kms.go index 96e6cba..815c414 100644 --- a/kms.go +++ b/kms.go @@ -4,18 +4,13 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go/service/kms/kmsiface" ) -// KeyManagement is a sub-set of the capabilities of the KMS client. -type KeyManagement interface { - GenerateDataKey(*kms.GenerateDataKeyInput) (*kms.GenerateDataKeyOutput, error) - Decrypt(*kms.DecryptInput) (*kms.DecryptOutput, error) -} - -var kmsSvc KeyManagement +var kmsSvc kmsiface.KMSAPI func init() { - kmsSvc = kms.New(session.New(), &aws.Config{Region: aws.String(Region)}) + kmsSvc = kms.New(session.New(), aws.NewConfig()) } // DataKey which contains the details of the KMS key