pytorch和tensorflow计算Flops和params的详细过程 目录 pytorch和tensorflow计算Flops和params 1.只计算params 2.计算flops和params 3.tensorflow计算params和flops pytorch和tensorflow计算Flops和params 1.只计算params net = model() # 定义好的网络模型 total = sum([param.nelement() for param in
目录
- pytorch和tensorflow计算Flops和params
- 1.只计算params
- 2.计算flops和params
- 3.tensorflow计算params和flops
pytorch和tensorflow计算Flops和params
1.只计算params
net = model() # 定义好的网络模型
total = sum([param.nelement() for param in net.parameters()])
print("Number of parameter: %.2fM" % total)
这是网上很常见的直接用自带方法计算params,基本不会出错。胜在简洁。
2.计算flops和params
要计算flops,目前没见到用自带方法计算的,基本都是要安装别的库。
这边我们安装thop库。
pip install thop # 安装thop库
import torch
from thop import profile
net = model() # 定义好的网络模型
img1 = torch.randn(1, 3, 512, 512)
img2 = torch.randn(1, 3, 512, 512)
img3 = torch.randn(1, 3, 512, 512)
macs, params = profile(net, (img1,img2,img3))
print('flops: ', 2*macs, 'params: ', params)
这边和其他网上教程的区别便是,他们macs和flops不分。因为macs表示乘加累积操作数,一个乘法加上一个加法才算一个macs。而flops表示浮点运算次数,每一个加、减、乘、除操作都算1FLOPs操作。所以很明显,在数值上,1flops=2macs。此外,(img1,img2,img3)就表示你如果有三个输入要输入模型,就这样写。
另外,要注意,params只和模型参数量相关,而和输入tensor大小无关。但flops和输入图片大小是相关的.
3.tensorflow计算params和flops
此处是我找到的一些用于tensorflow计算params和flops的方法,仅供参考,不保证效果。
def get_flops_params():
sess = tf.compat.v1.Session()
graph = sess.graph
flops = tf.compat.v1.profiler.profile(graph, options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation())
params = tf.compat.v1.profiler.profile(graph,
options=tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter())
print('FLOPs: {}; Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))
def count2():
print(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))
def get_nb_params_shape(shape):
'''
Computes the total number of params for a given shap.
Works for any number of shapes etc [D,F] or [W,H,C] computes D*F and W*H*C.
'''
nb_params = 1
for dim in shape:
nb_params = nb_params * int(dim)
return nb_params
def count3():
tot_nb_params = 0
for trainable_variable in tf.trainable_variables():
shape = trainable_variable.get_shape() # e.g [D,F] or [W,H,C]
current_nb_params = get_nb_params_shape(shape)
tot_nb_params = tot_nb_params + current_nb_params
print(tot_nb_params)
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
from model import Model
import keras.backend as K
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
# We use the Keras session graph in the call to the profiler.
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops # Prints the "flops" of the model.
# .... Define your model here ....
M = Model(BATCH_SIZE=1, INPUT_H=268, INPUT_W=360, is_training=False)
print(get_flops(M))
到此这篇关于pytorch和tensorflow计算Flops和params的文章就介绍到这了,更多相关pytorch和tensorflow计算内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!
本文标题为:pytorch和tensorflow计算Flops和params的详细过程


- python中列表添加元素的几种方式(+、append()、ext 2022-09-02
- Python之路-Python中的线程与进程 2023-09-04
- CentOS7 安装 Python3.6 2023-09-04
- Python Pandas如何获取和修改任意位置的值(at,iat,loc,iloc) 2023-08-04
- python中defaultdict用法实例详解 2022-10-20
- windows安装python2.7.12和pycharm2018教程 2023-09-03
- python线程池ThreadPoolExecutor与进程池ProcessPoolExecutor 2023-09-04
- Python实现将DNA序列存储为tfr文件并读取流程介绍 2022-10-20
- 在centos6.4下安装python3.5 2023-09-04
- Python 保存数据的方法(4种方法) 2023-09-04