关于python:从numpy ndarray中有效地采样矢量

Efficiently sample vectors from numpy ndarray

我有一个shape的多维numpy数组X(B, dim, H, W)我想从X中随机抽取kdim维向量。我可以从形状为(B, 1, H, W)msk中得到样本指数:

1
sIdx = random.sample((msk.flat>=0).nonzero()[0], k)

使用numpy的等效采样代码为:

1
sIdx = np.random.choice((msk.flat>=0).nonzero()[0], replace=False, size=(k,))

但如何根据"平面"抽样指标sIdx有效地分割X?也就是说,有没有一种有效的方法将随机抽取的msk与切片的X相结合?


从展平的指数中得到这三个轴中其余带有np.unravel_index的各指数,然后沿着这些轴简单地索引到X中,以获得最终输出,就像这样。-

1
2
I,J,K = np.unravel_index(sIdx, (B, H, W))
out = X[I,:,J,K]