师兄的代码封装成类,流畅精美,容易调试;我的代码是堆砌起来的,被师兄嘲笑说是在写脚本。好吧!我的代码只有我自己懂,哈哈!希望以后代码能写得工整一点,但现在还是先让自己看懂吧。这里我做了一个简单的任务:对 0、1、2 三个数字进行分类。
(部分)代码分为以下几个文件:
1 train_net.py
# Train a 3-class (digits 0/1/2) classifier by fine-tuning
# models/bvlc_reference_caffenet.caffemodel with the solver in models/.
import time
import os
import numpy as np
import sys
import cv2

# Local Caffe checkout whose python/ directory provides the pycaffe module;
# it must be on sys.path before `import caffe`.
CAFFE_PYTHON_DIR = "/home/wang/Downloads/caffe-master/python"
sys.path.append(CAFFE_PYTHON_DIR)
import caffe

# Dataset root handed to the data layer: one sub-directory per class
# (0, 1, 2), as read by pre_data.prepare_data.
data_config = '/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/'
SOLVER_PROTOTXT = 'models/solver.prototxt'
PRETRAINED_MODEL = 'models/bvlc_reference_caffenet.caffemodel'
SNAPSHOT_PATH = 'tmp.caffemodel'

NUM_ROUNDS = 10000    # the loop below runs rounds 1 .. NUM_ROUNDS - 1
STEPS_PER_ROUND = 5   # SGD iterations per solver.step() call
SNAPSHOT_EVERY = 100  # overwrite SNAPSHOT_PATH every this many rounds

# Run on the GPU; call caffe.set_mode_cpu() instead to train on the CPU.
caffe.set_mode_gpu()

# Build the solver (and its training net) from the solver definition.
solver = caffe.SGDSolver(SOLVER_PROTOTXT)

# Initialise the training net from the pretrained weights.
print('load pretrain model')
solver.net.copy_from(PRETRAINED_MODEL)

# Layer 0 is presumably the Python DataLayer (the only layer type in this
# project with SetDataConfig); tell it where the images live.
solver.net.layers[0].SetDataConfig(data_config)

for i in range(1, NUM_ROUNDS):
    # Run STEPS_PER_ROUND SGD updates.
    solver.step(STEPS_PER_ROUND)
    if i % SNAPSHOT_EVERY == 0:
        solver.net.save(SNAPSHOT_PATH)

# The in-loop save only fires on multiples of SNAPSHOT_EVERY, so save once
# more to keep the weights from the trailing rounds.
solver.net.save(SNAPSHOT_PATH)

# TODO: test code
2 test_net.py(还没写)
3 pre_data.py
import os
from random import choice, randint

import numpy as np
import cv2

from utils import PrepareImage, CatImage


def prepare_data(path, batchsize, num_classes=3):
    """Sample a random, labelled mini-batch from a class-per-folder dataset.

    ``path`` must contain one sub-directory per class, named ``0``, ``1``,
    ..., ``num_classes - 1``; the directory name is used as the label.
    Images are drawn with replacement, randomly flipped left-right, and
    resized to 227x227 by PrepareImage.

    Args:
        path: dataset root directory (a trailing slash is optional).
        batchsize: number of images to draw.
        num_classes: number of class sub-directories (default 3: digits 0-2).

    Returns:
        (imgData, label): ``imgData`` is the float32 blob built by CatImage;
        ``label`` is a float32 array of length ``batchsize`` holding each
        image's class index.

    Raises:
        ValueError: if a class directory contains no files.
        IOError: if OpenCV cannot decode a selected image.
    """
    # List every class directory once, instead of once per sampled image.
    class_files = []
    for idf in range(num_classes):
        class_dir = os.path.join(path, str(idf))
        files = sorted(os.listdir(class_dir))
        if not files:
            raise ValueError('class directory is empty: %s' % class_dir)
        class_files.append((class_dir, files))

    img_list = []
    label = np.zeros(batchsize, dtype=np.float32)
    for i in range(batchsize):
        # Randomly pick a class, then one image from that class.
        idf = randint(0, num_classes - 1)
        class_dir, files = class_files[idf]
        img_path = os.path.join(class_dir, choice(files))

        img = cv2.imread(img_path)
        if img is None:  # imread signals failure by returning None
            raise IOError('cannot read image: %s' % img_path)
        # cv2.imread already yields BGR, the order the per-channel fill
        # means in CatImage use, so no colour conversion is applied here.

        if randint(0, 1):
            img = img[:, ::-1, :]  # flip left to right (augmentation)

        img_list.append(PrepareImage(img, (227, 227)))
        label[i] = idf
    imgData = CatImage(img_list)
    return (imgData, label)
4 utils.py
import os
import cv2
import numpy as np

# Values CatImage writes into channels 0, 1 and 2 of any padded area.
# NOTE(review): these look like per-channel pixel means in B, G, R order;
# confirm they match the channel order of the images fed in.
PIXEL_MEANS = (102.9801, 115.9465, 122.7717)


def PrepareImage(im, size):
    """Resize an image and convert it to a channel-first float32 array.

    Args:
        im: image array of shape (rows, cols, channels), e.g. from cv2.imread.
        size: target size, passed to cv2.resize as (width, height).

    Returns:
        float32 array of shape (channels, size[1], size[0]).
    """
    # cv2.resize takes its dsize argument as (width, height).
    resized = cv2.resize(im, (size[0], size[1]))
    # (H, W, C) -> (C, H, W), the per-image layout of the data blob.
    chw = resized.transpose(2, 0, 1)
    return chw.astype(np.float32, copy=False)


def CatImage(im_list):
    """Stack (3, H, W) images into one (N, 3, H_max, W_max) float32 blob.

    Each image is copied into the top-left corner of its slot. Where images
    are smaller than the largest one, the uncovered area keeps the
    per-channel fill values from PIXEL_MEANS.

    Args:
        im_list: non-empty list of channel-first 3-channel arrays.

    Returns:
        float32 array of shape (len(im_list), 3, H_max, W_max).

    Raises:
        ValueError: if im_list is empty.
    """
    if not im_list:
        raise ValueError('CatImage needs at least one image')
    max_shape = np.array([im.shape for im in im_list]).max(axis=0)
    blob = np.empty((len(im_list), 3, max_shape[1], max_shape[2]), dtype=np.float32)
    # Pre-fill every channel with its fill value (broadcast over N, H, W).
    blob[...] = np.asarray(PIXEL_MEANS, dtype=np.float32).reshape(1, 3, 1, 1)
    for i, im in enumerate(im_list):
        blob[i, :, 0:im.shape[1], 0:im.shape[2]] = im
    return blob
5 layer/data_layer.py
import caffe
import numpy as np

from pre_data import prepare_data


class DataLayer(caffe.Layer):
    """Python data layer that feeds random mini-batches from prepare_data.

    Tops: top[0] is the image blob, top[1] the label vector. The dataset
    root directory must be set with SetDataConfig before the first forward
    pass.
    """

    # Number of images drawn per forward pass.
    batch_size = 128
    # Dataset root directory passed to prepare_data; set via SetDataConfig.
    _data_config = None

    def SetDataConfig(self, data_config):
        """Set the dataset root directory used by forward()."""
        self._data_config = data_config

    def GetDataConfig(self):
        """Return the dataset root directory (None until it has been set)."""
        return self._data_config

    def setup(self, bottom, top):
        """Declare initial top shapes; forward() reshapes to the real batch."""
        # data blob: one 3x227x227 image
        top[0].reshape(1, 3, 227, 227)
        # label blob
        top[1].reshape(1, 1)

    def reshape(self, bottom, top):
        # Top shapes are set in forward(), once the batch has been loaded.
        pass

    def forward(self, bottom, top):
        """Load a fresh random batch and copy it into the top blobs."""
        path = self.GetDataConfig()
        if path is None:
            raise RuntimeError('DataLayer: call SetDataConfig(path) before running the net')
        (imgs, label) = prepare_data(path, self.batch_size)
        # Blob layout is (N, C, H, W).
        (n, c, h, w) = imgs.shape
        # image data
        top[0].reshape(n, c, h, w)
        top[0].data[...] = imgs
        # object class label
        top[1].reshape(n)
        top[1].data[...] = label

    def backward(self, top, propagate_down, bottom):
        # Data layer: there is nothing to back-propagate.
        pass
6 layer/__init__.py
import data_layer
还有一些 Caffe 中的经典文件没有放进来,例如 train_net.py 用到的 models/solver.prototxt 和预训练模型 models/bvlc_reference_caffenet.caffemodel。
数据:http://pan.baidu.com/s/1mgYQa6G(尚未划分训练集和测试集)
用 Python 调用 Caffe:在师兄代码的基础上改写成自己风格的代码
原文:http://www.cnblogs.com/Wanggcong/p/5169737.html