mirror of
https://github.com/Soul-AILab/SoulX-Podcast.git
synced 2026-05-06 21:51:04 +08:00
Unify model behavior and update dialect data example
- Synchronize HF inference output with vLLM for consistency.
- Patch the dialect inference reference JSON file.
This commit is contained in:
@@ -8,4 +8,5 @@ input_file=example/podcast_script/script_mandarin.json
|
||||
python cli/podcast.py \
|
||||
--json_path ${input_file} \
|
||||
--model_path ${model_dir} \
|
||||
--output_path outputs/mandarin.wav
|
||||
--output_path outputs/mandarin.wav \
|
||||
--seed 7
|
||||
@@ -3,12 +3,12 @@
|
||||
"S1": {
|
||||
"prompt_audio": "example/audios/female_mandarin.wav",
|
||||
"prompt_text": "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
|
||||
"dialect_prompt": "<|Henan|>哎哟俺哩个乖乖!叻老师这生活咋恁得劲咧!又是攀岩又是烧瓷,赶明儿怕不是要踩着雪板往窑洞里钻?"
|
||||
"dialect_prompt": "<|Henan|>俺这不是怕恁路上不得劲儿嘛!那景德镇瓷泥可娇贵着哩,得先拿咱河南人这实诚劲儿给它揉透喽。"
|
||||
},
|
||||
"S2": {
|
||||
"prompt_audio": "example/audios/male_mandarin.wav",
|
||||
"prompt_text": "听到叻四老师关于自己身份的介绍,大家其实可以想象一下 叻四老师的生活有如此的丰富的状态。",
|
||||
"dialect_prompt": "<|Henan|>哎哟俺哩个乖乖!叻老师这生活咋恁得劲咧!又是攀岩又是烧瓷,赶明儿怕不是要踩着雪板往窑洞里钻?"
|
||||
"prompt_text": "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
|
||||
"dialect_prompt": "<|Henan|>恁这想法真闹挺!陕北民谣比黑神话早几百年都有了,咱可不兴这弄颠倒啊,中不?恁这想法真闹挺!那陕北民谣在黄土高坡响了几百年,咋能说是跟黑神话学的咧?咱得把这事儿捋直喽,中不中!"
|
||||
}
|
||||
},
|
||||
"text": [
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"S1": {
|
||||
"prompt_audio": "example/audios/female_mandarin.wav",
|
||||
"prompt_text": "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
|
||||
"dialect_prompt": "<|Sichuan|>哪个晓得耍耍哒哒的,倒把生科院的门门儿敲开啰!现在天天泡实验室,跟那些瓶瓶罐罐打交道,硬是闻到福尔马林的味道都比屋头泡菜坛子亲切喽!"
|
||||
"dialect_prompt": "<|Sichuan|>要得要得!前头几个耍洋盘,我后脚就背起铺盖卷去景德镇耍泥巴,巴适得喊老天爷!"
|
||||
},
|
||||
"S2": {
|
||||
"prompt_audio": "example/audios/male_mandarin.wav",
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
"S1": {
|
||||
"prompt_audio": "example/audios/female_mandarin.wav",
|
||||
"prompt_text": "喜欢攀岩、徒步、滑雪的语言爱好者,以及过两天要带着全部家当去景德镇做陶瓷的白日梦想家。",
|
||||
"dialect_prompt": "<|Yue|>真係冇讲错啊!佢哋制作认真到连背景音乐嘅呼吸声都执到最啱拍子,出亲街访仲要自带三个后备电池,呢种专业精神边度搵啊?我哋成日话学嘢就要学佢哋咁龟毛先得嘛!"
|
||||
"dialect_prompt": "<|Yue|>真係冇讲错啊!攀山滑雪嘅语言专家几巴闭,都唔及我听日拖成副身家去景德镇玩泥巴,呢铺真系发哂白日梦咯!"
|
||||
},
|
||||
"S2": {
|
||||
"prompt_audio": "example/audios/male_mandarin.wav",
|
||||
"prompt_text": "呃,还有一个就是要跟大家纠正一点,就是我们在看电影的时候,尤其是游戏玩家,看电影的时候,在看到那个到西北那边的这个陕北民谣,嗯,这个可能在想,哎,是不是他是受到了黑神话的启发?",
|
||||
"dialect_prompt": "<|Yue|>真係唔话得!徐生连标点符号都要拗到最啱先收货,成组人晚晚通顶改稿改到天光,呢种摞命搏嘅精神先至係最难得嘎!"
|
||||
"dialect_prompt": "<|Yue|>咪搞错啊!陕北民谣响度唱咗几十年,黑神话边有咁大面啊?你估佢哋抄游戏咩!"
|
||||
}
|
||||
},
|
||||
"text": [
|
||||
|
||||
@@ -8,7 +8,7 @@ from dataclasses import fields, asdict
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList
|
||||
from transformers import EosTokenCriteria
|
||||
from transformers import EosTokenCriteria, RepetitionPenaltyLogitsProcessor
|
||||
try:
|
||||
from vllm import LLM
|
||||
from vllm import SamplingParams as VllmSamplingParams
|
||||
@@ -48,7 +48,10 @@ class HFLLMEngine:
|
||||
win_size=sampling_param.win_size, tau_r=sampling_param.tau_r)
|
||||
else:
|
||||
sample_hf_engine_handler = None
|
||||
|
||||
rep_pen_processor = RepetitionPenaltyLogitsProcessor(
|
||||
penalty=sampling_param.repetition_penalty,
|
||||
prompt_ignore_length=len(prompt)
|
||||
) # exclude the input prompt, consistent with vLLM implementation;
|
||||
with torch.no_grad():
|
||||
input_len = len(prompt)
|
||||
generated_ids = self.model.generate(
|
||||
@@ -58,11 +61,11 @@ class HFLLMEngine:
|
||||
top_p=sampling_param.top_p,
|
||||
max_new_tokens=sampling_param.max_tokens,
|
||||
temperature=sampling_param.temperature,
|
||||
repetition_penalty=sampling_param.repetition_penalty,
|
||||
stopping_criteria=stopping_criteria,
|
||||
past_key_values=past_key_values,
|
||||
custom_generate=sample_hf_engine_handler,
|
||||
use_cache=True,
|
||||
logits_processor=[rep_pen_processor]
|
||||
)
|
||||
generated_ids = generated_ids[:, input_len:].cpu().numpy().tolist()[0]
|
||||
output = {
|
||||
|
||||
@@ -126,6 +126,7 @@ class SoulXPodcast(torch.nn.Module):
|
||||
prompt_inputs[-self.config.history_context:]
|
||||
))
|
||||
valid_turn_size = self.config.prompt_context + len(history_inputs) - prompt_text_bound
|
||||
past_key_values = DynamicCache(config=cache_config)
|
||||
valid_turn_size += 1
|
||||
|
||||
inputs.extend(text_tokens_for_llm[i])
|
||||
|
||||
Reference in New Issue
Block a user