001 package org.maltparser.parser.guide.decision;
002
003 import java.lang.reflect.Constructor;
004 import java.lang.reflect.InvocationTargetException;
005
006
007 import org.maltparser.core.exception.MaltChainedException;
008 import org.maltparser.core.feature.FeatureModel;
009 import org.maltparser.core.feature.FeatureVector;
010 import org.maltparser.core.helper.HashMap;
011 import org.maltparser.core.syntaxgraph.DependencyStructure;
012 import org.maltparser.parser.DependencyParserConfig;
013 import org.maltparser.parser.guide.ClassifierGuide;
014 import org.maltparser.parser.guide.GuideException;
015 import org.maltparser.parser.guide.instance.AtomicModel;
016 import org.maltparser.parser.guide.instance.FeatureDivideModel;
017 import org.maltparser.parser.guide.instance.InstanceModel;
018 import org.maltparser.parser.history.action.GuideDecision;
019 import org.maltparser.parser.history.action.MultipleDecision;
020 import org.maltparser.parser.history.action.SingleDecision;
021 import org.maltparser.parser.history.container.TableContainer.RelationToNextDecision;
022 /**
023 *
024 * @author Johan Hall
025 * @since 1.1
026 **/
027 public class BranchedDecisionModel implements DecisionModel {
028 private ClassifierGuide guide;
029 private String modelName;
030 private FeatureModel featureModel;
031 private InstanceModel instanceModel;
032 private int decisionIndex;
033 private DecisionModel parentDecisionModel;
034 private HashMap<Integer,DecisionModel> children;
035 private String branchedDecisionSymbols;
036
037 public BranchedDecisionModel(ClassifierGuide guide, FeatureModel featureModel) throws MaltChainedException {
038 this.branchedDecisionSymbols = "";
039 setGuide(guide);
040 setFeatureModel(featureModel);
041 setDecisionIndex(0);
042 setModelName("bdm"+decisionIndex);
043 setParentDecisionModel(null);
044 }
045
046 public BranchedDecisionModel(ClassifierGuide guide, DecisionModel parentDecisionModel, String branchedDecisionSymbol) throws MaltChainedException {
047 if (branchedDecisionSymbol != null && branchedDecisionSymbol.length() > 0) {
048 this.branchedDecisionSymbols = branchedDecisionSymbol;
049 } else {
050 this.branchedDecisionSymbols = "";
051 }
052 setGuide(guide);
053 setParentDecisionModel(parentDecisionModel);
054 setDecisionIndex(parentDecisionModel.getDecisionIndex() + 1);
055 setFeatureModel(parentDecisionModel.getFeatureModel());
056 if (branchedDecisionSymbols != null && branchedDecisionSymbols.length() > 0) {
057 setModelName("bdm"+decisionIndex+branchedDecisionSymbols);
058 } else {
059 setModelName("bdm"+decisionIndex);
060 }
061 this.parentDecisionModel = parentDecisionModel;
062 }
063
064 public void updateFeatureModel() throws MaltChainedException {
065 featureModel.update();
066 }
067
068 // public void updateCardinality() throws MaltChainedException {
069 // featureModel.updateCardinality();
070 // }
071
072
073 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
074 if (instanceModel != null) {
075 instanceModel.finalizeSentence(dependencyGraph);
076 }
077 if (children != null) {
078 for (DecisionModel child : children.values()) {
079 child.finalizeSentence(dependencyGraph);
080 }
081 }
082 }
083
084 public void noMoreInstances() throws MaltChainedException {
085 if (guide.getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
086 throw new GuideException("The decision model could not create it's model. ");
087 }
088 if (instanceModel != null) {
089 instanceModel.noMoreInstances();
090 instanceModel.train();
091 }
092 if (children != null) {
093 for (DecisionModel child : children.values()) {
094 child.noMoreInstances();
095 }
096 }
097 }
098
099 public void terminate() throws MaltChainedException {
100 if (instanceModel != null) {
101 instanceModel.terminate();
102 instanceModel = null;
103 }
104 if (children != null) {
105 for (DecisionModel child : children.values()) {
106 child.terminate();
107 }
108 }
109 }
110
111 public void addInstance(GuideDecision decision) throws MaltChainedException {
112 if (decision instanceof SingleDecision) {
113 throw new GuideException("A branched decision model expect more than one decisions. ");
114 }
115 featureModel.update();
116 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
117 if (instanceModel == null) {
118 initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
119 }
120
121 instanceModel.addInstance(singleDecision);
122 if (decisionIndex+1 < decision.numberOfDecisions()) {
123 if (singleDecision.continueWithNextDecision()) {
124 if (children == null) {
125 children = new HashMap<Integer,DecisionModel>();
126 }
127 DecisionModel child = children.get(singleDecision.getDecisionCode());
128 if (child == null) {
129 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1),
130 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
131 children.put(singleDecision.getDecisionCode(), child);
132 }
133 child.addInstance(decision);
134 }
135 }
136 }
137
138 public boolean predict(GuideDecision decision) throws MaltChainedException {
139 // if (decision instanceof SingleDecision) {
140 // throw new GuideException("A branched decision model expect more than one decisions. ");
141 // }
142 featureModel.update();
143 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
144 if (instanceModel == null) {
145 initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
146 }
147 instanceModel.predict(singleDecision);
148 if (decisionIndex+1 < decision.numberOfDecisions()) {
149 if (singleDecision.continueWithNextDecision()) {
150 if (children == null) {
151 children = new HashMap<Integer,DecisionModel>();
152 }
153 DecisionModel child = children.get(singleDecision.getDecisionCode());
154 if (child == null) {
155 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1),
156 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
157 children.put(singleDecision.getDecisionCode(), child);
158 }
159 child.predict(decision);
160 }
161 }
162
163 return true;
164 }
165
166 public FeatureVector predictExtract(GuideDecision decision) throws MaltChainedException {
167 if (decision instanceof SingleDecision) {
168 throw new GuideException("A branched decision model expect more than one decisions. ");
169 }
170 featureModel.update();
171 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
172 if (instanceModel == null) {
173 initInstanceModel(singleDecision.getTableContainer().getTableContainerName());
174 }
175 FeatureVector fv = instanceModel.predictExtract(singleDecision);
176 if (decisionIndex+1 < decision.numberOfDecisions()) {
177 if (singleDecision.continueWithNextDecision()) {
178 if (children == null) {
179 children = new HashMap<Integer,DecisionModel>();
180 }
181 DecisionModel child = children.get(singleDecision.getDecisionCode());
182 if (child == null) {
183 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1),
184 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
185 children.put(singleDecision.getDecisionCode(), child);
186 }
187 child.predictExtract(decision);
188 }
189 }
190
191 return fv;
192 }
193
194 public FeatureVector extract() throws MaltChainedException {
195 featureModel.update();
196 return instanceModel.extract(); // TODO handle many feature vectors
197 }
198
199 public boolean predictFromKBestList(GuideDecision decision) throws MaltChainedException {
200 if (decision instanceof SingleDecision) {
201 throw new GuideException("A branched decision model expect more than one decisions. ");
202 }
203
204 boolean success = false;
205 final SingleDecision singleDecision = ((MultipleDecision)decision).getSingleDecision(decisionIndex);
206 if (decisionIndex+1 < decision.numberOfDecisions()) {
207 if (singleDecision.continueWithNextDecision()) {
208 if (children == null) {
209 children = new HashMap<Integer,DecisionModel>();
210 }
211 DecisionModel child = children.get(singleDecision.getDecisionCode());
212 if (child != null) {
213 success = child.predictFromKBestList(decision);
214 }
215
216 }
217 }
218 if (!success) {
219 success = singleDecision.updateFromKBestList();
220 if (decisionIndex+1 < decision.numberOfDecisions()) {
221 if (singleDecision.continueWithNextDecision()) {
222 if (children == null) {
223 children = new HashMap<Integer,DecisionModel>();
224 }
225 DecisionModel child = children.get(singleDecision.getDecisionCode());
226 if (child == null) {
227 child = initChildDecisionModel(((MultipleDecision)decision).getSingleDecision(decisionIndex+1),
228 branchedDecisionSymbols+(branchedDecisionSymbols.length() == 0?"":"_")+singleDecision.getDecisionSymbol());
229 children.put(singleDecision.getDecisionCode(), child);
230 }
231 child.predict(decision);
232 }
233 }
234 }
235 return success;
236 }
237
238
239 public ClassifierGuide getGuide() {
240 return guide;
241 }
242
243 public String getModelName() {
244 return modelName;
245 }
246
247 public FeatureModel getFeatureModel() {
248 return featureModel;
249 }
250
251 public int getDecisionIndex() {
252 return decisionIndex;
253 }
254
255 public DecisionModel getParentDecisionModel() {
256 return parentDecisionModel;
257 }
258
259 private void setFeatureModel(FeatureModel featureModel) {
260 this.featureModel = featureModel;
261 }
262
263 private void setDecisionIndex(int decisionIndex) {
264 this.decisionIndex = decisionIndex;
265 }
266
267 private void setParentDecisionModel(DecisionModel parentDecisionModel) {
268 this.parentDecisionModel = parentDecisionModel;
269 }
270
271 private void setModelName(String modelName) {
272 this.modelName = modelName;
273 }
274
275 private void setGuide(ClassifierGuide guide) {
276 this.guide = guide;
277 }
278
279
280 private DecisionModel initChildDecisionModel(SingleDecision decision, String branchedDecisionSymbol) throws MaltChainedException {
281 Class<?> decisionModelClass = null;
282 if (decision.getRelationToNextDecision() == RelationToNextDecision.SEQUANTIAL) {
283 decisionModelClass = org.maltparser.parser.guide.decision.SeqDecisionModel.class;
284 } else if (decision.getRelationToNextDecision() == RelationToNextDecision.BRANCHED) {
285 decisionModelClass = org.maltparser.parser.guide.decision.BranchedDecisionModel.class;
286 } else if (decision.getRelationToNextDecision() == RelationToNextDecision.NONE) {
287 decisionModelClass = org.maltparser.parser.guide.decision.OneDecisionModel.class;
288 }
289
290 if (decisionModelClass == null) {
291 throw new GuideException("Could not find an appropriate decision model for the relation to the next decision");
292 }
293
294 try {
295 Class<?>[] argTypes = { org.maltparser.parser.guide.ClassifierGuide.class, org.maltparser.parser.guide.decision.DecisionModel.class,
296 java.lang.String.class };
297 Object[] arguments = new Object[3];
298 arguments[0] = getGuide();
299 arguments[1] = this;
300 arguments[2] = branchedDecisionSymbol;
301 Constructor<?> constructor = decisionModelClass.getConstructor(argTypes);
302 return (DecisionModel)constructor.newInstance(arguments);
303 } catch (NoSuchMethodException e) {
304 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
305 } catch (InstantiationException e) {
306 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
307 } catch (IllegalAccessException e) {
308 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
309 } catch (InvocationTargetException e) {
310 throw new GuideException("The decision model class '"+decisionModelClass.getName()+"' cannot be initialized. ", e);
311 }
312 }
313
314 private void initInstanceModel(String subModelName) throws MaltChainedException {
315 FeatureVector fv = featureModel.getFeatureVector(branchedDecisionSymbols+"."+subModelName);
316 if (fv == null) {
317 fv = featureModel.getFeatureVector(subModelName);
318 }
319 if (fv == null) {
320 fv = featureModel.getMainFeatureVector();
321 }
322
323 DependencyParserConfig c = guide.getConfiguration();
324
325 // if (c.getOptionValue("guide", "tree_automatic_split_order").toString().equals("yes") ||
326 // (c.getOptionValue("guide", "tree_split_columns")!=null &&
327 // c.getOptionValue("guide", "tree_split_columns").toString().length() > 0) ||
328 // (c.getOptionValue("guide", "tree_split_structures")!=null &&
329 // c.getOptionValue("guide", "tree_split_structures").toString().length() > 0)) {
330 // instanceModel = new DecisionTreeModel(fv, this);
331 // }else
332 if (c.getOptionValue("guide", "data_split_column").toString().length() == 0) {
333 instanceModel = new AtomicModel(-1, fv, this);
334 } else {
335 instanceModel = new FeatureDivideModel(fv, this);
336 }
337 }
338
339 public String toString() {
340 final StringBuilder sb = new StringBuilder();
341 sb.append(modelName + ", ");
342 for (DecisionModel model : children.values()) {
343 sb.append(model.toString() + ", ");
344 }
345 return sb.toString();
346 }
347 }