Hi everybody, this is Julien from Arcee. A few weeks ago, I noticed that AWS added a new built-in algo to Amazon SageMaker. This algo is called TabTransformer. Of course, that got me interested; SageMaker, transformers, these are my obsessions. So I had to take a look. In this video, we're going to run through this TabTransformer algo and try to look under the hood to see what it's really made of. Let's get to work.
Starting from the AWS documentation, we see that the TabTransformer algo has a really good name, because it's a transformer-based model applied to tabular data, with the transformer part working on the categorical features. Finally, AWS managed to name something right. Now it makes sense. Looking at the details, we can see some code and how to use this as a built-in algo. I'm sure we'll jump to one of the notebooks in a second. There's an example for tabular classification and another one for tabular regression. We need CSV data for training and inference. That's definitely interesting. I'm really curious how you can use transformers with tabular data. Of course, we're all familiar with transformers for NLP and computer vision, but transformers for tabular data are relatively new. If you're interested in the details, there's a link to the arXiv paper. Let's quickly take a look.
This is an Amazon algo, and funny enough, the authors are listed as Amazon AWS. Confusing naming, but okay. I work for Amazon AWS. Anyway, it's a good read. I went through it. It's not too crazy. You can see categorical features being embedded and numerical features being normalized, etc. If that's your thing, you can go deeper. It gets a little mathy, but they tested it on a whole bunch of tabular datasets. You get some interesting numbers. The algo does reasonably well compared to other classification and regression algos. You can go and read the paper. Now let's run some code.
I started from the regression example, but you can try the classification example as well. The beginning is really SageMaker as we know it, so a little bit of setup. In this example, they used the Abalone dataset, which is a toy dataset, but we just want to see how this works. The abalone is a shellfish, and the goal is to predict the age of the shellfish based on physical measurements like length, diameter, etc. We need to grab the container for this built-in algo, and it says PyTorch, so we know it's based on PyTorch. We set some training parameters, like where the dataset is, etc. The notebook can optionally run automatic model tuning to optimize hyperparameters. By default, it's on, but I turned it off because I just wanted to run one training job and see what was going on. Feel free to tweak it more if you'd like.
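If you want to see how the notebook wires that up, here's roughly what grabbing the training container and packaged script looks like. This is a sketch: the `model_id` string below is a placeholder, so copy the exact one from the AWS example notebook.

```python
from sagemaker import image_uris, model_uris, script_uris

# Placeholder model_id: copy the exact string from the AWS example notebook.
model_id, model_version = "pytorch-tabtransformer-regression", "*"

# Training container; the URI tells us it's a PyTorch image.
train_image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    model_id=model_id,
    model_version=model_version,
    image_scope="training",
    instance_type="ml.m5.2xlarge",
)

# Packaged training script and pre-trained artifacts for this algo.
train_source_uri = script_uris.retrieve(
    model_id=model_id, model_version=model_version, script_scope="training"
)
train_model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="training"
)
```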
We create the estimator, the cornerstone of any SageMaker script. As this is a built-in algo, we use the generic estimator and pass the name of the container. The entry point is a script called `transfer_learning.py`, which is unusual for built-in algos, as they typically have their entry point hard-coded. We'll take a look at this script. We disable automatic model tuning and then call `fit` to start training. Usually, we don't read the log, but this time we will. It installs a whole bunch of things, including PyTorch. One thing it does not install is the transformers library. Let's find out what's actually installed. We see PyTorch, and it installs more stuff. pytorch-widedeep rings a bell. You can check it out on GitHub. This is an interesting open-source project that lets you build transformer-based models for tabular data. It has a component called deeptabular, so the plot thickens. This is probably what's running under the hood.
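For reference, here's a sketch of the estimator setup. Variables like `aws_role`, `hyperparameters`, and the S3 paths come from earlier cells in the notebook (the hyperparameters dictionary is shown a bit further down).

```python
from sagemaker.estimator import Estimator

# Generic estimator for the built-in algo, with an explicit entry point script.
estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",   # the visible training script
    instance_count=1,
    instance_type="ml.m5.2xlarge",        # CPU instance, as discussed below
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
    base_job_name="tabtransformer-regression",
)

# Point the job at the CSV training data in S3 and stream the training log.
estimator.fit({"training": training_dataset_s3_path}, logs=True)
```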
We see some hyperparameters, like the number of transformer blocks. Let's keep reading. The source directory is interesting because it's a package with dependencies and other scripts. That's probably where we'll find the training script. The training log goes on forever, but we're not super interested in that. For the record, this was reasonably fast. I trained for 2-3 minutes on an ml.m5 instance, which is CPU-based. That was one of my early questions: can you train transformers on tabular data on CPU in a few minutes, at a cost comparable to traditional algos? Apparently, you can.
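This is also how the notebook builds that hyperparameters dictionary in the first place: it pulls the defaults for the algo and lets you override them. The key names in the comment are only what I'd expect to see, not something I verified, so print the dictionary to check.

```python
from sagemaker import hyperparameters as hp_utils

# Pull the default hyperparameters for this model, then tweak a couple of them.
hyperparameters = hp_utils.retrieve_default(
    model_id=model_id, model_version=model_version
)
print(hyperparameters)
# e.g. hyperparameters["n_epochs"] = "20"   # hypothetical key name, check the printout
```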
The rest of the notebook goes on to deploy the model, but feel free to run that and evaluate the algorithm. I just want to find out what's going on. The first thing I'll do is grab that script. You just copy the source, and inside, there's the `transfer_learning.py` script. It's all based on pytorch-widedeep. Let's take a look at the code. There's some data loading, which isn't fascinating. We use script mode, so we can see all the training parameters passed as command-line arguments. There's a little bit of column manipulation, preparing categorical and continuous columns, and fitting the data to the model. We use the TabTransformer model from pytorch-widedeep, which is referenced here. The model is saved, and that answers the question of where the model comes from. It's the TabTransformer implementation in pytorch-widedeep, which you can use out of the box on your own machine with the vanilla library.
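To give you an idea of what that looks like with the vanilla library, here's a minimal standalone sketch. The column names are illustrative, and argument names have changed across pytorch-widedeep releases (older versions use `embed_cols` and `for_transformer` instead of `cat_embed_cols` and `with_attention`), so check the version you have installed.

```python
import pandas as pd
from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import TabTransformer, WideDeep

# Illustrative Abalone-style columns: one categorical, a few continuous, "rings" as the target.
df = pd.read_csv("abalone_train.csv")
cat_cols = ["sex"]
cont_cols = ["length", "diameter", "height", "whole_weight"]

# Integer-encode categoricals and normalize continuous columns for an attention-based model.
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_cols, continuous_cols=cont_cols, with_attention=True
)
X_tab = tab_preprocessor.fit_transform(df)

# The TabTransformer itself: embedded categoricals go through the transformer blocks,
# normalized continuous features are concatenated before the final MLP.
tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=cont_cols,
    n_blocks=4,
)
model = WideDeep(deeptabular=tab_transformer)

trainer = Trainer(model, objective="regression")
trainer.fit(X_tab=X_tab, target=df["rings"].values, n_epochs=5, batch_size=256)
```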
They wrapped this around a training script that runs on SageMaker, saving you the trouble of writing data preparation code and training at scale. You can run this on GPU, but multi-GPU training is not supported. If you wanted to, you could tweak that script, copy-paste it, and use your own script in the estimator instead of the provided one.
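For example, assuming you've downloaded and unpacked the packaged source (the S3 object behind `train_source_uri`) into a local folder, you could point the estimator at your edited copy instead. Again, just a sketch; the archive name may differ.

```python
from sagemaker.estimator import Estimator

# Hypothetical local copy of the packaged training code, edited to taste:
#   aws s3 cp $train_source_uri .
#   tar -xzf sourcedir.tar.gz -C my_tabtransformer_code/   # archive name may differ

estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",      # your modified script
    source_dir="./my_tabtransformer_code",   # local directory instead of the packaged S3 source
    instance_count=1,
    instance_type="ml.m5.2xlarge",
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
)
```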
Now let's look at the model. We can list the output location for that job, and sure enough, there's the model. There's also a lot of profiler stuff, which is enabled by default and can be annoying. You can disable it in the estimator. You can extract the model and see the model parameters. We have 32 dimensions corresponding to the preprocessed columns and four blocks. If we load the model as a PyTorch model and print the layer names, we see the four blocks and some MLP at the end. That answers my curiosity about what the model is and where it comes from.
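Here's roughly what that inspection looks like locally. The file name inside the archive is an assumption (list the contents first), and loading the full model object requires pytorch-widedeep to be installed.

```python
import tarfile
import torch

# Unpack the artifact downloaded from the job's S3 output location:
#   aws s3 cp <s3_output_location>/<training_job_name>/output/model.tar.gz .
with tarfile.open("model.tar.gz") as tar:
    print(tar.getnames())      # check what the archive actually contains
    tar.extractall("model")

# Assumed file name; adjust to whatever the listing above shows.
model = torch.load("model/model.pt", map_location="cpu")
for name, module in model.named_modules():
    print(name, module.__class__.__name__)   # the four transformer blocks and the final MLP show up here
```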
It's consistent and clear. We have the research article, the implementation in PyTorch Wide Deep, and all that is wrapped around a built-in container and training script in SageMaker. I'm still not sure why they call it transfer learning, as this is initial training. However, the paper discusses using transfer learning with this model. If you have a ton of unlabeled data, you could do pre-training using something similar to masked language modeling for NLP, randomly masking some columns and training the model to predict them. Then you can fine-tune with a little bit of labeled data. That's an interesting technique, and it would be cool to see a notebook showing how to do that.
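Just to make the idea concrete, here's a toy sketch of what masked-column pre-training could look like in plain PyTorch. This is my own illustration under simplifying assumptions (one shared vocabulary for all columns, no continuous features), not the paper's exact recipe.

```python
import torch
import torch.nn as nn

# Toy setup: 8 categorical columns, each encoded as an integer in [0, cardinality).
n_cols, cardinality, d_model = 8, 20, 32
mask_token = cardinality          # reserve one extra index as the [MASK] token

embed = nn.Embedding(cardinality + 1, d_model)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True), num_layers=4
)
head = nn.Linear(d_model, cardinality)    # predict the original category of each masked cell

# Fake unlabeled batch of integer-encoded categorical features.
X_cat = torch.randint(0, cardinality, (256, n_cols))

# Mask ~15% of the cells, like masked language modeling in NLP.
mask = torch.rand(X_cat.shape) < 0.15
X_masked = X_cat.masked_fill(mask, mask_token)

# Train the encoder to recover the masked values from the surrounding columns.
logits = head(encoder(embed(X_masked)))                       # (batch, n_cols, cardinality)
loss = nn.functional.cross_entropy(logits[mask], X_cat[mask])
loss.backward()
```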
Although this particular algo is not available on the Hugging Face Hub, there is a TabTransformer implementation in Keras, thanks to my DevRel colleagues. There's a Space where you can demo a version trained on the Adult dataset to predict whether a person earns more than 50K dollars. The model is available on the Hub. Full credit goes to Khalid Salama. There's also a very good example in the Keras documentation on how to work with this.
Transformers are coming to tabular data, and we already have a bunch of options for working with them. I'm sure we'll see more models and examples of this. That's it for today. I hope you learned a few things, and I'll see you soon. Bye.