How to check if a list of numpy arrays contains a given test array?
我有一个
1 | a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)] |
我有一个测试阵列,比如
1 | b = np.random.rand(3, 3) |
我想检查一下
1 | b in a |
引发以下错误:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
我想要什么,正确的方法是什么?
您只需在
1 | a = np.asarray(a) |
然后将它与
1 | np.all(np.isclose(a, b), axis=(1, 2)) |
例如:
1 2 3 4 5 6 | a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)] a = np.asarray(a) b = a[1, ...] # set b to some value we know will yield True np.all(np.isclose(a, b), axis=(1, 2)) # array([False, True, False]) |
如@jotasi所强调的,由于数组中的元素比较,真值不明确。这个问题以前有个答案。总的来说,你的任务可以通过各种方式完成:
通过将列表转换为(3,3,3)形数组,可以使用"in"运算符,如下所示:
1 2 3 4 5 | >>> a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)] >>> a= np.asarray(a) >>> b= a[1].copy() >>> b in a True |
NP.ALL:
1 2 | >>> any(np.all((b==a),axis=(1,2))) True |
列表组成:这是通过迭代每个数组来完成的:
1 2 | >>> any([(b == a_s).all() for a_s in a]) True |
下面是上述三种方法的速度比较:
速度比较
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 | import numpy as np import perfplot perfplot.show( setup=lambda n: np.asarray([np.random.rand(3*3).reshape(3,3) for i in range(n)]), kernels=[ lambda a: a[-1] in a, lambda a: any(np.all((a[-1]==a),axis=(1,2))), lambda a: any([(a[-1] == a_s).all() for a_s in a]) ], labels=[ 'in', 'np.all', 'list_comperhension' ], n_range=[2**k for k in range(1,20)], xlabel='Array size', logx=True, logy=True, ) |
使用数组"从numpy开始相等"
1 2 3 4 5 6 7 | import numpy as np a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)] b = np.random.rand(3,3) for i in a: if np.array_equal(b,i): print("yes") |
如本回答中所指出的,文件规定:
For container types such as list, tuple, set, frozenset, dict, or collections.deque, the expression x in y is equivalent to any(x is e or x == e for e in y).
不过,
1 | any((b is e) or (b == e).all() for e in a) |
或者加入一个函数:
1 2 3 | def numpy_in(arrayToTest, listOfArrays): return any((arrayToTest is e) or (arrayToTest == e).all() for e in listOfArrays) |
这个错误是因为如果
您可以尝试如下操作:
1 | np.any([np.all(a_s == b) for a_s in a]) |
[np.all(a_s == b) for a_s in a] 这里创建boolean 值列表,遍历a 的元素,检查b 和a 的特定元素中的所有元素是否相同。使用
np.any 可以检查数组中是否有元素是True 元素。
好的,所以
1 2 3 4 5 | def in_(obj, iterable): for elem in iterable: if obj == elem: return True return False |
现在的问题是,对于两个星期的
1 2 3 4 5 6 7 8 9 10 | def array_in(arr, list_of_arr): for elem in list_of_arr: if (arr == elem).all(): return True return False a = [np.arange(5)] * 3 b = np.ones(5) array_in(b, a) # --> False |