Status atual do IC:
O Pytorch/XLA é um pacote Python que usa o compilador de aprendizado profundo do XLA para conectar a estrutura de aprendizado profundo do Pytorch e as TPUs em nuvem. Você pode experimentar agora, de graça, em uma única TPU VM com Kaggle!
Dê uma olhada em um dos nossos cadernos Kaggle para começar:
Para instalar a compilação estável pytorch/xla em uma nova TPU VM:
pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html
Para instalar o Pytorch/XLA Nightly Build em uma nova TPU VM:
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
Pytorch/XLA agora fornece suporte à GPU através de um pacote de plug -in semelhante ao libtpu
:
pip install torch~=2.5.0 torch_xla~=2.5.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.5.0-py3-none-any.whl
Para atualizar seu loop de treinamento existente, faça as seguintes alterações:
- import torch.multiprocessing as mp
+ import torch_xla as xla
+ import torch_xla.core.xla_model as xm
def _mp_fn(index):
...
+ # Move the model paramters to your XLA device
+ model.to(xla.device())
for inputs, labels in train_loader:
+ with xla.step():
+ # Transfer data to the XLA device. This happens asynchronously.
+ inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
- optimizer.step()
+ # `xm.optimizer_step` combines gradients across replicas
+ xm.optimizer_step(optimizer)
if __name__ == '__main__':
- mp.spawn(_mp_fn, args=(), nprocs=world_size)
+ # xla.launch automatically selects the correct world size
+ xla.launch(_mp_fn, args=())
Se você estiver usando DistributedDataParallel
, faça as seguintes alterações:
import torch.distributed as dist
- import torch.multiprocessing as mp
+ import torch_xla as xla
+ import torch_xla.distributed.xla_backend
def _mp_fn(rank):
...
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = '12355'
- dist.init_process_group("gloo", rank=rank, world_size=world_size)
+ # Rank and world size are inferred from the XLA device runtime
+ dist.init_process_group("xla", init_method='xla://')
+
+ model.to(xm.xla_device())
+ # `gradient_as_bucket_view=True` required for XLA
+ ddp_model = DDP(model, gradient_as_bucket_view=True)
- model = model.to(rank)
- ddp_model = DDP(model, device_ids=[rank])
for inputs, labels in train_loader:
+ with xla.step():
+ inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
optimizer.zero_grad()
outputs = ddp_model(inputs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
if __name__ == '__main__':
- mp.spawn(_mp_fn, args=(), nprocs=world_size)
+ xla.launch(_mp_fn, args=())
Informações adicionais sobre Pytorch/XLA, incluindo uma descrição de sua semântica e funções, estão disponíveis em pytorch.org. Consulte o Guia da API para obter as melhores práticas ao escrever redes que são executadas em dispositivos XLA (TPU, CUDA, CPU e ...).
Nossos guias abrangentes de usuário estão disponíveis em:
Documentação para o último lançamento
Documentação para filial mestre
Os lançamentos Pytorch/XLA começando com a versão R2.1 estarão disponíveis no Pypi. Agora você pode instalar a compilação principal com pip install torch_xla
. Para também instalar o plug -in da Cloud TPU correspondente ao seu torch_xla
instalado, instale as dependências opcionais tpu
após a instalação da compilação principal com
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
A GPU e as construções noturnas estão disponíveis em nosso balde público GCS.
Versão | Rodas VM da GPU em nuvem |
---|---|
2.5 (CUDA 12.1 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
Nightly (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp38-cp38-linux_x86_64.whl |
Nightly (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl |
Nightly (Cuda 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.6.0.dev-cp38-cp38-linux_x86_64.whl |
pip3 install torch==2.6.0.dev20240925+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly%2B20240925-cp310-cp310-linux_x86_64.whl
A versão da roda da tocha versão 2.6.0.dev20240925+cpu
pode ser encontrada em https://download.pytorch.org/whl/nightly/torch/.
Você também pode adicionar yyyymmdd
após torch_xla-2.6.0.dev
para obter a roda noturna de uma data especificada. Aqui está um exemplo:
pip3 install torch==2.5.0.dev20240820+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240820-cp310-cp310-linux_x86_64.whl
A versão da roda da tocha versão 2.6.0.dev20240925+cpu
pode ser encontrada em https://download.pytorch.org/whl/nightly/torch/.
Versão | Roda de VMs TPU em nuvem |
---|---|
2.4 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.3 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.2 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 (xrt + python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl |
Versão | Roda de GPU |
---|---|
2.5 (CUDA 12.1 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.4 (CUDA 12.1 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.4 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.4 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.3 (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.3 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.3 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.2 (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.2 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 + CUDA 11.8 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl |
Nightly + Cuda 12.0> = 2023/06/27 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
Versão | Cloud TPU VMS Docker |
---|---|
2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_tpuvm |
2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm |
2.3 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm |
2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm |
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm |
Python noturno | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm |
Para usar os Dockers acima, passe --privileged --net host --shm-size=16G
junto. Aqui está um exemplo:
docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash
Versão | GPU CUDA 12.4 DOCKER |
---|---|
2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4 |
2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.4 |
Versão | GPU CUDA 12.1 Docker |
---|---|
2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.1 |
2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.1 |
2.3 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1 |
2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1 |
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1 |
todas as noites | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1 |
todas as noites na data | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1_YYYYMMDD |
Versão | GPU CUDA 11.8 + Docker |
---|---|
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8 |
2.0 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.0_3.8_cuda_11.8 |
Para executar em instâncias de computação com GPUs.
Se o Pytorch/XLA não estiver funcionando como esperado, consulte o Guia de Solução de Problemas, que possui sugestões para depuração e otimização de suas redes.
A equipe Pytorch/XLA está sempre feliz em ouvir os usuários e os colaboradores da OSS! A melhor maneira de alcançar é arquivando um problema neste github. Perguntas, relatórios de bugs, solicitações de recursos, problemas de construção etc. são todos bem -vindos!
Veja o guia de contribuição.
Este repositório é operado e mantido em conjunto pelo Google, Meta e vários colaboradores individuais listados no arquivo dos colaboradores. Para perguntas direcionadas à Meta, envie um e -mail para [email protected]. Para perguntas direcionadas ao Google, envie um e-mail para [email protected]. Para todas as outras perguntas, abra um problema neste repositório aqui.
Você pode encontrar materiais de leitura úteis adicionais em