Skip to content

Commit

Permalink
Issue #352: instable runnable version
Browse files Browse the repository at this point in the history
  • Loading branch information
Horsmann committed May 14, 2016
1 parent 8a1da19 commit 59c0b0b
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,71 +129,45 @@ private String[][] parseSentence(String sentence)
}

@Override
public Instance pipe (Instance carrier)
{
Object inputData = carrier.getData();
Alphabet features = getDataAlphabet();
LabelAlphabet labels;
LabelSequence target = null;
String [][] tokens;
if (inputData instanceof String)
tokens = parseSentence((String)inputData);
else if (inputData instanceof String[][])
tokens = (String[][])inputData;
else
throw new IllegalArgumentException("Not a String or String[][]; got "+inputData);
FeatureVector[] fvs = new FeatureVector[tokens.length];
if (isTargetProcessing())
{
labels = (LabelAlphabet)getTargetAlphabet();
target = new LabelSequence (labels, tokens.length);
}

for (int l = 0; l < tokens.length; l++) {
int nFeatures;
if (isTargetProcessing())
{
if (tokens[l].length < 1)
throw new IllegalStateException ("Missing label at line " + l + " instance "+carrier.getName ());
nFeatures = tokens[l].length - 1;
target.add(tokens[l][nFeatures]);
}
else nFeatures = tokens[l].length;
ArrayList<Integer> featureIndices = new ArrayList<Integer>();
ArrayList<Double> featureValues = new ArrayList<Double>();
for (int f = 0; f < nFeatures; f++) {
int featureIndex = features.lookupIndex(tokens[l][f]);
// gdruck
// If the data alphabet's growth is stopped, featureIndex
// will be -1. Ignore these features.
if (featureIndex >= 0) {
featureIndices.add(featureIndex);
}
featureValues.add(Double.parseDouble(tokens[l][f]));
}
int[] featureIndicesArr = new int[featureIndices.size()];
for (int index = 0; index < featureIndices.size(); index++) {
featureIndicesArr[index] = featureIndices.get(index);
}
double[] featureValuesArr = new double[featureValues.size()];
for (int index = 0; index < featureValues.size(); index++) {
featureValuesArr[index] = featureValues.get(index);
}
if (denseFeatureValues)
fvs[l] = new FeatureVector(features, featureValuesArr);
else
fvs[l] = new FeatureVector(features, featureIndicesArr);
//fvs[l] = featureInductionOption.value ? new AugmentableFeatureVector(features, featureIndicesArr, null, featureIndicesArr.length) :
// fvs[l] = featureInductionOption.value ? new AugmentableFeatureVector(features, featureIndicesArr, null, featureIndicesArr.length) :
// new FeatureVector(features, featureValues);
}
carrier.setData(new FeatureVectorSequence(fvs));
if (isTargetProcessing())
carrier.setTarget(target);
else
carrier.setTarget(new LabelSequence(getTargetAlphabet()));
return carrier;
}
public Instance pipe (Instance carrier)
{
Object inputData = carrier.getData();
Alphabet features = getDataAlphabet();
LabelAlphabet labels;
LabelSequence target = null;
String [][] tokens;
if (inputData instanceof String)
tokens = parseSentence((String)inputData);
else if (inputData instanceof String[][])
tokens = (String[][])inputData;
else
throw new IllegalArgumentException("Not a String or String[][]; got "+inputData);
FeatureVector[] fvs = new FeatureVector[tokens.length];
if (isTargetProcessing())
{
labels = (LabelAlphabet)getTargetAlphabet();
target = new LabelSequence (labels, tokens.length);
}
for (int l = 0; l < tokens.length; l++) {
int nFeatures;
if (isTargetProcessing())
{
if (tokens[l].length < 1)
throw new IllegalStateException ("Missing label at line " + l + " instance "+carrier.getName ());
nFeatures = tokens[l].length - 1;
target.add(tokens[l][nFeatures]);
}
else nFeatures = tokens[l].length;
int featureIndices[] = new int[nFeatures];
for (int f = 0; f < nFeatures; f++)
featureIndices[f] = features.lookupIndex(tokens[l][f]);
fvs[l] = new FeatureVector(features, featureIndices);
}
carrier.setData(new FeatureVectorSequence(fvs));
if (isTargetProcessing())
carrier.setTarget(target);
return carrier;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -106,30 +106,50 @@ public static HashMap<String, Integer> getFeatureOffsetIndex(FeatureStore instan
public static CRF trainCRF(InstanceList training, CRF crf, double gaussianPriorVariance, int iterations, String defaultLabel,
boolean fullyConnected, int[] orders) {

if (crf == null) {
crf = new CRF(training.getPipe(), (Pipe)null);
String startName =
crf.addOrderNStates(training, orders, null,
defaultLabel, null, null,
fullyConnected);
for (int i = 0; i < crf.numStates(); i++) {
crf.getState(i).setInitialWeight (Transducer.IMPOSSIBLE_WEIGHT);
}
crf.getState(startName).setInitialWeight(0.0);
}
// logger.info("Training on " + training.size() + " instances");

CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood (crf);
crft.setGaussianPriorVariance(gaussianPriorVariance);

boolean converged;
for (int i = 1; i <= iterations; i++) {
converged = crft.train (training, 1);
if (converged) {

crf = new CRF(training.getPipe(), null);
//crf.addStatesForLabelsConnectedAsIn(trainingInstances);
crf.addStatesForThreeQuarterLabelsConnectedAsIn(training);
crf.addStartState();

CRFTrainerByLabelLikelihood trainer =
new CRFTrainerByLabelLikelihood(crf);
trainer.setGaussianPriorVariance(10.0);

boolean converged;
for (int i = 1; i <= iterations; i++) {
converged = trainer.train (training, 1);
if (converged) {
break;
}
}
return crf;
}
return crf;


// if (crf == null) {
// crf = new CRF(training.getPipe(), (Pipe)null);
// String startName =
// crf.addOrderNStates(training, orders, null,
// defaultLabel, null, null,
// fullyConnected);
// for (int i = 0; i < crf.numStates(); i++) {
// crf.getState(i).setInitialWeight (Transducer.IMPOSSIBLE_WEIGHT);
// }
// crf.getState(startName).setInitialWeight(0.0);
// }
// // logger.info("Training on " + training.size() + " instances");
//
// CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood (crf);
// crft.setGaussianPriorVariance(gaussianPriorVariance);
//
// boolean converged;
// for (int i = 1; i <= iterations; i++) {
// converged = crft.train (training, 1);
// if (converged) {
// break;
// }
// }
// return crf;
}

public static void runTrainCRF(File trainingFile, File modelFile, double var, int iterations, String defaultLabel,
Expand Down

0 comments on commit 59c0b0b

Please sign in to comment.