具有图像神经网络的 Python 库
基于PyTorch的分割。
该库的主要特点是:
访问阅读文档项目页面或阅读以下 README 以了解有关分割模型 Pytorch(简称 SMP)库的更多信息
分割模型只是一个 PyTorch torch.nn.Module
,创建起来很简单:
import segmentation_models_pytorch as smp
model = smp . Unet (
encoder_name = "resnet34" , # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
encoder_weights = "imagenet" , # use `imagenet` pre-trained weights for encoder initialization
in_channels = 1 , # model input channels (1 for gray-scale images, 3 for RGB, etc.)
classes = 3 , # model output channels (number of classes in your dataset)
)
所有编码器都有预训练的权重。以与权重预训练期间相同的方式准备数据可能会给您带来更好的结果(更高的指标分数和更快的收敛)。如果您训练整个模型,而不仅仅是解码器,则没有必要。
from segmentation_models_pytorch . encoders import get_preprocessing_fn
preprocess_input = get_preprocessing_fn ( 'resnet18' , pretrained = 'imagenet' )
恭喜!你完成了!现在您可以使用您最喜欢的框架来训练您的模型!
以下是 SMP 中支持的编码器列表。选择适当的编码器系列,然后单击展开表格并选择特定编码器及其预训练权重( encoder_name
和encoder_weights
参数)。
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
资源网18 | 图像网 / ssl / swsl | 11M |
资源网34 | 图像网 | 21M |
resnet50 | 图像网 / ssl / swsl | 23M |
资源网101 | 图像网 | 42M |
资源网152 | 图像网 | 58M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
resnext50_32x4d | 图像网 / ssl / swsl | 22M |
resnext101_32x4d | SSL/SWSL | 42M |
resnext101_32x8d | imagenet/instagram/ssl/swsl | 86M |
resnext101_32x16d | Instagram / ssl / swsl | 191M |
resnext101_32x32d | 466M | |
resnext101_32x48d | 826M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
蒂姆·雷斯内斯特14d | 图像网 | 8M |
蒂姆-resnest26d | 图像网 | 15M |
蒂姆-雷斯内斯特50d | 图像网 | 25M |
蒂姆-resnest101e | 图像网 | 46M |
蒂姆-resnest200e | 图像网 | 68M |
蒂姆-resnest269e | 图像网 | 108M |
蒂姆-resnest50d_4s2x40d | 图像网 | 28M |
蒂姆-resnest50d_1s4x24d | 图像网 | 23M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
timm-res2net50_26w_4s | 图像网 | 23M |
timm-res2net101_26w_4s | 图像网 | 43M |
timm-res2net50_26w_6s | 图像网 | 35M |
timm-res2net50_26w_8s | 图像网 | 46M |
timm-res2net50_48w_2s | 图像网 | 23M |
timm-res2net50_14w_8s | 图像网 | 23M |
timm-res2next50 | 图像网 | 22M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
蒂姆-regnetx_002 | 图像网 | 2M |
蒂姆-regnetx_004 | 图像网 | 4M |
蒂姆-regnetx_006 | 图像网 | 5M |
蒂姆-regnetx_008 | 图像网 | 6M |
蒂姆-regnetx_016 | 图像网 | 8M |
蒂姆-regnetx_032 | 图像网 | 14M |
蒂姆-regnetx_040 | 图像网 | 20M |
蒂姆-regnetx_064 | 图像网 | 24M |
蒂姆-regnetx_080 | 图像网 | 37M |
蒂姆-regnetx_120 | 图像网 | 43M |
蒂姆-regnetx_160 | 图像网 | 52M |
蒂姆-regnetx_320 | 图像网 | 105M |
蒂姆-regnety_002 | 图像网 | 2M |
蒂姆-regnety_004 | 图像网 | 3M |
蒂姆-regnety_006 | 图像网 | 5M |
蒂姆-regnety_008 | 图像网 | 5M |
蒂姆-regnety_016 | 图像网 | 10M |
蒂姆-regnety_032 | 图像网 | 17M |
蒂姆-regnety_040 | 图像网 | 19M |
蒂姆-regnety_064 | 图像网 | 29M |
蒂姆-regnety_080 | 图像网 | 37M |
蒂姆-regnety_120 | 图像网 | 49M |
蒂姆-regnety_160 | 图像网 | 80M |
蒂姆-regnety_320 | 图像网 | 141M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
蒂姆-gernet_s | 图像网 | 6M |
蒂姆-gernet_m | 图像网 | 18M |
蒂姆-gernet_l | 图像网 | 28M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
塞内特154 | 图像网 | 113M |
se_resnet50 | 图像网 | 26M |
se_resnet101 | 图像网 | 47M |
se_resnet152 | 图像网 | 64M |
se_resnext50_32x4d | 图像网 | 25M |
se_resnext101_32x4d | 图像网 | 46M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
蒂姆-skresnet18 | 图像网 | 11M |
蒂姆-skresnet34 | 图像网 | 21M |
蒂姆-skresnext50_32x4d | 图像网 | 25M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
密集网121 | 图像网 | 6M |
密集网169 | 图像网 | 12M |
密集网201 | 图像网 | 18M |
密集网161 | 图像网 | 26M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
盗梦空间resnetv2 | imagenet / imagenet+背景 | 54M |
盗梦空间v4 | imagenet / imagenet+背景 | 41M |
异象 | 图像网 | 22M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
高效网-b0 | 图像网 | 4M |
高效网-b1 | 图像网 | 6M |
高效网-b2 | 图像网 | 7M |
高效网-b3 | 图像网 | 10M |
高效网络-b4 | 图像网 | 17M |
高效网-b5 | 图像网 | 28M |
高效网-b6 | 图像网 | 40M |
高效网-b7 | 图像网 | 63M |
timm-efficientnet-b0 | imagenet / advprop / 嘈杂的学生 | 4M |
timm-efficientnet-b1 | imagenet / advprop / 嘈杂的学生 | 6M |
timm-efficientnet-b2 | imagenet / advprop / 嘈杂的学生 | 7M |
timm-efficientnet-b3 | imagenet / advprop / 嘈杂的学生 | 10M |
timm-efficientnet-b4 | imagenet / advprop / 嘈杂的学生 | 17M |
timm-efficientnet-b5 | imagenet / advprop / 嘈杂的学生 | 28M |
timm-efficientnet-b6 | imagenet / advprop / 嘈杂的学生 | 40M |
timm-efficientnet-b7 | imagenet / advprop / 嘈杂的学生 | 63M |
timm-efficientnet-b8 | 图像网 / advprop | 84M |
timm-efficientnet-l2 | 吵闹的学生 | 474M |
timm-efficientnet-lite0 | 图像网 | 4M |
timm-efficientnet-lite1 | 图像网 | 5M |
timm-efficientnet-lite2 | 图像网 | 6M |
timm-efficientnet-lite3 | 图像网 | 8M |
timm-efficientnet-lite4 | 图像网 | 13M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
移动网络_v2 | 图像网 | 2M |
timm-mobilenetv3_large_075 | 图像网 | 1.78M |
timm-mobilenetv3_large_100 | 图像网 | 2.97M |
timm-mobilenetv3_large_minimal_100 | 图像网 | 1.41M |
timm-mobilenetv3_small_075 | 图像网 | 0.57M |
timm-mobilenetv3_small_100 | 图像网 | 0.93M |
timm-mobilenetv3_small_minimal_100 | 图像网 | 0.43M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
dpn68 | 图像网 | 11M |
dpn68b | 图像网+5k | 11M |
DPN92 | 图像网+5k | 34M |
dpn98 | 图像网 | 58M |
DPN107 | 图像网+5k | 84M |
DPN131 | 图像网 | 76M |
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
VG11 | 图像网 | 9M |
vgg11_bn | 图像网 | 9M |
VG13 | 图像网 | 9M |
vgg13_bn | 图像网 | 9M |
VG16 | 图像网 | 14M |
vgg16_bn | 图像网 | 14M |
VG19 | 图像网 | 20M |
vgg19_bn | 图像网 | 20M |
来自 SegFormer 的骨干在 Imagenet 上进行了预训练!可以与包中的其他解码器一起使用,您可以将 Mix Vision Transformer 与 Unet、FPN 等结合使用!
限制:
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
mit_b0 | 图像网 | 3M |
mit_b1 | 图像网 | 13M |
mit_b2 | 图像网 | 24M |
mit_b3 | 图像网 | 44M |
mit_b4 | 图像网 | 60M |
mit_b5 | 图像网 | 81M |
Apple 的“亚一毫秒”Backbone 在 Imagenet 上进行了预训练!可与所有解码器一起使用。
注意:在官方 github 存储库中,s0 变体具有额外的 num_conv_branches,导致参数比 s1 更多。
编码器 | 重量 | 帕拉姆斯,M |
---|---|---|
mobileone_s0 | 图像网 | 4.6M |
mobileone_s1 | 图像网 | 4.0M |
mobileone_s2 | 图像网 | 6.5M |
mobileone_s3 | 图像网 | 8.8M |
mobileone_s4 | 图像网 | 13.6M |
* ssl
、 swsl
- ImageNet 上的半监督和弱监督学习(repo)。
文档
Pytorch 图像模型(又名 timm)有很多预训练模型和接口,允许使用这些模型作为 smp 中的编码器,但是,并非所有模型都受支持
features_only
功能支持的编码器总数:549
model.encoder
- 预训练的主干以提取不同空间分辨率的特征model.decoder
- 取决于模型架构( Unet
/ Linknet
/ PSPNet
/ FPN
)model.segmentation_head
- 生成所需数量的掩模通道的最后一个块(还包括可选的上采样和激活)model.classification_head
- 在编码器顶部创建分类头的可选块model.forward(x)
- 按顺序将x
通过模型的编码器、解码器和分段头(以及分类头,如果指定) 输入通道参数允许您创建模型,该模型处理具有任意数量通道的张量。如果您使用 imagenet 中的预训练权重 - 第一个卷积的权重将被重用。对于 1 通道情况,它将是第一个卷积层的权重之和,否则通道将填充new_weight[:, i] = pretrained_weight[:, i % 3]
之类的权重,然后使用new_weight * 3 / new_in_channels
进行缩放。
model = smp . FPN ( 'resnet34' , in_channels = 1 )
mask = model ( torch . ones ([ 1 , 1 , 64 , 64 ]))
所有模型都支持aux_params
参数,默认设置为None
。如果aux_params = None
则不会创建分类辅助输出,否则模型不仅会产生mask
,还会产生形状为NC
label
输出。分类头由GlobalPooling->Dropout(可选)->Linear->Activation(可选)层组成,可以通过aux_params
配置如下:
aux_params = dict (
pooling = 'avg' , # one of 'avg', 'max'
dropout = 0.5 , # dropout ratio, default is None
activation = 'sigmoid' , # activation function, default is None
classes = 4 , # define number of output labels
)
model = smp . Unet ( 'resnet34' , classes = 4 , aux_params = aux_params )
mask , label = model ( x )
深度参数指定编码器中的下采样操作次数,因此如果指定较小的depth
,可以使模型更轻。
model = smp . Unet ( 'resnet34' , encoder_depth = 4 )
PyPI 版本:
$ pip install segmentation-models-pytorch
来源最新版本:
$ pip install git+https://github.com/qubvel/segmentation_models.pytorch
Segmentation Models
包广泛应用于图像分割竞赛中。在这里您可以找到竞赛、获奖者姓名及其解决方案的链接。
make install_dev # create .venv, install SMP in dev mode
make fixup # Ruff for formatting and lint checks
make table # generate a table with encoders and print to stdout
@misc{Iakubovskii:2019,
Author = {Pavel Iakubovskii},
Title = {Segmentation Models Pytorch},
Year = {2019},
Publisher = {GitHub},
Journal = {GitHub repository},
Howpublished = {url{https://github.com/qubvel/segmentation_models.pytorch}}
}
该项目根据 MIT 许可证分发