% Start from a clean workspace and seed both generators so that the random
% training data below is reproducible from run to run.
clear all, close all
seed = 4;
randn('seed',seed); rand('seed',seed)

% dev(x,y): largest absolute elementwise difference between x and y divided by
% the largest absolute entry of either argument, but by no less than 1.
dev = @(x,y) max(abs(x(:)-y(:)))/max([max(abs(x(:))),max(abs(y(:))),1]);

addpath('toepgrid')                                           % infGrid, covGrid

% options passed unchanged to the infGrid calls below; keep these fixed
opt.cg_maxit = 500;
opt.cg_tol   = 1e-4;
opt.stat     = true;

% problem sizes: n training points, grid axes of length nu1 and nu2,
% test axes of length ns1 and ns2
% n = 6500; nu1 = 350; nu2 = 250; ns1 = 900; ns2 = 1200; % FULL
n = 500; nu1 = 150; nu2 = 120; ns1 = 40; ns2 = 50; % REGRESSION TEST (res@end)

% training inputs are uniform on [-5,15] x [-3,7]; the order of the rand/randn
% calls fixes the data set for a given seed, so do not reorder them
x1 = 20*rand(n,1) - 5;
x2 = 10*rand(n,1) - 3;
x  = [x1, x2];
% targets: sinusoid damped by a bump centred at (5,3), plus Gaussian noise
y = sin(0.8*x1 + 0.6*x2) .* exp(-((x1-5).^2 + (x2-3).^2)/(2*3.5^2)) ...
    + randn(n,1)/20;

% model specification; hyp fields are created in the order mean, lik, cov
mean = {@meanLinear};
hyp.mean = zeros(2,1);
lik = @likGauss;
sn = 0.3;
hyp.lik = log(sn);
sf = 1.5;
ell = 0.5;
hyp.cov = log([ell; sf]);
cov = {@covSEiso};

% axes of the grid handed to covGrid (column vectors)
xg1 = linspace(-6,16,nu1).';
xg2 = linspace(-4, 8,nu2).';

% test inputs: all (xs1,xs2) combinations with the xs1 index running fastest
xs1 = linspace(-7,17,ns1);
xs2 = linspace(-3, 9,ns2);
[Xs1,Xs2] = ndgrid(xs1,xs2);                      % both of size ns1 x ns2
xs = [Xs1(:), Xs2(:)];

% BTTB: a single grid factor {xg1,xg2} spanning both input dimensions.
% Runs inference and prediction, then prints one line with the problem sizes
% and the elapsed wall-clock time.
t0 = tic;                  % explicit timer handle: unaffected by any tic call
                           % made inside the timed code
covg = {@covGrid,cov,{{xg1,xg2}}};
[post nlZ dnlZ] = infGrid(hyp,mean,covg,lik,x,y,opt);
ymu = post.predict(xs);
t = toc(t0);
s = sprintf('n=%1.1g, nu=%1.1g, ns=%1.1g: t=%1.1fs\n',n,nu1*nu2,ns1*ns2,t);
s = strrep(s,'+0','');     % shorten exponents, e.g. 5e+02 -> 5e2
fprintf('%s',s)            % print verbatim rather than using s as a format

% twice Toeplitz: two one-dimensional grid factors {xg1} and {xg2}, each with
% its own covSEiso. Each factor gets [ell; sqrt(sf)]; the result is compared
% against the BTTB run (post, ymu) below.
t0 = tic;                  % explicit timer handle: unaffected by any tic call
                           % made inside the timed code
hypt = hyp; hypt.cov = log([ell;sqrt(sf); ell;sqrt(sf)]);
covt = {@covGrid,[cov,cov],{{xg1},{xg2}}};
[postt nlZt dnlZt] = infGrid(hypt,mean,covt,lik,x,y,opt);
ymut = postt.predict(xs);
tt = toc(t0);
s = sprintf('n=%1.1g, nu=%1.1g, ns=%1.1g: t=%1.1fs\n',n,nu1*nu2,ns1*ns2,tt);
s = strrep(s,'+0','');     % shorten exponents, e.g. 5e+02 -> 5e2
fprintf('%s',s)            % print verbatim rather than using s as a format
% posterior weights and predictions have to agree with the BTTB run
if dev(postt.alpha,post.alpha)+dev(ymut,ymu)>1e-3, error('Bug.'), end

% pure Kronecker: the grid axes xg1 and xg2 are passed directly (not wrapped
% in an inner cell), each with its own covSEiso and [ell; sqrt(sf)]; the
% result is compared against the BTTB run (post, ymu) below.
t0 = tic;                  % explicit timer handle: unaffected by any tic call
                           % made inside the timed code
hypk = hyp; hypk.cov = log([ell;sqrt(sf); ell;sqrt(sf)]);
covk = {@covGrid,[cov,cov],{xg1,xg2}};
[postk nlZk dnlZk] = infGrid(hypk,mean,covk,lik,x,y,opt);
ymuk = postk.predict(xs);
tk = toc(t0);
s = sprintf('n=%1.1g, nu=%1.1g, ns=%1.1g: t=%1.1fs\n',n,nu1*nu2,ns1*ns2,tk);
s = strrep(s,'+0','');     % shorten exponents, e.g. 5e+02 -> 5e2
fprintf('%s',s)            % print verbatim rather than using s as a format
% posterior weights and predictions have to agree with the BTTB run
if dev(postk.alpha,post.alpha)+dev(ymuk,ymu)>1e-3, error('Bug.'), end

% deliberately unsuppressed: shows the three negative log marginal likelihood
% values side by side (this is the 'ans =' part of the recorded output)
[nlZ,nlZt,nlZk]

if n<=1000 && ns1*ns2<=2000           % only done for small problem sizes
  % non-factorising BTTB: compare the derivatives returned by infGrid with
  % forward finite differences of the returned nlZ value
  covh = {@covGaborard};
  hyph = hyp;
  hyph.cov = log([ell;1.2*ell;0.9;1.3]);
  covg = {@covGrid,covh,{{xg1,xg2}}};
  [posta nlZa dnlZa] = infGrid(hyph,mean,covg,lik,x,y,opt);

  h = 1e-4;                                       % finite difference step
  dnlZn = zeros(size(unwrap(hyph)));              % numerical derivatives
  opt.stat = false;                               % switched off for the loop
  for i=1:numel(dnlZn)
    % add h to the i-th entry of the flattened hyperparameters only
    hyp_h = unwrap(hyph);
    hyp_h(i) = hyp_h(i)+h;
    hyp_h = rewrap(hyph,hyp_h);
    [post_h nlZ_h] = infGrid(hyp_h,mean,covg,lik,x,y,opt);
    dnlZn(i) = (nlZ_h-nlZa)/h;
  end
  fprintf('  derivative err=%1.1e\n',dev(dnlZn,unwrap(dnlZa)))
end

% left panel: training inputs coloured by their target value
subplot(1,2,1)
  scatter(x(:,1),x(:,2),[],y)
  grid on
  xlim([-5,15])
  ylim([-3,7])
  title('training data')
% right panel: ymu from the BTTB run arranged on the ns1 x ns2 test grid;
% transposed so that xs1 runs along the horizontal axis
subplot(1,2,2)
  imagesc(xs1,xs2,reshape(ymu,ns1,ns2).')
  axis xy
  xlim([-5,15])
  ylim([-3,7])
  title('prediction')
rmpath('toepgrid')                    % undo the addpath from the script start

% regression test
% test_infGrid_bttb
% K = bttb(150x120)
% n=5e2, nu=2e4, ns=2e3: t=9.2s
% K = toep(150) x toep(120)
% n=5e2, nu=2e4, ns=2e3: t=5.9s
% K = mat(150) x mat(120)
% n=5e2, nu=2e4, ns=2e3: t=1.0s
% 
% ans =
% 
%   588.7575  588.7575  591.3453
% 
% K = bttb(150x120)
%   derivative err=1.5e-02