diff --git a/include/gmssl/tls.h b/include/gmssl/tls.h index c03b0c1a1..4479208f6 100644 --- a/include/gmssl/tls.h +++ b/include/gmssl/tls.h @@ -718,7 +718,6 @@ typedef struct { uint8_t record[TLS_MAX_RECORD_SIZE]; - // 其实这个就不太对了,还是应该有一个完整的密文记录 uint8_t databuf[TLS_MAX_PLAINTEXT_SIZE]; uint8_t *data; size_t datalen; @@ -828,17 +827,20 @@ int tls13_gcm_decrypt(const BLOCK_CIPHER_KEY *key, const uint8_t iv[12], #ifdef ENABLE_TLS_DEBUG # define tls_trace(s) fprintf(stderr,(s)) # define tls_record_trace(fp,rec,reclen,fmt,ind) tls_record_print(fp,rec,reclen,fmt,ind) +# define tls_encrypted_record_trace(fp,rec,reclen,fmt,ind) tls_encrypted_record_print(fp,rec,reclen,fmt,ind) # define tlcp_record_trace(fp,rec,reclen,fmt,ind) tlcp_record_print(fp,rec,reclen,fmt,ind) # define tls12_record_trace(fp,rec,reclen,fmt,ind) tls12_record_print(fp,rec,reclen,fmt,ind) # define tls13_record_trace(fp,rec,reclen,fmt,ind) tls13_record_print(fp,fmt,ind,rec,reclen) #else # define tls_trace(s) # define tls_record_trace(fp,rec,reclen,fmt,ind) +# define tls_encrypted_record_trace(fp,rec,reclen,fmt,ind) # define tlcp_record_trace(fp,rec,reclen,fmt,ind) # define tls12_record_trace(fp,rec,reclen,fmt,ind) # define tls13_record_trace(fp,rec,reclen,fmt,ind) #endif +int tls_encrypted_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent); #ifdef __cplusplus } diff --git a/src/tlcp.c b/src/tlcp.c index a3308c6e9..5e8654b87 100644 --- a/src/tlcp.c +++ b/src/tlcp.c @@ -484,14 +484,13 @@ int tlcp_do_connect(TLS_CONNECT *conn) sm3_update(&sm3_ctx, finished_record + 5, finished_record_len - 5); // encrypt Client Finished - tls_trace("encrypt Finished\n"); if (tls_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key, conn->client_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) { error_print(); tls_send_alert(conn, TLS_alert_internal_error); goto end; } - tlcp_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据 + tls_encrypted_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据 tls_seq_num_incr(conn->client_seq_num); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); @@ -526,8 +525,7 @@ int tlcp_do_connect(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_bad_record_mac); goto end; } - tlcp_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据 - tls_trace("decrypt Finished\n"); + tls_encrypted_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据 if (tls_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key, conn->server_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) { error_print(); @@ -920,10 +918,10 @@ int tlcp_do_accept(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_unexpected_message); goto end; } - tlcp_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据 + tls_encrypted_record_trace(stderr, record, recordlen, 0, 0); // decrypt ClientFinished - tls_trace("decrypt Finished\n"); + //tls_trace("decrypt Finished\n"); if (tls_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key, conn->client_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) { error_print(); @@ -990,8 +988,7 @@ int tlcp_do_accept(TLS_CONNECT *conn) tls_send_alert(conn, TLS_alert_internal_error); goto end; } - tls_trace("encrypt Finished\n"); - tlcp_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据 + tls_encrypted_record_trace(stderr, record, recordlen, 0, 0); tls_seq_num_incr(conn->server_seq_num); if (tls_record_send(record, recordlen, conn->sock) != 1) { error_print(); diff --git a/src/tls.c b/src/tls.c index 54d4047c4..9a38dee33 100644 --- a/src/tls.c +++ b/src/tls.c @@ -1332,6 +1332,8 @@ int tls_record_set_alert(uint8_t *record, size_t *recordlen, return -1; } record[0] = TLS_record_alert; + //record[1] = protocol.major should be set by others + //record[2] = protocol.minor should be set by others record[3] = 0; // length record[4] = 2; // length record[5] = (uint8_t)alert_level; @@ -1491,7 +1493,7 @@ int tls_record_send(const uint8_t *record, size_t recordlen, tls_socket_t sock) return 1; } -int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock) +int tls_record_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock) { uint8_t *p = record; size_t len; @@ -1503,7 +1505,7 @@ int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock) p += n; len -= n; } else if (n == 0) { - error_puts("TCP connection closed"); + tls_trace("TCP connection closed"); *recordlen = 0; return 0; } else { @@ -1541,7 +1543,7 @@ int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock) p += n; len -= n; } else if (n == 0) { - error_puts("connection closed"); + tls_trace("connection closed"); *recordlen = 0; return 0; } else { @@ -1558,45 +1560,6 @@ int tls_record_do_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock) return 1; } -int tls_record_recv(uint8_t *record, size_t *recordlen, tls_socket_t sock) -{ - int ret; - - if ((ret = tls_record_do_recv(record, recordlen, sock)) != 1) { - if (ret && ret != -EAGAIN) error_print(); - return ret; - } - - if (tls_record_type(record) == TLS_record_alert) { - int level; - int alert; - - if (tls_record_get_alert(record, &level, &alert) != 1) { - error_print(); - return -1; - } - tls_record_trace(stderr, record, *recordlen, 0, 0); - - if (level == TLS_alert_level_fatal && alert == TLS_alert_close_notify) { -#if ENABLE_TLS_RESPOND_CLOSE_NOTIFY - tls_trace("send Alert close_notifiy\n"); - tls_record_trace(stderr, record, *recordlen, 0, 0); - if (tls_record_send(record, *recordlen, sock) != 1) { - error_print(); - return -1; - } -#endif - return 0; - - } else { - error_print(); - return -1; - } - } - - return 1; -} - int tls_seq_num_incr(uint8_t seq_num[8]) { int i; @@ -1604,7 +1567,7 @@ int tls_seq_num_incr(uint8_t seq_num[8]) seq_num[i]++; if (seq_num[i]) break; } - // FIXME: 检查溢出 + // FIXME: check overflow return 1; } @@ -1632,6 +1595,7 @@ int tls_send_alert(TLS_CONNECT *conn, int alert) error_print(); return -1; } + tls_record_set_protocol(record, conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol); tls_record_set_alert(record, &recordlen, TLS_alert_level_fatal, alert); @@ -1692,13 +1656,12 @@ int tls_send_warning(TLS_CONNECT *conn, int alert) return 1; } -int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen) +static int tls_encrypt_send(TLS_CONNECT *conn, int record_type, const uint8_t *in, size_t inlen, size_t *sentlen) { const SM3_HMAC_CTX *hmac_ctx; const SM4_KEY *enc_key; uint8_t *seq_num; - uint8_t *record; - size_t datalen; + size_t recordlen; if (!conn) { error_print(); @@ -1713,6 +1676,11 @@ int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen inlen = TLS_MAX_PLAINTEXT_SIZE; } + if (conn->datalen) { + error_puts("recv all buffered data before send"); + return -1; + } + if (conn->is_client) { hmac_ctx = &conn->client_write_mac_ctx; enc_key = &conn->client_write_enc_key; @@ -1722,37 +1690,34 @@ int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen enc_key = &conn->server_write_enc_key; seq_num = conn->server_seq_num; } - record = conn->record; - - tls_trace("send ApplicationData\n"); - if (tls_record_set_type(record, TLS_record_application_data) != 1 - || tls_record_set_protocol(record, conn->protocol) != 1 - || tls_record_set_length(record, inlen) != 1) { + if (tls_record_set_type(conn->databuf, record_type) != 1 + || tls_record_set_protocol(conn->databuf, conn->protocol) != 1 + || tls_record_set_data(conn->databuf, in, inlen) != 1) { error_print(); return -1; } + tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0); - if (tls_cbc_encrypt(hmac_ctx, enc_key, seq_num, tls_record_header(record), - in, inlen, tls_record_data(record), &datalen) != 1) { - error_print(); - return -1; - } - if (tls_record_set_length(record, datalen) != 1) { + if (tls_record_encrypt(hmac_ctx, enc_key, seq_num, + conn->databuf, tls_record_length(conn->databuf), + conn->record, &recordlen) != 1) { error_print(); return -1; } tls_seq_num_incr(seq_num); - if (tls_record_send(record, tls_record_length(record), conn->sock) != 1) { + + if (tls_record_send(conn->record, recordlen, conn->sock) != 1) { error_print(); return -1; } + tls_encrypted_record_trace(stderr, conn->record, recordlen, 0, 0); + *sentlen = inlen; - tls_record_trace(stderr, record, tls_record_length(record), 0, 0); return 1; } -int tls_do_recv(TLS_CONNECT *conn) +int tls_decrypt_recv(TLS_CONNECT *conn) { int ret; const SM3_HMAC_CTX *hmac_ctx; @@ -1772,68 +1737,111 @@ int tls_do_recv(TLS_CONNECT *conn) seq_num = conn->client_seq_num; } - tls_trace("recv ApplicationData\n"); + tls_trace("recv Encrypted Record\n"); if ((ret = tls_record_recv(record, &recordlen, conn->sock)) != 1) { if (ret < 0 && ret != -EAGAIN) error_print(); return ret; } + tls_encrypted_record_trace(stderr, record, recordlen, 0, 0); - tls_record_trace(stderr, record, recordlen, 0, 0); - if (tls_cbc_decrypt(hmac_ctx, dec_key, seq_num, record, - tls_record_data(record), tls_record_data_length(record), + if (tls_record_decrypt(hmac_ctx, dec_key, seq_num, + record, recordlen, conn->databuf, &conn->datalen) != 1) { error_print(); return -1; } - conn->data = conn->databuf; tls_seq_num_incr(seq_num); - tls_record_set_data(record, conn->data, conn->datalen); - tls_trace("decrypt ApplicationData\n"); - tls_record_trace(stderr, record, tls_record_length(record), 0, 0); + conn->data = tls_record_data(conn->databuf); + conn->datalen = tls_record_data_length(conn->databuf); + + tls_record_trace(stderr, conn->databuf, tls_record_length(conn->databuf), 0, 0); + return 1; } +int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen) +{ + tls_trace("send ApplicationData\n"); + return tls_encrypt_send(conn, TLS_record_application_data, in, inlen, sentlen); +} + int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen) { if (!conn || !out || !outlen || !recvlen) { error_print(); return -1; } + if (conn->datalen == 0) { int ret; - if ((ret = tls_do_recv(conn)) != 1) { + if ((ret = tls_decrypt_recv(conn)) != 1) { if (ret < 0 && ret != -EAGAIN) error_print(); return ret; } + + switch (tls_record_type(conn->record)) { + case TLS_record_application_data: + break; + case TLS_record_change_cipher_spec: + error_print(); + return -1; + case TLS_record_alert: + { + // should call tls_process_alert() + int level; + int alert; + tls_record_get_alert(conn->databuf, &level, &alert); + if (alert == TLS_alert_close_notify) { + tls_trace("recv Alert.close_notify\n"); + return 0; + } + tls_trace("alert received\n"); + return -1; + } + default: + error_print(); + return -1; + } } + *recvlen = outlen <= conn->datalen ? outlen : conn->datalen; memcpy(out, conn->data, *recvlen); conn->data += *recvlen; conn->datalen -= *recvlen; + return 1; } int tls_shutdown(TLS_CONNECT *conn) { + int ret; size_t recordlen; + uint8_t alert[2]; + alert[0] = TLS_alert_level_fatal; + alert[1] = TLS_alert_close_notify; + if (!conn) { error_print(); return -1; } - tls_trace("send Alert close_notify\n"); - if (tls_send_alert(conn, TLS_alert_close_notify) != 1) { + + tls_trace("send Alert.close_notify\n"); + + if (tls_encrypt_send(conn, TLS_record_alert, alert, sizeof(alert), &recordlen) != 1) { error_print(); return -1; } -#ifdef ENABLE_TLS_RESPOND_CLOSE_NOTIFY - tls_trace("recv Alert close_notify\n"); - if (tls_record_do_recv(conn->record, &recordlen, conn->sock) != 1) { - error_print(); + + tls_trace("recv Alert.close_notify\n"); + + if ((ret = tls_decrypt_recv(conn)) != 1) { + if (ret == 0) tls_trace("Connection closed by remote without close_notify\n"); + else if (ret == -EAGAIN) tls_trace("-EAGAIN\n"); + else error_print(); return -1; } - tls_record_trace(stderr, conn->record, recordlen, 0, 0); -#endif + return 1; } diff --git a/src/tls_trace.c b/src/tls_trace.c index 81a4b992f..fde98fb07 100644 --- a/src/tls_trace.c +++ b/src/tls_trace.c @@ -1070,7 +1070,7 @@ int tls13_record_print(FILE *fp, int format, int indent, const uint8_t *record, } -// FIXME: 需要根据RFC来考虑这个函数的参数,从底向上逐步修改每个函数的接口参数 +// FIXME: 根据RFC来考虑这个函数的参数,从底向上逐步修改每个函数的接口参数 // 仅从record数据是不能判断这个record是TLS 1.2还是TLS 1.3 // 不同协议上,同名的握手消息,其格式也是不一样的。这真是太恶心了!!!! @@ -1105,13 +1105,6 @@ int tls_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int for return -1; } - // 最高字节设置后强制打印记录原始数据 - if (format >> 24) { - format_bytes(fp, format, indent, "Data", data, datalen); - fprintf(fp, "\n"); - return 1; - } - switch (record[0]) { case TLS_record_handshake: if (tls_handshake_print(fp, data, datalen, format, indent) != 1) { @@ -1173,3 +1166,24 @@ int tls_secrets_print(FILE *fp, format_print(stderr, format, indent, "\n"); return 1; } + +int tls_encrypted_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent) +{ + int protocol; + + if (!fp || !record || recordlen < 5) { + error_print(); + return -1; + } + + protocol = tls_record_protocol(record); + format_print(fp, format, indent, "EncryptedRecord\n"); indent += 4; + format_print(fp, format, indent, "ContentType: %s (%d)\n", tls_record_type_name(record[0]), record[0]); + format_print(fp, format, indent, "Version: %s (%d.%d)\n", tls_protocol_name(protocol), protocol >> 8, protocol & 0xff); + format_print(fp, format, indent, "Length: %d\n", tls_record_data_length(record)); + format_bytes(fp, format, indent, "EncryptedData", tls_record_data(record), tls_record_data_length(record)); + + fprintf(fp, "\n"); + return 1; +} + diff --git a/tools/tlcp_client.c b/tools/tlcp_client.c index cfbfa2508..6380ad4c7 100644 --- a/tools/tlcp_client.c +++ b/tools/tlcp_client.c @@ -70,6 +70,7 @@ int tlcp_client_main(int argc, char *argv[]) size_t len = sizeof(buf); char send_buf[1024] = {0}; size_t sentlen; + int read_stdin = 1; argc--; argv++; @@ -130,24 +131,19 @@ int tlcp_client_main(int argc, char *argv[]) return -1; } + if (tls_socket_create(&sock, AF_INET, SOCK_STREAM, 0) != 1) { + fprintf(stderr, "%s: open socket error\n", prog); + goto end; + } + if (!(hp = gethostbyname(host))) { fprintf(stderr, "%s: invalid hostname '%s'\n", prog, host); goto end; } - - memset(&ctx, 0, sizeof(ctx)); - memset(&conn, 0, sizeof(conn)); - server.sin_addr = *((struct in_addr *)hp->h_addr_list[0]); server.sin_family = AF_INET; server.sin_port = htons(port); - - if (tls_socket_create(&sock, AF_INET, SOCK_STREAM, 0) != 1) { - fprintf(stderr, "%s: open socket error\n", prog); - goto end; - } - if (tls_socket_connect(sock, &server) != 1) { fprintf(stderr, "%s: socket connect error\n", prog); goto end; @@ -158,19 +154,30 @@ int tlcp_client_main(int argc, char *argv[]) fprintf(stderr, "%s: context init error\n", prog); goto end; } + if (cacertfile) { if (tls_ctx_set_ca_certificates(&ctx, cacertfile, TLS_DEFAULT_VERIFY_DEPTH) != 1) { fprintf(stderr, "%s: context init error\n", prog); goto end; } } + if (certfile) { + if (!keyfile) { + fprintf(stderr, "%s: option '-key' should be assigned with '-cert'\n", prog); + goto end; + } + if (!pass) { + fprintf(stderr, "%s: option '-pass' should be assigned with '-pass'\n", prog); + goto end; + } if (tls_ctx_set_certificate_and_key(&ctx, certfile, keyfile, pass) != 1) { fprintf(stderr, "%s: context init error\n", prog); goto end; } } - if (quiet || get) { + + if (quiet) { ctx.quiet = 1; } @@ -196,6 +203,9 @@ int tlcp_client_main(int argc, char *argv[]) fclose(outcertsfp); } +// tls_shutdown(&conn); +// return 0; + if (get) { struct timeval timeout; timeout.tv_sec = TIMEOUT_SECONDS; @@ -208,6 +218,7 @@ int tlcp_client_main(int argc, char *argv[]) goto end; } + // use timeout to close the HTTP connection if (setsockopt(conn.sock, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof(timeout)) != 0) { perror("setsockopt"); fprintf(stderr, "%s: set socket timeout error\n", prog); @@ -215,100 +226,84 @@ int tlcp_client_main(int argc, char *argv[]) } for (;;) { - int ret; - if ((ret = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len)) != 1) { - if (ret == 0) { - fprintf(stderr, "%s: TLCP connection is closed by remote host\n", prog); - } else if (ret != -EAGAIN) { - fprintf(stderr, "%s: recv error\n", prog); - } - break; - } - fwrite(buf, 1, len, stdout); - fflush(stdout); - } + int rv; - tls_shutdown(&conn); - goto end; - } + rv = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len); - - for (;;) { - fd_set fds; - - if (!fgets(send_buf, sizeof(send_buf), stdin)) { - if (feof(stdin)) { + if (rv == 1) { + fwrite(buf, 1, len, stdout); + fflush(stdout); + } else if (rv == 0) { + fprintf(stderr, "%s: TLCP connection is closed by remote host\n", prog); + goto end; + } else if (rv == -EAGAIN) { + // when timeout, tls_recv return -EAGAIN (-11) tls_shutdown(&conn); + ret = 0; goto end; } else { - continue; + fprintf(stderr, "%s: tls_recv error\n", prog); + goto end; } } - if (tls_send(&conn, (uint8_t *)send_buf, strlen(send_buf), &sentlen) != 1) { - fprintf(stderr, "%s: send error\n", prog); - goto end; - } + } + + for (;;) { + fd_set fds; FD_ZERO(&fds); FD_SET(conn.sock, &fds); -#ifdef WIN32 -#else - FD_SET(fileno(stdin), &fds); //FD_SET(STDIN_FILENO, &fds); // NOT allowed in winsock2 !!! -#endif - - if (select((int)(conn.sock + 1), // WinSock2 select() ignore this arg - &fds, NULL, NULL, NULL) < 0) { - fprintf(stderr, "%s: select failed\n", prog); -#ifdef WIN32 - fprintf(stderr, "WSAGetLastError = %u\n", WSAGetLastError()); -#endif + if (read_stdin) + FD_SET(STDIN_FILENO, &fds); + + if (select(conn.sock + 1, &fds, NULL, NULL, NULL) < 0) { + fprintf(stderr, "%s: select error\n", prog); goto end; } - if (FD_ISSET(conn.sock, &fds)) { - for (;;) { - memset(buf, 0, sizeof(buf)); - if (tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len) != 1) { + if (read_stdin && FD_ISSET(STDIN_FILENO, &fds)) { + + if (fgets(buf, sizeof(buf), stdin)) { + if (tls_send(&conn, (uint8_t *)buf, strlen(buf), &len) != 1) { + fprintf(stderr, "%s: send error\n", prog); goto end; } - fwrite(buf, 1, len, stdout); - fflush(stdout); - - // 应该调整tls_recv 逻辑、API或者其他方式 - if (conn.datalen == 0) { - break; + } else { + if (!feof(stdin)) { + fprintf(stderr, "%s: length of input line exceeds buffer size\n", prog); + goto end; } + read_stdin = 0; } - } -#ifdef WIN32 -#else - if (FD_ISSET(fileno(stdin), &fds)) { - fprintf(stderr, "recv from stdin\n"); - memset(send_buf, 0, sizeof(send_buf)); + if (FD_ISSET(conn.sock, &fds)) { + int rv; - if (!fgets(send_buf, sizeof(send_buf), stdin)) { - if (feof(stdin)) { - tls_shutdown(&conn); - goto end; - } else { - continue; - } - } - if (tls_send(&conn, (uint8_t *)send_buf, strlen(send_buf), &sentlen) != 1) { - fprintf(stderr, "%s: send error\n", prog); + rv = tls_recv(&conn, (uint8_t *)buf, sizeof(buf), &len); + + if (rv == 1) { + fwrite(buf, 1, len, stdout); + fflush(stdout); + } else if (rv == 0) { + fprintf(stderr, "Connection closed by remote host\n"); + goto end; + } else if (rv == -EAGAIN) { + // should not happen + error_print(); + goto end; + } else { + error_print(); + fprintf(stderr, "%s: tls_recv error\n", prog); goto end; } } -#endif - - fprintf(stderr, "end of this round\n"); } end: + // FIXME: clean ctx and connection ASAP, as Ctrl-C is not handled if (sock != -1) tls_socket_close(sock); tls_ctx_cleanup(&ctx); tls_cleanup(&conn); - return 0; + return ret; }