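# Demo script: ask a locally stored Qwen2.5-0.5B-Instruct model a question with
# transformers, then synthesize both the question and the answer with CosyVoice-300M
# SFT voices and play the resulting WAV files with pygame.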
from transformers import AutoModelForCausalLM, AutoTokenizer

from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav
import torchaudio

import pygame
import time


def play_audio(file_path):
    """Play a WAV file with pygame and block until playback finishes."""
    try:
        pygame.mixer.init()
        pygame.mixer.music.load(file_path)
        pygame.mixer.music.play()
        while pygame.mixer.music.get_busy():
            time.sleep(1)  # wait for the audio to finish playing
        print("Playback finished!")
    except Exception as e:
        print(f"Playback failed: {e}")
    finally:
        pygame.mixer.quit()


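# Load the chat model from a local directory. torch_dtype="auto" picks the dtype stored in
# the checkpoint config, and device_map="auto" places the weights on a GPU when one is
# available (device_map requires the accelerate package to be installed).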
model_name = r"D:\AI\download\Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

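# Load CosyVoice-300M for SFT-style synthesis (text plus a built-in speaker ID, no prompt
# audio needed). list_avaliable_spks() (spelled this way in the CosyVoice API) returns the
# built-in speaker IDs, e.g. the Chinese/English/Japanese/Cantonese/Korean voices listed
# in the comment further below.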
cosyvoice = CosyVoice(r'../pretrained_models/CosyVoice-300M', load_jit=True, load_onnx=False, fp16=True)
# sft usage
print(cosyvoice.list_avaliable_spks())

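# Build the chat prompt. apply_chat_template() with tokenize=False and
# add_generation_prompt=True returns the conversation formatted as a single string,
# ending with the assistant turn header so the model generates the reply next.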
prompt = "你好,你叫什么名字?"  # "Hello, what's your name?"
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

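# Generate the answer and strip the echoed prompt: model.generate() returns the prompt
# tokens followed by the new tokens, so slicing off len(input_ids) keeps only the reply.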
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512,
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print("Input:", prompt)
print("Answer:", response)

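# Synthesize speech. inference_sft() is a generator: each yielded dict carries a
# 'tts_speech' waveform tensor, saved here as a 22.05 kHz WAV to match the model's output
# rate. Long text may be split across several yields; stream=True would instead emit
# smaller chunks as they are generated.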
# ['中文女', '中文男', '日语男', '粤语女', '英文女', '英文男', '韩语女']
for i, j in enumerate(cosyvoice.inference_sft(prompt, '中文男', stream=False)):
    torchaudio.save('prompt_sft_{}.wav'.format(i), j['tts_speech'], 22050)
    # play_audio('prompt_sft_{}.wav'.format(i))

# change stream=True for chunk stream inference
num_response_chunks = 0  # track how many answer chunks get written so they can all be played back below
for i, j in enumerate(cosyvoice.inference_sft(response, '中文女', stream=False)):
    torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)
    # play_audio('sft_{}.wav'.format(i))
    num_response_chunks = i + 1

# Play the synthesized question first, then every chunk of the synthesized answer.
play_audio('prompt_sft_0.wav')
for i in range(num_response_chunks):
    play_audio('sft_{}.wav'.format(i))