comparison hmm/semiSup.py @ 3:26d9c0308fcf

updated/added from ecclerig version
author Henry S. Thompson <ht@inf.ed.ac.uk>
date Mon, 09 Mar 2020 17:35:28 +0000
parents e07789816ca5
children
comparison
equal deleted inserted replaced
2:e07789816ca5 3:26d9c0308fcf
1 '''Exploring the claim that a small dictionary can seed 1 '''Exploring the claim that a small dictionary can seed
2 an otherwise unsupervised HMM to learn a decent POS-tagger''' 2 an otherwise unsupervised HMM to learn a decent POS-tagger'''
3 import nltk, random, itertools 3 import nltk, random
4 from nltk.corpus import brown 4 from nltk.corpus import brown
5 from nltk.tag.hmm import HiddenMarkovModelTagger, HiddenMarkovModelTrainer, logsumexp2 5 from nltk.tag.hmm import HiddenMarkovModelTagger, HiddenMarkovModelTrainer
6 from nltk.probability import FreqDist,ConditionalFreqDist 6 from nltk.probability import FreqDist,ConditionalFreqDist
7 from nltk.probability import MLEProbDist, RandomProbDist, DictionaryConditionalProbDist 7 from nltk.probability import MLEProbDist, RandomProbDist, DictionaryConditionalProbDist
8 8
# Proportions used to carve up the corpus (consumed by the splitData and
# pickWords calls below): how much of the tagged words seeds the per-tag
# dictionaries, how much of the sentences feeds HMM training, and how much
# of each tag's word list is treated as 'known'.
trainTagsPercent = 0.99
trainHMMPercent = 0.9
knownWordsPercent = 0.99
32 TAGSETS={ 13 TAGSETS={
33 'univ': 14 'univ':
34 [u'ADJ', u'ADP', u'ADV', u'CONJ', u'DET', u'NOUN', u'NUM', 15 [u'ADJ', u'ADP', u'ADV', u'CONJ', u'DET', u'NOUN', u'NUM',
35 u'PRON', u'PRT', u'VERB', u'X', u'.',SST,EST], 16 u'PRON', u'PRT', u'VERB', u'X', u'.'],
36 'brown': 17 'brown':
37 [u"ABL", u"ABN", u"ABX", u"AP", u"AP$", u"AP+AP", u"AT", u"BE", 18 [u"ABL", u"ABN", u"ABX", u"AP", u"AP$", u"AP+AP", u"AT", u"BE",
38 u"BED", u"BED*", u"BEDZ", u"BEDZ*", u"BEG", u"BEM", u"BEM*", 19 u"BED", u"BED*", u"BEDZ", u"BEDZ*", u"BEG", u"BEM", u"BEM*",
39 u"BEN", u"BER", u"BER*", u"BEZ", u"BEZ*", u"CC", u"CD", 20 u"BEN", u"BER", u"BER*", u"BEZ", u"BEZ*", u"CC", u"CD",
40 u"CD$", u"CS", u"DO", u"DO*", u"DO+PPSS", u"DOD", u"DOD*", 21 u"CD$", u"CS", u"DO", u"DO*", u"DO+PPSS", u"DOD", u"DOD*",
76 57
# Alternate names under which callers may ask for the same inventories
TAGSETS['universal'] = TAGSETS['univ']
TAGSETS['penn'] = TAGSETS['upenn']
def setup(cat='news', tagset='brown', corpus=brown):
    '''Return (tagged sentences, tagged words, tag list) for one corpus
    category, with every word lower-cased.

    cat    -- corpus category to read (default 'news')
    tagset -- tagset name, also the key into TAGSETS (default 'brown')
    corpus -- an NLTK tagged-corpus reader (default brown)
    '''
    sents = [[(w.lower(), t) for (w, t) in sent]
             for sent in corpus.tagged_sents(categories=cat, tagset=tagset)]
    words = [(w.lower(), t)
             for (w, t) in corpus.tagged_words(categories=cat, tagset=tagset)]
    return (sents, words, TAGSETS[tagset])
89 66
90 def notCurrent(s,missList): 67 def notCurrent(s,missList):
91 global i,n,done 68 global i,n,done
92 if done or (missList[i] is not s): 69 if done or (missList[i] is not s):
111 88
def pickWords(tagged, percent):
    '''For each tag occurring in tagged (a sequence of (word, tag) pairs),
    return a dict mapping the tag to the leading `percent` of its
    (word, count) pairs, most frequent first.'''
    tToW = ConditionalFreqDist((t, w) for (w, t) in tagged)
    picked = {}
    for tag in tToW.keys():
        # Sort word/count pairs by count, descending.  A portable key
        # function replaces the original Python-2-only tuple-parameter
        # lambda (removed by PEP 3113), and a plain loop replaces the
        # immediately-invoked lambda used to capture the slice bound.
        ranked = sorted(tToW[tag].items(), key=lambda kv: kv[1], reverse=True)
        picked[tag] = ranked[:int(percent * len(ranked))]
    return picked
121 97
122 (tagged_s,tagged_w,tagset)=setup(tagset='universal') 98 (tagged_s,tagged_w,tagset)=setup(tagset='universal')
123 99
124 true_tagged_w=tagged_w[2:] # not SS, SE 100 wordTokens=FreqDist(word for word,tag in tagged_w)
125
126 wordTokens=FreqDist(word for word,tag in true_tagged_w)
127 wordsAsSuch=list(wordTokens.keys())
128 print len(wordTokens), wordTokens.N() 101 print len(wordTokens), wordTokens.N()
129 102
130 (trainTags,trainHMM,testHMM)=splitData(true_tagged_w,trainTagsPercent, 103 (trainTags,trainHMM,testHMM)=splitData(tagged_w,trainTagsPercent,
131 tagged_s,trainHMMPercent) 104 tagged_s,trainHMMPercent)
132 105
133 knownWords=pickWords(trainTags,knownWordsPercent) 106 knownWords=pickWords(trainTags,knownWordsPercent)
134 107
135 class SubsetFreqDist(FreqDist): 108 class SubsetFreqDist(FreqDist):
161 self._nTypes=len(self._words) 134 self._nTypes=len(self._words)
162 135
def words(self):
    '''Return this tag's word collection (treated as a set by callers,
    e.g. via intersection).'''
    return self._words
165 138
def buildPD(self, tokens):
    '''Build this tag's output distribution: a SubsetFreqDist of the
    tag's word/count pairs relative to tokens, wrapped in an MLE
    probability distribution.  Sets self._sfd and self._pd.'''
    self._sfd = SubsetFreqDist(self._wordsAndCounts, tokens)
    self._pd = MLEProbDist(self._sfd)
169 142
def getSFD(self):
    '''Return the SubsetFreqDist built by buildPD.'''
    return self._sfd
172 145
def getPD(self):
    '''Return the probability distribution built by buildPD.'''
    return self._pd
175 148
176 class FixedTag(Tag):
177 def buildPD(self):
178 self._pd=MLEProbDist(FreqDist(dict(self._wordsAndCounts)))
179
180 def getSFD(self):
181 raise NotImplementedError("not implemented for this subclass")
182
# One Tag object per tag over its known-word list, plus a parallel dict
# of just the word sets.
tags = dict((tagName, Tag(tagName, wl)) for tagName, wl in knownWords.items())
kws = dict((tagName, tag.words()) for tagName, tag in tags.items())

# All pairs of distinct tags whose known-word sets overlap, with the
# overlap itself: [((tag1, tag2), common_words), ...].  A plain nested
# loop replaces the original filter-over-immediately-invoked-lambda
# construction, which produced exactly the same list.
t2 = []
for i in xrange(len(tagset)):
    for j in xrange(i + 1, len(tagset)):
        common = kws[tagset[i]].intersection(kws[tagset[j]])
        if common:
            t2.append(((tagset[i], tagset[j]), common))

# Turn each tag's known-word counts into an output distribution relative
# to the full token frequency distribution.
for tag in tags.values():
    tag.buildPD(wordTokens)
195 161
196 tags[SST]=FixedTag(SST,[(SSW,1)]) 162 priors = RandomProbDist(tagset)
197 tags[SST].buildPD()
198 tags[EST]=FixedTag(EST,[(ESW,1)])
199 tags[EST].buildPD()
200
201 priors = MLEProbDist(FreqDist(dict((tag,1 if tag==SST else 0) for tag in tagset)))
202 163
203 transitions = DictionaryConditionalProbDist( 164 transitions = DictionaryConditionalProbDist(
204 dict((state, RandomProbDist(tagset)) 165 dict((state, RandomProbDist(tagset))
205 for state in tagset)) 166 for state in tagset))
206 167
207 outputs = DictionaryConditionalProbDist( 168 outputs = DictionaryConditionalProbDist(
208 dict((state, tags[state].getPD()) 169 dict((state, tags[state].getPD())
209 for state in tagset)) 170 for state in tagset))
210 171
211 model = HiddenMarkovModelTagger(wordsAsSuch, tagset, 172 model = HiddenMarkovModelTagger(wordTokens, tagset,
212 transitions, outputs, priors) 173 transitions, outputs, priors)
213 174
214 print "model", model.evaluate(testHMM), model.totLogProb(testHMM) 175 print model.evaluate(testHMM)
215 176
216 nm=HiddenMarkovModelTrainer(states=tagset,symbols=wordsAsSuch) 177 nm=HiddenMarkovModelTrainer(states=tagset,symbols=wordTokens)
217 178
218 # Note that contrary to naive reading of the documentation, 179 # Note that contrary to naive reading of the documentation,
219 # train_unsupervised expects a sequence of sequences of word/tag pairs, 180 # train_unsupervised expects a sequence of sequences of word/tag pairs,
220 # it just ignores the tags 181 # it just ignores the tags
221 nnm=nm.train_unsupervised(trainHMM,True,model=model,max_iterations=10,testMe=testHMM) 182 nnm=nm.train_unsupervised(trainHMM,model=model,max_iterations=15,updateOutputs=False)
222 183
223 print nnm.totLogProb(testHMM) 184 print nnm.evaluate(testHMM)