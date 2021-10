# 省略



from torch.utils.data import Subset



# 省略



def get_kfold_datasets(dataset, k):
    """Build (train, validation) ``Subset`` pairs for k-fold cross-validation.

    Args:
        dataset: The dataset to split. It must support ``len()``; each split
            is wrapped with ``torch.utils.data.Subset``.
        k: Presumably the number of folds; forwarded unchanged to
            ``make_kfold_range``.

    Returns:
        A list with one ``(train_subset, val_subset)`` tuple for every pair
        of index collections yielded by ``make_kfold_range(len(dataset), k)``.

    Raises:
        TypeError: If ``len(dataset)`` fails (the original ``TypeError`` is
            chained as the cause).
    """
    # Ask len() directly rather than hasattr(dataset, '__len__'): len() looks
    # __len__ up on the *type*, so an instance-level check can pass (e.g. via
    # a __getattr__ hook) while len() still fails with a less helpful error.
    try:
        ds_length = len(dataset)
    except TypeError as err:
        raise TypeError(f'{dataset} does not have a __len__ attr') from err

    # NOTE(review): assumes make_kfold_range yields (train_indices,
    # val_indices) pairs, one per fold — confirm against its definition.
    return [
        (Subset(dataset, train_idx), Subset(dataset, val_idx))
        for train_idx, val_idx in make_kfold_range(ds_length, k)
    ]