diff --git a/src/core/decryptor.cpp b/src/core/decryptor.cpp index 6c8781f..b7a4952 100644 --- a/src/core/decryptor.cpp +++ b/src/core/decryptor.cpp @@ -59,14 +59,15 @@ bool SDMCDecryptor::DecryptAndWriteFile(const std::string& source, const std::st return false; } + auto key = Key::GetNormalKey(Key::SDKey); + auto ctr = GetFileCTR(source); + quick_decryptor.SetCrypto(CreateCTRCrypto(key, ctr)); + auto source_file = std::make_shared(root_folder + source, "rb"); auto size = source_file->GetSize(); auto destination_file = std::make_shared(destination, "wb"); - auto key = Key::GetNormalKey(Key::SDKey); - auto ctr = GetFileCTR(source); - return quick_decryptor.DecryptAndWriteFile(std::move(source_file), size, - std::move(destination_file), callback, true, - std::move(key), std::move(ctr)); + return quick_decryptor.CryptAndWriteFile(std::move(source_file), size, + std::move(destination_file), callback); } void SDMCDecryptor::Abort() { diff --git a/src/core/importer.cpp b/src/core/importer.cpp index 5446851..e8ba239 100644 --- a/src/core/importer.cpp +++ b/src/core/importer.cpp @@ -196,8 +196,8 @@ bool SDMCImporter::ImportNandTitle(const ContentSpecifier& specifier, LOG_ERROR(Core, "Could not create path {}", citra_path); return false; } - // Do not specify keys: plain copy with progress. - return quick_decryptor.DecryptAndWriteFile( + // Crypto is not set: plain copy with progress. + return quick_decryptor.CryptAndWriteFile( std::make_shared(physical_path, "rb"), FileUtil::GetSize(physical_path), std::make_shared(citra_path, "wb"), callback); diff --git a/src/core/ncch/ncch_container.cpp b/src/core/ncch/ncch_container.cpp index 5183ce2..8dbea31 100644 --- a/src/core/ncch/ncch_container.cpp +++ b/src/core/ncch/ncch_container.cpp @@ -444,13 +444,14 @@ ResultStatus NCCHContainer::DecryptToFile(std::shared_ptr dest } if (!is_encrypted) { - // Simply copy everything + // Simply copy everything. QuickDecryptor is used for progress reporting file->Seek(0, SEEK_SET); const auto size = file->GetSize(); - decryptor.Reset(size); - const bool ret = decryptor.DecryptAndWriteFile(file, size, dest_file, callback); + decryptor.Reset(size); + decryptor.SetCrypto(nullptr); + const bool ret = decryptor.CryptAndWriteFile(file, size, dest_file, callback); return ret ? ResultStatus::Success : ResultStatus::Error; } @@ -520,8 +521,9 @@ ResultStatus NCCHContainer::DecryptToFile(std::shared_ptr dest } written = offset; - if (!decryptor.DecryptAndWriteFile(file, size, dest_file, decryptor_callback, decrypt, key, - ctr, aes_seek_pos)) { + + decryptor.SetCrypto(decrypt ? CreateCTRCrypto(key, ctr, aes_seek_pos) : nullptr); + if (!decryptor.CryptAndWriteFile(file, size, dest_file, decryptor_callback)) { LOG_ERROR(Core, "Could not write {}", name); return false; } diff --git a/src/core/quick_decryptor.cpp b/src/core/quick_decryptor.cpp index 4683fed..1d2f6d5 100644 --- a/src/core/quick_decryptor.cpp +++ b/src/core/quick_decryptor.cpp @@ -20,12 +20,13 @@ QuickDecryptor::QuickDecryptor() = default; QuickDecryptor::~QuickDecryptor() = default; -bool QuickDecryptor::DecryptAndWriteFile(std::shared_ptr source_, - std::size_t size, - std::shared_ptr destination_, - const Common::ProgressCallback& callback_, bool decrypt_, - Core::Key::AESKey key_, Core::Key::AESKey ctr_, - std::size_t aes_seek_pos_) { +void QuickDecryptor::SetCrypto(std::shared_ptr crypto_) { + crypto = std::move(crypto_); +} + +bool QuickDecryptor::CryptAndWriteFile(std::shared_ptr source_, std::size_t size, + std::shared_ptr destination_, + const Common::ProgressCallback& callback_) { if (is_running) { LOG_ERROR(Core, "Decryptor is running"); return false; @@ -48,10 +49,6 @@ bool QuickDecryptor::DecryptAndWriteFile(std::shared_ptr sourc source = std::move(source_); destination = std::move(destination_); - decrypt = decrypt_; - key = std::move(key_); - ctr = std::move(ctr_); - aes_seek_pos = aes_seek_pos_; callback = callback_; current_total_size = size; @@ -60,7 +57,7 @@ bool QuickDecryptor::DecryptAndWriteFile(std::shared_ptr sourc read_thread = std::make_unique(&QuickDecryptor::DataReadLoop, this); write_thread = std::make_unique(&QuickDecryptor::DataWriteLoop, this); - if (decrypt) { + if (crypto) { decrypt_thread = std::make_unique(&QuickDecryptor::DataDecryptLoop, this); } @@ -69,7 +66,7 @@ bool QuickDecryptor::DecryptAndWriteFile(std::shared_ptr sourc read_thread->join(); write_thread->join(); - if (decrypt) { + if (crypto) { decrypt_thread->join(); } @@ -117,10 +114,6 @@ void QuickDecryptor::DataReadLoop() { } void QuickDecryptor::DataDecryptLoop() { - CryptoPP::CTR_Mode::Decryption aes; - aes.SetKeyWithIV(key.data(), key.size(), ctr.data()); - aes.Seek(aes_seek_pos); - std::size_t current_buffer = 0; std::size_t file_size = current_total_size; @@ -128,8 +121,7 @@ void QuickDecryptor::DataDecryptLoop() { data_read_event[current_buffer].Wait(); const auto bytes_to_process = std::min(BufferSize, file_size); - aes.ProcessData(buffers[current_buffer].data(), buffers[current_buffer].data(), - bytes_to_process); + crypto->ProcessData(buffers[current_buffer].data(), bytes_to_process); file_size -= bytes_to_process; @@ -159,7 +151,7 @@ void QuickDecryptor::DataWriteLoop() { iteration++; - if (decrypt) { + if (crypto) { data_decrypted_event[current_buffer].Wait(); } else { data_read_event[current_buffer].Wait(); @@ -194,4 +186,30 @@ void QuickDecryptor::Reset(std::size_t total_size_) { imported_size = 0; } +CryptoFunc::~CryptoFunc() = default; + +class CryptoFunc_AES_CTR final : public CryptoFunc { +public: + explicit CryptoFunc_AES_CTR(const Key::AESKey& key, const Key::AESKey& ctr, + std::size_t seek_pos = 0) { + + aes.SetKeyWithIV(key.data(), key.size(), ctr.data()); + aes.Seek(seek_pos); + } + + ~CryptoFunc_AES_CTR() override = default; + + void ProcessData(u8* data, std::size_t size) override { + aes.ProcessData(data, data, size); + } + +private: + CryptoPP::CTR_Mode::Decryption aes; +}; + +std::shared_ptr CreateCTRCrypto(const Key::AESKey& key, const Key::AESKey& ctr, + std::size_t seek_pos) { + return std::make_shared(key, ctr, seek_pos); +} + } // namespace Core diff --git a/src/core/quick_decryptor.h b/src/core/quick_decryptor.h index 009bae0..81ecef7 100644 --- a/src/core/quick_decryptor.h +++ b/src/core/quick_decryptor.h @@ -15,6 +15,8 @@ namespace Core { +class CryptoFunc; + /** * Helper that reads, decrypts and writes data. This uses three threads to process the data * and call progress callbacks occasionally. @@ -25,23 +27,23 @@ public: ~QuickDecryptor(); /** - * Decrypts and writes a file. + * Set up the crypto to use. + * Default / nullptr is plain copy. + */ + void SetCrypto(std::shared_ptr crypto); + + /** + * Crypts and writes a file. * * @param source Source file * @param size Size to read, decrypt and write * @param destination Destination file - * @param callback Progress callback - * @param decrypt Whether to perform decryption or not - * @param key AES Key for decryption - * @param ctr AES CTR for decryption - * @param aes_seek_pos The position to seek to for decryption. + * @param callback Progress callback. default for nothing. */ - bool DecryptAndWriteFile( + bool CryptAndWriteFile( std::shared_ptr source, std::size_t size, std::shared_ptr destination, - const Common::ProgressCallback& callback = [](std::size_t, std::size_t) {}, - bool decrypt = false, Core::Key::AESKey key = {}, Core::Key::AESKey ctr = {}, - std::size_t aes_seek_pos = 0); + const Common::ProgressCallback& callback = [](std::size_t, std::size_t) {}); void DataReadLoop(); void DataDecryptLoop(); @@ -57,10 +59,7 @@ private: std::shared_ptr source; std::shared_ptr destination; - bool decrypt{}; - Core::Key::AESKey key; - Core::Key::AESKey ctr; - std::size_t aes_seek_pos; + std::shared_ptr crypto; // Total size of this content, may consist of multiple files std::size_t total_size{}; @@ -85,4 +84,13 @@ private: std::atomic_bool is_running{false}; }; +class CryptoFunc { +public: + virtual ~CryptoFunc(); + virtual void ProcessData(u8* data, std::size_t size) = 0; +}; + +std::shared_ptr CreateCTRCrypto(const Key::AESKey& key, const Key::AESKey& ctr, + std::size_t seek_pos = 0); + } // namespace Core