1. libsleipnir
  2. sleipnir

Commits

Jian Zhou  committed 9944767

fixed SVMhierarchy/SVMmulticlass read model function

  • Participants
  • Parent commits 65ac4e5
  • Branches sleipnir

Comments (0)

Files changed (3)

File src/svmstruct.h

View file
  • Ignore whitespace
 		EFilterInclude = 0, EFilterExclude = EFilterInclude + 1,
 	};
 
-	class CSVMSTRUCT{
+	class CSVMSTRUCTBASE{
 		/* This base class is solely intended to serve as a common template for different SVM Struct implementations
 		A few required functions are not defined here because their parameter type or return type has to differ 
 		among different implementations, but I listed them in comments. */
 
 
 	//class for SVMStruct
-	class CSVMSTRUCTMC : CSVMSTRUCT{
+	class CSVMSTRUCTMC : CSVMSTRUCTBASE{
 
 	public:
 		LEARN_PARM learn_parm;
 
 
 		void ReadModel(char* model_file) {
-			FreeModel();
 			structmodel = read_struct_model(model_file, &struct_parm);
+			if(structmodel.svm_model->kernel_parm.kernel_type == LINEAR) { /* linear kernel */
+				/* compute weight vector */
+				add_weight_vector_to_linear_model(structmodel.svm_model);
+				structmodel.w=structmodel.svm_model->lin_weights;
+			}
 		}
 
 		void WriteModel(char* model_file) {

File src/svmstructtree.h

View file
  • Ignore whitespace
 #include <svm_hierarchy/svm_struct/svm_struct_learn.h>
 #undef class
 
-	}
+}
 #endif
 
 #include "svmstruct.h"
 	};
 
 	//class for SVMStruct
-	class CSVMSTRUCTTREE : CSVMSTRUCT {
+	class CSVMSTRUCTTREE : CSVMSTRUCTBASE {
 
 	public:
 		LEARN_PARM learn_parm;
 		map<string,int> onto_map;
 		map<int, string> onto_map_rev;
 
-		
+
 		int Alg;
 		CSVMSTRUCTTREE() {
 			initialize();
 		}
 
 		void ReadModel(char* model_file) {
-			FreeModel();
+
 			structmodel = read_struct_model(model_file, &struct_parm);
+			if(structmodel.svm_model->kernel_parm.kernel_type == LINEAR) { /* linear kernel */
+				/* compute weight vector */
+				add_weight_vector_to_linear_model(structmodel.svm_model);
+				structmodel.w=structmodel.svm_model->lin_weights;
+			}
 		}
 
 		void WriteModel(char* model_file) {
 			//		ofsm << structmodel.w[i+1] << endl;
 			//	}
 			//} else {
-				write_struct_model(model_file, &structmodel, &struct_parm);
+			write_struct_model(model_file, &structmodel, &struct_parm);
 			/*}*/
 		}
 

File tools/SVMhierarchy/SVMhierarchy.cpp

View file
  • Ignore whitespace
 		for (i = 0; i < vecLabels.size(); i++)
 			setLabeledGenes.insert(vecLabels[i].GeneName);
 		cerr << "Read labels from file" << endl;
+		SVM.InitializeLikAfterReadLabels();
 	}
 
 
-	SVM.InitializeLikAfterReadLabels();
+	
 
 
 	//  cout << "there are " << vecLabels.size() << " labels processed" << endl;
 	size_t iFile;
 	vector<string> PCLs;
 	if (sArgs.input_given) {
+		cerr << "Loading PCL file" << endl;
 		if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) {
 			cerr << "Could not open input PCL" << endl;
 			return 1;
 		}
 	}
+	cerr << "PCL file Loaded" << endl;
 
 
 
 				vecAllLabels.push_back(SVMLabel(PCL.GetGene(i), 0));
 
 			SVM.ReadModel(sArgs.model_arg);
+			cerr << "Model Loaded" << endl;
+
 			AllResults = SVM.Classify(PCL, vecAllLabels);
 			ofstream ofsm;
 			ofsm.open(sArgs.output_arg);
 			ifsm.open(sArgs.test_labels_arg);
 			if (ifsm.is_open())
 				vecLabels = SVM.ReadLabels(ifsm);
+
 			else {
 				cerr << "Could not read label file" << endl;
 				exit(1);
 			}
 			for (i = 0; i < vecLabels.size(); i++)
 				setLabeledGenes.insert(vecLabels[i].GeneName);
-
+			cerr << "Loading Model" << endl;
+			SVM.ReadModel(sArgs.model_arg);
+			cerr << "Model Loaded" << endl;
 
 			pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
 			for (j = 0; j < vecLabels.size(); j++) {
 				pTestVector[0].push_back(vecLabels[j]);		      
 			}
+
+
 			tmpAllResults = SVM.Classify(PCL,	pTestVector[0]);
 			cerr << "Classified " << tmpAllResults.size() << " examples"<< endl;
 			AllResults.insert(AllResults.end(), tmpAllResults.begin(), tmpAllResults.end());