How to recognize handwritten digits with a convolutional neural network (CNN)?

A few days ago I used a CNN to recognize the handwritten digit dataset. Later I noticed a Kaggle competition on handwritten digit recognition that has been running for more than a year. It currently has 1179 valid submissions, and the top score is 100%. I gave it a try with Keras. I started with the simplest model, an MLP, which reached only 98.19% accuracy. After repeated improvements I am now at 99.78%, but first place is at 100%, which broke my heart ==, so I improved it once more. Below I record my best result so far, and I will update this post if it improves.

I'm sure everyone is familiar with the handwritten digit dataset: this program is the "Hello World" of learning a new language, or the "WordCount" of MapReduce :) So I won't introduce it at length; here is a quick look:

```python
# Author: Charlotte
# Plot the MNIST dataset
from keras.datasets import mnist
import matplotlib.pyplot as plt

# Load the MNIST dataset
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Plot the first 4 training images in a 2x2 grid (PuBuGn_r colormap)
plt.subplot(221)
plt.imshow(X_train[0], cmap=plt.get_cmap('PuBuGn_r'))
plt.subplot(222)
plt.imshow(X_train[1], cmap=plt.get_cmap('PuBuGn_r'))
plt.subplot(223)
plt.imshow(X_train[2], cmap=plt.get_cmap('PuBuGn_r'))
plt.subplot(224)
plt.imshow(X_train[3], cmap=plt.get_cmap('PuBuGn_r'))

# Show the plot
plt.show()
```

Figure: the first four MNIST training images.

1. Baseline version

At first I didn't plan to use a CNN, because it is relatively slow to train, so I wanted to see whether a simpler algorithm could already do well. I had previously run several classical machine learning algorithms on this data; the best was an SVM at 96.8% (default parameters, no tuning), so this time I went with a neural network. The baseline is a multilayer perceptron (MLP). Its structure is simple: input → hidden → output. The hidden layer uses rectified linear units (ReLU), and the output layer uses softmax for multi-class classification.

Figure: network structure of the baseline MLP.

Code:

```python
# coding: utf-8
# Baseline MLP for the MNIST dataset
# Written against the Keras 1.x API: Keras 2 renamed `init` to
# `kernel_initializer` and `nb_epoch` to `epochs`.
import numpy
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.utils import np_utils

seed = 7
numpy.random.seed(seed)

# Load the data
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Flatten each 28x28 image into a 784-dimensional vector
num_pixels = X_train.shape[1] * X_train.shape[2]
X_train = X_train.reshape(X_train.shape[0], num_pixels).astype('float32')
X_test = X_test.reshape(X_test.shape[0], num_pixels).astype('float32')

# Scale pixel values from 0-255 to 0-1
X_train = X_train / 255
X_test = X_test / 255

# One-hot encode the labels
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = y_test.shape[1]

# MLP model: one ReLU hidden layer with num_pixels units, softmax output
def baseline_model():
    model = Sequential()
    model.add(Dense(num_pixels, input_dim=num_pixels, init='normal', activation='relu'))
    model.add(Dense(num_classes, init='normal', activation='softmax'))
    model.summary()
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

# Build the model
model = baseline_model()

# Fit (the test set doubles as validation data here)
model.fit(X_train, y_train, validation_data=(X_test, y_test), nb_epoch=10, batch_size=200, verbose=2)

# Evaluation
scores = model.evaluate(X_test, y_test, verbose=0)
print("Baseline Error: %.2f%%" % (100 - scores[1] * 100))  # print the error rate
```

Result:

```
Layer (type)                     Output Shape          Param #     Connected to
==============================================================================================
dense_1 (Dense)                  (None, 784)           615440      dense_input_1[0][0]
______________________________________________________________________________________________
dense_2 (Dense)                  (None, 10)            7850        dense_1[0][0]
==============================================================================================
Total params: 623290
______________________________________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/10
3s - loss: 0.2791 - acc: 0.9203 - val_loss: 0.1420 - val_acc: 0.9579
Epoch 2/10
3s - loss: 0.1122 - acc: 0.9679 - val_loss: 0.0992 - val_acc: 0.9699
Epoch 3/10
3s - loss: 0.0724 - acc: 0.9790 - val_loss: 0.0784 - val_acc: 0.9745
Epoch 4/10
3s - loss: 0.0509 - acc: 0.9853 - val_loss: 0.0774 - val_acc: 0.9773
Epoch 5/10
3s - loss: 0.0366 - acc: 0.9898 - val_loss: 0.0626 - val_acc: 0.9794
Epoch 6/10
3s - loss: 0.0265 - acc: 0.9930 - val_loss: 0.0639 - val_acc: 0.9797
Epoch 7/10
3s - loss: 0.0185 - acc: 0.9956 - val_loss: 0.0611 - val_acc: 0.9811
Epoch 8/10
3s - loss: 0.0150 - acc: 0.9967 - val_loss: 0.0616 - val_acc: 0.9816
Epoch 9/10
4s - loss: 0.0107 - acc: 0.9980 - val_loss: 0.0604 - val_acc: 0.9821
Epoch 10/10
4s - loss: 0.0073 - acc: 0.9988 - val_loss: 0.0611 - val_acc: 0.9819
Baseline Error: 1.81%
```

The result is already quite good: 98.19% accuracy, i.e. an error rate of only 1.81%, after just 10 epochs. At this point I still didn't think of a CNN; instead I wondered whether training for 100 epochs would do better. So I trained for 100 epochs, with this result:

Epoch 100/100
8s - loss: 4.6181e-07 - acc: 1.0000 - val_loss: 0.0982 - val_acc: 0.9854
Baseline Error: 1.46%

Training for 100 epochs improved accuracy by only 0.35 percentage points and still did not break 99%. Meanwhile training accuracy reached 1.0000 and val_loss rose from 0.0611 to 0.0982, which suggests the MLP was overfitting rather than learning anything more. So I turned to a CNN.

2. A simple CNN

Keras's CNN support is quite complete. Since this post focuses on the CNN results, I won't cover CNN fundamentals here.

Figure: network structure of the simple CNN.

Code:

```python
# coding: utf-8
# Simple CNN
# Keras 1.x API: Convolution2D(32, 5, 5, border_mode='valid') is written
# Conv2D(32, (5, 5), padding='valid') in Keras 2.
import numpy
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers.convolutional import Convolution2D
from keras.layers.convolutional import MaxPooling2D
from keras.utils import np_utils

seed = 7
numpy.random.seed(seed)

# Load the data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Reshape to [samples][channels][height][width] (channels first)
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 1, 28, 28).astype('float32')

# Normalize inputs from 0-255 to 0-1
X_train = X_train / 255
X_test = X_test / 255

# One-hot encode outputs
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
num_classes = y_test.shape[1]

# Define a simple CNN model
def baseline_model():
    # Create the model
    model = Sequential()
    model.add(Convolution2D(32, 5, 5, border_mode='valid', input_shape=(1, 28, 28), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(128, act
```
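The parameter counts in the baseline summary (615440 + 7850 = 623290) and the feature-map sizes in the simple CNN both come from short arithmetic. The sketch below reproduces them in plain Python, without Keras. The helper names are hypothetical, and it assumes stride-1 'valid' convolution, non-overlapping 2x2 pooling (matching `border_mode='valid'` and `pool_size=(2, 2)`) and one bias per unit or filter, which Keras layers include by default. It stops at the 128-unit `Dense` layer because the listing is cut off there.

```python
# Layer arithmetic for the baseline MLP and the visible part of the simple CNN.
# Assumptions: stride-1 'valid' convolution, non-overlapping 2x2 max pooling,
# one bias per output unit / filter (Keras layers use a bias by default).

def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

def conv2d_params(n_filters, k_rows, k_cols, in_channels):
    """One k_rows x k_cols x in_channels kernel plus one bias per filter."""
    return n_filters * (k_rows * k_cols * in_channels) + n_filters

def valid_conv_side(side, kernel):
    """Side length after a stride-1 'valid' convolution (no padding)."""
    return side - kernel + 1

def pooled_side(side, pool):
    """Side length after non-overlapping max pooling (floor division)."""
    return side // pool

# Baseline MLP: 784 inputs -> 784 ReLU units -> 10 softmax units
num_pixels = 28 * 28
mlp_hidden = dense_params(num_pixels, num_pixels)
mlp_output = dense_params(num_pixels, 10)
print("MLP:", mlp_hidden, mlp_output, mlp_hidden + mlp_output)
# prints: MLP: 615440 7850 623290

# Simple CNN: 1x28x28 -> conv 32@5x5 -> max pool 2x2 -> flatten -> Dense(128)
side = valid_conv_side(28, 5)      # 24x24 feature maps
conv_params = conv2d_params(32, 5, 5, 1)
side = pooled_side(side, 2)        # 12x12 after pooling
flat = 32 * side * side            # length of the flattened vector
fc_params = dense_params(flat, 128)
print("CNN:", side, flat, conv_params, fc_params)
# prints: CNN: 12 4608 832 589952
```

The numbers show where the CNN's capacity sits: the 32 convolution filters need only 832 parameters, while the first dense layer after `Flatten` holds roughly 590k, about 700 times as many.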
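Both listings prepare the data the same way: reshape, cast to float32, divide by 255, and one-hot encode the labels with `np_utils.to_categorical`. A minimal NumPy sketch of those steps follows, run on synthetic stand-in images rather than the real MNIST download. The `preprocess` function and its `flatten` switch are hypothetical, and the one-hot step is a plain NumPy equivalent of the matrix `to_categorical` produces, not Keras's own code.

```python
import numpy as np

def preprocess(images, labels, num_classes=10, flatten=False):
    """Hypothetical helper mirroring the preprocessing in the listings.

    flatten=True  -> (n, 784) vectors for the MLP
    flatten=False -> (n, 1, 28, 28) channels-first tensors for the CNN
    """
    n = images.shape[0]
    if flatten:
        x = images.reshape(n, 28 * 28).astype("float32")
    else:
        x = images.reshape(n, 1, 28, 28).astype("float32")
    x /= 255.0  # scale pixel values from 0-255 to 0-1

    # One-hot encoding: row i gets a single 1 in column labels[i]
    y = np.zeros((n, num_classes), dtype="float32")
    y[np.arange(n), labels] = 1.0
    return x, y

# Synthetic stand-in for two MNIST images with labels 3 and 7
rng = np.random.default_rng(7)
fake_images = rng.integers(0, 256, size=(2, 28, 28), dtype=np.uint8)
fake_labels = np.array([3, 7])

x, y = preprocess(fake_images, fake_labels)
print(x.shape, x.dtype)  # prints: (2, 1, 28, 28) float32
print(y.argmax(axis=1))  # prints: [3 7]
```

Scaling to [0, 1] keeps the inputs in a small, uniform range, and the one-hot rows are the target format that `categorical_crossentropy` expects.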
Link: https://www.777doc.com/doc-5141636 .html