#include <stddef.h>
#include <string.h>

#include <mex.h>

/*
 * function [x] = block_trs(A, b, trans_flag);
 *
 * Solve A*x = b (or A'*x = b), where the upper triangular matrix A is
 * stored in the block packed format produced by packtri.
 * trans_flag is 0 (no transpose) or 1 (transpose); the default is
 * no transpose.
 */


/*
 * Prototype for the LAPACK triangular solver, called through the
 * Fortran convention (trailing underscore, every argument by pointer):
 *
 * SUBROUTINE TRTRS( UPLO, TRANS, DIAG, N, NRHS, A, LDA,  B,
 *                   LDB, INFO )
 *
 * Solves op(A) * X = B for an N-by-N triangular A (leading dimension
 * LDA) and NRHS right-hand sides stored in B (leading dimension LDB);
 * B is overwritten with X.  UPLO selects the upper/lower triangle,
 * TRANS selects op(A), DIAG says whether the diagonal is unit, and
 * INFO receives the status code (0 on success).
 *
 * NOTE(review): the integer arguments are declared as int*; confirm
 * that this matches the integer width of the LAPACK library that is
 * actually linked.
 */
void dtrtrs_(char* uplo, char* trans, char* diag,
             int* n, int* nrhs,
             double* a, int* lda,
             double* b, int* ldb,
             int* info);

/*
 * Prototype for the BLAS general matrix-matrix multiply, called through
 * the Fortran convention (trailing underscore, every argument by
 * pointer):
 *
 * SUBROUTINE GEMM ( TRANSA, TRANSB, M, N, K, ALPHA, A, LDA,
 *                   B, LDB, BETA, C, LDC )
 *
 * Computes C := alpha * op(A) * op(B) + beta * C, where op(A) is M-by-K,
 * op(B) is K-by-N and C is M-by-N.  TRANSA / TRANSB select whether each
 * operand is used as stored or transposed; LDA, LDB and LDC are the
 * leading dimensions of the column-major arrays.
 *
 * NOTE(review): the integer arguments are declared as int*; confirm
 * that this matches the integer width of the BLAS library that is
 * actually linked.
 */
void dgemm_(char* transA, char* transB, 
            int* m, int* n, int* k, double* alpha,
            double* A, int* ldA,
            double* B, int* ldB, double* beta, 
            double* C, int* ldC);



/*
 * Blocked triangular solve, performed in place on b.
 *
 * R      block-packed upper triangular matrix.  Block column J (kb
 *        columns wide) is stored column-major with leading dimension
 *        (J+1)*kb, starting (J*kb)*((J+1)*kb)/2 doubles into R; its
 *        diagonal block begins J*kb rows into that block column.
 * k      order of the triangular system (number of rows of b).
 * kb     block size used by the packed layout.
 * b      k-by-m column-major right-hand sides (leading dimension k);
 *        overwritten with the solution.
 * m      number of right-hand sides.
 * trans  'N' solves R*x = b; any other value solves R'*x = b.
 *
 * The call is a no-op when k, kb or m is not positive.  That keeps
 * kb == 0 out of the block-count division and keeps zero leading
 * dimensions away from LAPACK/BLAS.
 *
 * NOTE(review): the INFO value returned by dtrtrs_ is not inspected;
 * this function has no way to report a failed diagonal-block solve to
 * its caller.
 */
void block_trs(double* R, int k, int kb, double* b, int m, char trans)
{
    int J, info;
    int last_J;
    double one     = 1;
    double neg_one = -1;

    if (k <= 0 || kb <= 0 || m <= 0) {
        return;
    }

    /* Index of the last (possibly narrower) block column. */
    last_J = (k-1)/kb;

    if (trans == 'N') {

        /* Ordinary blocked backsolve: walk the block columns from last
         * to first.  n is the number of unknowns not yet solved. */
        int n = k;

        for (J = last_J; J >= 0; --J) {

            int    Jkb    = J*kb;        /* first row/column of block J   */
            int    Jkbp   = Jkb + kb;    /* leading dimension of column J */
            int    blockn = n-Jkb;       /* width of this diagonal block  */
            /* Offset of block column J inside R.  Computed in size_t
             * because Jkb*Jkbp exceeds INT_MAX for large k. */
            size_t col    = ((size_t) Jkb * (size_t) Jkbp) / 2;

            /* Solve with the diagonal block. */
            dtrtrs_("U", "N", "N", &blockn, &m,
                    R    + col + Jkb,  &Jkbp,
                    b    + Jkb,        &k,
                    &info);

            /* Eliminate the just-solved unknowns from the rows above:
             * b(0:Jkb-1,:) -= R(0:Jkb-1, block J) * b(block J,:). */
            dgemm_("N", "N", &Jkb, &m, &blockn,
                   &neg_one,
                   R    + col,   &Jkbp,
                   b    + Jkb,   &k,
                   &one,
                   b,            &k);

            n = Jkb;

        }

    } else {

        /* Transposed blocked solve: walk the block columns from first
         * to last. */
        for (J = 0; J <= last_J; ++J) {

            int    Jkb    = J*kb;
            int    Jkbp   = Jkb + kb;
            /* The last block may be narrower than kb. */
            int    blockn = (Jkbp < k) ? kb : (k - Jkb);
            size_t col    = ((size_t) Jkb * (size_t) Jkbp) / 2;

            /* Remove the contribution of the already-solved unknowns:
             * b(block J,:) -= R(0:Jkb-1, block J)' * b(0:Jkb-1,:). */
            dgemm_("T", "N", &blockn, &m, &Jkb,
                   &neg_one,
                   R    + col,   &Jkbp,
                   b,            &k,
                   &one,
                   b    + Jkb,   &k);

            /* Solve with the transposed diagonal block. */
            dtrtrs_("U", "T", "N", &blockn, &m,
                    R    + col + Jkb,  &Jkbp,
                    b    + Jkb,        &k,
                    &info);

        }

    }
}

/*
 * MEX gateway for  x = block_trs(A, b, trans_flag).
 *
 * prhs[0]  A: block-packed triangular matrix; its row count is used as
 *          the block size kb.
 * prhs[1]  b: k-by-m right-hand sides.
 * prhs[2]  trans_flag (optional): nonzero selects the transposed solve.
 *          A missing or empty flag means no transpose.
 * plhs[0]  x: a k-by-m copy of b on which the solve is done in place,
 *          so the caller's b is never modified.
 *
 * NOTE(review): the number of elements in A is not checked against the
 * size the packed layout requires for a system of order k.
 */
void mexFunction(int nlhs, mxArray** plhs,
                 int nrhs, const mxArray** prhs)
{
    double* A;
    double* b;
    double* x;
    int k, m, kb;
    char trans = 'N';

    (void) nlhs;

    if (nrhs < 2) {
        mexErrMsgTxt("block_trs: expected x = block_trs(A, b [, trans_flag]).");
    }

    /* mxGetPr only yields usable data for real, full, double arrays. */
    if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]) ||
        !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxIsSparse(prhs[1])) {
        mexErrMsgTxt("block_trs: A and b must be real, full, double arrays.");
    }

    A  = mxGetPr(prhs[0]);
    b  = mxGetPr(prhs[1]);
    kb = (int) mxGetM(prhs[0]);
    k  = (int) mxGetM(prhs[1]);
    m  = (int) mxGetN(prhs[1]);

    /* The flag is the THIRD argument, i.e. prhs[2]. */
    if (nrhs >= 3 && !mxIsEmpty(prhs[2]) && mxGetScalar(prhs[2]) != 0) {
        trans = 'T';
    }

    plhs[0] = mxCreateDoubleMatrix(k, m, mxREAL);

    /* Empty right-hand side: the empty result is already the answer. */
    if (k == 0 || m == 0) {
        return;
    }

    /* A zero block size would divide by zero in the blocked solve. */
    if (kb == 0) {
        mexErrMsgTxt("block_trs: A must not be empty when b is non-empty.");
    }

    x = mxGetPr(plhs[0]);
    /* Widen before multiplying so k*m cannot overflow int. */
    memcpy(x, b, (size_t) k * (size_t) m * sizeof(double));
    block_trs(A, k, kb, x, m, trans);
}

