diff --git a/internal/db/db.go b/internal/db/db.go new file mode 100644 index 00000000..cbc20c73 --- /dev/null +++ b/internal/db/db.go @@ -0,0 +1,68 @@ +package db + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/canonical/chisel/internal/jsonwall" + "github.com/klauspost/compress/zstd" +) + +const schema = "0.1" + +// New creates a new Chisel DB writer with the proper schema. +func New() *jsonwall.DBWriter { + options := jsonwall.DBWriterOptions{Schema: schema} + return jsonwall.NewDBWriter(&options) +} + +func dbPath(root string) string { + return filepath.Join(root, ".chisel.db") +} + +// Save uses the provided writer dbw to write the Chisel DB into the standard +// path under the provided root directory. +func Save(dbw *jsonwall.DBWriter, root string) error { + f, err := os.OpenFile(dbPath(root), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer f.Close() + // chmod the existing file + if err := f.Chmod(0644); err != nil { + return err + } + zw, err := zstd.NewWriter(f) + if err != nil { + return err + } + if _, err := dbw.WriteTo(zw); err != nil { + return err + } + zw.Close() + return nil +} + +// Load reads a Chisel DB from the standard path under the provided root +// directory. If the Chisel DB doesn't exist, the returned error satisfies +// os.IsNotExist(err). +func Load(root string) (*jsonwall.DB, error) { + f, err := os.Open(dbPath(root)) + if err != nil { + return nil, err + } + defer f.Close() + zr, err := zstd.NewReader(f) + if err != nil { + return nil, err + } + db, err := jsonwall.ReadDB(zr) + if err != nil { + return nil, err + } + if s := db.Schema(); s != schema { + return nil, fmt.Errorf("invalid schema %#v", s) + } + return db, nil +} diff --git a/internal/db/db_test.go b/internal/db/db_test.go new file mode 100644 index 00000000..32ec2f7d --- /dev/null +++ b/internal/db/db_test.go @@ -0,0 +1,74 @@ +package db_test + +import ( + "sort" + + "github.com/canonical/chisel/internal/db" + . "gopkg.in/check.v1" +) + +type testEntry struct { + S string `json:"s,omitempty"` + I int64 `json:"i,omitempty"` + L []string `json:"l,omitempty"` + M map[string]bool `json:"m,omitempty"` +} + +var saveLoadTestCase = []testEntry{ + {"", 0, nil, nil}, + {"hello", -1, nil, nil}, + {"", 0, nil, nil}, + {"", 100, []string{"a", "b"}, nil}, + {"", 0, nil, map[string]bool{"a": true, "b": false}}, + {"abc", 123, []string{"foo", "bar"}, nil}, +} + +func (s *S) TestSaveLoadRoundTrip(c *C) { + // To compare expected and obtained entries we first wrap the original + // entries in wrappers with increasing K. When we read the wrappers back + // they may be in different order because jsonwall sorts them serialized + // as JSON. So we sort them by K to compare them in the original order. + + type wrapper struct { + // test values + testEntry + // sort key for comparison + K int `json:"key"` + } + + // wrap the entries with increasing K + expected := make([]wrapper, len(saveLoadTestCase)) + for i, entry := range saveLoadTestCase { + expected[i] = wrapper{entry, i} + } + + workDir := c.MkDir() + dbw := db.New() + for _, entry := range expected { + err := dbw.Add(entry) + c.Assert(err, IsNil) + } + err := db.Save(dbw, workDir) + c.Assert(err, IsNil) + + dbr, err := db.Load(workDir) + c.Assert(err, IsNil) + c.Assert(dbr.Schema(), Equals, db.Schema) + + iter, err := dbr.Iterate(nil) + c.Assert(err, IsNil) + + obtained := make([]wrapper, 0, len(expected)) + for iter.Next() { + var wrapped wrapper + err := iter.Get(&wrapped) + c.Assert(err, IsNil) + obtained = append(obtained, wrapped) + } + + // sort the entries by K to get the original order + sort.Slice(obtained, func(i, j int) bool { + return obtained[i].K < obtained[j].K + }) + c.Assert(obtained, DeepEquals, expected) +} diff --git a/internal/db/export_test.go b/internal/db/export_test.go new file mode 100644 index 00000000..f04ccf3b --- /dev/null +++ b/internal/db/export_test.go @@ -0,0 +1,3 @@ +package db + +var Schema = schema diff --git a/internal/db/suite_test.go b/internal/db/suite_test.go new file mode 100644 index 00000000..4334c717 --- /dev/null +++ b/internal/db/suite_test.go @@ -0,0 +1,15 @@ +package db_test + +import ( + "testing" + + . "gopkg.in/check.v1" +) + +func Test(t *testing.T) { + TestingT(t) +} + +type S struct{} + +var _ = Suite(&S{})