diff --git a/go.mod b/go.mod index 5e0fb99..634d221 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/bwmarrin/snowflake +module github.com/hash-rabbit/snowflake go 1.16 diff --git a/snowflake.go b/snowflake.go index 194d50f..e326bbf 100644 --- a/snowflake.go +++ b/snowflake.go @@ -98,27 +98,13 @@ type ID int64 // NewNode returns a new snowflake node that can be used to generate snowflake // IDs func NewNode(node int64) (*Node, error) { - - if NodeBits+StepBits > 22 { - return nil, errors.New("Remember, you have a total 22 bits to share between Node/Step") - } - // re-calc in case custom NodeBits or StepBits were set - // DEPRECATED: the below block will be removed in a future release. - mu.Lock() - nodeMax = -1 ^ (-1 << NodeBits) - nodeMask = nodeMax << StepBits - stepMask = -1 ^ (-1 << StepBits) - timeShift = NodeBits + StepBits - nodeShift = StepBits - mu.Unlock() - n := Node{} n.node = node - n.nodeMax = -1 ^ (-1 << NodeBits) - n.nodeMask = n.nodeMax << StepBits - n.stepMask = -1 ^ (-1 << StepBits) - n.timeShift = NodeBits + StepBits - n.nodeShift = StepBits + + err := n.SetNodeAndStepBits(NodeBits, StepBits) + if err != nil { + return nil, err + } if n.node < 0 || n.node > n.nodeMax { return nil, errors.New("Node number must be between 0 and " + strconv.FormatInt(n.nodeMax, 10)) @@ -131,6 +117,35 @@ func NewNode(node int64) (*Node, error) { return &n, nil } +// SetNodeAndStepBits can set Node node and step bits by custom +func (n *Node) SetNodeAndStepBits(nodeBits, stepBits uint8) error { + if nodeBits+stepBits > 22 { + return errors.New("remember, you have a total 22 bits to share between Node/Step") + } + + // re-calc in case custom NodeBits or StepBits were set + // DEPRECATED: the below block will be removed in a future release. + mu.Lock() + NodeBits = nodeBits + StepBits = stepBits + nodeMax = -1 ^ (-1 << nodeBits) + nodeMask = nodeMax << stepBits + stepMask = -1 ^ (-1 << stepBits) + timeShift = nodeBits + stepBits + nodeShift = stepBits + mu.Unlock() + + n.mu.Lock() + n.nodeMax = -1 ^ (-1 << nodeBits) + n.nodeMask = n.nodeMax << stepBits + n.stepMask = -1 ^ (-1 << stepBits) + n.timeShift = nodeBits + stepBits + n.nodeShift = stepBits + n.mu.Unlock() + + return nil +} + // Generate creates and returns a unique snowflake ID // To help guarantee uniqueness // - Make sure your system is keeping accurate system time diff --git a/snowflake_test.go b/snowflake_test.go index ff750c4..e9ccfe4 100644 --- a/snowflake_test.go +++ b/snowflake_test.go @@ -58,6 +58,13 @@ func TestRace(t *testing.T) { } +func TestCustomSet(t *testing.T) { + node, _ := NewNode(1) + node.SetNodeAndStepBits(4, 4) + id := node.Generate() + t.Logf("Int64:%#v", id.Int64()) +} + //****************************************************************************** // Converters/Parsers Test funcs // We should have funcs here to test conversion both ways for everything