Dado que nuestro NAS Bench-201-201 se ha extendido al Bench NATS, este repositorio está en desuso y no se mantiene. Utilice Nats-Bench, que tiene 5 veces más información de arquitectura y API más rápida que NAS-Bench-201.
Proponemos un punto de referencia NAS agnóstico de algoritmo (NAS-Bench-201) con un espacio de búsqueda fijo, que proporciona un punto de referencia unificado para casi cualquier algoritmos NAS actualizados. El diseño de nuestro espacio de búsqueda se inspira en el utilizado en los algoritmos de búsqueda basados en células más populares, donde una celda se representa como un gráfico acíclico dirigido. Cada borde aquí está asociado con una operación seleccionada de un conjunto de operaciones predefinidas. Para que sea aplicable para todos los algoritmos NAS, el espacio de búsqueda definido en NAS-Bench-201-201N incluye 4 nodos y 5 opciones de operación asociadas, que genera 15,625 candidatos de células neurales en total.
En este archivo de Markdown, proporcionamos:
Para las siguientes dos cosas, use Autodl-Projects:
Nota: Utilice PyTorch >= 1.2.0
y Python >= 3.6.0
.
Simplemente puede escribir pip install nas-bench-201
para instalar nuestra API. Consulte los códigos de origen del módulo nas-bench-201
en este repositorio.
Si tiene alguna pregunta o problema, publíquelo aquí o envíeme un correo electrónico.
[Descargado] El antiguo archivo de referencia de NAS Bench-201-2010 se puede descargar de Google Drive o Baidu-Wangpan (Código: 6U5D).
[Recomendado] El último archivo de referencia de NAS-Bench-201 ( NAS-Bench-201-v1_1-096897.pth
) se puede descargar desde Google Drive. Los archivos para el peso del modelo son demasiado grandes (431 g) y necesito algo de tiempo para cargarlo. Por favor, sea paciente, gracias por su comprensión.
Puede moverlo a cualquier lugar que desee y enviar su camino a nuestra API para la inicialización.
NAS-Bench-201-v1_0-e61699.pth
(2.2G), donde e61699
es los últimos seis dígitos para este archivo. Contiene toda la información, excepto los pesos capacitados de cada prueba.NAS-Bench-201-v1_1-096897.pth
(4.7G), donde 096897
son los últimos seis dígitos para este archivo. Contiene información de más ensayos en comparación con NAS-Bench-201-v1_0-e61699.pth
, especialmente todos los modelos entrenados por 12 épocas en todos los conjuntos de datos son avalables. Recomendamos usar NAS-Bench-201-v1_1-096897.pth
Los datos de capacitación y evaluación utilizados en NAS-Bench-201-2010 se pueden descargar desde Google Drive o Baidu-Wangpan (Código: 4FG7). Se recomienda colocar estos datos en $TORCH_HOME
( ~/.torch/
Por defecto). Si desea generar nas-bench-201 o conjuntos de datos NAS o modelos de capacitación similares, necesita estos datos.
Se puede encontrar más uso en nuestros códigos de prueba .
from nas_201_api import NASBench201API as API
api = API('$path_to_meta_nas_bench_file')
# Create an API without the verbose log
api = API('NAS-Bench-201-v1_1-096897.pth', verbose=False)
# The default path for benchmark file is '{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_1-096897.pth')
api = API(None)
len(api)
y cada api[i]
: num = len(api)
for i, arch_str in enumerate(api):
print ('{:5d}/{:5d} : {:}'.format(i, len(api), arch_str))
# show all information for a specific architecture
api.show(1)
api.show(2)
# show the mean loss and accuracy of an architecture
info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults`
res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
# get the detailed information
results = api.query_by_index(1, 'cifar100') # a dict of all trials for 1st net on cifar100, where the key is the seed
print ('There are {:} trials for this architecture [{:}] on cifar100'.format(len(results), api[1]))
for seed, result in results.items():
print ('Latency : {:}'.format(result.get_latency()))
print ('Train Info : {:}'.format(result.get_train()))
print ('Valid Info : {:}'.format(result.get_eval('x-valid')))
print ('Test Info : {:}'.format(result.get_eval('x-test')))
# for the metric after a specific epoch
print ('Train Info [10-th epoch] : {:}'.format(result.get_train(10)))
index = api.query_index_by_arch('|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|')
api.show(index)
Esta cadena |nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|
medio:
node-0: the input tensor
node-1: conv-3x3( node-0 )
node-2: conv-3x3( node-0 ) + avg-pool-3x3( node-1 )
node-3: skip-connect( node-0 ) + conv-3x3( node-1 ) + skip-connect( node-2 )
config = api.get_net_config(123, 'cifar10') # obtain the network configuration for the 123-th architecture on the CIFAR-10 dataset
from models import get_cell_based_tiny_net # this module is in AutoDL-Projects/lib/models
network = get_cell_based_tiny_net(config) # create the network from configurration
print(network) # show the structure of this architecture
Si desea cargar los pesos capacitados de esta red creada, debe usar api.get_net_param(123, ...)
para obtener los pesos y luego cargarlo en la red.
api.get_more_info(...)
puede devolver la pérdida / precisión / tiempo en los conjuntos de entrenamiento / validación / prueba, lo cual es muy útil. Para obtener más detalles, mire los comentarios en la función get_more_info.
Para otros usos, consulte lib/nas_201_api/api.py
. Proporcionamos información de uso en los comentarios para las funciones correspondientes. Si no se proporciona lo que desea, no dude en abrir un problema para la discusión, y me complace responder cualquier pregunta sobre NAS-Bench-201.
En nas_201_api
, definimos tres clases: NASBench201API
, ArchResults
, ResultsCount
.
ResultsCount
mantiene toda la información de una prueba específica. Uno puede instanciar resultados de resultados y obtener la información a través de los siguientes códigos ( 000157-FULL.pth
guarda toda la información de todas las pruebas de 157 ° arquitectura):
from nas_201_api import ResultsCount
xdata = torch.load('000157-FULL.pth')
odata = xdata['full']['all_results'][('cifar10-valid', 777)]
result = ResultsCount.create_from_state_dict( odata )
print(result) # print it
print(result.get_train()) # print the final training loss/accuracy/[optional:time-cost-of-a-training-epoch]
print(result.get_train(11)) # print the training info of the 11-th epoch
print(result.get_eval('x-valid')) # print the final evaluation info on the validation set
print(result.get_eval('x-valid', 11)) # print the info on the validation set of the 11-th epoch
print(result.get_latency()) # print the evaluation latency [in batch]
result.get_net_param() # the trained parameters of this trial
arch_config = result.get_config(CellStructure.str2structure) # create the network with params
net_config = dict2config(arch_config, None)
network = get_cell_based_tiny_net(net_config)
network.load_state_dict(result.get_net_param())
ArchResults
mantiene toda la información de todas las pruebas de una arquitectura. Consulte los siguientes usos:
from nas_201_api import ArchResults
xdata = torch.load('000157-FULL.pth')
archRes = ArchResults.create_from_state_dict(xdata['less']) # load trials trained with 12 epochs
archRes = ArchResults.create_from_state_dict(xdata['full']) # load trials trained with 200 epochs
print(archRes.arch_idx_str()) # print the index of this architecture
print(archRes.get_dataset_names()) # print the supported training data
print(archRes.get_compute_costs('cifar10-valid')) # print all computational info when training on cifar10-valid
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, False)) # print the average loss/accuracy/time on all trials
print(archRes.get_metrics('cifar10-valid', 'x-valid', None, True)) # print loss/accuracy/time of a randomly selected trial
NASBench201API
es la API de nivel más alto. Consulte los siguientes usos:
from nas_201_api import NASBench201API as API
api = API('NAS-Bench-201-v1_1-096897.pth') # This will load all the information of NAS-Bench-201 except the trained weights
api = API('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-Bench-201-v1_1-096897.pth')) # The same as the above line while I usually save NAS-Bench-201-v1_1-096897.pth in ~/.torch/.
api.show(-1) # show info of all architectures
api.reload('{:}/{:}'.format(os.environ['TORCH_HOME'], 'NAS-BENCH-201-4-v1.0-archive'), 3) # This code will reload the information 3-th architecture with the trained weights
weights = api.get_net_param(3, 'cifar10', None) # Obtaining the weights of all trials for the 3-th architecture on cifar10. It will returns a dict, where the key is the seed and the value is the trained weights.
Para obtener la información de capacitación y evaluación (consulte los comentarios aquí):
api.get_more_info(112, 'cifar10', None, hp='200', is_random=True)
# Query info of last training epoch for 112-th architecture
# using 200-epoch-hyper-parameter and randomly select a trial.
api.get_more_info(112, 'ImageNet16-120', None, hp='200', is_random=True)
Si encuentra que NAS-Bench-2010 ayuda a su investigación, considere citarla:
@inproceedings{dong2020nasbench201,
title = {NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search},
author = {Dong, Xuanyi and Yang, Yi},
booktitle = {International Conference on Learning Representations (ICLR)},
url = {https://openreview.net/forum?id=HJxyZkBKDr},
year = {2020}
}