Candle Flash Attention v3 Layer

Flash Attention v3 layer for NVIDIA Hopper GPUs (sm90a architecture), integrated with the candle framework.

Usage

use candle_flash_attn_v3;
use anyhow::Result;
use candle::{DType, Device, IndexOp, Tensor, D};

fn flash_attn_acausal() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let q = Tensor::arange(0u32, 3 * 2 * 64, &device)?
        .to_dtype(DType::F16)?
        .reshape((1, 3, 2, 64))?; // batch, head, seqlen, hidden_dim
    let k = (&q / 400.)?;
    let v = (&q / 500.)?;
    let q = (&q / 300.)?;

    // flash_attn expects (batch, seqlen, head, hidden_dim), so transpose in and back out
    let att = {
        let q = q.transpose(1, 2)?;
        let k = k.transpose(1, 2)?;
        let v = v.transpose(1, 2)?;
        candle_flash_attn_v3::flash_attn(&q, &k, &v, 0.5, false, false)?.transpose(1, 2)?
    };
    assert_eq!(att.dims(), &[1, 3, 2, 64]); // same (batch, head, seqlen, hidden_dim) shape as q
    Ok(())
}

Integration instructions

Add the crate to your Cargo.toml:

[dependencies]
candle = { version = "*", package = "candle-core", default-features = false }
candle-nn = { version = "*" }
candle-transformers = { version = "*" }
candle-flash-attn-v3 = { git = "https://github.com/michaelfeil/candle-flash-attn-v3", rev = "main", optional = true }
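
Since the dependency is declared optional, downstream crates typically expose it behind a Cargo feature (a [features] entry with a dep: mapping) and fall back to plain candle ops when it is disabled. Below is a minimal sketch of such a wrapper, not part of this crate: the feature name flash-attn-v3 is an assumption, the f32 softmax scale follows candle-flash-attn, and the two boolean flags simply mirror the usage example above.

use candle::Tensor;

// Hypothetical downstream wrapper: only call into the flash kernel when the
// (assumed) `flash-attn-v3` Cargo feature is enabled.
#[cfg(feature = "flash-attn-v3")]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> candle::Result<Tensor> {
    // q, k, v in (batch, seqlen, head, hidden_dim); both flags false,
    // as in the non-causal usage example above
    candle_flash_attn_v3::flash_attn(q, k, v, scale, false, false)
}

#[cfg(not(feature = "flash-attn-v3"))]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> candle::Result<Tensor> {
    // Fallback when the feature is off; plug in a plain-ops attention
    // (e.g. softmax(q k^T * scale) v computed with candle) here.
    let _ = (q, k, v, scale);
    candle::bail!("flash-attn-v3 feature not enabled")
}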

Build and test locally:

git submodule update --init --recursive
cargo test --release

License

Licensed under Apache-2.0 (LICENSE-APACHE) and MIT (LICENSE-MIT).
