Notations:

  • $x$: input vector
  • $x_i$ ($i = 1, \dots, n$): training sample
  • $f(x)$: decision function

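In a classification problem, our goal is to estimate $f$ accurately. Sometimes $f$ has to be a nonlinear function. In such cases, the RBF (Gaussian) kernel model is reasonable and useful. It is defined as

$$f(x) = \sum_{i=1}^{n} \theta_i K(x, x_i),$$

where $\theta_1, \dots, \theta_n$ are parameters learned from the $n$ training samples and

$$K(x, x') = \exp\left(-\gamma \|x - x'\|^2\right)$$

is the Gaussian kernel, with $\gamma > 0$ (the gamma parameter in the code below) controlling how quickly it decays with distance.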
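To make the kernel model $f(x) = \sum_{i=1}^{n} \theta_i \exp(-\gamma \|x - x_i\|^2)$ concrete, here is a minimal NumPy sketch that evaluates it for given parameters. The function names `gaussian_kernel` and `rbf_model` are illustrative only, and $\theta$ is supplied by hand rather than learned.

```python
import numpy as np


def gaussian_kernel(X, Z, gamma):
    """Return the matrix K[i, j] = exp(-gamma * ||X[i] - Z[j]||^2)."""
    # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
    sq_dists = (np.sum(X**2, axis=1)[:, None]
                + np.sum(Z**2, axis=1)[None, :]
                - 2.0 * X @ Z.T)
    # Clip tiny negative values caused by floating-point round-off
    return np.exp(-gamma * np.maximum(sq_dists, 0.0))


def rbf_model(x, X_train, theta, gamma):
    """Evaluate f(x) = sum_i theta_i * K(x, x_i) for each row of x."""
    return gaussian_kernel(np.atleast_2d(x), X_train, gamma) @ theta


# Two training samples with hand-picked parameters
X_train = np.array([[0.0, 0.0], [1.0, 0.0]])
theta = np.array([1.0, -1.0])
f = rbf_model(X_train, X_train, theta, gamma=1.0)
print(f)  # f(x_1) = 1 - e^{-1}, f(x_2) = e^{-1} - 1
```

Each training sample contributes a bump centred on itself, so $f$ can bend around the data even though it is linear in $\theta$.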

In this article, I introduce a wrapper class, RbfModelWrapper (rbfmodel_wrapper.py), that lets some of scikit-learn's classifiers use the Gaussian kernel model. With this class, we can easily make such classifiers nonlinear. For example, Logistic Regression is made nonlinear by:

clf=RbfModelWrapper(LogisticRegression(),gamma=1.)
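One way such a wrapper could work is to map each input $x$ to the kernel features $(K(x, x_1), \dots, K(x, x_n))$ and fit the wrapped linear classifier on them. The sketch below is hypothetical: the class name `RbfFeatureWrapper` and its internals are my assumptions, not the code of rbfmodel_wrapper.py. It relies on scikit-learn's real `rbf_kernel`, `clone`, `BaseEstimator` and `ClassifierMixin`.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.datasets import make_moons
from sklearn.linear_model import LogisticRegression
from sklearn.metrics.pairwise import rbf_kernel


class RbfFeatureWrapper(ClassifierMixin, BaseEstimator):
    """Hypothetical sketch, not the article's RbfModelWrapper.

    Maps each input x to (K(x, x_1), ..., K(x, x_n)) with
    K(x, x') = exp(-gamma * ||x - x'||^2), where x_1..x_n are the
    training samples, then fits a clone of the wrapped classifier.
    """

    def __init__(self, estimator=None, gamma=1.0):
        self.estimator = estimator
        self.gamma = gamma

    def _features(self, X):
        # Gaussian kernel between the inputs and the stored training samples
        return rbf_kernel(np.asarray(X, dtype=float), self.centers_, gamma=self.gamma)

    def fit(self, X, y):
        self.centers_ = np.asarray(X, dtype=float)  # training samples act as centres
        base = LogisticRegression() if self.estimator is None else self.estimator
        self.estimator_ = clone(base).fit(self._features(X), y)
        self.classes_ = self.estimator_.classes_
        return self

    def decision_function(self, X):
        return self.estimator_.decision_function(self._features(X))

    def predict(self, X):
        return self.estimator_.predict(self._features(X))


# Usage: logistic regression on kernel features of the two-moons data
X, y = make_moons(n_samples=200, noise=0.05, random_state=0)
clf = RbfFeatureWrapper(LogisticRegression(C=100.0, max_iter=1000), gamma=1.0).fit(X, y)
print("training accuracy:", clf.score(X, y))
```

Because the sketch follows scikit-learn's estimator conventions (constructor arguments stored unchanged, fitted state in attributes ending with `_`), GridSearchCV can clone it; in this sketch the inner C would be addressed as `"estimator__C"`, so the article's RbfModelWrapper, whose grid uses `"C"` directly, presumably exposes its parameters differently. Using every training sample as a centre makes the feature matrix $n \times n$, so memory and time grow quadratically with the training set size.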

We can also use GridSearchCV for hyperparameter selection:

grid=GridSearchCV(RbfModelWrapper(LogisticRegression()),param_grid={"gamma":np.logspace(-2,0,9),"C":[1,10,100]})
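GridSearchCV evaluates every combination in `param_grid`. The grid above pairs 9 log-spaced values of gamma (0.01 to 1) with 3 values of C, i.e. 27 candidate settings, each scored by cross-validation (5-fold by default in recent scikit-learn). As a quick check, scikit-learn's `ParameterGrid` enumerates the same combinations:

```python
import numpy as np
from sklearn.model_selection import ParameterGrid

# Same grid as in the GridSearchCV call above
param_grid = {"gamma": np.logspace(-2, 0, 9), "C": [1, 10, 100]}

# Every combination of one gamma value and one C value
candidates = list(ParameterGrid(param_grid))
print(len(candidates))  # 27
print(candidates[0])
```

Adding a value to either list multiplies the number of fits, which is why the demo below searches over gamma alone.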

A demo is shown below.

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn import datasets
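from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed in scikit-learn 0.20
from rbfmodel_wrapper import RbfModelWrapper  # the wrapper class introduced in this article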
import matplotlib.pyplot as plt
np.random.seed(1)

n=500
X,y=datasets.make_moons(n_samples=n,noise=.05)
# random 70%/30% split into training and test indices
idx=np.random.permutation(n)
ntr=int(n*0.7)
itr=idx[:ntr]
ite=idx[ntr:]

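# search gamma over 9 log-spaced values from 0.01 to 1
param_grid={"gamma":np.logspace(-2,0,9)}
grid=GridSearchCV(RbfModelWrapper(LogisticRegression()),param_grid=param_grid)
grid.fit(X[itr],y[itr])
clf=grid.best_estimator_  # refitted on the whole training set with the best gamma
print("accuracy:",clf.score(X[ite],y[ite]))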

offset=.5
xx,yy=np.meshgrid(np.linspace(X[:,0].min()-offset,X[:,0].max()+offset,300),
                  np.linspace(X[:,1].min()-offset,X[:,1].max()+offset,300))

# predict a label for every grid point; the label change marks the decision boundary
Z=clf.predict(np.c_[xx.ravel(),yy.ravel()])
Z=Z.reshape(xx.shape)

# levels=[0.5] draws the boundary between predicted labels 0 and 1
plt.contour(xx, yy, Z, levels=[0.5], linewidths=2, colors='green')
# proxy handle for the legend (ContourSet.collections is deprecated in recent matplotlib)
a,=plt.plot([],[],color='green',linewidth=2)
b1=plt.scatter(X[y==1][:,0],X[y==1][:,1],color="blue",s=40)
b2=plt.scatter(X[y==0][:,0],X[y==0][:,1],color="red",s=40)
plt.axis("tight")
plt.xlim((X[:,0].min()-offset,X[:,0].max()+offset))
plt.ylim((X[:,1].min()-offset,X[:,1].max()+offset))
plt.legend([a,b1,b2],
           [r"decision boundary","positive","negative"],
           prop={"size":10},loc="lower right")
plt.tight_layout()
plt.show()

The result is as follows:

accuracy: 1.0
