Python实现生成对抗网络:生成逼真数据
        
        作者:weixin02 · 2025-01-09 · 阅读时间:5分钟
    
    
    
     
    
        
    
    
        
        
    
     Python实现GAN (生成对抗网络) – 从0到1的深度学习之旅
嘿,小伙伴们!今天咱们要玩一个有趣的项目 – 用Python实现GAN网络。这个项目会帮你理解如何训练AI来生成超逼真的数据。我们会用MNIST手写数字数据集来演示,让AI学会画数字!
第一步:导入必要的包
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
# 设置随机种子,确保结果可复现
torch.manual_seed(42)第二步:准备数据集
# 数据预处理和加载
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载MNIST数据集
train_dataset = torchvision.datasets.MNIST(
root='./data',
train=True,
transform=transform,
download=True
)
# 创建数据加载器
train_loader = DataLoader(
dataset=train_dataset,
batch_size=64,
shuffle=True
)第三步:构建生成器网络
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# 生成器结构:用简单的全连接层
self.gen = nn.Sequential(
# 输入是随机噪声(latent_dim)
nn.Linear(100, 256),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(256),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(512),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.BatchNorm1d(1024),
# 输出层,生成28*28=784维的图像
nn.Linear(1024, 784),
nn.Tanh()
)
def forward(self, x):
return self.gen(x)第四步:构建判别器网络
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
# 判别器结构
self.disc = nn.Sequential(
# 输入是展平的图像(784维)
nn.Linear(784, 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
# 输出一个概率值
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.disc(x)第五步:训练模型
# 初始化模型
generator = Generator()
discriminator = Discriminator()
# 损失函数和优化器
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# 训练参数
num_epochs = 100
latent_dim = 100
fixed_noise = torch.randn(16, latent_dim)  # 用于可视化
# 训练循环
for epoch in range(num_epochs):
for i, (real_images, _) in enumerate(train_loader):
batch_size = real_images.size(0)
# 准备真实和虚假的标签
real_label = torch.ones(batch_size, 1)
fake_label = torch.zeros(batch_size, 1)
# 展平图像
real_images = real_images.view(-1, 784)
# 训练判别器
d_optimizer.zero_grad()
output_real = discriminator(real_images)
d_loss_real = criterion(output_real, real_label)
# 生成假图像
noise = torch.randn(batch_size, latent_dim)
fake_images = generator(noise)
output_fake = discriminator(fake_images.detach())
d_loss_fake = criterion(output_fake, fake_label)
# 计算判别器总损失
d_loss = d_loss_real + d_loss_fake
d_loss.backward()
d_optimizer.step()
# 训练生成器
g_optimizer.zero_grad()
output_fake = discriminator(fake_images)
g_loss = criterion(output_fake, real_label)
g_loss.backward()
g_optimizer.step()
if i % 100 == 0:
print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{len(train_loader)}], '
f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')第六步:生成和显示结果
def show_images(images):
plt.figure(figsize=(4, 4))
plt.axis(“off”)
plt.imshow(np.transpose(torchvision.utils.make_grid(
images.reshape(-1, 1, 28, 28), nrow=4, padding=2, normalize=True
).cpu(), (1, 2, 0)))
plt.show()
# 生成示例图像
with torch.no_grad():
fake_images = generator(fixed_noise)
show_images(fake_images)小贴士:
- 增加训练稳定性:
* 使用标签平滑化技术
* 添加梯度裁剪
* 尝试不同的激活函数
- 提高生成质量:
* 增加网络层数和通道数
* 使用更复杂的网络结构(如卷积层)
* 调整学习率和批次大小
- 避免模式崩溃:
* 使用Wasserstein GAN
* 实现批次归一化
* 适当调整判别器和生成器的训练比例
这个GAN实现虽然简单,但已经能产生不错的效果了。记得调参是提升效果的关键,慢慢调整,你一定能生成出更逼真的图像!
文章转自微信公众号@寒江映孤月
热门推荐
        一个账号试用1000+ API
            助力AI无缝链接物理世界 · 无需多次注册
            
        3000+提示词助力AI大模型
            和专业工程师共享工作效率翻倍的秘密
            
        热门API
- 1. AI文本生成
- 2. AI图片生成_文生图
- 3. AI图片生成_图生图
- 4. AI图像编辑
- 5. AI视频生成_文生视频
- 6. AI视频生成_图生视频
- 7. AI语音合成_文生语音
- 8. AI文本生成(中国)
最新文章
- 9个最佳Text2Sql开源项目:自然语言到SQL的高效转换工具
- 深入解析API网关策略:认证、授权、安全、流量处理与可观测性
- GraphQL API手册:如何构建、测试、使用和记录
- 自助式入职培训服务API:如何让企业管理更上一层楼?
- Python如何调用Jenkins API自动化发布
- 模型压缩四剑客:量化、剪枝、蒸馏、二值化
- 火山引擎如何接入API:从入门到实践的技术指南
- 为什么每个使用 API 的大型企业都需要一个 API 市场来增强其合作伙伴生态系统
- 构建更优质的API:2025年顶级API开发工具推荐 – Strapi
- 外部函数与内存API – Java 22 – 未记录
- FAPI 2.0 深度解析:下一代金融级 API 安全标准与实践指南
- .NET Core 下的 API 网关