Use virtual classes instead of templates for file interfaces

For convenience.
This commit is contained in:
zhupengfei
2020-08-01 09:47:44 +08:00
parent e861d84b72
commit 1f91cbdaec
11 changed files with 111 additions and 212 deletions
+26 -14
View File
@@ -821,20 +821,6 @@ IOFile::~IOFile() {
Close();
}
IOFile::IOFile(IOFile&& other) {
Swap(other);
}
IOFile& IOFile::operator=(IOFile&& other) {
Swap(other);
return *this;
}
void IOFile::Swap(IOFile& other) {
std::swap(m_file, other.m_file);
std::swap(m_good, other.m_good);
}
bool IOFile::Open(const std::string& filename, const char openmode[], int flags) {
Close();
#ifdef _WIN32
@@ -861,6 +847,32 @@ bool IOFile::Close() {
return m_good;
}
std::size_t IOFile::Read(char* data, std::size_t length) {
if (!IsOpen()) {
m_good = false;
return std::numeric_limits<std::size_t>::max();
}
std::size_t items_read = std::fread(data, 1, length, m_file);
if (items_read != length)
m_good = false;
return items_read;
}
std::size_t IOFile::Write(const char* data, std::size_t length) {
if (!IsOpen()) {
m_good = false;
return std::numeric_limits<std::size_t>::max();
}
std::size_t items_written = std::fwrite(data, 1, length, m_file);
if (items_written != length)
m_good = false;
return items_written;
}
u64 IOFile::GetSize() const {
if (IsOpen())
return FileUtil::GetSize(m_file);
+9 -33
View File
@@ -178,48 +178,21 @@ public:
// isn't considered "locked" while citra is open and people can open the log file and view it
IOFile(const std::string& filename, const char openmode[], int flags = 0);
~IOFile();
IOFile(IOFile&& other);
IOFile& operator=(IOFile&& other);
void Swap(IOFile& other);
virtual ~IOFile();
bool Open(const std::string& filename, const char openmode[], int flags = 0);
bool Close();
template <typename T>
std::size_t ReadArray(T* data, std::size_t length) {
static_assert(std::is_trivially_copyable_v<T>,
"Given array does not consist of trivially copyable objects");
if (!IsOpen()) {
m_good = false;
return std::numeric_limits<std::size_t>::max();
}
std::size_t items_read = std::fread(data, sizeof(T), length, m_file);
if (items_read != length)
m_good = false;
return items_read;
static_assert(std::is_trivially_copyable_v<T>, "T must be trivially copyable");
return Read(reinterpret_cast<char*>(data), length * sizeof(T));
}
template <typename T>
std::size_t WriteArray(const T* data, std::size_t length) {
static_assert(std::is_trivially_copyable_v<T>,
"Given array does not consist of trivially copyable objects");
if (!IsOpen()) {
m_good = false;
return std::numeric_limits<std::size_t>::max();
}
std::size_t items_written = std::fwrite(data, sizeof(T), length, m_file);
if (items_written != length)
m_good = false;
return items_written;
static_assert(std::is_trivially_copyable_v<T>, "T must be trivially copyable");
return Write(reinterpret_cast<const char*>(data), length * sizeof(T));
}
template <typename T>
@@ -244,6 +217,9 @@ public:
return WriteArray(str.data(), str.length());
}
virtual std::size_t Read(char* data, std::size_t length);
virtual std::size_t Write(const char* data, std::size_t length);
bool IsOpen() const {
return nullptr != m_file;
}
@@ -256,7 +232,7 @@ public:
return IsGood();
}
bool Seek(s64 off, int origin);
virtual bool Seek(s64 off, int origin);
u64 Tell() const;
u64 GetSize() const;
bool Resize(u64 size);
+23 -36
View File
@@ -98,18 +98,25 @@ std::vector<u8> SDMCDecryptor::DecryptFile(const std::string& source) const {
return data;
}
SDMCFile::SDMCFile() = default;
struct SDMCFile::Impl {
CryptoPP::CTR_Mode<CryptoPP::AES>::Decryption aes;
std::array<u8, 16> original_ctr;
std::array<u8, 16> key;
};
SDMCFile::SDMCFile(std::string root_folder, const std::string& filename, const char openmode[],
int flags) {
impl = std::make_unique<Impl>();
if (root_folder.back() == '/' || root_folder.back() == '\\') {
// Remove '/' or '\' character at the end as we will add them back when combining path
root_folder.erase(root_folder.size() - 1);
}
original_ctr = GetFileCTR(filename);
key = Key::GetNormalKey(Key::SDKey);
// aes.SetKeyWithIV(key.data(), key.size(), original_ctr.data());
impl->original_ctr = GetFileCTR(filename);
impl->key = Key::GetNormalKey(Key::SDKey);
impl->aes.SetKeyWithIV(impl->key.data(), impl->key.size(), impl->original_ctr.data());
Open(root_folder + filename, openmode, flags);
}
@@ -118,46 +125,26 @@ SDMCFile::~SDMCFile() {
Close();
}
SDMCFile::SDMCFile(SDMCFile&& other) {
Swap(other);
std::size_t SDMCFile::Read(char* data, std::size_t length) {
const std::size_t length_read = FileUtil::IOFile::Read(data, length);
DecryptData(reinterpret_cast<u8*>(data), length_read);
return length_read;
}
SDMCFile& SDMCFile::operator=(SDMCFile&& other) {
Swap(other);
return *this;
}
void SDMCFile::Swap(SDMCFile& other) {
file.Swap(other.file);
std::swap(original_ctr, other.original_ctr);
std::swap(key, other.key);
}
bool SDMCFile::Open(const std::string& filename, const char openmode[], int flags) {
return file.Open(filename, openmode, flags);
}
bool SDMCFile::Close() {
return file.Close();
}
u64 SDMCFile::GetSize() const {
return file.GetSize();
std::size_t SDMCFile::Write(const char* data, std::size_t length) {
UNREACHABLE_MSG("Cannot write to a SDMCFile");
}
bool SDMCFile::Seek(s64 off, int origin) {
return file.Seek(off, origin);
}
u64 SDMCFile::Tell() const {
return file.Tell();
if (!FileUtil::IOFile::Seek(off, origin)) {
return false;
}
impl->aes.Seek(Tell());
return true;
}
void SDMCFile::DecryptData(u8* data, std::size_t size) {
CryptoPP::CTR_Mode<CryptoPP::AES>::Decryption aes;
aes.SetKeyWithIV(key.data(), key.size(), original_ctr.data());
aes.Seek(Tell() - size);
aes.ProcessData(data, data, size);
impl->aes.ProcessData(data, data, size);
}
} // namespace Core
+8 -51
View File
@@ -52,69 +52,26 @@ public:
private:
std::string root_folder;
QuickDecryptor<> quick_decryptor;
QuickDecryptor quick_decryptor;
};
/// Interface for reading an SDMC file like a normal IOFile. This is read-only.
class SDMCFile : public NonCopyable {
class SDMCFile : public FileUtil::IOFile {
public:
SDMCFile();
SDMCFile(std::string root_folder, const std::string& filename, const char openmode[],
int flags = 0);
~SDMCFile();
~SDMCFile() override;
SDMCFile(SDMCFile&& other);
SDMCFile& operator=(SDMCFile&& other);
void Swap(SDMCFile& other);
bool Open(const std::string& filename, const char openmode[], int flags = 0);
bool Close();
template <typename T>
std::size_t ReadArray(T* data, std::size_t length) {
std::size_t items_read = file.ReadArray(data, length);
if (IsGood()) {
DecryptData(reinterpret_cast<u8*>(data), sizeof(T) * length);
}
return items_read;
}
template <typename T>
std::size_t ReadBytes(T* data, std::size_t length) {
static_assert(std::is_trivially_copyable_v<T>, "T must be trivially copyable");
return ReadArray(reinterpret_cast<char*>(data), length);
}
bool IsOpen() const {
return file.IsOpen();
}
// m_good is set to false when a read, write or other function fails
bool IsGood() const {
return file.IsGood();
}
explicit operator bool() const {
return IsGood();
}
bool Seek(s64 off, int origin);
u64 Tell() const;
u64 GetSize() const;
void Clear();
std::size_t Read(char* data, std::size_t length) override;
std::size_t Write(const char* data, std::size_t length) override;
bool Seek(s64 off, int origin) override;
private:
void DecryptData(u8* data, std::size_t size);
FileUtil::IOFile file;
// CryptoPP::CTR_Mode<CryptoPP::AES>::Decryption aes;
std::array<u8, 16> original_ctr;
std::array<u8, 16> key;
struct Impl;
std::unique_ptr<Impl> impl;
};
} // namespace Core
+5 -6
View File
@@ -137,7 +137,7 @@ bool SDMCImporter::ImportNandTitle(const ContentSpecifier& specifier,
const auto base_path =
config.system_titles_path.substr(0, config.system_titles_path.size() - 6);
QuickDecryptor<> quick_decryptor;
QuickDecryptor quick_decryptor;
return ImportTitleGeneric(
quick_decryptor, base_path, specifier,
[&base_path, &quick_decryptor, &callback](const std::string& filepath) {
@@ -416,8 +416,7 @@ static bool LoadTMD(const std::string& sdmc_path, const std::string& path, SDMCD
// English short title name, extdata id, encryption, seed, icon
using TitleData = std::tuple<std::string, u64, EncryptionType, bool, std::vector<u16>>;
template <typename File>
TitleData LoadTitleData(NCCHContainer<File>& ncch) {
TitleData LoadTitleData(NCCHContainer& ncch) {
std::string codeset_name;
ncch.ReadCodesetName(codeset_name);
@@ -478,7 +477,7 @@ bool SDMCImporter::DumpCXI(const ContentSpecifier& specifier, const std::string&
const auto boot_content_path =
fmt::format("{}{:08x}.app", content_path, tmd.GetBootContentID());
dump_cxi_ncch = std::make_unique<NCCHContainer<SDMCFile>>(
dump_cxi_ncch = std::make_unique<NCCHContainer>(
std::make_shared<SDMCFile>(config.sdmc_path, boot_content_path, "rb"));
return dump_cxi_ncch->DecryptToFile(destination, callback) == ResultStatus::Success;
}
@@ -526,7 +525,7 @@ void SDMCImporter::ListTitle(std::vector<ContentSpecifier>& out) const {
const auto boot_content_path =
fmt::format("{}{:08x}.app", content_path, tmd.GetBootContentID());
NCCHContainer<SDMCFile> ncch(
NCCHContainer ncch(
std::make_shared<SDMCFile>(sdmc_path, boot_content_path, "rb"));
if (ncch.Load() != ResultStatus::Success) {
LOG_WARNING(Core, "Could not load NCCH {}", boot_content_path);
@@ -631,7 +630,7 @@ void SDMCImporter::ListNandTitle(std::vector<ContentSpecifier>& out) const {
const auto boot_content_path =
fmt::format("{}{:08x}.app", content_path, tmd.GetBootContentID());
NCCHContainer<FileUtil::IOFile> ncch(
NCCHContainer ncch(
std::make_shared<FileUtil::IOFile>(boot_content_path, "rb"));
if (ncch.Load() != ResultStatus::Success) {
LOG_WARNING(Core, "Could not load NCCH {}", boot_content_path);
+1 -3
View File
@@ -87,8 +87,6 @@ struct Config {
constexpr int CurrentDumperVersion = 1;
class SDMCFile;
template <typename File>
class NCCHContainer;
class SDMCImporter {
@@ -174,7 +172,7 @@ private:
std::unique_ptr<SDMCDecryptor> decryptor;
// The NCCH used to dump CXIs.
std::unique_ptr<NCCHContainer<SDMCFile>> dump_cxi_ncch;
std::unique_ptr<NCCHContainer> dump_cxi_ncch;
};
/**
+15 -31
View File
@@ -29,11 +29,9 @@ constexpr u32 MakeMagic(char a, char b, char c, char d) {
static const int kMaxSections = 8; ///< Maximum number of sections (files) in an ExeFs
static const int kBlockSize = 0x200; ///< Size of ExeFS blocks (in bytes)
template <typename File>
NCCHContainer<File>::NCCHContainer(std::shared_ptr<File> file_) : file(std::move(file_)) {}
NCCHContainer::NCCHContainer(std::shared_ptr<FileUtil::IOFile> file_) : file(std::move(file_)) {}
template <typename File>
ResultStatus NCCHContainer<File>::OpenFile(std::shared_ptr<File> file_) {
ResultStatus NCCHContainer::OpenFile(std::shared_ptr<FileUtil::IOFile> file_) {
file = std::move(file_);
if (!file->IsOpen()) {
@@ -45,8 +43,7 @@ ResultStatus NCCHContainer<File>::OpenFile(std::shared_ptr<File> file_) {
return ResultStatus::Success;
}
template <typename File>
ResultStatus NCCHContainer<File>::Load() {
ResultStatus NCCHContainer::Load() {
if (is_loaded)
return ResultStatus::Success;
@@ -185,7 +182,7 @@ ResultStatus NCCHContainer<File>::Load() {
// System archives and DLC don't have an extended header but have RomFS
if (ncch_header.extended_header_size) {
auto read_exheader = [this](File& file) {
auto read_exheader = [this](FileUtil::IOFile& file) {
const std::size_t size = sizeof(exheader_header);
return file && file.ReadBytes(&exheader_header, size) == size;
};
@@ -269,8 +266,7 @@ ResultStatus NCCHContainer<File>::Load() {
return ResultStatus::Success;
}
template <typename File>
ResultStatus NCCHContainer<File>::LoadSectionExeFS(const char* name, std::vector<u8>& buffer) {
ResultStatus NCCHContainer::LoadSectionExeFS(const char* name, std::vector<u8>& buffer) {
ResultStatus result = Load();
if (result != ResultStatus::Success)
return result;
@@ -307,8 +303,7 @@ ResultStatus NCCHContainer<File>::LoadSectionExeFS(const char* name, std::vector
return ResultStatus::ErrorNotUsed;
}
template <typename File>
ResultStatus NCCHContainer<File>::ReadProgramId(u64_le& program_id) {
ResultStatus NCCHContainer::ReadProgramId(u64_le& program_id) {
ResultStatus result = Load();
if (result != ResultStatus::Success)
return result;
@@ -320,8 +315,7 @@ ResultStatus NCCHContainer<File>::ReadProgramId(u64_le& program_id) {
return ResultStatus::Success;
}
template <typename File>
ResultStatus NCCHContainer<File>::ReadExtdataId(u64& extdata_id) {
ResultStatus NCCHContainer::ReadExtdataId(u64& extdata_id) {
ResultStatus result = Load();
if (result != ResultStatus::Success)
return result;
@@ -356,8 +350,7 @@ ResultStatus NCCHContainer<File>::ReadExtdataId(u64& extdata_id) {
return ResultStatus::Success;
}
template <typename File>
bool NCCHContainer<File>::HasExeFS() {
bool NCCHContainer::HasExeFS() {
ResultStatus result = Load();
if (result != ResultStatus::Success)
return false;
@@ -365,8 +358,7 @@ bool NCCHContainer<File>::HasExeFS() {
return has_exefs;
}
template <typename File>
bool NCCHContainer<File>::HasExHeader() {
bool NCCHContainer::HasExHeader() {
ResultStatus result = Load();
if (result != ResultStatus::Success)
return false;
@@ -374,8 +366,7 @@ bool NCCHContainer<File>::HasExHeader() {
return has_exheader;
}
template <typename File>
ResultStatus NCCHContainer<File>::ReadCodesetName(std::string& name) {
ResultStatus NCCHContainer::ReadCodesetName(std::string& name) {
ResultStatus result = Load();
if (result != ResultStatus::Success)
return result;
@@ -389,8 +380,7 @@ ResultStatus NCCHContainer<File>::ReadCodesetName(std::string& name) {
return ResultStatus::Success;
}
template <typename File>
ResultStatus NCCHContainer<File>::ReadEncryptionType(EncryptionType& encryption) {
ResultStatus NCCHContainer::ReadEncryptionType(EncryptionType& encryption) {
ResultStatus result = Load();
if (result != ResultStatus::Success)
return result;
@@ -425,8 +415,7 @@ ResultStatus NCCHContainer<File>::ReadEncryptionType(EncryptionType& encryption)
return ResultStatus::Success;
}
template <typename File>
ResultStatus NCCHContainer<File>::ReadSeedCrypto(bool& used) {
ResultStatus NCCHContainer::ReadSeedCrypto(bool& used) {
ResultStatus result = Load();
if (result != ResultStatus::Success)
return result;
@@ -438,9 +427,8 @@ ResultStatus NCCHContainer<File>::ReadSeedCrypto(bool& used) {
return ResultStatus::Success;
}
template <typename File>
ResultStatus NCCHContainer<File>::DecryptToFile(const std::string& destination,
const ProgressCallback& callback) {
ResultStatus NCCHContainer::DecryptToFile(const std::string& destination,
const ProgressCallback& callback) {
ResultStatus result = Load();
if (result != ResultStatus::Success)
return result;
@@ -565,15 +553,11 @@ ResultStatus NCCHContainer<File>::DecryptToFile(const std::string& destination,
return ResultStatus::Success;
}
template <typename File>
void NCCHContainer<File>::AbortDecryptToFile() {
void NCCHContainer::AbortDecryptToFile() {
aborted = true;
decryptor.Abort();
}
template class NCCHContainer<SDMCFile>;
template class NCCHContainer<FileUtil::IOFile>;
#pragma pack(push, 1)
struct RomFSIVFCHeader {
u32_le magic;
+5 -6
View File
@@ -204,13 +204,12 @@ enum class EncryptionType;
* Note that this is heavily stripped down and can only read (primary-key
* encrypted non-code sections of) ExeFS and ExHeader by design.
*/
template <typename File = SDMCFile>
class NCCHContainer {
public:
NCCHContainer(std::shared_ptr<File> file);
NCCHContainer(std::shared_ptr<FileUtil::IOFile> file);
NCCHContainer() {}
ResultStatus OpenFile(std::shared_ptr<File> file);
ResultStatus OpenFile(std::shared_ptr<FileUtil::IOFile> file);
/**
* Ensure ExeFS and exheader is loaded and ready for reading sections
@@ -304,11 +303,11 @@ private:
std::string root_folder;
std::string filepath;
std::shared_ptr<File> file;
std::shared_ptr<File> exefs_file;
std::shared_ptr<FileUtil::IOFile> file;
std::shared_ptr<FileUtil::IOFile> exefs_file;
// Used for DecryptToFile
QuickDecryptor<File, FileUtil::IOFile> decryptor;
QuickDecryptor decryptor;
std::atomic_bool aborted{false};
};
+13 -23
View File
@@ -16,18 +16,16 @@
namespace Core {
template <typename In, typename Out>
QuickDecryptor<In, Out>::QuickDecryptor() = default;
QuickDecryptor::QuickDecryptor() = default;
template <typename In, typename Out>
QuickDecryptor<In, Out>::~QuickDecryptor() = default;
QuickDecryptor::~QuickDecryptor() = default;
template <typename In, typename Out>
bool QuickDecryptor<In, Out>::DecryptAndWriteFile(std::shared_ptr<In> source_, std::size_t size,
std::shared_ptr<Out> destination_,
const ProgressCallback& callback_, bool decrypt_,
Core::Key::AESKey key_, Core::Key::AESKey ctr_,
std::size_t aes_seek_pos_) {
bool QuickDecryptor::DecryptAndWriteFile(std::shared_ptr<FileUtil::IOFile> source_,
std::size_t size,
std::shared_ptr<FileUtil::IOFile> destination_,
const ProgressCallback& callback_, bool decrypt_,
Core::Key::AESKey key_, Core::Key::AESKey ctr_,
std::size_t aes_seek_pos_) {
if (is_running) {
LOG_ERROR(Core, "Decryptor is running");
return false;
@@ -84,8 +82,7 @@ bool QuickDecryptor<In, Out>::DecryptAndWriteFile(std::shared_ptr<In> source_, s
return ret;
}
template <typename In, typename Out>
void QuickDecryptor<In, Out>::DataReadLoop() {
void QuickDecryptor::DataReadLoop() {
std::size_t current_buffer = 0;
bool is_first_run = true;
@@ -119,8 +116,7 @@ void QuickDecryptor<In, Out>::DataReadLoop() {
}
}
template <typename In, typename Out>
void QuickDecryptor<In, Out>::DataDecryptLoop() {
void QuickDecryptor::DataDecryptLoop() {
CryptoPP::CTR_Mode<CryptoPP::AES>::Decryption aes;
aes.SetKeyWithIV(key.data(), key.size(), ctr.data());
aes.Seek(aes_seek_pos);
@@ -142,8 +138,7 @@ void QuickDecryptor<In, Out>::DataDecryptLoop() {
}
}
template <typename In, typename Out>
void QuickDecryptor<In, Out>::DataWriteLoop() {
void QuickDecryptor::DataWriteLoop() {
std::size_t current_buffer = 0;
if (!*destination) {
@@ -187,21 +182,16 @@ void QuickDecryptor<In, Out>::DataWriteLoop() {
completion_event.Set();
}
template <typename In, typename Out>
void QuickDecryptor<In, Out>::Abort() {
void QuickDecryptor::Abort() {
if (is_running.exchange(false)) {
is_good = false;
completion_event.Set();
}
}
template <typename In, typename Out>
void QuickDecryptor<In, Out>::Reset(std::size_t total_size_) {
void QuickDecryptor::Reset(std::size_t total_size_) {
total_size = total_size_;
imported_size = 0;
}
template class QuickDecryptor<FileUtil::IOFile, FileUtil::IOFile>;
template class QuickDecryptor<SDMCFile, FileUtil::IOFile>;
} // namespace Core
+4 -7
View File
@@ -20,10 +20,7 @@ using ProgressCallback = std::function<void(std::size_t, std::size_t)>;
/**
* Helper that reads, decrypts and writes data. This uses three threads to process the data
* and call progress callbacks occasionally.
*
* While this is a template, it really should only be used with IOFile and SDMCFile.
*/
template <typename In = FileUtil::IOFile, typename Out = FileUtil::IOFile>
class QuickDecryptor {
public:
/**
@@ -45,8 +42,8 @@ public:
* @param ctr AES CTR for decryption
* @param aes_seek_pos The position to seek to for decryption.
*/
bool DecryptAndWriteFile(std::shared_ptr<In> source, std::size_t size,
std::shared_ptr<Out> destination,
bool DecryptAndWriteFile(std::shared_ptr<FileUtil::IOFile> source, std::size_t size,
std::shared_ptr<FileUtil::IOFile> destination,
const ProgressCallback& callback = [](std::size_t, std::size_t) {},
bool decrypt = false, Core::Key::AESKey key = {},
Core::Key::AESKey ctr = {}, std::size_t aes_seek_pos = 0);
@@ -63,8 +60,8 @@ public:
private:
static constexpr std::size_t BufferSize = 16 * 1024; // 16 KB
std::shared_ptr<In> source;
std::shared_ptr<Out> destination;
std::shared_ptr<FileUtil::IOFile> source;
std::shared_ptr<FileUtil::IOFile> destination;
bool decrypt{};
Core::Key::AESKey key;
Core::Key::AESKey ctr;
+2 -2
View File
@@ -42,8 +42,8 @@ private:
Core::ContentSpecifier SpecifierFromItem(QTreeWidgetItem* item) const;
void OnContextMenu(const QPoint& point);
void StartDumpingCXI(const Core::ContentSpecifier& content);
Core::NCCHContainer<Core::SDMCFile> dump_cxi_container; // NCCH container used for dumping CXI
QString last_dump_cxi_path; // Used for recording last path in StartDumpingCXI
Core::NCCHContainer dump_cxi_container; // NCCH container used for dumping CXI
QString last_dump_cxi_path; // Used for recording last path in StartDumpingCXI
std::unique_ptr<Ui::ImportDialog> ui;