diff --git a/src/AES.cpp b/src/AES.cpp index b3e91c1..7528947 100644 --- a/src/AES.cpp +++ b/src/AES.cpp @@ -1,7 +1,6 @@ #include "AES.h" AES::AES(const AESKeyLength keyLength) { - this->Nb = 4; switch (keyLength) { case AESKeyLength::AES_128: this->Nk = 4; @@ -16,8 +15,6 @@ AES::AES(const AESKeyLength keyLength) { this->Nr = 14; break; } - - blockBytesLen = 4 * this->Nb * sizeof(unsigned char); } unsigned char *AES::EncryptECB(const unsigned char in[], unsigned int inLen, @@ -55,7 +52,7 @@ unsigned char *AES::EncryptCBC(const unsigned char in[], unsigned int inLen, const unsigned char *iv) { CheckLength(inLen); unsigned char *out = new unsigned char[inLen]; - unsigned char *block = new unsigned char[blockBytesLen]; + unsigned char block[blockBytesLen]; unsigned char *roundKeys = new unsigned char[4 * Nb * (Nr + 1)]; KeyExpansion(key, roundKeys); memcpy(block, iv, blockBytesLen); @@ -65,7 +62,6 @@ unsigned char *AES::EncryptCBC(const unsigned char in[], unsigned int inLen, memcpy(block, out + i, blockBytesLen); } - delete[] block; delete[] roundKeys; return out; @@ -76,7 +72,7 @@ unsigned char *AES::DecryptCBC(const unsigned char in[], unsigned int inLen, const unsigned char *iv) { CheckLength(inLen); unsigned char *out = new unsigned char[inLen]; - unsigned char *block = new unsigned char[blockBytesLen]; + unsigned char block[blockBytesLen]; unsigned char *roundKeys = new unsigned char[4 * Nb * (Nr + 1)]; KeyExpansion(key, roundKeys); memcpy(block, iv, blockBytesLen); @@ -86,7 +82,6 @@ unsigned char *AES::DecryptCBC(const unsigned char in[], unsigned int inLen, memcpy(block, in + i, blockBytesLen); } - delete[] block; delete[] roundKeys; return out; @@ -97,8 +92,8 @@ unsigned char *AES::EncryptCFB(const unsigned char in[], unsigned int inLen, const unsigned char *iv) { CheckLength(inLen); unsigned char *out = new unsigned char[inLen]; - unsigned char *block = new unsigned char[blockBytesLen]; - unsigned char *encryptedBlock = new unsigned char[blockBytesLen]; + unsigned char block[blockBytesLen]; + unsigned char encryptedBlock[blockBytesLen]; unsigned char *roundKeys = new unsigned char[4 * Nb * (Nr + 1)]; KeyExpansion(key, roundKeys); memcpy(block, iv, blockBytesLen); @@ -108,8 +103,6 @@ unsigned char *AES::EncryptCFB(const unsigned char in[], unsigned int inLen, memcpy(block, out + i, blockBytesLen); } - delete[] block; - delete[] encryptedBlock; delete[] roundKeys; return out; @@ -120,8 +113,8 @@ unsigned char *AES::DecryptCFB(const unsigned char in[], unsigned int inLen, const unsigned char *iv) { CheckLength(inLen); unsigned char *out = new unsigned char[inLen]; - unsigned char *block = new unsigned char[blockBytesLen]; - unsigned char *encryptedBlock = new unsigned char[blockBytesLen]; + unsigned char block[blockBytesLen]; + unsigned char encryptedBlock[blockBytesLen]; unsigned char *roundKeys = new unsigned char[4 * Nb * (Nr + 1)]; KeyExpansion(key, roundKeys); memcpy(block, iv, blockBytesLen); @@ -131,8 +124,6 @@ unsigned char *AES::DecryptCFB(const unsigned char in[], unsigned int inLen, memcpy(block, in + i, blockBytesLen); } - delete[] block; - delete[] encryptedBlock; delete[] roundKeys; return out; @@ -147,12 +138,8 @@ void AES::CheckLength(unsigned int len) { void AES::EncryptBlock(const unsigned char in[], unsigned char out[], unsigned char *roundKeys) { - unsigned char **state = new unsigned char *[4]; - state[0] = new unsigned char[4 * Nb]; - int i, j, round; - for (i = 0; i < 4; i++) { - state[i] = state[0] + Nb * i; - } + unsigned char state[4][Nb]; + unsigned int i, j, round; for (i = 0; i < 4; i++) { for (j = 0; j < Nb; j++) { @@ -178,19 +165,12 @@ void AES::EncryptBlock(const unsigned char in[], unsigned char out[], out[i + 4 * j] = state[i][j]; } } - - delete[] state[0]; - delete[] state; } void AES::DecryptBlock(const unsigned char in[], unsigned char out[], unsigned char *roundKeys) { - unsigned char **state = new unsigned char *[4]; - state[0] = new unsigned char[4 * Nb]; - int i, j, round; - for (i = 0; i < 4; i++) { - state[i] = state[0] + Nb * i; - } + unsigned char state[4][Nb]; + unsigned int i, j, round; for (i = 0; i < 4; i++) { for (j = 0; j < Nb; j++) { @@ -216,13 +196,10 @@ void AES::DecryptBlock(const unsigned char in[], unsigned char out[], out[i + 4 * j] = state[i][j]; } } - - delete[] state[0]; - delete[] state; } -void AES::SubBytes(unsigned char **state) { - int i, j; +void AES::SubBytes(unsigned char state[4][Nb]) { + unsigned int i, j; unsigned char t; for (i = 0; i < 4; i++) { for (j = 0; j < Nb; j++) { @@ -232,19 +209,17 @@ void AES::SubBytes(unsigned char **state) { } } -void AES::ShiftRow(unsigned char **state, int i, - int n) // shift row i on n positions +void AES::ShiftRow(unsigned char state[4][Nb], unsigned int i, + unsigned int n) // shift row i on n positions { - unsigned char *tmp = new unsigned char[Nb]; - for (int j = 0; j < Nb; j++) { + unsigned char tmp[Nb]; + for (unsigned int j = 0; j < Nb; j++) { tmp[j] = state[i][(j + n) % Nb]; } memcpy(state[i], tmp, Nb * sizeof(unsigned char)); - - delete[] tmp; } -void AES::ShiftRows(unsigned char **state) { +void AES::ShiftRows(unsigned char state[4][Nb]) { ShiftRow(state, 1, 1); ShiftRow(state, 2, 2); ShiftRow(state, 3, 3); @@ -255,8 +230,8 @@ unsigned char AES::xtime(unsigned char b) // multiply on x return (b << 1) ^ (((b >> 7) & 1) * 0x1b); } -void AES::MixColumns(unsigned char **state) { - unsigned char temp_state[4][4]; +void AES::MixColumns(unsigned char state[4][Nb]) { + unsigned char temp_state[4][Nb]; for (size_t i = 0; i < 4; ++i) { memset(temp_state[i], 0, 4); @@ -278,8 +253,8 @@ void AES::MixColumns(unsigned char **state) { } } -void AES::AddRoundKey(unsigned char **state, unsigned char *key) { - int i, j; +void AES::AddRoundKey(unsigned char state[4][Nb], unsigned char *key) { + unsigned int i, j; for (i = 0; i < 4; i++) { for (j = 0; j < Nb; j++) { state[i][j] = state[i][j] ^ key[i + 4 * j]; @@ -309,8 +284,8 @@ void AES::XorWords(unsigned char *a, unsigned char *b, unsigned char *c) { } } -void AES::Rcon(unsigned char *a, int n) { - int i; +void AES::Rcon(unsigned char *a, unsigned int n) { + unsigned int i; unsigned char c = 1; for (i = 0; i < n - 1; i++) { c = xtime(c); @@ -321,10 +296,10 @@ void AES::Rcon(unsigned char *a, int n) { } void AES::KeyExpansion(const unsigned char key[], unsigned char w[]) { - unsigned char *temp = new unsigned char[4]; - unsigned char *rcon = new unsigned char[4]; + unsigned char temp[4]; + unsigned char rcon[4]; - int i = 0; + unsigned int i = 0; while (i < 4 * Nk) { w[i] = key[i]; i++; @@ -352,13 +327,10 @@ void AES::KeyExpansion(const unsigned char key[], unsigned char w[]) { w[i + 3] = w[i + 3 - 4 * Nk] ^ temp[3]; i += 4; } - - delete[] rcon; - delete[] temp; } -void AES::InvSubBytes(unsigned char **state) { - int i, j; +void AES::InvSubBytes(unsigned char state[4][Nb]) { + unsigned int i, j; unsigned char t; for (i = 0; i < 4; i++) { for (j = 0; j < Nb; j++) { @@ -368,8 +340,8 @@ void AES::InvSubBytes(unsigned char **state) { } } -void AES::InvMixColumns(unsigned char **state) { - unsigned char temp_state[4][4]; +void AES::InvMixColumns(unsigned char state[4][Nb]) { + unsigned char temp_state[4][Nb]; for (size_t i = 0; i < 4; ++i) { memset(temp_state[i], 0, 4); @@ -388,7 +360,7 @@ void AES::InvMixColumns(unsigned char **state) { } } -void AES::InvShiftRows(unsigned char **state) { +void AES::InvShiftRows(unsigned char state[4][Nb]) { ShiftRow(state, 1, Nb - 1); ShiftRow(state, 2, Nb - 2); ShiftRow(state, 3, Nb - 3); diff --git a/src/AES.h b/src/AES.h index 359955a..2f3c42d 100644 --- a/src/AES.h +++ b/src/AES.h @@ -11,24 +11,24 @@ enum class AESKeyLength { AES_128, AES_192, AES_256 }; class AES { private: - int Nb; - int Nk; - int Nr; + static constexpr unsigned int Nb = 4; + static constexpr unsigned int blockBytesLen = 4 * Nb * sizeof(unsigned char); - unsigned int blockBytesLen; + unsigned int Nk; + unsigned int Nr; - void SubBytes(unsigned char **state); + void SubBytes(unsigned char state[4][Nb]); - void ShiftRow(unsigned char **state, int i, - int n); // shift row i on n positions + void ShiftRow(unsigned char state[4][Nb], unsigned int i, + unsigned int n); // shift row i on n positions - void ShiftRows(unsigned char **state); + void ShiftRows(unsigned char state[4][Nb]); unsigned char xtime(unsigned char b); // multiply on x - void MixColumns(unsigned char **state); + void MixColumns(unsigned char state[4][Nb]); - void AddRoundKey(unsigned char **state, unsigned char *key); + void AddRoundKey(unsigned char state[4][Nb], unsigned char *key); void SubWord(unsigned char *a); @@ -36,13 +36,13 @@ class AES { void XorWords(unsigned char *a, unsigned char *b, unsigned char *c); - void Rcon(unsigned char *a, int n); + void Rcon(unsigned char *a, unsigned int n); - void InvSubBytes(unsigned char **state); + void InvSubBytes(unsigned char state[4][Nb]); - void InvMixColumns(unsigned char **state); + void InvMixColumns(unsigned char state[4][Nb]); - void InvShiftRows(unsigned char **state); + void InvShiftRows(unsigned char state[4][Nb]); void CheckLength(unsigned int len);