llm-asr-tts/0_Inference_QWen2.5.py
LIRUI ad1fee9c4b
Some checks failed
Lint / quick-checks (push) Has been cancelled
Lint / flake8-py3 (push) Has been cancelled
Close inactive issues / close-issues (push) Failing after 2s
first commit
2025-03-17 00:41:41 +08:00

74 lines
2.3 KiB
Python

from transformers import AutoModelForCausalLM, AutoTokenizer
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav
import torchaudio
import pygame
import time
def play_audio(file_path, poll_interval=0.1):
    """Play an audio file with pygame's mixer and block until playback ends.

    Errors (for example a missing or unreadable file) are printed instead of
    raised, so one bad clip does not abort the rest of the script. The mixer
    is initialized on entry and always shut down on exit.

    Args:
        file_path: Path of the audio file to play.
        poll_interval: Seconds to sleep between "still playing?" checks.
            A short interval lets the call return promptly once the clip
            ends, avoiding audible gaps between consecutive clips.
    """
    try:
        pygame.mixer.init()
        pygame.mixer.music.load(file_path)
        pygame.mixer.music.play()
        while pygame.mixer.music.get_busy():
            time.sleep(poll_interval)  # wait for playback to finish
    except Exception as e:
        # Deliberate best-effort: report and continue with the next clip.
        print(f"播放失败: {e}")
    else:
        print("播放完成!")
    finally:
        pygame.mixer.quit()
# Local directory of the Qwen2.5 instruct checkpoint (machine-specific path).
model_name = r"D:\AI\download\Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",  # use the dtype recorded with the checkpoint
    device_map="auto",  # let transformers choose device placement
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# CosyVoice-300M text-to-speech model, loaded with JIT modules, without ONNX,
# in half precision.
cosyvoice = CosyVoice(
    r'../pretrained_models/CosyVoice-300M',
    load_jit=True,
    load_onnx=False,
    fp16=True,
)

# SFT (built-in speaker) usage: show which speaker IDs can be requested.
# NOTE(review): the method name's spelling presumably matches the CosyVoice
# API exactly — do not "correct" it without checking the installed version.
print(cosyvoice.list_avaliable_spks())
# One-turn conversation: Qwen's default system prompt plus the user question.
prompt = "你好,你叫什么名字?"
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt},
]

# Render the conversation into a prompt string; add_generation_prompt=True
# appends the assistant-turn header so the model replies as the assistant.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Each generated row starts with the prompt tokens; slice them off so that
# only the newly generated reply tokens are decoded.
generated_ids = model.generate(**model_inputs, max_new_tokens=512)
generated_ids = [
    sequence[len(prompt_ids):]
    for prompt_ids, sequence in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print("Input:", prompt)
print("Answer:", response)
# Available SFT speaker IDs:
# ['中文女', '中文男', '日语男', '粤语女', '英文女', '英文男', '韩语女']


def _synthesize_sft(tts, text, speaker, file_prefix, sample_rate):
    """Synthesize *text* with built-in SFT *speaker* and save every output chunk.

    inference_sft can yield more than one result, so each one is written to
    '<file_prefix>_<index>.wav'. The written paths are returned in order so
    callers can play exactly the files produced by this run.
    """
    wav_paths = []
    # Change stream=False to stream=True for chunked streaming inference.
    for idx, chunk in enumerate(tts.inference_sft(text, speaker, stream=False)):
        wav_path = '{}_{}.wav'.format(file_prefix, idx)
        torchaudio.save(wav_path, chunk['tts_speech'], sample_rate)
        wav_paths.append(wav_path)
    return wav_paths


# Use the model's own output rate if this CosyVoice version exposes it;
# otherwise fall back to 22050 Hz.
tts_sample_rate = getattr(cosyvoice, 'sample_rate', 22050)

prompt_wavs = _synthesize_sft(cosyvoice, prompt, '中文男', 'prompt_sft', tts_sample_rate)
response_wavs = _synthesize_sft(cosyvoice, response, '中文女', 'sft', tts_sample_rate)

# Play the question and then the answer. Iterate over the files actually
# written (not over the characters of `response`), so no missing or stale
# 'sft_<i>.wav' from an earlier run is opened.
for wav_path in prompt_wavs + response_wavs:
    play_audio(wav_path)