Repositori ini memegang utilitas yang diperiksa NVIDIA untuk merampingkan presisi campuran dan pelatihan terdistribusi di Pytorch. Beberapa kode di sini akan dimasukkan dalam pytorch hulu pada akhirnya. Maksud Apex adalah untuk membuat utilitas terkini yang tersedia untuk pengguna secepat mungkin.
Tercerahkan. Gunakan amp Pytorch
apex.amp
adalah alat untuk mengaktifkan pelatihan presisi campuran dengan hanya mengubah 3 baris skrip Anda. Pengguna dapat dengan mudah bereksperimen dengan mode pelatihan presisi murni dan campuran dengan memasok bendera yang berbeda ke amp.initialize
.
Webinar Memperkenalkan AMP (bendera cast_batchnorm
telah diganti namanya menjadi keep_batchnorm_fp32
).
Dokumentasi API
Contoh Imagenet Komprehensif
Contoh dcgan segera hadir ...
Pindah ke API AMP baru (untuk pengguna "amp" yang sudah usang dan "fp16_optimizer" API)
apex.parallel.DistributedDataParallel
sudah usang. Gunakan torch.nn.parallel.DistributedDataParallel
apex.parallel.DistributedDataParallel
adalah pembungkus modul, mirip dengan torch.nn.parallel.DistributedDataParallel
. Ini memungkinkan pelatihan terdistribusi multiproses yang nyaman, dioptimalkan untuk Perpustakaan Komunikasi NCCL NVIDIA.
Dokumentasi API
Sumber Python
Contoh/walkthrough
Contoh imagenet menunjukkan penggunaan apex.parallel.DistributedDataParallel
bersama dengan apex.amp
.
Tercerahkan. Gunakan torch.nn.SyncBatchNorm
apex.parallel.SyncBatchNorm
memperluas torch.nn.modules.batchnorm._BatchNorm
untuk mendukung Bn yang disinkronkan. Ini allreduksi statistik di seluruh proses selama pelatihan multiproses (distributedDataParallel). BN sinkron telah digunakan dalam kasus di mana hanya minibatch lokal kecil yang dapat muat pada setiap GPU. Allreduced Stats meningkatkan ukuran batch efektif untuk lapisan BN ke ukuran batch global di semua proses (yang, secara teknis, adalah formulasi yang benar). BN sinkron telah diamati untuk meningkatkan akurasi konvergen dalam beberapa model penelitian kami.
Untuk menyimpan dan memuat pelatihan amp
Anda dengan benar, kami memperkenalkan amp.state_dict()
, yang berisi semua loss_scalers
dan langkah -langkah yang tidak dikeluarkan, serta amp.load_state_dict()
untuk mengembalikan atribut ini.
Untuk mendapatkan akurasi bitwise, kami merekomendasikan alur kerja berikut:
# Initialization
opt_level = 'O1'
model , optimizer = amp . initialize ( model , optimizer , opt_level = opt_level )
# Train your model
...
with amp . scale_loss ( loss , optimizer ) as scaled_loss :
scaled_loss . backward ()
...
# Save checkpoint
checkpoint = {
'model' : model . state_dict (),
'optimizer' : optimizer . state_dict (),
'amp' : amp . state_dict ()
}
torch . save ( checkpoint , 'amp_checkpoint.pt' )
...
# Restore
model = ...
optimizer = ...
checkpoint = torch . load ( 'amp_checkpoint.pt' )
model , optimizer = amp . initialize ( model , optimizer , opt_level = opt_level )
model . load_state_dict ( checkpoint [ 'model' ])
optimizer . load_state_dict ( checkpoint [ 'optimizer' ])
amp . load_state_dict ( checkpoint [ 'amp' ])
# Continue training
...
Perhatikan bahwa kami sarankan mengembalikan model menggunakan opt_level
yang sama. Perhatikan juga bahwa kami merekomendasikan untuk memanggil metode load_state_dict
setelah amp.initialize
.
Setiap modul apex.contrib
memerlukan satu atau lebih opsi instal selain --cpp_ext
dan --cuda_ext
. Perhatikan bahwa modul yang berkontribusi tidak selalu mendukung pelepasan pytorch yang stabil.
Kontainer Nvidia Pytorch tersedia di NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch. Wadah datang dengan semua ekstensi khusus yang tersedia saat ini.
Lihat dokumentasi NGC untuk detail seperti:
Untuk menginstal apex dari sumber, kami sarankan menggunakan pytorch malam yang dapat diperoleh dari https://github.com/pytorch/pytorch.
Rilis stabil terbaru yang dapat diperoleh dari https://pytorch.org juga harus berfungsi.
Kami sarankan menginstal Ninja
untuk membuat kompilasi lebih cepat.
Untuk kinerja dan fungsionalitas penuh, kami sarankan menginstal APEX dengan ekstensi CUDA dan C ++ melalui
git clone https://github.com/NVIDIA/apex
cd apex
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings " --build-option=--cpp_ext " --config-settings " --build-option=--cuda_ext " ./
# otherwise
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option= " --cpp_ext " --global-option= " --cuda_ext " ./
Apex juga mendukung build khusus Python melalui
pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./
Bangunan khusus Python menghilangkan:
apex.optimizers.FusedAdam
.apex.normalization.FusedLayerNorm
dan apex.normalization.FusedRMSNorm
.apex.parallel.SyncBatchNorm
.apex.parallel.DistributedDataParallel
dan apex.amp
. DistributedDataParallel
, amp
, dan SyncBatchNorm
masih dapat digunakan, tetapi mereka mungkin lebih lambat. pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" .
Mungkin bekerja jika Anda dapat membangun Pytorch dari sumber di sistem Anda. Build khusus Python via pip install -v --no-cache-dir .
lebih cenderung bekerja.
Jika Anda memasang pytorch di lingkungan Conda, pastikan untuk memasang apex di lingkungan yang sama.
Jika persyaratan modul tidak terpenuhi, maka itu tidak akan dibangun.
Nama Modul | Instal Opsi | Misc |
---|---|---|
apex_C | --cpp_ext | |
amp_C | --cuda_ext | |
syncbn | --cuda_ext | |
fused_layer_norm_cuda | --cuda_ext | apex.normalization |
mlp_cuda | --cuda_ext | |
scaled_upper_triang_masked_softmax_cuda | --cuda_ext | |
generic_scaled_masked_softmax_cuda | --cuda_ext | |
scaled_masked_softmax_cuda | --cuda_ext | |
fused_weight_gradient_mlp_cuda | --cuda_ext | Membutuhkan cuda> = 11 |
permutation_search_cuda | --permutation_search | apex.contrib.sparsity |
bnp | --bnp | apex.contrib.groupbn |
xentropy | --xentropy | apex.contrib.xentropy |
focal_loss_cuda | --focal_loss | apex.contrib.focal_loss |
fused_index_mul_2d | --index_mul_2d | apex.contrib.index_mul_2d |
fused_adam_cuda | --deprecated_fused_adam | apex.contrib.optimizers |
fused_lamb_cuda | --deprecated_fused_lamb | apex.contrib.optimizers |
fast_layer_norm | --fast_layer_norm | apex.contrib.layer_norm . berbeda dari fused_layer_norm |
fmhalib | --fmha | apex.contrib.fmha |
fast_multihead_attn | --fast_multihead_attn | apex.contrib.multihead_attn |
transducer_joint_cuda | --transducer | apex.contrib.transducer |
transducer_loss_cuda | --transducer | apex.contrib.transducer |
cudnn_gbn_lib | --cudnn_gbn | Membutuhkan cudnn> = 8.5, apex.contrib.cudnn_gbn |
peer_memory_cuda | --peer_memory | apex.contrib.peer_memory |
nccl_p2p_cuda | --nccl_p2p | Membutuhkan nccl> = 2.10, apex.contrib.nccl_p2p |
fast_bottleneck | --fast_bottleneck | Membutuhkan peer_memory_cuda dan nccl_p2p_cuda , apex.contrib.bottleneck |
fused_conv_bias_relu | --fused_conv_bias_relu | Membutuhkan cudnn> = 8.4, apex.contrib.conv_bias_relu |