001 package org.maltparser.parser.guide.instance;
002
003 import java.io.IOException;
004 import java.lang.reflect.Constructor;
005 import java.lang.reflect.InvocationTargetException;
006 import java.util.ArrayList;
007 import java.util.Formatter;
008
009 import org.maltparser.core.exception.MaltChainedException;
010 import org.maltparser.core.feature.FeatureVector;
011 import org.maltparser.core.feature.function.FeatureFunction;
012 import org.maltparser.core.feature.function.Modifiable;
013 import org.maltparser.core.syntaxgraph.DependencyStructure;
014 import org.maltparser.ml.LearningMethod;
015 import org.maltparser.parser.guide.ClassifierGuide;
016 import org.maltparser.parser.guide.GuideException;
017 import org.maltparser.parser.guide.Model;
018 import org.maltparser.parser.history.action.SingleDecision;
019
020
021 /**
022
023 @author Johan Hall
024 @since 1.0
025 */
026 public class AtomicModel implements InstanceModel {
027 private Model parent;
028 private String modelName;
029 private FeatureVector featureVector;
030 private int index;
031 private int frequency = 0;
032 private LearningMethod method;
033
034
035 /**
036 * Constructs an atomic model.
037 *
038 * @param index the index of the atomic model (-1..n), where -1 is special value (used by a single model
039 * or the master divide model) and n is number of divide models.
040 * @param features the feature vector used by the atomic model.
041 * @param parent the parent guide model.
042 * @throws MaltChainedException
043 */
044 public AtomicModel(int index, FeatureVector features, Model parent) throws MaltChainedException {
045 setParent(parent);
046 setIndex(index);
047 if (index == -1) {
048 setModelName(parent.getModelName()+".");
049 } else {
050 setModelName(parent.getModelName()+"."+new Formatter().format("%03d", index)+".");
051 }
052 setFeatures(features);
053 setFrequency(0);
054 initMethod();
055 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH && index == -1 && getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter() != null) {
056 try {
057 getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().write(method.toString());
058 getGuide().getConfiguration().getConfigurationDir().getInfoFileWriter().flush();
059 } catch (IOException e) {
060 throw new GuideException("Could not write learner settings to the information file. ", e);
061 }
062 }
063 }
064
065 public void addInstance(SingleDecision decision) throws MaltChainedException {
066 try {
067 method.addInstance(decision, featureVector);
068 } catch (NullPointerException e) {
069 throw new GuideException("The learner cannot be found. ", e);
070 }
071 }
072
073
074 public void noMoreInstances() throws MaltChainedException {
075 try {
076 method.noMoreInstances();
077 } catch (NullPointerException e) {
078 throw new GuideException("The learner cannot be found. ", e);
079 }
080 }
081
082 public void finalizeSentence(DependencyStructure dependencyGraph) throws MaltChainedException {
083 try {
084 method.finalizeSentence(dependencyGraph);
085 } catch (NullPointerException e) {
086 throw new GuideException("The learner cannot be found. ", e);
087 }
088 }
089
090 public boolean predict(SingleDecision decision) throws MaltChainedException {
091 try {
092 // if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
093 // throw new GuideException("Cannot predict during batch training. ");
094 // }
095 return method.predict(featureVector, decision);
096 } catch (NullPointerException e) {
097 throw new GuideException("The learner cannot be found. ", e);
098 }
099 }
100
101 public FeatureVector predictExtract(SingleDecision decision) throws MaltChainedException {
102 try {
103 // if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
104 // throw new GuideException("Cannot predict during batch training. ");
105 // }
106 if (method.predict(featureVector, decision)) {
107 return featureVector;
108 }
109 return null;
110 } catch (NullPointerException e) {
111 throw new GuideException("The learner cannot be found. ", e);
112 }
113 }
114
115 public FeatureVector extract() throws MaltChainedException {
116 return featureVector;
117 }
118
119 public void terminate() throws MaltChainedException {
120 if (method != null) {
121 method.terminate();
122 method = null;
123 }
124 featureVector = null;
125 parent = null;
126 }
127
128 /**
129 * Moves all instance from this atomic model into the destination atomic model and add the divide feature.
130 * This method is used by the feature divide model to sum up all model below a certain threshold.
131 *
132 * @param model the destination atomic model
133 * @param divideFeature the divide feature
134 * @param divideFeatureIndexVector the divide feature index vector
135 * @throws MaltChainedException
136 */
137 public void moveAllInstances(AtomicModel model, FeatureFunction divideFeature, ArrayList<Integer> divideFeatureIndexVector) throws MaltChainedException {
138 if (method == null) {
139 throw new GuideException("The learner cannot be found. ");
140 } else if (model == null) {
141 throw new GuideException("The guide model cannot be found. ");
142 } else if (divideFeature == null) {
143 throw new GuideException("The divide feature cannot be found. ");
144 } else if (divideFeatureIndexVector == null) {
145 throw new GuideException("The divide feature index vector cannot be found. ");
146 }
147 ((Modifiable)divideFeature).setFeatureValue(index);
148 method.moveAllInstances(model.getMethod(), divideFeature, divideFeatureIndexVector);
149 method.terminate();
150 method = null;
151 }
152
153 /**
154 * Invokes the train() of the learning method
155 *
156 * @throws MaltChainedException
157 */
158 public void train() throws MaltChainedException {
159 try {
160 method.train(featureVector);
161 method.terminate();
162 method = null;
163
164 } catch (NullPointerException e) {
165 throw new GuideException("The learner cannot be found. ", e);
166 }
167
168
169 }
170
171 /**
172 * Initialize the learning method according to the option --learner-method.
173 *
174 * @throws MaltChainedException
175 */
176 public void initMethod() throws MaltChainedException {
177 Class<?> clazz = (Class<?>)getGuide().getConfiguration().getOptionValue("guide", "learner");
178 Class<?>[] argTypes = { org.maltparser.parser.guide.instance.InstanceModel.class, java.lang.Integer.class };
179 Object[] arguments = new Object[2];
180 arguments[0] = this;
181 if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.CLASSIFY) {
182 arguments[1] = LearningMethod.CLASSIFY;
183 } else if (getGuide().getGuideMode() == ClassifierGuide.GuideMode.BATCH) {
184 arguments[1] = LearningMethod.BATCH;
185 }
186
187 try {
188 Constructor<?> constructor = clazz.getConstructor(argTypes);
189 this.method = (LearningMethod)constructor.newInstance(arguments);
190 } catch (NoSuchMethodException e) {
191 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
192 } catch (InstantiationException e) {
193 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
194 } catch (IllegalAccessException e) {
195 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
196 } catch (InvocationTargetException e) {
197 throw new GuideException("The learner class '"+clazz.getName()+"' cannot be initialized. ", e);
198 }
199 }
200
201
202
203 /**
204 * Returns the parent guide model
205 *
206 * @return the parent guide model
207 */
208 public Model getParent() throws MaltChainedException {
209 if (parent == null) {
210 throw new GuideException("The atomic model can only be used by a parent model. ");
211 }
212 return parent;
213 }
214
215 /**
216 * Sets the parent guide model
217 *
218 * @param parent the parent guide model
219 */
220 protected void setParent(Model parent) {
221 this.parent = parent;
222 }
223
224 public String getModelName() {
225 return modelName;
226 }
227
228 /**
229 * Sets the name of the atomic model
230 *
231 * @param modelName the name of the atomic model
232 */
233 protected void setModelName(String modelName) {
234 this.modelName = modelName;
235 }
236
237 /**
238 * Returns the feature vector used by this atomic model
239 *
240 * @return a feature vector object
241 */
242 public FeatureVector getFeatures() {
243 return featureVector;
244 }
245
246 /**
247 * Sets the feature vector used by the atomic model.
248 *
249 * @param features a feature vector object
250 */
251 protected void setFeatures(FeatureVector features) {
252 this.featureVector = features;
253 }
254
255 public ClassifierGuide getGuide() {
256 return parent.getGuide();
257 }
258
259 /**
260 * Returns the index of the atomic model
261 *
262 * @return the index of the atomic model
263 */
264 public int getIndex() {
265 return index;
266 }
267
268 /**
269 * Sets the index of the model (-1..n), where -1 is a special value.
270 *
271 * @param index index value (-1..n) of the atomic model
272 */
273 protected void setIndex(int index) {
274 this.index = index;
275 }
276
277 /**
278 * Returns the frequency (number of instances)
279 *
280 * @return the frequency (number of instances)
281 */
282 public int getFrequency() {
283 return frequency;
284 }
285
286 /**
287 * Increase the frequency by 1
288 */
289 public void increaseFrequency() {
290 if (parent instanceof InstanceModel) {
291 ((InstanceModel)parent).increaseFrequency();
292 }
293 frequency++;
294 }
295
296 public void decreaseFrequency() {
297 if (parent instanceof InstanceModel) {
298 ((InstanceModel)parent).decreaseFrequency();
299 }
300 frequency--;
301 }
302 /**
303 * Sets the frequency (number of instances)
304 *
305 * @param frequency (number of instances)
306 */
307 protected void setFrequency(int frequency) {
308 this.frequency = frequency;
309 }
310
311 /**
312 * Returns a learner object
313 *
314 * @return a learner object
315 */
316 public LearningMethod getMethod() {
317 return method;
318 }
319
320
321 /* (non-Javadoc)
322 * @see java.lang.Object#toString()
323 */
324 public String toString() {
325 final StringBuilder sb = new StringBuilder();
326 sb.append(method.toString());
327 return sb.toString();
328 }
329 }