Use Apple Accelerate framework for training and best models

Signed-off-by: Stefan Weil <sw@weilnetz.de>
This commit is contained in:
Stefan Weil 2021-02-28 12:04:17 +01:00
parent 3be11f12a9
commit 3ab8dcbf72
2 changed files with 43 additions and 5 deletions

View File

@ -297,13 +297,14 @@ OPENCL_CPPFLAGS=''
OPENCL_LDFLAGS=''
case "${host_os}" in
*darwin* | *-macos10*)
echo "checking for OpenCL framework"
MY_CHECK_FRAMEWORK([OpenCL])
if test $my_cv_framework_OpenCL = yes; then
have_opencl_lib=true
MY_CHECK_FRAMEWORK([Accelerate])
if test $my_cv_framework_Accelerate = yes; then
AM_CPPFLAGS="-DHAVE_FRAMEWORK_ACCELERATE $AM_CPPFLAGS"
LDFLAGS="$LDFLAGS -framework Accelerate"
fi
MY_CHECK_FRAMEWORK([OpenCL])
if test "$enable_opencl" = "yes"; then
if !($have_opencl_lib); then
if test $my_cv_framework_OpenCL = no; then
AC_MSG_ERROR([Required OpenCL library not found!])
fi
AM_CPPFLAGS="-DUSE_OPENCL $AM_CPPFLAGS"

View File

@ -25,6 +25,23 @@
#include "simddetect.h"
#include "tprintf.h" // for tprintf
#if defined(HAVE_FRAMEWORK_ACCELERATE)
// Use Apple Accelerate framework.
// https://developer.apple.com/documentation/accelerate/simd
// Comparison of execution time with different dot product implementations.
// time DOTPRODUCT=accelerate lstm_squashed_test
// Results for Apple M1:
// DotProductGeneric 64 s
// DotProduct 60 s
// DotProductAccelerate 33 s
// DotProductNative 30 s
#include <Accelerate/Accelerate.h>
#endif
#if defined(HAVE_AVX) || defined(HAVE_AVX2) || defined(HAVE_FMA) || defined(HAVE_SSE4_1)
# define HAS_CPUID
#endif
@ -83,6 +100,15 @@ bool SIMDDetect::fma_available_;
bool SIMDDetect::sse_available_;
#endif
#if defined(HAVE_FRAMEWORK_ACCELERATE)
static double DotProductAccelerate(const double* u, const double* v, int n) {
double total = 0.0;
const int stride = 1;
vDSP_dotprD(u, stride, v, stride, &total, n);
return total;
}
#endif
// Computes and returns the dot product of the two n-vectors u and v.
static double DotProductGeneric(const double *u, const double *v, int n) {
double total = 0.0;
@ -110,6 +136,17 @@ static void SetDotProduct(DotProductFunction f, const IntSimdMatrix *m = nullptr
SIMDDetect::SIMDDetect() {
// The fallback is a generic dot product calculation.
SetDotProduct(DotProductGeneric);
const char* dotproduct_env = getenv("DOTPRODUCT");
if (dotproduct_env != nullptr) {
if (strcmp(dotproduct_env, "native") == 0) {
SetDotProduct(DotProductNative);
#if defined(HAVE_FRAMEWORK_ACCELERATE)
} else if (strcmp(dotproduct_env, "accelerate") == 0) {
SetDotProduct(DotProductAccelerate);
}
#endif
return;
}
#if defined(HAS_CPUID)
# if defined(__GNUC__)