vllm.v1.attention.backends.rocm_aiter_fa

Attention layer with AiterFlashAttention.

_CP_TOKENS_PER_ITER_ROCM module-attribute

_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
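
This constant caps how many context (KV) tokens are gathered into the chunked-prefill workspace per iteration; AiterFlashAttentionMetadataBuilder.build divides it evenly across the extend requests in a batch. A minimal sketch of that arithmetic, assuming a batch of 4 extend requests and a longest cached context of 20,000 tokens (both numbers illustrative):

_CP_TOKENS_PER_ITER_ROCM = 32 * 1024

num_extends = 4                                               # assumed batch of extend requests
max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends   # 8192 tokens per request per chunk
longest_context = 20_000                                      # assumed cached KV length
num_chunks = -(-longest_context // max_context_chunk)         # ceil division -> 3 iterations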

_PARTITION_SIZE_ROCM module-attribute

_PARTITION_SIZE_ROCM = 256
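
The decode path splits each sequence's KV cache into partitions of this size for paged_attention_v1 and sizes its scratch buffer accordingly (see AiterFlashAttentionImpl.forward below). A minimal sketch of that workspace arithmetic, with illustrative batch dimensions:

_PARTITION_SIZE_ROCM = 256

# Assumed decode batch; all sizes are illustrative.
num_seqs, num_heads, head_size, max_seq_len = 8, 32, 128, 4096
nbytes_per_qo_elem = 2  # fp16 / bf16 query dtype

max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM
workspace_bytes = (
    num_seqs * num_heads * max_num_partitions * head_size * nbytes_per_qo_elem
    + 2 * num_seqs * num_heads * max_num_partitions * 4
)  # matches the torch.uint8 workspace_buffer allocated in forward()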

logger module-attribute

logger = init_logger(__name__)

AiterChunkContextMetadata dataclass

Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@dataclass
class AiterChunkContextMetadata:
    workspace: torch.Tensor
    cu_seq_lens_chunk: torch.Tensor
    chunk_starts: torch.Tensor
    token_to_batch: torch.Tensor
    seq_tot: list[int]
    max_seq_lens: list[int]
    seq_lens: torch.Tensor
    num_chunks: int
    total_token_per_batch: list[int]

chunk_starts instance-attribute

chunk_starts: Tensor

cu_seq_lens_chunk instance-attribute

cu_seq_lens_chunk: Tensor

max_seq_lens instance-attribute

max_seq_lens: list[int]

num_chunks instance-attribute

num_chunks: int

seq_lens instance-attribute

seq_lens: Tensor

seq_tot instance-attribute

seq_tot: list[int]

token_to_batch instance-attribute

token_to_batch: Tensor

total_token_per_batch instance-attribute

total_token_per_batch: list[int]

workspace instance-attribute

workspace: Tensor

__init__

__init__(
    workspace: Tensor,
    cu_seq_lens_chunk: Tensor,
    chunk_starts: Tensor,
    token_to_batch: Tensor,
    seq_tot: list[int],
    max_seq_lens: list[int],
    seq_lens: Tensor,
    num_chunks: int,
    total_token_per_batch: list[int],
) -> None

AiterFlashAttentionBackend

Bases: AttentionBackend

Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
class AiterFlashAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [64, 128, 256]

    @staticmethod
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
        return [MultipleOf(16)]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes."
            )

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
        return AiterFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return AiterFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
        return AiterFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")

        return (2, num_blocks, block_size, num_kv_heads, head_size)

accept_output_buffer class-attribute instance-attribute

accept_output_buffer: bool = True

get_builder_cls staticmethod

get_builder_cls() -> type[
    AiterFlashAttentionMetadataBuilder
]
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@staticmethod
def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
    return AiterFlashAttentionMetadataBuilder

get_impl_cls staticmethod

get_impl_cls() -> type[AiterFlashAttentionImpl]
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@staticmethod
def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
    return AiterFlashAttentionImpl

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    if block_size % 16 != 0:
        raise ValueError("Block size must be a multiple of 16.")

    return (2, num_blocks, block_size, num_kv_heads, head_size)
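
For example, with illustrative sizes the returned shape is simply the tuple documented above:

shape = AiterFlashAttentionBackend.get_kv_cache_shape(
    num_blocks=1024, block_size=16, num_kv_heads=8, head_size=128
)
assert shape == (2, 1024, 16, 8, 128)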

get_metadata_cls staticmethod

get_metadata_cls() -> type[AttentionMetadata]
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
    return AiterFlashAttentionMetadata

get_name staticmethod

get_name() -> str
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@staticmethod
def get_name() -> str:
    return "FLASH_ATTN"

get_supported_dtypes classmethod

get_supported_dtypes() -> list[dtype]
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
    return [torch.float16, torch.bfloat16]

get_supported_head_sizes classmethod

get_supported_head_sizes() -> list[int]
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    return [64, 128, 256]

get_supported_kernel_block_size staticmethod

get_supported_kernel_block_size() -> list[int | MultipleOf]
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
    return [MultipleOf(16)]

validate_head_size classmethod

validate_head_size(head_size: int) -> None
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@classmethod
def validate_head_size(cls, head_size: int) -> None:
    supported_head_sizes = cls.get_supported_head_sizes()
    if head_size not in supported_head_sizes:
        attn_type = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {supported_head_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes."
        )
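
Illustrative usage; a head size outside [64, 128, 256] raises a ValueError that points at the FLEX_ATTENTION fallback:

AiterFlashAttentionBackend.validate_head_size(128)  # supported, returns None

try:
    AiterFlashAttentionBackend.validate_head_size(96)  # unsupported
except ValueError as err:
    print(err)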

AiterFlashAttentionChunkPrefillMetadata dataclass

Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@dataclass
class AiterFlashAttentionChunkPrefillMetadata:
    max_query_len: int
    min_query_len: int
    max_seq_len: int
    query_start_loc: torch.Tensor
    chunk_context_metadata: AiterChunkContextMetadata

chunk_context_metadata instance-attribute

chunk_context_metadata: AiterChunkContextMetadata

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

min_query_len instance-attribute

min_query_len: int

query_start_loc instance-attribute

query_start_loc: Tensor

__init__

__init__(
    max_query_len: int,
    min_query_len: int,
    max_seq_len: int,
    query_start_loc: Tensor,
    chunk_context_metadata: AiterChunkContextMetadata,
) -> None

AiterFlashAttentionDecodeMetadata dataclass

Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@dataclass
class AiterFlashAttentionDecodeMetadata:
    max_query_len: int
    min_query_len: int
    max_seq_len: int
    query_start_loc: torch.Tensor

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

min_query_len instance-attribute

min_query_len: int

query_start_loc instance-attribute

query_start_loc: Tensor

__init__

__init__(
    max_query_len: int,
    min_query_len: int,
    max_seq_len: int,
    query_start_loc: Tensor,
) -> None

AiterFlashAttentionImpl

Bases: AttentionImpl

Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
class AiterFlashAttentionImpl(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: int | None = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = [-1, -1]
        else:
            self.sliding_window = [sliding_window - 1, 0]
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0.0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        AiterFlashAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashAttentionImpl"
            )

    def extend_forward(
        self,
        attn_metadata: AiterFlashAttentionMetadata,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        output: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        min_seqlen_q: int,
        block_table: torch.Tensor,
        slot_mapping: torch.Tensor,
        k_scale: float,
        v_scale: float,
    ):
        out, lse = aiter.flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_q,
            min_seqlen_q=min_seqlen_q,
            dropout_p=0.0,
            softmax_scale=self.scale,
            causal=True,
            window_size=self.sliding_window,
            alibi_slopes=self.alibi_slopes,
            return_lse=True,
        )
        assert attn_metadata.extend_metadata is not None
        chunk_context_metadata = attn_metadata.extend_metadata.chunk_context_metadata
        num_chunks = chunk_context_metadata.num_chunks
        workspace = chunk_context_metadata.workspace
        cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk
        max_seqlens = chunk_context_metadata.max_seq_lens
        chunk_starts = chunk_context_metadata.chunk_starts
        token_to_batch = chunk_context_metadata.token_to_batch
        total_token_per_batch = chunk_context_metadata.total_token_per_batch
        key_fetched, value_fetched = workspace[0], workspace[1]
        chunked_output = None
        chunked_lse = None
        for chunk_idx in range(num_chunks):
            cp_mha_gather_cache(
                key_cache=key_cache,
                value_cache=value_cache,
                key=key_fetched,
                value=value_fetched,
                block_tables=block_table,
                k_scales=k_scale,
                v_scales=v_scale,
                cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
                token_to_batch=token_to_batch[chunk_idx],
                seq_starts=chunk_starts[chunk_idx],
                dequant=False,
                kv_cache_layout="NHD",
                total_tokens=total_token_per_batch[chunk_idx],
            )

            suf_out, suf_lse = aiter.flash_attn_varlen_func(
                q=query,
                k=key_fetched,
                v=value_fetched,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_kv[chunk_idx],
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlens[chunk_idx],
                min_seqlen_q=min_seqlen_q,
                dropout_p=0.0,
                softmax_scale=self.scale,
                causal=False,
                window_size=self.sliding_window,
                alibi_slopes=self.alibi_slopes,
                return_lse=True,
            )
            if chunked_output is None:
                chunked_output = suf_out
                chunked_lse = suf_lse
            else:
                tmp_output = torch.empty_like(out)
                tmp_lse = torch.empty_like(lse)
                merge_attn_states(
                    output=tmp_output,
                    output_lse=tmp_lse,
                    prefix_output=chunked_output,
                    prefix_lse=chunked_lse,
                    suffix_output=suf_out,
                    suffix_lse=suf_lse,
                )
                chunked_output = tmp_output
                chunked_lse = tmp_lse

        merge_attn_states(
            output=output,
            prefix_output=chunked_output,
            prefix_lse=chunked_lse,
            suffix_output=out,
            suffix_lse=lse,
        )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AiterFlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with AiterFlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlashAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is
        # executed in eager-mode PyTorch. Thus, we need to be careful
        # about any CPU overhead in this method. For example, `view`
        # and `slice` (or `[:n]`) operations are surprisingly slow even
        # in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.
        num_actual_tokens = attn_metadata.num_actual_tokens
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping
            # is not padded. However, we don't need to do
            # key[:num_actual_tokens] and value[:num_actual_tokens] because
            # the reshape_and_cache_flash op uses the slot_mapping's shape
            # to determine the number of actual tokens.

            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(current_platform.fp8_dtype())
            value_cache = value_cache.view(current_platform.fp8_dtype())

        # decode:extend:prefill
        query = query[:num_actual_tokens]
        key = key[:num_actual_tokens]
        value = value[:num_actual_tokens]

        output_actual_tokens = output[:num_actual_tokens]

        num_decodes = attn_metadata.num_decodes
        num_prefills = attn_metadata.num_prefills
        num_extends = attn_metadata.num_extends

        num_decode_tokens = attn_metadata.num_decode_tokens
        num_extend_tokens = attn_metadata.num_extend_tokens
        if not attn_metadata.use_cascade:
            # calculate for pure prefills
            if num_prefills > 0:
                assert attn_metadata.prefill_metadata is not None

                prefill_query = query[num_decode_tokens + num_extend_tokens :]
                prefill_key = key[num_decode_tokens + num_extend_tokens :]
                prefill_value = value[num_decode_tokens + num_extend_tokens :]

                aiter.flash_attn_varlen_func(
                    q=prefill_query,
                    k=prefill_key,
                    v=prefill_value,
                    cu_seqlens_q=attn_metadata.prefill_metadata.query_start_loc,
                    cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc,
                    max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
                    max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
                    min_seqlen_q=attn_metadata.prefill_metadata.min_query_len,
                    dropout_p=0.0,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                    out=output_actual_tokens[num_decode_tokens + num_extend_tokens :],
                )

            # calculate for extends
            if num_extends > 0:
                assert attn_metadata.extend_metadata is not None
                extend_tokens_slice = slice(
                    num_decode_tokens, num_decode_tokens + num_extend_tokens
                )
                extend_querys = query[extend_tokens_slice]
                extend_keys = key[extend_tokens_slice]
                extend_values = value[extend_tokens_slice]
                extend_outputs = output[extend_tokens_slice]
                self.extend_forward(
                    attn_metadata=attn_metadata,
                    query=extend_querys,
                    key=extend_keys,
                    value=extend_values,
                    key_cache=key_cache,
                    value_cache=value_cache,
                    output=extend_outputs,
                    cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc,
                    max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
                    max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
                    min_seqlen_q=attn_metadata.extend_metadata.min_query_len,
                    block_table=attn_metadata.block_table[
                        num_decodes : num_decodes + num_extends
                    ],
                    slot_mapping=attn_metadata.slot_mapping[
                        num_decodes : num_decodes + num_extends
                    ],
                    k_scale=layer._k_scale,
                    v_scale=layer._v_scale,
                )

            # calculate for decodes
            if num_decodes > 0:
                assert attn_metadata.decode_metadata is not None
                _, num_heads, head_size = query.shape
                nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
                num_seqs = attn_metadata.seq_lens.shape[0]
                max_num_partitions = (
                    attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
                ) // _PARTITION_SIZE_ROCM

                workspace_buffer = torch.empty(
                    (num_seqs * num_heads * max_num_partitions * head_size)
                    * nbytes_per_qo_elem
                    + 2 * (num_seqs * num_heads * max_num_partitions) * 4,
                    dtype=torch.uint8,
                    device=output.device,
                )

                torch.ops.aiter.paged_attention_v1(
                    output[:num_decode_tokens],
                    workspace_buffer,
                    query[:num_decode_tokens],
                    key_cache,
                    value_cache,
                    self.scale,
                    attn_metadata.block_table[:num_decodes],
                    attn_metadata.query_start_loc[:num_decodes],
                    attn_metadata.seq_lens[:num_decodes],
                    attn_metadata.max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    "NHD",
                    self.logits_soft_cap,
                    layer._k_scale,
                    layer._v_scale,
                    None,
                    _PARTITION_SIZE_ROCM,
                )
        else:
            raise NotImplementedError(
                "Cascade attention is not implemented for ROCM AITER"
            )

        return output

alibi_slopes instance-attribute

alibi_slopes = alibi_slopes

head_size instance-attribute

head_size = head_size

kv_cache_dtype instance-attribute

kv_cache_dtype = kv_cache_dtype

kv_sharing_target_layer_name instance-attribute

kv_sharing_target_layer_name = kv_sharing_target_layer_name

logits_soft_cap instance-attribute

logits_soft_cap = logits_soft_cap

num_heads instance-attribute

num_heads = num_heads

num_kv_heads instance-attribute

num_kv_heads = num_kv_heads

num_queries_per_kv instance-attribute

num_queries_per_kv = num_heads // num_kv_heads

scale instance-attribute

scale = float(scale)

sliding_window instance-attribute

sliding_window = [-1, -1]

__init__

__init__(
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None = None,
    attn_type: AttentionType = DECODER,
    kv_sharing_target_layer_name: int | None = None,
) -> None
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None = None,
    attn_type: AttentionType = AttentionType.DECODER,
    kv_sharing_target_layer_name: int | None = None,
) -> None:
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    if alibi_slopes is not None:
        alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
    self.alibi_slopes = alibi_slopes
    if sliding_window is None:
        self.sliding_window = [-1, -1]
    else:
        self.sliding_window = [sliding_window - 1, 0]
    self.kv_cache_dtype = kv_cache_dtype
    if logits_soft_cap is None:
        # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
        logits_soft_cap = 0.0
    self.logits_soft_cap = logits_soft_cap
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    AiterFlashAttentionBackend.validate_head_size(head_size)

    if attn_type != AttentionType.DECODER:
        raise NotImplementedError(
            "Encoder self-attention and "
            "encoder/decoder cross-attention "
            "are not implemented for "
            "FlashAttentionImpl"
        )

extend_forward

extend_forward(
    attn_metadata: AiterFlashAttentionMetadata,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_cache: Tensor,
    value_cache: Tensor,
    output: Tensor,
    cu_seqlens_q: Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    min_seqlen_q: int,
    block_table: Tensor,
    slot_mapping: Tensor,
    k_scale: float,
    v_scale: float,
)
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
def extend_forward(
    self,
    attn_metadata: AiterFlashAttentionMetadata,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    output: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    min_seqlen_q: int,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    k_scale: float,
    v_scale: float,
):
    out, lse = aiter.flash_attn_varlen_func(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_q,
        min_seqlen_q=min_seqlen_q,
        dropout_p=0.0,
        softmax_scale=self.scale,
        causal=True,
        window_size=self.sliding_window,
        alibi_slopes=self.alibi_slopes,
        return_lse=True,
    )
    assert attn_metadata.extend_metadata is not None
    chunk_context_metadata = attn_metadata.extend_metadata.chunk_context_metadata
    num_chunks = chunk_context_metadata.num_chunks
    workspace = chunk_context_metadata.workspace
    cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk
    max_seqlens = chunk_context_metadata.max_seq_lens
    chunk_starts = chunk_context_metadata.chunk_starts
    token_to_batch = chunk_context_metadata.token_to_batch
    total_token_per_batch = chunk_context_metadata.total_token_per_batch
    key_fetched, value_fetched = workspace[0], workspace[1]
    chunked_output = None
    chunked_lse = None
    for chunk_idx in range(num_chunks):
        cp_mha_gather_cache(
            key_cache=key_cache,
            value_cache=value_cache,
            key=key_fetched,
            value=value_fetched,
            block_tables=block_table,
            k_scales=k_scale,
            v_scales=v_scale,
            cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
            token_to_batch=token_to_batch[chunk_idx],
            seq_starts=chunk_starts[chunk_idx],
            dequant=False,
            kv_cache_layout="NHD",
            total_tokens=total_token_per_batch[chunk_idx],
        )

        suf_out, suf_lse = aiter.flash_attn_varlen_func(
            q=query,
            k=key_fetched,
            v=value_fetched,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_kv[chunk_idx],
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlens[chunk_idx],
            min_seqlen_q=min_seqlen_q,
            dropout_p=0.0,
            softmax_scale=self.scale,
            causal=False,
            window_size=self.sliding_window,
            alibi_slopes=self.alibi_slopes,
            return_lse=True,
        )
        if chunked_output is None:
            chunked_output = suf_out
            chunked_lse = suf_lse
        else:
            tmp_output = torch.empty_like(out)
            tmp_lse = torch.empty_like(lse)
            merge_attn_states(
                output=tmp_output,
                output_lse=tmp_lse,
                prefix_output=chunked_output,
                prefix_lse=chunked_lse,
                suffix_output=suf_out,
                suffix_lse=suf_lse,
            )
            chunked_output = tmp_output
            chunked_lse = tmp_lse

    merge_attn_states(
        output=output,
        prefix_output=chunked_output,
        prefix_lse=chunked_lse,
        suffix_output=out,
        suffix_lse=lse,
    )
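
The chunk loop above accumulates partial attention results with merge_attn_states. A minimal numerical sketch of the log-sum-exp merge this step relies on (standalone math only, not the vLLM kernel; the tensor shapes below are assumptions for illustration):

import torch

def merge_partial_attention(o1, lse1, o2, lse2):
    # o*: [num_tokens, num_heads, head_size] partial outputs over disjoint KV ranges
    # lse*: [num_tokens, num_heads] natural-log softmax normalizers for each range
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)
    w2 = torch.exp(lse2 - max_lse)
    out = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / (w1 + w2).unsqueeze(-1)
    lse = max_lse + torch.log(w1 + w2)
    return out, lse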

forward

forward(
    layer: Module,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    attn_metadata: AiterFlashAttentionMetadata,
    output: Tensor | None = None,
    output_scale: Tensor | None = None,
    output_block_scale: Tensor | None = None,
) -> Tensor

Forward pass with AiterFlashAttention.

Parameters:

Name            Type                          Description                                                     Default
query           Tensor                        shape = [num_tokens, num_heads, head_size]                      required
key             Tensor                        shape = [num_tokens, num_kv_heads, head_size]                   required
value           Tensor                        shape = [num_tokens, num_kv_heads, head_size]                   required
kv_cache        Tensor                        shape = [2, num_blocks, block_size, num_kv_heads, head_size]    required
attn_metadata   AiterFlashAttentionMetadata   Metadata for attention.                                         required

Returns:

shape = [num_tokens, num_heads * head_size]

NOTE: For FP8 quantization, flash-attn expects the size of {q,k,v}_descale to be
(num_sequences, num_kv_heads). We use torch's .expand() to avoid duplicating values.
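
For reference, a minimal shape sketch of the tensors this method operates on (all sizes are illustrative, not taken from a real model config):

import torch

num_tokens, num_heads, num_kv_heads, head_size = 8, 32, 8, 128
num_blocks, block_size = 64, 16

query = torch.empty(num_tokens, num_heads, head_size)
key = torch.empty(num_tokens, num_kv_heads, head_size)
value = torch.empty(num_tokens, num_kv_heads, head_size)
kv_cache = torch.empty(2, num_blocks, block_size, num_kv_heads, head_size)
output = torch.empty(num_tokens, num_heads * head_size)  # flattened per-token result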

Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AiterFlashAttentionMetadata,
    output: torch.Tensor | None = None,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
    """Forward pass with AiterFlashAttention.

    Args:
        query: shape = [num_tokens, num_heads, head_size]
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: shape =
            [2, num_blocks, block_size, num_kv_heads, head_size]
        attn_metadata: Metadata for attention.
    Returns:
        shape = [num_tokens, num_heads * head_size]
    NOTE: FP8 quantization, flash-attn expect the size of
          {q,k,v}_descale to be (num_sequences, num_kv_heads).
          We use torch's .expand() to avoid duplicating values
    """
    assert output is not None, "Output tensor must be provided."

    if output_scale is not None or output_block_scale is not None:
        raise NotImplementedError(
            "fused output quantization is not yet supported for FlashAttentionImpl"
        )

    if attn_metadata is None:
        # Profiling run.
        return output.fill_(0)

    # IMPORTANT!
    # NOTE(woosuk): With piece-wise CUDA graphs, this method is
    # executed in eager-mode PyTorch. Thus, we need to be careful
    # about any CPU overhead in this method. For example, `view`
    # and `slice` (or `[:n]`) operations are surprisingly slow even
    # in the case they do not invoke any GPU ops.
    # Minimize the PyTorch ops in this method as much as possible.
    # Whenever making a change in this method, please benchmark the
    # performance to make sure it does not introduce any overhead.
    num_actual_tokens = attn_metadata.num_actual_tokens
    key_cache, value_cache = kv_cache.unbind(0)
    if self.kv_sharing_target_layer_name is None:
        # Reshape the input keys and values and store them in the cache.
        # Skip this if sharing KV cache with an earlier attention layer.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping
        # is not padded. However, we don't need to do
        # key[:num_actual_tokens] and value[:num_actual_tokens] because
        # the reshape_and_cache_flash op uses the slot_mapping's shape
        # to determine the number of actual tokens.

        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

    if self.kv_cache_dtype.startswith("fp8"):
        key_cache = key_cache.view(current_platform.fp8_dtype())
        value_cache = value_cache.view(current_platform.fp8_dtype())

    # decode:extend:prefill
    query = query[:num_actual_tokens]
    key = key[:num_actual_tokens]
    value = value[:num_actual_tokens]

    output_actual_tokens = output[:num_actual_tokens]

    num_decodes = attn_metadata.num_decodes
    num_prefills = attn_metadata.num_prefills
    num_extends = attn_metadata.num_extends

    num_decode_tokens = attn_metadata.num_decode_tokens
    num_extend_tokens = attn_metadata.num_extend_tokens
    if not attn_metadata.use_cascade:
        # calculate for pure prefills
        if num_prefills > 0:
            assert attn_metadata.prefill_metadata is not None

            prefill_query = query[num_decode_tokens + num_extend_tokens :]
            prefill_key = key[num_decode_tokens + num_extend_tokens :]
            prefill_value = value[num_decode_tokens + num_extend_tokens :]

            aiter.flash_attn_varlen_func(
                q=prefill_query,
                k=prefill_key,
                v=prefill_value,
                cu_seqlens_q=attn_metadata.prefill_metadata.query_start_loc,
                cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc,
                max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
                max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
                min_seqlen_q=attn_metadata.prefill_metadata.min_query_len,
                dropout_p=0.0,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.sliding_window,
                alibi_slopes=self.alibi_slopes,
                out=output_actual_tokens[num_decode_tokens + num_extend_tokens :],
            )

        # calculate for extends
        if num_extends > 0:
            assert attn_metadata.extend_metadata is not None
            extend_tokens_slice = slice(
                num_decode_tokens, num_decode_tokens + num_extend_tokens
            )
            extend_querys = query[extend_tokens_slice]
            extend_keys = key[extend_tokens_slice]
            extend_values = value[extend_tokens_slice]
            extend_outputs = output[extend_tokens_slice]
            self.extend_forward(
                attn_metadata=attn_metadata,
                query=extend_querys,
                key=extend_keys,
                value=extend_values,
                key_cache=key_cache,
                value_cache=value_cache,
                output=extend_outputs,
                cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc,
                max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
                max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
                min_seqlen_q=attn_metadata.extend_metadata.min_query_len,
                block_table=attn_metadata.block_table[
                    num_decodes : num_decodes + num_extends
                ],
                slot_mapping=attn_metadata.slot_mapping[
                    num_decodes : num_decodes + num_extends
                ],
                k_scale=layer._k_scale,
                v_scale=layer._v_scale,
            )

        # calculate for decodes
        if num_decodes > 0:
            assert attn_metadata.decode_metadata is not None
            _, num_heads, head_size = query.shape
            nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
            num_seqs = attn_metadata.seq_lens.shape[0]
            max_num_partitions = (
                attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
            ) // _PARTITION_SIZE_ROCM

            workspace_buffer = torch.empty(
                (num_seqs * num_heads * max_num_partitions * head_size)
                * nbytes_per_qo_elem
                + 2 * (num_seqs * num_heads * max_num_partitions) * 4,
                dtype=torch.uint8,
                device=output.device,
            )

            torch.ops.aiter.paged_attention_v1(
                output[:num_decode_tokens],
                workspace_buffer,
                query[:num_decode_tokens],
                key_cache,
                value_cache,
                self.scale,
                attn_metadata.block_table[:num_decodes],
                attn_metadata.query_start_loc[:num_decodes],
                attn_metadata.seq_lens[:num_decodes],
                attn_metadata.max_seq_len,
                self.alibi_slopes,
                self.kv_cache_dtype,
                "NHD",
                self.logits_soft_cap,
                layer._k_scale,
                layer._v_scale,
                None,
                _PARTITION_SIZE_ROCM,
            )
    else:
        raise NotImplementedError(
            "Cascade attention is not implemented for ROCM AITER"
        )

    return output

AiterFlashAttentionMetadata dataclass

Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@dataclass
class AiterFlashAttentionMetadata:
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    num_actual_kv_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    slot_mapping: torch.Tensor
    block_table: torch.Tensor

    # prefill and decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int
    num_extends: int
    num_extend_tokens: int

    decode_metadata: AiterFlashAttentionDecodeMetadata | None
    prefill_metadata: AiterFlashAttentionPrefillMetadata | None
    extend_metadata: AiterFlashAttentionChunkPrefillMetadata | None

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    total_tokens: int
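
Concretely, the per-request quantities in the diagram above relate to this metadata as follows (a small illustrative batch in decode:extend:prefill order):

import torch

query_start_loc = torch.tensor([0, 1, 5, 13])   # cumulative new (query) tokens
seq_lens = torch.tensor([9, 12, 8])             # total tokens attended to per request

query_lens = query_start_loc[1:] - query_start_loc[:-1]  # tensor([1, 4, 8])
context_lens = seq_lens - query_lens                     # tensor([8, 8, 0]) cached tokens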

block_table instance-attribute

block_table: Tensor

common_prefix_len instance-attribute

common_prefix_len: int

decode_metadata instance-attribute

decode_metadata: AiterFlashAttentionDecodeMetadata | None

extend_metadata instance-attribute

extend_metadata: (
    AiterFlashAttentionChunkPrefillMetadata | None
)

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

num_actual_kv_tokens instance-attribute

num_actual_kv_tokens: int

num_actual_tokens instance-attribute

num_actual_tokens: int

num_decode_tokens instance-attribute

num_decode_tokens: int

num_decodes instance-attribute

num_decodes: int

num_extend_tokens instance-attribute

num_extend_tokens: int

num_extends instance-attribute

num_extends: int

num_prefill_tokens instance-attribute

num_prefill_tokens: int

num_prefills instance-attribute

num_prefills: int

prefill_metadata instance-attribute

prefill_metadata: AiterFlashAttentionPrefillMetadata | None

query_start_loc instance-attribute

query_start_loc: Tensor

seq_lens instance-attribute

seq_lens: Tensor

slot_mapping instance-attribute

slot_mapping: Tensor

total_tokens instance-attribute

total_tokens: int

use_cascade instance-attribute

use_cascade: bool

__init__

__init__(
    num_actual_tokens: int,
    num_actual_kv_tokens: int,
    max_query_len: int,
    query_start_loc: Tensor,
    max_seq_len: int,
    seq_lens: Tensor,
    slot_mapping: Tensor,
    block_table: Tensor,
    num_decodes: int,
    num_decode_tokens: int,
    num_prefills: int,
    num_prefill_tokens: int,
    num_extends: int,
    num_extend_tokens: int,
    decode_metadata: AiterFlashAttentionDecodeMetadata
    | None,
    prefill_metadata: AiterFlashAttentionPrefillMetadata
    | None,
    extend_metadata: AiterFlashAttentionChunkPrefillMetadata
    | None,
    use_cascade: bool,
    common_prefix_len: int,
    total_tokens: int,
) -> None

AiterFlashAttentionMetadataBuilder

Bases: AttentionMetadataBuilder[AiterFlashAttentionMetadata]

Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
class AiterFlashAttentionMetadataBuilder(
    AttentionMetadataBuilder[AiterFlashAttentionMetadata]
):
    cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config
        )
        self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: tuple[int, int] | None = None
        self.total_tokens: int = 0

        self.extend_workspace = torch.empty(
            [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim],
            dtype=self.model_config.dtype,
            device=device,
        )

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ):
        self.total_tokens = (
            self.model_config.max_model_len
            * self.vllm_config.scheduler_config.max_num_partial_prefills
        )
        res = self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata)
        self.total_tokens = 0
        return res

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> "AiterFlashAttentionMetadata":
        split_ret = split_decodes_prefills_and_extends(
            common_attn_metadata,
            decode_threshold=self.reorder_batch_threshold,
        )

        (
            num_decodes,
            num_extends,
            num_prefills,
            num_decode_tokens,
            num_extend_tokens,
            num_prefill_tokens,
        ) = split_ret

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        seq_lens = common_attn_metadata.seq_lens_cpu

        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        decode_metadata = None
        if num_decodes > 0:
            decode_metadata = AiterFlashAttentionDecodeMetadata(
                max_query_len=query_lens_cpu[:num_decodes].max().item(),
                min_query_len=query_lens_cpu[:num_decodes].min().item(),
                max_seq_len=seq_lens[:num_decodes].max().item(),
                query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1],
            )

        prefill_metadata = None
        if num_prefills > 0:
            query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :]
            query_start_loc_device = common_attn_metadata.query_start_loc[
                num_decodes + num_extends :
            ]
            prefill_metadata = AiterFlashAttentionPrefillMetadata(
                max_query_len=query_lens_for_prefill.max().item(),
                min_query_len=query_lens_for_prefill.min().item(),
                max_seq_len=seq_lens[num_decodes + num_extends :].max().item(),
                query_start_loc=query_start_loc_device - query_start_loc_device[0],
            )

        extend_metadata = None
        if num_extends > 0:
            num_extends_slice = slice(num_decodes, num_decodes + num_extends)
            query_lens_for_extend = query_lens_cpu[num_extends_slice]
            seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice]
            computed_kv_lens = seq_lens_for_extend - query_lens_for_extend

            # allocate the equal amount of workspace for
            # each chunk prefill request
            max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends
            num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk)

            chunk_starts = (
                torch.arange(num_chunks, dtype=torch.int32)
                .unsqueeze(1)
                .expand(-1, num_extends)
                * max_context_chunk
            )
            chunk_ends = torch.min(
                computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk
            )
            chunk_seq_lens = (chunk_ends - chunk_starts).clamp(
                min=0
            )  # [num_chunks, num_extends]
            cu_seq_lens_cpu = torch.zeros(
                [num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True
            )
            torch.cumsum(
                chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
            )
            max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item()

            range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :]
            idx_to_batch_tensor = range_idx == cu_seq_lens_cpu[:, 1:][:, :, None]
            idx_to_batch_tensor = idx_to_batch_tensor.sum(
                dim=1
            )  # [num_chunks, max_cum_tokens]
            token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1)

            chunk_context_metadata = AiterChunkContextMetadata(
                workspace=self.extend_workspace,
                cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True),
                chunk_starts=chunk_starts.to(self.device, non_blocking=True),
                seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                seq_lens=chunk_seq_lens,
                token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
                num_chunks=num_chunks,
                total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
            )

            query_start_loc_device = common_attn_metadata.query_start_loc[
                num_decodes : num_decodes + num_extends + 1
            ]
            seq_lens_device = common_attn_metadata.seq_lens[num_extends_slice]
            cu_seq_lens = torch.zeros(
                num_extends + 1, dtype=torch.int32, device=seq_lens_device.device
            )
            torch.cumsum(
                seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]
            )
            extend_metadata = AiterFlashAttentionChunkPrefillMetadata(
                max_query_len=query_lens_for_extend.max().item(),
                min_query_len=query_lens_for_extend.min().item(),
                max_seq_len=seq_lens[num_extends_slice].max().item(),
                query_start_loc=query_start_loc_device - query_start_loc_device[0],
                chunk_context_metadata=chunk_context_metadata,
            )

        num_actual_kv_tokens = torch.sum(seq_lens).item()

        use_cascade = common_prefix_len > 0

        attn_metadata = AiterFlashAttentionMetadata(
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            num_actual_kv_tokens=num_actual_kv_tokens,
            max_query_len=common_attn_metadata.max_query_len,
            query_start_loc=common_attn_metadata.query_start_loc,
            max_seq_len=common_attn_metadata.max_seq_len,
            seq_lens=common_attn_metadata.seq_lens,
            block_table=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_extends=num_extends,
            num_extend_tokens=num_extend_tokens,
            decode_metadata=decode_metadata,
            prefill_metadata=prefill_metadata,
            extend_metadata=extend_metadata,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            total_tokens=self.total_tokens,
        )
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        return False

aot_sliding_window instance-attribute

aot_sliding_window: tuple[int, int] | None = None

block_size instance-attribute

block_size = block_size

cache_config instance-attribute

cache_config = cache_config

cudagraph_support class-attribute instance-attribute

cudagraph_support = UNIFORM_SINGLE_TOKEN_DECODE

extend_workspace instance-attribute

extend_workspace = empty(
    [2, _CP_TOKENS_PER_ITER_ROCM, num_heads_kv, headdim],
    dtype=dtype,
    device=device,
)

headdim instance-attribute

headdim = get_head_size()

model_config instance-attribute

model_config = model_config

num_heads_kv instance-attribute

num_heads_kv = get_num_kv_heads(parallel_config)

num_heads_q instance-attribute

num_heads_q = get_num_attention_heads(parallel_config)

parallel_config instance-attribute

parallel_config = parallel_config

reorder_batch_threshold class-attribute instance-attribute

reorder_batch_threshold: int = 1

total_tokens instance-attribute

total_tokens: int = 0

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    super().__init__(kv_cache_spec, layer_names, vllm_config, device)

    self.model_config = vllm_config.model_config
    self.parallel_config = vllm_config.parallel_config
    self.cache_config = vllm_config.cache_config

    self.num_heads_q = self.model_config.get_num_attention_heads(
        self.parallel_config
    )
    self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
    self.headdim = self.model_config.get_head_size()
    self.block_size = kv_cache_spec.block_size
    # Sliding window size to be used with the AOT scheduler will be
    # populated on first build() call.
    self.aot_sliding_window: tuple[int, int] | None = None
    self.total_tokens: int = 0

    self.extend_workspace = torch.empty(
        [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim],
        dtype=self.model_config.dtype,
        device=device,
    )

build

build(
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> AiterFlashAttentionMetadata
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> "AiterFlashAttentionMetadata":
    split_ret = split_decodes_prefills_and_extends(
        common_attn_metadata,
        decode_threshold=self.reorder_batch_threshold,
    )

    (
        num_decodes,
        num_extends,
        num_prefills,
        num_decode_tokens,
        num_extend_tokens,
        num_prefill_tokens,
    ) = split_ret

    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

    seq_lens = common_attn_metadata.seq_lens_cpu

    query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

    decode_metadata = None
    if num_decodes > 0:
        decode_metadata = AiterFlashAttentionDecodeMetadata(
            max_query_len=query_lens_cpu[:num_decodes].max().item(),
            min_query_len=query_lens_cpu[:num_decodes].min().item(),
            max_seq_len=seq_lens[:num_decodes].max().item(),
            query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1],
        )

    prefill_metadata = None
    if num_prefills > 0:
        query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :]
        query_start_loc_device = common_attn_metadata.query_start_loc[
            num_decodes + num_extends :
        ]
        prefill_metadata = AiterFlashAttentionPrefillMetadata(
            max_query_len=query_lens_for_prefill.max().item(),
            min_query_len=query_lens_for_prefill.min().item(),
            max_seq_len=seq_lens[num_decodes + num_extends :].max().item(),
            query_start_loc=query_start_loc_device - query_start_loc_device[0],
        )

    extend_metadata = None
    if num_extends > 0:
        num_extends_slice = slice(num_decodes, num_decodes + num_extends)
        query_lens_for_extend = query_lens_cpu[num_extends_slice]
        seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice]
        computed_kv_lens = seq_lens_for_extend - query_lens_for_extend

        # allocate the equal amount of workspace for
        # each chunk prefill request
        max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends
        num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk)

        chunk_starts = (
            torch.arange(num_chunks, dtype=torch.int32)
            .unsqueeze(1)
            .expand(-1, num_extends)
            * max_context_chunk
        )
        chunk_ends = torch.min(
            computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk
        )
        chunk_seq_lens = (chunk_ends - chunk_starts).clamp(
            min=0
        )  # [num_chunks, num_extends]
        cu_seq_lens_cpu = torch.zeros(
            [num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True
        )
        torch.cumsum(
            chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
        )
        max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item()

        range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :]
        idx_to_batch_tensor = range_idx == cu_seq_lens_cpu[:, 1:][:, :, None]
        idx_to_batch_tensor = idx_to_batch_tensor.sum(
            dim=1
        )  # [num_chunks, max_cum_tokens]
        token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1)

        chunk_context_metadata = AiterChunkContextMetadata(
            workspace=self.extend_workspace,
            cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True),
            chunk_starts=chunk_starts.to(self.device, non_blocking=True),
            seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
            max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
            seq_lens=chunk_seq_lens,
            token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
            num_chunks=num_chunks,
            total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
        )

        query_start_loc_device = common_attn_metadata.query_start_loc[
            num_decodes : num_decodes + num_extends + 1
        ]
        seq_lens_device = common_attn_metadata.seq_lens[num_extends_slice]
        cu_seq_lens = torch.zeros(
            num_extends + 1, dtype=torch.int32, device=seq_lens_device.device
        )
        torch.cumsum(
            seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]
        )
        extend_metadata = AiterFlashAttentionChunkPrefillMetadata(
            max_query_len=query_lens_for_extend.max().item(),
            min_query_len=query_lens_for_extend.min().item(),
            max_seq_len=seq_lens[num_extends_slice].max().item(),
            query_start_loc=query_start_loc_device - query_start_loc_device[0],
            chunk_context_metadata=chunk_context_metadata,
        )

    num_actual_kv_tokens = torch.sum(seq_lens).item()

    use_cascade = common_prefix_len > 0

    attn_metadata = AiterFlashAttentionMetadata(
        num_actual_tokens=common_attn_metadata.num_actual_tokens,
        num_actual_kv_tokens=num_actual_kv_tokens,
        max_query_len=common_attn_metadata.max_query_len,
        query_start_loc=common_attn_metadata.query_start_loc,
        max_seq_len=common_attn_metadata.max_seq_len,
        seq_lens=common_attn_metadata.seq_lens,
        block_table=common_attn_metadata.block_table_tensor,
        slot_mapping=common_attn_metadata.slot_mapping,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        num_extends=num_extends,
        num_extend_tokens=num_extend_tokens,
        decode_metadata=decode_metadata,
        prefill_metadata=prefill_metadata,
        extend_metadata=extend_metadata,
        use_cascade=use_cascade,
        common_prefix_len=common_prefix_len,
        total_tokens=self.total_tokens,
    )
    return attn_metadata
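
The chunked-context bookkeeping above is easiest to follow on a toy input. The sketch below replays the chunk_starts / chunk_ends / token_to_batch arithmetic on the CPU for two chunks and three extend requests; all sizes here are made up for illustration and are not taken from the module.

import torch

max_context_chunk, num_chunks, num_extends = 4, 2, 3
computed_kv_lens = torch.tensor([5, 7, 2], dtype=torch.int32)

chunk_starts = (
    torch.arange(num_chunks, dtype=torch.int32)
    .unsqueeze(1)
    .expand(-1, num_extends)
    * max_context_chunk
)                                               # [[0, 0, 0], [4, 4, 4]]
chunk_ends = torch.min(
    computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk
)                                               # [[4, 4, 2], [5, 7, 2]]
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
# -> [[4, 4, 2], [1, 3, 0]]: request 2 has no context left in chunk 1

cu_seq_lens = torch.zeros(num_chunks, num_extends + 1, dtype=torch.int32)
torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens[:, 1:], dtype=torch.int32)
# -> [[0, 4, 8, 10], [0, 1, 4, 4]]

max_cum_tokens = int(cu_seq_lens[:, -1].max())  # at most 10 gathered tokens per chunk
range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :]
token_to_batch = (
    (range_idx == cu_seq_lens[:, 1:][:, :, None]).sum(dim=1).cumsum(dim=1)
)
# chunk 0 -> [0, 0, 0, 0, 1, 1, 1, 1, 2, 2]
# chunk 1 -> [0, 1, 1, 1, 3, 3, 3, 3, 3, 3]  (entries past the 4 valid
#            tokens of chunk 1 are padding and are never read)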

build_for_cudagraph_capture

build_for_cudagraph_capture(
    common_attn_metadata: CommonAttentionMetadata,
)
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
def build_for_cudagraph_capture(
    self, common_attn_metadata: CommonAttentionMetadata
):
    self.total_tokens = (
        self.model_config.max_model_len
        * self.vllm_config.scheduler_config.max_num_partial_prefills
    )
    res = self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata)
    self.total_tokens = 0
    return res
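
During CUDA graph capture there are no real requests to size the chunked-context workspace against, so total_tokens is temporarily set to max_model_len * max_num_partial_prefills and reset to 0 afterwards. For a sense of scale, with illustrative config values (not read from any vLLM config):

max_model_len = 32_768                   # assumed model context length
max_num_partial_prefills = 2             # assumed scheduler setting
total_tokens = max_model_len * max_num_partial_prefills   # 65_536 tokens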

use_cascade_attention

use_cascade_attention(*args, **kwargs) -> bool
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
def use_cascade_attention(self, *args, **kwargs) -> bool:
    return False

AiterFlashAttentionPrefillMetadata dataclass

Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@dataclass
class AiterFlashAttentionPrefillMetadata:
    max_query_len: int
    min_query_len: int
    max_seq_len: int
    query_start_loc: torch.Tensor

max_query_len instance-attribute

max_query_len: int

max_seq_len instance-attribute

max_seq_len: int

min_query_len instance-attribute

min_query_len: int

query_start_loc instance-attribute

query_start_loc: Tensor

__init__

__init__(
    max_query_len: int,
    min_query_len: int,
    max_seq_len: int,
    query_start_loc: Tensor,
) -> None

block_size

block_size(x, head_dim)
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
def block_size(x, head_dim):
    return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
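
block_size picks the per-program copy width used by the gather kernel: the smaller of (i) how many cache elements fit in 64 KiB and (ii) the next power of two at or above head_dim. Two illustrative evaluations, with the dtypes and head sizes chosen here purely as examples and block_size taken from this module:

import torch

bf16_cache = torch.empty(1, dtype=torch.bfloat16)   # element_size() == 2
byte_cache = torch.empty(1, dtype=torch.uint8)      # stand-in for any 1-byte (e.g. fp8) cache

block_size(bf16_cache, 128)   # min(65536 // 2, 128) -> 128
block_size(byte_cache, 192)   # min(65536 // 1, 256) -> 256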

cp_mha_gather_cache

cp_mha_gather_cache(
    key_cache: Tensor,
    value_cache: Tensor,
    key: Tensor,
    value: Tensor,
    block_tables: Tensor,
    k_scales: Tensor,
    v_scales: Tensor,
    cu_seqlens_kv: Tensor,
    token_to_batch: Tensor,
    seq_starts: Tensor,
    dequant: bool,
    kv_cache_layout: str,
    total_tokens: int,
)
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
def cp_mha_gather_cache(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    block_tables: torch.Tensor,
    k_scales: torch.Tensor,
    v_scales: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    token_to_batch: torch.Tensor,
    seq_starts: torch.Tensor,
    dequant: bool,
    kv_cache_layout: str,
    total_tokens: int,
):
    assert kv_cache_layout in ["v0", "NHD", "HND"], (
        "kv_cache_layout only support v0, NHD, HND"
    )
    head_dim = key.shape[2]
    x = 0
    # assert dequant is True, "Currently, we only support "\
    # "gather cache with dequant"
    # For k cache layout: [num_blocks, num_heads, page_size, head_dim]
    assert kv_cache_layout == "NHD", (
        "ROCM_AITER_FA_BACKEND Only support NHD kv cache layout for now"
    )
    assert head_dim == key_cache.shape[3], (
        "We assume your kv cache layout is [num_blocks, "
        "page_size, num_heads, head_dim], but got otherwise"
    )
    page_size = key_cache.shape[1]
    num_heads = key_cache.shape[2]

    NUM_PRGMS = num_programs(total_tokens)
    BLOCK_SIZE = block_size(key_cache, head_dim)
    grid = lambda meta: (NUM_PRGMS,)
    cp_mha_gather_cache_kernel[grid](
        key_cache,
        value_cache,
        key,
        value,
        block_tables,
        cu_seqlens_kv,
        token_to_batch,
        seq_starts,
        k_scales,
        v_scales,
        num_heads,
        head_dim,
        x,
        block_tables.size(1),
        total_tokens,
        DEQUANT=dequant,
        PAGE_SIZE=page_size,
        CACHE_FORMAT=kv_cache_layout,
        BLOCK_SIZE=BLOCK_SIZE,
        NUM_PRGMS=NUM_PRGMS,
    )
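
A minimal invocation sketch, assuming a ROCm GPU with Triton available. The shapes follow the comments in the kernel listed below; the concrete sizes, the identity-style block table, and the unit KV scales are made-up illustration rather than values produced by the metadata builder.

import torch

num_blocks, page_size, num_heads, head_dim = 16, 16, 8, 128
num_batches, max_block_num, total_tokens = 2, 8, 24

key_cache = torch.randn(num_blocks, page_size, num_heads, head_dim,
                        dtype=torch.bfloat16, device="cuda")
value_cache = torch.randn_like(key_cache)

# Destination workspace for the gathered (and optionally dequantized) KV.
key = torch.empty(total_tokens, num_heads, head_dim,
                  dtype=torch.bfloat16, device="cuda")
value = torch.empty_like(key)

block_tables = torch.arange(num_batches * max_block_num, dtype=torch.int32,
                            device="cuda").view(num_batches, max_block_num)
cu_seqlens_kv = torch.tensor([0, 16, 24], dtype=torch.int32, device="cuda")
token_to_batch = torch.tensor([0] * 16 + [1] * 8, dtype=torch.int32,
                              device="cuda")
seq_starts = torch.zeros(num_batches, dtype=torch.int32, device="cuda")
k_scales = torch.ones(1, dtype=torch.float32, device="cuda")
v_scales = torch.ones(1, dtype=torch.float32, device="cuda")

cp_mha_gather_cache(
    key_cache, value_cache, key, value, block_tables,
    k_scales, v_scales, cu_seqlens_kv, token_to_batch, seq_starts,
    dequant=False, kv_cache_layout="NHD", total_tokens=total_tokens,
)
# key / value now hold the 16 context tokens of request 0 followed by
# the 8 context tokens of request 1, read out of the paged cache.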

cp_mha_gather_cache_kernel

cp_mha_gather_cache_kernel(
    key_cache_ptr,
    value_cache_ptr,
    key_ptr,
    value_ptr,
    block_table_ptr,
    cu_seqlens_kv_ptr,
    token_to_batch_ptr,
    seq_start_ptr,
    k_scale_ptr,
    v_scale_ptr,
    num_heads,
    head_size,
    x,
    max_block_num,
    num_tokens,
    DEQUANT: constexpr,
    PAGE_SIZE: constexpr,
    CACHE_FORMAT: constexpr,
    BLOCK_SIZE: constexpr,
    NUM_PRGMS: constexpr,
)
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
@triton.jit
def cp_mha_gather_cache_kernel(
    key_cache_ptr,  # [num_blocks, page_size, num_head, head_size]
    value_cache_ptr,  # [num_blocks, page_size, num_head, head_size]
    key_ptr,  # [num_tokens, num_heads, head_size]
    value_ptr,  # [num_tokens, num_heads, head_size]
    block_table_ptr,  # [num_batches, max_block_num]
    cu_seqlens_kv_ptr,  # [num_batches + 1]
    token_to_batch_ptr,  # [max_cum_tokens]
    seq_start_ptr,  # [num_batches]
    k_scale_ptr,
    v_scale_ptr,
    num_heads,
    head_size,
    x,
    max_block_num,
    num_tokens,
    DEQUANT: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    CACHE_FORMAT: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    NUM_PRGMS: tl.constexpr,
):
    bid = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    if DEQUANT:
        k_scale = tl.load(k_scale_ptr)
        v_scale = tl.load(v_scale_ptr)

    for token_id in tl.range(bid, num_tokens, NUM_PRGMS):
        key_ptr_offset = key_ptr + token_id * head_size * num_heads
        value_ptr_offset = value_ptr + token_id * head_size * num_heads
        batch_idx = tl.load(token_to_batch_ptr + token_id)
        batch_start = tl.load(seq_start_ptr + batch_idx)
        token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
        batch_offset = token_id - token_start + batch_start
        block_offset = batch_offset // PAGE_SIZE
        block_id = tl.load(
            block_table_ptr + max_block_num * batch_idx + block_offset
        )
        slot_id = batch_offset % PAGE_SIZE

        if CACHE_FORMAT == "NHD":
            # for kv cache layout as
            # K: [num_blocks, page_size, num_head, head_dim]
            # V: [num_blocks, page_size, num_head, head_dim]
            key_cache_ptr_offset = (
                key_cache_ptr
                + block_id * num_heads * head_size * PAGE_SIZE
                + slot_id * num_heads * head_size
            )
            value_cache_ptr_offset = (
                value_cache_ptr
                + block_id * num_heads * head_size * PAGE_SIZE
                + slot_id * num_heads * head_size
            )

            for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
                mask = (col_offsets + i) < head_size * num_heads
                k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
                v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
                if DEQUANT:
                    k_dtype = k_reg.dtype
                    v_dtype = v_reg.dtype
                    k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
                    v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
                tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
                tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
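
For readers more comfortable with eager PyTorch than Triton, the snippet below mirrors the kernel's index math for the K side (V is handled identically). gather_cache_reference is a hypothetical helper written only for exposition; it is not part of this module and would be far slower than the kernel.

import torch

def gather_cache_reference(key_cache, block_tables, cu_seqlens_kv,
                           token_to_batch, seq_starts, k_scale, dequant):
    # key_cache: [num_blocks, page_size, num_heads, head_dim] (NHD layout)
    page_size = key_cache.shape[1]
    num_tokens = token_to_batch.shape[0]
    out = torch.empty(num_tokens, key_cache.shape[2], key_cache.shape[3],
                      dtype=key_cache.dtype, device=key_cache.device)
    for token_id in range(num_tokens):
        batch_idx = int(token_to_batch[token_id])
        # Position of this token inside its request's context window.
        batch_offset = (token_id - int(cu_seqlens_kv[batch_idx])
                        + int(seq_starts[batch_idx]))
        block_id = int(block_tables[batch_idx, batch_offset // page_size])
        slot_id = batch_offset % page_size
        k = key_cache[block_id, slot_id]              # [num_heads, head_dim]
        if dequant:
            # Same round-trip as the kernel: upcast, scale, cast back.
            k = (k.float() * k_scale).to(key_cache.dtype)
        out[token_id] = k
    return out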

num_programs

num_programs(head_dim)
Source code in vllm/v1/attention/backends/rocm_aiter_fa.py
def num_programs(head_dim):
    return min(head_dim, get_num_sms())
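
Note that, despite the parameter name, the only call site above passes the gathered token count (NUM_PRGMS = num_programs(total_tokens)), so this simply caps the 1-D launch grid at the device's get_num_sms() value; each program then grid-strides over tokens inside cp_mha_gather_cache_kernel. For a rough sense of the launch size (the compute-unit count below is illustrative and device dependent):

num_cus = 304                      # e.g. an MI300X-class GPU reports ~304 CUs
print(min(24_000, num_cus))        # large gather -> 304 programs, each striding over tokens
print(min(128, num_cus))           # small gather -> 128 programs, one token each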