Test if a numpy array is a member of a list of numpy arrays, and remove it from the list
当测试numpy数组
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | import numpy as np c = np.array([[[ 75, 763]], [[ 57, 763]], [[ 57, 749]], [[ 75, 749]]]) CNTS = [np.array([[[ 78, 1202]], [[ 63, 1202]], [[ 63, 1187]], [[ 78, 1187]]]), np.array([[[ 75, 763]], [[ 57, 763]], [[ 57, 749]], [[ 75, 749]]]), np.array([[[ 72, 742]], [[ 58, 742]], [[ 57, 741]], [[ 57, 727]], [[ 58, 726]], [[ 72, 726]]]), np.array([[[ 66, 194]], [[ 51, 194]], [[ 51, 179]], [[ 66, 179]]])] print(c in CNTS) |
我得到:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
然而,答案相当清楚:
如何正确测试numpy数组是否是numpy数组列表的成员?
移除时也会出现同样的问题:
1 | CNTS.remove(c) |
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
应用程序:测试
您得到错误是因为
1 2 3 4 5 6 7 8 | >>> c == CNTS[1] array([[[ True, True]], [[ True, True]], [[ True, True]], [[ True, True]]]) >>> bool(_) ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() |
这同样适用于删除,因为它测试每个元素的相等性。
遏制
解决方法是使用
1 | any(np.array_equal(c, x) for x in CNTS) |
或
1 | any((c == x).all() for x in CNTS) |
去除
要执行删除,您对元素的索引比对它的存在更感兴趣。我能想到的最快方法是迭代索引,使用
1 | index = next((i for i, x in enumerate(CNTS) if (c == x).all()), -1) |
这个选项很好地短路,返回
现在可以像往常一样删除:
1 | del CNTS[index] |
此解决方案可用于以下情况:
1 2 3 4 5 | def arrayisin(array, list_of_arrays): for a in list_of_arrays: if np.array_equal(array, a): return True return False |
此函数迭代数组列表,并针对其他数组测试相等性。所以用法是:
1 2 | >>> arrayisin(c, CNTS) True |
要从列表中删除数组,可以获取数组的索引,然后使用
1 2 3 4 5 6 7 8 | def get_index(array, list_of_arrays): for j, a in enumerate(list_of_arrays): if np.array_equal(array, a): return j return None idx = get_index(c, CNTS) # 1 CNTS.pop(idx) |
有关
使用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 | del CNTS[int(np.where(list(np.array_equal(row, c) for row in CNTS))[0])] CNTS [array([[[ 78, 1202]], [[ 63, 1202]], [[ 63, 1187]], [[ 78, 1187]]]), array([[[ 72, 742]], [[ 58, 742]], [[ 57, 741]], [[ 57, 727]], [[ 58, 726]], [[ 72, 726]]]), array([[[ 66, 194]], [[ 51, 194]], [[ 51, 179]], [[ 66, 179]]])] |