comparison hmm/semiSup.py @ 2:e07789816ca5
adding more python files from lib/python on origen

author    Henry Thompson <ht@markup.co.uk>
date      Mon, 09 Mar 2020 16:48:09 +0000
parents
children  26d9c0308fcf
comparing 1:0a3abe59e364 with 2:e07789816ca5

'''Exploring the claim that a small dictionary can seed
an otherwise unsupervised HMM to learn a decent POS-tagger'''
import nltk, random, itertools
from nltk.corpus import brown
from nltk.tag.hmm import HiddenMarkovModelTagger, HiddenMarkovModelTrainer, logsumexp2
from nltk.probability import FreqDist, ConditionalFreqDist
from nltk.probability import MLEProbDist, RandomProbDist, DictionaryConditionalProbDist
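
# Overview: build a seed dictionary mapping each tag to its most frequent
# words, initialise the HMM's output distributions from that dictionary
# and its transitions randomly, then run unsupervised (Baum-Welch)
# training on untagged sentences, scoring tagging accuracy and
# log-likelihood on held-out sentences.

# Sum of the log-2 probabilities of a set of sequences under the model,
# computed with the forward algorithm; monkey-patched onto
# HiddenMarkovModelTagger below.  N, M and beta are computed but never used.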
def totLogProb(self, sequences):
    N = len(self._states)
    M = len(self._symbols)
    logProb = 0
    for sequence in sequences:
        T = len(sequence)
        # compute forward and backward probabilities
        alpha = self._forward_probability(sequence)
        beta = self._backward_probability(sequence)
        # find the log probability of the sequence
        logProb += logsumexp2(alpha[T - 1])
    return logProb

HiddenMarkovModelTagger.totLogProb = totLogProb
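
# Experiment knobs: what fraction of the tagged word tokens feed the
# seed dictionary, what fraction of sentences go to unsupervised HMM
# training (the rest are held out for evaluation), and what fraction of
# each tag's word list is treated as known.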
trainTagsPercent = 1.0
trainHMMPercent = 0.9
knownWordsPercent = 1.0
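
# Pseudo-word/pseudo-tag pairs used to mark sentence boundaries; setup
# wraps every sentence in these.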
SST = SSW = '<s>'
EST = ESW = '</s>'
SS = [(SSW, SST)]
ES = [(ESW, EST)]
TAGSETS = {
    'univ':
        [u'ADJ', u'ADP', u'ADV', u'CONJ', u'DET', u'NOUN', u'NUM',
         u'PRON', u'PRT', u'VERB', u'X', u'.', SST, EST],
    'brown':
        [u"ABL", u"ABN", u"ABX", u"AP", u"AP$", u"AP+AP", u"AT", u"BE",
         u"BED", u"BED*", u"BEDZ", u"BEDZ*", u"BEG", u"BEM", u"BEM*",
         u"BEN", u"BER", u"BER*", u"BEZ", u"BEZ*", u"CC", u"CD",
         u"CD$", u"CS", u"DO", u"DO*", u"DO+PPSS", u"DOD", u"DOD*",
         u"DOZ", u"DOZ*", u"DT", u"DT$", u"DT+BEZ", u"DT+MD", u"DTI",
         u"DTS", u"DTS+BEZ", u"DTX", u"EX", u"EX+BEZ", u"EX+HVD", u"EX+HVZ",
         u"EX+MD", u"FW-*", u"FW-AT", u"FW-AT+NN", u"FW-AT+NP", u"FW-BE", u"FW-BER",
         u"FW-BEZ", u"FW-CC", u"FW-CD", u"FW-CS", u"FW-DT", u"FW-DT+BEZ", u"FW-DTS",
         u"FW-HV", u"FW-IN", u"FW-IN+AT", u"FW-IN+NN", u"FW-IN+NP", u"FW-JJ",
         u"FW-JJR", u"FW-JJT", u"FW-NN", u"FW-NN$", u"FW-NNS", u"FW-NP", u"FW-NPS",
         u"FW-NR", u"FW-OD", u"FW-PN", u"FW-PP$", u"FW-PPL", u"FW-PPL+VBZ",
         u"FW-PPO", u"FW-PPO+IN", u"FW-PPS", u"FW-PPSS", u"FW-PPSS+HV", u"FW-QL",
         u"FW-RB", u"FW-RB+CC", u"FW-TO+VB", u"FW-UH", u"FW-VB", u"FW-VBD",
         u"FW-VBG", u"FW-VBN", u"FW-VBZ", u"FW-WDT", u"FW-WPO", u"FW-WPS", u"HV",
         u"HV*", u"HV+TO", u"HVD", u"HVD*", u"HVG", u"HVN", u"HVZ", u"HVZ*", u"IN",
         u"IN+IN", u"IN+PPO", u"JJ", u"JJ$", u"JJ+JJ", u"JJR", u"JJR+CS", u"JJS",
         u"JJT", u"MD", u"MD*", u"MD+HV", u"MD+PPSS", u"MD+TO", u"NN", u"NN$",
         u"NN+BEZ", u"NN+HVD", u"NN+HVZ", u"NN+IN", u"NN+MD", u"NN+NN", u"NNS",
         u"NNS$", u"NNS+MD", u"NP", u"NP$", u"NP+BEZ", u"NP+HVZ", u"NP+MD",
         u"NPS", u"NPS$", u"NR", u"NR$", u"NR+MD", u"NRS", u"OD",
         u"PN", u"PN$", u"PN+BEZ", u"PN+HVD", u"PN+HVZ", u"PN+MD", u"PP$",
         u"PP$$", u"PPL", u"PPLS", u"PPO", u"PPS", u"PPS+BEZ", u"PPS+HVD",
         u"PPS+HVZ", u"PPS+MD", u"PPSS", u"PPSS+BEM", u"PPSS+BER", u"PPSS+BEZ",
         u"PPSS+BEZ*", u"PPSS+HV", u"PPSS+HVD", u"PPSS+MD", u"PPSS+VB", u"QL",
         u"QLP", u"RB", u"RB$", u"RB+BEZ", u"RB+CS", u"RBR", u"RBR+CS", u"RBT",
         u"RN", u"RP", u"RP+IN", u"TO", u"TO+VB", u"UH", u"VB", u"VB+AT",
         u"VB+IN", u"VB+JJ", u"VB+PPO", u"VB+RP", u"VB+TO", u"VB+VB", u"VBD",
         u"VBG", u"VBG+TO", u"VBN", u"VBN+TO", u"VBZ", u"WDT", u"WDT+BER",
         u"WDT+BER+PP", u"WDT+BEZ", u"WDT+DO+PPS", u"WDT+DOD", u"WDT+HVZ", u"WP$",
         u"WPO", u"WPS", u"WPS+BEZ", u"WPS+HVD", u"WPS+HVZ", u"WPS+MD", u"WQL",
         u"WRB", u"WRB+BER", u"WRB+BEZ", u"WRB+DO", u"WRB+DOD", u"WRB+DOD*",
         u"WRB+DOZ", u"WRB+IN", u"WRB+MD",
         u"(", u")", u"*", u",", u"--", u".", u":"],
    'upenn':
        [u"CC", u"CD", u"DT", u"EX", u"FW", u"IN", u"JJ", u"JJR", u"JJS", u"LS",
         u"MD", u"NN", u"NNP", u"NNPS", u"NNS", u"PDT", u"POS", u"PRP", u"PRP$",
         u"RB", u"RBR", u"RBS", u"RP", u"SYM", u"TO", u"UH", u"VB", u"VBD", u"VBG",
         u"VBN", u"VBP", u"VBZ", u"WDT", u"WP", u"WP$", u"WRB",
         u"``", u"$", u"''", u"(", u")", u",", u"--", u".", u":"]}

TAGSETS['universal'] = TAGSETS['univ']
TAGSETS['penn'] = TAGSETS['upenn']
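
# setup returns a triple: sentences as lists of (word, tag) pairs wrapped
# in the boundary pairs, a flat tagged-word list with one SS and one ES
# pair prepended, and the tag list for the chosen tagset.  All words are
# lower-cased.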
def setup(cat='news', tagset='brown', corpus=brown):
    return ([list(itertools.chain(iter(SS),
                                  ((word.lower(), tag) for (word, tag) in s),
                                  iter(ES)))
             for s in corpus.tagged_sents(categories=cat, tagset=tagset)],
            list(itertools.chain(iter(SS), iter(ES),
                                 ((word.lower(), tag) for (word, tag) in
                                  corpus.tagged_words(categories=cat, tagset=tagset)))),
            TAGSETS[tagset])
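
# splitData picks trainSents by an independent coin-flip per sentence
# (so its size is only approximately sentPercent); notCurrent then
# selects the complement by identity, walking trainSents in order via
# the module-level counters i, n and done.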
def notCurrent(s, missList):
    global i, n, done
    if done or (missList[i] is not s):
        return True
    else:
        i += 1
        if i == n:
            done = True
        return False

def splitData(words, wordPercent, sentences, sentPercent):
    global i, n, done
    trainWords = random.sample(words, int(wordPercent * len(words)))
    # random.sample(sentences,int(sentPercent*len(sentences)))
    trainSents = [s for s in sentences if random.random() < sentPercent]
    # hack!
    i = 0
    n = len(trainSents)
    done = False
    testSents = [s for s in sentences if notCurrent(s, trainSents)]
    return trainWords, trainSents, testSents
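
# Build the seed dictionary: for each tag, keep the top `percent` of the
# word types it was seen with, ranked by co-occurrence frequency.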
def pickWords(tagged, percent):
    # wToT=ConditionalFreqDist(tagged)
    tToW = ConditionalFreqDist((t, w) for (w, t) in tagged)
    # print len(tToW[u'ADV'])
    dd = dict((tag, (lambda wl, p=percent:
                     wl[:int(p * len(wl))])(
                         sorted(tToW[tag].items(), key=lambda (k, v): v, reverse=True)))
              for tag in tToW.keys())
    return dd

(tagged_s, tagged_w, tagset) = setup(tagset='universal')

true_tagged_w = tagged_w[2:]  # drop the prepended SS and ES pairs

wordTokens = FreqDist(word for word, tag in true_tagged_w)
wordsAsSuch = list(wordTokens.keys())
print len(wordTokens), wordTokens.N()

(trainTags, trainHMM, testHMM) = splitData(true_tagged_w, trainTagsPercent,
                                           tagged_s, trainHMMPercent)

knownWords = pickWords(trainTags, knownWordsPercent)
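
# A FreqDist built from a tag's known (word, count) pairs over a larger
# vocabulary (baseset): words never seen with the tag get a small floor
# count instead of zero, and N() is adjusted to match, so the MLEProbDist
# built on top gives every vocabulary word non-zero output probability.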
class SubsetFreqDist(FreqDist):
    def __init__(self, pairs, baseset, basecount=.05):
        dict.update(self, pairs)
        self._baseset = baseset
        self._basecount = basecount
        pn = sum(n for w, n in pairs)
        self._N = pn + ((len(baseset) - len(pairs)) * basecount)

    def __getitem__(self, key):
        return dict.__getitem__(self, key)

    def __missing__(self, key):
        if key in self._baseset:
            return self._basecount
        else:
            return 0

    def N(self):
        return self._N
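
# A tag plus its known words and their counts; buildPD turns these into
# the tag's output distribution over the whole vocabulary.  FixedTag,
# used for the boundary tags below, emits exactly its listed words.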
class Tag:
    def __init__(self, tag, wordsAndCounts):
        self._tag = tag
        self._wordsAndCounts = wordsAndCounts
        self._words = set(w for w, n in wordsAndCounts)
        self._nTokens = sum(n for w, n in wordsAndCounts)
        self._nTypes = len(self._words)

    def words(self):
        return self._words

    def buildPD(self, allTokens):
        self._sfd = SubsetFreqDist(self._wordsAndCounts, allTokens)
        self._pd = MLEProbDist(self._sfd)

    def getSFD(self):
        return self._sfd

    def getPD(self):
        return self._pd

class FixedTag(Tag):
    def buildPD(self):
        self._pd = MLEProbDist(FreqDist(dict(self._wordsAndCounts)))

    def getSFD(self):
        raise NotImplementedError("not implemented for this subclass")

tags = dict((tagName, Tag(tagName, wl)) for tagName, wl in knownWords.items())
kws = dict((tagName, tag.words()) for tagName, tag in tags.items())
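
# t2 (apparently diagnostic; it is not used again) pairs every two
# distinct non-boundary tags with the known words they share, keeping
# only pairs whose intersection is non-empty.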
t2 = list(filter(None,
                 ((lambda i: False if not i[1] else i)
                  (((tagset[i], tagset[j]),
                    kws[tagset[i]].intersection(kws[tagset[j]])),)
                  for i in xrange(0, len(tagset) - 2)
                  for j in xrange(i + 1, len(tagset) - 2))))

for tag in tags.values():
    tag.buildPD(wordTokens)

tags[SST] = FixedTag(SST, [(SSW, 1)])
tags[SST].buildPD()
tags[EST] = FixedTag(EST, [(ESW, 1)])
tags[EST].buildPD()
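
# All the prior probability mass goes to the sentence-start state.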
priors = MLEProbDist(FreqDist(dict((tag, 1 if tag == SST else 0) for tag in tagset)))
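
# Transition distributions start out random; output distributions are
# initialised from the seed dictionary built above.  Both get
# re-estimated during the unsupervised training below.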
transitions = DictionaryConditionalProbDist(
    dict((state, RandomProbDist(tagset))
         for state in tagset))

outputs = DictionaryConditionalProbDist(
    dict((state, tags[state].getPD())
         for state in tagset))
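
# HiddenMarkovModelTagger's positional arguments are
# (symbols, states, transitions, outputs, priors).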
model = HiddenMarkovModelTagger(wordsAsSuch, tagset,
                                transitions, outputs, priors)

print "model", model.evaluate(testHMM), model.totLogProb(testHMM)

nm = HiddenMarkovModelTrainer(states=tagset, symbols=wordsAsSuch)
# Note that, contrary to a naive reading of the documentation,
# train_unsupervised expects a sequence of sequences of word/tag pairs;
# it just ignores the tags.
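# testMe is not a parameter that stock NLTK acts on (train_unsupervised
# takes extra options via **kwargs, so an unmodified nltk.tag.hmm would
# simply ignore it); presumably the locally modified NLTK under
# lib/python mentioned in the commit message uses it to track held-out
# performance between iterations.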
nnm = nm.train_unsupervised(trainHMM, True, model=model, max_iterations=10, testMe=testHMM)

print nnm.totLogProb(testHMM)