clear('all'); close('all');
rng(2016); % fixed seed so every run draws the same random numbers

%% param config

% --- hardware / model ---
use_gpu = 1
root_path = [pwd '/../../'];
% pre-trained network weights, copied into the training net for initialization
cnn_model_name = [root_path 'caffe/examples/airline/' ...
                  'airline_iter_100000.caffemodel.example'];

% --- optimization schedule ---
num_iters = 1000;   % number of outer training iterations
lr_scale  = 2e-5;   % multiplies the GP feature gradient before it is back-propagated into the net
bsize     = 50000;  % minibatch size drawn each iteration

% --- warm-up: stages are enabled once iter reaches these values ---
vi_skip_niters = 0;  % (not referenced later in this script)
ft_skip_niters = 20; % net fine-tuning
hl_skip_niters = 10; % kernel hyperparameter learning

% --- GP / variational settings ---
opt.vi.mccnt = 1; % # of MC samples
nx = 70;          % grid size passed to covGrid('create')
ell = 5;          % initial SE length-scale (stored as log in hyp.cov)
sf = 1;           % initial SE signal std (stored as log in hyp.cov)

test_freq = 10;   % evaluate on the test set every test_freq iterations

% ---------------------

%% initialize env

% Put the caffe MATLAB interface and the KISS-GP code on the search path.
% Folders are added one at a time in this order (addpath prepends, so the
% later folder takes precedence), exactly as before.
for sub_dir = {'caffe/matlab', 'gpml/kissgp'}
  addpath([root_path sub_dir{1}]);
end

% Select the caffe compute backend.
if ~use_gpu
  caffe.set_mode_cpu();
else
  caffe.set_mode_gpu();
  caffe.set_device(1); % hard-coded GPU id
end

%% initialize CNN

% Build the solver from its prototxt, take its training net, and load the
% pre-trained weights named in cnn_model_name into it.
example_dir = [root_path 'examples/airline/'];
solver_def = [example_dir 'airline_solver.prototxt_matlab'];
solver = caffe.Solver(solver_def);
train_net = solver.net;
test_net = solver.test_nets(1); % (not referenced later in this script)
train_net.copy_from(cnn_model_name);

%% load data

% Load `data` (raw inputs in columns 1..end-1, 0-based class label in the last
% column) and the train/test partition `cvo` (used via cvo.training(1) /
% cvo.test(1); presumably a cvpartition). Then forward the train and test
% inputs through train_net to obtain the per-class features Xtrain_cnn{c} and
% Xtest_cnn{c} (rows = samples) that the GP consumes.
fprintf('loading data ...\n');
load([root_path 'caffe/examples/airline/data/airline.mat']);
train_idx = cvo.training(1);
test_idx = cvo.test(1);
rawdim = size(data, 2) - 1
C = 2
blob_shape = [rawdim 1 1];

% load train data
numtrain = sum(train_idx)
ytrain = data(train_idx,end);
ytrain = ytrain + 1; % matlab index starts with 1
% full training input, laid out rawdim x 1 x 1 x N for the 'data' blob
input_data = zeros(rawdim, 1, 1, numtrain, 'single');
input_data(:, 1, 1, :) = data(train_idx,1:rawdim)';

% Forward the training set in npar chunks to avoid running out of memory.
% psz must be an integer: numtrain is generally not a multiple of npar, and a
% fractional chunk size produces non-integer indices. The last chunk absorbs
% the remainder. npar is clamped so that no chunk is empty.
npar = max(1, min(65, numtrain));
psz = floor(numtrain / npar);
Xtrain_cnn = cell(C,1);
for p = 1:npar
  idxst = (p-1)*psz + 1;
  if p < npar, idxed = p*psz; else, idxed = numtrain; end
  nchunk = idxed - idxst + 1;
  % forward to get cnn features
  train_net.blobs('data').reshape([blob_shape nchunk]); %reshape cnn input blob
  train_net.reshape();
  train_net.forward({input_data(:,:,:,idxst:idxed)});
  % get nn feature from the last layer
  for c = 1:C
    fea = double(train_net.blobs(['ip1.6-' num2str(c)]).get_data())'; % nchunk x D
    if p == 1
      % allocate once the feature width is known, so D > 1 also works
      Xtrain_cnn{c} = zeros(numtrain, size(fea, 2));
    end
    Xtrain_cnn{c}(idxst:idxed,:) = fea;
  end
end
D = size(Xtrain_cnn{1}, 2);
% one-hot labels: ytrain_ext(n, c) = 1 iff ytrain(n) == c
ytrain_ext = zeros(numtrain,C);
for c=1:C
  ytrain_ext(ytrain==c,c)=1;
end

% load test data
numtest = sum(test_idx);
test_data = zeros(rawdim, 1, 1, numtest, 'single');
test_data(:,1, 1, :) = data(test_idx,1:rawdim)';
ytest = data(test_idx,end);
ytest = ytest + 1; % matlab index starts from 1 !!
% forward pass to get cnn features (whole test set in one batch)
train_net.blobs('data').reshape([blob_shape numtest]); %reshape nn input blob
train_net.reshape();
train_net.forward({test_data});
% get nn feature from the last layer
Xtest_cnn=cell(C,1);
for c=1:C
  fea = train_net.blobs(['ip1.6-' num2str(c)]).get_data(); %nn features
  Xtest_cnn{c} = double(fea'); %size = numtest x D
end

% test nn performance: predict the argmax over the 'concat1' outputs.
% NOTE(review): the argmax index is a class index only when every ip1.6-c
% layer has a single output (D == 1); confirm for other architectures.
y_concat = train_net.blobs('concat1').get_data()';
[~,y_pred] = max(y_concat,[],2);
accu_cnn_test = sum(y_pred==ytest) / numtest;
fprintf('NN accu %f\n', accu_cnn_test);

%% initialize GP

% optimization params
opt.toep1d = false;    % allowed values: [], true, false -- toggles use of toep*
opt.ntoep_min = 1000;
opt.circ_emb = true;   % toggle to use the circular embedding
opt.cg_maxit = 20000; opt.cg_tol = 1e-3;
opt.eps = 1e-1; opt.eps2= 1e-4;
opt.C = C;
opt.n = numtrain;
opt.pi_decay = 0;
opt.mean_decay = 0;

% kernel hyps: one covSEiso per feature dimension (combined through covGrid
% below), initialized to log([ell; sf]); every class starts from the same hyp
cov = repmat({@covSEiso}, [1, D]);
hyp.cov = repmat(log([ell;sf]), [D, 1]);
meanfunc = {@meanIdentity}; hyp.mean = [];
hyp = repmat({hyp}, [1, C]);
likfunc = [];
% grid: for each class, a covGrid grid (size nx) covering the train+test
% features, padded beyond the min/max by the distance from the extreme to the
% 10%/90% quantile
xg=cell(C,1);
xg_min=cell(C,1); xg_max=cell(C,1);
covg=cell(C,1);
for c=1:C
  X_cnn = [Xtrain_cnn{c}; Xtest_cnn{c}];
  X_cnn_lb = 2*min(X_cnn,[],1)-quantile(X_cnn,0.1,1);
  X_cnn_ub = 2*max(X_cnn,[],1)-quantile(X_cnn,0.9,1);
  xg{c} = covGrid('create', [X_cnn; X_cnn_lb; X_cnn_ub], 1, nx);
  xg_min{c} = zeros(1, D);
  xg_max{c} = zeros(1, D);
  for i=1:D
    xg_min{c}(i) = min(xg{c}{i});
    xg_max{c}(i) = max(xg{c}{i});
  end
  covg{c} = {@covGrid,cov,xg{c}};
end
% assumes every class's grid has the same ng/Dg as class 1's (same D and nx)
[~,ng,Dg] = covGrid('expand', xg{1});
Ng=prod(ng); Dg=sum(Dg);

% variational params
% Posterior mean: Ng x C, where Ng is the total number of grid points. This is
% the same leading size as opt.vi.eps below. randn(ng, C) only worked for
% D == 1, where ng is a scalar; in that case both forms draw the same numbers.
vi.m = 0.5*randn(Ng, C); % posterior mean
vi.L = cell(1,C); % posterior covariance: per class, D identity factors of size ng(d)
%vi.pi = (eye(C)+1e-1).*(diag(randn(1,C)*0.1+1)+randn(C)*0.1);, TODO
vi.pi = eye(C);
for c=1:C
  L = cell(D,1);
  for d=1:D
    L{d} = eye(ng(d));
  end
  vi.L{c} = L;
end
opt.vi.eps = zeros(Ng, opt.vi.mccnt, C); %mc samples
opt.vi.u = zeros(size(opt.vi.eps));
opt.vi.mCProbc = zeros(bsize,opt.vi.mccnt,C);
opt.vi.yCProbc = zeros(size(opt.vi.mCProbc));
opt.vi.myCProbc = zeros(size(opt.vi.mCProbc));
opt.update_vi = 0;
opt.vi.sample = 0;
opt.vi.cnt = 0;
opt.vi.iter = 0;

% Evaluate the initial posterior on the test features (best effort: a failure
% only produces a warning).
disp 'init test';
opt.vi.update = 0;
try
  postg = infGrid_svi_dkl(hyp,meanfunc,covg,'likGauss',[],[],vi,opt);
  ymug = postg.ymu(Xtest_cnn);                            % quick interpolated prediction
  accu = sum(ymug==ytest) / numtest
  % row-normalized exp(latent mean) (computed but not reported)
  fmugn = exp(postg.fmu(Xtest_cnn));
  fmugn = fmugn ./ repmat(sum(fmugn,2), [1,size(fmugn,2)]);
catch err
  warning('Inference method failed [%s]', err.message);
end
disp 'init test done';

% Main loop. Every iteration draws a minibatch, computes its net features and
% updates the variational parameters vi. From iteration hl_skip_niters on it
% also updates the kernel hyperparameters hyp. From ft_skip_niters on it also
% takes 3 fine-tuning steps of the net, back-propagating the GP feature
% gradient. Every test_freq iterations the model is evaluated on the test set.
best_accu = 0;
best_iter = 0;
if ~exist('accu', 'var')
  accu = NaN; % the initial test may have failed; keep the final report well-defined
end
for iter = 1:num_iters
  fprintf('iter: %d\n', iter);
  do_test = mod(iter,test_freq)==0;

  if do_test
    % The net may have been fine-tuned since Xtest_cnn was last computed.
    % Recompute the test features from the current net so the GP is scored on
    % the same feature map it is trained on.
    train_net.blobs('data').reshape([blob_shape numtest]); %reshape cnn input blob
    train_net.reshape();
    train_net.forward({test_data});
    for c=1:C
      fea = train_net.blobs(['ip1.6-' num2str(c)]).get_data(); %cnn features
      Xtest_cnn{c} = double(fea'); %size = numtest x D
    end
  end

  % subsample a minibatch
  randidx = randsample(1:numtrain,bsize);
  curbsize = bsize;
  yb = ytrain_ext(randidx,:);
  xb = cell(size(Xtrain_cnn));
  % forward pass to get cnn features
  train_net.blobs('data').reshape([blob_shape curbsize]); %reshape cnn input blob
  train_net.reshape();
  train_net.forward({input_data(:,:,:,randidx)});
  for c=1:C
    fea = train_net.blobs(['ip1.6-' num2str(c)]).get_data(); %cnn features
    xb{c} = double(fea'); %size = curbsize x D
  end

  % draw MC samples
  opt.vi.sample = 1;
  opt = sample(opt, vi, C);
  opt.vi.sample = 0;
  % update variational params
  fprintf('update vi.\t mccnt=%d\t nx=%d\n', opt.vi.mccnt, nx);
  opt.update_vi = 1;
  % the handle captures opt as it is now (update_vi = 1)
  inf_method = @(varargin) infGrid_svi_dkl(varargin{:},opt);
  vi = minimize_svi_dkl_vi(vi,@gp_svi_dkl_vi,-100,hyp,inf_method,meanfunc,covg,likfunc,xb,yb);
  opt.update_vi = 0;
  if do_test
    % Posterior and scoring share one try, so a failed inference is never
    % scored with a stale (or undefined) postg.
    try
      postg = infGrid_svi_dkl(hyp,meanfunc,covg,'likGauss',[],[],vi,opt);
      ymug = postg.ymu(Xtest_cnn);                            % quick interpolated prediction
      accu = sum(ymug==ytest) / numtest
      fmugn = exp(postg.fmu(Xtest_cnn));
      fmugn = fmugn ./ repmat(sum(fmugn,2), [1,size(fmugn,2)]);
      if accu > best_accu
        best_accu = accu;
        best_iter = iter;
      end
      [iter accu best_accu best_iter]
    catch err
      warning('Inference method failed [%s]', err.message);
    end
  end

  % update kernel params
  if iter>=hl_skip_niters
    % re-create the handle so it captures opt with update_vi = 0
    inf_method = @(varargin) infGrid_svi_dkl(varargin{:},opt);
    hyp = minimize_svi_dkl_hyp(hyp,@gp_svi_dkl_hyp,-100,vi,inf_method,meanfunc,covg,likfunc,xb,yb);
    for c=1:C
      fprintf(['hyp{%d}.cov ' repmat('\t%f\t', [size(hyp{c}.cov,2), size(hyp{c}.cov,1)]) '\n'], c, hyp{c}.cov);
    end
    for c=1:C
      fprintf(['hyp{%d}.mean ' repmat('\t%f\t', [size(hyp{c}.mean,2), size(hyp{c}.mean,1)]) '\n'], c, hyp{c}.mean);
    end
    if do_test
      try % test error
        postg = infGrid_svi_dkl(hyp,meanfunc,covg,'likGauss',[],[],vi,opt);
        ymug = postg.ymu(Xtest_cnn);                            % quick interpolated prediction
        accu = sum(ymug==ytest) / numtest;
        fmugn = exp(postg.fmu(Xtest_cnn));
        fmugn = fmugn ./ repmat(sum(fmugn,2), [1,size(fmugn,2)]);
        if accu > best_accu
          best_accu = accu;
          best_iter = iter;
        end
        [iter accu best_accu best_iter]
      catch err
        warning('Inference method failed [%s]', err.message);
      end
    end
  end

  % update nn params
  if iter>=ft_skip_niters
    fprintf('update-nn\t lr_scale=%0.9f\n', lr_scale);
    for ft_iters=1:3
      try
        % forward to get cnn features
        train_net.blobs('data').reshape([blob_shape curbsize]); %reshape cnn input blob
        train_net.reshape();
        train_net.forward({input_data(:,:,:,randidx)});
        for c=1:C
          fea = train_net.blobs(['ip1.6-' num2str(c)]).get_data(); %cnn features
          xb{c} = double(fea'); %size = curbsize x D
        end
        % compute gradient w.r.t. the features (dxc{c} is curbsize x dpc)
        [~,~,~,dxc] = infGrid_svi_dkl(hyp,meanfunc,covg,'likGauss',xb,yb,vi,opt);
        dpc = size(dxc{1},2); % dim per c
        % stack the scaled per-class gradients into the 'concat1' top diff
        top_diff = zeros(dpc*C, curbsize);
        for c=1:C
          top_diff(dpc*(c-1)+1:dpc*c,:) = lr_scale*dxc{c}';
        end
        train_net.backward_from('concat1', {'concat1'}, {top_diff});
        solver.apply_update();
      catch err
        warning('Inference method failed [%s]', err.message);
      end
    end % end of cnn fine-tuning

    if do_test
      % test cnn accuracy once, after all fine-tuning steps, and refresh the
      % test features from the fine-tuned net
      train_net.blobs('data').reshape([blob_shape numtest]); %reshape cnn input blob
      train_net.reshape();
      train_net.forward({test_data});
      y_concat = train_net.blobs('concat1').get_data()'; %test
      [~,y_pred] = max(y_concat,[],2);
      accu_cnn = sum(y_pred==ytest) / numtest
      for c=1:C
        fea = train_net.blobs(['ip1.6-' num2str(c)]).get_data(); %cnn features
        Xtest_cnn{c} = double(fea'); %size = numtest x D
      end
      % test dkl accuracy
      try
        postg = infGrid_svi_dkl(hyp,meanfunc,covg,'likGauss',[],[],vi,opt);
        ymug = postg.ymu(Xtest_cnn);                            % quick interpolated prediction
        accu = sum(ymug==ytest) / numtest
        fmugn = exp(postg.fmu(Xtest_cnn));
        fmugn = fmugn ./ repmat(sum(fmugn,2), [1,size(fmugn,2)]);
        if accu > best_accu
          best_accu = accu;
          best_iter = iter;
        end
      catch err
        warning('Inference method failed [%s]', err.message);
      end
    end % end of test
  end

  if do_test
    fprintf('iter %d, accu %f\n', iter, accu);
  end
end

exit
