Jacob Perkins avatar Jacob Perkins committed b2d97d8

sklearn classification algorithms with scikitlearn.SklearnClassifier training

Comments (0)

Files changed (3)

analyze_classifier_coverage.py

 #!/usr/bin/python
 import argparse, collections, itertools, operator, re, string
+import cPickle as pickle
 import nltk.data
 from nltk.classify.util import accuracy
 from nltk.corpus import stopwords
 ## text extraction ##
 #####################
 
-classifier = nltk.data.load(args.classifier)
# Load the classifier either as an nltk_data resource path, or — if that
# lookup fails — as a plain pickle file on disk.
try:
	classifier = nltk.data.load(args.classifier)
except LookupError:
	# Not a resource path under nltk_data: treat the argument as a file path.
	# Binary mode is required for pickles, and the context manager closes the
	# handle promptly (the previous open() call leaked it).
	# NOTE(review): pickle.load executes arbitrary code — only load trusted files.
	with open(args.classifier, 'rb') as f:
		classifier = pickle.load(f)
 
 if args.metrics:
 	label_instance_function = {

nltk_trainer/classification/args.py

 classifier_choices = ['NaiveBayes', 'DecisionTree', 'Maxent'] + MaxentClassifier.ALGORITHMS
 
 try:
-	from .sci import ScikitsClassifier
-	classifier_choices.append('Scikits')
-except ImportError:
-	pass
+	from nltk.classify import scikitlearn
+	from sklearn.pipeline import Pipeline
+	from sklearn import linear_model, naive_bayes, neighbors, svm, tree
+	
+	classifiers = [
+		linear_model.LogisticRegression,
+		#linear_model.SGDClassifier, # NOTE: this seems terrible, but could just be the options
+		naive_bayes.BernoulliNB,
+		#naive_bayes.GaussianNB, # TODO: requires a dense matrix
+		naive_bayes.MultinomialNB,
+		neighbors.KNeighborsClassifier, # TODO: options for nearest neighbors
+		svm.LinearSVC,
+		svm.NuSVC,
+		svm.SVC,
+		#tree.DecisionTreeClassifier, # TODO: requires a dense matrix
+	]
+	sklearn_classifiers = {}
+	
+	for classifier in classifiers:
+		sklearn_classifiers[classifier.__name__] = classifier
+	
+	classifier_choices.extend(sorted(['sklearn.%s' % c.__name__ for c in classifiers]))
+except ImportError as exc:
+	sklearn_classifiers = {}
 
 def add_maxent_args(parser):
 	maxent_group = parser.add_argument_group('Maxent Classifier',
 	decisiontree_group.add_argument('--support_cutoff', default=10, type=int,
 		help='default is 10')
 
# Maps estimator class name -> list of argparse option names that apply to it;
# populated by add_sklearn_args() and consumed by make_sklearn_classifier().
sklearn_kwargs = {}

def add_sklearn_args(parser):
	"""Register command line options for the sklearn-backed classifiers.

	No-op when no sklearn estimators could be imported. As a side effect,
	records in the module-level sklearn_kwargs dict which option names are
	forwarded to each estimator class.
	"""
	if not sklearn_classifiers: return
	
	common_group = parser.add_argument_group('sklearn Classifiers',
		'These options are common to many of the sklearn classification algorithms.')
	common_group.add_argument('--alpha', type=float, default=1.0,
		help='smoothing parameter for naive bayes classifiers, default is %(default)s')
	common_group.add_argument('--C', type=float, default=1.0,
		help='penalty parameter, default is %(default)s')
	common_group.add_argument('--penalty', choices=['l1', 'l2'],
		default='l2', help='norm for penalization, default is %(default)s')
	common_group.add_argument('--kernel', default='rbf',
		choices=['linear', 'poly', 'rbf', 'sigmoid', 'precomputed'],
		help='kernel type for support vector machine classifiers, default is %(default)s')
	
	# Which of the shared options each estimator accepts.
	sklearn_kwargs.update({
		'LogisticRegression': ['C', 'penalty'],
		'BernoulliNB': ['alpha'],
		'MultinomialNB': ['alpha'],
		'SVC': ['C', 'kernel'],
	})
	
	linear_svc_group = parser.add_argument_group('sklearn Linear Support Vector Machine Classifier',
		'These options only apply when a sklearn.LinearSVC classifier is chosen.')
	linear_svc_group.add_argument('--loss', choices=['l1', 'l2'],
		default='l2', help='loss function, default is %(default)s')
	sklearn_kwargs['LinearSVC'] = ['C', 'loss', 'penalty']
	
	nu_svc_group = parser.add_argument_group('sklearn Nu Support Vector Machine Classifier',
		'These options only apply when a sklearn.NuSVC classifier is chosen.')
	nu_svc_group.add_argument('--nu', type=float, default=0.5,
		help='upper bound on fraction of training errors & lower bound on fraction of support vectors, default is %(default)s')
	sklearn_kwargs['NuSVC'] = ['nu', 'kernel']
+
def make_sklearn_classifier(algo, args):
	"""Instantiate the sklearn estimator named by algo (e.g. 'sklearn.LinearSVC').

	algo -- classifier choice string of the form 'sklearn.<ClassName>'
	args -- parsed argparse namespace; only the options registered for this
	        estimator in sklearn_kwargs are read, and unset (None) options
	        are omitted so the estimator's own defaults apply
	"""
	name = algo.split('.', 1)[1]
	kwargs = {}
	
	for key in sklearn_kwargs.get(name, []):
		val = getattr(args, key)
		if val is not None: kwargs[key] = val
	
	if args.trace and kwargs:
		# parenthesized single-argument form prints identically under
		# Python 2 and also compiles under Python 3 (the bare print
		# statement is Python-2-only syntax)
		print('training %s with %s' % (algo, kwargs))
	
	return sklearn_classifiers[name](**kwargs)
+
 def make_classifier_builder(args):
 	if isinstance(args.classifier, basestring):
 		algos = [args.classifier]
 			classifier_train_kwargs['verbose'] = args.trace
 		elif algo == 'NaiveBayes':
 			classifier_train = NaiveBayesClassifier.train
-		elif algo == 'Scikits':
-			classifier_train = ScikitsClassifier.train
+		elif algo.startswith('sklearn.'):
+			# TODO: support many options for building an estimator pipeline
+			estimator = Pipeline([('classifier', make_sklearn_classifier(algo, args))])
+			# TODO: option for dtype
+			classifier_train = scikitlearn.SklearnClassifier(estimator, dtype=bool).train
 		else:
 			if algo != 'Maxent':
 				classifier_train_kwargs['algorithm'] = algo

train_classifier.py

 
 nltk_trainer.classification.args.add_maxent_args(parser)
 nltk_trainer.classification.args.add_decision_tree_args(parser)
+nltk_trainer.classification.args.add_sklearn_args(parser)
 
 args = parser.parse_args()
 
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.