diff --git a/CMakeLists.txt b/CMakeLists.txt
index 80bbfc6b..eead313d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -560,6 +560,8 @@ if(MSVC)
 
   string(APPEND CMAKE_CXX_FLAGS " /FS")
   string(APPEND CMAKE_CUDA_FLAGS " -Xcompiler /FS")
+  # CUTLASS_CONSTEXPR_IF_CXX17 must be constexpr. Correct the __cplusplus value
+  string(APPEND CMAKE_CUDA_FLAGS " -Xcompiler /Zc:__cplusplus")
 endif(MSVC)
 
 string(APPEND CMAKE_CUDA_FLAGS " -Xfatbin -compress-all")
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index 93021a74..0755c890 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -437,7 +437,14 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
 endif()
 
 if(USE_CUDA AND NOT USE_ROCM)
-  list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
+  find_package(NvidiaCutlass CONFIG REQUIRED)
+  get_target_property(CUTLASS_INCLUDE_DIRS nvidia::cutlass::cutlass INTERFACE_INCLUDE_DIRECTORIES)
+  list(APPEND ATen_CUDA_INCLUDE ${CUTLASS_INCLUDE_DIRS})
+  if(MSVC)
+    # CUTLASS_CONSTEXPR_IF_CXX17 must be constexpr. Correct the __cplusplus value with MSVC.
+    # Only CUDA sources get the flag: -Xcompiler is an nvcc option that forwards it to the host compiler.
+    add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=/Zc:__cplusplus>")
+  endif()
   if($ENV{ATEN_STATIC_CUDA})
     list(APPEND ATen_CUDA_DEPENDENCY_LIBS
       ${CUDA_LIBRARIES}