00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #ifndef FST_COMPILE_MAIN_H__
00023 #define FST_COMPILE_MAIN_H__
00024
00025 #include <tr1/unordered_map>
00026 using std::tr1::unordered_map;
00027 #include <sstream>
00028 #include <string>
00029 #include <vector>
00030
00031 #include <fst/vector-fst.h>
00032 #include <fst/main.h>
00033
00034 DECLARE_bool(acceptor);
00035 DECLARE_string(arc_type);
00036 DECLARE_string(fst_type);
00037
00038 DECLARE_string(isymbols);
00039 DECLARE_string(osymbols);
00040 DECLARE_string(ssymbols);
00041
00042 DECLARE_bool(keep_isymbols);
00043 DECLARE_bool(keep_osymbols);
00044 DECLARE_bool(keep_state_numbering);
00045
00046 DECLARE_bool(allow_negative_labels);
00047
00048 namespace fst {
00049
00050 template <class A> class FstReader {
00051 public:
00052 typedef A Arc;
00053 typedef typename A::StateId StateId;
00054 typedef typename A::Label Label;
00055 typedef typename A::Weight Weight;
00056
00057 FstReader(istream &istrm, const string &source,
00058 const SymbolTable *isyms, const SymbolTable *osyms,
00059 const SymbolTable *ssyms, bool accep, bool ikeep,
00060 bool okeep, bool nkeep)
00061 : nline_(0), source_(source),
00062 isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
00063 nstates_(0), keep_state_numbering_(nkeep) {
00064 char line[kLineLen];
00065 while (istrm.getline(line, kLineLen)) {
00066 ++nline_;
00067 vector<char *> col;
00068 SplitToVector(line, "\n\t ", &col, true);
00069 if (col.size() == 0 || col[0][0] == '\0')
00070 continue;
00071 if (col.size() > 5 ||
00072 (col.size() > 4 && accep) ||
00073 (col.size() == 3 && !accep)) {
00074 LOG(ERROR) << "FstReader: Bad number of columns, source = " << source_
00075 << ", line = " << nline_;
00076 exit(1);
00077 }
00078 StateId s = StrToStateId(col[0]);
00079 while (s >= fst_.NumStates())
00080 fst_.AddState();
00081 if (nline_ == 1)
00082 fst_.SetStart(s);
00083
00084 Arc arc;
00085 StateId d = s;
00086 switch (col.size()) {
00087 case 1:
00088 fst_.SetFinal(s, Weight::One());
00089 break;
00090 case 2:
00091 fst_.SetFinal(s, StrToWeight(col[1], true));
00092 break;
00093 case 3:
00094 arc.nextstate = d = StrToStateId(col[1]);
00095 arc.ilabel = StrToILabel(col[2]);
00096 arc.olabel = arc.ilabel;
00097 arc.weight = Weight::One();
00098 fst_.AddArc(s, arc);
00099 break;
00100 case 4:
00101 arc.nextstate = d = StrToStateId(col[1]);
00102 arc.ilabel = StrToILabel(col[2]);
00103 if (accep) {
00104 arc.olabel = arc.ilabel;
00105 arc.weight = StrToWeight(col[3], false);
00106 } else {
00107 arc.olabel = StrToOLabel(col[3]);
00108 arc.weight = Weight::One();
00109 }
00110 fst_.AddArc(s, arc);
00111 break;
00112 case 5:
00113 arc.nextstate = d = StrToStateId(col[1]);
00114 arc.ilabel = StrToILabel(col[2]);
00115 arc.olabel = StrToOLabel(col[3]);
00116 arc.weight = StrToWeight(col[4], false);
00117 fst_.AddArc(s, arc);
00118 }
00119 while (d >= fst_.NumStates())
00120 fst_.AddState();
00121 }
00122 if (ikeep)
00123 fst_.SetInputSymbols(isyms);
00124 if (okeep)
00125 fst_.SetOutputSymbols(osyms);
00126 }
00127
00128 const VectorFst<A> &Fst() const { return fst_; }
00129
00130 private:
00131
00132 static const int kLineLen = 8096;
00133
00134 int64 StrToId(const char *s, const SymbolTable *syms,
00135 const char *name, bool allow_negative = false) const {
00136 int64 n;
00137
00138 if (syms) {
00139 n = syms->Find(s);
00140 if (n == -1 || (!allow_negative && n < 0)) {
00141 LOG(ERROR) << "FstReader: Symbol \"" << s
00142 << "\" is not mapped to any integer " << name
00143 << ", symbol table = " << syms->Name()
00144 << ", source = " << source_ << ", line = " << nline_;
00145 exit(1);
00146 }
00147 } else {
00148 char *p;
00149 n = strtoll(s, &p, 10);
00150 if (p < s + strlen(s) || (!allow_negative && n < 0)) {
00151 LOG(ERROR) << "FstReader: Bad " << name << " integer = \"" << s
00152 << "\", source = " << source_ << ", line = " << nline_;
00153 exit(1);
00154 }
00155 }
00156 return n;
00157 }
00158
00159 StateId StrToStateId(const char *s) {
00160 StateId n = StrToId(s, ssyms_, "state ID");
00161
00162 if (keep_state_numbering_)
00163 return n;
00164
00165
00166 typename unordered_map<StateId, StateId>::const_iterator it = states_.find(n);
00167 if (it == states_.end()) {
00168 states_[n] = nstates_;
00169 return nstates_++;
00170 } else {
00171 return it->second;
00172 }
00173 }
00174
00175 StateId StrToILabel(const char *s) const {
00176 return StrToId(s, isyms_, "arc ilabel", FLAGS_allow_negative_labels);
00177 }
00178
00179 StateId StrToOLabel(const char *s) const {
00180 return StrToId(s, osyms_, "arc olabel", FLAGS_allow_negative_labels);
00181 }
00182
00183 Weight StrToWeight(const char *s, bool allow_zero) const {
00184 Weight w;
00185 istringstream strm(s);
00186 strm >> w;
00187 if (!strm || (!allow_zero && w == Weight::Zero())) {
00188 LOG(ERROR) << "FstReader: Bad weight = \"" << s
00189 << "\", source = " << source_ << ", line = " << nline_;
00190 exit(1);
00191 }
00192 return w;
00193 }
00194
00195 VectorFst<A> fst_;
00196 size_t nline_;
00197 string source_;
00198 const SymbolTable *isyms_;
00199 const SymbolTable *osyms_;
00200 const SymbolTable *ssyms_;
00201 unordered_map<StateId, StateId> states_;
00202 StateId nstates_;
00203 bool keep_state_numbering_;
00204 DISALLOW_COPY_AND_ASSIGN(FstReader);
00205 };
00206
00207
00208
00209
00210 template <class Arc>
00211 int CompileMain(int argc, char **argv, istream & ,
00212 const FstReadOptions & ) {
00213 const char *source = "standard input";
00214 istream *istrm = &std::cin;
00215 if (argc > 1 && strcmp(argv[1], "-") != 0) {
00216 source = argv[1];
00217 istrm = new ifstream(argv[1]);
00218 if (!istrm) {
00219 LOG(ERROR) << argv[0] << ": Open failed, file = " << argv[1];
00220 return 1;
00221 }
00222 }
00223 const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0;
00224
00225 if (!FLAGS_isymbols.empty()) {
00226 isyms = SymbolTable::ReadText(FLAGS_isymbols, FLAGS_allow_negative_labels);
00227 if (!isyms) exit(1);
00228 }
00229
00230 if (!FLAGS_osymbols.empty()) {
00231 osyms = SymbolTable::ReadText(FLAGS_osymbols, FLAGS_allow_negative_labels);
00232 if (!osyms) exit(1);
00233 }
00234
00235 if (!FLAGS_ssymbols.empty()) {
00236 ssyms = SymbolTable::ReadText(FLAGS_ssymbols);
00237 if (!ssyms) exit(1);
00238 }
00239
00240 FstReader<Arc> fstreader(*istrm, source, isyms, osyms, ssyms,
00241 FLAGS_acceptor, FLAGS_keep_isymbols,
00242 FLAGS_keep_osymbols, FLAGS_keep_state_numbering);
00243
00244 const Fst<Arc> *fst = &fstreader.Fst();
00245 if (FLAGS_fst_type != "vector") {
00246 fst = Convert<Arc>(*fst, FLAGS_fst_type);
00247 if (!fst) return 1;
00248 }
00249 fst->Write(argc > 2 ? argv[2] : "");
00250 if (istrm != &std::cin)
00251 delete istrm;
00252 return 0;
00253 }
00254
00255 }
00256
00257 #endif // FST_COMPILE_MAIN_H__