Unify model behavior and update dialect data example

- Synchronize HF inference output with vLLM for consistency.
- Patch the dialect inference reference JSON file.
This commit is contained in:
linhaopeng
2025-10-28 16:46:57 +08:00
parent 82732cf526
commit 8967f71498
6 changed files with 15 additions and 10 deletions

View File

@@ -8,4 +8,5 @@ input_file=example/podcast_script/script_mandarin.json
python cli/podcast.py \
--json_path ${input_file} \
--model_path ${model_dir} \
--output_path outputs/mandarin.wav
--output_path outputs/mandarin.wav \
--seed 7

View File

@@ -3,12 +3,12 @@
"S1": {
"prompt_audio": "example/audios/female_mandarin.wav",
"prompt_text": "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
"dialect_prompt": "<|Henan|>哎哟俺哩个乖乖!叻老师这生活咋恁得劲咧!又是攀岩又是烧瓷,赶明儿怕不是要踩着雪板往窑洞里钻?"
"dialect_prompt": "<|Henan|>俺这不是怕恁路上不得劲儿嘛!那景德镇瓷泥可娇贵着哩,得先拿咱河南人这实诚劲儿给它揉透喽。"
},
"S2": {
"prompt_audio": "example/audios/male_mandarin.wav",
"prompt_text": "听到叻四老师关于自己身份的介绍,大家其实可以想象一下 叻四老师的生活有如此的丰富的状态。",
"dialect_prompt": "<|Henan|>哎哟俺哩个乖乖!叻老师这生活咋恁得劲咧!又是攀岩又是烧瓷,赶明儿怕不是要踩着雪板往窑洞里钻?"
"prompt_text": "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
"dialect_prompt": "<|Henan|>恁这想法真闹挺!陕北民谣比黑神话早几百年都有了,咱可不兴这弄颠倒啊,中不?恁这想法真闹挺!那陕北民谣在黄土高坡响了几百年,咋能说是跟黑神话学的咧?咱得把这事儿捋直喽,中不中!"
}
},
"text": [

View File

@@ -3,7 +3,7 @@
"S1": {
"prompt_audio": "example/audios/female_mandarin.wav",
"prompt_text": "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
"dialect_prompt": "<|Sichuan|>哪个晓得耍耍哒哒的,倒把生科院的门门儿敲开啰!现在天天泡实验室,跟那些瓶瓶罐罐打交道,硬是闻到福尔马林的味道都比屋头泡菜坛子亲切喽"
"dialect_prompt": "<|Sichuan|>要得要得!前头几个耍洋盘,我后脚就背起铺盖卷去景德镇耍泥巴,巴适得喊老天爷"
},
"S2": {
"prompt_audio": "example/audios/male_mandarin.wav",

View File

@@ -3,12 +3,12 @@
"S1": {
"prompt_audio": "example/audios/female_mandarin.wav",
"prompt_text": "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
"dialect_prompt": "<|Yue|>真係冇讲错啊!佢哋制作认真到连背景音乐嘅呼吸声都执到最啱拍子,出亲街访仲要自带三个后备电池,呢种专业精神边度搵啊?我哋成日话学嘢就要学佢哋咁龟毛先得嘛"
"dialect_prompt": "<|Yue|>真係冇讲错啊!攀山滑雪嘅语言专家几巴闭,都唔及我听日拖成副身家去景德镇玩泥巴,呢铺真系发哂白日梦咯"
},
"S2": {
"prompt_audio": "example/audios/male_mandarin.wav",
"prompt_text": "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
"dialect_prompt": "<|Yue|>真係唔话得!徐生连标点符号都要拗到最啱先收货,成组人晚晚通顶改稿改到天光,呢种摞命搏嘅精神先至係最难得嘎"
"dialect_prompt": "<|Yue|>咪搞错啊!陕北民谣响度唱咗几十年,黑神话边有咁大面啊?你估佢哋抄游戏咩"
}
},
"text": [

View File

@@ -8,7 +8,7 @@ from dataclasses import fields, asdict
import torch
import torch.multiprocessing as mp
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList
from transformers import EosTokenCriteria
from transformers import EosTokenCriteria, RepetitionPenaltyLogitsProcessor
try:
from vllm import LLM
from vllm import SamplingParams as VllmSamplingParams
@@ -48,7 +48,10 @@ class HFLLMEngine:
win_size=sampling_param.win_size, tau_r=sampling_param.tau_r)
else:
sample_hf_engine_handler = None
rep_pen_processor = RepetitionPenaltyLogitsProcessor(
penalty=sampling_param.repetition_penalty,
prompt_ignore_length=len(prompt)
) # exclude the input prompt, consistent with vLLM implementation;
with torch.no_grad():
input_len = len(prompt)
generated_ids = self.model.generate(
@@ -58,11 +61,11 @@ class HFLLMEngine:
top_p=sampling_param.top_p,
max_new_tokens=sampling_param.max_tokens,
temperature=sampling_param.temperature,
repetition_penalty=sampling_param.repetition_penalty,
stopping_criteria=stopping_criteria,
past_key_values=past_key_values,
custom_generate=sample_hf_engine_handler,
use_cache=True,
logits_processor=[rep_pen_processor]
)
generated_ids = generated_ids[:, input_len:].cpu().numpy().tolist()[0]
output = {

View File

@@ -126,6 +126,7 @@ class SoulXPodcast(torch.nn.Module):
prompt_inputs[-self.config.history_context:]
))
valid_turn_size = self.config.prompt_context + len(history_inputs) - prompt_text_bound
past_key_values = DynamicCache(config=cache_config)
valid_turn_size += 1
inputs.extend(text_tokens_for_llm[i])