Commits

Ceri Stagg committed 3407603

Refactor svm_learn2 method

Comments (0)

Files changed (1)

svmlight/svmlight.py

                 ("totwords",  c_int),
                 ("totdoc",  c_int)]'''
 
+SVMCallTuple = namedtuple('SVMCallTuple', 'doclistdata client_data kernel_cache model' )
+
 def svm_learn2( doclist, **kwds):
     
     client_data = read_learning_parameters( **kwds )
 
     print( "Learn type: %r" % ( client_data.plearn.type ))
 
+    svm_call_tuple = SVMCallTuple( doclistdata, client_data, kernel_cache, model )
+
     if client_data.plearn.type == CLASSIFICATION:
-        svm.svm_learn_classification( doclistdata.docs, 
-                                      doclistdata.labels,
-                                      doclistdata.totdoc,
-                                      doclistdata.totwords,
-                                      pointer( client_data.plearn ),
-                                      pointer( client_data.kparm ),
-                                      kernel_cache,
-                                      pointer( model ),
-                                      None )
-                                  
-    #if(!unpack_doclist(doclist, &docs, &target, &totwords, &totdoc))
-    #    return NULL;
+        call_svm_method_with_null( "svm_learn_classification", svm_call_tuple )
+    elif client_data.plearn.type == OPTIMIZATION:
+        call_svm_method_with_null( "svm_learn_optimization", svm_call_tuple )
+    elif client_data.plearn.type == REGRESSION:
+        call_svm_method_without_null( "svm_learn_regression", svm_call_tuple )
+    elif client_data.plearn.type == RANKING:
+        call_svm_method_without_null( "svm_learn_ranking", svm_call_tuple )
 
     # return (learn_parm, kernal_parm)
 # ----------------------------------------------
 
+def call_svm_method_with_null( method_name, svm_call_tuple ):
+
+    method = getattr(svm, method_name)
+    method( svm_call_tuple.doclistdata.docs, 
+            svm_call_tuple.doclistdata.labels,
+            svm_call_tuple.doclistdata.totdoc,
+            svm_call_tuple.doclistdata.totwords,
+            pointer( svm_call_tuple.client_data.plearn ),
+            pointer( svm_call_tuple.client_data.kparm ),
+            svm_call_tuple.kernel_cache,
+            pointer( svm_call_tuple.model ),
+            None )
+
+def call_svm_method_without_null( method_name, svm_call_tuple ):
+
+    method = getattr(svm, method_name)
+    method( svm_call_tuple.doclistdata.docs, 
+            svm_call_tuple.doclistdata.labels,
+            svm_call_tuple.doclistdata.totdoc,
+            svm_call_tuple.doclistdata.totwords,
+            pointer( svm_call_tuple.client_data.plearn ),
+            pointer( svm_call_tuple.client_data.kparm ),
+            pointer( svm_call_tuple.kernel_cache ),
+            pointer( svm_call_tuple.model ) )
+
 
 # -------------------- MAIN --------------------
 if __name__ == "__main__":