import Pkg
Pkg . add ( " Lux " )
Consejo
Si está utilizando una versión anterior a la v1 de Lux.jl, consulte la sección Actualización a la v1 para obtener instrucciones sobre cómo actualizar.
Paquetes | Versión estable | Descargas mensuales | Descargas totales | Estado de construcción |
---|---|---|---|---|
? Lux.jl | ||||
└ ? LuxLib.jl | ||||
└ ? LuxCore.jl | ||||
└ ? MLDataDevices.jl | ||||
└ ? PesoIniciales.jl | ||||
└ ? LuxTestUtils.jl | ||||
└ ? LuxCUDA.jl |
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
# Seeding
rng = Random . default_rng ()
Random . seed! (rng, 0 )
# Construct the layer
model = Chain ( Dense ( 128 , 256 , tanh), Chain ( Dense ( 256 , 1 , tanh), Dense ( 1 , 10 )))
# Get the device determined by Lux
dev = gpu_device ()
# Parameter and State Variables
ps, st = Lux . setup (rng, model) |> dev
# Dummy Input
x = rand (rng, Float32, 128 , 2 ) |> dev
# Run the model
y, st = Lux . apply (model, x, ps, st)
# Gradients
# # First construct a TrainState
train_state = Lux . Training . TrainState (model, ps, st, Adam ( 0.0001f0 ))
# # We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux . Training . compute_gradients ( AutoZygote (), MSELoss (),
(x, dev ( rand (rng, Float32, 10 , 2 ))), train_state)
# # Optimization
train_state = Training . apply_gradients! (train_state, gs) # or Training.apply_gradients (no `!` at the end)
# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training . single_train_step! ( AutoZygote (), MSELoss (),
(x, dev ( rand (rng, Float32, 10 , 2 ))), train_state)
Busque en el directorio de ejemplos ejemplos de uso independientes. La documentación tiene ejemplos ordenados en categorías adecuadas.
Para preguntas relacionadas con el uso, utilice Github Discussions, que permite indexar preguntas y respuestas. Para informar errores, utilice problemas de github o, mejor aún, envíe una solicitud de extracción.
Si esta biblioteca le resultó útil en el trabajo académico, cite:
@software { pal2023lux ,
author = { Pal, Avik } ,
title = { {Lux: Explicit Parameterization of Deep Neural Networks in Julia} } ,
month = apr,
year = 2023 ,
note = { If you use this software, please cite it as below. } ,
publisher = { Zenodo } ,
version = { v0.5.0 } ,
doi = { 10.5281/zenodo.7808904 } ,
url = { https://doi.org/10.5281/zenodo.7808904 }
}
@thesis { pal2023efficient ,
title = { {On Efficient Training & Inference of Neural Differential Equations} } ,
author = { Pal, Avik } ,
year = { 2023 } ,
school = { Massachusetts Institute of Technology }
}
También considere destacar nuestro repositorio de github.
Esta sección está algo incompleta. ¿Puedes contribuir contribuyendo a terminar esta sección?.
La prueba completa de Lux.jl
lleva mucho tiempo; aquí se explica cómo probar una parte del código.
Para cada @testitem
, hay tags
correspondientes, por ejemplo:
@testitem " SkipConnection " setup = [SharedTestSetup] tags = [ :core_layers ]
Por ejemplo, consideremos las pruebas de SkipConnection
:
@testitem " SkipConnection " setup = [SharedTestSetup] tags = [ :core_layers ] begin
...
end
Podemos probar el grupo al que pertenece SkipConnection
probando core_layers
. Para hacerlo, configure la variable de entorno LUX_TEST_GROUP
o cambie el nombre de la etiqueta para limitar aún más el alcance de la prueba:
export LUX_TEST_GROUP= " core_layers "
O modifique directamente la etiqueta de prueba predeterminada en runtests.jl
:
# const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))
const LUX_TEST_GROUP = lowercase ( get ( ENV , " LUX_TEST_GROUP " , " core_layers " ))
Pero asegúrese de restaurar el valor predeterminado "todo" antes de enviar el código.
Además, si desea ejecutar una prueba específica basada en el nombre del conjunto de pruebas, puede utilizar TestEnv.jl de la siguiente manera. Comience activando el entorno Lux y luego ejecute lo siguiente:
using TestEnv; TestEnv . activate (); using ReTestItems;
# Assuming you are in the main directory of Lux
ReTestItems . runtests ( " tests/ " ; name = " NAME OF THE TEST " )
Para las pruebas SkipConnection
eso sería:
ReTestItems . runtests ( " tests/ " ; name = " SkipConnection " )