{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Threading and parallel processing\n",
    "\n",
    "\n",
    "The Python language has built-in support for multi-threading in the\n",
    "[`threading`](https://docs.python.org/3.5/library/threading.html) module, and\n",
    "true parallelism in the\n",
    "[`multiprocessing`](https://docs.python.org/3.5/library/multiprocessing.html)\n",
    "module.  If you want to be impressed, skip straight to the section on\n",
    "[`multiprocessing`](todo).\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "## Threading\n",
    "\n",
    "\n",
    "The [`threading`](https://docs.python.org/3.5/library/threading.html) module\n",
    "provides a traditional multi-threading API that should be familiar to you if\n",
    "you have worked with threads in other languages.\n",
    "\n",
    "\n",
    "Running a task in a separate thread in Python is easy - simply create a\n",
    "`Thread` object, and pass it the function or method that you want it to\n",
    "run. Then call its `start` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import threading\n",
    "\n",
    "def longRunningTask(niters):\n",
    "    for i in range(niters):\n",
    "        if i % 2 == 0: print('Tick')\n",
    "        else:          print('Tock')\n",
    "        time.sleep(0.5)\n",
    "\n",
    "t = threading.Thread(target=longRunningTask, args=(8,))\n",
    "\n",
    "t.start()\n",
    "\n",
    "while t.is_alive():\n",
    "    time.sleep(0.4)\n",
    "    print('Waiting for thread to finish...')\n",
    "print('Finished!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also `join` a thread, which will block execution in the current thread\n",
    "until the thread that has been `join`ed has finished:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = threading.Thread(target=longRunningTask, args=(6, ))\n",
    "t.start()\n",
    "\n",
    "print('Joining thread ...')\n",
    "t.join()\n",
    "print('Finished!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Subclassing `Thread`\n",
    "\n",
    "\n",
    "It is also possible to sub-class the `Thread` class, and override its `run`\n",
    "method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LongRunningThread(threading.Thread):\n",
    "    def __init__(self, niters, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.niters = niters\n",
    "\n",
    "    def run(self):\n",
    "        for i in range(self.niters):\n",
    "            if i % 2 == 0: print('Tick')\n",
    "            else:          print('Tock')\n",
    "            time.sleep(0.5)\n",
    "\n",
    "t = LongRunningThread(6)\n",
    "t.start()\n",
    "t.join()\n",
    "print('Done')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Daemon threads\n",
    "\n",
    "\n",
    "By default, a Python application will not exit until _all_ active threads have\n",
    "finished execution.  If you want to run a task in the background for the\n",
    "duration of your application, you can mark it as a `daemon` thread - when all\n",
    "non-daemon threads in a Python application have finished, all daemon threads\n",
    "will be halted, and the application will exit.\n",
    "\n",
    "\n",
    "You can mark a thread as being a daemon by setting an attribute on it after\n",
    "creation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = threading.Thread(target=longRunningTask)\n",
    "t.daemon = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [`Thread`\n",
    "documentation](https://docs.python.org/3.5/library/threading.html#thread-objects)\n",
    "for more details.\n",
    "\n",
    "\n",
    "### Thread synchronisation\n",
    "\n",
    "\n",
    "The `threading` module provides some useful thread-synchronisation primitives\n",
    "- the `Lock`, `RLock` (re-entrant `Lock`), and `Event` classes.  The\n",
    "`threading` module also provides `Condition` and `Semaphore` classes - refer\n",
    "to the [documentation](https://docs.python.org/3.5/library/threading.html) for\n",
    "more details.\n",
    "\n",
    "\n",
    "#### `Lock`\n",
    "\n",
    "\n",
    "The [`Lock`](https://docs.python.org/3.5/library/threading.html#lock-objects)\n",
    "class (and its re-entrant version, the\n",
    "[`RLock`](https://docs.python.org/3.5/library/threading.html#rlock-objects))\n",
    "prevents a block of code from being accessed by more than one thread at a\n",
    "time. For example, if we have multiple threads running this `task` function,\n",
    "their [outputs](https://www.youtube.com/watch?v=F5fUFnfPpYU) will inevitably\n",
    "become intertwined:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def task():\n",
    "    for i in range(5):\n",
    "        print('{} Woozle '.format(i), end='')\n",
    "        time.sleep(0.1)\n",
    "        print('Wuzzle')\n",
    "\n",
    "threads = [threading.Thread(target=task) for i in range(5)]\n",
    "for t in threads:\n",
    "    t.start()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But if we protect the critical section with a `Lock` object, the output will\n",
    "look more sensible:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lock = threading.Lock()\n",
    "\n",
    "def task():\n",
    "\n",
    "    for i in range(5):\n",
    "        with lock:\n",
    "            print('{} Woozle '.format(i), end='')\n",
    "            time.sleep(0.1)\n",
    "            print('Wuzzle')\n",
    "\n",
    "threads = [threading.Thread(target=task) for i in range(5)]\n",
    "for t in threads:\n",
    "    t.start()"
   ]
  },
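  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The re-entrant `RLock` can be acquired several times by the thread that\n",
    "already holds it. The following is a minimal sketch of why that matters - the\n",
    "names `rlock`, `inner` and `outer` are hypothetical, and exist only for this\n",
    "illustration. With a plain `Lock`, the nested acquisition in `inner` would\n",
    "block forever:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import threading\n",
    "\n",
    "# Hypothetical names, used only for this sketch\n",
    "rlock = threading.RLock()\n",
    "\n",
    "def inner():\n",
    "    # Re-acquires the lock which is already\n",
    "    # held by the calling thread (see outer)\n",
    "    with rlock:\n",
    "        return 'inner done'\n",
    "\n",
    "def outer():\n",
    "    # Acquires the lock, then calls a function\n",
    "    # which needs the same lock\n",
    "    with rlock:\n",
    "        return inner()\n",
    "\n",
    "result = outer()\n",
    "print(result)"
   ]
  },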
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Instead of using a `Lock` object in a `with` statement, it is also possible\n",
    "> to manually call its `acquire` and `release` methods:\n",
    ">\n",
    ">     def task():\n",
    ">         for i in range(5):\n",
    ">             lock.acquire()\n",
    ">             print('{} Woozle '.format(i), end='')\n",
    ">             time.sleep(0.1)\n",
    ">             print('Wuzzle')\n",
    ">             lock.release()\n",
    "\n",
    "\n",
    "Python does not have any built-in constructs to implement `Lock`-based mutual\n",
    "exclusion across several functions or methods - each function/method must\n",
    "explicitly acquire/release a shared `Lock` instance. However, it is relatively\n",
    "straightforward to implement a decorator which does this for you:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mutex(func, lock):\n",
    "    def wrapper(*args):\n",
    "        with lock:\n",
    "            func(*args)\n",
    "    return wrapper\n",
    "\n",
    "class MyClass(object):\n",
    "\n",
    "    def __init__(self):\n",
    "        lock = threading.Lock()\n",
    "        self.safeFunc1 = mutex(self.safeFunc1, lock)\n",
    "        self.safeFunc2 = mutex(self.safeFunc2, lock)\n",
    "\n",
    "    def safeFunc1(self):\n",
    "        time.sleep(0.1)\n",
    "        print('safeFunc1 start')\n",
    "        time.sleep(0.2)\n",
    "        print('safeFunc1 end')\n",
    "\n",
    "    def safeFunc2(self):\n",
    "        time.sleep(0.1)\n",
    "        print('safeFunc2 start')\n",
    "        time.sleep(0.2)\n",
    "        print('safeFunc2 end')\n",
    "\n",
    "mc = MyClass()\n",
    "\n",
    "f1threads = [threading.Thread(target=mc.safeFunc1) for i in range(4)]\n",
    "f2threads = [threading.Thread(target=mc.safeFunc2) for i in range(4)]\n",
    "\n",
    "for t in f1threads + f2threads:\n",
    "    t.start()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Try removing the `mutex` lock from the two methods in the above code, and see\n",
    "what it does to the output.\n",
    "\n",
    "\n",
    "#### `Event`\n",
    "\n",
    "\n",
    "The\n",
    "[`Event`](https://docs.python.org/3.5/library/threading.html#event-objects)\n",
    "class is essentially a boolean [semaphore][semaphore-wiki]. It can be used to\n",
    "signal events between threads. Threads can `wait` on the event, and be awoken\n",
    "when the event is `set` by another thread:\n",
    "\n",
    "\n",
    "[semaphore-wiki]: https://en.wikipedia.org/wiki/Semaphore_(programming)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "processingFinished = threading.Event()\n",
    "\n",
    "def processData(data):\n",
    "    print('Processing data ...')\n",
    "    time.sleep(2)\n",
    "    print('Result: {}'.format(data.mean()))\n",
    "    processingFinished.set()\n",
    "\n",
    "data = np.random.randint(1, 100, 100)\n",
    "\n",
    "t = threading.Thread(target=processData, args=(data,))\n",
    "t.start()\n",
    "\n",
    "processingFinished.wait()\n",
    "print('Processing finished!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The Global Interpreter Lock (GIL)\n",
    "\n",
    "\n",
    "The [_Global Interpreter\n",
    "Lock_](https://docs.python.org/3/c-api/init.html#thread-state-and-the-global-interpreter-lock)\n",
    "is an implementation detail of [CPython](https://github.com/python/cpython)\n",
    "(the official Python interpreter).  The GIL means that a multi-threaded\n",
    "program written in pure Python is not able to take advantage of multiple\n",
    "cores - this essentially means that only one thread may be executing at any\n",
    "point in time.\n",
    "\n",
    "\n",
    "The `threading` module does still have its uses though, as this GIL problem\n",
    "does not affect tasks which involve calls to system or natively compiled\n",
    "libraries (e.g. file and network I/O, Numpy operations, etc.). So you can,\n",
    "for example, perform some expensive processing on a Numpy array in a thread\n",
    "running on one core, whilst having another thread (e.g. user interaction)\n",
    "running on another core.\n",
    "\n",
    "\n",
    "## Multiprocessing\n",
    "\n",
    "\n",
    "For true parallelism, you should check out the\n",
    "[`multiprocessing`](https://docs.python.org/3.5/library/multiprocessing.html)\n",
    "module.\n",
    "\n",
    "\n",
    "The `multiprocessing` module spawns sub-processes, rather than threads, and so\n",
    "is not subject to the GIL constraints that the `threading` module suffers\n",
    "from. It provides two APIs - a \"traditional\" equivalent to that provided by\n",
    "the `threading` module, and a powerful higher-level API.\n",
    "\n",
    "\n",
    "### `threading`-equivalent API\n",
    "\n",
    "\n",
    "The\n",
    "[`Process`](https://docs.python.org/3.5/library/multiprocessing.html#the-process-class)\n",
    "class is the `multiprocessing` equivalent of the\n",
    "[`threading.Thread`](https://docs.python.org/3.5/library/threading.html#thread-objects)\n",
    "class.  `multprocessing` also has equivalents of the [`Lock` and `Event`\n",
    "classes](https://docs.python.org/3.5/library/multiprocessing.html#synchronization-between-processes),\n",
    "and the other synchronisation primitives provided by `threading`.\n",
    "\n",
    "\n",
    "So you can simply replace `threading.Thread` with `multiprocessing.Process`,\n",
    "and you will have true parallelism.\n",
    "\n",
    "\n",
    "Because your \"threads\" are now independent processes, you need to be a little\n",
    "careful about how to share information across them. Fortunately, the\n",
    "`multiprocessing` module provides [`Queue` and `Pipe`\n",
    "classes](https://docs.python.org/3.5/library/multiprocessing.html#exchanging-objects-between-processes)\n",
    "which make it easy to share data across processes.\n",
    "\n",
    "\n",
    "### Higher-level API - the `multiprocessing.Pool`\n",
    "\n",
    "\n",
    "The real advantages of `multiprocessing` lie in its higher level API, centered\n",
    "around the [`Pool`\n",
    "class](https://docs.python.org/3.5/library/multiprocessing.html#using-a-pool-of-workers).\n",
    "\n",
    "\n",
    "Essentially, you create a `Pool` of worker processes - you specify the number\n",
    "of processes when you create the pool.\n",
    "\n",
    "\n",
    "> The best number of processes to use for a `Pool` will depend on the system\n",
    "> you are running on (number of cores), and the tasks you are running (e.g.\n",
    "> I/O bound or CPU bound).\n",
    "\n",
    "\n",
    "Once you have created a `Pool`, you can use its methods to automatically\n",
    "parallelise tasks. The most useful are the `map`, `starmap` and\n",
    "`apply_async` methods.\n",
    "\n",
    "\n",
    "#### `Pool.map`\n",
    "\n",
    "\n",
    "The\n",
    "[`Pool.map`](https://docs.python.org/3.5/library/multiprocessing.html#multiprocessing.pool.Pool.map)\n",
    "method is the multiprocessing equivalent of the built-in\n",
    "[`map`](https://docs.python.org/3.5/library/functions.html#map) function - it\n",
    "is given a function, and a sequence, and it applies the function to each\n",
    "element in the sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import                    time\n",
    "import multiprocessing as mp\n",
    "import numpy           as np\n",
    "\n",
    "def crunchImage(imgfile):\n",
    "\n",
    "    # Load a nifti image, do stuff\n",
    "    # to it. Use your imagination\n",
    "    # to fill in this function.\n",
    "    time.sleep(2)\n",
    "\n",
    "    # numpy's random number generator\n",
    "    # will be initialised in the same\n",
    "    # way in each process, so let's\n",
    "    # re-seed it.\n",
    "    np.random.seed()\n",
    "    result = np.random.randint(1, 100, 1)\n",
    "\n",
    "    print(imgfile, ':', result)\n",
    "\n",
    "    return result\n",
    "\n",
    "imgfiles = ['{:02d}.nii.gz'.format(i) for i in range(20)]\n",
    "\n",
    "p = mp.Pool(processes=16)\n",
    "\n",
    "print('Crunching images...')\n",
    "\n",
    "start   = time.time()\n",
    "results = p.map(crunchImage, imgfiles)\n",
    "end     = time.time()\n",
    "\n",
    "print('Total execution time: {:0.2f} seconds'.format(end - start))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Pool.map` method only works with functions that accept one argument, such\n",
    "as our `crunchImage` function above. If you have a function which accepts\n",
    "multiple arguments, use the\n",
    "[`Pool.starmap`](https://docs.python.org/3.5/library/multiprocessing.html#multiprocessing.pool.Pool.starmap)\n",
    "method instead:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def crunchImage(imgfile, modality):\n",
    "    time.sleep(2)\n",
    "\n",
    "    np.random.seed()\n",
    "\n",
    "    if modality == 't1':\n",
    "        result = np.random.randint(1, 100, 1)\n",
    "    elif modality == 't2':\n",
    "        result = np.random.randint(100, 200, 1)\n",
    "\n",
    "    print(imgfile, ': ', result)\n",
    "\n",
    "    return result\n",
    "\n",
    "imgfiles   = ['t1_{:02d}.nii.gz'.format(i) for i in range(10)] + \\\n",
    "             ['t2_{:02d}.nii.gz'.format(i) for i in range(10)]\n",
    "modalities = ['t1'] * 10 + ['t2'] * 10\n",
    "\n",
    "pool = mp.Pool(processes=16)\n",
    "\n",
    "args = [(f, m) for f, m in zip(imgfiles, modalities)]\n",
    "\n",
    "print('Crunching images...')\n",
    "\n",
    "start   = time.time()\n",
    "results = pool.starmap(crunchImage, args)\n",
    "end     = time.time()\n",
    "\n",
    "print('Total execution time: {:0.2f} seconds'.format(end - start))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `map` and `starmap` methods also have asynchronous equivalents `map_async`\n",
    "and `starmap_async`, which return immediately. Refer to the\n",
    "[`Pool`](https://docs.python.org/3.5/library/multiprocessing.html#module-multiprocessing.pool)\n",
    "documentation for more details.\n",
    "\n",
    "\n",
    "#### `Pool.apply_async`\n",
    "\n",
    "\n",
    "The\n",
    "[`Pool.apply`](https://docs.python.org/3.5/library/multiprocessing.html#multiprocessing.pool.Pool.apply)\n",
    "method will execute a function on one of the processes, and block until it has\n",
    "finished.  The\n",
    "[`Pool.apply_async`](https://docs.python.org/3.5/library/multiprocessing.html#multiprocessing.pool.Pool.apply_async)\n",
    "method returns immediately, and is thus more suited to asynchronously\n",
    "scheduling multiple jobs to run in parallel.\n",
    "\n",
    "\n",
    "`apply_async` returns an object of type\n",
    "[`AsyncResult`](https://docs.python.org/3.5/library/multiprocessing.html#multiprocessing.pool.AsyncResult).\n",
    "An `AsyncResult` object has `wait` and `get` methods which will block until\n",
    "the job has completed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import                    time\n",
    "import multiprocessing as mp\n",
    "import numpy           as np\n",
    "\n",
    "\n",
    "def linear_registration(src, ref):\n",
    "    time.sleep(1)\n",
    "\n",
    "    return np.eye(4)\n",
    "\n",
    "def nonlinear_registration(src, ref, affine):\n",
    "\n",
    "    time.sleep(3)\n",
    "\n",
    "    # this number represents a non-linear warp\n",
    "    # field - use your imagination people!\n",
    "    np.random.seed()\n",
    "    return np.random.randint(1, 100, 1)\n",
    "\n",
    "t1s = ['{:02d}_t1.nii.gz'.format(i) for i in range(20)]\n",
    "std = 'MNI152_T1_2mm.nii.gz'\n",
    "\n",
    "pool = mp.Pool(processes=16)\n",
    "\n",
    "print('Running structural-to-standard registration '\n",
    "      'on {} subjects...'.format(len(t1s)))\n",
    "\n",
    "# Run linear registration on all the T1s.\n",
    "#\n",
    "# We build a list of AsyncResult objects\n",
    "linresults = [pool.apply_async(linear_registration, (t1, std))\n",
    "              for t1 in t1s]\n",
    "\n",
    "# Then we wait for each job to finish,\n",
    "# and replace its AsyncResult object\n",
    "# with the actual result - an affine\n",
    "# transformation matrix.\n",
    "start = time.time()\n",
    "for i, r in enumerate(linresults):\n",
    "    linresults[i] = r.get()\n",
    "end = time.time()\n",
    "\n",
    "print('Linear registrations completed in '\n",
    "      '{:0.2f} seconds'.format(end - start))\n",
    "\n",
    "# Run non-linear registration on all the T1s,\n",
    "# using the linear registrations to initialise.\n",
    "nlinresults = [pool.apply_async(nonlinear_registration, (t1, std, aff))\n",
    "               for (t1, aff) in zip(t1s, linresults)]\n",
    "\n",
    "# Wait for each non-linear reg to finish,\n",
    "# and store the resulting warp field.\n",
    "start = time.time()\n",
    "for i, r in enumerate(nlinresults):\n",
    "    nlinresults[i] = r.get()\n",
    "end = time.time()\n",
    "\n",
    "print('Non-linear registrations completed in '\n",
    "      '{:0.2f} seconds'.format(end - start))\n",
    "\n",
    "print('Non linear registrations:')\n",
    "for t1, result in zip(t1s, nlinresults):\n",
    "    print(t1, ':', result)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}