Da unser NAS-Bench-2010 auf NATS-Bench ausgedehnt wurde, ist dieses Repo veraltet und nicht aufrechterhalten. Bitte verwenden Sie NATS-Bench, das 5x mehr Architekturinformationen und schnellere API als NAS-Bench-2010 hat.
Wir schlagen einen Algorithmus-Agnostic NAS-Benchmark (NAS-Bench-201) mit einem festen Suchraum vor, der einen einheitlichen Benchmark für fast jeden aktuellen NAS-Algorithmen bietet. Das Design unseres Suchraums wird von dem in den beliebtesten zellbasierten Suchalgorithmen inspiriert, bei denen eine Zelle als gerichteter acyclischer Graphen dargestellt wird. Jede Kante hier ist einer Operation zugeordnet, die aus einem vordefinierten Betriebssatz ausgewählt wurde. Damit es für alle NAS-Algorithmen anwendbar ist, umfasst der in NAS-Bench-201 definierte Suchraum 4 Knoten und 5 zugeordnete Betriebsoptionen, die insgesamt 15.625 Nervenzellenkandidaten erzeugen.
In dieser Markdown -Datei bieten wir:
Für die folgenden zwei Dinge verwenden Sie bitte Autodl-Projects:
Hinweis: Bitte verwenden Sie PyTorch >= 1.2.0
und Python >= 3.6.0
.
Sie können einfach pip install nas-bench-201
eingeben, um unsere API zu installieren. Bitte beachten Sie die Quellcodes des nas-bench-201
-Moduls in diesem Repo.
Wenn Sie Fragen oder Probleme haben, posten Sie diese bitte hier oder senden Sie mir eine E -Mail.
[veraltet] Die alte Benchmark-Datei von NAS-Bench-2010 kann von Google Drive oder Baidu-Wangpan (Code: 6U5d) heruntergeladen werden.
[Empfohlen] Die neueste Benchmark-Datei von NAS-Bench-20101 ( NAS-Bench-201-v1_1-096897.pth
) kann von Google Drive heruntergeladen werden. Die Dateien für das Modellgewicht sind zu groß (431G) und ich brauche einige Zeit, um es hochzuladen. Bitte sei geduldig, danke für dein Verständnis.
Sie können es auf überall bringen, wo Sie möchten, und ihren Weg zur Initialisierung zu unserer API senden.
NAS-Bench-201-v1_0-e61699.pth
(2.2G), wobei e61699
die letzten sechs Ziffern für diese Datei ist. Es enthält alle Informationen mit Ausnahme der ausgebildeten Gewichte jedes Versuchs.NAS-Bench-201-v1_1-096897.pth
(4.7G), wobei 096897
die letzten sechs Ziffern für diese Datei ist. Es enthält Informationen von weiteren Versuchen im Vergleich zu NAS-Bench-201-v1_0-e61699.pth
, insbesondere alle Modelle, die von 12 Epochen für alle Datensätze trainiert wurden, sind durchschnittlich. Wir empfehlen NAS-Bench-201-v1_1-096897.pth
zu verwenden
Die in NAS-Bench-201 verwendeten Schulungs- und Bewertungsdaten können von Google Drive oder Baidu-Wangpan (Code: 4FG7) heruntergeladen werden. Es wird empfohlen, diese Daten in $TORCH_HOME
( ~/.torch/
standardmäßig) einzulegen. Wenn Sie NAS-Bench-201-oder ähnliche NAS-Datensätze oder Trainingsmodelle selbst generieren möchten, benötigen Sie diese Daten.
In unseren Testcodes finden Sie mehr Nutzung .
from nas_201_api import NASBench201API as API
api = API('$path_to_meta_nas_bench_file')
# Create an API without the verbose log
api = API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
# The default path for benchmark file is '{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_1-096897.pth')
api = API(None)
len(api)
und jeder Architektur api[i]
: num = len(api)
for i, arch_str in enumerate(api):
print ('{:5d}/{:5d} : {:}'.format(i, len(api), arch_str))
# show all information for a specific architecture
api.show(1)
api.show(2)
# show the mean loss and accuracy of an architecture
info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults`
res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
# get the detailed information
results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
for seed, result in results.items():
print ('Latency : {:}'.format(result.get_latency()))
print ('Train Info : {:}'.format(result.get_train()))
print ('Valid Info : {:}'.format(result.get_eval('x-valid')))
print ('Test Info : {:}'.format(result.get_eval('x-test')))
# for the metric after a specific epoch
print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10)))
index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
api.show(index)
Diese Zeichenfolge |nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|
bedeutet:
node-0: the input tensor
node-1: conv-3x3( node-0 )
node-2: conv-3x3( node-0 ) + avg-pool-3x3( node-1 )
node-3: skip-connect( node-0 ) + conv-3x3( node-1 ) + skip-connect( node-2 )
config = api.get_net_config(123, 'cifar10') # obtain the network configuration for the 123-th architecture on the CIFAR-10 dataset
from models import get_cell_based_tiny_net # this module is in AutoDL-Projects/lib/models
network = get_cell_based_tiny_net(config) # create the network from configurration
print(network) # show the structure of this architecture
Wenn Sie die geschulten Gewichte dieses erstellten Netzwerks laden möchten, müssen Sie api.get_net_param(123, ...)
verwenden, um die Gewichte zu erhalten und sie dann in das Netzwerk zu laden.
api.get_more_info(...)
kann die Verlust / Genauigkeit / Zeit für Trainings- / Validierungs- / Testsätze zurückgeben, was sehr hilfreich ist. Weitere Informationen finden Sie in den Kommentaren in der Funktion get_more_info.
Weitere Verwendungen finden Sie lib/nas_201_api/api.py
. In den Kommentaren für die entsprechenden Funktionen geben wir einige Nutzungsinformationen an. Wenn das, was Sie wollen, nicht bereitgestellt werden, können Sie gerne ein Problem zur Diskussion eröffnen, und ich beantworte gerne Fragen zu NAS-Bench-2010.
In nas_201_api
definieren wir drei Klassen: NASBench201API
, ArchResults
, ResultsCount
.
ResultsCount
führt alle Informationen einer bestimmten Versuch bei. Man kann Ergebniscount instanziieren und die Informationen über die folgenden Codes abrufen ( 000157-FULL.pth
speichert alle Informationen aller Versuche von 157-Th-Architektur):
from nas_201_api import ResultsCount
xdata = torch.load('000157-FULL.pth')
odata = xdata['full']['all_results'][('cifar10-valid', 777)]
result = ResultsCount.create_from_state_dict( odata )
print(result) # print it
print(result.get_train()) # print the final training loss/accuracy/[optional:time-cost-of-a-training-epoch]
print(result.get_train(11)) # print the training info of the 11-th epoch
print(result.get_eval('x-valid')) # print the final evaluation info on the validation set
print(result.get_eval('x-valid', 11)) # print the info on the validation set of the 11-th epoch
print(result.get_latency()) # print the evaluation latency [in batch]
result.get_net_param() # the trained parameters of this trial
arch_config = result.get_config(CellStructure.str2structure) # create the network with params
net_config = dict2config(arch_config, None)
network = get_cell_based_tiny_net(net_config)
network.load_state_dict(result.get_net_param())
ArchResults
behält alle Informationen aller Versuche einer Architektur bei. Bitte beachten Sie die folgenden Verwendungen:
from nas_201_api import ArchResults
xdata = torch.load('000157-FULL.pth')
archRes = ArchResults.create_from_state_dict(xdata['less']) # load trials trained with 12 epochs
archRes = ArchResults.create_from_state_dict(xdata['full']) # load trials trained with 200 epochs
print(archRes.arch_idx_str()) # print the index of this architecture
print(archRes.get_dataset_names()) # print the supported training data
print(archRes.get_compute_costs('cifar10-valid')) # print all computational info when training on cifar10-valid
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False)) # print the average loss/accuracy/time on all trials
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) # print loss/accuracy/time of a randomly selected trial
NASBench201API
ist die optimale API. Bitte beachten Sie die folgenden Verwendungen:
from nas_201_api import NASBench201API as API
api = API('NAS-Bench-201-v1_1-096897.pth') # This will load all the information of NAS-Bench-201 except the trained weights
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_1-096897.pth')) # The same as the above line while I usually save NAS-Bench-201-v1_1-096897.pth in ~/.torch/.
api.show(-1) # show info of all architectures
api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-201-4-v1.0-archive'), 3) # This code will reload the information 3-th architecture with the trained weights
weights = api.get_net_param(3, 'cifar10', None) # Obtaining the weights of all trials for the 3-th architecture on cifar10. It will returns a dict, where the key is the seed and the value is the trained weights.
Um die Schulungs- und Bewertungsinformationen zu erhalten (siehe die Kommentare hier):
api.get_more_info(112, 'cifar10', None, hp='200', is_random=True)
# Query info of last training epoch for 112-th architecture
# using 200-epoch-hyper-parameter and randomly select a trial.
api.get_more_info(112, 'ImageNet16-120', None, hp='200', is_random=True)
Wenn Sie feststellen, dass NAS-Bench-2010 Ihre Recherche hilft, sollten Sie sich angeben:
@inproceedings{dong2020nasbench201,
title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
author = {Dong, Xuanyi and Yang, Yi},
booktitle = {International Conference on Learning Representations (ICLR)},
url = {https://openreview.net/forum?id=HJxyZkBKDr},
year = {2020}
}