下载
中文
注册

def __init__

函数功能

类初始化函数。

函数原型

def __init__(self, input_size=(None, None, None), in_channels=4, caption_channels=4096, enable_flash_attn=True, enable_sequence_parallelism=False, use_cache=True, cache_interval=2, cache_start=3, num_cache_layer=13, cache_start_steps=5)

参数说明

参数名

输入/输出

说明

input_size

输入

STDiT的输入latent_size大小,表示输入数据的尺寸,三元组 (T, H, W)分别表示时间维度、高度和宽度。

in_channels

输入

每个像素的RGBA通道数,默认值为4。

caption_channels

输入

TextEncoder模型文本编码维度,默认值为4096。

enable_flash_attn

输入

控制是否使用Flash Attention技术来加速注意力计算,默认值为True。

enable_sequence_parallelism

输入

控制是否在模型中使用序列并行化技术来加速计算,设置为True时需要使用多卡并行推理。默认值为False。

use_cache

输入

是否启用缓存机制,默认值为True。

cache_interval

输入

缓存数据的间隔步数。不建议用户修改默认值,如需修改,需要保证不要超过迭代的最大步数。默认值为2。

cache_start

输入

开始缓存的Block层数。不建议用户修改默认值,如需修改,需要保证不要超过模型最大Block层数。默认值为3。

num_cache_layer

输入

缓存的Block层数。不建议用户修改默认值,如需修改,需要保证不要超过模型最大Block层数。默认值为13。

cache_start_steps

输入

开始缓存的迭代步数。不建议用户修改默认值,如需修改,需要保证不要超过迭代的最大步数。默认值为5。

返回值说明