create_table
功能描述
创建稀疏表。
函数原型
def create_table(key_dtype, dim, name, emb_initializer, device_vocabulary_size=1, host_vocabulary_size=0, optimizer_list=None, mode=MxRecMode.ASC, value_dtype=tf.float32, shard_num=1, fusion_optimizer_var=True, hashtable_threshold=0, is_save=True, init_param=1.0)
参数说明
参数名 |
类型 |
可选/必选 |
说明 |
---|---|---|---|
key_dtype |
TensorFlow的dtype类型 |
必选 |
稀疏特征键(key)数据类型。 |
dim |
int |
必选 |
嵌入层(embedding)维度。 |
name |
str |
必选 |
稀疏表表名。 |
emb_initializer |
func |
必选 |
嵌入层初始值生成器。 |
device_vocabulary_size |
int |
可选 |
Device侧嵌入层数量,默认值为“1”。 |
host_vocabulary_size |
int |
可选 |
Host侧DDR存储的嵌入层数量,默认值为“0”。 |
optimizer_list |
list |
可选 |
优化器列表,默认值为“None”。 |
mode |
MxRecMode枚举类 |
可选 |
mxRec框架方案,默认值为“MxRecMode.ASC”。 |
value_dtype |
TensorFlow的dtype类型 |
可选 |
稀疏特征值(value)数据类型,默认值为“tf.float32”。 |
shard_num |
int |
可选 |
嵌入层分区数,默认值为“1”。 |
fusion_optimizer_var |
bool |
可选 |
是否使用融合优化参数,默认值为“True”。 |
hashtable_threshold |
int |
可选 |
哈希表阈值,高于阈值时使用哈希表,低于阈值时使用线性表,默认值为“0”。 |
is_save |
bool |
可选 |
是否保存embedding数据,默认值为“True”。 |
init_param |
float |
可选 |
Embedding初始化参数系数,默认值为“1.0”。 |
- 当“host_vocabulary_size”不为0时,“optimizer_list”为必选。
- 当“host_vocabulary_size”为“0”时,不开启Host侧DDR功能,不为“0”时开启。所有embedding表必须保持同时使用Host侧DDR功能或同时不使用Host侧DDR功能,即所有表“host_vocabulary_size”参数同时为“0”或同时不为“0”,否则进行参数校验时会报错,报错信息参考如下。
ValueError: The host-side DDR function of all tables must be used or not used at the same time. However, host voc size of each table is [].
返回值说明
- 成功:返回稀疏表实例。
- 失败:抛出异常。
使用示例
import tensorflow as tf from mx_rec.core.embedding import create_table sparse_hashtable = create_table(key_dtype=tf.int32, dim=tf.TensorShape([128]), name="sparse_embeddings_table", emb_initializer=tf.truncated_normal_initializer(), device_vocabulary_size=24_000_000 * 8, host_vocabulary_size=0)
参考资源
接口调用流程及示例,参见模型训练。
父主题: 模型接口