diff --git a/Source/Common/DataReader.cpp b/Source/Common/DataReader.cpp index bb7faef71ea8..920db7118fb7 100644 --- a/Source/Common/DataReader.cpp +++ b/Source/Common/DataReader.cpp @@ -274,11 +274,11 @@ bool DataReader::GetMinibatch(StreamMinibatchInputs& matrices) // uids - lables stored in size_t vector instead of ElemType matrix // boundary - phone boundaries // returns - true if there are more minibatches, false if no more minibatches remain -bool DataReader::GetMinibatch4SE(std::vector>& latticeinput, vector& uids, vector& boundaries, vector& extrauttmap) +bool DataReader::GetMinibatch4SE(std::vector>& latticeinput, vector& uids, vector& wids, vector& nws, vector& boundaries, vector& extrauttmap) { bool bRet = true; for (size_t i = 0; i < m_ioNames.size(); i++) - bRet &= m_dataReaders[m_ioNames[i]]->GetMinibatch4SE(latticeinput, uids, boundaries, extrauttmap); + bRet &= m_dataReaders[m_ioNames[i]]->GetMinibatch4SE(latticeinput, uids, wids, nws, boundaries, extrauttmap); return bRet; } diff --git a/Source/Common/Include/DataReader.h b/Source/Common/Include/DataReader.h index d74a8de02846..0e05fe98f14c 100644 --- a/Source/Common/Include/DataReader.h +++ b/Source/Common/Include/DataReader.h @@ -264,7 +264,7 @@ class DATAREADER_API IDataReader } virtual bool GetMinibatch(StreamMinibatchInputs& matrices) = 0; - virtual bool GetMinibatch4SE(std::vector>& /*latticeinput*/, vector& /*uids*/, vector& /*boundaries*/, vector& /*extrauttmap*/) + virtual bool GetMinibatch4SE(std::vector>& /*latticeinput*/, vector& /*uids*/, vector& /*wids*/, vector& /*nws*/, vector& /*boundaries*/, vector& /*extrauttmap*/) { NOT_IMPLEMENTED; }; @@ -444,7 +444,7 @@ class DataReader : public IDataReader, protected Plugin, public ScriptableObject // [out] each matrix resized if necessary containing data. // returns - true if there are more minibatches, false if no more minibatches remain virtual bool GetMinibatch(StreamMinibatchInputs& matrices); - virtual bool GetMinibatch4SE(std::vector>& latticeinput, vector& uids, vector& boundaries, vector& extrauttmap); + virtual bool GetMinibatch4SE(std::vector>& latticeinput, vector& uids, vector& wids, vector& nws, vector& boundaries, vector& extrauttmap); virtual bool GetHmmData(msra::asr::simplesenonehmm* hmm); size_t GetNumParallelSequencesForFixingBPTTMode(); diff --git a/Source/Common/Include/latticearchive.h b/Source/Common/Include/latticearchive.h index ed60f567dba5..e9ccf41246d8 100644 --- a/Source/Common/Include/latticearchive.h +++ b/Source/Common/Include/latticearchive.h @@ -23,7 +23,7 @@ #include // for find() #include "simplesenonehmm.h" #include "Matrix.h" - +#include namespace msra { namespace math { class ssematrixbase; @@ -67,7 +67,28 @@ enum mbrclassdefinition // used to identify definition of class in minimum bayes // =========================================================================== class lattice { -public: +public: + // definie structure for nbest EMBR + struct TokenInfo + { + double score; // the score of the token + size_t prev_edge_index; // edge ending with this token, edge start points to the previous node + size_t prev_token_index; // the token index in the previous node + }; + struct PrevTokenInfo + { + size_t prev_edge_index; + size_t prev_token_index; + double path_score; // use pure to indicatethe path score does not consider the WER of the path + }; + + struct NBestToken + { + // for sorting purpose + // make sure the map is stored with keys in descending order + std::map, std::greater > mp_score_token_infos; // for sorting the tokens in map + std::vector vt_nbest_tokens; // stores the nbest tokens in the node + }; struct header_v1_v2 { size_t numnodes : 32; @@ -90,12 +111,15 @@ class lattice static const unsigned int NOEDGE = 0xffffff; // 24 bits // static_assert (sizeof (nodeinfo) == 8, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary // ensure type size as these are expected to be of this size in the files we read - static_assert(sizeof(nodeinfo) == 2, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary + static_assert(sizeof(nodeinfo) == 16, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary static_assert(sizeof(edgeinfowithscores) == 16, "unexpected size of edgeinfowithscores"); static_assert(sizeof(aligninfo) == 4, "unexpected size of aligninfo"); std::vector nodes; + mutable std::vector> vt_node_out_edge_indices; // vt_node_out_edge_indices[i]: it stores the outgoing edge indices starting from node i + std::vector is_special_words; // true if it is special words that do not count to WER computation, false if it is not std::vector edges; std::vector align; + // V2 lattices --for a while, we will store both in RAM, until all code is updated static int fsgn(float f) { @@ -217,6 +241,10 @@ class lattice public: // TODO: make private again once // construct from edges/align // This is also used for merging, where the edges[] array is not correctly sorted. So don't assume this here. + void erase_node_out_edges(size_t nodeidx, size_t edgeidx_start, size_t edgeidx_end) const + { + vt_node_out_edge_indices[nodeidx].erase(vt_node_out_edge_indices[nodeidx].begin() + edgeidx_start, vt_node_out_edge_indices[nodeidx].begin() + edgeidx_end); + } void builduniquealignments(size_t spunit = SIZE_MAX /*fix this later*/) { // infer /sp/ unit if not given @@ -701,6 +729,7 @@ class lattice const float lmf, const float wp, const float amf, const_array_ref& uids, const edgealignments& thisedgealignments, std::vector& Eframescorrect) const; + void sMBRerrorsignal(parallelstate& parallelstate, msra::math::ssematrixbase& errorsignal, msra::math::ssematrixbase& errorsignalneg, const std::vector& logpps, const float amf, double minlogpp, @@ -736,7 +765,8 @@ class lattice const std::vector& logpps, const float amf, const std::vector& logEframescorrect, const double logEframescorrecttotal, msra::math::ssematrixbase& errorsignal, msra::math::ssematrixbase& errorsignalneg) const; - + void parallelEMBRerrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments, + const std::vector& edgeweights, msra::math::ssematrixbase& errorsignal) const; void parallelmmierrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments, const std::vector& logpps, msra::math::ssematrixbase& errorsignal) const; @@ -747,6 +777,18 @@ class lattice const_array_ref& uids, std::vector& logEframescorrect, std::vector& Eframescorrectbuf, double& logEframescorrecttotal) const; + double parallelbackwardlatticeEMBR(parallelstate& parallelstate, const std::vector& edgeacscores, + const float lmf, const float wp, + const float amf, std::vector& edgelogbetas, + std::vector& logbetas) const; + + void EMBRsamplepaths(const std::vector &edgelogbetas, + const std::vector &logbetas, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const bool excludeSpecialWords, std::vector< std::vector > & vt_paths) const; + + void EMBRnbestpaths(std::vector& tokenlattice, std::vector> & vt_paths, std::vector& path_posterior_probs) const; + + double get_edge_weights(std::vector& wids, std::vector>& vt_paths, std::vector& vt_edge_weights, std::vector& vt_path_posterior_probs, std::string getPathMethodEMBR, double& onebestwer) const; + static double scoregroundtruth(const_array_ref uids, const_array_ref transcript, const std::vector& transcriptunigrams, const msra::math::ssematrixbase& logLLs, const msra::asr::simplesenonehmm& hset, const float lmf, const float wp, const float amf); @@ -762,6 +804,14 @@ class lattice std::vector& logEframescorrect, std::vector& Eframescorrectbuf, double& logEframescorrecttotal) const; + double backwardlatticeEMBR(const std::vector& edgeacscores, parallelstate& parallelstate, std::vector &edgelogbetas, + std::vector& logbetas, + const float lmf, const float wp, const float amf) const; + + void constructnodenbestoken(std::vector &tokenlattice, const bool wordNbest, size_t numtokens2keep, size_t nidx) const; + + double nbestlatticeEMBR(const std::vector &edgeacscores, parallelstate ¶llelstate, std::vector &vt_nbesttokens, const size_t numtokens, const bool enforceValidPathEMBR, const bool excludeSpecialWords, + const float lmf, const float wp, const float amf, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numPathsEMBR, std::vector wids) const; public: // construct from a HTK lattice file void fromhtklattice(const std::wstring& path, const std::unordered_map& unitmap); @@ -1003,7 +1053,7 @@ class lattice // This will also map the aligninfo entries to the new symbol table, through idmap. // V1 lattices will be converted. 'spsenoneid' is used in that process. template - void fread(FILE* f, const IDMAP& idmap, size_t spunit) + void fread(FILE* f, const IDMAP& idmap, size_t spunit, std::set& specialwordids) { size_t version = freadtag(f, "LAT "); if (version == 1) @@ -1011,7 +1061,10 @@ class lattice freadOrDie(&info, sizeof(info), 1, f); freadvector(f, "NODE", nodes, info.numnodes); if (nodes.back().t != info.numframes) - RuntimeError("fread: mismatch between info.numframes and last node's time"); + { + // sometimes, the data is corrputed, let's try to live with it + fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes)); + } freadvector(f, "EDGE", edges, info.numedges); freadvector(f, "ALIG", align); fcheckTag(f, "END "); @@ -1024,11 +1077,14 @@ class lattice freadOrDie(&info, sizeof(info), 1, f); freadvector(f, "NODS", nodes, info.numnodes); if (nodes.back().t != info.numframes) - RuntimeError("fread: mismatch between info.numframes and last node's time"); + { + // sometimes, the data is corrputed, let's try to live with it + fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes)); + } freadvector(f, "EDGS", edges2, info.numedges); // uniqued edges freadvector(f, "ALNS", uniquededgedatatokens); // uniqued alignments fcheckTag(f, "END "); - ProcessV2Lattice(spunit, info, uniquededgedatatokens, idmap); + ProcessV2EMBRLattice(spunit, info, uniquededgedatatokens, idmap, specialwordids); } else RuntimeError("fread: unsupported lattice format version"); @@ -1124,7 +1180,28 @@ class lattice rebuildedges(info.impliedspunitid != spunit /*to be able to read somewhat broken V2 lattice archives*/); } - + + template + void ProcessV2EMBRLattice(size_t spunit, header_v1_v2& info, std::vector& uniquededgedatatokens, const IDMAP& idmap, std::set& specialwordids) + { + vt_node_out_edge_indices.resize(info.numnodes); + for (size_t j = 0; j < info.numedges; j++) + { + // an edge with !NULL pointing to not + // this code make sure if you always start from in the sampled path. + // mask here: we delay the processing in EMBRsamplepaths controlled by flag: enforceValidPathEMBR + // if (edges2[j].S == 0 && nodes[edges2[j].E].wid != 1) continue; + vt_node_out_edge_indices[edges2[j].S].push_back(j); + } + is_special_words.resize(info.numnodes); + for (size_t i = 0; i < info.numnodes; i++) + { + if (specialwordids.find(int(nodes[i].wid)) != specialwordids.end()) is_special_words[i] = true; + else is_special_words[i] = false; + } + ProcessV2Lattice(spunit, info, uniquededgedatatokens, idmap); + } + // parallel versions (defined in parallelforwardbackward.cpp) class parallelstate { @@ -1152,6 +1229,10 @@ class lattice const size_t getsilunitid(); void getedgeacscores(std::vector& edgeacscores); void getedgealignments(std::vector& edgealignments); + void getlogbetas(std::vector& logbetas); + void getedgelogbetas(std::vector& edgelogbetas); + void getedgeweights(std::vector& edgeweights); + void setedgeweights(const std::vector& edgeweights); // to work with CNTK's GPU memory void setdevice(size_t DeviceId); size_t getdevice(); @@ -1168,9 +1249,13 @@ class lattice // Note: logLLs and posteriors may be the same matrix (aliased). double forwardbackward(parallelstate& parallelstate, const class msra::math::ssematrixbase& logLLs, const class msra::asr::simplesenonehmm& hmms, class msra::math::ssematrixbase& result, class msra::math::ssematrixbase& errorsignalbuf, - const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode, array_ref uids, const_array_ref bounds = const_array_ref(), + const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode, const bool EMBR, const std::string EMBRUnit, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const std::string getPathMethodEMBR, const std::string showWERMode, + const bool excludeSpecialWords, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numRawPathsEMBR, + array_ref uids, std::vector wids, const_array_ref bounds = const_array_ref(), const_array_ref transcript = const_array_ref(), const std::vector& transcriptunigrams = std::vector()) const; - + + void EMBRerrorsignal(parallelstate ¶llelstate, + const edgealignments &thisedgealignments, std::vector& edge_weights, msra::math::ssematrixbase &errorsignal) const; std::wstring key; // (keep our own name (key) so we can identify ourselves for diagnostics messages) const wchar_t* getkey() const { @@ -1358,8 +1443,10 @@ class archive if (sscanf(q, "[%" PRIu64 "]%c", &offset, &c) != 1) #endif RuntimeError("open: invalid TOC line (bad [] expression): %s", line); + if (!toc.insert(make_pair(key, latticeref(offset, archiveindex))).second) - RuntimeError("open: TOC entry leads to duplicate key: %s", line); + // sometimes, the training will report this error. I believe it is due to some small data corruption, and fine to go on, so change the error to warning + fprintf(stderr, " open: TOC entry leads to duplicate key: %s\n", line); } // initialize symmaps --alloc the array, but actually read the symmap on demand @@ -1390,7 +1477,7 @@ class archive // Lattices will have unit ids updated according to the modelsymmap. // V1 lattices will be converted. 'spsenoneid' is used in the conversion for optimizing storing 0-frame /sp/ aligns. void getlattice(const std::wstring& key, lattice& L, - size_t expectedframes = SIZE_MAX /*if unknown*/) const + std::set& specialwordids, size_t expectedframes = SIZE_MAX) const { auto iter = toc.find(key); if (iter == toc.end()) @@ -1417,7 +1504,7 @@ class archive // seek to start fsetpos(f, offset); // get it - L.fread(f, idmap, spunit); + L.fread(f, idmap, spunit, specialwordids); L.setverbosity(verbosity); #ifdef HACK_IN_SILENCE // hack to simulate DEL in the lattice const size_t silunit = getid(modelsymmap, "sil"); @@ -1451,7 +1538,8 @@ class archive // - dump to stdout // - merge two lattices (for merging numer into denom lattices) static void convert(const std::wstring& intocpath, const std::wstring& intocpath2, const std::wstring& outpath, - const msra::asr::simplesenonehmm& hset); + const msra::asr::simplesenonehmm& hset, std::set& specialwordids); }; }; }; + diff --git a/Source/Common/Include/latticesource.h b/Source/Common/Include/latticesource.h index 4794af1b4a2d..97f447097c5a 100644 --- a/Source/Common/Include/latticesource.h +++ b/Source/Common/Include/latticesource.h @@ -62,10 +62,10 @@ class latticesource #endif } - void getlattices(const std::wstring& key, std::shared_ptr& L, size_t expectedframes) const + void getlattices(const std::wstring& key, std::shared_ptr& L, size_t expectedframes, std::set& specialwordids) const { std::shared_ptr LP(new latticepair); - denlattices.getlattice(key, LP->second, expectedframes); // this loads the lattice from disk, using the existing L.second object + denlattices.getlattice(key, LP->second, specialwordids, expectedframes); // this loads the lattice from disk, using the existing L.second object L = LP; } diff --git a/Source/Common/Include/latticestorage.h b/Source/Common/Include/latticestorage.h index 76b1b87e733a..e0d963b7f188 100644 --- a/Source/Common/Include/latticestorage.h +++ b/Source/Common/Include/latticestorage.h @@ -12,6 +12,7 @@ #include #include #include +#include #undef INITIAL_STRANGE // [v-hansu] initialize structs to strange values #define PARALLEL_SIL // [v-hansu] process sil on CUDA, used in other files, please search this @@ -30,11 +31,15 @@ struct nodeinfo // uint64_t firstinedge : 24; // index of first incoming edge // uint64_t firstoutedge : 24; // index of first outgoing edge // uint64_t t : 16; // time associated with this + + uint64_t wid; // word ID associated with the node unsigned short t; // time associated with this - nodeinfo(size_t pt) - : t((unsigned short) pt) // , firstinedge (NOEDGE), firstoutedge (NOEDGE) + + nodeinfo(size_t pt, size_t pwid) + : t((unsigned short)pt), wid(pwid) { checkoverflow(t, pt, "nodeinfo::t"); + checkoverflow(wid, pwid, "nodeinfo::wid"); // checkoverflow (firstinedge, NOEDGE, "nodeinfo::firstinedge"); // checkoverflow (firstoutedge, NOEDGE, "nodeinfo::firstoutedge"); } diff --git a/Source/ComputationNetworkLib/ComputationNetwork.cpp b/Source/ComputationNetworkLib/ComputationNetwork.cpp index 6801f5b45d50..50ef9a792a27 100644 --- a/Source/ComputationNetworkLib/ComputationNetwork.cpp +++ b/Source/ComputationNetworkLib/ComputationNetwork.cpp @@ -647,12 +647,37 @@ void ComputationNetwork::SetSeqParam(ComputationNetworkPtr net, const double& lmf /*= 14.0f*/, const double& wp /*= 0.0f*/, const double& bMMIfactor /*= 0.0f*/, - const bool& sMBR /*= false*/ + const bool& sMBR /*= false */, + const bool& EMBR /*= false */, + const string& EMBRUnit /* = "word" */, + const size_t& numPathsEMBR, + const bool& enforceValidPathEMBR, + const string& getPathMethodEMBR, + const string& showWERMode, + const bool& excludeSpecialWords, + const bool& wordNbest, + const bool& useAccInNbest, + const float& accWeightInNbest, + const size_t& numRawPathsEMBR ) { fprintf(stderr, "Setting Hsmoothing weight to %.8g and frame-dropping threshhold to %.8g\n", hsmoothingWeight, frameDropThresh); - fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, usesMBR=%s\n", - amf, lmf, wp, bMMIfactor, sMBR ? "true" : "false"); + + if(EMBR) + { + fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, useEMBR=true, EMBRUnit=%s, numPathsEMBR=%d, enforceValidPathEMBR = %d, getPathMethodEMBR = %s, showWERMode = %s, excludeSpecialWords = %d, wordNbest = %d, useAccInNbest = %d, accWeightInNbest = %f, numRawPathsEMBR = %d \n", + amf, lmf, wp, bMMIfactor, EMBRUnit.c_str(), int(numPathsEMBR), int(enforceValidPathEMBR), getPathMethodEMBR.c_str(), showWERMode.c_str(), int(excludeSpecialWords), int(wordNbest), int(useAccInNbest), float(accWeightInNbest), int(numRawPathsEMBR)); + } + else if(sMBR) + { + fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, usesMBR=true \n", + amf, lmf, wp, bMMIfactor); + } + else + { + fprintf(stderr, "Setting SeqGammar-related parameters: amf=%.2f, lmf=%.2f, wp=%.2f, bMMIFactor=%.2f, useMMI=true \n", + amf, lmf, wp, bMMIfactor); + } list seqNodes = net->GetNodesWithType(OperationNameOf(SequenceWithSoftmaxNode), criterionNode); if (seqNodes.size() == 0) { @@ -666,7 +691,8 @@ void ComputationNetwork::SetSeqParam(ComputationNetworkPtr net, node->SetSmoothWeight(hsmoothingWeight); node->SetFrameDropThresh(frameDropThresh); node->SetReferenceAlign(doreferencealign); - node->SetGammarCalculationParam(amf, lmf, wp, bMMIfactor, sMBR); + node->SetMBR(sMBR || EMBR); + node->SetGammarCalculationParamEMBR(amf, lmf, wp, bMMIfactor, sMBR, EMBR, EMBRUnit, numPathsEMBR, enforceValidPathEMBR, getPathMethodEMBR, showWERMode, excludeSpecialWords, wordNbest, useAccInNbest, accWeightInNbest, numRawPathsEMBR); } } } @@ -1550,17 +1576,21 @@ template void ComputationNetwork::Read(const wstring& fileName); template void ComputationNetwork::ReadPersistableParameters(size_t modelVersion, File& fstream, bool create); template void ComputationNetwork::PerformSVDecomposition(const map& SVDConfig, size_t alignedsize); template /*static*/ void ComputationNetwork::SetBatchNormalizationTimeConstants(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const double normalizationTimeConstant, double& prevNormalizationTimeConstant, double blendTimeConstant, double& prevBlendTimeConstant); -template void ComputationNetwork::SetSeqParam(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign, - const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR); + +template void ComputationNetwork::SetSeqParam(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign, const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR, const bool& EMBR, const string& EMBRUnit, const size_t& numPathsEMBR, const bool& enforceValidPathEMBR, const string& getPathMethodEMBR, const string& showWERMode, const bool& excludeSpecialWords, const bool& wordNbest, const bool& useAccInNbest, const float& accWeightInNbest, const size_t& numRawPathsEMBR); template void ComputationNetwork::SaveToDbnFile(ComputationNetworkPtr net, const std::wstring& fileName) const; template void ComputationNetwork::InitLearnableParametersWithBilinearFill(const ComputationNodeBasePtr& node, size_t kernelWidth, size_t kernelHeight); template void ComputationNetwork::Read(const wstring& fileName); template void ComputationNetwork::ReadPersistableParameters(size_t modelVersion, File& fstream, bool create); template void ComputationNetwork::PerformSVDecomposition(const map& SVDConfig, size_t alignedsize); + template /*static*/ void ComputationNetwork::SetBatchNormalizationTimeConstants(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const double normalizationTimeConstant, double& prevNormalizationTimeConstant, double blendTimeConstant, double& prevBlendTimeConstant); + template void ComputationNetwork::SetSeqParam(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign, - const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR); + const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR, const bool& EMBR, const string& EMBRUnit, const size_t& numPathsEMBR, + const bool& enforceValidPathEMBR, const string& getPathMethodEMBR, const string& showWERMode, const bool& excludeSpecialWords, const bool& wordNbest, const bool& useAccInNbest, const float& accWeightInNbest, const size_t& numRawPathsEMBR); + template void ComputationNetwork::SaveToDbnFile(ComputationNetworkPtr net, const std::wstring& fileName) const; template void ComputationNetwork::InitLearnableParametersWithBilinearFill(const ComputationNodeBasePtr& node, size_t kernelWidth, size_t kernelHeight); @@ -1568,8 +1598,6 @@ template void ComputationNetwork::Read(const wstring& fileName); template void ComputationNetwork::ReadPersistableParameters(size_t modelVersion, File& fstream, bool create); template void ComputationNetwork::PerformSVDecomposition(const map& SVDConfig, size_t alignedsize); template /*static*/ void ComputationNetwork::SetBatchNormalizationTimeConstants(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const double normalizationTimeConstant, double& prevNormalizationTimeConstant, double blendTimeConstant, double& prevBlendTimeConstant); -template void ComputationNetwork::SetSeqParam(ComputationNetworkPtr net, const ComputationNodeBasePtr criterionNode, const double& hsmoothingWeight, const double& frameDropThresh, const bool& doreferencealign, - const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR); template void ComputationNetwork::SaveToDbnFile(ComputationNetworkPtr net, const std::wstring& fileName) const; // register ComputationNetwork with the ScriptableObject system diff --git a/Source/ComputationNetworkLib/ComputationNetwork.h b/Source/ComputationNetworkLib/ComputationNetwork.h index 85d6922daf6a..0cfce9281989 100644 --- a/Source/ComputationNetworkLib/ComputationNetwork.h +++ b/Source/ComputationNetworkLib/ComputationNetwork.h @@ -554,7 +554,19 @@ class ComputationNetwork : const double& lmf = 14.0f, const double& wp = 0.0f, const double& bMMIfactor = 0.0f, - const bool& sMBR = false); + const bool& sMBR = false, + const bool& EMBR = false, + const string& EMBRUnit = "word", + const size_t& numPathsEMBR = 100, + const bool& enforceValidPathEMBR = false, + const string& getPathMethodEMBR = "sampling", + const string& showWERMode = "average", + const bool& excludeSpecialWords = false, + const bool& wordNbest = false, + const bool& useAccInNbest = false, + const float& accWeightInNbest = 1.0f, + const size_t& numRawPathsEMBR = 100 + ); static void SetMaxTempMemSizeForCNN(ComputationNetworkPtr net, const ComputationNodeBasePtr& criterionNode, const size_t maxTempMemSizeInSamples); // ----------------------------------------------------------------------- diff --git a/Source/ComputationNetworkLib/SpecialPurposeNodes.h b/Source/ComputationNetworkLib/SpecialPurposeNodes.h index f658339584fa..f0e18c931727 100755 --- a/Source/ComputationNetworkLib/SpecialPurposeNodes.h +++ b/Source/ComputationNetworkLib/SpecialPurposeNodes.h @@ -477,12 +477,12 @@ class SequenceWithSoftmaxNode : public ComputationNodeNonLooping, publ { Input(inputIndex)->Gradient().SetValue(0.0f); Value().SetValue(1.0f); - } + } else { FrameRange fr(Input(0)->GetMBLayout()); BackpropToRight(*m_softmaxOfRight, Input(0)->Value(), Input(inputIndex)->Gradient(), - Gradient(), *m_gammaFromLattice, m_fsSmoothingWeight, m_frameDropThreshold); + Gradient(), *m_gammaFromLattice, m_fsSmoothingWeight, m_frameDropThreshold, m_MBR); MaskMissingColumnsToZero(Input(inputIndex)->Gradient(), Input(0)->GetMBLayout(), fr); } #ifdef _DEBUG @@ -518,7 +518,7 @@ class SequenceWithSoftmaxNode : public ComputationNodeNonLooping, publ static void WINAPI BackpropToRight(const Matrix& softmaxOfRight, const Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues, - const Matrix& gammaFromLattice, double hsmoothingWeight, double frameDropThresh) + const Matrix& gammaFromLattice, double hsmoothingWeight, double frameDropThresh, bool MBR) { #if DUMPOUTPUT softmaxOfRight.Print("SequenceWithSoftmaxNode Partial-softmaxOfRight"); @@ -526,8 +526,7 @@ class SequenceWithSoftmaxNode : public ComputationNodeNonLooping, publ gradientValues.Print("SequenceWithSoftmaxNode Partial-gradientValues"); inputGradientValues.Print("SequenceWithSoftmaxNode Partial-Right-in"); #endif - - inputGradientValues.AssignSequenceError((ElemType) hsmoothingWeight, inputFunctionValues, softmaxOfRight, gammaFromLattice, gradientValues.Get00Element()); + inputGradientValues.AssignSequenceError((ElemType)hsmoothingWeight, inputFunctionValues, softmaxOfRight, gammaFromLattice, gradientValues.Get00Element(), MBR); inputGradientValues.DropFrame(inputFunctionValues, gammaFromLattice, (ElemType) frameDropThresh); #if DUMPOUTPUT inputGradientValues.Print("SequenceWithSoftmaxNode Partial-Right"); @@ -563,7 +562,7 @@ class SequenceWithSoftmaxNode : public ComputationNodeNonLooping, publ m_gammaFromLattice->Resize(*m_softmaxOfRight); m_gammaCalculator.calgammaformb(Value(), m_lattices, Input(2)->Value() /*log LLs*/, Input(0)->Value() /*labels*/, *m_gammaFromLattice, - m_uids, m_boundaries, Input(1)->GetNumParallelSequences(), + m_uids, m_wids, m_nws, m_boundaries, Input(1)->GetNumParallelSequences(), Input(0)->GetMBLayout(), m_extraUttMap, m_doReferenceAlignment); #if NANCHECK @@ -635,14 +634,41 @@ class SequenceWithSoftmaxNode : public ComputationNodeNonLooping, publ // TODO: method names should be CamelCase std::vector>* getLatticePtr() { return &m_lattices; } std::vector* getuidprt() { return &m_uids; } + std::vector* getwidprt() { return &m_wids; } + std::vector* getnwprt() { return &m_nws; } + std::vector* getboundaryprt() { return &m_boundaries; } std::vector* getextrauttmap() { return &m_extraUttMap; } msra::asr::simplesenonehmm* gethmm() { return &m_hmm; } void SetSmoothWeight(double fsSmoothingWeight) { m_fsSmoothingWeight = fsSmoothingWeight; } + void SetMBR(bool MBR) { m_MBR = MBR; } void SetFrameDropThresh(double frameDropThresh) { m_frameDropThreshold = frameDropThresh; } void SetReferenceAlign(const bool doreferencealign) { m_doReferenceAlignment = doreferencealign; } + void SetGammarCalculationParamEMBR(const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR, const bool& EMBR, const string& EMBRUnit, const size_t& numPathsEMBR, + const bool& enforceValidPathEMBR, const string& getPathMethodEMBR, const string& showWERMode, const bool& excludeSpecialWords, const bool& wordNbest, const bool& useAccInNbest, const float& accWeightInNbest, const size_t& numRawPathsEMBR) + { + msra::lattices::SeqGammarCalParam param; + param.amf = amf; + param.lmf = lmf; + param.wp = wp; + param.bMMIfactor = bMMIfactor; + param.sMBRmode = sMBR; + + param.EMBR = EMBR; + param.EMBRUnit = EMBRUnit; + param.numPathsEMBR = numPathsEMBR; + param.enforceValidPathEMBR = enforceValidPathEMBR; + param.getPathMethodEMBR = getPathMethodEMBR; + param.showWERMode = showWERMode; + param.excludeSpecialWords = excludeSpecialWords; + param.wordNbest = wordNbest; + param.useAccInNbest = useAccInNbest; + param.accWeightInNbest = accWeightInNbest; + param.numRawPathsEMBR = numRawPathsEMBR; + m_gammaCalculator.SetGammarCalculationParamsEMBR(param); + } void SetGammarCalculationParam(const double& amf, const double& lmf, const double& wp, const double& bMMIfactor, const bool& sMBR) { msra::lattices::SeqGammarCalParam param; @@ -653,7 +679,6 @@ class SequenceWithSoftmaxNode : public ComputationNodeNonLooping, publ param.sMBRmode = sMBR; m_gammaCalculator.SetGammarCalculationParams(param); } - void gettime(unsigned long long& gammatime, unsigned long long& partialtime) { gammatime = m_gammatime; @@ -667,6 +692,7 @@ class SequenceWithSoftmaxNode : public ComputationNodeNonLooping, publ bool m_invalidMinibatch; // for single minibatch double m_frameDropThreshold; double m_fsSmoothingWeight; // frame-sequence criterion interpolation weight --TODO: can this be done outside? + bool m_MBR; double m_seqGammarAMF; double m_seqGammarLMF; double m_seqGammarWP; @@ -678,6 +704,9 @@ class SequenceWithSoftmaxNode : public ComputationNodeNonLooping, publ msra::lattices::GammaCalculation m_gammaCalculator; bool m_gammaCalcInitialized; std::vector m_uids; + std::vector m_wids; + + std::vector m_nws; std::vector m_boundaries; std::vector m_extraUttMap; @@ -806,6 +835,8 @@ class LatticeSequenceWithSoftmaxNode : public SequenceWithSoftmaxNode, auto& currentLatticeSeq = latticeMBLayout->FindSequence(currentLabelSeq.seqId); std::shared_ptr latticePair(new msra::dbn::latticepair); const char* buffer = bufferStart + latticeMBNumTimeSteps * sizeof(float) * currentLatticeSeq.s + currentLatticeSeq.tBegin; + + latticePair->second.ReadFromBuffer(buffer, m_idmap, m_idmap.back()); assert((currentLabelSeq.tEnd - currentLabelSeq.tBegin) == latticePair->second.info.numframes); // The size of the vector is small -- the number of sequences in the minibatch. diff --git a/Source/Math/GPUMatrix.cu b/Source/Math/GPUMatrix.cu index 0e7ada4d51f4..fadc1c741b21 100755 --- a/Source/Math/GPUMatrix.cu +++ b/Source/Math/GPUMatrix.cu @@ -4687,7 +4687,7 @@ GPUMatrix& GPUMatrix::DropFrame(const GPUMatrix& l template GPUMatrix& GPUMatrix::AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix& label, - const GPUMatrix& dnnoutput, const GPUMatrix& gamma, ElemType alpha) + const GPUMatrix& dnnoutput, const GPUMatrix& gamma, ElemType alpha, bool MBR) { if (IsEmpty()) LogicError("AssignSequenceError: Matrix is empty."); @@ -4697,7 +4697,7 @@ GPUMatrix& GPUMatrix::AssignSequenceError(const ElemType hsm SyncGuard syncGuard; long N = (LONG64) label.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); - _AssignSequenceError<<>>(hsmoothingWeight, Data(), label.Data(), dnnoutput.Data(), gamma.Data(), alpha, N); + _AssignSequenceError << > >(hsmoothingWeight, Data(), label.Data(), dnnoutput.Data(), gamma.Data(), alpha, N, MBR); return *this; } diff --git a/Source/Math/GPUMatrix.h b/Source/Math/GPUMatrix.h index 80f650069f8e..20ec3ab911a5 100755 --- a/Source/Math/GPUMatrix.h +++ b/Source/Math/GPUMatrix.h @@ -370,8 +370,7 @@ class MATH_API GPUMatrix : public BaseMatrix // sequence training GPUMatrix& DropFrame(const GPUMatrix& label, const GPUMatrix& gamma, const ElemType& threshhold); - GPUMatrix& AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix& label, const GPUMatrix& dnnoutput, const GPUMatrix& gamma, ElemType alpha); - + GPUMatrix& AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix& label, const GPUMatrix& dnnoutput, const GPUMatrix& gamma, ElemType alpha, bool MBR); GPUMatrix& AssignCTCScore(const GPUMatrix& prob, GPUMatrix& alpha, GPUMatrix& beta, const GPUMatrix phoneSeq, const GPUMatrix phoneBoundary, GPUMatrix & totalScore, const vector& uttMap, const vector & uttBeginFrame, const vector & uttFrameNum, const vector & uttPhoneNum, const size_t samplesInRecurrentStep, const size_t maxFrameNum, const size_t blankTokenId, const int delayConstraint, const bool isColWise); diff --git a/Source/Math/GPUMatrixCUDAKernels.cuh b/Source/Math/GPUMatrixCUDAKernels.cuh index 59edf814b59c..35677b5e6c24 100755 --- a/Source/Math/GPUMatrixCUDAKernels.cuh +++ b/Source/Math/GPUMatrixCUDAKernels.cuh @@ -5275,13 +5275,17 @@ __global__ void _DropFrame( template __global__ void _AssignSequenceError(const ElemType hsmoothingWeight, ElemType* error, const ElemType* label, - const ElemType* dnnoutput, const ElemType* gamma, ElemType alpha, const long N) + const ElemType* dnnoutput, const ElemType* gamma, ElemType alpha, const long N, bool MBR) { typedef typename TypeSelector::comp_t comp_t; int id = blockDim.x * blockIdx.x + threadIdx.x; if (id >= N) return; - error[id] = (comp_t)error[id] - (comp_t)alpha * ((comp_t)label[id] - (1.0 - (comp_t)hsmoothingWeight) * (comp_t)dnnoutput[id] - (comp_t)hsmoothingWeight * (comp_t)gamma[id]); + if(!MBR) + error[id] -= alpha * (label[id] - (1.0 - hsmoothingWeight) * dnnoutput[id] - hsmoothingWeight * gamma[id]); + else + error[id] -= alpha * ( (1.0 - hsmoothingWeight) * (label[id] - dnnoutput[id]) + hsmoothingWeight * gamma[id]); + // change to ce // error[id] -= alpha * (label[id] - dnnoutput[id] ); } diff --git a/Source/Math/Matrix.cpp b/Source/Math/Matrix.cpp index a377eda814e8..f001e6825835 100755 --- a/Source/Math/Matrix.cpp +++ b/Source/Math/Matrix.cpp @@ -6171,7 +6171,7 @@ Matrix& Matrix::DropFrame(const Matrix& label, con /// Resulting matrix, user is responsible for allocating this template Matrix& Matrix::AssignSequenceError(const ElemType hsmoothingWeight, const Matrix& label, - const Matrix& dnnoutput, const Matrix& gamma, ElemType alpha) + const Matrix& dnnoutput, const Matrix& gamma, ElemType alpha, bool MBR) { DecideAndMoveToRightDevice(label, dnnoutput, gamma); @@ -6179,11 +6179,13 @@ Matrix& Matrix::AssignSequenceError(const ElemType hsmoothin NOT_IMPLEMENTED; SwitchToMatrixType(label.GetMatrixType(), label.GetFormat(), false); + + DISPATCH_MATRIX_ON_FLAG(this, this, m_CPUMatrix->AssignSequenceError(hsmoothingWeight, *label.m_CPUMatrix, *dnnoutput.m_CPUMatrix, *gamma.m_CPUMatrix, alpha), - m_GPUMatrix->AssignSequenceError(hsmoothingWeight, *label.m_GPUMatrix, *dnnoutput.m_GPUMatrix, *gamma.m_GPUMatrix, alpha), + m_GPUMatrix->AssignSequenceError(hsmoothingWeight, *label.m_GPUMatrix, *dnnoutput.m_GPUMatrix, *gamma.m_GPUMatrix, alpha, MBR), NOT_IMPLEMENTED, NOT_IMPLEMENTED); return *this; diff --git a/Source/Math/Matrix.h b/Source/Math/Matrix.h index 414e7e16678f..08437a2d1609 100755 --- a/Source/Math/Matrix.h +++ b/Source/Math/Matrix.h @@ -402,8 +402,7 @@ class MATH_API Matrix : public MatrixBase // sequence training Matrix& DropFrame(const Matrix& label, const Matrix& gamma, const ElemType& threshhold); - Matrix& AssignSequenceError(const ElemType hsmoothingWeight, const Matrix& label, const Matrix& dnnoutput, const Matrix& gamma, ElemType alpha); - + Matrix& AssignSequenceError(const ElemType hsmoothingWeight, const Matrix& label, const Matrix& dnnoutput, const Matrix& gamma, ElemType alpha, bool MBR); Matrix& AssignCTCScore(const Matrix& prob, Matrix& alpha, Matrix& beta, const Matrix& phoneSeq, const Matrix& phoneBound, Matrix& totalScore, const vector & extraUttMap, const vector & uttBeginFrame, const vector & uttFrameNum, const vector & uttPhoneNum, const size_t samplesInRecurrentStep, const size_t mbSize, const size_t blankTokenId, const int delayConstraint, const bool isColWise); diff --git a/Source/Math/NoGPU.cpp b/Source/Math/NoGPU.cpp index 422f1d389dec..6bf75ee4e150 100755 --- a/Source/Math/NoGPU.cpp +++ b/Source/Math/NoGPU.cpp @@ -1490,7 +1490,7 @@ GPUMatrix& GPUMatrix::DropFrame(const GPUMatrix& l return *this; } template -GPUMatrix& GPUMatrix::AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix& label, const GPUMatrix& dnnoutput, const GPUMatrix& gamma, ElemType alpha) +GPUMatrix& GPUMatrix::AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix& label, const GPUMatrix& dnnoutput, const GPUMatrix& gamma, ElemType alpha, bool MBR) { return *this; } diff --git a/Source/Math/cudalattice.cpp b/Source/Math/cudalattice.cpp index c154307295be..158b6ab8700f 100644 --- a/Source/Math/cudalattice.cpp +++ b/Source/Math/cudalattice.cpp @@ -162,6 +162,20 @@ class latticefunctionsimpl : public vectorbaseimpl> &>(Eframescorrectbuf), logEframescorrecttotal, totalfwscore); } + void backwardlatticeEMBR(const size_t *batchsizebackward, const size_t numlaunchbackward, + const floatvector &edgeacscores, const edgeinfowithscoresvector &edges, + const nodeinfovector &nodes, doublevector &edgelogbetas, doublevector &logbetas, + const float lmf, const float wp, const float amf, double &totalbwscore) + { + ondevice no(deviceid); + latticefunctionsops::backwardlatticeEMBR(batchsizebackward, numlaunchbackward, + dynamic_cast> &>(edgeacscores), + dynamic_cast> &>(edges), + dynamic_cast> &>(nodes), + dynamic_cast> &>(edgelogbetas), + dynamic_cast> &>(logbetas), + lmf, wp, amf, totalbwscore); + } void sMBRerrorsignal(const ushortvector &alignstateids, const uintvector &alignoffsets, @@ -183,6 +197,23 @@ class latticefunctionsimpl : public vectorbaseimpl &dengammas) + { + ondevice no(deviceid); + + matrixref dengammasMatrixRef = tomatrixref(dengammas); + + latticefunctionsops::EMBRerrorsignal(dynamic_cast> &>(alignstateids), + dynamic_cast> &>(alignoffsets), + dynamic_cast> &>(edges), + dynamic_cast> &>(nodes), + dynamic_cast> &>(edgeweights), + dengammasMatrixRef); + } void mmierrorsignal(const ushortvector &alignstateids, const uintvector &alignoffsets, const edgeinfowithscoresvector &edges, const nodeinfovector &nodes, const doublevector &logpps, Microsoft::MSR::CNTK::Matrix &dengammas) diff --git a/Source/Math/cudalattice.h b/Source/Math/cudalattice.h index 0e29c50fd1ee..fed8935845a2 100644 --- a/Source/Math/cudalattice.h +++ b/Source/Math/cudalattice.h @@ -99,6 +99,17 @@ struct latticefunctions : public vectorbase doublevector& logaccalphas, doublevector& logaccbetas, doublevector& logframescorrectedge, doublevector& logEframescorrect, doublevector& Eframescorrectbuf, double& logEframescorrecttotal, double& totalfwscore) = 0; + + virtual void backwardlatticeEMBR(const size_t* batchsizebackward, const size_t numlaunchbackward, + const floatvector& edgeacscores, const edgeinfowithscoresvector& edges, + const nodeinfovector& nodes, doublevector& edgelogbetas, doublevector& logbetas, + const float lmf, const float wp, const float amf, double& totalbwscore) = 0; + + virtual void EMBRerrorsignal(const ushortvector& alignstateids, const uintvector& alignoffsets, + const edgeinfowithscoresvector& edges, const nodeinfovector& nodes, + const doublevector& edgeweights, Microsoft::MSR::CNTK::Matrix& dengammas) = 0; + + virtual void sMBRerrorsignal(const ushortvector& alignstateids, const uintvector& alignoffsets, const edgeinfowithscoresvector& edges, const nodeinfovector& nodes, const doublevector& logpps, const float amf, const doublevector& logEframescorrect, diff --git a/Source/Math/cudalatticeops.cu.h b/Source/Math/cudalatticeops.cu.h index 6db5c35b2507..45f04848e80c 100644 --- a/Source/Math/cudalatticeops.cu.h +++ b/Source/Math/cudalatticeops.cu.h @@ -227,6 +227,24 @@ __global__ void backwardlatticej(const size_t batchsize, const size_t startindex } } +__global__ void backwardlatticejEMBR(const size_t batchsize, const size_t startindex, const vectorref edgeacscores, + vectorref edges, vectorref nodes, + vectorref edgelogbetas, vectorref logbetas, + float lmf, float wp, float amf) +{ + const size_t tpb = blockDim.x * blockDim.y; // total #threads in a block + const size_t jinblock = threadIdx.x + threadIdx.y * blockDim.x; + size_t j = jinblock + blockIdx.x * tpb; + if (j < batchsize) // note: will cause issues if we ever use __synctreads() + { + msra::lattices::latticefunctionskernels::backwardlatticejEMBR(j + startindex, edgeacscores, + edges, nodes, edgelogbetas, logbetas, + lmf, wp, amf); + + + } +} + void latticefunctionsops::forwardbackwardlattice(const size_t *batchsizeforward, const size_t *batchsizebackward, const size_t numlaunchforward, const size_t numlaunchbackward, const size_t spalignunitid, const size_t silalignunitid, @@ -326,6 +344,44 @@ void latticefunctionsops::forwardbackwardlattice(const size_t *batchsizeforward, } } +void latticefunctionsops::backwardlatticeEMBR( const size_t *batchsizebackward, const size_t numlaunchbackward, + const vectorref &edgeacscores, + const vectorref &edges, + const vectorref &nodes, vectorref &edgelogbetas, vectorref &logbetas, + const float lmf, const float wp, const float amf, double &totalbwscore) const +{ + // initialize log{,acc}(alhas/betas) + dim3 t(32, 8); + const size_t tpb = t.x * t.y; + dim3 b((unsigned int)((logbetas.size() + tpb - 1) / tpb)); + + // TODO: is this really efficient? One thread per value? + setvaluej << > >(logbetas, LOGZERO, logbetas.size()); + checklaunch("setvaluej"); + + // set initial tokens to probability 1 (0 in log) + double log1 = 0.0; + memcpy(logbetas.get(), nodes.size() - 1, &log1, 1); + + + // backward pass + size_t startindex = 0; + startindex = edges.size(); + for (size_t i = 0; i < numlaunchbackward; i++) + { + dim3 b2((unsigned int)((batchsizebackward[i] + tpb - 1) / tpb)); + backwardlatticejEMBR << > >(batchsizebackward[i], startindex - batchsizebackward[i], + edgeacscores, edges, nodes, edgelogbetas, logbetas, + lmf, wp, amf); + + + checklaunch("edgealignment"); + startindex -= batchsizebackward[i]; + } + memcpy(&totalbwscore, logbetas.get(), 0, 1); + +} + // ----------------------------------------------------------------------- // sMBRerrorsignal -- accumulate difference of logEframescorrect and logEframescorrecttotal into errorsignal // ----------------------------------------------------------------------- @@ -342,6 +398,18 @@ __global__ void sMBRerrorsignalj(const vectorref alignstateids, msra::lattices::latticefunctionskernels::sMBRerrorsignalj(j, alignstateids, alignoffsets, edges, nodes, logpps, amf, logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalneg); } } +__global__ void EMBRerrorsignalj(const vectorref alignstateids, const vectorref alignoffsets, + const vectorref edges, const vectorref nodes, + vectorref edgeweights, + matrixref errorsignal) +{ + const size_t shufflemode = 1; // [v-hansu] this gives us about 100% speed up than shufflemode = 0 (no shuffle) + const size_t j = msra::lattices::latticefunctionskernels::shuffle(threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode); + if (j < edges.size()) // note: will cause issues if we ever use __synctreads() + { + msra::lattices::latticefunctionskernels::EMBRerrorsignalj(j, alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal); + } +} // ----------------------------------------------------------------------- // stateposteriors --accumulate a per-edge quantity into the states that the edge is aligned with @@ -433,6 +501,27 @@ void latticefunctionsops::sMBRerrorsignal(const vectorref &align #endif } +void latticefunctionsops::EMBRerrorsignal(const vectorref &alignstateids, const vectorref &alignoffsets, + const vectorref &edges, const vectorref &nodes, + const vectorref &edgeweights, + matrixref &errorsignal) const +{ + // Layout: each thread block takes 1024 threads; and we have #edges/1024 blocks. + // This limits us to 16 million edges. If you need more, please adjust to either use wider thread blocks or a second dimension for the grid. Don't forget to adjust the kernel as well. + const size_t numedges = edges.size(); + dim3 t(32, 8); + const size_t tpb = t.x * t.y; + dim3 b((unsigned int)((numedges + tpb - 1) / tpb)); + + setvaluei << > >(errorsignal, 0); + checklaunch("setvaluei"); + + EMBRerrorsignalj << > >(alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal); + checklaunch("EMBRerrorsignal"); + + +} + void latticefunctionsops::mmierrorsignal(const vectorref &alignstateids, const vectorref &alignoffsets, const vectorref &edges, const vectorref &nodes, const vectorref &logpps, matrixref &errorsignal) const diff --git a/Source/Math/cudalatticeops.h b/Source/Math/cudalatticeops.h index b535d86d79cc..9b2c46c71bf2 100644 --- a/Source/Math/cudalatticeops.h +++ b/Source/Math/cudalatticeops.h @@ -53,6 +53,15 @@ class latticefunctionsops : protected vectorref vectorref& logframescorrectedge, vectorref& logEframescorrect, vectorref& Eframescorrectbuf, double& logEframescorrecttotal, double& totalfwscore) const; + void backwardlatticeEMBR(const size_t *batchsizebackward, const size_t numlaunchbackward, + const vectorref &edgeacscores, + const vectorref &edges, + const vectorref &nodes, vectorref &edgelogbetas, vectorref &logbetas, + const float lmf, const float wp, const float amf, double &totalbwscore) const; + void EMBRerrorsignal(const vectorref &alignstateids, const vectorref &alignoffsets, + const vectorref &edges, const vectorref &nodes, + const vectorref &edgeweights, + matrixref &errorsignal) const; void sMBRerrorsignal(const vectorref& alignstateids, const vectorref& alignoffsets, const vectorref& edges, const vectorref& nodes, const vectorref& logpps, const float amf, const vectorref& logEframescorrect, const double logEframescorrecttotal, diff --git a/Source/Math/latticefunctionskernels.h b/Source/Math/latticefunctionskernels.h index 207a6bb70b6a..4fbffd2ba4ac 100644 --- a/Source/Math/latticefunctionskernels.h +++ b/Source/Math/latticefunctionskernels.h @@ -302,6 +302,22 @@ struct latticefunctionskernels // note: critically, ^^ this comparison must copare the bits ('int') instead of the converted float values, since this will fail for NaNs (NaN != NaN is true always) return bitsasfloat(old); } + template // adapted from [http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ixzz32EuzZjxV] + static __device__ FLOAT atomicAdd(FLOAT *address, FLOAT val) // direct adaptation from NVidia source code + { + typedef decltype(floatasbits(val)) bitstype; + bitstype *address_as_ull = (bitstype *)address; + bitstype old = *address_as_ull, assumed; + do + { + assumed = old; + FLOAT sum = bitsasfloat(assumed); + sum = sum + val; + old = atomicCAS(address_as_ull, assumed, floatasbits(sum)); + } while (assumed != old); + // note: critically, ^^ this comparison must copare the bits ('int') instead of the converted float values, since this will fail for NaNs (NaN != NaN is true always) + return bitsasfloat(old); + } #else // this code does not work because (assumed != old) will not compare correctly in case of NaNs // same pattern as atomicAdd(), but performing the log-add operation instead template @@ -887,6 +903,61 @@ struct latticefunctionskernels logpps[j] = logpp; if (returnEframescorrect) logEframescorrect[j] = logEframescorrectj; + } + template + static inline __device__ void backwardlatticejEMBR(size_t j, const floatvector &edgeacscores, + const edgeinforvector &edges, const nodeinfovector &nodes, doublevector & edgelogbetas, + doublevector &logbetas, float lmf, float wp, float amf) + { + + // edge info + const edgeinfowithscores &e = edges[j]; + double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; + // zhaorui to deal with the abnormal score for sent start. + if (e.l < -200.0f) + edgescore = (0.0 * lmf + wp + edgeacscores[j]) / amf; + + + +#ifdef FORBID_INVALID_SIL_PATHS + // original mode + const bool forbidinvalidsilpath = (logbetas.size() > nodes.size()); // we prune sil to sil path if alphabetablowup != 1 + const bool isaddedsil = forbidinvalidsilpath && (e.unused == 1); // HACK: 'unused' indicates artificially added sil/sp edge + + if (!isaddedsil) // original mode +#endif + { + const size_t S = e.S; + const size_t E = e.E; + + // backward pass + const double inscore = logbetas[E]; + const double pathscore = inscore + edgescore; + edgelogbetas[j] = pathscore; + atomicLogAdd(&logbetas[S], pathscore); + } + +#ifdef FORBID_INVALID_SIL_PATHS + + // silence edge or second speech edge + if ((isaddedsil && e.E != nodes.size() - 1) || (forbidinvalidsilpath && e.S != 0)) + { + const size_t S = (size_t)(!isaddedsil ? e.S + nodes.size() : e.S); // second speech edge comes from special 'silence state' node + const size_t E = (size_t)(isaddedsil ? e.E + nodes.size() : e.E); // silence edge goes into special 'silence state' node + // remaining lines here are code dup from above, with two changes: logadd2/logEframescorrectj2 instead of logadd/logEframescorrectj + + // backward pass + const double inscore = logbetas[E]; + const double pathscore = inscore + edgescore; + edgelogbetas[j] = pathscore; + atomicLogAdd(&logbetas[S], pathscore); + + } +#else + nodes; +#endif + + } template @@ -930,6 +1001,29 @@ struct latticefunctionskernels } } + template + static inline __device__ void EMBRerrorsignalj(size_t j, const ushortvector &alignstateids, const uintvector &alignoffsets, + const edgeinfowithscoresvector &edges, + const nodeinfovector &nodes, const doublevector &edgeweights, + matrix &errorsignal) + { + size_t ts = nodes[edges[j].S].t; + size_t te = nodes[edges[j].E].t; + if (ts != te) + { + + + const float weight = (float)(edgeweights[j]); + size_t offset = alignoffsets[j]; + for (size_t t = ts; t < te; t++) + { + const size_t s = (size_t)alignstateids[t - ts + offset]; + + // use atomic function for lock the value + atomicAdd(&errorsignal(s, t), weight); + } + } + } // accumulate a per-edge quantity into the states that the edge is aligned with // Use this for MMI passing the edge posteriors logpps[] as logq, or for sMBR passing logEframescorrect[]. // j=edge index, alignment in (alignstateids, alignoffsets) diff --git a/Source/Readers/HTKMLFReader/HTKMLFReader.cpp b/Source/Readers/HTKMLFReader/HTKMLFReader.cpp index f56a1ac3ff3e..02fdda164b52 100644 --- a/Source/Readers/HTKMLFReader/HTKMLFReader.cpp +++ b/Source/Readers/HTKMLFReader/HTKMLFReader.cpp @@ -26,6 +26,8 @@ #include "ScriptableObjects.h" #include "HTKMLFReader.h" #include "TimerUtility.h" +#include "fileutil.h" +#include #ifdef LEAKDETECT #include // for memory leak detection #endif @@ -98,7 +100,30 @@ void HTKMLFReader::InitFromConfig(const ConfigRecordType& readerConfig PrepareForTrainingOrTesting(readerConfig); } } +void readwordidmap(const std::wstring &pathname, std::unordered_map& wordidmap, int start_id) +{ + std::unordered_map::iterator mp_itr; + auto_file_ptr f(fopenOrDie(pathname, L"rbS")); + fprintf(stderr, "readwordidmap: reading %ls \n", pathname.c_str()); + char buf[1024]; + char word[1024]; + int dumid; + while (!feof(f)) + { + fgetline(f, buf); + if (sscanf(buf, "%s %d", word, &dumid) != 2) + { + fprintf(stderr, "readwordidmap: reaching the end of line, with content = %s", buf); + break; + } + if (wordidmap.find(std::string(word)) == wordidmap.end()) + { + wordidmap.insert(pair(string(word),start_id++)); + } + } + fclose(f); +} // Load all input and output data. // Note that the terms features imply be real-valued quantities and // labels imply categorical quantities, irrespective of whether they @@ -116,6 +141,7 @@ void HTKMLFReader::PrepareForTrainingOrTesting(const ConfigRecordType& vector> infilesmulti; size_t numFiles; wstring unigrampath(L""); + wstring wordidmappath(L""); size_t randomize = randomizeAuto; size_t iFeat, iLabel; @@ -443,8 +469,16 @@ void HTKMLFReader::PrepareForTrainingOrTesting(const ConfigRecordType& if (readerConfig.Exists(L"unigram")) unigrampath = (const wstring&) readerConfig(L"unigram"); + if (readerConfig.Exists(L"wordidmap")) + wordidmappath = (const wstring&)readerConfig(L"wordidmap"); + // load a unigram if needed (this is used for MMI training) msra::lm::CSymbolSet unigramsymbols; + std::set specialwordids; + std::vector specialwords; + std::unordered_map wordidmap; + std::unordered_map::iterator wordidmap_itr; + std::unique_ptr unigram; size_t silencewordid = SIZE_MAX; size_t startwordid = SIZE_MAX; @@ -452,10 +486,96 @@ void HTKMLFReader::PrepareForTrainingOrTesting(const ConfigRecordType& if (unigrampath != L"") { unigram.reset(new msra::lm::CMGramLM()); + + unigramsymbols["!NULL"]; + unigramsymbols[""]; + unigramsymbols[""]; + unigramsymbols["!sent_start"]; + unigramsymbols["!sent_end"]; + unigramsymbols["!silence"]; unigram->read(unigrampath, unigramsymbols, false /*filterVocabulary--false will build the symbol map*/, 1 /*maxM--unigram only*/); silencewordid = unigramsymbols["!silence"]; // give this an id (even if not in the LM vocabulary) startwordid = unigramsymbols[""]; endwordid = unigramsymbols[""]; + + specialwordids.clear(); + specialwordids.insert(unigramsymbols[""]); + specialwordids.insert(unigramsymbols[""]); + specialwordids.insert(unigramsymbols["!NULL"]); + specialwordids.insert(unigramsymbols["!sent_start"]); + specialwordids.insert(unigramsymbols["!sent_end"]); + specialwordids.insert(unigramsymbols["!silence"]); + specialwordids.insert(unigramsymbols["[/CNON]"]); + specialwordids.insert(unigramsymbols["[/CSPN]"]); + specialwordids.insert(unigramsymbols["[/NPS]"]); + specialwordids.insert(unigramsymbols["[CNON/]"]); + specialwordids.insert(unigramsymbols["[CNON]"]); + specialwordids.insert(unigramsymbols["[CSPN]"]); + specialwordids.insert(unigramsymbols["[FILL/]"]); + specialwordids.insert(unigramsymbols["[NON/]"]); + specialwordids.insert(unigramsymbols["[NONNATIVE/]"]); + specialwordids.insert(unigramsymbols["[NPS]"]); + + specialwordids.insert(unigramsymbols["[SB/]"]); + specialwordids.insert(unigramsymbols["[SBP/]"]); + specialwordids.insert(unigramsymbols["[SN/]"]); + specialwordids.insert(unigramsymbols["[SPN/]"]); + specialwordids.insert(unigramsymbols["[UNKNOWN/]"]); + specialwordids.insert(unigramsymbols[".]"]); + + // this is to exclude the unknown words in lattice brought when merging the numerator lattice into denominator lattice. + specialwordids.insert(0xfffff); + } + + else if (wordidmappath != L"") + { + wordidmap.insert(pair("!NULL", 0)); + wordidmap.insert(pair("", 1)); + wordidmap.insert(pair("", 2)); + wordidmap.insert(pair("!sent_start", 3)); + wordidmap.insert(pair("!sent_end", 4)); + wordidmap.insert(pair("!silence", 5)); + + silencewordid = 5; // give this an id (even if not in the LM vocabulary) + startwordid = 1; + endwordid = 2; + + int start_id = 6; + readwordidmap(wordidmappath, wordidmap, start_id); + specialwordids.clear(); + specialwords.clear(); + + specialwords.push_back(""); + + specialwords.push_back(""); + specialwords.push_back("!NULL"); + specialwords.push_back("!sent_start"); + specialwords.push_back("!sent_end"); + specialwords.push_back("!silence"); + specialwords.push_back("[/CNON]"); + specialwords.push_back("[/CSPN]"); + specialwords.push_back("[/NPS]"); + specialwords.push_back("[CNON/]"); + specialwords.push_back("[CNON]"); + specialwords.push_back("[CSPN]"); + specialwords.push_back("[FILL/]"); + specialwords.push_back("[NON/]"); + specialwords.push_back("[NONNATIVE/]"); + specialwords.push_back("[NPS]"); + + specialwords.push_back("[SB/]"); + specialwords.push_back("[SBP/]"); + specialwords.push_back("[SN/]"); + specialwords.push_back("[SPN/]"); + specialwords.push_back("[UNKNOWN/]"); + specialwords.push_back(".]"); + + for (size_t i = 0; i < specialwords.size(); i++) + { + wordidmap_itr = wordidmap.find(specialwords[i]); + specialwordids.insert((wordidmap_itr == wordidmap.end()) ? -1 : wordidmap_itr->second); + } + specialwordids.insert(0xfffff); } if (!unigram && latticetocs.second.size() > 0) @@ -497,19 +617,18 @@ void HTKMLFReader::PrepareForTrainingOrTesting(const ConfigRecordType& double htktimetoframe = 100000.0; // default is 10ms // std::vector> labelsmulti; - std::vector>> labelsmulti; + std::vector, std::vector>>> labelsmulti; // std::vector pagepath; foreach_index (i, mlfpathsmulti) { - const msra::lm::CSymbolSet* wordmap = unigram ? &unigramsymbols : NULL; msra::asr::htkmlfreader - labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordmap, (map*) NULL, htktimetoframe); // label MLF + labels(mlfpathsmulti[i], restrictmlftokeys, statelistpaths[i], wordidmap, htktimetoframe); // label MLF + // get the temp file name for the page file // Make sure 'msra::asr::htkmlfreader' type has a move constructor static_assert(std::is_move_constructible>::value, "Type 'msra::asr::htkmlfreader' should be move constructible!"); - labelsmulti.push_back(std::move(labels)); } @@ -522,7 +641,7 @@ void HTKMLFReader::PrepareForTrainingOrTesting(const ConfigRecordType& // now get the frame source. This has better randomization and doesn't create temp files bool useMersenneTwisterRand = readerConfig(L"useMersenneTwisterRand", false); - m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, m_featDims, m_labelDims, + m_frameSource.reset(new msra::dbn::minibatchutterancesourcemulti(useMersenneTwisterRand, infilesmulti, labelsmulti, specialwordids, m_featDims, m_labelDims, numContextLeft, numContextRight, randomize, *m_lattices, m_latticeMap, m_frameMode, m_expandToUtt, m_maxUtteranceLength, m_truncated)); @@ -756,6 +875,8 @@ void HTKMLFReader::StartDistributedMinibatchLoop(size_t requestedMBSiz // for the multi-utterance process for lattice and phone boundary m_latticeBufferMultiUtt.assign(m_numSeqsPerMB, nullptr); m_labelsIDBufferMultiUtt.resize(m_numSeqsPerMB); + m_wlabelsIDBufferMultiUtt.resize(m_numSeqsPerMB); + m_nwsBufferMultiUtt.resize(m_numSeqsPerMB); m_phoneboundaryIDBufferMultiUtt.resize(m_numSeqsPerMB); if (m_frameMode && (m_numSeqsPerMB > 1)) @@ -894,11 +1015,11 @@ void HTKMLFReader::StartMinibatchLoopToWrite(size_t mbSize, size_t /*e template bool HTKMLFReader::GetMinibatch4SE(std::vector>& latticeinput, - vector& uids, vector& boundaries, vector& extrauttmap) + vector& uids, vector& wids, vector& nws, vector& boundaries, vector& extrauttmap) { if (m_trainOrTest) { - return GetMinibatch4SEToTrainOrTest(latticeinput, uids, boundaries, extrauttmap); + return GetMinibatch4SEToTrainOrTest(latticeinput, uids, wids, nws, boundaries, extrauttmap); } else { @@ -907,16 +1028,22 @@ bool HTKMLFReader::GetMinibatch4SE(std::vector bool HTKMLFReader::GetMinibatch4SEToTrainOrTest(std::vector>& latticeinput, - std::vector& uids, std::vector& boundaries, std::vector& extrauttmap) + + std::vector& uids, std::vector& wids, std::vector& nws, std::vector& boundaries, std::vector& extrauttmap) { latticeinput.clear(); uids.clear(); + wids.clear(); + nws.clear(); boundaries.clear(); extrauttmap.clear(); for (size_t i = 0; i < m_extraSeqsPerMB.size(); i++) { latticeinput.push_back(m_extraLatticeBufferMultiUtt[i]); uids.insert(uids.end(), m_extraLabelsIDBufferMultiUtt[i].begin(), m_extraLabelsIDBufferMultiUtt[i].end()); + wids.insert(wids.end(), m_extraWLabelsIDBufferMultiUtt[i].begin(), m_extraWLabelsIDBufferMultiUtt[i].end()); + + nws.insert(nws.end(), m_extraNWsBufferMultiUtt[i].begin(), m_extraNWsBufferMultiUtt[i].end()); boundaries.insert(boundaries.end(), m_extraPhoneboundaryIDBufferMultiUtt[i].begin(), m_extraPhoneboundaryIDBufferMultiUtt[i].end()); } @@ -984,6 +1111,8 @@ bool HTKMLFReader::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma m_extraLabelsIDBufferMultiUtt.clear(); m_extraPhoneboundaryIDBufferMultiUtt.clear(); m_extraSeqsPerMB.clear(); + m_extraWLabelsIDBufferMultiUtt.clear(); + m_extraNWsBufferMultiUtt.clear(); if (m_noData && m_numFramesToProcess[0] == 0) // no data left for the first channel of this minibatch, { return false; @@ -1064,6 +1193,8 @@ bool HTKMLFReader::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma { m_extraLatticeBufferMultiUtt.push_back(m_latticeBufferMultiUtt[i]); m_extraLabelsIDBufferMultiUtt.push_back(m_labelsIDBufferMultiUtt[i]); + m_extraWLabelsIDBufferMultiUtt.push_back(m_wlabelsIDBufferMultiUtt[i]); + m_extraNWsBufferMultiUtt.push_back(m_nwsBufferMultiUtt[i]); m_extraPhoneboundaryIDBufferMultiUtt.push_back(m_phoneboundaryIDBufferMultiUtt[i]); } } @@ -1106,6 +1237,8 @@ bool HTKMLFReader::GetMinibatchToTrainOrTest(StreamMinibatchInputs& ma { m_extraLatticeBufferMultiUtt.push_back(m_latticeBufferMultiUtt[src]); m_extraLabelsIDBufferMultiUtt.push_back(m_labelsIDBufferMultiUtt[src]); + m_extraWLabelsIDBufferMultiUtt.push_back(m_wlabelsIDBufferMultiUtt[src]); + m_extraNWsBufferMultiUtt.push_back(m_nwsBufferMultiUtt[src]); m_extraPhoneboundaryIDBufferMultiUtt.push_back(m_phoneboundaryIDBufferMultiUtt[src]); } @@ -1811,6 +1944,10 @@ bool HTKMLFReader::ReNewBufferForMultiIO(size_t i) m_phoneboundaryIDBufferMultiUtt[i] = m_mbiter->bounds(); m_labelsIDBufferMultiUtt[i].clear(); m_labelsIDBufferMultiUtt[i] = m_mbiter->labels(); + m_wlabelsIDBufferMultiUtt[i].clear(); + m_wlabelsIDBufferMultiUtt[i] = m_mbiter->wlabels(); + m_nwsBufferMultiUtt[i].clear(); + m_nwsBufferMultiUtt[i] = m_mbiter->nwords(); } m_processedFrame[i] = 0; @@ -2031,8 +2168,7 @@ unique_ptr& HTKMLFReader::GetCUDAAllocator if (m_cudaAllocator == nullptr) { m_cudaAllocator.reset(new CUDAPageLockedMemAllocator(deviceID)); - } - + } return m_cudaAllocator; } @@ -2049,6 +2185,7 @@ std::shared_ptr HTKMLFReader::AllocateIntermediateBuffer(int this->GetCUDAAllocator(deviceID)->Free((char*) p); }); } + else { return std::shared_ptr(new ElemType[numElements], @@ -2059,6 +2196,9 @@ std::shared_ptr HTKMLFReader::AllocateIntermediateBuffer(int } } + template class HTKMLFReader; template class HTKMLFReader; } } } + + diff --git a/Source/Readers/HTKMLFReader/HTKMLFReader.h b/Source/Readers/HTKMLFReader/HTKMLFReader.h index 752d36aab559..0a05945a019c 100644 --- a/Source/Readers/HTKMLFReader/HTKMLFReader.h +++ b/Source/Readers/HTKMLFReader/HTKMLFReader.h @@ -77,6 +77,12 @@ class HTKMLFReader : public DataReaderBase std::vector> m_phoneboundaryIDBufferMultiUtt; std::vector> m_extraLatticeBufferMultiUtt; std::vector> m_extraLabelsIDBufferMultiUtt; + + /* word labels */ + std::vector> m_wlabelsIDBufferMultiUtt; + std::vector> m_extraWLabelsIDBufferMultiUtt; + std::vector> m_nwsBufferMultiUtt; + std::vector> m_extraNWsBufferMultiUtt; std::vector> m_extraPhoneboundaryIDBufferMultiUtt; // hmm @@ -109,7 +115,7 @@ class HTKMLFReader : public DataReaderBase void PrepareForWriting(const ConfigRecordType& config); bool GetMinibatchToTrainOrTest(StreamMinibatchInputs& matrices); - bool GetMinibatch4SEToTrainOrTest(std::vector>& latticeinput, vector& uids, vector& boundaries, std::vector& extrauttmap); + bool GetMinibatch4SEToTrainOrTest(std::vector>& latticeinput, vector& uids, vector& wids, vector& nws, vector& boundaries, std::vector& extrauttmap); void fillOneUttDataforParallelmode(StreamMinibatchInputs& matrices, size_t startFr, size_t framenum, size_t channelIndex, size_t sourceChannelIndex); // TODO: PascalCase() bool GetMinibatchToWrite(StreamMinibatchInputs& matrices); @@ -189,7 +195,7 @@ class HTKMLFReader : public DataReaderBase virtual const std::map& GetLabelMapping(const std::wstring& sectionName); virtual void SetLabelMapping(const std::wstring& sectionName, const std::map& labelMapping); virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0); - virtual bool GetMinibatch4SE(std::vector>& latticeinput, vector& uids, vector& boundaries, vector& extrauttmap); + virtual bool GetMinibatch4SE(std::vector>& latticeinput, vector& uids, vector& wids, vector& nws, vector& boundaries, vector& extrauttmap); virtual bool GetHmmData(msra::asr::simplesenonehmm* hmm); virtual bool DataEnd(); diff --git a/Source/Readers/HTKMLFReader/htkfeatio.h b/Source/Readers/HTKMLFReader/htkfeatio.h index c888cd4812b3..910fe0ea6067 100644 --- a/Source/Readers/HTKMLFReader/htkfeatio.h +++ b/Source/Readers/HTKMLFReader/htkfeatio.h @@ -864,13 +864,14 @@ struct htkmlfentry setdata(ts, te, uid); } }; - template -class htkmlfreader : public map> // [key][i] the data + +class htkmlfreader : public map, vector>> // [key][i] the data { wstring curpath; // for error messages unordered_map statelistmap; // for state <=> index map wordsequences; // [key] word sequences (if we are building word entries as well, for MMI) + std::unordered_map symmap; void strtok(char* s, const char* delim, vector& toks) { @@ -900,10 +901,9 @@ class htkmlfreader : public map> // [key][i] the data return lines; } - template + template void parseentry(const vector& lines, size_t line, const set& restricttokeys, - const WORDSYMBOLTABLE* wordmap, const UNITSYMBOLTABLE* unitmap, - vector& wordseqbuffer, vector& alignseqbuffer, + const WORDSYMBOLTABLE* wordmap, /* const UNITSYMBOLTABLE* unitmap, */ const double htkTimeToFrame) { size_t idx = 0; @@ -936,13 +936,15 @@ class htkmlfreader : public map> // [key][i] the data // don't parse unused entries (this is supposed to be used for very small debugging setups with huge MLFs) if (!restricttokeys.empty() && restricttokeys.find(key) == restricttokeys.end()) return; - - vector& entries = (*this)[key]; // this creates a new entry + vector& entries = (*this)[key].first; // this creates a new entry if (!entries.empty()) - malformed(msra::strfun::strprintf("duplicate entry '%ls'", key.c_str())); + // do not want to die immediately + fprintf(stderr, + "Warning: duplicate entry: %ls \n", + key.c_str()); entries.resize(e - s); - wordseqbuffer.resize(0); - alignseqbuffer.resize(0); + vector& wordids = (*this)[key].second; + wordids.resize(0); vector toks; for (size_t i = s; i < e; i++) { @@ -957,55 +959,240 @@ class htkmlfreader : public map> // [key][i] the data { if (toks.size() > 6 /*word entry are in this column*/) { + // convert letter to uppercase + if (strcmp(toks[6], "") != 0 + && strcmp(toks[6], "") != 0 + && strcmp(toks[6], "!sent_start") != 0 + && strcmp(toks[6], "!sent_end") != 0 + && strcmp(toks[6], "!silence") != 0) + { + for(size_t j = 0; j < strlen(toks[6]); j++) + { + if(toks[6][j] >= 'a' && toks[6][j] <= 'z') + toks[6][j] = toks[6][j] + 'A' - 'a'; + } + } const char* w = toks[6]; // the word name - int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM) - size_t wordindex = (wid == -1) ? WORDSEQUENCE::word::unknownwordindex : (size_t) wid; - wordseqbuffer.push_back(typename WORDSEQUENCE::word(wordindex, entries[i - s].firstframe, alignseqbuffer.size())); - } - if (unitmap) - { - if (toks.size() > 4) + // For some alignment MLF the sentence start and end are both represented by , we change sentence end to be + if (i > s && strcmp(w, "") == 0) + w = ""; + /* skip the words that are not used in WER computation */ + /* ugly hard code, will improve later */ + if (strcmp(w, "") != 0 + && strcmp(w, "") != 0 + && strcmp(w, "!NULL") != 0 + && strcmp(w, "!sent_start") != 0 + && strcmp(w, "!sent_end") != 0 + && strcmp(w, "!silence") != 0 + && strcmp(w, "[/CNON]") != 0 + && strcmp(w, "[/CSPN]") != 0 + && strcmp(w, "[/NPS]") != 0 + && strcmp(w, "[CNON/]") != 0 + && strcmp(w, "[CNON]") != 0 + && strcmp(w, "[CSPN]") != 0 + && strcmp(w, "[FILL/]") != 0 + && strcmp(w, "[NON/]") != 0 + && strcmp(w, "[NONNATIVE/]") != 0 + && strcmp(w, "[NPS]") != 0 + && strcmp(w, "[SB/]") != 0 + && strcmp(w, "[SBP/]") != 0 + && strcmp(w, "[SN/]") != 0 + && strcmp(w, "[SPN/]") != 0 + && strcmp(w, "[UNKNOWN/]") != 0 + && strcmp(w, ".]") != 0 + ) { - const char* u = toks[4]; // the triphone name - auto iter = unitmap->find(u); // map to unit id - if (iter == unitmap->end()) - RuntimeError("parseentry: unknown unit %s in utterance %ls", u, key.c_str()); - const size_t uid = iter->second; - alignseqbuffer.push_back(typename WORDSEQUENCE::aligninfo(uid, 0 /*#frames--we accumulate*/)); + int wid = (*wordmap)[w]; // map to word id --may be -1 for unseen words in the transcript (word list typically comes from a test LM) + static const unsigned int unknownwordindex = 0xfffff; + size_t wordindex = (wid == -1) ? unknownwordindex : (size_t)wid; + wordids.push_back(wordindex); } - if (alignseqbuffer.empty()) - RuntimeError("parseentry: lonely senone entry at start without phone/word entry found, for utterance %ls", key.c_str()); - alignseqbuffer.back().frames += entries[i - s].numframes; // (we do not have an overflow check here, but should...) } } } if (wordmap) // if reading word sequences as well (for MMI), then record it (in a separate map) { - if (!entries.empty() && wordseqbuffer.empty()) - RuntimeError("parseentry: got state alignment but no word-level info, although being requested, for utterance %ls", key.c_str()); + if (!entries.empty() && wordids.empty()) + { + fprintf(stderr, + "Warning: parseentry: got state alignment but no word-level info, although being requested, for utterance %ls \n", + key.c_str()); + } // post-process silence // - first !silence -> !sent_start // - last !silence -> !sent_end - int silence = (*wordmap)["!silence"]; - if (silence >= 0) + else { - int sentstart = (*wordmap)["!sent_start"]; // these must have been created - int sentend = (*wordmap)["!sent_end"]; - // map first and last !silence to !sent_start and !sent_end, respectively - if (sentstart >= 0 && wordseqbuffer.front().wordindex == (size_t) silence) - wordseqbuffer.front().wordindex = sentstart; - if (sentend >= 0 && wordseqbuffer.back().wordindex == (size_t) silence) - wordseqbuffer.back().wordindex = sentend; + int silence = (*wordmap)["!silence"]; + if (silence >= 0) + { + int sentstart = (*wordmap)["!sent_start"]; // these must have been created + int sentend = (*wordmap)["!sent_end"]; + // map first and last !silence to !sent_start and !sent_end, respectively + if (sentstart >= 0 && wordids.front() == (size_t)silence) + wordids.front() = sentstart; + if (sentend >= 0 && wordids.back() == (size_t)silence) + wordids.back() = sentend; + } } // if (sentstart < 0 || sentend < 0 || silence < 0) // LogicError("parseentry: word map must contain !silence, !sent_start, and !sent_end"); // implant - auto& wordsequence = wordsequences[key]; // this creates the map entry - wordsequence.words = wordseqbuffer; // makes a copy - wordsequence.align = alignseqbuffer; } } + + void parseentry(const vector& lines, size_t line, const set& restricttokeys, + const std::unordered_map& wordidmap, + const double htkTimeToFrame) + { + + std::unordered_map::const_iterator mp_itr; + + size_t idx = 0; + string filename = lines[idx++]; + while (filename == "#!MLF!#") // skip embedded duplicate MLF headers (so user can 'cat' MLFs) + filename = lines[idx++]; + + // some mlf file have write errors, so skip malformed entry + if (filename.length() < 3 || filename[0] != '"' || filename[filename.length() - 1] != '"') + { + fprintf(stderr, "warning: filename entry (%s)\n", filename.c_str()); + fprintf(stderr, "skip current mlf entry from line (%lu) until line (%lu).\n", (unsigned long)(line + idx), (unsigned long)(line + lines.size())); + return; + } + + filename = filename.substr(1, filename.length() - 2); // strip quotes + if (filename.find("*/") == 0) + filename = filename.substr(2); +#ifdef _MSC_VER + wstring key = Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(regex_replace(filename, regex("\\.[^\\.\\\\/:]*$"), string())); // delete extension (or not if none) +#else + wstring key = Microsoft::MSR::CNTK::ToFixedWStringFromMultiByte(msra::dbn::removeExtension(filename)); // note that c++ 4.8 is incomplete for supporting regex +#endif + + // determine lines range + size_t s = idx; + size_t e = lines.size() - 1; + // lines range: [s,e) + // don't parse unused entries (this is supposed to be used for very small debugging setups with huge MLFs) + if (!restricttokeys.empty() && restricttokeys.find(key) == restricttokeys.end()) + return; + vector& entries = (*this)[key].first; + if (!entries.empty()) + // do not want to die immediately + fprintf(stderr, + "Warning: duplicate entry : %ls \n", + key.c_str()); + + entries.resize(e - s); + + vector& wordids = (*this)[key].second; + wordids.resize(0); + vector toks; + for (size_t i = s; i < e; i++) + { + // We can mutate the original string as it is no longer needed after tokenization + strtok(const_cast(lines[i].c_str()), " \t", toks); + if (statelistmap.size() == 0) + entries[i - s].parse(toks, htkTimeToFrame); + else + entries[i - s].parsewithstatelist(toks, statelistmap, htkTimeToFrame); + // if we also read word entries, do it here + if (wordidmap.size() != 0) + { + if (toks.size() > 6 /*word entry are in this column*/) + { + // convert word to uppercase + if (strcmp(toks[6], "") != 0 + && strcmp(toks[6], "") != 0 + && strcmp(toks[6], "!sent_start") != 0 + && strcmp(toks[6], "!sent_end") != 0 + && strcmp(toks[6], "!silence") != 0) + { + for(size_t j = 0; j < strlen(toks[6]); j++) + { + if(toks[6][j] >= 'a' && toks[6][j] <= 'z') + toks[6][j] = toks[6][j] + 'A' - 'a'; + } + } + const char* w = toks[6]; // the word name + // For some alignment MLF the sentence start and end are both represented by , we change sentence end to be + if (i > s && strcmp(w, "") == 0) + w = ""; + /* skip the words that are not used in WER computation */ + /* ugly hard code, will improve later */ + if (strcmp(w, "") != 0 + && strcmp(w, "") != 0 + && strcmp(w, "!NULL") != 0 + && strcmp(w, "!sent_start") != 0 + && strcmp(w, "!sent_end") != 0 + && strcmp(w, "!silence") != 0 + && strcmp(w, "[/CNON]") != 0 + && strcmp(w, "[/CSPN]") != 0 + && strcmp(w, "[/NPS]") != 0 + && strcmp(w, "[CNON/]") != 0 + && strcmp(w, "[CNON]") != 0 + && strcmp(w, "[CSPN]") != 0 + && strcmp(w, "[FILL/]") != 0 + && strcmp(w, "[NON/]") != 0 + && strcmp(w, "[NONNATIVE/]") != 0 + && strcmp(w, "[NPS]") != 0 + && strcmp(w, "[SB/]") != 0 + && strcmp(w, "[SBP/]") != 0 + && strcmp(w, "[SN/]") != 0 + && strcmp(w, "[SPN/]") != 0 + && strcmp(w, "[UNKNOWN/]") != 0 + && strcmp(w, ".]") != 0 + ) + { + mp_itr = wordidmap.find(std::string(w)); + int wid = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second); + static const unsigned int unknownwordindex = 0xfffff; + unsigned int wordindex = (wid == -1) ? unknownwordindex : (unsigned int)wid; + wordids.push_back(wordindex); + } + } + } + } + if (wordidmap.size() != 0) // if reading word sequences as well (for MMI), then record it (in a separate map) + { + if (!entries.empty() && wordids.empty()) + { + + fprintf(stderr, + "Warning: parseentry: got state alignment but no word-level info, although being requested, for utterance %ls. Ignoring this utterance for EMBR \n", + key.c_str()); + // delete this item + (*this).erase(key); + return; + + } + + // post-process silence + // - first !silence -> !sent_start + // - last !silence -> !sent_end + else + { + + mp_itr = wordidmap.find("!silence"); + int silence = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second); + if (silence >= 0) + { + mp_itr = wordidmap.find("!sent_start"); + int sentstart = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second); + + mp_itr = wordidmap.find("!sent_end"); + int sentend = ((mp_itr == wordidmap.end()) ? -1: mp_itr->second); + // map first and last !silence to !sent_start and !sent_end, respectively + if (sentstart >= 0 && wordids.front() == (size_t)silence) + wordids.front() = sentstart; + if (sentend >= 0 && wordids.back() == (size_t)silence) + wordids.back() = sentend; + } + } + } + } public: // return if input statename is sil state (hard code to compared first 3 chars with "sil") bool issilstate(const string& statename) const // (later use some configuration table) @@ -1044,9 +1231,32 @@ class htkmlfreader : public map> // [key][i] the data read(paths[i], restricttokeys, wordmap, unitmap, htkTimeToFrame); } - // note: this function is not designed to be pretty but to be fast + + htkmlfreader(const vector& paths, const set& restricttokeys, const wstring& stateListPath, const std::unordered_map& wordidmap, const double htkTimeToFrame) + { + // read state list + if (stateListPath != L"") + readstatelist(stateListPath); + + // read MLF(s) --note: there can be multiple, so this is a loop + foreach_index(i, paths) + read(paths[i], restricttokeys, wordidmap, htkTimeToFrame); + } + + // phone boundary template - void read(const wstring& path, const set& restricttokeys, const WORDSYMBOLTABLE* wordmap, const UNITSYMBOLTABLE* unitmap, const double htkTimeToFrame) + htkmlfreader(const vector& paths, const set& restricttokeys, const wstring& stateListPath, const WORDSYMBOLTABLE* wordmap, const UNITSYMBOLTABLE* unitmap, + const double htkTimeToFrame, const msra::asr::simplesenonehmm& hset) + { + if (stateListPath != L"") + readstatelist(stateListPath); + symmap = hset.symmap; + foreach_index (i, paths) + read(paths[i], restricttokeys, wordmap, unitmap, htkTimeToFrame); + } + // note: this function is not designed to be pretty but to be fast + template + void read(const wstring& path, const set& restricttokeys, const WORDSYMBOLTABLE* wordmap, const double htkTimeToFrame) { if (!restricttokeys.empty() && this->size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files) return; @@ -1060,8 +1270,6 @@ class htkmlfreader : public map> // [key][i] the data malformed("header missing"); // Read the file in blocks and parse MLF entries - std::vector wordsequencebuffer; - std::vector alignsequencebuffer; size_t readBlockSize = 1000000; std::vector currBlockBuf(readBlockSize + 1); size_t currLineNum = 1; @@ -1091,7 +1299,7 @@ class htkmlfreader : public map> // [key][i] the data { if (restricttokeys.empty() || (this->size() < restricttokeys.size())) { - parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame); + parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordmap, htkTimeToFrame); } currMLFLines.clear(); @@ -1134,6 +1342,93 @@ class htkmlfreader : public map> // [key][i] the data fprintf(stderr, " total %lu entries\n", (unsigned long)this->size()); } + // note: this function is not designed to be pretty but to be fast + void read(const wstring& path, const set& restricttokeys, const std::unordered_map& wordidmap, const double htkTimeToFrame) + { + if (!restricttokeys.empty() && this->size() >= restricttokeys.size()) // no need to even read the file if we are there (we support multiple files) + return; + + fprintf(stderr, "htkmlfreader: reading MLF file %ls ...", path.c_str()); + curpath = path; // for error messages only + + auto_file_ptr f(fopenOrDie(path, L"rb")); + std::string headerLine = fgetline(f); + if (headerLine != "#!MLF!#") + malformed("header missing"); + + // Read the file in blocks and parse MLF entries + size_t readBlockSize = 1000000; + std::vector currBlockBuf(readBlockSize + 1); + size_t currLineNum = 1; + std::vector currMLFLines; + bool reachedEOF = (feof(f) != 0); + char* nextReadPtr = currBlockBuf.data(); + size_t nextReadSize = readBlockSize; + while (!reachedEOF) + { + size_t numBytesRead = fread(nextReadPtr, sizeof(char), nextReadSize, f); + reachedEOF = (numBytesRead != nextReadSize); + if (ferror(f)) + RuntimeError("error reading from file: %s", strerror(errno)); + + // Add 0 at the end to make it a proper C string + nextReadPtr[numBytesRead] = 0; + + // Now extract lines from the currBlockBuf and parse MLF entries + char* context = nullptr; + const char* delim = "\r\n"; + + auto consumeMLFLine = [&](const char* mlfLine) + { + currLineNum++; + currMLFLines.push_back(mlfLine); + if ((mlfLine[0] == '.') && (mlfLine[1] == 0)) // utterance end delimiter: a single dot on a line + { + if (restricttokeys.empty() || (this->size() < restricttokeys.size())) + { + // parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordidmap, unitmap, wordsequencebuffer, alignsequencebuffer, htkTimeToFrame); + parseentry(currMLFLines, currLineNum - currMLFLines.size(), restricttokeys, wordidmap, htkTimeToFrame); + } + + currMLFLines.clear(); + } + }; + + char* prevLine = strtok_s(currBlockBuf.data(), delim, &context); + for (char* currLine = strtok_s(NULL, delim, &context); currLine; currLine = strtok_s(NULL, delim, &context)) + { + consumeMLFLine(prevLine); + prevLine = currLine; + } + + // The last line read from the block may be a full line or part of a line + // We can tell by whether the terminating NULL for this line is the NULL + // we inserted after reading from the file + size_t prevLineLen = strlen(prevLine); + if ((prevLine + prevLineLen) == (nextReadPtr + numBytesRead)) + { + // This is not a full line, but just a truncated part of a line. + // Lets copy this to the start of the currBlockBuf and read new data + // from there on + strcpy_s(currBlockBuf.data(), currBlockBuf.size(), prevLine); + nextReadPtr = currBlockBuf.data() + prevLineLen; + nextReadSize = readBlockSize - prevLineLen; + } + else + { + // A full line + consumeMLFLine(prevLine); + nextReadPtr = currBlockBuf.data(); + nextReadSize = readBlockSize; + } + } + + if (!currMLFLines.empty()) + malformed("unexpected end in mid-utterance"); + + curpath.clear(); + fprintf(stderr, " total %lu entries\n", (unsigned long)this->size()); + } // read state list, index is from 0 void readstatelist(const wstring& stateListPath = L"") { diff --git a/Source/Readers/HTKMLFReader/latticearchive.cpp b/Source/Readers/HTKMLFReader/latticearchive.cpp index 639ce8a32dc1..18aca4d36100 100644 --- a/Source/Readers/HTKMLFReader/latticearchive.cpp +++ b/Source/Readers/HTKMLFReader/latticearchive.cpp @@ -405,8 +405,8 @@ void lattice::dedup() // - empty ("") -> don't output, just check the format // - dash ("-") -> dump lattice to stdout instead /*static*/ void archive::convert(const std::wstring &intocpath, const std::wstring &intocpath2, const std::wstring &outpath, - const msra::asr::simplesenonehmm &hset) -{ + const msra::asr::simplesenonehmm &hset, std::set& specialwordids) + { const auto &modelsymmap = hset.getsymmap(); const std::wstring tocpath = outpath + L".toc"; @@ -457,8 +457,7 @@ void lattice::dedup() // fetch lattice --this performs any necessary format conversions already lattice L; - archive.getlattice(key, L); - + archive.getlattice(key, L, specialwordids); lattice L2; if (mergemode) { @@ -468,8 +467,7 @@ void lattice::dedup() skippedmerges++; continue; } - archive2.getlattice(key, L2); - + archive2.getlattice(key, L2, specialwordids); // merge it in // This will connect each node with matching 1-phone context conditions; aimed at merging numer lattices. L.removefinalnull(); // get rid of that final !NULL headache @@ -563,6 +561,7 @@ void lattice::fromhtklattice(const wstring &path, const std::unordered_map 0); nodes.reserve(info.numnodes); + vt_node_out_edge_indices.resize(info.numnodes); // parse the nodes for (size_t i = 0; i < info.numnodes; i++, iter++) { @@ -570,11 +569,15 @@ void lattice::fromhtklattice(const wstring &path, const std::unordered_map &uids, std::vector> &transcripts, std::vector> &lattices) = 0; // alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrixes or a vector of label strings + + + // alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrixes or a vector of label strings virtual bool getbatch(const size_t globalts, - const size_t framesrequested, std::vector &feat, std::vector> &uids, - std::vector> &transcripts, - std::vector> &lattices, std::vector> &sentendmark, - std::vector> &phoneboundaries) = 0; + const size_t framesrequested, std::vector &feat, std::vector> &uids, + std::vector> &transcripts, + std::vector> &lattices, std::vector> &sentendmark, + std::vector> &phoneboundaries) = 0; + + + // getbatch() overload to support subsetting of mini-batches for parallel training // Default implementation does not support subsetting and throws an exception on // calling this overload with a numsubsets value other than 1. + virtual bool getbatch(const size_t globalts, - const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced, - std::vector &feat, std::vector> &uids, - std::vector> &transcripts, - std::vector> &lattices, std::vector> &sentendmark, - std::vector> &phoneboundaries) + const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced, + std::vector &feat, std::vector> &uids, + std::vector> &transcripts, + std::vector> &lattices, std::vector> &sentendmark, + std::vector> &phoneboundaries) { assert((subsetnum == 0) && (numsubsets == 1) && !supportsbatchsubsetting()); subsetnum; numsubsets; + bool retVal = getbatch(globalts, framesrequested, feat, uids, transcripts, lattices, sentendmark, phoneboundaries); framesadvanced = feat[0].cols(); return retVal; } + virtual bool getbatch(const size_t globalts, + const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced, + std::vector &feat, std::vector> &uids, std::vector> &wids, std::vector> &nws, + std::vector> &transcripts, + std::vector> &lattices, std::vector> &sentendmark, + std::vector> &phoneboundaries) + { + wids.resize(0); + nws.resize(0); + + + bool retVal = getbatch(globalts, framesrequested, subsetnum, numsubsets, framesadvanced, feat, uids, transcripts, lattices, sentendmark, phoneboundaries); + + return retVal; + } + + virtual bool supportsbatchsubsetting() const { return false; @@ -102,6 +128,10 @@ class minibatchiterator std::vector featbuf; // buffer for holding curernt minibatch's frames std::vector> uids; // buffer for storing current minibatch's frame-level label sequence + + std::vector> wids; // buffer for storing current minibatch's word-level label sequence + std::vector> nws; // buffer for storing current minibatch's number of words for each utterance + std::vector> transcripts; // buffer for storing current minibatch's word-level label sequences (if available and used; empty otherwise) std::vector> lattices; // lattices of the utterances in current minibatch (empty in frame mode) @@ -127,6 +157,11 @@ class minibatchiterator foreach_index (i, uids) uids[i].clear(); + foreach_index(i, wids) + wids[i].clear(); + + foreach_index(i, nws) + nws[i].clear(); transcripts.clear(); actualmbframes = 0; return; @@ -135,7 +170,7 @@ class minibatchiterator assert(requestedmbframes > 0); const size_t requestedframes = std::min(requestedmbframes, epochendframe - mbstartframe); // (< mbsize at end) assert(requestedframes > 0); - source.getbatch(mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, transcripts, lattices, sentendmark, phoneboundaries); + source.getbatch(mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, wids, nws, transcripts, lattices, sentendmark, phoneboundaries); timegetbatch = source.gettimegetbatch(); actualmbframes = featbuf[0].cols(); // for single i/o, there featbuf is length 1 // note: @@ -314,6 +349,23 @@ class minibatchiterator assert(uids.size() >= i + 1); return uids[i]; } + // return the reference transcript word labels (word labels) for current minibatch + /*const*/ std::vector &wlabels() + { + checkhasdata(); + assert(wids.size() == 1); + + return wids[0]; + } + + // return the number of words for current minibatch + /*const*/ std::vector &nwords() + { + checkhasdata(); + assert(nws.size() == 1); + + return nws[0]; + } std::vector &sentends() { diff --git a/Source/Readers/HTKMLFReader/minibatchsourcehelpers.h b/Source/Readers/HTKMLFReader/minibatchsourcehelpers.h index 2e5043c6c45a..a59b9381207d 100644 --- a/Source/Readers/HTKMLFReader/minibatchsourcehelpers.h +++ b/Source/Readers/HTKMLFReader/minibatchsourcehelpers.h @@ -194,6 +194,7 @@ static void augmentneighbors(const std::vector>& frames, cons // TODO: This is currently being hardcoded to unsigned short for saving space, which means untied context-dependent phones // will not work. This needs to be changed to dynamically choose what size to use based on the number of class ids. typedef unsigned short CLASSIDTYPE; +typedef unsigned int WORDIDTYPE; typedef unsigned short HMMIDTYPE; #ifndef _MSC_VER diff --git a/Source/Readers/HTKMLFReader/msra_mgram.h b/Source/Readers/HTKMLFReader/msra_mgram.h index 043e42be5391..5573c6fd621c 100644 --- a/Source/Readers/HTKMLFReader/msra_mgram.h +++ b/Source/Readers/HTKMLFReader/msra_mgram.h @@ -1,4 +1,4 @@ -// +// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // @@ -15,6 +15,7 @@ #include // for various sort() calls #include + namespace msra { namespace lm { // =========================================================================== @@ -92,15 +93,33 @@ static inline double invertlogprob(double logP) // compare function to allow char* as keys (without, unordered_map will correctly // compute a hash key from the actual strings, but then compare the pointers // -- duh!) -struct less_strcmp : public std::binary_function +struct equal_strcmp : public std::binary_function { // this implements operator< bool operator()(const char *const &_Left, const char *const &_Right) const { - return strcmp(_Left, _Right) < 0; + return strcmp(_Left, _Right) == 0; + } +}; +struct BKDRHash { + //BKDR hash algorithm + int operator()(const char * str)const + { + unsigned int seed = 131; //31 131 1313 13131131313 etc// + unsigned int hash = 0; + while (*str) + { + hash = (hash * seed) + (*str); + str++; + } + + return hash & (0x7FFFFFFF); } }; -class CSymbolSet : public std::unordered_map, less_strcmp> +/* bug fix: the customize function of compare should be written in the one commented below is not right. The generated behavior is very strange: it does not correctly make a map. So, fix it. */ +// class CSymbolSet : public std::unordered_map, less_strcmp> +// class CSymbolSet : public std::unordered_map, equal_strcmp> +class CSymbolSet : public std::unordered_map { std::vector symbols; // the symbols @@ -128,7 +147,7 @@ class CSymbolSet : public std::unordered_map::const_iterator iter = find(key); + unordered_map::const_iterator iter = find(key); return (iter != end()) ? iter->second : -1; } @@ -136,7 +155,8 @@ class CSymbolSet : public std::unordered_map::const_iterator iter = find(key); + unordered_map::const_iterator iter = find(key); + if (iter != end()) return iter->second; @@ -149,7 +169,8 @@ class CSymbolSet : public std::unordered_map dims; dims.reserve(4); while (buf[0] == 0 && !feof(f)) lineNo++, fgetline(f, buf); - int n, dim; dims.push_back(1); // dummy zerogram entry while (sscanf(buf, "ngram %d=%d", &n, &dim) == 2 && n == (int) dims.size()) @@ -1510,11 +1531,10 @@ class CMGramLM : public ILM { while (buf[0] == 0 && !feof(f)) lineNo++, fgetline(f, buf); - if (sscanf(buf, "\\%d-grams:", &n) != 1 || n != m) RuntimeError("read: mal-formed LM file, bad section header (%d): %ls", lineNo, pathname.c_str()); lineNo++, fgetline(f, buf); - + std::vector mgram(m + 1, -1); // current mgram being read ([0]=dummy) std::vector prevmgram(m + 1, -1); // cache to speed up symbol lookup mgram_map::cache_t mapCache; // cache to speed up map.create() @@ -1576,9 +1596,7 @@ class CMGramLM : public ILM double boVal = atof(tokens[m + 1]); // ... use sscanf() instead for error checking? thisLogB = boVal * ln10xLMF; // convert to natural log } - lineNo++, fgetline(f, buf); - if (skipEntry) // word contained unknown vocabulary: skip entire entry goto skipMGram; diff --git a/Source/Readers/HTKMLFReader/rollingwindowsource.h b/Source/Readers/HTKMLFReader/rollingwindowsource.h index 91121104b69f..c4a77e7f3f50 100644 --- a/Source/Readers/HTKMLFReader/rollingwindowsource.h +++ b/Source/Readers/HTKMLFReader/rollingwindowsource.h @@ -561,7 +561,7 @@ class minibatchframesourcemulti : public minibatchsource public: // constructor // Pass empty labels to denote unsupervised training (so getbatch() will not return uids). - minibatchframesourcemulti(const std::vector> &infiles, const std::vector>> &labels, + minibatchframesourcemulti(const std::vector> &infiles, const std::vector, std::vector>>> &labels, std::vector vdim, std::vector udim, std::vector leftcontext, std::vector rightcontext, size_t randomizationrange, const std::vector &pagepath, const bool mayhavenoframe = false, int addEnergy = 0) : vdim(vdim), leftcontext(leftcontext), rightcontext(rightcontext), sampperiod(0), featdim(0), numframes(0), timegetbatch(0), verbosity(2), maxvdim(0) { @@ -656,7 +656,7 @@ class minibatchframesourcemulti : public minibatchsource // HVite occasionally generates mismatching output --skip such files if (!key.empty()) // (we have a key if supervised mode) { - const auto &labseq = labels[0].find(key)->second; // (we already checked above that it exists) + const auto &labseq = labels[0].find(key)->second.first; // (we already checked above that it exists) size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size() - 1].firstframe + labseq[labseq.size() - 1].numframes); if (abs((int) labframes - (int) feat.cols()) > 0) { @@ -695,7 +695,7 @@ class minibatchframesourcemulti : public minibatchsource { foreach_index (j, labels) { - const auto &labseq = labels[j].find(key)->second; // (we already checked above that it exists) + const auto &labseq = labels[j].find(key)->second.first; // (we already checked above that it exists) foreach_index (i2, labseq) { const auto &e = labseq[i2]; diff --git a/Source/Readers/HTKMLFReader/utterancesourcemulti.h b/Source/Readers/HTKMLFReader/utterancesourcemulti.h index c43aaac83a97..fba7e33563d1 100644 --- a/Source/Readers/HTKMLFReader/utterancesourcemulti.h +++ b/Source/Readers/HTKMLFReader/utterancesourcemulti.h @@ -14,6 +14,7 @@ #include "minibatchiterator.h" #include #include +#include namespace msra { namespace dbn { @@ -36,6 +37,7 @@ class minibatchutterancesourcemulti : public minibatchsource const bool truncated; //false -> truncated utterance or not within minibatch size_t maxUtteranceLength; //10000 ->maximum utterance length in non-frame and non-truncated mode + std::set specialwordids; // stores the word ids that will not be counted for WER computation std::vector> counts; // [s] occurence count for all states (used for priors) int verbosity; // lattice reader @@ -55,9 +57,11 @@ class minibatchutterancesourcemulti : public minibatchsource { msra::asr::htkfeatreader::parsedpath parsedpath; // archive filename and frame range in that file size_t classidsbegin; // index into allclassids[] array (first frame) + size_t wordidsbegin; + short numwords; - utterancedesc(msra::asr::htkfeatreader::parsedpath &&ppath, size_t classidsbegin) - : parsedpath(std::move(ppath)), classidsbegin(classidsbegin), framesToExpand(0), needsExpansion(false) + utterancedesc(msra::asr::htkfeatreader::parsedpath &&ppath, size_t classidsbegin, size_t wordidsbegin) + : parsedpath(std::move(ppath)), classidsbegin(classidsbegin), wordidsbegin(wordidsbegin), framesToExpand(0), needsExpansion(false) { } bool needsExpansion; // ivector type of feature @@ -73,6 +77,15 @@ class minibatchutterancesourcemulti : public minibatchsource else return parsedpath.numframes(); } + short getnumwords() const + { + return numwords; + } + + void setnumwords(short nw) + { + numwords = nw; + } std::wstring key() const // key used for looking up lattice (not stored to save space) { #ifdef _MSC_VER @@ -129,6 +142,15 @@ class minibatchutterancesourcemulti : public minibatchsource { return utteranceset[i].classidsbegin; } + size_t getwordidsbegin(size_t i) const + { + return utteranceset[i].wordidsbegin; + } + + short numwords(size_t i) const + { + return utteranceset[i].numwords; + } msra::dbn::matrixstripe getutteranceframes(size_t i) const // return the frame set for a given utterance { if (!isinram()) @@ -152,8 +174,9 @@ class minibatchutterancesourcemulti : public minibatchsource } // page in data for this chunk // We pass in the feature info variables by ref which will be filled lazily upon first read - void requiredata(std::string &featkind, size_t &featdim, unsigned int &sampperiod, const latticesource &latticesource, int verbosity = 0) const + void requiredata(std::string &featkind, size_t &featdim, unsigned int &sampperiod, const latticesource &latticesource, std::set& specialwordids, int verbosity = 0) const { + if (numutterances() == 0) LogicError("requiredata: cannot page in virgin block"); if (isinram()) @@ -182,8 +205,8 @@ class minibatchutterancesourcemulti : public minibatchsource reader.read(utteranceset[i].parsedpath, (const std::string &)featkind, sampperiod, uttframes, utteranceset[i].needsExpansion); // note: file info here used for checkuing only // page in lattice data if (!latticesource.empty()) - latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols()); - } + latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols(), specialwordids); + } if (verbosity) { fprintf(stderr, "requiredata: %d utterances read\n", (int)utteranceset.size()); @@ -234,6 +257,7 @@ class minibatchutterancesourcemulti : public minibatchsource std::vector> allchunks; // set of utterances organized in chunks, referred to by an iterator (not an index) std::vector>> classids; // [classidsbegin+t] concatenation of all state sequences + std::vector>> wordids; // [wordidsbegin+t] concatenation of all state sequences bool m_generatePhoneBoundaries; std::vector>> phoneboundaries; bool issupervised() const @@ -299,6 +323,7 @@ class minibatchutterancesourcemulti : public minibatchsource } size_t numframes; // (cached since we cannot directly access the underlying data from here) + short numwords; size_t globalts; // start frame in global space after randomization (for mapping frame index to utterance position) size_t globalte() const { @@ -850,6 +875,30 @@ class minibatchutterancesourcemulti : public minibatchsource } return allclassids; // nothing to return } + + template + std::vector>> getwordids(const UTTREF &uttref) // return sub-vector of classids[] for a given utterance + { + std::vector>> allwordids; + + if (!issupervised()) + { + foreach_index(i, wordids) + allwordids.push_back(std::move(shiftedvector>((*wordids[i]), 0, 0))); + return allwordids; // nothing to return + } + const auto &chunk = randomizedchunks[0][uttref.chunkindex]; + const auto &chunkdata = chunk.getchunkdata(); + const size_t wordidsbegin = chunkdata.getwordidsbegin(uttref.utteranceindex()); // index of first state label in global concatenated classids[] array + const size_t n = chunkdata.numwords(uttref.utteranceindex()); + foreach_index(i, wordids) + { + if ((*wordids[i])[wordidsbegin + n] != (WORDIDTYPE)-1) + LogicError("getwordids: expected boundary marker not found, internal data structure screwed up"); + allwordids.push_back(std::move(shiftedvector>((*wordids[i]), wordidsbegin, n))); + } + return allwordids; // nothing to return + } template std::vector>> getphonebound(const UTTREF &uttref) // return sub-vector of classids[] for a given utterance { @@ -882,13 +931,16 @@ class minibatchutterancesourcemulti : public minibatchsource // constructor // Pass empty labels to denote unsupervised training (so getbatch() will not return uids). // This mode requires utterances with time stamps. - minibatchutterancesourcemulti(bool useMersenneTwister, const std::vector> &infiles, const std::vector>> &labels, + + minibatchutterancesourcemulti(bool useMersenneTwister, const std::vector> &infiles, const std::vector, std::vector>>> &labels, + std::set& specialwordids, std::vector vdim, std::vector udim, std::vector leftcontext, std::vector rightcontext, size_t randomizationrange, const latticesource &lattices, const std::map &allwordtranscripts, const bool framemode, std::vector expandToUtt, const size_t maxUtteranceLength, const bool truncated) : vdim(vdim), leftcontext(leftcontext), rightcontext(rightcontext), sampperiod(0), featdim(0), randomizationrange(randomizationrange), currentsweep(SIZE_MAX), lattices(lattices), allwordtranscripts(allwordtranscripts), framemode(framemode), chunksinram(0), timegetbatch(0), verbosity(2), m_generatePhoneBoundaries(!lattices.empty()), m_frameRandomizer(randomizedchunks, useMersenneTwister), expandToUtt(expandToUtt), m_useMersenneTwister(useMersenneTwister), maxUtteranceLength(maxUtteranceLength), truncated(truncated) + , specialwordids(specialwordids) // [v-hansu] change framemode (lattices.empty()) into framemode (false) to run utterance mode without lattice // you also need to change another line, search : [v-hansu] comment out to run utterance mode without lattice { @@ -905,6 +957,7 @@ class minibatchutterancesourcemulti : public minibatchsource std::vector uttduration; // track utterance durations to determine utterance validity std::vector classidsbegin; + std::vector wordidsbegin; allchunks = std::vector>(infiles.size(), std::vector()); featdim = std::vector(infiles.size(), 0); @@ -917,6 +970,7 @@ class minibatchutterancesourcemulti : public minibatchsource foreach_index (i, labels) { classids.push_back(std::unique_ptr>(new biggrowablevector())); + wordids.push_back(std::unique_ptr>(new biggrowablevector())); if (m_generatePhoneBoundaries) phoneboundaries.push_back(std::unique_ptr>(new biggrowablevector())); @@ -945,7 +999,7 @@ class minibatchutterancesourcemulti : public minibatchsource foreach_index (i, infiles[m]) { - utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), 0); // mseltzer - is this foolproof for multiio? is classids always non-empty? + utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), 0, 0); const size_t uttframes = utterance.numframes(); // will throw if frame bounds not given --required to be given in this mode if (expandToUtt[m] && uttframes != 1) RuntimeError("minibatchutterancesource: utterance-based features must be 1 frame in duration"); @@ -1003,8 +1057,10 @@ class minibatchutterancesourcemulti : public minibatchsource // if (infiles[m].size()!=numutts) // RuntimeError("minibatchutterancesourcemulti: all feature files must have same number of utterances\n"); if (m == 0) + { classidsbegin.clear(); - + wordidsbegin.clear(); + } foreach_index (i, infiles[m]) { if (i % (infiles[m].size() / 100 + 1) == 0) @@ -1014,11 +1070,14 @@ class minibatchutterancesourcemulti : public minibatchsource } // build utterance descriptor if (m == 0 && !labels.empty()) + { classidsbegin.push_back(classids[0]->size()); + wordidsbegin.push_back(wordids[0]->size()); + } if (uttisvalid[i]) { - utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), labels.empty() ? 0 : classidsbegin[i]); // mseltzer - is this foolproof for multiio? is classids always non-empty? + utterancedesc utterance(msra::asr::htkfeatreader::parsedpath(infiles[m][i]), labels.empty() ? 0 : classidsbegin[i], labels.empty() ? 0 : wordidsbegin[i]); // mseltzer - is this foolproof for multiio? is classids always non-empty? const size_t uttframes = utterance.numframes(); // will throw if frame bounds not given --required to be given in this mode if (expandToUtt[m]) { @@ -1078,7 +1137,7 @@ class minibatchutterancesourcemulti : public minibatchsource // first verify that all the label files have the proper duration foreach_index (j, labels) { - const auto &labseq = labels[j].find(key)->second; + const auto &labseq = labels[j].find(key)->second.first; // check if durations match; skip if not size_t labframes = labseq.empty() ? 0 : (labseq[labseq.size() - 1].firstframe + labseq[labseq.size() - 1].numframes); if (labframes != uttframes) @@ -1092,12 +1151,12 @@ class minibatchutterancesourcemulti : public minibatchsource } if (uttisvalid[i]) { - utteranceset.push_back(std::move(utterance)); _totalframes += uttframes; // then parse each mlf if the durations are consistent foreach_index (j, labels) { - const auto &labseq = labels[j].find(key)->second; + const auto & seqs = labels[j].find(key)->second; + const auto &labseq = seqs.first; // expand classid sequence into flat array foreach_index (i2, labseq) { @@ -1126,18 +1185,37 @@ class minibatchutterancesourcemulti : public minibatchsource } classids[j]->push_back((CLASSIDTYPE) -1); // append a boundary marker marker for checking + + if (m_generatePhoneBoundaries) phoneboundaries[j]->push_back((HMMIDTYPE) -1); // append a boundary marker marker for checking - if (!labels[j].empty() && classids[j]->size() != _totalframes + utteranceset.size()) + if (!labels[j].empty() && classids[j]->size() != _totalframes + utteranceset.size() + 1) LogicError("minibatchutterancesource: label duration inconsistent with feature file in MLF label set: %ls", key.c_str()); - assert(labels[j].empty() || classids[j]->size() == _totalframes + utteranceset.size()); + assert(labels[j].empty() || classids[j]->size() == _totalframes + utteranceset.size() + 1); + + const auto &wordlabseq = seqs.second; + + if (j == 0) + utterance.setnumwords(short(wordlabseq.size())); + + foreach_index(i2, wordlabseq) + { + const auto &e = wordlabseq[i2]; + if (e != (WORDIDTYPE)e) + RuntimeError("WORDIDTYPE has too few bits"); + + wordids[j]->push_back(e); + } + wordids[j]->push_back((WORDIDTYPE)-1); // append a boundary marker marker for checking } + utteranceset.push_back(std::move(utterance)); + } } else { - assert(classids.empty() && labels.empty()); + assert(classids.empty() && labels.empty() && wordids.empty()); utteranceset.push_back(std::move(utterance)); _totalframes += uttframes; } @@ -1424,6 +1502,7 @@ class minibatchutterancesourcemulti : public minibatchsource auto &uttref = randomizedutterancerefs[i]; uttref.globalts = t; uttref.numframes = randomizedchunks[0][uttref.chunkindex].getchunkdata().numframes(uttref.utteranceindex()); + uttref.numwords = randomizedchunks[0][uttref.chunkindex].getchunkdata().numwords(uttref.utteranceindex()); t = uttref.globalte(); } assert(t == sweepts + _totalframes); @@ -1486,6 +1565,7 @@ class minibatchutterancesourcemulti : public minibatchsource // Returns true if we actually did read something. bool requirerandomizedchunk(const size_t chunkindex, const size_t windowbegin, const size_t windowend) { + size_t numinram = 0; if (chunkindex < windowbegin || chunkindex >= windowend) @@ -1510,7 +1590,7 @@ class minibatchutterancesourcemulti : public minibatchsource fprintf(stderr, "feature set %d: requirerandomizedchunk: paging in randomized chunk %d (frame range [%d..%d]), %d resident in RAM\n", m, (int) chunkindex, (int) chunk.globalts, (int) (chunk.globalte() - 1), (int) (chunksinram + 1)); msra::util::attempt(5, [&]() // (reading from network) { - chunkdata.requiredata(featkind[m], featdim[m], sampperiod[m], this->lattices, verbosity); + chunkdata.requiredata(featkind[m], featdim[m], sampperiod[m], this->lattices, specialwordids, verbosity); }); } chunksinram++; @@ -1561,6 +1641,8 @@ class minibatchutterancesourcemulti : public minibatchsource verbosity = newverbosity; } + + // get the next minibatch // A minibatch is made up of one or more utterances. // We will return less than 'framesrequested' unless the first utterance is too long. @@ -1569,16 +1651,18 @@ class minibatchutterancesourcemulti : public minibatchsource // This is efficient since getbatch() is called with sequential 'globalts' except at epoch start. // Note that the start of an epoch does not necessarily fall onto an utterance boundary. The caller must use firstvalidglobalts() to find the first valid globalts at or after a given time. // Support for data parallelism: If mpinodes > 1 then we will + // - load only a subset of blocks from the disk // - skip frames/utterances in not-loaded blocks in the returned data // - 'framesadvanced' will still return the logical #frames; that is, by how much the global time index is advanced bool getbatch(const size_t globalts, const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced, - std::vector &feat, std::vector> &uids, + std::vector &feat, std::vector> &uids, std::vector> &wids, std::vector> &nws, std::vector> &transcripts, std::vector> &latticepairs, std::vector> &sentendmark, std::vector> &phoneboundaries2) override { + bool readfromdisk = false; // return value: shall be 'true' if we paged in anything auto_timer timergetbatch; @@ -1624,6 +1708,7 @@ class minibatchutterancesourcemulti : public minibatchsource // determine the true #frames we return, for allocation--it is less than mbframes in the case of MPI/data-parallel sub-set mode size_t tspos = 0; + size_t twrds = 0; for (size_t pos = spos; pos < epos; pos++) { const auto &uttref = randomizedutterancerefs[pos]; @@ -1631,11 +1716,14 @@ class minibatchutterancesourcemulti : public minibatchsource continue; tspos += uttref.numframes; + twrds += uttref.numwords; } // resize feat and uids feat.resize(vdim.size()); uids.resize(classids.size()); + wids.resize(wordids.size()); + nws.resize(wordids.size()); if (m_generatePhoneBoundaries) phoneboundaries2.resize(classids.size()); sentendmark.resize(vdim.size()); @@ -1649,15 +1737,21 @@ class minibatchutterancesourcemulti : public minibatchsource { foreach_index (j, uids) { + nws[j].clear(); if (issupervised()) // empty means unsupervised training -> return empty uids { uids[j].resize(tspos); + wids[j].resize(twrds); if (m_generatePhoneBoundaries) phoneboundaries2[j].resize(tspos); } else { - uids[i].clear(); + // uids[i].clear(); + // guoye: i think original code above is a bug, i should be j + uids[j].clear(); + + wids[j].clear(); if (m_generatePhoneBoundaries) phoneboundaries2[i].clear(); } @@ -1674,6 +1768,7 @@ class minibatchutterancesourcemulti : public minibatchsource if (verbosity > 0) fprintf(stderr, "getbatch: getting utterances %d..%d (%d subset of %d frames out of %d requested) in sweep %d\n", (int) spos, (int) (epos - 1), (int) tspos, (int) mbframes, (int) framesrequested, (int) sweep); tspos = 0; // relative start of utterance 'pos' within the returned minibatch + twrds = 0; for (size_t pos = spos; pos < epos; pos++) { const auto &uttref = randomizedutterancerefs[pos]; @@ -1681,6 +1776,7 @@ class minibatchutterancesourcemulti : public minibatchsource continue; size_t n = 0; + size_t nw = 0; foreach_index (i, randomizedchunks) { const auto &chunk = randomizedchunks[i][uttref.chunkindex]; @@ -1692,6 +1788,7 @@ class minibatchutterancesourcemulti : public minibatchsource sentendmark[i].push_back(n + tspos); assert(n == uttframes.cols() && uttref.numframes == n && chunkdata.numframes(uttref.utteranceindex()) == n); + nw = uttref.numwords; // copy the frames and class labels for (size_t t = 0; t < n; t++) // t = time index into source utterance { @@ -1714,6 +1811,7 @@ class minibatchutterancesourcemulti : public minibatchsource if (i == 0) { auto uttclassids = getclassids(uttref); + auto uttwordids = getwordids(uttref); std::vector>> uttphoneboudaries; if (m_generatePhoneBoundaries) uttphoneboudaries = getphonebound(uttref); @@ -1742,9 +1840,23 @@ class minibatchutterancesourcemulti : public minibatchsource } } } + foreach_index(j, uttwordids) + { + nws[j].push_back(short(nw)); + + for (size_t t = 0; t < nw; t++) // t = time index into source utterance + { + if (issupervised()) + { + wids[j][t + twrds] = uttwordids[j][t]; + } + } + + } } } tspos += n; + twrds += nw; } foreach_index (i, feat) @@ -1795,6 +1907,7 @@ class minibatchutterancesourcemulti : public minibatchsource // resize feat and uids feat.resize(vdim.size()); uids.resize(classids.size()); + // no need to care about wids for framemode = true assert(feat.size() == vdim.size()); assert(feat.size() == randomizedchunks.size()); foreach_index (i, feat) @@ -1878,31 +1991,349 @@ class minibatchutterancesourcemulti : public minibatchsource return readfromdisk; } + // get the next minibatch + // A minibatch is made up of one or more utterances. + // We will return less than 'framesrequested' unless the first utterance is too long. + // Note that this may return frames that are beyond the epoch end, but the first frame is always within the epoch. + // We specify the utterance by its global start time (in a space of a infinitely repeated training set). + // This is efficient since getbatch() is called with sequential 'globalts' except at epoch start. + // Note that the start of an epoch does not necessarily fall onto an utterance boundary. The caller must use firstvalidglobalts() to find the first valid globalts at or after a given time. + // Support for data parallelism: If mpinodes > 1 then we will + // - load only a subset of blocks from the disk + // - skip frames/utterances in not-loaded blocks in the returned data + // - 'framesadvanced' will still return the logical #frames; that is, by how much the global time index is advanced + bool getbatch(const size_t globalts, const size_t framesrequested, + const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced, + std::vector &feat, std::vector> &uids, + std::vector> &transcripts, + std::vector> &latticepairs, std::vector> &sentendmark, + std::vector> &phoneboundaries2) override + { + + bool readfromdisk = false; // return value: shall be 'true' if we paged in anything + + auto_timer timergetbatch; + assert(_totalframes > 0); + + // update randomization if a new sweep is entered --this is a complex operation that updates many of the data members used below + const size_t sweep = lazyrandomization(globalts); + + size_t mbframes = 0; + const std::vector noboundaryflags; // dummy + if (!framemode) // regular utterance mode + { + // find utterance position for globalts + // There must be a precise match; it is not possible to specify frames that are not on boundaries. + auto positer = randomizedutteranceposmap.find(globalts); + if (positer == randomizedutteranceposmap.end()) + LogicError("getbatch: invalid 'globalts' parameter; must match an existing utterance boundary"); + const size_t spos = positer->second; + + // determine how many utterances will fit into the requested minibatch size + mbframes = randomizedutterancerefs[spos].numframes; // at least one utterance, even if too long + size_t epos; + for (epos = spos + 1; epos < numutterances && ((mbframes + randomizedutterancerefs[epos].numframes) < framesrequested); epos++) // add more utterances as long as they fit within requested minibatch size + mbframes += randomizedutterancerefs[epos].numframes; + + // do some paging housekeeping + // This will also set the feature-kind information if it's the first time. + // Free all chunks left of the range. + // Page-in all chunks right of the range. + // We are a little more blunt for now: Free all outside the range, and page in only what is touched. We could save some loop iterations. + const size_t windowbegin = positionchunkwindows[spos].windowbegin(); + const size_t windowend = positionchunkwindows[epos - 1].windowend(); + for (size_t k = 0; k < windowbegin; k++) + releaserandomizedchunk(k); + for (size_t k = windowend; k < randomizedchunks[0].size(); k++) + releaserandomizedchunk(k); + for (size_t pos = spos; pos < epos; pos++) + if ((randomizedutterancerefs[pos].chunkindex % numsubsets) == subsetnum) + readfromdisk |= requirerandomizedchunk(randomizedutterancerefs[pos].chunkindex, windowbegin, windowend); // (window range passed in for checking only) + + // Note that the above loop loops over all chunks incl. those that we already should have. + // This has an effect, e.g., if 'numsubsets' has changed (we will fill gaps). + + // determine the true #frames we return, for allocation--it is less than mbframes in the case of MPI/data-parallel sub-set mode + size_t tspos = 0; + for (size_t pos = spos; pos < epos; pos++) + { + const auto &uttref = randomizedutterancerefs[pos]; + if ((uttref.chunkindex % numsubsets) != subsetnum) // chunk not to be returned for this MPI node + continue; + + tspos += uttref.numframes; + } + + // resize feat and uids + feat.resize(vdim.size()); + uids.resize(classids.size()); + + if (m_generatePhoneBoundaries) + phoneboundaries2.resize(classids.size()); + sentendmark.resize(vdim.size()); + assert(feat.size() == vdim.size()); + assert(feat.size() == randomizedchunks.size()); + foreach_index(i, feat) + { + feat[i].resize(vdim[i], tspos); + + if (i == 0) + { + foreach_index(j, uids) + { + if (issupervised()) // empty means unsupervised training -> return empty uids + { + uids[j].resize(tspos); + if (m_generatePhoneBoundaries) + phoneboundaries2[j].resize(tspos); + } + else + { + uids[i].clear(); + if (m_generatePhoneBoundaries) + phoneboundaries2[i].clear(); + } + latticepairs.clear(); // will push_back() below + transcripts.clear(); + } + foreach_index(j, sentendmark) + { + sentendmark[j].clear(); + } + } + } + // return these utterances + if (verbosity > 0) + fprintf(stderr, "getbatch: getting utterances %d..%d (%d subset of %d frames out of %d requested) in sweep %d\n", (int)spos, (int)(epos - 1), (int)tspos, (int)mbframes, (int)framesrequested, (int)sweep); + tspos = 0; // relative start of utterance 'pos' within the returned minibatch + for (size_t pos = spos; pos < epos; pos++) + { + const auto &uttref = randomizedutterancerefs[pos]; + if ((uttref.chunkindex % numsubsets) != subsetnum) // chunk not to be returned for this MPI node + continue; + + size_t n = 0; + foreach_index(i, randomizedchunks) + { + const auto &chunk = randomizedchunks[i][uttref.chunkindex]; + const auto &chunkdata = chunk.getchunkdata(); + assert((numsubsets > 1) || (uttref.globalts == globalts + tspos)); + auto uttframes = chunkdata.getutteranceframes(uttref.utteranceindex()); + matrixasvectorofvectors uttframevectors(uttframes); // (wrapper that allows m[j].size() and m[j][i] as required by augmentneighbors()) + n = uttframevectors.size(); + sentendmark[i].push_back(n + tspos); + assert(n == uttframes.cols() && uttref.numframes == n && chunkdata.numframes(uttref.utteranceindex()) == n); + + // copy the frames and class labels + for (size_t t = 0; t < n; t++) // t = time index into source utterance + { + size_t leftextent, rightextent; + // page in the needed range of frames + if (leftcontext[i] == 0 && rightcontext[i] == 0) + { + leftextent = rightextent = augmentationextent(uttframevectors[t].size(), vdim[i]); + } + else + { + leftextent = leftcontext[i]; + rightextent = rightcontext[i]; + } + augmentneighbors(uttframevectors, noboundaryflags, t, leftextent, rightextent, feat[i], t + tspos); + // augmentneighbors(uttframevectors, noboundaryflags, t, feat[i], t + tspos); + } + + // copy the frames and class labels + if (i == 0) + { + auto uttclassids = getclassids(uttref); + std::vector>> uttphoneboudaries; + if (m_generatePhoneBoundaries) + uttphoneboudaries = getphonebound(uttref); + foreach_index(j, uttclassids) + { + for (size_t t = 0; t < n; t++) // t = time index into source utterance + { + if (issupervised()) + { + uids[j][t + tspos] = uttclassids[j][t]; + if (m_generatePhoneBoundaries) + phoneboundaries2[j][t + tspos] = uttphoneboudaries[j][t]; + } + } + + if (!this->lattices.empty()) + { + auto latticepair = chunkdata.getutterancelattice(uttref.utteranceindex()); + latticepairs.push_back(latticepair); + // look up reference + const auto &key = latticepair->getkey(); + if (!allwordtranscripts.empty()) + { + const auto &transcript = allwordtranscripts.find(key)->second; + transcripts.push_back(transcript.words); + } + } + } + } + } + tspos += n; + } + + foreach_index(i, feat) + { + assert(tspos == feat[i].cols()); + } + } + else + { + const size_t sweepts = sweep * _totalframes; // first global frame index for this sweep + const size_t sweepte = sweepts + _totalframes; // and its end + const size_t globalte = std::min(globalts + framesrequested, sweepte); // we return as much as requested, but not exceeding sweep end + mbframes = globalte - globalts; // that's our mb size + + // Perform randomization of the desired frame range + m_frameRandomizer.randomizeFrameRange(globalts, globalte); + + // determine window range + // We enumerate all frames--can this be done more efficiently? + const size_t firstchunk = chunkforframepos(globalts); + const size_t lastchunk = chunkforframepos(globalte - 1); + const size_t windowbegin = randomizedchunks[0][firstchunk].windowbegin; + const size_t windowend = randomizedchunks[0][lastchunk].windowend; + if (verbosity > 0) + fprintf(stderr, "getbatch: getting randomized frames [%d..%d] (%d frames out of %d requested) in sweep %d; chunks [%d..%d] -> chunk window [%d..%d)\n", + (int)globalts, (int)globalte, (int)mbframes, (int)framesrequested, (int)sweep, (int)firstchunk, (int)lastchunk, (int)windowbegin, (int)windowend); + // release all data outside, and page in all data inside + for (size_t k = 0; k < windowbegin; k++) + releaserandomizedchunk(k); + for (size_t k = windowbegin; k < windowend; k++) + if ((k % numsubsets) == subsetnum) // in MPI mode, we skip chunks this way + readfromdisk |= requirerandomizedchunk(k, windowbegin, windowend); // (window range passed in for checking only, redundant here) + for (size_t k = windowend; k < randomizedchunks[0].size(); k++) + releaserandomizedchunk(k); + + // determine the true #frames we return--it is less than mbframes in the case of MPI/data-parallel sub-set mode + // First determine it for all nodes, then pick the min over all nodes, as to give all the same #frames for better load balancing. + // TODO: No, return all; and leave it to caller to redistribute them [Zhijie Yan] + std::vector subsetsizes(numsubsets, 0); + for (size_t i = 0; i < mbframes; i++) // i is input frame index; j < i in case of MPI/data-parallel sub-set mode + { + const frameref &frameref = m_frameRandomizer.randomizedframeref(globalts + i); + subsetsizes[frameref.chunkindex % numsubsets]++; + } + size_t j = subsetsizes[subsetnum]; // return what we have --TODO: we can remove the above full computation again now + const size_t allocframes = std::max(j, (mbframes + numsubsets - 1) / numsubsets); // we leave space for the desired #frames, assuming caller will try to pad them later + + // resize feat and uids + feat.resize(vdim.size()); + uids.resize(classids.size()); + assert(feat.size() == vdim.size()); + assert(feat.size() == randomizedchunks.size()); + foreach_index(i, feat) + { + feat[i].resize(vdim[i], allocframes); + feat[i].shrink(vdim[i], j); + + if (i == 0) + { + foreach_index(k, uids) + { + if (issupervised()) // empty means unsupervised training -> return empty uids + uids[k].resize(j); + else + uids[k].clear(); + latticepairs.clear(); // will push_back() below + transcripts.clear(); + } + } + } + + // return randomized frames for the time range of those utterances + size_t currmpinodeframecount = 0; + for (size_t j2 = 0; j2 < mbframes; j2++) + { + if (currmpinodeframecount >= feat[0].cols()) // MPI/data-parallel mode: all nodes return the same #frames, which is how feat(,) is allocated + break; + + // map to time index inside arrays + const frameref &frameref = m_frameRandomizer.randomizedframeref(globalts + j2); + + // in MPI/data-parallel mode, skip frames that are not in chunks loaded for this MPI node + if ((frameref.chunkindex % numsubsets) != subsetnum) + continue; + + // random utterance + readfromdisk |= requirerandomizedchunk(frameref.chunkindex, windowbegin, windowend); // (this is just a check; should not actually page in anything) + + foreach_index(i, randomizedchunks) + { + const auto &chunk = randomizedchunks[i][frameref.chunkindex]; + const auto &chunkdata = chunk.getchunkdata(); + auto uttframes = chunkdata.getutteranceframes(frameref.utteranceindex()); + matrixasvectorofvectors uttframevectors(uttframes); // (wrapper that allows m[.].size() and m[.][.] as required by augmentneighbors()) + const size_t n = uttframevectors.size(); + assert(n == uttframes.cols() && chunkdata.numframes(frameref.utteranceindex()) == n); + n; + + // copy frame and class labels + const size_t t = frameref.frameindex(); + + size_t leftextent, rightextent; + // page in the needed range of frames + if (leftcontext[i] == 0 && rightcontext[i] == 0) + { + leftextent = rightextent = augmentationextent(uttframevectors[t].size(), vdim[i]); + } + else + { + leftextent = leftcontext[i]; + rightextent = rightcontext[i]; + } + augmentneighbors(uttframevectors, noboundaryflags, t, leftextent, rightextent, feat[i], currmpinodeframecount); + + if (issupervised() && i == 0) + { + auto frameclassids = getclassids(frameref); + foreach_index(k, uids) + uids[k][currmpinodeframecount] = frameclassids[k][t]; + } + } + + currmpinodeframecount++; + } + } + timegetbatch = timergetbatch; + + // this is the number of frames we actually moved ahead in time + framesadvanced = mbframes; + + return readfromdisk; + } bool supportsbatchsubsetting() const override { return true; } bool getbatch(const size_t globalts, - const size_t framesrequested, std::vector &feat, std::vector> &uids, - std::vector> &transcripts, - std::vector> &lattices2, std::vector> &sentendmark, - std::vector> &phoneboundaries2) + const size_t framesrequested, std::vector &feat, std::vector> &uids, + std::vector> &transcripts, + std::vector> &lattices2, std::vector> &sentendmark, + std::vector> &phoneboundaries2) + { size_t dummy; return getbatch(globalts, framesrequested, 0, 1, dummy, feat, uids, transcripts, lattices2, sentendmark, phoneboundaries2); } - double gettimegetbatch() - { - return timegetbatch; - } + + + // alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrixes or a vector of label strings bool getbatch(const size_t /*globalts*/, - const size_t /*framesrequested*/, msra::dbn::matrix & /*feat*/, std::vector & /*uids*/, - std::vector> & /*transcripts*/, - std::vector> & /*latticepairs*/) + const size_t /*framesrequested*/, msra::dbn::matrix & /*feat*/, std::vector & /*uids*/, + std::vector> & /*transcripts*/, + std::vector> & /*latticepairs*/) override { // should never get here RuntimeError("minibatchframesourcemulti: getbatch() being called for single input feature and single output feature, should use minibatchutterancesource instead\n"); @@ -1912,6 +2343,14 @@ class minibatchutterancesourcemulti : public minibatchsource // uids.resize(1); // return getbatch(globalts, framesrequested, feat[0], uids[0], transcripts, latticepairs); } + + + double gettimegetbatch() + { + return timegetbatch; + } + + size_t totalframes() const { diff --git a/Source/Readers/Kaldi2Reader/utterancesourcemulti.h b/Source/Readers/Kaldi2Reader/utterancesourcemulti.h index 77619d6642de..a3eafac9da2a 100644 --- a/Source/Readers/Kaldi2Reader/utterancesourcemulti.h +++ b/Source/Readers/Kaldi2Reader/utterancesourcemulti.h @@ -36,6 +36,8 @@ class minibatchutterancesourcemulti : public minibatchsource // const std::vector> &lattices; const latticesource &lattices; + + // std::vector lattices; // word-level transcripts (for MMI mode when adding best path to lattices) const std::map &allwordtranscripts; // (used for getting word-level transcripts) @@ -158,7 +160,13 @@ class minibatchutterancesourcemulti : public minibatchsource reader.readNoAlloc(utteranceset[i].parsedpath, (const string &) featkind, sampperiod, uttframes); // note: file info here used for checkuing only // page in lattice data if (!latticesource.empty()) - latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols()); + // we currently don't care about kaldi format, so, just to make the compiler happy + // latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols()); + { + std::set specialwordids; + specialwordids.clear(); + latticesource.getlattices(utteranceset[i].key(), lattices[i], uttframes.cols(), specialwordids); + } } // fprintf (stderr, "\n"); if (verbosity) diff --git a/Source/SGDLib/DataReaderHelpers.h b/Source/SGDLib/DataReaderHelpers.h index 3fb8c750a0a6..78cd3457302d 100644 --- a/Source/SGDLib/DataReaderHelpers.h +++ b/Source/SGDLib/DataReaderHelpers.h @@ -46,7 +46,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { bool useParallelTrain, StreamMinibatchInputs& inputMatrices, size_t& actualMBSize, - const MPIWrapperPtr& mpi) + const MPIWrapperPtr& mpi, + size_t& actualNumWords) { // Reading consists of a sequence of Reader API calls: // - GetMinibatch() --fills the inputMatrices and copies the MBLayout from Reader into inputMatrices @@ -71,8 +72,13 @@ namespace Microsoft { namespace MSR { namespace CNTK { auto uids = node->getuidprt(); auto boundaries = node->getboundaryprt(); auto extrauttmap = node->getextrauttmap(); + auto wids = node->getwidprt(); + auto nws = node->getnwprt(); + trainSetDataReader.GetMinibatch4SE(*latticeinput, *uids, *wids, *nws, *boundaries, *extrauttmap); - trainSetDataReader.GetMinibatch4SE(*latticeinput, *uids, *boundaries, *extrauttmap); + actualNumWords = 0; + for (size_t i = 0; i < (*nws).size(); i++) + actualNumWords += (*nws)[i]; } // TODO: move this into shim for the old readers. @@ -284,11 +290,15 @@ namespace Microsoft { namespace MSR { namespace CNTK { private: typedef std::vector> Lattice; typedef std::vector Uid; + typedef std::vector Wid; + typedef std::vector Nw; typedef std::vector ExtrauttMap; typedef std::vector Boundaries; typedef std::vector>* LatticePtr; typedef std::vector* UidPtr; + typedef std::vector* WidPtr; + typedef std::vector* NwPtr; typedef std::vector* ExtrauttMapPtr; typedef std::vector* BoundariesPtr; typedef StreamMinibatchInputs Matrices; @@ -298,6 +308,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { MBLayoutPtr m_MBLayoutCache; Lattice m_LatticeCache; Uid m_uidCache; + Wid m_widCache; + Nw m_nwCache; ExtrauttMap m_extrauttmapCache; Boundaries m_BoundariesCache; shared_ptr> m_netCriterionAccumulator; @@ -313,6 +325,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { Matrices m_netInputMatrixPtr; LatticePtr m_netLatticePtr; UidPtr m_netUidPtr; + WidPtr m_netWidPtr; + NwPtr m_netNwPtr; ExtrauttMapPtr m_netExtrauttMapPtr; BoundariesPtr m_netBoundariesPtr; // we remember the pointer to the learnable Nodes so that we can accumulate the gradient once a sub-minibatch is done @@ -352,7 +366,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { public: SubminibatchDispatcher() - : m_MBLayoutCache(nullptr), m_netLatticePtr(nullptr), m_netExtrauttMapPtr(nullptr), m_netUidPtr(nullptr), m_netBoundariesPtr(nullptr) + : m_MBLayoutCache(nullptr), m_netLatticePtr(nullptr), m_netExtrauttMapPtr(nullptr), m_netUidPtr(nullptr), m_netBoundariesPtr(nullptr), m_netWidPtr(nullptr), m_netNwPtr(nullptr) { } @@ -398,6 +412,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { m_netLatticePtr = node->getLatticePtr(); m_netExtrauttMapPtr = node->getextrauttmap(); m_netUidPtr = node->getuidprt(); + m_netWidPtr = node->getwidprt(); + m_netNwPtr = node->getnwprt(); m_netBoundariesPtr = node->getboundaryprt(); m_hasLattices = true; } @@ -408,6 +424,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { m_netUidPtr = nullptr; m_netBoundariesPtr = nullptr; m_hasLattices = false; + m_netWidPtr = nullptr; + m_netNwPtr = nullptr; } } @@ -444,11 +462,16 @@ namespace Microsoft { namespace MSR { namespace CNTK { m_uidCache.clear(); m_extrauttmapCache.clear(); m_BoundariesCache.clear(); + m_widCache.clear(); + m_nwCache.clear(); + m_LatticeCache = *m_netLatticePtr; m_uidCache = *m_netUidPtr; m_extrauttmapCache = *m_netExtrauttMapPtr; m_BoundariesCache = *m_netBoundariesPtr; + m_widCache = *m_netWidPtr; + m_nwCache = *m_netNwPtr; } // subminibatches are cutted at the parallel sequence level; @@ -495,10 +518,14 @@ namespace Microsoft { namespace MSR { namespace CNTK { BoundariesPtr decimatedBoundaryPtr, /* output: boundary after decimation*/ ExtrauttMapPtr decimatedExtraMapPtr, /* output: extramap after decimation*/ UidPtr decimatedUidPtr, /* output: Uid after decimation*/ + WidPtr decimatedWidPtr, /* output: Wid after decimation*/ + NwPtr decimatedNwPtr, /* output: Nw after decimation*/ const Lattice lattices, /* input: lattices to be decimated */ const Boundaries boundaries, /* input: boundary to be decimated */ const ExtrauttMap extraMaps, /* input: extra map to be decimated */ const Uid uids, /* input: uid to be decimated*/ + const Wid wids, /* input: uid to be decimated*/ + const Nw nws, /* input: uid to be decimated*/ pair parallelSeqRange /* input: what parallel sequence range we are looking at */ ) { @@ -509,12 +536,16 @@ namespace Microsoft { namespace MSR { namespace CNTK { decimatedBoundaryPtr->clear(); decimatedExtraMapPtr->clear(); decimatedUidPtr->clear(); + decimatedWidPtr->clear(); + decimatedNwPtr->clear(); size_t stFrame = 0; + size_t stWord = 0; for (size_t iUtt = 0; iUtt < extraMaps.size(); iUtt++) { size_t numFramesInThisUtterance = lattices[iUtt]->getnumframes(); size_t iParallelSeq = extraMaps[iUtt]; // i-th utterance belongs to iParallelSeq-th parallel sequence + size_t numWordsInThisUtterance = nws[iUtt]; if (iParallelSeq >= parallelSeqStId && iParallelSeq < parallelSeqEnId) { // this utterance has been selected @@ -522,8 +553,11 @@ namespace Microsoft { namespace MSR { namespace CNTK { decimatedBoundaryPtr->insert(decimatedBoundaryPtr->end(), boundaries.begin() + stFrame, boundaries.begin() + stFrame + numFramesInThisUtterance); decimatedUidPtr->insert(decimatedUidPtr->end(), uids.begin() + stFrame, uids.begin() + stFrame + numFramesInThisUtterance); decimatedExtraMapPtr->push_back(extraMaps[iUtt] - parallelSeqStId); + decimatedWidPtr->insert(decimatedWidPtr->end(), wids.begin() + stWord, wids.begin() + stWord + numWordsInThisUtterance); + decimatedNwPtr->push_back(numWordsInThisUtterance); } stFrame += numFramesInThisUtterance; + stWord += numWordsInThisUtterance; } } @@ -539,9 +573,9 @@ namespace Microsoft { namespace MSR { namespace CNTK { { DecimateLattices( /*output */ - m_netLatticePtr, m_netBoundariesPtr, m_netExtrauttMapPtr, m_netUidPtr, + m_netLatticePtr, m_netBoundariesPtr, m_netExtrauttMapPtr, m_netUidPtr, m_netWidPtr, m_netNwPtr, /*input to be decimated */ - m_LatticeCache, m_BoundariesCache, m_extrauttmapCache, m_uidCache, + m_LatticeCache, m_BoundariesCache, m_extrauttmapCache, m_uidCache, m_widCache, m_nwCache, /* what range we want ? */ seqRange); } diff --git a/Source/SGDLib/PostComputingActions.cpp b/Source/SGDLib/PostComputingActions.cpp index d67a7e3b3efd..03fd825e2fe1 100644 --- a/Source/SGDLib/PostComputingActions.cpp +++ b/Source/SGDLib/PostComputingActions.cpp @@ -81,6 +81,7 @@ void PostComputingActions::BatchNormalizationStatistics(IDataReader * let bnNode = static_pointer_cast>(node); size_t actualMBSize = 0; + size_t actualNumWords = 0; LOGPRINTF(stderr, "Estimating Statistics --> %ls\n", bnNode->GetName().c_str()); @@ -90,8 +91,7 @@ void PostComputingActions::BatchNormalizationStatistics(IDataReader * { // during the bn stat, dataRead must be ensured bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork(*dataReader, m_net, - nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi); - + nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi, actualNumWords); if (!wasDataRead) LogicError("DataRead Failure in batch normalization statistics"); ComputationNetwork::BumpEvalTimeStamp(featureNodes); diff --git a/Source/SGDLib/SGD.cpp b/Source/SGDLib/SGD.cpp index a924a92560cf..7190851b91f1 100644 --- a/Source/SGDLib/SGD.cpp +++ b/Source/SGDLib/SGD.cpp @@ -262,7 +262,6 @@ void SGD::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net, } } } - std::vector additionalNodesToEvaluate; // Do not include the output nodes in the matrix sharing structure when using forward value matrix @@ -273,13 +272,10 @@ void SGD::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net, auto& outputNodes = net->OutputNodes(); additionalNodesToEvaluate.insert(additionalNodesToEvaluate.end(), outputNodes.cbegin(), outputNodes.cend()); } - auto preComputeNodesList = net->GetNodesRequiringPreComputation(); additionalNodesToEvaluate.insert(additionalNodesToEvaluate.end(), preComputeNodesList.cbegin(), preComputeNodesList.cend()); - // allocate memory for forward and backward computation net->AllocateAllMatrices(evaluationNodes, additionalNodesToEvaluate, criterionNodes[0]); // TODO: use criterionNodes.front() throughout - // get feature and label nodes into an array of matrices that will be passed to GetMinibatch() // TODO: instead, remember the nodes directly, to be able to handle both float and double nodes; current version will crash for mixed networks StreamMinibatchInputs* inputMatrices = new StreamMinibatchInputs(); @@ -293,14 +289,12 @@ void SGD::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net, for (const auto & node : nodes) inputMatrices->AddInput(node->NodeName(), node->ValuePtr(), node->GetMBLayout(), node->GetSampleLayout()); } - // get hmm file for sequence training bool isSequenceTrainingCriterion = (criterionNodes[0]->OperationName() == L"SequenceWithSoftmax"); if (isSequenceTrainingCriterion) { // SequenceWithSoftmaxNode* node = static_cast*>(criterionNodes[0]); auto node = dynamic_pointer_cast>(criterionNodes[0]); - auto hmm = node->gethmm(); trainSetDataReader->GetHmmData(hmm); } @@ -329,7 +323,6 @@ void SGD::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net, // allocate memory for forward computation refNet->AllocateAllMatrices({refNode}, {}, nullptr); } - // initializing weights and gradient holder // only one criterion so far TODO: support multiple ones? auto& learnableNodes = net->LearnableParameterNodes(criterionNodes[0]); @@ -483,7 +476,8 @@ void SGD::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net, if (isSequenceTrainingCriterion) { ComputationNetwork::SetSeqParam(net, criterionNodes[0], m_hSmoothingWeight, m_frameDropThresh, m_doReferenceAlign, - m_seqGammarCalcAMF, m_seqGammarCalcLMF, m_seqGammarCalcWP, m_seqGammarCalcbMMIFactor, m_seqGammarCalcUsesMBR); + m_seqGammarCalcAMF, m_seqGammarCalcLMF, m_seqGammarCalcWP, m_seqGammarCalcbMMIFactor, m_seqGammarCalcUsesMBR, + m_seqGammarCalcUseEMBR, m_EMBRUnit, m_numPathsEMBR, m_enforceValidPathEMBR, m_getPathMethodEMBR, m_showWERMode, m_excludeSpecialWords, m_wordNbest, m_useAccInNbest, m_accWeightInNbest, m_numRawPathsEMBR); } // Multiverso Warpper for ASGD logic init @@ -660,8 +654,11 @@ void SGD::TrainOrAdaptModel(int startEpoch, ComputationNetworkPtr net, learnableNodes, smoothedGradients, smoothedCounts, epochCriterion, epochEvalErrors, "", SIZE_MAX, totalMBsSeen, tensorBoardWriter, startEpoch); - totalTrainingSamplesSeen += epochCriterion.second; // aggregate #training samples, for logging purposes only - + + if(!m_seqGammarCalcUseEMBR) + totalTrainingSamplesSeen += epochCriterion.second; + else + totalTrainingSamplesSeen += epochEvalErrors[0].second; timer.Stop(); double epochTime = timer.ElapsedSeconds(); @@ -1167,10 +1164,11 @@ size_t SGD::TrainOneEpoch(ComputationNetworkPtr net, // get minibatch // TODO: is it guaranteed that the GPU is already completed at this point, is it safe to overwrite the buffers? size_t actualMBSize = 0; + size_t actualNumWords = 0; auto profGetMinibatch = ProfilerTimeBegin(); bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork(*trainSetDataReader, net, criterionNodes[0], - useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize, m_mpi); + useDistributedMBReading, useParallelTrain, *inputMatrices, actualMBSize, m_mpi, actualNumWords); if (maxNumSamplesExceeded) // Dropping data. wasDataRead = false; @@ -1294,7 +1292,11 @@ size_t SGD::TrainOneEpoch(ComputationNetworkPtr net, // accumulate criterion values (objective, eval) assert(wasDataRead || numSamplesWithLabelOfNetwork == 0); // criteria are in Value()(0,0), we accumulate into another 1x1 Matrix (to avoid having to pull the values off the GPU) - localEpochCriterion.Add(0, numSamplesWithLabelOfNetwork); + if(!m_seqGammarCalcUseEMBR) + localEpochCriterion.Add(0, numSamplesWithLabelOfNetwork); + else + localEpochCriterion.Add(0, actualNumWords); + for (size_t i = 0; i < evaluationNodes.size(); i++) localEpochEvalErrors.Add(i, numSamplesWithLabelOfNetwork); } @@ -1326,14 +1328,22 @@ size_t SGD::TrainOneEpoch(ComputationNetworkPtr net, } // hoist the criterion into CPU space for all-reduce - localEpochCriterion.Assign(0, numSamplesWithLabelOfNetwork); + + if (!m_seqGammarCalcUseEMBR) + localEpochCriterion.Assign(0, numSamplesWithLabelOfNetwork); + else + localEpochCriterion.Assign(0, actualNumWords); + for (size_t i = 0; i < evaluationNodes.size(); i++) localEpochEvalErrors.Assign(i, numSamplesWithLabelOfNetwork); // copy all values to be aggregated into the header m_gradHeader->numSamples = aggregateNumSamples; m_gradHeader->criterion = localEpochCriterion.GetCriterion(0).first; - m_gradHeader->numSamplesWithLabel = localEpochCriterion.GetCriterion(0).second; // same as aggregateNumSamplesWithLabel + if (!m_seqGammarCalcUseEMBR) + m_gradHeader->numSamplesWithLabel = localEpochCriterion.GetCriterion(0).second; // same as aggregateNumSamplesWithLabel + else + m_gradHeader->numSamplesWithLabel = numSamplesWithLabelOfNetwork; assert(m_gradHeader->numSamplesWithLabel == aggregateNumSamplesWithLabel); for (size_t i = 0; i < evaluationNodes.size(); i++) m_gradHeader->evalErrors[i] = localEpochEvalErrors.GetCriterion(i); @@ -1482,7 +1492,9 @@ size_t SGD::TrainOneEpoch(ComputationNetworkPtr net, // epochCriterion aggregates over entire epoch, but we only show difference to last time we logged EpochCriterion epochCriterionSinceLastLogged = epochCriterion - epochCriterionLastLogged; let trainLossSinceLastLogged = epochCriterionSinceLastLogged.Average(); // TODO: Check whether old trainSamplesSinceLastLogged matches this ^^ difference - let trainSamplesSinceLastLogged = (int)epochCriterionSinceLastLogged.second; + + // for EMBR, epochCriterionSinceLastLogged.second stores the #words rather than #frames + let trainSamplesSinceLastLogged = (m_seqGammarCalcUseEMBR? (int)(epochEvalErrors[0].second - epochEvalErrorsLastLogged[0].second) : (int)epochCriterionSinceLastLogged.second); // determine progress in percent int mbProgNumPrecision = 2; @@ -1777,7 +1789,9 @@ bool SGD::PreCompute(ComputationNetworkPtr net, const size_t numIterationsBeforePrintingProgress = 100; size_t numItersSinceLastPrintOfProgress = 0; size_t actualMBSizeDummy; - while (DataReaderHelpers::GetMinibatchIntoNetwork(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy, m_mpi)) + size_t actualNumWordsDummy; + + while (DataReaderHelpers::GetMinibatchIntoNetwork(*trainSetDataReader, net, nullptr, false, false, *inputMatrices, actualMBSizeDummy, m_mpi, actualNumWordsDummy)) { // TODO: move these into GetMinibatchIntoNetwork() --but those are passed around; necessary? Can't we get them from 'net'? ComputationNetwork::BumpEvalTimeStamp(featureNodes); @@ -2987,6 +3001,40 @@ SGDParams::SGDParams(const ConfigRecordType& configSGD, size_t sizeofElemType) m_frameDropThresh = configSGD(L"frameDropThresh", 1e-10); m_doReferenceAlign = configSGD(L"doReferenceAlign", false); m_seqGammarCalcUsesMBR = configSGD(L"seqGammarUsesMBR", false); + + m_seqGammarCalcUseEMBR = configSGD(L"seqGammarUseEMBR", false); + m_EMBRUnit = configSGD(L"EMBRUnit", "word"); + + m_numPathsEMBR = configSGD(L"numPathsEMBR", (size_t)100); + // enforce the path starting with sentence start + m_enforceValidPathEMBR = configSGD(L"enforceValidPathEMBR", false); + //could be sampling or nbest + m_getPathMethodEMBR = configSGD(L"getPathMethodEMBR", "sampling"); + // could be average or onebest + m_showWERMode = configSGD(L"showWERMode", "average"); + + // don't include path that has special words if true + m_excludeSpecialWords = configSGD(L"excludeSpecialWords", false); + + // true then, we force the nbest has different word sequence + m_wordNbest = configSGD(L"wordNbest", false); + m_useAccInNbest = configSGD(L"useAccInNbest", false); + m_accWeightInNbest = configSGD(L"accWeightInNbest", 1.0f); + + m_numRawPathsEMBR = configSGD(L"numRawPathsEMBR", (size_t)100); + + if (!m_useAccInNbest) + { + if (m_numRawPathsEMBR > m_numPathsEMBR) + { + fprintf(stderr, "SGDParams: WARNING: we do not use acc in nbest, so no need to make numRawPathsEMBR = %d larger than numPathsEMBR = %d \n", (int)m_numRawPathsEMBR, (int)m_numPathsEMBR); + } + } + if (m_getPathMethodEMBR == "sampling" && m_showWERMode == "onebest") + { + RuntimeError("There is no way to show onebest WER in sampling based EMBR"); + } + m_seqGammarCalcAMF = configSGD(L"seqGammarAMF", 14.0); m_seqGammarCalcLMF = configSGD(L"seqGammarLMF", 14.0); m_seqGammarCalcbMMIFactor = configSGD(L"seqGammarBMMIFactor", 0.0); diff --git a/Source/SGDLib/SGD.h b/Source/SGDLib/SGD.h index 04eb55073143..f146e61038c8 100644 --- a/Source/SGDLib/SGD.h +++ b/Source/SGDLib/SGD.h @@ -340,6 +340,18 @@ struct SGDParams : public ScriptableObjects::Object double m_seqGammarCalcbMMIFactor; bool m_seqGammarCalcUsesMBR; + bool m_seqGammarCalcUseEMBR; + string m_EMBRUnit; //unit could be: word, phone, state (we all compute edit distance + bool m_enforceValidPathEMBR; + string m_getPathMethodEMBR; + size_t m_numPathsEMBR; // number of sampled paths + string m_showWERMode; // number of sampled paths + bool m_excludeSpecialWords; + bool m_wordNbest; + bool m_useAccInNbest; + float m_accWeightInNbest; + size_t m_numRawPathsEMBR; + // decide whether should apply regularization into BatchNormalizationNode // true: disable Regularization // false: enable Regularization (default) diff --git a/Source/SGDLib/SimpleEvaluator.h b/Source/SGDLib/SimpleEvaluator.h index 2941c26d4e26..ed669e51fdf9 100644 --- a/Source/SGDLib/SimpleEvaluator.h +++ b/Source/SGDLib/SimpleEvaluator.h @@ -120,7 +120,8 @@ class SimpleEvaluator for (;;) { size_t actualMBSize = 0; - bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork(*dataReader, m_net, nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi); + size_t actualNumWords = 0; + bool wasDataRead = DataReaderHelpers::GetMinibatchIntoNetwork(*dataReader, m_net, nullptr, useDistributedMBReading, useParallelTrain, inputMatrices, actualMBSize, m_mpi, actualNumWords); // in case of distributed reading, we do a few more loops until all ranks have completed // end of epoch if (!wasDataRead && (!useDistributedMBReading || noMoreSamplesToProcess)) diff --git a/Source/SGDLib/SimpleOutputWriter.h b/Source/SGDLib/SimpleOutputWriter.h index 93d46ed8d292..5d0cca529f35 100644 --- a/Source/SGDLib/SimpleOutputWriter.h +++ b/Source/SGDLib/SimpleOutputWriter.h @@ -62,7 +62,8 @@ class SimpleOutputWriter const size_t numIterationsBeforePrintingProgress = 100; size_t numItersSinceLastPrintOfProgress = 0; size_t actualMBSize; - while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr)) + size_t actualNumWords; + while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr, actualNumWords)) { ComputationNetwork::BumpEvalTimeStamp(inputNodes); m_net->ForwardProp(outputNodes); @@ -230,7 +231,8 @@ class SimpleOutputWriter char formatChar = !formattingOptions.isCategoryLabel ? 'f' : !formattingOptions.labelMappingFile.empty() ? 's' : 'u'; std::string valueFormatString = "%" + formattingOptions.precisionFormat + formatChar; // format string used in fprintf() for formatting the values - for (size_t numMBsRun = 0; DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr); numMBsRun++) + size_t actualNumWords; + for (size_t numMBsRun = 0; DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize, nullptr, actualNumWords); numMBsRun++) { ComputationNetwork::BumpEvalTimeStamp(inputNodes); m_net->ForwardProp(outputNodes); diff --git a/Source/SequenceTrainingLib/gammacalculation.h b/Source/SequenceTrainingLib/gammacalculation.h index 80f19d98aff9..1ce2a3626902 100644 --- a/Source/SequenceTrainingLib/gammacalculation.h +++ b/Source/SequenceTrainingLib/gammacalculation.h @@ -11,6 +11,7 @@ #include #include +#include #pragma warning(disable : 4127) // conditional expression is constant namespace msra { namespace lattices { @@ -22,6 +23,17 @@ struct SeqGammarCalParam double wp; double bMMIfactor; bool sMBRmode; + bool EMBR; + std::string EMBRUnit; + size_t numPathsEMBR; + bool enforceValidPathEMBR; + std::string getPathMethodEMBR; + std::string showWERMode; + bool excludeSpecialWords; + bool wordNbest; + bool useAccInNbest; + float accWeightInNbest; + size_t numRawPathsEMBR; SeqGammarCalParam() { amf = 14.0; @@ -29,6 +41,17 @@ struct SeqGammarCalParam wp = 0.0; bMMIfactor = 0.0; sMBRmode = false; + EMBR = false; + EMBRUnit = "word"; + numPathsEMBR = 100; + enforceValidPathEMBR = false; + getPathMethodEMBR = "sampling"; + showWERMode = "average"; + excludeSpecialWords = false; + wordNbest = false; + useAccInNbest = false; + accWeightInNbest = 1.0; + numRawPathsEMBR = 100; } }; @@ -82,6 +105,25 @@ class GammaCalculation seqsMBRmode = gammarParam.sMBRmode; boostmmifactor = (float) gammarParam.bMMIfactor; } + void SetGammarCalculationParamsEMBR(const SeqGammarCalParam& gammarParam) + { + lmf = (float) gammarParam.lmf; + amf = (float) gammarParam.amf; + wp = (float) gammarParam.wp; + seqsMBRmode = gammarParam.sMBRmode; + boostmmifactor = (float) gammarParam.bMMIfactor; + EMBR = gammarParam.EMBR; + EMBRUnit = gammarParam.EMBRUnit; + numPathsEMBR = gammarParam.numPathsEMBR; + enforceValidPathEMBR = gammarParam.enforceValidPathEMBR; + getPathMethodEMBR = gammarParam.getPathMethodEMBR; + showWERMode = gammarParam.showWERMode; + excludeSpecialWords = gammarParam.excludeSpecialWords; + wordNbest = gammarParam.wordNbest; + useAccInNbest = gammarParam.useAccInNbest; + accWeightInNbest = gammarParam.accWeightInNbest; + numRawPathsEMBR = gammarParam.numRawPathsEMBR; + } // ======================================== // Sec. 3 calculation functions @@ -91,7 +133,7 @@ class GammaCalculation const Microsoft::MSR::CNTK::Matrix& loglikelihood, Microsoft::MSR::CNTK::Matrix& labels, Microsoft::MSR::CNTK::Matrix& gammafromlattice, - std::vector& uids, std::vector& boundaries, + std::vector& uids, std::vector& wids, std::vector& nws, std::vector& boundaries, size_t samplesInRecurrentStep, /* numParallelUtterance ? */ std::shared_ptr pMBLayout, std::vector& extrauttmap, @@ -128,9 +170,11 @@ class GammaCalculation size_t mapi = 0; // parallel-sequence index for utterance [i] // cal gamma for each utterance size_t ts = 0; + size_t ws = 0; for (size_t i = 0; i < lattices.size(); i++) { const size_t numframes = lattices[i]->getnumframes(); + const short numwords = nws[i]; msra::dbn::matrixstripe predstripe(pred, ts, numframes); // logLLs for this utterance msra::dbn::matrixstripe dengammasstripe(dengammas, ts, numframes); // denominator gammas @@ -186,6 +230,7 @@ class GammaCalculation } array_ref uidsstripe(&uids[ts], numframes); + std::vector widsstripe(wids.begin() + ws, wids.begin() + ws + numwords); if (doreferencealign) { @@ -204,13 +249,12 @@ class GammaCalculation numavlogp /= numframes; // auto_timer dengammatimer; + double denavlogp = lattices[i]->second.forwardbackward(parallellattice, (const msra::math::ssematrixbase&) predstripe, (const msra::asr::simplesenonehmm&) m_hset, (msra::math::ssematrixbase&) dengammasstripe, (msra::math::ssematrixbase&) gammasbuffer /*empty, not used*/, - lmf, wp, amf, boostmmifactor, seqsMBRmode, uidsstripe, boundariesstripe); - - objectValue += (ElemType)((numavlogp - denavlogp) * numframes); - + lmf, wp, amf, boostmmifactor, seqsMBRmode, EMBR, EMBRUnit, numPathsEMBR, enforceValidPathEMBR, getPathMethodEMBR, showWERMode, excludeSpecialWords, wordNbest, useAccInNbest, accWeightInNbest, numRawPathsEMBR, uidsstripe, widsstripe, boundariesstripe); + objectValue += (ElemType)(denavlogp*numwords); if (samplesInRecurrentStep == 1) { tempmatrix = gammafromlattice.ColumnSlice(ts, numframes); @@ -244,8 +288,8 @@ class GammaCalculation } if (samplesInRecurrentStep > 1) validframes[mapi] += numframes; // advance the cursor within the parallel sequence - fprintf(stderr, "dengamma value %f\n", denavlogp); ts += numframes; + ws += numwords; } functionValues.SetValue(objectValue); } @@ -510,6 +554,18 @@ class GammaCalculation float boostmmifactor; bool seqsMBRmode; + bool EMBR; + std::string EMBRUnit; + size_t numPathsEMBR; + bool enforceValidPathEMBR; + std::string getPathMethodEMBR; + std::string showWERMode; + bool excludeSpecialWords; + bool wordNbest; + bool useAccInNbest; + float accWeightInNbest; + size_t numRawPathsEMBR; + private: std::unique_ptr m_cudaAllocator; std::shared_ptr m_intermediateCUDACopyBuffer; diff --git a/Source/SequenceTrainingLib/latticeforwardbackward.cpp b/Source/SequenceTrainingLib/latticeforwardbackward.cpp index 2dbb244bf930..f269d152aac1 100644 --- a/Source/SequenceTrainingLib/latticeforwardbackward.cpp +++ b/Source/SequenceTrainingLib/latticeforwardbackward.cpp @@ -504,6 +504,18 @@ double lattice::forwardbackwardlattice(const std::vector &edgeacscores, p { double totalfwscore = parallelforwardbackwardlattice(parallelstate, edgeacscores, thisedgealignments, lmf, wp, amf, boostingfactor, logpps, logalphas, logbetas, sMBRmode, uids, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal); + parallelstate.getlogbetas(logbetas); + if (nodes.size() != logbetas.size()) + { + // it is possible if #define TWO_CHANNEL in parallelforwardbackward.cpp: in which case, logbetas will be doulbe the size of (nodes) + if (logbetas.size() != (nodes.size() * 2)) + { + RuntimeError("forwardbackwardlattice: logbetas size is not equal or twice of node size, logbetas.size() = %d, nodes.size() = %d", int(logbetas.size()), int(nodes.size())); + } + + //only taket the first half of the data + logbetas.erase(logbetas.begin() + nodes.size(), logbetas.begin() + logbetas.size()); + } return totalfwscore; } // if we get here, we have no CUDA, and do it the good ol' way @@ -647,6 +659,373 @@ double lattice::forwardbackwardlattice(const std::vector &edgeacscores, p return totalfwscore; } +void lattice::constructnodenbestoken(std::vector &tokenlattice, const bool wordNbest, size_t numtokens2keep, size_t nidx) const +{ + std::map>::iterator mp_itr; + std::map> mp_wid_tokenidx; + std::map>::iterator mp_itr1; + size_t count; + bool done; + TokenInfo tokeninfo; + uint64_t wid; + vector vt_tokenidx; + + if (wordNbest) mp_wid_tokenidx.clear(); + + count = 0; + done = false; + // Sometime,s numtokens is larger than numPathsEMBR. if , keep tokens to be numPathsEMBR + + for (mp_itr = tokenlattice[nidx].mp_score_token_infos.begin(); mp_itr != tokenlattice[nidx].mp_score_token_infos.end(); mp_itr++) + { + for (size_t i = 0; i < mp_itr->second.size(); i++) + { + tokeninfo.prev_edge_index = mp_itr->second[i].prev_edge_index; + tokeninfo.prev_token_index = mp_itr->second[i].prev_token_index; + tokeninfo.score = mp_itr->second[i].path_score; + + if (wordNbest) + { + wid = nodes[edges[tokeninfo.prev_edge_index].S].wid; + mp_itr1 = mp_wid_tokenidx.find(wid); + + bool different = true; + + if (mp_itr1 == mp_wid_tokenidx.end()) + { + // the wid does not exist in previous tokens of this node, so it is a path with different word sequence + vt_tokenidx.clear(); + vt_tokenidx.push_back(count); + mp_wid_tokenidx.insert(pair>(wid, vt_tokenidx)); + } + else + { + for (size_t j = 0; j < mp_itr1->second.size(); j++) + { + + size_t oldnodeidx, oldtokenidx, newnodeidx, newtokenidx; + + oldnodeidx = edges[tokenlattice[nidx].vt_nbest_tokens[mp_itr1->second[j]].prev_edge_index].S; + oldtokenidx = tokenlattice[nidx].vt_nbest_tokens[mp_itr1->second[j]].prev_token_index; + newnodeidx = edges[tokeninfo.prev_edge_index].S; newtokenidx = tokeninfo.prev_token_index; + + + while (1) + { + if (nodes[oldnodeidx].wid != nodes[newnodeidx].wid) break; + if (oldnodeidx == newnodeidx) + { + if (oldtokenidx == newtokenidx) different = false; + break; + } + + if (oldnodeidx == 0 || newnodeidx == 0) + { + fprintf(stderr, "nbestlatticeEMBR: WARNING: should not come her, oldnodeidx = %d, newnodeidx = %d\n", int(oldnodeidx), int(newnodeidx)); + break; + } + size_t tmpnodeix, tmptokenidx; + + + tmpnodeix = edges[tokenlattice[oldnodeidx].vt_nbest_tokens[oldtokenidx].prev_edge_index].S; + tmptokenidx = tokenlattice[oldnodeidx].vt_nbest_tokens[oldtokenidx].prev_token_index; + oldnodeidx = tmpnodeix; oldtokenidx = tmptokenidx; + + + tmpnodeix = edges[tokenlattice[newnodeidx].vt_nbest_tokens[newtokenidx].prev_edge_index].S; + tmptokenidx = tokenlattice[newnodeidx].vt_nbest_tokens[newtokenidx].prev_token_index; + newnodeidx = tmpnodeix; newtokenidx = tmptokenidx; + } + if (!different) break; + } + + if (different) + { + mp_itr1->second.push_back(count); + } + } + + if (different) + { + tokenlattice[nidx].vt_nbest_tokens.push_back(tokeninfo); + count++; + } + } + else + { + tokenlattice[nidx].vt_nbest_tokens.push_back(tokeninfo); + count++; + } + + if (count >= numtokens2keep) + { + done = true; + break; + } + } + if (done) break; + } + + // free the space. + tokenlattice[nidx].mp_score_token_infos.clear(); + +} +float compute_wer(vector &ref, vector &rec) +{ + short ** mat; + size_t i, j; + + mat = new short*[rec.size() + 1]; + for (i = 0; i <= rec.size(); i++) mat[i] = new short[ref.size() + 1]; + + for (i = 0; i <= rec.size(); i++) mat[i][0] = short(i); + for (j = 1; j <= ref.size(); j++) mat[0][j] = short(j); + + for (i = 1; i <= rec.size(); i++) + for (j = 1; j <= ref.size(); j++) + { + mat[i][j] = mat[i - 1][j - 1]; + + if (rec[i - 1] != ref[j - 1]) + { + + if ((mat[i - 1][j]) < mat[i][j]) mat[i][j] = mat[i - 1][j]; + if ((mat[i][j - 1]) < mat[i][j]) mat[i][j] = mat[i][j - 1]; + mat[i][j] ++; + } + } + float wer = float(mat[rec.size()][ref.size()]) / ref.size(); + for (i = 0; i < rec.size(); i++) delete[] mat[i]; + delete[] mat; + return wer; +} + + + +double lattice::nbestlatticeEMBR(const std::vector &edgeacscores, parallelstate ¶llelstate, std::vector &tokenlattice, const size_t numtokens, const bool enforceValidPathEMBR, const bool excludeSpecialWords, + const float lmf, const float wp, const float amf, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numPathsEMBR, std::vector wids) const +{ // ^^ TODO: remove this + // --- hand off to parallelized (CUDA) implementation if available + std::map>::iterator mp_itr; + size_t numtokens2keep; + // TODO: support parallel state + parallelstate; + PrevTokenInfo prevtokeninfo; + std::vector vt_prevtokeninfo; + + // if we get here, we have no CUDA, and do it the good ol' way + + // allocate return values + tokenlattice.resize(nodes.size()); + + tokenlattice[0].vt_nbest_tokens.resize(1); + tokenlattice[0].vt_nbest_tokens[0].score = 0.0f; + tokenlattice[0].vt_nbest_tokens[0].prev_edge_index = 0; + tokenlattice[0].vt_nbest_tokens[0].prev_token_index = 0; + // forward pass + foreach_index(j, edges) + { + const auto &e = edges[j]; + if (enforceValidPathEMBR) + { + if (e.S == 0 && nodes[e.E].wid != 1) continue; + } + if (excludeSpecialWords) + { + // 0~4 is: !NULL, , , !sent_start, and !sent_end + if (nodes[e.E].wid > 4) + { + if (is_special_words[e.E]) continue; + } + if (nodes[e.S].wid > 4) + { + if (is_special_words[e.S]) continue; + } + + } + + if (tokenlattice[e.S].mp_score_token_infos.size() != 0) + { + //sanity check + if(tokenlattice[e.S].vt_nbest_tokens.size() != 0) + RuntimeError("nbestlatticeEMBR: node = %d, mp_score_token_infos.size() = %d, vt_nbest_tokens.size() = %d, both are not 0!", int(e.S), int(tokenlattice[e.S].mp_score_token_infos.size()), int(tokenlattice[e.S].vt_nbest_tokens.size())); + + + + // Sometime,s numtokens is larger than numPathsEMBR. if , keep tokens to be numPathsEMBR + + if (nodes[e.S].wid == 2) numtokens2keep = numPathsEMBR; + else numtokens2keep = numtokens; + + constructnodenbestoken(tokenlattice, wordNbest, numtokens2keep, e.S); + + } + + if (tokenlattice[e.S].vt_nbest_tokens.size() == 0) + { + // it is possible to happen, when you exclude specialwords + continue; + } + prevtokeninfo.prev_edge_index = j; + + const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned + + for (size_t i = 0; i < tokenlattice[e.S].vt_nbest_tokens.size(); i++) + { + prevtokeninfo.prev_token_index = i; + + double pathscore = tokenlattice[e.S].vt_nbest_tokens[i].score + edgescore; + + prevtokeninfo.path_score = pathscore; + + + if (useAccInNbest && nodes[e.E].wid == 2) + { + // add the wegithed path Accuracy into path score + + std::vector path, path_ids; // stores the edges in the path + + size_t curnodeidx, curtokenidx, prevtokenidx, prevnodeidx; + // ignore the edge with ending node in the path, as will anyway not be used for WER computation + path.clear(); // store the edge sequence of the path + path_ids.clear(); // store the wid sequence of the path + curnodeidx = e.S; + curtokenidx = i; + while (curnodeidx != 0) + { + path.insert(path.begin(), tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_edge_index); + + prevtokenidx = tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_token_index; + prevnodeidx = edges[tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_edge_index].S; + + curnodeidx = prevnodeidx; + curtokenidx = prevtokenidx; + } + + + for (size_t k = 0; k < path.size(); k++) + { + if (k == 0) + { + if (!is_special_words[edges[path[k]].S]) path_ids.push_back(nodes[edges[path[k]].S].wid); + } + if (!is_special_words[edges[path[k]].E]) path_ids.push_back(nodes[edges[path[k]].E].wid); + } + + float wer = compute_wer(wids, path_ids); + // will favor the path with better WER + pathscore -= double(accWeightInNbest*wer); + + // If you only want WER to affect the selection of Nbest, disable the below line. If you aslo want the WER as weight in error computation, enable this line + prevtokeninfo.path_score = pathscore; + } + + mp_itr = tokenlattice[e.E].mp_score_token_infos.find(pathscore); + if (mp_itr != tokenlattice[e.E].mp_score_token_infos.end()) + { + mp_itr->second.push_back(prevtokeninfo); + } + else + { + vt_prevtokeninfo.clear(); + vt_prevtokeninfo.push_back(prevtokeninfo); + tokenlattice[e.E].mp_score_token_infos.insert(std::pair>(pathscore, vt_prevtokeninfo)); + } + + } + } + // for the last node, which is or !NULL (!NULL if you do not merge numerator lattice into denominator lattice) + numtokens2keep = numPathsEMBR; + constructnodenbestoken(tokenlattice, wordNbest, numtokens2keep, tokenlattice.size() - 1); + + + double bestscore; + if (tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens.size() == 0) + { + if (!excludeSpecialWords) RuntimeError("nbestlatticeEMBR: no token survive while excludeSpecialWords is false"); + else bestscore = LOGZERO; + + } + else bestscore = tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens[0].score; + + + if (islogzero(bestscore)) + { + + fprintf(stderr, "nbestlatticeEMBR: WARNING: best score is logzero in lattice \n"); + return LOGZERO; // failed, do not use resulting matrix + } + + + return bestscore; +} + +// --------------------------------------------------------------------------- +// backwardlatticeEMBR() -- lattice-level backward +// +// This computes per-node betas for EMBR +// --------------------------------------------------------------------------- + +double lattice::backwardlatticeEMBR(const std::vector &edgeacscores, parallelstate ¶llelstate, std::vector &edgelogbetas, std::vector &logbetas, + const float lmf, const float wp, const float amf) const +{ // ^^ TODO: remove this + // --- hand off to parallelized (CUDA) implementation if available + if (parallelstate.enabled()) + { + double totalbwscore = parallelbackwardlatticeEMBR(parallelstate, edgeacscores, lmf, wp, amf, edgelogbetas, logbetas); + + parallelstate.getlogbetas(logbetas); + parallelstate.getedgelogbetas(edgelogbetas); + if (nodes.size() != logbetas.size()) + { + // it is possible if #define TWO_CHANNEL in parallelforwardbackward.cpp: in which case, logbetas will be doulbe the size of (nodes) + if (logbetas.size() != (nodes.size() * 2)) + { + RuntimeError("forwardbackwardlattice: logbetas size is not equal or twice of node size, logbetas.size() = %d, nodes.size() = %d", int(logbetas.size()), int(nodes.size())); + } + + //only taket the first half of the data + logbetas.erase(logbetas.begin() + nodes.size(), logbetas.begin() + logbetas.size()); + } + + + return totalbwscore; + } + // if we get here, we have no CUDA, and do it the good ol' way + + // allocate return values + + logbetas.assign(nodes.size(), LOGZERO); + logbetas.back() = 0.0f; + + edgelogbetas.assign(edges.size(), LOGZERO); + + + // backward pass + // this also computes the word posteriors on the fly, since we are at it + for (size_t j = edges.size() - 1; j + 1 > 0; j--) + { + const auto &e = edges[j]; + const double inscore = logbetas[e.E]; + const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; + const double pathscore = inscore + edgescore; + + edgelogbetas[j] = pathscore; + + logadd(logbetas[e.S], pathscore); + + } + + const double totalbwscore = logbetas.front(); + + if (islogzero(totalbwscore)) + { + fprintf(stderr, "backwardlatticeEMBR: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int)nodes.size(), (int)edges.size()); + return LOGZERO; // failed, do not use resulting matrix + } + + return totalbwscore; +} // --------------------------------------------------------------------------- // forwardbackwardlatticesMBR() -- compute expected frame-accuracy counts, // both the conditioned one (corresponding to c(q) in Dan Povey's thesis) @@ -997,6 +1376,13 @@ void lattice::forwardbackwardalign(parallelstate ¶llelstate, } } } + + // make sure thisedgealignment has values for later CPU use + if (parallelstate.enabled()) + { + parallelstate.copyalignments(thisedgealignments); + parallelstate.getedgeacscores(edgeacscores); + } } // compute the error signal for sMBR mode @@ -1043,6 +1429,301 @@ void lattice::sMBRerrorsignal(parallelstate ¶llelstate, } } +// compute the error signal for sMBR mode +size_t sample_from_cumulative_prob(const std::vector &cumulative_prob) +{ + if (cumulative_prob.size() < 1) + { + RuntimeError("sample_from_cumulative_prob: the number of bins is 0 \n"); + } + double rand_prob = (double)rand() / (double)RAND_MAX * cumulative_prob.back(); + for (size_t i = 0; i < cumulative_prob.size() - 1; i++) + { + if (rand_prob <= cumulative_prob[i]) return i; + } + return cumulative_prob.size() - 1; +} + +void lattice::EMBRsamplepaths(const std::vector &edgelogbetas, + const std::vector &logbetas, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const bool excludeSpecialWords, std::vector> & vt_paths) const +{ + // In mp_node_ocp, key is the node id, and value stores the outgoing cumulative locally normalized probability. e.g., if the outgoing probabilities of the node are 0.3 0.1 0.6, the ocp stores: 0.3 0.4 1.0. + // This serves as a cache to avoid recomputation if sampling the same node twice + std::map> mp_node_ocp; + std::map>::iterator mp_itr; + std::vector path; // stores the edges in the path + std::vector ocp; + + mp_node_ocp.clear(); + vt_paths.clear(); + size_t curnodeidx, edgeidx; + if(enforceValidPathEMBR) + { + for (size_t i = 0; i < vt_node_out_edge_indices[0].size(); i++) + { + // remove the edge + if (nodes[edges[vt_node_out_edge_indices[0][i]].E].wid != 1) lattice::erase_node_out_edges(0, i, i); + } + + } + + // this is inefficent implementation, we should think of efficient ways to do it later + if (excludeSpecialWords) + { + size_t nidx; + for(size_t j = 0; j < vt_node_out_edge_indices.size(); j++) + { + for (size_t i = 0; i < vt_node_out_edge_indices[j].size(); i++) + { + // remove the edge + // 0~4 is: !NULL, , , !sent_start, and !sent_end + nidx = edges[vt_node_out_edge_indices[j][i]].E; + + if (nodes[nidx].wid > 4) + { + if (is_special_words[nidx]) + { + lattice::erase_node_out_edges(j, i, i); + continue; + } + } + + nidx = edges[vt_node_out_edge_indices[j][i]].S; + + if (nodes[nidx].wid > 4) + { + if (is_special_words[nidx]) lattice::erase_node_out_edges(j, i, i); + } + } + } + } + while (vt_paths.size() < numPathsEMBR) + { + path.clear(); + curnodeidx = 0; + //start sampling from node 0 + bool success = false; + + while(true) + { + mp_itr = mp_node_ocp.find(curnodeidx); + if (mp_itr == mp_node_ocp.end()) + { + ocp.clear(); + + for (size_t i = 0; i < vt_node_out_edge_indices[curnodeidx].size(); i++) + { + double prob = exp(edgelogbetas[vt_node_out_edge_indices[curnodeidx][i]] - logbetas[curnodeidx]); + if(i == 0) ocp.push_back(prob); + else ocp.push_back(prob + ocp.back()); + } + mp_node_ocp.insert(pair>(curnodeidx, ocp)); + edgeidx = vt_node_out_edge_indices[curnodeidx][sample_from_cumulative_prob(ocp)]; + + } + else + { + edgeidx = vt_node_out_edge_indices[curnodeidx][sample_from_cumulative_prob(mp_itr->second)]; + } + + path.push_back(edgeidx); + curnodeidx = edges[edgeidx].E; + // the end of lattice is not !NULL (the end of !NULL is deleted in dbn.exe when converting lattice of htk format to chunk) + // if (nodes[edges[edgeidx].E].t == nodes[edges[edgeidx].S].t) + // wid = 2 is for , the lattice ends with + // the node has no outgoing arc + if (vt_node_out_edge_indices[curnodeidx].size() == 0) + { + if ( (nodes[curnodeidx].wid == 2 || nodes[curnodeidx].wid == 0) && nodes[curnodeidx].t == info.numframes) + { + success = true; + break; + } + else + { + fprintf(stderr, "EMBRsamplepaths: WARNING: the node with index = %d has no outgoing arc, but it is not the node with timing ending with last frame \n", int(curnodeidx)); + success = false; + break; + } + } + } + if (success == true) vt_paths.push_back(path); + } + if (vt_paths.size() != numPathsEMBR) + { + fprintf(stderr, "EMBRsamplepaths: Error: vt_paths.size() = %d, and numPathsEMBR = %d \n", int(vt_paths.size()), int(numPathsEMBR)); + exit(-1); + } +} + +void lattice::EMBRnbestpaths(std::vector& tokenlattice, std::vector> & vt_paths, std::vector& path_posterior_probs) const +{ + + double log_nbest_posterior_prob; + + path_posterior_probs.resize(tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens.size()); + log_nbest_posterior_prob = LOGZERO; + + for (size_t i = 0; i < tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens.size(); i++) + { + logadd(log_nbest_posterior_prob, tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens[i].score); + } + for (size_t i = 0; i < tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens.size(); i++) + { + path_posterior_probs[i] = exp(tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens[i].score - log_nbest_posterior_prob); + } + std::vector path; // stores the edges in the path + vt_paths.clear(); + size_t curnodeidx, curtokenidx, prevtokenidx, prevnodeidx; + + for (size_t i = 0; i < tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens.size(); i++) + { + path.clear(); + curnodeidx = tokenlattice.size() - 1; + curtokenidx = i; + while (curnodeidx != 0) + { + path.insert(path.begin(), tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_edge_index); + + prevtokenidx = tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_token_index; + prevnodeidx = edges[tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_edge_index].S; + + curnodeidx = prevnodeidx; + curtokenidx = prevtokenidx; + + } + vt_paths.push_back(path); + } +} +double lattice::get_edge_weights(std::vector& wids, std::vector>& vt_paths, std::vector& vt_edge_weights, std::vector& vt_path_posterior_probs, string getPathMethodEMBR, double& onebest_wer) const +{ + + struct PATHINFO + { + size_t count; + float WER; + }; + + std::map mp_path_info; + std::map::iterator mp_itr; + std::set set_edge_path; + + std::vector vt_path_weights; + vt_path_weights.resize(vt_paths.size()); + + + vector path_ids; + double avg_wer; + avg_wer = 0; + + for (size_t i = 0; i < vt_paths.size(); i++) + { + path_ids.clear(); + + for (size_t j = 0; j < vt_paths[i].size(); j++) + { + if (j == 0) + { + if (!is_special_words[edges[vt_paths[i][j]].S]) path_ids.push_back(nodes[edges[vt_paths[i][j]].S].wid); + + nodes[edges[vt_paths[i][j]].S].wid; + } + if (!is_special_words[edges[vt_paths[i][j]].E]) path_ids.push_back(nodes[edges[vt_paths[i][j]].E].wid); + nodes[edges[vt_paths[i][j]].E].wid; + } + + vt_path_weights[i] = compute_wer(wids, path_ids); + + string pathidstr = "$"; + for (size_t j = 0; j < path_ids.size(); j++) pathidstr += ("_" + std::to_string(path_ids[j])); + mp_itr = mp_path_info.find(pathidstr); + if (mp_itr != mp_path_info.end()) + { + mp_itr->second.count++; + } + else + { + PATHINFO pathinfo; + pathinfo.count = 1; + pathinfo.WER = float(vt_path_weights[i]); + mp_path_info.insert(pair(pathidstr, pathinfo)); + } + + // this uses weighted avg wer + avg_wer += (vt_path_weights[i] * vt_path_posterior_probs[i]); + + } + if (getPathMethodEMBR == "sampling") onebest_wer = -10000; + else onebest_wer = vt_path_weights[0]; + + for (size_t i = 0; i < vt_path_weights.size(); i++) + { + // loss - mean_loss + vt_path_weights[i] -= avg_wer; + if(getPathMethodEMBR == "sampling") vt_path_weights[i] /= (vt_paths.size() - 1); + else vt_path_weights[i] *= (vt_path_posterior_probs[i]); + } + + + for (size_t i = 0; i < vt_paths.size(); i++) + { + for (size_t j = 0; j < vt_paths[i].size(); j++) + // substraction, since we want to minimize the loss function, rather than maximize + vt_edge_weights[vt_paths[i][j]] -= vt_path_weights[i]; + } + + set_edge_path.clear(); + + for (size_t i = 0; i < vt_paths.size(); i++) + { + string pathedgeidstr = "$"; + for (size_t j = 0; j < vt_paths[i].size(); j++) + { + pathedgeidstr += ("_" + std::to_string(vt_paths[i][j])); + + } + set_edge_path.insert(pathedgeidstr); + } + return avg_wer; +} +void lattice::EMBRerrorsignal(parallelstate ¶llelstate, + const edgealignments &thisedgealignments, std::vector& edge_weights, msra::math::ssematrixbase &errorsignal) const + +{ + Microsoft::MSR::CNTK::Matrix errorsignalcpu(-1); + if (parallelstate.enabled()) // parallel version + { + parallelstate.setedgeweights(edge_weights); + std::vector verify_edge_weights; + parallelstate.getedgeweights(verify_edge_weights); + parallelEMBRerrorsignal(parallelstate, thisedgealignments, edge_weights, errorsignal); + parallelstate.getgamma(errorsignalcpu); + return; + + } + // linear mode + foreach_coord(i, j, errorsignal) + errorsignal(i, j) = 0.0f; // Note: we don't actually put anything into the numgammas + + foreach_index(j, edges) + { + + const auto &e = edges[j]; + if (nodes[e.S].t == nodes[e.E].t) // this happens for dummy !NULL edge at end of file + continue; + + + size_t ts = nodes[e.S].t; + size_t te = nodes[e.E].t; + + for (size_t t = ts; t < te; t++) + { + const size_t s = thisedgealignments[j][t - ts]; + errorsignal(s, t) = errorsignal(s, t) + float(edge_weights[j]); + } + } +} + // compute the error signal for MMI mode void lattice::mmierrorsignal(parallelstate ¶llelstate, double minlogpp, const std::vector &origlogpps, std::vector &abcs, const bool softalignstates, @@ -1335,9 +2016,54 @@ void sMBRsuppressweirdstuff(msra::math::ssematrixbase &errorsignal, const_array_ double lattice::forwardbackward(parallelstate ¶llelstate, const msra::math::ssematrixbase &logLLs, const msra::asr::simplesenonehmm &hset, msra::math::ssematrixbase &result, msra::math::ssematrixbase &errorsignalbuf, const float lmf, const float wp, const float amf, const float boostingfactor, - const bool sMBRmode, array_ref uids, const_array_ref bounds, - const_array_ref transcript, const std::vector &transcriptunigrams) const + const bool sMBRmode, const bool EMBR, const string EMBRUnit, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const string getPathMethodEMBR, const string showWERMode, const bool excludeSpecialWords, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numRawPathsEMBR, array_ref uids, vector wids, const_array_ref bounds, + const_array_ref transcript, const std::vector &transcriptunigram) const { + + std::vector tokenlattice; + tokenlattice.clear(); + + if (wids.size() == 0) return 0; + + if (numPathsEMBR < 1) + { + fprintf(stderr, "forwardbackward: WARNING: numPathsEMBR = %d , which is smaller than 1\n", (int)numPathsEMBR); + return LOGZERO; // failed, do not use resulting matrix + } + if (EMBRUnit != "word") + { + fprintf(stderr, "forwardbackward: Error: Currently do not support EMBR unit other than word\n"); + return LOGZERO; // failed, do not use resulting matrix + } + // sanity check + + if (nodes[0].wid != 0) RuntimeError("The first node is not 0 (i.e.) !NULL, but is %d \n", int(nodes[0].wid)); + + // the lattice last node could be either 0 or 2, i.e., if it is an merged lattice (merged numerator and denominator the dnb code dedicately removes ending !NULL, it is 0. If it is not merged lattice (the one that I changed TAMER code to only use denominator lattice), the last node could be !NULL + if(nodes[nodes.size()-1].wid != 2 && nodes[nodes.size() - 1].wid != 0) RuntimeError("The last node is not 2 (i.e.) or 0 (i.e, !NULL), but is %d \n", int(nodes[0].wid)); + // I want to make sure there is only one , it is crucial when the useAccinNbest is true: we add sentence acc into nbest cost function in the . + size_t sent_end_count = 0; + + if (nodes[nodes.size() - 1].wid == 2) sent_end_count = 1; + + for (size_t i = 1; i < nodes.size() - 1; i++) + { + if (nodes[i].wid == 2) + { + if (nodes[nodes.size() - 1].wid == 2) + { + RuntimeError("The node %d wid is 2 (i.e.) , but it is not the last node, total number of node is %d \n", int(i), int(nodes.size())); + } + sent_end_count++; + } + if (nodes[i].wid == 0) RuntimeError("The node %d wid is 0 (i.e.) , but it is not the first node or last node, total number of node is %d \n", int(i), int(nodes.size())); + + } + if (sent_end_count != 1) + { + RuntimeError(" count is not 1 in the lattice, but %d, and total number of node is %d \n", int(sent_end_count), int(nodes.size())); + } + bool softalign = true; bool softalignstates = false; // true if soft alignment within edges, currently we only support soft within edge in cpu mode bool softalignlattice = softalign; // w.r.t. whole lattice @@ -1364,14 +2090,13 @@ double lattice::forwardbackward(parallelstate ¶llelstate, const msra::math:: // score the ground truth --only if a transcript is provided, which happens if the user provides a language model // TODO: no longer used, remove this. 'transcript' parameter is no longer used in this function. transcript; - transcriptunigrams; + transcriptunigram; // allocate alpha/beta/gamma matrices (all are sharing the same memory in-place) std::vector abcs; - std::vector edgeacscores; // [edge index] acoustic scores - // funcation call for forwardbackward on edge level - forwardbackwardalign(parallelstate, hset, softalignstates, minlogpp, origlogpps, abcs, matrixheap, sMBRmode /*returnsenoneids*/, edgeacscores, logLLs, thisedgealignments, thisbackpointers, uids, bounds); - + std::vector edgeacscores; // [edge index] acoustic scores + // return senone id for EMBR or sMBR, but not for MMI + forwardbackwardalign(parallelstate, hset, softalignstates, minlogpp, origlogpps, abcs, matrixheap, (sMBRmode || EMBR) /*returnsenoneids*/, edgeacscores, logLLs, thisedgealignments, thisbackpointers, uids, bounds); // PHASE 2: lattice-level forward backward // we exploit that the lattice is sorted by (end node, start node) for in-place processing @@ -1385,16 +2110,36 @@ double lattice::forwardbackward(parallelstate ¶llelstate, const msra::math:: std::vector logEframescorrect; // this is the final output of PHASE 2 std::vector logalphas; std::vector logbetas; + + std::vector edgelogbetas; // the edge score plus the edge's outgoing node's beta scores double totalfwscore = 0; // TODO: name no longer precise in sMBRmode double logEframescorrecttotal = LOGZERO; + double totalbwscore = 0; bool returnEframescorrect = sMBRmode; if (softalignlattice) { - totalfwscore = forwardbackwardlattice(edgeacscores, parallelstate, logpps, logalphas, logbetas, lmf, wp, amf, boostingfactor, returnEframescorrect, (const_array_ref &) uids, thisedgealignments, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal); - if (sMBRmode && !returnEframescorrect) + if (EMBR) + { + //compute Beta only, + if (getPathMethodEMBR == "sampling") + { + totalbwscore = backwardlatticeEMBR(edgeacscores, parallelstate, edgelogbetas, logbetas, lmf, wp, amf); + totalfwscore = totalbwscore; // to make the existing code happy + } + else //nbest + { + double bestscore = nbestlatticeEMBR(edgeacscores, parallelstate, tokenlattice, numRawPathsEMBR, enforceValidPathEMBR, excludeSpecialWords, lmf, wp, amf, wordNbest, useAccInNbest, accWeightInNbest, numPathsEMBR, wids); + totalfwscore = bestscore; // to make the code happy, it should be called bestscore, rather than totalfwscore though, will fix later + } + } + else + { + totalfwscore = forwardbackwardlattice(edgeacscores, parallelstate, logpps, logalphas, logbetas, lmf, wp, amf, boostingfactor, returnEframescorrect, (const_array_ref &) uids, thisedgealignments, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal); + if (sMBRmode && !returnEframescorrect) logEframescorrecttotal = forwardbackwardlatticesMBR(edgeacscores, hset, logalphas, logbetas, lmf, wp, amf, (const_array_ref &) uids, thisedgealignments, Eframescorrectbuf); - // ^^ BUGBUG not tested + // ^^ BUGBUG not tested + } } else totalfwscore = bestpathlattice(edgeacscores, logpps, lmf, wp, amf); @@ -1403,8 +2148,9 @@ double lattice::forwardbackward(parallelstate ¶llelstate, const msra::math:: #endif if (islogzero(totalfwscore)) { - fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size()); - return LOGZERO; // failed, do not use resulting matrix + fprintf(stderr, "forwardbackward: WARNING: totalforwardscore is zero: (%d nodes/%d edges), totalfwscore = %f \n", (int)nodes.size(), (int)edges.size(), totalfwscore); + if(!EMBR || !excludeSpecialWords) + return LOGZERO; // failed, do not use resulting matrix } // PHASE 3: compute final state-level posteriors (MMI mode) @@ -1412,25 +2158,60 @@ double lattice::forwardbackward(parallelstate ¶llelstate, const msra::math:: // compute expected frames correct in sMBRmode const size_t numframes = logLLs.cols(); - assert(numframes == info.numframes); - // fprintf (stderr, "forwardbackward: total forward score %.6f (%d frames)\n", totalfwscore, (int) numframes); // for now--while we are debugging the GPU port - - // MMI mode - if (!sMBRmode) + assert(numframes == info.numframes); + if (EMBR) { - // we first take the sum in log domain to avoid numerical issues - auto &dengammas = result; // result is denominator gammas - mmierrorsignal(parallelstate, minlogpp, origlogpps, abcs, softalignstates, logpps, hset, thisedgealignments, dengammas); - return totalfwscore / numframes; // return value is av. posterior + std::vector> vt_paths; + std::vector edge_weights(edges.size(), 0.0); + std::vector path_posterior_probs; + + double onebest_wer = 0.0; + double avg_wer = 0.0; + // for getPathMethodEMBR=sampling, the onebest_wer does not make any sense, pls. do not use it + // ToDO: if it is logzero(totalfwscore), the criterion shown in the training log is not totally correct: for this problematic utterance, the wer is counted as 0. Problematic in the sense that: we set excludeSpecialWords is true, and found no token survive + + if (!islogzero(totalfwscore)) + { + // Do path sampling + if (getPathMethodEMBR == "sampling") + { + EMBRsamplepaths(edgelogbetas, logbetas, numPathsEMBR, enforceValidPathEMBR, excludeSpecialWords, vt_paths); + path_posterior_probs.resize(vt_paths.size(), (1.0 / vt_paths.size())); + } + else + { + EMBRnbestpaths(tokenlattice, vt_paths, path_posterior_probs); + } + + avg_wer = get_edge_weights(wids, vt_paths, edge_weights, path_posterior_probs, getPathMethodEMBR, onebest_wer); + } + + + auto &errorsignal = result; + EMBRerrorsignal(parallelstate, thisedgealignments, edge_weights, errorsignal); + if(getPathMethodEMBR == "nbest" && showWERMode == "onebest") return onebest_wer; + else return avg_wer; } - // sMBR mode + else { - auto &errorsignal = result; - sMBRerrorsignal(parallelstate, errorsignal, errorsignalbuf, logpps, amf, minlogpp, origlogpps, logEframescorrect, logEframescorrecttotal, thisedgealignments); + // MMI mode + if (!sMBRmode) + { + // we first take the sum in log domain to avoid numerical issues + auto &dengammas = result; // result is denominator gammas + mmierrorsignal(parallelstate, minlogpp, origlogpps, abcs, softalignstates, logpps, hset, thisedgealignments, dengammas); + return totalfwscore / numframes; // return value is av. posterior + } + // sMBR mode + else + { + auto &errorsignal = result; + sMBRerrorsignal(parallelstate, errorsignal, errorsignalbuf, logpps, amf, minlogpp, origlogpps, logEframescorrect, logEframescorrecttotal, thisedgealignments); - static bool dummyvariable = (fprintf(stderr, "note: new version with kappa adjustment, kappa = %.2f\n", 1 / amf), true); // we only print once - return exp(logEframescorrecttotal) / numframes; // return value is av. expected frame-correct count + static bool dummyvariable = (fprintf(stderr, "note: new version with kappa adjustment, kappa = %.2f\n", 1 / amf), true); // we only print once + return exp(logEframescorrecttotal) / numframes; // return value is av. expected frame-correct count + } } } }; diff --git a/Source/SequenceTrainingLib/parallelforwardbackward.cpp b/Source/SequenceTrainingLib/parallelforwardbackward.cpp index 678d89af3d21..374299d7d3fe 100644 --- a/Source/SequenceTrainingLib/parallelforwardbackward.cpp +++ b/Source/SequenceTrainingLib/parallelforwardbackward.cpp @@ -132,7 +132,22 @@ void backwardlatticej(const size_t batchsize, const size_t startindex, const std logaccalphas, Eframescorrectbuf, logaccbetas); } } - +void backwardlatticejEMBR(const size_t batchsize, const size_t startindex, const std::vector& edgeacscores, + const std::vector& edges, + const std::vector& nodes, + std::vector& edgelogbetas, std::vector& logbetas, + float lmf, float wp, float amf) +{ + const size_t tpb = blockDim.x * blockDim.y; // total #threads in a block + const size_t jinblock = threadIdx.x + threadIdx.y * blockDim.x; + const size_t j = jinblock + blockIdx.x * tpb; + if (j < batchsize) // note: will cause issues if we ever use __synctreads() in backwardlatticej + { + msra::lattices::latticefunctionskernels::backwardlatticejEMBR(j + startindex, edgeacscores, + edges, nodes, edgelogbetas, + logbetas, lmf, wp, amf); + } +} void sMBRerrorsignalj(const std::vector& alignstateids, const std::vector& alignoffsets, const std::vector& edges, const std::vector& nodes, const std::vector& logpps, const float amf, const std::vector& logEframescorrect, @@ -146,7 +161,17 @@ void sMBRerrorsignalj(const std::vector& alignstateids, const st errorsignal, errorsignalneg); } } - +void EMBRerrorsignalj(const std::vector& alignstateids, const std::vector& alignoffsets, + const std::vector& edges, const std::vector& nodes, + const std::vector& edgeweights, msra::math::ssematrixbase& errorsignal) +{ + const size_t shufflemode = 3; + const size_t j = msra::lattices::latticefunctionskernels::shuffle(threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode); + if (j < edges.size()) // note: will cause issues if we ever use __synctreads() + { + msra::lattices::latticefunctionskernels::EMBRerrorsignalj(j, alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal); + } +} void stateposteriorsj(const std::vector& alignstateids, const std::vector& alignoffsets, const std::vector& edges, const std::vector& nodes, const std::vector& logqs, msra::math::ssematrixbase& logacc) @@ -298,6 +323,39 @@ static double emulateforwardbackwardlattice(const size_t* batchsizeforward, cons #endif return totalfwscore; } +static double emulatebackwardlatticeEMBR(const size_t* batchsizebackward, const size_t numlaunchbackward, + const std::vector& edgeacscores, + const std::vector& edges, const std::vector& nodes, + std::vector& edgelogbetas, std::vector& logbetas, + const float lmf, const float wp, const float amf) +{ + dim3 t(32, 8); + const size_t tpb = t.x * t.y; + dim3 b((unsigned int)((logbetas.size() + tpb - 1) / tpb)); + + emulatecuda(b, t, [&]() + { + setvaluej(logbetas, LOGZERO, logbetas.size()); + }); + logbetas[nodes.size() - 1] = 0; + size_t startindex = edges.size(); + for (size_t i = 0; i < numlaunchbackward; i++) + { + dim3 b3((unsigned int)((batchsizebackward[i] + tpb - 1) / tpb)); + emulatecuda(b3, t, [&]() + { + backwardlatticejEMBR(batchsizebackward[i], startindex - batchsizebackward[i], edgeacscores, + edges, nodes, edgelogbetas, logbetas, lmf, wp, amf); + + + }); + startindex -= batchsizebackward[i]; + } + double totalbwscore = logbetas.front(); + + + return totalbwscore; +} // this function behaves as its CUDA conterparts, except that it takes CPU-side std::vectors for everything // this must be identical to CUDA kernel-launch function in -ops class (except for the input data types: vectorref -> std::vector) static void emulatesMBRerrorsignal(const std::vector& alignstateids, const std::vector& alignoffsets, @@ -324,6 +382,26 @@ static void emulatesMBRerrorsignal(const std::vector& alignstate }); } +// this function behaves as its CUDA conterparts, except that it takes CPU-side std::vectors for everything +// this must be identical to CUDA kernel-launch function in -ops class (except for the input data types: vectorref -> std::vector) +static void emulateEMBRerrorsignal(const std::vector& alignstateids, const std::vector& alignoffsets, + const std::vector& edges, const std::vector& nodes, + const std::vector& edgeweights, + msra::math::ssematrixbase& errorsignal) +{ + + const size_t numedges = edges.size(); + dim3 t(32, 8); + const size_t tpb = t.x * t.y; + foreach_coord(i, j, errorsignal) + errorsignal(i, j) = 0; + dim3 b((unsigned int)((numedges + tpb - 1) / tpb)); + emulatecuda(b, t, [&]() + { + EMBRerrorsignalj(alignstateids, alignoffsets, edges, nodes, edgeweights, errorsignal); + }); + dim3 b1((((unsigned int)errorsignal.rows()) + 31) / 32); +} // this function behaves as its CUDA conterparts, except that it takes CPU-side std::vectors for everything // this must be identical to CUDA kernel-launch function in -ops class (except for the input data types: vectorref -> std::vector) static void emulatemmierrorsignal(const std::vector& alignstateids, const std::vector& alignoffsets, @@ -388,6 +466,9 @@ struct parallelstateimpl logppsgpu(msra::cuda::newdoublevector(deviceid)), logalphasgpu(msra::cuda::newdoublevector(deviceid)), logbetasgpu(msra::cuda::newdoublevector(deviceid)), + edgelogbetasgpu(msra::cuda::newdoublevector(deviceid)), + edgeweightsgpu(msra::cuda::newdoublevector(deviceid)), + logaccalphasgpu(msra::cuda::newdoublevector(deviceid)), logaccbetasgpu(msra::cuda::newdoublevector(deviceid)), logframescorrectedgegpu(msra::cuda::newdoublevector(deviceid)), @@ -526,6 +607,8 @@ struct parallelstateimpl std::unique_ptr logppsgpu; std::unique_ptr logalphasgpu; + std::unique_ptr edgelogbetasgpu; + std::unique_ptr edgeweightsgpu; std::unique_ptr logbetasgpu; std::unique_ptr logaccalphasgpu; std::unique_ptr logaccbetasgpu; @@ -619,6 +702,18 @@ struct parallelstateimpl logEframescorrectgpu->allocate(edges.size()); } } + template + void allocbwvectorsEMBR(const edgestype& edges, const nodestype& nodes) + { +#ifndef TWO_CHANNEL + const size_t alphabetanoderatio = 1; +#else + const size_t alphabetanoderatio = 2; +#endif + logbetasgpu->allocate(alphabetanoderatio * nodes.size()); + edgelogbetasgpu->allocate(edges.size()); + + } // check if gpumatrixstorage supports size of cpumatrix, if not allocate. set gpumatrix to part of gpumatrixstorage // This function checks the size of errorsignalgpustorage, and then sets errorsignalgpu to a columnslice of the @@ -664,6 +759,30 @@ struct parallelstateimpl edgealignments.resize(alignresult->size()); alignresult->fetch(edgealignments, true); } + void getlogbetas(std::vector& logbetas) + { + logbetas.resize(logbetasgpu->size()); + logbetasgpu->fetch(logbetas, true); + } + + void getedgelogbetas(std::vector& edgelogbetas) + { + edgelogbetas.resize(edgelogbetasgpu->size()); + edgelogbetasgpu->fetch(edgelogbetas, true); + } + + void getedgeweights(std::vector& edgeweights) + { + edgeweights.resize(edgeweightsgpu->size()); + edgeweightsgpu->fetch(edgeweights, true); + } + + + + void setedgeweights(const std::vector& edgeweights) + { + edgeweightsgpu->assign(edgeweights, false); + } }; void lattice::parallelstate::setdevice(size_t deviceid) @@ -725,6 +844,22 @@ void lattice::parallelstate::getedgealignments(std::vector& edge { pimpl->getedgealignments(edgealignments); } +void lattice::parallelstate::getlogbetas(std::vector& logbetas) +{ + pimpl->getlogbetas(logbetas); +} +void lattice::parallelstate::getedgelogbetas(std::vector& edgelogbetas) +{ + pimpl->getedgelogbetas(edgelogbetas); +} +void lattice::parallelstate::getedgeweights(std::vector& edgeweights) +{ + pimpl->getedgeweights(edgeweights); +} +void lattice::parallelstate::setedgeweights(const std::vector& edgeweights) +{ + pimpl->setedgeweights(edgeweights); +} //template void lattice::parallelstate::setloglls(const Microsoft::MSR::CNTK::Matrix& loglls) { @@ -909,6 +1044,68 @@ double lattice::parallelforwardbackwardlattice(parallelstate& parallelstate, con return totalfwscore; } +double lattice::parallelbackwardlatticeEMBR(parallelstate& parallelstate, const std::vector& edgeacscores, + const float lmf, const float wp, const float amf, + std::vector& edgelogbetas, std::vector& logbetas) const +{ // ^^ TODO: remove this + vector batchsizebackward; // record the batch size that exclude the data dependency for backward + + + size_t endindexbackward = edges.back().S; + size_t countbatchbackward = 0; + foreach_index (j, edges) // compute the batch size info for kernel launches + { + const size_t backj = edges.size() - 1 - j; + if (edges[backj].E > endindexbackward) + { + countbatchbackward++; + if (endindexbackward < edges[backj].S) + endindexbackward = edges[backj].S; + } + else + { + batchsizebackward.push_back(countbatchbackward); + countbatchbackward = 1; + endindexbackward = edges[backj].S; + } + } + batchsizebackward.push_back(countbatchbackward); + + + double totalbwscore = 0.0f; + if (!parallelstate->emulation) + { + if (verbosity >= 2) + fprintf(stderr, "parallelbackwardlatticeEMBR: %d launches for backward\n", (int) batchsizebackward.size()); + + + parallelstate->allocbwvectorsEMBR(edges, nodes); + + std::unique_ptr latticefunctions(msra::cuda::newlatticefunctions(parallelstate.getdevice())); // final CUDA call + latticefunctions->backwardlatticeEMBR(&batchsizebackward[0], batchsizebackward.size(), + *parallelstate->edgeacscoresgpu.get(), *parallelstate->edgesgpu.get(), + *parallelstate->nodesgpu.get(), *parallelstate->edgelogbetasgpu.get(), + *parallelstate->logbetasgpu.get(), lmf, wp, amf, totalbwscore); + + } + else // emulation + { +#ifndef TWO_CHANNEL + fprintf(stderr, "forbid invalid sil path\n"); + const size_t alphabetanoderatio = 1; +#else + const size_t alphabetanoderatio = 2; +#endif + logbetas.resize(alphabetanoderatio * nodes.size()); + edgelogbetas.resize(edges.size()); + + + totalbwscore = emulatebackwardlatticeEMBR(&batchsizebackward[0], batchsizebackward.size(), + edgeacscores, edges, nodes, + edgelogbetas, logbetas, lmf, wp, amf); + } + return totalbwscore; +} // ------------------------------------------------------------------------ // parallel implementations of sMBR error updating step // ------------------------------------------------------------------------ @@ -948,6 +1145,32 @@ void lattice::parallelsMBRerrorsignal(parallelstate& parallelstate, const edgeal } } +void lattice::parallelEMBRerrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments, + const std::vector& edgeweights, + msra::math::ssematrixbase& errorsignal) const +{ + + if (!parallelstate->emulation) + { + // no need negative buffer for EMBR + const bool cacheerrorsignalneg = false; + parallelstate->cacheerrorsignal(errorsignal, cacheerrorsignalneg); + + std::unique_ptr latticefunctions(msra::cuda::newlatticefunctions(parallelstate.getdevice())); + latticefunctions->EMBRerrorsignal(*parallelstate->alignresult.get(), *parallelstate->alignoffsetsgpu.get(), *parallelstate->edgesgpu.get(), + *parallelstate->nodesgpu.get(), *parallelstate->edgeweightsgpu.get(), + *parallelstate->errorsignalgpu.get()); + + if (errorsignal.rows() > 0 && errorsignal.cols() > 0) + { + parallelstate->errorsignalgpu->CopySection(errorsignal.rows(), errorsignal.cols(), &errorsignal(0, 0), errorsignal.getcolstride()); + } + } + else + { + emulateEMBRerrorsignal(thisedgealignments.getalignmentsbuffer(), thisedgealignments.getalignoffsets(), edges, nodes, edgeweights, errorsignal); + } +} // ------------------------------------------------------------------------ // parallel implementations of MMI error updating step // ------------------------------------------------------------------------