Transfer Learning Suite
1.0.0
此儲存庫用作遷移學習套件。目標是能夠使用任何內建 Keras 影像分類模型輕鬆執行遷移學習!歡迎提出任何改進此存儲庫的建議或您希望看到的任何新功能!
您也可以查看我的語義分割套件。
所有 Keras 內建模型均可用:
模型 | 尺寸 | Top-1 準確度 | 前 5 名準確度 | 參數 | 深度 |
---|---|---|---|---|---|
VGG16 | 528MB | 0.715 | 0.901 | 138,357,544 | 23 |
VGG19 | 549MB | 0.727 | 0.910 | 143,667,240 | 26 |
殘網50 | 99MB | 0.759 | 0.929 | 25,636,712 | 168 |
Xception | 88MB | 0.790 | 0.945 | 22,910,480 | 126 |
盜夢空間V3 | 92MB | 0.788 | 0.944 | 23,851,784 | 159 |
InceptionResNetV2 | 215MB | 0.804 | 0.953 | 55,873,736 | 第572章 |
行動網路 | 17MB | 0.665 | 0.871 | 4,253,864 | 88 |
密集網121 | 33MB | 0.745 | 0.918 | 8,062,504 | 121 |
密集網169 | 57MB | 0.759 | 0.928 | 14,307,880 | 169 |
密集網201 | 80MB | 0.770 | 0.933 | 20,242,984 | 201 |
NAS網路移動 | 21MB | 不適用 | 不適用 | 5,326,716 | 不適用 |
NAS網路大型 | 342MB | 不適用 | 不適用 | 88,949,818 | 不適用 |
main.py:訓練和預測模式
utils.py:輔助實用函數
檢查點:訓練期間每個時期的檢查點文件
預測:預測結果
該項目具有以下相依性:
numpy sudo pip install numpy
OpenCV Python sudo apt-get install python-opencv
TensorFlow sudo pip install --upgrade tensorflow-gpu
Keras sudo pip install keras
您唯一需要做的就是按照以下結構設置資料夾:
├── "dataset_name"
| ├── train
| | ├── class_1_images
| | ├── class_2_images
| | ├── class_X_images
| | ├── .....
| ├── val
| | ├── class_1_images
| | ├── class_2_images
| | ├── class_X_images
| | ├── .....
| ├── test
| | ├── class_1_images
| | ├── class_2_images
| | ├── class_X_images
| | ├── .....
然後你可以簡單地運行main.py
!查看可選的命令列參數:
usage: main.py [-h] [--num_epochs NUM_EPOCHS] [--mode MODE] [--image IMAGE]
[--continue_training CONTINUE_TRAINING] [--dataset DATASET]
[--resize_height RESIZE_HEIGHT] [--resize_width RESIZE_WIDTH]
[--batch_size BATCH_SIZE] [--dropout DROPOUT] [--h_flip H_FLIP]
[--v_flip V_FLIP] [--rotation ROTATION] [--zoom ZOOM]
[--shear SHEAR] [--model MODEL]
optional arguments:
-h, --help show this help message and exit
--num_epochs NUM_EPOCHS
Number of epochs to train for
--mode MODE Select "train", or "predict" mode. Note that for
prediction mode you have to specify an image to run
the model on.
--image IMAGE The image you want to predict on. Only valid in
"predict" mode.
--continue_training CONTINUE_TRAINING
Whether to continue training from a checkpoint
--dataset DATASET Dataset you are using.
--resize_height RESIZE_HEIGHT
Height of cropped input image to network
--resize_width RESIZE_WIDTH
Width of cropped input image to network
--batch_size BATCH_SIZE
Number of images in each batch
--dropout DROPOUT Dropout ratio
--h_flip H_FLIP Whether to randomly flip the image horizontally for
data augmentation
--v_flip V_FLIP Whether to randomly flip the image vertically for data
augmentation
--rotation ROTATION Whether to randomly rotate the image for data
augmentation
--zoom ZOOM Whether to randomly zoom in for data augmentation
--shear SHEAR Whether to randomly shear in for data augmentation
--model MODEL Your pre-trained classification model of choice