REaLTabFormer (Realistic Relational and Tabular Data using Transformers) offers a unified framework for synthesizing tabular data of different types. It uses a sequence-to-sequence (Seq2Seq) model to generate synthetic relational datasets. For non-relational tabular data, the REaLTabFormer model uses GPT-2 and can be applied out of the box to any tabular dataset with independent observations.
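To make the two modeling modes more concrete, here is a hypothetical illustration of how table rows could be flattened into token sequences: one sequence per independent observation for a GPT-2 style model, and a (parent, children) source/target pair for a Seq2Seq model. The `column=value` token format, the `<bos>`/`<eos>`/`<sep>` tokens, and all function names are assumptions for illustration only; REaLTabFormer uses its own encoding, which this sketch does not reproduce.

```python
# Hypothetical illustration only: REaLTabFormer uses its own encoding.
# Assumes column names do not contain "=".
from typing import Dict, List, Tuple

BOS, EOS, SEP = "<bos>", "<eos>", "<sep>"  # assumed special tokens


def row_to_tokens(row: Dict[str, object]) -> List[str]:
    """Encode one observation as [BOS, "col=value", ..., EOS]."""
    return [BOS] + [f"{col}={val}" for col, val in row.items()] + [EOS]


def tokens_to_row(tokens: List[str]) -> Dict[str, str]:
    """Decode row_to_tokens output back to a row (values come back as str)."""
    row = {}
    for tok in tokens:
        if tok in (BOS, EOS):
            continue
        col, _, val = tok.partition("=")  # split on the first "=" only
        row[col] = val
    return row


def relational_pair(
    parent_row: Dict[str, object], child_rows: List[Dict[str, object]]
) -> Tuple[List[str], List[str]]:
    """Seq2Seq training pair: the parent row is the source sequence, and all
    of its child rows, separated by SEP, form one target sequence."""
    source = row_to_tokens(parent_row)
    target = [BOS]
    for i, child in enumerate(child_rows):
        if i > 0:
            target.append(SEP)
        target.extend(row_to_tokens(child)[1:-1])  # drop the per-row BOS/EOS
    target.append(EOS)
    return source, target
```

In the sketch's relational pair, all child rows linked to one parent form a single target sequence, so a model trained on such pairs would generate a parent's children jointly, conditioned on the parent row.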
REaLTabFormer: Generating Realistic Relational and Tabular Data using Transformers
Paper on arXiv (arXiv:2302.02041)
REaLTabFormer is available on PyPI and can be installed with pip (Python >= 3.7):
pip install realtabformer
The examples below show how to use REaLTabFormer to train models and to generate synthetic data from a trained model.
Note
The model implements an optimal stopping criterion based on the synthetic data distribution when training a non-relational tabular model. The model will stop training when the synthetic data distribution is close to the real data distribution.
Make sure to set the epochs parameter to a large number so that the model has room to fit the data well; training ends early once the optimal stopping criterion is met.
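The idea of stopping once the synthetic distribution is close to the real one can be sketched as follows. This is a hypothetical, simplified loop: it compares a single numeric column using the two-sample Kolmogorov-Smirnov statistic and stops at a fixed threshold. The names `train_until_close`, `train_one_epoch`, `sample`, and the default `threshold` are assumptions; REaLTabFormer's own stopping criterion is not reproduced here.

```python
# Hypothetical sketch of distribution-based early stopping. The KS statistic
# and the threshold are illustrative assumptions, not REaLTabFormer's criterion.
from bisect import bisect_right
from typing import Callable, Sequence


def ks_statistic(a: Sequence[float], b: Sequence[float]) -> float:
    """Two-sample Kolmogorov-Smirnov statistic: the largest gap between the
    empirical CDFs of a and b (the CDFs only change at observed values)."""
    a_sorted, b_sorted = sorted(a), sorted(b)
    gap = 0.0
    for x in a_sorted + b_sorted:
        cdf_a = bisect_right(a_sorted, x) / len(a_sorted)
        cdf_b = bisect_right(b_sorted, x) / len(b_sorted)
        gap = max(gap, abs(cdf_a - cdf_b))
    return gap


def train_until_close(
    train_one_epoch: Callable[[], None],
    sample: Callable[[int], Sequence[float]],
    real: Sequence[float],
    epochs: int = 1000,
    threshold: float = 0.05,
) -> int:
    """Run up to `epochs` epochs. After each one, draw as many synthetic values
    as there are real ones and stop as soon as their distribution is within
    `threshold` of the real data. Returns the number of epochs actually run."""
    for epoch in range(1, epochs + 1):
        train_one_epoch()
        if ks_statistic(sample(len(real)), real) <= threshold:
            return epoch
    return epochs
```

A large `epochs` value only raises the upper bound: the loop returns as soon as the distribution check passes, which is why setting it generously costs nothing when the criterion triggers early.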
# pip install realtabformer
import pandas as pd
from realtabformer import REaLTabFormer

df = pd.read_csv("foo.csv")

# NOTE: Remove any unique identifiers in the
# data that you don't want to be modeled.

# Non-relational or parent table.
rtf_model = REaLTabFormer(
    model_type="tabular",
    gradient_accumulation_steps=4,
    logging_steps=100)

# Fit the model on the dataset.
# Additional parameters can be
# passed to the `.fit` method.
rtf_model.fit(df)

# Save the model to the current directory.
# A new directory `rtf_model/` will be created.
# In it, a directory with the model's
# experiment id `idXXXX` will also be created
# where the artefacts of the model will be stored.
rtf_model.save("rtf_model/")

# Generate synthetic data with the same
# number of observations as the real dataset.
samples = rtf_model.sample(n_samples=len(df))

# Load the saved model. The directory to the
# experiment must be provided.
rtf_model2 = REaLTabFormer.load_from_dir(
    path="rtf_model/idXXXX")
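The NOTE in the tabular example asks you to remove unique identifiers before fitting. A hypothetical helper such as `candidate_id_columns` (not part of realtabformer) can flag columns whose values never repeat, which is typical of identifiers. It accepts a dict of lists or a pandas DataFrame, skips all-float columns, and only produces candidates for manual review, since other all-distinct columns (for example, names in a small table) are flagged too.

```python
# Hypothetical helper (not part of realtabformer): flag columns whose values
# are all distinct, which is typical of unique identifiers.
from typing import List


def candidate_id_columns(table) -> List[str]:
    """Return the names of columns whose values never repeat.

    `table` can be a dict mapping column names to lists, or a pandas
    DataFrame (iterating over a DataFrame yields its column names).
    Columns holding only floats are skipped, because continuous
    measurements are often all-distinct without being identifiers.
    """
    candidates = []
    for name in table:
        values = list(table[name])
        if not values or all(isinstance(v, float) for v in values):
            continue
        if len(set(values)) == len(values):
            candidates.append(name)
    return candidates
```

After reviewing the returned list, the identifier columns can be dropped with `df.drop(columns=[...])` before calling `rtf_model.fit`.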
# pip install realtabformer
import os
import pandas as pd
from pathlib import Path
from realtabformer import REaLTabFormer

parent_df = pd.read_csv("foo.csv")
child_df = pd.read_csv("bar.csv")
join_on = "unique_id"

# Make sure that the key columns in both the
# parent and the child table have the same name.
assert ((join_on in parent_df.columns) and
        (join_on in child_df.columns))

# Non-relational or parent table. Don't include the
# unique_id field.
parent_model = REaLTabFormer(model_type="tabular")
parent_model.fit(parent_df.drop(join_on, axis=1))

pdir = Path("rtf_parent/")
parent_model.save(pdir)

# # Get the most recently saved parent model,
# # or specify some other saved model.
# parent_model_path = pdir / "idXXX"
parent_model_path = sorted([
    p for p in pdir.glob("id*") if p.is_dir()],
    key=os.path.getmtime)[-1]

child_model = REaLTabFormer(
    model_type="relational",
    parent_realtabformer_path=parent_model_path,
    output_max_length=None,
    train_size=0.8)

child_model.fit(
    df=child_df,
    in_df=parent_df,
    join_on=join_on)

# Generate parent samples.
parent_samples = parent_model.sample(len(parent_df))

# Create the unique ids based on the index.
parent_samples.index.name = join_on
parent_samples = parent_samples.reset_index()

# Generate the relational observations.
child_samples = child_model.sample(
    input_unique_ids=parent_samples[join_on],
    input_df=parent_samples.drop(join_on, axis=1),
    gen_batch=64)
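In the relational example, each parent sample gets a unique id from its index, and the child model generates rows for those ids. A hypothetical post-generation check (`check_foreign_keys`, not part of realtabformer) can confirm referential integrity: every child key should match a parent key, and the per-parent counts show how many child rows each parent received.

```python
# Hypothetical post-generation check (not part of realtabformer):
# every child row should point to a parent that exists.
from collections import Counter
from typing import Dict, Iterable


def check_foreign_keys(parent_keys: Iterable, child_keys: Iterable) -> Dict:
    """Summarize how child rows link to parent rows.

    Returns the child keys that have no matching parent ("orphans") and
    the number of child rows found for each parent key.
    """
    parents = set(parent_keys)
    counts = Counter(child_keys)
    orphans = sorted(k for k in counts if k not in parents)
    per_parent = {k: counts.get(k, 0) for k in parents}
    return {"orphans": orphans, "children_per_parent": per_parent}
```

Pass `parent_samples[join_on]` as `parent_keys`. For `child_keys`, use wherever the generated `child_samples` DataFrame stores the parent id (as a column or as its index); inspect the returned frame first, since this sketch does not assume either layout.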
The REaLTabFormer framework provides an interface for building observation validators that filter out invalid synthetic samples. The example below uses the GeoValidator. The chart on the left shows the distribution of the generated latitude and longitude without validation; the chart on the right shows synthetic samples that have been validated with the GeoValidator against the California boundary. Even though the model was not trained optimally for this demo, invalid samples (those falling outside the boundary) are already scarce in the data generated without a validator.
# !pip install geopandas &> /dev/null
# !pip install realtabformer &> /dev/null
# !git clone https://github.com/joncutrer/geopandas-tutorial.git &> /dev/null
import geopandas
import seaborn as sns
import matplotlib.pyplot as plt
from realtabformer import REaLTabFormer
from realtabformer import rtf_validators as rtf_val
from shapely.geometry import Polygon, LineString, Point, MultiPolygon
from sklearn.datasets import fetch_california_housing


def plot_sf(data, samples, title=None):
    xlims = (-126, -113.5)
    ylims = (31, 43)
    bins = (50, 50)

    # Keep only samples within the range of the real coordinates.
    dd = samples.copy()
    pp = dd.loc[
        dd["Longitude"].between(data["Longitude"].min(), data["Longitude"].max()) &
        dd["Latitude"].between(data["Latitude"].min(), data["Latitude"].max())
    ]

    g = sns.JointGrid(data=pp, x="Longitude", y="Latitude", marginal_ticks=True)
    g.plot_joint(
        sns.histplot,
        bins=bins,
    )

    # Overlay the California boundary (uses the global `states` below).
    states[states['NAME'] == 'California'].boundary.plot(ax=g.ax_joint)
    g.ax_joint.set_xlim(*xlims)
    g.ax_joint.set_ylim(*ylims)

    g.plot_marginals(sns.histplot, element="step", color="#03012d")

    if title:
        g.ax_joint.set_title(title)

    plt.tight_layout()


# Get geographic files
states = geopandas.read_file('geopandas-tutorial/data/usa-states-census-2014.shp')
states = states.to_crs("EPSG:4326")  # GPS Projection

# Get the California housing dataset
data = fetch_california_housing(as_frame=True).frame

# We create a model with small epochs for the demo, default is 200.
rtf_model = REaLTabFormer(
    model_type="tabular",
    batch_size=64,
    epochs=10,
    gradient_accumulation_steps=4,
    logging_steps=100)

# Fit the specified model. We also reduce the num_bootstrap, default is 500.
rtf_model.fit(data, num_bootstrap=10)

# Save the trained model
rtf_model.save("rtf_model/")

# Sample raw data without validator
samples_raw = rtf_model.sample(n_samples=10240, gen_batch=512)

# Sample data with the geographic validator
obs_validator = rtf_val.ObservationValidator()

# Use .iloc[0] so the lookup does not depend on California's row label.
obs_validator.add_validator(
    "geo_validator",
    rtf_val.GeoValidator(
        MultiPolygon(states[states['NAME'] == 'California'].geometry.iloc[0])),
    ("Longitude", "Latitude")
)

samples_validated = rtf_model.sample(
    n_samples=10240, gen_batch=512,
    validator=obs_validator,
)

# Visualize the samples
plot_sf(data, samples_raw, title="Raw samples")
plot_sf(data, samples_validated, title="Validated samples")
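The GeoValidator in the example decides, for each generated (Longitude, Latitude) pair, whether it lies inside the California boundary. A minimal stand-alone sketch of that decision uses the even-odd (ray casting) point-in-polygon rule. The function names and the single-ring polygon format are assumptions, and this is not how realtabformer or shapely implement the check.

```python
# Illustrative sketch (not realtabformer's implementation): keep only
# (longitude, latitude) points that fall inside a boundary polygon,
# using the even-odd (ray casting) rule.
from typing import Iterable, List, Sequence, Tuple

LonLat = Tuple[float, float]


def point_in_polygon(pt: LonLat, polygon: Sequence[LonLat]) -> bool:
    """Return True if pt lies inside the polygon (a single ring of vertices).

    Points exactly on an edge may be classified either way.
    """
    x, y = pt
    inside = False
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]
        # Does this edge straddle the horizontal line through y?
        if (y1 > y) != (y2 > y):
            # x coordinate where the edge crosses that line
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            # A ray cast to the right crosses this edge: flip the state.
            if x < x_cross:
                inside = not inside
    return inside


def filter_valid(points: Iterable[LonLat], polygon: Sequence[LonLat]) -> List[LonLat]:
    """Keep only the points that fall inside the polygon."""
    return [p for p in points if point_in_polygon(p, polygon)]
```

Unlike this single-ring sketch, the example passes a shapely MultiPolygon, which can represent a boundary made of several parts (such as a mainland plus islands).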
Please cite our work if you use REaLTabFormer in your projects or research.
@article{solatorio2023realtabformer,
  title={REaLTabFormer: Generating Realistic Relational and Tabular Data using Transformers},
  author={Solatorio, Aivin V. and Dupriez, Olivier},
  journal={arXiv preprint arXiv:2302.02041},
  year={2023}
}
We thank the World Bank-UNHCR Joint Data Center on Forced Displacement (JDC) for funding the project "Enhancing Responsible Microdata Access to Improve Policy and Response in Forced Displacement Situations" (KP-P174174-GINP-TF0B5124). Part of the funding supported the development of the REaLTabFormer framework, which was used to generate the synthetic population for research on disclosure risk and the mosaic effect.
We also thank Hugging Face for all the open-source software they release. And to all open-source projects, thank you!