From f6541def80f0cd105ddacbcb917ba27c44765dec Mon Sep 17 00:00:00 2001 From: Tobias Horsmann Date: Sat, 14 May 2016 19:23:51 +0200 Subject: [PATCH] Issue #352: --- dkpro-tc-examples/pom.xml | 4 + .../single/sequence/MalletBrownPosDemo.java | 140 ++++++ dkpro-tc-ml-mallet/.activate_rat-check | 1 + dkpro-tc-ml-mallet/LICENSE.txt | 201 ++++++++ dkpro-tc-ml-mallet/pom.xml | 72 +++ .../org/dkpro/tc/mallet/MalletAdapter.java | 92 ++++ .../org/dkpro/tc/mallet/package-info.java | 24 + .../mallet/report/MalletOutcomeIDReport.java | 31 ++ .../mallet/report/MalletReportConstants.java | 46 ++ .../dkpro/tc/mallet/task/MalletTestTask.java | 107 +++++ .../ConversionToFeatureVectorSequence.java | 199 ++++++++ .../util/MalletFoldDimensionBundle.java | 143 ++++++ .../org/dkpro/tc/mallet/util/MalletUtils.java | 454 ++++++++++++++++++ .../tc/mallet/util/PerClassEvaluator.java | 212 ++++++++ .../tc/mallet/writer/MalletDataWriter.java | 210 ++++++++ .../mallet/writer/MalletFeatureEncoder.java | 51 ++ pom.xml | 13 +- 17 files changed, 1999 insertions(+), 1 deletion(-) create mode 100644 dkpro-tc-examples/src/main/java/org/dkpro/tc/examples/single/sequence/MalletBrownPosDemo.java create mode 100644 dkpro-tc-ml-mallet/.activate_rat-check create mode 100644 dkpro-tc-ml-mallet/LICENSE.txt create mode 100644 dkpro-tc-ml-mallet/pom.xml create mode 100644 dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/MalletAdapter.java create mode 100644 dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/package-info.java create mode 100644 dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/report/MalletOutcomeIDReport.java create mode 100644 dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/report/MalletReportConstants.java create mode 100644 dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/task/MalletTestTask.java create mode 100644 dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/ConversionToFeatureVectorSequence.java create mode 100644 
dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/MalletFoldDimensionBundle.java create mode 100644 dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/MalletUtils.java create mode 100644 dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/PerClassEvaluator.java create mode 100644 dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/writer/MalletDataWriter.java create mode 100644 dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/writer/MalletFeatureEncoder.java diff --git a/dkpro-tc-examples/pom.xml b/dkpro-tc-examples/pom.xml index 213099e41..95bf06305 100644 --- a/dkpro-tc-examples/pom.xml +++ b/dkpro-tc-examples/pom.xml @@ -189,6 +189,10 @@ org.dkpro.tc dkpro-tc-fstore-simple + + org.dkpro.tc + dkpro-tc-ml-mallet + diff --git a/dkpro-tc-examples/src/main/java/org/dkpro/tc/examples/single/sequence/MalletBrownPosDemo.java b/dkpro-tc-examples/src/main/java/org/dkpro/tc/examples/single/sequence/MalletBrownPosDemo.java new file mode 100644 index 000000000..deaeb7b96 --- /dev/null +++ b/dkpro-tc-examples/src/main/java/org/dkpro/tc/examples/single/sequence/MalletBrownPosDemo.java @@ -0,0 +1,140 @@ +/** + * Copyright 2016 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see http://www.gnu.org/licenses/. 
+ */ +package org.dkpro.tc.examples.single.sequence; + +import static de.tudarmstadt.ukp.dkpro.core.api.io.ResourceCollectionReaderBase.INCLUDE_PREFIX; +import static java.util.Arrays.asList; +import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.uima.analysis_engine.AnalysisEngineDescription; +import org.apache.uima.collection.CollectionReaderDescription; +import org.apache.uima.fit.component.NoOpAnnotator; +import org.apache.uima.fit.factory.CollectionReaderFactory; +import org.apache.uima.resource.ResourceInitializationException; +import org.dkpro.lab.Lab; +import org.dkpro.lab.task.BatchTask.ExecutionPolicy; +import org.dkpro.lab.task.Dimension; +import org.dkpro.lab.task.ParameterSpace; +import org.dkpro.tc.core.Constants; +import org.dkpro.tc.crfsuite.CRFSuiteAdapter; +import org.dkpro.tc.examples.io.BrownCorpusReader; +import org.dkpro.tc.examples.util.DemoUtils; +import org.dkpro.tc.features.length.NrOfTokensUFE; +import org.dkpro.tc.features.ngram.LuceneCharacterNGramUFE; +import org.dkpro.tc.features.style.InitialCharacterUpperCaseUFE; +import org.dkpro.tc.fstore.simple.DenseFeatureStore; +import org.dkpro.tc.mallet.MalletAdapter; +import org.dkpro.tc.ml.ExperimentCrossValidation; +import org.dkpro.tc.ml.ExperimentTrainTest; +import org.dkpro.tc.ml.report.BatchCrossValidationReport; +import org.dkpro.tc.ml.report.BatchTrainTestReport; + +/** + * This a pure Java-based experiment setup of POS tagging as sequence tagging. + */ +public class MalletBrownPosDemo + implements Constants +{ + public static final String LANGUAGE_CODE = "en"; + + public static final int NUM_FOLDS = 2; + + public static final String corpusFilePathTrain = "src/main/resources/data/brown_tei/"; + + public static void main(String[] args) + throws Exception + { + + // This is used to ensure that the required DKPRO_HOME environment variable is set. 
+ // Ensures that people can run the experiments even if they haven't read the setup + // instructions first :) +// DemoUtils.setDkproHome(MalletBrownPosDemo.class.getSimpleName()); + System.setProperty("DKPRO_HOME", System.getProperty("user.home")+"/Desktop/"); + + ParameterSpace pSpace = getParameterSpace(Constants.FM_SEQUENCE, Constants.LM_SINGLE_LABEL); + + MalletBrownPosDemo experiment = new MalletBrownPosDemo(); + experiment.runCrossValidation(pSpace); + } + + public static ParameterSpace getParameterSpace(String featureMode, String learningMode) + throws ResourceInitializationException + { + // configure training and test data reader dimension + Map dimReaders = new HashMap(); + + CollectionReaderDescription train = CollectionReaderFactory.createReaderDescription(BrownCorpusReader.class, BrownCorpusReader.PARAM_LANGUAGE, "en", + BrownCorpusReader.PARAM_SOURCE_LOCATION, corpusFilePathTrain, + BrownCorpusReader.PARAM_PATTERNS, + asList(INCLUDE_PREFIX + "a01.xml")); + dimReaders.put(DIM_READER_TRAIN, train); + + CollectionReaderDescription test = CollectionReaderFactory.createReaderDescription(BrownCorpusReader.class, BrownCorpusReader.PARAM_LANGUAGE, "en", + BrownCorpusReader.PARAM_SOURCE_LOCATION, corpusFilePathTrain, + BrownCorpusReader.PARAM_PATTERNS, + asList(INCLUDE_PREFIX + "a02.xml")); + dimReaders.put(DIM_READER_TEST, test); + + @SuppressWarnings("unchecked") + Dimension> dimPipelineParameters = Dimension.create(DIM_PIPELINE_PARAMS, + asList(new Object[] { LuceneCharacterNGramUFE.PARAM_CHAR_NGRAM_MIN_N, 2, + LuceneCharacterNGramUFE.PARAM_CHAR_NGRAM_MAX_N, 4, + LuceneCharacterNGramUFE.PARAM_CHAR_NGRAM_USE_TOP_K, 50 })); + + + @SuppressWarnings("unchecked") + Dimension> dimFeatureSets = Dimension.create(DIM_FEATURE_SET, + asList(new String[] { NrOfTokensUFE.class.getName(), + InitialCharacterUpperCaseUFE.class.getName() })); + + ParameterSpace pSpace = new ParameterSpace(Dimension.createBundle("readers", dimReaders), + Dimension.create(DIM_LEARNING_MODE, 
learningMode), + Dimension.create(DIM_FEATURE_MODE, featureMode), + Dimension.create(Constants.DIM_FEATURE_STORE, DenseFeatureStore.class.getName()), + dimPipelineParameters, dimFeatureSets); + + return pSpace; + } + + // ##### CV ##### + protected void runCrossValidation(ParameterSpace pSpace) + throws Exception + { + + ExperimentTrainTest batch = new ExperimentTrainTest("BrownPosDemoCV_Mallet", + MalletAdapter.class); + batch.setPreprocessing(getPreprocessing()); + batch.setParameterSpace(pSpace); + batch.setExecutionPolicy(ExecutionPolicy.RUN_AGAIN); + batch.addReport(BatchTrainTestReport.class); + + // Run + Lab.getInstance().run(batch); + } + + protected AnalysisEngineDescription getPreprocessing() + throws ResourceInitializationException + { + return createEngineDescription(NoOpAnnotator.class); + } +} diff --git a/dkpro-tc-ml-mallet/.activate_rat-check b/dkpro-tc-ml-mallet/.activate_rat-check new file mode 100644 index 000000000..5c4334301 --- /dev/null +++ b/dkpro-tc-ml-mallet/.activate_rat-check @@ -0,0 +1 @@ +Marker file to activate rat license checker profile \ No newline at end of file diff --git a/dkpro-tc-ml-mallet/LICENSE.txt b/dkpro-tc-ml-mallet/LICENSE.txt new file mode 100644 index 000000000..e6e77b089 --- /dev/null +++ b/dkpro-tc-ml-mallet/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/dkpro-tc-ml-mallet/pom.xml b/dkpro-tc-ml-mallet/pom.xml new file mode 100644 index 000000000..c510a32f2 --- /dev/null +++ b/dkpro-tc-ml-mallet/pom.xml @@ -0,0 +1,72 @@ + + + 4.0.0 + + org.dkpro.tc + dkpro-tc + 0.9.0-SNAPSHOT + + dkpro-tc-ml-mallet + Interface to the Mallet Machine Learning Toolkit + + + Apache License, Version 2.0 + http://www.apache.org/licenses/LICENSE-2.0 + repo + + + + + cc.mallet + mallet + + + org.dkpro.tc + dkpro-tc-core + + + org.dkpro.tc + dkpro-tc-api-features + + + org.dkpro.tc + dkpro-tc-api + + + org.apache.commons + commons-math + + + commons-io + commons-io + + + org.dkpro.lab + dkpro-lab-core + + + commons-lang + commons-lang + + + org.dkpro.tc + dkpro-tc-ml + + + \ No newline at end of file diff --git a/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/MalletAdapter.java b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/MalletAdapter.java new file mode 100644 index 000000000..1f580c1cd --- /dev/null +++ b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/MalletAdapter.java @@ -0,0 +1,92 @@ +/******************************************************************************* + * Copyright 2015 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ +package org.dkpro.tc.mallet; + +import java.util.Collection; + +import org.dkpro.lab.reporting.ReportBase; +import org.dkpro.lab.task.Dimension; +import org.dkpro.lab.task.impl.DimensionBundle; +import org.dkpro.lab.task.impl.ExecutableTaskBase; +import org.dkpro.tc.core.io.DataWriter; +import org.dkpro.tc.core.ml.ModelSerialization_ImplBase; +import org.dkpro.tc.core.ml.TCMachineLearningAdapter; +import org.dkpro.tc.core.task.ModelSerializationTask; +import org.dkpro.tc.mallet.report.MalletOutcomeIDReport; +import org.dkpro.tc.mallet.task.MalletTestTask; +import org.dkpro.tc.mallet.util.MalletFoldDimensionBundle; +import org.dkpro.tc.mallet.writer.MalletDataWriter; +import org.dkpro.tc.ml.report.InnerBatchUsingTCEvaluationReport; + +public class MalletAdapter + implements TCMachineLearningAdapter +{ + + public static TCMachineLearningAdapter getInstance() { + return new MalletAdapter(); + } + + @Override + public ExecutableTaskBase getTestTask() { + return new MalletTestTask(); + } + + @Override + public Class getOutcomeIdReportClass() { + return MalletOutcomeIDReport.class; + } + + @Override + public Class getBatchTrainTestReportClass() { + return InnerBatchUsingTCEvaluationReport.class; + } + + @SuppressWarnings("unchecked") + @Override + public DimensionBundle> getFoldDimensionBundle( + String[] files, int folds) { + return new MalletFoldDimensionBundle("files", Dimension.create("", files), folds); + } + + @Override + public String 
getFrameworkFilename(AdapterNameEntries name) { + + switch (name) { + case featureVectorsFile: return "training-data.txt"; + case predictionsFile : return "predictions.txt"; + case featureSelectionFile : return "attributeEvaluationResults.txt"; + } + + return null; + } + + @Override + public Class getDataWriterClass() { + return MalletDataWriter.class; + } + + @Override + public Class getLoadModelConnectorClass() { + throw new UnsupportedOperationException(); + } + + @Override + public Class getSaveModelTask() { + throw new UnsupportedOperationException(); + } +} diff --git a/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/package-info.java b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/package-info.java new file mode 100644 index 000000000..9be768ed8 --- /dev/null +++ b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/package-info.java @@ -0,0 +1,24 @@ +/******************************************************************************* + * Copyright 2015 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +/** + * Support for MALLET machine learning framework. 
+ * + * @since 0.1.0 + */ +package org.dkpro.tc.mallet; diff --git a/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/report/MalletOutcomeIDReport.java b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/report/MalletOutcomeIDReport.java new file mode 100644 index 000000000..f3d00bd70 --- /dev/null +++ b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/report/MalletOutcomeIDReport.java @@ -0,0 +1,31 @@ +/******************************************************************************* + * Copyright 2015 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/
+package org.dkpro.tc.mallet.report;
+
+import org.dkpro.lab.reporting.ReportBase;
+
+/**
+ * Outcome-id report of the Mallet module. The report currently does no work.
+ */
+public class MalletOutcomeIDReport
+    extends ReportBase
+{
+
+    /**
+     * Does nothing: the method body is empty, so no report artefact is written.
+     * NOTE(review): presumably a stub until a Mallet-specific outcome-id report exists -
+     * confirm.
+     */
+    @Override
+    public void execute()
+        throws Exception
+    {
+    }
+}
\ No newline at end of file
diff --git a/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/report/MalletReportConstants.java b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/report/MalletReportConstants.java
new file mode 100644
index 000000000..b364b57b2
--- /dev/null
+++ b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/report/MalletReportConstants.java
@@ -0,0 +1,46 @@
+/*******************************************************************************
+ * Copyright 2015
+ * Ubiquitous Knowledge Processing (UKP) Lab
+ * Technische Universität Darmstadt
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ ******************************************************************************/ +package org.dkpro.tc.mallet.report; + +/** + * Constants that are used in reports + * + * @deprecated As of release 0.7.0, only dkpro-tc-ml-crfsuite is supported + */ +public interface MalletReportConstants +{ + // accuracy + public static final String CORRECT = "Correctly Classified Examples"; + public static final String INCORRECT = "Incorrectly Classified Examples"; + public static final String PCT_CORRECT = "Percentage Correct"; + public static final String PCT_INCORRECT = "Percentage Incorrect"; + + // P/R/F/Accuracy + public static final String PRECISION = "Precision"; + public static final String RECALL = "Recall"; + public static final String FMEASURE = "F-Measure"; + + public static final String MACRO_AVERAGE_FMEASURE = "Macro-averaged F-Measure"; + +// public static final String WGT_PRECISION = "Weighted Precision"; +// public static final String WGT_RECALL = "Weighted Recall"; +// public static final String WGT_FMEASURE = "Weighted F-Measure"; + + public static final String NUMBER_EXAMPLES = "Absolute Number of Examples"; + public static final String NUMBER_LABELS = "Absolute Number of Labels"; +} diff --git a/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/task/MalletTestTask.java b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/task/MalletTestTask.java new file mode 100644 index 000000000..382e770a4 --- /dev/null +++ b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/task/MalletTestTask.java @@ -0,0 +1,107 @@ +/******************************************************************************* + * Copyright 2015 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ +package org.dkpro.tc.mallet.task; + +import java.io.File; +import java.util.ArrayList; + +import org.dkpro.lab.engine.TaskContext; +import org.dkpro.lab.storage.StorageService.AccessMode; +import org.dkpro.lab.task.Discriminator; +import org.dkpro.lab.task.impl.ExecutableTaskBase; +import org.dkpro.tc.core.Constants; +import org.dkpro.tc.mallet.util.MalletUtils; + +import cc.mallet.fst.TransducerEvaluator; + +public class MalletTestTask + extends ExecutableTaskBase +{ + @Discriminator + private String tagger = "CRF"; //added to configure other taggers like HMM, although these are not supported + + @Discriminator + private double gaussianPriorVariance = 10.0; //Gaussian Prior Variance + + @Discriminator + private int iterations = 1000; //Number of iterations + + @Discriminator + private String defaultLabel = "O"; + + @Discriminator + private int[] orders = new int[]{0, 1, 2, 3, 4}; + + @Discriminator + private boolean denseFeatureValues = true; + + public static ArrayList precisionValues; + + public static ArrayList recallValues; + + public static ArrayList f1Values; + + public static ArrayList labels; + + // TODO - most of that should be in Constants + public static final String PREDICTIONS_KEY = "predictions.txt"; + public static final String TRAINING_DATA_KEY = "training-data.txt"; //TODO Issue 127: add from Constants + public static final String EVALUATION_DATA_KEY = "evaluation.csv"; + public static final String CONFUSION_MATRIX_KEY = "confusionMatrix.csv"; +// 
public static final String FEATURE_SELECTION_DATA_KEY = "attributeEvaluationResults.txt"; + public static final String PREDICTION_CLASS_LABEL_NAME = "PredictedOutcome"; + public static final String OUTCOME_CLASS_LABEL_NAME = "Outcome"; + public static final String MALLET_MODEL_KEY = "mallet-model"; + + public static boolean MULTILABEL; + + @Override + public void execute(TaskContext aContext) + throws Exception + { + + File fileTrain = new File(aContext.getStorageLocation(Constants.TEST_TASK_INPUT_KEY_TRAINING_DATA, + AccessMode.READONLY).getPath() + + "/" + TRAINING_DATA_KEY); + File fileTest = new File(aContext.getStorageLocation(Constants.TEST_TASK_INPUT_KEY_TEST_DATA, + AccessMode.READONLY).getPath() + + "/" + TRAINING_DATA_KEY); + + File fileModel = new File(aContext.getStorageLocation(Constants.TEST_TASK_OUTPUT_KEY, AccessMode.READWRITE) + .getPath() + "/" + MALLET_MODEL_KEY); + + TransducerEvaluator eval = MalletUtils.runTrainTest(fileTrain, fileTest, fileModel, gaussianPriorVariance, iterations, defaultLabel, + false, orders, tagger, denseFeatureValues); + + + File filePredictions = new File(aContext.getStorageLocation(Constants.TEST_TASK_OUTPUT_KEY, AccessMode.READWRITE) + .getPath() + "/" + PREDICTIONS_KEY); + + MalletUtils.outputPredictions(eval, fileTest, filePredictions, PREDICTION_CLASS_LABEL_NAME); + + File fileEvaluation = new File(aContext.getStorageLocation(Constants.TEST_TASK_OUTPUT_KEY, AccessMode.READWRITE) + .getPath() + "/" + EVALUATION_DATA_KEY); + + MalletUtils.outputEvaluation(eval, fileEvaluation); + + File fileConfusionMatrix = new File(aContext.getStorageLocation(Constants.TEST_TASK_OUTPUT_KEY, AccessMode.READWRITE) + .getPath() + "/" + CONFUSION_MATRIX_KEY); + + MalletUtils.outputConfusionMatrix(eval, fileConfusionMatrix); + } +} \ No newline at end of file diff --git a/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/ConversionToFeatureVectorSequence.java 
b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/ConversionToFeatureVectorSequence.java new file mode 100644 index 000000000..5bc240f7b --- /dev/null +++ b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/ConversionToFeatureVectorSequence.java @@ -0,0 +1,199 @@ +/******************************************************************************* + * Copyright 2015 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ +package org.dkpro.tc.mallet.util; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.util.ArrayList; +import java.util.HashMap; + +import org.dkpro.tc.api.exception.TextClassificationException; +import org.dkpro.tc.api.features.Feature; +import org.dkpro.tc.api.features.FeatureStore; +import org.dkpro.tc.core.Constants; +import org.dkpro.tc.mallet.task.MalletTestTask; + +import cc.mallet.pipe.Pipe; +import cc.mallet.types.Alphabet; +import cc.mallet.types.FeatureVector; +import cc.mallet.types.FeatureVectorSequence; +import cc.mallet.types.Instance; +import cc.mallet.types.LabelAlphabet; +import cc.mallet.types.LabelSequence; + +/** + * Modification of SimpleTagger2FeatureVectorSequence from Mallet + */ + +public class ConversionToFeatureVectorSequence extends Pipe +{ + + private int idFeatureIndex; + private boolean denseFeatureValues; + + // Previously, there was no serialVersionUID. This is ID that would + // have been automatically generated by the compiler. Therefore, + // other changes should not break serialization. + private static final long serialVersionUID = -2059308802200728625L; + + public ConversionToFeatureVectorSequence (boolean denseFeatureValues) + { + super (new Alphabet(), new LabelAlphabet()); + idFeatureIndex = -1; + } + + /** + * Parses a string representing a sequence of rows of tokens into an + * array of arrays of tokens (Ignore first sentence containing feature names) + * + * @param sentence a String + * @return the corresponding array of arrays of tokens. 
+ */ + private String[][] parseSentence(String sentence) + { + String[] lines = sentence.split("\n"); + String[][] tokens = null; + if (lines[0].matches("^[A-Za-z]+.*$")) { //parsing first line group containing feature names + String[] featureNames = lines[0].split(" "); + for (int i = 0; i < featureNames.length; i++) { + if (featureNames[i].equals(Constants.ID_FEATURE_NAME)) { + idFeatureIndex = i; + } + } + if (idFeatureIndex != -1) { // if file contained the DKPro Instance ID feature + tokens = new String[lines.length - 1][]; + String[][] tempTokens = new String[lines.length - 1][]; + for (int i = 1; i < lines.length; i++) { + tempTokens[i - 1] = lines[i].split(" "); + } + for (int i = 0; i < tempTokens.length; i++) { + tokens[i] = new String[tempTokens[i].length - 1]; + int tokenLineIndex = 0; + for (int j = 0; j < tempTokens[i].length; j++) { + if (j != idFeatureIndex) { + tokens[i][tokenLineIndex++] = tempTokens[i][j]; + } + } + } + } + else { + tokens = new String[lines.length - 1][]; + for (int i = 1; i < lines.length; i++) { + tokens[i - 1] = lines[i].split(" "); + } + } + } + else { + if (idFeatureIndex != -1) { // if file contained the DKPro Instance ID feature + tokens = new String[lines.length][]; + String[][] tempTokens = new String[lines.length][]; + for (int i = 0; i < lines.length; i++) { + tempTokens[i] = lines[i].split(" "); + } + for (int i = 0; i < tempTokens.length; i++) { + tokens[i] = new String[tempTokens[i].length - 1]; + int tokenLineIndex = 0; + for (int j = 0; j < tempTokens[i].length; j++) { + if (j != idFeatureIndex) { + tokens[i][tokenLineIndex++] = tempTokens[i][j]; + } + } + } + } + else { + tokens = new String[lines.length][]; + for (int i = 0; i < lines.length; i++) { + tokens[i] = lines[i].split(" "); + } + } + } + return tokens; + } + + @Override + public Instance pipe (Instance carrier) + { + Object inputData = carrier.getData(); + Alphabet features = getDataAlphabet(); + LabelAlphabet labels; + LabelSequence target = null; + 
String [][] tokens; + if (inputData instanceof String) + tokens = parseSentence((String)inputData); + else if (inputData instanceof String[][]) + tokens = (String[][])inputData; + else + throw new IllegalArgumentException("Not a String or String[][]; got "+inputData); + FeatureVector[] fvs = new FeatureVector[tokens.length]; + if (isTargetProcessing()) + { + labels = (LabelAlphabet)getTargetAlphabet(); + target = new LabelSequence (labels, tokens.length); + } + + for (int l = 0; l < tokens.length; l++) { + int nFeatures; + if (isTargetProcessing()) + { + if (tokens[l].length < 1) + throw new IllegalStateException ("Missing label at line " + l + " instance "+carrier.getName ()); + nFeatures = tokens[l].length - 1; + target.add(tokens[l][nFeatures]); + } + else nFeatures = tokens[l].length; + ArrayList featureIndices = new ArrayList(); + ArrayList featureValues = new ArrayList(); + for (int f = 0; f < nFeatures; f++) { + int featureIndex = features.lookupIndex(tokens[l][f]); + // gdruck + // If the data alphabet's growth is stopped, featureIndex + // will be -1. Ignore these features. + if (featureIndex >= 0) { + featureIndices.add(featureIndex); + } + featureValues.add(Double.parseDouble(tokens[l][f])); + } + int[] featureIndicesArr = new int[featureIndices.size()]; + for (int index = 0; index < featureIndices.size(); index++) { + featureIndicesArr[index] = featureIndices.get(index); + } + double[] featureValuesArr = new double[featureValues.size()]; + for (int index = 0; index < featureValues.size(); index++) { + featureValuesArr[index] = featureValues.get(index); + } + if (denseFeatureValues) + fvs[l] = new FeatureVector(features, featureValuesArr); + else + fvs[l] = new FeatureVector(features, featureIndicesArr); + //fvs[l] = featureInductionOption.value ? new AugmentableFeatureVector(features, featureIndicesArr, null, featureIndicesArr.length) : + // fvs[l] = featureInductionOption.value ? 
new AugmentableFeatureVector(features, featureIndicesArr, null, featureIndicesArr.length) : + // new FeatureVector(features, featureValues); + } + carrier.setData(new FeatureVectorSequence(fvs)); + if (isTargetProcessing()) + carrier.setTarget(target); + else + carrier.setTarget(new LabelSequence(getTargetAlphabet())); + return carrier; + } + + +} \ No newline at end of file diff --git a/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/MalletFoldDimensionBundle.java b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/MalletFoldDimensionBundle.java new file mode 100644 index 000000000..709191f5d --- /dev/null +++ b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/MalletFoldDimensionBundle.java @@ -0,0 +1,143 @@ +/******************************************************************************* + * Copyright 2015 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/ +package org.dkpro.tc.mallet.util; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.dkpro.lab.task.Dimension; +import org.dkpro.lab.task.impl.DimensionBundle; +import org.dkpro.lab.task.impl.DynamicDimension; + +/** + * Modification to FoldDimensionBundle in order to add instances belonging to the same sequence in + * the same fold + * + * @deprecated As of release 0.7.0, only dkpro-tc-ml-crfsuite is supported + */ +public class MalletFoldDimensionBundle extends DimensionBundle> implements DynamicDimension +{ + private Dimension foldedDimension; + private List[] buckets; + private int validationBucket = -1; + private int folds; + + public MalletFoldDimensionBundle(String aName, Dimension aFoldedDimension, int aFolds) + { + super(aName, new Object[0] ); + foldedDimension = aFoldedDimension; + folds = aFolds; + } + + private void init() + { + buckets = new List[folds]; + + // Capture all data from the dimension into buckets, one per fold + foldedDimension.rewind(); + int i = 0; + + String remainingFile = null; + while (foldedDimension.hasNext()) { + int bucket = i % folds; + + if (buckets[bucket] == null) { + buckets[bucket] = new ArrayList(); + } + + if (remainingFile != null) { + buckets[bucket].add(remainingFile); + } + + String firstFile = foldedDimension.next(); + String firstEssayName = getEssayName(firstFile); + buckets[bucket].add(firstFile); + + // Ensure that all instances belonging to same sequence are put in the same bucket + while (foldedDimension.hasNext()) { + String currentFile = foldedDimension.next(); + String currentEssayName = getEssayName(currentFile); + if (!firstEssayName.equals(currentEssayName)) { + remainingFile = currentFile; + break; + } + buckets[bucket].add(currentFile); + } + i++; + } + + if (i < folds) { + throw new 
IllegalStateException("Requested [" + folds + "] folds, but only got [" + i + + "] values. There must be at least as many values as folds."); + } + } + + private String getEssayName(String file) { + String simpleFileName = new File(file).getName(); + return simpleFileName.substring(0, simpleFileName.indexOf('_')); + } + + @Override + public boolean hasNext() + { + return validationBucket < buckets.length-1; + } + + @Override + public void rewind() + { + init(); + validationBucket = -1; + } + + @Override + public Map> next() + { + validationBucket++; + return current(); + } + + @Override + public Map> current() + { + List trainingData = new ArrayList(); + for (int i = 0; i < buckets.length; i++) { + if (i != validationBucket) { + trainingData.addAll(buckets[i]); + } + } + + Map> data = new HashMap>(); + data.put(getName()+"_training", trainingData); + data.put(getName()+"_validation", buckets[validationBucket]); + + return data; + } + + @Override + public void setConfiguration(Map aConfig) + { + if (foldedDimension instanceof DynamicDimension) { + ((DynamicDimension) foldedDimension).setConfiguration(aConfig); + } + } +} diff --git a/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/MalletUtils.java b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/MalletUtils.java new file mode 100644 index 000000000..d71d932ea --- /dev/null +++ b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/MalletUtils.java @@ -0,0 +1,454 @@ +/******************************************************************************* + * Copyright 2015 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ +package org.dkpro.tc.mallet.util; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.OutputStreamWriter; +import java.io.Reader; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +import org.apache.commons.lang.StringUtils; + +import cc.mallet.fst.CRF; +import cc.mallet.fst.CRFTrainerByLabelLikelihood; +import cc.mallet.fst.NoopTransducerTrainer; +import cc.mallet.fst.Transducer; +import cc.mallet.fst.TransducerEvaluator; +import cc.mallet.fst.TransducerTrainer; +import cc.mallet.pipe.Pipe; +import cc.mallet.pipe.iterator.LineGroupIterator; +import cc.mallet.types.Alphabet; +import cc.mallet.types.InstanceList; +import org.dkpro.tc.api.exception.TextClassificationException; +import org.dkpro.tc.api.features.Feature; +import org.dkpro.tc.api.features.FeatureStore; +import org.dkpro.tc.api.features.Instance; +import org.dkpro.tc.api.features.MissingValue; +import org.dkpro.tc.mallet.report.MalletReportConstants; +import org.dkpro.tc.mallet.task.MalletTestTask; +import 
org.dkpro.tc.mallet.writer.MalletFeatureEncoder; + +/** + * Utility class for the Mallet machine learning toolkit + */ +public class MalletUtils +{ + + // TODO yet to decide when to call this method + public static void writeFeatureNamesToFile(FeatureStore instanceList, File outputFile) + throws IOException, TextClassificationException + { + BufferedWriter bw = new BufferedWriter(new OutputStreamWriter( + new FileOutputStream(outputFile), "UTF-8")); + HashMap featureOffsetIndex = new HashMap(); + for (int i = 0; i < instanceList.getNumberOfInstances(); i++) { + Instance instance = instanceList.getInstance(i); + for (Feature feature : instance.getFeatures()) { + String featureName = feature.getName(); + if (!featureOffsetIndex.containsKey(featureName)) { + featureOffsetIndex.put(featureName, featureOffsetIndex.size()); + bw.write(featureName + " "); + } + } + } + bw.write(MalletTestTask.OUTCOME_CLASS_LABEL_NAME); + bw.close(); + } + + public static HashMap getFeatureOffsetIndex(FeatureStore instanceList) + { + HashMap featureOffsetIndex = new HashMap(); + for (int i = 0; i < instanceList.getNumberOfInstances(); i++) { + Instance instance = instanceList.getInstance(i); + for (Feature feature : instance.getFeatures()) { + String featureName = feature.getName(); + if (!featureOffsetIndex.containsKey(featureName)) { + featureOffsetIndex.put(featureName, featureOffsetIndex.size()); + } + } + + } + return featureOffsetIndex; + } + + + public static CRF trainCRF(InstanceList training, CRF crf, double gaussianPriorVariance, int iterations, String defaultLabel, + boolean fullyConnected, int[] orders) { + + if (crf == null) { + crf = new CRF(training.getPipe(), (Pipe)null); + String startName = + crf.addOrderNStates(training, orders, null, + defaultLabel, null, null, + fullyConnected); + for (int i = 0; i < crf.numStates(); i++) { + crf.getState(i).setInitialWeight (Transducer.IMPOSSIBLE_WEIGHT); + } + crf.getState(startName).setInitialWeight(0.0); + } + // 
logger.info("Training on " + training.size() + " instances"); + + CRFTrainerByLabelLikelihood crft = new CRFTrainerByLabelLikelihood (crf); + crft.setGaussianPriorVariance(gaussianPriorVariance); + + boolean converged; + for (int i = 1; i <= iterations; i++) { + converged = crft.train (training, 1); + if (converged) { + break; + } + } + return crf; + } + + public static void runTrainCRF(File trainingFile, File modelFile, double var, int iterations, String defaultLabel, + boolean fullyConnected, int[] orders, boolean denseFeatureValues) throws FileNotFoundException, IOException, ClassNotFoundException { + Reader trainingFileReader = null; + InstanceList trainingData = null; + //trainingFileReader = new FileReader(trainingFile); + trainingFileReader = new InputStreamReader(new FileInputStream(trainingFile), "UTF-8"); + Pipe p = null; + CRF crf = null; + p = new ConversionToFeatureVectorSequence(denseFeatureValues); //uses first line of file to identify DKProInstanceID feature and discard + p.getTargetAlphabet().lookupIndex(defaultLabel); + p.setTargetProcessing(true); + trainingData = new InstanceList(p); + trainingData.addThruPipe(new LineGroupIterator(trainingFileReader, + Pattern.compile("^\\s*$"), true)); //if you want to skip the line containing feature names, add "|^[A-Za-z]+.*$" + // logger.info + // ("Number of features in training data: "+p.getDataAlphabet().size()); + + // logger.info ("Number of predicates: "+p.getDataAlphabet().size()); + + + if (p.isTargetProcessing()) + { + Alphabet targets = p.getTargetAlphabet(); + StringBuffer buf = new StringBuffer("Labels:"); + for (int i = 0; i < targets.size(); i++) + { + buf.append(" ").append(targets.lookupObject(i).toString()); + // logger.info(buf.toString()); + } + } + + crf = trainCRF(trainingData, crf, var, iterations, defaultLabel, fullyConnected, orders); + + ObjectOutputStream s = + new ObjectOutputStream(new FileOutputStream(modelFile)); + s.writeObject(crf); + s.close(); + } + + public static void 
test(TransducerTrainer tt, TransducerEvaluator eval, + InstanceList testing) + { + eval.evaluateInstanceList(tt, testing, "Testing"); + } + + public static TransducerEvaluator runTestCRF(File testFile, File modelFile) throws FileNotFoundException, IOException, ClassNotFoundException { + Reader testFileReader = null; + InstanceList testData = null; + //testFileReader = new FileReader(testFile); + testFileReader = new InputStreamReader(new FileInputStream(testFile), "UTF-8"); + Pipe p = null; + CRF crf = null; + TransducerEvaluator eval = null; + ObjectInputStream s = + new ObjectInputStream(new FileInputStream(modelFile)); + crf = (CRF) s.readObject(); + s.close(); + p = crf.getInputPipe(); + p.setTargetProcessing(true); + testData = new InstanceList(p); + testData.addThruPipe( + new LineGroupIterator(testFileReader, + Pattern.compile("^\\s*$"), true)); + // logger.info ("Number of predicates: "+p.getDataAlphabet().size()); + + eval = new PerClassEvaluator(new InstanceList[] {testData}, new String[] {"Testing"}); + + if (p.isTargetProcessing()) + { + Alphabet targets = p.getTargetAlphabet(); + StringBuffer buf = new StringBuffer("Labels:"); + for (int i = 0; i < targets.size(); i++) + { + buf.append(" ").append(targets.lookupObject(i).toString()); + // logger.info(buf.toString()); + } + } + + test(new NoopTransducerTrainer(crf), eval, testData); + + List labels = ((PerClassEvaluator) eval).getLabelNames(); + List precisionValues = ((PerClassEvaluator) eval).getPrecisionValues(); + List recallValues = ((PerClassEvaluator) eval).getRecallValues(); + List f1Values = ((PerClassEvaluator) eval).getF1Values(); + + printEvaluationMeasures(labels, precisionValues, recallValues, f1Values); + + return eval; + } + + public static TransducerEvaluator runTrainTest(File trainFile, File testFile, File modelFile, + double var, int iterations, String defaultLabel, + boolean fullyConnected, int[] orders, String tagger, boolean denseFeatureValues) throws FileNotFoundException, 
ClassNotFoundException, IOException, TextClassificationException { + TransducerEvaluator eval = null; + if (tagger.equals("CRF")) { + runTrainCRF(trainFile,modelFile, var, iterations, defaultLabel, fullyConnected, orders, denseFeatureValues); + eval = runTestCRF(testFile, modelFile); + } + else if (tagger.equals("HMM")){ + throw new TextClassificationException("'HMM' is not currently supported."); + //runTrainHMM(trainFile,modelFile, defaultLabel, iterations, denseFeatureValues); + //eval = runTestHMM(testFile, modelFile); + } + else { + throw new TextClassificationException("Unsupported tagger name for sequence tagging. Supported taggers are 'CRF' and 'HMM'."); + } + return eval; + } + + //FIXME HMM is not currently supported (uncomment and use a different vector sequence compatible to HMM utilities + //in Mallet) + +// public static void runTrainHMM(File trainingFile, File modelFile, String defaultLabel, int iterations, boolean denseFeatureValues) throws FileNotFoundException, IOException { +// Reader trainingFileReader = null; +// InstanceList trainingData = null; +// //trainingFileReader = new FileReader(trainingFile); +// trainingFileReader = new InputStreamReader(new GZIPInputStream(new FileInputStream(trainingFile))); +// Pipe p = null; +// p = new ConversionToFeatureVectorSequence(denseFeatureValues); //uses first line of file to identify DKProInstanceID feature and discard +// p.getTargetAlphabet().lookupIndex(defaultLabel); +// p.setTargetProcessing(true); +// trainingData = new InstanceList(p); +// trainingData.addThruPipe(new LineGroupIterator(trainingFileReader, +// Pattern.compile("^\\s*$"), true)); //if you want to skip the line containing feature names, add "|^[A-Za-z]+.*$" +// // logger.info +// // ("Number of features in training data: "+p.getDataAlphabet().size()); +// +// // logger.info ("Number of predicates: "+p.getDataAlphabet().size()); +// +// if (p.isTargetProcessing()) +// { +// Alphabet targets = p.getTargetAlphabet(); +// StringBuffer 
buf = new StringBuffer("Labels:"); +// for (int i = 0; i < targets.size(); i++) +// buf.append(" ").append(targets.lookupObject(i).toString()); +// // logger.info(buf.toString()); +// } +// +// HMM hmm = null; +// hmm = trainHMM(trainingData, hmm, iterations); +// ObjectOutputStream s = +// new ObjectOutputStream(new FileOutputStream(modelFile)); +// s.writeObject(hmm); +// s.close(); +// } +// +// public static HMM trainHMM(InstanceList training, HMM hmm, int numIterations) throws IOException { +// if (hmm == null) { +// hmm = new HMM(training.getPipe(), null); +// hmm.addStatesForLabelsConnectedAsIn(training); +// //hmm.addStatesForBiLabelsConnectedAsIn(trainingInstances); +// +// HMMTrainerByLikelihood trainer = +// new HMMTrainerByLikelihood(hmm); +// +// trainer.train(training, numIterations); +// +// //trainingEvaluator.evaluate(trainer); +// } +// return hmm; +// } +// +// public static TransducerEvaluator runTestHMM(File testFile, File modelFile) throws FileNotFoundException, IOException, ClassNotFoundException { +// ArrayList pipes = new ArrayList(); +// +// pipes.add(new SimpleTaggerSentence2TokenSequence()); +// pipes.add(new TokenSequence2FeatureSequence()); +// +// Pipe pipe = new SerialPipes(pipes); +// +// InstanceList testData = new InstanceList(pipe); +// +// testData.addThruPipe(new LineGroupIterator(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(testFile)))), Pattern.compile("^\\s*$"), true)); +// +// TransducerEvaluator eval = +// new PerClassEvaluator(testData, "testing"); +// +// ObjectInputStream s = +// new ObjectInputStream(new FileInputStream(modelFile)); +// HMM hmm = (HMM) s.readObject(); +// +// test(new NoopTransducerTrainer(hmm), eval, testData); +// labels = ((PerClassEvaluator) eval).getLabelNames(); +// precisionValues = ((PerClassEvaluator) eval).getPrecisionValues(); +// recallValues = ((PerClassEvaluator) eval).getRecallValues(); +// f1Values = ((PerClassEvaluator) eval).getF1Values(); +// 
return eval; +// } + + + public static void printEvaluationMeasures(List labels, List precisionValues, List recallValues, List f1Values) { + double values[][] = new double[labels.size()][3]; + Iterator itPrecision = precisionValues.iterator(); + Iterator itRecall = recallValues.iterator(); + Iterator itF1 = f1Values.iterator(); + int i = 0; + while(itPrecision.hasNext()) { + values[i++][0] = itPrecision.next(); + } + i = 0; + while(itRecall.hasNext()) { + values[i++][1] = itRecall.next(); + } + i = 0; + while(itF1.hasNext()) { + values[i++][2] = itF1.next(); + } + Iterator itLabels = labels.iterator(); + for(i=0; i predictedLabels = ((PerClassEvaluator) eval).getPredictedLabels(); + BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(fileTest), "UTF-8")); + BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(filePredictions), "UTF-8")); + String line; + boolean header = false; + int i = 0; + while ((line = br.readLine()) != null) { + if (!header) { + bw.write(line + " " + predictionClassLabelName); + bw.flush(); + header = true; + continue; + } + if (!line.isEmpty()) { + bw.write("\n" + line + " " + predictedLabels.get(i++)); + bw.flush(); + } + else { + bw.write("\n"); + bw.flush(); + } + } + br.close(); + bw.close(); + } + + public static void outputEvaluation(TransducerEvaluator eval, File fileEvaluation) throws IOException { + BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fileEvaluation), "UTF-8")); + + ArrayList labelNames = ((PerClassEvaluator) eval).getLabelNames(); + + ArrayList precisionValues = ((PerClassEvaluator) eval).getPrecisionValues(); + ArrayList recallValues = ((PerClassEvaluator) eval).getRecallValues(); + ArrayList f1Values = ((PerClassEvaluator) eval).getF1Values(); + + int numLabels = labelNames.size(); + bw.write("Measure,Value"); + bw.write("\n" + MalletReportConstants.CORRECT + "," + ((PerClassEvaluator) eval).getNumberOfCorrectPredictions()); 
+ bw.write("\n" + MalletReportConstants.INCORRECT + "," + ((PerClassEvaluator) eval).getNumberOfIncorrectPredictions()); + bw.write("\n" + MalletReportConstants.NUMBER_EXAMPLES + "," + ((PerClassEvaluator) eval).getNumberOfExamples()); + bw.write("\n" + MalletReportConstants.PCT_CORRECT + "," + ((PerClassEvaluator) eval).getPercentageOfCorrectPredictions()); + bw.write("\n" + MalletReportConstants.PCT_INCORRECT + "," + ((PerClassEvaluator) eval).getPercentageOfIncorrectPredictions()); + + for (int i = 0; i < numLabels; i++) { + String label = labelNames.get(i); + bw.write("\n" + MalletReportConstants.PRECISION + "_" + label + "," + precisionValues.get(i)); + bw.write("\n" + MalletReportConstants.RECALL + "_" + label + "," + recallValues.get(i)); + bw.write("\n" + MalletReportConstants.FMEASURE + "_" + label + "," + f1Values.get(i)); + bw.flush(); + } + bw.write("\n" + MalletReportConstants.MACRO_AVERAGE_FMEASURE + "," + ((PerClassEvaluator) eval).getMacroAverage()); + bw.flush(); + bw.close(); + } + + public static void outputConfusionMatrix(TransducerEvaluator eval, File fileConfusionMatrix) throws IOException { + BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fileConfusionMatrix), "UTF-8")); + + ArrayList labelNames = ((PerClassEvaluator) eval).getLabelNames(); + int numLabels = labelNames.size(); + + HashMap labelNameToIndexMap = new HashMap(); + for (int i = 0; i < numLabels; i++) { + labelNameToIndexMap.put(labelNames.get(i), i); + } + + ArrayList goldLabels = ((PerClassEvaluator) eval).getGoldLabels(); + ArrayList predictedLabels = ((PerClassEvaluator) eval).getPredictedLabels(); + + Integer[][] confusionMatrix = new Integer[numLabels][numLabels]; + + //initialize to 0 + for (int i = 0; i < confusionMatrix.length; i++) { + for (int j = 0; j < confusionMatrix.length; j++) { + confusionMatrix[i][j] = 0; + } + } + + for (int i = 0; i < goldLabels.size(); i++) { + 
confusionMatrix[labelNameToIndexMap.get(goldLabels.get(i))][labelNameToIndexMap.get(predictedLabels.get(i))]++; + } + + String[][] confusionMatrixString = new String[numLabels + 1][numLabels + 1]; + confusionMatrixString[0][0] = " "; + for (int i = 1; i < numLabels + 1; i++) { + confusionMatrixString[i][0] = labelNames.get(i-1) + "_actual"; + confusionMatrixString[0][i] = labelNames.get(i-1) + "_predicted"; + } + for (int i = 1; i < numLabels + 1; i++) { + for (int j = 1; j < numLabels + 1; j++) { + confusionMatrixString[i][j] = confusionMatrix[i-1][j-1].toString(); + } + } + + for (String[] element : confusionMatrixString) { + bw.write(StringUtils.join(element, ",")); + bw.write("\n"); + bw.flush(); + } + bw.close(); + } +} diff --git a/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/PerClassEvaluator.java b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/PerClassEvaluator.java new file mode 100644 index 000000000..1fd5deb91 --- /dev/null +++ b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/util/PerClassEvaluator.java @@ -0,0 +1,212 @@ +/******************************************************************************* + * Copyright 2015 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ******************************************************************************/
+package org.dkpro.tc.mallet.util;
+
+import java.text.DecimalFormat;
+import java.util.ArrayList;
+import java.util.logging.Logger;
+
+import cc.mallet.fst.TokenAccuracyEvaluator;
+import cc.mallet.fst.Transducer;
+import cc.mallet.fst.TransducerEvaluator;
+import cc.mallet.fst.TransducerTrainer;
+import cc.mallet.types.Alphabet;
+import cc.mallet.types.Instance;
+import cc.mallet.types.InstanceList;
+import cc.mallet.types.MatrixOps;
+import cc.mallet.types.Sequence;
+import cc.mallet.util.MalletLogger;
+
+/**
+ * Determines the precision, recall and F1 on a per-class basis.
+ * <p>
+ * Each call to {@link #evaluateInstanceList} replaces the per-label scores and the label names
+ * with those of the evaluated list. The correct/incorrect counters and the gold/predicted label
+ * lists are <em>not</em> reset: they keep growing over all calls made on the same evaluator.
+ * All of this state belongs to the individual evaluator instance.
+ *
+ * @deprecated As of release 0.7.0, only dkpro-tc-ml-crfsuite is supported
+ */
+@Deprecated
+public class PerClassEvaluator extends TransducerEvaluator {
+
+    private static final Logger logger = MalletLogger.getLogger(PerClassEvaluator.class.getName());
+
+    // Running totals; evaluateInstanceList only ever adds to the two prediction counters.
+    private Integer numberOfCorrectPredictions = 0;
+    private Integer numberOfIncorrectPredictions = 0;
+    private Integer numberOfExamples = 0;
+
+    private Double percentageOfCorrectPredictions = 0.0;
+    private Double percentageOfIncorrectPredictions = 0.0;
+
+    // Scores of the most recent evaluation; position i belongs to entry i of the target alphabet.
+    private ArrayList<Double> precisionValues = new ArrayList<Double>();
+    private ArrayList<Double> recallValues = new ArrayList<Double>();
+    private ArrayList<Double> f1Values = new ArrayList<Double>();
+    private ArrayList<String> labelNames = new ArrayList<String>();
+
+    // One entry per evaluated token, appended in evaluation order and never cleared.
+    private final ArrayList<String> predictedLabels = new ArrayList<String>();
+    private final ArrayList<String> goldLabels = new ArrayList<String>();
+
+    private Double macroAverage = 0.0;
+
+    /**
+     * @param instanceLists the instance lists handed to the {@link TransducerEvaluator} base class
+     * @param descriptions one description per instance list
+     */
+    public PerClassEvaluator (InstanceList[] instanceLists, String[] descriptions) {
+        super (instanceLists, descriptions);
+    }
+
+    /** Convenience constructor for a single instance list. */
+    public PerClassEvaluator (InstanceList i1, String d1) {
+        this (new InstanceList[] {i1}, new String[] {d1});
+    }
+
+    /** Convenience constructor for two instance lists. */
+    public PerClassEvaluator (InstanceList i1, String d1, InstanceList i2, String d2) {
+        this (new InstanceList[] {i1, i2}, new String[] {d1, d2});
+    }
+
+    /**
+     * Runs the trainer's transducer over every sequence in {@code data}, compares the predicted
+     * labels with the targets token by token and derives precision, recall and F1 per label.
+     * <p>
+     * A label that was never predicted (or never occurs as gold label) yields {@code NaN} scores;
+     * these are stored as they are, but count as 0 in the macro-averaged F1.
+     *
+     * @param tt trainer whose transducer produces the predictions
+     * @param data sequences to evaluate
+     * @param description name of the data set, used in the log output only
+     */
+    @Override
+    public void evaluateInstanceList (TransducerTrainer tt, InstanceList data, String description)
+    {
+        Transducer model = tt.getTransducer();
+        Alphabet dict = model.getInputPipe().getTargetAlphabet();
+        int numLabels = dict.size();
+        int[] numCorrectTokens = new int [numLabels];
+        int[] numPredTokens = new int [numLabels];
+        int[] numTrueTokens = new int [numLabels];
+
+        logger.info("Per-token results for " + description);
+        for (int i = 0; i < data.size(); i++) {
+            Instance instance = data.get(i);
+            Sequence input = (Sequence) instance.getData();
+            Sequence trueOutput = (Sequence) instance.getTarget();
+            assert (input.size() == trueOutput.size());
+            Sequence predOutput = model.transduce (input);
+            assert (predOutput.size() == trueOutput.size());
+            for (int j = 0; j < trueOutput.size(); j++) {
+                Object gold = trueOutput.get(j);
+                Object predicted = predOutput.get(j);
+                int idx = dict.lookupIndex(gold);
+                numTrueTokens[idx]++;
+                numPredTokens[dict.lookupIndex(predicted)]++;
+                if (gold.equals(predicted)) {
+                    numCorrectTokens[idx]++;
+                    numberOfCorrectPredictions++;
+                }
+                else {
+                    numberOfIncorrectPredictions++;
+                }
+                goldLabels.add(gold.toString());
+                predictedLabels.add(predicted.toString());
+            }
+        }
+
+        setNumberOfExamples(numberOfCorrectPredictions + numberOfIncorrectPredictions);
+        // Multiply in double so that very large token counts cannot overflow the int range.
+        setPercentageOfCorrectPredictions(numberOfCorrectPredictions * 100.0 / numberOfExamples);
+        setPercentageOfIncorrectPredictions(numberOfIncorrectPredictions * 100.0 / numberOfExamples);
+
+        precisionValues = new ArrayList<Double>();
+        recallValues = new ArrayList<Double>();
+        f1Values = new ArrayList<Double>();
+        labelNames = new ArrayList<String>();
+
+        DecimalFormat f = new DecimalFormat ("0.####");
+        double[] allf = new double [numLabels];
+        for (int i = 0; i < numLabels; i++) {
+            Object label = dict.lookupObject(i);
+            double precision = ((double) numCorrectTokens[i]) / numPredTokens[i];
+            double recall = ((double) numCorrectTokens[i]) / numTrueTokens[i];
+            double f1 = (2 * precision * recall) / (precision + recall);
+            // A NaN F1 leaves the array slot at 0.0 so that the mean below stays a number.
+            if (!Double.isNaN (f1)) {
+                allf [i] = f1;
+            }
+            logger.info(description +" label " + label + " P " + f.format (precision)
+                    + " R " + f.format(recall) + " F1 "+ f.format (f1));
+            precisionValues.add(precision);
+            recallValues.add(recall);
+            f1Values.add(f1);
+            labelNames.add(label.toString());
+        }
+
+        double macroF1 = MatrixOps.mean (allf);
+        logger.info ("Macro-average F1 "+f.format (macroF1));
+        setMacroAverage(macroF1);
+    }
+
+    /** @return number of tokens whose predicted label equalled the gold label, over all evaluations so far */
+    public Integer getNumberOfCorrectPredictions() {
+        return numberOfCorrectPredictions;
+    }
+
+    public void setNumberOfCorrectPredictions(Integer numberOfCorrectPredictions) {
+        this.numberOfCorrectPredictions = numberOfCorrectPredictions;
+    }
+
+    /** @return number of tokens whose predicted label differed from the gold label, over all evaluations so far */
+    public Integer getNumberOfIncorrectPredictions() {
+        return numberOfIncorrectPredictions;
+    }
+
+    public void setNumberOfIncorrectPredictions(Integer numberOfIncorrectPredictions) {
+        this.numberOfIncorrectPredictions = numberOfIncorrectPredictions;
+    }
+
+    /** @return sum of correct and incorrect predictions as of the last evaluation */
+    public Integer getNumberOfExamples() {
+        return numberOfExamples;
+    }
+
+    public void setNumberOfExamples(Integer numberOfExamples) {
+        this.numberOfExamples = numberOfExamples;
+    }
+
+    /** @return share of correct predictions in percent (0-100); {@code NaN} if nothing was evaluated */
+    public Double getPercentageOfCorrectPredictions() {
+        return percentageOfCorrectPredictions;
+    }
+
+    public void setPercentageOfCorrectPredictions(
+            Double percentageOfCorrectPredictions) {
+        this.percentageOfCorrectPredictions = percentageOfCorrectPredictions;
+    }
+
+    /** @return share of incorrect predictions in percent (0-100); {@code NaN} if nothing was evaluated */
+    public Double getPercentageOfIncorrectPredictions() {
+        return percentageOfIncorrectPredictions;
+    }
+
+    public void setPercentageOfIncorrectPredictions(
+            Double percentageOfIncorrectPredictions) {
+        this.percentageOfIncorrectPredictions = percentageOfIncorrectPredictions;
+    }
+
+    /** @return mean of the per-label F1 values of the last evaluation, with NaN counted as 0 */
+    public Double getMacroAverage() {
+        return macroAverage;
+    }
+
+    public void setMacroAverage(Double macroAverage) {
+        this.macroAverage = macroAverage;
+    }
+
+    /** @return per-label precision of the last evaluation, parallel to {@link #getLabelNames()} */
+    public ArrayList<Double> getPrecisionValues() {
+        return precisionValues;
+    }
+
+    /** @return per-label recall of the last evaluation, parallel to {@link #getLabelNames()} */
+    public ArrayList<Double> getRecallValues() {
+        return recallValues;
+    }
+
+    /** @return per-label F1 of the last evaluation, parallel to {@link #getLabelNames()} */
+    public ArrayList<Double> getF1Values() {
+        return f1Values;
+    }
+
+    /** @return string form of every target-alphabet entry, in alphabet order */
+    public ArrayList<String> getLabelNames() {
+        return labelNames;
+    }
+
+    /** @return gold label of every token evaluated so far by this evaluator */
+    public ArrayList<String> getGoldLabels() {
+        return goldLabels;
+    }
+
+    /** @return predicted label of every token evaluated so far, parallel to {@link #getGoldLabels()} */
+    public ArrayList<String> getPredictedLabels() {
+        return predictedLabels;
+    }
+
+}
diff --git a/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/writer/MalletDataWriter.java b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/writer/MalletDataWriter.java
new file mode 100644
index 000000000..564beffa9
--- /dev/null
+++ b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/writer/MalletDataWriter.java
@@ -0,0 +1,210 @@
+/*******************************************************************************
+ * Copyright 2015
+ * Ubiquitous Knowledge Processing (UKP) Lab
+ * Technische Universität Darmstadt
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ ******************************************************************************/
+package org.dkpro.tc.mallet.writer;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.dkpro.tc.api.exception.TextClassificationException;
+import org.dkpro.tc.api.features.Feature;
+import org.dkpro.tc.api.features.FeatureStore;
+import org.dkpro.tc.api.features.Instance;
+import org.dkpro.tc.api.features.MissingValue;
+import org.dkpro.tc.core.Constants;
+import org.dkpro.tc.core.io.DataWriter;
+import org.dkpro.tc.core.ml.TCMachineLearningAdapter.AdapterNameEntries;
+import org.dkpro.tc.mallet.MalletAdapter;
+import org.dkpro.tc.mallet.task.MalletTestTask;
+
+/**
+ * {@link DataWriter} for the Mallet machine learning tool.
+ * <p>
+ * Every instance becomes one text line consisting of its feature values in a fixed column order,
+ * separated by single spaces and followed by the outcome. An additional empty line is written
+ * whenever the sequence id changes.
+ */
+public class MalletDataWriter
+    implements DataWriter, Constants
+{
+
+    /**
+     * Writes all instances of the feature store to the feature vector file of the
+     * {@link MalletAdapter} inside {@code outputFolder}. The file is opened in append mode.
+     * <p>
+     * Instances are ordered by sequence id and, within a sequence, by sequence position before
+     * they are written. A feature that an instance does not define ends up as the literal text
+     * {@code null} in its column, because the column slot is never filled.
+     *
+     * @param outputFolder folder that receives the feature vector file
+     * @param featureStore instances to write; must contain at least one outcome
+     * @param useDenseInstances not used by this writer
+     * @param learningMode not used by this writer
+     * @param applyWeighting not used by this writer
+     * @throws IllegalArgumentException if the store has no outcomes or a feature value is
+     *         {@code null}
+     */
+    @Override
+    public void write(File outputFolder, FeatureStore featureStore, boolean useDenseInstances,
+            String learningMode, boolean applyWeighting)
+        throws Exception
+    {
+        String frameworkFilename = MalletAdapter.getInstance()
+                .getFrameworkFilename(AdapterNameEntries.featureVectorsFile);
+        File outputFile = new File(outputFolder, frameworkFilename);
+
+        // check for error conditions
+        if (featureStore.getUniqueOutcomes().isEmpty()) {
+            throw new IllegalArgumentException("List of instance outcomes is empty.");
+        }
+
+        Map<String, Integer> featureOffsetIndex = getFeatureOffsetIndex(featureStore);
+
+        List<Instance> instances = new ArrayList<Instance>();
+        for (int i = 0; i < featureStore.getNumberOfInstances(); i++) {
+            instances.add(featureStore.getInstance(i));
+        }
+
+        // Group by sequence and order by position inside each sequence. The comparator has to
+        // compare the sequence ids as well: answering "equal" for instances of different
+        // sequences is not a consistent ordering, so the sort could leave sequences interleaved
+        // or be rejected by Collections.sort.
+        Collections.sort(instances, new Comparator<Instance>()
+        {
+            @Override
+            public int compare(Instance o1, Instance o2)
+            {
+                int sequenceId1 = o1.getSequenceId();
+                int sequenceId2 = o2.getSequenceId();
+                if (sequenceId1 != sequenceId2) {
+                    return sequenceId1 < sequenceId2 ? -1 : 1;
+                }
+
+                int position1 = o1.getSequencePosition();
+                int position2 = o2.getSequencePosition();
+                if (position1 == position2) {
+                    return 0;
+                }
+                return position1 < position2 ? -1 : 1;
+            }
+        });
+
+        // NOTE(review): starts at 1, so a first sequence with any other id is preceded by an
+        // extra empty line - confirm that sequence ids are 1-based.
+        int currentSequenceId = 1;
+        for (Instance instance : instances) {
+            if (currentSequenceId != instance.getSequenceId()) {
+                writeNewLineToFile(outputFile);
+                currentSequenceId = instance.getSequenceId();
+            }
+
+            String[] featureValues = new String[featureOffsetIndex.size()];
+            for (Feature feature : instance.getFeatures()) {
+                // Convert first so that a null value is reported even for an unknown feature.
+                String featureValue = encodeFeatureValue(feature);
+                Integer offset = featureOffsetIndex.get(feature.getName());
+                if (offset != null) {
+                    featureValues[offset] = featureValue;
+                }
+            }
+            writeFeatureValuesToFile(featureValues, instance.getOutcome(), outputFile);
+        }
+    }
+
+    /**
+     * Converts a feature value to the text written into the data file: numbers as their double
+     * value, booleans as 1.0/0.0, missing values via {@link MalletFeatureEncoder}, anything else
+     * via {@code toString()}.
+     */
+    private static String encodeFeatureValue(Feature feature)
+    {
+        Object value = feature.getValue();
+        if (value instanceof Number) {
+            return String.valueOf(((Number) value).doubleValue());
+        }
+        if (value instanceof Boolean) {
+            return String.valueOf((Boolean) value ? 1.0d : 0.0d);
+        }
+        if (value instanceof MissingValue) {
+            return MalletFeatureEncoder.getMissingValueConversionMap()
+                    .get(((MissingValue) value).getType());
+        }
+        if (value == null) {
+            throw new IllegalArgumentException(
+                    "You have an instance which doesn't specify a value for the feature "
+                            + feature.getName());
+        }
+        return value.toString();
+    }
+
+    /** Opens {@code file} for UTF-8 text output, either appending to or replacing its content. */
+    private static BufferedWriter openWriter(File file, boolean append)
+        throws IOException
+    {
+        return new BufferedWriter(
+                new OutputStreamWriter(new FileOutputStream(file, append), "UTF-8"));
+    }
+
+    /**
+     * Assigns a column offset to every feature name, numbered from 0 in the order in which the
+     * names are first encountered while iterating over the store.
+     *
+     * @param instanceList the feature store to scan
+     * @return map from feature name to column offset
+     */
+    public HashMap<String, Integer> getFeatureOffsetIndex(FeatureStore instanceList)
+    {
+        HashMap<String, Integer> featureOffsetIndex = new HashMap<String, Integer>();
+        for (int i = 0; i < instanceList.getNumberOfInstances(); i++) {
+            Instance instance = instanceList.getInstance(i);
+            for (Feature feature : instance.getFeatures()) {
+                String featureName = feature.getName();
+                if (!featureOffsetIndex.containsKey(featureName)) {
+                    featureOffsetIndex.put(featureName, featureOffsetIndex.size());
+                }
+            }
+        }
+        return featureOffsetIndex;
+    }
+
+    /**
+     * Replaces the content of {@code outputFile} with a single header line: every feature name
+     * in order of first occurrence, each followed by a space, then the outcome label name.
+     *
+     * @param instanceList the feature store that supplies the feature names
+     * @param outputFile the file to overwrite
+     * @throws IOException if the file cannot be written
+     */
+    public void writeFeatureNamesToFile(FeatureStore instanceList, File outputFile)
+        throws IOException, TextClassificationException
+    {
+        BufferedWriter bw = openWriter(outputFile, false);
+        try {
+            HashMap<String, Integer> featureOffsetIndex = new HashMap<String, Integer>();
+            for (int i = 0; i < instanceList.getNumberOfInstances(); i++) {
+                Instance instance = instanceList.getInstance(i);
+                for (Feature feature : instance.getFeatures()) {
+                    String featureName = feature.getName();
+                    if (!featureOffsetIndex.containsKey(featureName)) {
+                        featureOffsetIndex.put(featureName, featureOffsetIndex.size());
+                        bw.write(featureName + " ");
+                    }
+                }
+            }
+            bw.write(MalletTestTask.OUTCOME_CLASS_LABEL_NAME);
+        }
+        finally {
+            // close even when a write fails, otherwise the file handle leaks
+            bw.close();
+        }
+    }
+
+    /**
+     * Appends one instance to {@code outputFile}: a line break, then each feature value followed
+     * by a space, then the outcome. The line break comes first, so the file never ends with one.
+     *
+     * @param featureValues the values in column order; a {@code null} entry is written as "null"
+     * @param outcome the outcome of the instance
+     * @param outputFile the file to append to
+     * @throws IOException if the file cannot be written
+     */
+    public void writeFeatureValuesToFile(String featureValues[], String outcome, File outputFile)
+        throws IOException
+    {
+        BufferedWriter bw = openWriter(outputFile, true);
+        try {
+            bw.write("\n");
+            for (String featureValue : featureValues) {
+                bw.write(featureValue + " ");
+            }
+            bw.write(outcome);
+        }
+        finally {
+            bw.close();
+        }
+    }
+
+    /**
+     * Appends a single line break to {@code outputFile}.
+     *
+     * @param outputFile the file to append to
+     * @throws IOException if the file cannot be written
+     */
+    public void writeNewLineToFile(File outputFile)
+        throws IOException
+    {
+        BufferedWriter bw = openWriter(outputFile, true);
+        try {
+            bw.write("\n");
+        }
+        finally {
+            bw.close();
+        }
+    }
+}
diff --git a/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/writer/MalletFeatureEncoder.java b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/writer/MalletFeatureEncoder.java
new file mode 100644
index 000000000..93e3b73aa
--- /dev/null
+++ b/dkpro-tc-ml-mallet/src/main/java/org/dkpro/tc/mallet/writer/MalletFeatureEncoder.java
@@ -0,0 +1,51 @@
+/*******************************************************************************
+ * Copyright 2015
+ * Ubiquitous Knowledge Processing (UKP) Lab
+ * Technische Universität Darmstadt
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ ******************************************************************************/
+package org.dkpro.tc.mallet.writer;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.dkpro.tc.api.features.MissingValue.MissingValueType;
+
+/**
+ * Supplies the text that stands in for a missing feature value in the Mallet data file.
+ *
+ * @deprecated As of release 0.7.0, only dkpro-tc-ml-crfsuite is supported
+ */
+@Deprecated
+public class MalletFeatureEncoder
+{
+
+    /**
+     * A map returning a String value for each valid {@link MissingValueType}
+     * <p>
+     * A new, independent map is built on every call. All four types currently map to "0".
+     *
+     * @return a map with {@link MissingValueType} keys, and strings as values
+     */
+    public static Map<MissingValueType, String> getMissingValueConversionMap()
+    {
+        Map<MissingValueType, String> map = new HashMap<MissingValueType, String>();
+        // for boolean attributes: false
+        map.put(MissingValueType.BOOLEAN, "0");
+        // for numeric attributes: zero
+        map.put(MissingValueType.NUMERIC, "0");
+        // TODO is this really what we want?
+        // for nominal attributes: the first
+        map.put(MissingValueType.NOMINAL, "0");
+        // TODO is this really what we want?
+        // for string attributes: the first
+        map.put(MissingValueType.STRING, "0");
+        return map;
+    }
+}
diff --git a/pom.xml b/pom.xml
index f217fc40a..67f349c0c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -79,6 +79,7 @@ dkpro-tc-integrationtest dkpro-tc-ml dkpro-tc-ml-crfsuite + dkpro-tc-ml-mallet dkpro-tc-ml-liblinear dkpro-tc-ml-svmhmm dkpro-tc-ml-weka
@@ -172,6 +173,11 @@ dkpro-tc-ml-crfsuite 0.9.0-SNAPSHOT + + org.dkpro.tc + dkpro-tc-ml-mallet + 0.9.0-SNAPSHOT + org.dkpro.tc dkpro-tc-ml-liblinear
@@ -381,7 +387,12 @@ spring-web ${spring.version} - + + cc.mallet + mallet + 2.0.8 + +