Tensorflow两类版本获取模型参数情况的方法

来源:人工智能技术干货 - AITechStudy

Tensorflow存在不同版本架构的差异,以Tensorflow 2.0为分界线,Tensorflow 2.0以下的版本都是以静态图构建图结构,且一般是调用TF内建接口来构造神经网络模型的,而Tensorflow 2.0及以上版本则默认以动态图构建图结构,且默认是以Keras API接口来构建神经网络模型的,因此,获取模型参数情况的方法也因不同版本的差异而有所不同。本篇文章,将提供Tensorflow两类版本获取模型参数情况的不同方法,从而让我们可以轻松获取所训练模型的总参数量、可训练参数量及非可训练参数量情况。

01

Tensorflow V0.x及V1.x版本提供了tf.global_variables()这个可获知全局变量的接口,我们将以该接口来获得在Tensorflow静态图模式下的模型参数情况。该接口返回的是'tensorflow.python.ops.variables.RefVariable'类型,我们需要遍历其返回变量表并获取每一个变量的shape,之后我们可依据Tensorflow的Tensor是否trainable来区分可训练参数和非可训练参数。值得注意的是,在这所获得的数据类型均是Tensorflow的某种内建类型,没有经过session.run()我们拿到的还不是实际数据,要计算参数量,我们还需要将其转为numpy数组,以便更好的计算参数量,具体方法如下:

# 定义总参数量、可训练参数量及非可训练参数量变量
Total_params = 0 
Trainable_params = 0
NonTrainable_params = 0

# 遍历tf.global_variables()返回的全局变量列表
for var in tf.global_variables():
    shape = var.shape # 获取每个变量的shape,其类型为'tensorflow.python.framework.tensor_shape.TensorShape'
    array = np.asarray([dim.value for dim in shape]) # 转换为numpy数组,方便后续计算
    mulValue = np.prod(array) # 使用numpy prod接口计算数组所有元素之积

    Total_params += mulValue # 总参数量
    if var.trainable:
        Trainable_params += mulValue # 可训练参数量
    else:
        NonTrainable_params += mulValue # 非可训练参数量

print(f'Total params: {Total_params}')
print(f'Trainable params: {Trainable_params}')
print(f'Non-trainable params: {NonTrainable_params}')

通过该方法,我们即可得到Tensorflow在静态图模式下的模型参数情况,如果要获得粗略的模型大小,则可以将参数量除以1e6并乘以4即可得大概的模型大小(单位MB)。

02

Tensorflow V2.x及其之后的版本,如要获取模型参数情况,则因为其默认使用Keras API构建神经网络模型而变得非常简单,不需要我们自己设计代码获取模型参数情况,仅需要调用其提供的模型类型接口summary()即可得到我们想要的信息,方法如下:

# 构建神经网络模型, ...表示神经网络各个Op连接构造
model = tf.keras.Sequential(...)
model.summary() # 调用模型summary()接口直接获取模型参数情况

我们以MobileNetV3_large模型为例,通过该方法获取到的模型参数情况信息一般格式如下图所示:

Tensorflow两类版本获取模型参数情况的方法

从上图我们可以看到,Tensorflow V2.x动态图版本的模型summary()接口不仅为我们输出了各个类别的参数量情况,还为我们输出了每层网络的参数量,并提供了额外的层名称及其输出shape大小,这极大地方便了我们查看模型构造的情况。

以上即为Tensorflow两类版本获取模型参数情况的方法,仅供参考!

来源:人工智能技术干货 - AITechStudy ,转载此文目的在于传递更多信息,版权归原作者所有。

最新文章