comparison hmm/tinySup.py @ 2:e07789816ca5

adding more python files from lib/python on origen
author Henry Thompson <ht@markup.co.uk>
date Mon, 09 Mar 2020 16:48:09 +0000
parents
children 26d9c0308fcf
comparison
equal deleted inserted replaced
1:0a3abe59e364 2:e07789816ca5
1 '''Trivial test of unsupervised learning with full dictionary supplied
2 See fnlp/lectures/12/hmmDNV.xlsx'''
3 import nltk, random
4 from nltk.tag.hmm import HiddenMarkovModelTagger, HiddenMarkovModelTrainer
5 from nltk.probability import FreqDist,ConditionalFreqDist
6 from nltk.probability import MLEProbDist, RandomProbDist, DictionaryConditionalProbDist
7
# Tag inventory; '<s>' and '</s>' are sentence-boundary pseudo-tags.
tagset = ['<s>', 'D', 'N', 'V', '</s>']
# Word inventory; the boundary markers double as their own "words".
symbols = ['<s>', 'the', 'sheep', 'run', '</s>']

# Toy corpus of (word, tag) pairs, each sentence wrapped in boundary markers.
sents = [
    [('<s>', '<s>'), ('the', 'D'), ('sheep', 'N'), ('run', 'V'),
     ('</s>', '</s>')],
    [('<s>', '<s>'), ('sheep', 'N'), ('run', 'V'), ('the', 'D'),
     ('sheep', 'N'), ('</s>', '</s>')],
    [('<s>', '<s>'), ('run', 'V'), ('the', 'D'), ('sheep', 'N'),
     ('</s>', '</s>')],
]

# The full tag dictionary: for each tag, a weight per word.  'D' can only
# emit 'the'; 'N' and 'V' are split evenly between 'sheep' and 'run'.
taglists = [
    ('<s>', [('<s>', 1), ('the', 0), ('sheep', 0), ('run', 0), ('</s>', 0)]),
    ('D', [('the', 1), ('sheep', 0), ('run', 0), ('<s>', 0), ('</s>', 0)]),
    ('N', [('the', 0), ('sheep', .5), ('run', .5), ('<s>', 0), ('</s>', 0)]),
    ('V', [('the', 0), ('sheep', .5), ('run', .5), ('<s>', 0), ('</s>', 0)]),
    ('</s>', [('<s>', 0), ('the', 0), ('sheep', 0), ('run', 0), ('</s>', 1)]),
]

# Normalise each tag's weights into an MLE emission distribution.
tagdict = {tag: MLEProbDist(FreqDist(dict(weights)))
           for tag, weights in taglists}

# Start-state distribution: all of the mass sits on '<s>'.
priors = MLEProbDist(FreqDist(
    {t: (1 if t == '<s>' else 0) for t in tagset}))
27
# Transition rows start out random, so the unsupervised training that
# follows has something to refine.
transitions = DictionaryConditionalProbDist(
    {source: RandomProbDist(tagset) for source in tagset})

# Emission rows are taken verbatim from the hand-built tag dictionary.
outputs = DictionaryConditionalProbDist(tagdict)
33
34
def _print_output_sums(cpd):
    """Print each tag with the total emission probability it assigns to ``symbols``.

    Every dictionary row above spreads its weight over words that are all in
    ``symbols``, so each printed total is expected to be 1; anything else
    suggests the emission table and ``symbols`` disagree.

    The single-string ``print(...)`` form produces the same output under
    Python 2 and Python 3.
    """
    for t in tagset:
        dist = cpd[t]
        print("%s %s" % (t, sum(dist.prob(w) for w in symbols)))


# Check the emission table as built...
_print_output_sums(outputs)

model = HiddenMarkovModelTagger(symbols, tagset,
                                transitions, outputs, priors)

# ...and again as held by the tagger (presumably to confirm construction
# left the distributions intact).
_print_output_sums(model._outputs)
45
nm = HiddenMarkovModelTrainer(states=tagset, symbols=symbols)

# Note that contrary to a naive reading of the documentation,
# train_unsupervised expects a sequence of sequences of word/tag pairs;
# it just ignores the tags.
#
# The emission distributions are meant to stay fixed at the supplied
# dictionary.  NLTK's keyword for that is ``update_outputs``; a misspelling
# such as ``updateOutputs`` is silently absorbed by the method's ``**kwargs``,
# and the outputs then get re-estimated anyway.
nnm = nm.train_unsupervised(sents, model=model, max_iterations=10,
                            update_outputs=False)
52
# Print the re-estimated transition matrix once, as a single table:
# one row per source tag before '</s>', one column per destination tag
# other than '<s>'.  The header is padded to line up with the row labels,
# whose width is taken from the longest tag name.
_dest = tagset[1:]
_width = max(len(t) for t in tagset)
print(" " * (_width + 2) + "".join("%6s" % d for d in _dest))
for tag in tagset[:tagset.index('</s>')]:
    cp = nnm._transitions[tag]
    print("%*s: " % (_width, tag)
          + "".join("%6.3f" % cp.prob(d) for d in _dest))
59
# Print the emission table after training: one row per tag, one column per
# word, plus the row total (expected to be 1 for a proper distribution).
# The header is printed once and aligned with the row labels; the label
# width comes from the longest tag, so the '</s>' row no longer overflows.
_width = max(len(t) for t in tagset)
print(" " * (_width + 2) + "".join("%6s" % w for w in symbols) + "%11s" % "sum")
for tag in tagset:
    cp = nnm._outputs[tag]
    x = [cp.prob(s) for s in symbols]
    print("%*s: " % (_width, tag)
          + "".join("%6.3f" % p for p in x)
          + "%11.4e" % sum(x))
65
# Tagging accuracy of the trained model against the gold tags in ``sents``.
# Under the dictionary above, only '<s>' can emit '<s>' and only '</s>' can
# emit '</s>', so this figure presumably overstates accuracy on D/N/V alone.
# Newer NLTK deprecates TaggerI.evaluate in favour of accuracy; use whichever
# this NLTK provides.
_score = nnm.accuracy if hasattr(nnm, 'accuracy') else nnm.evaluate
print(_score(sents))