展会信息港展会大全

Google的神经网络表格处理模型TabNet介绍
来源:互联网   发布日期:2020-10-14 08:30:33   浏览:8238次  

导读:Google Research的TabNet于2019年发布,在预印稿中被宣称优于表格数据的现有方法。它是如何工作的,又如何可以尝试呢? 表格数据可能构成当今大多数业务数据。考虑诸如零售交易,点击流数据,工厂中的温度和压力传感器,银行使用的KYC (Know Your Customer)...

Google Research的TabNet于2019年发布,在预印稿中被宣称优于表格数据的现有方法。它是如何工作的,又如何可以尝试呢?

Google的神经网络表格处理模型TabNet介绍

表格数据可能构成当今大多数业务数据。考虑诸如零售交易,点击流数据,工厂中的温度和压力传感器,银行使用的KYC (Know Your Customer) 信息或制药公司使用的模型生物的基因表达数据之类的事情。

论文称为TabNet: Attentive Interpretable Tabular Learning(https://arxiv.org/pdf/1908.07442.pdf),很好地总结了作者正在尝试做的事情。“Net”部分告诉我们这是一种神经网络,“Attentive ”部分表示它正在使用一种注意力机制,旨在实现可解释性,并用于表格数据的机器学习。

它是如何工作的?

TabNet使用一种软功能选择将重点仅放在对当前示例很重要的功能上。这是通过顺序的多步骤决策机制完成的。即,以多个步骤自上而下地处理输入信息。正如论文所指出的那样,“自上而下关注的思想是从处理视觉和语言数据或强化学习中得到的启发,可以在高维输入中搜索一小部分相关信息。”

尽管它们与BERT等流行的NLP模型中使用的transformer 有些不同,但执行这种顺序关注的构件却称为transformer 块。这些transformer 使用自注意力机制,试图模拟句子中不同单词之间的依赖关系。这里使用的transformer类型试图使用“软”特性选择,一步一步地消除与示例无关的那些特性,这是通过使用sparsemax函数完成的。

这篇论文的第一个图,如下重现,描绘了信息是如何聚集起来形成预测的。

Google的神经网络表格处理模型TabNet介绍

TabNet的一个好特性是它不需要特性预处理。另一个原因是,它具有内置的可解释性,即为每个示例选择最相关的特性。这意味着您不必应用外部解释模块,如shap或LIME。

在阅读本文时,要理解这个架构中发生了什么并不容易,但幸运的是,已经发表的代码稍微澄清了一些问题,并表明它并不像您可能认为的那样复杂。

我怎么使用它?

现在TabNet有了更好的实现,如下所述:一个是PyTorch的接口,它有一个类似scikit学习的接口,还有一个是FastAI的接口。

根据作者readme描述要点如下:

为每个数据集创建新的train.csv,val.csv和test.csv文件,我不如读取整个数据集并在内存中进行拆分(当然,只要可行),所以我写了一个在我的代码中为Pandas提供了新的输入功能。

修改data_helper.py文件可能需要一些工作,至少在最初不确定您要做什么以及应该如何定义功能列时(至少我是这样)。还有许多参数需要更改,但它们位于主训练循环文件中,而不是数据帮助器文件中。有鉴于此,我还尝试在我的代码中概括和简化此过程。

我添加了一些快速的代码来进行超参数优化,但到目前为止仅用于分类。

还值得一提的是,作者提供的示例代码仅显示了如何进行分类,而不是回归,因此用户也必须编写额外的代码。我添加了具有简单均方误差损失的回归功能。

使用命令行运行测试

python train_tabnet.py \

--csv-path data/adult.csv \

--target-name "

--categorical-features workclass,education,marital.status,\

occupation,relationship,race,sex,native.country\

--feature_dim 16 \

--output_dim 16 \

--batch-size 4096 \

--virtual-batch-size 128 \

--batch-momentum 0.98 \

--gamma 1.5 \

--n_steps 5 \

--decay-every 2500 \

--lambda-sparsity 0.0001 \

--max-steps 7700

强制性参数包括--csv-path(指向CSV文件的位置),-target-name(具有预测目标的列的名称)和-category-featues(逗号分隔列表) 应该视为分类的功能)。其余输入参数是需要针对每个特定问题进行优化的超参数。但是,上面显示的值直接取自TabNet论文,因此作者已经针对成人普查数据集对其进行了优化。

默认情况下,训练过程会将信息写入执行脚本的位置的tflog子文件夹。您可以将tensorboard指向此文件夹以查看训练和验证统计信息:

tensorboard --logdir tflog

如果您没有GPU ...

…您可以尝试这款Colaboratory笔记(https://colab.research.google.com/drive/1AWnaS6uQVDw0sdWjfh-E77QlLtD0cpDa)。请注意,如果您想查看Tensorboard日志,最好的选择是创建一个Google Storage存储桶,并让脚本在其中写入日志。这可以通过使用tb-log-location参数来完成。例如。如果您的存储桶名称是camembert-skyscrape,则可以在脚本的调用中添加--tb-log-location gs:// camembert-skyscraper。(不过请注意,您必须正确设置存储桶的权限。这可能有点麻烦。)

然后可以将tensorboard从自己的本地计算机指向该存储桶:

tensorboard --logdir gs://camembert-skyscraper

超参数优化

在存储库(opt_tabnet.py)中也有一个用于完成超参数优化的快捷脚本。同样,在协作笔记本中显示了一个示例。该脚本仅适用于到目前为止的分类,值得注意的是,某些训练参数虽然实际上并不需要,但仍进行了硬编码(例如,用于尽早停止的参数[您可以继续执行多少步,而 验证准确性没有提高]。)

优化脚本中变化的参数为N_steps,feature_dim,batch-momentum,gamma,lambda-sparsity。(正如下面的优化技巧所建议的那样,output_dim设置为等于feature_dim。)

论文中具有以下有关超参数优化的提示:

大多数数据集对N_steps∈[3,10]产生最佳结果。通常,更大的数据集和更复杂的任务需要更大的N_steps。N_steps的非常高的值可能会过度拟合并导致不良的泛化。

调整Nd [feature_dim]和Na [output_dim]的值是获得性能与复杂性之间折衷的最有效方法。Nd = Na是大多数数据集的合理选择。Nd和Na的非常高的值可能会过度拟合,导致泛化效果差。

γ的最佳选择对整体性能具有重要作用。通常,较大的N_steps值有利于较大的γ。

批量较大对性能有利-如果内存限制允许,建议最大训练数据集总大小的1-10%。虚拟批次大小通常比批次大小小得多。

最初,较高的学习率很重要,应逐渐降低直至收敛。

结果

我已经通过此命令行界面尝试了TabNet的多个数据集,作者提供了他们在那里找到的最佳参数设置。使用这些设置重复运行后,我注意到最佳验证误差(和测试误差)往往在86%左右,类似于不进行超参数调整的CatBoost。作者报告论文中测试集的性能为85.7%。当我使用hyperopt进行超参数优化时,尽管使用了不同的参数设置,但我毫不奇怪地达到了约86%的相似性能。

对于其他数据集,例如Poker Hand 数据集,TabNet被认为远远击败了其他方法。我还没有花很多时间,但是当然每个人都应邀请他们自己对各种数据集进行超参数优化的TabNet!

TabNet是一个有趣的体系结构,似乎有望用于表格数据分析。它直接对原始数据进行操作,并使用顺序注意机制对每个示例执行显式特征选择。此属性还使其具有某种内置的可解释性。

我试图通过围绕它编写一些包装器代码来使TabNet稍微容易一些。下一步是将其与各种数据集中的其他方法进行比较。

tabnet的各种实现

google官方:https://github.com/google-research/google-research/tree/master/tabnet

pytorch:https://github.com/dreamquark-ai/tabnet

本文作者的一些改进:https://github.com/hussius/tabnet_fork

作者:Mikael Huss

deephub翻译组


赞助本站

相关内容
AiLab云推荐
推荐内容
展开

热门栏目HotCates

Copyright © 2010-2024 AiLab Team. 人工智能实验室 版权所有    关于我们 | 联系我们 | 广告服务 | 公司动态 | 免责声明 | 隐私条款 | 工作机会 | 展会港