torchvision 中的 default_loader 有什么作用?
What does default_loader do in torchvision?
import os
import pandas as pd
import numpy as np
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torch.utils.data import Dataset
class Sample_Class(Dataset):
    """Skeleton image dataset modelled on torchvision's ImageFolder.

    Args:
        root: Dataset root directory; a leading ``~`` is expanded.
        train: Split flag; stored on the instance (presumably selects the
            train vs. test split in code not shown here — confirm).
        transform: Optional callable to apply to each loaded image.
        loader: Callable that takes an image file path and returns an
            image. Defaults to torchvision's ``default_loader``.
    """

    def __init__(self, root, train=True, transform=None, loader=default_loader):
        self.root = os.path.expanduser(root)
        self.train = train
        self.transform = transform
        # Store the loader the caller passed in so image decoding can be
        # overridden; assigning ``default_loader`` here would ignore it.
        self.loader = loader
上面的代码片段中,参数 loader=default_loader 有什么意义?它具体是做什么的?
这个 Sample_Class 很可能是在模仿 ImageFolder、DatasetFolder 和 ImageNet 的行为。loader 参数应当是一个可调用对象:它以图像文件路径作为输入,返回 PIL.Image 或 accimage.Image,具体取决于所选的图像后端。为了让调用者能够传入自定义的加载函数,__init__ 中应写 self.loader = loader。
default_loader 函数定义在 torchvision/datasets/folder.py 中,相关代码如下:
def pil_loader(path):
    """Read the image file at *path* with Pillow and return it in RGB mode."""
    # Pass Pillow an explicit file object so the handle is closed when the
    # ``with`` block exits, avoiding a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
def accimage_loader(path):
    """Load *path* with accimage, falling back to ``pil_loader`` on OSError."""
    import accimage

    try:
        image = accimage.Image(path)
    except OSError:  # IOError is an alias of OSError in Python 3
        # Most likely accimage could not decode the file; retry via PIL.
        image = pil_loader(path)
    return image
def default_loader(path):
    """Load the image at *path* using torchvision's active image backend.

    Delegates to ``accimage_loader`` when the backend is ``'accimage'``;
    otherwise delegates to ``pil_loader``.
    """
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == 'accimage'
    chosen_loader = accimage_loader if use_accimage else pil_loader
    return chosen_loader(path)
注意:default_loader 默认使用 PIL 读取图像;只有通过 torchvision.set_image_backend('accimage') 切换后端之后,才会改用 accimage。
import os
import pandas as pd
import numpy as np
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
from torch.utils.data import Dataset
class Sample_Class(Dataset):
    """Skeleton image dataset modelled on torchvision's ImageFolder.

    Args:
        root: Dataset root directory; a leading ``~`` is expanded.
        train: Split flag; stored on the instance (presumably selects the
            train vs. test split in code not shown here — confirm).
        transform: Optional callable to apply to each loaded image.
        loader: Callable that takes an image file path and returns an
            image. Defaults to torchvision's ``default_loader``.
    """

    def __init__(self, root, train=True, transform=None, loader=default_loader):
        self.root = os.path.expanduser(root)
        self.train = train
        self.transform = transform
        # Store the loader the caller passed in so image decoding can be
        # overridden; assigning ``default_loader`` here would ignore it.
        self.loader = loader
上面的代码片段中,参数 loader=default_loader 有什么意义?它具体是做什么的?
这个 Sample_Class 很可能是在模仿 ImageFolder、DatasetFolder 和 ImageNet 的行为。loader 参数应当是一个可调用对象:它以图像文件路径作为输入,返回 PIL.Image 或 accimage.Image,具体取决于所选的图像后端。为了让调用者能够传入自定义的加载函数,__init__ 中应写 self.loader = loader。
default_loader 函数定义在 torchvision/datasets/folder.py 中,相关代码如下:
def pil_loader(path):
    """Read the image file at *path* with Pillow and return it in RGB mode."""
    # Pass Pillow an explicit file object so the handle is closed when the
    # ``with`` block exits, avoiding a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
def accimage_loader(path):
    """Load *path* with accimage, falling back to ``pil_loader`` on OSError."""
    import accimage

    try:
        image = accimage.Image(path)
    except OSError:  # IOError is an alias of OSError in Python 3
        # Most likely accimage could not decode the file; retry via PIL.
        image = pil_loader(path)
    return image
def default_loader(path):
    """Load the image at *path* using torchvision's active image backend.

    Delegates to ``accimage_loader`` when the backend is ``'accimage'``;
    otherwise delegates to ``pil_loader``.
    """
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == 'accimage'
    chosen_loader = accimage_loader if use_accimage else pil_loader
    return chosen_loader(path)
注意:default_loader 默认使用 PIL 读取图像;只有通过 torchvision.set_image_backend('accimage') 切换后端之后,才会改用 accimage。