Mamba2-hf
Collection
HF-compatible format of state-spaces/mamba2
5 items
Use the code below to get started with the model.
import torch
from transformers import AutoTokenizer, Mamba2ForCausalLM

if __name__ == "__main__":
    device = "cuda"
    model_id = "benchang1110/mamba2-370m-hf"

    # Load the tokenizer and the HF-format Mamba2 model in bfloat16.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = Mamba2ForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map=device
    )
    model.eval()

    with torch.no_grad():
        text = input("Input: ")
        input_ids = tokenizer(text, return_tensors="pt").to(device)
        # Greedy decoding; set do_sample=True for sampled output.
        output = model.generate(**input_ids, max_new_tokens=1024, do_sample=False)
        print(tokenizer.decode(output[0], skip_special_tokens=True))
Conversion script: mamba2hf.py
Base model: state-spaces/mamba2-370m