metnet3 pytorch
0.0.12
MetNet 3、Google Deepmind の SOTA ニューラル気象モデルの Pytorch での実装
モデルのアーキテクチャは非常に目立たないものです。これは基本的に、特定の優れたパフォーマンスのビジョン トランスフォーマーを備えた U-net です。この論文で最も興味深い点は、最終的にセクション 4.3.2 の損失スケーリングになるかもしれません。
$ pip install metnet3-pytorch
import torch
from metnet3_pytorch import MetNet3
metnet3 = MetNet3 (
dim = 512 ,
num_lead_times = 722 ,
lead_time_embed_dim = 32 ,
input_spatial_size = 624 ,
attn_dim_head = 8 ,
hrrr_channels = 617 ,
input_2496_channels = 2 + 14 + 1 + 2 + 20 ,
input_4996_channels = 16 + 1 ,
precipitation_target_bins = dict (
mrms_rate = 512 ,
mrms_accumulation = 512 ,
),
surface_target_bins = dict (
omo_temperature = 256 ,
omo_dew_point = 256 ,
omo_wind_speed = 256 ,
omo_wind_component_x = 256 ,
omo_wind_component_y = 256 ,
omo_wind_direction = 180
),
hrrr_loss_weight = 10 ,
hrrr_norm_strategy = 'sync_batchnorm' , # this would use a sync batchnorm to normalize the input hrrr and target, without having to precalculate the mean and variance of the hrrr dataset per channel
hrrr_norm_statistics = None # you can also also set `hrrr_norm_strategy = "precalculated"` and pass in the mean and variance as shape `(2, 617)` through this keyword argument
)
# inputs
lead_times = torch . randint ( 0 , 722 , ( 2 ,))
hrrr_input_2496 = torch . randn (( 2 , 617 , 624 , 624 ))
hrrr_stale_state = torch . randn (( 2 , 1 , 624 , 624 ))
input_2496 = torch . randn (( 2 , 39 , 624 , 624 ))
input_4996 = torch . randn (( 2 , 17 , 624 , 624 ))
# targets
precipitation_targets = dict (
mrms_rate = torch . randint ( 0 , 512 , ( 2 , 512 , 512 )),
mrms_accumulation = torch . randint ( 0 , 512 , ( 2 , 512 , 512 )),
)
surface_targets = dict (
omo_temperature = torch . randint ( 0 , 256 , ( 2 , 128 , 128 )),
omo_dew_point = torch . randint ( 0 , 256 , ( 2 , 128 , 128 )),
omo_wind_speed = torch . randint ( 0 , 256 , ( 2 , 128 , 128 )),
omo_wind_component_x = torch . randint ( 0 , 256 , ( 2 , 128 , 128 )),
omo_wind_component_y = torch . randint ( 0 , 256 , ( 2 , 128 , 128 )),
omo_wind_direction = torch . randint ( 0 , 180 , ( 2 , 128 , 128 ))
)
hrrr_target = torch . randn ( 2 , 617 , 128 , 128 )
total_loss , loss_breakdown = metnet3 (
lead_times = lead_times ,
hrrr_input_2496 = hrrr_input_2496 ,
hrrr_stale_state = hrrr_stale_state ,
input_2496 = input_2496 ,
input_4996 = input_4996 ,
precipitation_targets = precipitation_targets ,
surface_targets = surface_targets ,
hrrr_target = hrrr_target
)
total_loss . backward ()
# after much training from above, you can predict as follows
metnet3 . eval ()
surface_preds , hrrr_pred , precipitation_preds = metnet3 (
lead_times = lead_times ,
hrrr_input_2496 = hrrr_input_2496 ,
hrrr_stale_state = hrrr_stale_state ,
input_2496 = input_2496 ,
input_4996 = input_4996 ,
)
# Dict[str, Tensor], Tensor, Dict[str, Tensor]
すべてのクロスエントロピーとMSE損失を計算する
トレーニング中にターゲットの移動平均と分散を追跡することにより、HRRR のすべてのチャネルにわたる正規化を自動処理します (ハックとして同期バッチノルムを使用)
研究者が独自の HRRR の正規化変数を渡すことができる
すべての入力を仕様に合わせて構築し、hrrr 入力が正規化されていることを確認し、hrrr 予測を非正規化するオプションを提供します
hrrr ノルムを処理するさまざまな方法を使用して、モデルを簡単に保存およびロードできることを確認します。
トポロジカルな埋め込みを解明し、神経天気の研究者に相談する
@article { Andrychowicz2023DeepLF ,
title = { Deep Learning for Day Forecasts from Sparse Observations } ,
author = { Marcin Andrychowicz and Lasse Espeholt and Di Li and Samier Merchant and Alexander Merose and Fred Zyda and Shreya Agrawal and Nal Kalchbrenner } ,
journal = { ArXiv } ,
year = { 2023 } ,
volume = { abs/2306.06079 } ,
url = { https://api.semanticscholar.org/CorpusID:259129311 }
}
@inproceedings { ElNouby2021XCiTCI ,
title = { XCiT: Cross-Covariance Image Transformers } ,
author = { Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{'e} J{'e}gou } ,
booktitle = { Neural Information Processing Systems } ,
year = { 2021 } ,
url = { https://api.semanticscholar.org/CorpusID:235458262 }
}