Keras 3 是一个多后端深度学习框架,支持 JAX、TensorFlow 和 PyTorch。轻松构建和训练计算机视觉、自然语言处理、音频处理、时间序列预测、推荐系统等模型。
与近 300 万开发人员(从新兴初创公司到跨国企业)一起利用 Keras 3 的强大功能。
Keras 3 在 PyPI 上以keras
形式提供。请注意,Keras 2 仍可作为tf-keras
软件包使用。
keras
: pip install keras --upgrade
要使用keras
,您还应该安装选择的后端: tensorflow
、 jax
或torch
。请注意,使用某些 Keras 3 功能需要tensorflow
:某些预处理层以及tf.data
管道。
Keras 3 兼容 Linux 和 MacOS 系统。对于 Windows 用户,我们建议使用 WSL2 来运行 Keras。要安装本地开发版本:
pip install -r requirements.txt
python pip_build.py --install
keras_export
公共 API 的 PR 时运行 API 生成脚本: ./shell/api_gen.sh
requirements.txt
文件将安装仅 CPU 版本的 TensorFlow、JAX 和 PyTorch。对于GPU支持,我们还为TensorFlow、JAX和PyTorch提供了单独的requirements-{backend}-cuda.txt
。它们通过pip
安装所有 CUDA 依赖项,并期望预安装 NVIDIA 驱动程序。我们建议每个后端使用干净的 python 环境,以避免 CUDA 版本不匹配。作为示例,以下是如何使用conda
创建 Jax GPU 环境:
conda create -y -n keras-jax python=3.10
conda activate keras-jax
pip install -r requirements-jax-cuda.txt
python pip_build.py --install
您可以导出环境变量KERAS_BACKEND
,也可以在~/.keras/keras.json
编辑本地配置文件来配置后端。可用的后端选项有: "tensorflow"
、 "jax"
、 "torch"
。例子:
export KERAS_BACKEND="jax"
在 Colab 中,您可以执行以下操作:
import os
os . environ [ "KERAS_BACKEND" ] = "jax"
import keras
注意:导入keras
之前必须配置后端,导入包后不能更改后端。
Keras 3 旨在作为tf.keras
的直接替代品(当使用 TensorFlow 后端时)。只需使用现有的tf.keras
代码,确保对model.save()
调用使用最新的.keras
格式,然后就完成了。
如果您的tf.keras
模型不包含自定义组件,您可以立即开始在 JAX 或 PyTorch 上运行它。
如果它确实包含自定义组件(例如自定义层或自定义train_step()
),通常可以在短短几分钟内将其转换为与后端无关的实现。
此外,Keras 模型可以使用任何格式的数据集,无论您使用的后端是什么:您可以使用现有的tf.data.Dataset
管道或 PyTorch DataLoaders
来训练模型。
Module
的一部分或 JAX 原生模型函数的一部分。请阅读 Keras 3 发布公告了解更多信息。