Skip to content

Commit

Permalink
NPY to train data bugs fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
roma160 committed Apr 25, 2021
1 parent 2e81866 commit d9c47cf
Showing 1 changed file with 59 additions and 28 deletions.
87 changes: 59 additions & 28 deletions utils/npy_to_traindata/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const char* magic_string = "\x93NUMPY";
const char* descr = "\'descr\' : \'|";
const char* shape = "\'shape\': (";
const char* separator = "\n------------------------------------------------\n";
const char* default_result_location = "./result.b";
const char header_length_size = 2;

struct npy_contents
Expand All @@ -25,9 +26,9 @@ struct npy_contents
unsigned short int header_len = 0;
npy_file->seekg(strlen(magic_string) + 2, ios_base::beg);
npy_file->read((char*) &header_len, header_length_size);
string header(new char[header_len]);
npy_file->read((char*)header.c_str(), header_len);

string header(header_len, '\x00');
npy_file->read((char*) header.c_str(), header_len);

unsigned short data_type_start = header.find(descr) + strlen(descr);
ret.data_type = header.substr(
Expand All @@ -47,25 +48,26 @@ struct npy_contents
size_t get_int(const char* int_name)
{
size_t ret;
char check = ' ';
string check = "";
do
{
cout << "Please, enter the " << int_name << ": ";
cin >> ret;
getline(cin, check);
cout << int_name << " = " << ret << ". Are you sure? (Y/N) : ";
cin >> check;
} while (check != 'Y');
getline(cin, check);
} while (check != "Y");
return ret;
}
bool get_bool(const char* message)
{
char entered = ' ';
while (entered != 'Y' && entered != 'N')
string entered = "";
while (entered != "Y" && entered != "N")
{
cout << message << " (Y/N) : ";
cin>>entered;
getline(cin, entered);
}
return entered == 'Y';
return entered == "Y";
}

char get_type_size(const string &type)
Expand Down Expand Up @@ -143,11 +145,11 @@ void program(ifstream* input, ifstream* output, ofstream* result)
//TODO: implement more flexible conversion system
// (now its working only for MNIST.npy train data)

cout << separator << "\nThe result file header : " << endl <<
"Input layer size : " << input_layer_size << endl <<
"Output layer size : " << output_layer_size << endl <<
"Number of data samples : " << data_length << endl <<
"Writing started..." << endl;
cout << separator << "\nThe result file header: " << endl <<
" Input layer size : " << input_layer_size << endl <<
" Output layer size : " << output_layer_size << endl <<
" Number of data samples : " << data_length << endl <<
" Writing started..." << endl;

result->write((char*) &input_layer_size, sizeof(size_t));
result->write((char*) &output_layer_size, sizeof(size_t));
Expand Down Expand Up @@ -180,6 +182,13 @@ void program(ifstream* input, ifstream* output, ofstream* result)
result->close();
input->close();
output->close();

cout << "Program finished successfully!";
}

bool is_rewrite(const char* location)
{
return ifstream(location).good();
}

void interactive_mode()
Expand Down Expand Up @@ -207,6 +216,11 @@ void interactive_mode()
cout << "Enter result file location : ";
getline(cin, buff_string);
if(buff_string.empty()) continue;
if (is_rewrite(buff_string.c_str()) &&
!get_bool("Such file already exists. Do you want to rewrite it?")) {
buff_string = "";
continue;
}
result_file.open(buff_string, ios::out | ios::binary);
}
cout << "Starting program...\n";
Expand All @@ -216,26 +230,28 @@ void interactive_mode()
void invalid_arguments(const char* message)
{
cout << message << "\nDo you want to start interactive mode? ";
char entered = 'l';
while (entered != 'Y' && entered != 'N')
string entered = "";
while (entered != "Y" && entered != "N")
{
cout << "(Y/N) : ";
entered = getc(stdin);
getline(cin, entered);
}
if (entered == 'Y') interactive_mode();
if (entered == "Y") interactive_mode();
else cout << "Exiting...";
}

int main(int argc, char* argv[])
{
#ifdef _DEBUG
argc = 7;
argv = new char* []
{
"", "-i", "C:\\Users\\1\\Downloads\\mnist\\x_train.npy",
"-o", "C:\\Users\\1\\Downloads\\mnist\\y_train.npy",
"-r", "./res.b"
};
if (get_bool("Do you want to use default debug arguments?")) {
argc = 7;
argv = new char* []
{
"", "-i", "C:\\Users\\1\\Downloads\\mnist\\x_test.npy",
"-o", "C:\\Users\\1\\Downloads\\mnist\\y_test.npy",
"-r", "./res_test.b"
};
}
#endif

if (argc > 1)
Expand Down Expand Up @@ -273,7 +289,12 @@ int main(int argc, char* argv[])
return 0;
}
cout << "Result will be saved to ./result.b...\n";
ofstream result_file("./result.b", ios::out | ios::binary);
if (is_rewrite(default_result_location))
{
invalid_arguments("Such file already exists!");
return 0;
}
ofstream result_file(default_result_location, ios::out | ios::binary);
cout << "Starting program...\n";
program(&input_file, &output_file, &result_file);
}
Expand Down Expand Up @@ -308,6 +329,11 @@ int main(int argc, char* argv[])
ofstream result_file;
if (result != nullptr)
{
if (is_rewrite(result))
{
invalid_arguments("Such result file already exists!");
return 0;
}
new (&result_file) ofstream(result, ios::out | ios::binary);
if (!result_file.good())
{
Expand All @@ -318,7 +344,12 @@ int main(int argc, char* argv[])
else
{
cout << "Result will be saved to ./result.b...\n";
new (&result_file) ofstream("./result.b", ios::out | ios::binary);
if (is_rewrite(default_result_location))
{
invalid_arguments("Such file already exists!");
return 0;
}
new (&result_file) ofstream(default_result_location, ios::out | ios::binary);
}
cout << "Starting program...\n";
program(&input_file, &output_file, &result_file);
Expand Down

0 comments on commit d9c47cf

Please sign in to comment.