Étant donné que notre NAS-BENCH-2010 a été étendu à Nats-Bench, ce dépôt est obsolète et non entretenu. Veuillez utiliser Nats-Bench, qui a 5 fois plus d'informations d'architecture et API plus rapide que NAS-BENCH-2010.
Nous proposons une référence NAS agnostique algorithme (NAS-BENCH-2010) avec un espace de recherche fixe, qui fournit une référence unifiée pour presque tous les algorithmes NAS à jour. La conception de notre espace de recherche est inspirée par celle utilisée dans les algorithmes de recherche de cellules les plus populaires, où une cellule est représentée comme un graphique acyclique dirigé. Chaque bord ici est associé à une opération sélectionnée à partir d'un ensemble d'opération prédéfini. Pour qu'il soit applicable à tous les algorithmes NAS, l'espace de recherche défini dans NAS-BENCH-2010 comprend 4 nœuds et 5 options de fonctionnement associées, qui génère 15 625 candidats à cellules neuronales au total.
Dans ce fichier Markdown, nous fournissons:
Pour les deux choses suivantes, veuillez utiliser Autodl-Projects:
Remarque: veuillez utiliser PyTorch >= 1.2.0
et Python >= 3.6.0
.
Vous pouvez simplement taper pip install nas-bench-201
pour installer notre API. Veuillez consulter les codes source du module nas-bench-201
dans ce dépôt.
Si vous avez des questions ou des problèmes, veuillez le poster ici ou m'envoyer un e-mail.
[Désacré] L' ancien fichier de référence de NAS-BENCH-2010 peut être téléchargé à partir de Google Drive ou Baidu-Wangpan (code: 6U5D).
[Recommandé] Le dernier fichier de référence de NAS-BENCH-2010 ( NAS-Bench-201-v1_1-096897.pth
) peut être téléchargé depuis Google Drive. Les fichiers de poids du modèle sont trop grands (431g) et j'ai besoin de temps pour le télécharger. Soyez patient, merci pour votre compréhension.
Vous pouvez le déplacer dans n'importe où vous le souhaitez et envoyer son chemin à notre API pour l'initialisation.
NAS-Bench-201-v1_0-e61699.pth
(2.2G), où e61699
est les six derniers chiffres de ce fichier. Il contient toutes les informations à l'exception des poids formés de chaque essai.096897
NAS-Bench-201-v1_1-096897.pth
Il contient des informations sur plus d'essais par rapport à NAS-Bench-201-v1_0-e61699.pth
, en particulier tous les modèles formés par 12 époques sur tous les ensembles de données sont avalimes. Nous vous recommandons d'utiliser NAS-Bench-201-v1_1-096897.pth
Les données de formation et d'évaluation utilisées dans NAS-BENCH-2010 peuvent être téléchargées à partir de Google Drive ou Baidu-Wangpan (code: 4FG7). Il est recommandé de placer ces données dans $TORCH_HOME
( ~/.torch/
par défaut). Si vous souhaitez générer des ensembles de données NAS ou de formation NAS-BENCH-2010 ou similaires, vous avez besoin de ces données.
Plus d'utilisation peut être trouvée dans nos codes de test .
from nas_201_api import NASBench201API as API
api = API('$path_to_meta_nas_bench_file')
# Create an API without the verbose log
api = API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
# The default path for benchmark file is '{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_1-096897.pth')
api = API(None)
len(api)
et chaque api[i]
: num = len(api)
for i, arch_str in enumerate(api):
print ('{:5d}/{:5d} : {:}'.format(i, len(api), arch_str))
# show all information for a specific architecture
api.show(1)
api.show(2)
# show the mean loss and accuracy of an architecture
info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults`
res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
# get the detailed information
results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
for seed, result in results.items():
print ('Latency : {:}'.format(result.get_latency()))
print ('Train Info : {:}'.format(result.get_train()))
print ('Valid Info : {:}'.format(result.get_eval('x-valid')))
print ('Test Info : {:}'.format(result.get_eval('x-test')))
# for the metric after a specific epoch
print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10)))
index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
api.show(index)
Cette chaîne |nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|
moyens:
node-0: the input tensor
node-1: conv-3x3( node-0 )
node-2: conv-3x3( node-0 ) + avg-pool-3x3( node-1 )
node-3: skip-connect( node-0 ) + conv-3x3( node-1 ) + skip-connect( node-2 )
config = api.get_net_config(123, 'cifar10') # obtain the network configuration for the 123-th architecture on the CIFAR-10 dataset
from models import get_cell_based_tiny_net # this module is in AutoDL-Projects/lib/models
network = get_cell_based_tiny_net(config) # create the network from configurration
print(network) # show the structure of this architecture
Si vous souhaitez charger les poids formés de ce réseau créé, vous devez utiliser api.get_net_param(123, ...)
pour obtenir les poids, puis le charger sur le réseau.
api.get_more_info(...)
peut renvoyer la perte / précision / temps sur les ensembles de formation / validation / tests, ce qui est très utile. Pour plus de détails, veuillez consulter les commentaires dans la fonction get_more_info.
Pour d'autres usages, veuillez consulter lib/nas_201_api/api.py
Nous fournissons des informations d'utilisation dans les commentaires pour les fonctions correspondantes. Si ce que vous voulez n'est pas fourni, n'hésitez pas à ouvrir un problème à la discussion, et je suis heureux de répondre à toutes les questions concernant le NAS-BENCH-20101.
Dans nas_201_api
, nous définissons trois classes: NASBench201API
, ArchResults
, ResultsCount
.
ResultsCount
maintient toutes les informations d'un essai spécifique. On peut instancier des résultats et obtenir les informations via les codes suivants ( 000157-FULL.pth
enregistre toutes les informations de tous les essais de la 157e architecture):
from nas_201_api import ResultsCount
xdata = torch.load('000157-FULL.pth')
odata = xdata['full']['all_results'][('cifar10-valid', 777)]
result = ResultsCount.create_from_state_dict( odata )
print(result) # print it
print(result.get_train()) # print the final training loss/accuracy/[optional:time-cost-of-a-training-epoch]
print(result.get_train(11)) # print the training info of the 11-th epoch
print(result.get_eval('x-valid')) # print the final evaluation info on the validation set
print(result.get_eval('x-valid', 11)) # print the info on the validation set of the 11-th epoch
print(result.get_latency()) # print the evaluation latency [in batch]
result.get_net_param() # the trained parameters of this trial
arch_config = result.get_config(CellStructure.str2structure) # create the network with params
net_config = dict2config(arch_config, None)
network = get_cell_based_tiny_net(net_config)
network.load_state_dict(result.get_net_param())
ArchResults
maintiennent toutes les informations de tous les essais d'une architecture. Veuillez consulter les usages suivants:
from nas_201_api import ArchResults
xdata = torch.load('000157-FULL.pth')
archRes = ArchResults.create_from_state_dict(xdata['less']) # load trials trained with 12 epochs
archRes = ArchResults.create_from_state_dict(xdata['full']) # load trials trained with 200 epochs
print(archRes.arch_idx_str()) # print the index of this architecture
print(archRes.get_dataset_names()) # print the supported training data
print(archRes.get_compute_costs('cifar10-valid')) # print all computational info when training on cifar10-valid
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False)) # print the average loss/accuracy/time on all trials
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) # print loss/accuracy/time of a randomly selected trial
NASBench201API
est l'API de niveau de toppest. Veuillez consulter les usages suivants:
from nas_201_api import NASBench201API as API
api = API('NAS-Bench-201-v1_1-096897.pth') # This will load all the information of NAS-Bench-201 except the trained weights
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_1-096897.pth')) # The same as the above line while I usually save NAS-Bench-201-v1_1-096897.pth in ~/.torch/.
api.show(-1) # show info of all architectures
api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-201-4-v1.0-archive'), 3) # This code will reload the information 3-th architecture with the trained weights
weights = api.get_net_param(3, 'cifar10', None) # Obtaining the weights of all trials for the 3-th architecture on cifar10. It will returns a dict, where the key is the seed and the value is the trained weights.
Pour obtenir les informations de formation et d'évaluation (veuillez consulter les commentaires ici):
api.get_more_info(112, 'cifar10', None, hp='200', is_random=True)
# Query info of last training epoch for 112-th architecture
# using 200-epoch-hyper-parameter and randomly select a trial.
api.get_more_info(112, 'ImageNet16-120', None, hp='200', is_random=True)
Si vous constatez que le NAS-BENCH-20101 aide vos recherches, veuillez envisager de le citer:
@inproceedings{dong2020nasbench201,
title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
author = {Dong, Xuanyi and Yang, Yi},
booktitle = {International Conference on Learning Representations (ICLR)},
url = {https://openreview.net/forum?id=HJxyZkBKDr},
year = {2020}
}