{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a19ed815",
   "metadata": {},
   "source": [
    "# Threading and parallel processing\n",
    "\n",
    "\n",
    "The Python language has built-in support for multi-threading in the\n",
    "[`threading`](https://docs.python.org/3/library/threading.html) module, and\n",
    "true parallelism in the\n",
    "[`multiprocessing`](https://docs.python.org/3/library/multiprocessing.html)\n",
    "module.  If you want to be impressed, skip straight to the section on\n",
    "[`multiprocessing`](#multiprocessing).\n",
    "\n",
    "\n",
    "> *Note*: This notebook covers features that are built into the Python\n",
    "> programming language. However, there are many other parallelisation options\n",
    "> available to you through third-party libraries - some of them are covered in `applications/parallel/parallel.ipynb`.\n",
    "\n",
    "\n",
    "> *Note*: If you are familiar with a \"real\" programming language such as C++\n",
    "> or Java, you might be disappointed with the native support for parallelism in\n",
    "> Python. Python threads do not run in parallel because of the Global\n",
    "> Interpreter Lock, and if you use `multiprocessing`, be prepared to either\n",
    "> bear the performance hit of copying data between processes, or jump through\n",
    "> hoops in order to share data between processes.\n",
    ">\n",
    "> This limitation is being addressed in recent Python versions, with\n",
    "> [_Free-threaded Python_ builds](https://docs.python.org/3/howto/free-threading-python.html),\n",
    "> which will hopefully soon be the default behaviour.\n",
    "\n",
    "\n",
    "* [Threading](#threading)\n",
    "  * [Subclassing `Thread`](#subclassing-thread)\n",
    "  * [Daemon threads](#daemon-threads)\n",
    "  * [Thread synchronisation](#thread-synchronisation)\n",
    "    * [`Lock`](#lock)\n",
    "    * [`Event`](#event)\n",
    "  * [The Global Interpreter Lock (GIL)](#the-global-interpreter-lock-gil)\n",
    "* [Multiprocessing](#multiprocessing)\n",
    "  * [`threading`-equivalent API](#threading-equivalent-api)\n",
    "  * [Higher-level API - the `multiprocessing.Pool`](#higher-level-api-the-multiprocessing-pool)\n",
    "    * [`Pool.map`](#pool-map)\n",
    "    * [`Pool.apply_async`](#pool-apply-async)\n",
    "* [Sharing data between processes](#sharing-data-between-processes)\n",
    "  * [Memory-mapping](#memory-mapping)\n",
    "  * [Read-only sharing](#read-only-sharing)\n",
    "  * [Read/write sharing](#read-write-sharing)\n",
    "\n",
    "\n",
    "<a class=\"anchor\" id=\"threading\"></a>\n",
    "## Threading\n",
    "\n",
    "\n",
    "The [`threading`](https://docs.python.org/3/library/threading.html) module\n",
    "provides a traditional multi-threading API that should be familiar to you if\n",
    "you have worked with threads in other languages.\n",
    "\n",
    "\n",
    "Running a task in a separate thread in Python is easy - simply create a\n",
    "`Thread` object, and pass it the function or method that you want it to\n",
    "run. Then call its `start` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50172fe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import threading\n",
    "\n",
    "def longRunningTask(niters):\n",
    "    for i in range(niters):\n",
    "        if i % 2 == 0: print('Tick')\n",
    "        else:          print('Tock')\n",
    "        time.sleep(0.5)\n",
    "\n",
    "t = threading.Thread(target=longRunningTask, args=(8,))\n",
    "\n",
    "t.start()\n",
    "\n",
    "while t.is_alive():\n",
    "    time.sleep(0.4)\n",
    "    print('Waiting for thread to finish...')\n",
    "print('Finished!')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9f0d52d",
   "metadata": {},
   "source": [
    "You can also `join` a thread, which will block execution in the current thread\n",
    "until the thread that has been `join`ed has finished:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5389e92d",
   "metadata": {},
   "outputs": [],
   "source": [
    "t = threading.Thread(target=longRunningTask, args=(6, ))\n",
    "t.start()\n",
    "\n",
    "print('Joining thread ...')\n",
    "t.join()\n",
    "print('Finished!')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ca8ed21",
   "metadata": {},
   "source": [
    "<a class=\"anchor\" id=\"subclassing-thread\"></a>\n",
    "### Subclassing `Thread`\n",
    "\n",
    "\n",
    "It is also possible to sub-class the `Thread` class, and override its `run`\n",
    "method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19bf75de",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LongRunningThread(threading.Thread):\n",
    "    def __init__(self, niters, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.niters = niters\n",
    "\n",
    "    def run(self):\n",
    "        for i in range(self.niters):\n",
    "            if i % 2 == 0: print('Tick')\n",
    "            else:          print('Tock')\n",
    "            time.sleep(0.5)\n",
    "\n",
    "t = LongRunningThread(6)\n",
    "t.start()\n",
    "t.join()\n",
    "print('Done')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61391158",
   "metadata": {},
   "source": [
    "<a class=\"anchor\" id=\"daemon-threads\"></a>\n",
    "### Daemon threads\n",
    "\n",
    "\n",
    "By default, a Python application will not exit until _all_ active threads have\n",
    "finished execution.  If you want to run a task in the background for the\n",
    "duration of your application, you can mark it as a `daemon` thread - when all\n",
    "non-daemon threads in a Python application have finished, all daemon threads\n",
    "will be halted, and the application will exit.\n",
    "\n",
    "\n",
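    "You can mark a thread as being a daemon either by passing `daemon=True` to\n",
    "the `Thread` constructor, or by setting its `daemon` attribute after creation\n",
    "(but before calling `start`):"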
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fcb1f65",
   "metadata": {},
   "outputs": [],
   "source": [
    "t = threading.Thread(target=longRunningTask, args=(8,))\n",
    "t.daemon = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69c1f604",
   "metadata": {},
   "source": [
    "See the [`Thread`\n",
    "documentation](https://docs.python.org/3/library/threading.html#thread-objects)\n",
    "for more details.\n",
    "\n",
    "\n",
    "<a class=\"anchor\" id=\"thread-synchronisation\"></a>\n",
    "### Thread synchronisation\n",
    "\n",
    "\n",
    "The `threading` module provides some useful thread-synchronisation\n",
    "primitives: the `Lock`, `RLock` (re-entrant `Lock`), and `Event` classes.\n",
    "The `threading` module also provides `Condition` and `Semaphore` classes;\n",
    "refer to the [documentation](https://docs.python.org/3/library/threading.html)\n",
    "for more details.\n",
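    "\n",
    "\n",
    "The difference between a `Lock` and an `RLock` is that an `RLock` may be\n",
    "acquired again by the thread that already holds it, whereas a thread which\n",
    "tries to acquire a `Lock` that it already holds will block forever. In this\n",
    "minimal sketch (the `outer` and `inner` functions are hypothetical), `inner`\n",
    "re-acquires a lock that `outer` is already holding - with a plain `Lock`,\n",
    "this would deadlock:\n",
    "\n",
    "```python\n",
    "import threading\n",
    "\n",
    "rlock = threading.RLock()\n",
    "log   = []\n",
    "\n",
    "def inner():\n",
    "    with rlock:\n",
    "        log.append('inner')\n",
    "\n",
    "def outer():\n",
    "    with rlock:\n",
    "        log.append('outer')\n",
    "        # Same thread acquires rlock a second time - allowed for an RLock\n",
    "        inner()\n",
    "\n",
    "t = threading.Thread(target=outer)\n",
    "t.start()\n",
    "t.join(timeout=5)\n",
    "print(log)  # ['outer', 'inner']\n",
    "```\n",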
    "\n",
    "\n",
    "<a class=\"anchor\" id=\"lock\"></a>\n",
    "#### `Lock`\n",
    "\n",
    "\n",
    "The [`Lock`](https://docs.python.org/3/library/threading.html#lock-objects)\n",
    "class (and its re-entrant version, the\n",
    "[`RLock`](https://docs.python.org/3/library/threading.html#rlock-objects))\n",
    "prevents a block of code from being accessed by more than one thread at a\n",
    "time. For example, if we have multiple threads running this `task` function,\n",
    "their [outputs](https://www.youtube.com/watch?v=F5fUFnfPpYU) will inevitably\n",
    "become intertwined:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7811d789",
   "metadata": {},
   "outputs": [],
   "source": [
    "def task():\n",
    "    for i in range(5):\n",
    "        print(f'{i} Woozle ', end='')\n",
    "        time.sleep(0.1)\n",
    "        print('Wuzzle')\n",
    "\n",
    "threads = [threading.Thread(target=task) for i in range(5)]\n",
    "for t in threads:\n",
    "    t.start()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f3c088c",
   "metadata": {},
   "source": [
    "But if we protect the critical section with a `Lock` object, the output will\n",
    "look more sensible:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20981b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "lock = threading.Lock()\n",
    "\n",
    "def task():\n",
    "\n",
    "    for i in range(5):\n",
    "        with lock:\n",
    "            print(f'{i} Woozle ', end='')\n",
    "            time.sleep(0.1)\n",
    "            print('Wuzzle')\n",
    "\n",
    "threads = [threading.Thread(target=task) for i in range(5)]\n",
    "for t in threads:\n",
    "    t.start()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f665943f",
   "metadata": {},
   "source": [
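    "> Instead of using a `Lock` object in a `with` statement, it is also possible\n",
    "> to manually call its `acquire` and `release` methods. In this case, use a\n",
    "> `try`/`finally` block to make sure that the lock is released even if an\n",
    "> error occurs (the `with` statement does this for you):\n",
    ">\n",
    ">     def task():\n",
    ">         for i in range(5):\n",
    ">             lock.acquire()\n",
    ">             try:\n",
    ">                 print(f'{i} Woozle ', end='')\n",
    ">                 time.sleep(0.1)\n",
    ">                 print('Wuzzle')\n",
    ">             finally:\n",
    ">                 lock.release()\n",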
    "\n",
    "\n",
    "Python does not have any built-in constructs to implement `Lock`-based mutual\n",
    "exclusion across several functions or methods - each function/method must\n",
    "explicitly acquire/release a shared `Lock` instance. However, it is relatively\n",
    "straightforward to implement a decorator which does this for you:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac049359",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mutex(func, lock):\n",
    "    # Wrap func so that it can only run while holding lock\n",
    "    def wrapper(*args, **kwargs):\n",
    "        with lock:\n",
    "            return func(*args, **kwargs)\n",
    "    return wrapper\n",
    "\n",
    "class MyClass(object):\n",
    "\n",
    "    def __init__(self):\n",
    "        lock = threading.Lock()\n",
    "        self.safeFunc1 = mutex(self.safeFunc1, lock)\n",
    "        self.safeFunc2 = mutex(self.safeFunc2, lock)\n",
    "\n",
    "    def safeFunc1(self):\n",
    "        time.sleep(0.1)\n",
    "        print('safeFunc1 start')\n",
    "        time.sleep(0.2)\n",
    "        print('safeFunc1 end')\n",
    "\n",
    "    def safeFunc2(self):\n",
    "        time.sleep(0.1)\n",
    "        print('safeFunc2 start')\n",
    "        time.sleep(0.2)\n",
    "        print('safeFunc2 end')\n",
    "\n",
    "mc = MyClass()\n",
    "\n",
    "f1threads = [threading.Thread(target=mc.safeFunc1) for i in range(4)]\n",
    "f2threads = [threading.Thread(target=mc.safeFunc2) for i in range(4)]\n",
    "\n",
    "for t in f1threads + f2threads:\n",
    "    t.start()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4eb2c3a8",
   "metadata": {},
   "source": [
    "Try removing the `mutex` lock from the two methods in the above code, and see\n",
    "what it does to the output.\n",
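    "\n",
    "\n",
    "The `mutex` function above is applied to each method inside `__init__`. If\n",
    "you prefer the `@decorator` syntax, one option is a decorator factory which\n",
    "takes the shared lock as an argument. This is a minimal sketch: the\n",
    "`synchronised` decorator and the `increment`/`decrement` functions are\n",
    "hypothetical names used for illustration, and `functools.wraps` preserves\n",
    "the name and docstring of each wrapped function:\n",
    "\n",
    "```python\n",
    "import functools\n",
    "import threading\n",
    "\n",
    "def synchronised(lock):\n",
    "    # Hypothetical decorator factory: every function decorated\n",
    "    # with the same lock is mutually exclusive with the others\n",
    "    def decorator(func):\n",
    "        @functools.wraps(func)\n",
    "        def wrapper(*args, **kwargs):\n",
    "            with lock:\n",
    "                return func(*args, **kwargs)\n",
    "        return wrapper\n",
    "    return decorator\n",
    "\n",
    "shared_lock = threading.Lock()\n",
    "counter     = {'value': 0}\n",
    "\n",
    "@synchronised(shared_lock)\n",
    "def increment():\n",
    "    counter['value'] += 1\n",
    "\n",
    "@synchronised(shared_lock)\n",
    "def decrement():\n",
    "    counter['value'] -= 1\n",
    "\n",
    "def worker():\n",
    "    for _ in range(1000):\n",
    "        increment()\n",
    "        decrement()\n",
    "        increment()\n",
    "\n",
    "threads = [threading.Thread(target=worker) for _ in range(8)]\n",
    "for t in threads:\n",
    "    t.start()\n",
    "for t in threads:\n",
    "    t.join()\n",
    "\n",
    "print(counter['value'])  # 8000\n",
    "```\n",
    "\n",
    "Note that if you apply a decorator like this to methods in a class body, a\n",
    "single lock is shared by *all* instances of the class, whereas the `mutex`\n",
    "approach above creates a separate lock for each instance.\n",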
    "\n",
    "\n",
    "<a class=\"anchor\" id=\"event\"></a>\n",
    "#### `Event`\n",
    "\n",
    "\n",
    "The\n",
    "[`Event`](https://docs.python.org/3/library/threading.html#event-objects)\n",
    "class is essentially a boolean [semaphore][semaphore-wiki]. It can be used to\n",
    "signal events between threads. Threads can `wait` on the event, and be awoken\n",
    "when the event is `set` by another thread:\n",
    "\n",
    "\n",
    "[semaphore-wiki]: https://en.wikipedia.org/wiki/Semaphore_(programming)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30b6989c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "processingFinished = threading.Event()\n",
    "\n",
    "def processData(data):\n",
    "    print('Processing data ...')\n",
    "    time.sleep(2)\n",
    "    print('Result:', data.mean())\n",
    "    processingFinished.set()\n",
    "\n",
    "data = np.random.randint(1, 100, 100)\n",
    "\n",
    "t = threading.Thread(target=processData, args=(data,))\n",
    "t.start()\n",
    "\n",
    "processingFinished.wait()\n",
    "print('Processing finished!')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1bb6596",
   "metadata": {},
   "source": [
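    "The example above uses an `Event` to signal that some work has finished.\n",
    "Another common pattern is to use an `Event` to ask a long-running worker\n",
    "thread to stop. The following minimal sketch (the `stop_requested` event and\n",
    "`worker` function are hypothetical) relies on the fact that `Event.wait`\n",
    "accepts a `timeout`, and returns `False` if the timeout expires before the\n",
    "event has been set:\n",
    "\n",
    "```python\n",
    "import threading\n",
    "import time\n",
    "\n",
    "stop_requested = threading.Event()\n",
    "ticks          = []\n",
    "\n",
    "def worker():\n",
    "    # wait() returns True as soon as the event is set, or\n",
    "    # False if the 0.1 second timeout expires first\n",
    "    while not stop_requested.wait(timeout=0.1):\n",
    "        ticks.append(time.monotonic())\n",
    "\n",
    "t = threading.Thread(target=worker)\n",
    "t.start()\n",
    "time.sleep(0.5)\n",
    "stop_requested.set()  # ask the worker to stop\n",
    "t.join()\n",
    "print(f'Worker stopped after {len(ticks)} ticks')\n",
    "```\n",
    "\n",
    "\n",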
    "<a class=\"anchor\" id=\"the-global-interpreter-lock-gil\"></a>\n",
    "### The Global Interpreter Lock (GIL)\n",
    "\n",
    "\n",
    "The [*Global Interpreter\n",
    "Lock*](https://docs.python.org/3/c-api/init.html#thread-state-and-the-global-interpreter-lock)\n",
    "is an implementation detail of [CPython](https://github.com/python/cpython)\n",
    "(the official Python interpreter).  The GIL means that a multi-threaded\n",
    "program written in pure Python is not able to take advantage of multiple\n",
    "cores - only one thread may be executing Python bytecode at any point in\n",
    "time.\n",
    "\n",
    "\n",
    "> Note that this is likely to change in future Python releases, with the\n",
    "> development of [_Free-threaded Python_](https://docs.python.org/3/howto/free-threading-python.html).\n",
    "\n",
    "\n",
    "The `threading` module does still have its uses though, as the GIL is\n",
    "released during many calls to system or natively compiled libraries (e.g.\n",
    "file and network I/O, many Numpy operations, etc.). So you can, for example,\n",
    "perform some expensive processing on a Numpy array in a thread running on one\n",
    "core, whilst having another thread (e.g. user interaction) running on another\n",
    "core.\n",
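    "\n",
    "\n",
    "As a rough illustration, the following sketch uses `time.sleep` as a\n",
    "stand-in for a blocking I/O call (such as reading a file or waiting on a\n",
    "network request). `time.sleep` releases the GIL while it waits, so four\n",
    "threads which each sleep for half a second should finish in about half a\n",
    "second in total, rather than two seconds. The `fake_io` function is\n",
    "hypothetical, and exact timings will vary between systems:\n",
    "\n",
    "```python\n",
    "import threading\n",
    "import time\n",
    "\n",
    "def fake_io(results, idx):\n",
    "    # Stand-in for a blocking I/O call - the GIL is\n",
    "    # released while the thread is sleeping\n",
    "    time.sleep(0.5)\n",
    "    results[idx] = idx * 10\n",
    "\n",
    "results = [None] * 4\n",
    "threads = [threading.Thread(target=fake_io, args=(results, i))\n",
    "           for i in range(4)]\n",
    "\n",
    "start = time.time()\n",
    "for t in threads:\n",
    "    t.start()\n",
    "for t in threads:\n",
    "    t.join()\n",
    "elapsed = time.time() - start\n",
    "\n",
    "print(results)  # [0, 10, 20, 30]\n",
    "print(f'Elapsed: {elapsed:0.2f} seconds')\n",
    "```\n",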
    "\n",
    "\n",
    "<a class=\"anchor\" id=\"multiprocessing\"></a>\n",
    "## Multiprocessing\n",
    "\n",
    "\n",
    "For true parallelism, you should check out the\n",
    "[`multiprocessing`](https://docs.python.org/3/library/multiprocessing.html)\n",
    "module.\n",
    "\n",
    "\n",
    "The `multiprocessing` module spawns sub-processes, rather than threads, and so\n",
    "is not subject to the GIL constraints that the `threading` module suffers\n",
    "from. It provides two APIs - a \"traditional\" equivalent to that provided by\n",
    "the `threading` module, and a powerful higher-level API.\n",
    "\n",
    "\n",
    "> Python also provides the\n",
    "> [`concurrent.futures`](https://docs.python.org/3/library/concurrent.futures.html)\n",
    "> module, which offers a simpler alternative API to `multiprocessing`. It\n",
    "> offers no additional functionality over `multiprocessing`, so it is not\n",
    "> covered here.\n",
    "\n",
    "\n",
    "<a class=\"anchor\" id=\"threading-equivalent-api\"></a>\n",
    "### `threading`-equivalent API\n",
    "\n",
    "\n",
    "The\n",
    "[`Process`](https://docs.python.org/3/library/multiprocessing.html#the-process-class)\n",
    "class is the `multiprocessing` equivalent of the\n",
    "[`threading.Thread`](https://docs.python.org/3/library/threading.html#thread-objects)\n",
    "class.  `multiprocessing` also has equivalents of the [`Lock` and `Event`\n",
    "classes](https://docs.python.org/3/library/multiprocessing.html#synchronization-between-processes),\n",
    "and the other synchronisation primitives provided by `threading`.\n",
    "\n",
    "\n",
    "So you can simply replace `threading.Thread` with `multiprocessing.Process`,\n",
    "and you will have true parallelism.\n",
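    "\n",
    "\n",
    "A minimal sketch of this, written as a standalone script (the `square`\n",
    "function is hypothetical), is shown below. Note the\n",
    "`if __name__ == '__main__':` guard: with the `spawn` start method (the\n",
    "default on macOS and Windows), each child process re-imports the main\n",
    "module, so code which creates processes must not run at import time. The\n",
    "result from each process is sent back to the parent through a\n",
    "`multiprocessing.Queue`:\n",
    "\n",
    "```python\n",
    "import multiprocessing as mp\n",
    "\n",
    "def square(x, queue):\n",
    "    # Runs in a child process - send the result back via the queue\n",
    "    queue.put((x, x * x))\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    queue = mp.Queue()\n",
    "    procs = [mp.Process(target=square, args=(i, queue)) for i in range(4)]\n",
    "    for p in procs:\n",
    "        p.start()\n",
    "    # Collect one result per process before joining them\n",
    "    results = dict(queue.get() for _ in procs)\n",
    "    for p in procs:\n",
    "        p.join()\n",
    "    print(sorted(results.items()))  # [(0, 0), (1, 1), (2, 4), (3, 9)]\n",
    "```\n",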
    "\n",
    "\n",
    "Because your \"threads\" are now independent processes, you need to be a little\n",
    "careful about how to share information across them. If you only need to share\n",
    "small amounts of data, you can use the [`Queue` and `Pipe`\n",
    "classes](https://docs.python.org/3/library/multiprocessing.html#exchanging-objects-between-processes)\n",
    "in the `multiprocessing` module. If you are working with large amounts of data\n",
    "where copying between processes is not feasible, things become more\n",
    "complicated, but read on...\n",
    "\n",
    "\n",
    "<a class=\"anchor\" id=\"higher-level-api-the-multiprocessing-pool\"></a>\n",
    "### Higher-level API - the `multiprocessing.Pool`\n",
    "\n",
    "\n",
    "The real advantages of `multiprocessing` lie in its higher level API, centered\n",
    "around the [`Pool`\n",
    "class](https://docs.python.org/3/library/multiprocessing.html#using-a-pool-of-workers).\n",
    "\n",
    "\n",
    "Essentially, you create a `Pool` of worker processes - you specify the number\n",
    "of processes when you create the pool. Once you have created a `Pool`, you can\n",
    "use its methods to automatically parallelise tasks. The most useful are the\n",
    "`map`, `starmap` and `apply_async` methods.\n",
    "\n",
    "\n",
    "The `Pool` class is a context manager, so it can be used in a `with` statement,\n",
    "e.g.:\n",
    "\n",
    "> ```\n",
    "> with mp.Pool(processes=16) as pool:\n",
    ">     # do stuff with the pool\n",
    "> ```\n",
    "\n",
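    "It is possible to create a `Pool` outside of a `with` statement, but in this\n",
    "case you must ensure that you call its `close` method (followed by `join`, to\n",
    "wait for the worker processes to exit) when you are finished. Using a `Pool`\n",
    "in a `with` statement is therefore recommended, because you know that it will\n",
    "be shut down correctly, even in the event of an error. Note that on exit, the\n",
    "`with` statement calls the `Pool.terminate` method, which stops the worker\n",
    "processes immediately - make sure that you have retrieved all of your results\n",
    "before the `with` block ends.\n",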
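    "\n",
    "\n",
    "For reference, a minimal sketch of using a `Pool` without a `with` statement\n",
    "might look like this (the `cube` function is hypothetical). The\n",
    "`try`/`finally` block ensures that the pool is shut down even if an error\n",
    "occurs; `close` stops the pool from accepting new tasks, and `join` waits for\n",
    "the worker processes to exit:\n",
    "\n",
    "```python\n",
    "import multiprocessing as mp\n",
    "\n",
    "def cube(x):\n",
    "    return x ** 3\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    pool = mp.Pool(processes=4)\n",
    "    try:\n",
    "        results = pool.map(cube, range(6))\n",
    "    finally:\n",
    "        # No more tasks will be submitted - wait for the workers to exit\n",
    "        pool.close()\n",
    "        pool.join()\n",
    "    print(results)  # [0, 1, 8, 27, 64, 125]\n",
    "```\n",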
    "\n",
    "\n",
    "> The best number of processes to use for a `Pool` will depend on the system\n",
    "> you are running on (number of cores), and the tasks you are running (e.g.\n",
    "> I/O bound or CPU bound). If you do not specify the number of processes when\n",
    "> creating a `Pool`, it will default to the number of cores on your machine.\n",
    "\n",
    "\n",
    "<a class=\"anchor\" id=\"pool-map\"></a>\n",
    "#### `Pool.map`\n",
    "\n",
    "\n",
    "The\n",
    "[`Pool.map`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.Pool.map)\n",
    "method is the multiprocessing equivalent of the built-in\n",
    "[`map`](https://docs.python.org/3/library/functions.html#map) function - it\n",
    "is given a function, and a sequence, and it applies the function to each\n",
    "element in the sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "025a74c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import                    time\n",
    "import multiprocessing as mp\n",
    "import numpy           as np\n",
    "\n",
    "def crunchImage(imgfile):\n",
    "\n",
    "    # Load a nifti image and calculate some\n",
    "    # metric from the image. Use your\n",
    "    # imagination to fill in this function.\n",
    "    time.sleep(2)\n",
    "    np.random.seed()\n",
    "    result = np.random.randint(1, 100, 1)[0]\n",
    "\n",
    "    return result\n",
    "\n",
    "\n",
    "imgfiles = [f'{i:02d}.nii.gz' for i in range(20)]\n",
    "\n",
    "print(f'Crunching {len(imgfiles)} images...')\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "with mp.Pool(processes=16) as p:\n",
    "    results = p.map(crunchImage, imgfiles)\n",
    "\n",
    "end = time.time()\n",
    "\n",
    "for imgfile, result in zip(imgfiles, results):\n",
    "    print(f'Result for {imgfile}: {result}')\n",
    "\n",
    "print(f'Total execution time: {end - start:0.2f} seconds')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f363f85a",
   "metadata": {},
   "source": [
    "The `Pool.map` method only works with functions that accept one argument, such\n",
    "as our `crunchImage` function above. If you have a function which accepts\n",
    "multiple arguments, use the\n",
    "[`Pool.starmap`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.Pool.starmap)\n",
    "method instead:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c23af106",
   "metadata": {},
   "outputs": [],
   "source": [
    "def crunchImage(imgfile, modality):\n",
    "    time.sleep(2)\n",
    "\n",
    "    np.random.seed()\n",
    "\n",
    "    if modality == 't1':\n",
    "        result = np.random.randint(1, 100, 1)\n",
    "    elif modality == 't2':\n",
    "        result = np.random.randint(100, 200, 1)\n",
    "    else:\n",
    "        raise ValueError(f'Unknown modality: {modality}')\n",
    "\n",
    "    return result[0]\n",
    "\n",
    "\n",
    "imgfiles   = [f't1_{i:02d}.nii.gz' for i in range(10)] + \\\n",
    "             [f't2_{i:02d}.nii.gz' for i in range(10)]\n",
    "modalities = ['t1'] * 10 + ['t2'] * 10\n",
    "\n",
    "args = [(f, m) for f, m in zip(imgfiles, modalities)]\n",
    "\n",
    "print('Crunching images...')\n",
    "\n",
    "start = time.time()\n",
    "\n",
    "with mp.Pool(processes=16) as pool:\n",
    "    results = pool.starmap(crunchImage, args)\n",
    "\n",
    "end = time.time()\n",
    "\n",
    "for imgfile, modality, result in zip(imgfiles, modalities, results):\n",
    "    print(f'{imgfile} [{modality}]: {result}')\n",
    "\n",
    "print('Total execution time: {:0.2f} seconds'.format(end - start))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e91c1ea5",
   "metadata": {},
   "source": [
    "The `map` and `starmap` methods also have asynchronous equivalents `map_async`\n",
    "and `starmap_async`, which return immediately. Refer to the\n",
    "[`Pool`](https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing.pool)\n",
    "documentation for more details.\n",
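    "\n",
    "\n",
    "For example, in this minimal sketch (the `slow_square` function is\n",
    "hypothetical), the main process can carry on with other work after calling\n",
    "`map_async`, and only blocks when it calls `get` on the returned\n",
    "`AsyncResult`. The results are retrieved inside the `with` block, because the\n",
    "pool is terminated when the block ends:\n",
    "\n",
    "```python\n",
    "import multiprocessing as mp\n",
    "import time\n",
    "\n",
    "def slow_square(x):\n",
    "    time.sleep(0.2)\n",
    "    return x * x\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    with mp.Pool(processes=4) as pool:\n",
    "        # map_async returns an AsyncResult object straight away\n",
    "        async_result = pool.map_async(slow_square, range(8))\n",
    "        print('Tasks submitted - the main process is free to do other work')\n",
    "        # get() blocks until all of the tasks have finished\n",
    "        results = async_result.get(timeout=60)\n",
    "    print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]\n",
    "```\n",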
    "\n",
    "\n",
    "<a class=\"anchor\" id=\"pool-apply-async\"></a>\n",
    "#### `Pool.apply_async`\n",
    "\n",
    "\n",
    "The\n",
    "[`Pool.apply`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.Pool.apply)\n",
    "method will execute a function on one of the processes, and block until it has\n",
    "finished.  The\n",
    "[`Pool.apply_async`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.Pool.apply_async)\n",
    "method returns immediately, and is thus more suited to asynchronously\n",
    "scheduling multiple jobs to run in parallel.\n",
    "\n",
    "\n",
    "`apply_async` returns an object of type\n",
    "[`AsyncResult`](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.AsyncResult).\n",
    "An `AsyncResult` object has `wait` and `get` methods which will block until\n",
    "the job has completed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55e69074",
   "metadata": {},
   "outputs": [],
   "source": [
    "import                    time\n",
    "import multiprocessing as mp\n",
    "import numpy           as np\n",
    "\n",
    "\n",
    "def linear_registration(src, ref):\n",
    "    time.sleep(1)\n",
    "    return np.eye(4)\n",
    "\n",
    "\n",
    "def nonlinear_registration(src, ref, affine):\n",
    "\n",
    "    time.sleep(3)\n",
    "\n",
    "    # this number represents a non-linear warp\n",
    "    # field - use your imagination people!\n",
    "    np.random.seed()\n",
    "    return np.random.randint(1, 100, 1)[0]\n",
    "\n",
    "\n",
    "t1s = [f'{i:02d}_t1.nii.gz' for i in range(20)]\n",
    "std = 'MNI152_T1_2mm.nii.gz'\n",
    "\n",
    "print('Running structural-to-standard registration '\n",
    "      f'on {len(t1s)} subjects...')\n",
    "\n",
    "\n",
    "# Run linear registration on all the T1s.\n",
    "start = time.time()\n",
    "with mp.Pool(processes=16) as pool:\n",
    "\n",
    "    # We build a list of AsyncResult objects\n",
    "    linresults = [pool.apply_async(linear_registration, (t1, std))\n",
    "                  for t1 in t1s]\n",
    "\n",
    "    # Then we wait for each job to finish,\n",
    "    # and replace its AsyncResult object\n",
    "    # with the actual result - an affine\n",
    "    # transformation matrix.\n",
    "    for i, r in enumerate(linresults):\n",
    "        linresults[i] = r.get()\n",
    "\n",
    "\n",
    "end = time.time()\n",
    "\n",
    "print('Linear registrations completed in '\n",
    "      f'{end - start:0.2f} seconds')\n",
    "\n",
    "\n",
    "# Run non-linear registration on all the T1s,\n",
    "# using the linear registrations to initialise.\n",
    "start = time.time()\n",
    "with mp.Pool(processes=16) as pool:\n",
    "    nlinresults = [pool.apply_async(nonlinear_registration, (t1, std, aff))\n",
    "                   for (t1, aff) in zip(t1s, linresults)]\n",
    "\n",
    "    # Wait for each non-linear reg to finish,\n",
    "    # and store the resulting warp field.\n",
    "    for i, r in enumerate(nlinresults):\n",
    "        nlinresults[i] = r.get()\n",
    "\n",
    "\n",
    "end = time.time()\n",
    "\n",
    "print('Non-linear registrations completed in '\n",
    "      '{:0.2f} seconds'.format(end - start))\n",
    "\n",
    "print('Non linear registrations:')\n",
    "for t1, result in zip(t1s, nlinresults):\n",
    "    print(f'{t1} : {result}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c855c8d",
   "metadata": {},
   "source": [
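    "One more detail about `AsyncResult.get`: if the function raised an exception\n",
    "in the worker process, calling `get` re-raises that exception in the parent\n",
    "process, so you can handle errors in the usual way. A minimal sketch, using\n",
    "a hypothetical `divide` function:\n",
    "\n",
    "```python\n",
    "import multiprocessing as mp\n",
    "\n",
    "def divide(a, b):\n",
    "    return a / b\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    with mp.Pool(processes=2) as pool:\n",
    "        jobs     = [pool.apply_async(divide, (10, d)) for d in (2, 0, 5)]\n",
    "        outcomes = []\n",
    "        for job in jobs:\n",
    "            try:\n",
    "                outcomes.append(job.get(timeout=60))\n",
    "            except ZeroDivisionError as e:\n",
    "                # The exception from the worker is re-raised here\n",
    "                outcomes.append(f'failed: {e}')\n",
    "    print(outcomes)\n",
    "```\n",
    "\n",
    "\n",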
    "<a class=\"anchor\" id=\"sharing-data-between-processes\"></a>\n",
    "## Sharing data between processes\n",
    "\n",
    "\n",
    "When you use the `Pool.map` method (or any of the other methods we have shown)\n",
    "to run a function on a sequence of items, each item must be copied into the\n",
    "memory of the child process that handles it. When the child processes are\n",
    "finished, the data that they return then has to be copied back to the parent\n",
    "process.\n",
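    "\n",
    "\n",
    "One consequence of this copying is that a worker process cannot modify the\n",
    "parent's objects in place. In this minimal sketch (the `append_item` function\n",
    "is hypothetical), the worker modifies its own copy of a list, and the parent\n",
    "only sees the change in the copy that is returned:\n",
    "\n",
    "```python\n",
    "import multiprocessing as mp\n",
    "\n",
    "def append_item(lst):\n",
    "    # lst is a copy that was unpickled in the worker process\n",
    "    lst.append('added in child')\n",
    "    return lst\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    original = ['a', 'b']\n",
    "    with mp.Pool(processes=2) as pool:\n",
    "        returned = pool.apply(append_item, (original,))\n",
    "    print(original)  # ['a', 'b']\n",
    "    print(returned)  # ['a', 'b', 'added in child']\n",
    "```\n",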
    "\n",
    "\n",
    "Any items which you wish to pass to a function that is executed by a `Pool`\n",
    "must be *pickleable*<sup>1</sup> - the built-in\n",
    "[`pickle`](https://docs.python.org/3/library/pickle.html) module is used by\n",
    "`multiprocessing` to serialise and de-serialise the data passed to and\n",
    "returned from a child process. The majority of standard Python types (`list`,\n",
    "`dict`, `str`, etc.) and Numpy arrays can be pickled and unpickled, so you only\n",
    "need to worry about this detail if you are passing objects of a custom type\n",
    "(e.g. instances of classes that you have written, or that are defined in some\n",
    "third-party library).\n",
    "\n",
    "\n",
    "> <sup>1</sup>*Pickleable* is the term used in the Python world to refer to\n",
    "> something that is *serialisable* - basically, the process of converting an\n",
    "> in-memory object into a binary form that can be stored and/or transmitted,\n",
    "> and then loaded back into memory at some point in the future (in the same\n",
    "> process, or in another process).\n",
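    "\n",
    "\n",
    "You can check whether an object is pickleable by pickling it yourself with\n",
    "the `pickle` module. For example, a `dict` of built-in types survives a round\n",
    "trip (producing a new, equal object), whereas a `lambda` function cannot be\n",
    "pickled, which is why it cannot be passed to `Pool.map`. A minimal sketch:\n",
    "\n",
    "```python\n",
    "import pickle\n",
    "\n",
    "data     = {'subject': 'sub-01', 'values': [1.5, 2.5]}\n",
    "blob     = pickle.dumps(data)  # serialise to bytes\n",
    "restored = pickle.loads(blob)  # de-serialise into a new object\n",
    "print(restored == data, restored is data)  # True False\n",
    "\n",
    "try:\n",
    "    pickle.dumps(lambda x: x + 1)\n",
    "    lambda_ok = True\n",
    "except (pickle.PicklingError, AttributeError, TypeError):\n",
    "    lambda_ok = False\n",
    "print('lambda pickleable:', lambda_ok)  # lambda pickleable: False\n",
    "```\n",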
    "\n",
    "\n",
    "There is obviously some overhead in copying data back and forth between the\n",
    "main process and the worker processes; this may or may not be a problem.  For\n",
    "most computationally intensive tasks, this communication overhead is not\n",
    "important - the performance bottleneck is typically going to be the\n",
    "computation time, rather than I/O between the parent and child processes.\n",
    "\n",
    "\n",
    "However, if you are working with a large data set, where copying it between\n",
    "processes is not viable, you have a couple of options available to you.\n",
    "\n",
    "\n",
    "<a class=\"anchor\" id=\"memory-mapping\"></a>\n",
    "### Memory-mapping\n",
    "\n",
    "\n",
    "One method for sharing a large `numpy` array between multiple processes is to\n",
    "use a _memory-mapped_ array. This is a feature built into `numpy` which\n",
    "stores your data in a regular file, instead of in memory.  This allows your\n",
    "data to be simultaneously read and written by multiple processes, and is fairly\n",
    "straightforward to accomplish.\n",
    "\n",
    "For example, let's say you have some 4D fMRI data, and wish to fit a\n",
    "complicated model to the time series at each voxel.  First we will load our 4D\n",
    "data, and pre-allocate another array to store the fitted model parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46938bac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import                    time\n",
    "import functools       as ft\n",
    "import multiprocessing as mp\n",
    "import numpy           as np\n",
    "\n",
    "# Store the parameters that are required\n",
    "# to create our memory-mapped arrays, as\n",
    "# we need to re-use them a couple of times.\n",
    "#\n",
    "# Note that in practice you would usually\n",
    "# want to store these files in a temporary\n",
    "# directory, and/or ensure that they are\n",
    "# deleted once you are finished.\n",
    "data_params  = dict(filename='data.mmap',  shape=(91, 109, 91, 50), dtype=np.float32)\n",
    "model_params = dict(filename='model.mmap', shape=(91, 109, 91),     dtype=np.float32)\n",
    "\n",
    "# Load our data as a memory-mapped array (we\n",
    "# are using random data for this example)\n",
    "data    = np.memmap(**data_params, mode='w+')\n",
    "data[:] = np.random.random((91, 109, 91, 50)).astype(np.float32)\n",
    "data.flush()\n",
    "\n",
    "# Pre-allocate space to store the fitted\n",
    "# model parameters\n",
    "model = np.memmap(**model_params, mode='w+')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89897973",
   "metadata": {},
   "source": [
    "> If your image files are uncompressed (i.e. `.nii` rather than `.nii.gz`),\n",
    "> you can instruct `nibabel` and `fslpy` to load them as a memory-map by\n",
    "> passing `mmap=True` to the `nibabel.load` function, and the\n",
    "> `fsl.data.image.Image` constructor.\n",
    "\n",
    "\n",
    "Now we will write our model fitting function so that it works on one slice at\n",
    "a time - this will allow us to process multiple slices in parallel. Note\n",
    "that, within this function, we have to _re-load_ the memory-mapped arrays. In\n",
    "this example we have written this function so as to expect the arguments\n",
    "required to create the two memory-maps to be passed in (the `data_params` and\n",
    "`model_params` dictionaries that we created above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "037dbf6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_model(indata, outdata, sliceidx):\n",
    "\n",
    "    indata  = np.memmap(**indata,  mode='r')\n",
    "    outdata = np.memmap(**outdata, mode='r+')\n",
    "\n",
    "    # sleep to simulate expensive model fitting\n",
    "    print(f'Fitting model at slice {sliceidx}')\n",
    "    time.sleep(1)\n",
    "    outdata[:, :, sliceidx] = indata[:, :, sliceidx, :].mean() + sliceidx\n",
    "    # make sure that our changes are written to the file\n",
    "    outdata.flush()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb4b7f8f",
   "metadata": {},
   "source": [
    "Now we can use `multiprocessing` to fit the model in parallel across all of the\n",
    "image slices:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba54b93c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_function = ft.partial(fit_model, data_params, model_params)\n",
    "slice_idxs   = list(range(91))\n",
    "\n",
    "with mp.Pool(processes=16) as pool:\n",
    "    pool.map(fit_function, slice_idxs)\n",
    "\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "145a22d2",
   "metadata": {},
   "source": [
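    "As mentioned in the comments of the memory-mapping example, in practice you\n",
    "would usually keep your memory-mapped files in a temporary directory, so that\n",
    "they are removed when you are finished. A minimal sketch using the standard\n",
    "`tempfile` module (the file name and array shape here are arbitrary):\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "import numpy as np\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmpdir:\n",
    "    params = dict(filename=os.path.join(tmpdir, 'example.mmap'),\n",
    "                  shape=(4, 5), dtype=np.float32)\n",
    "    arr    = np.memmap(**params, mode='w+')\n",
    "    arr[:] = 1\n",
    "    arr.flush()\n",
    "    # Re-open the file read-only, as a worker process would\n",
    "    check = np.memmap(**params, mode='r')\n",
    "    total = float(check.sum())\n",
    "    # Release the memory-maps before the directory is deleted\n",
    "    del arr, check\n",
    "\n",
    "print(total)                   # 20.0\n",
    "print(os.path.exists(tmpdir))  # False\n",
    "```\n",
    "\n",
    "\n",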
    "<a class=\"anchor\" id=\"read-only-sharing\"></a>\n",
    "### Read-only sharing\n",
    "\n",
    "\n",
    "Suppose that you are working with a large dataset, you have determined that\n",
    "copying data between processes is having a substantial impact on your\n",
    "performance, and you have also decided that memory-mapping is not an option\n",
    "for you. If you instead wish to *share* a single copy of the data between the\n",
    "processes, you will need to:\n",
    "\n",
    " 1. Structure your code so that the data you want to share is accessible at\n",
    "    the *module level*.\n",
    " 2. Define/create/load the data *before* creating the `Pool`.\n",
    "\n",
    "\n",
    "This is because, when you create a `Pool` using the `fork` start method\n",
    "(historically the default on Linux, but not on macOS or Windows), the process\n",
    "your Python script is running in will [**fork**][wiki-fork] itself - the child\n",
    "processes that are created are used as the worker processes by the `Pool`. If\n",
    "you create/load your data in your main process *before* this fork occurs, all\n",
    "of the child processes will inherit the memory space of the main process, and\n",
    "will therefore have (read-only) access to the data, without any copying\n",
    "required (memory pages are only copied if a process writes to them).\n",
    "\n",
    "\n",
    "[wiki-fork]: https://en.wikipedia.org/wiki/Fork_(system_call)\n",
    "\n",
    "\n",
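    "The same idea can be shown in a few lines. This is a minimal sketch that\n",
    "assumes a POSIX platform where the `fork` start method is available (it is\n",
    "not available on Windows). It requests `fork` explicitly via `mp.get_context`\n",
    "instead of relying on the platform default. The names `shared_values`,\n",
    "`double_item` and `run_with_fork` are illustrative only:\n",
    "\n",
    "```python\n",
    "import multiprocessing as mp\n",
    "\n",
    "# Illustrative sketch - assumes the 'fork' start method (POSIX only).\n",
    "# Module-level data, created *before* the pool,\n",
    "# so that the forked workers inherit it\n",
    "shared_values = list(range(10))\n",
    "\n",
    "def double_item(i):\n",
    "    # Only the index is sent to the worker - the\n",
    "    # data itself is read from the inherited memory\n",
    "    return shared_values[i] * 2\n",
    "\n",
    "def run_with_fork(nprocs=2):\n",
    "    # Explicitly use the 'fork' start method,\n",
    "    # regardless of the platform default\n",
    "    ctx = mp.get_context('fork')\n",
    "    with ctx.Pool(nprocs) as pool:\n",
    "        return pool.map(double_item, range(len(shared_values)))\n",
    "\n",
    "print('Default start method:', mp.get_start_method())\n",
    "print(run_with_fork())\n",
    "```\n",
    "\n",
    "\n",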
    "Let's see this in action with a simple example. We'll start by defining a\n",
    "horrible little helper function which allows us to track the total memory\n",
    "usage:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd36f0b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import subprocess as sp\n",
    "def memusage(msg):\n",
    "    if sys.platform == 'darwin':\n",
    "        total = sp.run(['sysctl', 'hw.memsize'], capture_output=True).stdout.decode()\n",
    "        total = int(total.split()[1]) // 1048576\n",
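    "        usage = sp.run('vm_stat', capture_output=True).stdout.decode()\n",
    "        usage = usage.strip().split('\\n')\n",
    "        # The first line reports the page size, which is\n",
    "        # 4096 bytes on Intel Macs but 16384 on Apple Silicon\n",
    "        pagesize = int(usage[0].split('page size of')[1].split()[0])\n",
    "        usage = [l.split(':') for l in usage[1:]]\n",
    "        usage = {k.strip() : v.strip() for k, v in usage}\n",
    "        usage = int(usage['Pages free'][:-1]) * pagesize // 1048576\n",
    "        usage = int(total - usage)\n",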
    "    else:\n",
    "        stdout = sp.run(['free', '--mega'], capture_output=True).stdout.decode()\n",
    "        stdout = stdout.split('\\n')[1].split()\n",
    "        total  = int(stdout[1])\n",
    "        usage  = int(stdout[2])\n",
    "    print(f'Memory usage {msg}: {usage} / {total} MB')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f50cbe16",
   "metadata": {},
   "source": [
    "Now our task is simply to calculate the sum of a large array of numbers. We're\n",
    "going to create a big chunk of data, and process it in chunks, keeping track\n",
    "of memory usage as the task progresses:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa189b29",
   "metadata": {},
   "outputs": [],
   "source": [
    "import                    time\n",
    "import multiprocessing as mp\n",
    "import numpy           as np\n",
    "\n",
    "memusage('before creating data')\n",
    "\n",
    "# allocate 500MB of data\n",
    "data = np.random.random(500 * (1048576 // 8))\n",
    "\n",
    "# Assign nelems values to each job\n",
    "# (hard-coded so that we need 12 jobs,\n",
    "# plus one for any leftover values)\n",
    "nelems = len(data) // 12\n",
    "\n",
    "memusage('after creating data')\n",
    "\n",
    "# Each job processes nelems values,\n",
    "# starting from the specified offset\n",
    "def process_chunk(offset):\n",
    "    time.sleep(1)\n",
    "    return data[offset:offset + nelems].sum()\n",
    "\n",
    "# Generate an offset into the data for each job -\n",
    "# we will call process_chunk for each offset\n",
    "offsets = range(0, len(data), nelems)\n",
    "\n",
    "# Create our worker process pool\n",
    "with mp.Pool(4) as pool:\n",
    "\n",
    "    results = pool.map_async(process_chunk, offsets)\n",
    "\n",
    "    # Wait for all of the jobs to finish\n",
    "    elapsed = 0\n",
    "    while not results.ready():\n",
    "        memusage(f'after {elapsed} seconds')\n",
    "        time.sleep(1)\n",
    "        elapsed += 1\n",
    "\n",
    "    results = results.get()\n",
    "\n",
    "print('Total sum:   ', sum(results))\n",
    "print('Sanity check:', data.sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23ab2138",
   "metadata": {},
   "source": [
    "You should be able to see that only one copy of `data` is created, and is\n",
    "shared by all of the worker processes without any copying taking place.\n",
    "\n",
    "So things are reasonably straightforward if you only need read-only access to\n",
    "your data. But what if your worker processes need to be able to modify the\n",
    "data? Go back to the code block above and:\n",
    "\n",
    "1. Modify the `process_chunk` function so that it modifies every element of\n",
    "   its assigned portion of the data before the call to `time.sleep`.  For\n",
    "   example:\n",
    "\n",
    "   > ```\n",
    "   > data[offset:offset + nelems] += 1\n",
    "   > ```\n",
    "\n",
    "2. Restart the Jupyter notebook kernel (*Kernel -> Restart*) - this example is\n",
    "   somewhat dependent on the behaviour of the Python garbage collector, so it\n",
    "   helps to start afresh.\n",
    "\n",
    "3. Re-run the two code blocks, and watch what happens to the memory usage.\n",
    "\n",
    "\n",
    "What happened? Well, you are seeing [copy-on-write][wiki-copy-on-write] in\n",
    "action. When the `process_chunk` function is invoked in a worker process, the\n",
    "`data` array it refers to lives in memory pages that are still shared with the\n",
    "parent process. But as soon as the worker attempts to modify those pages, the\n",
    "operating system creates a private copy of them in the memory space of the\n",
    "child process. The modifications are then applied to this child process copy,\n",
    "and not to the original copy. So the total memory usage blows out to roughly\n",
    "twice as much as before, and the changes made by each child process are lost!\n",
    "\n",
    "\n",
    "[wiki-copy-on-write]: https://en.wikipedia.org/wiki/Copy-on-write\n",
    "\n",
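    "If you want to confirm that writes made by a worker are never seen by the\n",
    "parent, here is a small sketch. It again assumes that the `fork` start method\n",
    "is available, and `counts`, `increment_slice` and `run_cow_demo` are\n",
    "illustrative names. Each job increments a disjoint two-element slice of a\n",
    "module-level array, and returns the sum of that slice as seen *inside the worker*:\n",
    "\n",
    "```python\n",
    "import multiprocessing as mp\n",
    "import numpy as np\n",
    "\n",
    "# Illustrative sketch - assumes the 'fork' start method (POSIX only).\n",
    "# Module-level array, inherited by the forked workers\n",
    "counts = np.zeros(8)\n",
    "\n",
    "def increment_slice(offset):\n",
    "    # Writing triggers copy-on-write, so this only\n",
    "    # modifies the worker's private copy of the pages\n",
    "    counts[offset:offset + 2] += 1\n",
    "    return float(counts[offset:offset + 2].sum())\n",
    "\n",
    "def run_cow_demo():\n",
    "    ctx = mp.get_context('fork')\n",
    "    with ctx.Pool(2) as pool:\n",
    "        worker_sums = pool.map(increment_slice, range(0, len(counts), 2))\n",
    "    return worker_sums\n",
    "\n",
    "print('Sums seen by workers:', run_cow_demo())\n",
    "print('Sum seen by parent:  ', counts.sum())\n",
    "```\n",
    "\n",
    "Each worker sees its own increments (every slice sums to `2.0`), but the\n",
    "parent's `counts` array is still all zeros.\n",
    "\n",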
    "\n",
    "<a class=\"anchor\" id=\"read-write-sharing\"></a>\n",
    "### Read/write sharing\n",
    "\n",
    "\n",
    "> If you have worked with a real programming language with true parallelism\n",
    "> and shared memory via within-process multi-threading, feel free to take a\n",
    "> break at this point. Breathe. Relax. Go punch a hole in a wall. I've been\n",
    "> coding in Python for years, and this still makes me angry. Sometimes\n",
    "> ... don't tell anyone I said this ... I even find myself wishing I were\n",
    "> coding in *Java* instead of Python. Ugh. I need to take a shower.\n",
    "\n",
    "\n",
    "In order to truly share memory between multiple processes, the\n",
    "`multiprocessing` module provides the [`Value`, `Array`, and `RawArray`\n",
    "functions](https://docs.python.org/3/library/multiprocessing.html#shared-ctypes-objects).\n",
    "`Value` allows you to share an individual value, and `Array` and `RawArray`\n",
    "allow you to share arrays of values.\n",
    "\n",
    "\n",
    "The objects returned by `Array` and `RawArray` essentially wrap a typed\n",
    "pointer (from the built-in [`ctypes`](https://docs.python.org/3/library/ctypes.html)\n",
    "module) to a block of shared memory. We can use either of them to share a\n",
    "Numpy array between our worker processes. The difference between them is\n",
    "that `Array` protects the shared memory with a lock, offering low-level\n",
    "synchronised (i.e. process-safe) access to it. This is necessary if your\n",
    "child processes will be modifying the same parts of your data.\n",
    "\n",
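    "As an illustration of when the synchronised `Array` matters, the sketch below\n",
    "has several processes increment the *same* element. `counter[0] += 1` is a\n",
    "separate read and write, so the lock returned by `get_lock()` is held around\n",
    "the whole operation. The names `add_many` and `run_counter` are illustrative,\n",
    "and the `fork` start method is again assumed:\n",
    "\n",
    "```python\n",
    "import ctypes\n",
    "import multiprocessing as mp\n",
    "\n",
    "# Illustrative sketch - assumes the 'fork' start method (POSIX only).\n",
    "def add_many(counter, n):\n",
    "    for _ in range(n):\n",
    "        # Hold the Array's lock for the whole read-modify-write\n",
    "        with counter.get_lock():\n",
    "            counter[0] += 1\n",
    "\n",
    "def run_counter(nprocs=4, n=1000):\n",
    "    ctx = mp.get_context('fork')\n",
    "    # A synchronised shared array (lock=True is the default)\n",
    "    counter = ctx.Array(ctypes.c_long, 1)\n",
    "    procs = [ctx.Process(target=add_many, args=(counter, n))\n",
    "             for _ in range(nprocs)]\n",
    "    for p in procs:\n",
    "        p.start()\n",
    "    for p in procs:\n",
    "        p.join()\n",
    "    return counter[0]\n",
    "\n",
    "print(run_counter())\n",
    "```\n",
    "\n",
    "A `RawArray` has no `get_lock()`, so concurrent increments of the same element\n",
    "could be lost. To get a Numpy view of a synchronised `Array`, pass it the\n",
    "underlying ctypes object, i.e. `np.ctypeslib.as_array(arr.get_obj())`.\n",
    "\n",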
    "\n",
    "> If you need fine-grained control over synchronising access to shared data by\n",
    "> multiple processes, all of the [synchronisation\n",
    "> primitives](https://docs.python.org/3/library/multiprocessing.html#synchronization-between-processes)\n",
    "> from the `multiprocessing` module are at your disposal.\n",
    "\n",
    "\n",
    "The requirements for sharing memory between processes still apply here - we\n",
    "need to make our data accessible at the *module level*, and we need to create\n",
    "our data before creating the `Pool`. And to achieve read and write capability,\n",
    "we also need to make sure that our input and output arrays are located in\n",
    "shared memory - we can do this via `Array` or `RawArray`.\n",
    "\n",
    "\n",
    "As an example, let's say we want to parallelise processing of an image by\n",
    "having each worker process perform calculations on a chunk of the image.\n",
    "First, let's define a function which does the calculation on a specified set\n",
    "of image coordinates:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1e1d718",
   "metadata": {},
   "outputs": [],
   "source": [
    "import multiprocessing as mp\n",
    "import ctypes\n",
    "import numpy as np\n",
    "np.set_printoptions(suppress=True)\n",
    "\n",
    "\n",
    "def process_chunk(shape, idxs):\n",
    "\n",
    "    # Get references to our\n",
    "    # input/output data, and\n",
    "    # create Numpy array views\n",
    "    # into them.\n",
    "    sindata  = process_chunk.input_data\n",
    "    soutdata = process_chunk.output_data\n",
    "    indata   = np.ctypeslib.as_array(sindata) .reshape(shape)\n",
    "    outdata  = np.ctypeslib.as_array(soutdata).reshape(shape)\n",
    "\n",
    "    # Do the calculation on\n",
    "    # the specified voxels\n",
    "    outdata[idxs] = indata[idxs] ** 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7a57523",
   "metadata": {},
   "source": [
    "Rather than passing the input and output data arrays in as arguments to the\n",
    "`process_chunk` function, we set them as attributes of the `process_chunk`\n",
    "function. This makes the input/output data accessible at the module level,\n",
    "which is required in order to share the data between the main process and the\n",
    "child processes.\n",
    "\n",
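    "As an aside, a commonly used alternative to function attributes is to hand\n",
    "the shared arrays to the `Pool` through its `initializer` and `initargs`\n",
    "arguments. Each worker then stores them in a module-level variable when it\n",
    "starts. This is not the approach used in the rest of this notebook. The sketch\n",
    "below is minimal, and the names `_arrays`, `init_worker`, `square_rows` and\n",
    "`square_shared` are all illustrative:\n",
    "\n",
    "```python\n",
    "import ctypes\n",
    "import multiprocessing as mp\n",
    "import numpy as np\n",
    "\n",
    "# Illustrative sketch of the initializer/initargs pattern.\n",
    "# Per-worker storage, filled in by init_worker\n",
    "_arrays = {}\n",
    "\n",
    "def init_worker(sindata, soutdata, shape):\n",
    "    # Runs once in each worker process when the pool starts\n",
    "    _arrays['in']  = np.ctypeslib.as_array(sindata).reshape(shape)\n",
    "    _arrays['out'] = np.ctypeslib.as_array(soutdata).reshape(shape)\n",
    "\n",
    "def square_rows(rows):\n",
    "    # rows is a slice along the first axis; the result is\n",
    "    # written straight into the shared output memory\n",
    "    _arrays['out'][rows] = _arrays['in'][rows] ** 2\n",
    "\n",
    "def square_shared(data, nprocs=2, method='fork'):\n",
    "    ctx      = mp.get_context(method)\n",
    "    sindata  = ctx.RawArray(ctypes.c_double, data.size)\n",
    "    soutdata = ctx.RawArray(ctypes.c_double, data.size)\n",
    "    # Copy the input into shared memory\n",
    "    np.ctypeslib.as_array(sindata)[:] = data.ravel()\n",
    "    with ctx.Pool(nprocs, initializer=init_worker,\n",
    "                  initargs=(sindata, soutdata, data.shape)) as pool:\n",
    "        pool.map(square_rows, [slice(i, i + 1) for i in range(data.shape[0])])\n",
    "    return np.ctypeslib.as_array(soutdata).reshape(data.shape).copy()\n",
    "\n",
    "print(square_shared(np.arange(12, dtype=np.float64).reshape(3, 4)))\n",
    "```\n",
    "\n",
    "Because the arrays are handed over when each worker starts, this pattern can\n",
    "also work with the `spawn` start method. That requires the functions to live\n",
    "in an importable module (not a notebook) and the pool to be created under an\n",
    "`if __name__ == '__main__':` guard.\n",
    "\n",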
    "\n",
    "Now let's define a second function which processes an entire image. It does\n",
    "the following:\n",
    "\n",
    "\n",
    "1. Initialises shared memory areas to store the input and output data.\n",
    "2. Copies the input data into shared memory.\n",
    "3. Sets the input and output data as attributes of the `process_chunk` function.\n",
    "4. Creates sets of indices into the input data which, for each worker process,\n",
    "   specifies the portion of the data that it is responsible for.\n",
    "5. Creates a worker pool, and runs the `process_chunk` function for each set\n",
    "   of indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c16c5587",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_dataset(data):\n",
    "\n",
    "    nprocs   = 8\n",
    "    origData = data\n",
    "\n",
    "    # Create arrays to store the\n",
    "    # input and output data\n",
    "    sindata  = mp.RawArray(ctypes.c_double, data.size)\n",
    "    soutdata = mp.RawArray(ctypes.c_double, data.size)\n",
    "    data     = np.ctypeslib.as_array(sindata).reshape(data.shape)\n",
    "    outdata  = np.ctypeslib.as_array(soutdata).reshape(data.shape)\n",
    "\n",
    "    # Copy the input data\n",
    "    # into shared memory\n",
    "    data[:]  = origData\n",
    "\n",
    "    # Make the input/output data\n",
    "    # accessible to the process_chunk\n",
    "    # function. This must be done\n",
    "    # *before* the worker pool is\n",
    "    # created - even though we are\n",
    "    # doing things differently to the\n",
    "    # read-only example, we are still\n",
    "    # making the data arrays accessible\n",
    "    # at the *module* level, so the\n",
    "    # memory they are stored in can be\n",
    "    # shared with the child processes.\n",
    "    process_chunk.input_data  = sindata\n",
    "    process_chunk.output_data = soutdata\n",
    "\n",
    "    # number of voxels to be computed\n",
    "    # by each worker process.\n",
    "    nvox = int(data.size / nprocs)\n",
    "\n",
    "    # Generate coordinates for\n",
    "    # every voxel in the image\n",
    "    xlen, ylen, zlen = data.shape\n",
    "    xs, ys, zs = np.meshgrid(np.arange(xlen),\n",
    "                             np.arange(ylen),\n",
    "                             np.arange(zlen))\n",
    "\n",
    "    xs = xs.flatten()\n",
    "    ys = ys.flatten()\n",
    "    zs = zs.flatten()\n",
    "\n",
    "    # We're going to pass each worker\n",
    "    # process a list of indices, which\n",
    "    # specify the data items which that\n",
    "    # worker process needs to compute.\n",
    "    xs = [xs[nvox * i:nvox * i + nvox] for i in range(nprocs)] + [xs[nvox * nprocs:]]\n",
    "    ys = [ys[nvox * i:nvox * i + nvox] for i in range(nprocs)] + [ys[nvox * nprocs:]]\n",
    "    zs = [zs[nvox * i:nvox * i + nvox] for i in range(nprocs)] + [zs[nvox * nprocs:]]\n",
    "\n",
    "    # Build the argument lists for\n",
    "    # each worker process.\n",
    "    args = [(data.shape, (x, y, z)) for x, y, z in zip(xs, ys, zs)]\n",
    "\n",
    "    # Create a pool of worker\n",
    "    # processes and run the jobs.\n",
    "    with mp.Pool(processes=nprocs) as pool:\n",
    "        pool.starmap(process_chunk, args)\n",
    "\n",
    "    return outdata"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dedc20ea",
   "metadata": {},
   "source": [
    "Now we can call our `process_dataset` function just like any other function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3bfb596",
   "metadata": {},
   "outputs": [],
   "source": [
    "indata  = np.array(np.arange(64).reshape((4, 4, 4)), dtype=np.float64)\n",
    "outdata = process_dataset(indata)\n",
    "\n",
    "print('Input')\n",
    "print(indata)\n",
    "\n",
    "print('Output')\n",
    "print(outdata)"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}