首页 后端开发 Python教程 Tensorflow分类器项目自定义数据读入的方法介绍(代码示例)

Tensorflow分类器项目自定义数据读入的方法介绍(代码示例)

Feb 11, 2019 am 10:28 AM
input python tensorflow

本篇文章给大家带来的内容是关于Tensorflow分类器项目自定义数据读入的方法介绍(代码示例),有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助。

Tensorflow分类器项目自定义数据读入

在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
登录后复制

图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

IMG_SIZE_X = 30
IMG_SIZE_Y = 30
登录后复制

其次确定你图片的方式目录:

image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
# 你也可以使用相对路径的方式
# image_path =os.path.join(path, "set")
登录后复制

目录下的结构如下:

210901301-5c57a5e09df2b_articlex.png

相应的label.txt如下:

动漫
风景
美女
物语
樱花
登录后复制

接下来是接在labels.txt,如下:

label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))
登录后复制

这里简便起见,直接利用了numpy的loadtxt函数直接加载。

之后便是正式处理图片数据了,注释就写在里面了:

re_load = False
re_build = False
# re_load = True
re_build = True

data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)

count = 0

# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
if not os.path.exists(data_path) or re_load:
    labels = []
    images = []
    print('Handle images')
    # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
    for index, name in enumerate(class_names):
        # 这里是拼接后的子目录path
        classpath = os.path.join(image_path, name)
        # 先判断一下是否是目录
        if not os.path.isdir(classpath):
            continue
        # limit是测试时候用的这里可以去除
        limit = 0
        for image_name in os.listdir(classpath):
            if limit >= max_size:
                break
            # 这里是拼接后的待处理的图片path
            imagepath = os.path.join(classpath, image_name)
            count = count + 1
            limit = limit + 1
            # 利用Image打开图片
            img = Image.open(imagepath)
            # 缩放到你最初确定要处理的图片分辨率大小
            img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))
            # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
            img = img.convert("L")
            # 转为numpy数组
            img = np.array(img)
            # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
            img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))
            # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
            labels.append([index])
            # 添加到images中,最后统一处理
            images.append(img)
            # 循环中一些状态的输出,可以去除
            print("{} class: {} {} limit: {} {}"
                  .format(count, index + 1, class_names[index], limit, imagepath))
    # 最后一次性将images和labels都转换成numpy数组
    npy_data = np.array(images)
    npy_labels = np.array(labels)
    # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
    np.savez(data_path, x=npy_data, y=npy_labels)
    print("Save images by npz")
else:
    # 如果存在序列化号的数据,便直接读取,提高速度
    npy_data = np.load(data_path)["x"]
    npy_labels = np.load(data_path)["y"]
    print("Load images by npz")
image_data = npy_data
labels_data = npy_labels
登录后复制

到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

# 最后一步就是将原始数据分成训练数据和测试数据
train_images, test_images, train_labels, test_labels = \
    train_test_split(image_data, labels_data, test_size=0.2, random_state=6)
登录后复制

这里将相关信息打印的方法也附上:

print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")

print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")
登录后复制

之后别忘了归一化哟:

print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0
登录后复制

最后附上读取自定义数据的完整代码:

import os

import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
re_load = False
re_build = False
# re_load = True
re_build = True
epochs = 50
batch_size = 5
count = 0
max_size = 2000000000
IMG_SIZE_X = 30
IMG_SIZE_Y = 30
np.random.seed(9277)
image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)
label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))
print('Load class names')
if not os.path.exists(data_path) or re_load:
    labels = []
    images = []
    print('Handle images')
    for index, name in enumerate(class_names):
        classpath = os.path.join(image_path, name)
        if not os.path.isdir(classpath):
            continue
        limit = 0
        for image_name in os.listdir(classpath):
            if limit >= max_size:
                break
            imagepath = os.path.join(classpath, image_name)
            count = count + 1
            limit = limit + 1
            img = Image.open(imagepath)
            img = img.resize((30, 30))
            img = img.convert("L")
            img = np.array(img)
            img = np.reshape(img, (1, 30, 30))
            # img = skimage.io.imread(imagepath, as_grey=True)
            # if img.shape[2] != 3:
            #     print("{} shape is {}".format(image_name, img.shape))
            #     continue
            # data = transform.resize(img, (IMG_SIZE_X, IMG_SIZE_Y))
            labels.append([index])
            images.append(img)
            print("{} class: {} {} limit: {} {}"
                  .format(count, index + 1, class_names[index], limit, imagepath))
    npy_data = np.array(images)
    npy_labels = np.array(labels)
    np.savez(data_path, x=npy_data, y=npy_labels)
    print("Save images by npz")
else:
    npy_data = np.load(data_path)["x"]
    npy_labels = np.load(data_path)["y"]
    print("Load images by npz")
image_data = npy_data
labels_data = npy_labels
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")
train_images, test_images, train_labels, test_labels = \
    train_test_split(image_data, labels_data, test_size=0.2, random_state=6)
print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")

# 归一化
# 我们将这些值缩小到 0 到 1 之间,然后将其馈送到神经网络模型。为此,将图像组件的数据类型从整数转换为浮点数,然后除以 255。以下是预处理图像的函数:
# 务必要以相同的方式对训练集和测试集进行预处理:
print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0
登录后复制

以上是Tensorflow分类器项目自定义数据读入的方法介绍(代码示例)的详细内容。更多信息请关注PHP中文网其他相关文章!

本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Clothoff.io

Clothoff.io

AI脱衣机

Video Face Swap

Video Face Swap

使用我们完全免费的人工智能换脸工具轻松在任何视频中换脸!

热门文章

<🎜>:泡泡胶模拟器无穷大 - 如何获取和使用皇家钥匙
3 周前 By 尊渡假赌尊渡假赌尊渡假赌
北端:融合系统,解释
3 周前 By 尊渡假赌尊渡假赌尊渡假赌
Mandragora:巫婆树的耳语 - 如何解锁抓钩
3 周前 By 尊渡假赌尊渡假赌尊渡假赌

热工具

记事本++7.3.1

记事本++7.3.1

好用且免费的代码编辑器

SublimeText3汉化版

SublimeText3汉化版

中文版,非常好用

禅工作室 13.0.1

禅工作室 13.0.1

功能强大的PHP集成开发环境

Dreamweaver CS6

Dreamweaver CS6

视觉化网页开发工具

SublimeText3 Mac版

SublimeText3 Mac版

神级代码编辑软件(SublimeText3)

热门话题

Java教程
1666
14
CakePHP 教程
1425
52
Laravel 教程
1325
25
PHP教程
1273
29
C# 教程
1252
24
PHP和Python:解释了不同的范例 PHP和Python:解释了不同的范例 Apr 18, 2025 am 12:26 AM

PHP主要是过程式编程,但也支持面向对象编程(OOP);Python支持多种范式,包括OOP、函数式和过程式编程。PHP适合web开发,Python适用于多种应用,如数据分析和机器学习。

在PHP和Python之间进行选择:指南 在PHP和Python之间进行选择:指南 Apr 18, 2025 am 12:24 AM

PHP适合网页开发和快速原型开发,Python适用于数据科学和机器学习。1.PHP用于动态网页开发,语法简单,适合快速开发。2.Python语法简洁,适用于多领域,库生态系统强大。

sublime怎么运行代码python sublime怎么运行代码python Apr 16, 2025 am 08:48 AM

在 Sublime Text 中运行 Python 代码,需先安装 Python 插件,再创建 .py 文件并编写代码,最后按 Ctrl B 运行代码,输出会在控制台中显示。

PHP和Python:深入了解他们的历史 PHP和Python:深入了解他们的历史 Apr 18, 2025 am 12:25 AM

PHP起源于1994年,由RasmusLerdorf开发,最初用于跟踪网站访问者,逐渐演变为服务器端脚本语言,广泛应用于网页开发。Python由GuidovanRossum于1980年代末开发,1991年首次发布,强调代码可读性和简洁性,适用于科学计算、数据分析等领域。

Python vs. JavaScript:学习曲线和易用性 Python vs. JavaScript:学习曲线和易用性 Apr 16, 2025 am 12:12 AM

Python更适合初学者,学习曲线平缓,语法简洁;JavaScript适合前端开发,学习曲线较陡,语法灵活。1.Python语法直观,适用于数据科学和后端开发。2.JavaScript灵活,广泛用于前端和服务器端编程。

Golang vs. Python:性能和可伸缩性 Golang vs. Python:性能和可伸缩性 Apr 19, 2025 am 12:18 AM

Golang在性能和可扩展性方面优于Python。1)Golang的编译型特性和高效并发模型使其在高并发场景下表现出色。2)Python作为解释型语言,执行速度较慢,但通过工具如Cython可优化性能。

vscode在哪写代码 vscode在哪写代码 Apr 15, 2025 pm 09:54 PM

在 Visual Studio Code(VSCode)中编写代码简单易行,只需安装 VSCode、创建项目、选择语言、创建文件、编写代码、保存并运行即可。VSCode 的优点包括跨平台、免费开源、强大功能、扩展丰富,以及轻量快速。

notepad 怎么运行python notepad 怎么运行python Apr 16, 2025 pm 07:33 PM

在 Notepad 中运行 Python 代码需要安装 Python 可执行文件和 NppExec 插件。安装 Python 并为其添加 PATH 后,在 NppExec 插件中配置命令为“python”、参数为“{CURRENT_DIRECTORY}{FILE_NAME}”,即可在 Notepad 中通过快捷键“F6”运行 Python 代码。

See all articles