diff --git a/tests/test_bitmap.py b/tests/test_bitmap.py index b59c4083e2e822400f19ccc6895d500179fb4149..81c1cd1815abf652e07876421851284778c2cbbd 100644 --- a/tests/test_bitmap.py +++ b/tests/test_bitmap.py @@ -19,22 +19,27 @@ def test_bitmap(): from PIL import Image + nchannels = (1, 3, 4) + with tempdir.tempdir(): - data = np.random.randint(0, 255, (100, 200, 4), dtype=np.uint8) - img = Image.fromarray(data, mode='RGBA') - img.save('image.png') + for nch in nchannels: + data = np.random.randint(0, 255, (100, 200, nch), dtype=np.uint8) + img = Image.fromarray(data.squeeze()) + + fname = 'image.png' + img.save(fname) - bmp = fslbmp.Bitmap('image.png') + bmp = fslbmp.Bitmap(fname) - assert bmp.name == 'image.png' - assert bmp.dataSource == 'image.png' - assert bmp.shape == (200, 100, 4) + assert bmp.name == fname + assert bmp.dataSource == fname + assert bmp.shape == (200, 100, nch) - repr(bmp) - hash(bmp) + repr(bmp) + hash(bmp) - assert np.all(bmp.data == np.fliplr(data.transpose(1, 0, 2))) + assert np.all(bmp.data == np.fliplr(data.transpose(1, 0, 2))) @pytest.mark.piltest