使用Keras API入门TensorFlow | 作者:Anurag Dhadse

作者:API传播员 · 2025-11-27 · 阅读时间:5分钟
本文介绍如何使用Keras API入门TensorFlow,涵盖Sequential API构建简单模型、Functional API处理复杂模型如Wide & Deep,以及子类化API用于动态结构,帮助开发者根据任务需求高效构建和训练神经网络。

使用 Keras API 入门 TensorFlow

Keras 是一个高级深度学习 API(应用编程接口),它能够帮助开发者轻松地构建、训练、评估和部署各种神经网络。Keras 的核心功能是对底层深度学习框架(如 TensorFlow、Microsoft Cognitive Toolkit(CNTK)和 Theano)的实现进行抽象,提供统一的接口。

TensorFlow 是一个功能强大的深度学习库,除了支持数值计算和大规模机器学习外,还提供了许多实用工具,例如用于模型可视化的 TensorBoard,以及用于生产环境的 TensorFlow 扩展(TFX)等。


使用 TensorFlow 的 Keras 实现

TensorFlow 提供了自己的 Keras 实现,允许我们在其框架内构建神经网络。以下是使用 Keras 的基本步骤:

  1. 导入 Keras 模块。
  2. 根据模型的复杂程度选择合适的 API。

Keras 提供了三种不同的 API,用于创建神经网络:

  • Sequential API:适用于简单的线性堆叠模型。
  • Functional API:适用于复杂的非线性模型。
  • 子类化 API:适用于动态模型结构。

使用 Sequential API 构建简单模型

Sequential API 是最简单的 Keras API,适用于线性堆叠的神经网络。以下是一个使用 MNIST 数据集构建分类模型的示例:

模型结构

  1. Flatten 层:将输入的二维图像(28×28 像素)展平为一维数组(784 个特征)。
  2. Dense 层:堆叠两个全连接层,激活函数为 ReLU。
  3. 输出层:使用带有 Softmax 激活的 Dense 层,用于多分类任务。

以下是代码示例:

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Flatten, Dense

model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(300, activation='relu'),
    Dense(100, activation='relu'),
    Dense(10, activation='softmax')
])

模型摘要

调用 model.summary() 方法可以查看模型的层次结构和参数数量。例如:

  • Flatten 层将输入展平为 784 个神经元。
  • 第一层 Dense 层有 300 个神经元,总参数数为 300x784 + 300 = 235500
  • 其他层的参数计算方式类似。

编译和训练

在编译模型时,我们需要指定损失函数、优化器和评估指标:

model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

然后,使用 fit() 方法进行训练:

history = model.fit(X_train, y_train, epochs=10, validation_data=(X_val, y_val))

使用 Functional API 构建复杂模型

Functional API 适用于具有多个输入、输出或复杂连接的模型。例如,Wide & Deep 模型同时学习深层模式和简单规则。

示例:加州房价预测

以下是使用 Functional API 构建 Wide & Deep 模型的步骤:

  1. 创建输入层:

    from tensorflow.keras.layers import Input, Dense, concatenate
    input_ = Input(shape=(8,))
  2. 构建隐藏层:

    hidden1 = Dense(30, activation='relu')(input_)
    hidden2 = Dense(30, activation='relu')(hidden1)
  3. 合并输入和隐藏层:

    concat = concatenate([input_, hidden2])
  4. 添加输出层并创建模型:

    output = Dense(1)(concat)
    from tensorflow.keras.models import Model
    model = Model(inputs=[input_], outputs=[output])

多输入和多输出

Functional API 还支持多输入、多输出模型。例如:

  • 使用字典传递多个输入:

    model.fit({'input1': X1, 'input2': X2}, y, epochs=10)
  • 为每个输出指定损失函数和权重:

    model.compile(
      loss=['mse', 'mae'],
      loss_weights=[0.8, 0.2],
      optimizer='adam'
    )

使用子类化 API 构建动态模型

对于需要循环、条件分支或动态结构的模型,可以使用子类化 API。这种方法需要继承 tf.keras.Model 类,并在构造函数中定义层,在 call() 方法中定义前向传播逻辑。

以下是 Wide & Deep 模型的子类化实现:

import tensorflow as tf

class WideAndDeepModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.hidden1 = Dense(30, activation='relu')
        self.hidden2 = Dense(30, activation='relu')
        self.output_layer = Dense(1)

    def call(self, inputs):
        input_, hidden = inputs
        hidden1 = self.hidden1(input_)
        hidden2 = self.hidden2(hidden1)
        concat = tf.concat([hidden, hidden2], axis=-1)
        return self.output_layer(concat)

尽管子类化 API 提供了更大的灵活性,但它也有一些缺点,例如无法使用 model.summary() 检查模型结构,且容易出错。因此,除非需要动态模型,否则建议优先使用 Sequential 或 Functional API。


总结

Keras 提供了多种 API,适用于不同复杂度的深度学习任务:

  • Sequential API:适合简单模型。
  • Functional API:适合复杂模型,例如多输入、多输出。
  • 子类化 API:适合动态模型,但使用时需谨慎。

根据任务需求选择合适的 API,可以更高效地构建和训练神经网络。

原文链接: https://medium.com/analytics-vidhya/introduction-to-tensorflow-with-keras-api-36cbeeb562d5