diff --git a/scripts/test_vistas_single_gpu.py b/scripts/test_vistas_single_gpu.py index 466ff64..2dc0019 100644 --- a/scripts/test_vistas_single_gpu.py +++ b/scripts/test_vistas_single_gpu.py @@ -14,9 +14,11 @@ from dataset.transform import SegmentationTransform from inplace_abn import InPlaceABN from modules import DeeplabV3 +import PIL from PIL import Image, ImagePalette from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from itertools import chain parser = argparse.ArgumentParser( description="Testing script for the Vistas segmentation model" @@ -317,11 +319,18 @@ def load_snapshot(snapshot_file): _PALETTE = np.concatenate( [_PALETTE, np.zeros((256 - _PALETTE.shape[0], 3), dtype=np.uint8)], axis=0 ) -_PALETTE = ImagePalette.ImagePalette( - palette=list(_PALETTE[:, 0]) + list(_PALETTE[:, 1]) + list(_PALETTE[:, 2]), - mode="RGB", -) +PIL_version = PIL.__version__.split('.') +if int(PIL_version[0])<8 or (int(PIL_version[0])==8 and int(PIL_version[1])<3): + _PALETTE = ImagePalette.ImagePalette( + palette=list(_PALETTE[:, 0]) + list(_PALETTE[:, 1]) + list(_PALETTE[:, 2]), + mode="RGB", + ) +else: + _PALETTE = ImagePalette.ImagePalette( + palette=list(chain.from_iterable(_PALETTE)), + mode="RGB", + ) def get_pred_image(tensor, out_size, with_palette): tensor = tensor.numpy()