Ceri Stagg committed d9a975f

Add svm_classify method. Conversion complete.
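The new svm_classify( model, doclist ) walks an iterable of documents and returns one decision value per document: the signed distance from the separating hyperplane, computed via the C library's classify_example_linear for linear kernels and classify_example otherwise. A minimal sketch of turning those values into class labels, assuming the usual SVM-light convention that a positive decision value indicates the positive class (that convention is an assumption, not something this diff states):

# Sketch only: map svm_classify decision values to +1 / -1 labels,
# assuming the standard SVM-light sign convention.
def decision_values_to_labels( distances ):
    return [ 1 if d > 0 else -1 for d in distances ]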


Files changed (1)

svmlight/svmlight.py

 
 # specify return types for key methods in the svm C library
 svm.sprod_ss.restype = c_double
+svm.classify_example_linear.restype = c_double
 svm.read_model.restype = POINTER( MODEL )
+
 # ----------------------------------------------
 
 ''' This auxiliary function to svm_learn reads some parameters from the keywords to
     for feature_pair in words_list:
         if len( words ) >= max_words_doc:
             break 
-        wordtuple = WordTuple( feature_pair[0], feature_pair[1] )
+        wordtuple = WordTuple( int( feature_pair[0] ), feature_pair[1] )
         words.append( wordtuple )
 
     # sentinel entry required by C code
     templabellist = []
     totwords = 0
     for item in doc_iterator:
+
         unpackdata = unpack_document( item, max_words )
         numwords = len( unpackdata.words )
         if numwords > 0:
 # ----------------------------------------------
 
 LearnResultsTuple = namedtuple( 'LearnResultsTuple', 'model docs totdoc' )
+# ----------------------------------------------
 
 def svm_learn( doclist, **kwds):
     
     return pmodel.contents
 # ----------------------------------------------
 
+def svm_classify( model, doclist ):
+
+    try:
+        doc_iterator = iter(doclist)
+    except TypeError:
+        raise TypeError( "doclist must be iterable" )
+
+    docnum = 0
+    dist = None
+
+    has_linear_kernel = ( model.kernel_parm.kernel_type == 0 )
+
+    if has_linear_kernel:
+        svm.add_weight_vector_to_linear_model( pointer( model ) )
+
+    (max_docs, max_words) = count_doclist( doclist )
+
+    result = []
+    for item in doc_iterator:
+        unpackdata = unpack_document( item, max_words )
+
+        if has_linear_kernel:
+
+            for doc_item in unpackdata.words:
+
+                if doc_item.wnum == 0:
+                    # sentinel entry marks the end of the word list
+                    break
+                if doc_item.wnum > model.totwords:
+                    # feature id not present in the trained model; clear it
+                    doc_item.wnum = 0
+
+            svector = create_svector( unpackdata.words, "", 1.0 )
+            doc = create_example( -1, 0, 0, 0.0, svector )
+            dist = svm.classify_example_linear( pointer(model), pointer(doc) )
+        else:
+            svector = create_svector( unpackdata.words, "", 1.0 )
+            doc = create_example( -1, 0, 0, 0.0, svector )
+            dist = svm.classify_example( pointer(model), pointer(doc) )
+
+        result.append( dist )
+
+    return result
+
+
 # -------------------- MAIN --------------------
 if __name__ == "__main__":
     training_data = localdata.train0
     test_data = localdata.test0
     learn_results_tuple = svm_learn( training_data, type='classification' )
 
-    print( "Begin read model" )
-    temp_model = read_model('my_python_model.dat')
-    print( "End read model" )
+    print( "Begin write model" )
+    write_model( learn_results_tuple.model, 'my_python_model.dat')
+    print( "End write model" )
 
-    print( "Begin write model" )
-    write_model( temp_model, 'my_python_model_duplicate.dat')
-    print( "End write model" )
+    predictions = svm_classify( learn_results_tuple.model, test_data)
+    for p in predictions:
+        print '%.8f' % p
 
     with open("model.pickle", 'wb') as f:
         pickle.dump( learn_results_tuple.model, f)
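Since write_model saves the model in the C library's native format, a previously written model can later be reloaded and used for classification without retraining. A minimal sketch, assuming read_model returns a MODEL structure compatible with svm_classify (the read/write round-trip removed above suggests it does, but that is an assumption):

# Sketch: reload a previously saved model and score new documents.
# Assumes read_model returns a MODEL usable directly by svm_classify.
reloaded_model = read_model( 'my_python_model.dat' )
for p in svm_classify( reloaded_model, test_data ):
    print '%.8f' % p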