Ceri Stagg avatar Ceri Stagg committed 3407603

Refactor svm_learn2 method

Comments (0)

Files changed (1)

svmlight/svmlight.py

                 ("totwords",  c_int),
                 ("totdoc",  c_int)]'''
 
+SVMCallTuple = namedtuple('SVMCallTuple', 'doclistdata client_data kernel_cache model' )
+
 def svm_learn2( doclist, **kwds):
     
     client_data = read_learning_parameters( **kwds )
 
     print( "Learn type: %r" % ( client_data.plearn.type ))
 
+    svm_call_tuple = SVMCallTuple( doclistdata, client_data, kernel_cache, model )
+
     if client_data.plearn.type == CLASSIFICATION:
-        svm.svm_learn_classification( doclistdata.docs, 
-                                      doclistdata.labels,
-                                      doclistdata.totdoc,
-                                      doclistdata.totwords,
-                                      pointer( client_data.plearn ),
-                                      pointer( client_data.kparm ),
-                                      kernel_cache,
-                                      pointer( model ),
-                                      None )
-                                  
-    #if(!unpack_doclist(doclist, &docs, &target, &totwords, &totdoc))
-    #    return NULL;
+        call_svm_method_with_null( "svm_learn_classification", svm_call_tuple )
+    elif client_data.plearn.type == OPTIMIZATION:
+        call_svm_method_with_null( "svm_learn_optimization", svm_call_tuple )
+    elif client_data.plearn.type == REGRESSION:
+        call_svm_method_without_null( "svm_learn_regression", svm_call_tuple )
+    elif client_data.plearn.type == RANKING:
+        call_svm_method_without_null( "svm_learn_ranking", svm_call_tuple )
 
     # return (learn_parm, kernal_parm)
 # ----------------------------------------------
 
+def call_svm_method_with_null( method_name, svm_call_tuple ):
+
+    method = getattr(svm, method_name)
+    method( svm_call_tuple.doclistdata.docs, 
+            svm_call_tuple.doclistdata.labels,
+            svm_call_tuple.doclistdata.totdoc,
+            svm_call_tuple.doclistdata.totwords,
+            pointer( svm_call_tuple.client_data.plearn ),
+            pointer( svm_call_tuple.client_data.kparm ),
+            svm_call_tuple.kernel_cache,
+            pointer( svm_call_tuple.model ),
+            None )
+
+def call_svm_method_without_null( method_name, svm_call_tuple ):
+
+    method = getattr(svm, method_name)
+    method( svm_call_tuple.doclistdata.docs, 
+            svm_call_tuple.doclistdata.labels,
+            svm_call_tuple.doclistdata.totdoc,
+            svm_call_tuple.doclistdata.totwords,
+            pointer( svm_call_tuple.client_data.plearn ),
+            pointer( svm_call_tuple.client_data.kparm ),
+            pointer( svm_call_tuple.kernel_cache ),
+            pointer( svm_call_tuple.model ) )
+
 
 # -------------------- MAIN --------------------
 if __name__ == "__main__":
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.