clear all, close all

%% param config
% User-tunable settings; everything below in this script reads these values.

use_gpu = 1; % nonzero -> caffe.set_mode_gpu(), zero -> caffe.set_mode_cpu()
root_path = [pwd '/../../']; % base directory (two levels above the current dir); must end with '/'
cnn_model_name = [root_path 'caffe/examples/kin40k/kin40k_iter_40000.caffemodel.example']; % pre-trained nn for initialization
num_iters = 1000; % # training iterations
lr_scale = 2e-5; % learning rate scaling applied to the feature gradient before backprop
nx = 70; % grid size: number of grid points per feature dimension

%% initialize env

% put the Caffe MATLAB interface and the gpml/kissgp code on the MATLAB path
% (order kept: addpath prepends, so the later call takes precedence)
caffe_matlab_dir = [root_path 'caffe/matlab'];
kissgp_dir = [root_path 'gpml/kissgp'];
addpath(caffe_matlab_dir);
addpath(kissgp_dir);
% pick the Caffe compute mode according to the use_gpu flag
if use_gpu
  caffe.set_mode_gpu();
  %caffe.set_device(1);
else
  caffe.set_mode_cpu();
end

%% initialize CNN

% build the solver from its prototxt and grab the nets it owns
example_dir = [root_path 'examples/kin40k/'];
solver_file = [example_dir 'kin40k_solver.prototxt_matlab'];
solver = caffe.Solver(solver_file);
train_net = solver.net;
test_net = solver.test_nets(1);
% warm-start the training net from the pre-trained weights
train_net.copy_from(cnn_model_name);

%% load data
% Loads the kin40k data matrix, reorders its rows with the stored shuffle
% mapping, splits it 90/10 into train/test parts, and runs both parts through
% the network to obtain the 'ip5' features used as GP inputs.

fprintf('loading data ...\n');
load([root_path 'caffe/examples/kin40k/data/kin40k.mat']); % expected to define the matrix `data`
numtotal = size(data, 1);
rawdim = size(data, 2) - 1; % all columns but the last are inputs; the last is the target

% read shuffle mapping
mapping_file = [root_path 'caffe/examples/kin40k/data/kin40k.tab_mapping'];
train_fid = fopen(mapping_file);
if train_fid == -1
  % fail early with a clear message instead of a cryptic textscan error
  error('cannot open shuffle mapping file: %s', mapping_file);
end
C = textscan(train_fid, '%d', numtotal);
fclose(train_fid); % release the file handle (it was previously leaked)
if numel(C{1}) ~= numtotal
  % a short mapping would silently drop rows and break the split below
  error('shuffle mapping has %d entries, expected %d', numel(C{1}), numtotal);
end
data = data(C{1},:);

% first 90% of the shuffled rows are training data
numtrain = numtotal - int32(numtotal * 0.1);
input_data = zeros(rawdim, 1, 1, numtrain, 'single');
input_data(:,1, 1, :) = data(1:numtrain,1:rawdim)';
ytrain = data(1:numtrain,end);
% forward pass to get nn features
train_net.blobs('data').reshape([rawdim 1 1 numtrain]); %reshape nn input blob
train_net.reshape();
train_net.forward({input_data});
% get nn feature from the last layer
Xtrain_cnn = train_net.blobs('ip5').get_data(); %nn features
Xtrain_cnn = double(Xtrain_cnn'); %size = numtrain x D
D = size(Xtrain_cnn, 2); % feature dimension

% load test data (remaining 10% of the shuffled rows)
numtest = numtotal - numtrain;
test_data = zeros(rawdim, 1, 1, numtest, 'single');
test_data(:, 1, 1, :) = data(numtrain+1:end,1:rawdim)';
ytest = data(numtrain+1:end,end);
% forward pass to get nn features
train_net.blobs('data').reshape([rawdim 1 1 numtest]); %reshape nn input blob
train_net.reshape();
train_net.forward({test_data});
% get nn feature from the last layer
Xtest_cnn = train_net.blobs('ip5').get_data(); %nn features
Xtest_cnn = double(Xtest_cnn'); %size = numtest x D

% test cnn performance: 'ip6' holds the network's own output for the
% test batch that was just forwarded
y_reg = train_net.blobs('ip6').get_data();
rmse = sqrt(mean((ytest-y_reg').^2));
fprintf('CNN RMSE %f\n', rmse);

%% initialize GP
% Sets up the options, hyperparameters, mean/likelihood functions and the
% per-dimension grid for the grid-based GP that sits on top of the nn features.

% optimization params (passed through to infGrid)
opt.toep1d = false;    % toggle [], true, false toggle to allow use of toep*
opt.ntoep_min = 1000;
opt.circ_emb = true;   % toggle to use the circular embedding
opt.cg_maxit = 2000; opt.cg_tol = 1e-6;

% kernel hyps: one covSEiso factor per feature dimension, each with the same
% initial log length-scale and log signal std
cov = repmat({@covSEiso}, [1, D]);
sf = 1; ell = 0.5; hyp.cov = repmat(log([ell;sf]), [D, 1]);
% mean = linear + constant; D linear weights followed by one constant, all zero
meanfunc = {@meanSum, {@meanLinear, @meanConst}}; hyp.mean = [zeros(D,1); 0];
sn = 0.3;  hyp.lik = log(sn);
likfunc = @likGauss;
inf_method = @(varargin) infGrid(varargin{:},opt);
% grid: per-dimension range of train and test features, padded by 5 on each
% side. The reduction dimension is given explicitly so a single-row feature
% matrix is still reduced column-wise.
xg_min = min(min(Xtrain_cnn, [], 1), min(Xtest_cnn, [], 1)) - 5;
xg_max = max(max(Xtrain_cnn, [], 1), max(Xtest_cnn, [], 1)) + 5;
xg = cell([1, D]);
for i = 1:D
  xg{i} = linspace(xg_min(i), xg_max(i), nx)'; % nx equispaced points as a column
end
covg = {@covGrid,cov,xg};

%% training
% Alternates between updating the GP hyperparameters and updating the nn
% weights by backpropagating the GP's gradient w.r.t. the 'ip5' features.

% initial fit of the GP hyperparameters on the pre-trained features
% (GPML minimize: a negative length bounds the number of function evaluations)
hyp = minimize(hyp,@gp,-100,inf_method,meanfunc,covg,likfunc,Xtrain_cnn,ytrain);

report_every = 10; % print the test RMSE every this many iterations
for iter=1:num_iters
  % compute gradient w.r.t the base kernel input
  [~,~,~,dx] = infGrid_cnngp(hyp,meanfunc,covg,'likGauss',Xtrain_cnn,ytrain,opt);
  % update base kernel hyps using the GPML routine
  hyp = minimize(hyp,@gp,-2,inf_method,meanfunc,covg,likfunc,Xtrain_cnn,ytrain);

  % update NN params through backpropagation
  % (named feat_grad rather than diff to avoid shadowing the builtin diff)
  feat_grad = {dx' * lr_scale};
  train_net.blobs('data').reshape([rawdim 1 1 numtrain]); %reshape nn input blob
  train_net.reshape();
  train_net.forward({input_data});
  train_net.backward_from('ip5', {'ip5'}, feat_grad);
  solver.apply_update();

  % get the latest nn features (needed by the next iteration)
  train_net.forward({input_data});
  Xtrain_cnn = train_net.blobs('ip5').get_data(); %nn features
  Xtrain_cnn = double(Xtrain_cnn'); %size = numtrain x D

  % The test evaluation does not feed back into training, so it is only run
  % when its result is reported, plus on the last iteration so that the
  % workspace ends with up-to-date Xtest_cnn / postg / ymug / rmse.
  do_report = (mod(iter,report_every)==0);
  if do_report || iter==num_iters
    train_net.blobs('data').reshape([rawdim 1 1 numtest]); %reshape nn input blob
    train_net.reshape();
    train_net.forward({test_data});
    Xtest_cnn = train_net.blobs('ip5').get_data(); %nn features
    Xtest_cnn = double(Xtest_cnn'); %size = numtest x D

    % test
    [postg,~,~] = infGrid(hyp,meanfunc,covg,'likGauss',Xtrain_cnn,ytrain,opt);
    ymug = postg.fmu(Xtest_cnn);                            % quick interpolated prediction
    rmse = sqrt(mean((ytest-ymug).^2));
    if do_report
      fprintf('iter %d, RMSE %f\n', iter, rmse);
    end
  end
end

%% save
%train_net.save(['dkl_rbf_kin40k.caffemodel']);
%save dkl_rbf_kin40k_hyps.mat hyp
