Flash Attention v2 doesn't work
No matter what I try, running with Flash Attention 2 does not work for me. I get this error:
raise ValueError(
ValueError: GraniteSpeechForConditionalGeneration does not support Flash Attention 2.0 yet. Please request to add support where the model is hosted, on its model hub page: https://huggingface.co/ibm-granite/granite-speech-3.3-8b/discussions/new or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new
Steps to reproduce:
Install flash attention v2
Load the model via AutoModelForSpeechSeq2Seq.from_pretrained with attn_implementation="flash_attention_2"
Here is my basic code snippet:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

MODEL_ID = "ibm-granite/granite-speech-3.3-8b"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Load processor (feature extractor + tokenizer)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Load model with fast attention enabled (FlashAttention 2)
# Note: requires a recent transformers version that supports this flag
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map=device,
    attn_implementation="flash_attention_2",  # key part for fast attention
)

# Build ASR pipeline
asr_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=dtype,
    device=0 if device == "cuda" else -1,
)
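For completeness, this is how I would then call the pipeline (the audio path below is just a placeholder; the error is raised earlier, during from_pretrained, so this part is never reached):

# Transcribe a local audio file ("sample.wav" is a placeholder path)
result = asr_pipeline("sample.wav")
print(result["text"])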
I have the following versions installed:
transformers 4.57.3
torch 2.9.1
flash_attn 2.8.3
The post https://huggingface.co/ibm-granite/granite-speech-3.3-8b/discussions/11 implies that flash attention should be supported and working. If anyone could give me pointers on whether it is simply not supported, or whether there is a way to make it work, please let me know.
Hi @abcdefghijklmnop52 - Thank you for opening this!
Internally, this model uses a Conformer and a Blip2QFormer to process the audio (i.e., prefill), and a Granite LLM to process the embedded result. The main issue is that the audio-encoding components do not support Flash Attention 2 at the moment, which I think is why it has been disabled on the model class.
With that being said, the underlying LLM that processes the multimodal embeddings and does the decoding can still use Flash Attention 2 or other attention implementations! To try this, instead of passing attn_implementation="flash_attention_2", pass attn_implementation={"text_config": "flash_attention_2"}. These docs in transformers may also help.
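For example, the load call from your snippet would become something like this (same MODEL_ID, dtype, and device as in your code; only the attn_implementation argument changes):

# Enable Flash Attention 2 only for the text (LLM) sub-config;
# the audio encoder keeps its default attention implementation.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    device_map=device,
    attn_implementation={"text_config": "flash_attention_2"},
)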