使用Keras API入门TensorFlow | 作者:Anurag Dhadse
文章目录
使用 Keras API 入门 TensorFlow
Keras 是一个高级深度学习 API(应用编程接口),它能够帮助开发者轻松地构建、训练、评估和部署各种神经网络。Keras 的核心功能是对底层深度学习框架(如 TensorFlow、Microsoft Cognitive Toolkit(CNTK)和 Theano)的实现进行抽象,提供统一的接口。
TensorFlow 是一个功能强大的深度学习库,除了支持数值计算和大规模机器学习外,还提供了许多实用工具,例如用于模型可视化的 TensorBoard,以及用于生产环境的 TensorFlow 扩展(TFX)等。
使用 TensorFlow 的 Keras 实现
TensorFlow 提供了自己的 Keras 实现,允许我们在其框架内构建神经网络。以下是使用 Keras 的基本步骤:
- 导入 Keras 模块。
- 根据模型的复杂程度选择合适的 API。
Keras 提供了三种不同的 API,用于创建神经网络:
- Sequential API:适用于简单的线性堆叠模型。
- Functional API:适用于复杂的非线性模型。
- 子类化 API:适用于动态模型结构。
使用 Sequential API 构建简单模型
Sequential API 是最简单的 Keras API,适用于线性堆叠的神经网络。以下是一个使用 MNIST 数据集构建分类模型的示例:
模型结构
- Flatten 层:将输入的二维图像(28×28 像素)展平为一维数组(784 个特征)。
- Dense 层:堆叠两个全连接层,激活函数为 ReLU。
- 输出层:使用带有 Softmax 激活的 Dense 层,用于多分类任务。
以下是代码示例:
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Flatten, Dense
model = Sequential([
Flatten(input_shape=(28, 28)),
Dense(300, activation='relu'),
Dense(100, activation='relu'),
Dense(10, activation='softmax')
])
模型摘要
调用 model.summary() 方法可以查看模型的层次结构和参数数量。例如:
- Flatten 层将输入展平为 784 个神经元。
- 第一层 Dense 层有 300 个神经元,总参数数为
300x784 + 300 = 235500。 - 其他层的参数计算方式类似。
编译和训练
在编译模型时,我们需要指定损失函数、优化器和评估指标:
model.compile(
loss='sparse_categorical_crossentropy',
optimizer='adam',
metrics=['accuracy']
)
然后,使用 fit() 方法进行训练:
history = model.fit(X_train, y_train, epochs=10, validation_data=(X_val, y_val))
使用 Functional API 构建复杂模型
Functional API 适用于具有多个输入、输出或复杂连接的模型。例如,Wide & Deep 模型同时学习深层模式和简单规则。
示例:加州房价预测
以下是使用 Functional API 构建 Wide & Deep 模型的步骤:
-
创建输入层:
from tensorflow.keras.layers import Input, Dense, concatenate input_ = Input(shape=(8,)) -
构建隐藏层:
hidden1 = Dense(30, activation='relu')(input_) hidden2 = Dense(30, activation='relu')(hidden1) -
合并输入和隐藏层:
concat = concatenate([input_, hidden2]) -
添加输出层并创建模型:
output = Dense(1)(concat) from tensorflow.keras.models import Model model = Model(inputs=[input_], outputs=[output])
多输入和多输出
Functional API 还支持多输入、多输出模型。例如:
-
使用字典传递多个输入:
model.fit({'input1': X1, 'input2': X2}, y, epochs=10) -
为每个输出指定损失函数和权重:
model.compile( loss=['mse', 'mae'], loss_weights=[0.8, 0.2], optimizer='adam' )
使用子类化 API 构建动态模型
对于需要循环、条件分支或动态结构的模型,可以使用子类化 API。这种方法需要继承 tf.keras.Model 类,并在构造函数中定义层,在 call() 方法中定义前向传播逻辑。
以下是 Wide & Deep 模型的子类化实现:
import tensorflow as tf
class WideAndDeepModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.hidden1 = Dense(30, activation='relu')
self.hidden2 = Dense(30, activation='relu')
self.output_layer = Dense(1)
def call(self, inputs):
input_, hidden = inputs
hidden1 = self.hidden1(input_)
hidden2 = self.hidden2(hidden1)
concat = tf.concat([hidden, hidden2], axis=-1)
return self.output_layer(concat)
尽管子类化 API 提供了更大的灵活性,但它也有一些缺点,例如无法使用 model.summary() 检查模型结构,且容易出错。因此,除非需要动态模型,否则建议优先使用 Sequential 或 Functional API。
总结
Keras 提供了多种 API,适用于不同复杂度的深度学习任务:
- Sequential API:适合简单模型。
- Functional API:适合复杂模型,例如多输入、多输出。
- 子类化 API:适合动态模型,但使用时需谨慎。
根据任务需求选择合适的 API,可以更高效地构建和训练神经网络。
原文链接: https://medium.com/analytics-vidhya/introduction-to-tensorflow-with-keras-api-36cbeeb562d5
最新文章
- 用 Poe-API-wrapper 连接 DALLE、ChatGPT,批量完成AI绘图或文字创作
- 2025年20大自动化API测试工具 – HeadSpin
- RESTful Web API 设计中要避免的 6 个常见错误
- LangGraph 工具详解:构建 AI 多步骤流程的关键利器
- GitHubAPI调用频率限制的增加方法
- 如何使用Route Optimization API优化配送路线
- 什么是聚类分析?
- 安全好用的OpenApi
- 医疗数据管理与fhir api的未来发展趋势
- 为什么要使用Google My Business Reviews API
- 2025年7月第2周GitHub热门API推荐:rustfs/rustfs、pocketbase/pocketbase、smallcloudai/refact
- API设计的首要原则