sparse_lookup
功能描述
mxRec模型训练框架,稀疏特征表查询接口。
当前仅支持一表一查和一表多查。若存在一表多查的情况下,查询次数最大值为128。
暂不支持tf.SparseTensor数据类型,若是tf.SparseTensor需转成tf.Tensor。示例代码如下:
# 示例代码 sparse_ids = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) dense_ids = tf.sparse.to_dense(sparse_ids, default_value=0) embedding = sparse_lookup(sparse_hashtable, dense_ids)
函数原型
def sparse_lookup(hashtable, ids, send_count, is_train=True, name=None, modify_graph=False, batch=None, access_and_evict_config=None, is_grad=True, serving_default_value, **kwargs)
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
hashtable |
SparseEmbedding |
必选 |
待查询的稀疏表。 |
ids |
FeatureSpec/tf.Tensor |
必选 |
查询的关键字(key),对应参数类型在不同功能模式下存在区别,具体参见如下。
|
send_count |
int |
可选参数;开启静态shape时为必选参数。 |
作为All2All通信技术,取值范围:[1, 2147483647]。 开启动态shape时无需传该参数,或传“None”即可。默认值为“None”。 |
is_train |
bool |
必选 |
是否为训练模式。默认值为“True”。 取值范围:
|
name |
str |
可选 |
为该次查询操作创建对应的名称,字符串长度为[1,255]。默认值为“None”。 |
modify_graph |
bool |
可选 |
自动改图功能开关,该功能将在创建Session实例前对模型原图进行修改优化,默认值为“False”。 取值范围:
|
batch |
dict |
可选 |
数据集的迭代器。 当同时使用FeatureSpec类型、动态Shape功能时,“batch”参数必须传入。默认值为“None”。 |
access_and_evict_config |
dict |
可选 |
自动改图模式下开启特征准入与淘汰时使用。该dict由两个key-value对组成,“key”分别为“access_threshold”和“eviction_threshold”,“value”为对应的阈值。默认值为“None”。 |
is_grad |
bool |
可选 |
此次查询是否需要梯度更新,默认值为“True”。 取值范围:
|
serving_default_value |
tf.Tensor |
可选 |
训练时未准入特征/预测时的新特征的默认emb值。如果不指定,默认为“0”。 |
**kwargs参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
feature_spec_name_ids_dict |
dict |
可选 |
字典结构,key为FeatureSpec名称,value为公开接口sparse_lookp()的参数ids,无默认值。 |
multi_lookup |
bool |
可选 |
是否存在一表多查的情况,无默认值。 取值范围:
|
lookup_ids |
FeatureSpec/tf.Tensor |
可选 |
查询的关键字(key),对应参数类型在不同功能模式下存在区别,具体参见如下。无默认值。
|
- **kwargs参数中的“feature_spec_name_ids_dict”、“multi_lookup”和“lookup_ids”作为内部使用参数,不建议用户通过kwargs传递这三个参数。
- 如果通过kwargs传递其他未说明参数,则mxRec内部不会使用到该参数。
返回值说明
- 成功:返回查询到的Tensor类结果。
- 失败:抛出异常。
使用示例
from mx_rec.core.embedding import sparse_lookup from mx_rec.core.asc.feature_spec import FeatureSpec feature_spec = FeatureSpec("sparse_feature", table_name="sparse_embeddings_table", batch_size=1) embedding = sparse_lookup(sparse_hashtable, feature_spec, send_count=6000, is_train=True, name="sparse_embeddings")
参考资源
接口调用流程及示例,参见模型迁移与训练。