diff --git a/net/tls/tls.h b/net/tls/tls.h index e5e47452308a..774859b63f0d 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -145,7 +145,8 @@ void tls_err_abort(struct sock *sk, int err); int init_prot_info(struct tls_prot_info *prot, const struct tls_crypto_info *crypto_info, const struct tls_cipher_desc *cipher_desc); -int tls_set_sw_offload(struct sock *sk, int tx); +int tls_set_sw_offload(struct sock *sk, int tx, + struct tls_crypto_info *new_crypto_info); void tls_update_rx_zc_capable(struct tls_context *tls_ctx); void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx); void tls_sw_strparser_done(struct tls_context *tls_ctx); diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index dc063c2c7950..e50b6e71df13 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -1227,7 +1227,7 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) context->resync_nh_reset = 1; ctx->priv_ctx_rx = context; - rc = tls_set_sw_offload(sk, 0); + rc = tls_set_sw_offload(sk, 0, NULL); if (rc) goto release_ctx; diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 6b4b9f2749a6..68b5735dafc1 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -423,9 +423,10 @@ static __poll_t tls_sk_poll(struct file *file, struct socket *sock, ctx = tls_sw_ctx_rx(tls_ctx); psock = sk_psock_get(sk); - if (skb_queue_empty_lockless(&ctx->rx_list) && - !tls_strp_msg_ready(ctx) && - sk_psock_queue_empty(psock)) + if ((skb_queue_empty_lockless(&ctx->rx_list) && + !tls_strp_msg_ready(ctx) && + sk_psock_queue_empty(psock)) || + READ_ONCE(ctx->key_update_pending)) mask &= ~(EPOLLIN | EPOLLRDNORM); if (psock) @@ -612,11 +613,13 @@ static int validate_crypto_info(const struct tls_crypto_info *crypto_info, static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, unsigned int optlen, int tx) { - struct tls_crypto_info *crypto_info; - struct tls_crypto_info *alt_crypto_info; + struct tls_crypto_info *crypto_info, *alt_crypto_info; + struct tls_crypto_info *old_crypto_info = NULL; struct tls_context *ctx = tls_get_ctx(sk); const struct tls_cipher_desc *cipher_desc; union tls_crypto_context *crypto_ctx; + union tls_crypto_context tmp = {}; + bool update = false; int rc = 0; int conf; @@ -633,9 +636,18 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, crypto_info = &crypto_ctx->info; - /* Currently we don't support set crypto info more than one time */ - if (TLS_CRYPTO_INFO_READY(crypto_info)) - return -EBUSY; + if (TLS_CRYPTO_INFO_READY(crypto_info)) { + /* Currently we only support setting crypto info more + * than one time for TLS 1.3 + */ + if (crypto_info->version != TLS_1_3_VERSION) + return -EBUSY; + + update = true; + old_crypto_info = crypto_info; + crypto_info = &tmp.info; + crypto_ctx = &tmp; + } rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info)); if (rc) { @@ -643,7 +655,14 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, goto err_crypto_info; } - rc = validate_crypto_info(crypto_info, alt_crypto_info); + if (update) { + /* Ensure that TLS version and ciphers are not modified */ + if (crypto_info->version != old_crypto_info->version || + crypto_info->cipher_type != old_crypto_info->cipher_type) + rc = -EINVAL; + } else { + rc = validate_crypto_info(crypto_info, alt_crypto_info); + } if (rc) goto err_crypto_info; @@ -673,7 +692,8 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); } else { - rc = tls_set_sw_offload(sk, 1); + rc = tls_set_sw_offload(sk, 1, + update ? crypto_info : NULL); if (rc) goto err_crypto_info; TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); @@ -687,14 +707,16 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); } else { - rc = tls_set_sw_offload(sk, 0); + rc = tls_set_sw_offload(sk, 0, + update ? crypto_info : NULL); if (rc) goto err_crypto_info; TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); conf = TLS_SW; } - tls_sw_strparser_arm(sk, ctx); + if (!update) + tls_sw_strparser_arm(sk, ctx); } if (tx) diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 3dcf8ee60fea..9e5aff5bab98 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -2716,12 +2716,22 @@ int init_prot_info(struct tls_prot_info *prot, return 0; } -int tls_set_sw_offload(struct sock *sk, int tx) +static void tls_finish_key_update(struct sock *sk, struct tls_context *tls_ctx) { + struct tls_sw_context_rx *ctx = tls_ctx->priv_ctx_rx; + + WRITE_ONCE(ctx->key_update_pending, false); + /* wake-up pre-existing poll() */ + ctx->saved_data_ready(sk); +} + +int tls_set_sw_offload(struct sock *sk, int tx, + struct tls_crypto_info *new_crypto_info) +{ + struct tls_crypto_info *crypto_info, *src_crypto_info; struct tls_sw_context_tx *sw_ctx_tx = NULL; struct tls_sw_context_rx *sw_ctx_rx = NULL; const struct tls_cipher_desc *cipher_desc; - struct tls_crypto_info *crypto_info; char *iv, *rec_seq, *key, *salt; struct cipher_context *cctx; struct tls_prot_info *prot; @@ -2733,45 +2743,47 @@ int tls_set_sw_offload(struct sock *sk, int tx) ctx = tls_get_ctx(sk); prot = &ctx->prot_info; - if (tx) { - ctx->priv_ctx_tx = init_ctx_tx(ctx, sk); - if (!ctx->priv_ctx_tx) - return -ENOMEM; + /* new_crypto_info != NULL means rekey */ + if (!new_crypto_info) { + if (tx) { + ctx->priv_ctx_tx = init_ctx_tx(ctx, sk); + if (!ctx->priv_ctx_tx) + return -ENOMEM; + } else { + ctx->priv_ctx_rx = init_ctx_rx(ctx); + if (!ctx->priv_ctx_rx) + return -ENOMEM; + } + } + if (tx) { sw_ctx_tx = ctx->priv_ctx_tx; crypto_info = &ctx->crypto_send.info; cctx = &ctx->tx; aead = &sw_ctx_tx->aead_send; } else { - ctx->priv_ctx_rx = init_ctx_rx(ctx); - if (!ctx->priv_ctx_rx) - return -ENOMEM; - sw_ctx_rx = ctx->priv_ctx_rx; crypto_info = &ctx->crypto_recv.info; cctx = &ctx->rx; aead = &sw_ctx_rx->aead_recv; - sw_ctx_rx->key_update_pending = false; } - cipher_desc = get_cipher_desc(crypto_info->cipher_type); + src_crypto_info = new_crypto_info ?: crypto_info; + + cipher_desc = get_cipher_desc(src_crypto_info->cipher_type); if (!cipher_desc) { rc = -EINVAL; goto free_priv; } - rc = init_prot_info(prot, crypto_info, cipher_desc); + rc = init_prot_info(prot, src_crypto_info, cipher_desc); if (rc) goto free_priv; - iv = crypto_info_iv(crypto_info, cipher_desc); - key = crypto_info_key(crypto_info, cipher_desc); - salt = crypto_info_salt(crypto_info, cipher_desc); - rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc); - - memcpy(cctx->iv, salt, cipher_desc->salt); - memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv); - memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq); + iv = crypto_info_iv(src_crypto_info, cipher_desc); + key = crypto_info_key(src_crypto_info, cipher_desc); + salt = crypto_info_salt(src_crypto_info, cipher_desc); + rec_seq = crypto_info_rec_seq(src_crypto_info, cipher_desc); if (!*aead) { *aead = crypto_alloc_aead(cipher_desc->cipher_name, 0, 0); @@ -2784,20 +2796,30 @@ int tls_set_sw_offload(struct sock *sk, int tx) ctx->push_pending_record = tls_sw_push_pending_record; + /* setkey is the last operation that could fail during a + * rekey. if it succeeds, we can start modifying the + * context. + */ rc = crypto_aead_setkey(*aead, key, cipher_desc->key); - if (rc) - goto free_aead; + if (rc) { + if (new_crypto_info) + goto out; + else + goto free_aead; + } - rc = crypto_aead_setauthsize(*aead, prot->tag_size); - if (rc) - goto free_aead; + if (!new_crypto_info) { + rc = crypto_aead_setauthsize(*aead, prot->tag_size); + if (rc) + goto free_aead; + } - if (sw_ctx_rx) { + if (!tx && !new_crypto_info) { tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv); tls_update_rx_zc_capable(ctx); sw_ctx_rx->async_capable = - crypto_info->version != TLS_1_3_VERSION && + src_crypto_info->version != TLS_1_3_VERSION && !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC); rc = tls_strp_init(&sw_ctx_rx->strp, sk); @@ -2805,18 +2827,33 @@ int tls_set_sw_offload(struct sock *sk, int tx) goto free_aead; } + memcpy(cctx->iv, salt, cipher_desc->salt); + memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv); + memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq); + + if (new_crypto_info) { + unsafe_memcpy(crypto_info, new_crypto_info, + cipher_desc->crypto_info, + /* size was checked in do_tls_setsockopt_conf */); + memzero_explicit(new_crypto_info, cipher_desc->crypto_info); + if (!tx) + tls_finish_key_update(sk, ctx); + } + goto out; free_aead: crypto_free_aead(*aead); *aead = NULL; free_priv: - if (tx) { - kfree(ctx->priv_ctx_tx); - ctx->priv_ctx_tx = NULL; - } else { - kfree(ctx->priv_ctx_rx); - ctx->priv_ctx_rx = NULL; + if (!new_crypto_info) { + if (tx) { + kfree(ctx->priv_ctx_tx); + ctx->priv_ctx_tx = NULL; + } else { + kfree(ctx->priv_ctx_rx); + ctx->priv_ctx_rx = NULL; + } } out: return rc;