clear all, close all

%% param config

use_gpu = 1; % nonzero: run Caffe on the GPU, 0: run it on the CPU
root_path = [pwd '/../../']; % two directories above the current working directory
cnn_model_name = [root_path 'caffe/examples/mnist/lenet_iter_10000.caffemodel.example']; % pre-trained CNN weights used for initialization
num_iters = 10; % number of joint CNN/GP training iterations
lr_scale = 0.0001; % factor applied to the GP input gradient before it is backpropagated into the CNN
nx = 70; % number of grid points per feature dimension

%% initialize env

% related paths: matcaffe (provides the caffe.* package used below) and the
% gpml/kissgp directory (presumably provides the GP routines such as infGrid
% and covGrid used later in this script — confirm)
addpath([root_path 'caffe/matlab']);
addpath([root_path 'gpml/kissgp']);
% caffe mode: chosen from the use_gpu flag, before the solver is created below
if use_gpu
  caffe.set_mode_gpu();
  %caffe.set_device(1); % uncomment to select a specific GPU device id
else
  caffe.set_mode_cpu();
end

%% initialize CNN

% Build the solver from its prototxt. train_net is the solver's own net, which
% is the net that solver.apply_update() in the training loop operates on.
example_dir = [root_path 'examples/mnist/'];
solver = caffe.Solver([example_dir 'lenet_solver.prototxt_matlab']);
train_net = solver.net;
test_net = solver.test_nets(1); % NOTE(review): not used anywhere else in the visible script
train_net.copy_from(cnn_model_name); % initialize weights from the pre-trained model

%% load data

fprintf('loading data ...\n');
data_path = [root_path 'caffe/examples/mnist/data/'];
im_size = 28*28; % pixels per 28x28 image
numtrain = 60000;
numtest = 10000;

% load training data
% mnist.dat is parsed as a flat stream of numbers in which every record is
% im_size pixel values followed by one label. All numtrain records are read
% with a single textscan call (same value stream as reading im_size values and
% then 1 value per record). This avoids 2*numtrain tiny calls and the
% quadratic growth of the label vector.
train_file = [data_path '/mnist_train/mnist.dat'];
train_fid = fopen(train_file);
if train_fid < 0
  error('cannot open training data file %s', train_file);
end
C = textscan(train_fid, '%f', (im_size + 1) * numtrain);
fclose(train_fid);
if numel(C{1}) ~= (im_size + 1) * numtrain
  error('expected %d values in %s, read %d', (im_size + 1) * numtrain, train_file, numel(C{1}));
end
vals = reshape(C{1}, im_size + 1, numtrain); % column m holds record m
clear C
ytrain = vals(end, :)'; % numtrain x 1 labels (double)
% Pixels are scaled by 0.00390625 (= 2^-8) into a 28 x 28 x 1 x numtrain single
% array. Converting to single before scaling by a power of two yields the same
% values as scaling in double and then storing as single, with less peak memory.
vals = single(vals);
input_data = reshape(vals(1:im_size, :) * 0.00390625, 28, 28, 1, numtrain);
clear vals
% forward pass to get cnn features
train_net.blobs('data').reshape([28 28 1 numtrain]); % reshape cnn input blob
train_net.reshape();
train_net.forward({input_data});
% get cnn features from the 'ip1.6' blob (the GP input features)
Xtrain_cnn = train_net.blobs('ip1.6').get_data();
Xtrain_cnn = double(Xtrain_cnn'); % size = numtrain x D
D = size(Xtrain_cnn, 2); % feature dimension

% load test data
% Parsed exactly like the training file: each record is im_size pixel values
% followed by one label. All numtest records are read with one textscan call.
test_file = [data_path '/mnist_test/mnist.dat'];
test_fid = fopen(test_file);
if test_fid < 0
  error('cannot open test data file %s', test_file);
end
C = textscan(test_fid, '%f', (im_size + 1) * numtest);
fclose(test_fid);
if numel(C{1}) ~= (im_size + 1) * numtest
  error('expected %d values in %s, read %d', (im_size + 1) * numtest, test_file, numel(C{1}));
end
vals = reshape(C{1}, im_size + 1, numtest); % column m holds record m
clear C
ytest = vals(end, :)'; % numtest x 1 labels (double)
% pixels scaled by 0.00390625 (= 2^-8) into a 28 x 28 x 1 x numtest single array
vals = single(vals);
test_data = reshape(vals(1:im_size, :) * 0.00390625, 28, 28, 1, numtest);
clear vals
% forward pass to get cnn features (test images also go through train_net)
train_net.blobs('data').reshape([28 28 1 numtest]); %reshape cnn input blob
train_net.reshape();
train_net.forward({test_data});
% get cnn features from the 'ip1.6' blob
Xtest_cnn = train_net.blobs('ip1.6').get_data(); %cnn features
Xtest_cnn = double(Xtest_cnn'); %size = numtest x D

% test cnn performance
% Reads the 'ip2' blob left by the most recent forward pass, which in this
% script is the pass over test_data. It assumes ip2 produces one value per test
% image (TODO confirm against the prototxt). y_reg(:) gives a column whatever
% orientation get_data returns, so the subtraction below cannot silently
% broadcast into a numtest x numtest matrix.
y_reg = train_net.blobs('ip2').get_data();
rmse = sqrt(mean((ytest - y_reg(:)).^2));
fprintf('CNN RMSE %f\n', rmse);

%% initialize GP

% optimization params (this opt struct is passed to infGrid and infGrid_cnngp)
opt.toep1d = false;    % may be [], true or false; toggles use of the toep* routines
opt.ntoep_min = 1000;  % NOTE(review): threshold for the toep* routines - confirm its meaning in infGrid
opt.circ_emb = true;   % toggle to use the circular embedding
opt.cg_maxit = 500; opt.cg_tol = 1e-6; % presumably conjugate-gradient iteration cap and tolerance

% kernel hyps: one covSEiso per CNN feature dimension. Each has hyperparameters
% [log(ell); log(sf)], stacked for all D dimensions. The cell is named base_cov
% rather than cov so that MATLAB's cov function is not shadowed.
base_cov = repmat({@covSEiso}, [1, D]);
sf = 1; ell = 0.5; hyp.cov = repmat(log([ell;sf]), [D, 1]);
% mean: linear part (D weights) plus a constant (1 offset), all initialized to 0
meanfunc = {@meanSum, {@meanLinear, @meanConst}}; hyp.mean = [zeros(D,1); 0];
% Gaussian likelihood with noise level sn, stored as log(sn)
sn = 0.3;  hyp.lik = log(sn);
likfunc = @likGauss;
% infGrid with opt appended as its last argument. opt is captured by value
% when this handle is created, so later edits to opt do not reach it.
inf_method = @(varargin) infGrid(varargin{:},opt);
% grid: nx evenly spaced points per feature dimension. It spans the per-column
% range of both training and test features, padded by 5 on each side. The
% explicit dimension argument keeps min/max column-wise even for a single row.
xg_min = min(min(Xtrain_cnn, [], 1), min(Xtest_cnn, [], 1)) - 5;
xg_max = max(max(Xtrain_cnn, [], 1), max(Xtest_cnn, [], 1)) + 5;
xg = cell([1, D]);
for d = 1:D
  xg{d} = linspace(xg_min(d), xg_max(d), nx)';
end
covg = {@covGrid,base_cov,xg};

% training
% Alternating scheme per iteration:
%   GP input gradient -> GP hyperparameter update -> CNN update via backprop
%   -> recompute features -> report test RMSE.
for iter = 1:num_iters
  % Gradient w.r.t. the base kernel inputs (the CNN features). It is computed
  % with the hyperparameters from before this iteration's minimize call.
  [~,~,~,dx] = infGrid_cnngp(hyp,meanfunc,covg,'likGauss',Xtrain_cnn,ytrain,opt);
  % update base kernel hyps using the GPML routine (length -2: at most 2 function evaluations)
  hyp = minimize(hyp,@gp,-2,inf_method,meanfunc,covg,likfunc,Xtrain_cnn,ytrain);
  % Update CNN params through backpropagation. dx is assumed to be
  % numtrain x D like Xtrain_cnn (TODO confirm), so it is transposed back to
  % the layout returned by the 'ip1.6' blob and scaled by lr_scale. The
  % variable is named top_diff rather than diff so that MATLAB's diff
  % function is not shadowed.
  top_diff = {dx' * lr_scale};
  train_net.blobs('data').reshape([28 28 1 numtrain]); % reshape cnn input blob back to the training batch
  train_net.reshape();
  train_net.forward({input_data}); % forward pass
  train_net.backward_from('ip1.6', {'ip1.6'}, top_diff); % backward pass starting at ip1.6
  % NOTE(review): upstream Caffe accumulates parameter diffs across backward
  % passes and clears them inside Solver::Step. Confirm that apply_update (or
  % backward_from) in this Caffe build resets them; otherwise gradients from
  % earlier iterations keep contributing to every update.
  solver.apply_update();

  % get the latest cnn features (the data blob still has the training shape)
  train_net.forward({input_data});
  Xtrain_cnn = train_net.blobs('ip1.6').get_data();
  Xtrain_cnn = double(Xtrain_cnn');
  train_net.blobs('data').reshape([28 28 1 numtest]);
  train_net.reshape();
  train_net.forward({test_data});
  Xtest_cnn = train_net.blobs('ip1.6').get_data();
  Xtest_cnn = double(Xtest_cnn');
  % NOTE(review): the grid inside covg was built once from the initial
  % features. Updated features may drift outside [xg_min, xg_max] - consider
  % checking or rebuilding the grid.

  % test
  [postg,nlZg,dnlZg] = infGrid(hyp,meanfunc,covg,'likGauss',Xtrain_cnn,ytrain,opt);
  ymug = postg.fmu(Xtest_cnn); % quick interpolated prediction
  rmse = sqrt(mean((ytest-ymug).^2));
  fprintf('iter %d, RMSE %f\n', iter, rmse);
end

%% save
%train_net.save(['dkl_rbf_mnist.caffemodel']);
%save dkl_rbf_mnist_hyps.mat hyp
