Dies ist ein Python-Modul zum Experimentieren mit verschiedenen aktiven Lernalgorithmen. Es gibt einige Schlüsselkomponenten für die Durchführung aktiver Lernexperimente:
Das Hauptexperimentskript ist run_experiment.py
mit vielen Flags für verschiedene Ausführungsoptionen.
Unterstützte Datensätze können durch Ausführen utils/create_data.py
in ein angegebenes Verzeichnis heruntergeladen werden.
Unterstützte aktive Lernmethoden finden Sie in sampling_methods
.
Im Folgenden werde ich näher auf die einzelnen Komponenten eingehen.
HAFTUNGSAUSSCHLUSS: Dies ist kein offizielles Google-Produkt.
Die Abhängigkeiten befinden sich in requirements.txt
. Bitte stellen Sie sicher, dass diese Pakete installiert sind, bevor Sie Experimente durchführen. Wenn GPU-fähiger tensorflow
gewünscht wird, folgen Sie bitte den Anweisungen hier.
Es wird dringend empfohlen, alle Abhängigkeiten in einer separaten virtualenv
zu installieren, um die Paketverwaltung zu vereinfachen.
Standardmäßig werden die Datensätze unter /tmp/data
gespeichert. Über das Flag --save_dir
können Sie ein anderes Verzeichnis angeben.
Das erneute Herunterladen aller Datensätze wird sehr zeitaufwändig sein. Bitte haben Sie etwas Geduld. Sie können eine Teilmenge der herunterzuladenden Daten angeben, indem Sie über das Flag --datasets
eine durch Kommas getrennte Zeichenfolge von Datensätzen übergeben.
Es gibt einige Schlüsselflags für run_experiment.py
:
dataset
: Name des Datensatzes, muss mit dem in create_data.py
verwendeten Speichernamen übereinstimmen. Muss auch im Datenverzeichnis vorhanden sein.
sampling_method
: Zu verwendende aktive Lernmethode. Muss in sampling_methods/constants.py
angegeben werden.
warmstart_size
: erster Stapel einheitlich abgetasteter Beispiele zur Verwendung als Startdaten. Float gibt den Prozentsatz der gesamten Trainingsdaten an und Integer gibt die Rohgröße an.
batch_size
: Anzahl der Datenpunkte, die in jedem Batch angefordert werden sollen. Float gibt den Prozentsatz der gesamten Trainingsdaten an und Integer gibt die Rohgröße an.
score_method
: Modell zur Bewertung der Leistung der Stichprobenmethode. Muss in der get_model
-Methode von utils/utils.py
enthalten sein.
data_dir
: Verzeichnis mit gespeicherten Datensätzen.
save_dir
: Verzeichnis zum Speichern der Ergebnisse.
Dies ist nur eine Teilmenge aller Flags. Es gibt auch Optionen für die Vorverarbeitung, die Einführung von Beschriftungsrauschen, die Unterabtastung von Datensätzen und die Verwendung eines anderen Modells zur Auswahl als zur Bewertung/Bewertung.
Alle genannten aktiven Lernmethoden sind in sampling_methods/constants.py
enthalten.
Sie können auch eine Mischung aktiver Lernmethoden angeben, indem Sie dem durch Bindestriche getrennten Muster [sampling_method]-[mixture_weight]
folgen; dh mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34
.
Zu den unterstützten Stichprobenmethoden gehören:
Einheitlich: Proben werden durch einheitliche Probenahme ausgewählt.
Marge: auf Unsicherheit basierende Stichprobenmethode.
Informativ und vielfältig: Margen- und Cluster-basierte Stichprobenmethode.
k-center greedy: Repräsentative Strategie, die gierig eine Reihe von Punkten bildet, um den maximalen Abstand von einem markierten Punkt zu minimieren.
Diagrammdichte: repräsentative Strategie, die Punkte in dichten Poolbereichen auswählt.
Exp3 Bandit: Metaaktive Lernmethode, die versucht, mithilfe eines beliebten mehrarmigen Bandit-Algorithmus die optimale Stichprobenmethode zu erlernen.
Implementieren Sie entweder einen Basis-Sampler, der von SamplingMethod
erbt, oder einen Meta-Sampler, der Basis-Sampler aufruft, der von WrapperSamplingMethod
erbt.
Die einzige Methode, die von jedem Sampler implementiert werden muss, ist select_batch_
, die beliebige benannte Argumente haben kann. Die einzige Einschränkung besteht darin, dass der Name für die gleiche Eingabe über alle Sampler hinweg konsistent sein muss (dh die Indizes für bereits ausgewählte Beispiele haben alle über alle Sampler hinweg denselben Namen). Das Hinzufügen eines neuen benannten Arguments, das in anderen Stichprobenmethoden nicht verwendet wurde, erfordert die Eingabe dieses in den select_batch
-Aufruf in run_experiment.py
.
Nachdem Sie Ihren Sampler implementiert haben, müssen Sie ihn unbedingt zu constants.py
hinzufügen, damit er von run_experiment.py
aufgerufen werden kann.
Alle verfügbaren Modelle befinden sich in der Methode get_model
von utils/utils.py
.
Unterstützte Methoden:
Lineare SVM: Scikit-Methode mit Grid-Such-Wrapper für Regularisierungsparameter.
Kernel-SVM: Scikit-Methode mit Grid-Such-Wrapper für Regularisierungsparameter.
Logistische Regression: Scikit-Methode mit Grid-Such-Wrapper für Regularisierungsparameter.
Kleines CNN: 4-schichtiges CNN, optimiert mit rmsprop, implementiert in Keras mit Tensorflow-Backend.
Kernel-Klassifizierung der kleinsten Quadrate: Blockgradienten-Löser, der mehrere Kerne verwenden kann und daher oft schneller ist als Scikit-Kernel-SVM.
Neue Modelle müssen der Scikit-Lern-API folgen und die folgenden Methoden implementieren
fit(X, y[, sample_weight])
: Passen Sie das Modell an die Eingabemerkmale und das Ziel an.
predict(X)
: Sagen Sie den Wert der Eingabemerkmale voraus.
score(X, y)
: gibt die Zielmetrik bei gegebenen Testfunktionen und Testzielen zurück.
decision_function(X)
(optional): Klassenwahrscheinlichkeiten, Abstand zu Entscheidungsgrenzen oder andere Metriken zurückgeben, die vom Margin-Sampler als Maß für die Unsicherheit verwendet werden können.
Ein Beispiel finden Sie unter small_cnn.py
.
Nachdem Sie Ihr neues Modell implementiert haben, müssen Sie es unbedingt zur get_model
-Methode von utils/utils.py
hinzufügen.
Derzeit müssen Modelle einmalig hinzugefügt werden und nicht alle Scikit-Learn-Klassifikatoren werden unterstützt, da Benutzereingaben dazu erforderlich sind, ob und wie die Hyperparameter des Modells optimiert werden sollen. Es ist jedoch sehr einfach, ein Scikit-Learn-Modell mit umschlossener Hyperparametersuche als unterstütztes Modell hinzuzufügen.
Das Skript utils/chart_data.py
übernimmt die Datenverarbeitung und Diagrammerstellung für einen angegebenen Datensatz und ein angegebenes Quellverzeichnis.