• Main Page
  • Related Pages
  • Data Structures
  • Files
  • File List
  • Globals

src/libsphinxbase/lm/lm3g_templates.c

00001 /* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
00002 /* ====================================================================
00003  * Copyright (c) 1999-2007 Carnegie Mellon University.  All rights
00004  * reserved.
00005  *
00006  * Redistribution and use in source and binary forms, with or without
00007  * modification, are permitted provided that the following conditions
00008  * are met:
00009  *
00010  * 1. Redistributions of source code must retain the above copyright
00011  *    notice, this list of conditions and the following disclaimer. 
00012  *
00013  * 2. Redistributions in binary form must reproduce the above copyright
00014  *    notice, this list of conditions and the following disclaimer in
00015  *    the documentation and/or other materials provided with the
00016  *    distribution.
00017  *
00018  * This work was supported in part by funding from the Defense Advanced 
00019  * Research Projects Agency and the National Science Foundation of the 
00020  * United States of America, and the CMU Sphinx Speech Consortium.
00021  *
00022  * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND 
00023  * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
00024  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
00025  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
00026  * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
00027  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
00028  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
00029  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
00030  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
00031  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
00032  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
00033  *
00034  * ====================================================================
00035  *
00036  */
00037 /*
00038  * \file lm3g_templates.c Core Sphinx 3-gram code used in
00039  * DMP/DMP32/ARPA (for now) model code.
00040  */
00041 
00042 #include <assert.h>
00043 
00044 /* Locate a specific bigram within a bigram list */
00045 #define BINARY_SEARCH_THRESH    16
00046 static int32
00047 find_bg(bigram_t * bg, int32 n, int32 w)
00048 {
00049     int32 i, b, e;
00050 
00051     /* Binary search until segment size < threshold */
00052     b = 0;
00053     e = n;
00054     while (e - b > BINARY_SEARCH_THRESH) {
00055         i = (b + e) >> 1;
00056         if (bg[i].wid < w)
00057             b = i + 1;
00058         else if (bg[i].wid > w)
00059             e = i;
00060         else
00061             return i;
00062     }
00063 
00064     /* Linear search within narrowed segment */
00065     for (i = b; (i < e) && (bg[i].wid != w); i++);
00066     return ((i < e) ? i : -1);
00067 }
00068 
00069 static int32
00070 lm3g_bg_score(NGRAM_MODEL_TYPE *model,
00071               int32 lw1, int32 lw2, int32 *n_used)
00072 {
00073     int32 i, n, b, score;
00074     bigram_t *bg;
00075 
00076     if (lw1 < 0) {
00077         *n_used = 1;
00078         return model->lm3g.unigrams[lw2].prob1.l;
00079     }
00080 
00081     b = FIRST_BG(model, lw1);
00082     n = FIRST_BG(model, lw1 + 1) - b;
00083     bg = model->lm3g.bigrams + b;
00084 
00085     if ((i = find_bg(bg, n, lw2)) >= 0) {
00086         /* Access mode = bigram */
00087         *n_used = 2;
00088         score = model->lm3g.prob2[bg[i].prob2].l;
00089     }
00090     else {
00091         /* Access mode = unigram */
00092         *n_used = 1;
00093         score = model->lm3g.unigrams[lw1].bo_wt1.l + model->lm3g.unigrams[lw2].prob1.l;
00094     }
00095 
00096     return (score);
00097 }
00098 
00099 static void
00100 load_tginfo(NGRAM_MODEL_TYPE *model, int32 lw1, int32 lw2)
00101 {
00102     int32 i, n, b, t;
00103     bigram_t *bg;
00104     tginfo_t *tginfo;
00105 
00106     /* First allocate space for tg information for bg lw1,lw2 */
00107     tginfo = (tginfo_t *) listelem_malloc(model->lm3g.le);
00108     tginfo->w1 = lw1;
00109     tginfo->tg = NULL;
00110     tginfo->next = model->lm3g.tginfo[lw2];
00111     model->lm3g.tginfo[lw2] = tginfo;
00112 
00113     /* Locate bigram lw1,lw2 */
00114     b = model->lm3g.unigrams[lw1].bigrams;
00115     n = model->lm3g.unigrams[lw1 + 1].bigrams - b;
00116     bg = model->lm3g.bigrams + b;
00117 
00118     if ((n > 0) && ((i = find_bg(bg, n, lw2)) >= 0)) {
00119         tginfo->bowt = model->lm3g.bo_wt2[bg[i].bo_wt2].l;
00120 
00121         /* Find t = Absolute first trigram index for bigram lw1,lw2 */
00122         b += i;                 /* b = Absolute index of bigram lw1,lw2 on disk */
00123         t = FIRST_TG(model, b);
00124 
00125         tginfo->tg = model->lm3g.trigrams + t;
00126 
00127         /* Find #tg for bigram w1,w2 */
00128         tginfo->n_tg = FIRST_TG(model, b + 1) - t;
00129     }
00130     else {                      /* No bigram w1,w2 */
00131         tginfo->bowt = 0;
00132         tginfo->n_tg = 0;
00133     }
00134 }
00135 
00136 /* Similar to find_bg */
00137 static int32
00138 find_tg(trigram_t * tg, int32 n, int32 w)
00139 {
00140     int32 i, b, e;
00141 
00142     b = 0;
00143     e = n;
00144     while (e - b > BINARY_SEARCH_THRESH) {
00145         i = (b + e) >> 1;
00146         if (tg[i].wid < w)
00147             b = i + 1;
00148         else if (tg[i].wid > w)
00149             e = i;
00150         else
00151             return i;
00152     }
00153 
00154     for (i = b; (i < e) && (tg[i].wid != w); i++);
00155     return ((i < e) ? i : -1);
00156 }
00157 
00158 static int32
00159 lm3g_tg_score(NGRAM_MODEL_TYPE *model, int32 lw1,
00160               int32 lw2, int32 lw3, int32 *n_used)
00161 {
00162     ngram_model_t *base = &model->base;
00163     int32 i, n, score;
00164     trigram_t *tg;
00165     tginfo_t *tginfo, *prev_tginfo;
00166 
00167     if ((base->n < 3) || (lw1 < 0) || (lw2 < 0))
00168         return (lm3g_bg_score(model, lw2, lw3, n_used));
00169 
00170     prev_tginfo = NULL;
00171     for (tginfo = model->lm3g.tginfo[lw2]; tginfo; tginfo = tginfo->next) {
00172         if (tginfo->w1 == lw1)
00173             break;
00174         prev_tginfo = tginfo;
00175     }
00176 
00177     if (!tginfo) {
00178         load_tginfo(model, lw1, lw2);
00179         tginfo = model->lm3g.tginfo[lw2];
00180     }
00181     else if (prev_tginfo) {
00182         prev_tginfo->next = tginfo->next;
00183         tginfo->next = model->lm3g.tginfo[lw2];
00184         model->lm3g.tginfo[lw2] = tginfo;
00185     }
00186 
00187     tginfo->used = 1;
00188 
00189     /* Trigrams for w1,w2 now pointed to by tginfo */
00190     n = tginfo->n_tg;
00191     tg = tginfo->tg;
00192     if ((i = find_tg(tg, n, lw3)) >= 0) {
00193         /* Access mode = trigram */
00194         *n_used = 3;
00195         score = model->lm3g.prob3[tg[i].prob3].l;
00196     }
00197     else {
00198         score = tginfo->bowt + lm3g_bg_score(model, lw2, lw3, n_used);
00199     }
00200 
00201     return (score);
00202 }
00203 
00204 static int32
00205 lm3g_template_score(ngram_model_t *base, int32 wid,
00206                       int32 *history, int32 n_hist,
00207                       int32 *n_used)
00208 {
00209     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
00210     switch (n_hist) {
00211     case 0:
00212         /* Access mode: unigram */
00213         *n_used = 1;
00214         return model->lm3g.unigrams[wid].prob1.l;
00215     case 1:
00216         return lm3g_bg_score(model, history[0], wid, n_used);
00217     case 2:
00218     default:
00219         /* Anything greater than 2 is the same as a trigram for now. */
00220         return lm3g_tg_score(model, history[1], history[0], wid, n_used);
00221     }
00222 }
00223 
00224 static int32
00225 lm3g_template_raw_score(ngram_model_t *base, int32 wid,
00226                         int32 *history, int32 n_hist,
00227                           int32 *n_used)
00228 {
00229     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
00230     int32 score;
00231 
00232     switch (n_hist) {
00233     case 0:
00234         /* Access mode: unigram */
00235         *n_used = 1;
00236         /* Undo insertion penalty. */
00237         score = model->lm3g.unigrams[wid].prob1.l - base->log_wip;
00238         /* Undo language weight. */
00239         score = (int32)(score / base->lw);
00240         /* Undo unigram interpolation */
00241         if (strcmp(base->word_str[wid], "<s>") != 0) { /* FIXME: configurable start_sym */
00242             score = logmath_log(base->lmath,
00243                                 logmath_exp(base->lmath, score)
00244                                 - logmath_exp(base->lmath, 
00245                                               base->log_uniform + base->log_uniform_weight));
00246         }
00247         return score;
00248     case 1:
00249         score = lm3g_bg_score(model, history[0], wid, n_used);
00250         break;
00251     case 2:
00252     default:
00253         /* Anything greater than 2 is the same as a trigram for now. */
00254         score = lm3g_tg_score(model, history[1], history[0], wid, n_used);
00255         break;
00256     }
00257     /* FIXME (maybe): This doesn't undo unigram weighting in backoff cases. */
00258     return (int32)((score - base->log_wip) / base->lw);
00259 }
00260 
00261 static int32
00262 lm3g_template_add_ug(ngram_model_t *base,
00263                        int32 wid, int32 lweight)
00264 {
00265     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
00266     return lm3g_add_ug(base, &model->lm3g, wid, lweight);
00267 }
00268 
00269 static void
00270 lm3g_template_flush(ngram_model_t *base)
00271 {
00272     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
00273     lm3g_tginfo_reset(base, &model->lm3g);
00274 }
00275 
00276 typedef struct lm3g_iter_s {
00277     ngram_iter_t base;
00278     unigram_t *ug;
00279     bigram_t *bg;
00280     trigram_t *tg;
00281 } lm3g_iter_t;
00282 
00283 static ngram_iter_t *
00284 lm3g_template_iter(ngram_model_t *base, int32 wid,
00285                    int32 *history, int32 n_hist)
00286 {
00287     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
00288     lm3g_iter_t *itor = ckd_calloc(1, sizeof(*itor));
00289 
00290     ngram_iter_init((ngram_iter_t *)itor, base, n_hist, FALSE);
00291 
00292     if (n_hist == 0) {
00293         /* Unigram is the easiest. */
00294         itor->ug = model->lm3g.unigrams + wid;
00295         return (ngram_iter_t *)itor;
00296     }
00297     else if (n_hist == 1) {
00298         int32 i, n, b;
00299         /* Find the bigram, as in bg_score above (duplicate code...) */
00300         itor->ug = model->lm3g.unigrams + history[0];
00301         b = FIRST_BG(model, history[0]);
00302         n = FIRST_BG(model, history[0] + 1) - b;
00303         itor->bg = model->lm3g.bigrams + b;
00304         /* If no such bigram exists then fail. */
00305         if ((i = find_bg(itor->bg, n, wid)) < 0) {
00306             ngram_iter_free((ngram_iter_t *)itor);
00307             return NULL;
00308         }
00309         itor->bg += i;
00310         return (ngram_iter_t *)itor;
00311     }
00312     else if (n_hist == 2) {
00313         int32 i, n;
00314         tginfo_t *tginfo, *prev_tginfo;
00315         /* Find the trigram, as in tg_score above (duplicate code...) */
00316         itor->ug = model->lm3g.unigrams + history[1];
00317         prev_tginfo = NULL;
00318         for (tginfo = model->lm3g.tginfo[history[0]];
00319              tginfo; tginfo = tginfo->next) {
00320             if (tginfo->w1 == history[1])
00321                 break;
00322             prev_tginfo = tginfo;
00323         }
00324 
00325         if (!tginfo) {
00326             load_tginfo(model, history[1], history[0]);
00327             tginfo = model->lm3g.tginfo[history[0]];
00328         }
00329         else if (prev_tginfo) {
00330             prev_tginfo->next = tginfo->next;
00331             tginfo->next = model->lm3g.tginfo[history[0]];
00332             model->lm3g.tginfo[history[0]] = tginfo;
00333         }
00334 
00335         tginfo->used = 1;
00336 
00337         /* Trigrams for w1,w2 now pointed to by tginfo */
00338         n = tginfo->n_tg;
00339         itor->tg = tginfo->tg;
00340         if ((i = find_tg(itor->tg, n, wid)) >= 0) {
00341             itor->tg += i;
00342             /* Now advance the bigram pointer accordingly.  FIXME:
00343              * Note that we actually already found the relevant bigram
00344              * in load_tginfo. */
00345             itor->bg = model->lm3g.bigrams;
00346             while (FIRST_TG(model, (itor->bg - model->lm3g.bigrams + 1))
00347                    <= (itor->tg - model->lm3g.trigrams))
00348                 ++itor->bg;
00349             return (ngram_iter_t *)itor;
00350         }
00351         else {
00352             ngram_iter_free((ngram_iter_t *)itor);
00353             return (ngram_iter_t *)NULL;
00354         }
00355     }
00356     else {
00357         /* Should not happen. */
00358         assert(n_hist == 0); /* Guaranteed to fail. */
00359         ngram_iter_free((ngram_iter_t *)itor);
00360         return NULL;
00361     }
00362 }
00363 
00364 static ngram_iter_t *
00365 lm3g_template_mgrams(ngram_model_t *base, int m)
00366 {
00367     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base;
00368     lm3g_iter_t *itor = ckd_calloc(1, sizeof(*itor));
00369     ngram_iter_init((ngram_iter_t *)itor, base, m, FALSE);
00370 
00371     itor->ug = model->lm3g.unigrams;
00372     itor->bg = model->lm3g.bigrams;
00373     itor->tg = model->lm3g.trigrams;
00374 
00375     /* Advance bigram pointer to match first trigram. */
00376     if (m > 1 && base->n_counts[1] > 1)  {
00377         while (FIRST_TG(model, (itor->bg - model->lm3g.bigrams + 1))
00378                <= (itor->tg - model->lm3g.trigrams))
00379             ++itor->bg;
00380     }
00381 
00382     /* Advance unigram pointer to match first bigram. */
00383     if (m > 0 && base->n_counts[0] > 1) {
00384         while (itor->ug[1].bigrams <= (itor->bg - model->lm3g.bigrams))
00385             ++itor->ug;
00386     }
00387 
00388     return (ngram_iter_t *)itor;
00389 }
00390 
00391 static ngram_iter_t *
00392 lm3g_template_successors(ngram_iter_t *bitor)
00393 {
00394     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)bitor->model;
00395     lm3g_iter_t *from = (lm3g_iter_t *)bitor;
00396     lm3g_iter_t *itor = ckd_calloc(1, sizeof(*itor));
00397 
00398     itor->ug = from->ug;
00399     switch (bitor->m) {
00400     case 0:
00401         /* Next itor bigrams is the same as this itor bigram or
00402            itor bigrams is more than total count. This means no successors */
00403         if (((itor->ug + 1) - model->lm3g.unigrams < bitor->model->n_counts[0] &&
00404             itor->ug->bigrams == (itor->ug + 1)->bigrams) || 
00405             itor->ug->bigrams == bitor->model->n_counts[1])
00406             goto done;
00407             
00408         /* Start iterating from first bigram successor of from->ug. */
00409         itor->bg = model->lm3g.bigrams + itor->ug->bigrams;
00410         break;
00411     case 1:
00412         itor->bg = from->bg;
00413 
00414         /* This indicates no successors */
00415         if (((itor->bg + 1) - model->lm3g.bigrams < bitor->model->n_counts[1] &&
00416             FIRST_TG (model, itor->bg - model->lm3g.bigrams) == 
00417             FIRST_TG (model, (itor->bg + 1) - model->lm3g.bigrams)) ||
00418             FIRST_TG (model, itor->bg - model->lm3g.bigrams) == bitor->model->n_counts[2])
00419             goto done;
00420             
00421         /* Start iterating from first trigram successor of from->bg. */
00422         itor->tg = (model->lm3g.trigrams 
00423                     + FIRST_TG(model, (itor->bg - model->lm3g.bigrams)));
00424 #if 0
00425         printf("%s %s => %d (%s)\n",
00426                model->base.word_str[itor->ug - model->lm3g.unigrams],
00427                model->base.word_str[itor->bg->wid],
00428                FIRST_TG(model, (itor->bg - model->lm3g.bigrams)),
00429                model->base.word_str[itor->tg->wid]);
00430 #endif
00431         break;
00432     case 2:
00433     default:
00434         /* All invalid! */
00435         goto done;
00436     }
00437 
00438     ngram_iter_init((ngram_iter_t *)itor, bitor->model, bitor->m + 1, TRUE);
00439     return (ngram_iter_t *)itor;
00440     done:
00441         ckd_free(itor);
00442         return NULL;
00443 }
00444 
00445 static int32 const *
00446 lm3g_template_iter_get(ngram_iter_t *base,
00447                        int32 *out_score, int32 *out_bowt)
00448 {
00449     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base->model;
00450     lm3g_iter_t *itor = (lm3g_iter_t *)base;
00451 
00452     base->wids[0] = itor->ug - model->lm3g.unigrams;
00453     if (itor->bg) base->wids[1] = itor->bg->wid;
00454     if (itor->tg) base->wids[2] = itor->tg->wid;
00455 #if 0
00456     printf("itor_get: %d %d %d\n", base->wids[0], base->wids[1], base->wids[2]);
00457 #endif
00458 
00459     switch (base->m) {
00460     case 0:
00461         *out_score = itor->ug->prob1.l;
00462         *out_bowt = itor->ug->bo_wt1.l;
00463         break;
00464     case 1:
00465         *out_score = model->lm3g.prob2[itor->bg->prob2].l;
00466         if (model->lm3g.bo_wt2)
00467             *out_bowt = model->lm3g.bo_wt2[itor->bg->bo_wt2].l;
00468         else
00469             *out_bowt = 0;
00470         break;
00471     case 2:
00472         *out_score = model->lm3g.prob3[itor->tg->prob3].l;
00473         *out_bowt = 0;
00474         break;
00475     default: /* Should not happen. */
00476         return NULL;
00477     }
00478     return base->wids;
00479 }
00480 
00481 static ngram_iter_t *
00482 lm3g_template_iter_next(ngram_iter_t *base)
00483 {
00484     NGRAM_MODEL_TYPE *model = (NGRAM_MODEL_TYPE *)base->model;
00485     lm3g_iter_t *itor = (lm3g_iter_t *)base;
00486 
00487     switch (base->m) {
00488     case 0:
00489         ++itor->ug;
00490         /* Check for end condition. */
00491         if (itor->ug - model->lm3g.unigrams >= base->model->n_counts[0])
00492             goto done;
00493         break;
00494     case 1:
00495         ++itor->bg;
00496         /* Check for end condition. */
00497         if (itor->bg - model->lm3g.bigrams >= base->model->n_counts[1])
00498             goto done;
00499         /* Advance unigram pointer if necessary in order to get one
00500          * that points to this bigram. */
00501         while (itor->bg - model->lm3g.bigrams >= itor->ug[1].bigrams) {
00502             /* Stop if this is a successor iterator, since we don't
00503              * want a new unigram. */
00504             if (base->successor)
00505                 goto done;
00506             ++itor->ug;
00507             if (itor->ug == model->lm3g.unigrams + base->model->n_counts[0]) {
00508                 E_ERROR("Bigram %d has no vaild unigram parent\n",
00509                         itor->bg - model->lm3g.bigrams);
00510                 goto done;
00511             }
00512         }
00513         break;
00514     case 2:
00515         ++itor->tg;
00516         /* Check for end condition. */
00517         if (itor->tg - model->lm3g.trigrams >= base->model->n_counts[2])
00518             goto done;
00519         /* Advance bigram pointer if necessary. */
00520         while (itor->tg - model->lm3g.trigrams >=
00521             FIRST_TG(model, (itor->bg - model->lm3g.bigrams + 1))) {
00522             if (base->successor)
00523                 goto done;
00524             ++itor->bg;
00525             if (itor->bg == model->lm3g.bigrams + base->model->n_counts[1]) {
00526                 E_ERROR("Trigram %d has no vaild bigram parent\n",
00527                         itor->tg - model->lm3g.trigrams);
00528                 goto done;
00529             }
00530         }
00531         /* Advance unigram pointer if necessary. */
00532         while (itor->bg - model->lm3g.bigrams >= itor->ug[1].bigrams) {
00533             ++itor->ug;
00534             if (itor->ug == model->lm3g.unigrams + base->model->n_counts[0]) {
00535                 E_ERROR("Trigram %d has no vaild unigram parent\n",
00536                         itor->tg - model->lm3g.trigrams);
00537                 goto done;
00538             }
00539         }
00540         break;
00541     default: /* Should not happen. */
00542         goto done;
00543     }
00544 
00545     return (ngram_iter_t *)itor;
00546 done:
00547     ngram_iter_free(base);
00548     return NULL;
00549 }
00550 
00551 static void
00552 lm3g_template_iter_free(ngram_iter_t *base)
00553 {
00554     ckd_free(base);
00555 }

Generated on Mon Aug 29 2011 for SphinxBase by  doxygen 1.7.1