Este repositório mantém as concessionárias mantidas na NVIDIA para otimizar a precisão mista e o treinamento distribuído em Pytorch. Alguns do código aqui serão incluídos no Upstream Pytorch eventualmente. A intenção do Apex é disponibilizar utilitários atualizados para os usuários o mais rápido possível.
Descontinuado. Use Pytorch AMP
apex.amp
é uma ferramenta para ativar o treinamento de precisão mista, alterando apenas 3 linhas do seu script. Os usuários podem experimentar facilmente diferentes modos de treinamento de precisão pura e mista, fornecendo sinalizadores diferentes para amp.initialize
.
Webinar Apresentando o AMP (o flag cast_batchnorm
foi renomeado para keep_batchnorm_fp32
).
Documentação da API
Exemplo abrangente de imagenet
Exemplo de dcgan em breve ...
Mudando para a nova API do AMP (para usuários das APIs depreciadas "AMP" e "FP16_Optimizer")
apex.parallel.DistributedDataParallel
está descontinuado. Use torch.nn.parallel.DistributedDataParallel
apex.parallel.DistributedDataParallel
é um invólucro de módulo, semelhante a torch.nn.parallel.DistributedDataParallel
. Ele permite o conveniente treinamento distribuído multiprocess, otimizado para a Biblioteca de Comunicação NCCL da NVIDIA.
Documentação da API
Fonte Python
Exemplo/passo a passo
O exemplo do ImageNet mostra o uso de apex.parallel.DistributedDataParallel
junto com apex.amp
.
Descontinuado. Use torch.nn.SyncBatchNorm
apex.parallel.SyncBatchNorm
estende torch.nn.modules.batchnorm._BatchNorm
para suportar o BN sincronizado. ALLDUCE STATS ENTRE PROCESSOS DURANTE O TREINAMENTO MULTIPROCESS (DISTIBUDEDDATAPARALL). O BN síncrono tem sido usado nos casos em que apenas um pequeno minibatch local pode caber em cada GPU. As estatísticas alterações aumentam o tamanho efetivo do lote da camada BN para o tamanho do lote global em todos os processos (que, tecnicamente, é a formulação correta). O BN síncrono foi observado para melhorar a precisão convergente em alguns de nossos modelos de pesquisa.
Para salvar e carregar corretamente o seu treinamento amp
, apresentamos o amp.state_dict()
, que contém todos loss_scalers
e suas etapas não elaboradas correspondentes, bem como amp.load_state_dict()
para restaurar esses atributos.
Para obter uma precisão bittual, recomendamos o seguinte fluxo de trabalho:
# Initialization
opt_level = 'O1'
model , optimizer = amp . initialize ( model , optimizer , opt_level = opt_level )
# Train your model
...
with amp . scale_loss ( loss , optimizer ) as scaled_loss :
scaled_loss . backward ()
...
# Save checkpoint
checkpoint = {
'model' : model . state_dict (),
'optimizer' : optimizer . state_dict (),
'amp' : amp . state_dict ()
}
torch . save ( checkpoint , 'amp_checkpoint.pt' )
...
# Restore
model = ...
optimizer = ...
checkpoint = torch . load ( 'amp_checkpoint.pt' )
model , optimizer = amp . initialize ( model , optimizer , opt_level = opt_level )
model . load_state_dict ( checkpoint [ 'model' ])
optimizer . load_state_dict ( checkpoint [ 'optimizer' ])
amp . load_state_dict ( checkpoint [ 'amp' ])
# Continue training
...
Observe que recomendamos restaurar o modelo usando o mesmo opt_level
. Observe também que recomendamos chamar os métodos load_state_dict
após amp.initialize
.
Cada módulo apex.contrib
requer uma ou mais opções de instalação que não --cpp_ext
e --cuda_ext
. Observe que os módulos de contribuição não suportam necessariamente as liberações estáveis de Pytorch.
Os contêineres NVIDIA Pytorch estão disponíveis no NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch. Os contêineres vêm com todas as extensões personalizadas disponíveis no momento.
Veja a documentação do NGC para obter detalhes como:
Para instalar o APEX da fonte, recomendamos o uso do Pytorch noturno obtido em https://github.com/pytorch/pytorch.
A versão estável mais recente obtida em https://pytorch.org também deve funcionar.
Recomendamos instalar Ninja
para tornar a compilação mais rapidamente.
Para desempenho e funcionalidade completa, recomendamos a instalação do APEX com extensões CUDA e C ++ via
git clone https://github.com/NVIDIA/apex
cd apex
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings " --build-option=--cpp_ext " --config-settings " --build-option=--cuda_ext " ./
# otherwise
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option= " --cpp_ext " --global-option= " --cuda_ext " ./
Apex também suporta uma construção somente para Python via
pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./
Uma construção somente para Python omite:
apex.optimizers.FusedAdam
.apex.normalization.FusedLayerNorm
e apex.normalization.FusedRMSNorm
.apex.parallel.SyncBatchNorm
.apex.parallel.DistributedDataParallel
e apex.amp
. DistributedDataParallel
, amp
e SyncBatchNorm
ainda serão utilizáveis, mas podem ser mais lentos. pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" .
Pode funcionar se você puder construir Pytorch a partir da fonte em seu sistema. Uma construção somente para Python via pip install -v --no-cache-dir .
é mais provável que funcione.
Se você instalou o Pytorch em um ambiente do CONDA, instale o APEX no mesmo ambiente.
Se um requisito de um módulo não for atendido, ele não será construído.
Nome do módulo | Opção de instalação | Misc |
---|---|---|
apex_C | --cpp_ext | |
amp_C | --cuda_ext | |
syncbn | --cuda_ext | |
fused_layer_norm_cuda | --cuda_ext | apex.normalization |
mlp_cuda | --cuda_ext | |
scaled_upper_triang_masked_softmax_cuda | --cuda_ext | |
generic_scaled_masked_softmax_cuda | --cuda_ext | |
scaled_masked_softmax_cuda | --cuda_ext | |
fused_weight_gradient_mlp_cuda | --cuda_ext | Requer cuda> = 11 |
permutation_search_cuda | --permutation_search | apex.contrib.sparsity |
bnp | --bnp | apex.contrib.groupbn |
xentropy | --xentropy | apex.contrib.xentropy |
focal_loss_cuda | --focal_loss | apex.contrib.focal_loss |
fused_index_mul_2d | --index_mul_2d | apex.contrib.index_mul_2d |
fused_adam_cuda | --deprecated_fused_adam | apex.contrib.optimizers |
fused_lamb_cuda | --deprecated_fused_lamb | apex.contrib.optimizers |
fast_layer_norm | --fast_layer_norm | apex.contrib.layer_norm . Diferente de fused_layer_norm |
fmhalib | --fmha | apex.contrib.fmha |
fast_multihead_attn | --fast_multihead_attn | apex.contrib.multihead_attn |
transducer_joint_cuda | --transducer | apex.contrib.transducer |
transducer_loss_cuda | --transducer | apex.contrib.transducer |
cudnn_gbn_lib | --cudnn_gbn | Requer cudnn> = 8.5, apex.contrib.cudnn_gbn |
peer_memory_cuda | --peer_memory | apex.contrib.peer_memory |
nccl_p2p_cuda | --nccl_p2p | Requer NCCL> = 2.10, apex.contrib.nccl_p2p |
fast_bottleneck | --fast_bottleneck | Requer peer_memory_cuda e nccl_p2p_cuda , apex.contrib.bottleneck |
fused_conv_bias_relu | --fused_conv_bias_relu | Requer cudnn> = 8.4, apex.contrib.conv_bias_relu |