view hmm/tinySup.py @ 45:7d4da4e72d37

fix argv handling
author Henry S. Thompson <ht@inf.ed.ac.uk>
date Tue, 05 Jul 2022 10:22:50 +0100
parents 26d9c0308fcf
children
line wrap: on
line source

'''Trivial test of unsupervised learning with full dictionary supplied
See fnlp/lectures/12/hmmDNV.xlsx'''
import nltk, random
from nltk.tag.hmm import HiddenMarkovModelTagger, HiddenMarkovModelTrainer
from nltk.probability import FreqDist,ConditionalFreqDist
from nltk.probability import MLEProbDist, RandomProbDist, DictionaryConditionalProbDist

# Hidden states (tags) and observable symbols (words) of the toy HMM.
# '<s>' and '</s>' are sentence-boundary markers that appear in both lists.
tagset = ['<s>', 'D', 'N', 'V', '</s>']
symbols = ['<s>', 'the', 'sheep', 'run', '</s>']

# Three tiny sentences, each a list of (word, tag) pairs wrapped in the
# boundary markers.
sents = [
    [('<s>', '<s>'),
     ('the', 'D'), ('sheep', 'N'), ('run', 'V'),
     ('</s>', '</s>')],
    [('<s>', '<s>'),
     ('sheep', 'N'), ('run', 'V'), ('the', 'D'), ('sheep', 'N'),
     ('</s>', '</s>')],
    [('<s>', '<s>'),
     ('run', 'V'), ('the', 'D'), ('sheep', 'N'),
     ('</s>', '</s>')],
]

# The supplied "dictionary": for every tag, a weight for every symbol.
# The boundary tags put all their weight on their own marker; D, N and V
# spread theirs over the three real words only.
taglists = [
    ('<s>', [('<s>', 1), ('the', 0), ('sheep', 0), ('run', 0), ('</s>', 0)]),
    ('D', [('the', 0.8), ('sheep', 0.1), ('run', 0.1), ('<s>', 0), ('</s>', 0)]),
    ('N', [('the', 0.2), ('sheep', 0.4), ('run', 0.4), ('<s>', 0), ('</s>', 0)]),
    ('V', [('the', 0.2), ('sheep', 0.4), ('run', 0.4), ('<s>', 0), ('</s>', 0)]),
    ('</s>', [('<s>', 0), ('the', 0), ('sheep', 0), ('run', 0), ('</s>', 1)]),
]

# One emission distribution per tag: the weight table for that tag is loaded
# into a FreqDist and wrapped in an MLEProbDist.
tagdict = {
    state: MLEProbDist(FreqDist(dict(weights)))
    for state, weights in taglists
}

# Initial-state distribution: the only non-zero count is on '<s>'.
priors = MLEProbDist(
    FreqDist({
        '<s>': 1,
        'D': 0,
        'N': 0,
        'V': 0,
        '</s>': 0,
    })
)

# Transition table: each source tag gets its own RandomProbDist over the
# tagset, created in tagset order.
transitions = DictionaryConditionalProbDist(
    {state: RandomProbDist(tagset) for state in tagset}
)

# Emission table, conditioned on tag, built from the per-tag distributions.
outputs = DictionaryConditionalProbDist(tagdict)



# Sanity check on the hand-built emission table: for every tag, print the tag
# followed by the total probability it assigns across all symbols.
for tag in tagset:
    cp = outputs[tag]
    # A single pre-formatted argument to print() produces the same text as
    # the old `print tag, total` statement under Python 2 and is also valid
    # syntax under Python 3 (where the statement form is a SyntaxError).
    print("%s %s" % (tag, sum(cp.prob(s) for s in symbols)))

# Starting-point HMM handed to the trainer: the supplied emission table and
# priors, combined with the randomly initialised transitions.
model = HiddenMarkovModelTagger(
    symbols,
    tagset,
    transitions,
    outputs,
    priors,
)

# Repeat the emission-sum check, this time reading the table back out of the
# constructed model (model._outputs) rather than from the local variable.
for tag in tagset:
    cp = model._outputs[tag]
    # A single pre-formatted argument to print() produces the same text as
    # the old `print tag, total` statement under Python 2 and is also valid
    # syntax under Python 3 (where the statement form is a SyntaxError).
    print("%s %s" % (tag, sum(cp.prob(s) for s in symbols)))

# Trainer configured with the same state and symbol inventories as the model.
nm = HiddenMarkovModelTrainer(states=tagset, symbols=symbols)

# Despite what a naive reading of the documentation suggests,
# train_unsupervised wants a sequence of sequences of (word, tag) pairs;
# it simply ignores the tag half of each pair.
# NOTE(review): updateOutputs is not a keyword documented for stock NLTK's
# train_unsupervised; presumably a locally patched nltk.tag.hmm reads it to
# keep the emission table fixed during re-estimation -- confirm.
nnm = nm.train_unsupervised(
    sents,
    model=model,
    max_iterations=15,
    updateOutputs=False,
)

# Print the trained transition probabilities, one source tag at a time: a
# header naming the destination tags, then the row of probabilities.
# '<s>' is left out as a destination, and the loop stops once it reaches
# '</s>' as a source.
for src in tagset:
    if src == '</s>':
        break
    row_dist = nnm._transitions[src]
    targets = tagset[1:]
    print(("    " + 4*"%6s") % tuple(targets))
    row = [src] + [row_dist.prob(dst) for dst in targets]
    print(("%3s: " + 4*"%6.3f") % tuple(row))

# Print the trained model's emission probabilities, one tag at a time: a
# header naming the symbols, then the per-symbol probabilities followed by
# their sum in exponent notation.
for state in tagset:
    emit_dist = nnm._outputs[state]
    print(("    " + 5*"%6s") % tuple(symbols))
    probs = [emit_dist.prob(word) for word in symbols]
    total = sum(probs)
    print(("%3s: " + 5*"%6.3f" + "%11.4e") % tuple([state] + probs + [total]))

# Print the result of evaluating the retrained tagger on the tagged training
# sentences. The parenthesised single-argument form prints exactly what the
# old print statement did under Python 2 and is valid under Python 3.
# NOTE(review): recent NLTK releases appear to prefer `accuracy` over
# `evaluate` -- confirm against the installed NLTK version.
print(nnm.evaluate(sents))