TIL: How a Harmless Refactor Exposed a Hidden CUDA Bug in Vision-Language Models
Debugging GPU errors is like detective work: you think you've found the culprit, but it turns out the real suspect has been hiding in the shadows all along.
This week, while maintaining TRL, I chased down one such bug that began as a CI failure and ended as a deep dive into CUDA index mechanics.
Here's what happened and what I learned along the way.
The Symptom: A Mysterious CI Failure
Our continuous integration (CI) started failing only when using the latest dev dependencies, suddenly flooded with cryptic error messages like:
torch.AcceleratorError: CUDA error: device-side assert triggered
You know the kind of error: it doesn't tell you where or why, only that something went catastrophically wrong on the GPU.
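Because CUDA kernel launches are asynchronous, the Python traceback for this kind of error usually points far away from the operation that actually failed. Here is a minimal sketch of a first debugging step, assuming you can rerun the failing test locally: setting CUDA_LAUNCH_BLOCKING forces synchronous launches, so the assert surfaces at the offending call.

import os

# Must be set before the CUDA context is created, i.e. before any CUDA work runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # noqa: E402

# ... run the failing test or generation code here; with blocking launches,
# the traceback now stops at the kernel that triggered the device-side assert.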
After a preliminary investigation, I discovered that the first occurrences were happening with Vision-Language Models (VLMs) such as Gemma3ForConditionalGeneration, during tests like:
tests/test_grpo_trainer.py::TestGRPOTrainer::test_training_vlm_0_trl_internal_testing_tiny_Gemma3ForConditionalGeneration
I opened a tracking issue pointing to the specific failing tests: CI fails with dev dependencies: torch.AcceleratorError: CUDA error: device-side assert triggered
The False Trail
My first hypothesis was that something broke in a recent Transformers PR, since the error only appeared after updating to the latest dev version.
So, I started bisecting and eventually found the exact PR after which the error appeared: #41505: Refactor cache initialization in generate()
At first, it looked like the cause. CI failures began right after it merged. But appearances can be deceiving: the real issue was lurking much deeper.
The Deep Investigation
Determined to find the real culprit, I dove deep into the logs. Line after line of CUDA traces later, I finally spotted a clue: an "index out of bounds" error quietly revealing what was really going on beneath the surface.
The stack trace hinted at a low-level CUDA kernel crash:
Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
I realized something fascinating:
- The PR wasn't breaking anything: it was revealing a latent bug that had always been there.
- The refactor changed cache behavior just enough to expose a silent indexing error that had previously gone unnoticed.
That's the kind of debugging moment where you both curse and cheer at the same time!
Time to find the real culprit.
The Root Cause: A CUDA Index Out of Bounds
After tracing through the VLM masking code, I found it.
In the function handling token-type masking for bidirectional image attention, there were two index tensors:
- kv_idx
- q_idx
A previous PR had correctly added bounds-checking for kv_idx, but overlooked that q_idx also needed the same protection.
#39396: Fix bidirectional image mask for Gemma3
That omission was harmless... until static caches entered the picture.
During generation, the cache dimension can be larger than the actual input sequence (e.g. a static cache for 2048 tokens with an input of only 512). This means some q_idx values can exceed token_type_ids.shape[1].
When the masking code accessed:
token_type_ids[batch_idx, q_idx]
with q_idx beyond the sequence length, CUDA threw a device-side assert, and the GPU crashed.
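Here's a minimal sketch of that failure mode in isolation, using the illustrative shapes from above (512 input tokens, a 2048-position static cache) rather than the actual masking code:

import torch

batch, seq_len, cache_len = 1, 512, 2048
token_type_ids = torch.zeros(batch, seq_len, dtype=torch.long, device="cuda")

# The mask is built over cache positions, so q_idx can run up to cache_len - 1...
q_idx = torch.arange(cache_len, device="cuda")
batch_idx = torch.zeros(cache_len, dtype=torch.long, device="cuda")

# ...but token_type_ids only has seq_len columns, so positions >= 512 are out of bounds.
values = token_type_ids[batch_idx, q_idx]  # launches the gather kernel
torch.cuda.synchronize()                   # the device-side assert surfaces here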
So the real issue wasn't the refactor in #41505: it was a missing safety guard for q_idx.
The Fix: Complete the Bounds Check
I opened a PR to Transformers to finish what #39396 started: #41757: Fix CUDA index out of bounds for q_idx in VLM token type masking
The fix was straightforward but crucial:
# Replace any out-of-range q_idx with a valid placeholder index (0) so the gather
# itself never reads past token_type_ids' sequence dimension...
safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
# ...then return 0 for those out-of-range positions instead of the gathered value.
token_type_ids_at_q_idx = torch.where(
    q_idx < token_type_ids.shape[1],
    token_type_ids[batch_idx, safe_q_idx],
    0,
)
Now both kv_idx and q_idx are safely bounded within valid dimensions, even when cache shapes exceed the input length.
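To see the pattern in isolation, here is a small standalone sketch (illustrative tensor shapes and values, not the actual Transformers code) showing that out-of-range query positions now resolve to 0 instead of crashing:

import torch

token_type_ids = torch.tensor([[1, 0, 1, 0]])  # (batch=1, seq_len=4)
batch_idx = torch.zeros(6, dtype=torch.long)   # mask built for 6 cache positions
q_idx = torch.arange(6)                        # positions 4 and 5 exceed seq_len

safe_q_idx = torch.where(q_idx < token_type_ids.shape[1], q_idx, 0)
token_type_ids_at_q_idx = torch.where(
    q_idx < token_type_ids.shape[1],
    token_type_ids[batch_idx, safe_q_idx],
    0,
)
print(token_type_ids_at_q_idx)  # tensor([1, 0, 1, 0, 0, 0]) -> no out-of-bounds access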
Lessons Learned
- A refactor can reveal, not cause, a bug. The cache logic change in #41505 wasn't the villain: it was the light that made the real bug visible.
- Static cache ≠ input sequence. Always be mindful that caching layers may hold more positions than the current sequence length.
- Device-side asserts are detective clues. CUDA doesn't tell you much, but the message "index out of bounds" almost always means you're indexing a tensor beyond its valid range on the GPU.
- Document your debug journey. Writing detailed issue and PR notes not only helps reviewers but also your future self (and the next maintainer who hits a similar bug at 2 AM).
The Outcome
After merging the fix:
- TRL's CI is getting green again
- Transformers' masking logic is becoming safer and more consistent across:
  - Gemma3ForConditionalGeneration
  - PaliGemmaForConditionalGeneration
  - Example modular transformer templates
The Takeaway
Sometimes, the hardest bugs aren't the ones introduced yesterday: they're the ones that have always been there, silently waiting for the right change to expose them.
In this case, one refactor illuminated a hidden CUDA indexing issue, and one small bounds check made Transformers (and TRL) a little more robust.
References
- TRL issue #4281: CI fails with dev dependencies: torch.AcceleratorError: CUDA error: device-side assert triggered
- Transformers fix PR #41757: Fix CUDA index out of bounds for q_idx in VLM token type masking for Gemma3, PaliGemma, and example modular
Original related PRs:
- #39396: Fix bidirectional image mask for Gemma3
- #41505: Refactor cache initialization in generate()
Special thanks to @joaogante (@gante on GitHub) and @RaushanTurganbay (@zucchini-nlp on GitHub) for their earlier work on these related PRs.