//------------------------------------------------------------------------------
// GB_mex_AdotB: compute C=spones(Mask).*(A'*B)
//------------------------------------------------------------------------------

// SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2020, All Rights Reserved.
// http://suitesparse.com   See GraphBLAS/Doc/License.txt for license.

//------------------------------------------------------------------------------

// Returns a plain MATLAB sparse matrix, not a struct.  Only works in double
// and complex.  Input matrices must be MATLAB sparse matrices, or GraphBLAS
// structs in CSC format.

#include "GB_mex.h"

#define USAGE "C = GB_mex_AdotB (A,B,Mask,flipxy)"

// FREE_ALL: free every file-scope matrix (A, Aconj, B, C, Mask), the monoid
// "add" and the semiring "semiring", then call GB_mx_put_global.  Wrapped in
// do { ... } while (0) so that "FREE_ALL ;" is a single statement and is safe
// in any statement position (e.g. an unbraced if/else).
#define FREE_ALL                        \
do {                                    \
    GB_MATRIX_FREE (&A) ;               \
    GB_MATRIX_FREE (&Aconj) ;           \
    GB_MATRIX_FREE (&B) ;               \
    GB_MATRIX_FREE (&C) ;               \
    GB_MATRIX_FREE (&Mask) ;            \
    GrB_free (&add) ;                   \
    GrB_free (&semiring) ;              \
    GB_mx_put_global (true, GxB_AxB_DOT) ; \
} while (0)

// File-scope state shared by mexFunction and the adotb / adotb_complex helpers.
// The helpers take only a GB_Context, so their inputs (A, B, Mask, flipxy and
// the dimensions) and their output C are passed through these variables.
// FREE_ALL releases the matrices, the monoid, and the semiring.  The data
// objects are static: none of them is part of the MEX file's exported symbols.
static GrB_Matrix A = NULL, B = NULL, C = NULL, Aconj = NULL, Mask = NULL ;
static GrB_Monoid add = NULL ;
static GrB_Semiring semiring = NULL ;
GrB_Info adotb_complex (GB_Context Context) ;
GrB_Info adotb (GB_Context Context) ;
static GrB_Index anrows, ancols, bnrows, bncols, mnrows, mncols ;
static bool flipxy = false ;

//------------------------------------------------------------------------------

// adotb_complex: complex case of C = A'*B (or C<Mask> = A'*B).
//
// Builds Aconj = Complex_conj applied to A (same dimensions as A), forces it
// to complete with GrB_wait, then multiplies with the Complex_plus_times
// semiring: GB_AxB_dot3 when the global Mask is non-NULL, otherwise
// GB_AxB_dot2 with a single thread.  Reads the globals A, B, Mask, flipxy,
// anrows, ancols; writes the result into the global C.  Aconj is freed on
// every return path after it has been created.
//
// Returns the GrB_Info of the first failing step, else that of the multiply.

GrB_Info adotb_complex (GB_Context Context)
{
    // Aconj = conj (A)
    GrB_Info info = GrB_Matrix_new (&Aconj, Complex, anrows, ancols) ;
    if (info != GrB_SUCCESS) return (info) ;
    info = GrB_apply (Aconj, NULL, NULL, Complex_conj, A, NULL) ;
    if (info == GrB_SUCCESS)
    {
        // force completion
        info = GrB_wait ( ) ;
    }
    if (info != GrB_SUCCESS)
    {
        GrB_free (&Aconj) ;
        return (info) ;
    }

    // Complex_plus_times is not created here and is not freed here.  It is
    // held under a distinct name so the file-scope "semiring" (which FREE_ALL
    // frees) is not shadowed.
    GrB_Semiring complex_semiring = Complex_plus_times ;

    if (Mask != NULL)
    {
        // C<M> = A'*B using dot product method
        info = GB_AxB_dot3 (&C, Mask, false, Aconj, B, complex_semiring,
            flipxy, Context) ;
    }
    else
    {
        // C = A'*B using dot product method; dot2 takes an array of matrices
        // for A (named Aslice), of which only one is passed here
        bool mask_applied = false ;
        GrB_Matrix Aslice [1] = { Aconj } ;
        info = GB_AxB_dot2 (&C, NULL, false, Aslice, B, complex_semiring,
            flipxy, &mask_applied,
            /* single thread: */
            1, 1, 1, Context) ;
    }

    GrB_free (&Aconj) ;
    return (info) ;
}

//------------------------------------------------------------------------------

// adotb: non-complex case of C = A'*B (or C<Mask> = A'*B).
//
// Creates the global monoid "add" (GrB_PLUS_FP64, identity 0) and the global
// semiring "semiring" (add, GrB_TIMES_FP64), multiplies with GB_AxB_dot2
// (single thread, no mask) or GB_AxB_dot3 (when the global Mask is non-NULL),
// stores the result in the global C, and frees the monoid and semiring again.
//
// Returns the GrB_Info of the first failing step, else that of the multiply.

GrB_Info adotb (GB_Context Context)
{
    // semiring for z += x*y in double precision
    GrB_Info status = GrB_Monoid_new (&add, GrB_PLUS_FP64, (double) 0) ;
    if (status != GrB_SUCCESS) return (status) ;

    status = GrB_Semiring_new (&semiring, add, GrB_TIMES_FP64) ;
    if (status != GrB_SUCCESS)
    {
        // the semiring was not created; only the monoid needs releasing
        GrB_free (&add) ;
        return (status) ;
    }

    if (Mask == NULL)
    {
        // C = A'*B, no mask, one matrix in the A array
        bool mask_was_applied = false ;
        GrB_Matrix A_parts [1] = { A } ;
        status = GB_AxB_dot2 (&C, NULL, false, A_parts, B,
            semiring /* GxB_PLUS_TIMES_FP64 */,
            flipxy, &mask_was_applied,
            // single thread:
            1, 1, 1, Context) ;
    }
    else
    {
        // C<M> = A'*B using dot product method
        status = GB_AxB_dot3 (&C, Mask, false, A, B,
            semiring /* GxB_PLUS_TIMES_FP64 */,
            flipxy, Context) ;
    }

    GrB_free (&add) ;
    GrB_free (&semiring) ;
    return (status) ;
}

//------------------------------------------------------------------------------

// mexFunction: C = GB_mex_AdotB (A, B, Mask, flipxy)
//
//  pargin [0]  A: CSC matrix
//  pargin [1]  B: CSC matrix with the same number of rows as A
//  pargin [2]  Mask (optional): CSC, ncols(A)-by-ncols(B).  If it converts to
//              NULL it is only accepted when the argument is empty, and then
//              no mask is used.
//  pargin [3]  flipxy (optional, default false)
//  pargout [0] C, returned as a MATLAB sparse matrix
//
// Errors are raised with mexErrMsgTxt after FREE_ALL releases all globals.

void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    bool malloc_debug = GB_mx_get_global (true) ;

    GB_WHERE (USAGE) ;

    // check inputs
    if (nargout > 1 || nargin < 2 || nargin > 4)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // A, B, and Mask are shallow copies, so there is no deep copy to create
    // or free.  NOTE(review): these empty macros are presumably referenced by
    // METHOD (GB_mex.h) in its malloc-debug retry loop; keep them defined.
    #define GET_DEEP_COPY ;
    #define FREE_DEEP_COPY ;

    GET_DEEP_COPY ;
    // get A and B (shallow copies)
    A = GB_mx_mxArray_to_Matrix (pargin [0], "A input", false, true) ;
    B = GB_mx_mxArray_to_Matrix (pargin [1], "B input", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }
    if (B == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("B failed") ;
    }

    GrB_Matrix_nrows (&anrows, A) ;
    GrB_Matrix_ncols (&ancols, A) ;

    GrB_Matrix_nrows (&bnrows, B) ;
    GrB_Matrix_ncols (&bncols, B) ;

    if (!A->is_csc || !B->is_csc)
    {
        FREE_ALL ;
        mexErrMsgTxt ("matrices must be CSC only") ;
    }

    // get Mask (shallow copy), if given
    if (nargin > 2)
    {
        Mask = GB_mx_mxArray_to_Matrix (pargin [2], "Mask input", false, false);

        if (Mask == NULL)
        {
            // NULL is acceptable only for an empty Mask argument (no mask);
            // otherwise the conversion failed.  Never dereference it.
            if (!mxIsEmpty (pargin [2]))
            {
                FREE_ALL ;
                mexErrMsgTxt ("Mask failed") ;
            }
        }
        else
        {
            GrB_Matrix_nrows (&mnrows, Mask) ;
            GrB_Matrix_ncols (&mncols, Mask) ;

            if (!Mask->is_csc)
            {
                FREE_ALL ;
                mexErrMsgTxt ("matrices must be CSC only") ;
            }

            // A'*B is ancols-by-bncols, so the Mask must be as well
            if (mnrows != ancols || mncols != bncols)
            {
                FREE_ALL ;
                mexErrMsgTxt ("mask wrong dimension") ;
            }
        }
    }

    if (anrows != bnrows)
    {
        FREE_ALL ;
        mexErrMsgTxt ("inner dimensions of A'*B do not match") ;
    }

    // get flipxy
    GET_SCALAR (3, bool, flipxy, false) ;

    if (A->type == Complex)
    {
        // C = A'*B, complex case
        METHOD (adotb_complex (Context)) ;
    }
    else
    {
        METHOD (adotb (Context)) ;
    }

    // return C to MATLAB
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C AdotB result", false) ;

    FREE_ALL ;
}

