// See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Compresses and decompresses unweighted FSTs. #ifndef FST_EXTENSIONS_COMPRESS_COMPRESS_H_ #define FST_EXTENSIONS_COMPRESS_COMPRESS_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace fst { // Identifies stream data as a vanilla compressed FST. static const int32 kCompressMagicNumber = 1858869554; // Identifies stream data as (probably) a Gzip file accidentally read from // a vanilla stream, without gzip support. static const int32 kGzipMagicNumber = 0x8b1f; // Selects the two most significant bytes. constexpr uint32 kGzipMask = 0xffffffff >> 16; namespace internal { // Expands a Lempel Ziv code and returns the set of code words. expanded_code[i] // is the i^th Lempel Ziv codeword. template bool ExpandLZCode(const std::vector> &code, std::vector> *expanded_code) { expanded_code->resize(code.size()); for (int i = 0; i < code.size(); ++i) { if (code[i].first > i) { LOG(ERROR) << "ExpandLZCode: Not a valid code"; return false; } if (code[i].first == 0) { (*expanded_code)[i].resize(1, code[i].second); } else { (*expanded_code)[i].resize((*expanded_code)[code[i].first - 1].size() + 1); std::copy((*expanded_code)[code[i].first - 1].begin(), (*expanded_code)[code[i].first - 1].end(), (*expanded_code)[i].begin()); (*expanded_code)[i][(*expanded_code)[code[i].first - 1].size()] = code[i].second; } } return true; } } // namespace internal // Lempel Ziv on data structure Edge, with a less than operator // EdgeLessThan and an equals operator EdgeEquals. // Edge has a value defaultedge which it never takes and // Edge is defined, it is initialized to defaultedge template class LempelZiv { public: LempelZiv() : dict_number_(0), default_edge_() { root_.current_number = dict_number_++; root_.current_edge = default_edge_; decode_vector_.push_back(std::make_pair(0, default_edge_)); } // Encodes a vector input into output void BatchEncode(const std::vector &input, std::vector> *output); // Decodes codedvector to output. Returns false if // the index exceeds the size. bool BatchDecode(const std::vector> &input, std::vector *output); // Decodes a single dictionary element. Returns false // if the index exceeds the size. bool SingleDecode(const Var &index, Edge *output) { if (index >= decode_vector_.size()) { LOG(ERROR) << "LempelZiv::SingleDecode: " << "Index exceeded the dictionary size"; return false; } else { *output = decode_vector_[index].second; return true; } } ~LempelZiv() { for (auto it = (root_.next_number).begin(); it != (root_.next_number).end(); ++it) { CleanUp(it->second); } } // Adds a single dictionary element while decoding // void AddDictElement(const std::pair &newdict) { // EdgeEquals InstEdgeEquals; // if (InstEdgeEquals(newdict.second, default_edge_) != 1) // decode_vector_.push_back(newdict); // } private: // Node datastructure is used for encoding struct Node { Var current_number; Edge current_edge; std::map next_number; }; void CleanUp(Node *temp) { for (auto it = (temp->next_number).begin(); it != (temp->next_number).end(); ++it) { CleanUp(it->second); } delete temp; } Node root_; Var dict_number_; // decode_vector_ is used for decoding std::vector> decode_vector_; Edge default_edge_; }; template void LempelZiv::BatchEncode( const std::vector &input, std::vector> *output) { for (typename std::vector::const_iterator it = input.begin(); it != input.end(); ++it) { Node *temp_node = &root_; while (it != input.end()) { auto next = (temp_node->next_number).find(*it); if (next != (temp_node->next_number).end()) { temp_node = next->second; ++it; } else { break; } } if (it == input.end() && temp_node->current_number != 0) { output->push_back( std::make_pair(temp_node->current_number, default_edge_)); } else if (it != input.end()) { output->push_back(std::make_pair(temp_node->current_number, *it)); Node *new_node = new (Node); new_node->current_number = dict_number_++; new_node->current_edge = *it; (temp_node->next_number)[*it] = new_node; } if (it == input.end()) break; } } template bool LempelZiv::BatchDecode( const std::vector> &input, std::vector *output) { for (typename std::vector>::const_iterator it = input.begin(); it != input.end(); ++it) { std::vector temp_output; EdgeEquals InstEdgeEquals; if (InstEdgeEquals(it->second, default_edge_) != 1) { decode_vector_.push_back(*it); temp_output.push_back(it->second); } Var temp_integer = it->first; if (temp_integer >= decode_vector_.size()) { LOG(ERROR) << "LempelZiv::BatchDecode: " << "Index exceeded the dictionary size"; return false; } else { while (temp_integer != 0) { temp_output.push_back(decode_vector_[temp_integer].second); temp_integer = decode_vector_[temp_integer].first; } std::reverse(temp_output.begin(), temp_output.end()); output->insert(output->end(), temp_output.begin(), temp_output.end()); } } return true; } // The main Compressor class template class Compressor { public: typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; Compressor() {} // Compresses fst into a boolean vector code. Returns true on sucesss. bool Compress(const Fst &fst, std::ostream &strm); // Decompresses the boolean vector into Fst. Returns true on sucesss. bool Decompress(std::istream &strm, const string &source, MutableFst *fst); // Finds the BFS order of a fst void BfsOrder(const ExpandedFst &fst, std::vector *order); // Preprocessing step to convert fst to a isomorphic fst // Returns a preproccess fst and a dictionary void Preprocess(const Fst &fst, MutableFst *preprocessedfst, EncodeMapper *encoder); // Performs Lempel Ziv and outputs a stream of integers // and sends it to a stream void EncodeProcessedFst(const ExpandedFst &fst, std::ostream &strm); // Decodes fst from the stream void DecodeProcessedFst(const std::vector &input, MutableFst *fst, bool unweighted); // Converts buffer_code_ to uint8 and writes to a stream. // Writes the boolean file to the stream void WriteToStream(std::ostream &strm); // Writes the weights to the stream void WriteWeight(const std::vector &input, std::ostream &strm); void ReadWeight(std::istream &strm, std::vector *output); // Same as fst::Decode without the line RmFinalEpsilon(fst) void DecodeForCompress(MutableFst *fst, const EncodeMapper &mapper); // Updates the buffer_code_ template void WriteToBuffer(CVar input) { std::vector current_code; Elias::DeltaEncode(input, ¤t_code); if (!buffer_code_.empty()) { buffer_code_.insert(buffer_code_.end(), current_code.begin(), current_code.end()); } else { buffer_code_.assign(current_code.begin(), current_code.end()); } } private: struct LZLabel { LZLabel() : label(0) {} Label label; }; struct LabelLessThan { bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const { return labelone.label < labeltwo.label; } }; struct LabelEquals { bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const { return labelone.label == labeltwo.label; } }; struct Transition { Transition() : nextstate(0), label(0), weight(Weight::Zero()) {} StateId nextstate; Label label; Weight weight; }; struct TransitionLessThan { bool operator()(const Transition &transition_one, const Transition &transition_two) const { if (transition_one.nextstate == transition_two.nextstate) return transition_one.label < transition_two.label; else return transition_one.nextstate < transition_two.nextstate; } } transition_less_than; struct TransitionEquals { bool operator()(const Transition &transition_one, const Transition &transition_two) const { return transition_one.nextstate == transition_two.nextstate && transition_one.label == transition_two.label; } } transition_equals; struct OldDictCompare { bool operator()(const std::pair &pair_one, const std::pair &pair_two) const { if ((pair_one.second).nextstate == (pair_two.second).nextstate) return (pair_one.second).label < (pair_two.second).label; else return (pair_one.second).nextstate < (pair_two.second).nextstate; } } old_dict_compare; std::vector buffer_code_; std::vector arc_weight_; std::vector final_weight_; }; template inline void Compressor::DecodeForCompress( MutableFst *fst, const EncodeMapper &mapper) { ArcMap(fst, EncodeMapper(mapper, DECODE)); fst->SetInputSymbols(mapper.InputSymbols()); fst->SetOutputSymbols(mapper.OutputSymbols()); } // Compressor::BfsOrder template void Compressor::BfsOrder(const ExpandedFst &fst, std::vector *order) { Arc arc; StateId bfs_visit_number = 0; std::queue states_queue; order->assign(fst.NumStates(), kNoStateId); states_queue.push(fst.Start()); (*order)[fst.Start()] = bfs_visit_number++; while (!states_queue.empty()) { for (ArcIterator> aiter(fst, states_queue.front()); !aiter.Done(); aiter.Next()) { arc = aiter.Value(); StateId nextstate = arc.nextstate; if ((*order)[nextstate] == kNoStateId) { (*order)[nextstate] = bfs_visit_number++; states_queue.push(nextstate); } } states_queue.pop(); } // If the FST is unconnected, then the following // code finds them while (bfs_visit_number < fst.NumStates()) { int unseen_state = 0; for (unseen_state = 0; unseen_state < fst.NumStates(); ++unseen_state) { if ((*order)[unseen_state] == kNoStateId) break; } states_queue.push(unseen_state); (*order)[unseen_state] = bfs_visit_number++; while (!states_queue.empty()) { for (ArcIterator> aiter(fst, states_queue.front()); !aiter.Done(); aiter.Next()) { arc = aiter.Value(); StateId nextstate = arc.nextstate; if ((*order)[nextstate] == kNoStateId) { (*order)[nextstate] = bfs_visit_number++; states_queue.push(nextstate); } } states_queue.pop(); } } } template void Compressor::Preprocess(const Fst &fst, MutableFst *preprocessedfst, EncodeMapper *encoder) { *preprocessedfst = fst; if (!preprocessedfst->NumStates()) { return; } // Relabels the edges and develops a dictionary Encode(preprocessedfst, encoder); std::vector order; // Finds the BFS sorting order of the fst BfsOrder(*preprocessedfst, &order); // Reorders the states according to the BFS order StateSort(preprocessedfst, order); } template void Compressor::EncodeProcessedFst(const ExpandedFst &fst, std::ostream &strm) { std::vector output; LempelZiv dict_new; LempelZiv dict_old; std::vector current_new_input; std::vector current_old_input; std::vector> current_new_output; std::vector> current_old_output; std::vector final_states; StateId number_of_states = fst.NumStates(); StateId seen_states = 0; // Adding the number of states WriteToBuffer(number_of_states); for (StateId state = 0; state < number_of_states; ++state) { current_new_input.clear(); current_old_input.clear(); current_new_output.clear(); current_old_output.clear(); if (state > seen_states) ++seen_states; // Collecting the final states if (fst.Final(state) != Weight::Zero()) { final_states.push_back(state); final_weight_.push_back(fst.Final(state)); } // Reading the states for (ArcIterator> aiter(fst, state); !aiter.Done(); aiter.Next()) { Arc arc = aiter.Value(); if (arc.nextstate > seen_states) { // RILEY: > or >= ? ++seen_states; LZLabel temp_label; temp_label.label = arc.ilabel; arc_weight_.push_back(arc.weight); current_new_input.push_back(temp_label); } else { Transition temp_transition; temp_transition.nextstate = arc.nextstate; temp_transition.label = arc.ilabel; temp_transition.weight = arc.weight; current_old_input.push_back(temp_transition); } } // Adding new states dict_new.BatchEncode(current_new_input, ¤t_new_output); WriteToBuffer(current_new_output.size()); for (auto it = current_new_output.begin(); it != current_new_output.end(); ++it) { WriteToBuffer(it->first); WriteToBuffer