Skip to content

Commit

Permalink
Use new Rand
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 committed Aug 21, 2024
1 parent 50a8213 commit b0189f6
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 39 deletions.
4 changes: 2 additions & 2 deletions ops/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ func Float32TensorFixture(shp ...int) tensor.Tensor {
)
}

func RandomFloat32TensorFixture(shp ...int) tensor.Tensor {
func RandomFloat32TensorFixture(r *rand.Rand, shp ...int) tensor.Tensor {
rands := make([]float32, NElements(shp...))
for i := 0; i < NElements(shp...); i++ {
rands[i] = rand.Float32()
rands[i] = r.Float32()
}

return tensor.New(
Expand Down
42 changes: 21 additions & 21 deletions ops/opset13/lstm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,17 +280,17 @@ func TestInputValidationLSTM(t *testing.T) {
}

func lstmInput0() []tensor.Tensor {
rand.Seed(10)
r := rand.New(rand.NewSource(10))

return []tensor.Tensor{
// Input X: (sequence_length, batch_size, input_size).
ops.RandomFloat32TensorFixture(2, 1, 3),
ops.RandomFloat32TensorFixture(r, 2, 1, 3),
// Input W: (num_directions, 4 * hidden_size, input_size).
ops.RandomFloat32TensorFixture(1, 16, 3),
ops.RandomFloat32TensorFixture(r, 1, 16, 3),
// Input R: (num_directions, 4 * hidden_size, hidden_size).
ops.RandomFloat32TensorFixture(1, 16, 4),
ops.RandomFloat32TensorFixture(r, 1, 16, 4),
// Input B: (num_directions, 8 * hidden_size).
ops.RandomFloat32TensorFixture(1, 32),
ops.RandomFloat32TensorFixture(r, 1, 32),
// Input sequence_lens: not supported.
nil,
// Input initial_h: (num_directions, batch_size, hidden_size).
Expand All @@ -303,38 +303,38 @@ func lstmInput0() []tensor.Tensor {
}

func lstmInput1() []tensor.Tensor {
rand.Seed(11)
r := rand.New(rand.NewSource(11))

return []tensor.Tensor{
// Input X: (sequence_length, batch_size, input_size).
ops.RandomFloat32TensorFixture(10, 1, 3),
ops.RandomFloat32TensorFixture(r, 10, 1, 3),
// Input W: (num_directions, 4 * hidden_size, input_size).
ops.RandomFloat32TensorFixture(1, 16, 3),
ops.RandomFloat32TensorFixture(r, 1, 16, 3),
// Input R: (num_directions, 4 * hidden_size, hidden_size).
ops.RandomFloat32TensorFixture(1, 16, 4),
ops.RandomFloat32TensorFixture(r, 1, 16, 4),
// Input B: (num_directions, 8 * hidden_size).
ops.RandomFloat32TensorFixture(1, 32),
ops.RandomFloat32TensorFixture(r, 1, 32),
// Input sequence_lens: not supported.
nil,
// Input initial_h: (num_directions, batch_size, hidden_size).
ops.RandomFloat32TensorFixture(1, 1, 4),
ops.RandomFloat32TensorFixture(r, 1, 1, 4),
// Input initial_c: (num_directions, batch_size, hidden_size).
ops.RandomFloat32TensorFixture(1, 1, 4),
ops.RandomFloat32TensorFixture(r, 1, 1, 4),
// Input P: peephole weights.
nil,
}
}

func lstmInputNoBNoH() []tensor.Tensor {
rand.Seed(12)
r := rand.New(rand.NewSource(12))

return []tensor.Tensor{
// Input X: (sequence_length, batch_size, input_size).
ops.RandomFloat32TensorFixture(10, 1, 3),
ops.RandomFloat32TensorFixture(r, 10, 1, 3),
// Input W: (num_directions, 4 * hidden_size, input_size).
ops.RandomFloat32TensorFixture(1, 16, 3),
ops.RandomFloat32TensorFixture(r, 1, 16, 3),
// Input R: (num_directions, 4 * hidden_size, hidden_size).
ops.RandomFloat32TensorFixture(1, 16, 4),
ops.RandomFloat32TensorFixture(r, 1, 16, 4),
// Input B.
nil,
// Input sequence_lens: not supported.
Expand All @@ -349,15 +349,15 @@ func lstmInputNoBNoH() []tensor.Tensor {
}

func lstmInputPeepholes() []tensor.Tensor {
rand.Seed(13)
r := rand.New(rand.NewSource(13))

return []tensor.Tensor{
// Input X: (sequence_length, batch_size, input_size).
ops.RandomFloat32TensorFixture(10, 1, 3),
ops.RandomFloat32TensorFixture(r, 10, 1, 3),
// Input W: (num_directions, 4 * hidden_size, input_size).
ops.RandomFloat32TensorFixture(1, 16, 3),
ops.RandomFloat32TensorFixture(r, 1, 16, 3),
// Input R: (num_directions, 4 * hidden_size, hidden_size).
ops.RandomFloat32TensorFixture(1, 16, 4),
ops.RandomFloat32TensorFixture(r, 1, 16, 4),
// Input B.
nil,
// Input sequence_lens: not supported.
Expand All @@ -367,7 +367,7 @@ func lstmInputPeepholes() []tensor.Tensor {
// Input initial_c.
nil,
// Input P: (num_directions, 3 * hidden_size).
ops.RandomFloat32TensorFixture(1, 12),
ops.RandomFloat32TensorFixture(r, 1, 12),
}
}

Expand Down
32 changes: 16 additions & 16 deletions ops/opset13/rnn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,15 +246,15 @@ func TestInputValidationRNN(t *testing.T) {
}

func rnnInput0() []tensor.Tensor {
rand.Seed(13)
r := rand.New(rand.NewSource(13))

return []tensor.Tensor{
// Input X: (sequence_length, batch_size, input_size).
ops.RandomFloat32TensorFixture(2, 1, 3),
ops.RandomFloat32TensorFixture(r, 2, 1, 3),
// Input W: (num_directions, hidden_size, input_size).
ops.RandomFloat32TensorFixture(1, 4, 3),
ops.RandomFloat32TensorFixture(r, 1, 4, 3),
// Input R: (num_directions, hidden_size, hidden_size).
ops.RandomFloat32TensorFixture(1, 4, 4),
ops.RandomFloat32TensorFixture(r, 1, 4, 4),
// Input B: (num_directions, 2 * hidden_size)
ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 8)), 1, 8),
// Input sequence_lens: not supported
Expand All @@ -265,15 +265,15 @@ func rnnInput0() []tensor.Tensor {
}

func rnnInput1() []tensor.Tensor {
rand.Seed(13)
r := rand.New(rand.NewSource(13))

return []tensor.Tensor{
// Input X: (sequence_length, batch_size, input_size).
ops.RandomFloat32TensorFixture(10, 3, 4),
ops.RandomFloat32TensorFixture(r, 10, 3, 4),
// Input W: (num_directions, hidden_size, input_size).
ops.RandomFloat32TensorFixture(1, 10, 4),
ops.RandomFloat32TensorFixture(r, 1, 10, 4),
// Input R: (num_directions, hidden_size, hidden_size).
ops.RandomFloat32TensorFixture(1, 10, 10),
ops.RandomFloat32TensorFixture(r, 1, 10, 10),
// Input B: (num_directions, 2 * hidden_size)
ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 20)), 1, 20),
// Input sequence_lens: not supported
Expand All @@ -284,15 +284,15 @@ func rnnInput1() []tensor.Tensor {
}

func rnnInputNoB() []tensor.Tensor {
rand.Seed(13)
r := rand.New(rand.NewSource(13))

return []tensor.Tensor{
// Input X: (sequence_length, batch_size, input_size).
ops.RandomFloat32TensorFixture(2, 1, 3),
ops.RandomFloat32TensorFixture(r, 2, 1, 3),
// Input W: (num_directions, hidden_size, input_size).
ops.RandomFloat32TensorFixture(1, 4, 3),
ops.RandomFloat32TensorFixture(r, 1, 4, 3),
// Input R: (num_directions, hidden_size, hidden_size).
ops.RandomFloat32TensorFixture(1, 4, 4),
ops.RandomFloat32TensorFixture(r, 1, 4, 4),
// Input B: not provided.
nil,
// Input sequence_lens: not supported
Expand All @@ -303,15 +303,15 @@ func rnnInputNoB() []tensor.Tensor {
}

func rnnInputNoBNoH() []tensor.Tensor {
rand.Seed(13)
r := rand.New(rand.NewSource(13))

return []tensor.Tensor{
// Input X: (sequence_length, batch_size, input_size).
ops.RandomFloat32TensorFixture(2, 1, 3),
ops.RandomFloat32TensorFixture(r, 2, 1, 3),
// Input W: (num_directions, hidden_size, input_size).
ops.RandomFloat32TensorFixture(1, 4, 3),
ops.RandomFloat32TensorFixture(r, 1, 4, 3),
// Input R: (num_directions, hidden_size, hidden_size).
ops.RandomFloat32TensorFixture(1, 4, 4),
ops.RandomFloat32TensorFixture(r, 1, 4, 4),
// Input B: not provided.
nil,
// Input sequence_lens: not supported
Expand Down

0 comments on commit b0189f6

Please sign in to comment.