下载
中文
注册

模型推理代码

“wenet/model.py”中的python类WeNetASR封装了WeNet模型推理的主要代码。
  • transcribe是推理主函数,将输入的录音文件经模型推理转换为文本。
  • preprocess函数用于执行音频数据预处理,返回音频特征作为模型输入。
  • post_process函数用于模型推理结果后处理,将数值型推理结果转换为文本输出。
class WeNetASR:
    def __init__(self, model_path, vocab_path):
        self.vocabulary = load_vocab(vocab_path)
        self.model = InferSession(0, model_path)
        self.max_len = self.model.get_inputs()[0].shape[1]
 
    def transcribe(self, wav_file):
        """执行模型推理,将录音文件转为文本。"""
        feats_pad, feats_lengths = self.preprocess(wav_file)
        output = self.model.infer([feats_pad, feats_lengths])
        txt = self.post_process(output)
        return txt
 
    def preprocess(self, wav_file):
        """数据预处理"""
        waveform, sample_rate = torchaudio.load(wav_file)
        # 音频重采样,采样率16000
        waveform, sample_rate = resample(waveform, sample_rate, resample_rate=16000)
        # 计算fbank特征
        feature = compute_fbank(waveform, sample_rate)
        feats_lengths = np.array([feature.shape[0]]).astype(np.int32)
        # 对输入特征进行padding,使符合模型输入尺寸
        feats_pad = pad_sequence(feature,
                                 batch_first=True,
                                 padding_value=0,
                                 max_len=self.max_len)
        feats_pad = feats_pad.numpy().astype(np.float32)
        return feats_pad, feats_lengths
 
    def post_process(self, output):
        """对模型推理结果进行后处理"""
        encoder_out, encoder_out_lens, ctc_log_probs, \
            beam_log_probs, beam_log_probs_idx = output
        batch_size = beam_log_probs.shape[0]
 
        num_processes = batch_size
        log_probs_idx = beam_log_probs_idx[:, :, 0]
        batch_sents = []
        for idx, seq in enumerate(log_probs_idx):
            batch_sents.append(seq[:encoder_out_lens[idx]].tolist())
        # 根据预置的标签字典将推理结果转换为文本
        txt = map_batch(batch_sents, self.vocabulary, num_processes, True, 0)[0]
        return txt