diff --git a/fsl/utils/filetree/filetree.py b/fsl/utils/filetree/filetree.py index 3159fefa842f986115921ae4166df7411ac93893..b1dad5f1b14c467e207ace2a518e977856880fe2 100644 --- a/fsl/utils/filetree/filetree.py +++ b/fsl/utils/filetree/filetree.py @@ -3,9 +3,9 @@ from typing import Tuple, Optional, Dict, Any, Set from copy import deepcopy from . import parse import pickle +import json import os.path as op from . import utils -from fsl.utils.deprecated import deprecated class MissingVariable(KeyError): @@ -255,12 +255,28 @@ class FileTree(object): with open(filename, 'wb') as f: pickle.dump(self, f) + def save_json(self, filename): + """ + Saves the Filetree to a JSON file + + :param filename: filename to store the file tree in + """ + def default(obj): + if isinstance(obj, FileTree): + res = dict(obj.__dict__) + del res['_parent'] + return res + return obj + + with open(filename, 'w') as f: + json.dump(self, f, default=default) + @classmethod def load_pickle(cls, filename): """ Loads the Filetree from a pickle file - :param filename: filename produced from Filetree.save + :param filename: filename produced from Filetree.save_pickle :return: stored Filetree """ with open(filename, 'rb') as f: @@ -269,6 +285,30 @@ class FileTree(object): raise IOError("Pickle file did not contain %s object" % cls) return res + @classmethod + def load_json(cls, filename): + """ + Loads the FileTree from a JSON file + + :param filename: filename produced by FileTree.save_json + :return: stored FileTree + """ + def from_dict(input_dict): + res_tree = FileTree( + templates=input_dict['templates'], + variables=input_dict['variables'], + sub_trees={name: from_dict(value) for name, value in input_dict['sub_trees'].items()}, + name=input_dict['_name'], + ) + for sub_tree in res_tree.sub_trees.values(): + sub_tree._parent = res_tree + return res_tree + + with open(filename, 'r') as f: + as_dict = json.load(f) + return from_dict(as_dict) + + def defines(self, short_names, error=False): """ Checks whether templates are defined for all the `short_names` diff --git a/tests/test_filetree/test_read.py b/tests/test_filetree/test_read.py index 6627969e343094ca8a8cc8797c09872195ab9774..8382cc6d2f48a3965a7e71ca52c4708e8f06ecb9 100644 --- a/tests/test_filetree/test_read.py +++ b/tests/test_filetree/test_read.py @@ -1,5 +1,6 @@ # Sample Test passing with nose and pytest from fsl.utils import filetree +from fsl.utils.tempdir import tempdir from pathlib import PurePath import os.path as op import pytest @@ -180,3 +181,26 @@ def test_read_local_sub_children(): # ensure current directory is not the test directory, which would cause the test to be too easy os.chdir('..') filetree.FileTree.read(op.join(directory, 'local_parent.tree')) + + +def same_tree(t1, t2): + assert t1.all_variables == t2.all_variables + assert t1.templates == t2.templates + assert len(t1.sub_trees) == len(t2.sub_trees) + for name in t1.sub_trees: + same_tree(t1.sub_trees[name], t2.sub_trees[name]) + assert t1.sub_trees[name].parent is t1 + assert t2.sub_trees[name].parent is t2 + + +def test_io(): + directory = op.split(__file__)[0] + tree = filetree.FileTree.read(op.join(directory, 'parent.tree'), partial_fill=True) + with tempdir(): + tree.save_pickle('test.pck') + new_tree = filetree.FileTree.load_pickle('test.pck') + same_tree(tree, new_tree) + + tree.save_json('test.json') + new_tree = filetree.FileTree.load_json('test.json') + same_tree(tree, new_tree)