From 009aca6e4406632cab511c6907179abb1e5f1006 Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Tue, 10 Dec 2024 11:15:20 +0800
Subject: [PATCH] [CI] Add llama2-7b-cinn test (#9578)

* refine log

* refine
---
 scripts/distribute/ci_case_auto.sh | 105 +++++++++++++++++++++++++++++
 1 file changed, 105 insertions(+)

diff --git a/scripts/distribute/ci_case_auto.sh b/scripts/distribute/ci_case_auto.sh
index 010e2ba6cd14..c8bc21a002fc 100755
--- a/scripts/distribute/ci_case_auto.sh
+++ b/scripts/distribute/ci_case_auto.sh
@@ -102,6 +102,7 @@ function llama_case_list_auto() {
         llama_align_dy2st_fthenb_and_vpp_auto_bs2_fp32_DP1-MP1-PP4
         llama_align_dygraph_dy2st_pir_auto_pp_bs2_bf16_DP1-MP1-PP4
         llama_baichuan_pir_auto_fuse_ffn_attention_qkv_DP2_MP2_PP2
+        llama_dy2st_auto_bs2_bf16_DP2-MP1-PP1-CINN
     )
     if [ $1 = "prepare_case" ]; then
         restore_func $fun_list
@@ -1025,6 +1026,110 @@ function llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1() {
     echo "=========== $FUNCNAME run end ==========="
 }
 
+function llama_dy2st_auto_bs2_bf16_DP2-MP1-PP1-CINN() {
+    echo "=========== $FUNCNAME run begin ==========="
+    export PYTHONPATH=$root_path/:$PYTHONPATH
+    export FLAGS_call_stack_level=3
+    export FLAGS_cudnn_deterministic=1
+    export NVIDIA_TF32_OVERRIDE=0
+    export FLAGS_embedding_deterministic=1
+    export FLAGS_flash_attn_version=v1
+    export FLAGS_enable_pir_api=1
+    export FLAGS_max_inplace_grad_add=4
+    export PARALLEL_CROSS_ENTROPY=true
+
+    export FLAGS_use_cinn=1
+    export FLAGS_dist_prim_all=1
+    export FLAGS_prim_forward_blacklist="pd_op.stack;pd_op.squeeze;pd_op.swiglu;pd_op.squared_l2_norm"
+    export FLAGS_prim_backward_blacklist="swiglu_grad"
+
+    task_name="llama_dy2st_auto_bs2_bf16_DP2-MP1-PP1-CINN"
+    case_out_dir="output/$task_name"
+    case_log_dir="output/$task_name""_log"
+    rm -rf $case_out_dir
+    rm -rf $case_log_dir
+
+    python -u -m paddle.distributed.launch \
+        --gpus "0,1" \
+        --log_dir $case_log_dir \
+        run_pretrain_auto.py \
+        --model_type "llama" \
+        --model_name_or_path "facebook/llama-7b" \
+        --tokenizer_name_or_path "facebook/llama-7b" \
+        --input_dir "./data" \
+        --output_dir $case_out_dir \
+        --split 949,50,1 \
+        --weight_decay 0.01 \
+        --warmup_ratio 0.01 \
+        --warmup_steps 30 \
+        --max_grad_norm 1.0 \
+        --learning_rate 3e-05 \
+        --min_learning_rate 3e-06 \
+        --max_steps 10 \
+        --logging_steps 10 \
+        --eval_steps 1000 \
+        --save_steps 50000 \
+        --continue_training 0 \
+        --do_train true \
+        --do_eval false \
+        --do_predict false \
+        --disable_tqdm true \
+        --skip_profile_timer true \
+        --save_total_limit 2 \
+        --device gpu \
+        --disable_tqdm true \
+        --dataloader_num_workers 1 \
+        --distributed_dataloader 0 \
+        --enable_auto_parallel 1 \
+        --per_device_train_batch_size 1 \
+        --gradient_accumulation_steps 1 \
+        --per_device_eval_batch_size 2 \
+        --recompute false \
+        --recompute_use_reentrant true \
+        --recompute_granularity full \
+        --pp_recompute_interval 0 \
+        --bf16 1 \
+        --fp16_opt_level "O2" \
+        --amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
+        --amp_custom_white_list "lookup_table" "lookup_table_v2" \
+        --amp_master_grad 1 \
+        --fuse_attention_ffn true \
+        --fuse_attention_qkv true \
+        --fuse_sequence_parallel_allreduce false \
+        --use_flash_attention 0 \
+        --use_fused_rope false \
+        --use_fused_rms_norm false \
+        --max_seq_length 4096 \
+        --sep_parallel_degree 1 \
+        --sequence_parallel false \
+        --pipeline_parallel_degree 1 \
+        --sharding_parallel_degree 1 \
+        --tensor_parallel_degree 1 \
+        --virtual_pp_degree 1 \
+        --pipeline_schedule_mode "VPP" \
+        --sharding "" \
+        --to_static ${to_static} \
+        --num_hidden_layers 2 \
+        >>${log_path}/$FUNCNAME 2>&1
+    loss=`cat $case_log_dir/workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
+    ips=-1
+    mem=-1
+    echo "result: to_static=$to_static loss=$loss ips=$ips mem=$mem"
+    loss_base=9.99302597
+    if [ $IS_A100 -ne 0 ];then
+        loss_base=10.20988007
+    fi
+    ips_base=-1
+    mem_base=-1
+    check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
+
+    unset FLAGS_use_cinn
+    unset FLAGS_dist_prim_all
+    unset FLAGS_prim_forward_blacklist
+    unset FLAGS_prim_backward_blacklist
+    echo "=========== $FUNCNAME run end ==========="
+}
+
 function llama_align_dygraph_dy2st_pir_auto_grad_merge_bs2_fp32_DP1-MP1-PP1() {
     echo "=========== $FUNCNAME run begin ==========="
     export PYTHONPATH=$root_path/:$PYTHONPATH
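
For reference, a minimal sketch of how the loss check added above reads the result out of the trainer log. The log-line format used here is assumed for illustration; the real workerlog.0 written by run_pretrain_auto.py may carry additional fields, but the grep/awk pipeline is the same one used in llama_dy2st_auto_bs2_bf16_DP2-MP1-PP1-CINN:

    # hypothetical rank-0 log line (format assumed, not copied from a real run)
    echo "global_step: 10, loss: 9.99302597, lr: 3e-05" > workerlog.0
    # same extraction as the new case: split on 'loss: ', keep the value before the next comma
    loss=`cat workerlog.0 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
    echo $loss    # -> 9.99302597

The case passes --to_static ${to_static} and echoes to_static in its result line; to_static is presumably set to 1, or inherited from an earlier dy2st case in the same shell, since the CINN flags exported above apply to the program captured by dynamic-to-static.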