Commits

Jian Zhou committed 1cd104c Merge

merge

  • Participants
  • Parent commits 57f7d09, 6f01eea

Comments (0)

Files changed (1)

File tools/SVMperfer/SVMperfer.cpp

 using namespace SVMLight;
 //#include "../../extlib/svm_light/svm_light/kernel.h"
 
-inline bool file_exists (const std::string& name) {
-    struct stat buffer;   
-    return (stat (name.c_str(), &buffer) == 0); 
-}
-
-vector< pair< string, string > > ReadLabelList(ifstream & ifsm, string output_prefix) {
-  static const size_t c_iBuffer = 1024;
-  char acBuffer[c_iBuffer];
-  vector<string> vecstrTokens;
-  vector< pair < string, string > > inout;
-  while (!ifsm.eof()) {
-    ifsm.getline(acBuffer, c_iBuffer - 1);
-    acBuffer[c_iBuffer - 1] = 0;
-    vecstrTokens.clear();
-    CMeta::Tokenize(acBuffer, vecstrTokens);
-    if (vecstrTokens.empty())
-      continue;
-    if (vecstrTokens.size() != 2) {
-      cerr << "Illegal input line (" << vecstrTokens.size() << "): "
-        << acBuffer << endl;
-      continue;
-    }
-    
-    if( file_exists( output_prefix + "/" + vecstrTokens[1] ) ){
-      continue;
-    }
-    
-
-    //cout << file_exists( vecstrTokens[1] ) << endl;
-
-    inout.push_back( make_pair( vecstrTokens[0], vecstrTokens[1] ) );
-  }
-  cout << inout.size() << " number of label files." << endl;
-  return inout;
-
-}
-
 vector<SVMLight::SVMLabel> ReadLabels(ifstream & ifsm) {
 
-  static const size_t c_iBuffer = 1024;
-  char acBuffer[c_iBuffer];
-  vector<string> vecstrTokens;
-  vector<SVMLight::SVMLabel> vecLabels;
-  size_t numPositives, numNegatives;
-  numPositives = numNegatives = 0;
-  while (!ifsm.eof()) {
-    ifsm.getline(acBuffer, c_iBuffer - 1);
-    acBuffer[c_iBuffer - 1] = 0;
-    vecstrTokens.clear();
-    CMeta::Tokenize(acBuffer, vecstrTokens);
-    if (vecstrTokens.empty())
-      continue;
-    if (vecstrTokens.size() != 2) {
-      cerr << "Illegal label line (" << vecstrTokens.size() << "): "
-        << acBuffer << endl;
-      continue;
-    }
-    //cout << vecstrTokens[0] << endl;
-    //cout << vecstrTokens[1] << endl;
-
-
-    vecLabels.push_back(SVMLight::SVMLabel(vecstrTokens[0], atof(
-            vecstrTokens[1].c_str())));
-    if (vecLabels.back().Target > 0)
-      numPositives++;
-    else
-      numNegatives++;
-  }
-
-
-
-  return vecLabels;
+	static const size_t c_iBuffer = 1024;
+	char acBuffer[c_iBuffer];
+	vector<string> vecstrTokens;
+	vector<SVMLight::SVMLabel> vecLabels;
+	size_t numPositives, numNegatives;
+	numPositives = numNegatives = 0;
+	while (!ifsm.eof()) {
+		ifsm.getline(acBuffer, c_iBuffer - 1);
+		acBuffer[c_iBuffer - 1] = 0;
+		vecstrTokens.clear();
+		CMeta::Tokenize(acBuffer, vecstrTokens);
+		if (vecstrTokens.empty())
+			continue;
+		if (vecstrTokens.size() != 2) {
+			cerr << "Illegal label line (" << vecstrTokens.size() << "): "
+					<< acBuffer << endl;
+			continue;
+		}
+		vecLabels.push_back(SVMLight::SVMLabel(vecstrTokens[0], atof(
+				vecstrTokens[1].c_str())));
+		if (vecLabels.back().Target > 0)
+			numPositives++;
+		else
+			numNegatives++;
+	}
+	return vecLabels;
 }
 
 struct SortResults {
 
-  bool operator()(const SVMLight::Result& rOne, const SVMLight::Result & rTwo) const {
-    return (rOne.Value > rTwo.Value);
-  }
+	bool operator()(const SVMLight::Result& rOne, const SVMLight::Result & rTwo) const {
+		return (rOne.Value > rTwo.Value);
+	}
 };
 
 size_t PrintResults(vector<SVMLight::Result> vecResults, ofstream & ofsm) {
-  sort(vecResults.begin(), vecResults.end(), SortResults());
-  int LabelVal;
-  for (size_t i = 0; i < vecResults.size(); i++) {
-    ofsm << vecResults[i].GeneName << '\t' << vecResults[i].Target << '\t'
-      << vecResults[i].Value << endl;
-  }
+	sort(vecResults.begin(), vecResults.end(), SortResults());
+	int LabelVal;
+	for (size_t i = 0; i < vecResults.size(); i++) {
+		ofsm << vecResults[i].GeneName << '\t' << vecResults[i].Target << '\t'
+				<< vecResults[i].Value << endl;
+	}
 }
 ;
 
 struct ParamStruct {
-  vector<float> vecK, vecTradeoff;
-  vector<size_t> vecLoss;
-  vector<char*> vecNames;
+	vector<float> vecK, vecTradeoff;
+	vector<size_t> vecLoss;
+	vector<char*> vecNames;
 };
 
 ParamStruct ReadParamsFromFile(ifstream& ifsm, string outFile) {
-  static const size_t c_iBuffer = 1024;
-  char acBuffer[c_iBuffer];
-  char* nameBuffer;
-  vector<string> vecstrTokens;
-  size_t extPlace;
-  string Ext, FileName;
-  if ((extPlace = outFile.find_first_of(".")) != string::npos) {
-    FileName = outFile.substr(0, extPlace);
-    Ext = outFile.substr(extPlace, outFile.size());
-  } else {
-    FileName = outFile;
-    Ext = "";
-  }
-  ParamStruct PStruct;
-  size_t index = 0;
-  while (!ifsm.eof()) {
-    ifsm.getline(acBuffer, c_iBuffer - 1);
-    acBuffer[c_iBuffer - 1] = 0;
-    vecstrTokens.clear();
-    CMeta::Tokenize(acBuffer, vecstrTokens);
-    if (vecstrTokens.empty())
-      continue;
-    if (vecstrTokens.size() != 3) {
-      cerr << "Illegal params line (" << vecstrTokens.size() << "): "
-        << acBuffer << endl;
-      continue;
-    }
-    if (acBuffer[0] == '#') {
-      cerr << "skipping " << acBuffer << endl;
-    } else {
-      PStruct.vecLoss.push_back(atoi(vecstrTokens[0].c_str()));
-      PStruct.vecTradeoff.push_back(atof(vecstrTokens[1].c_str()));
-      PStruct.vecK.push_back(atof(vecstrTokens[2].c_str()));
-      PStruct.vecNames.push_back(new char[c_iBuffer]);
-      if (PStruct.vecLoss[index] == 4 || PStruct.vecLoss[index] == 5)
-        sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f_k%4.3f%s",
-            FileName.c_str(), PStruct.vecLoss[index],
-            PStruct.vecTradeoff[index], PStruct.vecK[index],
-            Ext.c_str());
-      else
-        sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f%s",
-            FileName.c_str(), PStruct.vecLoss[index],
-            PStruct.vecTradeoff[index], Ext.c_str());
-      index++;
-    }
+	static const size_t c_iBuffer = 1024;
+	char acBuffer[c_iBuffer];
+	char* nameBuffer;
+	vector<string> vecstrTokens;
+	size_t extPlace;
+	string Ext, FileName;
+	if ((extPlace = outFile.find_first_of(".")) != string::npos) {
+		FileName = outFile.substr(0, extPlace);
+		Ext = outFile.substr(extPlace, outFile.size());
+	} else {
+		FileName = outFile;
+		Ext = "";
+	}
+	ParamStruct PStruct;
+	size_t index = 0;
+	while (!ifsm.eof()) {
+		ifsm.getline(acBuffer, c_iBuffer - 1);
+		acBuffer[c_iBuffer - 1] = 0;
+		vecstrTokens.clear();
+		CMeta::Tokenize(acBuffer, vecstrTokens);
+		if (vecstrTokens.empty())
+			continue;
+		if (vecstrTokens.size() != 3) {
+			cerr << "Illegal params line (" << vecstrTokens.size() << "): "
+					<< acBuffer << endl;
+			continue;
+		}
+		if (acBuffer[0] == '#') {
+			cerr << "skipping " << acBuffer << endl;
+		} else {
+			PStruct.vecLoss.push_back(atoi(vecstrTokens[0].c_str()));
+			PStruct.vecTradeoff.push_back(atof(vecstrTokens[1].c_str()));
+			PStruct.vecK.push_back(atof(vecstrTokens[2].c_str()));
+			PStruct.vecNames.push_back(new char[c_iBuffer]);
+			if (PStruct.vecLoss[index] == 4 || PStruct.vecLoss[index] == 5)
+				sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f_k%4.3f%s",
+						FileName.c_str(), PStruct.vecLoss[index],
+						PStruct.vecTradeoff[index], PStruct.vecK[index],
+						Ext.c_str());
+			else
+				sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f%s",
+						FileName.c_str(), PStruct.vecLoss[index],
+						PStruct.vecTradeoff[index], Ext.c_str());
+			index++;
+		}
 
-  }
-  return PStruct;
+	}
+	return PStruct;
 }
 
 int main(int iArgs, char** aszArgs) {
-  gengetopt_args_info sArgs;
+	gengetopt_args_info sArgs;
 
-  CPCL PCL;
-  SVMLight::CSVMPERF SVM;
+	CPCL PCL;
+	SVMLight::CSVMPERF SVM;
 
-  size_t i, j, iGene, jGene;
-  ifstream ifsm, iifsm;
+	size_t i, j, iGene, jGene;
+	ifstream ifsm;
+	if (cmdline_parser(iArgs, aszArgs, &sArgs)) {
+		cmdline_parser_print_help();
+		return 1;
+	}
+	SVM.SetVerbosity(sArgs.verbosity_arg);
+	SVM.SetLossFunction(sArgs.error_function_arg);
+	if (sArgs.k_value_arg > 1) {
+		cerr << "k_value is >1. Setting default 0.5" << endl;
+		SVM.SetPrecisionFraction(0.5);
+	} else if (sArgs.k_value_arg <= 0) {
+		cerr << "k_value is <=0. Setting default 0.5" << endl;
+		SVM.SetPrecisionFraction(0.5);
+	} else {
+		SVM.SetPrecisionFraction(sArgs.k_value_arg);
+	}
 
-  if (cmdline_parser(iArgs, aszArgs, &sArgs)) {
-    cmdline_parser_print_help();
-    return 1;
-  }
-  SVM.SetVerbosity(sArgs.verbosity_arg);
-  SVM.SetLossFunction(sArgs.error_function_arg);
-  if (sArgs.k_value_arg > 1) {
-    cerr << "k_value is >1. Setting default 0.5" << endl;
-    SVM.SetPrecisionFraction(0.5);
-  } else if (sArgs.k_value_arg <= 0) {
-    cerr << "k_value is <=0. Setting default 0.5" << endl;
-    SVM.SetPrecisionFraction(0.5);
-  } else {
-    SVM.SetPrecisionFraction(sArgs.k_value_arg);
-  }
+	
+	if (sArgs.cross_validation_arg < 1){
+	  cerr << "cross_valid is <1. Must be set at least 1" << endl;
+	  return 1;
+	}
+	else if(sArgs.cross_validation_arg < 2){
+	  cerr << "cross_valid is set to 1. No cross validation holdouts will be run." << endl;
+	}
+	
+	SVM.SetTradeoff(sArgs.tradeoff_arg);
+	if (sArgs.slack_flag)
+		SVM.UseSlackRescaling();
+	else
+		SVM.UseMarginRescaling();
 
 
-  if (sArgs.cross_validation_arg < 1){
-    cerr << "cross_valid is <1. Must be set at least 1" << endl;
-    return 1;
-  }
-  else if(sArgs.cross_validation_arg < 2){
-    cerr << "cross_valid is set to 1. No cross validation holdouts will be run." << endl;
-  }
+	if (!SVM.parms_check()) {
+		cerr << "Sanity check failed, see above errors" << endl;
+		return 1;
+	}
 
-  SVM.SetTradeoff(sArgs.tradeoff_arg);
-  if (sArgs.slack_flag)
-    SVM.UseSlackRescaling();
-  else
-    SVM.UseMarginRescaling();
+	//  cout << "there are " << vecLabels.size() << " labels processed" << endl;
+	size_t iFile;
+	vector<string> PCLs;
+	if (sArgs.input_given) {
+		if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) {
+			cerr << "Could not open input PCL" << endl;
+			return 1;
+		}
+	}
 
+	vector<SVMLight::SVMLabel> vecLabels;
+	set<string> setLabeledGenes;
+	if (sArgs.labels_given) {
+		ifsm.clear();
+		ifsm.open(sArgs.labels_arg);
+		if (ifsm.is_open())
+			vecLabels = ReadLabels(ifsm);
+		else {
+			cerr << "Could not read label file" << endl;
+			return 1;
+		}
+		for (i = 0; i < vecLabels.size(); i++)
+			setLabeledGenes.insert(vecLabels[i].GeneName);
+	}
 
-  if (!SVM.parms_check()) {
-    cerr << "Sanity check failed, see above errors" << endl;
-    return 1;
-  }
+	SVMLight::SAMPLE* pTrainSample;
+	vector<SVMLight::SVMLabel> pTrainVector[sArgs.cross_validation_arg];
+	vector<SVMLight::SVMLabel> pTestVector[sArgs.cross_validation_arg];
+	vector<SVMLight::Result> AllResults;
+	vector<SVMLight::Result> tmpAllResults;
 
-  if (!sArgs.output_given){
-    cerr << "output prefix not provided" << endl;
-    return 1;
-  }
-  
-  string output_prefix(sArgs.output_arg);
+	if (sArgs.model_given && sArgs.labels_given) { //learn once and write to file
+		pTrainSample = CSVMPERF::CreateSample(PCL, vecLabels);
+		SVM.Learn(*pTrainSample);
+		SVM.WriteModel(sArgs.model_arg,sArgs.simple_model_flag);
+	} else if (sArgs.model_given && sArgs.output_given) { //read model and classify all
 
-  //  cout << "there are " << vecLabels.size() << " labels processed" << endl;
-  size_t iFile;
-  vector<string> PCLs;
-  if (sArgs.input_given) {
-    if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) {
-      cerr << "Could not open input PCL" << endl;
-      return 1;
-    }
-  }
+		if(sArgs.test_labels_given && !sArgs.all_flag){
+		vector<SVMLight::SVMLabel> vecTestLabels;
+			ifsm.clear();
+			ifsm.open(sArgs.test_labels_arg);
+			if (ifsm.is_open())
+				vecTestLabels = ReadLabels(ifsm);
 
+			else {
+				cerr << "Could not read label file" << endl;
+				exit(1);
+			}
 
-  vector< pair < string, string > > vecLabelLists;
-  if (sArgs.labels_given) {
-    ifsm.clear();
-    ifsm.open(sArgs.labels_arg);
-    if (ifsm.is_open())
-      vecLabelLists = ReadLabelList(ifsm, output_prefix);
-    else {
-      cerr << "Could not read label list" << endl;
-      return 1;
-    }
-    ifsm.close();
-  }else{
-    cerr << "list of labels not given" << endl;
-    return 1;
-    //  if (sArgs.labels_given) {
-    //    vecLabelLists.push_back(pair(sArgs.labels_arg,sArgs.output_arg))
-    //  }
-  }
-  size_t k;
-  string labels_fn;
-  string output_fn;
 
-  
-    SVMLight::SAMPLE* pTrainSample;
-    vector<SVMLight::Result> AllResults;
-    vector<SVMLight::Result> tmpAllResults;
-    vector<SVMLight::SVMLabel> pTrainVector[sArgs.cross_validation_arg];
-    vector<SVMLight::SVMLabel> pTestVector[sArgs.cross_validation_arg];
-    vector<SVMLight::SVMLabel> vecLabels;
- 
-    string out_fn;
+			cerr << "Loading Model" << endl;
+			SVM.ReadModel(sArgs.model_arg);
+			cerr << "Model Loaded" << endl;
 
-  for(k = 0; k < vecLabelLists.size(); k ++){
-    labels_fn = vecLabelLists[k].first;
-    output_fn = vecLabelLists[k].second;
+			pTestVector[0].reserve((size_t) vecTestLabels.size()+1 );
+			for (j = 0; j < vecTestLabels.size(); j++) {
+				pTestVector[0].push_back(vecTestLabels[j]);		      
+			}
 
-    cout << labels_fn << endl;
-    cout << output_fn << endl;
-    
-    vecLabels.clear();
 
-    ifsm.clear();
-    ifsm.open(labels_fn.c_str());
-    if (ifsm.is_open())
-      vecLabels = ReadLabels(ifsm);
-    else {
-      cerr << "Could not read label file" << endl;
-      return 1;
-    }
-    ifsm.close();
+			tmpAllResults = SVM.Classify(PCL,	pTestVector[0]);
+			cerr << "Classified " << tmpAllResults.size() << " examples"<< endl;
+			AllResults.insert(AllResults.end(), tmpAllResults.begin(), tmpAllResults.end());
+			tmpAllResults.resize(0);
+			ofstream ofsm;
+			ofsm.clear();
+			ofsm.open(sArgs.output_arg);
+			PrintResults(AllResults, ofsm);
+			return 0;
+		}else{
+			vector<SVMLabel> vecAllLabels;
 
-    cout << "finished reading labels." << endl;
+			for (size_t i = 0; i < PCL.GetGenes(); i++)
+				vecAllLabels.push_back(SVMLabel(PCL.GetGene(i), 0));
 
+			SVM.ReadModel(sArgs.model_arg);
+			AllResults = SVM.Classify(PCL, vecAllLabels);
+			ofstream ofsm;
+			ofsm.open(sArgs.output_arg);
+			if (ofsm.is_open())
+				PrintResults(AllResults, ofsm);
+			else {
+				cerr << "Could not open output file" << endl;
+			}
+		}
+	} else if (sArgs.output_given && sArgs.labels_given) {
+		//do learning and classifying with cross validation
+	        if( sArgs.cross_validation_arg > 1){	    
+		  for (i = 0; i < sArgs.cross_validation_arg; i++) {
+		    pTestVector[i].reserve((size_t) vecLabels.size()
+					   / sArgs.cross_validation_arg + sArgs.cross_validation_arg);
+		    pTrainVector[i].reserve((size_t) vecLabels.size()
+					    / (sArgs.cross_validation_arg)
+					    * (sArgs.cross_validation_arg - 1)
+					    + sArgs.cross_validation_arg);
+		    for (j = 0; j < vecLabels.size(); j++) {
+		      if (j % sArgs.cross_validation_arg == i) {
+			pTestVector[i].push_back(vecLabels[j]);
+		      } else {
+			pTrainVector[i].push_back((vecLabels[j]));
+		      }
+		    }
+		  }
+		}
+		else{ // if you have less than 2 fold cross, no cross validation is done, all train genes are used. If test_labels are predicted if given, otherwise all genes are predicted.
+		  
+			if(sArgs.test_labels_given){
+					  pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
+					  for (j = 0; j < vecLabels.size(); j++) {
+						pTrainVector[0].push_back(vecLabels[j]);		    
+					  }
 
-    //do learning and classifying with cross validation
-    if( sArgs.cross_validation_arg > 1){	    
-      for (i = 0; i < sArgs.cross_validation_arg; i++) {
+						ifstream ifsm2;
+						vector<SVMLight::SVMLabel> vecTestLabels;
+						ifsm2.clear();
+						ifsm2.open(sArgs.test_labels_arg);
+						if (ifsm2.is_open())
+							vecTestLabels = ReadLabels(ifsm2);
+						else {
+							cerr << "Could not read label file" << endl;
+							exit(1);
+						}
 
-        pTestVector[i].clear();
-        pTrainVector[i].clear();
+						pTestVector[0].reserve((size_t) vecTestLabels.size()+1 );
+						for (j = 0; j < vecTestLabels.size(); j++) {
+							pTestVector[0].push_back(vecTestLabels[j]);		      
+						}
+						
+			}
+			else{// no holdout so train is the same as test gene set
+					  pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
+					  pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
+		  
+					  for (j = 0; j < vecLabels.size(); j++) {
+						pTestVector[0].push_back(vecLabels[j]);		      
+						pTrainVector[0].push_back(vecLabels[j]);		    
+					  }
+			}
+		}
+		
+		
+		vector<SVMLabel> vec_allUnlabeledLabels;
+		vector<Result> vec_allUnlabeledResults;
+		vector<Result> vec_tmpUnlabeledResults;
+		if (sArgs.all_flag) {
+			vec_allUnlabeledLabels.reserve(PCL.GetGenes());
+			vec_allUnlabeledResults.reserve(PCL.GetGenes());
+			for (i = 0; i < PCL.GetGenes(); i++) {
+				if (setLabeledGenes.find(PCL.GetGene(i))
+						== setLabeledGenes.end()) {
+					vec_allUnlabeledLabels.push_back(
+							SVMLabel(PCL.GetGene(i), 0));
+					vec_allUnlabeledResults.push_back(Result(PCL.GetGene(i)));
+				}
+			}
+		}
+		if (sArgs.params_given) { //reading paramters from file
+			ifsm.close();
+			ifsm.clear();
+			ifsm.open(sArgs.params_arg);
+			if (!ifsm.is_open()) {
+				cerr << "Could not open: " << sArgs.params_arg << endl;
+				return 1;
+			}
+			ParamStruct PStruct;
+			string outFile(sArgs.output_arg);
+			PStruct = ReadParamsFromFile(ifsm, outFile);
 
-        pTestVector[i].reserve((size_t) vecLabels.size()
-            / sArgs.cross_validation_arg + sArgs.cross_validation_arg);
-        pTrainVector[i].reserve((size_t) vecLabels.size()
-            / (sArgs.cross_validation_arg)
-            * (sArgs.cross_validation_arg - 1)
-            + sArgs.cross_validation_arg);
-        for (j = 0; j < vecLabels.size(); j++) {
-          if (j % sArgs.cross_validation_arg == i) {
-            pTestVector[i].push_back(vecLabels[j]);
-          } else {
-            pTrainVector[i].push_back((vecLabels[j]));
-          }
-        }
-      }
-    }
-    else{ // if you have less than 2 fold cross, no cross validation is done, all train genes are used and predicted
+			size_t iParams;
+			ofstream ofsm;
+			SVMLight::SAMPLE * ppTrainSample[sArgs.cross_validation_arg];
+			
+			//build all the samples since they are being reused
+			for (i = 0; i < sArgs.cross_validation_arg; i++)
+				ppTrainSample[i] = SVMLight::CSVMPERF::CreateSample(PCL,
+						pTrainVector[i]);
+			
+			for (iParams = 0; iParams < PStruct.vecTradeoff.size(); iParams++) {
+				SVM.SetLossFunction(PStruct.vecLoss[iParams]);
+				SVM.SetTradeoff(PStruct.vecTradeoff[iParams]);
+				SVM.SetPrecisionFraction(PStruct.vecK[iParams]);
+				for (j = 0; j < vec_allUnlabeledResults.size(); j++)
+					vec_allUnlabeledResults[j].Value = 0;
+				for (i = 0; i < sArgs.cross_validation_arg; i++) {
+					cerr << "Cross Validation Trial " << i << endl;
+					SVM.Learn(*ppTrainSample[i]);
+					
+					cerr << "Learned" << endl;					
+					
+					tmpAllResults = SVM.Classify(PCL, pTestVector[i]);
+					cerr << "Classified " << tmpAllResults.size()
+							<< " examples" << endl;
+					AllResults.insert(AllResults.end(), tmpAllResults.begin(),
+							tmpAllResults.end());
+					tmpAllResults.resize(0);
+					if (sArgs.all_flag && vec_allUnlabeledLabels.size() > 0) {
+						vec_tmpUnlabeledResults = SVM.Classify(PCL,
+								vec_allUnlabeledLabels);
+						for (j = 0; j < vec_tmpUnlabeledResults.size(); j++)
+							vec_allUnlabeledResults[j].Value
+									+= vec_tmpUnlabeledResults[j].Value;
+					}
 
-      // no holdout so train is the same as test gene set
-      pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
-      pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
+				}
 
-      for (j = 0; j < vecLabels.size(); j++) {
-        pTestVector[0].push_back(vecLabels[j]);		      
-        pTrainVector[0].push_back(vecLabels[j]);		    
-      }
-    }
 
-    for (i = 0; i < sArgs.cross_validation_arg; i++) {
-      pTrainSample = SVMLight::CSVMPERF::CreateSample(PCL,
-          pTrainVector[i]);
+				ofsm.open(PStruct.vecNames[iParams]);
+				if (sArgs.all_flag) { //add the unlabeled results
+					for (j = 0; j < vec_tmpUnlabeledResults.size(); j++)
+						vec_allUnlabeledResults[j].Value
+								/= sArgs.cross_validation_arg;
+					AllResults.insert(AllResults.end(),
+							vec_allUnlabeledResults.begin(),
+							vec_allUnlabeledResults.end());
+				}
 
-      cerr << "Cross Validation Trial " << i << endl;
+				PrintResults(AllResults, ofsm);
+				ofsm.close();
+				ofsm.clear();
+				if (i > 0 || iParams > 0)
+					SVM.FreeModel();
+				AllResults.resize(0);
+			}
+		} else { //run once
 
-      SVM.Learn(*pTrainSample);
-      cerr << "Learned" << endl;
-      tmpAllResults = SVM.Classify(PCL,
-          pTestVector[i]);
-      cerr << "Classified " << tmpAllResults.size() << " examples"
-        << endl;
-      AllResults.insert(AllResults.end(), tmpAllResults.begin(),
-          tmpAllResults.end());
-      tmpAllResults.resize(0);
+			for (i = 0; i < sArgs.cross_validation_arg; i++) {
+				pTrainSample = SVMLight::CSVMPERF::CreateSample(PCL,
+						pTrainVector[i]);
 
-      if (i > 0) {
-        SVMLight::CSVMPERF::FreeSample(*pTrainSample);
-      }
-    }
+				cerr << "Cross Validation Trial " << i << endl;
 
-    ofstream ofsm;
-    ofsm.clear();
-    out_fn = output_prefix + "/" + output_fn;
-    ofsm.open(out_fn.c_str());
-    PrintResults(AllResults, ofsm);
-    cout << "printed: " << output_fn << endl;
+				SVM.Learn(*pTrainSample);
+				cerr << "Learned" << endl;
+				tmpAllResults = SVM.Classify(PCL,
+						pTestVector[i]);
 
- 
-    delete[] pTrainSample;
-    AllResults.clear();
-    tmpAllResults.clear();
-    vecLabels.clear();
+				cerr << "Classified " << tmpAllResults.size() << " examples"
+						<< endl;
+				AllResults.insert(AllResults.end(), tmpAllResults.begin(),
+						tmpAllResults.end());
+				tmpAllResults.resize(0);
+				if (sArgs.all_flag) {
+					vec_tmpUnlabeledResults = SVM.Classify(
+							PCL, vec_allUnlabeledLabels);
+					for (j = 0; j < vec_tmpUnlabeledResults.size(); j++)
+						vec_allUnlabeledResults[j].Value
+								+= vec_tmpUnlabeledResults[j].Value;
 
+				}
+				if (i > 0) {
+					SVMLight::CSVMPERF::FreeSample(*pTrainSample);
+				}
+			}
 
+			if (sArgs.all_flag) { //add the unlabeled results
+				for (j = 0; j < vec_allUnlabeledResults.size(); j++)
+					vec_allUnlabeledResults[j].Value
+							/= sArgs.cross_validation_arg;
+				AllResults.insert(AllResults.end(),
+						vec_allUnlabeledResults.begin(),
+						vec_allUnlabeledResults.end());
+			}
 
-  } 
+			ofstream ofsm;
+			ofsm.clear();
+			ofsm.open(sArgs.output_arg);
+			PrintResults(AllResults, ofsm);
+			return 0;
+		}
+	} else {
+		cerr << "More options are needed" << endl;
+	}
+
 }