关于python:如何检查numpy数组列表是否包含给定的测试数组?

How to check if a list of numpy arrays contains a given test array?

我有一个numpy数组的列表,比如,

1
a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)]

我有一个测试阵列,比如

1
b = np.random.rand(3, 3)

我想检查一下a是否含有b。然而

1
b in a

引发以下错误:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我想要什么,正确的方法是什么?


您只需在a中制作一个形状为(3, 3, 3)的数组:

1
a = np.asarray(a)

然后将它与b进行比较(我们在这里比较浮点数,所以应该使用isclose())

1
np.all(np.isclose(a, b), axis=(1, 2))

例如:

1
2
3
4
5
6
a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)]
a = np.asarray(a)
b = a[1, ...]       # set b to some value we know will yield True

np.all(np.isclose(a, b), axis=(1, 2))
# array([False,  True, False])

如@jotasi所强调的,由于数组中的元素比较,真值不明确。这个问题以前有个答案。总的来说,你的任务可以通过各种方式完成:

  • 列表到数组:
  • 通过将列表转换为(3,3,3)形数组,可以使用"in"运算符,如下所示:

    1
    2
    3
    4
    5
        >>> a = [np.random.rand(3, 3), np.random.rand(3, 3), np.random.rand(3, 3)]
        >>> a= np.asarray(a)
        >>> b= a[1].copy()
        >>> b in a
        True
  • NP.ALL:

    1
    2
    >>> any(np.all((b==a),axis=(1,2)))
    True
  • 列表组成:这是通过迭代每个数组来完成的:

    1
    2
    >>> any([(b == a_s).all() for a_s in a])
    True
  • 下面是上述三种方法的速度比较:

    速度比较

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    import numpy as np
    import perfplot

    perfplot.show(
        setup=lambda n: np.asarray([np.random.rand(3*3).reshape(3,3) for i in range(n)]),
        kernels=[
            lambda a: a[-1] in a,
            lambda a: any(np.all((a[-1]==a),axis=(1,2))),
            lambda a: any([(a[-1] == a_s).all() for a_s in a])
            ],
        labels=[
            'in', 'np.all', 'list_comperhension'
            ],
        n_range=[2**k for k in range(1,20)],
        xlabel='Array size',
        logx=True,
        logy=True,
        )


    使用数组"从numpy开始相等"

    1
    2
    3
    4
    5
    6
    7
        import numpy as np
        a = [np.random.rand(3,3),np.random.rand(3,3),np.random.rand(3,3)]
        b = np.random.rand(3,3)

        for i in a:
            if np.array_equal(b,i):
                print("yes")


    如本回答中所指出的,文件规定:

    For container types such as list, tuple, set, frozenset, dict, or collections.deque, the expression x in y is equivalent to any(x is e or x == e for e in y).

    不过,a[0]==b是一个数组,包含了a[0]b的元素比较。此数组的整体真值明显不明确。如果所有元素都匹配,或者如果大多数元素至少匹配一个元素,它们是相同的吗?因此,numpy迫使你明确你的意思。你想知道的是,测试所有元素是否相同。您可以使用numpyall方法:

    1
    any((b is e) or (b == e).all() for e in a)

    或者加入一个函数:

    1
    2
    3
    def numpy_in(arrayToTest, listOfArrays):
        return any((arrayToTest is e) or (arrayToTest == e).all()
                   for e in listOfArrays)

    这个错误是因为如果abnumpy arrays,那么a == b不会返回TrueFalse,但是在比较ab元素之后,boolean值的array

    您可以尝试如下操作:

    1
    np.any([np.all(a_s == b) for a_s in a])
    • [np.all(a_s == b) for a_s in a]这里创建boolean值列表,遍历a的元素,检查ba的特定元素中的所有元素是否相同。

    • 使用np.any可以检查数组中是否有元素是True元素。


    好的,所以in不起作用,因为它有效地起作用了。

    1
    2
    3
    4
    5
    def in_(obj, iterable):
        for elem in iterable:
            if obj == elem:
                return True
        return False

    现在的问题是,对于两个星期的aba == b是一个数组(尝试它),而不是布尔值,所以if a == b失败。解决方法是定义一个新函数

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    def array_in(arr, list_of_arr):
         for elem in list_of_arr:
            if (arr == elem).all():
                return True
         return False

    a = [np.arange(5)] * 3
    b = np.ones(5)

    array_in(b, a) # --> False