La eficacia irrazonable de las características profundas como métrica de percepción
Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang. En CVPR, 2018.
Ejecute pip install lpips
. El siguiente código Python es todo lo que necesita.
import lpips
loss_fn_alex = lpips . LPIPS ( net = 'alex' ) # best forward scores
loss_fn_vgg = lpips . LPIPS ( net = 'vgg' ) # closer to "traditional" perceptual loss, when used for optimization
import torch
img0 = torch . zeros ( 1 , 3 , 64 , 64 ) # image should be RGB, IMPORTANT: normalized to [-1,1]
img1 = torch . zeros ( 1 , 3 , 64 , 64 )
d = loss_fn_alex ( img0 , img1 )
A continuación encontrará información más detallada sobre las variantes. Este repositorio contiene nuestra métrica de percepción (LPIPS) y nuestro conjunto de datos (BAPPS) . También se puede utilizar como "pérdida de percepción". Esto usa PyTorch; Una alternativa a Tensorflow está aquí.
Tabla de contenido
pip install -r requirements.txt
git clone https://github.com/richzhang/PerceptualSimilarity
cd PerceptualSimilarity
Evalúe la distancia entre parches de imágenes. Más alto significa más/más diferente. Más bajo significa más similar.
Scripts de ejemplo para tomar la distancia entre 2 imágenes específicas, todos los pares de imágenes correspondientes en 2 directorios o todos los pares de imágenes dentro de un directorio:
python lpips_2imgs.py -p0 imgs/ex_ref.png -p1 imgs/ex_p0.png --use_gpu
python lpips_2dirs.py -d0 imgs/ex_dir0 -d1 imgs/ex_dir1 -o imgs/example_dists.txt --use_gpu
python lpips_1dir_allpairs.py -d imgs/ex_dir_pair -o imgs/example_dists_pair.txt --use_gpu
El archivo test_network.py muestra un ejemplo de uso. Este fragmento es todo lo que realmente necesitas.
import lpips
loss_fn = lpips . LPIPS ( net = 'alex' )
d = loss_fn . forward ( im0 , im1 )
Las variables im0, im1
son un tensor/variable de PyTorch con forma Nx3xHxW
( N
parches de tamaño HxW
, imágenes RGB escaladas en [-1,+1]
). Esto devuelve d
, un tensor/variable de longitud N
Ejecute python test_network.py
para tomar la distancia entre la imagen de referencia de ejemplo ex_ref.png
y las imágenes distorsionadas ex_p0.png
y ex_p1.png
. Antes de ejecutarlo, ¿cuál crees que debería estar más cerca?
Algunas opciones por defecto en model.initialize
:
net='alex'
. Network alex
es más rápido, tiene el mejor rendimiento (como métrica directa) y es el predeterminado. Para el backpropping, la pérdida net='vgg'
está más cerca de la tradicional "pérdida de percepción".lpips=True
. Esto agrega una calibración lineal además de las características intermedias en la red. Establezca esto en lpips=False
para ponderar por igual todas las funciones. El archivo lpips_loss.py
muestra cómo optimizar de forma iterativa utilizando la métrica. Ejecute python lpips_loss.py
para una demostración. El código también se puede utilizar para implementar la pérdida VGG básica, sin nuestros pesos aprendidos.
Más alto significa más/más diferente. Más bajo significa más similar.
Descubrimos que las activaciones de red profundas funcionan sorprendentemente bien como métrica de similitud perceptiva. Esto fue cierto en todas las arquitecturas de red (SqueezeNet [2,8 MB], AlexNet [9,1 MB] y VGG [58,9 MB] proporcionaron puntuaciones similares) y señales de supervisión (las no supervisadas, las autosupervisadas y las supervisadas tienen un buen rendimiento). Mejoramos ligeramente las puntuaciones al "calibrar" linealmente las redes, agregando una capa lineal encima de las redes de clasificación disponibles en el mercado. Proporcionamos 3 variantes, utilizando capas lineales encima de las redes SqueezeNet, AlexNet (predeterminada) y VGG.
Si utiliza LPIPS en su publicación, especifique qué versión está utilizando. La versión actual es 0.1. Puede configurar version='0.0'
para la versión inicial.
Ejecute bash ./scripts/download_dataset.sh
para descargar y descomprimir el conjunto de datos en el directorio ./dataset
. Se necesitan [6,6 GB] en total. Alternativamente, ejecute bash ./scripts/download_dataset_valonly.sh
para descargar solo el conjunto de validación [1,3 GB].
El script test_dataset_model.py
evalúa un modelo de percepción en un subconjunto del conjunto de datos.
Banderas de conjunto de datos
--dataset_mode
: 2afc
o jnd
, qué tipo de juicio perceptual evaluar--datasets
: enumera los conjuntos de datos a evaluar--dataset_mode 2afc
: las opciones son [ train/traditional
, train/cnn
, val/traditional
, val/cnn
, val/superres
, val/deblur
, val/color
, val/frameinterp
]--dataset_mode jnd
: las opciones son [ val/traditional
, val/cnn
]Banderas del modelo de similitud perceptiva
--model
: modelo de similitud perceptiva a utilizarlpips
para nuestro modelo de similitud aprendido LPIPS (red lineal además de activaciones internas de red previamente entrenada)baseline
para una red de clasificación (sin calibrar con todas las capas promediadas)l2
para distancia euclidianassim
para métrica de imagen de similitud estructurada--net
: [ squeeze
, alex
, vgg
] para los modelos net-lin
y net
; ignorado para los modelos l2
y ssim
--colorspace
: las opciones son [ Lab
, RGB
], utilizadas para los modelos l2
y ssim
; ignorado para los modelos net-lin
y net
Banderas varias
--batch_size
: tamaño del lote de evaluación (el valor predeterminado será 1)--use_gpu
: activa esta bandera para el uso de GPU Un ejemplo de uso es el siguiente: python ./test_dataset_model.py --dataset_mode 2afc --datasets val/traditional val/cnn --model lpips --net alex --use_gpu --batch_size 50
. Esto evaluaría nuestro modelo en los conjuntos de datos de validación "tradicional" y "cnn".
El conjunto de datos contiene dos tipos de juicios perceptivos: dos opciones forzadas alternativas (2AFC) y diferencias apenas perceptibles (JND) .
(1) A los evaluadores de 2AFC se les asignó un triplete de parches (1 de referencia + 2 distorsionados). Se les pidió que seleccionaran cuál de los distorsionados estaba "más cerca" de la referencia.
Los conjuntos de entrenamiento contienen 2 juicios/triplete.
train/traditional
[56,6k trillizos]train/cnn
[38,1k trillizos]train/mix
[56,6k trillizos]Los conjuntos de validación contienen 5 juicios/triplete.
val/traditional
[4,7k trillizos]val/cnn
[4,7k trillizos]val/superres
[10,9k tripletes]val/deblur
[9,4k trillizos]val/color
[4,7k tripletes]val/frameinterp
[tripletes de 1,9k]Cada subdirectorio 2AFC contiene las siguientes carpetas:
ref
: parches de referencia originalesp0,p1
: dos parches distorsionadosjudge
: juicios humanos - 0 si todos prefirieron p0, 1 si todos los humanos prefirieron p1(2) A los evaluadores de JND se les presentaron dos parches, uno de referencia y otro distorsionado, por un tiempo limitado. Se les preguntó si los parches eran iguales (idénticos) o diferentes.
Cada conjunto contiene 3 evaluaciones humanas/ejemplo.
val/traditional
[4,8k pares]val/cnn
[4,8k pares]Cada subdirectorio JND contiene las siguientes carpetas:
p0,p1
: dos parchessame
: juicios humanos: 0 si todos los humanos pensaran que los parches eran diferentes, 1 si todos los humanos pensaran que los parches eran iguales Consulte el script train_test_metric.sh
para ver un ejemplo de entrenamiento y prueba de la métrica. El script entrenará un modelo en el conjunto de entrenamiento completo durante 10 épocas y luego probará la métrica aprendida en todos los conjuntos de validación. Los números deben coincidir aproximadamente con la fila de Alex-lin en la Tabla 5 del artículo. El código admite el entrenamiento de una capa lineal sobre una representación existente. La capacitación agregará un subdirectorio en el directorio checkpoints
.
También puedes entrenar versiones "scratch" y "tune" ejecutando train_test_metric_scratch.sh
y train_test_metric_tune.sh
, respectivamente.
Si encuentra que este repositorio es útil para su investigación, utilice lo siguiente.
@inproceedings{zhang2018perceptual,
title={The Unreasonable Effectiveness of Deep Features as a Perceptual Metric},
author={Zhang, Richard and Isola, Phillip and Efros, Alexei A and Shechtman, Eli and Wang, Oliver},
booktitle={CVPR},
year={2018}
}
Este repositorio toma prestado parcialmente del repositorio pytorch-CycleGAN-and-pix2pix. El código de precisión promedio (AP) se toma prestado del repositorio py-faster-rcnn. Angjoo Kanazawa, Connelly Barnes, Gaurav Mittal, wilhelmhb, Filippo Mameli, SuperShinyEyes, Minyoung Huh ayudaron a mejorar el código base.