Sad-Comedian-711

Sad-Comedian-711 t1_jcqgv1x wrote

This approach has been shown to work. The Longformer repo even provides a notebook that does the conversion for you: https://github.com/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb
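The core trick in that notebook is to stretch the model's learned position-embedding table by tiling the original entries until the new context length is covered. A rough NumPy sketch of that idea (function name and shapes are illustrative, not the notebook's actual code):

```python
import numpy as np

def extend_position_embeddings(pos_emb, new_max_len):
    """Tile an existing learned position-embedding table so it covers a
    longer context (the rough idea behind Longformer's conversion
    notebook; this is a sketch, not the notebook's exact code)."""
    old_max_len, dim = pos_emb.shape
    reps = -(-new_max_len // old_max_len)  # ceil division
    return np.tile(pos_emb, (reps, 1))[:new_max_len]

# e.g. stretch a 512-position table to 4096 positions
old = np.random.randn(512, 768)
new = extend_position_embeddings(old, 4096)
```

Copying (rather than random-initializing) the new positions is what lets the converted model keep working with only a little additional pretraining.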

I think for flash attention you do not want to use Longformer's attention though; you want to use Big Bird's, with specific block sizes or something like that.

1

Sad-Comedian-711 t1_jcpvahm wrote

So there is flash attention and then there is block sparse flash attention.

Flash attention by itself only got them to 16k on an A100 for their model; to go further they needed windowed attention... You could already have gone to 16k with windowed attention before this paper without much issue.
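For reference, windowed (sliding-window) attention just restricts each token to a fixed-size neighborhood, so memory grows linearly in sequence length instead of quadratically. A generic sketch of the mask (not Longformer's exact implementation):

```python
import numpy as np

def sliding_window_mask(seq_len, window):
    """Boolean mask where token i may attend only to tokens within
    `window` positions on either side. Each row has at most
    2*window + 1 True entries, so cost is O(n * window), not O(n^2).
    Generic windowed-attention pattern, for illustration only."""
    idx = np.arange(seq_len)
    return np.abs(idx[:, None] - idx[None, :]) <= window

mask = sliding_window_mask(8, 2)
```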

The special thing about this windowed attention is that it is computed in blocks that fit into SRAM. From what I can tell, PyTorch's implementation of Flash Attention doesn't look like it supports block sparse flash attention:
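The reason the block structure matters: flash attention streams K/V through SRAM one tile at a time and keeps a running ("online") softmax, so it never materializes the full attention matrix. A pure-NumPy sketch of that accumulation (for exposition only; the real kernels fuse this on the GPU):

```python
import numpy as np

def naive_attention(q, k, v):
    """Standard attention: materializes the full n x n score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=4):
    """Flash-attention-style pass: visit K/V in SRAM-sized blocks,
    maintaining a running max and running softmax denominator so the
    result matches naive attention without the full score matrix.
    Illustrative sketch, not the actual CUDA/Triton kernel."""
    n, d = q.shape
    out = np.zeros_like(q)
    row_max = np.full((n, 1), -np.inf)  # running max per query row
    row_sum = np.zeros((n, 1))          # running softmax denominator
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T / np.sqrt(d)
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        p = np.exp(s - new_max)
        scale = np.exp(row_max - new_max)  # rescale old partial sums
        row_sum = row_sum * scale + p.sum(axis=-1, keepdims=True)
        out = out * scale + p @ vb
        row_max = new_max
    return out / row_sum
```

Block-sparse flash attention is then just this loop with some K/V blocks skipped entirely, which is why the sparsity pattern has to be expressed in the same block units.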

https://github.com/pytorch/pytorch/blob/eb32bb2ca6811ea21002699f4be884d3012dc362/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h

While Triton's looks like it does: https://github.com/openai/triton/blob/c9740f0870f6ae2480acd2a76a5fb4c920bc5ce5/python/triton/ops/flash_attention.py

I think windowing must be done in blocks that align with the SRAM grid, so it kinda has to be part of the Flash Attention implementation. You might be able to throw normal Big Bird block sparse attention on top...
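To make that concrete, a Big-Bird-style layout is defined over whole blocks rather than individual tokens, so every unit of attention work maps onto a tile. A sketch of such a block layout (illustrative only; real Big Bird also adds random blocks, omitted here):

```python
import numpy as np

def block_sparse_layout(n_blocks, window_blocks=1, global_blocks=1):
    """Big-Bird-style layout expressed in *blocks*: each query block
    attends to nearby blocks (window) plus a few global blocks that
    attend to and are attended by everything. Because the pattern is
    in whole-block units, each True entry is one SRAM-sized tile of
    work. Sketch only, not the Big Bird reference implementation."""
    idx = np.arange(n_blocks)
    layout = np.abs(idx[:, None] - idx[None, :]) <= window_blocks
    layout[:, :global_blocks] = True  # all blocks attend to global blocks
    layout[:global_blocks, :] = True  # global blocks attend everywhere
    return layout

layout = block_sparse_layout(6)
```

With a block size of, say, 128 tokens, entry `layout[i, j]` says whether the kernel should compute the 128x128 score tile between query block i and key block j, or skip it.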

You also may be able to call out to triton's implementation:
https://github.com/violethaze74/pumpkin-py/blob/d9250933bec045e6add61b3930ff3dbbe08f6501/aten/src/ATen/native/transformers/attention.cpp#L726

3