这是一个用于试验不同主动学习算法的 python 模块。运行主动学习实验有几个关键组成部分:
主要实验脚本是run_experiment.py
其中包含许多用于不同运行选项的标志。
可以通过运行utils/create_data.py
将支持的数据集下载到指定目录。
支持的主动学习方法位于sampling_methods
中。
下面我将更详细地介绍每个组件。
免责声明:这不是 Google 官方产品。
依赖项位于requirements.txt
中。 请确保在运行实验之前安装这些软件包。 如果需要支持 GPU 的张tensorflow
,请按照此处的说明进行操作。
强烈建议您将所有依赖项安装到单独的virtualenv
中,以便于包管理。
默认情况下,数据集保存到/tmp/data
。您可以通过--save_dir
标志指定另一个目录。
重新下载所有数据集将非常耗时,因此请耐心等待。您可以通过--datasets
标志传入逗号分隔的数据集字符串来指定要下载的数据子集。
run_experiment.py
有几个关键标志:
dataset
:数据集的名称,必须与create_data.py
中使用的保存名称匹配。还必须存在于 data_dir 中。
sampling_method
:使用的主动学习方法。必须在sampling_methods/constants.py
中指定。
warmstart_size
:用作种子数据的初始批次均匀采样示例。浮点表示总训练数据的百分比,整数表示原始大小。
batch_size
:每批中请求的数据点数量。浮点表示总训练数据的百分比,整数表示原始大小。
score_method
:用于评估采样方法性能的模型。必须位于utils/utils.py
的get_model
方法中。
data_dir
:保存数据集的目录。
save_dir
:保存结果的目录。
这只是所有标志的子集。还有一些选项用于预处理、引入标签噪声、数据集二次采样以及使用不同的模型进行选择而不是评分/评估。
所有命名的主动学习方法都位于sampling_methods/constants.py
中。
您还可以按照由破折号分隔的[sampling_method]-[mixture_weight]
模式来指定主动学习方法的混合;即mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34
。
一些支持的采样方法包括:
均匀:通过均匀抽样来选择样本。
裕度:基于不确定性的抽样方法。
信息丰富且多样化:基于边际和聚类的抽样方法。
k 中心贪婪:贪婪地形成一批点以最小化距标记点的最大距离的代表性策略。
图密度:在池的密集区域中选择点的代表性策略。
Exp3 bandit:元主动学习方法,尝试使用流行的多臂老虎机算法来学习最佳采样方法。
实现继承自SamplingMethod
的基本采样器或调用继承自WrapperSamplingMethod
的基本采样器的元采样器。
任何采样器必须实现的唯一方法是select_batch_
,它可以具有任意命名参数。唯一的限制是相同输入的名称必须在所有采样器中保持一致(即已选定示例的索引在采样器中都具有相同的名称)。添加尚未在其他采样方法中使用的新命名参数需要将其输入到run_experiment.py
中的select_batch
调用中。
实现采样器后,请务必将其添加到constants.py
以便可以从run_experiment.py
调用它。
所有可用的模型都在utils/utils.py
的get_model
方法中。
支持的方法:
线性 SVM:带有网格搜索包装器的 scikit 方法,用于正则化参数。
内核 SVM:带有网格搜索包装器的 scikit 方法,用于正则化参数。
Logistc 回归:带有网格搜索包装器的 scikit 方法,用于正则化参数。
小型 CNN:使用 rmsprop 优化的 4 层 CNN,在 Keras 中实现,具有张量流后端。
内核最小二乘分类:块梯度下降求解器,可以使用多个内核,因此通常比 scikit 内核 SVM 更快。
新模型必须遵循 scikit learn api 并实现以下方法
fit(X, y[, sample_weight])
:将模型拟合到输入特征和目标。
predict(X)
:预测输入特征的值。
score(X, y)
:返回给定测试特征和测试目标的目标指标。
decision_function(X)
(可选):返回类概率、到决策边界的距离或可被边际采样器用作不确定性度量的其他度量。
有关示例,请参阅small_cnn.py
。
实现新模型后,请务必将其添加到utils/utils.py
的get_model
方法中。
目前,模型必须一次性添加,并且由于需要用户输入是否以及如何调整模型的超参数,因此并非所有 scikit-learn 分类器都受支持。然而,添加带有超参数搜索作为支持模型的 scikit-learn 模型非常容易。
utils/chart_data.py
脚本处理指定数据集和源目录的数据和图表。