博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Keras手写识别例子(1)----softmax
阅读量:4962 次
发布时间:2019-06-12

本文共 2691 字,大约阅读时间需要 8 分钟。

转自:

下载数据

# download the mnist to the path '~/.keras/datasets/' if it is the first time to be called

# X shape (60,000 28x28), y shape (10,000, )
(X_train, y_train), (X_test, y_test) = mnist.load_data()

data预处理:

X_train = X_train.reshape(X_train.shape[0], -1) / 255.   # normalize

X_test = X_test.reshape(X_test.shape[0], -1) / 255.      # normalize
y_train = np_utils.to_categorical(y_train, num_classes=10)
y_test = np_utils.to_categorical(y_test, num_classes=10)

 

导入包:

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("./", one_hot=True)
X_train=mnist.train.images
Y_train=mnist.train.labels
X_test=mnist.test.images
Y_test=mnist.test.labels

因为(X_train, y_train), (X_test, y_test) = mnist.load_data()需从网上下载数据,由于网络限制,下载失败。

可以先在官网上下载四个数据(

在当前目录,不要解压!

#input_data.py该模块在tensorflow.examples.tutorials.mnist下,直接加载来读取上面四个压缩包。

#四个压缩包形式为特殊形式。非图片和标签,要解析。

from tensorflow.examples.tutorials.mnist import input_data

#加载数据路径为"./",为当前路径,自动加载数据,用one-hot方式处理好数据。

#read_data_sets是input_data.py里面的一个函数,主要是将数据解压之后,放到对应的位置。 第一个参数为路径,写"./"表示当前路径,其会判断该路径下有没有数据,没有的话会自动下载数据。

mnist = input_data.read_data_sets("./", one_hot=True)  

 

相关的包:

model.Sequential():用来一层一层的去建立神经层。

layers.Dense,表示这个神经层是全连接层。

layers.Activation,激励函数

optimizers.RMSprop,优化器采用RMSprop,加速神经网络训练方法。

Keras工作流程:

  1. 定义训练数据:输入张量和目标张量
  2. 定义层组成的网络(或模型),将输入映射到目标
  3. 配置学习过程:选择损失函数、优化器和需要监控的指标
  4. 调用模型的fit方法在训练数据上进行迭代

代码:

import numpy as npnp.random.seed(1337)  # for reproducibilityfrom keras.datasets import mnist from keras.models import Sequentialfrom keras.layers import Dense, Activationfrom keras.optimizers import RMSprop#读取数据,其中,X_train为55000*784,Y_train为55000*10,X_test为10000*784,Y_test大小为10000*10.from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("./", one_hot=True)X_train=mnist.train.imagesY_train=mnist.train.labelsX_test=mnist.test.imagesY_test=mnist.test.labels #建立神经网络模型,一共两层,第一层输入784个变量,输出为32,激活函数为relu,第二层输入是上层的输出32,输出为10,激活函数为softmax。model = Sequential([    Dense(32, input_dim=784),    Activation('relu'),    Dense(10),    Activation('softmax'),])#采用RMSprop来求解模型,设学习率lr为0.001,以及别的参数。rmsprop = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08, decay=0.0)#激活模型,优化器为rmsprop,损失函数为交叉熵,metric,里面可以放入需要计算的,比如cost、accuracy、score等model.compile(optimizer=rmsprop,              loss='categorical_crossentropy',              metrics=['accuracy'])#训练网络,用fit函数,导入数据,训练次数为20,每批处理32个model.fit(X_train, Y_train, nb_epoch=20, batch_size=32)#测试模型print('\nTesting ------------')# Evaluate the model with the metrics we defined earlierloss, accuracy = model.evaluate(X_test, Y_test)print('test loss: ', loss)print('test accuracy: ', accuracy)

 结果:

 

 

 

转载于:https://www.cnblogs.com/Lee-yl/p/8572443.html

你可能感兴趣的文章
Java中的return this
查看>>
Java调用淘宝API demo源代码
查看>>
跟我学Android之十一 列表和适配器
查看>>
JS学习
查看>>
Dreamweaver安装须知
查看>>
SQL Server 2008 Database Mirror - DB 镜像 - Some Key Learnings
查看>>
iPhone开机键坏了如何开机
查看>>
从C 到 OC----从面向过程到面向对象的转变
查看>>
Code analysis 笔记
查看>>
SSMS 远程连接SERVER 设置 - Unable to connect to SQL Server instance remotely
查看>>
jqGrid 自定义搜索
查看>>
结对开发地铁
查看>>
linux(centos)设置tomcat开机启动
查看>>
操作使用的常见的问题集合 http://bbs.ecshop.com/thread-95341-1-1.html
查看>>
BZOJ 2467 生成树(组合数学)
查看>>
dedecms关键词维护里面字数多的词优先字数少的词的解决办法 相关案例演示
查看>>
eclipse和android studio的目录结构分析
查看>>
我的第一个canvas的作品:漫画对白编辑器
查看>>
NYOJ题目100 1的个数
查看>>
用字符串连接SQL语句并用EXEC执行时,出现名称 '‘不是有效的标识符
查看>>