?网站| ?文档|安装指南|教程|示例|推特|领英|中等的
Optuna是一个自动超参数优化软件框架,专为机器学习而设计。它具有命令式、按运行定义风格的用户 API。得益于我们的按运行定义API,使用 Optuna 编写的代码具有高度模块化性,Optuna 用户可以动态构建超参数的搜索空间。
Terminator
文章,该文章在 Optuna 4.0 中进行了扩展。JournalStorage
文章,该文章已在 Optuna 4.0 中稳定下来。pip install -U optuna
安装它。在这里查找最新信息并查看我们的文章。 Optuna 具有以下现代功能:
我们使用术语“研究”和“试验”如下:
请参考下面的示例代码。研究的目标是通过多次试验(例如n_trials=100
)找出最佳的超参数值集(例如regressor
和svr_c
)。 Optuna 是一个专为自动化和加速优化研究而设计的框架。
import ...
# Define an objective function to be minimized.
def objective ( trial ):
# Invoke suggest methods of a Trial object to generate hyperparameters.
regressor_name = trial . suggest_categorical ( 'regressor' , [ 'SVR' , 'RandomForest' ])
if regressor_name == 'SVR' :
svr_c = trial . suggest_float ( 'svr_c' , 1e-10 , 1e10 , log = True )
regressor_obj = sklearn . svm . SVR ( C = svr_c )
else :
rf_max_depth = trial . suggest_int ( 'rf_max_depth' , 2 , 32 )
regressor_obj = sklearn . ensemble . RandomForestRegressor ( max_depth = rf_max_depth )
X , y = sklearn . datasets . fetch_california_housing ( return_X_y = True )
X_train , X_val , y_train , y_val = sklearn . model_selection . train_test_split ( X , y , random_state = 0 )
regressor_obj . fit ( X_train , y_train )
y_pred = regressor_obj . predict ( X_val )
error = sklearn . metrics . mean_squared_error ( y_val , y_pred )
return error # An objective value linked with the Trial object.
study = optuna . create_study () # Create a new study.
study . optimize ( objective , n_trials = 100 ) # Invoke optimization of the objective function.
笔记
更多示例可以在 optuna/optuna-examples 中找到。
这些示例涵盖了不同的问题设置,例如多目标优化、约束优化、剪枝和分布式优化。
Optuna 可在 Python 包索引和 Anaconda Cloud 上获取。
# PyPI
$ pip install optuna
# Anaconda Cloud
$ conda install -c conda-forge optuna
重要的
Optuna 支持 Python 3.8 或更高版本。
此外,我们还在 DockerHub 上提供 Optuna docker 镜像。
Optuna 具有与各种第三方库的集成功能。集成可以在 optuna/optuna-integration 中找到,文档可以在此处找到。
Optuna Dashboard 是 Optuna 的实时 Web 仪表板。您可以通过图表查看优化历史记录、超参数重要性等。您无需创建Python脚本来调用Optuna的可视化功能。欢迎提出功能请求和错误报告!
optuna-dashboard
可以通过 pip 安装:
$ pip install optuna-dashboard
提示
请使用下面的示例代码查看 Optuna Dashboard 的便利性。
将以下代码保存为optimize_toy.py
。
import optuna
def objective ( trial ):
x1 = trial . suggest_float ( "x1" , - 100 , 100 )
x2 = trial . suggest_float ( "x2" , - 100 , 100 )
return x1 ** 2 + 0.01 * x2 ** 2
study = optuna . create_study ( storage = "sqlite:///db.sqlite3" ) # Create a new study with database.
study . optimize ( objective , n_trials = 100 )
然后尝试以下命令:
# Run the study specified above
$ python optimize_toy.py
# Launch the dashboard based on the storage `sqlite:///db.sqlite3`
$ optuna-dashboard sqlite:///db.sqlite3
...
Listening on http://localhost:8080/
Hit Ctrl-C to quit.
OptunaHub 是 Optuna 的功能共享平台。您可以使用注册的功能并发布您的包。
optunahub
可以通过 pip 安装:
$ pip install optunahub
# Install AutoSampler dependencies (CPU only is sufficient for PyTorch)
$ pip install cmaes scipy torch --extra-index-url https://download.pytorch.org/whl/cpu
您可以使用optunahub.load_module
加载已注册的模块。
import optuna
import optunahub
def objective ( trial : optuna . Trial ) -> float :
x = trial . suggest_float ( "x" , - 5 , 5 )
y = trial . suggest_float ( "y" , - 5 , 5 )
return x ** 2 + y ** 2
module = optunahub . load_module ( package = "samplers/auto_sampler" )
study = optuna . create_study ( sampler = module . AutoSampler ())
study . optimize ( objective , n_trials = 10 )
print ( study . best_trial . value , study . best_trial . params )
有关更多详细信息,请参阅 optunahub 文档。
您可以通过 optunahub-registry 发布您的包。请参阅 OptunaHub 教程。
我们非常欢迎对 Optuna 做出任何贡献!
如果您是 Optuna 的新手,请查看优先问题。它们相对简单、定义明确,并且通常是您熟悉贡献工作流程和其他开发人员的良好起点。
如果您已经为 Optuna 做出过贡献,我们会推荐其他欢迎贡献的问题。
有关如何为项目做出贡献的一般准则,请查看 CONTRIBUTING.md。
如果您在某个研究项目中使用 Optuna,请引用我们的 KDD 论文“Optuna:下一代超参数优化框架”:
@inproceedings { akiba2019optuna ,
title = { {O}ptuna: A Next-Generation Hyperparameter Optimization Framework } ,
author = { Akiba, Takuya and Sano, Shotaro and Yanase, Toshihiko and Ohta, Takeru and Koyama, Masanori } ,
booktitle = { The 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining } ,
pages = { 2623--2631 } ,
year = { 2019 }
}
麻省理工学院许可证(请参阅许可证)。
Optuna 使用 SciPy 和 fdlibm 项目的代码(请参阅 LICENSE_THIRD_PARTY)。