環境:WSL的Ubuntu24.04
1.創建conda環境,其中python版本為3.10.13
2.當前conda環境依次執行下面命令:
conda install cudatoolkit==11.8 -c nvidia
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 -f https://mirrors.aliyun.com/pytorch-wheels/cu118
conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
conda install packaging
3.安裝causal_conv1d和mamba_ssm
3.1嘗試使用pip線上安裝causal_conv1d:
pip install causal-conv1d==1.5.2
失敗。
3.2嘗試本地安裝:
在
https://github.com/Dao-AILab/causal-conv1d/releases/tag/v1.5.0
下載
causal_conv1d-1.5.0+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
在
https://github.com/state-spaces/mamba/releases/tag/v2.0.3
下載
mamba_ssm-2.0.3+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
注意這兩個whl文件名中的“cu”后的數字(即cuda版本)要與上一步安裝的cuda版本相同,“torch”后面的數字要與上一步安裝的torch版本相同,“cp”后面的數字要與你當前環境的python版本相同。
在這兩個whl文件所在目錄下分別執行:
pip install 文件名
根據打印信息確定是否安裝成功。
python導入:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
但打印如下信息:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.6 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
...
/miniconda3/envs/mapy31013/lib/python3.10/site-packages/torch/nn/modules/transformer.py:20: UserWarning: Failed to initialize NumPy: _ARRAY_API not found (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
device: torch.device = torch.device(torch._C._get_default_device()), ?# torch.device('cpu'),
就是說當前虛擬環境的NumPy版本為numpy2+,需要使用numpy<2版本才能導入causal_conv1d。
于是先
pip uninstall numpy
在
https://pypi.org/project/numpy/
網頁的左側點擊“Release history”找到合適的numpy版本(我選的是1.26.4),再執行
pip install numpy==1.26.4
順利安裝numpy1.26.4后再次執行
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
不報錯則安裝成功。
3.3python導入:
from mamba_ssm import Mamba, Mamba2
不報錯則mamba安裝成功。