from torch.utils.data import Dataset import cv2 from PIL import Image import os class MyDataset(Dataset): def __init__(self, root_dir, label_dir): self.root_dir = root_dir self.label_dir = label_dir self.path = os.path.join(self.root_dir, self.label_dir) self.img_path = os.listdir(self.path) # __getitem__ 是一个魔法方法,当使用dataset[i]获取第i个样本时,就会调用该方法。 def __getitem__(self, idx): img_name = self.img_path[idx] img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) img = Image.open(img_item_path) label = self.label_dir return img, label def __len__(self): return len(self.img_path) root_dir = 'dataset/train' ants_label_dir = 'ants' bees_label_dir = 'bees' ants_dataset = MyDataset(root_dir, ants_label_dir) bees_dataset = MyDataset(root_dir, bees_label_dir) print(ants_dataset[0])