Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made critical changes to small_gemm #568

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions frame/2/gemv/bli_gemv_unf_var2.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ void PASTEMAC(ch,varname) \
) \
{ \
const num_t dt = PASTEMAC(ch,type); \
\
bli_init_once(); \
\
if(cntx == NULL) cntx = bli_gks_query_cntx(); \
\
ctype* zero = PASTEMAC(ch,0); \
ctype* A1; \
Expand Down
350 changes: 177 additions & 173 deletions frame/2/trsv/bli_trsv_unf_var2.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,179 +49,183 @@ void PASTEMAC(ch,varname) \
cntx_t* cntx \
) \
{ \
const num_t dt = PASTEMAC(ch,type); \
\
ctype* minus_one = PASTEMAC(ch,m1); \
ctype* A01; \
ctype* A11; \
ctype* A21; \
ctype* a01; \
ctype* alpha11; \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please avoid whitespace changes.

ctype* a21; \
ctype* x0; \
ctype* x1; \
ctype* x2; \
ctype* x01; \
ctype* chi11; \
ctype* x21; \
ctype alpha11_conj; \
ctype minus_chi11; \
dim_t iter, i, k, j, l; \
dim_t b_fuse, f; \
dim_t n_ahead, f_ahead; \
inc_t rs_at, cs_at; \
uplo_t uploa_trans; \
conj_t conja; \
\
/* x = alpha * x; */ \
PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \
( \
BLIS_NO_CONJUGATE, \
m, \
alpha, \
x, incx, \
cntx, \
NULL \
); \
\
if ( bli_does_notrans( transa ) ) \
{ \
rs_at = rs_a; \
cs_at = cs_a; \
uploa_trans = uploa; \
} \
else /* if ( bli_does_trans( transa ) ) */ \
{ \
rs_at = cs_a; \
cs_at = rs_a; \
uploa_trans = bli_uplo_toggled( uploa ); \
} \
\
conja = bli_extract_conj( transa ); \
\
PASTECH(ch,axpyf_ker_ft) kfp_af; \
\
/* Query the context for the kernel function pointer and fusing factor. */ \
kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \
b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \
\
/* We reduce all of the possible cases down to just lower/upper. */ \
if ( bli_is_upper( uploa_trans ) ) \
{ \
for ( iter = 0; iter < m; iter += f ) \
{ \
f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); \
i = m - iter - f; \
n_ahead = i; \
A11 = a + (i )*rs_at + (i )*cs_at; \
A01 = a + (0 )*rs_at + (i )*cs_at; \
x1 = x + (i )*incx; \
x0 = x + (0 )*incx; \
\
/* x1 = x1 / triu( A11 ); */ \
for ( k = 0; k < f; ++k ) \
{ \
l = f - k - 1; \
f_ahead = l; \
alpha11 = A11 + (l )*rs_at + (l )*cs_at; \
a01 = A11 + (0 )*rs_at + (l )*cs_at; \
chi11 = x1 + (l )*incx; \
x01 = x1 + (0 )*incx; \
\
/* chi11 = chi11 / alpha11; */ \
if ( bli_is_nonunit_diag( diaga ) ) \
{ \
PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \
PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \
} \
\
/* x01 = x01 - chi11 * a01; */ \
PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \
if ( bli_is_conj( conja ) ) \
{ \
for ( j = 0; j < f_ahead; ++j ) \
PASTEMAC(ch,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \
} \
else \
{ \
for ( j = 0; j < f_ahead; ++j ) \
PASTEMAC(ch,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \
} \
} \
\
/* x0 = x0 - A01 * x1; */ \
kfp_af \
( \
conja, \
BLIS_NO_CONJUGATE, \
n_ahead, \
f, \
minus_one, \
A01, rs_at, cs_at, \
x1, incx, \
x0, incx, \
cntx \
); \
} \
} \
else /* if ( bli_is_lower( uploa_trans ) ) */ \
{ \
for ( iter = 0; iter < m; iter += f ) \
{ \
f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); \
i = iter; \
n_ahead = m - iter - f; \
A11 = a + (i )*rs_at + (i )*cs_at; \
A21 = a + (i+f)*rs_at + (i )*cs_at; \
x1 = x + (i )*incx; \
x2 = x + (i+f)*incx; \
\
/* x1 = x1 / tril( A11 ); */ \
for ( k = 0; k < f; ++k ) \
{ \
l = k; \
f_ahead = f - k - 1; \
alpha11 = A11 + (l )*rs_at + (l )*cs_at; \
a21 = A11 + (l+1)*rs_at + (l )*cs_at; \
chi11 = x1 + (l )*incx; \
x21 = x1 + (l+1)*incx; \
\
/* chi11 = chi11 / alpha11; */ \
if ( bli_is_nonunit_diag( diaga ) ) \
{ \
PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \
PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \
} \
\
/* x21 = x21 - chi11 * a21; */ \
PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \
if ( bli_is_conj( conja ) ) \
{ \
for ( j = 0; j < f_ahead; ++j ) \
PASTEMAC(ch,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \
} \
else \
{ \
for ( j = 0; j < f_ahead; ++j ) \
PASTEMAC(ch,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \
} \
} \
\
/* x2 = x2 - A21 * x1; */ \
kfp_af \
( \
conja, \
BLIS_NO_CONJUGATE, \
n_ahead, \
f, \
minus_one, \
A21, rs_at, cs_at, \
x1, incx, \
x2, incx, \
cntx \
); \
} \
} \
const num_t dt = PASTEMAC(ch,type); \
\
bli_init_once(); \
\
if( cntx == NULL ) cntx = bli_gks_query_cntx(); \
\
ctype* minus_one = PASTEMAC(ch,m1); \
ctype* A01; \
ctype* A11; \
ctype* A21; \
ctype* a01; \
ctype* alpha11; \
ctype* a21; \
ctype* x0; \
ctype* x1; \
ctype* x2; \
ctype* x01; \
ctype* chi11; \
ctype* x21; \
ctype alpha11_conj; \
ctype minus_chi11; \
dim_t iter, i, k, j, l; \
dim_t b_fuse, f; \
dim_t n_ahead, f_ahead; \
inc_t rs_at, cs_at; \
uplo_t uploa_trans; \
conj_t conja; \
\
/* x = alpha * x; */ \
PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \
( \
BLIS_NO_CONJUGATE, \
m, \
alpha, \
x, incx, \
cntx, \
NULL \
); \
\
if ( bli_does_notrans( transa ) ) \
{ \
rs_at = rs_a; \
cs_at = cs_a; \
uploa_trans = uploa; \
} \
else /* if ( bli_does_trans( transa ) ) */ \
{ \
rs_at = cs_a; \
cs_at = rs_a; \
uploa_trans = bli_uplo_toggled( uploa ); \
} \
\
conja = bli_extract_conj( transa ); \
\
PASTECH(ch,axpyf_ker_ft) kfp_af; \
\
/* Query the context for the kernel function pointer and fusing factor. */ \
kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \
b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \
\
/* We reduce all of the possible cases down to just lower/upper. */ \
if ( bli_is_upper( uploa_trans ) ) \
{ \
for ( iter = 0; iter < m; iter += f ) \
{ \
f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); \
i = m - iter - f; \
n_ahead = i; \
A11 = a + (i )*rs_at + (i )*cs_at; \
A01 = a + (0 )*rs_at + (i )*cs_at; \
x1 = x + (i )*incx; \
x0 = x + (0 )*incx; \
\
/* x1 = x1 / triu( A11 ); */ \
for ( k = 0; k < f; ++k ) \
{ \
l = f - k - 1; \
f_ahead = l; \
alpha11 = A11 + (l )*rs_at + (l )*cs_at; \
a01 = A11 + (0 )*rs_at + (l )*cs_at; \
chi11 = x1 + (l )*incx; \
x01 = x1 + (0 )*incx; \
\
/* chi11 = chi11 / alpha11; */ \
if ( bli_is_nonunit_diag( diaga ) ) \
{ \
PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \
PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \
} \
\
/* x01 = x01 - chi11 * a01; */ \
PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \
if ( bli_is_conj( conja ) ) \
{ \
for ( j = 0; j < f_ahead; ++j ) \
PASTEMAC(ch,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \
} \
else \
{ \
for ( j = 0; j < f_ahead; ++j ) \
PASTEMAC(ch,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \
} \
} \
\
/* x0 = x0 - A01 * x1; */ \
kfp_af \
( \
conja, \
BLIS_NO_CONJUGATE, \
n_ahead, \
f, \
minus_one, \
A01, rs_at, cs_at, \
x1, incx, \
x0, incx, \
cntx \
); \
} \
} \
else /* if ( bli_is_lower( uploa_trans ) ) */ \
{ \
for ( iter = 0; iter < m; iter += f ) \
{ \
f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); \
i = iter; \
n_ahead = m - iter - f; \
A11 = a + (i )*rs_at + (i )*cs_at; \
A21 = a + (i+f)*rs_at + (i )*cs_at; \
x1 = x + (i )*incx; \
x2 = x + (i+f)*incx; \
\
/* x1 = x1 / tril( A11 ); */ \
for ( k = 0; k < f; ++k ) \
{ \
l = k; \
f_ahead = f - k - 1; \
alpha11 = A11 + (l )*rs_at + (l )*cs_at; \
a21 = A11 + (l+1)*rs_at + (l )*cs_at; \
chi11 = x1 + (l )*incx; \
x21 = x1 + (l+1)*incx; \
\
/* chi11 = chi11 / alpha11; */ \
if ( bli_is_nonunit_diag( diaga ) ) \
{ \
PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \
PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \
} \
\
/* x21 = x21 - chi11 * a21; */ \
PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \
if ( bli_is_conj( conja ) ) \
{ \
for ( j = 0; j < f_ahead; ++j ) \
PASTEMAC(ch,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \
} \
else \
{ \
for ( j = 0; j < f_ahead; ++j ) \
PASTEMAC(ch,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \
} \
} \
\
/* x2 = x2 - A21 * x1; */ \
kfp_af \
( \
conja, \
BLIS_NO_CONJUGATE, \
n_ahead, \
f, \
minus_one, \
A21, rs_at, cs_at, \
x1, incx, \
x2, incx, \
cntx \
); \
} \
} \
}

INSERT_GENTFUNC_BASIC0( trsv_unf_var2 )
Expand Down
3 changes: 1 addition & 2 deletions frame/3/gemm/bli_gemm_front.c
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void bli_gemm_front
return;
}

#if 0

#ifdef BLIS_ENABLE_SMALL_MATRIX
// Only handle small problems separately for homogeneous datatypes.
if ( bli_obj_dt( a ) == bli_obj_dt( b ) &&
Expand All @@ -83,7 +83,6 @@ void bli_gemm_front
err_t status = bli_gemm_small( alpha, a, b, beta, c, cntx, cntl );
if ( status == BLIS_SUCCESS ) return;
}
#endif
#endif

// Alias A, B, and C in case we need to apply transformations.
Expand Down
1 change: 0 additions & 1 deletion frame/compat/bla_gemm.c
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ void PASTEF77(ch,blasname) \
/* Finalize BLIS. */ \
bli_finalize_auto(); \
}

#endif

#ifdef BLIS_ENABLE_BLAS
Expand Down
Loading