How to freeze a pool

Originally this was a part of the Python is easy article, but it quickly grew way too big for a tangent, so I've decided to give it its own place under the sun.

Let's imagine you are running a multiprocessing.Pool and want some clever stuff to happen when a worker terminates. You set up a signal handler with the signal module like this, and you'll be good to go:

import signal
import sys
from typing import Any


def _handle_signal(sig: int, frame: Any) -> None:
    logger.info(f'Received signal {sig}')
    _cleanup()  # your application-specific cleanup
    sys.exit()


signal.signal(signal.SIGINT, _handle_signal)
signal.signal(signal.SIGTERM, _handle_signal)

As you probably know, sys.exit() just raises SystemExit, so what happens when you raise some other exception instead? First of all, all hell breaks loose and you really should not do that. But if the exception happens to be a subclass of Exception, it will deadlock the pool.
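To make the distinction concrete, here is a quick look at the hierarchy (plain standard-library facts, nothing specific to this code):

print(issubclass(SystemExit, Exception))      # False: sys.exit() escapes `except Exception`
print(issubclass(SystemExit, BaseException))  # True
print(issubclass(RuntimeError, Exception))    # True: ordinary errors do not escape it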

Almost all exceptions in Python are subclasses of the Exception class. That makes it very easy to just write

try:
    func()
except Exception:
    logger.exception('Oh-oh something went wrong:')
    raise

But there was one big caveat, and this caveat was CancelledError. In asyncio, Python's built-in async library, CancelledError signals cancellation of a coroutine or a task and can be raised at basically any await. So, if your particular workload needs to do some cleanup or logging before exiting, you do it like this

from asyncio import CancelledError

try:
    await afunc()
except CancelledError:
    # do some cleanup
    raise

But this popular pattern suddenly becomes a problem.

try:
    await afunc()
except Exception:
    pass

Why? Because until Python 3.8, CancelledError was a subclass of Exception, so this pattern would just straight up ignore a cancellation of a coroutine. And that becomes extremely dangerous when the event loop is stopping, since it has to cancel all currently scheduled coroutines.

So async code was riddled with stuff like this:

try:
    await afunc()
except CancelledError:
    raise
except SomeException:
    logger.info('bad')

Now CancelledError is a subclass of BaseException, so there is no need to check for it every time you want to recover from an exception.
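You can verify the change on whatever interpreter you happen to be running:

import asyncio

# On 3.8+ CancelledError derives from BaseException only, so a bare `except Exception`
# no longer swallows cancellation. On 3.7 and earlier both lines print True.
print(issubclass(asyncio.CancelledError, BaseException))  # True on every version
print(issubclass(asyncio.CancelledError, Exception))      # False on 3.8+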

But why all this tangent about CancelledError? Well, you see, the deadlock in multiprocessing.Pool revolves around a similar issue. When you create a pool, it starts several management threads to deal with queues and whatnot, and it also spawns worker processes running a special worker function. That function roughly looks like this, minus the exception handling and setup parts.

def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
    completed = 0
    while maxtasks is None or (maxtasks and completed < maxtasks):

        task = get()  # Get a job to execute
        
        if task is None:
            util.debug('worker got sentinel -- exiting')
            break

        job, i, func, args, kwds = task
        try:
            result = (True, func(*args, **kwds))  # Runs a function submitted with .apply()
        except Exception as e:
            result = (False, e)  # Stores the exception as the result to send back to the parent process
        
        put((job, i, result))  # Puts the result in finished queue
        
        # Remove references, so the data will not linger on the next iteration
        task = job = result = func = args = kwds = None  
        completed += 1
    util.debug('worker exiting after %d tasks' % completed)

When the pool shuts down, it sends SIGTERM to the workers and then joins them, so it won't leave any zombies behind. But when you register a signal handler inside a worker process that raises some subclass of Exception, the worker, in fact, does not terminate. It just sends your exception back as a result and waits for more work to come its way. Meanwhile the pool has already shut down all its management threads and is now waiting for the worker to exit. A classic deadlock.
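If you want to see it with your own eyes, here is a minimal reproduction sketch of my own (not code from any real project). It assumes a POSIX system, and the sleep is only there so SIGTERM lands while a task is still running; expect the script to hang, so be ready to kill it.

import multiprocessing
import signal
import time


def _bad_handler(sig, frame):
    # The mistake: raising an Exception subclass instead of calling sys.exit().
    # The worker loop catches it like an ordinary task failure and keeps running.
    raise RuntimeError(f'Received signal {sig}')


def _init():
    signal.signal(signal.SIGTERM, _bad_handler)


def _work(x):
    time.sleep(30)  # long enough that SIGTERM arrives mid-task
    return x


if __name__ == '__main__':
    pool = multiprocessing.Pool(processes=2, initializer=_init)
    pool.map_async(_work, range(2))
    time.sleep(1)      # let the workers pick their tasks up
    pool.terminate()   # sends SIGTERM, then joins workers that never exit -- hangs here
    pool.join()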

So yeah, be careful with setting signal handlers in worker processes. Actually, it's just one of many caveats you need to be wary of when using multiprocessing in Python. Some of them are described in this article.

And one more thing

As the documentation for the signal module states, a Python signal handler runs between bytecode instructions on the main thread, so an exception it raises can surface after basically any bytecode instruction of your own code. With this bit of information, where does the memory leak occur in the _wait_response method?

from threading import Event
from typing import Any, Dict


class Client:
    def __init__(self):
        # snip
        self._responses: Dict[str, Any] = {}
        self._response_events: Dict[str, Event] = {}

    # snip
    def _wait_response(self, correlation_id: str, expiration: int) -> Dict:
        if not self._response_events[correlation_id].wait(expiration):
            self._response_events.pop(correlation_id)
            raise TimeoutError
        self._response_events.pop(correlation_id)
        return self._responses.pop(correlation_id)

It's a trick question. Because there is no memory leak.

Yet.

If you set up a signal handler for SIGALRM like this in the worker process running Client

def raise_timeout_error(signal, frame):
    raise RuntimeError("Time is up!")

signal.signal(signal.SIGALRM, raise_timeout_error)
signal.alarm(_timeout)

You will leak Event instances in the _response_events dict. If the expiration you pass to Event.wait() happens to be a bit longer than (or equal to, for that matter) the _timeout for SIGALRM, the RuntimeError will be raised from the wait(expiration) call. Because RuntimeError is a subclass of Exception, the worker will not terminate on it, and you will skip the resource cleanup code.
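One way to narrow that window (a sketch of my own, not the original code) is to move the cleanup into a finally block, so the pop runs even when the alarm's exception fires out of wait():

    def _wait_response(self, correlation_id: str, expiration: int) -> Dict:
        try:
            if not self._response_events[correlation_id].wait(expiration):
                raise TimeoutError
            return self._responses.pop(correlation_id)
        finally:
            # Runs on success, on TimeoutError, and on the RuntimeError injected by the
            # signal handler; pop with a default so a missing key cannot raise here.
            self._response_events.pop(correlation_id, None)

This does not make the method signal-safe in general, since the handler can still fire before the try block is entered, but the Event no longer outlives the call in the common case.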

So yeah, your process pool is full of sharks, and you are the one who released them.


2022-06-26