## UNUSED BY KERNEL-BUILDER
## File is retained for reference, but is not currently used in the build process.
# import os
# from pathlib import Path
# from datetime import datetime
# import subprocess
# from setuptools import setup, find_packages
# from torch.utils.cpp_extension import (
#     BuildExtension,
#     CUDAExtension,
#     IS_WINDOWS,
#     CUDA_HOME
# )
# def is_flag_set(flag: str) -> bool:
#     return os.getenv(flag, "FALSE").lower() in ["true", "1", "y", "yes"]
# def get_features_args():
#     features_args = []
#     if is_flag_set("FLASH_MLA_DISABLE_FP16"):
#         features_args.append("-DFLASH_MLA_DISABLE_FP16")
#     return features_args
# def get_arch_flags():
#     # Check NVCC Version
#     # NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py`
#     assert CUDA_HOME is not None, "PyTorch must be compiled with CUDA support"
#     nvcc_version = subprocess.check_output(
#         [os.path.join(CUDA_HOME, "bin", "nvcc"), '--version'], stderr=subprocess.STDOUT
#     ).decode('utf-8')
#     nvcc_version_number = nvcc_version.split('release ')[1].split(',')[0].strip()
#     major, minor = map(int, nvcc_version_number.split('.'))
#     print(f'Compiling using NVCC {major}.{minor}')
#     DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
#     DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90")
#     if major < 12 or (major == 12 and minor <= 8):
#         assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment."  # TODO Implement this
#     arch_flags = []
#     if not DISABLE_SM100:
#         arch_flags.extend(["-gencode", "arch=compute_100f,code=sm_100f"])
#     if not DISABLE_SM90:
#         arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"])
#     return arch_flags
# def get_nvcc_thread_args():
#     nvcc_threads = os.getenv("NVCC_THREADS") or "32"
#     return ["--threads", nvcc_threads]
# subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
# this_dir = os.path.dirname(os.path.abspath(__file__))
# if IS_WINDOWS:
#     cxx_args = ["/O2", "/std:c++20", "/DNDEBUG", "/W0"]
# else:
#     cxx_args = ["-O3", "-std=c++20", "-DNDEBUG", "-Wno-deprecated-declarations"]
# ext_modules = []
# ext_modules.append(
#     CUDAExtension(
#         name="flash_mla.cuda",
#         sources=[
#             # API
#             "csrc/api/api.cpp",
#             # Misc kernels for decoding
#             "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
#             "csrc/smxx/decode/combine/combine.cu",
#             # sm90 dense decode
#             "csrc/sm90/decode/dense/instantiations/fp16.cu",
#             "csrc/sm90/decode/dense/instantiations/bf16.cu",
#             # sm90 sparse decode
#             "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
#             "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
#             "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
#             "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
#             # sm90 sparse prefill
#             "csrc/sm90/prefill/sparse/fwd.cu",
#             "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu",
#             "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu",
#             "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu",
#             "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu",
#             # sm100 dense prefill & backward
#             "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
#             "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
#             # sm100 sparse prefill
#             "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu",
#             "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu",
#             "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu",
#             "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu",
#             "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu",
#             # sm100 sparse decode
#             "csrc/sm100/decode/head64/instantiations/v32.cu",
#             "csrc/sm100/decode/head64/instantiations/model1.cu",
#             "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu",
#         ],
#         extra_compile_args={
#             "cxx": cxx_args + get_features_args(),
#             "nvcc": [
#                 "-O3",
#                 "-std=c++20",
#                 "-DNDEBUG",
#                 "-D_USE_MATH_DEFINES",
#                 "-Wno-deprecated-declarations",
#                 "-U__CUDA_NO_HALF_OPERATORS__",
#                 "-U__CUDA_NO_HALF_CONVERSIONS__",
#                 "-U__CUDA_NO_HALF2_OPERATORS__",
#                 "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
#                 "--expt-relaxed-constexpr",
#                 "--expt-extended-lambda",
#                 "--use_fast_math",
#                 "--ptxas-options=-v,--register-usage-level=10,--warn-on-spills,--warn-on-local-memory-usage,--warn-on-double-precision-use",
#                 "-lineinfo",
#                 "--source-in-ptx",
#             ] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(),
#         },
#         include_dirs=[
#             Path(this_dir) / "csrc",
#             Path(this_dir) / "csrc" / "kerutils" / "include",  # TODO Remove me
#             Path(this_dir) / "csrc" / "sm90",
#             Path(this_dir) / "csrc" / "cutlass" / "include",
#             Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include",
#         ],
#     )
# )
# try:
#     cmd = ['git', 'rev-parse', '--short', 'HEAD']
#     rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
# except Exception as _:
#     now = datetime.now()
#     date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S")
#     rev = '+' + date_time_str
# setup(
#     name="flash_mla",
#     version="1.0.0" + rev,
#     packages=find_packages(include=['flash_mla']),
#     ext_modules=ext_modules,
#     cmdclass={"build_ext": BuildExtension},
# )