Skip to content

Commit

Permalink
Added delete option and output formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
wolfeidau committed Dec 13, 2015
1 parent f0ad453 commit 0100de4
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 27 deletions.
77 changes: 77 additions & 0 deletions cli.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package unicreds

import (
"encoding/csv"
"io"

"github.com/olekukonko/tablewriter"
)

const (
// TableFormatTerm format the table for a terminal session
TableFormatTerm = iota // 0
// TableFormatCSV format the table as CSV
TableFormatCSV // 1
)

// TableWriter enables writing of tables in a variety of formats
type TableWriter struct {
tableFormat int
headers []string
rows [][]string
wr io.Writer
}

// NewTable create a new table writer
func NewTable(wr io.Writer) *TableWriter {
return &TableWriter{wr: wr}
}

// SetHeaders set the column headers
func (tw *TableWriter) SetHeaders(headers []string) {
tw.headers = headers
}

// SetFormat set the format
func (tw *TableWriter) SetFormat(tableFormat int) {
tw.tableFormat = tableFormat
}

func (tw *TableWriter) Write(row []string) {
tw.rows = append(tw.rows, row)
}

// BulkWrite append an array of rows to the buffer
func (tw *TableWriter) BulkWrite(rows [][]string) {
tw.rows = append(tw.rows, rows...)
}

// Render render the table out to the supplied writer
func (tw *TableWriter) Render() error {
switch tw.tableFormat {
case TableFormatTerm:
table := tablewriter.NewWriter(tw.wr)
table.SetHeader(tw.headers)
table.AppendBulk(tw.rows)
table.Render()
case TableFormatCSV:
w := csv.NewWriter(tw.wr)

if err := w.Write(tw.headers); err != nil {
return err
}

for _, r := range tw.rows {
if err := w.Write(r); err != nil {
return err
}
}
w.Flush()

if err := w.Error(); err != nil {
return err
}
}

return nil
}
50 changes: 50 additions & 0 deletions cli_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package unicreds

import (
"bytes"
"testing"

"github.com/stretchr/testify/assert"
)

func TestRender(t *testing.T) {

tt := []struct {
tableFormat int
output string
headers []string
rows [][]string
}{
{
tableFormat: TableFormatTerm,
output: `+------------+-----------+
| NAME | VERSION |
+------------+-----------+
| testlogin1 | testpass1 |
| testlogin2 | testpass2 |
+------------+-----------+
`,
headers: []string{"Name", "Version"},
rows: [][]string{{"testlogin1", "testpass1"}, {"testlogin2", "testpass2"}},
},
{
tableFormat: TableFormatCSV,
output: "Name,Version\ntestlogin1,testpass1\ntestlogin2,testpass2\n",
headers: []string{"Name", "Version"},
rows: [][]string{{"testlogin1", "testpass1"}, {"testlogin2", "testpass2"}},
},
}

for _, tv := range tt {
var b bytes.Buffer

table := NewTable(&b)
table.SetHeaders(tv.headers)
table.SetFormat(tv.tableFormat)
table.BulkWrite(tv.rows)
table.Render()

assert.Equal(t, tv.output, b.String())
}

}
48 changes: 38 additions & 10 deletions cmd/unicred/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
var (
app = kingpin.New("unicreds", "A credential/secret storage command line tool.")
debug = app.Flag("debug", "Enable debug mode.").Bool()
csv = app.Flag("csv", "Enable csv output for table data.").Bool()

alias = app.Flag("alias", "KMS key alias.").Default("alias/credstash").String()

Expand All @@ -22,11 +23,13 @@ var (

cmdList = app.Command("list", "List all credentials names and version.")

cmdPut = app.Command("put", "Put a credential in the store.")
cmdPutName = cmdPut.Arg("credential", "The name of the credential to get.").Required().String()
cmdPutSecret = cmdPut.Arg("value", "The value of the credential to store.").Required().String()
cmdPut = app.Command("put", "Put a credential in the store.")
cmdPutName = cmdPut.Arg("credential", "The name of the credential to get.").Required().String()
cmdPutSecret = cmdPut.Arg("value", "The value of the credential to store.").Required().String()
cmdPutVersion = cmdPut.Arg("version", "The version to store with the credential.").Int()

cmdDelete = app.Command("delete", "Delete a credential from the store.")
cmdDelete = app.Command("delete", "Delete a credential from the store.")
cmdDeleteName = cmdDelete.Arg("credential", "The name of the credential to get.").Required().String()

// Version app version
Version = "1.0.0"
Expand All @@ -43,33 +46,58 @@ func main() {
}
fmt.Printf("%+v\n", cred.Secret)
case cmdPut.FullCommand():
err := unicreds.PutSecret(*cmdPutName, *cmdPutSecret, "")
var version string
if *cmdPutVersion != 0 {
version = fmt.Sprintf("%d", *cmdPutVersion)
}
err := unicreds.PutSecret(*cmdPutName, *cmdPutSecret, version)
if err != nil {
printFatalError(err)
}
fmt.Printf("%s has been stored\n", *cmdPutName)
case cmdList.FullCommand():
creds, err := unicreds.ListSecrets()
if err != nil {
printFatalError(err)
}

table := unicreds.NewTable(os.Stdout)
table.SetHeaders([]string{"Name", "Version"})

if *csv {
table.SetFormat(unicreds.TableFormatCSV)
}

for _, cred := range creds {
fmt.Printf("%s\t%s\n", cred.Name, cred.Version)
table.Write([]string{cred.Name, cred.Version})
}
table.Render()
case cmdGetAll.FullCommand():
creds, err := unicreds.ListSecrets()
if err != nil {
printFatalError(err)
}

table := unicreds.NewTable(os.Stdout)
table.SetHeaders([]string{"Name", "Secret"})

if *csv {
table.SetFormat(unicreds.TableFormatCSV)
}

for _, cred := range creds {
fmt.Printf("%s\t%s\n", cred.Name, cred.Secret)
table.Write([]string{cred.Name, cred.Secret})
}
table.Render()
case cmdDelete.FullCommand():
printFatalError(fmt.Errorf("Command %s not implemented", cmdDelete.FullCommand()))
err := unicreds.DeleteSecret(*cmdDeleteName)
if err != nil {
printFatalError(err)
}
}
}

func printFatalError(err error) {
fmt.Fprintf(os.Stderr, "error occured: %v\n", err)
os.Exit(1)
}

//func printFatal(msg, arg string)
73 changes: 64 additions & 9 deletions ds.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,34 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface"
)

const (
// Table the name of the dynamodb table
Table = "credential-store"

// Region the AWS region dynamodb table
Region = "ap-southeast-2"
Region = "us-west-2"

// KmsKey default KMS key alias name
KmsKey = "alias/credstash"
)

var (
dynamoSvc dynamodbiface.DynamoDBAPI

// ErrSecretNotFound returned when unable to find the specified secret in dynamodb
ErrSecretNotFound = errors.New("Secret Not Found")

// ErrHmacValidationFailed returned when the hmac signature validation fails
ErrHmacValidationFailed = errors.New("Secret HMAC validation failed")
)

func init() {
dynamoSvc = dynamodb.New(session.New(), aws.NewConfig())
}

// Credential managed credential information
type Credential struct {
Name string `ds:"name"`
Expand All @@ -46,9 +53,8 @@ type DecryptedCredential struct {

// CreateDBTable create the table which stores credentials
func CreateDBTable() (err error) {
svc := dynamodb.New(session.New(), &aws.Config{Region: aws.String(Region)})

res, err := svc.CreateTable(&dynamodb.CreateTableInput{
res, err := dynamoSvc.CreateTable(&dynamodb.CreateTableInput{
AttributeDefinitions: []*dynamodb.AttributeDefinition{
{
AttributeName: aws.String("name"),
Expand Down Expand Up @@ -88,9 +94,8 @@ func CreateDBTable() (err error) {

// GetSecret retrieve the secret from dynamodb using the name
func GetSecret(name string) (*DecryptedCredential, error) {
svc := dynamodb.New(session.New(), &aws.Config{Region: aws.String(Region)})

res, err := svc.Query(&dynamodb.QueryInput{
res, err := dynamoSvc.Query(&dynamodb.QueryInput{
TableName: aws.String(Table),
ExpressionAttributeNames: map[string]*string{
"#N": aws.String("name"),
Expand Down Expand Up @@ -127,9 +132,8 @@ func GetSecret(name string) (*DecryptedCredential, error) {

// ListSecrets return a list of secrets
func ListSecrets() ([]*DecryptedCredential, error) {
svc := dynamodb.New(session.New(), &aws.Config{Region: aws.String(Region)})

res, err := svc.Scan(&dynamodb.ScanInput{
res, err := dynamoSvc.Scan(&dynamodb.ScanInput{
TableName: aws.String(Table),
AttributesToGet: []*string{
aws.String("name"),
Expand Down Expand Up @@ -168,7 +172,6 @@ func ListSecrets() ([]*DecryptedCredential, error) {

// PutSecret retrieve the secret from dynamodb
func PutSecret(name, secret, version string) error {
svc := dynamodb.New(session.New(), &aws.Config{Region: aws.String(Region)})

if version == "" {
version = "1"
Expand Down Expand Up @@ -200,14 +203,66 @@ func PutSecret(name, secret, version string) error {
return err
}

_, err = svc.PutItem(&dynamodb.PutItemInput{
_, err = dynamoSvc.PutItem(&dynamodb.PutItemInput{
TableName: aws.String(Table),
Item: data,
})

return err
}

// DeleteSecret delete a secret
func DeleteSecret(name string) error {

res, err := dynamoSvc.Query(&dynamodb.QueryInput{
TableName: aws.String(Table),
ExpressionAttributeNames: map[string]*string{
"#N": aws.String("name"),
},
ExpressionAttributeValues: map[string]*dynamodb.AttributeValue{
":name": &dynamodb.AttributeValue{
S: aws.String(name),
},
},
KeyConditionExpression: aws.String("#N = :name"),
ConsistentRead: aws.Bool(true),
ScanIndexForward: aws.Bool(false), // descending order
})

if err != nil {
return err
}

for _, item := range res.Items {
cred := new(Credential)

err = Decode(item, cred)
if err != nil {
return err
}

fmt.Printf("deleting name=%s version=%s\n", cred.Name, cred.Version)

_, err = dynamoSvc.DeleteItem(&dynamodb.DeleteItemInput{
TableName: aws.String(Table),
Key: map[string]*dynamodb.AttributeValue{
"name": &dynamodb.AttributeValue{
S: aws.String(cred.Name),
},
"version": &dynamodb.AttributeValue{
S: aws.String(cred.Version),
},
},
})

if err != nil {
return err
}
}

return nil
}

func decryptCredential(cred *Credential) (*DecryptedCredential, error) {

wrappedKey, err := base64.StdEncoding.DecodeString(cred.Key)
Expand Down
11 changes: 3 additions & 8 deletions kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,13 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
)

// KeyManagement is a sub-set of the capabilities of the KMS client.
type KeyManagement interface {
GenerateDataKey(*kms.GenerateDataKeyInput) (*kms.GenerateDataKeyOutput, error)
Decrypt(*kms.DecryptInput) (*kms.DecryptOutput, error)
}

var kmsSvc KeyManagement
var kmsSvc kmsiface.KMSAPI

func init() {
kmsSvc = kms.New(session.New(), &aws.Config{Region: aws.String(Region)})
kmsSvc = kms.New(session.New(), aws.NewConfig())
}

// DataKey which contains the details of the KMS key
Expand Down

0 comments on commit 0100de4

Please sign in to comment.