diff --git a/src/tpm2/crypto/openssl/CryptSym.c b/src/tpm2/crypto/openssl/CryptSym.c index 2c97c1ab..756102cf 100644 --- a/src/tpm2/crypto/openssl/CryptSym.c +++ b/src/tpm2/crypto/openssl/CryptSym.c @@ -540,7 +540,7 @@ CryptSymmetricEncrypt( INT16 blockSize; BYTE *iv; BYTE defaultIv[MAX_SYM_BLOCK_SIZE] = {0}; - evpfunc evpfn; + const EVP_CIPHER *evp_cipher; EVP_CIPHER_CTX *ctx = NULL; int outlen1 = 0; int outlen2 = 0; @@ -580,9 +580,9 @@ CryptSymmetricEncrypt( return TPM_RC_SIZE; } - evpfn = GetEVPCipher(algorithm, keySizeInBits, mode, key, - keyToUse, &keyToUseLen); - if (evpfn == NULL) + evp_cipher = GetEVPCipher(algorithm, keySizeInBits, mode, key, + keyToUse, &keyToUseLen); + if (evp_cipher == NULL) return TPM_RC_FAILURE; if (dIn == dOut) { @@ -605,7 +605,7 @@ CryptSymmetricEncrypt( ctx = EVP_CIPHER_CTX_new(); if (!ctx || - EVP_EncryptInit_ex(ctx, evpfn(), NULL, keyToUse, iv) != 1 || + EVP_EncryptInit_ex(ctx, evp_cipher, NULL, keyToUse, iv) != 1 || EVP_CIPHER_CTX_set_padding(ctx, 0) != 1 || EVP_EncryptUpdate(ctx, pOut, &outlen1, dIn, dSize) != 1) ERROR_RETURN(TPM_RC_FAILURE); @@ -656,7 +656,7 @@ CryptSymmetricDecrypt( INT16 blockSize; BYTE *iv; BYTE defaultIv[MAX_SYM_BLOCK_SIZE] = {0}; - evpfunc evpfn; + const EVP_CIPHER *evp_cipher; EVP_CIPHER_CTX *ctx = NULL; int outlen1 = 0; int outlen2 = 0; @@ -704,9 +704,9 @@ CryptSymmetricDecrypt( break; } - evpfn = GetEVPCipher(algorithm, keySizeInBits, mode, key, - keyToUse, &keyToUseLen); - if (evpfn == NULL) + evp_cipher = GetEVPCipher(algorithm, keySizeInBits, mode, key, + keyToUse, &keyToUseLen); + if (evp_cipher == NULL) return TPM_RC_FAILURE; /* a buffer with a 'safety margin' for EVP_DecryptUpdate */ @@ -725,7 +725,7 @@ CryptSymmetricDecrypt( ctx = EVP_CIPHER_CTX_new(); if (!ctx || - EVP_DecryptInit_ex(ctx, evpfn(), NULL, keyToUse, iv) != 1 || + EVP_DecryptInit_ex(ctx, evp_cipher, NULL, keyToUse, iv) != 1 || EVP_CIPHER_CTX_set_padding(ctx, 0) != 1 || EVP_DecryptUpdate(ctx, buffer, &outlen1, dIn, dSize) != 1) ERROR_RETURN(TPM_RC_FAILURE); diff --git a/src/tpm2/crypto/openssl/Helpers.c b/src/tpm2/crypto/openssl/Helpers.c index 03a454a0..4ad0e2f0 100644 --- a/src/tpm2/crypto/openssl/Helpers.c +++ b/src/tpm2/crypto/openssl/Helpers.c @@ -72,6 +72,8 @@ # include #endif +typedef const EVP_CIPHER *(*evpfunc)(void); + /* to enable RSA_check_key() on private keys set to != 0 */ #ifndef DO_RSA_CHECK_KEY #define DO_RSA_CHECK_KEY 0 @@ -113,17 +115,56 @@ OpenSSLCryptGenerateKeyDes( return retVal; } -evpfunc GetEVPCipher(TPM_ALG_ID algorithm, // IN - UINT16 keySizeInBits, // IN - TPM_ALG_ID mode, // IN - const BYTE *key, // IN - BYTE *keyToUse, // OUT same as key or stretched key - UINT16 *keyToUseLen // IN/OUT - ) + +#define __NUM_ALGS 4 /* AES, TDES, Camellia, SM4 */ +#define __NUM_MODES 5 /* CTR, OFB, CBC, CFB, ECB */ +#define __NUM_KEYSIZES 3 /* 128, 192, 256 */ + +static const EVP_CIPHER *evp_cipher_cache[__NUM_ALGS][__NUM_MODES][__NUM_KEYSIZES] = { + { { NULL, } }, +}; + +static const EVP_CIPHER * +GetCachedEVPCipher( + evpfunc evpfunc, // IN + size_t algIdx, // IN algorithm Index for the cache + TPM_ALG_ID mode, // IN mode + size_t keySizeIdx // IN + ) +{ + size_t modeIdx = mode - ALG_CTR_VALUE; + const EVP_CIPHER *evp_cipher; + + pAssert(algIdx < __NUM_ALGS && + modeIdx < __NUM_MODES && + keySizeIdx < __NUM_KEYSIZES); + + evp_cipher = evp_cipher_cache[algIdx][modeIdx][keySizeIdx]; + if (evp_cipher == NULL) { + evp_cipher = evpfunc(); + evp_cipher_cache[algIdx][modeIdx][keySizeIdx] = evp_cipher; + } + + return evp_cipher; +} + +#undef __NUM_KEYSIZES +#undef __NUM_MODES +#undef __NUM_ALGS + +const EVP_CIPHER * +GetEVPCipher(TPM_ALG_ID algorithm, // IN + UINT16 keySizeInBits, // IN + TPM_ALG_ID mode, // IN + const BYTE *key, // IN + BYTE *keyToUse, // OUT same as key or stretched key + UINT16 *keyToUseLen // IN/OUT + ) { int i; UINT16 keySizeInBytes = keySizeInBits / 8; evpfunc evpfn = NULL; + size_t algIdx; // key size to array index: 128 -> 0, 192 -> 1, 256 -> 2 i = (keySizeInBits >> 6) - 2; @@ -136,7 +177,9 @@ evpfunc GetEVPCipher(TPM_ALG_ID algorithm, // IN switch (algorithm) { #if ALG_AES case TPM_ALG_AES: + algIdx = 0; *keyToUseLen = keySizeInBytes; + switch (mode) { #if ALG_CTR case TPM_ALG_CTR: @@ -173,6 +216,7 @@ evpfunc GetEVPCipher(TPM_ALG_ID algorithm, // IN #endif #if ALG_TDES case TPM_ALG_TDES: + algIdx = 1; if (keySizeInBits == 128) { pAssert(*keyToUseLen >= BITS_TO_BYTES(192)) // stretch the key @@ -212,7 +256,9 @@ evpfunc GetEVPCipher(TPM_ALG_ID algorithm, // IN #if ALG_SM4 case TPM_ALG_SM4: + algIdx = 2; *keyToUseLen = keySizeInBytes; + switch (mode) { #if ALG_CTR case TPM_ALG_CTR: @@ -245,7 +291,9 @@ evpfunc GetEVPCipher(TPM_ALG_ID algorithm, // IN #if ALG_CAMELLIA case TPM_ALG_CAMELLIA: + algIdx = 3; *keyToUseLen = keySizeInBytes; + switch (mode) { #if ALG_CTR case TPM_ALG_CTR: @@ -282,10 +330,13 @@ evpfunc GetEVPCipher(TPM_ALG_ID algorithm, // IN #endif } - if (evpfn == NULL) + if (evpfn == NULL) { MemorySet(keyToUse, 0, *keyToUseLen); + return NULL; + } - return evpfn; + /* get cached result of evpfn() */ + return GetCachedEVPCipher(evpfn, algIdx, mode, i); } TPM_RC DoEVPGetIV( diff --git a/src/tpm2/crypto/openssl/Helpers_fp.h b/src/tpm2/crypto/openssl/Helpers_fp.h index c42402d9..3fa0f96a 100644 --- a/src/tpm2/crypto/openssl/Helpers_fp.h +++ b/src/tpm2/crypto/openssl/Helpers_fp.h @@ -71,15 +71,13 @@ OpenSSLCryptGenerateKeyDes( TPMT_SENSITIVE *sensitive // OUT: sensitive area ); -typedef const EVP_CIPHER *(*evpfunc)(void); - -evpfunc GetEVPCipher(TPM_ALG_ID algorithm, // IN - UINT16 keySizeInBits, // IN - TPM_ALG_ID mode, // IN - const BYTE *key, // IN - BYTE *keyToUse, // OUT same as key or stretched key - UINT16 *keyToUseLen // IN/OUT - ); +const EVP_CIPHER *GetEVPCipher(TPM_ALG_ID algorithm, // IN + UINT16 keySizeInBits, // IN + TPM_ALG_ID mode, // IN + const BYTE *key, // IN + BYTE *keyToUse, // OUT same as key or stretched key + UINT16 *keyToUseLen // IN/OUT + ); TPM_RC DoEVPGetIV( EVP_CIPHER_CTX *ctx, // IN: required context