A Wrapper Class of RBF (Gaussian) Kernel Model for scikit-learn's Classifiers
Notation:
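- $\mathbf{x}$: input vector
- $\{(\mathbf{x}_i, y_i)\}_{i=1}^{n}$: training samples
- $f(\mathbf{x})$: decision function
In a classification problem, our goal is to estimate $f(\mathbf{x})$ accurately. Sometimes we need a nonlinear function for $f(\mathbf{x})$. In such cases, the RBF (Gaussian) kernel model, defined as
$$f(\mathbf{x}) = \sum_{i=1}^{n} \alpha_i \exp\left(-\gamma \|\mathbf{x} - \mathbf{x}_i\|^2\right),$$
where $\alpha_1, \dots, \alpha_n$ are the parameters to be learned and $\gamma > 0$ controls the kernel width, is reasonable and useful.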
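To make the model concrete, here is a minimal NumPy sketch that evaluates a Gaussian kernel model of the form f(x) = sum_i alpha_i * exp(-gamma * ||x - x_i||^2). The function names, the two kernel centers, and the coefficient values are illustrative assumptions and do not come from rbfmodel_wrapper.py.

```python
import numpy as np


def gaussian_kernel_matrix(X, centers, gamma):
    """Return K with K[j, i] = exp(-gamma * ||X[j] - centers[i]||^2)."""
    # Pairwise squared Euclidean distances via broadcasting,
    # giving an array of shape (len(X), len(centers)).
    sq_dist = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return np.exp(-gamma * sq_dist)


def decision_function(X, centers, alpha, gamma):
    """Evaluate f(x) = sum_i alpha_i * k(x, x_i) for each row of X."""
    return gaussian_kernel_matrix(X, centers, gamma) @ alpha


# Toy values chosen only for illustration.
centers = np.array([[0.0, 0.0], [1.0, 0.0]])   # play the role of training inputs x_i
alpha = np.array([1.0, -1.0])                  # play the role of learned coefficients
X_new = np.array([[0.0, 0.0], [0.5, 0.0]])

values = decision_function(X_new, centers, alpha, gamma=1.0)
print(values)
```

The second query point lies midway between the two centers, so the two kernel terms cancel and the decision value there is zero.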
In this article, I introduce a wrapper class, RbfModelWrapper (in rbfmodel_wrapper.py), that lets some of scikit-learn’s classifiers use the Gaussian kernel model. With this class, we can easily make a linear classifier nonlinear. For example, logistic regression is made nonlinear by
clf=RbfModelWrapper(LogisticRegression(),gamma=1.)
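The source of rbfmodel_wrapper.py is not shown here, so the following is only a sketch of how such a wrapper could work, not the author's implementation. It assumes every training point is used as a kernel center and that the wrapped classifier is fitted on the resulting kernel features. The class name RbfFeatureClassifier and the parameter name base_estimator are hypothetical.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.datasets import make_moons
from sklearn.linear_model import LogisticRegression
from sklearn.metrics.pairwise import rbf_kernel


class RbfFeatureClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical stand-in for RbfModelWrapper (an assumption, not the original)."""

    def __init__(self, base_estimator, gamma=1.0):
        self.base_estimator = base_estimator
        self.gamma = gamma

    def fit(self, X, y):
        # Assumption: all training inputs serve as kernel centers.
        self.centers_ = np.asarray(X)
        K = rbf_kernel(self.centers_, self.centers_, gamma=self.gamma)
        # Fit a fresh copy of the wrapped classifier on the kernel features.
        self.estimator_ = clone(self.base_estimator).fit(K, y)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X):
        # Map new inputs to kernel features against the stored centers.
        K = rbf_kernel(np.asarray(X), self.centers_, gamma=self.gamma)
        return self.estimator_.predict(K)


X, y = make_moons(n_samples=200, noise=0.05, random_state=0)
clf = RbfFeatureClassifier(LogisticRegression(), gamma=1.0).fit(X, y)
train_accuracy = clf.score(X, y)
print("training accuracy:", train_accuracy)
```

Because this sketch inherits from BaseEstimator, a grid search would address the inner classifier's parameters with the nested name base_estimator__C. The article's wrapper appears to accept C directly, which suggests it forwards unknown parameters to the wrapped classifier.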
We can also use GridSearchCV for hyperparameter selection:
grid=GridSearchCV(RbfModelWrapper(LogisticRegression()),param_grid={"gamma":np.logspace(-2,0,9),"C":[1,10,100]})
The demo code is shown below.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn import datasets
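from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed in scikit-learn 0.20
from rbfmodel_wrapper import RbfModelWrapper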
import matplotlib.pyplot as plt
np.random.seed(1)
n=500
X,y=datasets.make_moons(n_samples=n,noise=.05)
idx=np.random.permutation(n)
ntr=np.int32(n*0.7)
itr=idx[:ntr]
ite=idx[ntr:]
param_grid={"gamma":np.logspace(-2,0,9)}
grid=GridSearchCV(RbfModelWrapper(LogisticRegression()),param_grid=param_grid)
grid.fit(X[itr],y[itr])
clf=grid.best_estimator_
print("accuracy:",clf.score(X[ite],y[ite]))
offset=.5
xx,yy=np.meshgrid(np.linspace(X[:,0].min()-offset,X[:,0].max()+offset,300),
np.linspace(X[:,1].min()-offset,X[:,1].max()+offset,300))
Z=clf.predict(np.c_[xx.ravel(),yy.ravel()])
Z=Z.reshape(xx.shape)
a=plt.contour(xx, yy, Z, levels=[0.5], linewidths=2, colors='green')
b1=plt.scatter(X[y==1][:,0],X[y==1][:,1],color="blue",s=40)
b2=plt.scatter(X[y==0][:,0],X[y==0][:,1],color="red",s=40)
plt.axis("tight")
plt.xlim((X[:,0].min()-offset,X[:,0].max()+offset))
plt.ylim((X[:,1].min()-offset,X[:,1].max()+offset))
plt.legend([a.legend_elements()[0][0],b1,b2],
["decision boundary","positive","negative"],
prop={"size":10},loc="lower right")
plt.tight_layout()
plt.show()
The result looks like this:
accuracy: 1.0