Karena NAS-Bench-201 telah diperluas ke Nats-Bench, repo ini sudah usang dan tidak dipertahankan. Harap gunakan Nats-Bench, yang memiliki informasi arsitektur 5x lebih banyak dan API lebih cepat daripada NAS-Bench-201.
Kami mengusulkan benchmark NAS algoritma-agnostik (NAS-Bench-201) dengan ruang pencarian tetap, yang menyediakan tolok ukur terpadu untuk hampir semua algoritma NAS terkini. Desain ruang pencarian kami terinspirasi oleh yang digunakan dalam algoritma pencarian berbasis sel yang paling populer, di mana sel diwakili sebagai grafik asiklik terarah. Setiap tepi di sini terkait dengan operasi yang dipilih dari set operasi yang telah ditentukan. Agar dapat diterapkan untuk semua algoritma NAS, ruang pencarian yang didefinisikan dalam NAS-Bench-201 mencakup 4 node dan 5 opsi operasi terkait, yang menghasilkan 15.625 kandidat sel saraf secara total.
Dalam file penurunan harga ini, kami menyediakan:
Untuk dua hal berikut, silakan gunakan Autodl-Projects:
Catatan: Harap gunakan PyTorch >= 1.2.0
dan Python >= 3.6.0
.
Anda dapat mengetik pip install nas-bench-201
untuk menginstal API kami. Silakan lihat kode sumber modul nas-bench-201
dalam repo ini.
Jika Anda memiliki pertanyaan atau masalah, silakan posting di sini atau email saya.
[DESTRECATED] File tolok ukur lama NAS-Bench-201 dapat diunduh dari Google Drive atau Baidu-wangpan (kode: 6U5D).
[Direkomendasikan] File benchmark terbaru NAS-Bench-201 ( NAS-Bench-201-v1_1-096897.pth
) dapat diunduh dari Google Drive. File untuk berat model terlalu besar (431g) dan saya perlu waktu untuk mengunggahnya. Harap bersabar, terima kasih atas pengertian Anda.
Anda dapat memindahkannya ke mana pun yang Anda inginkan dan mengirim jalurnya ke API kami untuk inisialisasi.
NAS-Bench-201-v1_0-e61699.pth
(2.2G), di mana e61699
adalah enam digit terakhir untuk file ini. Ini berisi semua informasi kecuali untuk bobot terlatih dari setiap percobaan.NAS-Bench-201-v1_1-096897.pth
(4.7g), di mana 096897
adalah enam digit terakhir untuk file ini. Ini berisi informasi lebih banyak uji coba dibandingkan dengan NAS-Bench-201-v1_0-e61699.pth
, terutama semua model yang dilatih oleh 12 zaman pada semua dataset tersedia. Kami merekomendasikan untuk menggunakan NAS-Bench-201-v1_1-096897.pth
Data pelatihan dan evaluasi yang digunakan dalam NAS-Bench-201 dapat diunduh dari Google Drive atau Baidu-Wangpan (kode: 4FG7). Disarankan untuk memasukkan data ini ke $TORCH_HOME
( ~/.torch/
secara default). Jika Anda ingin menghasilkan NAS-Bench-201 atau kumpulan data NAS serupa atau model pelatihan sendiri, Anda memerlukan data ini.
Lebih banyak penggunaan dapat ditemukan dalam kode uji kami .
from nas_201_api import NASBench201API as API
api = API('$path_to_meta_nas_bench_file')
# Create an API without the verbose log
api = API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
# The default path for benchmark file is '{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_1-096897.pth')
api = API(None)
len(api)
dan setiap api[i]
: num = len(api)
for i, arch_str in enumerate(api):
print ('{:5d}/{:5d} : {:}'.format(i, len(api), arch_str))
# show all information for a specific architecture
api.show(1)
api.show(2)
# show the mean loss and accuracy of an architecture
info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults`
res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
# get the detailed information
results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
for seed, result in results.items():
print ('Latency : {:}'.format(result.get_latency()))
print ('Train Info : {:}'.format(result.get_train()))
print ('Valid Info : {:}'.format(result.get_eval('x-valid')))
print ('Test Info : {:}'.format(result.get_eval('x-test')))
# for the metric after a specific epoch
print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10)))
index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
api.show(index)
String ini |nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|
cara:
node-0: the input tensor
node-1: conv-3x3( node-0 )
node-2: conv-3x3( node-0 ) + avg-pool-3x3( node-1 )
node-3: skip-connect( node-0 ) + conv-3x3( node-1 ) + skip-connect( node-2 )
config = api.get_net_config(123, 'cifar10') # obtain the network configuration for the 123-th architecture on the CIFAR-10 dataset
from models import get_cell_based_tiny_net # this module is in AutoDL-Projects/lib/models
network = get_cell_based_tiny_net(config) # create the network from configurration
print(network) # show the structure of this architecture
Jika Anda ingin memuat bobot terlatih dari jaringan yang dibuat ini, Anda perlu menggunakan api.get_net_param(123, ...)
untuk mendapatkan bobot dan kemudian memuatnya ke jaringan.
api.get_more_info(...)
dapat mengembalikan kehilangan / akurasi / waktu pada pelatihan / validasi / set tes, yang sangat membantu. Untuk detail lebih lanjut, silakan lihat komentar di fungsi get_more_info.
Untuk penggunaan lain, silakan lihat lib/nas_201_api/api.py
. Kami memberikan beberapa informasi penggunaan dalam komentar untuk fungsi yang sesuai. Jika apa yang Anda inginkan tidak disediakan, jangan ragu untuk membuka masalah untuk diskusi, dan saya senang menjawab pertanyaan apa pun tentang NAS-Bench-201.
Dalam nas_201_api
, kami mendefinisikan tiga kelas: NASBench201API
, ArchResults
, ResultsCount
.
ResultsCount
menyimpan semua informasi dari uji coba tertentu. Seseorang dapat membuat instantiate hasil dan mendapatkan info melalui kode berikut ( 000157-FULL.pth
menyimpan semua informasi dari semua uji coba arsitektur ke-157):
from nas_201_api import ResultsCount
xdata = torch.load('000157-FULL.pth')
odata = xdata['full']['all_results'][('cifar10-valid', 777)]
result = ResultsCount.create_from_state_dict( odata )
print(result) # print it
print(result.get_train()) # print the final training loss/accuracy/[optional:time-cost-of-a-training-epoch]
print(result.get_train(11)) # print the training info of the 11-th epoch
print(result.get_eval('x-valid')) # print the final evaluation info on the validation set
print(result.get_eval('x-valid', 11)) # print the info on the validation set of the 11-th epoch
print(result.get_latency()) # print the evaluation latency [in batch]
result.get_net_param() # the trained parameters of this trial
arch_config = result.get_config(CellStructure.str2structure) # create the network with params
net_config = dict2config(arch_config, None)
network = get_cell_based_tiny_net(net_config)
network.load_state_dict(result.get_net_param())
ArchResults
menyimpan semua informasi dari semua uji coba arsitektur. Silakan lihat penggunaan berikut:
from nas_201_api import ArchResults
xdata = torch.load('000157-FULL.pth')
archRes = ArchResults.create_from_state_dict(xdata['less']) # load trials trained with 12 epochs
archRes = ArchResults.create_from_state_dict(xdata['full']) # load trials trained with 200 epochs
print(archRes.arch_idx_str()) # print the index of this architecture
print(archRes.get_dataset_names()) # print the supported training data
print(archRes.get_compute_costs('cifar10-valid')) # print all computational info when training on cifar10-valid
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False)) # print the average loss/accuracy/time on all trials
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) # print loss/accuracy/time of a randomly selected trial
NASBench201API
adalah API tingkat terbanyak. Silakan lihat penggunaan berikut:
from nas_201_api import NASBench201API as API
api = API('NAS-Bench-201-v1_1-096897.pth') # This will load all the information of NAS-Bench-201 except the trained weights
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_1-096897.pth')) # The same as the above line while I usually save NAS-Bench-201-v1_1-096897.pth in ~/.torch/.
api.show(-1) # show info of all architectures
api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-201-4-v1.0-archive'), 3) # This code will reload the information 3-th architecture with the trained weights
weights = api.get_net_param(3, 'cifar10', None) # Obtaining the weights of all trials for the 3-th architecture on cifar10. It will returns a dict, where the key is the seed and the value is the trained weights.
Untuk mendapatkan informasi pelatihan dan evaluasi (silakan lihat komentar di sini):
api.get_more_info(112, 'cifar10', None, hp='200', is_random=True)
# Query info of last training epoch for 112-th architecture
# using 200-epoch-hyper-parameter and randomly select a trial.
api.get_more_info(112, 'ImageNet16-120', None, hp='200', is_random=True)
Jika Anda menemukan bahwa NAS-Bench-201 membantu meneliti Anda, silakan pertimbangkan mengutipnya:
@inproceedings{dong2020nasbench201,
title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
author = {Dong, Xuanyi and Yang, Yi},
booktitle = {International Conference on Learning Representations (ICLR)},
url = {https://openreview.net/forum?id=HJxyZkBKDr},
year = {2020}
}