1. libsleipnir
  2. sleipnir

Commits

Jian Zhou  committed e09676d

added hierarchical multilabel SVM

  • Participants
  • Parent commits 36982d5
  • Branches sleipnir

Comments (0)

Files changed (7)

File configure.ac

View file
 #  #include "svm/svm_struct_api_types.h"],
 
 
+## SVM HIERARCHY
+# If the user specifies a path, use it.
+# If the user says not to use it, then don't.
+# If the user doesn't specify anything, check for it.
+# While checking, look for a source distribution in extlib/.
+# If none, set NO_SVM_HIERARCHY=1.
+AC_ARG_WITH([svm_hierarchy],
+  [AS_HELP_STRING([--with-svm-hierarchy=PATH], [prefix of SVM Hierarchy installation])],
+  [
+         if test "x$with_svm_hierarchy" = "xyes"; then
+     svm_hierarchy_state=check
+         elif test "x$with_svm_hierarchy" != "xno"; then
+           svm_hierarchy_state=try
+     qualify_path with_svm_hierarchy
+           LOCAL_CHECK_APPEND_PATHS([$with_svm_hierarchy], [. include], [svm_hierarchy/svm_struct_api.h], [SVM_HIERARCHY_INCLUDE_DIR], [svm_hierarchy_state=warn])
+           LOCAL_CHECK_APPEND_PATHS([$with_svm_hierarchy], [. lib], [svm_hierarchy/libsvmhierarchy.a], [SVM_HIERARCHY_LIB_DIR], [svm_hierarchy_state=warn])
+         else
+           svm_hierarchy_state=no
+         fi
+        ],
+  [svm_hierarchy_state=check])
+AC_LANG_PUSH([C])
+LOCAL_CHECK_LIB([svmhierarchy], [svm_hierarchy], [optimize_svm],
+        [svm_hierarchy_state=ok],                                      dnl found
+        [SVM_HIERARCHY_LIBS="-lsvmhierarchy"],                              dnl and found installed
+        [
+         SVM_HIERARCHY_CFLAGS="-I $SVM_HIERARCHY_INCLUDE_DIR"
+         SVM_HIERARCHY_LIBS="-L $SVM_HIERARCHY_LIB_DIR/svm_hierarchy -lsvmhierarchy"
+        ],                                                        dnl and found in specified path
+  [],                                                       dnl not found
+        [svm_hierarchy_state=no],                                      dnl and not found installed
+        [svm_hierarchy_state=warn],                                    dnl and not found in specified path
+        [],                                                       dnl requested without
+  [#include <sys/types.h>
+   #include "svm_light/svm_common.h",
+   #include "svm_light/svm_learn.h"],
+        [],
+        [-L $SVM_HIERARCHY_LIB_DIR/svm_hierarchy -I $SVM_HIERARCHY_INCLUDE_DIR/svm_hierarchy])
+AC_LANG_POP
+if test "x$svm_hierarchy_state" != "xok"; then
+  AC_DEFINE([NO_SVM_HIERARCHY], [1])
+fi
+#AC_SUBST(SVM_HIERARCHY_PREFIX)
+AC_SUBST(SVM_HIERARCHY_CFLAGS)
+AC_SUBST(SVM_HIERARCHY_LIBS)
+#  #include "svm/svm_struct_api.h"
+#  #include "svm/svm_struct_api_types.h"],
+
 
 ## BOOST
 # If the user specifies a path, use it.
 AM_CONDITIONAL([WITH_SMILE_TOOLS], [test "x$smile_state" = "xok"])
 AM_CONDITIONAL([WITH_SVM_TOOLS], [test "x$svm_perf_state" = "xok"])
 AM_CONDITIONAL([WITH_SVM_MULTICLASS_TOOLS], [test "x$svm_multiclass_state" = "xok"])
+AM_CONDITIONAL([WITH_SVM_HIERARCHY_TOOLS], [test "x$svm_hierarchy_state" = "xok"])
 AM_CONDITIONAL([WITH_SVM_LIBSRC], [test "x$svm_perf_state" = "xok"])
 AM_CONDITIONAL([WITH_VW_TOOLS], [test "x$vowpal_wabbit_state" = "xok"])
 AM_CONDITIONAL([WITH_READLINE_TOOLS], [test "x$readline_state" = "xok"])
                  tools/Data2Svm/Makefile \
                  tools/SVMer/Makefile \
                  tools/SVMperfer/Makefile \
-                tools/SVMmulticlass/Makefile \
+                 tools/SVMmulticlass/Makefile \
+                 tools/SVMhierarchy/Makefile \
                  tools/SVMperfing/Makefile \
                  tools/LibSVMer/Makefile \
                  tools/VWer/Makefile \
 echo "  log4cpp                 = $log4cpp_info"
 echo "  SMILE                   = $smile_info"
 echo "  SVM Perf                = $svm_perf_info"
-echo "  SVM Multiclass                = $svm_multiclass_info"
+echo "  SVM Multiclass          = $svm_multiclass_info"
+echo "  SVM Hierarchy          = $svm_hierarchy_info"
 echo "  Vowpal Wabbit           = $vowpal_wabbit_info"
 echo "  pthread                 = $pthread_info"
 echo "  gsl                     = $gsl_info"
 fi
 
 
+if test "x$svm_hierarchy_state" != "xok"; then
+  cat << EOF
+
+** BUILDING WITHOUT SVM HIERARCHY
+SVM Hierarchy is strongly recommended.
+It is available from http://svmlight.joachims.org/.
+EOF
+
+fi
+
+if test "x$svm_hierarchy_state" = "xwarn"; then
+  cat << EOF
+** WARNING: The path to SVM Hierarchy may be incorrect.
+I looked for svm_struct_api.h and libsvmhierarchy.a and did not
+find both. NOTE: You may need to make libsvmhierarchy.a. See the README.
+EOF
+
+fi
+
+
+
 if test "x$smile_state" = "xwarn"; then
   cat << EOF
 ** WARNING: The path to SMILE may be incorrect.

File gen_tools_am

View file
                             SVMer => ['SVM_PERF'],
                             SVMperfer => ['SVM_PERF'],
                             SVMmulticlass => ['SVM_MULTICLASS'],
+                            SVMhierarchy => ['SVM_HIERARCHY'],
                             SVMperfing => ['SVM_PERF'],
                             LibSVMer => ['LIBSVM'],
                             VWer => ['VOWPAL_WABBIT'],

File src/Makefile.am

View file
 lib_LIBRARIES			= libSleipnir.a
 
-AM_CPPFLAGS = $(GSL_CFLAGS) $(LOG4CPP_CFLAGS) $(SMILE_CFLAGS) $(SVM_PERF_CFLAGS) $(SVM_MULTICLASS_CFLAGS) $(PTHREAD_CFLAGS) $(VOWPAL_WABBIT_CFLAGS) $(LIBSVM_CFLAGS)
+AM_CPPFLAGS = $(GSL_CFLAGS) $(LOG4CPP_CFLAGS) $(SMILE_CFLAGS) $(SVM_PERF_CFLAGS) $(SVM_MULTICLASS_CFLAGS) $(SVM_HIERARCHY_CFLAGS) $(PTHREAD_CFLAGS) $(VOWPAL_WABBIT_CFLAGS) $(LIBSVM_CFLAGS)
 
-#LDADD = $(LOG4CPP_LIBS) $(SMILE_LIBS) $(SVM_PERF_LIBS) $(SVM_MULTICLASS_LIBS) $(PTHREAD_LIBS)
+#LDADD = $(LOG4CPP_LIBS) $(SMILE_LIBS) $(SVM_PERF_LIBS) $(SVM_MULTICLASS_LIBS) $(SVM_HIERARCHY_LIBS) $(PTHREAD_LIBS)
 
 if WITH_SVM_TOOLS
 libSleipnir_SVM_SOURCES = svm.cpp svmperf.cpp
 libSleipnir_SVM_MULTICLASS_INCLUDES = svmstruct.h
 endif
 
+if WITH_SVM_HIERARCHY_TOOLS
+libSleipnir_SVM_HIERARCHY_SOURCES = svmstructtree.cpp
+libSleipnir_SVM_HIERARCHY_INCLUDES = svmstructtree.h
+endif
+
 if WITH_LIBSVM_TOOLS
 libSleipnir_LIBSVM_SOURCES = libsvm.cpp
 libSleipnir_LIBSVM_INCLUDES = libsvm.h
 	$(libSleipnir_GSL_SOURCES) \
 	$(libSleipnir_LIBSVM_SOURCES) \
 	$(libSleipnir_SVM_MULTICLASS_SOURCES)\
+	$(libSleipnir_SVM_HIERARCHY_SOURCES)\
 	vwb.cpp
 include_HEADERS			=	\
 	annotation.h			\
 	$(libSleipnir_GSL_SOURCES) \
 	$(libSleipnir_LIBSVM_INCLUDES) \
 	$(libSleipnir_SVM_MULTICLASS_INCLUDES)\
+	$(libSleipnir_SVM_HIERARCHY_INCLUDES)\
 	trie.h					\
 	triei.h					\
 	typesi.h				\

File src/svmstruct.cpp

View file
 		/* set default */
 		Alg = DEFAULT_ALG_TYPE;
 		//Learn_parms
-		struct_parm.C=0.01;
-		struct_parm.slack_norm=1;
-		struct_parm.epsilon=DEFAULT_EPS;
-		struct_parm.custom_argc=0;
-		struct_parm.loss_function=DEFAULT_LOSS_FCT;
-		struct_parm.loss_type=DEFAULT_RESCALING;
-		struct_parm.newconstretrain=100;
-		struct_parm.ccache_size=5;
+		struct_parm.C=0.01;
+		struct_parm.slack_norm=1;
+		struct_parm.epsilon=DEFAULT_EPS;
+		struct_parm.custom_argc=0;
+		struct_parm.loss_function=DEFAULT_LOSS_FCT;
+		struct_parm.loss_type=DEFAULT_RESCALING;
+		struct_parm.newconstretrain=100;
+		struct_parm.ccache_size=5;
 		struct_parm.batch_size=100;
 		//Learn_parms
 		//strcpy (learn_parm.predfile, "trans_predictions");
 		strcpy(learn_parm.alphafile, "");
 		//verbosity=0;/*verbosity for svm_light*/
 		//struct_verbosity = 1; /*verbosity for struct learning portion*/
-		learn_parm.biased_hyperplane=1;
-		learn_parm.remove_inconsistent=0;
-		learn_parm.skip_final_opt_check=0;
-		learn_parm.svm_maxqpsize=10;
-		learn_parm.svm_newvarsinqp=0;
-		learn_parm.svm_iter_to_shrink=-9999;
-		learn_parm.maxiter=100000;
-		learn_parm.kernel_cache_size=40;
-		learn_parm.svm_c=99999999;  /* overridden by struct_parm.C */
-		learn_parm.eps=0.001;       /* overridden by struct_parm.epsilon */
-		learn_parm.transduction_posratio=-1.0;
-		learn_parm.svm_costratio=1.0;
-		learn_parm.svm_costratio_unlab=1.0;
-		learn_parm.svm_unlabbound=1E-5;
-		learn_parm.epsilon_crit=0.001;
-		learn_parm.epsilon_a=1E-10;  /* changed from 1e-15 */
-		learn_parm.compute_loo=0;
-		learn_parm.rho=1.0;
-		learn_parm.xa_depth=0;
-		kernel_parm.kernel_type=0;
-		kernel_parm.poly_degree=3;
-		kernel_parm.rbf_gamma=1.0;
-		kernel_parm.coef_lin=1;
+		learn_parm.biased_hyperplane=1;
+		learn_parm.remove_inconsistent=0;
+		learn_parm.skip_final_opt_check=0;
+		learn_parm.svm_maxqpsize=10;
+		learn_parm.svm_newvarsinqp=0;
+		learn_parm.svm_iter_to_shrink=-9999;
+		learn_parm.maxiter=100000;
+		learn_parm.kernel_cache_size=40;
+		learn_parm.svm_c=99999999;  /* overridden by struct_parm.C */
+		learn_parm.eps=0.001;       /* overridden by struct_parm.epsilon */
+		learn_parm.transduction_posratio=-1.0;
+		learn_parm.svm_costratio=1.0;
+		learn_parm.svm_costratio_unlab=1.0;
+		learn_parm.svm_unlabbound=1E-5;
+		learn_parm.epsilon_crit=0.001;
+		learn_parm.epsilon_a=1E-10;  /* changed from 1e-15 */
+		learn_parm.compute_loo=0;
+		learn_parm.rho=1.0;
+		learn_parm.xa_depth=0;
+		kernel_parm.kernel_type=0;
+		kernel_parm.poly_degree=3;
+		kernel_parm.rbf_gamma=1.0;
+		kernel_parm.coef_lin=1;
 		kernel_parm.coef_const=1;
 		strcpy(kernel_parm.custom, "empty");
 
 	}
 
 
+	vector<SVMLabel> CSVMSTRUCTMC::ReadLabels(ifstream & ifsm) {
 
+		static const size_t c_iBuffer = 1024;
+		char acBuffer[c_iBuffer];
+		vector<string> vecstrTokens;
+		vector<SVMLabel> vecLabels;
+		size_t numPositives, numNegatives;
+		numPositives = numNegatives = 0;
+		while (!ifsm.eof()) {
+			ifsm.getline(acBuffer, c_iBuffer - 1);
+			acBuffer[c_iBuffer - 1] = 0;
+			vecstrTokens.clear();
+			CMeta::Tokenize(acBuffer, vecstrTokens);
+			if (vecstrTokens.empty())
+				continue;
+			if (vecstrTokens.size() != 2) {
+				cerr << "Illegal label line (" << vecstrTokens.size() << "): "
+					<< acBuffer << endl;
+				continue;
+			}
+			vecLabels.push_back(SVMArc::SVMLabel(vecstrTokens[0], atoi(
+				vecstrTokens[1].c_str())));
+			if (vecLabels.back().Target > 0)
+				numPositives++;
+			else
+				numNegatives++;
+		}
+		return vecLabels;
+	}
 
 
 	SAMPLE* CSVMSTRUCTMC::CreateSample(Sleipnir::CPCL &PCL, vector<SVMLabel> SVMLabels) {
 					SVMLabels[i].SetIndex(PCL.GetGene(SVMLabels[i].GeneName));
 				}
 				iGene = SVMLabels[i].index;
-				   //cout << "CLASS gene=" << iGene << endl;
+				//cout << "CLASS gene=" << iGene << endl;
 				if (iGene != -1) {
 					iDoc++;
 
 					vecResult[iDoc - 1].Value = label.Class;
 					vecResult[iDoc - 1].num_class=struct_parm.num_classes;
 					//vecResult[iDoc - 1].Scores.reserve(label.num_classes);
-					for (k = 1; k <= struct_parm.num_classes; k++)
-								vecResult[iDoc - 1].Scores.push_back(label.scores[k]);
+					for (k = 1; k <= struct_parm.num_classes; k++)
+						vecResult[iDoc - 1].Scores.push_back(label.scores[k]);
 					//cerr<<"CLASSIFY Called FreeDoc"<<endl;
 					FreeDoc(pattern.doc);
 					//cerr<<"CLASSIFY End FreeDoc"<<endl;

File src/svmstruct.h

View file
 #include "pclset.h"
 #include "meta.h"
 #include "dat.h"
-
-#include <stdio.h>
-
-/* removed to support cygwin */
-//#include <execinfo.h>
-
-namespace SVMArc {
-	extern "C" {
+#ifndef NO_SVM_STRUCT
+#define SVMSTRUCT_H
+extern "C" {
 
 #define class Class
 
 #include <svm_multiclass/svm_struct_api.h>
 #include <svm_multiclass/svm_struct/svm_struct_learn.h>
 #undef class
-		//#include "svm_struct_api.h"
+	//#include "svm_struct_api.h"
 
-	}
+}
+#endif
 
+#include <stdio.h>
+using namespace Sleipnir;
+using namespace std;
+
+/* removed to support cygwin */
+//#include <execinfo.h>
+
+namespace SVMArc {
 	class SVMLabel {
 	public:
 		string GeneName;
-		size_t Target;
+		size_t Target; //Save single integer label; used for single label classification (0-1, or multiclass)
+		vector<char> TargetM; //Save multiple labels; used for hierarchical multi-label classification;
+
 		size_t index;
 		bool hasIndex;
 		SVMLabel(std::string name, size_t target) {
 			index = -1;
 		}
 
+		SVMLabel(std::string name, vector<char> cl) {
+			GeneName = name;
+			TargetM = cl;
+			hasIndex = false;
+			index = -1;
+		}
 		SVMLabel() {
 			GeneName = "";
 			Target = 0;
 	class Result {
 	public:
 		std::string GeneName;
-		int Target;
-		int Value;
+		int Target; //for single label prediction
+		int Value; //for single label prediction
+		vector<char> TargetM;//for multi label prediction
+		vector<char> ValueM; //for multi label prediction
 		vector<double> Scores;
 		int num_class;
 		int CVround;
 		Result() {
 			GeneName = "";
 			Target = 0;
-			Value = Sleipnir::CMeta::GetNaN();
+			Value = -1;
 		}
 
 		Result(std::string name, int cv = -1) {
 			}
 			return ss.str();
 		}
+		string toStringMC() {
+			stringstream ss;
+			ss << GeneName << '\t' << Target << '\t' << Value << '\t';
+			for(size_t j=1;j<=num_class;j++)
+				ss << Scores[j]<<'\t';
+			return ss.str();
+		}
+		string toStringTREE(map<int, string>* ponto_map_rev, int returnindex) {
+			stringstream ss;
+			int mark=1;
+			ss << GeneName << '\t';
+			for(size_t j=0;j<=num_class;j++){
+				if(TargetM[j])
+					if(mark){
+						if(returnindex)
+							ss<<j;
+						else
+							ss <<(*ponto_map_rev)[j];
+						mark = 0;
+					}
+					else
+						ss <<','<<(*ponto_map_rev)[j];
+			}
+			if(mark==1)
+				ss<<"??"<<'\t';
+			else
+				ss<<'\t';
 
+			mark=1;
+			for(size_t j=0;j<=num_class;j++){
+				if(ValueM[j])
+					if(mark){
+						if(returnindex)
+							ss<<j;
+						else
+							ss <<(*ponto_map_rev)[j];
+						mark = 0;
+					}
+					else
+						ss <<','<<(*ponto_map_rev)[j];
+			}
+			if(mark)
+				ss<<"??";
+			ss <<'\t';
+			for(size_t j=1;j<=num_class;j++)
+				ss << Scores[j]<<'\t';
+			return ss.str();
+		}
 	};
 
 	enum EFilter {
 		EFilterInclude = 0, EFilterExclude = EFilterInclude + 1,
 	};
 
+	class CSVMSTRUCT{
+		/* This base class is solely intended to serve as a common template for different SVM Struct implementations
+		A few required functions are not defined here because their parameter type or return type has to differ 
+		among different implementations, but I listed them in comments. */
+	public:
+		virtual vector<Result> Classify(Sleipnir::CPCL& PCL, vector<SVMLabel> SVMLabels) = 0;
+		virtual void SetTradeoff(double tradeoff)=0;
+		virtual void SetLossFunction(size_t loss_f)=0;
+		virtual void SetLearningAlgorithm(int alg)=0;
+		virtual void UseSlackRescaling()=0;
+		virtual void UseMarginRescaling()=0;
+		virtual void ReadModel(char* model_file)=0;
+		virtual void WriteModel(char* model_file)=0;
+		virtual vector<SVMLabel> ReadLabels(ifstream & ifsm)=0;
+		virtual void SetVerbosity(size_t V)=0;
+		virtual bool parms_check() = 0;
+		virtual bool initialize() = 0;
+
+		/*The following functions should also be implemented
+		SAMPLE* CreateSample(Sleipnir::CPCL &PCL, vector<SVMLabel> SVMLabels);
+		static void FreeSample(sample s)
+		void Learn(SAMPLE &sample)
+		*/
+	};
+
+
+
+
 	//this class encapsulates the model and parameters and has no associated data
 
 
 	//class for SVMStruct
-	class CSVMSTRUCTMC {
+	class CSVMSTRUCTMC : CSVMSTRUCT{
 
 	public:
 		LEARN_PARM learn_parm;
 		static DOC* CreateDoc(Sleipnir::CPCL &PCL, size_t iGene, size_t iDoc);
 
 
+		//read in labels
+		vector<SVMLabel> ReadLabels(ifstream & ifsm);
+
 		//Creates a sample using a single PCL and SVMlabels Looks up genes by name.
-		static SAMPLE
+		SAMPLE
 			* CreateSample(Sleipnir::CPCL &PCL, vector<SVMLabel> SVMLabels);
 
 		//Classify single genes
 
 			cerr << "ALG=" << Alg << endl;
 
-			if(Alg == 0)
-				svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG);
-			else if(Alg == 1)
-				svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG);
-			else if(Alg == 2)
-				svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG);
-			else if(Alg == 3)
-				svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG);
-			else if(Alg == 4)
-				svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG);
-			else if(Alg == 9)
-				svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel);
-			else
+			if(Alg == 0)
+				svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG);
+			else if(Alg == 1)
+				svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG);
+			else if(Alg == 2)
+				svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG);
+			else if(Alg == 3)
+				svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG);
+			else if(Alg == 4)
+				svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG);
+			else if(Alg == 9)
+				svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel);
+			else
 				exit(1);
 			//
 		}
 
+		struct SortResults {
+
+			bool operator()(const Result& rOne, const Result & rTwo) const {
+				return (rOne.Value < rTwo.Value);
+			}
+		};
+
+		size_t PrintResults(vector<Result> vecResults, ofstream & ofsm) {
+			sort(vecResults.begin(), vecResults.end(), SortResults());
+			int LabelVal;
+			for (size_t i = 0; i < vecResults.size(); i++) {
+				ofsm << vecResults[i].GeneName << '\t' << vecResults[i].Target << '\t'
+					<< vecResults[i].Value<<'\t';
+				for(size_t j=1;j<=vecResults[i].num_class;j++)
+					ofsm << vecResults[i].Scores[j]<<'\t';
+				ofsm<< endl;
+
+			}
+		};
 
 		bool parms_check();
 		bool initialize();

File tools/Makefile.am

View file
 		SVMmulticlass
 endif
 
+if WITH_SVM_HIERARCHY_TOOLS
+    SVM_HIERARCHY_SUBDIRS = \
+		SVMhierarchy
+endif
+
 if WITH_LIBSVM_TOOLS
     LIBSVM_TOOLS_SUBDIRS = \
 	  LibSVMer
 endif
 endif
 
-SUBDIRS = $(TOOLS_SUBDIRS) $(SMILE_TOOLS_SUBDIRS) $(BOOST_TOOLS_SUBDIRS) $(SVM_TOOLS_SUBDIRS) $(READLINE_TOOLS_SUBDIRS) $(VW_TOOLS_SUBDIRS) $(GSL_TOOLS_SUBDIRS) $(LIBSVM_TOOLS_SUBDIRS) $(SVM_MULTICLASS_TOOLS_SUBDIRS)
+SUBDIRS = $(TOOLS_SUBDIRS) $(SMILE_TOOLS_SUBDIRS) $(BOOST_TOOLS_SUBDIRS) $(SVM_TOOLS_SUBDIRS) $(READLINE_TOOLS_SUBDIRS) $(VW_TOOLS_SUBDIRS) $(GSL_TOOLS_SUBDIRS) $(LIBSVM_TOOLS_SUBDIRS) $(SVM_MULTICLASS_TOOLS_SUBDIRS) $(SVM_HIERARCHY_TOOLS_SUBDIRS)

File tools/SVMmulticlass/SVMmulti.cpp

View file
 #include <fstream>
-#include <iostream>
+#include <iostream>
 #include <iterator>
 #include <vector>
 #include <queue>
 using namespace SVMArc;
 //#include "../../extlib/svm_light/svm_light/kernel.h"
 
-vector<SVMArc::SVMLabel> ReadLabels(ifstream & ifsm) {
 
-	static const size_t c_iBuffer = 1024;
-	char acBuffer[c_iBuffer];
-	vector<string> vecstrTokens;
-	vector<SVMArc::SVMLabel> vecLabels;
-	size_t numPositives, numNegatives;
-	numPositives = numNegatives = 0;
-	while (!ifsm.eof()) {
-		ifsm.getline(acBuffer, c_iBuffer - 1);
-		acBuffer[c_iBuffer - 1] = 0;
-		vecstrTokens.clear();
-		CMeta::Tokenize(acBuffer, vecstrTokens);
-		if (vecstrTokens.empty())
-			continue;
-		if (vecstrTokens.size() != 2) {
-			cerr << "Illegal label line (" << vecstrTokens.size() << "): "
-				<< acBuffer << endl;
-			continue;
-		}
-		vecLabels.push_back(SVMArc::SVMLabel(vecstrTokens[0], atoi(
-			vecstrTokens[1].c_str())));
-		if (vecLabels.back().Target > 0)
-			numPositives++;
-		else
-			numNegatives++;
-	}
-	return vecLabels;
-}
 
-struct SortResults {
-
-	bool operator()(const SVMArc::Result& rOne, const SVMArc::Result & rTwo) const {
-		return (rOne.Value < rTwo.Value);
-	}
-};
-
-size_t PrintResults(vector<SVMArc::Result> vecResults, ofstream & ofsm) {
-	sort(vecResults.begin(), vecResults.end(), SortResults());
-	int LabelVal;
-	for (size_t i = 0; i < vecResults.size(); i++) {
-		ofsm << vecResults[i].GeneName << '\t' << vecResults[i].Target << '\t'
-			<< vecResults[i].Value<<'\t';
-		for(size_t j=1;j<=vecResults[i].num_class;j++)
-			ofsm << vecResults[i].Scores[j]<<'\t';
-		ofsm<< endl;
-
-	}
-};
 
 
 int main(int iArgs, char** aszArgs) {
 		ifsm.clear();
 		ifsm.open(sArgs.labels_arg);
 		if (ifsm.is_open())
-			vecLabels = ReadLabels(ifsm);
+			vecLabels = SVM.ReadLabels(ifsm);
 		else {
 			cerr << "Could not read label file" << endl;
 			return 1;
 
 
 	//Training
-	SVMArc::SAMPLE* pTrainSample;
+	SAMPLE* pTrainSample;
 	vector<SVMArc::SVMLabel> pTrainVector[sArgs.cross_validation_arg];
 	vector<SVMArc::SVMLabel> pTestVector[sArgs.cross_validation_arg];
 	vector<SVMArc::Result> AllResults;
 	vector<SVMArc::Result> tmpAllResults;
 
 	if (sArgs.model_given && sArgs.labels_given) { //learn once and write to file
-		pTrainSample = CSVMSTRUCTMC::CreateSample(PCL, vecLabels);
+		pTrainSample = SVM.CreateSample(PCL, vecLabels);
 		SVM.Learn(*pTrainSample);
 		SVM.WriteModel(sArgs.model_arg);
 	} else if (sArgs.model_given && sArgs.output_given) { //read model and classify all
 		ofstream ofsm;
 		ofsm.open(sArgs.output_arg);
 		if (ofsm.is_open())
-			PrintResults(AllResults, ofsm);
+			SVM.PrintResults(AllResults, ofsm);
 		else {
 			cerr << "Could not open output file" << endl;
 		}
 		}
 		//run once
 		for (i = 0; i < sArgs.cross_validation_arg; i++) {
-			pTrainSample = SVMArc::CSVMSTRUCTMC::CreateSample(PCL,
+			pTrainSample = SVM.CreateSample(PCL,
 				pTrainVector[i]);
 
 			cerr << "Cross Validation Trial " << i << endl;
 		ofstream ofsm;
 		ofsm.clear();
 		ofsm.open(sArgs.output_arg);
-		PrintResults(AllResults, ofsm);
+		SVM.PrintResults(AllResults, ofsm);
 		return 0;
 
 	} else {