view ngram.py @ 60:bc1acb1416ab

working on fixing gnus home foulup, see /disk/scratch/{mail,gnus}
author Henry S. Thompson <ht@inf.ed.ac.uk>
date Wed, 13 Dec 2023 17:31:28 +0000
parents fee51ab07d09
children
line wrap: on
line source

# Natural Language Toolkit: Language Models
#
# Copyright (C) 2001-2009 NLTK Project
# Author: Steven Bird <sb@csse.unimelb.edu.au>
# URL: <http://www.nltk.org/>
# For license information, see LICENSE.TXT

import random, types
from itertools import chain
from math import log

from nltk.probability import (ConditionalProbDist, ConditionalFreqDist,
                              MLEProbDist, FreqDist)
try:
    from nltk.util import ingrams
except:
    from nltkx.util import ingrams

from api import *

class NgramModel(ModelI):
    """
    A processing interface for assigning a probability to the next word.

    An order-n model estimates P(word | previous n-1 words) from counts in
    the training text.  For ngrams that were not seen in training it backs
    off (Katz-style) to a recursively constructed order-(n-1) model, scaled
    by a per-context weight alpha.
    """

    # Tolerance for sanity checks on sums of floating-point probabilities,
    # which can overshoot their mathematical bounds by rounding error.
    _PROB_EPSILON = 1e-9

    def __init__(self, n, train, pad_left=False, pad_right=False,
                 estimator=None, *estimator_args, **estimator_kwargs):
        """
        Creates an ngram language model to capture patterns in n consecutive
        words of training text.  An estimator smooths the probabilities derived
        from the text and may allow generation of ngrams not seen during
        training.

        @param n: the order of the language model (ngram size)
        @type n: C{int}
        @param train: the training text.  A generator is materialized into a
            list, because every lower-order model re-reads it.
        @type train: C{list} of C{list} of C{string}
        @param estimator: a function for generating a probability distribution.
            When omitted, an unsmoothed C{MLEProbDist} is used, and then no
            estimator_args/_kwargs may be given.
        @type estimator: a function that takes a C{FreqDist} and a number of
              bins (followed by any estimator_args/_kwargs) and returns a
              C{ProbDist}
        @param pad_left: whether to pad the left of each sentence with
            (n-1) '<s>' tokens
        @type pad_left: bool
        @param pad_right: whether to pad the right of each sentence with
            max(1, n-1) '</s>' tokens
        @type pad_right: bool
        @param estimator_args: Extra arguments for estimator.
            These arguments are usually used to specify extra
            properties for the probability distributions of individual
            conditions, such as the number of bins they contain.
            Note: For backward-compatibility, if no arguments are specified, the
            number of bins in the underlying ConditionalFreqDist are passed to
            the estimator as an argument.
        @type estimator_args: (any)
        @param estimator_kwargs: Extra keyword arguments for the estimator
        @type estimator_kwargs: (any)
        """
        # protection from cryptic behavior for calling programs
        # that use the pre-2.0.2 interface
        assert(isinstance(pad_left, bool))
        assert(isinstance(pad_right, bool))

        # Given backoff, a generator isn't acceptable: the same training data
        # is read again by every lower-order model.  Convert it before len().
        if isinstance(train, types.GeneratorType):
            train = list(train)

        self._n = n
        self._W = len(train)
        self._lpad = ('<s>',) * (n - 1) if pad_left else ()
        # Need _rpad even for unigrams or padded entropy will give the
        #  wrong answer, because the pad token would be treated as unseen...
        self._rpad = ('</s>',) * max(1, n - 1) if pad_right else ()
        self._padLen = len(self._lpad) + len(self._rpad)
        self._N = 0

        if estimator is None:
            # *args/**kwargs are always a tuple/dict (never None), so test
            # for emptiness rather than identity with None.
            assert not estimator_args and not estimator_kwargs, \
                   "estimator_args or _kwargs supplied, but no estimator"
            estimator = lambda fdist, bins: MLEProbDist(fdist)

        if n == 1:
            self._train_unigrams(train, estimator,
                                 estimator_args, estimator_kwargs)
        else:
            cfd = self._train_ngrams(train, estimator,
                                     estimator_args, estimator_kwargs)
            # recursively construct the lower-order models
            self._backoff = NgramModel(n - 1, train, pad_left, pad_right,
                                       estimator, *estimator_args,
                                       **estimator_kwargs)
            self._backoff_alphas = self._compute_backoff_alphas(cfd)

    def _train_unigrams(self, train, estimator, estimator_args,
                        estimator_kwargs):
        """Build the unigram C{ProbDist} in self._model and set self._N."""
        fd = FreqDist()
        for sent in train:
            # _rpad is () when pad_right is False, so this is just sent then
            fd.update(chain(sent, self._rpad))
        if not estimator_args and not estimator_kwargs:
            self._model = estimator(fd, fd.B())
        else:
            self._model = estimator(fd, fd.B(),
                                    *estimator_args, **estimator_kwargs)
        self._N = fd.N()

    def _train_ngrams(self, train, estimator, estimator_args,
                      estimator_kwargs):
        """Count padded ngrams, build the C{ConditionalProbDist} in
        self._model, record the observed ngrams in self._ngrams, and
        return the C{ConditionalFreqDist} of the counts."""
        n = self._n
        cfd = ConditionalFreqDist()
        self._ngrams = set()
        delta = 1 + self._padLen - n        # len(sent)+delta == ngrams in sent
        for sent in train:
            # a sentence too short to yield any ngram contributes nothing
            # (rather than a negative amount) to the ngram total
            self._N += max(0, len(sent) + delta)
            for ngram in ingrams(chain(self._lpad, sent, self._rpad), n):
                self._ngrams.add(ngram)
                context = tuple(ngram[:-1])
                token = ngram[-1]
                cfd[context][token] += 1
        if not estimator_args and not estimator_kwargs:
            self._model = ConditionalProbDist(cfd, estimator, len(cfd))
        else:
            self._model = ConditionalProbDist(cfd, estimator,
                                              *estimator_args,
                                              **estimator_kwargs)
        return cfd

    def _compute_backoff_alphas(self, cfd):
        """Return a dict mapping each observed context to its backoff alpha.

        alpha = (1 - sum of this model's probs for the words observed after
        the context) / (1 - sum of the backoff model's probs for those same
        words).  Adapted from http://www.nltk.org/_modules/nltk/model/ngram.html
        "Last updated on Feb 26, 2015".
        """
        eps = self._PROB_EPSILON
        alphas = dict()
        for ctxt in cfd.conditions():
            backoff_ctxt = ctxt[1:]
            backoff_total_pr = 0.0
            total_observed_pr = 0.0

            # the subset of words that we OBSERVED following this context,
            # i.e. Count(word | context) > 0
            for word in self._words_following(ctxt, cfd):
                total_observed_pr += self.prob(word, ctxt)
                # we also need the total (n-1)-gram probability of
                # words observed in this n-gram context
                backoff_total_pr += self._backoff.prob(word, backoff_ctxt)
            assert 0 <= total_observed_pr <= 1 + eps, \
                   "sum of probs for %s out of bounds: %s" % (ctxt, total_observed_pr)
            # beta is the remaining probability weight after we factor out
            # the probability of observed words (clamped against rounding).
            beta = max(0.0, 1.0 - total_observed_pr)

            # Greater than 1 is an error.
            assert 0 <= backoff_total_pr <= 1 + eps, \
                   "sum of backoff probs for %s out of bounds: %s" % (ctxt, backoff_total_pr)
            if backoff_total_pr >= 1.0 - eps:
                # The backoff model puts (numerically) all of its mass on
                # words already observed in this context, so there is no
                # mass left to back off to; avoid dividing by ~0.
                alphas[ctxt] = 0.0
            else:
                alphas[ctxt] = beta / (1.0 - backoff_total_pr)
        return alphas

    def _words_following(self, context, cond_freq_dist):
        """Iterate over the words counted after context in cond_freq_dist."""
        return cond_freq_dist[context].iterkeys()
        # below from http://www.nltk.org/_modules/nltk/model/ngram.html,
        # depends on new CFD???
        #for ctxt, word in cond_freq_dist.iterkeys():
        #    if ctxt == context:
        #        yield word

    def prob(self, word, context, verbose=False):
        """
        Evaluate the probability of this word in this context
        using Katz Backoff.

        @param word: the word to score
        @param context: the preceding words (any sequence; made a tuple)
        @param verbose: print a trace of backoff decisions
        @return: the probability.  It is 0 for an unseen ngram whose context
            has a backoff alpha of 0 (e.g. an unsmoothed MLE model).
        @raise RuntimeError: if a unigram model whose distribution sums to
            one fails to score word
        """
        assert(isinstance(word, types.StringTypes))
        context = tuple(context)
        if self._n == 1:
            if not(self._model.SUM_TO_ONE):
                # Smoothing models should do the right thing for unigrams
                #  even if they're 'absent'
                return self._model.prob(word)
            try:
                return self._model.prob(word)
            except Exception:
                raise RuntimeError("No probability mass assigned "
                                   "to unigram %s" % (word,))
        if context + (word,) in self._ngrams:
            return self[context].prob(word)
        alpha = self._alpha(context)
        if alpha > 0:
            if verbose:
                print("backing off for %s" % (context + (word,),))
            return alpha * self._backoff.prob(word, context[1:], verbose)
        if verbose:
            print("no backoff for %s as model doesn't do any smoothing" % word)
        return alpha

    def _alpha(self, context, verbose=False):
        """Get the backoff alpha value for the given context.

        Contexts never observed in training get an alpha of 1.
        """
        error_message = "Alphas and backoff are not defined for unigram models"
        assert (not self._n == 1), error_message

        if context in self._backoff_alphas:
            res = self._backoff_alphas[context]
        else:
            res = 1
        if verbose:
            print(" alpha: %s = %s" % (context, res))
        return res

    def logprob(self, word, context, verbose=False):
        """
        Evaluate the (negative) log probability of this word in this context.

        @return: -log2(prob(word, context)), or float('inf') when that
            probability is 0 (math.log would raise ValueError there)
        """
        p = self.prob(word, context, verbose)
        if p == 0:
            # an impossible event has infinite cost
            return float('inf')
        return -log(p, 2)

    # NB, this will always start with same word since model
    # is trained on a single text
    def generate(self, num_words, context=()):
        '''Generate random text based on the language model.

        @param num_words: how many words to append
        @param context: initial words; they are included in the result
        @return: C{list} of the context words followed by num_words
            generated words
        '''
        text = list(context)
        for i in range(num_words):
            text.append(self._generate_one(text))
        return text

    def _generate_one(self, context):
        """Sample one word given the preceding words in context, backing
        off to lower-order models when the context was not seen."""
        if self._n == 1:
            # The unigram model is a plain ProbDist, not keyed by context.
            return self._model.generate()
        # Left-pad (if the model was trained with left padding) and keep
        # the last n-1 words, the context length this model was trained on.
        context = (self._lpad + tuple(context))[-(self._n - 1):]
        if context in self:
            return self[context].generate()
        return self._backoff._generate_one(context[1:])

    def entropy(self, text, pad_left=False, pad_right=False,
                verbose=False, perItem=False):
        """
        Evaluate the total entropy of a text with respect to the model.
        This is the sum of the negative log2 probability of each ngram
        in the (padded) text.

        @param text: a single sentence (sequence of words)
        @param pad_left: ignored; kept for interface compatibility.  The
            padding chosen when the model was trained is always applied.
        @param pad_right: ignored, as for pad_left
        @param verbose: print each ngram's probability
        @param perItem: if true, return the average cost per ngram
        """
        # This version takes account of padding for greater accuracy
        e = 0.0
        for ngram in ingrams(chain(self._lpad, text, self._rpad), self._n):
            context = tuple(ngram[:-1])
            token = ngram[-1]
            cost = self.logprob(token, context, verbose)  # _negative_
                                                          # log2 prob == cost!
            if verbose:
                print("p(%s|%s) = [%s-gram] %7f" % (token, context, self._n, 2**-cost))
            e += cost
        if perItem:
            return e / ((len(text) + self._padLen) - (self._n - 1))
        return e

    def dump(self, file, logBase=None, precision=7):
        """Dump this model in SRILM/ARPA/Doug Paul format

        Use logBase=10 and the default precision to get something comparable
        to SRILM ngram-model -lm output
        @param file to dump to
        @type file file
        @param logBase If not None, output logBases to the specified base
        @type logBase int|None"""
        file.write('\n\\data\\\n')
        self._writeLens(file)
        self._writeModels(file, logBase, precision, None)
        file.write('\\end\\\n')

    @staticmethod
    def _unigramKeys(pd):
        """Return the sorted 1-gram entries dump writes: pd's samples plus
        the pseudo-words '<unk>' and '<s>', each exactly once."""
        return sorted(set(pd.samples()) | set(['<unk>', '<s>']))

    def _writeLens(self, file):
        """Write the 'ngram k=count' header lines, lowest order first."""
        if self._n > 1:
            self._backoff._writeLens(file)
            file.write('ngram %s=%s\n' % (self._n,
                                          sum(len(self._model[c].samples())
                                              for c in self._model.keys())))
        else:
            # must match the number of lines _writeProbs emits for 1-grams
            file.write('ngram 1=%s\n' % len(self._unigramKeys(self._model)))

    def _writeModels(self, file, logBase, precision, alphas):
        """Write the n-gram sections, lowest order first.

        alphas are the backoff weights of the next-higher model (None for
        the highest order), written alongside this model's entries."""
        if self._n > 1:
            self._backoff._writeModels(file, logBase, precision,
                                       self._backoff_alphas)
        file.write('\n\\%s-grams:\n' % self._n)
        if self._n == 1:
            self._writeProbs(self._model, file, logBase, precision, (), alphas)
        else:
            for c in sorted(self._model.conditions()):
                self._writeProbs(self._model[c], file, logBase, precision,
                                 c, alphas)

    @staticmethod
    def _writeAlpha(file, logBase, precision, alphas, ngram):
        """Write a tab and the backoff weight recorded for ngram, if any.

        Nothing is written when alphas is None or ngram never occurred as a
        context of the next-higher model; _alpha uses a weight of 1 for such
        contexts, i.e. log-weight 0.
        """
        if alphas is not None and ngram in alphas:
            file.write('\t')
            _writeProb(file, logBase, precision, alphas[ngram])

    def _writeProbs(self, pd, file, logBase, precision, ctxt, alphas):
        """Write one line per sample of pd: probability, ngram text and,
        where known, its backoff weight."""
        if self._n == 1:
            for k in self._unigramKeys(pd):
                if k == '<s>':
                    # NOTE(review): written as -99 even when logBase is None
                    file.write('-99')
                elif k == '<unk>':
                    _writeProb(file, logBase, precision, 1 - pd.discount())
                else:
                    _writeProb(file, logBase, precision, pd.prob(k))
                file.write('\t%s' % k)
                if k not in ('</s>', '<unk>'):
                    self._writeAlpha(file, logBase, precision, alphas,
                                     ctxt + (k,))
                file.write('\n')
        else:
            ctxtString = ' '.join(ctxt)
            for k in sorted(pd.samples()):
                _writeProb(file, logBase, precision, pd.prob(k))
                file.write('\t%s %s' % (ctxtString, k))
                self._writeAlpha(file, logBase, precision, alphas,
                                 ctxt + (k,))
                file.write('\n')

    def __contains__(self, item):
        """True if item is a context (n>1) or a sample (n==1) of the model."""
        try:
            return item in self._model
        except (TypeError, AttributeError):
            # the unigram model is a ProbDist without container support
            try:
                # hack if model is an MLEProbDist, more efficient
                return item in self._model._freqdist
            except (TypeError, AttributeError):
                return item in self._model.samples()

    def __getitem__(self, item):
        return self._model[item]

    def __repr__(self):
        return '<NgramModel with %d %d-grams>' % (self._N, self._n)

def _writeProb(file,logBase,precision,p):
    file.write('%.*g'%(precision,
                       p if logBase is None else log(p,logBase)))

def demo():
    """Train a Lidstone-smoothed trigram model on the Brown 'news'
    sentences, print it, then print 100 generated words."""
    from nltk.corpus import brown
    from nltk.probability import LidstoneProbDist, WittenBellProbDist
    import textwrap
    estimator = lambda fdist, bins: LidstoneProbDist(fdist, 0.2)
    # Alternative: estimator = lambda fdist, bins: WittenBellProbDist(fdist, 0.2)
    # NgramModel expects a list of sentences (lists of words), and the
    # estimator must be given by keyword: the third positional parameter
    # is pad_left, which must be a bool.
    lm = NgramModel(3, brown.sents(categories='news'), estimator=estimator)
    print(lm)
    # e.g. print(lm.entropy(some_sentence))
    text = lm.generate(100)
    print('\n'.join(textwrap.wrap(' '.join(text))))

if __name__ == '__main__':
    demo()