![](https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/14076/3ede223eb868435588437d18e31af318/bfa82193-a864-45cb-a26c-9c49f44e25a7.png?x-oss-process=image/resize,w_100,m_lfit)
![](https://cdn1.deepmd.net/bohrium/web/static/images/level-v2-4.png?x-oss-process=image/resize,w_50,m_lfit)
快速开始 TensorFlow2|面向初学者的快速入门
©️ Copyright 2023 @ Authors
作者:阙浩辉 📨
日期:2023-05-09
共享协议:本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可。
🎯 本教程旨在快速掌握使用 TensorFlow 建立神经网络机器学习模型的范式周期。
一键运行,你可以快速在实践中检验你的想法。
丰富完善的注释,对于入门者友好。
在 Bohrium Notebook 界面,你可以点击界面上方蓝色按钮 开始连接
,选择 bohrium-notebook
镜像及任何一款节点配置,稍等片刻即可运行。
目标
使用 TensorFlow 快速建立一个神经网络机器学习模型。
在学习本教程后,你将能够:
- 加载一个预构建的数据集。
- 构建对图像进行分类的神经网络机器学习模型。
- 训练此神经网络。
- 评估模型的准确率。
阅读该教程【最多】约需 5 分钟,让我们开始吧!
1.2 安装 TensorFlow
本教程是一个 Bohrium Notebook。Python 程序可直接在浏览器中运行,Bohrium 已安装 TensorFlow。这是学习和使用 TensorFlow 的好方法。
要按照本教程进行操作,请点击本页顶部的按钮,在 Bohrium Notebook 中运行本笔记本。
你可以点击界面上方蓝色按钮
开始连接
,选择bohrium-notebook
镜像及任何一款计算机型,稍等片刻即可运行。若要运行笔记本中的所有代码,请点击左上角“ 运行全部单元格 ”。若要一次运行一个代码单元,请选择需要运行的单元格,然后点击左上角 “运行选中的单元格” 图标。
如果你的 Bohrium 镜像尚未安装 TensorFlow, 最方便的方法是通过 pip 安装:
【注意】TensorFlow 2 仅支持 Ubuntu 和 Windows !
Requirement already satisfied: tensorflow in /opt/conda/lib/python3.8/site-packages (2.11.0) Requirement already satisfied: absl-py>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (1.4.0) Requirement already satisfied: libclang>=13.0.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (15.0.6.1) Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (0.30.0) Requirement already satisfied: google-pasta>=0.1.1 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (0.2.0) Requirement already satisfied: grpcio<2.0,>=1.24.3 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (1.51.3) Requirement already satisfied: numpy>=1.20 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (1.22.4) Requirement already satisfied: termcolor>=1.1.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (2.2.0) Requirement already satisfied: six>=1.12.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (1.15.0) Requirement already satisfied: opt-einsum>=2.3.2 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (3.3.0) Requirement already satisfied: keras<2.12,>=2.11.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (2.11.0) Requirement already satisfied: flatbuffers>=2.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (23.1.21) Requirement already satisfied: h5py>=2.9.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (3.1.0) Requirement already satisfied: gast<=0.4.0,>=0.2.1 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (0.4.0) Requirement already satisfied: packaging in /opt/conda/lib/python3.8/site-packages (from tensorflow) (23.0) Requirement already satisfied: protobuf<3.20,>=3.9.2 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (3.19.6) Requirement already satisfied: tensorflow-estimator<2.12,>=2.11.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (2.11.0) Requirement already satisfied: typing-extensions>=3.6.6 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (4.5.0) Requirement already satisfied: tensorboard<2.12,>=2.11 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (2.11.2) Requirement already satisfied: astunparse>=1.6.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (1.6.3) Requirement already satisfied: setuptools in /opt/conda/lib/python3.8/site-packages (from tensorflow) (64.0.2) Requirement already satisfied: wrapt>=1.11.0 in /opt/conda/lib/python3.8/site-packages (from tensorflow) (1.14.1) Requirement already satisfied: wheel<1.0,>=0.23.0 in /opt/conda/lib/python3.8/site-packages (from astunparse>=1.6.0->tensorflow) (0.37.1) Requirement already satisfied: requests<3,>=2.21.0 in /opt/conda/lib/python3.8/site-packages (from tensorboard<2.12,>=2.11->tensorflow) (2.28.2) Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /opt/conda/lib/python3.8/site-packages (from tensorboard<2.12,>=2.11->tensorflow) (0.6.1) Requirement already satisfied: werkzeug>=1.0.1 in /opt/conda/lib/python3.8/site-packages (from tensorboard<2.12,>=2.11->tensorflow) (1.0.1) Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /opt/conda/lib/python3.8/site-packages (from tensorboard<2.12,>=2.11->tensorflow) (0.4.6) Requirement already satisfied: google-auth<3,>=1.6.3 in /opt/conda/lib/python3.8/site-packages (from tensorboard<2.12,>=2.11->tensorflow) (2.16.1) Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /opt/conda/lib/python3.8/site-packages (from tensorboard<2.12,>=2.11->tensorflow) (1.8.1) Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.8/site-packages (from tensorboard<2.12,>=2.11->tensorflow) (3.4.1) Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow) (0.2.8) Requirement already satisfied: rsa<5,>=3.1.4 in /opt/conda/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow) (4.9) Requirement already satisfied: cachetools<6.0,>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow) (5.3.0) Requirement already satisfied: requests-oauthlib>=0.7.0 in /opt/conda/lib/python3.8/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.12,>=2.11->tensorflow) (1.3.1) Requirement already satisfied: importlib-metadata>=4.4 in /opt/conda/lib/python3.8/site-packages (from markdown>=2.6.8->tensorboard<2.12,>=2.11->tensorflow) (6.0.0) Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow) (2022.12.7) Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow) (2.10) Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow) (1.26.14) Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.8/site-packages (from requests<3,>=2.21.0->tensorboard<2.12,>=2.11->tensorflow) (3.0.1) Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.8/site-packages (from importlib-metadata>=4.4->markdown>=2.6.8->tensorboard<2.12,>=2.11->tensorflow) (3.15.0) Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /opt/conda/lib/python3.8/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.12,>=2.11->tensorflow) (0.4.8) Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.8/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.12,>=2.11->tensorflow) (3.2.2) WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
如果你需要使用更特定于你的平台或包管理器的安装方法,你可以在这里查看更完整的安装说明。
TensorFlow 版本: 2.11.0
TensorFlow 版本: 2.11.0
2.1 加载数据集
加载并准备 MNIST 数据集。将样本数据从整数转换为浮点数:
array([[ 0.28472176, -0.03656976, -0.5568978 , 0.0447401 , -0.04145458, -0.14942075, 0.26909587, 0.25949594, -0.04086691, 0.34572643]], dtype=float32)
tf.nn.softmax
函数将这些 logits 转换为每个类的概率:
array([[0.12411637, 0.09001065, 0.05349563, 0.09763518, 0.08957204, 0.08040505, 0.122192 , 0.12102458, 0.0896247 , 0.13192376]], dtype=float32)
注:可以将 tf.nn.softmax
烘焙到网络最后一层的激活函数中。虽然这可以使模型输出更易解释,但不建议使用这种方式,因为在使用 softmax 输出时不可能为所有模型提供精确且数值稳定的损失计算。
使用 losses.SparseCategoricalCrossentropy
为训练定义损失函数,它会接受 logits 向量和 True
索引,并为每个样本返回一个标量损失。
此损失等于 true 类的负对数概率:如果模型确定类正确,则损失为零。
这个未经训练的模型给出的概率接近随机(每个类为 1/10),因此初始损失应该接近 -tf.math.log(1/10) ~= 2.3
。
2.5206783
在开始训练之前,使用 Keras Model.compile
配置和编译模型。将 optimizer
类设置为 adam
,将 loss
设置为您之前定义的 loss_fn
函数,并通过将 metrics
参数设置为 accuracy
来指定要为模型评估的指标。
Epoch 1/5 1875/1875 [==============================] - 6s 3ms/step - loss: 0.3014 - accuracy: 0.9123 Epoch 2/5 1875/1875 [==============================] - 6s 3ms/step - loss: 0.1438 - accuracy: 0.9579 Epoch 3/5 1875/1875 [==============================] - 6s 3ms/step - loss: 0.1062 - accuracy: 0.9681 Epoch 4/5 1875/1875 [==============================] - 6s 3ms/step - loss: 0.0848 - accuracy: 0.9739 Epoch 5/5 1875/1875 [==============================] - 6s 3ms/step - loss: 0.0745 - accuracy: 0.9767
<keras.callbacks.History at 0x7f70f2a6c2e0>
Model.evaluate
方法通常在 "Validation-set" 或 "Test-set" 上检查模型性能。
313/313 - 0s - loss: 0.0734 - accuracy: 0.9786 - 458ms/epoch - 1ms/step
[0.07335715740919113, 0.978600025177002]
现在,这个照片分类器的准确度已经达到 98%。想要了解更多,请阅读 TensorFlow 教程。
如果您想让模型返回概率,可以封装经过训练的模型,并将 softmax 附加到该模型:
<tf.Tensor: shape=(5, 10), dtype=float32, numpy= array([[1.08303681e-07, 6.54607124e-09, 5.52915935e-06, 3.82668914e-05, 2.15424262e-10, 1.91471074e-07, 1.79550073e-13, 9.99944329e-01, 7.39323582e-07, 1.08608301e-05], [1.18504238e-06, 2.17750960e-04, 9.98743713e-01, 8.70145741e-04, 5.40504800e-16, 1.63173332e-04, 1.18685925e-07, 4.76410404e-11, 3.88029775e-06, 7.46752920e-12], [1.45594271e-07, 9.97109830e-01, 6.52094022e-04, 3.42531675e-05, 2.32106504e-05, 2.75247676e-05, 3.74190495e-05, 1.88533065e-03, 2.27929631e-04, 2.34026993e-06], [9.99918818e-01, 3.94812805e-09, 2.14232896e-05, 1.83060642e-08, 3.61103304e-07, 5.01507458e-08, 8.47206138e-06, 5.02170624e-05, 7.43030359e-11, 5.72064039e-07], [7.77674018e-07, 1.13904397e-09, 8.17028558e-06, 5.48501355e-10, 9.99079227e-01, 4.27116431e-09, 2.78238957e-07, 2.03034971e-04, 5.55814040e-07, 7.07846368e-04]], dtype=float32)>
![](https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/14076/3ede223eb868435588437d18e31af318/bfa82193-a864-45cb-a26c-9c49f44e25a7.png?x-oss-process=image/resize,w_100,m_lfit)
![](https://cdn1.deepmd.net/bohrium/web/static/images/level-v2-4.png?x-oss-process=image/resize,w_50,m_lfit)
![](https://cdn1.deepmd.net/static/img/d7d9741bda38a158-957c-4877-942f-4bf6f81fcc63.png?x-oss-process=image/resize,w_100,m_lfit)
![](https://cdn1.deepmd.net/bohrium/web/static/images/level-v2-1.png?x-oss-process=image/resize,w_50,m_lfit)
![](https://bohrium.oss-cn-zhangjiakou.aliyuncs.com/article/14076/3ede223eb868435588437d18e31af318/bfa82193-a864-45cb-a26c-9c49f44e25a7.png?x-oss-process=image/resize,w_100,m_lfit)
![](https://cdn1.deepmd.net/bohrium/web/static/images/level-v2-4.png?x-oss-process=image/resize,w_50,m_lfit)
![](https://cdn1.deepmd.net/static/img/d7d9741bda38a158-957c-4877-942f-4bf6f81fcc63.png?x-oss-process=image/resize,w_100,m_lfit)
![](https://cdn1.deepmd.net/bohrium/web/static/images/level-v2-1.png?x-oss-process=image/resize,w_50,m_lfit)