From d2603183fb96948bdea2aef75d851c6423770e59 Mon Sep 17 00:00:00 2001 From: Chunk Date: Mon, 20 Apr 2015 15:36:26 +0800 Subject: [PATCH] staged. --- test/test_model.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_model.py b/test/test_model.py index b4e471f..afb7bea 100755 --- a/test/test_model.py +++ b/test/test_model.py @@ -5,7 +5,7 @@ from sklearn import cross_validation from ..common import * from ..mdata import CV, ILSVRC, ILSVRC_S from ..mmodel.svm import SVM -from ..mmodel.theano import THEANO +from ..mmodel.theano import THEANO import gzip import cPickle @@ -14,6 +14,7 @@ import cPickle timer = Timer() package_dir = os.path.dirname(os.path.abspath(__file__)) + def test_SVM_CV(): timer.mark() dcv = CV.DataCV() @@ -89,7 +90,7 @@ def test_SVM_ILSVRC_HBASE(): X1, Y1 = dil.load_data(mode='local') X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.4, random_state=0) - print Y,np.sum(np.array(Y)==0),np.sum(np.array(Y)==1) + print Y, np.sum(np.array(Y) == 0), np.sum(np.array(Y) == 1) print np.array(Y).shape, np.array(X).shape print np.array(X_train).shape, np.array(Y_train).shape print np.array(X_test).shape, np.array(Y_test).shape @@ -148,9 +149,7 @@ def test_SVM_ILSVRC_S(): # test_SVM_ILSVRC_SPARK() - def test_THEANO_crop(): - timer.mark() dilc = ILSVRC.DataILSVRC(base_dir='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val', category='Test_crop_pil') X, Y = dilc.load_data(mode='local', feattype='coef') @@ -158,11 +157,11 @@ def test_THEANO_crop(): # X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) # with open(os.path.join(package_dir,'../res/','ils_crop.pkl'),'wb') as f: - # cPickle.dump([(X_train,Y_train),(X_test,Y_test)], f) + # cPickle.dump([(X_train,Y_train),(X_test,Y_test)], f) timer.mark() mtheano = THEANO.ModelTHEANO(toolset='cnn') - mtheano._train_cnn(dataset='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val/ils_crop.pkl') + mtheano._train_cnn(X, Y) timer.report() -- libgit2 0.21.2