TensorFlow-slim包进行图像数据集分类---具体流程

这篇具有很好参考价值的文章主要介绍了TensorFlow-slim包进行图像数据集分类---具体流程。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

本次使用的TensorFlow版本是1.13.0
地址:https://github.com/tensorflow/models/tree/r1.13.0
到tensorflow-models的GitHub下载research下面的slim这个包到本地
TensorFlow-slim包进行图像数据集分类---具体流程,tensorflow,neo4j,人工智能

TensorFlow中slim包的目录结构:

-- slim
    |-- BUILD
    |-- README.md
    |-- WORKSPACE
    |-- __init__.py
    |-- datasets
    |   |-- __init__.py
    |   |-- __pycache__
    |   |   |-- __init__.cpython-37.pyc
    |   |   |-- dataset_utils.cpython-37.pyc
    |   |   |-- download_and_convert_cifar10.cpython-37.pyc
    |   |   |-- download_and_convert_flowers.cpython-37.pyc
    |   |   `-- download_and_convert_mnist.cpython-37.pyc
    |   |-- build_imagenet_data.py
    |   |-- cifar10.py
    |   |-- dataset_factory.py
    |   |-- dataset_utils.py
    |   |-- download_and_convert_cifar10.py
    |   |-- download_and_convert_flowers.py
    |   |-- download_and_convert_imagenet.sh
    |   |-- download_and_convert_mnist.py
    |   |-- download_imagenet.sh
    |   |-- flowers.py
    |   |-- imagenet.py
    |   |-- imagenet_2012_validation_synset_labels.txt
    |   |-- imagenet_lsvrc_2015_synsets.txt
    |   |-- imagenet_metadata.txt
    |   |-- mnist.py
    |   |-- preprocess_imagenet_validation_data.py
    |   `-- process_bounding_boxes.py
    |-- deployment
    |   |-- __init__.py
    |   |-- model_deploy.py
    |   `-- model_deploy_test.py
    |-- download_and_convert_data.py    # 下载相应的数据集,并将数据打包成TF-record的格式
    |-- eval_image_classifier.py        # 测试模型分类效果
    |-- export_inference_graph.py
    |-- export_inference_graph_test.py
    |-- nets
    |   |-- __init__.py
    |   |-- alexnet.py
    |   |-- alexnet_test.py
    |   |-- cifarnet.py
    |   |-- cyclegan.py
    |   |-- cyclegan_test.py
    |   |-- dcgan.py
    |   |-- dcgan_test.py
    |   |-- i3d.py
    |   |-- i3d_test.py
    |   |-- i3d_utils.py
    |   |-- inception.py
    |   |-- inception_resnet_v2.py
    |   |-- inception_resnet_v2_test.py
    |   |-- inception_utils.py
    |   |-- inception_v1.py
    |   |-- inception_v1_test.py
    |   |-- inception_v2.py
    |   |-- inception_v2_test.py
    |   |-- inception_v3.py
    |   |-- inception_v3_test.py
    |   |-- inception_v4.py
    |   |-- inception_v4_test.py
    |   |-- lenet.py
    |   |-- mobilenet
    |   |   |-- README.md
    |   |   |-- __init__.py
    |   |   |-- conv_blocks.py
    |   |   |-- madds_top1_accuracy.png
    |   |   |-- mnet_v1_vs_v2_pixel1_latency.png
    |   |   |-- mobilenet.py
    |   |   |-- mobilenet_example.ipynb
    |   |   |-- mobilenet_v2.py
    |   |   `-- mobilenet_v2_test.py
    |   |-- mobilenet_v1.md
    |   |-- mobilenet_v1.png
    |   |-- mobilenet_v1.py
    |   |-- mobilenet_v1_eval.py
    |   |-- mobilenet_v1_test.py
    |   |-- mobilenet_v1_train.py
    |   |-- nasnet
    |   |   |-- README.md
    |   |   |-- __init__.py
    |   |   |-- nasnet.py
    |   |   |-- nasnet_test.py
    |   |   |-- nasnet_utils.py
    |   |   |-- nasnet_utils_test.py
    |   |   |-- pnasnet.py
    |   |   `-- pnasnet_test.py
    |   |-- nets_factory.py
    |   |-- nets_factory_test.py
    |   |-- overfeat.py
    |   |-- overfeat_test.py
    |   |-- pix2pix.py
    |   |-- pix2pix_test.py
    |   |-- resnet_utils.py
    |   |-- resnet_v1.py
    |   |-- resnet_v1_test.py
    |   |-- resnet_v2.py
    |   |-- resnet_v2_test.py
    |   |-- s3dg.py
    |   |-- s3dg_test.py
    |   |-- vgg.py
    |   `-- vgg_test.py
    |-- preprocessing
    |   |-- __init__.py
    |   |-- cifarnet_preprocessing.py
    |   |-- inception_preprocessing.py
    |   |-- lenet_preprocessing.py
    |   |-- preprocessing_factory.py
    |   `-- vgg_preprocessing.py
    |-- scripts                     # gqr:存储的是相关的模型训练脚本                
    |   |-- export_mobilenet.sh
    |   |-- finetune_inception_resnet_v2_on_flowers.sh
    |   |-- finetune_inception_v1_on_flowers.sh
    |   |-- finetune_inception_v3_on_flowers.sh
    |   |-- finetune_resnet_v1_50_on_flowers.sh
    |   |-- train_cifarnet_on_cifar10.sh
    |   `-- train_lenet_on_mnist.sh
    |-- setup.py
    |-- slim_walkthrough.ipynb
    `-- train_image_classifier.py    # 训练模型的脚本

1、训练脚本文件(该文件包含数据下载打包、模型训练,模型评估流程)

scripts/finetune_resnet_v1_50_on_flowers.sh

#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script performs the following operations:
# 1. Downloads the Flowers dataset
# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
# 3. Evaluates the model on the Flowers validation set.
#
# Usage:
# cd slim
# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
set -e

# Where the pre-trained ResNetV1-50 checkpoint is saved to.
PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints   # gqr:预训练模型存放路径

# Where the training (fine-tuned) checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/flowers-models/resnet_v1_50

# Where the dataset is saved to.
DATASET_DIR=/tmp/flowers    # gqr:数据集存放路径

# Download the pre-trained checkpoint.
if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
  mkdir ${PRETRAINED_CHECKPOINT_DIR}
fi
if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then
  wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
  tar -xvf resnet_v1_50_2016_08_28.tar.gz
  mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt
  rm resnet_v1_50_2016_08_28.tar.gz
fi

# Download the dataset
python download_and_convert_data.py \
  --dataset_name=flowers \
  --dataset_dir=${DATASET_DIR}

# Fine-tune only the new layers for 3000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50 \
  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \
  --checkpoint_exclude_scopes=resnet_v1_50/logits \
  --trainable_scopes=resnet_v1_50/logits \
  --max_number_of_steps=3000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR} \
  --eval_dir=${TRAIN_DIR} \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

# Fine-tune all the new layers for 1000 steps.
python train_image_classifier.py \
  --train_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=${DATASET_DIR} \
  --checkpoint_path=${TRAIN_DIR} \
  --model_name=resnet_v1_50 \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.001 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

# Run evaluation.
python eval_image_classifier.py \
  --checkpoint_path=${TRAIN_DIR}/all \
  --eval_dir=${TRAIN_DIR}/all \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=${DATASET_DIR} \
  --model_name=resnet_v1_50

以上文件以下载并打包flowers数据集为例会调用slim/datasets下的****download_and_convert_flowers.py
TensorFlow-slim包进行图像数据集分类---具体流程,tensorflow,neo4j,人工智能
代码43行:_NUM_VALIDATION = 350值的意思的测试数据集的数量,我们一般2,8分数据集,这里只用填写测试集的数据代码会自动吧总数据集分成2部分
代码48行:_NUM_SHARDS = 1这个的意思是生成几个tfrecord文件,这个数量是根据你数据量来划分
TensorFlow-slim包进行图像数据集分类---具体流程,tensorflow,neo4j,人工智能
代码190行:dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 函数为下载数据集函数,如果本地已经存在数据集,可将将其注释掉
TensorFlow-slim包进行图像数据集分类---具体流程,tensorflow,neo4j,人工智能
代码210行:_clean_up_temporary_files(dataset_dir) 函数为打包完毕后删除下载的数据集文件,如果需要下载的数据集可以将其注释掉

上述文件执行完毕后,会得到以下文件
TensorFlow-slim包进行图像数据集分类---具体流程,tensorflow,neo4j,人工智能

3、模型训练

模型训练文件为
TensorFlow-slim包进行图像数据集分类---具体流程,tensorflow,neo4j,人工智能
以下是该文件中各个模块相关内容

1、数据集相关模块:

TensorFlow-slim包进行图像数据集分类---具体流程,tensorflow,neo4j,人工智能

2、设置网络模型模块

TensorFlow-slim包进行图像数据集分类---具体流程,tensorflow,neo4j,人工智能

3、数据预处理模块

TensorFlow-slim包进行图像数据集分类---具体流程,tensorflow,neo4j,人工智能

4、定义损失loss

TensorFlow-slim包进行图像数据集分类---具体流程,tensorflow,neo4j,人工智能

5、定义优化器模块

TensorFlow-slim包进行图像数据集分类---具体流程,tensorflow,neo4j,人工智能

运行训练指令:

python train_image_classifier.py \
  --train_dir=./data/flowers-models/resnet_v1_50\
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=./data/flowers \
  --model_name=resnet_v1_50 \
  --checkpoint_path=./data/checkpoints/resnet_v1_50.ckpt \
  --checkpoint_exclude_scopes=resnet_v1_50/logits \
  --trainable_scopes=resnet_v1_50/logits \
  --max_number_of_steps=3000 \ 
  --batch_size=32 \
  --learning_rate=0.01 \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

–dataset_name=指定模板
–model_name=指定预训练模板
–dataset_dir=指定训练集目录
–checkpoint_exclude_scopes=指定忘记那几层的参数,不带进训练里面,记住提取特征的部分
–train_dir=训练参数存放地址
–trainable_scopes=设定只对那几层变量进行调整,其他层都不进行调整,不设定就会对所有层训练(所以是必须要给定的)
–learning_rate=学习率
–optimizer=优化器
–checkpoint_path:预训练模型存放地址
–max_number_of_steps=训练步数
–batch_size=一次训练所选取的样本数。 (Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。)
–weight_decay=即模型中所有参数的二次正则化超参数(这个的加入就是为了防止过拟合加入正则项,weight_decay 是乘在正则项的前面,控制正则化项在损失函数中所占权重的)

注意:在模型训练前,需要下载预训练模型,
wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz

解压后存放在相应目录文章来源地址https://www.toymoban.com/news/detail-681060.html

到了这里,关于TensorFlow-slim包进行图像数据集分类---具体流程的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 第61步 深度学习图像识别:多分类建模(TensorFlow)

    一、写在前面 截至上期,我们一直都在做二分类的任务,无论是之前的机器学习任务,还是最近更新的图像分类任务。然而,在实际工作中,我们大概率需要进行多分类任务。例如肺部胸片可不仅仅能诊断肺结核,还有COVID-19、细菌性(病毒性)肺炎等等,这就涉及到图像识

    2024年02月11日
    浏览(37)
  • RK3568笔记四:基于TensorFlow花卉图像分类部署

    若该文为原创文章,转载请注明原文出处。 基于正点原子的ATK-DLRK3568部署测试。 花卉图像分类任务,使用使用 tf.keras.Sequential 模型,简单构建模型,然后转换成 RKNN 模型部署到ATK-DLRK3568板子上。 在 PC 使用 Windows 系统安装 tensorflow,并创建虚拟环境进行训练,然后切换到VM下

    2024年02月07日
    浏览(51)
  • 第63步 深度学习图像识别:多分类建模误判病例分析(Tensorflow)

    一、写在前面 上两期我们基于TensorFlow和Pytorch环境做了图像识别的多分类任务建模。这一期我们做误判病例分析,分两节介绍,分别基于TensorFlow和Pytorch环境的建模和分析。 本期以健康组、肺结核组、COVID-19组、细菌性(病毒性)肺炎组为数据集,基于TensorFlow环境,构建mob

    2024年02月10日
    浏览(38)
  • 【计算机视觉 | Pytorch】timm 包的具体介绍和图像分类案例(含源代码)

    timm 是一个 PyTorch 原生实现的计算机视觉模型库。它提供了预训练模型和各种网络组件,可以用于各种计算机视觉任务,例如图像分类、物体检测、语义分割等等。 timm 的特点如下: PyTorch 原生实现: timm 的实现方式与 PyTorch 高度契合,开发者可以方便地使用 PyTorch 的 API 进行

    2024年02月15日
    浏览(40)
  • Azure 机器学习 - 使用 Visual Studio Code训练图像分类 TensorFlow 模型

    了解如何使用 TensorFlow 和 Azure 机器学习 Visual Studio Code 扩展训练图像分类模型来识别手写数字。 关注TechLead,分享AI全维度知识。作者拥有10+年互联网服务架构、AI产品研发经验、团队管理经验,同济本复旦硕,复旦机器人智能实验室成员,阿里云认证的资深架构师,项目管理

    2024年02月06日
    浏览(48)
  • uni-app:对数组对象进行以具体某一项的分类处理

    这里定义为五个数组,种类product有aaa,bbb两种 原始数据在data中进行定义 注:使用了数组的 reduce() 方法来对 devices 数组进行循环遍历,并根据每个设备对象中的 product 值将其分类到一个以 product 为名称的数组中。 reduce() 方法接收一个回调函数和一个初始值作为参数。在这里,

    2024年02月07日
    浏览(53)
  • 卷积神经网络教程 (CNN) – 使用 TensorFlow 在 Python 中开发图像分类器

    在这篇博客中,让我们讨论什么是卷积神经网络 (CNN) 以及 卷积神经网络背后的 架构 ——旨在 解决   图像识别 系统和 分类 问题。 卷积神经网络在图像和视频识别、推荐系统和 自然语言处理方面有着 广泛的应用 。 目录 计算机如何读取图像? 为什么不是全连接网络?

    2024年02月12日
    浏览(42)
  • 三、学习分类 - 基于图像大小进行分类

    天下一半剑仙是我友 谁家娘子不娇羞 我以醇酒洗我剑 谁人说我不风流 根据图片的尺寸,把图片分为纵向图像和横向图像。这种把图像分成两种类别的问题,就是 二分类问题 。 纵向图片示例:    横向图片示例:  这样就有了两个训练数据: 增加训练数据,并在图像中表

    2024年02月16日
    浏览(25)
  • 【Tensorflow object detection API + 微软NNI】图像分类问题完成自动调参,进一步提升模型准确率!

    1. 背景目标 利用Tensorflow object detection API开发并训练图像分类模型(例如,Mobilenetv2等),自己直接手动调参,对于模型的准确率提不到极致,利用微软NNI自动调参工具进行调参,进一步提升准确率。 2. 方法 关于 Tensorflow object detection API 开发并训练图像分类模型详见这篇博客

    2024年02月12日
    浏览(50)
  • YOLOV5 分类:利用yolov5进行图像分类

    之前介绍了yolov5的目标检测示例,这次将介绍yolov5的分类展示 目标检测:YOLOv5 项目:训练代码和参数详细介绍(train)_yolov5训练代码的详解-CSDN博客 yolov5和其他网络的性能对比 yolov5分类的代码部分在这 yolov5分类的数据集就是常规的摆放方式 相同数据放在同样的目录下,目

    2024年04月12日
    浏览(27)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包