001 package org.maltparser.ml.lib;
002
003 import java.io.BufferedReader;
004 import java.io.EOFException;
005 import java.io.File;
006 import java.io.FileInputStream;
007 import java.io.IOException;
008 import java.io.InputStreamReader;
009 import java.io.ObjectInputStream;
010 import java.io.ObjectOutputStream;
011 import java.io.Reader;
012 import java.io.Serializable;
013 import java.nio.charset.Charset;
014 import java.util.Arrays;
015 import java.util.regex.Pattern;
016
017 import org.maltparser.core.helper.Util;
018
019 import de.bwaldvogel.liblinear.SolverType;
020
021 /**
022 * <p>This class borrows code from liblinear.Model.java of the Java implementation of the liblinear package.
023 * MaltLiblinearModel stores the model obtained from the training procedure. In addition to the original code the model is more integrated to
024 * MaltParser. Instead of moving features from MaltParser's internal data structures to liblinear's data structure it uses MaltParser's data
025 * structure directly on the model. </p>
026 *
027 * @author Johan Hall
028 *
029 */
030 public class MaltLiblinearModel implements Serializable, MaltLibModel {
031 private static final long serialVersionUID = 7526471155622776147L;
032 private static final Charset FILE_CHARSET = Charset.forName("ISO-8859-1");
033 private double bias;
034 /** label of each class */
035 private int[] labels;
036 private int nr_class;
037 private int nr_feature;
038 private SolverType solverType;
039 /** feature weight array */
040 private double[][] w;
041
042 public MaltLiblinearModel(int[] labels, int nr_class, int nr_feature, double[][] w, SolverType solverType) {
043 this.labels = labels;
044 this.nr_class = nr_class;
045 this.nr_feature = nr_feature;
046 this.w = w;
047 this.solverType = solverType;
048 }
049
050 public MaltLiblinearModel(Reader inputReader) throws IOException {
051 loadModel(inputReader);
052 }
053
054 public MaltLiblinearModel(File modelFile) throws IOException {
055 BufferedReader inputReader = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), FILE_CHARSET));
056 loadModel(inputReader);
057 }
058
059 /**
060 * @return number of classes
061 */
062 public int getNrClass() {
063 return nr_class;
064 }
065
066 /**
067 * @return number of features
068 */
069 public int getNrFeature() {
070 return nr_feature;
071 }
072
073 public int[] getLabels() {
074 return Util.copyOf(labels, nr_class);
075 }
076
077 /**
078 * The nr_feature*nr_class array w gives feature weights. We use one
079 * against the rest for multi-class classification, so each feature
080 * index corresponds to nr_class weight values. Weights are
081 * organized in the following way
082 *
083 * <pre>
084 * +------------------+------------------+------------+
085 * | nr_class weights | nr_class weights | ...
086 * | for 1st feature | for 2nd feature |
087 * +------------------+------------------+------------+
088 * </pre>
089 *
090 * If bias >= 0, x becomes [x; bias]. The number of features is
091 * increased by one, so w is a (nr_feature+1)*nr_class array. The
092 * value of bias is stored in the variable bias.
093 * @see #getBias()
094 * @return a <b>copy of</b> the feature weight array as described
095 */
096 // public double[] getFeatureWeights() {
097 // return Util.copyOf(w, w.length);
098 // }
099
100 /**
101 * @return true for logistic regression solvers
102 */
103 public boolean isProbabilityModel() {
104 return (solverType == SolverType.L2R_LR || solverType == SolverType.L2R_LR_DUAL || solverType == SolverType.L1R_LR);
105 }
106
107 public double getBias() {
108 return bias;
109 }
110
111 public int[] predict(MaltFeatureNode[] x) {
112 final double[] dec_values = new double[nr_class];
113 final int[] predictionList = Util.copyOf(labels, nr_class);
114 final int n = (bias >= 0)?nr_feature + 1:nr_feature;
115 // final int nr_w = (nr_class == 2 && solverType != SolverType.MCSVM_CS)?1:nr_class;
116 final int xlen = x.length;
117 // int i;
118 // for (i = 0; i < nr_w; i++) {
119 // dec_values[i] = 0;
120 // }
121
122 for (int i=0; i < xlen; i++) {
123 if (x[i].index <= n) {
124 final int t = (x[i].index - 1);
125 if (w[t] != null) {
126 for (int j = 0; j < w[t].length; j++) {
127 dec_values[j] += w[t][j] * x[i].value;
128 }
129 }
130 }
131 }
132
133
134 double tmpDec;
135 int tmpObj;
136 int lagest;
137 final int nc = nr_class-1;
138 for (int i=0; i < nc; i++) {
139 lagest = i;
140 for (int j=i; j < nr_class; j++) {
141 if (dec_values[j] > dec_values[lagest]) {
142 lagest = j;
143 }
144 }
145 tmpDec = dec_values[lagest];
146 dec_values[lagest] = dec_values[i];
147 dec_values[i] = tmpDec;
148 tmpObj = predictionList[lagest];
149 predictionList[lagest] = predictionList[i];
150 predictionList[i] = tmpObj;
151 }
152 return predictionList;
153 }
154
155 private void readObject(ObjectInputStream is) throws ClassNotFoundException, IOException {
156 is.defaultReadObject();
157 }
158
159 private void writeObject(ObjectOutputStream os) throws IOException {
160 os.defaultWriteObject();
161 }
162
163 private void loadModel(Reader inputReader) throws IOException {
164 labels = null;
165 Pattern whitespace = Pattern.compile("\\s+");
166 BufferedReader reader = null;
167 if (inputReader instanceof BufferedReader) {
168 reader = (BufferedReader)inputReader;
169 } else {
170 reader = new BufferedReader(inputReader);
171 }
172
173 try {
174 String line = null;
175 while ((line = reader.readLine()) != null) {
176 String[] split = whitespace.split(line);
177 if (split[0].equals("solver_type")) {
178 SolverType solver = SolverType.valueOf(split[1]);
179 if (solver == null) {
180 throw new RuntimeException("unknown solver type");
181 }
182 solverType = solver;
183 } else if (split[0].equals("nr_class")) {
184 nr_class = Util.atoi(split[1]);
185 Integer.parseInt(split[1]);
186 } else if (split[0].equals("nr_feature")) {
187 nr_feature = Util.atoi(split[1]);
188 } else if (split[0].equals("bias")) {
189 bias = Util.atof(split[1]);
190 } else if (split[0].equals("w")) {
191 break;
192 } else if (split[0].equals("label")) {
193 labels = new int[nr_class];
194 for (int i = 0; i < nr_class; i++) {
195 labels[i] = Util.atoi(split[i + 1]);
196 }
197 } else {
198 throw new RuntimeException("unknown text in model file: [" + line + "]");
199 }
200 }
201
202 int w_size = nr_feature;
203 if (bias >= 0) w_size++;
204
205 int nr_w = nr_class;
206 if (nr_class == 2 && solverType != SolverType.MCSVM_CS) nr_w = 1;
207 w = new double[w_size][nr_w];
208 int[] buffer = new int[128];
209
210 for (int i = 0; i < w_size; i++) {
211 for (int j = 0; j < nr_w; j++) {
212 int b = 0;
213 while (true) {
214 int ch = reader.read();
215 if (ch == -1) {
216 throw new EOFException("unexpected EOF");
217 }
218 if (ch == ' ') {
219 w[i][j] = Util.atof(new String(buffer, 0, b));
220 break;
221 } else {
222 buffer[b++] = ch;
223 }
224 }
225 }
226 }
227 }
228 finally {
229 Util.closeQuietly(reader);
230 }
231 }
232
233 public int hashCode() {
234 final int prime = 31;
235 long temp = Double.doubleToLongBits(bias);
236 int result = prime * 1 + (int)(temp ^ (temp >>> 32));
237 result = prime * result + Arrays.hashCode(labels);
238 result = prime * result + nr_class;
239 result = prime * result + nr_feature;
240 result = prime * result + ((solverType == null) ? 0 : solverType.hashCode());
241 for (int i = 0; i < w.length; i++) {
242 result = prime * result + Arrays.hashCode(w[i]);
243 }
244 return result;
245 }
246
247 public boolean equals(Object obj) {
248 if (this == obj) return true;
249 if (obj == null) return false;
250 if (getClass() != obj.getClass()) return false;
251 MaltLiblinearModel other = (MaltLiblinearModel)obj;
252 if (Double.doubleToLongBits(bias) != Double.doubleToLongBits(other.bias)) return false;
253 if (!Arrays.equals(labels, other.labels)) return false;
254 if (nr_class != other.nr_class) return false;
255 if (nr_feature != other.nr_feature) return false;
256 if (solverType == null) {
257 if (other.solverType != null) return false;
258 } else if (!solverType.equals(other.solverType)) return false;
259 for (int i = 0; i < w.length; i++) {
260 if (other.w.length <= i) return false;
261 if (!Util.equals(w[i], other.w[i])) return false;
262 }
263 return true;
264 }
265
266 public String toString() {
267 final StringBuilder sb = new StringBuilder("Model");
268 sb.append(" bias=").append(bias);
269 sb.append(" nr_class=").append(nr_class);
270 sb.append(" nr_feature=").append(nr_feature);
271 sb.append(" solverType=").append(solverType);
272 return sb.toString();
273 }
274 }