vcpkg/ports/libtorch/fix-aten-cutlass.patch

33 lines
1.3 KiB
Diff
Raw Normal View History

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 80bbfc6b..eead313d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -560,6 +560,8 @@ if(MSVC)
string(APPEND CMAKE_CXX_FLAGS " /FS")
string(APPEND CMAKE_CUDA_FLAGS " -Xcompiler /FS")
+ # CUTLASS_CONSTEXPR_IF_CXX17 must be constexpr. Forward /Zc:__cplusplus to the MSVC host compiler so __cplusplus reports the real standard version
+ string(APPEND CMAKE_CUDA_FLAGS " -Xcompiler /Zc:__cplusplus")
endif(MSVC)
string(APPEND CMAKE_CUDA_FLAGS " -Xfatbin -compress-all")
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index 93021a74..0755c890 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -437,7 +437,14 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
endif()
if(USE_CUDA AND NOT USE_ROCM)
- list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
+ find_package(NvidiaCutlass CONFIG REQUIRED)
+ get_target_property(CUTLASS_INCLUDE_DIRS nvidia::cutlass::cutlass INTERFACE_INCLUDE_DIRECTORIES)
+ list(APPEND ATen_CUDA_INCLUDE ${CUTLASS_INCLUDE_DIRS})
+ if(MSVC)
+ # CUTLASS_CONSTEXPR_IF_CXX17 must be constexpr. Correct the __cplusplus value with MSVC
+ # Single-token -Xcompiler=<flag> form: no literal quotes or embedded space reach nvcc, and option de-duplication cannot split the pair
+ add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=/Zc:__cplusplus>")
+ endif()
if($ENV{ATEN_STATIC_CUDA})
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_LIBRARIES}