在 Tensorflow无人车使用移动端的SSD(单发多框检测)来识别物体及Graph的认识 中我们对Graph这个计算图有了一定的了解,也知道了它具备的优点:性能做了提升,可以并行处理以及由于它是一种数据结构,可以在非Python环境中进行交互。
我们先来看下自己的tensorflow的版本:
print(tf.__version__) # 2.11.0
目前基本上都是2.0以上,不过这个Session的用法在tensorflow2.0版本之后就没有了,所以大家在上一篇文章看到的是我使用的兼容1.0版本的用法:tf.compat.v1.Session(graph=g1)
如果是直接去调用的话:tf.compat.v1.Session(graph=g1) 就会报下面这样的错误:
AttributeError: module 'tensorflow' has no attribute 'Session'
于是到了2.0版本之后,我们使用function来代替Graph!
1、tf.function声明
我们具体来看下,2.0及以上版本中的tf.function是如何使用的。
import tensorflow as tf
# 常规函数
def f(x, w, b):
y = tf.matmul(x, w) # 矩阵乘法就是dot点积运算
return y + b
tf_f = tf.function(f)
print(type(f),type(tf_f))
# <class 'function'>
# <class 'tensorflow.python.eager.polymorphic_function.polymorphic_function.Function'>
可以看到这里的tf.function和平时定义的def的这个函数类型是不一样,def的类型就是function,而tf.function(函数名)得到的类型是
tensorflow.python.eager.polymorphic_function.polymorphic_function.Function
eager:渴望的,急切的,这里就是一种即时模型的意思,polymorphic:意思来看是,多形态的,多态的。
函数定义好了之后,我们来看下是如何调用并计算的
c1 = tf.constant([[1.0, 2.0]])
c2 = tf.constant([[2.0], [3.0]])
c3 = tf.constant(4.0)
f_value = f(c1, c2, c3).numpy()
tf_value = tf_f(c1, c2, c3).numpy()
print(f_value,tf_value)#[[12.]] [[12.]]
得到的结果是一样的,那我们引入这个tf.function的作用是什么呢?接着往下看
2、@tf.function装饰器
上面我们可以看到 tf.function 的类型,虽然也是函数,但跟常规函数还是有很大区别,因为我们的目的是能够代替Graph,而使用计算图的目的又是为了性能的提升,所以应该知道这个函数所要表达的意思了吧,在这里我们可以使用 @tf.function 装饰器,就可以让这种即时执行模式的控制流转换成计算图的方式了。
其中matmul是MatrixMultiple的缩写,矩阵乘法的意思,也就是在numpy中的dot点积运算的用法(行乘以列的和)
实际上,这个tf.function可能封装多个tf.graph,所以这两种不同的函数表达,在性能和部署上存在很大的不同。
import tensorflow as tf
def inner_function(x, w, b):
x = tf.matmul(x, w)
return x + b
# 使用装饰器来定义函数
@tf.function
def outer_function(x):
w = tf.constant([[2.0], [3.0]])
b = tf.constant(4.0)
return inner_function(x, w, b)
# 创建一个Graph计算图,里面包含inner_function和outer_function
r1 = outer_function(tf.constant([[1.0, 2.0]])).numpy()
r2 = outer_function(tf.constant([[1.0, 2.0],[3.0, 4.0],[5.0, 6.0]])).numpy()
print(r1)
print(r2)
'''
[[12.]]
[[12.]
[22.]
[32.]]
'''
这里使用了一个@tf.function装饰器来声明这个函数为多态函数,我们来打印看下它的具体特征:
print(outer_function.pretty_printed_concrete_signatures())
'''
outer_function(x)
Args:
x: float32 Tensor, shape=(1, 2)
Returns:
float32 Tensor, shape=(1, 1)
outer_function(x)
Args:
x: float32 Tensor, shape=(3, 2)
Returns:
float32 Tensor, shape=(3, 1)
'''
我使用了两种形状的输入,这里也对应出现两种形式的计算图。这种多态的作用是可以用来提升性能,因为可以判断输入的类型(以及形状),如果是一样的形状,同类型的就不需要新建计算图,我们接着来看下
r3 = outer_function(tf.constant([[11.0, 22.0],[3.0, 4.0],[5.0, 6.0]])).numpy()
print(outer_function.pretty_printed_concrete_signatures())
'''
outer_function(x)
Args:
x: float32 Tensor, shape=(1, 2)
Returns:
float32 Tensor, shape=(1, 1)
outer_function(x)
Args:
x: float32 Tensor, shape=(3, 2)
Returns:
float32 Tensor, shape=(3, 1)
'''
可以看到结果是一样的,对于这样的输入,因为r3跟r2的类型形状是一样的,所以r3可以使用r2的,那么从另一角度可以理解为缓存,当数据类型或形状不一致的时候才会创建新的计算图。
r4 = outer_function([[11.0, 22.0],[3.0, 4.0],[5.0, 6.0]]).numpy()
print(outer_function.pretty_printed_concrete_signatures())
'''
outer_function(x)
Args:
x: float32 Tensor, shape=(1, 2)
Returns:
float32 Tensor, shape=(1, 1)
outer_function(x)
Args:
x: float32 Tensor, shape=(3, 2)
Returns:
float32 Tensor, shape=(3, 1)
outer_function(x=[[11.0, 22.0], [3.0, 4.0], [5.0, 6.0]])
Returns:
float32 Tensor, shape=(3, 1)
'''
这里的r4虽然输出是跟r3一样,不过这里的输入类型不一致,所以还是会新建一个。
3、tf.autograph
现在我们又回到最开始的tf.function,它的本质其实是对原函数做了转换,函数体做了新的变化处理。依然是上面的示例,我们查看下它的本质:
import tensorflow as tf
# 常规函数
def f(x, w, b):
y = tf.matmul(x, w) # 矩阵乘法就是dot点积运算
return y + b
tf_f = tf.function(f)
w = tf.constant([[2.0], [3.0]])
b = tf.constant(4.0)
print(tf_f(tf.constant([[1.0, 2.0]]),w,b).numpy()) # [[12.]]
print(tf.autograph.to_code(f))
'''
def tf__f(x, w, b):
with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
do_return = False
retval_ = ag__.UndefinedReturnValue()
y = ag__.converted_call(ag__.ld(tf).matmul, (ag__.ld(x), ag__.ld(w)), None, fscope)
try:
do_return = True
retval_ = (ag__.ld(y) + ag__.ld(b))
except:
do_return = False
raise
return fscope.ret(retval_, do_return)
'''
可以看到这里是将原函数 f 转成 tf__f 函数,其函数体是做了另外的处理,里面结构也是很类似的。我们打印 tf__f 这个graph计算图的详情看下:
print(tf_f.get_concrete_function(tf.constant([[1.0, 2.0]]),w,b).graph.as_graph_def())
'''
node {
name: "x"
op: "Placeholder"
attr {
key: "_user_specified_name"
value {
s: "x"
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
dim {
size: 2
}
}
}
}
}
node {
name: "w"
op: "Placeholder"
attr {
key: "_user_specified_name"
value {
s: "w"
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 2
}
dim {
size: 1
}
}
}
}
}
node {
name: "b"
op: "Placeholder"
attr {
key: "_user_specified_name"
value {
s: "b"
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
}
}
}
}
node {
name: "MatMul"
op: "MatMul"
input: "x"
input: "w"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "transpose_a"
value {
b: false
}
}
attr {
key: "transpose_b"
value {
b: false
}
}
}
node {
name: "add"
op: "AddV2"
input: "MatMul"
input: "b"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
node {
name: "Identity"
op: "Identity"
input: "add"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
versions {
producer: 1286
}
'''
可以看到每个节点里面是名称、属性、类型、操作等详细的说明,这也印证了最开始说的这个Graph计算图是一种数据结构,数据类型是
<class 'tensorflow.core.framework.graph_pb2.GraphDef'>
4、追踪
graph里的Tracing也是其一个特性,这里的print不在追踪范围内,所以虽然调用了三次,结果只输出一次!
import tensorflow as tf
#tf.config.run_functions_eagerly(False)
@tf.function
def mse(y_true, y_pred):
print("计算均方误差")
tf.print("均方误差")
sq_diff = tf.pow(y_true - y_pred, 2)
return tf.reduce_mean(sq_diff)
tf.config.run_functions_eagerly(False)
y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)
print(y_true,y_pred)
mse_r = mse(y_true, y_pred)
mse_r = mse(y_true, y_pred)
mse_r = mse(y_true, y_pred)
print(mse_r.numpy())
'''
tf.Tensor([8 0 6 3 9], shape=(5,), dtype=int32) tf.Tensor([9 5 9 3 9], shape=(5,), dtype=int32)
计算均方误差
均方误差
均方误差
均方误差
7
'''
其中的tf.print是可以追踪的,所以每次的调用都会输出。
5、非严格执行
tf.graph计算图只关心需要的操作,其余的不会关心,就算是错误的情况也不处理。
def t(x):
tf.gather(x, [3])
return x
try:
print(t(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
print(f'{type(e).__name__}: {e}')
gather用法:
help(tf.gather)
gather_v2(params, indices, validate_indices=None, axis=None, batch_dims=0, name=None)
params = tf.constant(['p0', 'p1', 'p2', 'p3', 'p4', 'p5'])
params[3].numpy() # b'p3'
indices = [2, 0, 2, 5]
tf.gather(params, indices).numpy() #array([b'p2', b'p0', b'p2', b'p5'], dtype=object)
这里的tf.gather调用,我们可以很明显知道,在输入是[0]的情况,[3]索引的数据是不存在的,所以会报错:
InvalidArgumentError: {{function_node __wrapped__GatherV2_device_/job:localhost/replica:0/task:0/device:CPU:0}} indices[0] = 3 is not in [0, 1) [Op:GatherV2]
而当我们使用@tf.function装饰器来装饰这个函数的时候,会发现即便有错误也不会执行。 文章来源:https://www.toymoban.com/news/detail-610224.html
@tf.function
def t(x):
tf.gather(x, [3])
return x
print(t(tf.constant([0.0])))#tf.Tensor([0.], shape=(1,), dtype=float32)
这也再次说明了计算图只关心流程图,里面的具体计算不会去验证。其中tf.gather的用法就是在指定维度抽取数据,这个用法在很多情况使用起来特别有用,我们再来看一个示例:文章来源地址https://www.toymoban.com/news/detail-610224.html
import tensorflow as tf
a = tf.range(8)
a = tf.reshape(a, [4,2])
print(a)
print(tf.gather(a, [3,1,0], axis=0))
'''
tf.Tensor(
[[0 1]
[2 3]
[4 5]
[6 7]], shape=(4, 2), dtype=int32)
tf.Tensor(
[[6 7]
[2 3]
[0 1]], shape=(3, 2), dtype=int32)
'''
到了这里,关于Tensorflow2.0中function(是1.0版本的Graph的推荐替代)的相关知识介绍的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!