模型推理代码
“wenet/model.py”中的Python类WeNetASR封装了WeNet模型推理的主要代码。
- transcribe是推理主函数,将输入的录音文件经模型推理转换为文本。
- preprocess函数用于执行音频数据预处理,返回音频特征作为模型输入。
- post_process函数用于模型推理结果后处理,将数值型推理结果转换为文本输出。
class WeNetASR:
    """Wrap a WeNet speech-recognition model for single-file inference.

    ``transcribe`` is the entry point: it chains ``preprocess`` (audio file
    to padded features), the model's ``infer`` call, and ``post_process``
    (numeric model output to text).
    """

    def __init__(self, model_path, vocab_path):
        """Load the vocabulary and open the inference session.

        Args:
            model_path: Model location handed to ``InferSession``.
            vocab_path: Vocabulary location handed to ``load_vocab``.
        """
        self.vocabulary = load_vocab(vocab_path)
        # NOTE(review): the leading 0 is presumably a device id -- confirm
        # against the InferSession API.
        self.model = InferSession(0, model_path)
        # Size of axis 1 of the first model input; used as the padding target.
        first_input = self.model.get_inputs()[0]
        self.max_len = first_input.shape[1]

    def transcribe(self, wav_file):
        """Run model inference and turn a recording into text.

        Args:
            wav_file: Audio file to transcribe (forwarded to ``preprocess``).

        Returns:
            Whatever ``post_process`` yields for the model output.
        """
        padded_feats, feat_lens = self.preprocess(wav_file)
        raw_output = self.model.infer([padded_feats, feat_lens])
        return self.post_process(raw_output)

    def preprocess(self, wav_file):
        """Convert an audio file into padded model inputs.

        Args:
            wav_file: Audio file loaded through ``torchaudio.load``.

        Returns:
            A ``(feats_pad, feats_lengths)`` pair: the padded features as a
            float32 NumPy array, and an int32 NumPy array holding the
            un-padded size of the feature's first axis.
        """
        audio, rate = torchaudio.load(wav_file)
        # Resample the audio to 16000 Hz.
        audio, rate = resample(audio, rate, resample_rate=16000)
        # Compute fbank features.
        fbank = compute_fbank(audio, rate)
        # Record the length before padding changes it.
        lengths = np.array([fbank.shape[0]]).astype(np.int32)
        # Pad the features so that they match the model's input size.
        padded = pad_sequence(
            fbank,
            batch_first=True,
            padding_value=0,
            max_len=self.max_len,
        )
        return padded.numpy().astype(np.float32), lengths

    def post_process(self, output):
        """Convert the numeric model output into text.

        Args:
            output: Five-element model output, unpacked as encoder output,
                encoder output lengths, CTC log-probs, beam log-probs and
                beam log-prob indices. Only the lengths, the beam log-probs
                (for the batch size) and the beam indices are used.

        Returns:
            The first element of the ``map_batch`` result.
        """
        (_encoder_out, out_lens, _ctc_log_probs,
         beam_scores, beam_indices) = output
        n_items = beam_scores.shape[0]
        # Keep index 0 along the last axis (presumably the best beam
        # candidate -- confirm against the model's output layout).
        best_indices = beam_indices[:, :, 0]
        # Truncate each sequence to its encoder output length.
        token_ids = [
            row[:out_lens[i]].tolist() for i, row in enumerate(best_indices)
        ]
        # Map the indices to text with the preloaded vocabulary; the process
        # count equals the batch size. NOTE(review): meaning of the trailing
        # ``True, 0`` arguments is defined by map_batch -- not visible here.
        return map_batch(token_ids, self.vocabulary, n_items, True, 0)[0]
父主题: 代码实现