function [post nlZ dnlZ] = infGrid(hyp, mean, cov, lik, x, y, opt)

% Inference for a GP with Gaussian likelihood and covGrid covariance.
% The (Kronecker) covariance matrix used is given by:
%   K = kron( kron(...,K{2}), K{1} ) = K_p x .. x K_2 x K_1.
%
% Compute a parametrization of the posterior, the negative log marginal
% likelihood and its derivatives w.r.t. the hyperparameters.
% The result is exact for complete grids, otherwise results are approximate.
% See also "help infMethods".
%
% The function takes a specified covariance function (see covFunctions.m) and
% likelihood function (see likFunctions.m), and is designed to be used with
% gp.m and in conjunction with covGrid* and likGauss.
%
% In case of equispaced data points, we use Toeplitz/BTTB algebra. We use a
% circulant embedding approach to approximate the log determinant of the
% covariance matrix. If any of the factors K_i, i=1..p has Toeplitz or more
% general BTTB structure (which is indicated by K.kron.factor(i).descr being
% equal to 'toep', 'bttb2', 'bttb3', etc.), we automatically use the circulant
% determinant approximation. The grid specification needs to reflect this; see
% "help covGrid" for examples of the doubly nested curly bracket formalism.
%
% The following options are available:
%   The conjugate gradient-based linear system solver has two adjustable
%   parameters, the relative residual threshold for convergence opt.cg_tol and
%   the maximum number of MVMs opt.cg_maxit until the process stops.
% opt.cg_tol,   default is 1e-6      as in Matlab's pcg function
% opt.cg_maxit, default is min(n,20) as in Matlab's pcg function
%   We can tell the inference engine to precompute the explained grid variance
%   so that post.predict returns useful latent and predictive variances fs2
%   and ys2 for unknown test data points. The precomputation uses
%   Perturb-and-MAP sampling; without it, fs2 equals the prior variance.
% opt.pred_var, number of samples; the value used is max(ceil(abs(.)),20), where
%   20 is suggested in the Papandreou paper. A negative value selects
%   elliptical slice sampling instead of Perturb-and-MAP.
%   Instead of the data x, we can tell the engine to use x*hyp.P' to make grid
%   methods available to higher dimensional data. We offer two ways of
%   restricting the projection matrix hyp.P to either orthonormal matrices,
%   where hyp.P*hyp.P'=I or normalised projections diag(hyp.P*hyp.P')=1.
% opt.proj_ortho, enforce orthonormal projections by working with 
%   sqrtm(hyp.P*hyp.P')\hyp.P instead of hyp.P
% opt.proj_norm, enforce normal projections by working with 
%   diag(1./sqrt(diag(hyp.P*hyp.P')))*hyp.P instead of hyp.P
%   The first parameter has priority over the second.
% opt.ndcovs, if present, max(opt.ndcovs,20) samples are used (when derivatives
%   are requested) for a stochastic estimate of the covariance hyperparameter
%   derivatives dnlZ.covs, a Hessian approximation dnlZ.covs2 and, if hyp.P is
%   given, the projection derivatives dnlZ.Ps.
% opt.stat = true returns a little bit of output summarising the exploited
%   structure of the covariance of the grid.
%
% DEPRECATED: Usually, cov = {@covGrid,covg,xg} but we also offer scaled
% additive covariance functions, where cov = {@covGrid,covg,xg,Q,w}. Here, the
% covariance matrix is given by: sum_j=1..Q diag(sj)*Kj*diag(sj), where the
% scalings sj are given by sj=softmax(w1,..,wQ)_j and wj(x) is given by the
% function w, which is a GP mean function. The overall hyperparameters are
%  hyp = [hyp_w1; hyp_cov1; ..; hyp_wQ; hyp_covQ];
%
% Copyright (c) by Hannes Nickisch and Andrew Wilson, 2016-02-03.
%
% See also INFMETHODS.M, COVGRID.M.

if iscell(lik), likstr = lik{1}; else likstr = lik; end
if ~ischar(likstr), likstr = func2str(likstr); end
if ~strcmp(likstr,'likGauss')               % NOTE: no explicit call to likGauss
  error('Inference with infGrid only possible with Gaussian likelihood.');
end
if ~isfield(hyp,'mean'), hyp.mean = []; end  % set an empty mean hyper parameter
if numel(cov)>3 && isnumeric(cov{4}) && numel(cov{4})==1 % DEPRECATED scaled add
  Q = cov{4}; if numel(cov)>4, w=cov{5}; else w={@meanOne}; end, cov = cov(1:3);
  covs = cell(Q,1); ws = cell(Q,1); for j=1:Q, covs{j} = cov{2}; ws{j} = w; end
  cov = {@covGridScaleAdd,covs,cov{3},ws};
end
cov1 = cov{1}; if isa(cov1, 'function_handle'), cov1 = func2str(cov1); end
if ~strncmp(cov1,'covGrid',7); error('Only covGrid* supported.'), end% check cov
scale_add = strcmp(cov1,'covGridScaleAdd');    % additive scaled covariance mode
if scale_add, Q = numel(cov{2}); end           % number of scaled additive terms

xg = cov{3}; p = numel(xg);                 % extract underlying grid parameters
if nargin<=6, opt = []; end                        % make opt variable available
proj_norm  = false; proj_ortho = false;                                   % init
if isfield(opt,'proj_ortho'), proj_ortho = opt.proj_ortho; end  % enforce P*P'=I
if isfield(opt,'proj_norm'),  proj_norm  = opt.proj_norm; end  % or diag(P*P')=1
hP = 1;                                                   % no projection at all
if isfield(hyp,'P')                  % apply transformation matrix P if provided
  hP = hyp.P;
  if proj_ortho                                                       % priority
    hP = sqrtm(hP*hP')\hP;                               % orthonormal projector
  elseif proj_norm
    hP = diag(1./sqrt(diag(hP*hP')))*hP;                      % normal projector
  end
end
if isfield(opt,'stat'), stat = opt.stat; else stat = false; end  % report status

[K,M] = feval(cov{:}, hyp.cov, x*hP');    % evaluate covariance mat constituents
[xx,ng,Dg] = covGrid('idx2dat',xg,x);             % turn into data vector if idx
kronmvm = K.kronmvm;
if stat      % show some information about the nature of the p Kronecker factors
  if scale_add, s = ['K = sum(',num2str(Q),')']; Ks = K.sum.K; Ks = Ks.kron;
  else          s = 'K ='; Ks = K.kron; end; s = [s,' kron[ '];
  for i=1:p
    if isnumeric(Ks(i).factor)
      si = sprintf('mat(%d)',size(Ks(i).factor,1));
    else
      sz = num2str(size(xg{i}{1},1));
      for j=2:numel(xg{i}), sz = [sz,'x',num2str(size(xg{i}{j},1))]; end
      si = sprintf('%s(%s)',Ks(i).factor.descr(1:4),sz);
    end
    if i<p, si = [si,' x ']; end
    s = sprintf('%s%s',s,si);
  end
  fprintf('%s ]\n',s)
end
n = size(x,1); D = sum(Dg); N = prod(ng);                                 % dims
mid = n==N && ~scale_add;                                            % M==eye(N)
if mid, q = randn(N,1); mid = max(abs(q-M*q))<1e-10; end       % random MVM test
m = feval(mean{:}, hyp.mean, xx*hP');                     % evaluate mean vector

if isfield(opt,'cg_tol'), cgtol = opt.cg_tol;         % stop conjugate gradients
else cgtol = 1e-6; end                                          % same as in pcg
if isfield(opt,'cg_maxit'), cgmit = opt.cg_maxit;      % number of cg iterations
else cgmit = min(n,20); end                                     % same as in pcg
% no of samples for perturb-and-MAP, see George Papandreou and Alan L. Yuille:
% "Efficient Variational Inference in Large-Scale Bayesian Compressed Sensing"
ns = 0;                       % do nothing per default, 20 is suggested in paper
if isfield(opt,'pred_var')
  ess = opt.pred_var<0;        % elliptical slice sampling if negative parameter
  ns = max(ceil(abs(opt.pred_var)),20);
end
% no of samples for covariance hyperparameter sampling approximation,
% see Po-Ru Loh et.al.: "Contrasting regional architectures of schizophrenia and
% other complex diseases using fast variance components analysis, biorxiv.org
if isfield(opt,'ndcovs'), ndcovs = max(opt.ndcovs,20);
else ndcovs = 0; end

sn2 = exp(2*hyp.lik);                               % noise variance of likGauss
circ_emb = -5;                                % Whittle embedding overlap factor
if scale_add
  Sq = K.sum.S; Vq=cell(Q,1); Eq=zeros(N,Q); Iq=false(N,Q); dEq=cell(Q,1);
  for j=1:Q
    [Vq{j},Eq(:,j),Iq(:,j),dEq{j}] = covGrid('eig',K.sum(j).K,xg,circ_emb);
  end
  % See also Prob.III.6.14 in Matrix Analysis, Bhatia 1997.
  % lam(A*B) < lam(A)*lam(B), where lam(.) returns sorted eigvals
  [junk,Si] = sort(Sq,1,'descend'); Si = Si + N*ones(N,1)*(0:Q-1);
  [junk,Ei] = sort(Eq,1,'descend'); Ei = Ei + N*ones(N,1)*(0:Q-1);
  SES = Eq(Ei).*Sq(Si).^2;
  [ldu,I] = logdet_weyl(SES); Ei = Ei(I); Si = Si(I);   % propagate Weyl indices
  e = sum(SES(I),2);                              % e = sum(Eq(Ei).*Sq(Si).^2,2)
  Ei = Ei(1:min(n,N),:); Si = Si(1:min(n,N),:);      % restrict to relevant part
  Eqi = Eq(Ei); Sqi = Sq(Si); eqi = sum(Eqi.*Sqi.^2,2);
  assert( norm(eqi-e(1:min(n,N)))<1e-12 )
  dWq = cell(Q,1); for j=1:Q, dWq{j} = K.sum(j).dW; end
  nw = K.nw; nc = K.nc;                   % number of hyp for weights/covariance
  nh = cumsum(nc(:)+nw(:));         % number of hypers after each of the Q terms
  vnum = false;
else
  [V,e,ir,def] = covGrid('eig',K,xg,circ_emb);        % eigenvalue decomposition
  vnum = true;            % flag indicating whether all entries in V are numeric
  for j=1:p, vnum = vnum && isnumeric(V{j}); end
end

if mid
  s = 1./(e+sn2); ord = 1:N;                  % V*diag(s)*V' = inv(K+sn2*eye(N))
else
  [eord,ord] = sort(e,'descend');              % sort long vector of eigenvalues
  if n>N, eord = [eord;zeros(n-N,1)]; ord = [ord;(N+1:n)']; end   % special case
  s = 1./((n/N)*eord(1:n) + sn2);               % approx using top n eigenvalues
end

if mid && vnum                    % decide between Toeplitz or Kronecker algebra
  L=@(k) real(kronmvm(V,repmat(1/sn2-s(ir),1,size(k,2)).*kronmvm(V,k,1)))-k/sn2;
else                                                % go for conjugate gradients
  mvm = @(t) mvmK(t,K,sn2,M);                             % single parameter mvm
  L = @(k) -solveMVM(k,mvm,cgtol,cgmit);
end
alpha = -L(y-m);

post.alpha = alpha;                            % return the posterior parameters
post.sW = ones(n,1)/sqrt(sn2);                         % sqrt of noise precision
post.L = L;                           % function to compute inv(K+sn2*eye(N))*Ks
Mtal = M'*alpha;                              % blow up alpha vector from n to N
if ns>0
  % the grid variance is not available for covGridScaleAdd; fail before the
  % (expensive) sampling rather than after it
  if scale_add, error('Not yet implemented.'), end
  % explained variance on the grid vg=diag(Ku*M'* inv(M*Ku*M'+sn2*eye(n)) *M*Ku)
  % relative accuracy r = std(vg_est)/vg_exact = sqrt(2/ns)
  if ess, snp2 = 0;   nskip = 30; nburn = 25;         % params depending on mode
  else    snp2 = sn2; nskip =  0; nburn =  0;
  end
  z = sample(V,e,ir,M,[],snp2,ns*(1+nskip)+nburn,kronmvm);   % z~N(0,A) with
                                                        % A=M*Ku*M'+snp2*eye(n)
  if ess                        % draw posterior samples using elliptical slices
    f = zeros(n,1); l = sum(likGauss(hyp.lik,y,m));          % initialize values
    T = ns*(1+nskip)+nburn;
    for t=1:T
      r = z(:,t);                                 % z are samples from the prior
      [f,l] = sample_ess_step(f,l,r,m,y,{@likGauss},hyp.lik);
      z(:,t) = f;
    end
    z = z(:,nburn+1+(0:ns-1)*nskip);            % only keep the relevant samples
    xe = covGrid('expand',xg);
    vg = sn2 + feval(cov{:},hyp.cov,xe,'diag') - M'*var(z,[],2);
  else
    z = K.mvm(M'*L(z)); vg = sum(z.*z,2)/ns;          % z~N(0,Ku*M'*inv(A)*M*Ku)
  end
else
  vg = zeros(N,1);                                       % no variance explained
end
post.predict = @(xs) predict(xs,xg,K.mvm(Mtal),vg,hyp,mean,cov,lik); % f|y,mu|s2

if nargout>1                               % do we want the marginal likelihood?
  lda = -sum(log(s));                                       % complexity penalty
  % exact: lda = 2*sum(log(diag(chol(M*kronmvm(K,eye(N))*M'+sn2*eye(n)))));
  nlZ = (y-m)'*alpha/2 + n*log(2*pi)/2 + lda/2;
  if nargout>2                                         % do we want derivatives?
    dnlZ = hyp;     % allocate space for derivatives, define Q=inv(M*K*M'+sn2*I)
    for i = 1:numel(hyp.cov)
      dK = feval(cov{:}, hyp.cov, x*hP', [], i);
      if scale_add                                   % deal with covGridScaleAdd
        ij = find(i<=nh,1);                % index of the additive term, ij=1..Q
        if ij==1, ii = i; else ii = i-nh(ij-1); end % in ij, ii=1..nw(ij)+nc(ij)
        if ii<=nw(ij)          % derivative w.r.t. squashing function parameters
          sj_dWij = Sq(:,ij).*dWq{ij}(ii); dSq = zeros(size(Sq)); dSq(:,ij) = 1;
          dSq = (dSq - Sq).*(sj_dWij*ones(1,Q));
          de = 2*n/N*sum(Eqi.*Sqi.*dSq(Si),2);
        else            % derivatives w.r.t. ordinary covariance hyperparameters
          % d lam(K) = diag(V'*dK*V), for psd matrix K = V*diag(lam)*V'.
          dKj = K.sum(ij).dK; dEqi = dEq{ij}(dKj(ii-nw(ij)));
          dEqi = dEqi(Ei(:,ij)-N*(ij-1));         % restrict to relevant indices
          de = n/N*dEqi.*Sqi(:,ij).^2;         % only one single term ij is left
        end
      else                                                   % deal with covGrid
        de = [def(dK); zeros(max(n-N,0),1)];             % from size N to size n
        de = (n/N)*de(ord(1:n));    % approximate if incomplete grid observation
      end
      % formally dp'*s lacks a (small) term trace(dK)/sn2-sum(dp)/sn2 with the
      % missing orthogonal complement of V to represent tr((K+sn2*eye(n))\dK)
      % but in practice the term is harmful for incomplete grids
      dlda = de'*s;
      dnlZ.cov(i) = (dlda-Mtal'*dK.mvm(Mtal))/2; % complexity penalty + data fit
    end
    dnlZ.lik = sn2*(sum(s) - alpha'*alpha);                  % sum(s) = trace(Q)
    for i = 1:numel(hyp.mean)
      dnlZ.mean(i) = -feval(mean{:}, hyp.mean, xx*hP', i)'*alpha;
    end
    if isfield(hyp,'P')
      dnlZ.P = deriv_P(alpha,hP,K,covGrid('flatten',xg),m,mean,hyp,x);
      if     proj_ortho, dnlZ.P = chain_ortho(hyp.P,dnlZ.P);
      elseif proj_norm,  dnlZ.P = chain_norm( hyp.P,dnlZ.P);
      end
    end
    if ndcovs>0
      if scale_add                                % a~N(0,C), C = M*Ku*M'+sn^2*I
        A = sample(Vq,Eq,Iq,M,Sq,sn2,ndcovs,kronmvm);     % sign does not matter
      else
        A = sample(V, e, ir,M,[],sn2,ndcovs,kronmvm);
      end
%       R = chol(M*K.mvm(eye(N))*M'+sn2*eye(n))'; % transform such that R*R' = C
%       A = R * randn(n,ndcovs);         % equivalent to sampling a~N(0,C) above
%       A = R * sqrt(ndcovs)*eye(n);       % obtain exact value by full sampling
      A = L(A); MtA = M'*A;                                      % a~N(0,inv(C))
      dcovs = zeros(size(hyp.cov));        % covariance function hyperparameters
      for i = 1:numel(hyp.cov)     
        dK = feval(cov{:}, hyp.cov, x*hP', [], i);        % E[a'*dK*a] - a'*dK*a
        dcovs(i) = (sum(sum(MtA.*dK.mvm(MtA)))/ndcovs-Mtal'*dK.mvm(Mtal))/2;
      end
      dnlZ.covs = dcovs;

      W = zeros(n,numel(hyp.cov));                       % Hessian approximation
      for i = 1:numel(hyp.cov)     
        dK = feval(cov{:}, hyp.cov, x*hP', [], i); W(:,i) = M*dK.mvm(Mtal);
      end
      dnlZ.covs2 = -(W'*L(W))/2;

      if isfield(hyp,'P')                               % data projection matrix
        KMtA = K.mvm(MtA); KMtal = K.mvm(Mtal);
        dPs = zeros(size(hyp.P));
        % use the flattened grid (one entry per projected dimension) exactly
        % as deriv_P does; the nested grid xg cannot be indexed per dimension
        xgf = covGrid('flatten',xg);
        xP = x*hP'; [junk,dM] = covGrid('interp',xgf,xP);   % keep M untouched
        for i = 1:size(dPs,1)
          if equi(xgf,i), wi = max(xgf{i})-min(xgf{i}); else wi = 1; end% scaling
          for j = 1:size(dPs,2)
            dMtal = dM{i}'*(x(:,j).*alpha/wi);
            dMtA  = dM{i}'*(repmat(x(:,j),1,size(A,2)).*A/wi);
            dPs(i,j) = sum(sum(dMtA.*KMtA))/ndcovs - dMtal'*KMtal;
          end
        end
        dnlZ.Ps = dPs;
        if     proj_ortho, dnlZ.Ps = chain_ortho(hyp.P,dnlZ.Ps);
        elseif proj_norm,  dnlZ.Ps = chain_norm( hyp.P,dnlZ.Ps);
        end
      end
    end  
  end
end
% global mem, S = whos(); mem=0; for i=1:numel(S), mem=mem+S(i).bytes/1e6; end

% One step of elliptical slice sampling.
%   f    current latent sample (without the mean m), l its log likelihood
%   r    auxiliary prior draw defining the ellipse f*cos(t) + r*sin(t)
%   lik  likelihood in cell form, evaluated as feval(lik{:},hyp,y,f+m)
%   tau  optional tempering factor, clipped to [0,1] (default 1)
% Returns the new sample f and its (untempered) log likelihood l. After 22
% likelihood evaluations the last proposal is returned (emergency break).
function [f,l] = sample_ess_step(f,l,r,m,y,lik,hyp,tau)
  if nargin<8
    tau = 1;
  elseif tau>1
    tau = 1;
  elseif tau<0
    tau = 0;
  end
  thresh = log(rand) + tau*l;                        % slice level to exceed
  theta = rand*2*pi;                                 % initial angle
  lo = theta-2*pi; hi = theta;                       % bracket covers ellipse
  for ntry = 0:21
    fnew = f*cos(theta) + r*sin(theta);             % point on the ellipse
    l = sum(feval(lik{:},hyp,y,fnew+m));
    if tau*l>thresh || ntry==21, break, end      % on slice, or give up
    if theta>0                                    % shrink bracket towards 0
      hi = theta;
    elseif theta<0
      lo = theta;
    end
    theta = rand*(hi-lo) + lo;                       % redraw inside bracket
  end
  f = fnew;

% Draw ns samples a~N(0,C), C = M*Ku*M'+sn2*I, using the eigendecomposition of
% the grid covariance Ku (V: eigenvectors for kronmvm, E: eigenvalues, I: index
% of relevant eigenvalues). A non-empty S selects the scaled additive mode,
% where V, E and I hold one component per column/cell and S(:,j) are scalings.
function A = sample(V,E,I,M,S,sn2,ns,kronmvm)
  [n,N] = size(M);
  if isempty(S)                                       % single Kronecker term
    A = randn(N,ns);                                             % a~N(0,I)
    A = kronmvm(V,repmat(sqrt(E(I)),1,ns).*kronmvm(V,A,1));     % a~N(0,Ku)
  else                                          % sum of scaled components
    A = zeros(N,ns);
    for j = 1:numel(V)
      W = randn(N,ns);
      W = repmat(sqrt(E(I(:,j),j)),1,ns).*kronmvm(V{j},W,1);
      A = A + repmat(S(:,j),1,ns).*kronmvm(V{j},W);         % add S_j*a_j
    end
  end
  A = M*A + sqrt(sn2)*randn(n,ns);                 % interpolate, add noise

% Upper bound the log determinant of a sum of p symmetric positive semi-definite
% matrices of size nxn, each represented by its eigenvalues, using Weyl's
% inequalities:
% Let a(nx1) and b(nx1) be the decreasingly ordered eigenvalues of the Hermitian
% matrices A(nxn) and B(nxn). Then the decreasingly ordered eigenvalues c(nx1)
% of the matrix C = A+B are upper bounded by c(i+j-1) <= a(i)+b(j).
%
% Inputs:
%   E   - n x p matrix; column k contains the eigenvalues of the k-th summand
%   mod - string selecting how the index pairs i,j are chosen (default 'halve'):
%          'halve'  - select eigenvalues of roughly equal indices i,j
%          'greedy' - select the locally smallest bound; the integer s
%                     (default 1) is the number of search steps ahead
% Outputs:
%   ldu - upper bound on the log determinant, sum(log(e)) for the bounds e
%   I   - n x p matrix of linear indices into E. It can be used to reconstruct
%         the contributions of the individual eigenvalues to the overall log
%         determinant upper bound given by:
%           ldu = sum(log( sum(E(I),2) ));
function [ldu,I] = logdet_weyl(E,mod,s)
if nargin<2, mod = 'halve'; end
[n,p] = size(E); I = zeros(n,p);                         % get dims, init result
[e,I(:,1)] = sort(E(:,1),'descend');    % bound so far: sorted 1st summand eigs
i = ceil((1:n)'/2); j = (1:n)'-i+1;        % default: similar index size (halve)
for k=2:p
  [fs,js] = sort(E(:,k),'descend');          % sort eigenvalues of k-th summand
  if strcmp(mod,'greedy')     % set up i and j adaptively using greedy heuristic
    if nargin<3, s = 1; end                                  % set default value
    i(1) = 1; j(1) = 1;                                                   % init
    for m=2:n
      ss = s;                           % reduce steps if it violates boundaries
      if ss>  i(m-1) || ss>  j(m-1), ss = min(  i(m-1),   j(m-1)); end
      if ss>n-i(m-1) || ss>n-j(m-1), ss = min(n-i(m-1), n-j(m-1)); end
      % candidate pairs keep i+j-1 = m as required by the Weyl inequality
      i_ss = i(m-1)+(1-ss:ss); j_ss = m-i_ss+1;                % left/right step
      [junk,idx] = min(e(i_ss) + fs(j_ss));   % keep min of steps left and right
      i(m) = i_ss(idx); j(m) = j_ss(idx);
    end
  end
  I(:,1:k-1) = I(i,1:k-1);               % sort previous summands according to i
  I(:,k) = js(j) + (k-1)*n;       % linear indices into column k of E (summand k)
  e = e(i) + fs(j);                              % compute Weyl eigenvalue bound
end
ldu = sum(log(e));
% sanity checks: the index matrix I must reproduce the bounds e and ldu
assert( norm(e-sum(E(I),2))<1e-10 )
if ~isinf(ldu), assert( abs(ldu-sum(log(sum(E(I),2))))<1e-10 ), end

% Derivative of the negative log marginal likelihood w.r.t. the projection
% matrix P (inputs enter as x*P'). Uses the grid interpolation derivatives dM
% from covGrid and a forward finite difference for the mean function.
function dP = deriv_P(alpha,P,K,xg,m,mean,hyp,x)
  h = 1e-4;                                       % finite difference step size
  xP = x*P';                                                % projected inputs
  [M,dM] = covGrid('interp',xg,xP);      % interp matrix and its derivatives
  beta = K.mvm(M'*alpha);                   % dP(i,j) = -alpha'*dMij*beta
  [nrow,ncol] = size(P);
  dP = zeros(size(P));
  for i = 1:nrow
    wi = 1;                                 % scaling factor of grid dimension i
    if equi(xg,i), wi = max(xg{i})-min(xg{i}); end
    xP(:,i) = xP(:,i)+h;                                  % shift dimension i
    dmi = (feval(mean{:},hyp.mean,xP)-m)/h;            % estimate dm/d(xP_i)
    xP(:,i) = xP(:,i)-h;                                   % undo the shift
    betai = dmi + dM{i}*beta/wi;
    for j = 1:ncol
      dP(i,j) = -alpha'*(x(:,j).*betai);
    end
  end

function eq = equi(xg,i)        % true if the grid along dim i is equispaced
  g = xg{i};
  eq = true;                              % a single grid point is equispaced
  if size(g,1)>1                  % all row increments must equal the first one
    steps = diff(g,1,1);
    dev = abs(bsxfun(@minus,steps,g(2,:)-g(1,:)));
    eq = max(dev(:))<1e-9;
  end

% Chain rule for Q = sqrtm(P*P')\P: maps the derivative dP w.r.t. Q to the
% derivative w.r.t. P. For the derivative of sqrtm(X) see
% http://math.stackexchange.com/questions/540361
function dQ = chain_ortho(P,dP)
  [U,Lam] = eig(P*P');
  sq = sqrt(diag(Lam));
  S = U*diag(sq)*U';                  % S = sqrtm(P*P') from the eig-decomposition
  H = dP'/S;
  G = H'*(P'/S);
  e = ones(size(dP,1),1);
  denom = sq*e' + e*sq';                      % pairwise sums sq(a)+sq(b)
  inner = (U'*(G+G')*U)./denom;
  dQ = (H - P'*U*inner*U')';

% Chain rule for the row normalisation Q = diag(1./sqrt(diag(P*P')))*P: maps
% the derivative dP w.r.t. Q to the derivative w.r.t. P.
function dQ = chain_norm(P,dP)
  ir = 1./sqrt(diag(P*P'));                         % inverse row norms of P
  radial = diag(dP*P').*ir.^3;                 % component along each row of P
  dQ = diag(ir)*dP - diag(radial)*P;

% Noisy covariance MVM via the grid representation: q = M*K*M'*p + sn2*p,
% where K.mvm multiplies by the grid covariance.
function q = mvmK(p,K,sn2,M)
  u = M'*p;                                            % lift p onto the grid
  q = M*K.mvm(u) + sn2*p;                       % interpolate back, add noise

% Solve mvm(q) = p for q with conjugate gradients; extra arguments (tolerance,
% maximum number of iterations) are forwarded to conjgrad. Errors if the
% solver reports non-convergence.
function q = solveMVM(p,mvm,varargin)
  [q,flag,relres,iter] = conjgrad(mvm,p,varargin{:});                 % like pcg
  if flag, return, end                                           % converged
  error('Not converged after %d iterations, r=%1.2e\n',iter,relres)

% Latent (fmu,fs2) and predictive (ymu,ys2) means and variances at test inputs
% xs, obtained by interpolating grid quantities (Kalpha: grid posterior mean
% term, vg: explained grid variance) with the covGrid interpolation matrix.
function [fmu,fs2,ymu,ys2] = predict(xs,xg,Kalpha,vg,hyp,mean,cov,lik)
  Ms = covGrid('interp',xg,xs);                      % interpolation matrix
  xs = covGrid('idx2dat',xg,xs);                    % index vector -> data
  fmu = feval(mean{:},hyp.mean,xs) + Ms*Kalpha;    % prior mean + interp term
  if nargout<2, return, end
  ve = 0;                                       % no explained variance
  if norm(vg,1)>1e-10, ve = Ms*vg; end            % interpolate grid variance
  fs2 = max(feval(cov{:},hyp.cov,xs,'diag')-ve,0);    % prior var minus expl.
  % if nargout>2, [lp, ymu, ys2] = feval(lik{:},hyp.lik,[],fmu,fs2); end
  ymu = fmu;                                      % explicit for likGauss
  ys2 = fs2 + exp(2*hyp.lik);

% Solve A*x=b with symmetric positive definite A(n,n), b(n,m), x(n,m) using
% conjugate gradients, starting from x0 = 0. The method is along the lines of
% PCG but suited for matrix inputs b: columns are dropped once converged.
%   A     - numeric matrix or function handle computing A*d for a matrix d
%   tol   - relative residual threshold per column (default 1e-10)
%   maxit - maximum number of iterations, i.e. MVMs with A (default
%           min(n,20)); columns of b equal to zero are solved by x = 0
% Outputs: flag is 1 if all columns converged and 0 otherwise, relres holds
% the relative residual of each column, iter the number of iterations done
% and r the residuals of the still unconverged columns.
function [x,flag,relres,iter,r] = conjgrad(A,b,tol,maxit)
if nargin<3, tol = 1e-10; end
if nargin<4, maxit = min(size(b,1),20); end
x = zeros(size(b));
r = b;                   % r = b-A*x0 = b since x0 = 0; saves one MVM with A
r2 = sum(r.*r,1); r2new = r2;                   % squared residual norms
nb = sqrt(sum(b.*b,1)); flag = 0; iter = 0;
relres = sqrt(r2)./nb; todo = relres>=tol; if ~any(todo), flag = 1; return, end
on = ones(size(b,1),1); r = r(:,todo); d = r;
for k = 1:maxit              % every iteration costs exactly one MVM with A
  iter = k;
  if isnumeric(A), z = A*d; else z = A(d); end
  a = r2(todo)./sum(d.*z,1);                                    % step sizes
  a = on*a;
  x(:,todo) = x(:,todo) + a.*d;
  r = r - a.*z;
  r2new(todo) = sum(r.*r,1);
  relres = sqrt(r2new)./nb; cnv = relres(todo)<tol; todo = relres>=tol;
  d = d(:,~cnv); r = r(:,~cnv);                           % get rid of converged
  if ~any(todo), flag = 1; return, end
  fr = r2new./r2;                                              % Fletcher-Reeves
  d = r + (on*fr(todo)).*d;
  r2 = r2new;
end