Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tls1prf: require callers to pass in the result buffer #45

Merged
merged 3 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions cng/tls1prf.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,29 @@ func loadTLS1PRF(id string) (bcrypt.ALG_HANDLE, error) {
return h.(bcrypt.ALG_HANDLE), nil
}

func TLS1PRF(secret, label, seed []byte, keyLen int, h func() hash.Hash) ([]byte, error) {
// TLS1PRF implements the TLS 1.0/1.1 pseudo-random function if h is nil,
// else it implements the TLS 1.2 pseudo-random function.
// The pseudo-random number will be written to result and will be of length len(result).
func TLS1PRF(result, secret, label, seed []byte, h func() hash.Hash) error {
// TLS 1.0/1.1 PRF uses MD5SHA1.
algID := bcrypt.TLS1_1_KDF_ALGORITHM
var hashID string
if h != nil {
// If h is specified, assume the caller wants to use TLS 1.2 PRF.
// TLS 1.0/1.1 PRF doesn't allow specifying the hash function.
if hashID = hashToID(h()); hashID == "" {
qmuntal marked this conversation as resolved.
Show resolved Hide resolved
return nil, errors.New("cng: unsupported hash function")
return errors.New("cng: unsupported hash function")
}
algID = bcrypt.TLS1_2_KDF_ALGORITHM
}

alg, err := loadTLS1PRF(algID)
if err != nil {
return nil, err
return err
}
var kh bcrypt.KEY_HANDLE
if err := bcrypt.GenerateSymmetricKey(alg, &kh, nil, secret, 0); err != nil {
return nil, err
return err
}

buffers := make([]bcrypt.Buffer, 0, 3)
Expand Down Expand Up @@ -73,11 +76,15 @@ func TLS1PRF(secret, label, seed []byte, keyLen int, h func() hash.Hash) ([]byte
Count: uint32(len(buffers)),
Buffers: &buffers[0],
}
out := make([]byte, keyLen)
var size uint32
err = bcrypt.KeyDerivation(kh, params, out, &size, 0)
err = bcrypt.KeyDerivation(kh, params, result, &size, 0)
if err != nil {
return nil, err
return err
}
return out[:size], nil
// The Go standard library expects TLS1PRF to return the requested number of bytes,
// fail if it doesn't.
if size != uint32(len(result)) {
qmuntal marked this conversation as resolved.
Show resolved Hide resolved
qmuntal marked this conversation as resolved.
Show resolved Hide resolved
return errors.New("tls1-prf: derived less bytes than requested")
}
return nil
}
7 changes: 4 additions & 3 deletions cng/tls1prf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,13 @@ var tls1prfTests = []tls1prfTest{

func TestTLS1PRF(t *testing.T) {
for i, tt := range tls1prfTests {
out, err := cng.TLS1PRF(tt.secret, tt.label, tt.seed, len(tt.out), tt.hash)
result := make([]byte, len(tt.out))
err := cng.TLS1PRF(result, tt.secret, tt.label, tt.seed, tt.hash)
if err != nil {
t.Errorf("test %d: error deriving TLS 1.2 PRF: %v.", i, err)
}
if !bytes.Equal(out, tt.out) {
t.Errorf("test %d: incorrect key output: have %v, need %v.", i, out, tt.out)
if !bytes.Equal(result, tt.out) {
t.Errorf("test %d: incorrect key output: have %v, need %v.", i, result, tt.out)
}
}
}