Ceri Stagg avatar Ceri Stagg committed 6a77cdd

Add write_model method

Comments (0)

Files changed (1)

svmlight/svmlight.py

 
 
 UnpackData = namedtuple('Unpackdata', 'words doc_label queryid slackid costfactor')
+# ----------------------------------------------
 
 WordTuple = namedtuple('WordTuple', 'wnum weight' )
 # ----------------------------------------------
 
     # We initialize these parameters with their default values, since we won't
     # be reading them from the feature pairs (don't really care).
-    queryid = 0
-    slackid = 0
-    costfactor = 1
+    queryid, slackid, costfactor = 0, 0, 1
 
     if type(docobj) != tuple:
         raise Exception("document should be a tuple")
 
-    label     = docobj[0]
-    words_list = docobj[1]
+    label, words_list = docobj[0], docobj[1]
     if len( docobj ) > 2:
         queryid = docobj[2]
 
 # ----------------------------------------------
 
 SVMCallTuple = namedtuple('SVMCallTuple', 'doclistdata client_data kernel_cache model' )
+# ----------------------------------------------
 
 LearnResultsTuple = namedtuple( 'LearnResultsTuple', 'model docs totdoc' )
 
             svm_call_tuple.kernel_cache,
             pointer( svm_call_tuple.model ),
             None )
+# ----------------------------------------------
 
 def call_svm_method_without_null( method_name, svm_call_tuple ):
 
             pointer( svm_call_tuple.client_data.kparm ),
             pointer( svm_call_tuple.kernel_cache ),
             pointer( svm_call_tuple.model ) )
+# ----------------------------------------------
 
+def write_model( model, filename ):
+    filename_as_c_string = generate_C_string_from_python( filename )
+    svm.write_model( filename, pointer( model ) )
+# ----------------------------------------------
 
 # -------------------- MAIN --------------------
 if __name__ == "__main__":
     training_data = localdata.train0
     test_data = localdata.test0
-    model = svm_learn( training_data, type='classification' )
+    learn_results_tuple = svm_learn( training_data, type='classification' )
     
-    # write_model(model, 'my_model.dat')
+    print( "Begin write model" )
+    write_model( learn_results_tuple.model, 'my_python_model.dat')
+    print( "End write model" )
 
     with open("model.pickle", 'wb') as f:
-        pickle.dump(model, f)
+        pickle.dump( learn_results_tuple.model, f)
 # ----------------------------------------------
 
 
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.