001 package org.maltparser.parser;
002
003 import org.maltparser.core.exception.MaltChainedException;
004 import org.maltparser.core.syntaxgraph.DependencyStructure;
005 import org.maltparser.parser.guide.ClassifierGuide;
006 import org.maltparser.parser.guide.OracleGuide;
007 import org.maltparser.parser.guide.SingleGuide;
008 import org.maltparser.parser.history.GuideHistory;
009 import org.maltparser.parser.history.action.GuideDecision;
010 import org.maltparser.parser.history.action.GuideUserAction;
011 /**
012 * @author Johan Hall
013 *
014 */
015 public class BatchTrainer extends Trainer {
016 private final OracleGuide oracleGuide;
017 private int parseCount;
018
019 public BatchTrainer(DependencyParserConfig manager) throws MaltChainedException {
020 super(manager);
021 ((SingleMalt)manager).addRegistry(org.maltparser.parser.Algorithm.class, this);
022 setManager(manager);
023 initParserState(1);
024 setGuide(new SingleGuide(manager, (GuideHistory)parserState.getHistory(), ClassifierGuide.GuideMode.BATCH));
025 oracleGuide = parserState.getFactory().makeOracleGuide(parserState.getHistory());
026 }
027
028 public DependencyStructure parse(DependencyStructure goldDependencyGraph, DependencyStructure parseDependencyGraph) throws MaltChainedException {
029 parserState.clear();
030 parserState.initialize(parseDependencyGraph);
031 currentParserConfiguration = parserState.getConfiguration();
032 parseCount++;
033 if (diagnostics == true) {
034 writeToDiaFile(parseCount + "");
035 }
036 TransitionSystem transitionSystem = parserState.getTransitionSystem();
037 while (!parserState.isTerminalState()) {
038 GuideUserAction action = transitionSystem.getDeterministicAction(parserState.getHistory(), currentParserConfiguration);
039 if (action == null) {
040 action = oracleGuide.predict(goldDependencyGraph, currentParserConfiguration);
041 try {
042 classifierGuide.addInstance((GuideDecision)action);
043 } catch (NullPointerException e) {
044 throw new MaltChainedException("The guide cannot be found. ", e);
045 }
046 } else if (diagnostics == true) {
047 writeToDiaFile(" *");
048 }
049 if (diagnostics == true) {
050 writeToDiaFile(" " + transitionSystem.getActionString(action));
051 }
052 parserState.apply(action);
053 }
054 copyEdges(currentParserConfiguration.getDependencyGraph(), parseDependencyGraph);
055 parseDependencyGraph.linkAllTreesToRoot();
056 oracleGuide.finalizeSentence(parseDependencyGraph);
057 if (diagnostics == true) {
058 writeToDiaFile("\n");
059 }
060 return parseDependencyGraph;
061 }
062
063 public OracleGuide getOracleGuide() {
064 return oracleGuide;
065 }
066
067 public void train() throws MaltChainedException { }
068 public void terminate() throws MaltChainedException {
069 if (diagnostics == true) {
070 closeDiaWriter();
071 }
072 }
073 }