diff --git a/tls1prf.go b/tls1prf.go index 7043b3bc..3153fc81 100644 --- a/tls1prf.go +++ b/tls1prf.go @@ -16,7 +16,10 @@ func SupportsTLS1PRF() bool { (vMajor >= 1 && vMinor >= 1) } -func TLS1PRF(secret, label, seed []byte, keyLen int, h func() hash.Hash) ([]byte, error) { +// TLS1PRF implements the TLS 1.0/1.1 pseudo-random function if h is nil, +// else it implements the TLS 1.2 pseudo-random function. +// The pseudo-random number will be written to result and will be of length len(result). +func TLS1PRF(result, secret, label, seed []byte, h func() hash.Hash) error { var md C.GO_EVP_MD_PTR if h == nil { // TLS 1.0/1.1 PRF doesn't allow to specify the hash function, @@ -29,70 +32,73 @@ func TLS1PRF(secret, label, seed []byte, keyLen int, h func() hash.Hash) ([]byte md = hashToMD(h()) } if md == nil { - return nil, errors.New("unsupported hash function") + return errors.New("unsupported hash function") } ctx := C.go_openssl_EVP_PKEY_CTX_new_id(C.GO_EVP_PKEY_TLS1_PRF, nil) if ctx == nil { - return nil, newOpenSSLError("EVP_PKEY_CTX_new_id") + return newOpenSSLError("EVP_PKEY_CTX_new_id") } defer func() { C.go_openssl_EVP_PKEY_CTX_free(ctx) }() if C.go_openssl_EVP_PKEY_derive_init(ctx) != 1 { - return nil, newOpenSSLError("EVP_PKEY_derive_init") + return newOpenSSLError("EVP_PKEY_derive_init") } switch vMajor { case 3: if C.go_openssl_EVP_PKEY_CTX_set_tls1_prf_md(ctx, md) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md") + return newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md") } if C.go_openssl_EVP_PKEY_CTX_set1_tls1_prf_secret(ctx, base(secret), C.int(len(secret))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret") + return newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret") } if C.go_openssl_EVP_PKEY_CTX_add1_tls1_prf_seed(ctx, base(label), C.int(len(label))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed") + return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed") } if C.go_openssl_EVP_PKEY_CTX_add1_tls1_prf_seed(ctx, base(seed), C.int(len(seed))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed") + return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed") } case 1: if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_TLS_MD, 0, unsafe.Pointer(md)) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md") + return newOpenSSLError("EVP_PKEY_CTX_set_tls1_prf_md") } if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_TLS_SECRET, C.int(len(secret)), unsafe.Pointer(base(secret))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret") + return newOpenSSLError("EVP_PKEY_CTX_set1_tls1_prf_secret") } if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_TLS_SEED, C.int(len(label)), unsafe.Pointer(base(label))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed") + return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed") } if C.go_openssl_EVP_PKEY_CTX_ctrl(ctx, -1, C.GO1_EVP_PKEY_OP_DERIVE, C.GO_EVP_PKEY_CTRL_TLS_SEED, C.int(len(seed)), unsafe.Pointer(base(seed))) != 1 { - return nil, newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed") + return newOpenSSLError("EVP_PKEY_CTX_add1_tls1_prf_seed") } } - outLen := C.size_t(keyLen) - out := make([]byte, outLen) - if C.go_openssl_EVP_PKEY_derive(ctx, base(out), &outLen) != 1 { - return nil, newOpenSSLError("EVP_PKEY_derive") + outLen := C.size_t(len(result)) + if C.go_openssl_EVP_PKEY_derive(ctx, base(result), &outLen) != 1 { + return newOpenSSLError("EVP_PKEY_derive") } - if outLen != C.size_t(keyLen) { - return nil, errors.New("tls1-prf: entropy limit reached") + // The Go standard library expects TLS1PRF to return the requested number of bytes, + // fail if it doesn't. While there is no known situation where this will happen, + // EVP_PKEY_derive handles multiple algorithms and there could be a subtle mismatch + // after more code changes in the future. + if outLen != C.size_t(len(result)) { + return errors.New("tls1-prf: derived less bytes than requested") } - return out[:outLen], nil + return nil } diff --git a/tls1prf_test.go b/tls1prf_test.go index bc81d984..80324603 100644 --- a/tls1prf_test.go +++ b/tls1prf_test.go @@ -155,12 +155,13 @@ func TestTLS1PRF(t *testing.T) { if !openssl.SupportsHash(tt.hash) { t.Skip("skipping: hash not supported") } - out, err := openssl.TLS1PRF(tt.secret, tt.label, tt.seed, len(tt.out), cryptoToHash(tt.hash)) + result := make([]byte, len(tt.out)) + err := openssl.TLS1PRF(result, tt.secret, tt.label, tt.seed, cryptoToHash(tt.hash)) if err != nil { t.Fatalf("error deriving TLS 1.2 PRF: %v.", err) } - if !bytes.Equal(out, tt.out) { - t.Errorf("incorrect key output: have %v, need %v.", out, tt.out) + if !bytes.Equal(result, tt.out) { + t.Errorf("incorrect key output: have %v, need %v.", result, tt.out) } }) }