Implementation of Toolformer, Language Models That Can Use Tools, by MetaAI
Thanks for the generous sponsorship to work on and open source cutting-edge artificial intelligence research
Thanks to Enrico for getting the ball rolling with the initial implementation of the various tools!
Thanks go out to ChatGPT for writing all the regular expressions in this repository that parse the functions and parameters of the API calls. I am terrible at regular expressions, so this was an enormous help from the AI (no hitches, it was perfect).
$ pip install toolformer-pytorch
Example usage: giving a language model awareness of the current date and time.
import torch
from toolformer_pytorch import Toolformer, PaLM
# simple calendar api call - function that returns a string
def Calendar():
    import datetime
    from calendar import day_name, month_name
    now = datetime.datetime.now()
    return f'Today is {day_name[now.weekday()]}, {month_name[now.month]} {now.day}, {now.year}.'
# prompt for teaching it to use the Calendar function from above
prompt = f"""
Your task is to add calls to a Calendar API to a piece of text.
The API calls should help you get information required to complete the text.
You can call the API by writing "[Calendar()]"
Here are some examples of API calls:
Input: Today is the first Friday of the year.
Output: Today is the first [Calendar()] Friday of the year.
Input: The president of the United States is Joe Biden.
Output: The president of the United States is [Calendar()] Joe Biden.
Input: [input]
Output:
"""
data = [
    "The store is never open on the weekend, so today it is closed.",
    "The number of days from now until Christmas is 30",
    "The current day of the week is Wednesday."
]
# model - here using PaLM, but any nn.Module that returns logits in the shape (batch, seq, num_tokens) is fine
model = PaLM(
    dim = 512,
    depth = 2,
    heads = 8,
    dim_head = 64
).cuda()
# toolformer
toolformer = Toolformer(
    model = model,
    model_seq_len = 256,
    teach_tool_prompt = prompt,
    tool_id = 'Calendar',
    tool = Calendar,
    finetune = True
)
# invoking this will
# (1) prompt the model with your inputs (data), inserted into [input] tag
# (2) with the sampled outputs, filter out the ones that made proper API calls
# (3) execute the API calls with the `tool` given
# (4) filter with the specialized filter function (which can be used independently as shown in the next section)
# (5) fine-tune on the filtered results
filtered_stats = toolformer(data)
# then, once you see the 'finetune complete' message
response = toolformer.sample_model_with_api_calls("How many days until the next new years?")
# hopefully you see it invoke the calendar and utilize the response of the api call...
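Steps (1) and (2) of the pipeline above can be illustrated without a model. The sketch below is a minimal, hypothetical approximation: fill_prompt and has_proper_api_call are invented names, not part of toolformer_pytorch, and it assumes a well-formed call is the tool name followed by parentheses inside square brackets, as in the prompt's [Calendar()] examples.

```python
import re

# Hypothetical helpers illustrating pipeline steps (1) and (2);
# these names are not part of toolformer_pytorch.

def fill_prompt(template: str, text: str) -> str:
    # Step (1): insert a data example where the [input] tag appears.
    return template.replace("[input]", text)

def has_proper_api_call(sampled: str, tool_id: str) -> bool:
    # Step (2): a sample counts as a proper call if it contains
    # "[<tool_id>(...)]" with no nested brackets inside the parentheses.
    pattern = r"\[" + re.escape(tool_id) + r"\([^()\[\]]*\)\]"
    return re.search(pattern, sampled) is not None

template = "Input: [input]\nOutput:"
prompts = [fill_prompt(template, t) for t in ["Today is Friday.", "It is cold."]]

# Pretend these strings were sampled from the model.
samples = [
    "Today is [Calendar()] Friday.",
    "It is cold [Weather()] today.",
    "Today is [Calendar( Friday.",
]
kept = [s for s in samples if has_proper_api_call(s, "Calendar")]
print(kept)  # ['Today is [Calendar()] Friday.']
```

Only the first sample survives: the second calls a different tool and the third never closes its call.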
The main novelty of the paper is defining a fitness score for the outputs of a transformer instructed to insert API calls. The score is used to filter the sampled outputs, which are then used to fine-tune the transformer to make API calls that decrease the perplexity of the text that follows them.
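For reference, the filtering criterion as stated in the Toolformer paper can be written as follows, where x is the text, c_i an API call inserted at position i, r_i its response, e(·) the linearized call, ε the empty sequence, and τ_f the filter threshold. Reading tokens, tokens_without_api_response, tokens_with_api_response and filter_threshold in the filter example that follows as L_i(ε), L_i(e(c_i, ε)), L_i(e(c_i, r_i)) and τ_f is an assumption based on the argument names, not something this README confirms.

```latex
L_i(\mathbf{z}) = -\sum_{j=i}^{n} w_{j-i} \, \log p_M\left(x_j \mid \mathbf{z}, x_{1:j-1}\right),
\qquad
w_t = \frac{\tilde{w}_t}{\sum_s \tilde{w}_s}, \quad \tilde{w}_t = \max(0,\; 1 - 0.2\,t)

L_i^{+} = L_i\big(e(c_i, r_i)\big),
\qquad
L_i^{-} = \min\Big(L_i(\varepsilon),\; L_i\big(e(c_i, \varepsilon)\big)\Big)

\text{keep } c_i \iff L_i^{-} - L_i^{+} \ge \tau_f
```

A call is kept only when seeing both the call and its response makes the following tokens easier to predict than either no call at all or the call without its response.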
import torch
from toolformer_pytorch import (
    Toolformer,
    PaLM,
    filter_tokens_with_api_response
)
# model
palm = PaLM(
    dim = 512,
    num_tokens = 20000,
    depth = 2,
    heads = 8,
    dim_head = 64
).cuda()
# mock some tokens
mock_start_pos = 512
mock_api_call_length = 10
mock_api_start_id = 19998
mock_api_stop_id = 19999
tokens = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_with_api_response = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_without_api_response = torch.randint(0, 20000, (10, 1024)).cuda()
tokens_with_api_response[:, mock_start_pos] = mock_api_start_id
tokens_with_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id
tokens_without_api_response[:, mock_start_pos] = mock_api_start_id
tokens_without_api_response[:, mock_start_pos + mock_api_call_length] = mock_api_stop_id
# filter
filtered_results = filter_tokens_with_api_response(
    model = palm,
    tokens = tokens,
    tokens_with_api_response = tokens_with_api_response,
    tokens_without_api_response = tokens_without_api_response,
    filter_threshold = 1.,
    api_start_token_id = mock_api_start_id,
    api_end_token_id = mock_api_stop_id
)
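The decision made by the filter can be sketched in plain Python, assuming the per-token log-probabilities of the continuation have already been computed for three variants: no call, the call without its response, and the call with its response. This is an illustrative re-derivation of the paper's criterion with hypothetical helper names (weights, weighted_loss, keep_api_call), not the implementation of filter_tokens_with_api_response; the 0.2 decay mirrors the weighting reported in the paper.

```python
import math
from typing import List

# Hypothetical sketch of the paper's filtering rule; not the code behind
# filter_tokens_with_api_response. Each list holds log p(x_j | prefix) for
# the tokens that follow the API call position, one entry per token.

def weights(n: int, decay: float = 0.2) -> List[float]:
    # Unnormalized weight max(0, 1 - decay * t) for offset t, then normalized.
    raw = [max(0.0, 1.0 - decay * t) for t in range(n)]
    total = sum(raw)
    return [w / total for w in raw]

def weighted_loss(logprobs: List[float]) -> float:
    # Weighted negative log-likelihood of the continuation (requires n >= 1).
    return -sum(w * lp for w, lp in zip(weights(len(logprobs)), logprobs))

def keep_api_call(plain: List[float], call_no_response: List[float],
                  call_with_response: List[float], threshold: float = 1.0) -> bool:
    loss_plus = weighted_loss(call_with_response)
    loss_minus = min(weighted_loss(plain), weighted_loss(call_no_response))
    # Keep the call only if the response lowers the loss by at least threshold.
    return loss_minus - loss_plus >= threshold

plain = [math.log(0.1)] * 3          # continuation is unlikely with no call
no_response = [math.log(0.1)] * 3    # the bare call does not help either
with_response = [math.log(0.9)] * 3  # the API response makes it likely
print(keep_api_call(plain, no_response, with_response, threshold=1.0))  # True
```

With a 0.2 decay the weights reach zero after five tokens, so only the text immediately following the call influences the decision.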
To invoke the tools on a string generated by the language model, use invoke_tools
from toolformer_pytorch import invoke_tools
def inc(i):
    return i + 1
def dec(i):
    return i - 1
function_registry = dict(
    inc = inc,
    dec = dec
)
text = 'make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]'
invoke_tools(function_registry, text)
# make the following api calls: [inc(1) → 2] and [dec(2) → 1] and [ignored(3)]
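The substitution performed by invoke_tools can be approximated with a single regular expression. The following is a hypothetical sketch (invoke_tools_sketch is an invented name, not the library's code) that assumes each call has the form [name(args)] with comma-separated Python-literal arguments, and that calls to unregistered tools are left untouched, as the example output shows.

```python
import ast
import re
from typing import Callable, Dict

# Hypothetical sketch, not the toolformer_pytorch implementation.
# Matches "[name(args)]" where args contain no brackets or parentheses.
CALL_PATTERN = re.compile(r"\[(\w+)\(([^()\[\]]*)\)\]")

def invoke_tools_sketch(registry: Dict[str, Callable], text: str) -> str:
    def replace(match: re.Match) -> str:
        name, raw_args = match.group(1), match.group(2)
        fn = registry.get(name)
        if fn is None:
            # Unregistered tools are left as written.
            return match.group(0)
        # Assumption: arguments are Python literals separated by commas.
        args = ast.literal_eval(f"({raw_args},)") if raw_args.strip() else ()
        return f"[{name}({raw_args}) → {fn(*args)}]"

    return CALL_PATTERN.sub(replace, text)

def inc(i):
    return i + 1

def dec(i):
    return i - 1

text = "make the following api calls: [inc(1)] and [dec(2)] and [ignored(3)]"
print(invoke_tools_sketch(dict(inc=inc, dec=dec), text))
# make the following api calls: [inc(1) → 2] and [dec(2) → 1] and [ignored(3)]
```

Passing a function to re.sub keeps the parsing and the tool execution in one pass; zero-argument calls such as [Calendar()] are handled by the empty-argument branch.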
