Calculate histograms along axis(沿轴计算直方图)
本文介绍了沿轴计算直方图的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
是否有办法沿ND数组的轴计算多个直方图?我当前使用方法使用for
循环迭代所有其他轴,并为每个结果一维数组计算numpy.histogram()
:
import numpy
import itertools
data = numpy.random.rand(4, 5, 6)
# axis=-1, place `200001` and `[slice(None)]` on any other position to process along other axes
out = numpy.zeros((4, 5, 200001), dtype="int64")
indices = [
numpy.arange(4), numpy.arange(5), [slice(None)]
]
# Iterate over all axes, calculate histogram for each cell
for idx in itertools.product(*indices):
out[idx] = numpy.histogram(
data[idx],
bins=2 * 100000 + 1,
range=(-100000 - 0.5, 100000 + 0.5),
)[0]
out.shape # (4, 5, 200001)
不用说这非常慢,但是我找不到使用numpy.histogram
、numpy.histogram2d
或numpy.histogramdd
解决此问题的方法。
推荐答案
这里是一种利用高效工具np.searchsorted
和np.bincount
的矢量化方法。searchsorted
为我们提供了根据垃圾箱放置每个元素的位置,并且bincount
为我们进行了计数。
实施-
def hist_laxis(data, n_bins, range_limits):
# Setup bins and determine the bin location for each element for the bins
R = range_limits
N = data.shape[-1]
bins = np.linspace(R[0],R[1],n_bins+1)
data2D = data.reshape(-1,N)
idx = np.searchsorted(bins, data2D,'right')-1
# Some elements would be off limits, so get a mask for those
bad_mask = (idx==-1) | (idx==n_bins)
# We need to use bincount to get bin based counts. To have unique IDs for
# each row and not get confused by the ones from other rows, we need to
# offset each row by a scale (using row length for this).
scaled_idx = n_bins*np.arange(data2D.shape[0])[:,None] + idx
# Set the bad ones to be last possible index+1 : n_bins*data2D.shape[0]
limit = n_bins*data2D.shape[0]
scaled_idx[bad_mask] = limit
# Get the counts and reshape to multi-dim
counts = np.bincount(scaled_idx.ravel(),minlength=limit+1)[:-1]
counts.shape = data.shape[:-1] + (n_bins,)
return counts
运行时测试
原始方法-
def org_app(data, n_bins, range_limits):
R = range_limits
m,n = data.shape[:2]
out = np.zeros((m, n, n_bins), dtype="int64")
indices = [
np.arange(m), np.arange(n), [slice(None)]
]
# Iterate over all axes, calculate histogram for each cell
for idx in itertools.product(*indices):
out[idx] = np.histogram(
data[idx],
bins=n_bins,
range=(R[0], R[1]),
)[0]
return out
计时和验证-
In [2]: data = np.random.randn(4, 5, 6)
...: out1 = org_app(data, n_bins=200001, range_limits=(- 2.5, 2.5))
...: out2 = hist_laxis(data, n_bins=200001, range_limits=(- 2.5, 2.5))
...: print np.allclose(out1, out2)
...:
True
In [3]: %timeit org_app(data, n_bins=200001, range_limits=(- 2.5, 2.5))
10 loops, best of 3: 39.3 ms per loop
In [4]: %timeit hist_laxis(data, n_bins=200001, range_limits=(- 2.5, 2.5))
100 loops, best of 3: 3.17 ms per loop
因为,在循环解中,我们循环通过前两个轴。因此,让我们增加它们的长度,因为这将向我们展示矢量化的有多好-
In [59]: data = np.random.randn(400, 500, 6)
In [60]: %timeit org_app(data, n_bins=21, range_limits=(- 2.5, 2.5))
1 loops, best of 3: 9.59 s per loop
In [61]: %timeit hist_laxis(data, n_bins=21, range_limits=(- 2.5, 2.5))
10 loops, best of 3: 44.2 ms per loop
In [62]: 9590/44.2 # Speedup number
Out[62]: 216.9683257918552
这篇关于沿轴计算直方图的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
沃梦达教程
本文标题为:沿轴计算直方图
猜你喜欢
- pytorch 中的自适应池是如何工作的? 2022-07-12
- python-m http.server 443--使用SSL? 2022-01-01
- 如何在 Python 的元组列表中对每个元组中的第一个值求和? 2022-01-01
- padding='same' 转换为 PyTorch padding=# 2022-01-01
- python check_output 失败,退出状态为 1,但 Popen 适用于相同的命令 2022-01-01
- 使用Heroku上托管的Selenium登录Instagram时,找不到元素';用户名'; 2022-01-01
- 如何将一个类的函数分成多个文件? 2022-01-01
- 如何在 python3 中将 OrderedDict 转换为常规字典 2022-01-01
- 分析异常:路径不存在:dbfs:/databricks/python/lib/python3.7/site-packages/sampleFolder/data; 2022-01-01
- 沿轴计算直方图 2022-01-01