文档|示例
Sonnet 是一个构建在 TensorFlow 2 之上的库,旨在为机器学习研究提供简单、可组合的抽象。
Sonnet 是由 DeepMind 的研究人员设计和构建的。它可用于构建用于许多不同目的的神经网络(无监督学习、强化学习……)。我们发现它对我们的组织来说是一个成功的抽象,您也可能如此!
更具体地说,Sonnet 提供了一个简单但功能强大的编程模型,该模型围绕一个概念: snt.Module
。模块可以保存对参数、其他模块和对用户输入应用某些功能的方法的引用。 Sonnet 附带了许多预定义的模块(例如snt.Linear
、 snt.Conv2D
、 snt.BatchNorm
)和一些预定义的模块网络(例如snt.nets.MLP
),但也鼓励用户构建自己的模块。
与许多框架不同,Sonnet 对于如何使用模块没有任何意见。模块被设计为独立的并且彼此完全解耦。 Sonnet 不附带培训框架,鼓励用户构建自己的框架或采用其他人构建的框架。
Sonnet 的设计也易于理解,我们的代码(希望如此!)清晰且重点突出。在我们选择默认值(例如初始参数值的默认值)的地方,我们尝试指出原因。
尝试 Sonnet 的最简单方法是使用 Google Colab,它提供了一个连接到 GPU 或 TPU 的免费 Python 笔记本。
snt.distribute
进行分布式训练要开始安装 TensorFlow 2.0 和 Sonnet 2:
$ pip install tensorflow tensorflow-probability
$ pip install dm-sonnet
您可以运行以下命令来验证安装是否正确:
import tensorflow as tf
import sonnet as snt
print ( "TensorFlow version {}" . format ( tf . __version__ ))
print ( "Sonnet version {}" . format ( snt . __version__ ))
Sonnet 附带了许多您可以轻松使用的内置模块。例如,要定义 MLP,我们可以使用snt.Sequential
模块来调用一系列模块,将给定模块的输出作为下一个模块的输入。我们可以使用snt.Linear
和tf.nn.relu
来实际定义我们的计算:
mlp = snt . Sequential ([
snt . Linear ( 1024 ),
tf . nn . relu ,
snt . Linear ( 10 ),
])
要使用我们的模块,我们需要“调用”它。 Sequential
模块(以及大多数模块)定义了一个__call__
方法,这意味着您可以通过名称调用它们:
logits = mlp ( tf . random . normal ([ batch_size , input_size ]))
请求模块的所有参数也很常见。 Sonnet 中的大多数模块在第一次使用某些输入调用时都会创建其参数(因为在大多数情况下,参数的形状是输入的函数)。 Sonnet 模块提供了两个用于访问参数的属性。
variables
属性返回给定模块引用的所有tf.Variable
:
all_variables = mlp . variables
值得注意的是tf.Variable
不仅仅用于模型的参数。例如,它们用于保存snt.BatchNorm
中使用的指标的状态。在大多数情况下,用户检索模块变量以将它们传递给优化器进行更新。在这种情况下,不可训练的变量通常不应出现在该列表中,因为它们是通过不同的机制进行更新的。 TensorFlow 有一个内置机制,可将变量标记为“可训练”(模型的参数)与不可训练(其他变量)。 Sonnet 提供了一种从模块中收集所有可训练变量的机制,这可能是您想要传递给优化器的内容:
model_parameters = mlp . trainable_variables
Sonnet 强烈鼓励用户子类化snt.Module
来定义自己的模块。让我们首先创建一个名为MyLinear
的简单Linear
层:
class MyLinear ( snt . Module ):
def __init__ ( self , output_size , name = None ):
super ( MyLinear , self ). __init__ ( name = name )
self . output_size = output_size
@ snt . once
def _initialize ( self , x ):
initial_w = tf . random . normal ([ x . shape [ 1 ], self . output_size ])
self . w = tf . Variable ( initial_w , name = "w" )
self . b = tf . Variable ( tf . zeros ([ self . output_size ]), name = "b" )
def __call__ ( self , x ):
self . _initialize ( x )
return tf . matmul ( x , self . w ) + self . b
使用这个模块很简单:
mod = MyLinear ( 32 )
mod ( tf . ones ([ batch_size , input_size ]))
通过子类化snt.Module
您可以免费获得许多不错的属性。例如__repr__
的默认实现,它显示构造函数参数(对于调试和内省非常有用):
>> > print ( repr ( mod ))
MyLinear ( output_size = 10 )
您还可以获得variables
和trainable_variables
属性:
>> > mod . variables
( < tf . Variable 'my_linear/b:0' shape = ( 10 ,) ...) > ,
< tf . Variable 'my_linear/w:0' shape = ( 1 , 10 ) ...) > )
您可能会注意到上面变量上的my_linear
前缀。这是因为每当调用方法时,Sonnet 模块也会进入模块名称范围。通过输入模块名称范围,我们为 TensorBoard 之类的工具提供了一个更有用的图表来使用(例如,my_linear 内发生的所有操作都将位于名为 my_linear 的组中)。
此外,您的模块现在将支持 TensorFlow 检查点和保存模型,这些是稍后介绍的高级功能。
Sonnet 支持多种序列化格式。我们支持的最简单的格式是 Python 的pickle
,并且所有内置模块都经过测试,以确保它们可以在同一 Python 进程中通过 pickle 保存/加载。一般来说,我们不鼓励使用 pickle,它没有得到 TensorFlow 许多部分的良好支持,而且根据我们的经验,它可能非常脆弱。
参考: https://www.tensorflow.org/alpha/guide/checkpoints
TensorFlow 检查点可用于在训练期间定期保存参数值。这对于保存训练进度很有用,以防程序崩溃或停止。 Sonnet 旨在与 TensorFlow 检查点完美配合:
checkpoint_root = "/tmp/checkpoints"
checkpoint_name = "example"
save_prefix = os . path . join ( checkpoint_root , checkpoint_name )
my_module = create_my_sonnet_module () # Can be anything extending snt.Module.
# A `Checkpoint` object manages checkpointing of the TensorFlow state associated
# with the objects passed to it's constructor. Note that Checkpoint supports
# restore on create, meaning that the variables of `my_module` do **not** need
# to be created before you restore from a checkpoint (their value will be
# restored when they are created).
checkpoint = tf . train . Checkpoint ( module = my_module )
# Most training scripts will want to restore from a checkpoint if one exists. This
# would be the case if you interrupted your training (e.g. to use your GPU for
# something else, or in a cloud environment if your instance is preempted).
latest = tf . train . latest_checkpoint ( checkpoint_root )
if latest is not None :
checkpoint . restore ( latest )
for step_num in range ( num_steps ):
train ( my_module )
# During training we will occasionally save the values of weights. Note that
# this is a blocking call and can be slow (typically we are writing to the
# slowest storage on the machine). If you have a more reliable setup it might be
# appropriate to save less frequently.
if step_num and not step_num % 1000 :
checkpoint . save ( save_prefix )
# Make sure to save your final values!!
checkpoint . save ( save_prefix )
参考: https://www.tensorflow.org/alpha/guide/saved_model
TensorFlow 保存的模型可用于保存与 Python 源分离的网络副本。这是通过保存描述计算的 TensorFlow 图和包含权重值的检查点来实现的。
为了创建保存的模型,要做的第一件事是创建一个要保存的snt.Module
:
my_module = snt . nets . MLP ([ 1024 , 1024 , 10 ])
my_module ( tf . ones ([ 1 , input_size ]))
接下来,我们需要创建另一个模块来描述我们想要导出的模型的特定部分。我们建议这样做(而不是就地修改原始模型),以便您可以对实际导出的内容进行细粒度控制。这通常很重要,可以避免创建非常大的保存模型,这样您就可以只共享模型中您想要的部分(例如,您只想共享 GAN 的生成器,但保持鉴别器私有)。
@ tf . function ( input_signature = [ tf . TensorSpec ([ None , input_size ])])
def inference ( x ):
return my_module ( x )
to_save = snt . Module ()
to_save . inference = inference
to_save . all_variables = list ( my_module . variables )
tf . saved_model . save ( to_save , "/tmp/example_saved_model" )
现在,我们在/tmp/example_saved_model
文件夹中有一个保存的模型:
$ ls -lh /tmp/example_saved_model
total 24K
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:14 assets
-rw-rw-r-- 1 tomhennigan 154432098 14K Apr 28 00:15 saved_model.pb
drwxrwsr-t 2 tomhennigan 154432098 4.0K Apr 28 00:15 variables
加载此模型非常简单,可以在不同的计算机上完成,无需任何构建已保存模型的 Python 代码:
loaded = tf . saved_model . load ( "/tmp/example_saved_model" )
# Use the inference method. Note this doesn't run the Python code from `to_save`
# but instead uses the TensorFlow Graph that is part of the saved model.
loaded . inference ( tf . ones ([ 1 , input_size ]))
# The all_variables property can be used to retrieve the restored variables.
assert len ( loaded . all_variables ) > 0
请注意,加载的对象不是 Sonnet 模块,而是一个容器对象,具有我们在上一个块中添加的特定方法(例如inference
)和属性(例如all_variables
)。
示例: https://github.com/deepmind/sonnet/blob/v2/examples/distributed_cifar10.ipynb
Sonnet 支持使用自定义 TensorFlow 分发策略进行分布式训练。
Sonnet 和使用tf.keras
的分布式训练之间的一个关键区别是,Sonnet 模块和优化器在分布策略下运行时的行为没有不同(例如,我们不会平均您的梯度或同步您的批量标准统计数据)。我们认为用户应该完全控制他们的培训的这些方面,并且他们不应该被纳入库中。这里的权衡是,您需要在训练脚本中实现这些功能(通常这只是 2 行代码,以便在应用优化器之前减少梯度)或交换明确分布感知的模块(例如snt.distribute.CrossReplicaBatchNorm
)。
我们的分布式 Cifar-10 示例演示了如何使用 Sonnet 进行多 GPU 训练。