assert flatten_list_of_lists([[1],[2],[3]]) == [1,2,3]
assert flatten_recursive([[1],[2],[3, [4,5, [6,7,8]]]]) == [1,2,3,4,5,6,7,8]
Parallel Processing
MRL tries to build in parallel processing at every level. This can make a huge difference when you're processing millions of molecules.
new_pool_parallel and maybe_parallel are convenient wrappers for parallel processing. The given func is wrapped with **kwargs and used to process the iterable. If iterable is a list or np.ndarray, the elements in iterable are run in parallel by func.
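As a rough usage sketch (double and its scale keyword argument are made-up placeholders, and the wrappers are assumed to be imported already; the call pattern follows the description above):

def double(x, scale=1):
    return 2 * x * scale

# elements of the list are processed in parallel by `double`;
# extra keyword arguments (here `scale`) are forwarded to `double`
outputs = new_pool_parallel(double, list(range(100)), cpus=4, scale=3)
outputs = maybe_parallel(double, list(range(100)), scale=3)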
Parallel processing tradeoffs
Parallel processing can significantly speed up a computation. There are, however, some trade-offs.
In Python, parallel processing is done by creating a Pool. A pool maps a function over an iterable using multiple worker processes.
from multiprocessing import Pool

# uses 5 worker processes to map `my_func` over `my_iterable`
with Pool(processes=5) as p:
    outputs = p.map(my_func, my_iterable)
The above code creates a new Pool with 5 processes and uses those processes to evaluate the function calls. The code incurs some i/o overhead creating the Pool. This means that if the time required to process the function calls is less than the Pool overhead, parallel processing will actually be slower than serial processing. There are two ways around this:
- Use process pools for bulk processing (ie function time much greater than i/o time)
- Maintain an open process pool to avoid repeated pool creation overhead
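To make the trade-off concrete, here is a minimal sketch using only the standard library (timings will vary by machine): for a function that finishes almost instantly, creating a fresh Pool for every call is slower than a plain loop.

from multiprocessing import Pool
import time

def fast_func(x):
    return x * 2  # trivial work: pool creation dominates the runtime

if __name__ == '__main__':
    start = time.time()
    _ = [fast_func(i) for i in range(100)]
    serial_time = time.time() - start

    start = time.time()
    with Pool(processes=5) as p:  # pool startup/teardown overhead is paid here
        _ = p.map(fast_func, range(100))
    pool_time = time.time() - start

    print(f'serial: {serial_time:.4f}s, new pool: {pool_time:.4f}s')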
MRL provides different functions for controlling which method of parallel processing is used.
The new_pool_parallel function implements parallel processing by creating a new pool for every function call, similar to the code above. This function is best used to process large numbers of inputs infrequently. Parallel processing is controlled by the cpus argument. If cpus=None, the ncpus environment variable is used (ie os.environ['ncpus'] = '8').
The maybe_parallel function allows for repeated use of a stateful process Pool, defined by the GLOBAL_POOL variable. By default, GLOBAL_POOL=None. To create a global pool, use the set_global_pool function.
set_global_pool(cpus=8)
If cpus=None, maybe_parallel will run processes using GLOBAL_POOL if it exists, or fall back to serial processing if it does not. If cpus is not None, maybe_parallel defaults back to using new_pool_parallel.
If you need to frequently run parallel processing on small batches of inputs (ie batches from a model), set a global pool and use maybe_parallel.
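As a rough illustration of that dispatch logic (assuming test_func and items are already defined, and that no global pool has been set yet):

out = maybe_parallel(test_func, items)          # no GLOBAL_POOL yet: runs serially

set_global_pool(cpus=8)
out = maybe_parallel(test_func, items)          # cpus=None: runs on the 8-worker GLOBAL_POOL
out = maybe_parallel(test_func, items, cpus=2)  # cpus given: falls back to new_pool_parallel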
Global Pool Gotchas
Using a global pool allows us to take advantage of parallel processing on small batches without paying the overhead of creating process pools over and over again. However, process pools left open accumulate memory. If memory usage builds up, use refresh_global_pool to release the memory and create a new global pool, or use close_global_pool to delete the global pool and reset it to None.
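A long-running job might manage the pool roughly as follows (the zero-argument calls and the make_chunk / process_item helpers are assumptions for illustration; set_global_pool, maybe_parallel, refresh_global_pool and close_global_pool are the MRL functions named above):

set_global_pool(cpus=8)

for chunk_id in range(100):
    chunk = make_chunk(chunk_id)             # hypothetical: load the next batch of work
    _ = maybe_parallel(process_item, chunk)  # reuses the open GLOBAL_POOL
    if chunk_id % 10 == 9:
        refresh_global_pool()                # assumed call: release memory, open a fresh pool

close_global_pool()                          # delete the pool and reset GLOBAL_POOL to None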
import os
import time

def test_func(x):
    time.sleep(0.5)
    return x
start = time.time()
_ = [test_func(i) for i in range(10)]
t1 = time.time()
_ = new_pool_parallel(test_func, list(range(10))) # serial processing
t2 = time.time()
_ = new_pool_parallel(test_func, list(range(10)), cpus=4) # 4 cpus manually defined
t3 = time.time()
os.environ['ncpus'] = '4'
_ = new_pool_parallel(test_func, list(range(10))) # 4 cpus defined by environ variable
t4 = time.time()
print_str = f'''
Serial time: {t1-start:.2f}\n
new_pool_parallel, 0 cpus time: {t2-t1:.2f}\n
new_pool_parallel, 4 cpus (arg defined) time: {t3-t2:.2f}\n
new_pool_parallel, 4 cpus (environ defined) time: {t4-t3:.2f}\n
'''
print(print_str)
print(type(GLOBAL_POOL))
set_global_pool(5)
print(type(GLOBAL_POOL))
start = time.time()
_ = maybe_parallel(test_func, list(range(10)))
t1 = time.time()
_ = maybe_parallel(test_func, list(range(10)), cpus=2)
t2 = time.time()
print_str = f'''
maybe_parallel Global Pool (5 cpus) time: {t1-start:.2f}\n
maybe_parallel arg override 2 cpus time: {t2-t1:.2f}\n
'''
print(print_str)
start = time.time()
for i in range(10):
    _ = new_pool_parallel(test_func, list(range(10)))
end = time.time() - start
print(f'{end:.2f} elapsed')
start = time.time()
for i in range(10):
    _ = maybe_parallel(test_func, list(range(10)))
end = time.time() - start
print(f'{end:.2f} elapsed')
In the above example, new_pool_parallel takes about 5 seconds longer to execute than maybe_parallel. The time difference is driven by the overhead of creating a new pool on every call.
Debugging Parallel Processing
Errors in parallel processing can be difficult to debug because the true error and stack trace are obscured by the parallel processing stack trace. If you have errors in parallel processing, first try setting os.environ['ncpus'] = '0' and running close_global_pool to disable python multiprocessing. This should reveal the true error.
If everything works fine when multiprocessing is disabled, it is likely one of your functions is failing to pickle.
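As a concrete starting point, the switch to serial execution described above looks like this (my_func and my_inputs are hypothetical placeholders):

import os

os.environ['ncpus'] = '0'   # new_pool_parallel now runs serially
close_global_pool()         # GLOBAL_POOL is reset, so maybe_parallel also runs serially

# the call now executes in the main process, so the real stack trace surfaces
outputs = maybe_parallel(my_func, my_inputs)

Locally defined functions, lambdas, and objects that hold things like open file handles are common sources of pickling failures, since multiprocessing must pickle the function and its arguments to send them to worker processes.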