Il s'agit d'un module Python permettant d'expérimenter différents algorithmes d'apprentissage actif. Il existe quelques éléments clés pour mener des expériences d’apprentissage actif :
Le script d'expérimentation principal est run_experiment.py
avec de nombreux indicateurs pour différentes options d'exécution.
Les ensembles de données pris en charge peuvent être téléchargés dans un répertoire spécifié en exécutant utils/create_data.py
.
Les méthodes d'apprentissage actif prises en charge se trouvent dans sampling_methods
.
Ci-dessous, j'entrerai dans chaque composant plus en détail.
AVERTISSEMENT : ce n'est pas un produit Google officiel.
Les dépendances sont dans requirements.txt
. Veuillez vous assurer que ces packages sont installés avant d'exécuter des expériences. Si tensorflow
compatible GPU est souhaité, veuillez suivre les instructions ici.
Il est fortement suggéré d'installer toutes les dépendances dans un virtualenv
distinct pour faciliter la gestion des packages.
Par défaut, les ensembles de données sont enregistrés dans /tmp/data
. Vous pouvez spécifier un autre répertoire via l'indicateur --save_dir
.
Le téléchargement de tous les ensembles de données prendra beaucoup de temps, alors soyez patient. Vous pouvez spécifier un sous-ensemble de données à télécharger en transmettant une chaîne d'ensembles de données séparés par des virgules via l'indicateur --datasets
.
Il existe quelques indicateurs clés pour run_experiment.py
:
dataset
: nom de l'ensemble de données, doit correspondre au nom de sauvegarde utilisé dans create_data.py
. Doit également exister dans le data_dir.
sampling_method
: méthode d’apprentissage actif à utiliser. Doit être spécifié dans sampling_methods/constants.py
.
warmstart_size
: lot initial d'exemples uniformément échantillonnés à utiliser comme données de départ. Le flotteur indique le pourcentage du total des données d'entraînement et le nombre entier indique la taille brute.
batch_size
: nombre de points de données à demander dans chaque lot. Le flotteur indique le pourcentage du total des données d'entraînement et le nombre entier indique la taille brute.
score_method
: modèle à utiliser pour évaluer les performances de la méthode d'échantillonnage. Doit être dans la méthode get_model
de utils/utils.py
.
data_dir
: répertoire avec les ensembles de données enregistrés.
save_dir
: répertoire pour sauvegarder les résultats.
Ceci est juste un sous-ensemble de tous les drapeaux. Il existe également des options de prétraitement, d'introduction du bruit d'étiquetage, de sous-échantillonnage des ensembles de données et d'utilisation d'un modèle différent pour sélectionner que pour noter/évaluer.
Toutes les méthodes d'apprentissage actives nommées se trouvent dans sampling_methods/constants.py
.
Vous pouvez également spécifier un mélange de méthodes d'apprentissage actives en suivant le modèle de [sampling_method]-[mixture_weight]
séparés par des tirets ; c'est-à-dire mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34
.
Certaines méthodes d'échantillonnage prises en charge incluent :
Uniforme : les échantillons sont sélectionnés via un échantillonnage uniforme.
Marge : méthode d’échantillonnage basée sur l’incertitude.
Informative et diversifiée : méthode d’échantillonnage basée sur la marge et les grappes.
k-center greedy : stratégie représentative qui forme de manière gourmande un lot de points pour minimiser la distance maximale par rapport à un point étiqueté.
Densité graphique : stratégie représentative qui sélectionne des points dans les régions denses du pool.
Exp3 bandit : méthode d'apprentissage méta-active qui tente d'apprendre la méthode d'échantillonnage optimale à l'aide d'un algorithme de bandit multi-bras populaire.
Implémentez soit un échantillonneur de base qui hérite de SamplingMethod
, soit un méta-échantillonneur qui appelle des échantillonneurs de base qui héritent de WrapperSamplingMethod
.
La seule méthode qui doit être implémentée par n'importe quel échantillonneur est select_batch_
, qui peut avoir des arguments nommés arbitrairement. La seule restriction est que le nom de la même entrée doit être cohérent dans tous les échantillonneurs (c'est-à-dire que les indices des exemples déjà sélectionnés ont tous le même nom dans tous les échantillonneurs). L'ajout d'un nouvel argument nommé qui n'a pas été utilisé dans d'autres méthodes d'échantillonnage nécessitera de l'introduire dans l'appel select_batch
dans run_experiment.py
.
Après avoir implémenté votre échantillonneur, assurez-vous de l'ajouter à constants.py
afin qu'il puisse être appelé depuis run_experiment.py
.
Tous les modèles disponibles se trouvent dans la méthode get_model
de utils/utils.py
.
Méthodes prises en charge :
SVM linéaire : méthode scikit avec wrapper de recherche de grille pour le paramètre de régularisation.
Kernel SVM : méthode scikit avec wrapper de recherche de grille pour le paramètre de régularisation.
Logistc Regression : méthode scikit avec wrapper de recherche de grille pour le paramètre de régularisation.
Petit CNN : CNN à 4 couches optimisé à l'aide de rmsprop implémenté dans Keras avec le backend tensorflow.
Classification des moindres carrés du noyau : un solveur de gradient de bloc qui peut utiliser plusieurs cœurs est donc souvent plus rapide que scikit Kernel SVM.
Les nouveaux modèles doivent suivre l'API scikit learn et implémenter les méthodes suivantes
fit(X, y[, sample_weight])
: ajuste le modèle aux caractéristiques d'entrée et à la cible.
predict(X)
: prédire la valeur des caractéristiques en entrée.
score(X, y)
: renvoie la métrique cible en fonction des fonctionnalités de test et des cibles de test.
decision_function(X)
(facultatif) : probabilités de classe de retour, distance par rapport aux limites de décision ou autre métrique pouvant être utilisée par l'échantillonneur de marge comme mesure de l'incertitude.
Voir small_cnn.py
pour un exemple.
Après avoir implémenté votre nouveau modèle, assurez-vous de l'ajouter à la méthode get_model
de utils/utils.py
.
Actuellement, les modèles doivent être ajoutés de manière ponctuelle et tous les classificateurs scikit-learn ne sont pas pris en charge en raison de la nécessité d'une intervention de l'utilisateur sur l'opportunité et la manière d'ajuster les hyperparamètres du modèle. Cependant, il est très facile d'ajouter un modèle scikit-learn avec une recherche d'hyperparamètres enveloppée comme modèle pris en charge.
Le script utils/chart_data.py
gère le traitement des données et la création de graphiques pour un ensemble de données et un répertoire source spécifiés.