• <xmp id="om0om">
  • <table id="om0om"><noscript id="om0om"></noscript></table>
  • Conversational AI / NLP

    NVIDIA cuDNN 9? ????? ???

    Reading Time: 7 minutes

    NVIDIA CUDA ? ?? ???? ?????(cuDNN)? ??? ???? ? ?? ?? ??? ????? ?? GPU ?? ????????.

    cuDNN? PyTorch, TensorFlow ? XLA(?? ?? ??)? ?? ?? ?? ? ?? ?????? ?????. ??? ?????? ?? GPU ?????? ???? ?????? ?? ????? ??? ?? ?? ??? ???? ?????? ? ??? ? ????. cuDNN? ???? ?? ?? ??? ?? ??? ??????? ??? ??? ????? ??? ? ?? ????.

    ???? SDPA(Scaled Dot Product Attention)? ?? ?? ??(LLM)? ?? ?? ?????? ??? ???? ??? ?? ??? ??????. cuDNN? ? ?? ??? ?? ??? ????? ??? ??? ? ? ? ???? ???? ????? ??? ???? ??? ?? ?? ??? ???? ??? ??? ?? ??? ???? ????. 

    NVIDIA H200 Tensor ?? GPU?? cuDNN? FP8?? ?? 1.2PFLOPS? ??? ? ????. ?? ? ?? ???, ?? ?? Llama2 70B LoRA ?? ??? ?? cuDNN FP8 SDPA? ???? ? 1.15?? ?? ??? ??????. ? ????? 8GPU H200 ???? NVIDIA NeMo? NVIDIA TE(Transformer Engine)? ??????.

    ????? ??e2e ??
    cuDNN? ????? NeMo + TE1x(??)
    NeMo + TE(BF16? cuDNN SDPA ??)1.11x
    NeMo + TE(FP8? cuDNN SDPA ??)1.15x
    ? 1. 8-GPU H200 ???? ?? ? ?? ???? ??(Llama2 70B LoRA ?? ??)? ??? SDPA? cuDNN? ?? ? ?? ?? ??

    ? ????? cuDNN SDPA? ?? ??? ??? ?? ??? ??? ?? ??? ????, cuDNN 9?? ????? ? ?? ??? ??? ??? ???? ????.

    SDPA(Scaled Dot Product Attention) ??

    NVIDIA? ??? ????? ?? ??? ???? APEX ?????? ?? ???? ???(fMHA) ??? ?? ??? ???? ???? ???? ??????. Tri Dao? ???? ??? ? ??? ????? ???? ??? ???? ??? ??? ?? ??? ???? ??????. ??? ??? GitHub? /Dao-AILab/flash-attention ? FlashAttention-2: ? ?? ?? ?? ? ?? ??? ? ??? ??? ??? ?????.

    ?? ??, NVIDIA? ? ??? ??? ???? ??? ???? ??? ??? ???????. ? ??? ?? NVIDIA Hopper ???? GPU? ?? NVIDIA TE? ?? ?? ??? ??????. 

    XLA? ?? cuDNN SDPA?? ??? ???? JAX SDPA API? ?? ?????? XLA ????? ???? JAX/PyTorch? ??? ???? cuDNN SDPA? ?? ? ????. 

    PyTorch ?? ?? SDPA? ?? cuDNN? ???? ?? ????, cuDNN ?? ??? ?? ????. ??? ??? Fprop ? Bprop? PyTorch PR? ?????.  

    cuDNN SDPA ????? ??? ??????.

    • ?? NVIDIA ???? ????? ?? ???? ??
    • ? ??????? ??? ??? v2 ? ? ??? ?? ?? ????? ?? ??? SDPA ????? ??
    • ?? ?? ? ?? GPU ????? ?? ?? ??(?: ?? ??)? ???? ???? ????

    ?? NVIDIA GPU?? ??? ?? ??? ?????. ?? 1? 2? ??? ?? ???? cuDNN 9 BF16? BF16??? ?? ??? ??? PyTorch eager ???? ?? 2? ? ??? ?????. cuDNN FP8 ??? ?? 3? ? ?? ??? ?????.

    ??? ???? ??? ??? ? ???? ??? ?? ???? ? ?? ?? ??? ?????. ?? ?? ????(?: ??? ??? ?? ????)? ?? ? ???? GPU ??? ???? ??? ????? ???? ????.

    ?? 1. ?? ???, ?? ?? 128? ?? SDPA(??? ??)
    ?? 2. ?? ???, ?? ?? 128? ?? SDPA(??? + ???)
    ?? 3. ?? ???, ?? ?? 256? ?? SDPA(??? ??)

    cuDNN ?????? SDPA

    cuDNN? SDPA? Tensor ??? cuDNN ???? ??? ? ????. ??? ????? cuDNN ??????? ?? ??? ? ?? ?? ?? ??? ????. 

    ?? ????? ??? ??? ?? ?? ???? ? ?? ??? GPU?? ????? ???? ???? cuDNN ???? ?? ??? ??? ??? ???? ? ? ??? ????. ??? ????? ??? ??? ????? ??? ?? ?? ?? ???? ?????. ?? ?? ????? ???? “?? ???”? ? ? ????.

    SDPA(??? ??)? ??? ?? ???? ???? ?? ????. ??? ??? ???? ??? ??? ???? ?? ?? ??? ??? ????? ???? ??? ?? ????? ??? ??? ???? ???????. ?, ??? ??? ?????? ?? GPU?? ????? ??? ? ????. ? ??? ?? ??? ???? ??? ???? ? ? ??? ????.

    ?? 4? ??? ??(fprop) ?? ??? ?? cuDNN ???? ?????. ?? ?? ???? ??? ??? ????? SDPA? ??? ??? ??? ??? ??? ?????.

    ?? 4. SDPA? ?? ???? fprop ?? cuDNN ????

    ?? ??? ??? ?? gen ???, ?? ?? ? ?? ???? ?????. softmax ??? 6?? cuDNN ???? ?????(?? 4).? cuDNN ??? ??? ???? ??? ??? ??? ? ?? ???? ?????. ?? 5? ??? ??? ?????.?

    ?? 5. cuDNN ??? ??

    ?? 5? ?? ??? ? ??(256) ?? ???? ?? ? cuDNN FP8 ??? ??? ???? ?? 1.2PFLOP? ??? ? ??? ?????. 

    BMM1? Softmax ?? ??? ?? ???? ??? ???? ??? ?? ????. ??? ??? ??? ?? ???? ?? ??? ??? ??? ? ????. 

    ?????? ???? ?? ??? ??? ???? ? ???? ????? ???? ???? ?? ??? ??? ???? ?????. ?? 2? ?? ???? ?? ??? ?????. ??? ??? ????? ?? ??? ???? ??? ? ????.

    SDPA ?? ??

    cuDNN ???? ???? ???? ? ??? ? ?? ?? API ???? ???? ??? ????.

    • ????? API(C++ ? Python ?? ?? ??)
    • ??? API(C? ??)

    ?? ???? ??? ?? cuDNN ??? ??? ? API ??? ?? ?????. ??? C ?????? ???? ??? cuDNN ?? Python ?? C++?? cuDNN ????? API? ?????? ??? ?? ?????. ????? API? ?? ? ???? ? ?? ???? ?????.

    ?? ??, ????? API? ?? ??? ??? ???? ?? ??? ?? ?? ??? ??? ????? ??? ?????. ?, SDPA? ?? ???? ??? ??? ?? ??? ????? ??? ?????. ??? ? ??? ??? ??? ? ??? ???? ?? ?????. ? SDPA ?? ??? ?? ??? ??? Python?? ??? SDPA ??? ?????.

    ????? ?????? SDPA Python ??? ?? ??? ?? ?? ??? ?????.

    1. ??? ??? ???? cudnn.pygraph ??? ??????.?
    2. Tensor? ??, ???? ? ??? ??? ???? Tensor ??? ????.?
    3. ??? ?? ??? ??? ??? ???? ??? ??? ?????.?
    4. ???? ???? ???? ???? ?????.?
    5. ???? ?????.??
    # 1. cuDNN graph
    graph = cudnn.pygraph(
              io_data_type=cudnn.data_type.BFLOAT16,
              intermediate_data_type=cudnn.data_type.FLOAT,
              compute_data_type=cudnn.data_type.FLOAT,
            )
     
    # 2. tensor attributes
    q = make_tensor_attr(graph, q_gpu, "q")
    k = make_tensor_attr(graph, k_gpu, "k")
    v = make_tensor_attr(graph, v_gpu, "v")
    attn_scale = make_tensor_attr(graph, attn_scale_cpu, "attn_scale", is_pass_by_value=True)
     
    # 3. sdpa node with the configurations
    o, stats = graph.scaled_dot_product_flash_attention(
                        name="scaled_dot_product_flash_attention",
                        q=q,
                        k=k,
                        v=v,
                        is_inference=false,
                        attn_scale=attn_scal,
                        bias=None,
                        use_alibi_mask=false,
                        use_padding_mask=false,
                        seq_len_q=None,
                        seq_len_kv=None,
                        use_causal_mask=true,
                        dropout=None,
                    )
     
    # 4. build the graph and provide the device pointers
    graph.build()
    variant_pack = {
                     q: q_gpu,
                     k: k_gpu,
                     v: v_gpu,
                     o: o_gpu
                   }
     
    # 5. execute the graph
    graph.execute()

    SDPA ??? ????? ???? ??? ? ?? ???? ???? ?? SDPA ???? ???? ??? ?? ????? ???? ? ????. 

    ?? ??, ?? ?? ????? SDPA ?? ?? ?? ??? ??????. ??? ??? scaled_dot_product_flash_attention.h? ?????.?

    // Optional scale
    if (options.inputs.Attn_scale) {
          // Lower options to scale options
          auto attn_scale_output = std::make_shared<Tensor_attributes>();
          attn_scale_output->set_is_virtual(true);
     
          Pointwise_attributes scale_attributes;
          scale_attributes.set_name("attn_scale");
          scale_attributes.set_mode(PointwiseMode_t::MUL);
          scale_attributes.inputs.IN_0 = last_output;
          scale_attributes.inputs.IN_1 = options.inputs.Attn_scale;
          last_output = scale_attributes.outputs.OUT_0 = attn_scale_output;
          auto scale_node = std::make_unique<PointwiseNode>(std::move(scale_attributes), context);
           
          sub_nodes.emplace_back(std::move(scale_node));
    }

    ??? ?????? ????? ??? ?? ??? ??? NVIDIA cuDNN ???? ?????.

    ? ? ????? cuDNN 9 ??

    SDPA ?? ?? ??? cuDNN 9?? ??? ?? ? ?? ??? ?? ??? ???????.

    • ??? ? ???? ?? ?? ?? ??? ??
    • ??? ?? ??
    • ???? ?? ???? ???
    • ???? ??

    ??? ? ???? ?? ?? ?? ??? ?? 

    ?? ????? ??? ??? ??? ??(FP16, FP32)??? ?? ???(matmul) ? ??? API? ???? FP16? ?? ???? INT8? ?? ? ?? AWQ(Activation-aware Weight Quantization)? ?? ??? ???? ????. ? ?? ???? ????? ???? ????? ???. ???? ?? ??, ?? ??? ?? ?? ???? ?????. 

    cuDNN? ?? ?? ? ??? ???? ?? ??? ?? ?? ??? ?? A ? B ????? ?? ?? ??? ??? ? ?? ?? ?? ??? ???? ???? ?????. ???? ? ??? ?? ???? ???? ? ????. cuDNN? ???? ???? ??? ??? ?????.

    ?? 6? cuDNN ?? ?? ??? ???? ???? ?? ????? ?? ?? ??? ?????. ?? ??? ?? A? B? ?? FP16 ? INT8 ????? A? INT8? ??? ? INT32 ??? ?? INT8xINT8 ???? ???? ??? ?????. ?? ??? A? INT8?? FP16?? ????? ? FP32 ??? ?? FP16xFP16 ???? ???? ??? ?????.

    ?? 6. cuDNN?? ?? ?? ??? ??? ? ??? ??

    ??? ?? ??

    ??? ????? ??, ?? cuDNN? ???? ? ?? ?????? ?? ??? ????? ??????. ??? cuDNN?? ?? ??? ???? ?? ? ?? ???? ????????. NVIDIA? ? ??? ???? ?? ?? ??? ????? ??? ????. 

    cuDNN 9? ??? ?? ? ??? ?????.

    • ?? ???? ?? ??
    • ???? ?? ?? ?? ???? ? ??? ?? ??
    • ?? ??? ???? ??? ?? ??
    • cuDNNGetLastErrorstring(????? ???? ??? ?? ???? ???? ? ??)

    ??? ??? ??? ???? ?? ? API ?? ??? ?????.

    ???? ?? ???? ???

    ?? 9.0.0 ??? cuDNN ?????? ????? ???? ????? ?? ??? ?? GPU ?????? ??????. ?? ??, cuDNN 8.9.x? NVIDIA Hopper?? ??????(?, ??? ?? 9.0). ?? GPU ?????? 8.9.x? ???? ?? ???? ????. 

    ??? cuDNN 9? API? ??? ?? ??? ?? ???? ?? ???? ???? ?????. ?, cuDNN v9? ?? ? API ?? ??? ???? ????? ?? ??? GPU??? ???? ??? ????? ???? ?? ??? GPU? ???? ?? cuDNN ??? ??? ??????? ??? ???. 

    9.0 ??? ??? ??? ?? ?? GPU?? ??? ? ??? ???? ?? ?? ???? ???? ?????? ????? ??? ??? ?? PTX JIT? ???? ? ????? ???? ??? ?? ?????.

    ?? ?? ? ?? ??? ?? ??? ??? ??? ???? ???? ?? ???? ??? ??? ?????.

    ???? ??

    Python ????? ?? ????? ??? pip? ???? ? Python ?????? ??? ? ????.

    # cuDNN library for CUDA 12
    pip install nvidia-cudnn-cu12 
     
    # cuDNN frontend
    pip install git+https://github.com/NVIDIA/cudnn-frontend.git

    ?? cuDNN 9? RPM ? Debian ?? ???? ?? ????? ???????. ?? ??, ?? ??? ???? Ubuntu 22? cuDNN? ?????.

    wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
    sudo dpkg -i cuda-keyring_1.1-1_all.deb
    sudo apt-get update
    sudo apt-get install -y cudnn

    ? ? ??? Debian ??? ??? ?? ??? ????? cudnn ?? ???? ???? ??? cuDNN ????? ??? ??????.

    ??? ?? ? ?? ??? cuDNN ?? ???? ??????. 

    ?? ??

    ???, ?? ?? ??? ?? ?? cuDNN NVIDIA ??? ??(cuDNN ????? ?? ??) ?? GitHub? NVIDIA/cudnn-frontend(????? ?? ??)? ??? ? ????. cuDNN ???? ????.

    ??? ? ?? ? cuDNN ??? ??? ???. ? ????? ???? ??? ?? ??? ??? ???? ???? ??? ? ?? ??? ??? ????. 

    AI? ???? ???? ??? ????? ????? ??? ??? ???? ?? NVIDIA? ? ?? ?????? ??? ???? ???? cuDNN? ?? ????? ????? ??? ? ??? ??? ????? ??? ??? ????? ???? ????.

    ?? ?? ??

    NVIDIA cuDNN ?? ?? ??? ?? ?? ?? ???? ? ???? ?? ???? ??????.

    ?? ???

    Discuss (0)
    0

    Tags

    人人超碰97caoporen国产