diff --git a/sunshine/crypto.cpp b/sunshine/crypto.cpp index ddf41bc4..50d190d5 100644 --- a/sunshine/crypto.cpp +++ b/sunshine/crypto.cpp @@ -130,19 +130,19 @@ static int init_encrypt_cbc(cipher_ctx_t &ctx, aes_t *key, aes_t *iv, bool paddi return 0; } -int gcm_t::decrypt(const std::string_view &tagged_cipher, std::vector &plaintext) { - if(!decrypt_ctx && init_decrypt_gcm(decrypt_ctx, &key, &iv, padding)) { +int gcm_t::decrypt(const std::string_view &tagged_cipher, std::vector &plaintext, aes_t *iv) { + if(!decrypt_ctx && init_decrypt_gcm(decrypt_ctx, &key, iv, padding)) { return -1; } // Calling with cipher == nullptr results in a parameter change // without requiring a reallocation of the internal cipher ctx. - if(EVP_DecryptInit_ex(decrypt_ctx.get(), nullptr, nullptr, nullptr, iv.data()) != 1) { + if(EVP_DecryptInit_ex(decrypt_ctx.get(), nullptr, nullptr, nullptr, iv->data()) != 1) { return false; } - auto cipher = tagged_cipher.substr(16); - auto tag = tagged_cipher.substr(0, 16); + auto cipher = tagged_cipher.substr(tag_size); + auto tag = tagged_cipher.substr(0, tag_size); plaintext.resize((cipher.size() + 15) / 16 * 16); @@ -164,33 +164,38 @@ int gcm_t::decrypt(const std::string_view &tagged_cipher, std::vector &cipher) { - if(!encrypt_ctx && init_encrypt_gcm(encrypt_ctx, &key, &iv, padding)) { +int gcm_t::encrypt(const std::string_view &plaintext, std::uint8_t *tagged_cipher, aes_t *iv) { + if(!encrypt_ctx && init_encrypt_gcm(encrypt_ctx, &key, iv, padding)) { return -1; } // Calling with cipher == nullptr results in a parameter change // without requiring a reallocation of the internal cipher ctx. - if(EVP_EncryptInit_ex(encrypt_ctx.get(), nullptr, nullptr, nullptr, iv.data()) != 1) { - return false; + if(EVP_EncryptInit_ex(encrypt_ctx.get(), nullptr, nullptr, nullptr, iv->data()) != 1) { + return -1; } + auto tag = tagged_cipher; + auto cipher = tag + tag_size; + int len; - - cipher.resize((plaintext.size() + 15) / 16 * 16); - auto size = (int)cipher.size(); + int size = round_to_pkcs7_padded(plaintext.size()); // Encrypt into the caller's buffer - if(EVP_EncryptUpdate(encrypt_ctx.get(), cipher.data(), &size, (const std::uint8_t *)plaintext.data(), plaintext.size()) != 1) { + if(EVP_EncryptUpdate(encrypt_ctx.get(), cipher, &size, (const std::uint8_t *)plaintext.data(), plaintext.size()) != 1) { return -1; } - if(EVP_EncryptFinal_ex(encrypt_ctx.get(), cipher.data() + size, &len) != 1) { + // GCM encryption won't ever fill ciphertext here but we have to call it anyway + if(EVP_EncryptFinal_ex(encrypt_ctx.get(), cipher + size, &len) != 1) { return -1; } - cipher.resize(len + size); - return 0; + if(EVP_CIPHER_CTX_ctrl(encrypt_ctx.get(), EVP_CTRL_GCM_GET_TAG, tag_size, tag) != 1) { + return -1; + } + + return len + size; } int ecb_t::decrypt(const std::string_view &cipher, std::vector &plaintext) { @@ -285,10 +290,8 @@ ecb_t::ecb_t(const aes_t &key, bool padding) cbc_t::cbc_t(const aes_t &key, bool padding) : cipher_t { nullptr, nullptr, key, padding } {} -gcm_t::gcm_t(const crypto::aes_t &key, const crypto::aes_t &iv, bool padding) - : cipher_t { nullptr, nullptr, key, padding } { - this->iv = iv; -} +gcm_t::gcm_t(const crypto::aes_t &key, bool padding) + : cipher_t { nullptr, nullptr, key, padding } {} } // namespace cipher diff --git a/sunshine/crypto.h b/sunshine/crypto.h index 9c925845..7f648311 100644 --- a/sunshine/crypto.h +++ b/sunshine/crypto.h @@ -67,6 +67,7 @@ private: }; namespace cipher { +constexpr std::size_t tag_size = 16; constexpr std::size_t round_to_pkcs7_padded(std::size_t size) { return ((size + 15) / 16) * 16; } @@ -99,14 +100,17 @@ public: gcm_t(gcm_t &&) noexcept = default; gcm_t &operator=(gcm_t &&) noexcept = default; - gcm_t(const crypto::aes_t &key, const crypto::aes_t &iv, bool padding = true); + gcm_t(const crypto::aes_t &key, bool padding = true); - int encrypt(const std::string_view &plaintext, std::vector &cipher); - int decrypt(const std::string_view &cipher, std::vector &plaintext); + /** + * length of cipher must be at least: round_to_pkcs7_padded(plaintext.size()) + crypto::cipher::tag_size + * + * return -1 on error + * return bytes written on success + */ + int encrypt(const std::string_view &plaintext, std::uint8_t *tagged_cipher, aes_t *iv); - aes_t &get_iv() { return iv; } - - aes_t iv; + int decrypt(const std::string_view &cipher, std::vector &plaintext, aes_t *iv); }; class cbc_t : public cipher_t { diff --git a/sunshine/stream.cpp b/sunshine/stream.cpp index 98d2de47..8f1c26a7 100644 --- a/sunshine/stream.cpp +++ b/sunshine/stream.cpp @@ -123,7 +123,7 @@ using message_queue_queue_t = std::shared_ptrpayload()); @@ -174,23 +174,18 @@ public: _map_type_cb.emplace(type, std::move(cb)); } - void send(const std::string_view &payload, net::peer_t peer) { + int send(const std::string_view &payload, net::peer_t peer) { auto packet = enet_packet_create(payload.data(), payload.size(), ENET_PACKET_FLAG_RELIABLE); if(enet_peer_send(peer, 0, packet)) { enet_packet_destroy(packet); + + return -1; } - enet_host_flush(_host.get()); + return 0; } - void send(const std::string_view &payload) { - std::for_each(_host->peers, _host->peers + _host->peerCount, [payload](auto &peer) { - auto packet = enet_packet_create(payload.data(), payload.size(), ENET_PACKET_FLAG_RELIABLE); - if(enet_peer_send(&peer, 0, packet)) { - enet_packet_destroy(packet); - } - }); - + void flush() { enet_host_flush(_host.get()); } @@ -251,8 +246,10 @@ struct session_t { struct { crypto::cipher::gcm_t cipher; + crypto::aes_t iv; net::peer_t peer; + std::uint8_t seq; } control; safe::mail_raw_t::event_t shutdown_event; @@ -261,6 +258,40 @@ struct session_t { std::atomic state; }; +/** + * First part of cipher must be struct of type NVCTL_ENCRYPTED_PACKET_HEADER + * + * returns empty string_view on failure + * returns string_view pointing to payload data + */ +template +static inline std::string_view encode_control(session_t *session, const std::string_view &plaintext, std::array &tagged_cipher) { + static_assert( + max_payload_size >= sizeof(NVCTL_ENCRYPTED_PACKET_HEADER) + sizeof(crypto::cipher::tag_size), + "max_payload_size >= sizeof(NVCTL_ENCRYPTED_PACKET_HEADER) + sizeof(crypto::cipher::tag_size)"); + + + if(session->config.controlProtocolType != 13) { + return plaintext; + } + + crypto::aes_t iv {}; + auto seq = session->control.seq++; + iv[0] = seq; + + auto packet = (PNVCTL_ENCRYPTED_PACKET_HEADER)tagged_cipher.data(); + + auto bytes = session->control.cipher.encrypt(plaintext, packet->payload(), &iv); + if(bytes <= 0) { + BOOST_LOG(error) << "Couldn't encrypt control data"sv; + return {}; + } + + packet->seq = util::endian::little(seq); + + return std::string_view { (char *)tagged_cipher.data(), (std::size_t)bytes }; +} + int start_broadcast(broadcast_ctx_t &ctx); void end_broadcast(broadcast_ctx_t &ctx); @@ -529,7 +560,8 @@ void controlBroadcastThread(control_server_t *server) { std::vector plaintext; auto &cipher = session->control.cipher; - if(cipher.decrypt(tagged_cipher, plaintext)) { + auto &iv = session->control.iv; + if(cipher.decrypt(tagged_cipher, plaintext, &iv)) { // something went wrong :( BOOST_LOG(error) << "Failed to verify tag"sv; @@ -539,7 +571,7 @@ void controlBroadcastThread(control_server_t *server) { } if(tagged_cipher_length >= 16 + sizeof(crypto::aes_t)) { - std::copy(payload.end() - 16, payload.end(), std::begin(cipher.get_iv())); + std::copy(payload.end() - 16, payload.end(), std::begin(iv)); } input::print(plaintext.data()); @@ -564,11 +596,13 @@ void controlBroadcastThread(control_server_t *server) { auto &cipher = session->control.cipher; crypto::aes_t iv {}; - iv[0] = (char)seq; - cipher.get_iv() = iv; + iv[0] = (std::uint8_t)seq; + + // update control sequence + ++session->control.seq; std::vector plaintext; - if(cipher.decrypt(tagged_cipher, plaintext)) { + if(cipher.decrypt(tagged_cipher, plaintext, &iv)) { // something went wrong :( BOOST_LOG(error) << "Failed to verify tag"sv; @@ -644,18 +678,34 @@ void controlBroadcastThread(control_server_t *server) { // Let all remaining connections know the server is shutting down std::uint16_t reason = 0x0100; - std::array payload; - payload[0] = packetTypes[IDX_TERMINATION]; - payload[1] = reason; + std::array plaintext; + plaintext[0] = packetTypes[IDX_TERMINATION]; + plaintext[1] = reason; - server->send(std::string_view { (char *)payload.data(), payload.size() }); + std::array + encrypted_payload; + + auto packet = (PNVCTL_ENCRYPTED_PACKET_HEADER)encrypted_payload.data(); + packet->encryptedHeaderType = util::endian::little(0x0001); + packet->length = encrypted_payload.size() - sizeof(NVCTL_ENCRYPTED_PACKET_HEADER) + 4; auto lg = server->_map_addr_session.lock(); for(auto pos = std::begin(*server->_map_addr_session); pos != std::end(*server->_map_addr_session); ++pos) { auto session = pos->second.second; + + auto payload = encode_control(session, std::string_view { (char *)plaintext.data(), plaintext.size() }, encrypted_payload); + + if(server->send(payload, session->control.peer)) { + TUPLE_2D(port, addr, platf::from_sockaddr_ex((sockaddr *)&session->control.peer->address.address)); + BOOST_LOG(warning) << "Couldn't send termination code to ["sv << addr << ':' << port << ']'; + } + session->shutdown_event->raise(true); session->controlEnd.raise(true); } + + server->flush(); } void recvThread(broadcast_ctx_t &ctx) { @@ -1154,14 +1204,14 @@ std::shared_ptr alloc(config_t &config, crypto::aes_t &gcm_key, crypt session->config = config; + session->control.iv = iv; session->control.cipher = crypto::cipher::gcm_t { - gcm_key, iv, false + gcm_key, false }; session->video.idr_events = mail->event(mail::idr); session->video.lowseq = 0; - session->audio.cipher = crypto::cipher::cbc_t { gcm_key, true };