multiprocessing: How do I share a dict among multiple processes?
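I have a program that creates several processes that work on a joinable queue, `Q`, and that may manipulate a global dictionary `D` to store results.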
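If I print `D` in a child process, I can see the modifications that were made to it. But after the main process joins `Q`, printing `D` shows an empty dict!

I understand this is a synchronization/locking issue. Can someone explain what is happening here, and how I can synchronize access to `D`?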
The general answer involves using a `Manager` object:
```python
from multiprocessing import Process, Manager


def f(d):
    # note: `d[key] += x` is a read followed by a separate write through the
    # proxy, so concurrent updates to the same key are not atomic
    d[1] += '1'
    d['2'] += 2


if __name__ == '__main__':
    manager = Manager()

    d = manager.dict()
    d[1] = '1'
    d['2'] = 2

    p1 = Process(target=f, args=(d,))
    p2 = Process(target=f, args=(d,))
    p1.start()
    p2.start()
    p1.join()
    p2.join()

    print(d)
```
Output:

```
$ python mul.py
{1: '111', '2': 6}
```
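Because each `+=` on the proxy is a separate get and set, two processes can interleave between them and lose an update. A minimal sketch, assuming you want each pair of increments to be atomic, guards the read-modify-write with a `multiprocessing.Lock` passed to both processes. The variable name `result` is only for illustration.

```python
from multiprocessing import Lock, Manager, Process


def f(d, lock):
    # `d[key] += x` is a get followed by a separate set through the proxy;
    # holding the lock makes the whole read-modify-write atomic across processes
    with lock:
        d[1] += '1'
        d['2'] += 2


if __name__ == '__main__':
    lock = Lock()
    with Manager() as manager:
        d = manager.dict({1: '1', '2': 2})
        procs = [Process(target=f, args=(d, lock)) for _ in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        result = dict(d)  # copy out before the manager shuts down
    print(result)  # {1: '111', '2': 6}
```

Single-key assignments such as `d[k] = v` are one proxied call each and do not need the lock; only compound updates do.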
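Multiprocessing is not like threading. Each child process gets its own copy of the main process's memory, so changes a child makes to an ordinary dict are never seen by the parent. Generally, state is shared via communication (pipes/sockets), signals, or shared memory.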
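A small sketch of the behaviour described in the question: a plain module-level dict (the name `D` mirrors the question) is modified in a child, the child reports what it sees through a `Queue`, and the parent's copy stays empty.

```python
import multiprocessing as mp

D = {}  # ordinary module-level dict, NOT shared between processes


def child(q):
    D['x'] = 1
    q.put(dict(D))  # report what the child's own copy looks like


if __name__ == '__main__':
    q = mp.Queue()
    p = mp.Process(target=child, args=(q,))
    p.start()
    in_child = q.get()  # read before join() so the child can flush the queue
    p.join()
    in_parent = dict(D)
    print('child sees: ', in_child)   # {'x': 1}
    print('parent sees:', in_parent)  # {}
```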
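`multiprocessing` provides some abstractions for your use case that let you treat shared state as if it were local, either through proxies or through shared memory: http://docs.python.org/library/multiprocessing.html#sharing-state-between-processes

Relevant sections: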
- http://docs.python.org/library/multiprocessing.html#shared-ctypes-objects
- http://docs.python.org/library/multiprocessing.html#module-multiprocessing.managers
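For the shared-memory route, `multiprocessing.Value` and `multiprocessing.Array` hold fixed-size C data rather than a dict, which is why a `Manager` is the usual choice for dicts. A minimal sketch of both; the function name `work` and the sizes are arbitrary.

```python
import multiprocessing as mp


def work(counter, arr, i):
    # Value and Array live in shared memory; get_lock() guards the += below
    with counter.get_lock():
        counter.value += 1
    arr[i] = i * i  # each process writes only its own slot, so no lock needed


if __name__ == '__main__':
    counter = mp.Value('i', 0)  # shared C int
    arr = mp.Array('d', 4)      # shared array of 4 C doubles, zero-initialised
    procs = [mp.Process(target=work, args=(counter, arr, i)) for i in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    result = (counter.value, arr[:])
    print(result)  # (4, [0.0, 1.0, 4.0, 9.0])
```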
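I'd like to share my own work, which is faster than Manager's dict and simpler and more stable than the pyshmht library, which uses a lot of memory and doesn't work on Mac OS. My dict only works with plain strings, though, and is currently immutable.

I use a linear probing implementation and store the key/value pairs in a separate memory block after the table.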
```python
# Python 2 code (uses iteritems, xrange, unicode and print statements)
from mmap import mmap
import struct
from timeit import default_timer
from multiprocessing import Manager
from pyshmht import HashTable


class shared_immutable_dict:
    def __init__(self, a):
        # power-of-two table size with more than 3 slots per key
        self.hs = 1 << (len(a) * 3).bit_length()
        # key/value data starts right after the table of 4-byte offsets
        kvp = self.hs * 4
        ht = [0xffffffff] * self.hs  # 0xffffffff marks an empty slot
        kvl = []
        for k, v in a.iteritems():
            h = self.hash(k)
            # linear probing: move on to the next slot until a free one is found
            while ht[h] != 0xffffffff:
                h = (h + 1) & (self.hs - 1)
            ht[h] = kvp
            kvp += self.kvlen(k) + self.kvlen(v)
            kvl.append(k)
            kvl.append(v)

        # anonymous memory map: offset table followed by the key/value block
        self.m = mmap(-1, kvp)
        for p in ht:
            self.m.write(uint_format.pack(p))
        for x in kvl:
            # length prefix: 1 byte if < 128, else 4 bytes with the high bit set
            if len(x) <= 0x7f:
                self.m.write_byte(chr(len(x)))
            else:
                self.m.write(uint_format.pack(0x80000000 + len(x)))
            self.m.write(x)

    def hash(self, k):
        h = hash(k)
        h = (h + (h >> 3) + (h >> 13) + (h >> 23)) * 1749375391 & (self.hs - 1)
        return h

    def get(self, k, d=None):
        h = self.hash(k)
        while True:
            x = uint_format.unpack(self.m[h * 4:h * 4 + 4])[0]
            if x == 0xffffffff:
                return d
            self.m.seek(x)
            if k == self.read_kv():
                return self.read_kv()
            h = (h + 1) & (self.hs - 1)

    def read_kv(self):
        sz = ord(self.m.read_byte())
        if sz & 0x80:
            sz = uint_format.unpack(chr(sz) + self.m.read(3))[0] - 0x80000000
        return self.m.read(sz)

    def kvlen(self, k):
        return len(k) + (1 if len(k) <= 0x7f else 4)

    def __contains__(self, k):
        return self.get(k, None) is not None

    def close(self):
        self.m.close()

uint_format = struct.Struct('>I')


def uget(a, k, d=None):
    return to_unicode(a.get(to_str(k), d))


def uin(a, k):
    return to_str(k) in a


def to_unicode(s):
    return s.decode('utf-8') if isinstance(s, str) else s


def to_str(s):
    return s.encode('utf-8') if isinstance(s, unicode) else s


def mmap_test():
    n = 1000000
    d = shared_immutable_dict({str(i * 2): '1' for i in xrange(n)})
    start_time = default_timer()
    for i in xrange(n):
        if bool(d.get(str(i))) != (i % 2 == 0):
            raise Exception(i)
    print 'mmap speed: %d gets per sec' % (n / (default_timer() - start_time))


def manager_test():
    n = 100000
    d = Manager().dict({str(i * 2): '1' for i in xrange(n)})
    start_time = default_timer()
    for i in xrange(n):
        if bool(d.get(str(i))) != (i % 2 == 0):
            raise Exception(i)
    print 'manager speed: %d gets per sec' % (n / (default_timer() - start_time))


def shm_test():
    n = 1000000
    d = HashTable('tmp', n)
    d.update({str(i * 2): '1' for i in xrange(n)})
    start_time = default_timer()
    for i in xrange(n):
        if bool(d.get(str(i))) != (i % 2 == 0):
            raise Exception(i)
    print 'shm speed: %d gets per sec' % (n / (default_timer() - start_time))


if __name__ == '__main__':
    mmap_test()
    manager_test()
    shm_test()
```
The performance results on my laptop are:

```
mmap speed: 247288 gets per sec
manager speed: 33792 gets per sec
shm speed: 691332 gets per sec
```
A simple usage example:

```python
ht = shared_immutable_dict({'a': '1', 'b': '2'})
print ht.get('a')
```
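The class above is Python 2. Below is a minimal Python 3 sketch of the same layout (a linear-probing table of 4-byte offsets, followed by length-prefixed keys and values in an anonymous `mmap`), not the author's code. It assumes `bytes` keys and values, and it swaps the built-in `hash()` for `zlib.crc32`, because Python 3 randomizes `str`/`bytes` hashes per interpreter. `SharedImmutableDict` and `_encode` are hypothetical names. As with the original, sharing relies on children being created with `fork`, which inherits the mapping; `mmap` objects cannot be pickled to `spawn`ed children.

```python
import mmap
import struct
import zlib

EMPTY = 0xFFFFFFFF         # marks an unused slot in the offset table
U32 = struct.Struct('>I')  # big-endian 4-byte unsigned int


def _encode(b):
    # length prefix: 1 byte if len < 128, else 4 bytes with the high bit set
    if len(b) <= 0x7F:
        return bytes([len(b)]) + b
    return U32.pack(0x80000000 + len(b)) + b


class SharedImmutableDict:
    """Read-only bytes -> bytes hash table stored in an anonymous mmap."""

    def __init__(self, items):
        items = dict(items)
        # power-of-two table with more than 3 slots per key, so probing ends
        self.size = 1 << (len(items) * 3).bit_length()
        table = [EMPTY] * self.size
        data = bytearray()
        base = self.size * 4  # key/value block starts after the table
        for k, v in items.items():
            h = self._slot(k)
            while table[h] != EMPTY:  # linear probing
                h = (h + 1) & (self.size - 1)
            table[h] = base + len(data)
            data += _encode(k) + _encode(v)
        self.m = mmap.mmap(-1, base + len(data))
        self.m.write(b''.join(U32.pack(off) for off in table))
        self.m.write(bytes(data))

    def _slot(self, k):
        # crc32 gives the same value in every interpreter, unlike hash()
        return zlib.crc32(k) & (self.size - 1)

    def _read(self, pos):
        # returns (payload, position just after the payload)
        n = self.m[pos]
        if n & 0x80:
            n = U32.unpack(self.m[pos:pos + 4])[0] - 0x80000000
            pos += 4
        else:
            pos += 1
        return self.m[pos:pos + n], pos + n

    def get(self, k, default=None):
        h = self._slot(k)
        while True:
            off = U32.unpack(self.m[h * 4:h * 4 + 4])[0]
            if off == EMPTY:
                return default
            key, after_key = self._read(off)
            if key == k:
                return self._read(after_key)[0]
            h = (h + 1) & (self.size - 1)

    def __contains__(self, k):
        return self.get(k) is not None

    def close(self):
        self.m.close()


if __name__ == '__main__':
    d = SharedImmutableDict({b'a': b'1', b'b': b'2', b'k' * 300: b'x' * 200})
    print(d.get(b'a'), b'b' in d, d.get(b'missing'))  # b'1' True None
    print(len(d.get(b'k' * 300)))                     # 200
```

Reading through slices instead of `seek()`/`read()` keeps lookups from touching the map's shared file position.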
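In addition to @senderle's answer, some might also be wondering how to use the functionality of `multiprocessing.Pool`.

The nice thing is that the manager instance has a `.Pool()` method that mimics the familiar API of the top-level `multiprocessing` module.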
```python
import multiprocessing as mp
import os
import pprint


def f(d):
    pid = os.getpid()
    d[pid] = "Hi, I was written by process %d" % pid


if __name__ == '__main__':
    with mp.Manager() as manager:
        d = manager.dict()
        with manager.Pool() as pool:
            # a plain list instead of itertools.repeat(d, 10): the arguments are
            # pickled and sent to the manager process, and Python 3.14 no
            # longer supports pickling itertools objects
            pool.map(f, [d] * 10)
        # `d` is a DictProxy object that can be converted to dict
        pprint.pprint(dict(d))
```
Output:

```
$ python3 mul.py
{22562: 'Hi, I was written by process 22562',
 22563: 'Hi, I was written by process 22563',
 22564: 'Hi, I was written by process 22564',
 22565: 'Hi, I was written by process 22565',
 22566: 'Hi, I was written by process 22566',
 22567: 'Hi, I was written by process 22567',
 22568: 'Hi, I was written by process 22568',
 22569: 'Hi, I was written by process 22569',
 22570: 'Hi, I was written by process 22570',
 22571: 'Hi, I was written by process 22571'}
```
This is a slightly different example in which each process just logs its process ID to the global `DictProxy` object `d`.
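For comparison, here is a sketch that uses the ordinary top-level `mp.Pool` instead of `manager.Pool()`. The `DictProxy` is pickled into each task, and every worker's write is forwarded to the manager process. The pool size of 4 and the name `snapshot` are arbitrary choices for illustration. Because the pool has at most 4 workers, you should see at most 4 distinct PIDs.

```python
from itertools import repeat
import multiprocessing as mp
import os


def f(d):
    pid = os.getpid()
    d[pid] = "Hi, I was written by process %d" % pid


if __name__ == '__main__':
    with mp.Manager() as manager:
        d = manager.dict()
        # the proxy is picklable, so it can be passed to ordinary pool tasks
        with mp.Pool(4) as pool:
            pool.map(f, repeat(d, 10))
        snapshot = dict(d)  # copy out before the manager shuts down
    print(len(snapshot), 'distinct worker PIDs')
```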
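Maybe you can try pyshmht, a shared-memory-based hash table extension for Python.

Notes:

- It has not been fully tested and is for reference only.
- It currently lacks a lock/semaphore mechanism for multiprocessing.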