Sandwood 是一种基于 JVM 的概率模型的语言、编译器和运行时。它旨在允许使用 Java 开发人员熟悉的语言编写模型。生成的模型采用 Java 对象的形式,使它们成为周围系统的良好抽象组件。
使用传统的贝叶斯模型,用户必须设计模型,然后为他们希望在模型上执行的任何操作实现推理代码。这会产生许多问题:
构建推理代码在技术上具有挑战性,而且非常耗时。此步骤提供了引入细微错误的机会。
如果模型被修改,那么推理代码也必须更新。这也很耗时且技术上具有挑战性,导致以下问题:
它可以阻止修改模型。
不同的推理操作可能会不一致,因此一些工作针对旧模型,一些工作针对新模型。
当用户尝试对现有代码进行细微调整时,它为错误进入推理算法提供了另一个机会。
概率编程通过允许使用 API 或领域特定语言 (DSL) 来描述模型(如 Sandwood 的情况)来克服这些问题。 Sandwood DSL 被编译为生成代表模型并实现所有所需推理操作的 Java 类。这有很多优点:
Sandwood 由 3 个组件组成,每个组件都在其对应的目录中:
每个部分都依赖于前面的部分。每个组件目录都包含一个用于构建组件的 Maven POM 文件。对于编译器和插件,需要使用install
来调用它们,以使它们可用于后续阶段,即mvn clean install
。这些示例只能构建为mvn clean package
。
安装 Sandwood 后,目前有 3 种编译模型的方法:
构建编译器和运行时后,要从命令行使用 Sandwood,可以在commandline/SandwoodC/bin
中找到与javac
具有类似功能的命令行脚本。要使用此功能,用户通常将 bin 目录添加到路径中,然后调用 sandwoodc.sh HMM.sandwood 来编译 HMM 模型。 sandwoodc.sh -h
或sandwoodc.bat -h
将打印出用法和可用选项的描述。
SandwoodC 的所有功能都可以通过调用org.sandwood.compilation.SandwoodC
中的compile
方法并传递一个包含已传递到命令行的参数的数组来实现。
Maven插件可用于在依赖项目构建时自动触发sandwood文件的编译。要使用该插件,您需要添加 sandwood 运行时作为依赖项,并将插件添加到构建中。这是通过在 POM 文件中添加以下内容来实现的:
<dependencies>
<dependency>
<groupId>org.sandwood</groupId>
<artifactId>sandwood-runtime</artifactId>
<version>0.3.0</version>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.sandwood</groupId>
<artifactId>sandwoodc-maven-plugin</artifactId>
<version>0.3-SNAPSHOT</version>
<executions>
<execution>
<configuration>
<partialInferenceWarning>true</partialInferenceWarning>
<sourceDirectory>${basedir}/src/main/java</sourceDirectory>
</configuration>
<goals>
<goal>sandwoodc</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>`
包含元素<sourceDirectory>${basedir}/src/main/java</sourceDirectory>
指示插件在哪个目录中查找模型。其他有用的标志包括:
debug
该选项用于从SandwoodC 获取调试信息。将此选项设置为true
会导致 Sandwood 生成其操作的痕迹。默认值为false
。请注意,此标志用于调试编译器配置/编译器的错误,而不是用于调试正在编译的模型的错误。编译器将始终返回 Sandwood 模型文件中的错误和警告。
partialInferenceWarning
该选项用于在无法构建某些推理步骤时阻止 SandwoodC 失败。将此选项设置为true
会导致 Sandwood 仅针对缺少的步骤生成警告。默认值为false
。
sourceDirectory
此参数设置在哪个目录中查找模型文件。在此目录中,模型可以位于不同的包中。
outputDirectory
此参数设置模型的 Java 源代码应放置在哪个目录中。默认值为${project.build.directory}/generated-sources/sandwood
。
calculateIndividualProbabilities
此参数指定是否应计算循环中构造的每个随机变量的概率,而不是计算所有实例的单个值。默认值为false
。
javadoc
此参数指示编译器生成 JavaDoc 以补充模型。默认值为false
。
javadocDirectory
该参数指定生成的文档应放置的位置。
executable
此参数允许指定替代 JVM 来运行 Sandwood 编译器。
接下来介绍如何编写 Sandwood 模型以及如何使用实现模型的结果类。
在此图中可以看到模型所经历的步骤的概述。模型以.sandwood
文件开始,该文件被编译为一组类文件。这些可以被实例化多次,以生成具有不同配置的模型的多个实例。
作为一个运行示例,我们将使用隐马尔可夫模型(HMM)。该模型是用 Sandwood 编写的。该模型应保存在包目录org/sandwood/examples/hmm
中名为HMM.sandwood
的文件中。可以在此处找到该语言的更完整描述。
package org . sandwood . examples . hmm ;
model HMM ( int [] eventsMeasured , int numStates , int numEvents ) {
//Construct a transition matrix m.
double [] v = new double [ numStates ] <~ 0.1 ;
double [][] m = dirichlet ( v ). sample ( numStates );
//Construct weighting for which state to start in.
double [] initialState = new Dirichlet ( v ). sample ();
//Construct weighting for each event in each state.
double [] w = new double [ numEvents ] <~ 0.1 ;
double [][] bias = dirichlet ( w ). sample ( numStates );
//Allocate space to record the sequence of states.
int sequenceLength = eventsMeasured . length ;
int [] st = new int [ sequenceLength ];
//Calculate the movements between states.
st [ 0 ] = categorical ( initialState ). sampleDistribution ();
for ( int i : [ 1. . sequenceLength ) )
st [ i ] = categorical ( m [ st [ i - 1 ]]). sampleDistribution ();
//Emit the events for each state.
int [] events = new int [ sequenceLength ];
for ( int j = 0 ; j < sequenceLength ; j ++)
events [ j ] = new Categorical ( bias [ st [ j ]]). sample ();
//Assert that the events match the eventsMeasured data.
events . observe ( eventsMeasured );
}
除了 Sandwood 语言的文档和可为模型生成的 JavaDoc 注释之外,Sandwood Examples 目录中还有许多示例,我们建议新用户首先检查和修改这些示例。
用于描述 Sandwood 模型的语言的描述可以在此处找到。该语言的构建目的是让 Java 开发人员熟悉,但不包含构建对象的能力。我们计划将来添加对记录类型的支持,以使模型中的数据导入和导出更加简单。
编译模型时,会在定义模型的同一包中生成许多类文件。其中一个类的名称将与提供给模型的名称相同,因此在本例中为HMM.class ,并且此是用户应该实例化以获得模型实例的类。模型中的每个公开可见的变量对应于生成的类中的一个字段。示例 HMM 如下所示。
通过运行带有javadoc
标志集的编译器,将为生成的模型文件中的每个公共方法和类创建 JavaDoc。
模型编译完成后,我们需要实例化它的实例。这些实例是独立的,用户可以根据需要创建任意多个不同的模型副本。
模型对象的实例是通过类构造函数构造的。如前所述,模型通常有 3 个构造函数。唯一会减少的情况是构造函数的不同变体映射到相同的签名,在这种情况下,一个构造函数将适用于多个场景。
完整构造函数 - 此构造函数获取模型签名中出现的所有参数并设置它们。此构造函数用于推断值和推断概率操作。
空构造函数 - 此构造函数不带任何参数,将参数留给用户稍后设置。
执行构造函数 - 此构造函数删除仅观察到的参数,并且对于其维度用作代码输入的观察到的参数,采用这些维度而不是完整参数。因此,在 HMM 示例中,eventsMeasured 参数将变为描述序列长度的整数。
这些代码示例演示了如何调用已编译的模型。
通过模型对象与模型的交互有两种形式:
调用模型对象方法进行全局操作,例如设置默认保留策略、检查模型是否准备好进行推理以及启动推理步骤等。
调用模型参数对象。模型中的每个命名公共变量都由模型对象中的相应字段表示。如果变量在模型的最外层作用域中声明且未标记为private
,或者在内部作用域中声明且未标记为public
,则变量是公共的。如果某个字段在内部迭代作用域中声明为公共,例如 for 循环体,则将存储每次迭代的值。
对象的类型取决于变量。这些可以分为 3 类:
这些字段中的每一个都引用一个具有一组方法的对象,这些方法允许用户从参数中设置和读取值和属性。可以设置和读取的属性包括参数的概率、参数的保留策略以及参数是否应固定为其当前值。
参数对象在进行模型推理时比较重要的一些方法有:
getSamples 返回采样值。
getMAP 返回最大后验值。
setValue 允许将值设置为特定值。
setFixed 接受一个boolean
来将值标记为固定,因此在推理过程中不会更新。在修复参数之前设置参数值非常重要。
getLogProbability 在推断概率后获取变量的对数概率。
还有更多方法,我们建议您查阅 JavaDoc 来熟悉它们。
可以在模型上执行 3 种基本类型的操作:
setRentionPolicy
方法来为整个模型设置保留策略。然后,可以选择通过调用每个变量对象中相应的setRetentionPolicy
方法来设置各个变量的保留策略。抽样政策有3种:
NONE不记录值。如果其中一个变量很大,那么花费时间和空间来存储它会很浪费,这一点特别有用。
SAMPLE记录推理算法每次迭代的值,因此,如果执行 1000 次迭代,将从为此保留策略设置的每个变量中采样 1000 个值。这对于计算方差和平均值很有用。但这有一个弱点,如果模型内值的位置在推理过程中可以移动,则无法对这些值进行平均。例如,对于主题模型,主题 2 和 3 可能在推理过程中交换位置,因此对主题 2 的所有值进行平均,生成主题 2 和主题 3 的混合。为了克服这个问题,还提供了最大后验概率 (MAP)保留政策。
MAP或最大后验概率 (MAP) 记录模型处于最可能状态时的变量值。这克服了瞬态值位置的问题,这意味着值无法被平均,但代价是无法计算边界。如果某些变量很大,此选项还具有空间优势。
配置:模型对象上的其他方法调用允许用户在执行此推理步骤时设置属性,例如老化和细化。 Burnin 忽略前n次迭代的值,允许模型在开始采样之前远离低概率起点。细化通过仅考虑每n次迭代的值来减少 MCMC 过程引起的自相关。
推断概率设置模型中部分或所有参数的值后,即可计算生成这些值的概率。这可以针对模型中的每个变量以及整个模型进行计算。
执行模型运行模型,就好像它是为用户未修复的任何参数生成新值的常规代码一样。何时使用此行为的一个示例是线性回归模型。在这种情况下,首先使用训练数据推断模型系数。一旦它们被推断出来,它们将被固定并成为新的输入数据集。然后执行该模型以生成该新输入数据的相应预测。这种执行形式还可用于从经过训练的模型生成代表性合成数据。
构建和训练模型
//Load inputs
int nStates = 25 ;
int [] actions = loadActions (....);
int nActions = maxActions (....);
//Construct the model
HMM model = new HMM ( actions , nActions , nStates );
//Set the retention policies
model . setDefaultRetentionPolicy ( RetentionPolicy . MAP );
model . st . setRetentionPolicy ( RetentionPolicy . NONE );
//Pick a random number generator. The ones introduced in Java 17 are faster and better quality.
model . setRNGType ( RandomType . L64X1024MixRandom );
//Instruct the model to use the ForkJoin framework for parallel execution.
model . setExecutionTarget ( ExecutionTarget . forkJoin );
//Run 2000 inference steps to infer model values
model . inferValues ( 2000 );
//Gather the results.
double [] initialState = model . initialState . getMAP ();
double [][] bias = model . bias . getMAP ();
double [][] transitions = model . m . getMAP ();
构建模型并推断概率
//Load inputs
int nStates = 25 ;
int [] actions = loadActions (....);
int nActions = maxActions (....);
//Load model parameters
double [][] bias = model . bias . getMAP ();
double [][] transitions = model . m . getMAP ();
//Construct the model
HMM model = new HMM ( actions , nActions , nStates );
//Set and fix trained values
model . bias . setValue ( bias );
Model . m . setValue ( transitions );
//Run 2000 inference steps to infer probabilities
model . inferProbabilities ( 2000 );
//Recover the probabilities of the model parameter actions.
double actionsProbability = model . actions . getProbability ();
//Recover the probability of the model as a whole
double modelProbability = model . getProbability ()
如需有关 Sandwood 的帮助,请在讨论页面上发起或加入讨论。
该项目欢迎社区的贡献。在提交拉取请求之前,请查看我们的贡献指南。
请参阅安全指南,了解我们负责任的安全漏洞披露流程。
版权所有 (c) 2019-2024 Oracle 和/或其附属公司。
根据通用许可许可证 v1.0 发布,如 https://oss.oracle.com/licenses/upl/ 所示。