function [post nlZ dnlZc dxc] = infGrid_svi_dkl(hypc, mean, covc, lik, xxc, yc, vi, opt)

% Stochastic variational inference for multi-class classification with one
% covGrid (Kronecker) GP per class. The latent functions of all classes are
% combined by the mixing matrix vi.pi and squashed through a softmax.
%
% For every class the (Kronecker) covariance matrix used is given by:
%   K = kron( kron(...,K{2}), K{1} ) = K_p x .. x K_2 x K_1.
% The variational distribution over the grid values of class c is
% N(vi.m(:,c), S) with S = kron of the factors vi.L{c}{d}*vi.L{c}{d}'.
%
% Inputs:
%   hypc - cell array of per-class hyperparameter structs (.cov, .mean and
%          optionally .P, a linear projection applied to the inputs)
%   mean - mean function (GPML cell syntax), shared by all classes
%   covc - cell array of per-class covGrid specifications
%   lik  - likelihood; its name must be 'likGauss' (it is never called)
%   xxc  - cell array of per-class input matrices
%   yc   - n x C label matrix (presumably one-hot -- TODO confirm)
%   vi   - variational parameters: vi.m, vi.L, vi.pi
%   opt  - option struct (required); fields read here: C, n, update_vi, eps,
%          vi.mccnt, vi.eps, pi_decay, mean_decay and optionally
%          proj_ortho, stat
%
% Outputs:
%   post  - prediction handles post.fmu, post.ymu, post.liks
%   nlZ   - Monte Carlo estimate of the negative lower bound: the data-fit
%           term rescaled by opt.n/n plus the KL divergence of every class
%   dnlZc - derivatives rescaled to unit Euclidean norm; w.r.t. vi when
%           opt.update_vi==1, otherwise w.r.t. hypc
%   dxc   - per-class derivatives w.r.t. the inputs (hyperparameter mode)
%
% Only dense Kronecker factors are supported; the additive covGrid mode is
% rejected with an error.
%
% Copyright (c) by Hannes Nickisch and Andrew Wilson, 2015-09-26.
%
% See also INFMETHODS.M, COVGRID.M.

if iscell(lik), likstr = lik{1}; else likstr = lik; end
if ~ischar(likstr), likstr = func2str(likstr); end
if ~strcmp(likstr,'likGauss')               % NOTE: no explicit call to likGauss
  error('Inference with infGrid only possible with Gaussian likelihood.');
end
for c=1:numel(hypc)            % set an empty mean hyperparameter for every class
  if ~isfield(hypc{c},'mean'), hypc{c}.mean = []; end
end
cov = covc{1};
cov1 = cov{1}; if isa(cov1, 'function_handle'), cov1 = func2str(cov1); end
if ~strcmp(cov1,'covGrid'); error('Only covGrid supported.'), end    % check cov

xg = cov{3}; p = numel(xg);                 % extract underlying grid parameters
if nargin<8, opt = []; end         % opt is the 8th argument; make it available
if isfield(opt,'proj_ortho'), proj_ortho = opt.proj_ortho;    % enforce P*P' = I
else proj_ortho = false; end

if numel(cov)>3                       % the additive covGrid mode is not implemented
  error('(additive mode) Not supported');
end

pi = vi.pi;                     % mixing weights (NOTE: shadows the builtin pi)
post.fmu  = @(xsc) pred_muc(xsc,vi.m,covc,pi,hypc,mean,opt.C);  % interpolate latent mean
post.ymu  = @(xsc) pred_y(post.fmu(xsc));
post.liks = @(xsc,ys) pred_lik(post.fmu(xsc), ys);
if nargout==1, return; end

if isfield(opt,'stat'), stat = opt.stat; else stat = false; end  % report status
n = size(xxc{1},1);
C = opt.C;
scale = opt.n / n;        % rescale the data-fit term of n points to opt.n points

nlZ = 0;
if opt.update_vi==1  % allocate space for derivatives
  dnlZc = vi;
else
  dnlZc = hypc;
end
dxc = cell(size(xxc));

Kc = cell(C,1); Mc = cell(C,1); mc = cell(C,1);
for c=1:C        % covariance, interpolation matrix and prior mean of each class
  cov = covc{c}; xg = cov{3};
  x = xxc{c};
  xx = covGrid('idx2dat',xg,x);                     % turn into data vector if idx
  hyp = hypc{c};
  hP = 1;                              % no projection at all
  if isfield(hyp,'P')                  % apply transformation matrix P if provided
    hP = hyp.P; if proj_ortho, hP = sqrtm(hP*hP')\hP; end  % orthonormal projector
  end
  [K,M] = feval(cov{:}, hyp.cov, x*hP');    % evaluate covariance mat constituents
  for j=1:p                  % check every numeric Kronecker factor for inf/nan
    Kj = K.kron(j).factor;
    if isnumeric(Kj) && any(isinf(Kj(:))), warning('K inf'); end
    if isnumeric(Kj) && any(isnan(Kj(:))), warning('K nan'); end
  end
  Kc{c} = K; Mc{c} = M;
  mc{c} = feval(mean{:}, hyp.mean, xx*hP');  % evaluate mean vector
end
kronmvm = Kc{1}.kronmvm;

if nargout>1
  mccnt = opt.vi.mccnt;                             % number of Monte Carlo samples
  for c=1:C             % reparameterised samples u = m + kron(L)*eps per class
    for si=1:mccnt
      opt.vi.u(:,si,c) = vi.m(:,c) + kronmvm(vi.L{c}, opt.vi.eps(:,si,c));
    end
  end

  CProbc = zeros(n,C,mccnt);
  mCProbc = zeros(n,C,mccnt);
  for c=1:C                        % latent values of every class at the n inputs
    CProbc(:,c,:) = Mc{c}*opt.vi.u(:,:,c)+repmat(mc{c},[1,mccnt]);
  end
  for si=1:mccnt                                                  % mixing layer
    mCProbc(:,:,si) = CProbc(:,:,si)*pi';
  end
  mCProbc = mCProbc-repmat(max(mCProbc,[],2),[1,C,1]); % to avoid overflow
  mCProbc = exp(mCProbc);
  mCProbcsum = sum(mCProbc, 2);
  if any(isinf(mCProbc(:))), warning('mCProbc inf !'); end      % check infinity
  if any(isinf(mCProbcsum(:))), warning('mCProbcsum inf !'); end
  mCProbc = mCProbc./repmat(mCProbcsum,[1,C,1]);          % softmax probabilities
  opt.vi.mCProbc = permute(mCProbc, [1 3 2]);                   % n x mccnt x C

  ycext = repmat(double(yc), [1,1,mccnt]);                      % n x C x mccnt
  yCProbc = ycext - mCProbc;                             % labels minus softmax
  opt.vi.yCProbc(1:n,:,:) = permute(yCProbc, [1 3 2]); % n x mccnt x C
  for si=1:mccnt              % residuals mapped back through the mixing layer
    opt.vi.myCProbc(1:n,si,:) = yCProbc(:,:,si)*pi;
  end

  % data-fit term: Monte Carlo average of -log p(y|f), rescaled to opt.n points
  ycmc = permute(ycext, [1 3 2]);                               % n x mccnt x C
  nlZext = opt.vi.mCProbc.*ycmc;
  if any(isinf(nlZext(:))), error('inf: opt.vi.mCProbc.*ycext'); end
  nlZext = -log(nlZext(nlZext~=0));      % drop zero entries (unobserved labels)
  if any(isinf(nlZext)), error('inf: log(opt.vi.mCProbc.*ycext)'); end
  nlZ = sum(nlZext(:))/mccnt;
  nlZ = nlZ*scale;
  if isinf(nlZ), error('inf: nlZ data-fit'); end
  if isnan(nlZ), error('nan: nlZ data-fit'); end

  for c=1:C
    cov = covc{c}; xg = cov{3};
    x = xxc{c};
    [xx,ng] = covGrid('idx2dat',xg,x);             % turn into data vector if idx
    N = prod(ng);                                          % number of grid points
    hyp = hypc{c};
    hP = 1;                              % no projection at all
    if isfield(hyp,'P')                  % apply transformation matrix P if provided
      hP = hyp.P; if proj_ortho, hP = sqrtm(hP*hP')\hP; end  % orthonormal projector
    end
    K = Kc{c}; M = Mc{c}; m = mc{c};
    kronmvm = K.kronmvm;

    if stat      % show some information about the nature of the p Kronecker factors
      s = 'K = ';
      for i=1:p
        if isnumeric(K.kron(i).factor)
          si = sprintf('mat(%d)',size(K.kron(i).factor,1));
        else
          sz = num2str(size(xg{i}{1},1));
          for j=2:numel(xg{i}), sz = [sz,'x',num2str(size(xg{i}{j},1))]; end
          si = sprintf('%s(%s)',K.kron(i).factor.descr(1:4),sz);
        end
        if i<p, si = [si,' x ']; end
        s = sprintf('%s%s',s,si);
      end
      fprintf('%s\n',s)
    end

    circ_emb = -5;                                % Whittle embedding overlap factor
    [~,Vfull,E,e] = eigkron(K,xg,ng,circ_emb);   % eigendecomposition of the factors

    vim = vi.m(:,c);
    L = vi.L{c};

    %% KL( N(vim,S) || N(0,K) ) with S = kron(L{d}*L{d}') and K = kron(K_d)
    logdetS = 0;   % sum_d (N/ng(d))*log|S_d|, from diag(L{d}) (triangular factors)
    trprod = 1;    % tr(inv(K)*S) = prod_d tr(inv(K_d)*S_d)
    for di=1:p
      VSVdiag = sum((Vfull{di}'*L{di}*L{di}').*Vfull{di}.',2); % VSVdiag = diag(V'*S*V)
      if any(isnan(VSVdiag))
        if any(isnan(L{di}(:))), fprintf('nan L{%d}\n', di); end
        if any(isnan(Vfull{di}(:))), fprintf('nan Vfull{%d}\n', di); end
        fprintf('nan VSVdiag d=%d\n', di);
      end
      trprod = trprod * (1./(E{di}+opt.eps)'*VSVdiag);
      if isnan(trprod), fprintf('nan trprod d=%d\n', di); end
      logdetS = logdetS + (N/ng(di))*sum(log(diag(L{di})));
      if any(isinf(log(diag(L{di})))), warning('logdetS inf!'); end
    end
    logdetS = 2*logdetS;
    Kinvm = kronmvm(Vfull,1./(e+opt.eps).*kronmvm(Vfull,vim,1));     % inv(K)*vim
    mKinvm = vim'*Kinvm;
    % Gaussian KL: the constant subtracted is the dimension N of vim
    nlZKL = (sum(log(e+opt.eps)) - logdetS - N + trprod + mKinvm)/2;
    if isinf(nlZKL), error('inf: nlZ KL'); end
    if isnan(nlZKL)
      if isnan(sum(log(e+opt.eps))), warning('nan: log|inv(K)|'); end
      if isnan(logdetS), warning('nan: log|S|'); end
      if isnan(trprod), warning('nan: trprod'); end
      error('nan: nlZ KL');
    end
    nlZ = nlZ + nlZKL;

    if nargout>2                                         % do we want derivatives?
      myCProb = opt.vi.myCProbc(1:n,:,c);      % n x mccnt residuals of class c
      if opt.update_vi==1  % E-step: derivatives w.r.t. the variational params

        % d/dvi.m: (1) softmax term minus (2) KL term, negated to minimize
        dvim = scale*M'*sum(myCProb, 2)/mccnt;
        dvim = -(dvim - Kinvm);

        % d/dvi.L{c}{d}
        vieps = opt.vi.eps(:,:,c);
        kronKinv = cell(p,1);                                     % { inv(K_d) }
        trKinvSext = zeros(p,1);                          % [ tr(inv(K_d)*S_d) ]
        for di=1:p
          kronKinv{di} = Vfull{di} .* repmat(1./(E{di}+opt.eps)',[ng(di),1]) * Vfull{di}';
          KinvSdiag = (L{di}*L{di}').*kronKinv{di}.'; % sum(KinvSdiag(:)) = tr(inv(K_d)*S_d)
          trKinvSext(di) = sum(KinvSdiag(:));
        end
        trprod = prod(trKinvSext);
        dL = L;
        if p==1        % single factor: the softmax term is a sum of outer products
          Ldi = L{1};
          dLdi = scale*((M'*myCProb)*vieps')/mccnt;            % (1) softmax term
          trdKSdi = kronKinv{1}*Ldi*(trprod/trKinvSext(1));
          dlogdet = (N/ng(1))*diag(1./diag(Ldi));  % d(logdetS/2)/dLdi as coded
          dL{1} = -(dLdi + dlogdet - trdKSdi);   % (2) KL term, negate to minimize
        else           % several factors: softmax term entry by entry
          gM = myCProb'*M;                    % row si equals myCProb(:,si)'*M
          tmpdL = L;                % L with factor di replaced by a unit matrix
          for di=1:p
            Ldi = L{di};
            tmpdL{di} = zeros(size(Ldi));
            dLdi = zeros(size(Ldi));
            for j=1:ng(di)
              for jj=1:j        % every entry of the lower triangular factor
                tmpdL{di}(j,jj) = 1;
                dl = 0;
                for si=1:mccnt
                  dl = dl + gM(si,:)*kronmvm(tmpdL,vieps(:,si));
                end
                dLdi(j,jj) = dl;
                tmpdL{di}(j,jj) = 0;
              end
            end
            dLdi = scale*dLdi/mccnt;                           % (1) softmax term
            trdKSdi = kronKinv{di}*Ldi*(trprod/trKinvSext(di));
            dlogdet = (N/ng(di))*diag(1./diag(Ldi));  % d(logdetS/2)/dLdi as coded
            dL{di} = -(dLdi + dlogdet - trdKSdi); % (2) KL term, negate to minimize
            tmpdL{di} = Ldi;
          end
        end
        dnlZc.m(:,c) = dvim;
        dnlZc.L{c} = dL;

        % d/dvi.pi(:,c)
        dpipred = M*opt.vi.u(:,:,c)+repmat(m,[1,mccnt]);            % n x mccnt
        dpi = opt.vi.yCProbc(1:n,:,:).*repmat(dpipred, [1,1,C]); % n x mccnt x C
        dnlZc.pi(:,c) = -sum(sum(dpi,1),2)/mccnt*scale;
        dnlZc.pi(:,c) = dnlZc.pi(:,c) + 2*opt.pi_decay*pi(:,c); % L2-regularizer

      else  % M-step: derivatives w.r.t. the hyperparameters
        dnlZ = hyp;     % allocate space for derivatives
        for i = 1:numel(hyp.cov)
          dK = feval(cov{:}, hyp.cov, x*hP', [], i);
          dcmplx = 1; dKS = -1;
          for j=1:p
            VdKVEj = Vfull{j}'*dK.kron(j).factor*Vfull{j}.*repmat(1./(E{j}+opt.eps)',[ng(j),1]); % V{j}'*dK{j}*V{j}*inv(E{j})
            VLj = Vfull{j}'*L{j};
            trKdKKSj = 1./(E{j}+opt.eps)'*sum((VdKVEj*VLj).*VLj.',2); % tr(inv(K)*dK*inv(K)*S) {j}
            dKS = dKS*trKdKKSj;
            dcmplx = dcmplx*trace(VdKVEj);
          end
          dfit = -Kinvm'*dK.mvm(Kinvm);
          dnlZ.cov(i) = (dcmplx+dfit+dKS)/2;
        end

        for i = 1:numel(hyp.mean)
          dnlZ.mean(i) = -feval(mean{:}, hyp.mean, xx*hP', i)'* ...
              sum(myCProb, 2)/mccnt*scale + 2*opt.mean_decay*hyp.mean(i);
        end

        if isfield(hyp,'P')
          dnlZ.P = deriv_P(hP,covGrid('flatten',xg),m,mean,hyp,x,vi,myCProb,scale,c,opt);
          if proj_ortho, dnlZ.P = chain_ortho(hyp.P,dnlZ.P); end
        end

        dnlZc{c} = dnlZ;

        if nargout>3                                % do we want derivatives w.r.t. x?
          dxc{c} = deriv_x(hP,covGrid('flatten',xg),m,mean,hyp,x,vi,myCProb,scale,c,opt);
        end
      end
    end
  end

  if nargout>2            % rescale the stacked gradient to unit Euclidean norm
    g = unwrap(dnlZc);
    gnorm = norm(g);
    if gnorm>0, dnlZc = rewrap(dnlZc, g./gnorm); end  % keep a zero gradient as is
  end
end

% ------- methods of classification svi --------

function dL = deriv_L(L, d, idx)
% Derivative of the Kronecker factor list L w.r.t. one entry of factor d:
% returns a copy of L whose d-th factor is replaced by a zero matrix holding
% ones at idx. A two-element idx is read as [row, col]; any other idx is
% used as (a list of) linear indices. The input L is left untouched.
  dLd = zeros(size(L{d}));
  if numel(idx)==2
    idx = sub2ind(size(dLd), idx(1), idx(2));      % [row,col] -> linear index
  end
  dLd(idx) = 1;
  dL = L;
  dL{d} = dLd;                 % replace only factor d of the returned copy

function v = unwrap(s)
% Flatten all numerical content of "s" into one column vector "v".
% Numeric arrays contribute their elements in column-major order, struct
% fields are visited in alphabetical order and cell elements in linear
% order; anything else (char, logical, function handles, ...) is skipped.
% The inverse mapping is rewrap.
if iscell(s)
  parts = cellfun(@unwrap, s(:), 'UniformOutput', false);  % recurse per element
  if isempty(parts), v = []; else v = vertcat(parts{:}); end
elseif isstruct(s)
  v = unwrap(struct2cell(orderfields(s)));        % alphabetical field order
elseif isnumeric(s)
  v = s(:);
else
  v = [];
end

function [s v] = rewrap(s, v)
% Inverse of unwrap: pour the leading entries of the vector "v" into the
% numeric parts of "s" (same traversal order as unwrap) and return the
% unused tail of "v". Non-numeric parts of "s" are copied unchanged and the
% original field order of structs is preserved.
if iscell(s)
  for k = 1:numel(s)                          % elements in linear order
    [s{k} v] = rewrap(s{k}, v);
  end
elseif isstruct(s)
  [sorted, perm] = orderfields(s);              % alphabetical traversal order
  back = zeros(size(perm)); back(perm) = 1:numel(perm);   % inverse permutation
  [vals v] = rewrap(struct2cell(sorted), v);
  s = orderfields(cell2struct(vals, fieldnames(sorted), 1), back);
elseif isnumeric(s)
  k = numel(s);
  if numel(v) < k
    error('The vector for conversion contains too few elements')
  end
  s = reshape(v(1:k), size(s));
  v = v(k+1:end);
end


% ----------------------------------------------

% START
% precompute terms for mvm with additive scaled covariance
function [Kq,Sq,dKq,dWq] = additive_scaled_mvm_prec(N,Q,D,w,cov,hyp,x,xg,hP)
% Precompute the Q additive terms of a softmax-scaled additive covariance.
% hyp stacks, for every term q, nw scaling-weight hyperparameters followed
% by nc covariance hyperparameters (nw and nc are the counts reported by w
% and cov when called without arguments). Returns the covariance
% representations Kq, the N x Q softmax scaling matrix Sq and closures
% dKq{q}(i), dWq{q}(i) evaluating derivatives. D is unused.
  nw = eval(feval(w{:}));                         % hypers per scaling weight
  nc = eval(feval(cov{:}));                       % hypers per covariance term
  Kq = cell(Q,1); dKq = cell(Q,1); dWq = cell(Q,1);
  Wq = zeros(N,Q);                                % raw (unsquashed) weights
  for q = 1:Q
    base = (q-1)*(nw+nc);                         % offset of term q in hyp
    hw = hyp(base + (1:nw));
    hc = hyp(base + nw + (1:nc));
    Wq(:,q) = feval(w{:}, hw, covGrid('expand',xg));
    dWq{q} = @(i) feval(w{:}, hw, covGrid('expand',xg), i);
    Kq{q} = feval(cov{:}, hc, x*hP');
    dKq{q} = @(i) feval(cov{:}, hc, x*hP', [], i);
  end
  Sq = softmax(Wq);

% mvm with additive scaled covariance
function r = additive_scaled_mvm(t,Kq,Sq,dKq,dWq,nc,nw,i)
% Matrix-vector product with the additive scaled covariance
%   K = sum_j diag(s_j)*K_j*diag(s_j),   s_j = Sq(:,j),
% or, when the hyperparameter index i is supplied, with dK/dh_i.
% Hyperparameters are stacked per additive term as [nw weight hypers, nc
% covariance hypers]; dWq{j}(k) must only be queried for k<=nw.
  r = zeros(size(t)); o = ones(1,size(r,2)); Q = numel(Kq);
  if nargin>7                                                      % derivatives
    ii = mod(i-1,nc+nw)+1;           % hyperparameter index within additive term
    ij = 1+(i-ii)/(nc+nw);                          % index of the additive term
    if ii<=nw                  % derivative w.r.t. squashing function parameters
      sj_dWij = Sq(:,ij).*dWq{ij}(ii);      % factor for dsj (weight hypers only)
      for j=1:Q
        dsj = ((double(ij==j)-Sq(:,j)).*sj_dWij)*o; sj = Sq(:,j)*o;  % dsj/dh_ii
        r = r + dsj.*Kq{j}.mvm(sj.*t) + sj.*Kq{j}.mvm(dsj.*t);
      end
    else                % derivatives w.r.t. ordinary covariance hyperparameters
      dKij = dKq{ij}(ii-nw);                   % only term ij is left in the sum
      sj = Sq(:,ij)*o; r = r + sj.*dKij.mvm(sj.*t);
    end
  else                                                        % plain evaluation
    for j=1:Q, sj = Sq(:,j)*o; r = r + sj.*Kq{j}.mvm(sj.*t); end
  end

% softmax along the second dimension
function s = softmax(w)
% Row-wise softmax of w; subtracting each row maximum keeps exp from
% overflowing. Asserts that every row of the result sums to one.
  e = exp(bsxfun(@minus, w, max(w,[],2)));
  s = bsxfun(@rdivide, e, sum(e,2));
  assert(norm(sum(s,2)-1)<1e-12)

% Upper bound the log determinant of a sum of p symmetric positive semi-definite
% matrices of size nxn represented in terms of their eigenvalues E using Weyl's
% inequalities:
% Let a(nx1) and b(nx1) be the ordered eigenvalues of the Hermitian matrices
% A(nxn) and B(nxn). Then, the eigenvalues c(nx1) of the matrix C = A+B are
% upper bounded by c(i+j-1) <= a(i)+b(j).
%
% Each of the p columns of the matrix E of size nxp contains the
% eigenvalues of the respective matrix.
% The string mod contains the method of upper bounding the log determinant:
%  'halve'  - select eigenvalues of roughly equal indices i,j
%  'greedy' - select the locally smallest bound; the integer s is the number of
%             search steps ahead
%
% The index matrix I can be used to reconstruct the contributions of the
% individual eigenvalues to the overall log determinant upper bound given by:
%   ldu = sum(log( sum(E(I),2) ));
function [ldu,I] = logdet_weyl(E,mod,s)
% Weyl upper bound on log|A_1+...+A_p| from the eigenvalue columns of E.
% NOTE: the argument name "mod" shadows the builtin mod inside this function.
if nargin<2, mod = 'halve'; end
[n,p] = size(E); I = zeros(n,p);                         % get dims, init result
[e,I(:,1)] = sort(E(:,1),'descend');                    % get upper bound so far
% the m-th bound pairs eigenvalue indices i(m), j(m) with i(m)+j(m)-1 = m
i = ceil((1:n)'/2); j = (1:n)'-i+1;        % default: similar index size (halve)
for k=2:p
  [fs,js] = sort(E(:,k),'descend');                     % sort both k-th summand
  if strcmp(mod,'greedy')     % set up i and j adaptively using greedy heuristic
    if nargin<3, s = 1; end                                  % set default value
    i(1) = 1; j(1) = 1;                                                   % init
    for m=2:n
      ss = s;                           % reduce steps if it violates boundaries
      if ss>  i(m-1) || ss>  j(m-1), ss = min(  i(m-1),   j(m-1)); end
      if ss>n-i(m-1) || ss>n-j(m-1), ss = min(n-i(m-1), n-j(m-1)); end
      i_ss = i(m-1)+(1-ss:ss); j_ss = m-i_ss+1;                % left/right step
      [junk,idx] = min(e(i_ss) + fs(j_ss));   % keep min of steps left and right
      i(m) = i_ss(idx); j(m) = j_ss(idx);
    end
  end
  % I holds linear indices into E; rows are permuted so that row m
  % lists the eigenvalues that make up bound m
  I(:,1:k-1) = I(i,1:k-1);               % sort previous summands according to i
  I(:,k) = js(j) + (k-1)*n;              % bound contributions from k-th summand
  e = e(i) + fs(j);                              % compute Weyl eigenvalue bound
end
ldu = sum(log(e));
% consistency checks: the recorded indices I reproduce the bound e
assert( norm(e-sum(E(I),2))<1e-10 )
if ~isinf(ldu), assert( abs(ldu-sum(log(sum(E(I),2))))<1e-10 ), end
% END

function dx = deriv_x(P,xg,mu,mean,hyp,x,vi,myCProb,scale,c,opt)
% Derivative of the negative data-fit term w.r.t. the inputs x of class c.
%   P       - projection matrix (or 1), inputs are used as x*P'
%   xg      - flattened grid used for interpolation
%   mu      - prior mean evaluated at the projected inputs
%   myCProb - n x mccnt residuals mapped through the mixing layer
%   c       - class index selecting the samples opt.vi.u(:,:,c)
%   opt     - option struct providing opt.vi.mccnt and opt.vi.u
% The result is divided by the number of rows of x; vi and scale are unused.
  xP = x*P'; [~,dM] = covGrid('interp',xg,xP); % grid interp derivative matrices
  mccnt = opt.vi.mccnt;
  aveyCProb = sum(myCProb, 2)/mccnt;           % sample average of the residuals
  dxP = zeros(size(xP)); h = 1e-4;               % allocate result, num deriv step
  for i=1:size(xP,2)
    xPi = xP(:,i);
    xP(:,i) = xPi+h;
    dmi = (feval(mean{:},hyp.mean,xP)-mu)/h;         % numerically estimate dm/di
    xP(:,i) = xPi;        % restore exactly (xP+h-h can differ in floating point)
    for si=1:mccnt          % interpolation derivative applied to every sample
      dxP(:,i) = dxP(:,i) + myCProb(:,si).*(dM{i}*opt.vi.u(:,si,c));
    end
    dxP(:,i) = dxP(:,i)/mccnt + aveyCProb.*dmi;
  end
  dx = -dxP*P/size(x,1); % normalize by #data

% compute derivative of neg log marginal likelihood w.r.t. projection matrix P
function dP = deriv_P(P,xg,mu,mean,hyp,x,vi,myCProb,scale,c,opt)
% Derivative of the negative (scaled) data-fit term w.r.t. the projection P.
%   P       - projection matrix, inputs are used as x*P'
%   xg      - flattened grid used for interpolation
%   mu      - prior mean evaluated at the projected inputs
%   myCProb - n x mccnt residuals mapped through the mixing layer
%   scale   - data-fit rescaling factor applied to the result
%   c       - class index selecting the samples opt.vi.u(:,:,c)
%   opt     - option struct providing opt.vi.mccnt and opt.vi.u (required)
% vi is unused.
  if nargin<11
    error('deriv_P: opt (providing opt.vi.mccnt and opt.vi.u) is required');
  end
  xP = x*P'; [~,dM] = covGrid('interp',xg,xP); % grid interp derivative matrices
  mccnt = opt.vi.mccnt;
  aveyCProb = sum(myCProb, 2)/mccnt;           % sample average of the residuals
  U = opt.vi.u(:,:,c);                         % latent grid samples of class c
  dP = zeros(size(P)); h = 1e-4;               % allocate result, num deriv step
  for i=1:size(P,1)
    xPi = xP(:,i);
    xP(:,i) = xPi+h;
    dmi = (feval(mean{:},hyp.mean,xP)-mu)/h;         % numerically estimate dm/di
    xP(:,i) = xPi;        % restore exactly (xP+h-h can differ in floating point)
    % per-row weight shared by every column j of P:
    %   mean over samples of myCProb(:,si).*(dM{i}*u_si), plus aveyCProb.*dmi
    w = sum(myCProb.*(dM{i}*U), 2)/mccnt + aveyCProb.*dmi;
    dP(i,:) = w'*x;                      % dP(i,j) = x(:,j)'*w for all j at once
  end
  dP = -scale*dP;

function y = zfill(z,n)                                      % fill z with zeros
% Pad the entries of z with trailing zeros into a column of length n.
  pos = (1:numel(z)).';                         % target position of each entry
  y = accumarray(pos, z, [n,1]);

function eq = equi(xg,i)                        % grid along dim i is equispaced
% True when consecutive rows of xg{i} differ by a constant step (absolute
% tolerance 1e-9); a grid with at most one point counts as equispaced.
  g = xg{i};
  ni = size(g,1);
  eq = true;
  if ni<=1, return; end
  step = g(2,:) - g(1,:);                                   % first increment
  dev = abs(diff(g) - ones(ni-1,1)*step);       % deviation from that increment
  eq = max(dev(:))<1e-9;

% chain rule for the function Q = sqrtm(P*P')\P;  for d sqrtm(X) see the website
function dQ = chain_ortho(P,dP) % http://math.stackexchange.com/questions/540361
  % Map a derivative dP taken w.r.t. Q = sqrtm(P*P')\P back to a derivative
  % w.r.t. P. S = sqrtm(P*P') is formed from the eigendecomposition V*F*V'.
  [V,F] = eig(P*P'); sf = sqrt(diag(F)); S = V*diag(sf)*V';         % eig-decomp
  % H and G collect the product-rule terms of inv(S)*P
  H = dP'/S; G = H'*(P'/S); o = ones(size(dP,1),1);                 % chain rule
  % dividing by sf_i+sf_j in the eigenbasis solves S*dS + dS*S = d(P*P'),
  % i.e. differentiates the matrix square root
  dQ = (H - P'*V*((V'*(G+G')*V)./(sf*o'+o*sf'))*V')';

function q = solveMVM(p,mvm,varargin) % solve mvm(q) = p via conjugate gradients
% Extra arguments (tolerance, maximum iterations) are forwarded to conjgrad.
% Raises an error when conjgrad reports non-convergence.
  [q,flag,relres,iter] = conjgrad(mvm,p,varargin{:});                 % like pcg
  if ~flag          % conjgrad returns flag==1 on convergence (pcg uses 0)
    error('Not converged after %d iterations, r=%1.2e\n',iter,relres);
  end

function q = mvmK(p,K,sn2,M) % mvm q = M*K*M'*p + sn2*p using the Kronecker repr
% K must expose a matrix-vector product handle K.mvm.
  Mtp = M'*p;                                 % map p onto the grid
  KMtp = K.mvm(Mtp);                          % apply the grid covariance
  q = M*KMtp + sn2*p;                         % interpolate back, add noise

function fmuc = pred_muc(xsc,vim,covc,pi,hyp,mean,C) % predict interpolated latent mean
% Latent predictive means of all C classes at the test inputs xsc{k}: prior
% mean plus interpolated grid mean vim(:,k), combined by the mixing matrix pi.
  cols = cell(1,C);
  for k=1:C
    xs = xsc{k};
    Ms = covGrid('interp',covc{k}{3},xs);            % obtain interpolation matrix
    ms = feval(mean{:},hyp{k}.mean,xs);                     % evaluate prior mean
    cols{k} = ms + Ms*vim(:,k);                       % perform grid interpolation
  end
  fmuc = [cols{:}];
  fmuc = fmuc*pi';                                                % mixing layer

function y = pred_y(fmuc) % predict the class label of every row
% Returns, per row, the column index of the largest mixed latent value (the
% first such column on ties).
  [maxval, y] = max(fmuc, [], 2); %#ok<ASGLU> only the index is needed

function liks = pred_lik(fmuc, ys)
% Average log predictive probability of the labels ys under a softmax over
% the mixed latent values fmuc (n x C). ys holds class indices; rows whose
% label is outside 1..C contribute nothing but still count in the divisor
% size(ys,1). A stable log-softmax is used so that a probability that would
% underflow to zero yields a large negative log value instead of being
% silently dropped from the sum.
  C = size(fmuc,2);
  z = fmuc - repmat(max(fmuc,[],2),[1,C]);               % shift for stability
  logp = z - repmat(log(sum(exp(z),2)),[1,C]);                    % log softmax
  ys_ext = false(size(ys,1),C);                            % one-hot label mask
  for c=1:C, ys_ext(ys==c,c) = true; end
  liks = sum(logp(ys_ext))/size(ys,1);

function fs2 = pred_s2(xs,vg,xg,hyp,cov)  % predict interpolated latent variance
% Prior variance at xs minus the interpolated grid variance explained vg,
% clipped at zero. The interpolation is skipped only when vg is numerically
% zero (the threshold used to be 1e10, which skipped it in practice).
  if norm(vg,1)>1e-10
    ve = covGrid('interp',xg,xs)*vg;       % interpolate grid variance explained
  else
    ve = 0;                                  % nothing is explained by the grid
  end
  xs = covGrid('idx2dat',xg,xs);                        % deal with index vector
  ks = feval(cov{:},hyp.cov,xs,'diag');                % evaluate prior variance
  fs2 = max(ks-ve,0);                % combine, perform grid interpolation, clip

% Solve A*x=b for x with symmetric A(n,n), b(n,m), x(n,m) using conjugate gradients.
% The method is along the lines of PCG but suited for matrix inputs b.
function [x,flag,relres,iter,r] = conjgrad(A,b,tol,maxit)
% A      - numeric matrix or function handle computing A*x for a matrix x
% b      - right-hand sides, one system per column
% tol    - relative residual tolerance per column (default 1e-10)
% maxit  - maximum iteration index (default min(size(b,1),20))
% x      - approximate solutions, started from zero
% flag   - 1 if every column converged, 0 otherwise (NOTE: the opposite of
%          pcg's convention, where 0 signals success)
% relres - relative residual norms per column, ||r||/||b||
% iter   - index of the last iteration performed
% r      - residual columns still being iterated (all columns when the
%          initial guess already satisfies the tolerance)
if nargin<3, tol = 1e-10; end
if nargin<4, maxit = min(size(b,1),20); end
x0 = zeros(size(b)); x = x0;
if isnumeric(A), r = b-A*x; else r = b-A(x); end, r2 = sum(r.*r,1); r2new = r2;
nb = sqrt(sum(b.*b,1)); flag = 0; iter = 1;
relres = sqrt(r2)./nb; todo = relres>=tol; if ~any(todo), flag = 1; return, end
% from here on, d and r only hold the columns that have not converged yet
on = ones(size(b,1),1); r = r(:,todo); d = r;
for iter = 2:maxit
  if isnumeric(A), z = A*d; else z = A(d); end
  a = r2(todo)./sum(d.*z,1);
  a = on*a;
  x(:,todo) = x(:,todo) + a.*d;
  r = r - a.*z;
  r2new(todo) = sum(r.*r,1);
  relres = sqrt(r2new)./nb; cnv = relres(todo)<tol; todo = relres>=tol;
  d = d(:,~cnv); r = r(:,~cnv);                           % get rid of converged
  if ~any(todo), flag = 1; return, end
  % b is reused for the step coefficients; the right-hand side itself is no
  % longer needed because its column norms nb were computed above
  b = r2new./r2;                                               % Fletcher-Reeves
  d = r + (on*b(todo)).*d;
  r2 = r2new;
end

% Eigendecomposition of a Kronecker matrix with dense, Toeplitz or BTTB factors.
% K.mvm(z) == kronmvm(V,e.*kronmvm(V,z,1)), approximate for Toeplitz/BTTB
% Vfull: the full eigenvectors
function [V,Vfull,E,e,ir] = eigkron(K,xg,ng,circ_emb)
% Outputs (cell arrays hold one entry per Kronecker factor j = 1..p):
%   V{j}     - eigenvectors of factor j up to its rank (economy size)
%   Vfull{j} - complete square eigenvector matrix of factor j
%   E{j}     - eigenvalues of factor j, zero-padded to length ng(j)
%   e        - eigenvalues of the full matrix, e = kron(E{p},...,E{1})
%   ir       - logical mask of the strictly positive entries of e
% Toeplitz/BTTB factors raise an error because their partial Fourier
% representation provides no dense Vfull; xg and circ_emb are therefore
% unused and only kept for interface compatibility.
isbttb = @(Ki) isstruct(Ki) && (strcmp (Ki.descr,'toep') ...
                             || strncmp(Ki.descr,'bttb',4));   % BTTB covariance
p = numel(K.kron); V = cell(p,1); Vfull = cell(p,1); E = cell(p,1);     % sizes and allocate memory
e = 1; ir = 1;      % running Kronecker products of eigenvalues / rank masks
for j=1:p
  if isbttb(K.kron(j).factor)
    error('bttb Vfull not supported');
  end
  [V{j},E{j},Vfull{j}] = eigr(K.kron(j).factor,0); % eigenpairs of a dense factor
end
for j=1:p
  de = zfill(E{j},ng(j)); e = kron(de,e); ir = kron(de>0,ir);
  E{j} = de;
end
ir = ir~=0;


% Real eigenvalues and eigenvectors up to the rank of a real symmetric matrix.
% Decompose A into V*D*V' with orthonormal matrix V and diagonal matrix
% D = diag(d) represented by a column vector d.
% Entries of d with index larger than the rank r of the matrix A as returned by
% the call rank(A,tol) are zero.
% Vfull: the full eigenvectors
function [V,d,Vfull] = eigr(A,econ,tol)
if nargin<2, econ = false; end                          % assign a default value
if isnumeric(econ), econ = econ==0; end                      % turn into boolean
if ischar(econ), econ = strcmp(econ,'econ'); end             % turn into boolean
if nargout==0, return, end                     % avoid computation in limit case
if nargout==1
  d = sort(eig((A+A')/2),'descend');       % eigenvalues of strictly symmetric A
else
  [V,D] = eig((A+A')/2);                 % decomposition of strictly symmetric A
  d = max(real(diag(D)),0); [d,ord] = sort(d,'descend');      % tidy up and sort
end
n = size(A,1);                                                  % dimensionality
if nargin<3, tol = n*eps(max(d)); end, r = sum(d>tol);             % get rank(A)
d(r+1:n) = 0;                             % set junk eigenvalues to strict zeros
if econ, d = d(1:r); end                                   % truncate if desired
if nargout==1                                        % only eigenvalues required
  V = d;
else                                             % entire decomposition required
  V(:,1:r) = real(V(:,ord(1:r)));                    % eigenvectors up to rank r
  if econ
    if nargout>2
      Vfull = V;
    end
    V = V(:,1:r);
  else                                            % ortho completion if required
    V(:,r+1:n) = null(V(:,1:r)');
    if nargout>2
      Vfull = V;
    end
  end
end

% Construct a circular embedding c(nx1) from a covariance function k.
%  - k is a function and the call k(1:n) returns a vector of length n
%  - n holds the grid size per dimension (s=2..5 support 1d grids only)
%  - s is the setting for the embedding with values s={..,-2,-1,0,1,2,3,4,5};
%    the default is s=-2, and any other value raises an error.
%
% s<0  Whittle embedding [2] as described by Guinness & Fuentes [3] in
%      equation (5) with N = |s|.
% s=0  No embedding c = k(1:n).
% s=1  Generic embedding c = [k(1:n/2), k(n/2:-1:2)] without smoothing
%      as in the Strang preconditioner.
% s=2  Variant of smoothing inspired by [1] respecting the conditions
%      c(1) = k(1), c(i)=c(n-i+2) for i=2..n and idempotence. We use the generic
%      scheme c(2:n) = w.*k(2:end) + flipud(w.*k(2:end)), where w are sigmoidal
%      interpolation weights.
% s=3  Linear smoothing as in T. Chan's preconditioner.
% s=4  Tyrtyshnikov preconditioner.
% s=5  R. Chan's preconditioner.
%
% For s~=0 the result is checked to be circulant, i.e. c(i) = c(n-i+2) for
% i=2..n along every dimension, up to a tolerance of 1e-10*max(1,max|c|).
%
% [1] Helgason, Pipiras & Abry, Smoothing windows for the synthesis of
%     Gaussian stationary random fields using circulant matrix embedding,
%     Journal of Computational and Graphical Statistics, 2014, 23(3).
% [2] Whittle, On stationary processes in the plane, Biometrika, 1954, 41(3/4).
% [3] Guinness & Fuentes, Circulant embedding of approximate covariances for
%     inference from Gaussian data on large lattices, 2014, preprint,
%     http://www4.stat.ncsu.edu/~guinness/circembed.html.
function c = circ(k,n,s)
p = numel(n); n = n(:)';                             % dimensions and row vector
if nargin<3, s = -2; end                                         % default value
if s==0                                                    % no embedding at all
  xg = cell(p,1); for i=1:p, xg{i} = (1:n(i))'; end              % standard grid
  xc = covGrid('expand',xg);
  c = reshape(k(xc),[n,1]);
elseif s<0                                           % Whittle/Guinness aliasing
  N = abs(s);
  xg = cell(p,1); for i=1:p, xg{i} = (1-N*n(i):N*n(i))'; end
  sz = [n; 2*N*ones(1,p)]; sz = sz(:)';
  c = reshape(k(covGrid('expand',xg)),sz);
  for i=1:p, c = sum(c,2*i); end, c = squeeze(c);
elseif s==1                                                     % Strang embedding
  xg = cell(p,1);
  for i=1:p, n2 = floor(n(i)/2)+1; xg{i} = [1:n2, n2-n(i)+1:0]'; end
  xc = covGrid('expand',xg);
  c = reshape(k(xc),[n,1]);
elseif s==2                                                   % Helgason smoothing
  if numel(n)>1, error('Only 1d allowed'), end
  k = k(1:n); k = k(:);
  k0 = k(2:n); % smooth by sigmoid interpolation between the two flipped parts
  wk = k0./(1+exp( 30*(((1:n-1)-n/2)'/n) ));               % sigmoid-weighted part
  c = [k(1); wk + flipud(wk)];
elseif s==3                                           % T. Chan's linear smoothing
  if numel(n)>1, error('Only 1d allowed'), end
  k = k(1:n); k = k(:);
  w = (1:n-1)'/n;
  c = k; c(2:n) = (1-w).*c(2:n) + w.*flipud(c(2:n));
elseif s==4                                                         % Tyrtyshnikov
  if numel(n)>1, error('Only 1d allowed'), end
  k = k(1:n); k = k(:);
  w = (1:n-1)'/n;
  c = k; c2 = c(2:n).*c(2:n);
  c(2:n) = ((1-w).*c2 + w.*flipud(c2)) ./ ((1-w).*c(2:n) + w.*flipud(c(2:n)));
elseif s==5                                                              % R. Chan
  if numel(n)>1, error('Only 1d allowed'), end
  k = k(1:n); k = k(:);
  c = k; c(2:n) = c(2:n) + flipud(c(2:n));                       % mirrored add-up
else                                 % previously fell through with c undefined
  error('Unknown circulant embedding setting s=%g.',s)
end
assert(numel(c)==prod(n))                                               % length
if s~=0                                                      % circularity check
  % Index every grid dimension separately. The former reshape loop left a
  % 1 x (n-1) row for 1d grids, so flipping along dimension 1 was a no-op and
  % the check could never fail there.
  c2 = reshape(c,[n,1]);
  sub = cell(1,p); for i=1:p, sub{i} = 2:n(i); end
  c2 = c2(sub{:});                         % drop the first entry in every dim
  c2f = c2; for i=1:p, c2f = flipdim(c2f,i); end
  tol = 1e-10*max([1; abs(c(:))]);   % relative for large |c|, 1e-10 otherwise
  assert(all(abs(c2(:)-c2f(:))<tol), 'circ: embedding is not circulant')
end
