def __init__
函数功能
类初始化函数。
函数原型
def __init__(self, input_size=(None, None, None), in_channels=4, caption_channels=4096, enable_flash_attn=True, enable_sequence_parallelism=False, use_cache=True, cache_interval=2, cache_start=3, num_cache_layer=13, cache_start_steps=5)
参数说明
参数名 |
输入/输出 |
说明 |
---|---|---|
input_size |
输入 |
STDiT的输入latent_size大小,表示输入数据的尺寸,三元组 (T, H, W)分别表示时间维度、高度和宽度。 |
in_channels |
输入 |
每个像素的RGBA通道数,默认值为4。 |
caption_channels |
输入 |
TextEncoder模型文本编码维度,默认值为4096。 |
enable_flash_attn |
输入 |
控制是否使用Flash Attention技术来加速注意力计算,默认值为True。 |
enable_sequence_parallelism |
输入 |
控制是否在模型中使用序列并行化技术来加速计算,设置为True时需要使用多卡并行推理。默认值为False。 |
use_cache |
输入 |
是否启用缓存机制,默认值为True。 |
cache_interval |
输入 |
缓存数据的间隔步数。不建议用户修改默认值,如需修改,需要保证不要超过迭代的最大步数。默认值为2。 |
cache_start |
输入 |
开始缓存的Block层数。不建议用户修改默认值,如需修改,需要保证不要超过模型最大Block层数。默认值为3。 |
num_cache_layer |
输入 |
缓存的Block层数。不建议用户修改默认值,如需修改,需要保证不要超过模型最大Block层数。默认值为13。 |
cache_start_steps |
输入 |
开始缓存的迭代步数。不建议用户修改默认值,如需修改,需要保证不要超过迭代的最大步数。默认值为5。 |
返回值说明
无