diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index e074d702b6c87ab9dae882cdd6a42e9eac12c669..2790ec5d8549a8ae2352c5aef6cf6c8d862d0842 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -2,6 +2,17 @@ This document contains the ``fslpy`` release history in reverse chronological
 order.
 
 
+2.3.1 (Friday July 5th 2019)
+----------------------------
+
+
+Fixed
+^^^^^
+
+
+* The :class:`.Bitmap` class now supports greyscale images and palette images.
+
+
 2.3.0 (Tuesday June 25th 2019)
 ------------------------------
 
diff --git a/fsl/data/bitmap.py b/fsl/data/bitmap.py
index ec7b6b1fe1c9e12c5fa8e52c6e52589f1be3281c..51b354f7ddfcef93cb44a2d2a83231b518d89fe2 100644
--- a/fsl/data/bitmap.py
+++ b/fsl/data/bitmap.py
@@ -51,30 +51,48 @@ class Bitmap(object):
                   data.
         """
 
-        try:
-            import PIL.Image as Image
-        except ImportError:
-            raise RuntimeError('Install Pillow to use the Bitmap class')
-
         if isinstance(bmp, six.string_types):
-            source = bmp
-            data   = np.array(Image.open(source))
+
+            try:
+                # Allow big images
+                import PIL.Image as Image
+                Image.MAX_IMAGE_PIXELS = 1e9
+
+            except ImportError:
+                raise RuntimeError('Install Pillow to use the Bitmap class')
+
+            src = bmp
+            img = Image.open(src)
+
+            # If this is a palette/LUT
+            # image, convert it into a
+            # regular rgb(a) image.
+            if img.mode == 'P':
+                img = img.convert()
+
+            data = np.array(img)
 
         elif isinstance(bmp, np.ndarray):
-            source = 'array'
-            data   = np.copy(bmp)
+            src  = 'array'
+            data = np.copy(bmp)
 
         else:
             raise ValueError('unknown bitmap: {}'.format(bmp))
 
-        # Make the array (w, h, c)
+        # Make the array (w, h, c). Single channel
+        # (e.g. greyscale) images are returned as
+        # 2D arrays, whereas multi-channel images
+        # are returned as 3D. In either case, the
+        # first two dimensions are (height, width),
+        # but we watn them the other way aruond.
+        data = np.atleast_3d(data)
         data = np.fliplr(data.transpose((1, 0, 2)))
         data = np.array(data, dtype=np.uint8, order='C')
         w, h = data.shape[:2]
 
         self.__data       = data
-        self.__dataSource = source
-        self.__name       = op.basename(source)
+        self.__dataSource = src
+        self.__name       = op.basename(src)
 
 
     def __hash__(self):
@@ -132,7 +150,6 @@ class Bitmap(object):
         if nchannels == 1:
             dtype = np.uint8
 
-
         elif nchannels == 3:
             dtype = np.dtype([('R', 'uint8'),
                               ('G', 'uint8'),
@@ -158,4 +175,6 @@ class Bitmap(object):
 
         data = np.array(data, order='F', copy=False)
 
-        return fslimage.Image(data, name=self.name)
+        return fslimage.Image(data,
+                              name=self.name,
+                              dataSource=self.dataSource)
diff --git a/fsl/data/image.py b/fsl/data/image.py
index 417a0445ceccd64f6bcc28d7abc1d2214c2de32d..1bf901c287d5c2ba3ea033e11f08fdb79cc55e3d 100644
--- a/fsl/data/image.py
+++ b/fsl/data/image.py
@@ -809,6 +809,7 @@ class Image(Nifti):
         """
 
         nibImage = None
+        saved    = False
 
         if indexed is not False:
             warnings.warn('The indexed argument is deprecated '
@@ -841,10 +842,10 @@ class Image(Nifti):
 
         # The image parameter may be the name of an image file
         if isinstance(image, six.string_types):
-
             image      = op.abspath(addExt(image))
             nibImage   = nib.load(image, **kwargs)
             dataSource = image
+            saved      = True
 
         # Or a numpy array - we wrap it in a nibabel image,
         # with an identity transformation (each voxel maps
@@ -906,7 +907,7 @@ class Image(Nifti):
         self.__dataSource   = dataSource
         self.__threaded     = threaded
         self.__nibImage     = nibImage
-        self.__saveState    = dataSource is not None
+        self.__saveState    = saved
         self.__imageWrapper = imagewrapper.ImageWrapper(self.nibImage,
                                                         self.name,
                                                         loadData=loadData,
diff --git a/tests/test_bitmap.py b/tests/test_bitmap.py
index b59c4083e2e822400f19ccc6895d500179fb4149..20b369e75775e52b2acd5e83397aa980bb437f93 100644
--- a/tests/test_bitmap.py
+++ b/tests/test_bitmap.py
@@ -19,22 +19,30 @@ def test_bitmap():
 
     from PIL import Image
 
+    nchannels = (1, 3, 4)
+
     with tempdir.tempdir():
-        data = np.random.randint(0, 255, (100, 200, 4), dtype=np.uint8)
-        img  = Image.fromarray(data, mode='RGBA')
 
-        img.save('image.png')
+        for nch in nchannels:
+            data = np.random.randint(0, 255, (100, 200, nch), dtype=np.uint8)
+            img  = Image.fromarray(data.squeeze())
+
+            fname = 'image.png'
+            img.save(fname)
 
-        bmp = fslbmp.Bitmap('image.png')
+            bmp1 = fslbmp.Bitmap(fname)
+            bmp2 = fslbmp.Bitmap(data)
 
-        assert bmp.name       == 'image.png'
-        assert bmp.dataSource == 'image.png'
-        assert bmp.shape      == (200, 100, 4)
+            assert bmp1.name       == fname
+            assert bmp1.dataSource == fname
+            assert bmp1.shape      == (200, 100, nch)
+            assert bmp2.shape      == (200, 100, nch)
 
-        repr(bmp)
-        hash(bmp)
+            repr(bmp1)
+            hash(bmp1)
 
-        assert np.all(bmp.data == np.fliplr(data.transpose(1, 0, 2)))
+            assert np.all(bmp1.data == np.fliplr(data.transpose(1, 0, 2)))
+            assert np.all(bmp2.data == np.fliplr(data.transpose(1, 0, 2)))
 
 
 @pytest.mark.piltest
@@ -47,17 +55,23 @@ def test_bitmap_asImage():
 
         img3 = Image.fromarray(d3, mode='RGB')
         img4 = Image.fromarray(d4, mode='RGBA')
+        img1 = img3.convert(mode='P')
 
         img3.save('rgb.png')
         img4.save('rgba.png')
+        img1.save('p.png')
 
-        bmp3  = fslbmp.Bitmap('rgb.png')
-        bmp4  = fslbmp.Bitmap('rgba.png')
+        bmp3 = fslbmp.Bitmap('rgb.png')
+        bmp4 = fslbmp.Bitmap('rgba.png')
+        bmp1 = fslbmp.Bitmap('p.png')
 
-        i3 = bmp3.asImage()
-        i4 = bmp4.asImage()
+        i3   = bmp3.asImage()
+        i4   = bmp4.asImage()
+        i1   = bmp1.asImage()
 
         assert i3.shape == (200, 100, 1)
         assert i4.shape == (200, 100, 1)
+        assert i1.shape == (200, 100, 1)
         assert i3.nvals == 3
         assert i4.nvals == 4
+        assert i1.nvals == 3