[Caffe in Practice] Code Implementation of CNN-Based Gender and Age Recognition


Main reference: http://www.openu.ac.il/home/hassner/projects/cnn_agegender/

Implementation (demo notebook): http://nbviewer.jupyter.org/url/www.openu.ac.il/home/hassner/projects/cnn_agegender/cnn_age_gender_demo.ipynb




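First, download cnn_age_gender_models_and_data.0.0.2.zip from http://www.openu.ac.il/home/hassner/projects/cnn_agegender/ and extract it into Caffe's models directory. The prediction code below expects the files under /opt/caffe/models/AgeGenderCNN/.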
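Before running the prediction script, it can help to confirm that the extracted files sit where the script looks for them. This is a minimal sketch: the helper name `missing_model_files` is hypothetical, and the file list is simply the set of names hard-coded in the prediction code (directory `/opt/caffe/models/AgeGenderCNN/`).

```python
from __future__ import print_function

import os

# File names referenced by the prediction script (assumed model layout).
REQUIRED_FILES = [
    'mean.binaryproto',
    'age_net.caffemodel',
    'deploy_age.prototxt',
    'gender_net.caffemodel',
    'deploy_gender.prototxt',
]


def missing_model_files(model_dir):
    """Return the required files that are absent from model_dir (hypothetical helper)."""
    return [name for name in REQUIRED_FILES
            if not os.path.isfile(os.path.join(model_dir, name))]


if __name__ == '__main__':
    missing = missing_model_files('/opt/caffe/models/AgeGenderCNN/')
    if missing:
        print('Missing files:', ', '.join(missing))
    else:
        print('All model files found.')
```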

Make sure Python is version 2.7 or later; see http://yijiebuyi.com/blog/108ae6186bb00cc708bc54f02adec277.html for upgrading.
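A quick way to confirm the 2.7 requirement from inside Python, using the same interpreter that will later load pycaffe. This is a small sketch; `python_version_ok` is a hypothetical helper name.

```python
import sys


def python_version_ok(minimum=(2, 7)):
    """True if the running interpreter is at least `minimum` as (major, minor)."""
    # Tuples compare element by element, so (3, 8) >= (2, 7) and (2, 6) < (2, 7).
    return tuple(sys.version_info[:2]) >= tuple(minimum)


if __name__ == '__main__':
    if not python_version_ok():
        sys.exit('Python 2.7 or later is required, found %d.%d'
                 % tuple(sys.version_info[:2]))
    print('Python version OK: %d.%d.%d' % tuple(sys.version_info[:3]))
```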

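Upgrading Python can leave pip unusable; running the following commands reinstalls it. (For a Python 2.7 interpreter, the current get-pip.py no longer supports 2.7; PyPA provides a 2.7-compatible copy at https://bootstrap.pypa.io/pip/2.7/get-pip.py.)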

wget https://bootstrap.pypa.io/get-pip.py

python get-pip.py
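After reinstalling, one way to confirm that pip works for the upgraded interpreter is to run it as a module of that interpreter (`python -m pip --version`). A minimal sketch; `pip_works` is a hypothetical helper name.

```python
from __future__ import print_function

import os
import subprocess
import sys


def pip_works(python=sys.executable):
    """Run `<python> -m pip --version` and report whether it exits successfully."""
    try:
        with open(os.devnull, 'w') as devnull:
            return subprocess.call([python, '-m', 'pip', '--version'],
                                   stdout=devnull, stderr=devnull) == 0
    except OSError:
        # The interpreter path does not exist or cannot be executed.
        return False


if __name__ == '__main__':
    print('pip usable: %s' % pip_works())
```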

NumPy header path on this system (building Caffe's Python interface needs the NumPy headers): /usr/lib64/python2.7/site-packages/numpy/core/include/numpy/arrayobject.h
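The header location varies by distribution and NumPy version. Rather than hard-coding it, NumPy reports its include directory through `numpy.get_include()`; that directory (not the `arrayobject.h` file itself) is what goes on the compiler include path, for example in the `PYTHON_INCLUDE` setting of Caffe's `Makefile.config`. A minimal sketch:

```python
from __future__ import print_function

import os

import numpy as np

# Directory to add to the include path; the headers live under <dir>/numpy/.
include_dir = np.get_include()
header = os.path.join(include_dir, 'numpy', 'arrayobject.h')
print('NumPy include dir:', include_dir)
print('arrayobject.h present:', os.path.isfile(header))
```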

之后需要編譯caffe,使其支持python,具體編譯方法見:http://caffe.berkeleyvision.org/installation.html

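Prediction code. The script expects src_folder to contain one subfolder per person, all of the same gender (set by gender_folder). It runs the gender network on every image and copies each image whose predicted gender differs from gender_folder into ../maleout or ../femaleout next to src_folder. The age-prediction lines are commented out.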

```python
from __future__ import print_function  # print() behaves the same on Python 2.7 and 3

import os
import shutil
import sys

caffe_root = '/opt/caffe/'
sys.path.insert(0, caffe_root + 'python')  # make pycaffe importable
import caffe

model_dir = caffe_root + 'models/AgeGenderCNN/'


def predict(src_folder):
    # Load the mean image shipped with the models (a serialized BlobProto).
    with open(model_dir + 'mean.binaryproto', 'rb') as f:
        proto_data = f.read()
    a = caffe.io.caffe_pb2.BlobProto.FromString(proto_data)
    mean = caffe.io.blobproto_to_array(a)[0]  # shape (3, 256, 256)
    # Newer pycaffe versions reject a mean whose shape differs from the network
    # input ("Mean shape incompatible with input shape"); reducing it to one
    # value per channel avoids that.
    mean = mean.mean(1).mean(1)

    # Both networks share the same preprocessing: load_image returns RGB in
    # [0, 1], so scale to [0, 255] and swap RGB to BGR.
    age_net = caffe.Classifier(model_dir + 'deploy_age.prototxt',
                               model_dir + 'age_net.caffemodel',
                               mean=mean,
                               channel_swap=(2, 1, 0),
                               raw_scale=255,
                               image_dims=(256, 256))
    gender_net = caffe.Classifier(model_dir + 'deploy_gender.prototxt',
                                  model_dir + 'gender_net.caffemodel',
                                  mean=mean,
                                  channel_swap=(2, 1, 0),
                                  raw_scale=255,
                                  image_dims=(256, 256))

    age_list = ['(0, 2)', '(4, 6)', '(8, 12)', '(15, 20)',
                '(25, 32)', '(38, 43)', '(48, 53)', '(60, 100)']
    gender_list = ['Male', 'Female']
    # Gender of everyone under src_folder; must match an entry of gender_list.
    gender_folder = 'Male'
    out_dir = os.path.join(src_folder, '..',
                           'maleout' if gender_folder == 'Male' else 'femaleout')
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    for people_folder in sorted(os.listdir(src_folder)):
        people_path = os.path.join(src_folder, people_folder)
        if not os.path.isdir(people_path):
            continue  # skip stray files; one subfolder per person is expected
        for img_file in sorted(os.listdir(people_path)):
            img_path = os.path.join(people_path, img_file)
            input_image = caffe.io.load_image(img_path)
            # prediction = age_net.predict([input_image])
            # print('predicted age:', age_list[prediction[0].argmax()])
            prediction = gender_net.predict([input_image])
            predicted = gender_list[prediction[0].argmax()]
            if predicted != gender_folder:
                # Copy images whose predicted gender disagrees with the folder.
                print('processing img:', img_path, 'gender:', predicted,
                      'prediction:', prediction)
                shutil.copy(img_path, out_dir)


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print('Usage: python %s src_folder' % sys.argv[0])
        sys.exit(1)
    src_folder = sys.argv[1]
    if not src_folder.endswith('/'):
        src_folder += '/'
    predict(src_folder)
```
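`caffe.Classifier.predict` returns one row of class probabilities per input image (with the default `oversample=True`, averaged over 10 crops), and the script takes the `argmax` of the first row as the label. The sketch below shows that mapping on made-up score vectors so it runs without Caffe; the helper name `top_label` and the sample numbers are illustrative only.

```python
import numpy as np

age_list = ['(0, 2)', '(4, 6)', '(8, 12)', '(15, 20)',
            '(25, 32)', '(38, 43)', '(48, 53)', '(60, 100)']
gender_list = ['Male', 'Female']


def top_label(prediction, labels):
    """Return (label, score) for the highest-scoring class of the first image."""
    scores = np.asarray(prediction)[0]
    idx = int(scores.argmax())
    return labels[idx], float(scores[idx])


# Illustrative scores shaped like predict()'s output: (num_images, num_classes).
fake_gender = np.array([[0.2, 0.8]])
fake_age = np.array([[0.01, 0.02, 0.05, 0.10, 0.60, 0.15, 0.05, 0.02]])
print(top_label(fake_gender, gender_list))  # ('Female', 0.8)
print(top_label(fake_age, age_list))        # ('(25, 32)', 0.6)
```

Printing the score alongside the label makes low-confidence predictions easy to spot when reviewing the copied images.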

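The script reads `mean.binaryproto` into a `(3, 256, 256)` array. Depending on the pycaffe version, `caffe.io.Transformer.set_mean` accepts either a mean with exactly the network's input shape or one value per channel, so a full-size 256x256 mean can be rejected when the network input is smaller. Averaging over the spatial axes gives a per-channel mean; the NumPy sketch below demonstrates this on a random stand-in array, not the real model mean.

```python
import numpy as np

# Stand-in for caffe.io.blobproto_to_array(a)[0]: (channels, height, width).
rng = np.random.RandomState(0)
mean_image = rng.uniform(0, 255, size=(3, 256, 256))

# Average over height and width, leaving one value per channel, shape (3,).
channel_mean = mean_image.mean(axis=(1, 2))

# Same result as the chained form often seen in pycaffe code.
assert np.allclose(channel_mean, mean_image.mean(1).mean(1))
print(channel_mean.shape)  # (3,)
```

In pycaffe versions whose `set_mean` accepts a 1-D mean, the resulting array can be passed as `mean=` to `caffe.Classifier`.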


Guangdong ICP filing No. 14056181  © 2014-2021 ITdaan.com