kongdeqiang
5 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
 
#include <fst/symbol-table-ops.h>
 
#include <string>
 
namespace fst {
 
SymbolTable *MergeSymbolTable(const SymbolTable &left, const SymbolTable &right,
                              bool *right_relabel_output) {
  // MergeSymbolTable detects several special cases.  It will return a reference
  // copied version of SymbolTable of left or right if either symbol table is
  // a superset of the other.
  std::unique_ptr<SymbolTable> merged(
      new SymbolTable("merge_" + left.Name() + "_" + right.Name()));
  // Copies everything from the left symbol table.
  bool left_has_all = true;
  bool right_has_all = true;
  bool relabel = false;
  for (SymbolTableIterator liter(left); !liter.Done(); liter.Next()) {
    merged->AddSymbol(liter.Symbol(), liter.Value());
    if (right_has_all) {
      int64 key = right.Find(liter.Symbol());
      if (key == -1) {
        right_has_all = false;
      } else if (!relabel && key != liter.Value()) {
        relabel = true;
      }
    }
  }
  if (right_has_all) {
    if (right_relabel_output) *right_relabel_output = relabel;
    return right.Copy();
  }
  // add all symbols we can from right symbol table
  std::vector<string> conflicts;
  for (SymbolTableIterator riter(right); !riter.Done(); riter.Next()) {
    int64 key = merged->Find(riter.Symbol());
    if (key != -1) {
      // Symbol already exists, maybe with different value
      if (key != riter.Value()) relabel = true;
      continue;
    }
    // Symbol doesn't exist from left
    left_has_all = false;
    if (!merged->Find(riter.Value()).empty()) {
      // we can't add this where we want to, add it later, in order
      conflicts.push_back(riter.Symbol());
      continue;
    }
    // there is a hole and we can add this symbol with its id
    merged->AddSymbol(riter.Symbol(), riter.Value());
  }
  if (right_relabel_output) *right_relabel_output = relabel;
  if (left_has_all) return left.Copy();
  // Add all symbols that conflicted, in order
  for (const auto &conflict : conflicts) merged->AddSymbol(conflict);
  return merged.release();
}
 
SymbolTable *CompactSymbolTable(const SymbolTable &syms) {
  std::map<int64, string> sorted;
  SymbolTableIterator stiter(syms);
  for (; !stiter.Done(); stiter.Next()) {
    sorted[stiter.Value()] = stiter.Symbol();
  }
  auto *compact = new SymbolTable(syms.Name() + "_compact");
  int64 newkey = 0;
  for (const auto &kv : sorted) compact->AddSymbol(kv.second, newkey++);
  return compact;
}
 
SymbolTable *FstReadSymbols(const string &filename, bool input_symbols) {
  std::ifstream in(filename, std::ios_base::in | std::ios_base::binary);
  if (!in) {
    LOG(ERROR) << "FstReadSymbols: Can't open file " << filename;
    return nullptr;
  }
  FstHeader hdr;
  if (!hdr.Read(in, filename)) {
    LOG(ERROR) << "FstReadSymbols: Couldn't read header from " << filename;
    return nullptr;
  }
  if (hdr.GetFlags() & FstHeader::HAS_ISYMBOLS) {
    std::unique_ptr<SymbolTable> isymbols(SymbolTable::Read(in, filename));
    if (isymbols == nullptr) {
      LOG(ERROR) << "FstReadSymbols: Couldn't read input symbols from "
                 << filename;
      return nullptr;
    }
    if (input_symbols) return isymbols.release();
  }
  if (hdr.GetFlags() & FstHeader::HAS_OSYMBOLS) {
    std::unique_ptr<SymbolTable> osymbols(SymbolTable::Read(in, filename));
    if (osymbols == nullptr) {
      LOG(ERROR) << "FstReadSymbols: Couldn't read output symbols from "
                 << filename;
      return nullptr;
    }
    if (!input_symbols) return osymbols.release();
  }
  LOG(ERROR) << "FstReadSymbols: The file " << filename
             << " doesn't contain the requested symbols";
  return nullptr;
}
 
bool AddAuxiliarySymbols(const string &prefix, int64 start_label,
                         int64 nlabels, SymbolTable *syms) {
  for (int64 i = 0; i < nlabels; ++i) {
    auto index = i + start_label;
    if (index != syms->AddSymbol(prefix + std::to_string(i), index)) {
      FSTERROR() << "AddAuxiliarySymbols: Symbol table clash";
      return false;
    }
  }
  return true;
}
 
}  // namespace fst