Ceri Stagg committed 51457dc

Refactoring of Python code

Comments (0)

Files changed (1)

svmlight/svmlight.py

 RANKING        = 3
 OPTIMIZATION   = 4
 
+# maps learn_parm.type enum values to the learning-type strings accepted in kwds["type"]
+method_dict = { CLASSIFICATION : 'classification',
+                REGRESSION     : 'regression',
+                RANKING        : 'ranking',
+                OPTIMIZATION   : 'optimization' }
+
+# reverse lookup: learning-type string -> enum value
+enum_dict = dict( ( v, k ) for k, v in method_dict.items() )
+
 MAXSHRINK = 50000
+# change this as necessary to the correct path to your svmlight.so lib
 svm = CDLL("./svmlight.so")
 # ----------------------------------------------
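Not part of the commit, but one sketch of making that path configurable rather than hard-coded; the SVMLIGHT_LIB environment variable below is a hypothetical name, and the fallback keeps the current behaviour:

import os
from ctypes import CDLL

# hypothetical override via the environment; falls back to the bundled path
_lib_path = os.environ.get( "SVMLIGHT_LIB", "./svmlight.so" )
svm = CDLL( _lib_path )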
 
 
     if "type" in kwds:
         typ = kwds["type"]
-        if typ == "classification":
-            learn_parm.type = CLASSIFICATION
-        elif typ == "regression":
-            learn_parm.type = REGRESSION
-        elif typ == "ranking":
-            learn_parm.type = RANKING
-        elif typ == "optimization":
-            learn_parm.type = OPTIMIZATION
-        else:
-            raise Exception("unknown learning type specified. Valid types are: 'classification', 'regression', 'ranking' and 'optimization'.")
+        if typ not in method_dict.values():
+            raise Exception, "unknown learning type specified. Valid types are: 'classification', 'regression', 'ranking' and 'optimization'."
 
+        learn_parm.type = enum_dict[ typ ]
+            
     print 'Type:'
     print learn_parm.type
 
 
     words = [WordTuple( int( feat0 ), feat1 ) for
              feat0, feat1 in words_list[:max_words_doc]]
-    '''
-    words = []
-
-    for (feat0, feat1) in words_list:
-        if len( words ) >= max_words_doc:
-            break 
-        wordtuple = WordTuple( int( feat0 ), feat1 )
-        words.append( wordtuple )
-    '''
 
     # sentinel entry required by C code
     words.append( WordTuple( 0, 0.0 ) )
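For illustration only (not in the commit): with the WordTuple class defined earlier in this file, a two-feature document becomes three entries, the trailing (0, 0.0) sentinel telling the C code where the word array ends.

words_list = [ ( 4, 0.5 ), ( 9, 1.0 ) ]
words = [ WordTuple( int( f ), v ) for f, v in words_list ]
words.append( WordTuple( 0, 0.0 ) )
# words is now [ WordTuple(4, 0.5), WordTuple(9, 1.0), WordTuple(0, 0.0) ]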
                 ("totdoc",  c_int)]
 # ----------------------------------------------
 
+
 def unpack_doclist( doclist ):
     try:
         doc_iterator = iter(doclist)
     except TypeError, te:
-        raise Exception("Not iterable")
+        raise Exception, "Not iterable"
 
-    (max_docs, max_words) = count_doclist( doclist )
+    max_docs, max_words = count_doclist( doclist )
 
     tempdoclist = []
     templabellist = []
     for item in doc_iterator:
 
         unpackdata = unpack_document( item, max_words )
-        numwords = len( unpackdata.words )
-        if numwords > 0:
+        if unpackdata.words:
+            # words always ends with the (0, 0.0) sentinel, so index -2 is the last real feature
+            assert len(unpackdata.words) > 1
             candidatewords = unpackdata.words[-2].wnum
-            if candidatewords > totwords:
-                totwords = candidatewords
+            totwords = max(totwords, candidatewords)
 
         docnum = unpackdata.doc_label
 
     totdoc = len( doclist )
 
     carraydoc = ( POINTER( DOC ) * totdoc )()
+    carraylabel = ( c_double * totdoc )() 
           
-    counter = 0
-    for item in iter( tempdoclist ):
-
-        carraydoc[ counter ] = item
-        counter += 1
-
-    carraylabel = ( c_double * totdoc )() 
-
-    counter = 0
-    for item in iter( templabellist ):
-        carraylabel[ counter ] = item
-        counter += 1 
+    for i, item in enumerate( tempdoclist ):
+        carraydoc[ i ] = item
+        
+    for i, item in enumerate( templabellist ):
+        carraylabel[ i ] = item
 
     result = DOCLISTDATA()
 
 
     model = MODEL()
 
-    kernel_cache = None
+    # slightly hacky: the C code compares against the address of a null pointer,
+    # so pass a real c_int( 0 ) here rather than None (ctypes would hand None over as NULL)
+    kernel_cache = c_int( 0 )
+    
     if client_data.kparm.kernel_type != LINEAR:
         kernel_cache = svm.kernel_cache_init( doclistdata.totdoc, 
                                               client_data.plearn.kernel_cache_size )
 
     svm_call_tuple = SVMCallTuple( doclistdata, client_data, kernel_cache, model )
 
-    if client_data.plearn.type == CLASSIFICATION:
-        call_svm_method_with_null( "svm_learn_classification", svm_call_tuple )
-    elif client_data.plearn.type == OPTIMIZATION:
-        call_svm_method_with_null( "svm_learn_optimization", svm_call_tuple )
-    elif client_data.plearn.type == REGRESSION:
-        call_svm_method_without_null( "svm_learn_regression", svm_call_tuple )
-    elif client_data.plearn.type == RANKING:
-        call_svm_method_without_null( "svm_learn_ranking", svm_call_tuple )
+    call_pattern = call_svm_method_with_null
+    if client_data.plearn.type in [ REGRESSION, RANKING ]:
+        call_pattern = call_svm_method_without_null
+
+    svm_method_name = "svm_learn_" + method_dict[ client_data.plearn.type ]
+    call_pattern( svm_method_name, svm_call_tuple )
 
     result = LearnResultsTuple( model, doclistdata.docs, doclistdata.totdoc )
     return result 
-    # return (learn_parm, kernal_parm)
 # ----------------------------------------------
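A small illustration (not part of the commit) of how the name-based dispatch above can resolve to a C symbol. It assumes call_svm_method_with_null / call_svm_method_without_null look the function up on the loaded library, which ctypes supports through ordinary attribute or getattr() access:

# e.g. for a classification run
svm_method_name = "svm_learn_" + method_dict[ CLASSIFICATION ]    # -> "svm_learn_classification"
svm_method = getattr( svm, svm_method_name )                      # same symbol as svm.svm_learn_classification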
 
 def call_svm_method_with_null( method_name, svm_call_tuple ):
     try:
         doc_iterator = iter(doclist)
     except TypeError, te:
-        raise Exception("Not iterable")
+        raise Exception, "Not iterable"
 
     docnum = 0
     dist = None
     if has_linear_kernel:
         svm.add_weight_vector_to_linear_model( pointer( model ) )
 
-    (max_docs, max_words) = count_doclist( doclist )
+    max_docs, max_words = count_doclist( doclist )
 
     result = []
     for item in doc_iterator:
 
 # -------------------- MAIN --------------------
 
+# Example main function. It assumes the existence of a file 'localdata.py'
+# containing train0 and test0 data list entries.
+
+'''
 import localdata
 
 if __name__ == "__main__":
     training_data = localdata.train0
     test_data = localdata.test0
-    learn_results_tuple = svm_learn( training_data, type='classification' )
+    learn_results_tuple = svm_learn( training_data, type='optimization' )
 
-    print( "Begin write model" )
     write_model( learn_results_tuple.model, 'my_python_model.dat')
-    print( "End write model" )
 
     predictions = svm_classify( learn_results_tuple.model, test_data)
     for p in predictions:
          print '%.8f' % p
 
+    # As things stand, this will fail because pickle cannot handle
+    # ctypes objects that contain pointers
     with open("model.pickle", 'wb') as f:
         pickle.dump( learn_results_tuple.model, f)
+'''
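A rough sketch of the pickle limitation mentioned above. The Toy structure is a stand-in (the assumption being that MODEL also carries POINTER fields, which is what trips pickle up); until that is resolved, write_model() above remains the way to persist a trained model.

import pickle
from ctypes import Structure, POINTER, c_double

class Toy( Structure ):
    _fields_ = [ ( "alpha", POINTER( c_double ) ) ]

try:
    pickle.dumps( Toy() )
except ValueError, e:
    # pickle refuses ctypes objects that contain pointers
    print 'pickle failed: %s' % e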
 
 # ----------------------------------------------
 