import Pkg
Pkg . add ( " Lux " )
Tip
Jika Anda menggunakan Lux.jl versi pra-v1, silakan lihat bagian Memperbarui ke v1 untuk petunjuk tentang cara memperbarui.
Paket | Versi Stabil | Unduhan Bulanan | Jumlah Unduhan | Membangun Status |
---|---|---|---|---|
? Lux.jl | ||||
└ ? LuxLib.jl | ||||
└ ? LuxCore.jl | ||||
└ ? MLDataDevices.jl | ||||
└ ? WeightInitializers.jl | ||||
└ ? LuxTestUtils.jl | ||||
└ ? LuxCUDA.jl |
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support
# Seeding
rng = Random . default_rng ()
Random . seed! (rng, 0 )
# Construct the layer
model = Chain ( Dense ( 128 , 256 , tanh), Chain ( Dense ( 256 , 1 , tanh), Dense ( 1 , 10 )))
# Get the device determined by Lux
dev = gpu_device ()
# Parameter and State Variables
ps, st = Lux . setup (rng, model) |> dev
# Dummy Input
x = rand (rng, Float32, 128 , 2 ) |> dev
# Run the model
y, st = Lux . apply (model, x, ps, st)
# Gradients
# # First construct a TrainState
train_state = Lux . Training . TrainState (model, ps, st, Adam ( 0.0001f0 ))
# # We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux . Training . compute_gradients ( AutoZygote (), MSELoss (),
(x, dev ( rand (rng, Float32, 10 , 2 ))), train_state)
# # Optimization
train_state = Training . apply_gradients! (train_state, gs) # or Training.apply_gradients (no `!` at the end)
# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training . single_train_step! ( AutoZygote (), MSELoss (),
(x, dev ( rand (rng, Float32, 10 , 2 ))), train_state)
Lihat di direktori contoh untuk contoh penggunaan mandiri. Dokumentasi memiliki contoh yang diurutkan ke dalam kategori yang sesuai.
Untuk pertanyaan terkait penggunaan, silakan gunakan Diskusi Github yang memungkinkan pertanyaan dan jawaban diindeks. Untuk melaporkan bug, gunakan masalah github atau lebih baik lagi kirimkan permintaan tarik.
Jika menurut Anda perpustakaan ini berguna dalam pekerjaan akademis, silakan kutip:
@software { pal2023lux ,
author = { Pal, Avik } ,
title = { {Lux: Explicit Parameterization of Deep Neural Networks in Julia} } ,
month = apr,
year = 2023 ,
note = { If you use this software, please cite it as below. } ,
publisher = { Zenodo } ,
version = { v0.5.0 } ,
doi = { 10.5281/zenodo.7808904 } ,
url = { https://doi.org/10.5281/zenodo.7808904 }
}
@thesis { pal2023efficient ,
title = { {On Efficient Training & Inference of Neural Differential Equations} } ,
author = { Pal, Avik } ,
year = { 2023 } ,
school = { Massachusetts Institute of Technology }
}
Pertimbangkan juga untuk membintangi repo github kami.
Bagian ini agak kurang lengkap. Anda dapat berkontribusi dengan berkontribusi menyelesaikan bagian ini?.
Pengujian lengkap Lux.jl
membutuhkan waktu lama, berikut cara menguji sebagian kodenya.
Untuk setiap @testitem
, ada tags
yang sesuai, misalnya:
@testitem " SkipConnection " setup = [SharedTestSetup] tags = [ :core_layers ]
Misalnya, mari pertimbangkan pengujian untuk SkipConnection
:
@testitem " SkipConnection " setup = [SharedTestSetup] tags = [ :core_layers ] begin
...
end
Kita dapat menguji grup tempat SkipConnection
berada dengan menguji core_layers
. Untuk melakukannya, atur variabel lingkungan LUX_TEST_GROUP
, atau ganti nama tag untuk lebih mempersempit cakupan pengujian:
export LUX_TEST_GROUP= " core_layers "
Atau langsung ubah tag pengujian default di runtests.jl
:
# const LUX_TEST_GROUP = lowercase(get(ENV, "LUX_TEST_GROUP", "all"))
const LUX_TEST_GROUP = lowercase ( get ( ENV , " LUX_TEST_GROUP " , " core_layers " ))
Namun pastikan untuk mengembalikan nilai default "semua" sebelum mengirimkan kode.
Selanjutnya jika Anda ingin menjalankan pengujian tertentu berdasarkan nama testset tersebut, Anda dapat menggunakan TestEnv.jl sebagai berikut. Mulailah dengan mengaktifkan lingkungan Lux lalu jalankan perintah berikut:
using TestEnv; TestEnv . activate (); using ReTestItems;
# Assuming you are in the main directory of Lux
ReTestItems . runtests ( " tests/ " ; name = " NAME OF THE TEST " )
Untuk tes SkipConnection
itu adalah:
ReTestItems . runtests ( " tests/ " ; name = " SkipConnection " )