pytorch调用caffe的lmdb

时间:2019-11-21
本文章向大家介绍pytorch调用caffe的lmdb,主要包括pytorch调用caffe的lmdb使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

一. 处理好训练集和验证集后,通过caffe的convert_imageset生成lmdb:

1 /usr/softwares/caffe/build/tools/convert_imageset --resize_width=224 --resize_height=224 --gray=true --shuffle=true --encoded=true /usr/data_path/ /usr/data/data_lt.txt /usr/lmdb

(1) /usr/softwares/caffe/build/tools/convert_imageset:caffe中convert_imagese脚本的路径;

(2) --resize_width, --resize_height: 网络输入图像的宽和高;

(3) --gray: 是否灰度化,如果没设这个参数,则默认lmdb中像素值为三通道;

(4) --shuffle: 是否随机打乱输入图像列表的顺序;

(5) --encoded: 是否对像素做编码;

(6) /usr/data_path/: 存放图像的主目录路径;

(7) /usr/data/data_lt.txt: 图像的列表,存放的是图像的相对路径和label;

(8) /usr/lmdb: 生成的lmdb的绝对路径。

二. 提取lmdb中的key,生成key列表:

1 fl = open(saveKey_path, "w")
2 lmdb_env = lmdb.open(lmdb_path)
3 lmdb_txn = lmdb_env.begin()
4 lmdb_cursor = lmdb_txn.cursor()
5 for key, value in lmdb_cursor:
6     fl.write("%s\n" %(key))
7     count += 1
8 fl.close()

  根据caffe中读取lmdb的方式读取lmdb中每个元素的key值,详见代码get_lmdbKey.py。

三. 根据上述key值列表读取batch(详见文件readLmdb.py):

(1) 得到所有图像的索引,并shuffle;

1 self.indices_total = np.arange(self.dataset_size)
2 np.random.seed(321)
3 np.random.shuffle(self.indices_total)

(2) 定义图像batch和label batch,从lmdb中取数据放入这两个batch中:

 //定义image batch和label batch blob(N C H W),其中label_batch后的dtype是为了方便将其转化为pytorch接受的tensor形式的变量
1
image_batch = np.zeros((self.batch_size, 1, 224, 224), dtype=np.float32) 2 label_batch = np.zeros((self.batch_size), dtype=getattr(np, 'long'))
//按照batch值依次从shuffle后的key list中读取图像
3 for i in range(self.batch_size): 4 ind = self.indices_total[self.data_idx] 5 self.data_idx += 1
//若变量data_idx取到了list的最后一个值,则重新shuffle key list 6 if self.data_idx == self.dataset_size: 7 self.data_idx = 0 8 np.random.shuffle(self.indices_total)
//注意以下代码可能python2和python3的有差异,此处代码基于python3
9 temp = (self.key_lt[ind]).encode() 10 value = self.txn.get(temp) 11 datum = caffe_pb2.Datum() 12 datum.ParseFromString(value)
//读取label,并放入label batch中
13 label = float(datum.label)
//读取图像数据,对其预处理后放入image batch中
14 encoded = datum.encoded 15 if encoded: 16 stream = BytesIO(datum.data) 17 img = np.uint8(Image.open(stream)) 18 img = img[...,::-1] 19 else: 20 data = caffe.io.datum_to_array(datum) 21 img = np.transpose(data, (1, 2, 0)) 22 img_tmp = img.copy() 23 img_tmp = np.float64(img_tmp) 24 img_tmp -= 127.5 25 img_tmp *= 0.0078125 26 img_tmp = img_tmp[np.newaxis, :] 27 image_batch[i, :] = img_tmp 28 label_batch[i] = label

 (3) 在训练程序中,送入网络满足pytorch的数据形式:

1 train_inputs, train_label = batchClass.GetBatch()
2 pytorch_inputs = (torch.tensor(train_inputs)).cuda()
3 pytorch_labels = (torch.tensor(train_label)).cuda()

 将自己定义的blob通过torch.tensor转化为pytorch网络接受的数据形式。

原文地址:https://www.cnblogs.com/liangx-img/p/11905374.html