Skip to content

Commit

Permalink
fix: add kms encryption
Browse files Browse the repository at this point in the history
  • Loading branch information
lutaoact committed May 31, 2022
1 parent 81b52bc commit 8509cf6
Show file tree
Hide file tree
Showing 10 changed files with 1,013 additions and 314 deletions.
3 changes: 3 additions & 0 deletions sample.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,7 @@ func main() {

// SQL sample
sample.SQLQuerySample(client)

// Server side encryption sample
sample.ServerSideEncryptionSample(client)
}
171 changes: 171 additions & 0 deletions sample/ServerSideEncryptionSample.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
package sample

import (
"fmt"
"github.com/aliyun/aliyun-tablestore-go-sdk/tablestore"
"time"
)

const (
TABLE_NAME_DISABLE = "disableSseSampleTable"
TABLE_NAME_KMS_SERVICE = "kmsServiceSampleTable"
TABLE_NAME_BYOK = "byokSampleTable"
PRIMARY_KEY_NAME = "pk"

BYOK_KEY_ID = ""
BYOK_ROLE_ARN = "acs:ram::<aliuid>:role/kms-ots-test"
)

func ServerSideEncryptionSample(client *tablestore.TableStoreClient) {
// 创建关闭服务器端加密功能的表
deleteTableIfExist(client, TABLE_NAME_DISABLE)
createTableDisableSse(client, TABLE_NAME_DISABLE)

// 创建开启服务器端加密功能(服务主秘钥)的表
deleteTableIfExist(client, TABLE_NAME_KMS_SERVICE)
createTableKmsService(client, TABLE_NAME_KMS_SERVICE)

// 创建开启服务器端加密功能(用户主秘钥)的表
deleteTableIfExist(client, TABLE_NAME_BYOK)
createTableByok(client, TABLE_NAME_BYOK, BYOK_KEY_ID, BYOK_ROLE_ARN)

// 查看表的属性
describeTable(client, TABLE_NAME_DISABLE)
describeTable(client, TABLE_NAME_KMS_SERVICE)
describeTable(client, TABLE_NAME_BYOK)

// 等待表load完毕.
time.Sleep(10 * time.Second)

// 各写入一行数据
putRow(client, TABLE_NAME_DISABLE, "pkValue")
putRow(client, TABLE_NAME_KMS_SERVICE, "pkValue")
putRow(client, TABLE_NAME_BYOK, "pkValue")

// 各读取该行数据
getRow(client, TABLE_NAME_DISABLE, "pkValue")
getRow(client, TABLE_NAME_KMS_SERVICE, "pkValue")
getRow(client, TABLE_NAME_BYOK, "pkValue")
}

func deleteTableIfExist(client *tablestore.TableStoreClient, tableName string) {
_, err := client.DeleteTable(&tablestore.DeleteTableRequest{
TableName: tableName,
})
if err != nil {
fmt.Println("DeleteTable failed", tableName, err.Error())
}
}

func createTable(client *tablestore.TableStoreClient, tableName string, sseSpec *tablestore.SSESpecification) {
createtableRequest := new(tablestore.CreateTableRequest)
tableMeta := new(tablestore.TableMeta)
tableMeta.TableName = tableName
tableMeta.AddPrimaryKeyColumn(PRIMARY_KEY_NAME, tablestore.PrimaryKeyType_STRING)
tableOption := new(tablestore.TableOption)
tableOption.TimeToAlive = -1
tableOption.MaxVersion = 3
reservedThroughput := new(tablestore.ReservedThroughput)
reservedThroughput.Readcap = 0
reservedThroughput.Writecap = 0
createtableRequest.TableMeta = tableMeta
createtableRequest.TableOption = tableOption
createtableRequest.ReservedThroughput = reservedThroughput
createtableRequest.SSESpecification = sseSpec

_, err := client.CreateTable(createtableRequest)
if err != nil {
fmt.Println("CreateTable failed", tableName, err.Error())
}
}

func createTableDisableSse(client *tablestore.TableStoreClient, tableName string) {
// 关闭服务器端加密功能
sseSpec := new(tablestore.SSESpecification)
sseSpec.SetEnable(false)

createTable(client, tableName, sseSpec)
}

func createTableKmsService(client *tablestore.TableStoreClient, tableName string) {
// 打开服务器端加密功能,使用KMS的服务主密钥
// 需要确保已经在所在区域开通了KMS服务
sseSpec := new(tablestore.SSESpecification)
sseSpec.SetEnable(true)
sseSpec.SetKeyType(tablestore.SSE_KMS_SERVICE)

createTable(client, tableName, sseSpec)
}

func createTableByok(client *tablestore.TableStoreClient, tableName string, keyId string, roleArn string) {
// 打开服务器端加密功能,使用KMS的用户主密钥
// 需要确保keyId合法有效且未被禁用,同时roleArn被授予了临时访问该keyId的权限
sseSpec := new(tablestore.SSESpecification)
sseSpec.SetEnable(true)
sseSpec.SetKeyType(tablestore.SSE_BYOK)
sseSpec.SetKeyId(keyId)
sseSpec.SetRoleArn(roleArn)

createTable(client, tableName, sseSpec)
}

func describeTable(client *tablestore.TableStoreClient, tableName string) {
resp, err := client.DescribeTable(&tablestore.DescribeTableRequest{
TableName: tableName,
})
if err != nil {
fmt.Println("describe table failed", tableName, err.Error())
return
}
fmt.Println("表的名称:" + resp.TableMeta.TableName)
sseDetails := resp.SSEDetails
if sseDetails.Enable {
fmt.Println("表是否开启服务器端加密功能:是")
fmt.Println("表的加密秘钥类型:", sseDetails.KeyType.String())
fmt.Println("表的加密主密钥id:", sseDetails.KeyId)
if sseDetails.KeyType == tablestore.SSE_BYOK {
fmt.Println("表的全局资源描述符:" + sseDetails.RoleArn)
}
} else {
fmt.Println("表是否开启服务器端加密功能:否")
}

}

func putRow(client *tablestore.TableStoreClient, tableName string, pkValue string) {
putRowRequest := new(tablestore.PutRowRequest)
putRowChange := new(tablestore.PutRowChange)
putRowChange.TableName = tableName
putPk := new(tablestore.PrimaryKey)
putPk.AddPrimaryKeyColumn(PRIMARY_KEY_NAME, pkValue)

putRowChange.PrimaryKey = putPk
putRowChange.AddColumn("price", int64(5120))
putRowChange.SetCondition(tablestore.RowExistenceExpectation_IGNORE)
putRowRequest.PutRowChange = putRowChange
_, err := client.PutRow(putRowRequest)
if err != nil {
fmt.Println("PutRow failed", tableName, err.Error())
}
}

func getRow(client *tablestore.TableStoreClient, tableName string, pkValue string) {
getRowRequest := new(tablestore.GetRowRequest)
criteria := new(tablestore.SingleRowQueryCriteria)
putPk := new(tablestore.PrimaryKey)
putPk.AddPrimaryKeyColumn(PRIMARY_KEY_NAME, pkValue)

criteria.PrimaryKey = putPk
getRowRequest.SingleRowQueryCriteria = criteria
getRowRequest.SingleRowQueryCriteria.TableName = tableName
getRowRequest.SingleRowQueryCriteria.MaxVersion = 1
getResp, err := client.GetRow(getRowRequest)

if err != nil {
fmt.Println("GetRow failed", tableName, err)
} else {
colmap := getResp.GetColumnMap()
fmt.Println(tableName, "length is ", len(colmap.Columns))
fmt.Println("get row col0 result is ", getResp.Columns[0].ColumnName, getResp.Columns[0].Value)
}
}
85 changes: 73 additions & 12 deletions tablestore/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ func (tableStoreClient *TableStoreClient) CreateTable(request *CreateTableReques
if request.StreamSpec.EnableStream {
ss = otsprotocol.StreamSpecification{
EnableStream: &request.StreamSpec.EnableStream,
ExpirationTime: &request.StreamSpec.ExpirationTime}
ExpirationTime: &request.StreamSpec.ExpirationTime,
ColumnsToGet: request.StreamSpec.OriginColumnsToGet,
}
} else {
ss = otsprotocol.StreamSpecification{
EnableStream: &request.StreamSpec.EnableStream}
Expand All @@ -499,6 +501,37 @@ func (tableStoreClient *TableStoreClient) CreateTable(request *CreateTableReques
req.StreamSpec = &ss
}

if request.SSESpecification != nil {
if err := request.SSESpecification.CheckArguments(); err != nil {
return nil, err
}
sse := new(otsprotocol.SSESpecification)
sse.Enable = proto.Bool(request.SSESpecification.Enable)
if request.SSESpecification.KeyType != nil {
sseType := *request.SSESpecification.KeyType
switch sseType {
case SSE_KMS_SERVICE:
keyType := otsprotocol.SSEKeyType_SSE_KMS_SERVICE
sse.KeyType = &keyType
case SSE_BYOK:
keyType := otsprotocol.SSEKeyType_SSE_BYOK
sse.KeyType = &keyType
default:
return nil, errInvalidSSEKeyType(sseType.String())
}
}

if request.SSESpecification.KeyId != nil {
sse.KeyId = []byte(*request.SSESpecification.KeyId)
}

if request.SSESpecification.RoleArn != nil {
sse.RoleArn = []byte(*request.SSESpecification.RoleArn)
}

req.SseSpec = sse
}

resp := new(otsprotocol.CreateTableResponse)
response := &CreateTableResponse{}
if err := tableStoreClient.doRequestWithRetry(createTableUri, req, resp, &response.ResponseInfo); err != nil {
Expand Down Expand Up @@ -1030,11 +1063,11 @@ func (tableStoreClient *TableStoreClient) DescribeTable(request *DescribeTableRe

if resp.StreamDetails != nil && *resp.StreamDetails.EnableStream {
response.StreamDetails = &StreamDetails{
EnableStream: *resp.StreamDetails.EnableStream,
StreamId: (*StreamId)(resp.StreamDetails.StreamId),
ExpirationTime: *resp.StreamDetails.ExpirationTime,
LastEnableTime: *resp.StreamDetails.LastEnableTime,
ColumnsToGet: resp.GetStreamDetails().GetColumnsToGet(),
EnableStream: *resp.StreamDetails.EnableStream,
StreamId: (*StreamId)(resp.StreamDetails.StreamId),
ExpirationTime: *resp.StreamDetails.ExpirationTime,
LastEnableTime: *resp.StreamDetails.LastEnableTime,
OriginColumnsToGet: resp.GetStreamDetails().GetColumnsToGet(),
}
} else {
response.StreamDetails = &StreamDetails{
Expand All @@ -1045,6 +1078,34 @@ func (tableStoreClient *TableStoreClient) DescribeTable(request *DescribeTableRe
response.IndexMetas = append(response.IndexMetas, ConvertPbIndexMetaToIndexMeta(meta))
}

if resp.GetSseDetails() == nil {
response.SSEDetails = &SSEDetails{
Enable: false,
}
} else {
respSse := resp.GetSseDetails()
sseDetail := new(SSEDetails)
sseDetail.Enable = resp.GetSseDetails().GetEnable()
switch respSse.GetKeyType() {
case otsprotocol.SSEKeyType_SSE_KMS_SERVICE:
sseDetail.KeyType = SSE_KMS_SERVICE
case otsprotocol.SSEKeyType_SSE_BYOK:
sseDetail.KeyType = SSE_BYOK
default:
return nil, errInvalidSSEKeyType(respSse.GetKeyType().String())
}

if respSse.GetKeyId() != nil {
sseDetail.KeyId = string(respSse.GetKeyId())
}

if respSse.GetRoleArn() != nil {
sseDetail.RoleArn = string(respSse.GetRoleArn())
}

response.SSEDetails = sseDetail
}

return response, nil
}

Expand Down Expand Up @@ -1080,7 +1141,7 @@ func (tableStoreClient *TableStoreClient) UpdateTable(request *UpdateTableReques
req.StreamSpec = &otsprotocol.StreamSpecification{
EnableStream: &request.StreamSpec.EnableStream,
ExpirationTime: &request.StreamSpec.ExpirationTime,
ColumnsToGet: request.StreamSpec.ColumnsToGet,
ColumnsToGet: request.StreamSpec.OriginColumnsToGet,
}
} else {
req.StreamSpec = &otsprotocol.StreamSpecification{EnableStream: &request.StreamSpec.EnableStream}
Expand All @@ -1104,11 +1165,11 @@ func (tableStoreClient *TableStoreClient) UpdateTable(request *UpdateTableReques

if *resp.StreamDetails.EnableStream {
response.StreamDetails = &StreamDetails{
EnableStream: *resp.StreamDetails.EnableStream,
StreamId: (*StreamId)(resp.StreamDetails.StreamId),
ExpirationTime: *resp.StreamDetails.ExpirationTime,
LastEnableTime: *resp.StreamDetails.LastEnableTime,
ColumnsToGet: resp.GetStreamDetails().GetColumnsToGet(),
EnableStream: *resp.StreamDetails.EnableStream,
StreamId: (*StreamId)(resp.StreamDetails.StreamId),
ExpirationTime: *resp.StreamDetails.ExpirationTime,
LastEnableTime: *resp.StreamDetails.LastEnableTime,
OriginColumnsToGet: resp.GetStreamDetails().GetColumnsToGet(),
}
} else {
response.StreamDetails = &StreamDetails{
Expand Down
56 changes: 49 additions & 7 deletions tablestore/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ func PrepareSQLSearchIndex(c *C, tableName string, indexName string) {
}

func (s *TableStoreSuite) TestCreateTable(c *C) {
fmt.Println("TestCreateTable finished")

tableName := tableNamePrefix + "testcreatetable1"

deleteReq := new(DeleteTableRequest)
Expand Down Expand Up @@ -243,10 +241,54 @@ func (s *TableStoreSuite) TestCreateTable(c *C) {

_, error := client.CreateTable(createtableRequest)
c.Check(error, Equals, nil)

fmt.Println("TestCreateTable finished")
}

func (s *TableStoreSuite) TestCreateTableWithOriginColumn(c *C) {
tableName := tableNamePrefix + "originColumn"
defer func() {
deleteReq := new(DeleteTableRequest)
deleteReq.TableName = tableName
client.DeleteTable(deleteReq)
}()

ctReq := new(CreateTableRequest)
tableMeta := new(TableMeta)
tableMeta.TableName = tableName
tableMeta.AddPrimaryKeyColumn("pk1", PrimaryKeyType_STRING)

tableOption := new(TableOption)

tableOption.TimeToAlive = -1
tableOption.MaxVersion = 3

reservedThroughput := new(ReservedThroughput)
reservedThroughput.Readcap = 0
reservedThroughput.Writecap = 0

ctReq.TableMeta = tableMeta
ctReq.TableOption = tableOption
ctReq.ReservedThroughput = reservedThroughput

ctReq.StreamSpec = &StreamSpecification{
EnableStream: true,
ExpirationTime: 168,
OriginColumnsToGet: []string{"col1", "col2"},
}

_, err := client.CreateTable(ctReq)
c.Check(err, Equals, nil)

descTableRequest := &DescribeTableRequest{
TableName: tableName,
}
descResp, err := client.DescribeTable(descTableRequest)
c.Check(err, Equals, nil)
c.Check(2, Equals, len(descResp.StreamDetails.OriginColumnsToGet))

fmt.Println("TestCreateTableWithOriginColumn finished")
}

func (s *TableStoreSuite) TestReCreateTableAndPutRow(c *C) {
fmt.Println("TestReCreateTableAndPutRow started")

Expand Down Expand Up @@ -330,7 +372,7 @@ func (s *TableStoreSuite) TestUpdateAndDescribeTable(c *C) {
updateTableReq.StreamSpec = new(StreamSpecification)
updateTableReq.StreamSpec.EnableStream = true
updateTableReq.StreamSpec.ExpirationTime = 168
updateTableReq.StreamSpec.ColumnsToGet = []string{"col1", "col2"}
updateTableReq.StreamSpec.OriginColumnsToGet = []string{"col1", "col2"}

updateTableResp, error := client.UpdateTable(updateTableReq)
c.Assert(error, Equals, nil)
Expand All @@ -350,9 +392,9 @@ func (s *TableStoreSuite) TestUpdateAndDescribeTable(c *C) {
c.Assert(describ.TableOption.MaxVersion, Equals, updateTableReq.TableOption.MaxVersion)
c.Assert(describ.StreamDetails.EnableStream, Equals, updateTableReq.StreamSpec.EnableStream)
c.Assert(describ.StreamDetails.ExpirationTime, Equals, updateTableReq.StreamSpec.ExpirationTime)
c.Assert(len(describ.StreamDetails.ColumnsToGet), Equals, len(updateTableReq.StreamSpec.ColumnsToGet))
for i, s := range describ.StreamDetails.ColumnsToGet {
c.Assert(s, Equals, updateTableReq.StreamSpec.ColumnsToGet[i])
c.Assert(len(describ.StreamDetails.OriginColumnsToGet), Equals, len(updateTableReq.StreamSpec.OriginColumnsToGet))
for i, s := range describ.StreamDetails.OriginColumnsToGet {
c.Assert(s, Equals, updateTableReq.StreamSpec.OriginColumnsToGet[i])
}
fmt.Println("TestUpdateAndDescribeTable finished")
}
Expand Down
Loading

0 comments on commit 8509cf6

Please sign in to comment.