Mercurial > hg > python
view hmm/tinySup.py @ 51:44fea514ca45
foo
author | Henry S. Thompson <ht@inf.ed.ac.uk> |
---|---|
date | Sun, 19 Feb 2023 16:44:06 +0000 |
parents | 26d9c0308fcf |
children |
line wrap: on
line source
'''Trivial test of unsupervised HMM learning with a full dictionary supplied.

The emission (output) distributions are fixed by hand from a tag
dictionary, and the intent is that training leaves them alone: only the
randomly-initialised transition distributions are re-estimated from three
tiny untagged sentences.

See fnlp/lectures/12/hmmDNV.xlsx
'''
# Must be the first statement after the docstring.  The script originally
# mixed Python 2 print statements with print() calls; using the function form
# everywhere makes it run identically under Python 2.7 and Python 3.
from __future__ import print_function

import nltk
import random
from nltk.tag.hmm import HiddenMarkovModelTagger, HiddenMarkovModelTrainer
from nltk.probability import FreqDist, ConditionalFreqDist
from nltk.probability import MLEProbDist, RandomProbDist, DictionaryConditionalProbDist

tagset = ['<s>', 'D', 'N', 'V', '</s>']
symbols = ['<s>', 'the', 'sheep', 'run', '</s>']

# Training data.  The tags are present only because train_unsupervised wants
# word/tag pairs (see the note at the training call); they are not used to
# train.
sents = [
    [('<s>', '<s>'), ('the', 'D'), ('sheep', 'N'), ('run', 'V'), ('</s>', '</s>')],
    [('<s>', '<s>'), ('sheep', 'N'), ('run', 'V'), ('the', 'D'), ('sheep', 'N'),
     ('</s>', '</s>')],
    [('<s>', '<s>'), ('run', 'V'), ('the', 'D'), ('sheep', 'N'), ('</s>', '</s>')],
]

# The hand-specified "full dictionary": emission weights for every tag over
# every symbol.  Each row's weights sum to 1, and the sentence-boundary tags
# emit only their own boundary symbol.
taglists = [
    ('<s>', [('<s>', 1), ('the', 0), ('sheep', 0), ('run', 0), ('</s>', 0)]),
    ('D', [('the', .8), ('sheep', .1), ('run', .1), ('<s>', 0), ('</s>', 0)]),
    ('N', [('the', .2), ('sheep', .4), ('run', .4), ('<s>', 0), ('</s>', 0)]),
    ('V', [('the', .2), ('sheep', .4), ('run', .4), ('<s>', 0), ('</s>', 0)]),
    ('</s>', [('<s>', 0), ('the', 0), ('sheep', 0), ('run', 0), ('</s>', 1)]),
]
tagdict = dict((tag, MLEProbDist(FreqDist(dict(weights))))
               for tag, weights in taglists)

# All prior mass on '<s>': every sentence starts in the start state.
priors = MLEProbDist(FreqDist({'<s>': 1, 'D': 0, 'N': 0, 'V': 0, '</s>': 0}))
# Random initial transitions; these are what training is meant to re-estimate.
# NOTE: RandomProbDist gives a different starting point on each run; seed
# `random` here if reproducible runs are wanted.
transitions = DictionaryConditionalProbDist(
    dict((state, RandomProbDist(tagset)) for state in tagset))
outputs = DictionaryConditionalProbDist(tagdict)


def _print_output_sums(cpd, label):
    """Print `label`, then each tag with its total emission probability.

    The total is taken over `symbols`.  Each total should come out as
    (approximately) 1, so this is a sanity check on the distributions held
    in `cpd`.
    """
    print(label)
    for tag in tagset:
        dist = cpd[tag]
        print(tag, sum(dist.prob(s) for s in symbols))


_print_output_sums(outputs, 'Hand-specified outputs:')

model = HiddenMarkovModelTagger(symbols, tagset, transitions, outputs, priors)
# _outputs is a private attribute of the tagger, read here only to check that
# the model holds the distributions it was given.
_print_output_sums(model._outputs, 'Outputs held by the initial model:')

nm = HiddenMarkovModelTrainer(states=tagset, symbols=symbols)
# Note that contrary to naive reading of the documentation,
# train_unsupervised expects a sequence of sequences of word/tag pairs,
# it just ignores the tags.
#
# Keep the hand-specified outputs fixed so that only transitions are learned.
# NOTE(review): released NLTK 3 spells this keyword `update_outputs`;
# `updateOutputs` is the spelling this script originally used (presumably for
# an older or locally patched NLTK -- confirm).  train_unsupervised takes
# extra keywords via **kwargs (verify for the installed version), so the
# spelling it does not recognise is ignored and passing both keeps the
# outputs frozen either way.
nnm = nm.train_unsupervised(sents, model=model, max_iterations=15,
                            update_outputs=False, updateOutputs=False)

# Row labels are printed with '%3s: ' (5 characters wide); pad the header
# rows by the same amount so that the columns line up.
LABEL_PAD = ' ' * 5

# Re-estimated transition table.  Rows run over the source tags before '</s>'
# and columns skip '<s>' (tagset[1:]), as in the original script.
targets = tagset[1:]
print((LABEL_PAD + len(targets) * '%6s') % tuple(targets))
for tag in tagset[:tagset.index('</s>')]:
    cp = nnm._transitions[tag]
    print(('%3s: ' + len(targets) * '%6.3f')
          % tuple([tag] + [cp.prob(s) for s in targets]))

# Emission table after training, with each row's total in the last column.
# The totals should still be ~1, and the rows unchanged if the outputs were
# frozen.
print((LABEL_PAD + len(symbols) * '%6s' + '%11s') % (tuple(symbols) + ('sum',)))
for tag in tagset:
    cp = nnm._outputs[tag]
    x = [cp.prob(s) for s in symbols]
    print(('%3s: ' + len(symbols) * '%6.3f' + '%11.4e')
          % tuple([tag] + x + [sum(x)]))

# Score the trained tagger against the tags carried in `sents`.
# NOTE(review): newer NLTK releases deprecate evaluate() in favour of
# accuracy(); switch if the warning appears.
print(nnm.evaluate(sents))