AssertionError: Format for classes is `<label> file` - python-3.x

This is a Python script that extracts features from a set of images for training an SVM.
import os
import sys
import argparse
import _pickle as cPickle
import json
import cv2
import numpy as np
from sklearn.cluster import KMeans
def build_arg_parser():
    """Build the command-line parser for the feature-extraction script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--samples`` (repeatable,
        stored in ``args.cls``), ``--codebook-file``, ``--feature-map-file``
        and ``--scale-image``.
    """
    parser = argparse.ArgumentParser(
        description='Creates features for given images')
    # action="append" combined with nargs="+" yields one inner list per
    # --samples occurrence, e.g. [["cat", "imgs/cat"], ["dog", "imgs/dog"]].
    parser.add_argument(
        "--samples", dest="cls", nargs="+", action="append", required=True,
        # Adjacent literals instead of a backslash continuation inside the
        # string, so source indentation can never leak into the help text.
        help="Folders containing the training images. "
             "The first element needs to be the class label.")
    parser.add_argument(
        "--codebook-file", dest='codebook_file', required=True,
        help="Base file name to store the codebook")
    parser.add_argument(
        "--feature-map-file", dest='feature_map_file', required=True,
        help="Base file name to store the feature map")
    parser.add_argument(
        "--scale-image", dest="scale", type=int, default=150,
        help="Scales the longer dimension of the image down to this size.")
    return parser
def load_input_map(label, input_folder):
    """Collect every ``.jpg`` file below *input_folder* and tag it with *label*.

    Args:
        label: class label stored with each image entry.
        input_folder: directory that is walked recursively.

    Returns:
        list[dict]: one ``{'label': ..., 'image': <path>}`` entry per file
        whose name ends with ``.jpg``.

    Raises:
        IOError: if *input_folder* is not an existing directory.
    """
    if not os.path.isdir(input_folder):
        message = "The folder " + input_folder + " doesn't exist"
        print(message)
        # Carry the message on the exception too, so callers that catch it
        # (or read a traceback) see which folder was missing.
        raise IOError(message)

    combined_data = []
    for root, _dirs, files in os.walk(input_folder):
        for filename in files:
            if filename.endswith('.jpg'):
                combined_data.append(
                    {'label': label, 'image': os.path.join(root, filename)})
    return combined_data
class FeatureExtractor(object):
    """Builds the k-means codebook and per-image feature vectors."""

    def extract_image_features(self, img):
        """Return the descriptors SIFTExtractor computes at DenseDetector keypoints."""
        kps = DenseDetector().detect(img)
        kps, fvs = SIFTExtractor().compute(img, kps)
        return fvs

    def get_centroids(self, input_map, num_samples_to_fit=10):
        """Fit the codebook on a sample of the images in *input_map*.

        Args:
            input_map: iterable of ``{'label': ..., 'image': <path>}`` dicts.
            num_samples_to_fit: images used per run of equal labels.
                NOTE(review): the per-label cap only works as a cap if
                entries with the same label are contiguous -- confirm.

        Returns:
            tuple: ``(kmeans, centroids)`` as returned by
            ``Quantizer().quantize``.
        """
        kps_all = []
        count = 0
        cur_label = ''
        for item in input_map:
            if count >= num_samples_to_fit:
                if cur_label != item['label']:
                    # A new label starts: reset the per-label counter.
                    count = 0
                else:
                    # Enough samples for this label already; skip the rest.
                    continue
            count += 1
            if count == num_samples_to_fit:
                print("Built centroids for", item['label'])
            cur_label = item['label']
            img = cv2.imread(item['image'])
            img = resize_to_size(img, 150)
            fvs = self.extract_image_features(img)
            # Accumulate descriptors; without this the quantizer would be
            # fitted on an empty list.
            kps_all.extend(fvs)
        kmeans, centroids = Quantizer().quantize(kps_all)
        return kmeans, centroids

    def get_feature_vector(self, img, kmeans, centroids):
        """Delegate to ``Quantizer.get_feature_vector`` for *img*."""
        return Quantizer().get_feature_vector(img, kmeans, centroids)
def extract_feature_map(input_map, kmeans, centroids):
    """Compute a feature vector for every image listed in *input_map*.

    Args:
        input_map: iterable of ``{'label': ..., 'image': <path>}`` dicts.
        kmeans: fitted clustering model passed through to the extractor.
        centroids: cluster centres passed through to the extractor.

    Returns:
        list[dict]: ``{'label': ..., 'feature_vector': ...}`` for each image
        whose feature vector is not ``None``.
    """
    feature_map = []
    extractor = FeatureExtractor()
    for item in input_map:
        print("Extracting features for", item['image'])
        img = cv2.imread(item['image'])
        img = resize_to_size(img, 150)
        feature_vector = extractor.get_feature_vector(img, kmeans, centroids)
        if feature_vector is not None:
            # The append was missing, so the function always returned [].
            feature_map.append(
                {'label': item['label'], 'feature_vector': feature_vector})
    return feature_map
class Quantizer(object):
    """Clusters descriptors with k-means and builds normalized histograms."""

    def __init__(self, num_clusters=32):
        """Store the clustering settings and create the descriptor extractor."""
        self.num_dims = 128
        self.extractor = SIFTExtractor()
        self.num_clusters = num_clusters
        self.num_retries = 10

    def quantize(self, datapoints):
        """Fit k-means on *datapoints*.

        Returns:
            tuple: ``(kmeans, centroids)`` -- the fitted estimator and its
            ``cluster_centers_``.
        """
        kmeans = KMeans(self.num_clusters,
                        n_init=max(self.num_retries, 1),
                        max_iter=10, tol=1.0)
        # fit() returns the fitted estimator itself.
        res = kmeans.fit(datapoints)
        centroids = res.cluster_centers_
        return kmeans, centroids

    def normalize(self, input_data):
        """Scale *input_data* so it sums to 1; return it unchanged if the sum is not positive."""
        sum_input = np.sum(input_data)
        if sum_input > 0:
            return input_data / sum_input
        return input_data

    def get_feature_vector(self, img, kmeans, centroids):
        """Return the normalized cluster-assignment histogram of *img*.

        The result has shape ``(1, num_clusters)``. *centroids* is accepted
        for interface compatibility but not read here.
        """
        kps = DenseDetector().detect(img)
        kps, fvs = self.extractor.compute(img, kps)
        labels = kmeans.predict(fvs)
        # Count how many descriptors fall into each cluster.
        fv = np.zeros(self.num_clusters)
        for label in labels:
            fv[label] += 1
        fv_image = np.reshape(fv, (1, fv.shape[0]))
        return self.normalize(fv_image)
class DenseDetector(object):
    """Generates keypoints on a regular grid over an image.

    The previous body called ``cv2.xfeatures2d.SIFT_create("Dense")`` and
    ``setInt(...)``; the SIFT factory takes numeric parameters, not a
    detector name, and ``setInt`` belongs to the OpenCV 2.x algorithm API.
    The grid is therefore built explicitly from ``cv2.KeyPoint`` objects.
    """

    def __init__(self, step_size=20, feature_scale=40, img_bound=20):
        """
        Args:
            step_size: distance in pixels between neighbouring keypoints.
            feature_scale: ``size`` given to every keypoint.
            img_bound: margin in pixels kept free at each image border.
        """
        self.step_size = step_size
        self.feature_scale = feature_scale
        self.img_bound = img_bound

    def detect(self, img):
        """Return the list of grid keypoints for *img*.

        Raises:
            TypeError: if *img* is ``None``.
        """
        if img is None:
            raise TypeError("Not a valid image")
        rows, cols = img.shape[:2]
        keypoints = []
        for y in range(self.img_bound, rows - self.img_bound, self.step_size):
            for x in range(self.img_bound, cols - self.img_bound, self.step_size):
                keypoints.append(
                    cv2.KeyPoint(float(x), float(y), float(self.feature_scale)))
        return keypoints
class SIFTExtractor(object):
    """Computes SIFT descriptors at given keypoints."""

    def compute(self, image, kps):
        """Return ``(keypoints, descriptors)`` for *image* at *kps*.

        Args:
            image: BGR image array (converted to grayscale here).
            kps: keypoints at which descriptors are computed.

        Raises:
            TypeError: if *image* is ``None``.
        """
        if image is None:
            print("Not a valid image")
            raise TypeError("Not a valid image")
        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # Use the factory function, as the rest of this file does, instead
        # of the legacy ``cv2.SIFT()`` constructor. Prefer the main-module
        # factory when this OpenCV build provides it.
        create_sift = getattr(cv2, "SIFT_create", None)
        if create_sift is None:
            create_sift = cv2.xfeatures2d.SIFT_create
        kps, des = create_sift().compute(gray_image, kps)
        return kps, des
def resize_to_size(input_image, new_size=150):
    """Resize so the shorter dimension equals *new_size*, keeping the aspect ratio.

    Args:
        input_image: image array whose first two shape entries are
            height and width.
        new_size: target length in pixels for the shorter side.

    Returns:
        The result of ``cv2.resize`` at the scaled ``(width, height)``.
    """
    height, width = input_image.shape[:2]
    shorter_side = width if width < height else height
    scale = new_size / float(shorter_side)
    target = (int(width * scale), int(height * scale))
    return cv2.resize(input_image, target)
if __name__=='__main__':
    args = build_arg_parser().parse_args()

    # Each --samples occurrence must be "<label> <folder>".
    input_map = []
    for cls in args.cls:
        assert len(cls) >= 2, "Format for classes is `<label> file`"
        label = cls[0]
        input_map += load_input_map(label, cls[1])

    # NOTE: parsed but not consumed yet -- the helpers above call
    # resize_to_size with a hard-coded 150.
    downsample_length = args.scale

    # Building the codebook
    print("===== Building codebook =====")
    kmeans, centroids = FeatureExtractor().get_centroids(input_map)
    if args.codebook_file:
        # Pickle streams are bytes: open in binary mode. The module is
        # imported as ``cPickle``; the bare name ``pickle`` is undefined.
        with open(args.codebook_file, 'wb') as f:
            cPickle.dump((kmeans, centroids), f)

    # Input data and labels
    print("===== Building feature map =====")
    feature_map = extract_feature_map(input_map, kmeans, centroids)
    if args.feature_map_file:
        with open(args.feature_map_file, 'wb') as f:
            cPickle.dump(feature_map, f)
I receive the following error:
Traceback (most recent call last):
File "", line 164, in <module>
assert len(cls) >= 2, ("Format for classes is `<label> file`")
AssertionError: Format for classes is `<label> file`
Any idea what could be wrong? I'm just following the instructions in 'OpenCV with Python by Example' by Prateek Joshi (pages 494-526).

Assertions are used to check a condition. If the condition isn't satisfied, an AssertionError is thrown. In your case, len(cls) >= 2 isn't satisfied, which means that len(cls) is smaller than 2. Here cls is the list of values passed to one --samples argument of the program. The first element of this list must be a label, so whenever you add an argument (a folder), you must also specify a label for it.
For example, if you choose the label name my_label, you must pass the folder as --samples my_label my_file.


How can I speed up using Pytorch DataLoader?

I have a dataset of about a million rows. Previously, I read the rows, preprocessed the data and created a list of rows to train on. Then I defined a DataLoader over this data like:
train_dataloader =['train'],
Preprocessing can be time-consuming, so I decided to define an IterableDataset with an __iter__ function. Then I could define my DataLoader like:
train_dataloader =['train'],
However, before training begins it still seems to call my preprocessing function on all the data and only then create an iterator over the result. So it seems I didn't gain much speed-up.
Please advise how I could speed things up in this case.
Here is part of my class:
def __iter__(self):
    """Return an iterator over this worker's share of the samples.

    If ``self.flat_data`` is already populated it is iterated as-is
    (every worker then sees the full list). Otherwise the range
    ``[self.start, self.num_samples)`` is split evenly across the
    DataLoader workers and ``self.fill_data`` builds the rows for this
    worker's slice.
    """
    import math
    from torch.utils.data import get_worker_info

    if self.flat_data:
        return iter(self.flat_data)

    worker_info = get_worker_info()
    if worker_info is None:
        # Single-process data loading: iterate the full range.
        iter_start = self.start
        iter_end = self.num_samples
    else:
        # In a worker process: split the workload into contiguous chunks.
        per_worker = int(math.ceil(
            (self.num_samples - self.start) / float(worker_info.num_workers)))
        worker_id = worker_info.id
        iter_start = self.start + worker_id * per_worker
        iter_end = min(iter_start + per_worker, self.num_samples)
    return iter(self.fill_data(iter_start, iter_end))
def fill_data(self, iter_start, iter_end, show_progress=False):
flat_data = []
if iter_end < 0:
iter_end = self.num_samples
kk = 0"========================== SPLIT: %s", self.split_name)"get data from %s to %s", iter_start, iter_end)"total rows: %s", len(self.split_df))
if show_progress:
pbar = tqdm(total = self.num_samples)
for index, d in self.split_df.iterrows():
if kk < iter_start:"!!!!!!!!! before start %s", iter_start)
kk += 1
rel = d["prefix"]
# preprocessing and adding to returned list
I used to do the preprocessing in the body of fill_data or __iter__. However, I can use map for the preprocessing instead. Then the preprocessing is called lazily during training, for every batch, and not for all the data before training starts.
import pandas as pd
import torch
class MyDataset(
def __init__(self, fname, until=10):
self.df = pd.read_table("atomic/" + fname)
self.until = until
def preproc(self, t):
prefix, data = t
text = "Preproc: " + prefix + "|" + data
print(text) # to check when it is called
return text
def __iter__(self):
_iter = self.df_iter()
return map(self.preproc, _iter)
def df_iter(self):
ret = []
for idx, row in self.df.iterrows():
return iter(ret)

keras BatchGenerator(keras.utils.Sequence) is too slow

I'm using a custom batch generator with a large DataFrame, but the generator takes too much time to generate a batch: it takes 127 s to generate a batch of 1024. I've tried Dask, but the processing is still slow. Is there any way to integrate multiprocessing inside the generator? Note that I've already tried use_multiprocessing=True with workers=12.
import keras
from random import randint
import glob
import warnings
import numpy as np
import math
import pandas as pd
import dask.dataframe as dd
class BatchGenerator(keras.utils.Sequence):
    """Generates data for Keras."""

    # Width of the third axis of the batch arrays (slots 0..5 are filled).
    _N_SLOTS = 6

    def __init__(self, labels=None, batch_size=8, n_classes=4, shuffle=True,
                 seq_len=6, data_path=None, meta_path=None, list_IDs=None):
        """Store the configuration and prepare the first epoch's ordering.

        ``data_path`` and ``meta_path`` are used as already-loaded
        dataframes (they are indexed and cast directly), despite the names.
        """
        self.batch_size = batch_size
        self.labels = labels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.seq_len = seq_len
        self.meta_df = meta_path
        self.data_df = data_path
        self.data_df = self.data_df.astype({"mjd": int})
        self.list_IDs = list_IDs
        if self.list_IDs is None:
            self.list_IDs = list(self.meta_df['object_id'].unique())
        # __getitem__ reads self.indexes, which only on_epoch_end creates,
        # so it must be built once before the first batch is requested.
        self.on_epoch_end()

    def __len__(self):
        """Denotes the number of batches per epoch."""
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        """Generate one batch of data."""
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        # Generate data
        X, y = self.__data_generation(list_IDs_temp)
        return X, y

    def on_epoch_end(self):
        """Updates indexes after each epoch."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Build the input list and target array for the given object ids.

        Returns:
            tuple: ``([X_noised, X_length, mask_flat], X_flat)`` where the
            flat arrays have shape ``(batch_size, seq_len * 6)``.
        """
        n_slots = self._N_SLOTS
        X_dat = np.zeros((self.batch_size, self.seq_len, n_slots, 1))
        Y_mask = np.zeros((self.batch_size, self.seq_len, n_slots, 1))
        X_length = np.empty((self.batch_size, 1), dtype=int)

        for i, trans_id in enumerate(list_IDs_temp):
            curve = self.data_df[self.data_df.object_id == trans_id]
            mjdlist = list(curve['mjd'].unique())
            ts_length = len(mjdlist)
            if ts_length <= self.seq_len:
                start_ind = 0
            else:
                # Pick a random window of seq_len consecutive time steps.
                start_ind = randint(0, ts_length - self.seq_len)
                ts_length = self.seq_len

            for j in range(ts_length):
                if j + start_ind >= len(mjdlist):
                    continue
                step = curve[curve.mjd == mjdlist[j + start_ind]]
                # NOTE(review): k runs over the number of rows in this time
                # step, not over all six passbands -- kept as in the
                # original; confirm whether range(6) was intended.
                for k in range(min(len(step.mjd), n_slots)):
                    obs = step[step.passband == k]
                    if len(obs) == 0:
                        continue
                    X_dat[i, j, k, 0] = obs.flux.iloc[0]
                    Y_mask[i, j, k, 0] = 1

            X_length[i, 0] = ts_length
            flux_range = np.max(X_dat[i]) - np.min(X_dat[i])
            # log2 is undefined for a zero range and yields 0 for a range
            # of exactly 1; skip the scaling in those cases instead of
            # raising or dividing by zero.
            if flux_range > 0:
                flux_pow = math.log2(flux_range)
                if flux_pow != 0:
                    X_dat[i] /= flux_pow

        X_noised = X_dat + np.random.uniform(low=0, high=0.5, size=X_dat.shape)
        flat_shape = (self.batch_size, self.seq_len * n_slots)
        return ([X_noised, X_length, np.reshape(Y_mask, flat_shape)],
                np.reshape(X_dat, flat_shape))
To make it faster, the for loop in the function __data_generation should be parallelized. Using the joblib package may help.

using openCV to open the webcam and take picture with it every five seconds

I tried to use the webcam and take pictures every 5 seconds via openCV but the cam itself didn't work and kept causing the error...
I also tried changing the integer in the cv2.VideoCapture() to -1 and 1 but still that didn't work.
This is the form of the error: "[ WARN:0] global /io/opencv/modules/videoio/src/cap_v4l.cpp (802)
open VIDEOIO ERROR: V4L: can't open camera by index 0
Traceback (most recent call last):
File "", line 176, in
raise IOError("Cannot open webcam")
OSError: Cannot open webcam"
import colorsys
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import cv2
import time
import numpy as np
from keras import backend as K
from keras.models import load_model
from keras.layers import Input
from yolo3.model import yolo_eval, yolo_body, tiny_yolo_body
from yolo3.utils import image_preporcess
class YOLO(object):
    """Wraps a Keras YOLOv3 model: loads it and draws detections on images."""

    _defaults = {
        #"model_path": 'logs/trained_weights_final.h5',
        "model_path": 'model_data/yolo_weights.h5',
        "anchors_path": 'model_data/yolo_anchors.txt',
        "classes_path": 'model_data/coco_classes.txt',
        "score": 0.3,
        "iou": 0.45,
        "model_image_size": (416, 416),
        "text_size": 1,
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default for option *n*, or an error string if unknown."""
        if n in cls._defaults:
            return cls._defaults[n]
        return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)  # set up default values
        self.__dict__.update(kwargs)  # and update with user overrides
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()
        self.sess = K.get_session()
        self.boxes, self.scores, self.classes = self.generate()

    def _get_class(self):
        """Read the class names file: one name per line."""
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            return [line.strip() for line in f.readlines()]

    def _get_anchors(self):
        """Read the comma-separated first line of the anchors file as an (N, 2) array."""
        anchors_path = os.path.expanduser(self.anchors_path)
        with open(anchors_path) as f:
            anchors = f.readline()
        anchors = [float(x) for x in anchors.split(',')]
        return np.array(anchors).reshape(-1, 2)

    def generate(self):
        """Load the model and build the filtered output tensors.

        Returns:
            tuple: ``(boxes, scores, classes)`` tensors from ``yolo_eval``.
        """
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        # Load model, or construct model and load weights.
        num_anchors = len(self.anchors)
        num_classes = len(self.class_names)
        is_tiny_version = num_anchors == 6  # default setting
        try:
            self.yolo_model = load_model(model_path, compile=False)
        except Exception:
            # The file could not be loaded as a full model: build the
            # architecture and load the file as weights only.
            self.yolo_model = tiny_yolo_body(Input(shape=(None, None, 3)), num_anchors//2, num_classes) \
                if is_tiny_version else yolo_body(Input(shape=(None, None, 3)), num_anchors//3, num_classes)
            self.yolo_model.load_weights(self.model_path)  # make sure model, anchors and classes match
        else:
            assert self.yolo_model.layers[-1].output_shape[-1] == \
                num_anchors/len(self.yolo_model.output) * (num_classes + 5), \
                'Mismatch between model and given anchor and class sizes'

        print('{} model, anchors, and classes loaded.'.format(model_path))

        # Generate colors for drawing bounding boxes.
        hsv_tuples = [(x / len(self.class_names), 1., 1.)
                      for x in range(len(self.class_names))]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(
            map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                self.colors))
        np.random.shuffle(self.colors)  # Shuffle colors to decorrelate adjacent classes.

        # Generate output tensor targets for filtered bounding boxes.
        self.input_image_shape = K.placeholder(shape=(2, ))
        boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors,
                                           len(self.class_names), self.input_image_shape,
                                           score_threshold=self.score, iou_threshold=self.iou)
        return boxes, scores, classes

    def detect_image(self, image):
        """Run detection on *image* and draw the results onto it.

        Returns:
            tuple: ``(image, ObjectsList)`` where each list entry is
            ``[top, left, bottom, right, mid_v, mid_h, label, scores]``.

        Raises:
            ValueError: if ``model_image_size`` is ``(None, None)``.
        """
        if self.model_image_size == (None, None):
            # No preprocessing path exists for an unset size; fail with a
            # clear message rather than an undefined-variable error below.
            raise ValueError('model_image_size must be set')
        assert self.model_image_size[0] % 32 == 0, 'Multiples of 32 required'
        assert self.model_image_size[1] % 32 == 0, 'Multiples of 32 required'
        image_data = image_preporcess(np.copy(image), tuple(reversed(self.model_image_size)))

        out_boxes, out_scores, out_classes = self.sess.run(
            [self.boxes, self.scores, self.classes],
            feed_dict={
                self.yolo_model.input: image_data,
                self.input_image_shape: [image.shape[0], image.shape[1]],
                K.learning_phase(): 0
            })

        thickness = (image.shape[0] + image.shape[1]) // 600
        ObjectsList = []
        for i, c in reversed(list(enumerate(out_classes))):
            predicted_class = self.class_names[c]
            box = out_boxes[i]
            score = out_scores[i]
            label = '{} {:.2f}'.format(predicted_class, score)
            scores = '{:.2f}'.format(score)

            # Round to the nearest pixel and clamp to the image bounds.
            top, left, bottom, right = box
            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            bottom = min(image.shape[0], np.floor(bottom + 0.5).astype('int32'))
            right = min(image.shape[1], np.floor(right + 0.5).astype('int32'))
            mid_h = (bottom-top)/2+top
            mid_v = (right-left)/2+left

            # put object rectangle
            cv2.rectangle(image, (left, top), (right, bottom), self.colors[c], thickness)
            # get text size
            (text_width, text_height), baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, thickness/self.text_size, 1)
            # put text rectangle
            cv2.rectangle(image, (left, top), (left + text_width, top - text_height - baseline),
                          self.colors[c], thickness=cv2.FILLED)
            # put text above rectangle
            cv2.putText(image, label, (left, top-2), cv2.FONT_HERSHEY_SIMPLEX,
                        thickness/self.text_size, (0, 0, 0), 1)
            # add everything to list
            ObjectsList.append([top, left, bottom, right, mid_v, mid_h, label, scores])
        return image, ObjectsList

    def close_session(self):
        """Close the backend session obtained in ``__init__``."""
        self.sess.close()

    def detect_img(self, image):
        """Run ``detect_image`` on a copy of the frame *image*."""
        # NOTE(review): the channel swap is applied twice, so the channel
        # order handed to detect_image equals the input's -- kept as in
        # the original; confirm whether RGB input was intended.
        original_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_image_color = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
        r_image, ObjectsList = self.detect_image(original_image_color)
        return r_image, ObjectsList
if __name__=="__main__":
    yolo = YOLO()

    # we create the video capture object cap
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        raise IOError("Cannot open webcam")
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)

    # Both were read before being assigned in the loop below.
    start_time = time.time()
    img_counter = 0

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                # No frame delivered (camera unplugged or stream ended).
                break
            # resize our captured frame if we need
            frame = cv2.resize(frame, None, fx=1.0, fy=1.0)
            # detect object on our frame
            r_image, ObjectsList = yolo.detect_img(frame)
            # show us frame with detection
            cv2.imshow('Web cam input', r_image)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            if time.time() - start_time >= 5:  # <---- Check if 5 sec passed
                img_name = "opencv_frame_{}.png".format(img_counter)
                cv2.imwrite(img_name, frame)
                print("{} written!".format(img_counter))
                start_time = time.time()
                img_counter += 1
    finally:
        # Always free the camera and windows, even after an error.
        cap.release()
        cv2.destroyAllWindows()
        yolo.close_session()

Building a dataset with dataloader pytorch getting error cannot import name 'read_data_sets'

I am loading data into a dataset using the PyTorch DataLoader.
I get the error: cannot import name 'read_data_sets'.
I tried searching for results from similar issues.
If Python is confusing one of my files with a module and therefore can't find read_data_sets in my file, what do I change to fix it?
class MRDataset(data.Dataset):
    """Dataset of ``.npy`` arrays listed in a per-task CSV file.

    Each item is ``(array, label, weights)``: the array as a float tensor,
    the label as a 1x2 one-hot float tensor, and the class-weight tensor
    shared by all items.
    """

    def __init__(self, root_dir, task, plane, train=True, transform=None, weights=None):
        """
        Args:
            root_dir: dataset root; concatenated directly with relative
                paths, so it must end with a path separator.
            task: selects ``<split>-<task>.csv``.
            plane: selects the ``<split>/<plane>/`` folder.
            train: use the ``train`` split when true, else ``valid``.
            transform: optional callable applied to each array; ignored
                for the validation split.
            weights: optional class weights; when ``None`` they are
                ``[1, neg / pos]`` computed from the labels.
        """
        self.task = task
        self.plane = plane
        self.root_dir = root_dir
        self.train = train
        if self.train:
            split = 'train'
        else:
            split = 'valid'
            # Validation data is never transformed.
            transform = None
        self.folder_path = self.root_dir + '{0}/{1}/'.format(split, plane)
        self.records = pd.read_csv(
            self.root_dir + '{0}-{1}.csv'.format(split, task),
            header=None, names=['id', 'label'])

        # File names are the ids left-padded with zeros to four characters.
        self.records['id'] = self.records['id'].map(lambda i: str(i).zfill(4))
        self.paths = [self.folder_path + filename + '.npy'
                      for filename in self.records['id'].tolist()]
        self.labels = self.records['label'].tolist()
        self.transform = transform

        if weights is None:
            pos = np.sum(self.labels)
            neg = len(self.labels) - pos
            self.weights = torch.FloatTensor([1, neg / pos])
        else:
            self.weights = torch.FloatTensor(weights)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Load sample *index* and return ``(array, label, weights)``."""
        array = np.load(self.paths[index])
        label = self.labels[index]
        if label == 1:
            label = torch.FloatTensor([[0, 1]])
        elif label == 0:
            label = torch.FloatTensor([[1, 0]])

        if self.transform:
            array = self.transform(array)
        else:
            # Repeat the array three times along a new axis 1.
            array = np.stack((array,)*3, axis=1)
            array = torch.FloatTensor(array)
        return array, label, self.weights
There is also a model and a train script to run this; the arguments are specified in the train script.
Running the train script should load the data and run it through the model.

TypeError: __init__() takes from 1 to 4 positional arguments but 9 were given

When I run the following program I get this error:
originDataset = dataset.lmdbDataset(originPath, 'abc', *args)
TypeError: __init__() takes from 1 to 4 positional arguments but 9 were given
This error is related to the second source file I present below. It's strange because I don't have 9 arguments. What's wrong with my code?
import sys
origin_path = sys.path
import dataset
sys.path = origin_path
import lmdb
def writeCache(env, cache):
    """Write every key/value pair of *cache* to *env* in one write transaction.

    Args:
        env: object whose ``begin(write=True)`` returns a context manager
            yielding a transaction with a ``put(key, value)`` method.
        cache: mapping of keys to values; entries that are not ``bytes``
            are UTF-8 encoded before being stored.
    """
    with env.begin(write=True) as txn:
        # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
        for k, v in cache.items():
            if not isinstance(k, bytes):
                k = k.encode('utf-8')
            if not isinstance(v, bytes):
                v = v.encode('utf-8')
            txn.put(k, v)
def convert(originPath, outputPath):
args = [0] * 6
originDataset = dataset.lmdbDataset(originPath, 'abc', *args)
print('Origin dataset has %d samples' % len(originDataset))
labelStrList = []
for i in range(len(originDataset)):
label = originDataset.getLabel(i + 1)
if i % 10000 == 0:
lengthList = [len(s) for s in labelStrList]
items = zip(lengthList, range(len(labelStrList)))
items.sort(key=lambda item: item[0])
env =, map_size=1099511627776)
cnt = 1
cache = {}
nSamples = len(items)
for i in range(nSamples):
imageKey = 'image-%09d' % cnt
labelKey = 'label-%09d' % cnt
origin_i = items[i][1]
img, label = originDataset[origin_i + 1]
cache[labelKey] = label
cache[imageKey] = img
if cnt % 1000 == 0 or cnt == nSamples:
writeCache(env, cache)
cache = {}
print('Written %d / %d' % (cnt, nSamples))
cnt += 1
nSamples = cnt - 1
cache['num-samples'] = str(nSamples)
writeCache(env, cache)
print('Convert dataset with %d samples' % nSamples)
if __name__ == "__main__":
    # Re-order the validation set first, then the training set.
    for source_db, ordered_db in (
        ('/share/datasets/scene_text/Synth90k/synth90k-val-lmdb',
         '/share/datasets/scene_text/Synth90k/synth90k-val-ordered-lmdb'),
        ('/share/datasets/scene_text/Synth90k/synth90k-train-lmdb',
         '/share/datasets/scene_text/Synth90k/synth90k-train-ordered-lmdb'),
    ):
        convert(source_db, ordered_db)
which calls the following program :
# encoding: utf-8
import random
import sys

import lmdb
import numpy as np
import six
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import sampler
class lmdbDataset(Dataset):
    """Reads ``image-%09d`` / ``label-%09d`` records from an LMDB database."""

    def __init__(self, root=None, transform=None, target_transform=None):
        """
        Args:
            root: path of the LMDB environment to open read-only.
            transform: optional callable applied to each loaded image.
            target_transform: optional callable applied to each label.

        Raises:
            IOError: if the environment object is falsy after opening.
        """
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)
        if not self.env:
            message = 'cannot creat lmdb from %s' % (root)
            print(message)
            raise IOError(message)

        with self.env.begin(write=False) as txn:
            # LMDB keys are bytes on Python 3.
            self.nSamples = int(txn.get('num-samples'.encode()))

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        """Return ``(img, label)`` for the 0-based *index*.

        If the image cannot be decoded, the following sample is returned
        instead.
        """
        assert index <= len(self), 'index range error'
        index += 1  # stored keys are 1-based
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % index
            imgbuf = txn.get(img_key.encode())

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % index)
                # ``index`` is already the next 0-based position; using
                # ``index + 1`` here would silently skip one more sample.
                return self[index]

            if self.transform is not None:
                img = self.transform(img)

            label_key = 'label-%09d' % index
            # Decode instead of str(): str(bytes) yields "b'...'" on Python 3.
            label = txn.get(label_key.encode()).decode('utf-8')
            if self.target_transform is not None:
                label = self.target_transform(label)
        return (img, label)
class resizeNormalize(object):
    """Callable that resizes an image and converts it with ``ToTensor``."""

    def __init__(self, size, interpolation=Image.BILINEAR):
        """Remember the target *size* and the resampling filter."""
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        """Return ``toTensor`` applied to *img* resized to ``self.size``."""
        resized = img.resize(self.size, self.interpolation)
        return self.toTensor(resized)
class randomSequentialSampler(sampler.Sampler):
    """Yields indices in batch-sized runs of consecutive values.

    Every run starts at a random position; the indices inside one run are
    consecutive.
    """

    def __init__(self, data_source, batch_size):
        self.num_samples = len(data_source)
        self.batch_size = batch_size

    def __iter__(self):
        total = len(self)
        n_batch = total // self.batch_size
        tail = total % self.batch_size
        index = torch.LongTensor(total).fill_(0)
        # Largest start that keeps a full run in range; clamped to 0 so a
        # dataset smaller than one batch does not give randint a negative
        # upper bound.
        max_start = max(0, total - self.batch_size)
        for i in range(n_batch):
            random_start = random.randint(0, max_start)
            # torch.arange yields integer indices; torch.range is
            # deprecated and returns floats.
            batch_index = random_start + torch.arange(0, self.batch_size)
            index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
        # deal with tail
        if tail:
            random_start = random.randint(0, max_start)
            tail_index = random_start + torch.arange(0, tail)
            # Computed from n_batch, not the loop variable, which is
            # undefined when there are no full batches.
            index[n_batch * self.batch_size:] = tail_index
        return iter(index)

    def __len__(self):
        return self.num_samples
class alignCollate(object):
    """Collate function: resizes all images of a batch to one size and stacks them."""

    def __init__(self, imgH=32, imgW=128, keep_ratio=False, min_ratio=1):
        """
        Args:
            imgH: target height.
            imgW: target width when ``keep_ratio`` is false.
            keep_ratio: derive the width from the widest aspect ratio in
                the batch instead of using ``imgW``.
            min_ratio: lower bound for width / height when ``keep_ratio``
                is true.
        """
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio

    def __call__(self, batch):
        """Return ``(images, labels)`` for a batch of ``(image, label)`` pairs.

        ``images`` is the concatenation along dim 0 of the transformed
        images, each given a leading dimension of 1.
        """
        images, labels = zip(*batch)

        imgH = self.imgH
        imgW = self.imgW
        if self.keep_ratio:
            # Use the true maximum; the last element of an unsorted list
            # is just whichever image happens to be last in the batch.
            max_ratio = max(image.size[0] / float(image.size[1])
                            for image in images)
            imgW = int(np.floor(max_ratio * imgH))
            imgW = max(imgH * self.min_ratio, imgW)  # assure imgW >= imgH * min_ratio

        transform = resizeNormalize((imgW, imgH))
        images = torch.cat(
            [transform(image).unsqueeze(0) for image in images], 0)
        return images, labels
