注:在此之前确保环境中已经安装packaging模块
pip install packaging
若此模块安装过程中出现图下报错:这是由于环境创建时候的权限问题导致的,具体解决方案见我另一篇博文!
1.Triton模块安装
注意:必须先安装triton,否则后续mamba_ssm安装会出错!
(1)Triton模块在huggingface上的下载地址:https://hf-mirror.com/madbuda/triton-windows-builds,选择对应python版本下载到本地。
(2)将下载好的文件放到方便的路径下:我习惯放到用户目录下,这样在anaconda prompt中安装比较方便;当然其他路径也可以,安装的时候记得cd到文件对应路径下安装即可。
(3)cd到模块文件路径下,pip install即可(温馨提示:文件名过长,可输入前几个字符后按Tab键自动补全!)
2.casual-conv1d安装
(1)模块下载地址:casual-conv1d
下载解压放到本地路径下:
(2)对文件中的setup.py文件做部分修改:
FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"
修改为:
FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "FALSE"
(3)cd到文件路径下:输入pip install . 进行安装
3.mamba_ssm安装
务必保证triton模块成功安装后在安装mamba_ssm,否则会出现如下错误!
(1)下载地址:mamba_ssm
同样,下载好后解压放到本地路径
(2)修改setup.py部分代码:
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
修改为:
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "FALSE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "FALSE"
再修改mamba_ssm/ops/selective_scan_interface.py 两处代码:
第一处:
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,return_last_state=False):"""if return_last_state is True, returns (out, last_state)last_state has shape (batch, dim, dstate). Note that the gradient of the last state isnot considered in the backward pass."""return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
修改为:
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,return_last_state=False):"""if return_last_state is True, returns (out, last_state)last_state has shape (batch, dim, dstate). Note that the gradient of the last state isnot considered in the backward pass."""return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
第二处:
def mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,C_proj_bias=None, delta_softplus=True
):return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
修改为:
def mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,C_proj_bias=None, delta_softplus=True
):return mamba_inner_ref(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,out_proj_weight, out_proj_bias,A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus)
(3)文件准备就绪,cd到文件所在路径下,执行如下代码安装。
pip install .
(4)安装成功:
温馨提示:所有模块的安装,请先activate对应的pytorch环境!