基于Tensorflow的最基本GAN网络模型

这篇具有很好参考价值的文章主要介绍了基于Tensorflow的最基本GAN网络模型。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import glob
import os
#(1)创建输入管道
# 导入原始数据
(train_images, train_labels),(_, _) = tf.keras.datasets.mnist.load_data()
# 查看原始数据大小与数据格式
# 60000张图片,每一张图片都是28*28像素
# print(train_images.shape)
# dtype('uint8'),每一位的范围都是0-255的整数,由于图像的一个通道中rgb颜色值就是0-255不等,因此uint8就是图像的标准数字格式
# print(train_images.dtype)

#(1.1)数据预处理
# 转换数据类型
train_images = train_images.reshape(train_images.shape[0], 28,28,1)
train_images = train_images.astype('float32')

# 归一化0-255>>[-1,1]
train_images = (train_images - 127.5)/127.5

#(1.2)确定训练时的BATCH_SIZE与BUFFER_SIZE
BATCH_SIZE = 256 # 每一个batch指一次训练,batch_size表示一次训练所需的数据个数。这里一次训练需要256张图片
BUFFER_SIZE = 60000 # 目前不知道buffer是干什么的

#(1.3)将归一化后的图像转化为tf内置的一种数据形式
datasets = tf.data.Dataset.from_tensor_slices(train_images)

#(1.4)将训练模型的数据集进行打乱的操作:shuffle
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
#(2)生成器模型
def Generator_Model():
    model = keras.Sequential() # 顺序模型
    # dense 全连接层
    # 输入:长度为100的随机数向量(自己定义)
    # 输出:长度为256的向量
    model.add(layers.Dense(256, input_shape = (100,), use_bias = False))
    model.add(layers.BatchNormalization()) # 归一化层
    model.add(layers.LeakyReLU()) # 激活层
    
    model.add(layers.Dense(512, use_bias = False))
    model.add(layers.BatchNormalization()) # 归一化层
    model.add(layers.LeakyReLU()) # 激活层
    
    model.add(layers.Dense(28*28*1, use_bias = False, activation = 'tanh'))
    model.add(layers.BatchNormalization()) # 归一化层
    
    model.add(layers.Reshape((28,28,1))) # 写为元组的形式
    
    return model
#(3)判别器模型
def Discriminator_Model():
    model = keras.Sequential()
    
    model.add(layers.Flatten()) # 将3维图像拉伸为一维图像
    
    model.add(layers.Dense(512, use_bias = False))
    model.add(layers.BatchNormalization()) # 归一化层
    model.add(layers.LeakyReLU()) # 激活层
    
    model.add(layers.Dense(256, use_bias = False))
    model.add(layers.BatchNormalization()) # 归一化层
    model.add(layers.LeakyReLU()) # 激活层
    
    model.add(layers.Dense(1)) # 输出1或者0
    
    return model
    
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits = True)

#(4)判别器的损失函数:对于真是图片,判定为1;对于生成图片,判定为0
def discriminator_loss(real_out, fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out),fake_out)
    return real_loss+fake_loss

#(5)生成器损失函数:对于生成图片,判定为1
def generator_loss(fake_out):
    fake_loss = cross_entropy(tf.ones_like(fake_out),fake_out)
    return fake_loss
#(6)创建判别器和生成器的优化器,定义参数的学习速率
generator_opt = tf.keras.optimizers.Adam(1e-4)
discriminator_opt = tf.keras.optimizers.Adam(1e-4)
EPOCHS = 100
noise_dim = 100
num_exp_to_generate = 16
seed = tf.random.normal([num_exp_to_generate, noise_dim])

# 实例化生成器与判别器
Generator = Generator_Model()
Discriminator = Discriminator_Model()
#(7)训练GAN网络
# 每一个batch
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        real_output = Discriminator(images, training = True)
        gen_image = Generator(noise, training = True)
        fake_output = Discriminator(gen_image, training = True)
        
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
        
    #优化
    gradient_gen = gen_tape.gradient(gen_loss, Generator.trainable_variables)
    gradient_disc = disc_tape.gradient(disc_loss, Discriminator.trainable_variables)
    generator_opt.apply_gradients(zip(gradient_gen, Generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc, Discriminator.trainable_variables))
# 可视化函数
def generator_plt_img(gen_model, test_noise):

    pre_images = gen_model(test_noise, training = False)
    fig = plt.figure(figsize=(4, 4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1)
        plt.imshow((pre_images[i,:,:,0]+1)/2, cmap = 'gray')
        plt.axis('off')
    plt.show()
# 完整的训练模型的函数
def train(dataset, epochs):
    for epoch in range(epochs):
        for img_batch in dataset:
            train_step(img_batch)
            print('.',end='')
        generator_plt_img(Generator, seed)
# 训练模型
train(datasets, EPOCHS)

视频链接:https://www.bilibili.com/video/BV1f7411E7wU/?spm_id_from=333.999.0.0文章来源地址https://www.toymoban.com/news/detail-412721.html

到了这里,关于基于Tensorflow的最基本GAN网络模型的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • Stable Diffusion模型基于 TensorFlow 或 PyTorch 训练

    安装必要的软件和库 : 安装 Python(建议使用 Python 3.x 版本)。 安装 TensorFlow 或 PyTorch,具体版本取决于你的模型是基于哪个框架训练的。 安装其他可能需要的依赖,如 NumPy、Matplotlib 等。 获取模型代码和权重 : 下载或克隆 Stable Diffusion 的代码仓库(如果可用)。 下载预训

    2024年04月28日
    浏览(42)
  • TensorFlow学习之:了解和实践卷积神经网络和序列模型

    学习CNN的结构和原理,了解如何用TensorFlow实现CNN。 卷积神经网络(Convolutional Neural Networks,CNN)是深度学习中的一种强大的模型架构,特别适合于处理图像数据。CNN通过使用卷积层自动地从图像中学习空间层级的特征,这使得它们在图像分类、物体检测、图像分割等计算机视

    2024年04月17日
    浏览(32)
  • 鸟类识别系统python+TensorFlow+Django网页界面+卷积网络算法+深度学习模型

    鸟类识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Django框架,开发网页端操作平台,实现用户上传一张图片识别其名称。 视频+代码:https://www.yuque.com/ziwu/

    2024年02月16日
    浏览(39)
  • 基于包围框回归的目标检测网络原理及Tensorflow实现

    对象检测是对图像内的对象进行分类和定位。 换句话说,它是图像分类和对象定位的结合。 构建用于图像分类的机器学习模型更简单,我在我的一篇文章中对此进行了描述。 然而,图像分类器无法准确判断对象在图像内的位置。 为了实现这一目标,我们需要构建一个神经网

    2024年02月16日
    浏览(43)
  • 水果识别系统Python,基于TensorFlow卷积神经网络算法

    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 提示 面对水果识别系统Python,基于TensorFlow卷积神经网络算---深度学习算法: 提示:以下是本篇文章正文内容,下面案例可供参考 果蔬识别系统,使用Python作为主要开发语言,使用深度学习 TensorFLOw框架

    2024年01月16日
    浏览(75)
  • CNN卷积神经网络实现手写数字识别(基于tensorflow)

    卷积网络的 核心思想 是将: 局部感受野 权值共享(或者权值复制) 时间或空间亚采样 卷积神经网络 (Convolutional Neural Networks,简称: CNN )是深度学习当中一个非常重要的神经网络结构。它主要用于用在 图像图片处理 , 视频处理 , 音频处理 以及 自然语言处理 等等。

    2024年02月11日
    浏览(41)
  • 鸟类识别Python,基于TensorFlow卷积神经网络【实战项目】

    鸟类识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Django框架,开发网页端操作平台,实现用户上传一张图片识别其名称。 数据集选自加州理工学院200种鸟类

    2024年02月10日
    浏览(48)
  • FPGA上利用Vitis AI部署resnet50 TensorFlow神经网络模型

    参考Xilinx官方教程快速入门 • Vitis AI 用户指南 (UG1414) 克隆 Vitis AI 存储库以获取示例、参考代码和脚本(连接github失败可能需要科学上网)。 安装Docker如何在 Ubuntu 20.04 上安装和使用 Docker 安装完docker后,下载最新Vitis AI Docker, 将官方的指令 docker pull xilinx/vitis-ai-pytorch/tensorfl

    2024年02月04日
    浏览(44)
  • 【深度学习】基于卷积神经网络(tensorflow)的人脸识别项目(一)

    ​ 活动地址:CSDN21天学习挑战赛 经过前段时间研究,从LeNet-5手写数字入门到最近研究的一篇天气识别。我想干一票大的,因为我本身从事的就是C++/Qt开发,对Qt还是比较熟悉,所以我想实现一个基于Qt的界面化的一个人脸识别。 对卷积神经网络的概念比较陌生的可以看一看

    2024年02月04日
    浏览(54)
  • 【深度学习】基于卷积神经网络(tensorflow)的人脸识别项目(四)

    经过前段时间研究,从LeNet-5手写数字入门到最近研究的一篇天气识别。我想干一票大的,因为我本身从事的就是C++/Qt开发,对Qt还是比较熟悉,所以我想实现一个界面化的一个人脸识别。 对卷积神经网络的概念比较陌生的可以看一看这篇文章:卷积实际上是干了什么 想了解

    2024年01月17日
    浏览(149)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包