ToTensor() 是pytorch中的数据预处理函数,包含在 torchvision.transforms 模块下。一般用于处理图像数据,所以其处理对象是 PIL Image 和 numpy.ndarray 。
1、ToTensor() 函数的作用
必须要声明不能只看函数名,就以为 ToTensor() 只是将图像转为 tensor,其实它的功能不止于此
看一下 ToTensor() 函数的源码:
class ToTensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.
Converts a PIL Image or numpy.ndarray (H x W x C) in the range
[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
or if the numpy.ndarray has dtype = np.uint8
In the other cases, tensors are returned without scaling.
.. note::
Because the input image is scaled to [0.0, 1.0], this transformation should not be used when
transforming target image masks. See the `references`_ for implementing the transforms for image masks.
.. _references: https://github.com/pytorch/vision/tree/main/references/segmentation
"""
大意是:
(1)将 PIL Image 或 numpy.ndarray 转为 tensor
(2)如果 PIL Image 属于 (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 中的一种图像类型,或者 numpy.ndarray 格式数据类型是 np.uint8 ,则将 [0, 255] 的数据转为 [0.0, 1.0] ,也就是说将所有数据除以 255 进行归一化。
(3)将 HWC 的图像格式转为 CHW 的 tensor 格式。CNN训练时需要的数据格式是[N,C,N,W],也就是说经过 ToTensor() 处理的图像可以直接输入到CNN网络中,不需要再进行reshape。
2、读取图像时 PIL 和 opencv 的选择
在自己建立 dataset 迭代器时,一般操作是检索数据集图像的路径,然后使用 PIL 库或 opencv库读取图片路径。
2.1 使用 PIL
import numpy as np
from PIL import Image
filePath="Dataset/FFHQ/00000.png"
img1=Image.open(filePath)
print(f"img1 = {img1}")
# img1 = <PIL.PngImagePlugin.PngImageFile image mode=RGB size=128x128 at 0x253DC205A88>
img2 = np.array(img1)
print(f"img2 = {img2}")
"""
img2 = [[[ 0 130 146]
[ 0 128 144]
[ 0 125 141]
...
[133 162 164]
[133 157 159]
[134 157 163]]]
"""
可以看到,使用 PIL.Image 读取的图像是一种 PIL 类,mode=RGB,要想获得图像的像素值还需要将其转为 np.array 格式。
而 opencv 可以直接将图像读取为 np.array 格式,因此首选 opencv 。
2.2 使用 opencv
import cv2
filePath="Dataset/FFHQ/00000.png"
img=cv2.imread(filePath)
print(f"img.shape = {img.shape}") # img.shape = (128, 128, 3)
print(f"img = {img}") # img.dtype = uint8
"""
img = [[[146 130 0]
[144 128 0]
[141 125 0]
...
[164 162 133]
[159 157 133]
[163 157 134]]]
"""
仔细对比PIL 和 opencv 的输出结果可以发现,PIL默认输出的图片格式为 RGB,而opencv输出的是BGR格式。
使用opencv读取的图像是[H,W,C]大小的,数据格式是 np.uint8 ,经过 ToTensor() 会进行归一化。而其他的数据类型(如 np.int8)经过 ToTensor() 数值不变,不进行归一化,后面会详细讲述。并且经过ToTensor()后图像格式变为 [C,H,W]。
3、ToTensor() 的使用
3.1 关键知识点
不管是使用 PLT还是opencv,最终得到都是 np.array类型。因此:
ToTensor() 是将 np.array 的数据 转为 tensor 格式
这里一定要明确几个点:
(1)np.array 整型的默认数据类型为 np.int32,经过 ToTensor() 后数值不变,不进行归一化。
(2)np.array 浮点型的默认数据类型为 np.float64,经过 ToTensor() 后数值不变,不进行归一化。
(3)opencv 读取的图像格式为 np.array,其数据类型为 np.uint8
经过 ToTensor() 后数值由 [0,255] 变为 [0,1],通过将每个数据除以255进行归一化。
(4)经过 ToTensor() 后,HWC 的图像格式变为 CHW 的 tensor 格式。
(5)np.uint8 和 np.int8 不一样,uint8是无符号整型,数值都是正数。
(6)ToTensor() 可以处理任意 shape 的 np.array,并不只是三通道的图像数据。
3.2 代码示例
下面通过代码熟悉 ToTensor() 的使用,分别看一下 np.uint8 和 非 np.uint8 类型的 np.array 经过 ToTensor() 之后的输出。
(1) np.uint8 类型文章来源:https://www.toymoban.com/news/detail-433792.html
import numpy as np
from torchvision import transforms
data = np.array([
[0, 5, 10, 20, 0],
[255, 125, 180, 255, 196]
], dtype=np.uint8)
tensor = transforms.ToTensor()(data)
print(tensor)
"""
tensor([[[0.0000, 0.0196, 0.0392, 0.0784, 0.0000],
[1.0000, 0.4902, 0.7059, 1.0000, 0.7686]]])
"""
(2)非 np.uint8 类型文章来源地址https://www.toymoban.com/news/detail-433792.html
import numpy as np
from torchvision import transforms
data = np.array([
[0, 5, 10, 20, 0],
[255, 125, 180, 255, 196]
]) # data.dtype = int32
tensor = transforms.ToTensor()(data)
print(tensor)
"""
tensor([[[ 0, 5, 10, 20, 0],
[255, 125, 180, 255, 196]]], dtype=torch.int32)
"""
到了这里,关于torchvision.transforms 数据预处理:ToTensor()的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!