first commit
38
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,38 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Desktop (please complete the following information):**
- OS: [e.g. iOS]
- Browser [e.g. chrome, safari]
- Version [e.g. 22]

**Smartphone (please complete the following information):**
- Device: [e.g. iPhone6]
- OS: [e.g. iOS8.1]
- Browser [e.g. stock browser, safari]
- Version [e.g. 22]

**Additional context**
Add any other context about the problem here.
20
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
56
.github/workflows/lint.yml
vendored
Normal file
@@ -0,0 +1,56 @@
name: Lint

on:
  pull_request:
  push:

jobs:
  quick-checks:
    runs-on: ubuntu-latest
    steps:
      - name: Fetch CosyVoice
        uses: actions/checkout@v1
      - name: Checkout PR tip
        run: |
          set -eux
          if [[ "${{ github.event_name }}" == "pull_request" ]]; then
            # We are on a PR, so actions/checkout leaves us on a merge commit.
            # Check out the actual tip of the branch.
            git checkout ${{ github.event.pull_request.head.sha }}
          fi
          echo ::set-output name=commit_sha::$(git rev-parse HEAD)
        id: get_pr_tip
      - name: Ensure no tabs
        run: |
          (! git grep -I -l $'\t' -- . ':(exclude)*.txt' ':(exclude)*.svg' ':(exclude)**Makefile' ':(exclude)**/contrib/**' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have tabs; please convert them to spaces"; false))
      - name: Ensure no trailing whitespace
        run: |
          (! git grep -I -n $' $' -- . ':(exclude)*.txt' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have trailing whitespace; please remove them"; false))

  flake8-py3:
    runs-on: ubuntu-latest
    steps:
      - name: Setup Python
        uses: actions/setup-python@v1
        with:
          python-version: 3.9
          architecture: x64
      - name: Fetch CosyVoice
        uses: actions/checkout@v1
      - name: Checkout PR tip
        run: |
          set -eux
          if [[ "${{ github.event_name }}" == "pull_request" ]]; then
            # We are on a PR, so actions/checkout leaves us on a merge commit.
            # Check out the actual tip of the branch.
            git checkout ${{ github.event.pull_request.head.sha }}
          fi
          echo ::set-output name=commit_sha::$(git rev-parse HEAD)
        id: get_pr_tip
      - name: Run flake8
        run: |
          set -eux
          pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
          flake8 --version
          flake8 --max-line-length 150 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
          if [ $? != 0 ]; then exit 1; fi
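Note: the `::set-output` workflow command used in the "Checkout PR tip" steps has since been deprecated by GitHub Actions. A minimal sketch of the same step written against the `$GITHUB_OUTPUT` file, assuming the step id and output name stay unchanged:

      - name: Checkout PR tip
        run: |
          set -eux
          if [[ "${{ github.event_name }}" == "pull_request" ]]; then
            git checkout ${{ github.event.pull_request.head.sha }}
          fi
          echo "commit_sha=$(git rev-parse HEAD)" >> "$GITHUB_OUTPUT"
        id: get_pr_tip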
22
.github/workflows/stale-issues.yml
vendored
Normal file
@@ -0,0 +1,22 @@
name: Close inactive issues
on:
  schedule:
    - cron: "30 1 * * *"

jobs:
  close-issues:
    runs-on: ubuntu-latest
    permissions:
      issues: write
      pull-requests: write
    steps:
      - uses: actions/stale@v5
        with:
          days-before-issue-stale: 30
          days-before-issue-close: 14
          stale-issue-label: "stale"
          stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
          close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
          days-before-pr-stale: -1
          days-before-pr-close: -1
          repo-token: ${{ secrets.GITHUB_TOKEN }}
52
.gitignore
vendored
Normal file
@@ -0,0 +1,52 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.venv/

# Visual Studio Code files
.vscode
.vs

# PyCharm files
.idea

# Eclipse Project settings
*.*project
.settings

# Sublime Text settings
*.sublime-workspace
*.sublime-project

# Editor temporaries
*.swn
*.swo
*.swp
*.swm
*~

# IPython notebook checkpoints
.ipynb_checkpoints

# macOS dir files
.DS_Store

exp
data
raw_wav
tensorboard
**/*build*

# Clangd files
.cache
compile_commands.json

# train/inference files
*.wav
*.m4a
*.aac
*.pt
pretrained_models/*
*_pb2_grpc.py
*_pb2.py
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
[submodule "third_party/Matcha-TTS"]
	path = third_party/Matcha-TTS
	url = https://github.com/shivammehta25/Matcha-TTS.git
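Because Matcha-TTS is tracked as a git submodule, a fresh clone will typically need `git clone --recursive`, or a follow-up `git submodule update --init --recursive`, before the `cosyvoice` imports used in the scripts below can resolve.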
73
0_Inference_QWen2.5.py
Normal file
@@ -0,0 +1,73 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav
import torchaudio

import pygame
import time


def play_audio(file_path):
    try:
        pygame.mixer.init()
        pygame.mixer.music.load(file_path)
        pygame.mixer.music.play()
        while pygame.mixer.music.get_busy():
            time.sleep(1)  # wait until playback finishes
        print("Playback finished!")
    except Exception as e:
        print(f"Playback failed: {e}")
    finally:
        pygame.mixer.quit()


# Load the Qwen2.5 LLM from a local path
model_name = r"D:\AI\download\Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the CosyVoice TTS model
cosyvoice = CosyVoice(r'../pretrained_models/CosyVoice-300M', load_jit=True, load_onnx=False, fp16=True)
# sft usage
print(cosyvoice.list_avaliable_spks())

prompt = "你好,你叫什么名字?"  # "Hello, what's your name?"
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt},
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512,
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print("Input:", prompt)
print("Answer:", response)

# Available speakers: ['中文女', '中文男', '日语男', '粤语女', '英文女', '英文男', '韩语女']
# Synthesize the user prompt with the male Mandarin voice
for i, j in enumerate(cosyvoice.inference_sft(f'{prompt}', '中文男', stream=False)):
    torchaudio.save('prompt_sft_{}.wav'.format(i), j['tts_speech'], 22050)
    # play_audio('prompt_sft_{}.wav'.format(i))

# Synthesize the model answer with the female Mandarin voice
# (change stream=True for chunked streaming inference)
num_answer_chunks = 0
for i, j in enumerate(cosyvoice.inference_sft(f'{response}', '中文女', stream=False)):
    torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)
    num_answer_chunks = i + 1
    # play_audio('sft_{}.wav'.format(i))

# Play back the prompt audio, then each saved answer chunk
play_audio('prompt_sft_0.wav')
for i in range(num_answer_chunks):
    play_audio('sft_{}.wav'.format(i))
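A likely prerequisite for running this script, assumed from the usual CosyVoice setup rather than stated in this commit: the repository root and `third_party/Matcha-TTS` need to be on `PYTHONPATH` (e.g. `export PYTHONPATH=third_party/Matcha-TTS`), and the Qwen2.5-0.5B-Instruct and CosyVoice-300M checkpoints must already be downloaded to the hard-coded local paths.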
156
10_SenceVoice_QWen2.5_cosyVoice.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
from funasr import AutoModel
|
||||
import torchaudio
|
||||
import pygame
|
||||
import time
|
||||
import sys
|
||||
import sounddevice as sd
|
||||
from scipy.io.wavfile import write
|
||||
import numpy as np
|
||||
|
||||
def record_audio(filename="output.wav", sample_rate=44100):
|
||||
print("按下 Enter 开始录音...")
|
||||
input() # 等待用户按下 Enter 键开始录音
|
||||
print("录音中... 按下 Enter 键结束录音")
|
||||
|
||||
# 开始录音
|
||||
recording = []
|
||||
try:
|
||||
def callback(indata, frames, time, status):
|
||||
recording.append(indata.copy())
|
||||
with sd.InputStream(samplerate=sample_rate, channels=1, callback=callback):
|
||||
input() # 等待用户再次按下 Enter 键结束录音
|
||||
except Exception as e:
|
||||
print(f"录音出现错误: {e}")
|
||||
return
|
||||
|
||||
# 将录音数据合并并保存为 WAV 文件
|
||||
audio_data = np.concatenate(recording, axis=0)
|
||||
write(filename, sample_rate, (audio_data * 32767).astype(np.int16))
|
||||
print(f"录音已保存为 {filename}")
|
||||
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
def clear_folder(folder_path):
|
||||
# 检查文件夹是否存在
|
||||
if not os.path.exists(folder_path):
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
print(f"文件夹 '{folder_path}' 不存在,已创建")
|
||||
return
|
||||
|
||||
# 获取文件夹中的所有文件和子文件夹
|
||||
items = os.listdir(folder_path)
|
||||
|
||||
# 如果文件夹为空,直接返回
|
||||
if not items:
|
||||
print(f"文件夹 '{folder_path}' 已经为空")
|
||||
return
|
||||
|
||||
# 遍历文件和文件夹并删除
|
||||
for item in items:
|
||||
item_path = os.path.join(folder_path, item)
|
||||
|
||||
# 判断是否是文件夹或文件
|
||||
if os.path.isfile(item_path):
|
||||
os.remove(item_path) # 删除文件
|
||||
print(f"删除文件: {item_path}")
|
||||
elif os.path.isdir(item_path):
|
||||
shutil.rmtree(item_path) # 删除文件夹及其内容
|
||||
print(f"删除文件夹: {item_path}")
|
||||
|
||||
print(f"文件夹 '{folder_path}' 已清空")
|
||||
|
||||
# ------------------- 模型初始化 ---------------
|
||||
# --- SenceVoice-语音识别模型
|
||||
model_dir = r"D:\AI\download\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
|
||||
# --- QWen2.5大语言模型 ---
|
||||
# model_name = r":\2_PYTHON\Project\GPT\QWen\Qwen2.5-0.5B-Instruct"
|
||||
model_name = r"D:\AI\download\Qwen2.5-1.5B-Instruct"
|
||||
# model_name = r':\2_PYTHON\Project\GPT\QWen\Qwen2.5-7B-Instruct-GPTQ-Int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# --- CosyVoice - 语音合成模型
|
||||
cosyvoice = CosyVoice(r'D:\AI\download\CosyVoice-300M', load_jit=True, load_onnx=False, fp16=True)
|
||||
# --- CosyVoice - 支持的音色列表
|
||||
print(cosyvoice.list_avaliable_spks())
|
||||
# ------------------ 模型初始化结束 ----------------
|
||||
|
||||
while(1):
|
||||
# 使用函数录音,作为输入
|
||||
record_audio("my_recording.wav")
|
||||
|
||||
# input_file = ( "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" )
|
||||
input_file = ("my_recording.wav")
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
|
||||
# -------- 模型推理阶段,将语音识别结果作为大模型Prompt ------
|
||||
prompt = res[0]['text'].split(">")[-1] + ",回答简短一些,保持50字以内!"
|
||||
messages = [
|
||||
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
generated_ids = model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=512,
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
print("Input:", prompt)
|
||||
print("Answer:", response)
|
||||
|
||||
# --- 答复输出文件夹 ---
|
||||
folder_path = "./out_answer/"
|
||||
clear_folder(folder_path)
|
||||
|
||||
# ['中文女', '中文男', '日语男', '粤语女', '英文女', '英文男', '韩语女']
|
||||
# change stream=True for chunk stream inference
|
||||
index_out = 0
|
||||
for i, j in enumerate(cosyvoice.inference_sft(f'{response}', '中文女', stream=False)):
|
||||
torchaudio.save('{}/sft_{}.wav'.format(folder_path,i), j['tts_speech'], 22050)
|
||||
index_out += 1
|
||||
# play_audio('sft_{}.wav'.format(i))
|
||||
|
||||
for idx in range(index_out):
|
||||
play_audio('{}/sft_{}.wav'.format(folder_path,idx))
|
||||
|
158
11_SenceVoice_QWen2.5_pytts3.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
from funasr import AutoModel
|
||||
import torchaudio
|
||||
import pygame
|
||||
import time
|
||||
import sys
|
||||
import sounddevice as sd
|
||||
from scipy.io.wavfile import write
|
||||
import numpy as np
|
||||
|
||||
def record_audio(filename="output.wav", sample_rate=44100):
|
||||
print("按下 Enter 开始录音...")
|
||||
input() # 等待用户按下 Enter 键开始录音
|
||||
print("录音中... 按下 Enter 键结束录音")
|
||||
|
||||
# 开始录音
|
||||
recording = []
|
||||
try:
|
||||
def callback(indata, frames, time, status):
|
||||
recording.append(indata.copy())
|
||||
with sd.InputStream(samplerate=sample_rate, channels=1, callback=callback):
|
||||
input() # 等待用户再次按下 Enter 键结束录音
|
||||
except Exception as e:
|
||||
print(f"录音出现错误: {e}")
|
||||
return
|
||||
|
||||
# 将录音数据合并并保存为 WAV 文件
|
||||
audio_data = np.concatenate(recording, axis=0)
|
||||
write(filename, sample_rate, (audio_data * 32767).astype(np.int16))
|
||||
print(f"录音已保存为 {filename}")
|
||||
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
def clear_folder(folder_path):
|
||||
# 检查文件夹是否存在
|
||||
if not os.path.exists(folder_path):
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
print(f"文件夹 '{folder_path}' 不存在,已创建")
|
||||
return
|
||||
|
||||
# 获取文件夹中的所有文件和子文件夹
|
||||
items = os.listdir(folder_path)
|
||||
|
||||
# 如果文件夹为空,直接返回
|
||||
if not items:
|
||||
print(f"文件夹 '{folder_path}' 已经为空")
|
||||
return
|
||||
|
||||
# 遍历文件和文件夹并删除
|
||||
for item in items:
|
||||
item_path = os.path.join(folder_path, item)
|
||||
|
||||
# 判断是否是文件夹或文件
|
||||
if os.path.isfile(item_path):
|
||||
os.remove(item_path) # 删除文件
|
||||
print(f"删除文件: {item_path}")
|
||||
elif os.path.isdir(item_path):
|
||||
shutil.rmtree(item_path) # 删除文件夹及其内容
|
||||
print(f"删除文件夹: {item_path}")
|
||||
|
||||
print(f"文件夹 '{folder_path}' 已清空")
|
||||
|
||||
# ------------------- 模型初始化 ---------------
|
||||
# --- SenceVoice-语音识别模型
|
||||
model_dir = r"E:\2_PYTHON\Project\GPT\QWen\pretrained_models\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
|
||||
# --- QWen2.5大语言模型 ---
|
||||
# model_name = r":\2_PYTHON\Project\GPT\QWen\Qwen2.5-0.5B-Instruct"
|
||||
model_name = r"E:\2_PYTHON\Project\GPT\QWen\Qwen2.5-1.5B-Instruct"
|
||||
# model_name = r':\2_PYTHON\Project\GPT\QWen\Qwen2.5-7B-Instruct-GPTQ-Int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
import pyttsx3
|
||||
# 初始化 TTS 引擎
|
||||
engine = pyttsx3.init()
|
||||
# 设置语音属性
|
||||
engine.setProperty('rate', 200)    # speaking rate
engine.setProperty('volume', 0.9)  # volume (0.0 to 1.0)
|
||||
# 选择语音
|
||||
voices = engine.getProperty('voices')
|
||||
# print(voices)
|
||||
engine.setProperty('voice', voices[0].id) # 使用第一个语音
|
||||
|
||||
while(1):
|
||||
# 使用函数录音,作为输入
|
||||
record_audio("my_recording.wav")
|
||||
|
||||
# input_file = ( "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" )
|
||||
input_file = ("my_recording.wav")
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
|
||||
# -------- 模型推理阶段,将语音识别结果作为大模型Prompt ------
|
||||
prompt = res[0]['text'].split(">")[-1] + ",回答简短一些,保持50字以内!"
|
||||
messages = [
|
||||
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
generated_ids = model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=512,
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
print("Input:", prompt)
|
||||
print("Answer:", response)
|
||||
|
||||
# 答复输出文件夹
|
||||
folder_path = "./out_answer/"
|
||||
clear_folder(folder_path)
|
||||
|
||||
# 输入文本
|
||||
text = response
|
||||
# 朗读文本
|
||||
# engine.say(text)
|
||||
# # 等待朗读完成
|
||||
# engine.runAndWait()
|
||||
engine.save_to_file(text, os.path.join(folder_path,"sft_0.wav"))
|
||||
engine.runAndWait()
|
||||
play_audio(f'{folder_path}/sft_0.wav')
|
157
12_SenceVoice_QWen2.5_edgeTTS.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
from funasr import AutoModel
|
||||
import torchaudio
|
||||
import pygame
|
||||
import time
|
||||
import sys
|
||||
import sounddevice as sd
|
||||
from scipy.io.wavfile import write
|
||||
import numpy as np
|
||||
import asyncio
|
||||
import edge_tts
|
||||
import os
|
||||
import shutil
|
||||
|
||||
async def amain(TEXT, VOICE, OUTPUT_FILE) -> None:
|
||||
"""Main function"""
|
||||
communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
await communicate.save(OUTPUT_FILE)
|
||||
|
||||
def record_audio(filename="output.wav", sample_rate=44100):
|
||||
print("按下 Enter 开始录音...")
|
||||
input() # 等待用户按下 Enter 键开始录音
|
||||
print("录音中... 按下 Enter 键结束录音")
|
||||
|
||||
# 开始录音
|
||||
recording = []
|
||||
try:
|
||||
def callback(indata, frames, time, status):
|
||||
recording.append(indata.copy())
|
||||
with sd.InputStream(samplerate=sample_rate, channels=1, callback=callback):
|
||||
input() # 等待用户再次按下 Enter 键结束录音
|
||||
except Exception as e:
|
||||
print(f"录音出现错误: {e}")
|
||||
return
|
||||
|
||||
# 将录音数据合并并保存为 WAV 文件
|
||||
audio_data = np.concatenate(recording, axis=0)
|
||||
write(filename, sample_rate, (audio_data * 32767).astype(np.int16))
|
||||
print(f"录音已保存为 {filename}")
|
||||
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
|
||||
def clear_folder(folder_path):
|
||||
# 检查文件夹是否存在
|
||||
if not os.path.exists(folder_path):
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
print(f"文件夹 '{folder_path}' 不存在,已创建")
|
||||
return
|
||||
|
||||
# 获取文件夹中的所有文件和子文件夹
|
||||
items = os.listdir(folder_path)
|
||||
|
||||
# 如果文件夹为空,直接返回
|
||||
if not items:
|
||||
print(f"文件夹 '{folder_path}' 已经为空")
|
||||
return
|
||||
|
||||
# 遍历文件和文件夹并删除
|
||||
for item in items:
|
||||
item_path = os.path.join(folder_path, item)
|
||||
|
||||
# 判断是否是文件夹或文件
|
||||
if os.path.isfile(item_path):
|
||||
os.remove(item_path) # 删除文件
|
||||
print(f"删除文件: {item_path}")
|
||||
elif os.path.isdir(item_path):
|
||||
shutil.rmtree(item_path) # 删除文件夹及其内容
|
||||
print(f"删除文件夹: {item_path}")
|
||||
|
||||
print(f"文件夹 '{folder_path}' 已清空")
|
||||
|
||||
# ------------------- 模型初始化 ---------------
|
||||
# --- SenceVoice-语音识别模型
|
||||
model_dir = r"D:\AI\download\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
|
||||
|
||||
# --- QWen2.5大语言模型 ---
|
||||
model_name = r"D:\AI\download\Qwen2.5-0.5B-Instruct"
|
||||
# model_name = r"E:\2_PYTHON\Project\GPT\QWen\Qwen2.5-1.5B-Instruct"
|
||||
# model_name = r':\2_PYTHON\Project\GPT\QWen\Qwen2.5-7B-Instruct-GPTQ-Int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# --- CosyVoice - 语音合成模型 ------
|
||||
# cosyvoice = CosyVoice(r'../pretrained_models/CosyVoice-300M', load_jit=True, load_onnx=False, fp16=True)
|
||||
# print(cosyvoice.list_avaliable_spks())
|
||||
# --- CosyVoice - 支持的音色列表
|
||||
# ------------------ 模型初始化结束 ----------------
|
||||
|
||||
|
||||
while(1):
|
||||
# 使用函数录音,作为输入
|
||||
record_audio("my_recording.wav")
|
||||
|
||||
input_file = ("my_recording.wav")
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
|
||||
# -------- 模型推理阶段,将语音识别结果作为大模型Prompt ------
|
||||
prompt = res[0]['text'].split(">")[-1] + ",回答简短一些,保持50字以内!"
|
||||
messages = [
|
||||
{"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
generated_ids = model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=512,
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
print("Input:", prompt)
|
||||
print("Answer:", response)
|
||||
|
||||
# 答复输出文件夹
|
||||
folder_path = "./out_answer/"
|
||||
clear_folder(folder_path)
|
||||
|
||||
# 输入文本
|
||||
text = response
|
||||
asyncio.run(amain(text, "zh-CN-XiaoyiNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
play_audio(f'{folder_path}/sft_0.mp3')
|
355
13_1_SenceVoice_QWen2.5_kokoro_realTime.py
Normal file
@@ -0,0 +1,355 @@
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
import numpy as np
|
||||
import time
|
||||
from queue import Queue
|
||||
import webrtcvad
|
||||
import os
|
||||
import threading
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
import pygame
|
||||
# import edge_tts
|
||||
from kokoro import KPipeline
|
||||
import asyncio
|
||||
from time import sleep
|
||||
import langid
|
||||
from langdetect import detect
|
||||
from IPython.display import display, Audio
|
||||
import soundfile as sf
|
||||
|
||||
# kokoro 语音合成
|
||||
# 🇪🇸 'e' => Spanish es
|
||||
# 🇫🇷 'f' => French fr-fr
|
||||
# 🇮🇳 'h' => Hindi hi
|
||||
# 🇮🇹 'i' => Italian it
|
||||
# 🇧🇷 'p' => Brazilian Portuguese pt-br
|
||||
# 🇺🇸 'a' => American English, 🇬🇧 'b' => British English
|
||||
# 🇯🇵 'j' => Japanese: pip install misaki[ja]
|
||||
# 🇨🇳 'z' => Mandarin Chinese: pip install misaki[zh]
|
||||
root_voice = r'E:\2_PYTHON\Project\TTS\Kokoro-82M\voices'
|
||||
|
||||
def tts_kokoro(text, outpath, lid='z', voice_glo='zm_yunjian'):
|
||||
global root_voice
|
||||
pipeline = KPipeline(lang_code=lid)
|
||||
voice_tensor = torch.load(os.path.join(root_voice, voice_glo+'.pt'), weights_only=True)
|
||||
generator = pipeline(
|
||||
text, voice=voice_tensor,
|
||||
speed=1, split_pattern=r'\n+'
|
||||
)
|
||||
|
||||
for i, (gs, ps, audio) in enumerate(generator):
|
||||
# display(Audio(data=audio, rate=24000, autoplay=i==0))
|
||||
sf.write(f'{outpath}', audio, 24000) # save each audio file
|
||||
|
||||
# --- 配置huggingFace国内镜像 ---
|
||||
import os
|
||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||
|
||||
# 参数设置
|
||||
AUDIO_RATE = 16000 # 音频采样率
|
||||
AUDIO_CHANNELS = 1 # 单声道
|
||||
CHUNK = 1024 # 音频块大小
|
||||
VAD_MODE = 3 # VAD 模式 (0-3, 数字越大越敏感)
|
||||
OUTPUT_DIR = "./output" # 输出目录
|
||||
NO_SPEECH_THRESHOLD = 1 # 无效语音阈值,单位:秒
|
||||
folder_path = "./Test_QWen2_VL/"
|
||||
audio_file_count = 0
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
|
||||
# 队列用于音频和视频同步缓存
|
||||
audio_queue = Queue()
|
||||
video_queue = Queue()
|
||||
|
||||
# 全局变量
|
||||
last_active_time = time.time()
|
||||
recording_active = True
|
||||
segments_to_save = []
|
||||
saved_intervals = []
|
||||
last_vad_end_time = 0 # 上次保存的 VAD 有效段结束时间
|
||||
|
||||
# 初始化 WebRTC VAD
|
||||
vad = webrtcvad.Vad()
|
||||
vad.set_mode(VAD_MODE)
|
||||
|
||||
# 音频录制线程
|
||||
def audio_recorder():
|
||||
global audio_queue, recording_active, last_active_time, segments_to_save, last_vad_end_time
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16,
|
||||
channels=AUDIO_CHANNELS,
|
||||
rate=AUDIO_RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK)
|
||||
|
||||
audio_buffer = []
|
||||
print("音频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
data = stream.read(CHUNK)
|
||||
audio_buffer.append(data)
|
||||
|
||||
# 每 0.5 秒检测一次 VAD
|
||||
if len(audio_buffer) * CHUNK / AUDIO_RATE >= 0.5:
|
||||
# 拼接音频数据并检测 VAD
|
||||
raw_audio = b''.join(audio_buffer)
|
||||
vad_result = check_vad_activity(raw_audio)
|
||||
|
||||
if vad_result:
|
||||
print("检测到语音活动")
|
||||
last_active_time = time.time()
|
||||
segments_to_save.append((raw_audio, time.time()))
|
||||
else:
|
||||
print("静音中...")
|
||||
|
||||
audio_buffer = [] # 清空缓冲区
|
||||
|
||||
# 检查无效语音时间
|
||||
if time.time() - last_active_time > NO_SPEECH_THRESHOLD:
|
||||
# 检查是否需要保存
|
||||
if segments_to_save and segments_to_save[-1][1] > last_vad_end_time:
|
||||
save_audio_video()
|
||||
last_active_time = time.time()
|
||||
else:
|
||||
pass
|
||||
# print("无新增语音段,跳过保存")
|
||||
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
# 视频录制线程
|
||||
def video_recorder():
|
||||
global video_queue, recording_active
|
||||
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
print("视频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
video_queue.put((frame, time.time()))
|
||||
|
||||
# 实时显示摄像头画面
|
||||
cv2.imshow("Real Camera", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 键退出
|
||||
break
|
||||
else:
|
||||
print("无法获取摄像头画面")
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 检测 VAD 活动
|
||||
def check_vad_activity(audio_data):
|
||||
# 将音频数据分块检测
|
||||
num, rate = 0, 0.4
|
||||
step = int(AUDIO_RATE * 0.02) # 20ms 块大小
|
||||
flag_rate = round(rate * len(audio_data) // step)
|
||||
|
||||
for i in range(0, len(audio_data), step):
|
||||
chunk = audio_data[i:i + step]
|
||||
if len(chunk) == step:
|
||||
if vad.is_speech(chunk, sample_rate=AUDIO_RATE):
|
||||
num += 1
|
||||
|
||||
if num > flag_rate:
|
||||
return True
|
||||
return False
|
||||
|
||||
# 保存音频和视频
|
||||
def save_audio_video():
|
||||
pygame.mixer.init()
|
||||
|
||||
global segments_to_save, video_queue, last_vad_end_time, saved_intervals
|
||||
|
||||
# 全局变量,用于保存音频文件名计数
|
||||
global audio_file_count
|
||||
audio_file_count += 1
|
||||
audio_output_path = f"{OUTPUT_DIR}/audio_{audio_file_count}.wav"
|
||||
# audio_output_path = f"{OUTPUT_DIR}/audio_0.wav"
|
||||
|
||||
if not segments_to_save:
|
||||
return
|
||||
|
||||
# 停止当前播放的音频
|
||||
if pygame.mixer.music.get_busy():
|
||||
pygame.mixer.music.stop()
|
||||
print("检测到新的有效音,已停止当前音频播放")
|
||||
|
||||
# 获取有效段的时间范围
|
||||
start_time = segments_to_save[0][1]
|
||||
end_time = segments_to_save[-1][1]
|
||||
|
||||
# 检查是否与之前的片段重叠
|
||||
if saved_intervals and saved_intervals[-1][1] >= start_time:
|
||||
print("当前片段与之前片段重叠,跳过保存")
|
||||
segments_to_save.clear()
|
||||
return
|
||||
|
||||
# 保存音频
|
||||
audio_frames = [seg[0] for seg in segments_to_save]
|
||||
|
||||
wf = wave.open(audio_output_path, 'wb')
|
||||
wf.setnchannels(AUDIO_CHANNELS)
|
||||
wf.setsampwidth(2) # 16-bit PCM
|
||||
wf.setframerate(AUDIO_RATE)
|
||||
wf.writeframes(b''.join(audio_frames))
|
||||
wf.close()
|
||||
print(f"音频保存至 {audio_output_path}")
|
||||
|
||||
# Inference()
|
||||
# 使用线程执行推理
|
||||
inference_thread = threading.Thread(target=Inference, args=(audio_output_path,))
|
||||
inference_thread.start()
|
||||
|
||||
# 记录保存的区间
|
||||
saved_intervals.append((start_time, end_time))
|
||||
|
||||
# 清空缓冲区
|
||||
segments_to_save.clear()
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
# async def amain(TEXT, VOICE, OUTPUT_FILE) -> None:
|
||||
# """Main function"""
|
||||
# communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
# await communicate.save(OUTPUT_FILE)
|
||||
|
||||
|
||||
# -------- SenceVoice 语音识别 --模型加载-----
|
||||
model_dir = r"E:\2_PYTHON\Project\GPT\QWen\pretrained_models\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
|
||||
# --- QWen2.5大语言模型 ---
|
||||
# model_name = r"E:\2_PYTHON\Project\GPT\QWen\Qwen2.5-0.5B-Instruct"
|
||||
model_name = r"E:\2_PYTHON\Project\GPT\QWen\Qwen2.5-1.5B-Instruct"
|
||||
# model_name = r'E:\2_PYTHON\Project\GPT\QWen\Qwen2.5-7B-Instruct-GPTQ-Int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
def Inference(TEMP_AUDIO_FILE=f"{OUTPUT_DIR}/audio_0.wav"):
|
||||
# -------- SenceVoice 推理 ---------
|
||||
input_file = (TEMP_AUDIO_FILE)
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
# prompt = res[0]['text'].split(">")[-1]
|
||||
prompt = res[0]['text'].split(">")[-1] + ",回答简短一些,保持50字以内!"
|
||||
print("ASR OUT:", prompt)
|
||||
# ---------SenceVoice --end----------
|
||||
# -------- 模型推理阶段,将语音识别结果作为大模型Prompt ------
|
||||
messages = [
|
||||
{"role": "system", "content": "你叫千问,是一个18岁的女大学生,性格活泼开朗,说话俏皮"},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
generated_ids = model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=512,
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
output_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
print("answer", output_text)
|
||||
|
||||
# 输入文本
|
||||
text = output_text
|
||||
# 语种识别 -- langid
|
||||
language, confidence = langid.classify(text)
|
||||
# 语种识别 -- langdetect
|
||||
# language = detect(text).split("-")[0]
|
||||
|
||||
# 🇪🇸 'e' => Spanish es
|
||||
# 🇫🇷 'f' => French fr-fr
|
||||
# 🇮🇳 'h' => Hindi hi
|
||||
# 🇮🇹 'i' => Italian it
|
||||
# 🇧🇷 'p' => Brazilian Portuguese pt-br
|
||||
# 🇺🇸 'a' => American English, 🇬🇧 'b' => British English
|
||||
# 🇯🇵 'j' => Japanese: pip install misaki[ja]
|
||||
# 🇨🇳 'z' => Mandarin Chinese: pip install misaki[zh]
|
||||
|
||||
language_speaker = {
|
||||
"ja" : "j", # ok
|
||||
"fr" : "f", # ok
|
||||
"es" : "e", # ok
|
||||
"zh" : "z", # ok
|
||||
"en" : "a", # ok
|
||||
}
|
||||
|
||||
language_spk = {
|
||||
"j" : "jf_nezumi", # ok
|
||||
"f" : "ff_siwis", # ok
|
||||
"e" : "em_santa", # ok
|
||||
"z" : "zm_yunyang", # ok
|
||||
"a" : "af_heart", # ok
|
||||
}
|
||||
|
||||
if language not in language_speaker.keys():
|
||||
used_speaker = "z"
|
||||
else:
|
||||
used_speaker = language_speaker[language]
|
||||
print("检测到语种:", language, "使用音色:", language_speaker[language])
|
||||
|
||||
global audio_file_count
|
||||
outpath = os.path.join(folder_path,f"sft_{audio_file_count}.wav")
|
||||
tts_kokoro(text, outpath, lid=used_speaker, voice_glo=language_spk[used_speaker])
|
||||
play_audio(f'{folder_path}/sft_{audio_file_count}.wav')
|
||||
|
||||
# 主函数
|
||||
if __name__ == "__main__":
|
||||
|
||||
try:
|
||||
# 启动音视频录制线程
|
||||
audio_thread = threading.Thread(target=audio_recorder)
|
||||
# video_thread = threading.Thread(target=video_recorder)
|
||||
audio_thread.start()
|
||||
# video_thread.start()
|
||||
|
||||
print("按 Ctrl+C 停止录制")
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("录制停止中...")
|
||||
recording_active = False
|
||||
audio_thread.join()
|
||||
# video_thread.join()
|
||||
print("录制已停止")
|
311
13_SenceVoice_QWen2.5_edgeTTS_realTime.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
import numpy as np
|
||||
import time
|
||||
from queue import Queue
|
||||
import webrtcvad
|
||||
import os
|
||||
import threading
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
import pygame
|
||||
import edge_tts
|
||||
import asyncio
|
||||
from time import sleep
|
||||
import langid
|
||||
from langdetect import detect
|
||||
|
||||
# --- 配置huggingFace国内镜像 ---
|
||||
import os
|
||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||
|
||||
# 参数设置
|
||||
AUDIO_RATE = 16000 # 音频采样率
|
||||
AUDIO_CHANNELS = 1 # 单声道
|
||||
CHUNK = 1024 # 音频块大小
|
||||
VAD_MODE = 3 # VAD 模式 (0-3, 数字越大越敏感)
|
||||
OUTPUT_DIR = "./output" # 输出目录
|
||||
NO_SPEECH_THRESHOLD = 1 # 无效语音阈值,单位:秒
|
||||
folder_path = "./Test_QWen2_VL/"
|
||||
audio_file_count = 0
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
|
||||
# 队列用于音频和视频同步缓存
|
||||
audio_queue = Queue()
|
||||
video_queue = Queue()
|
||||
|
||||
# 全局变量
|
||||
last_active_time = time.time()
|
||||
recording_active = True
|
||||
segments_to_save = []
|
||||
saved_intervals = []
|
||||
last_vad_end_time = 0 # 上次保存的 VAD 有效段结束时间
|
||||
|
||||
# 初始化 WebRTC VAD
|
||||
vad = webrtcvad.Vad()
|
||||
vad.set_mode(VAD_MODE)
|
||||
|
||||
# 音频录制线程
|
||||
def audio_recorder():
|
||||
global audio_queue, recording_active, last_active_time, segments_to_save, last_vad_end_time
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16,
|
||||
channels=AUDIO_CHANNELS,
|
||||
rate=AUDIO_RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK)
|
||||
|
||||
audio_buffer = []
|
||||
print("音频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
data = stream.read(CHUNK)
|
||||
audio_buffer.append(data)
|
||||
|
||||
# 每 0.5 秒检测一次 VAD
|
||||
if len(audio_buffer) * CHUNK / AUDIO_RATE >= 0.5:
|
||||
# 拼接音频数据并检测 VAD
|
||||
raw_audio = b''.join(audio_buffer)
|
||||
vad_result = check_vad_activity(raw_audio)
|
||||
|
||||
if vad_result:
|
||||
print("检测到语音活动")
|
||||
last_active_time = time.time()
|
||||
segments_to_save.append((raw_audio, time.time()))
|
||||
else:
|
||||
print("静音中...")
|
||||
|
||||
audio_buffer = [] # 清空缓冲区
|
||||
|
||||
# 检查无效语音时间
|
||||
if time.time() - last_active_time > NO_SPEECH_THRESHOLD:
|
||||
# 检查是否需要保存
|
||||
if segments_to_save and segments_to_save[-1][1] > last_vad_end_time:
|
||||
save_audio_video()
|
||||
last_active_time = time.time()
|
||||
else:
|
||||
pass
|
||||
# print("无新增语音段,跳过保存")
|
||||
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
# 视频录制线程
|
||||
def video_recorder():
|
||||
global video_queue, recording_active
|
||||
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
print("视频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
video_queue.put((frame, time.time()))
|
||||
|
||||
# 实时显示摄像头画面
|
||||
cv2.imshow("Real Camera", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 键退出
|
||||
break
|
||||
else:
|
||||
print("无法获取摄像头画面")
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 检测 VAD 活动
|
||||
def check_vad_activity(audio_data):
|
||||
# 将音频数据分块检测
|
||||
num, rate = 0, 0.4
|
||||
step = int(AUDIO_RATE * 0.02) # 20ms 块大小
|
||||
flag_rate = round(rate * len(audio_data) // step)
|
||||
|
||||
for i in range(0, len(audio_data), step):
|
||||
chunk = audio_data[i:i + step]
|
||||
if len(chunk) == step:
|
||||
if vad.is_speech(chunk, sample_rate=AUDIO_RATE):
|
||||
num += 1
|
||||
|
||||
if num > flag_rate:
|
||||
return True
|
||||
return False
|
||||
|
||||
# 保存音频和视频
|
||||
def save_audio_video():
|
||||
pygame.mixer.init()
|
||||
|
||||
global segments_to_save, video_queue, last_vad_end_time, saved_intervals
|
||||
|
||||
# 全局变量,用于保存音频文件名计数
|
||||
global audio_file_count
|
||||
audio_file_count += 1
|
||||
audio_output_path = f"{OUTPUT_DIR}/audio_{audio_file_count}.wav"
|
||||
# audio_output_path = f"{OUTPUT_DIR}/audio_0.wav"
|
||||
|
||||
if not segments_to_save:
|
||||
return
|
||||
|
||||
# 停止当前播放的音频
|
||||
if pygame.mixer.music.get_busy():
|
||||
pygame.mixer.music.stop()
|
||||
print("检测到新的有效音,已停止当前音频播放")
|
||||
|
||||
# 获取有效段的时间范围
|
||||
start_time = segments_to_save[0][1]
|
||||
end_time = segments_to_save[-1][1]
|
||||
|
||||
# 检查是否与之前的片段重叠
|
||||
if saved_intervals and saved_intervals[-1][1] >= start_time:
|
||||
print("当前片段与之前片段重叠,跳过保存")
|
||||
segments_to_save.clear()
|
||||
return
|
||||
|
||||
# 保存音频
|
||||
audio_frames = [seg[0] for seg in segments_to_save]
|
||||
|
||||
wf = wave.open(audio_output_path, 'wb')
|
||||
wf.setnchannels(AUDIO_CHANNELS)
|
||||
wf.setsampwidth(2) # 16-bit PCM
|
||||
wf.setframerate(AUDIO_RATE)
|
||||
wf.writeframes(b''.join(audio_frames))
|
||||
wf.close()
|
||||
print(f"音频保存至 {audio_output_path}")
|
||||
|
||||
# Inference()
|
||||
# 使用线程执行推理
|
||||
inference_thread = threading.Thread(target=Inference, args=(audio_output_path,))
|
||||
inference_thread.start()
|
||||
|
||||
# 记录保存的区间
|
||||
saved_intervals.append((start_time, end_time))
|
||||
|
||||
# 清空缓冲区
|
||||
segments_to_save.clear()
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
async def amain(TEXT, VOICE, OUTPUT_FILE) -> None:
|
||||
"""Main function"""
|
||||
communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
await communicate.save(OUTPUT_FILE)
|
||||
|
||||
|
||||
# -------- SenceVoice 语音识别 --模型加载-----
|
||||
model_dir = r"D:\AI\download\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
|
||||
# --- QWen2.5大语言模型 ---
|
||||
model_name = r"D:\AI\download\Qwen2.5-1.5B-Instruct"
|
||||
# model_name = r'D:\AI\download\Qwen2.5-7B-Instruct-GPTQ-Int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
def Inference(TEMP_AUDIO_FILE=f"{OUTPUT_DIR}/audio_0.wav"):
|
||||
|
||||
# -------- SenceVoice 推理 ---------
|
||||
input_file = (TEMP_AUDIO_FILE)
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
# prompt = res[0]['text'].split(">")[-1]
|
||||
prompt = res[0]['text'].split(">")[-1] #识别结果
|
||||
print("ASR OUT:", prompt)
|
||||
# ---------SenceVoice --end----------
|
||||
# -------- 模型推理阶段,将语音识别结果作为大模型Prompt ------
|
||||
messages = [
|
||||
{"role": "system", "content": "你叫千问,是一个18岁的女大学生,性格活泼开朗,说话俏皮"},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
generated_ids = model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=512,
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
output_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
print("answer", output_text)
|
||||
|
||||
# 输入文本
|
||||
text = output_text
|
||||
# 语种识别 -- langid
|
||||
language, confidence = langid.classify(text)
|
||||
# 语种识别 -- langdetect
|
||||
# language = detect(text).split("-")[0]
|
||||
|
||||
language_speaker = {
|
||||
"ja" : "ja-JP-NanamiNeural", # ok
|
||||
"fr" : "fr-FR-DeniseNeural", # ok
|
||||
"es" : "ca-ES-JoanaNeural", # ok
|
||||
"de" : "de-DE-KatjaNeural", # ok
|
||||
"zh" : "zh-CN-XiaoyiNeural", # ok
|
||||
"en" : "en-US-AnaNeural", # ok
|
||||
}
|
||||
|
||||
if language not in language_speaker.keys():
|
||||
used_speaker = "zh-CN-XiaoyiNeural"
|
||||
else:
|
||||
used_speaker = language_speaker[language]
|
||||
print("检测到语种:", language, "使用音色:", language_speaker[language])
|
||||
|
||||
global audio_file_count
|
||||
asyncio.run(amain(text, used_speaker, os.path.join(folder_path,f"sft_{audio_file_count}.mp3")))
|
||||
play_audio(f'{folder_path}/sft_{audio_file_count}.mp3')
|
||||
|
||||
# 主函数
|
||||
if __name__ == "__main__":
|
||||
|
||||
try:
|
||||
# 启动音视频录制线程
|
||||
audio_thread = threading.Thread(target=audio_recorder)
|
||||
# video_thread = threading.Thread(target=video_recorder)
|
||||
audio_thread.start()
|
||||
# video_thread.start()
|
||||
|
||||
print("按 Ctrl+C 停止录制")
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("录制停止中...")
|
||||
recording_active = False
|
||||
audio_thread.join()
|
||||
# video_thread.join()
|
||||
print("录制已停止")
|
389
14_SenceVoice_QWen2VL_edgeTTS_realTime.py
Normal file
@@ -0,0 +1,389 @@
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
import numpy as np
|
||||
import time
|
||||
from queue import Queue
|
||||
import webrtcvad
|
||||
import os
|
||||
import threading
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
import pygame
|
||||
import edge_tts
|
||||
import asyncio
|
||||
from time import sleep
|
||||
|
||||
import langid
|
||||
from langdetect import detect
|
||||
|
||||
# --- 配置huggingFace国内镜像 ---
|
||||
import os
|
||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||
|
||||
# 参数设置
|
||||
AUDIO_RATE = 16000 # 音频采样率
|
||||
AUDIO_CHANNELS = 1 # 单声道
|
||||
CHUNK = 1024 # 音频块大小
|
||||
VAD_MODE = 3 # VAD 模式 (0-3, 数字越大越敏感)
|
||||
OUTPUT_DIR = "./output" # 输出目录
|
||||
NO_SPEECH_THRESHOLD = 1 # 无效语音阈值,单位:秒
|
||||
audio_file_count = 0
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# 队列用于音频和视频同步缓存
|
||||
audio_queue = Queue()
|
||||
video_queue = Queue()
|
||||
|
||||
# 全局变量
|
||||
last_active_time = time.time()
|
||||
recording_active = True
|
||||
segments_to_save = []
|
||||
saved_intervals = []
|
||||
last_vad_end_time = 0 # 上次保存的 VAD 有效段结束时间
|
||||
|
||||
# 初始化 WebRTC VAD
|
||||
vad = webrtcvad.Vad()
|
||||
vad.set_mode(VAD_MODE)
|
||||
|
||||
# 音频录制线程
|
||||
def audio_recorder():
|
||||
global audio_queue, recording_active, last_active_time, segments_to_save, last_vad_end_time
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16,
|
||||
channels=AUDIO_CHANNELS,
|
||||
rate=AUDIO_RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK)
|
||||
|
||||
audio_buffer = []
|
||||
print("音频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
data = stream.read(CHUNK)
|
||||
audio_buffer.append(data)
|
||||
|
||||
# 每 0.5 秒检测一次 VAD
|
||||
if len(audio_buffer) * CHUNK / AUDIO_RATE >= 0.5:
|
||||
# 拼接音频数据并检测 VAD
|
||||
raw_audio = b''.join(audio_buffer)
|
||||
vad_result = check_vad_activity(raw_audio)
|
||||
|
||||
if vad_result:
|
||||
print("检测到语音活动")
|
||||
last_active_time = time.time()
|
||||
segments_to_save.append((raw_audio, time.time()))
|
||||
else:
|
||||
print("静音中...")
|
||||
|
||||
audio_buffer = [] # 清空缓冲区
|
||||
|
||||
# 检查无效语音时间
|
||||
if time.time() - last_active_time > NO_SPEECH_THRESHOLD:
|
||||
# 检查是否需要保存
|
||||
if segments_to_save and segments_to_save[-1][1] > last_vad_end_time:
|
||||
save_audio_video()
|
||||
last_active_time = time.time()
|
||||
else:
|
||||
pass
|
||||
# print("无新增语音段,跳过保存")
|
||||
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
# 视频录制线程
|
||||
def video_recorder():
|
||||
global video_queue, recording_active
|
||||
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
print("视频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
video_queue.put((frame, time.time()))
|
||||
|
||||
# 实时显示摄像头画面
|
||||
cv2.imshow("Real Camera", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 键退出
|
||||
break
|
||||
else:
|
||||
print("无法获取摄像头画面")
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 检测 VAD 活动
|
||||
def check_vad_activity(audio_data):
|
||||
# 将音频数据分块检测,设置有效激活率rate=40%,低于此比例当作静音段
|
||||
num, rate = 0, 0.4
|
||||
step = int(AUDIO_RATE * 0.02) # 20ms 块大小
|
||||
flag_rate = round(rate * len(audio_data) // step)
|
||||
|
||||
for i in range(0, len(audio_data), step):
|
||||
chunk = audio_data[i:i + step]
|
||||
if len(chunk) == step:
|
||||
if vad.is_speech(chunk, sample_rate=AUDIO_RATE):
|
||||
num += 1
|
||||
|
||||
if num > flag_rate:
|
||||
return True
|
||||
return False
|
||||
|
||||
# 保存音频和视频
|
||||
def save_audio_video():
|
||||
pygame.mixer.init()
|
||||
global segments_to_save, video_queue, last_vad_end_time, saved_intervals
|
||||
|
||||
# 全局变量,用于保存音频文件名计数
|
||||
global audio_file_count
|
||||
audio_file_count += 1
|
||||
audio_output_path = f"{OUTPUT_DIR}/audio_{audio_file_count}.wav"
|
||||
video_output_path = f"{OUTPUT_DIR}/video_{audio_file_count}.avi"
|
||||
|
||||
if not segments_to_save:
|
||||
return
|
||||
|
||||
# 用于实时打断:接收到新保存文件需求,停止当前播放的音频
|
||||
if pygame.mixer.music.get_busy():
|
||||
pygame.mixer.music.stop()
|
||||
print("检测到新的有效音,已停止当前音频播放")
|
||||
|
||||
# 获取有效段的时间范围
|
||||
start_time = segments_to_save[0][1]
|
||||
end_time = segments_to_save[-1][1]
|
||||
|
||||
# 检查是否与之前的片段重叠
|
||||
if saved_intervals and saved_intervals[-1][1] >= start_time:
|
||||
print("当前片段与之前片段重叠,跳过保存")
|
||||
segments_to_save.clear()
|
||||
return
|
||||
|
||||
# 保存音频
|
||||
audio_frames = [seg[0] for seg in segments_to_save]
|
||||
wf = wave.open(audio_output_path, 'wb')
|
||||
wf.setnchannels(AUDIO_CHANNELS)
|
||||
wf.setsampwidth(2) # 16-bit PCM
|
||||
wf.setframerate(AUDIO_RATE)
|
||||
wf.writeframes(b''.join(audio_frames))
|
||||
wf.close()
|
||||
print(f"音频保存至 {audio_output_path}")
|
||||
|
||||
# 保存视频
|
||||
video_frames = []
|
||||
while not video_queue.empty():
|
||||
frame, timestamp = video_queue.get()
|
||||
if start_time <= timestamp <= end_time:
|
||||
video_frames.append(frame)
|
||||
|
||||
if video_frames:
|
||||
out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*'XVID'), 20.0, (640, 480))
|
||||
for frame in video_frames:
|
||||
out.write(frame)
|
||||
out.release()
|
||||
print(f"视频保存至 {video_output_path}")
|
||||
|
||||
# --- 直接推理会影响录制主线程,无法实现实时打断逻辑 ---
|
||||
# Inference()
|
||||
|
||||
# --- 使用线程执行推理
|
||||
inference_thread = threading.Thread(target=Inference, args=(video_output_path, audio_output_path))
|
||||
inference_thread.start()
|
||||
else:
|
||||
pass
|
||||
# print("无可保存的视频帧")
|
||||
|
||||
# 记录保存的区间
|
||||
saved_intervals.append((start_time, end_time))
|
||||
|
||||
# 清空缓冲区
|
||||
segments_to_save.clear()
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
async def amain(TEXT, VOICE, OUTPUT_FILE) -> None:
|
||||
"""Main function"""
|
||||
communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
await communicate.save(OUTPUT_FILE)
|
||||
|
||||
# -------------- Load QWen2-VL Model ------------
|
||||
# default: Load the model on the available device(s)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
|
||||
)
|
||||
# ------- 设置分辨率,降低现存占用 -------
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 512*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
# --------------------------------------
|
||||
|
||||
# -------- SenceVoice 语音识别 --模型加载-----
|
||||
model_dir = r"E:\2_PYTHON\Project\GPT\QWen\pretrained_models\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
folder_path = "./Test_QWen2_VL/"
|
||||
|
||||
def Inference(TEMP_VIDEO_FILE, TEMP_AUDIO_FILE):
|
||||
|
||||
cap = cv2.VideoCapture(TEMP_VIDEO_FILE)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
# --- 设定视频截取帧时间比例
|
||||
S_index = [0.2, 0.4, 0.6, 0.8]
|
||||
frame_index = [int(total_frames * i) for i in S_index]
|
||||
# 设置视频帧位置
|
||||
for idx in frame_index:
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print(f"无法读取帧索引 {idx}")
|
||||
else:
|
||||
# 保存帧
|
||||
file_path = os.path.join(folder_path, f"captured_image{idx}.jpg") # 设置保存路径
|
||||
cv2.imwrite(file_path, frame)
|
||||
|
||||
# -------- SenceVoice 推理 --start-------
|
||||
input_file = (TEMP_AUDIO_FILE)
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
# prompt = res[0]['text'].split(">")[-1]
|
||||
prompt = res[0]['text'].split(">")[-1] + ",回答简短一些,保持50字以内!"
|
||||
print("ASR OUT:", prompt)
|
||||
# ---------SenceVoice 推理--end----------
|
||||
|
||||
MODE_FLAG = 0
|
||||
# -------- QWen2-VL 模型推理 --------- 多图模式
|
||||
# Messages containing a images list as a video and a text query
|
||||
if not MODE_FLAG:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"video": [
|
||||
f'{os.path.join(folder_path, f"captured_image{frame_index[0]}.jpg")}',
|
||||
f'{os.path.join(folder_path, f"captured_image{frame_index[1]}.jpg")}',
|
||||
f'{os.path.join(folder_path, f"captured_image{frame_index[2]}.jpg")}',
|
||||
f'{os.path.join(folder_path, f"captured_image{frame_index[3]}.jpg")}',
|
||||
],
|
||||
"fps": 1.0,
|
||||
},
|
||||
{"type": "text", "text": f"{prompt}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# -------- QWen2-VL 模型推理 --------- 视频模式
|
||||
else:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"video": f"{TEMP_VIDEO_FILE}",
|
||||
"max_pixels": 360 * 420,
|
||||
"fps": 1.0,
|
||||
},
|
||||
{"type": "text", "text": f"{prompt}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Preparation for inference
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
|
||||
# 输入文本
|
||||
text = output_text[0]
|
||||
# 语种识别 -- langid
|
||||
language, confidence = langid.classify(text)
|
||||
# 语种识别 -- langdetect -- 没做结果对应关键词映射
|
||||
# language = detect(text)
|
||||
|
||||
language_speaker = {
|
||||
"ja" : "ja-JP-NanamiNeural", # ok
|
||||
"fr" : "fr-FR-DeniseNeural", # ok
|
||||
"es" : "ca-ES-JoanaNeural", # ok
|
||||
"de" : "de-DE-KatjaNeural", # ok
|
||||
"zh" : "zh-CN-XiaoyiNeural", # ok
|
||||
"en" : "en-US-AnaNeural", # ok
|
||||
}
|
||||
|
||||
if language not in language_speaker.keys():
|
||||
used_speaker = "zh-CN-XiaoyiNeural"
|
||||
else:
|
||||
used_speaker = language_speaker[language]
|
||||
print("检测到语种:", language, "使用音色:", language_speaker[language])
|
||||
|
||||
global audio_file_count
|
||||
asyncio.run(amain(text, used_speaker, os.path.join(folder_path,f"sft_{audio_file_count}.mp3")))
|
||||
play_audio(f'{folder_path}/sft_{audio_file_count}.mp3')
|
||||
|
||||
# asyncio.run(amain(text, "zh-CN-YunjianNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# asyncio.run(amain(text, "zh-CN-shaanxi-XiaoniNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# 主函数
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# 启动音视频录制线程
|
||||
audio_thread = threading.Thread(target=audio_recorder)
|
||||
video_thread = threading.Thread(target=video_recorder)
|
||||
audio_thread.start()
|
||||
video_thread.start()
|
||||
|
||||
print("按 Ctrl+C 停止录制")
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("录制停止中...")
|
||||
recording_active = False
|
||||
audio_thread.join()
|
||||
video_thread.join()
|
||||
print("录制已停止")
|
474
15.0_SenceVoice_kws_CAM++.py
Normal file
@ -0,0 +1,474 @@
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
import numpy as np
|
||||
import time
|
||||
from queue import Queue
|
||||
import webrtcvad
|
||||
import os
|
||||
import threading
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
import pygame
|
||||
import edge_tts
|
||||
import asyncio
|
||||
from time import sleep
|
||||
import langid
|
||||
from langdetect import detect
|
||||
import re
|
||||
from pypinyin import pinyin, Style
|
||||
from modelscope.pipelines import pipeline
|
||||
|
||||
# --- 配置huggingFace国内镜像 ---
|
||||
import os
|
||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||
|
||||
# 参数设置
|
||||
AUDIO_RATE = 16000 # 音频采样率
|
||||
AUDIO_CHANNELS = 1 # 单声道
|
||||
CHUNK = 1024 # 音频块大小
|
||||
VAD_MODE = 3 # VAD 模式 (0-3, 数字越大越敏感)
|
||||
OUTPUT_DIR = "./output" # 输出目录
|
||||
NO_SPEECH_THRESHOLD = 1 # 无效语音阈值,单位:秒
|
||||
folder_path = "./Test_QWen2_VL/"
|
||||
audio_file_count = 0
|
||||
audio_file_count_tmp = 0
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
|
||||
# 队列用于音频和视频同步缓存
|
||||
audio_queue = Queue()
|
||||
video_queue = Queue()
|
||||
|
||||
# 全局变量
|
||||
last_active_time = time.time()
|
||||
recording_active = True
|
||||
segments_to_save = []
|
||||
saved_intervals = []
|
||||
last_vad_end_time = 0 # 上次保存的 VAD 有效段结束时间
|
||||
|
||||
|
||||
# --- 唤醒词、声纹变量配置 ---
|
||||
# set_KWS = "ni hao xiao qian"
|
||||
set_KWS = "shuo hua xiao qian"
|
||||
flag_KWS = 0
|
||||
flag_KWS_used = 1
|
||||
|
||||
flag_sv_used = 1
|
||||
flag_sv_enroll = 0
|
||||
thred_sv = 0.35
|
||||
|
||||
# 初始化 WebRTC VAD
|
||||
vad = webrtcvad.Vad()
|
||||
vad.set_mode(VAD_MODE)
|
||||
|
||||
|
||||
def extract_chinese_and_convert_to_pinyin(input_string):
|
||||
"""
|
||||
提取字符串中的汉字,并将其转换为拼音。
|
||||
|
||||
:param input_string: 原始字符串
|
||||
:return: 转换后的拼音字符串
|
||||
"""
|
||||
# 使用正则表达式提取所有汉字
|
||||
chinese_characters = re.findall(r'[\u4e00-\u9fa5]', input_string)
|
||||
# 将汉字列表合并为字符串
|
||||
chinese_text = ''.join(chinese_characters)
|
||||
|
||||
# 转换为拼音
|
||||
pinyin_result = pinyin(chinese_text, style=Style.NORMAL)
|
||||
# 将拼音列表拼接为字符串
|
||||
pinyin_text = ' '.join([item[0] for item in pinyin_result])
|
||||
|
||||
return pinyin_text
|
||||
|
||||
|
||||
# 音频录制线程
|
||||
def audio_recorder():
|
||||
global audio_queue, recording_active, last_active_time, segments_to_save, last_vad_end_time
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16,
|
||||
channels=AUDIO_CHANNELS,
|
||||
rate=AUDIO_RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK)
|
||||
|
||||
audio_buffer = []
|
||||
print("音频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
data = stream.read(CHUNK)
|
||||
audio_buffer.append(data)
|
||||
|
||||
# 每 0.5 秒检测一次 VAD
|
||||
if len(audio_buffer) * CHUNK / AUDIO_RATE >= 0.5:
|
||||
# 拼接音频数据并检测 VAD
|
||||
raw_audio = b''.join(audio_buffer)
|
||||
vad_result = check_vad_activity(raw_audio)
|
||||
|
||||
if vad_result:
|
||||
print("检测到语音活动")
|
||||
last_active_time = time.time()
|
||||
segments_to_save.append((raw_audio, time.time()))
|
||||
else:
|
||||
print("静音中...")
|
||||
|
||||
audio_buffer = [] # 清空缓冲区
|
||||
|
||||
# 检查无效语音时间
|
||||
if time.time() - last_active_time > NO_SPEECH_THRESHOLD:
|
||||
# 检查是否需要保存
|
||||
if segments_to_save and segments_to_save[-1][1] > last_vad_end_time:
|
||||
save_audio_video()
|
||||
last_active_time = time.time()
|
||||
else:
|
||||
pass
|
||||
# print("无新增语音段,跳过保存")
|
||||
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
# 视频录制线程
|
||||
def video_recorder():
|
||||
global video_queue, recording_active
|
||||
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
print("视频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
video_queue.put((frame, time.time()))
|
||||
|
||||
# 实时显示摄像头画面
|
||||
cv2.imshow("Real Camera", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 键退出
|
||||
break
|
||||
else:
|
||||
print("无法获取摄像头画面")
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 检测 VAD 活动
|
||||
def check_vad_activity(audio_data):
|
||||
# 将音频数据分块检测
|
||||
num, rate = 0, 0.4
|
||||
step = int(AUDIO_RATE * 0.02)  # 320 bytes per chunk; audio_data is raw 16-bit bytes, so this is ~10 ms of audio (a frame length webrtcvad accepts)
|
||||
flag_rate = round(rate * len(audio_data) // step)
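# Worked example (values follow from the constants above): the recorder flushes roughly
# every 0.5 s, i.e. 8 reads * 1024 frames * 2 bytes = 16384 bytes of raw audio, so with
# step = 320 bytes there are ~51 ten-millisecond chunks and flag_rate = round(0.4 * 16384 // 320) = 20;
# the buffer is treated as speech only if more than 20 of those chunks are voiced (~40%).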
|
||||
|
||||
for i in range(0, len(audio_data), step):
|
||||
chunk = audio_data[i:i + step]
|
||||
if len(chunk) == step:
|
||||
if vad.is_speech(chunk, sample_rate=AUDIO_RATE):
|
||||
num += 1
|
||||
|
||||
if num > flag_rate:
|
||||
return True
|
||||
return False
|
||||
|
||||
# 保存音频和视频
|
||||
def save_audio_video():
|
||||
pygame.mixer.init()
|
||||
|
||||
global segments_to_save, video_queue, last_vad_end_time, saved_intervals
|
||||
|
||||
# 全局变量,用于保存音频文件名计数
|
||||
global audio_file_count
|
||||
global flag_sv_enroll
|
||||
global set_SV_enroll
|
||||
|
||||
if flag_sv_enroll:
|
||||
audio_output_path = f"{set_SV_enroll}/enroll_0.wav"
|
||||
else:
|
||||
audio_file_count += 1
|
||||
audio_output_path = f"{OUTPUT_DIR}/audio_{audio_file_count}.wav"
|
||||
# audio_output_path = f"{OUTPUT_DIR}/audio_0.wav"
|
||||
|
||||
if not segments_to_save:
|
||||
return
|
||||
|
||||
# 停止当前播放的音频
|
||||
if pygame.mixer.music.get_busy():
|
||||
pygame.mixer.music.stop()
|
||||
print("检测到新的有效音,已停止当前音频播放")
|
||||
|
||||
# 获取有效段的时间范围
|
||||
start_time = segments_to_save[0][1]
|
||||
end_time = segments_to_save[-1][1]
|
||||
|
||||
# 检查是否与之前的片段重叠
|
||||
if saved_intervals and saved_intervals[-1][1] >= start_time:
|
||||
print("当前片段与之前片段重叠,跳过保存")
|
||||
segments_to_save.clear()
|
||||
return
|
||||
|
||||
# 保存音频
|
||||
audio_frames = [seg[0] for seg in segments_to_save]
|
||||
if flag_sv_enroll:
|
||||
audio_length = 0.5 * len(segments_to_save)
|
||||
if audio_length < 3:
|
||||
print("声纹注册语音需大于3秒,请重新注册")
|
||||
return 1
|
||||
|
||||
wf = wave.open(audio_output_path, 'wb')
|
||||
wf.setnchannels(AUDIO_CHANNELS)
|
||||
wf.setsampwidth(2) # 16-bit PCM
|
||||
wf.setframerate(AUDIO_RATE)
|
||||
wf.writeframes(b''.join(audio_frames))
|
||||
wf.close()
|
||||
print(f"音频保存至 {audio_output_path}")
|
||||
|
||||
# Inference()
|
||||
|
||||
if flag_sv_enroll:
|
||||
text = "声纹注册完成!现在只有你可以命令我啦!"
|
||||
print(text)
|
||||
flag_sv_enroll = 0
|
||||
system_introduction(text)
|
||||
else:
|
||||
# 使用线程执行推理
|
||||
inference_thread = threading.Thread(target=Inference, args=(audio_output_path,))
|
||||
inference_thread.start()
|
||||
|
||||
# 记录保存的区间
|
||||
saved_intervals.append((start_time, end_time))
|
||||
|
||||
# 清空缓冲区
|
||||
segments_to_save.clear()
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
async def amain(TEXT, VOICE, OUTPUT_FILE) -> None:
|
||||
"""Main function"""
|
||||
communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
await communicate.save(OUTPUT_FILE)
|
||||
|
||||
import os
|
||||
|
||||
def is_folder_empty(folder_path):
|
||||
"""
|
||||
检测指定文件夹内是否有文件。
|
||||
|
||||
:param folder_path: 文件夹路径
|
||||
:return: 如果文件夹为空返回 True,否则返回 False
|
||||
"""
|
||||
# 获取文件夹中的所有条目(文件或子文件夹)
|
||||
entries = os.listdir(folder_path)
|
||||
# 检查是否存在文件
|
||||
for entry in entries:
|
||||
# 获取完整路径
|
||||
full_path = os.path.join(folder_path, entry)
|
||||
# 如果是文件,返回 False
|
||||
if os.path.isfile(full_path):
|
||||
return False
|
||||
# 如果没有文件,返回 True
|
||||
return True
|
||||
|
||||
|
||||
# -------- SenceVoice 语音识别 --模型加载-----
|
||||
model_dir = r"E:\2_PYTHON\Project\GPT\QWen\pretrained_models\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
|
||||
set_SV_enroll = r'.\SpeakerVerification_DIR\enroll_wav\\'
|
||||
# -------- CAM++声纹识别 -- 模型加载 --------
|
||||
sv_pipeline = pipeline(
|
||||
task='speaker-verification',
|
||||
model='damo/speech_campplus_sv_zh-cn_16k-common',
|
||||
model_revision='v1.0.0'
|
||||
)
|
||||
|
||||
# --------- QWen2.5大语言模型 ---------------
|
||||
# model_name = r"E:\2_PYTHON\Project\GPT\QWen\Qwen2.5-0.5B-Instruct"
|
||||
model_name = r"E:\2_PYTHON\Project\GPT\QWen\Qwen2.5-1.5B-Instruct"
|
||||
# model_name = r'E:\2_PYTHON\Project\GPT\QWen\Qwen2.5-7B-Instruct-GPTQ-Int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
def system_introduction(text):
|
||||
global audio_file_count
|
||||
global folder_path
|
||||
text = text
|
||||
print("LLM output:", text)
|
||||
used_speaker = "zh-CN-XiaoyiNeural"
|
||||
asyncio.run(amain(text, used_speaker, os.path.join(folder_path,f"sft_tmp_{audio_file_count}.mp3")))
|
||||
play_audio(f'{folder_path}/sft_tmp_{audio_file_count}.mp3')
|
||||
|
||||
def Inference(TEMP_AUDIO_FILE=f"{OUTPUT_DIR}/audio_0.wav"):
|
||||
'''
|
||||
1. 使用senceVoice做asr,转换为拼音,检测唤醒词
|
||||
- 首先检测声纹注册文件夹是否有注册文件,如果无,启动声纹注册
|
||||
2. 使用CAM++做声纹识别
|
||||
- 设置固定声纹注册语音目录,每次输入音频均进行声纹对比
|
||||
3. 以上两者均通过,则进行大模型推理
|
||||
'''
|
||||
global audio_file_count
|
||||
|
||||
global set_SV_enroll
|
||||
global flag_sv_enroll
|
||||
global thred_sv
|
||||
global flag_sv_used
|
||||
|
||||
global set_KWS
|
||||
global flag_KWS
|
||||
global flag_KWS_used
|
||||
|
||||
|
||||
# flag_info = f'{flag_sv_used}-{flag_KWS_used}'
|
||||
# dict_flag_info = {
|
||||
# "1-1": "您已开启声纹识别和关键词唤醒,",
|
||||
# "0-1":"您已开启关键词唤醒",
|
||||
# "1-0":"您已开启声纹识别",
|
||||
# "0-0":"",
|
||||
# }
|
||||
# if flag_sv_used or flag_KWS_used:
|
||||
# text = dict_flag_info[flag_info]
|
||||
# system_introduction(text)
|
||||
|
||||
os.makedirs(set_SV_enroll, exist_ok=True)
|
||||
# --- 如果开启声纹识别,且声纹文件夹为空,则开始声纹注册。设定注册语音有效长度需大于3秒
|
||||
if flag_sv_used and is_folder_empty(set_SV_enroll):
|
||||
text = f"无声纹注册文件!请先注册声纹,需大于三秒哦~"
|
||||
print(text)
|
||||
system_introduction(text)
|
||||
flag_sv_enroll = 1
|
||||
|
||||
else:
|
||||
# -------- SenceVoice 推理 ---------
|
||||
input_file = (TEMP_AUDIO_FILE)
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
prompt = res[0]['text'].split(">")[-1]
|
||||
prompt_pinyin = extract_chinese_and_convert_to_pinyin(prompt)
|
||||
print(prompt, prompt_pinyin)
|
||||
|
||||
# --- 判断是否启动KWS
|
||||
if not flag_KWS_used:
|
||||
flag_KWS = 1
|
||||
if not flag_KWS:
|
||||
if set_KWS in prompt_pinyin:
|
||||
flag_KWS = 1
|
||||
|
||||
# --- KWS成功,或不设置KWS
|
||||
if flag_KWS:
|
||||
sv_score = sv_pipeline([os.path.join(set_SV_enroll, "enroll_0.wav"), TEMP_AUDIO_FILE], thr=thred_sv)
|
||||
print(sv_score)
|
||||
sv_result = sv_score['text']
|
||||
if sv_result == "yes":
|
||||
prompt = res[0]['text'].split(">")[-1] + ",回答简短一些,保持50字以内!"
|
||||
print("ASR OUT:", prompt)
|
||||
# ---------SenceVoice --end----------
|
||||
# -------- 模型推理阶段,将语音识别结果作为大模型Prompt ------
|
||||
messages = [
|
||||
{"role": "system", "content": "你叫小千,是一个18岁的女大学生,性格活泼开朗,说话俏皮"},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
generated_ids = model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=512,
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
output_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
print("answer", output_text)
|
||||
|
||||
# 输入文本
|
||||
text = output_text
|
||||
# 语种识别 -- langid
|
||||
language, confidence = langid.classify(text)
|
||||
# 语种识别 -- langdetect
|
||||
# language = detect(text).split("-")[0]
|
||||
|
||||
language_speaker = {
|
||||
"ja" : "ja-JP-NanamiNeural", # ok
|
||||
"fr" : "fr-FR-DeniseNeural", # ok
|
||||
"es" : "ca-ES-JoanaNeural", # ok
|
||||
"de" : "de-DE-KatjaNeural", # ok
|
||||
"zh" : "zh-CN-XiaoyiNeural", # ok
|
||||
"en" : "en-US-AnaNeural", # ok
|
||||
}
|
||||
|
||||
if language not in language_speaker.keys():
|
||||
used_speaker = "zh-CN-XiaoyiNeural"
|
||||
else:
|
||||
used_speaker = language_speaker[language]
|
||||
print("检测到语种:", language, "使用音色:", language_speaker[language])
|
||||
|
||||
asyncio.run(amain(text, used_speaker, os.path.join(folder_path,f"sft_{audio_file_count}.mp3")))
|
||||
play_audio(f'{folder_path}/sft_{audio_file_count}.mp3')
|
||||
else:
|
||||
text = "很抱歉,你不是我的主人哦,我无法为您服务"
|
||||
system_introduction(text)
|
||||
else:
|
||||
text = "很抱歉,唤醒词错误,我无法为您服务。请说出正确的唤醒词哦"
|
||||
system_introduction(text)
|
||||
|
||||
# 主函数
|
||||
if __name__ == "__main__":
|
||||
|
||||
try:
|
||||
# 启动音视频录制线程
|
||||
audio_thread = threading.Thread(target=audio_recorder)
|
||||
# video_thread = threading.Thread(target=video_recorder)
|
||||
audio_thread.start()
|
||||
# video_thread.start()
|
||||
|
||||
flag_info = f'{flag_sv_used}-{flag_KWS_used}'
|
||||
dict_flag_info = {
|
||||
"1-1": "您已开启声纹识别和关键词唤醒,",
|
||||
"0-1":"您已开启关键词唤醒",
|
||||
"1-0":"您已开启声纹识别",
|
||||
"0-0":"",
|
||||
}
|
||||
if flag_sv_used or flag_KWS_used:
|
||||
text = dict_flag_info[flag_info]
|
||||
system_introduction(text)
|
||||
|
||||
print("按 Ctrl+C 停止录制")
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("录制停止中...")
|
||||
recording_active = False
|
||||
audio_thread.join()
|
||||
# video_thread.join()
|
||||
print("录制已停止")
|
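For reference, the wake-word check in the script above is just a substring match on the pinyin of the ASR output. A minimal standalone sketch of that mechanism, reusing the same pypinyin call as extract_chinese_and_convert_to_pinyin (the sample sentence is hypothetical):

import re
from pypinyin import pinyin, Style

def to_pinyin(s: str) -> str:
    # keep only CJK characters, then join their pinyin syllables with spaces
    han = ''.join(re.findall(r'[\u4e00-\u9fa5]', s))
    return ' '.join(item[0] for item in pinyin(han, style=Style.NORMAL))

wake_word = "ni hao xiao qian"            # same format as set_KWS
asr_text = "你好小千,今天天气怎么样?"      # hypothetical ASR output
print(to_pinyin(asr_text))                # ni hao xiao qian jin tian tian qi zen me yang
print(wake_word in to_pinyin(asr_text))   # True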
502
15.1_SenceVoice_kws_CAM++.py
Normal file
@ -0,0 +1,502 @@
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
import numpy as np
|
||||
import time
|
||||
from queue import Queue
|
||||
import webrtcvad
|
||||
import os
|
||||
import threading
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
import pygame
|
||||
import edge_tts
|
||||
import asyncio
|
||||
from time import sleep
|
||||
import langid
|
||||
from langdetect import detect
|
||||
import re
|
||||
from pypinyin import pinyin, Style
|
||||
from modelscope.pipelines import pipeline
|
||||
|
||||
# --- 配置huggingFace国内镜像 ---
|
||||
import os
|
||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||
|
||||
# 参数设置
|
||||
AUDIO_RATE = 16000 # 音频采样率
|
||||
AUDIO_CHANNELS = 1 # 单声道
|
||||
CHUNK = 1024 # 音频块大小
|
||||
VAD_MODE = 3 # VAD 模式 (0-3, 数字越大越敏感)
|
||||
OUTPUT_DIR = "./output" # 输出目录
|
||||
NO_SPEECH_THRESHOLD = 1 # 无效语音阈值,单位:秒
|
||||
folder_path = "./Test_QWen2_VL/"
|
||||
audio_file_count = 0
|
||||
audio_file_count_tmp = 0
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
|
||||
# 队列用于音频和视频同步缓存
|
||||
audio_queue = Queue()
|
||||
video_queue = Queue()
|
||||
|
||||
# 全局变量
|
||||
last_active_time = time.time()
|
||||
recording_active = True
|
||||
segments_to_save = []
|
||||
saved_intervals = []
|
||||
last_vad_end_time = 0 # 上次保存的 VAD 有效段结束时间
|
||||
|
||||
|
||||
# --- 唤醒词、声纹变量配置 ---
|
||||
# set_KWS = "ni hao xiao qian"
|
||||
# set_KWS = "shuo hua xiao qian"
|
||||
set_KWS = "ni hao qian wen"
|
||||
flag_KWS = 0
|
||||
|
||||
flag_KWS_used = 1
|
||||
flag_sv_used = 1
|
||||
|
||||
flag_sv_enroll = 0
|
||||
thred_sv = 0.35
|
||||
|
||||
# 初始化 WebRTC VAD
|
||||
vad = webrtcvad.Vad()
|
||||
vad.set_mode(VAD_MODE)
|
||||
|
||||
|
||||
def extract_chinese_and_convert_to_pinyin(input_string):
|
||||
"""
|
||||
提取字符串中的汉字,并将其转换为拼音。
|
||||
|
||||
:param input_string: 原始字符串
|
||||
:return: 转换后的拼音字符串
|
||||
"""
|
||||
# 使用正则表达式提取所有汉字
|
||||
chinese_characters = re.findall(r'[\u4e00-\u9fa5]', input_string)
|
||||
# 将汉字列表合并为字符串
|
||||
chinese_text = ''.join(chinese_characters)
|
||||
|
||||
# 转换为拼音
|
||||
pinyin_result = pinyin(chinese_text, style=Style.NORMAL)
|
||||
# 将拼音列表拼接为字符串
|
||||
pinyin_text = ' '.join([item[0] for item in pinyin_result])
|
||||
|
||||
return pinyin_text
|
||||
|
||||
|
||||
# 音频录制线程
|
||||
def audio_recorder():
|
||||
global audio_queue, recording_active, last_active_time, segments_to_save, last_vad_end_time
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16,
|
||||
channels=AUDIO_CHANNELS,
|
||||
rate=AUDIO_RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK)
|
||||
|
||||
audio_buffer = []
|
||||
print("音频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
data = stream.read(CHUNK)
|
||||
audio_buffer.append(data)
|
||||
|
||||
# 每 0.5 秒检测一次 VAD
|
||||
if len(audio_buffer) * CHUNK / AUDIO_RATE >= 0.5:
|
||||
# 拼接音频数据并检测 VAD
|
||||
raw_audio = b''.join(audio_buffer)
|
||||
vad_result = check_vad_activity(raw_audio)
|
||||
|
||||
if vad_result:
|
||||
print("检测到语音活动")
|
||||
last_active_time = time.time()
|
||||
segments_to_save.append((raw_audio, time.time()))
|
||||
else:
|
||||
print("静音中...")
|
||||
|
||||
audio_buffer = [] # 清空缓冲区
|
||||
|
||||
# 检查无效语音时间
|
||||
if time.time() - last_active_time > NO_SPEECH_THRESHOLD:
|
||||
# 检查是否需要保存
|
||||
if segments_to_save and segments_to_save[-1][1] > last_vad_end_time:
|
||||
save_audio_video()
|
||||
last_active_time = time.time()
|
||||
else:
|
||||
pass
|
||||
# print("无新增语音段,跳过保存")
|
||||
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
# 视频录制线程
|
||||
def video_recorder():
|
||||
global video_queue, recording_active
|
||||
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
print("视频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
video_queue.put((frame, time.time()))
|
||||
|
||||
# 实时显示摄像头画面
|
||||
cv2.imshow("Real Camera", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 键退出
|
||||
break
|
||||
else:
|
||||
print("无法获取摄像头画面")
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 检测 VAD 活动
|
||||
def check_vad_activity(audio_data):
|
||||
# 将音频数据分块检测
|
||||
num, rate = 0, 0.5
|
||||
step = int(AUDIO_RATE * 0.02)  # 320 bytes per chunk; audio_data is raw 16-bit bytes, so this is ~10 ms of audio (a frame length webrtcvad accepts)
|
||||
flag_rate = round(rate * len(audio_data) // step)
|
||||
|
||||
for i in range(0, len(audio_data), step):
|
||||
chunk = audio_data[i:i + step]
|
||||
if len(chunk) == step:
|
||||
if vad.is_speech(chunk, sample_rate=AUDIO_RATE):
|
||||
num += 1
|
||||
|
||||
if num > flag_rate:
|
||||
return True
|
||||
return False
|
||||
|
||||
# 保存音频和视频
|
||||
def save_audio_video():
|
||||
pygame.mixer.init()
|
||||
|
||||
global segments_to_save, video_queue, last_vad_end_time, saved_intervals
|
||||
|
||||
# 全局变量,用于保存音频文件名计数
|
||||
global audio_file_count
|
||||
global flag_sv_enroll
|
||||
global set_SV_enroll
|
||||
|
||||
if flag_sv_enroll:
|
||||
audio_output_path = f"{set_SV_enroll}/enroll_0.wav"
|
||||
else:
|
||||
audio_file_count += 1
|
||||
audio_output_path = f"{OUTPUT_DIR}/audio_{audio_file_count}.wav"
|
||||
# audio_output_path = f"{OUTPUT_DIR}/audio_0.wav"
|
||||
|
||||
if not segments_to_save:
|
||||
return
|
||||
|
||||
# 停止当前播放的音频
|
||||
if pygame.mixer.music.get_busy():
|
||||
pygame.mixer.music.stop()
|
||||
print("检测到新的有效音,已停止当前音频播放")
|
||||
|
||||
# 获取有效段的时间范围
|
||||
start_time = segments_to_save[0][1]
|
||||
end_time = segments_to_save[-1][1]
|
||||
|
||||
# 检查是否与之前的片段重叠
|
||||
if saved_intervals and saved_intervals[-1][1] >= start_time:
|
||||
print("当前片段与之前片段重叠,跳过保存")
|
||||
segments_to_save.clear()
|
||||
return
|
||||
|
||||
# 保存音频
|
||||
audio_frames = [seg[0] for seg in segments_to_save]
|
||||
if flag_sv_enroll:
|
||||
audio_length = 0.5 * len(segments_to_save)
|
||||
if audio_length < 3:
|
||||
print("声纹注册语音需大于3秒,请重新注册")
|
||||
return 1
|
||||
|
||||
wf = wave.open(audio_output_path, 'wb')
|
||||
wf.setnchannels(AUDIO_CHANNELS)
|
||||
wf.setsampwidth(2) # 16-bit PCM
|
||||
wf.setframerate(AUDIO_RATE)
|
||||
wf.writeframes(b''.join(audio_frames))
|
||||
wf.close()
|
||||
print(f"音频保存至 {audio_output_path}")
|
||||
|
||||
# Inference()
|
||||
|
||||
if flag_sv_enroll:
|
||||
text = "声纹注册完成!现在只有你可以命令我啦!"
|
||||
print(text)
|
||||
flag_sv_enroll = 0
|
||||
system_introduction(text)
|
||||
else:
|
||||
# 使用线程执行推理
|
||||
inference_thread = threading.Thread(target=Inference, args=(audio_output_path,))
|
||||
inference_thread.start()
|
||||
|
||||
# 记录保存的区间
|
||||
saved_intervals.append((start_time, end_time))
|
||||
|
||||
# 清空缓冲区
|
||||
segments_to_save.clear()
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
async def amain(TEXT, VOICE, OUTPUT_FILE) -> None:
|
||||
"""Main function"""
|
||||
communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
await communicate.save(OUTPUT_FILE)
|
||||
|
||||
import os
|
||||
|
||||
def is_folder_empty(folder_path):
|
||||
"""
|
||||
检测指定文件夹内是否有文件。
|
||||
|
||||
:param folder_path: 文件夹路径
|
||||
:return: 如果文件夹为空返回 True,否则返回 False
|
||||
"""
|
||||
# 获取文件夹中的所有条目(文件或子文件夹)
|
||||
entries = os.listdir(folder_path)
|
||||
# 检查是否存在文件
|
||||
for entry in entries:
|
||||
# 获取完整路径
|
||||
full_path = os.path.join(folder_path, entry)
|
||||
# 如果是文件,返回 False
|
||||
if os.path.isfile(full_path):
|
||||
return False
|
||||
# 如果没有文件,返回 True
|
||||
return True
|
||||
|
||||
|
||||
# -------- SenceVoice 语音识别 --模型加载-----
|
||||
model_dir = r"D:\AI\download\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
|
||||
# -------- CAM++声纹识别 -- 模型加载 --------
|
||||
set_SV_enroll = r'.\SpeakerVerification_DIR\enroll_wav\\'
|
||||
sv_pipeline = pipeline(
|
||||
task='speaker-verification',
|
||||
model='damo/speech_campplus_sv_zh-cn_16k-common',
|
||||
model_revision='v1.0.0'
|
||||
)
|
||||
|
||||
# --------- QWen2.5大语言模型 ---------------
|
||||
# model_name = r"E:\2_PYTHON\Project\GPT\QWen\Qwen2.5-0.5B-Instruct"
|
||||
model_name = r"D:\AI\download\Qwen2.5-0.5B-Instruct"
|
||||
# model_name = r'E:\2_PYTHON\Project\GPT\QWen\Qwen2.5-7B-Instruct-GPTQ-Int4'
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
# ---------- 模型加载结束 -----------------------
|
||||
|
||||
class ChatMemory:
|
||||
def __init__(self, max_length=2048):
|
||||
self.history = []
|
||||
self.max_length = max_length # 最大输入长度
|
||||
|
||||
def add_to_history(self, user_input, model_response):
|
||||
"""
|
||||
添加用户输入和模型响应到历史记录。
|
||||
"""
|
||||
self.history.append(f"User: {user_input}")
|
||||
self.history.append(f"system: {model_response}")
|
||||
|
||||
def get_context(self):
|
||||
"""
|
||||
获取拼接后的对话上下文。
|
||||
"""
|
||||
context = "\n".join(self.history)
|
||||
# 截断上下文,使其不超过 max_length
|
||||
if len(context) > self.max_length:
|
||||
context = context[-self.max_length :]
|
||||
return context
|
||||
|
||||
# -------- memory 初始化 --------
|
||||
memory = ChatMemory(max_length=512)
|
||||
|
||||
def system_introduction(text):
|
||||
global audio_file_count
|
||||
global folder_path
|
||||
text = text
|
||||
print("LLM output:", text)
|
||||
used_speaker = "zh-CN-XiaoyiNeural"
|
||||
asyncio.run(amain(text, used_speaker, os.path.join(folder_path,f"sft_tmp_{audio_file_count}.mp3")))
|
||||
play_audio(f'{folder_path}/sft_tmp_{audio_file_count}.mp3')
|
||||
|
||||
def Inference(TEMP_AUDIO_FILE=f"{OUTPUT_DIR}/audio_0.wav"):
|
||||
'''
|
||||
1. 使用senceVoice做asr,转换为拼音,检测唤醒词
|
||||
- 首先检测声纹注册文件夹是否有注册文件,如果无,启动声纹注册
|
||||
2. 使用CAM++做声纹识别
|
||||
- 设置固定声纹注册语音目录,每次输入音频均进行声纹对比
|
||||
3. 以上两者均通过,则进行大模型推理
|
||||
'''
|
||||
global audio_file_count
|
||||
|
||||
global set_SV_enroll
|
||||
global flag_sv_enroll
|
||||
global thred_sv
|
||||
global flag_sv_used
|
||||
|
||||
global set_KWS
|
||||
global flag_KWS
|
||||
global flag_KWS_used
|
||||
|
||||
os.makedirs(set_SV_enroll, exist_ok=True)
|
||||
# --- 如果开启声纹识别,且声纹文件夹为空,则开始声纹注册。设定注册语音有效长度需大于3秒
|
||||
if flag_sv_used and is_folder_empty(set_SV_enroll):
|
||||
text = f"无声纹注册文件!请先注册声纹,需大于三秒哦~"
|
||||
print(text)
|
||||
system_introduction(text)
|
||||
flag_sv_enroll = 1
|
||||
|
||||
else:
|
||||
# -------- SenceVoice 推理 ---------
|
||||
input_file = (TEMP_AUDIO_FILE)
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
prompt = res[0]['text'].split(">")[-1]
|
||||
prompt_pinyin = extract_chinese_and_convert_to_pinyin(prompt)
|
||||
print(prompt, prompt_pinyin)
|
||||
|
||||
# --- 判断是否启动KWS
|
||||
if not flag_KWS_used:
|
||||
flag_KWS = 1
|
||||
if not flag_KWS:
|
||||
if set_KWS in prompt_pinyin:
|
||||
flag_KWS = 1
|
||||
|
||||
# --- KWS成功,或不设置KWS
|
||||
if flag_KWS:
|
||||
sv_score = sv_pipeline([os.path.join(set_SV_enroll, "enroll_0.wav"), TEMP_AUDIO_FILE], thr=thred_sv)
|
||||
print(sv_score)
|
||||
sv_result = sv_score['text']
|
||||
if sv_result == "yes":
|
||||
|
||||
# --- 读取历史对话 ---
|
||||
context = memory.get_context()
|
||||
|
||||
# prompt_tmp = res[0]['text'].split(">")[-1] + ",回答简短一些,保持50字以内!"
|
||||
prompt_tmp = res[0]['text'].split(">")[-1]
|
||||
prompt = f"{context}\nUser:{prompt_tmp}\n"
|
||||
|
||||
print("History:", context)
|
||||
print("ASR OUT:", prompt)
|
||||
# ---------SenceVoice --end----------
|
||||
# -------- 模型推理阶段,将语音识别结果作为大模型Prompt ------
|
||||
messages = [
|
||||
{"role": "system", "content": "你叫小千,是一个18岁的女大学生,性格活泼开朗,说话俏皮简洁。"},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
generated_ids = model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=512,
|
||||
)
|
||||
generated_ids = [
|
||||
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
output_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
print("answer", output_text)
|
||||
|
||||
# -------- 更新记忆库 -----
|
||||
memory.add_to_history(prompt_tmp, output_text)
|
||||
|
||||
# 输入文本
|
||||
text = output_text
|
||||
# 语种识别 -- langid
|
||||
language, confidence = langid.classify(text)
|
||||
# 语种识别 -- langdetect
|
||||
# language = detect(text).split("-")[0]
|
||||
|
||||
language_speaker = {
|
||||
"ja" : "ja-JP-NanamiNeural", # ok
|
||||
"fr" : "fr-FR-DeniseNeural", # ok
|
||||
"es" : "ca-ES-JoanaNeural", # ok
|
||||
"de" : "de-DE-KatjaNeural", # ok
|
||||
"zh" : "zh-CN-XiaoyiNeural", # ok
|
||||
"en" : "en-US-AnaNeural", # ok
|
||||
}
|
||||
|
||||
if language not in language_speaker.keys():
|
||||
used_speaker = "zh-CN-XiaoyiNeural"
|
||||
else:
|
||||
used_speaker = language_speaker[language]
|
||||
print("检测到语种:", language, "使用音色:", language_speaker[language])
|
||||
|
||||
asyncio.run(amain(text, used_speaker, os.path.join(folder_path,f"sft_{audio_file_count}.mp3")))
|
||||
play_audio(f'{folder_path}/sft_{audio_file_count}.mp3')
|
||||
else:
|
||||
text = "很抱歉,声纹验证失败,我无法为您服务"
|
||||
print(text)
|
||||
# system_introduction(text)
|
||||
else:
|
||||
text = "很抱歉,唤醒词错误,请说出正确的唤醒词哦"
|
||||
system_introduction(text)
|
||||
|
||||
# 主函数
|
||||
if __name__ == "__main__":
|
||||
|
||||
try:
|
||||
# 启动音视频录制线程
|
||||
audio_thread = threading.Thread(target=audio_recorder)
|
||||
# video_thread = threading.Thread(target=video_recorder)
|
||||
audio_thread.start()
|
||||
# video_thread.start()
|
||||
|
||||
flag_info = f'{flag_sv_used}-{flag_KWS_used}'
|
||||
dict_flag_info = {
|
||||
"1-1": "您已开启声纹识别和关键词唤醒,",
|
||||
"0-1":"您已开启关键词唤醒",
|
||||
"1-0":"您已开启声纹识别",
|
||||
"0-0":"",
|
||||
}
|
||||
if flag_sv_used or flag_KWS_used:
|
||||
text = dict_flag_info[flag_info]
|
||||
system_introduction(text)
|
||||
|
||||
print("按 Ctrl+C 停止录制")
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("录制停止中...")
|
||||
recording_active = False
|
||||
audio_thread.join()
|
||||
# video_thread.join()
|
||||
print("录制已停止")
|
35
1_Inference_CosyVoice.py
Normal file
@ -0,0 +1,35 @@
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
import torchaudio
|
||||
cosyvoice = CosyVoice(r'D:\AI\download\CosyVoice-300M', load_jit=True, load_onnx=False, fp16=True)
|
||||
print(cosyvoice.list_avaliable_spks())
|
||||
prompt_speech_16k = load_wav('vocal_3.mp3_10.wav_0006151680_0006360320.wav', 16000)
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜', '可以动动你的小手点个关注,感谢各位好哥哥,如果之后有新消息,我还会在更新呢。', prompt_speech_16k, stream=False)):
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||
|
||||
# change stream=True for chunk stream inference
|
||||
# out = cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)
|
||||
# torchaudio.save('sft_0.wav', out['tts_speech'], 22050)
|
||||
|
||||
# for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
|
||||
# torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||
|
||||
# cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz') # or change to pretrained_models/CosyVoice-300M for 50Hz inference
|
||||
# # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
|
||||
# prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
||||
# for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
||||
# torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||
# # cross_lingual usage
|
||||
# prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
|
||||
# for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
|
||||
# torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||
# # vc usage
|
||||
# prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
||||
# source_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
|
||||
# for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
|
||||
# torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||
|
||||
# cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
|
||||
# # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
|
||||
# for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
|
||||
# torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], 22050)
|
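A minimal sketch of the chunked streaming mode hinted at in the comment above ("change stream=True for chunk stream inference"), assuming the same cosyvoice model and prompt_speech_16k objects defined earlier in this script; the streamed chunks are concatenated before saving:

import torch
import torchaudio

chunks = []
for _, out in enumerate(cosyvoice.inference_zero_shot(
        '收到好友从远方寄来的生日礼物,那份意外的惊喜',
        '可以动动你的小手点个关注,感谢各位好哥哥,如果之后有新消息,我还会在更新呢。',
        prompt_speech_16k, stream=True)):
    chunks.append(out['tts_speech'])  # each chunk is a (1, T) waveform tensor
torchaudio.save('zero_shot_stream.wav', torch.cat(chunks, dim=1), 22050)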
46
2_record_test.py
Normal file
@ -0,0 +1,46 @@
|
||||
import pygame
|
||||
import time
|
||||
import sys
|
||||
import sounddevice as sd
|
||||
from scipy.io.wavfile import write
|
||||
import numpy as np
|
||||
|
||||
def record_audio(filename="output.wav", sample_rate=44100):
|
||||
print("按下 Enter 开始录音...")
|
||||
input() # 等待用户按下 Enter 键开始录音
|
||||
print("录音中... 按下 Enter 键结束录音")
|
||||
|
||||
# 开始录音
|
||||
recording = []
|
||||
try:
|
||||
def callback(indata, frames, time, status):
|
||||
recording.append(indata.copy())
|
||||
with sd.InputStream(samplerate=sample_rate, channels=1, callback=callback):
|
||||
input() # 等待用户再次按下 Enter 键结束录音
|
||||
except Exception as e:
|
||||
print(f"录音出现错误: {e}")
|
||||
return
|
||||
|
||||
# 将录音数据合并并保存为 WAV 文件
|
||||
audio_data = np.concatenate(recording, axis=0)
|
||||
write(filename, sample_rate, (audio_data * 32767).astype(np.int16))
|
||||
print(f"录音已保存为 {filename}")
|
||||
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
# 使用函数录音,作为输入
|
||||
record_audio("my_recording.wav")
|
||||
play_audio('my_recording.wav')
|
155
3.04backend_service.py
Normal file
@ -0,0 +1,155 @@
|
||||
from flask import Flask, request, jsonify, send_from_directory
|
||||
from flask_cors import CORS
|
||||
import time
|
||||
import os
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from funasr import AutoModel
|
||||
import edge_tts
|
||||
import asyncio
|
||||
import langid
|
||||
import tempfile
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
# 配置参数
|
||||
AUDIO_RATE = 16000
|
||||
OUTPUT_DIR = "./output"
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化模型
|
||||
print("Loading models...")
|
||||
model_dir = "D:/AI/download/SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel(model=model_dir, trust_remote_code=True)
|
||||
|
||||
# 加载Qwen2.5大语言模型
|
||||
model_name = "D:/AI/download/Qwen2.5-1.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
print("All models loaded!")
|
||||
|
||||
# 语言映射
|
||||
language_speaker = {
|
||||
"ja": "ja-JP-NanamiNeural",
|
||||
"fr": "fr-FR-DeniseNeural",
|
||||
"es": "ca-ES-JoanaNeural",
|
||||
"de": "de-DE-KatjaNeural",
|
||||
"zh": "zh-CN-XiaoyiNeural",
|
||||
"en": "en-US-AnaNeural",
|
||||
}
|
||||
|
||||
# 全局变量
|
||||
def process_audio(input_path):
|
||||
# 语音识别
|
||||
res = model_senceVoice.generate(
|
||||
input=input_path,
|
||||
cache={},
|
||||
language="auto",
|
||||
use_itn=False,
|
||||
)
|
||||
prompt = res[0]['text'].split(">")[-1]
|
||||
print("ASR OUT:", prompt)
|
||||
|
||||
# 大模型处理
|
||||
messages = [
|
||||
{"role": "system", "content": "你叫千问,是一个18岁的女大学生,性格活泼开朗,说话俏皮"},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
generated_ids = model.generate(**model_inputs, max_new_tokens=512)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids
|
||||
in zip(model_inputs.input_ids, generated_ids)]
|
||||
output_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
print("Answer:", output_text)
|
||||
|
||||
# 语种识别
|
||||
lang, _ = langid.classify(output_text)
|
||||
speaker = language_speaker.get(lang, "zh-CN-XiaoyiNeural")
|
||||
|
||||
# 语音合成
|
||||
output_file = os.path.join(OUTPUT_DIR, f"response_{int(time.time())}.mp3")
|
||||
asyncio.run(edge_tts.Communicate(output_text, speaker).save(output_file))
|
||||
|
||||
return output_file
|
||||
|
||||
# 新增ASR专用接口
|
||||
@app.route('/asr', methods=['POST'])
|
||||
def handle_asr():
|
||||
if 'audio' not in request.files:
|
||||
return jsonify({"error": "No audio file provided"}), 400
|
||||
|
||||
try:
|
||||
audio_file = request.files['audio']
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
audio_file.save(tmp.name)
|
||||
|
||||
# 语音识别
|
||||
res = model_senceVoice.generate(
|
||||
input=tmp.name,
|
||||
cache={},
|
||||
language="auto",
|
||||
use_itn=False,
|
||||
)
|
||||
asr_text = res[0]['text'].split(">")[-1]
|
||||
|
||||
return jsonify({
|
||||
"asr_text": asr_text
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
# 大模型处理接口
|
||||
@app.route('/generate', methods=['POST'])
|
||||
def handle_generate():
|
||||
try:
|
||||
data = request.get_json()
|
||||
asr_text = data.get('asr_text', '')
|
||||
|
||||
if not asr_text:
|
||||
return jsonify({"error": "No ASR text provided"}), 400
|
||||
|
||||
# 大模型处理
|
||||
messages = [
|
||||
{"role": "system", "content": "你叫千问,是一个18岁的女大学生,性格活泼开朗,说话俏皮"},
|
||||
{"role": "user", "content": asr_text},
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
generated_ids = model.generate(**model_inputs, max_new_tokens=512)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids
|
||||
in zip(model_inputs.input_ids, generated_ids)]
|
||||
answer_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
lang, _ = langid.classify(answer_text)
|
||||
speaker = language_speaker.get(lang, "zh-CN-XiaoyiNeural")
|
||||
output_file = os.path.join(OUTPUT_DIR, f"response_{int(time.time())}.mp3")
|
||||
asyncio.run(edge_tts.Communicate(answer_text, speaker).save(output_file))
|
||||
|
||||
return jsonify({
|
||||
"answer_text": answer_text,
|
||||
"audio_url": f"/audio/{os.path.basename(output_file)}"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
@app.route('/audio/<filename>')
|
||||
def get_audio(filename):
|
||||
return send_from_directory(OUTPUT_DIR, filename)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(port=5000, threaded=True)
|
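A small client-side sketch for exercising this service once it is running on port 5000; the "audio" field name, the /asr and /generate routes, and the JSON keys mirror the handlers above, while test.wav and reply.mp3 are hypothetical file names:

import requests

BASE = "http://127.0.0.1:5000"

# 1) speech recognition
with open("test.wav", "rb") as f:
    asr = requests.post(f"{BASE}/asr", files={"audio": f}).json()
print("ASR:", asr["asr_text"])

# 2) LLM answer + speech synthesis
gen = requests.post(f"{BASE}/generate", json={"asr_text": asr["asr_text"]}).json()
print("Answer:", gen["answer_text"])

# 3) download the synthesized reply
with open("reply.mp3", "wb") as f:
    f.write(requests.get(BASE + gen["audio_url"]).content)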
158
3.07backend_service.py
Normal file
@ -0,0 +1,158 @@
|
||||
from flask import Flask, request, jsonify, send_from_directory
|
||||
from flask_cors import CORS
|
||||
import time
|
||||
import os
|
||||
import logging
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from funasr import AutoModel
|
||||
import edge_tts
|
||||
import asyncio
|
||||
import langid
|
||||
import tempfile
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('app.log'), # 文件日志
|
||||
logging.StreamHandler() # 控制台日志
|
||||
]
|
||||
)
|
||||
|
||||
# 配置参数
|
||||
AUDIO_RATE = 16000
|
||||
OUTPUT_DIR = "./output"
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# 初始化模型
|
||||
app.logger.info("Loading models...")
|
||||
model_dir = "D:/AI/download/SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel(model=model_dir, trust_remote_code=True)
|
||||
|
||||
# 加载Qwen2.5大语言模型
|
||||
model_name = "D:/AI/download/Qwen2.5-1.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
app.logger.info("All models loaded!")
|
||||
|
||||
# 语言映射
|
||||
language_speaker = {
|
||||
"ja": "ja-JP-NanamiNeural",
|
||||
"fr": "fr-FR-DeniseNeural",
|
||||
"es": "ca-ES-JoanaNeural",
|
||||
"de": "de-DE-KatjaNeural",
|
||||
"zh": "zh-CN-XiaoyiNeural",
|
||||
"en": "en-US-AnaNeural",
|
||||
}
|
||||
|
||||
# ---------------------- 接口路由 ----------------------
|
||||
@app.route('/asr', methods=['POST'])
|
||||
def handle_asr():
|
||||
"""处理语音识别请求"""
|
||||
if 'audio' not in request.files:
|
||||
return jsonify({"error": "No audio file provided"}), 400
|
||||
|
||||
try:
|
||||
audio_file = request.files['audio']
|
||||
app.logger.info(f"Received audio file: {audio_file.filename}")
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
audio_file.save(tmp.name)
|
||||
|
||||
# 语音识别
|
||||
res = model_senceVoice.generate(
|
||||
input=tmp.name,
|
||||
cache={},
|
||||
language="auto",
|
||||
use_itn=False,
|
||||
)
|
||||
asr_text = res[0]['text'].split(">")[-1]
|
||||
app.logger.info(f"ASR识别结果: {asr_text}")
|
||||
|
||||
return jsonify({"asr_text": asr_text})
|
||||
|
||||
except Exception as e:
|
||||
app.logger.error(f"ASR处理异常: {str(e)}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
@app.route('/generate_text', methods=['POST'])
|
||||
def handle_generate_text():
|
||||
"""处理大模型文本生成请求"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
asr_text = data.get('asr_text', '')
|
||||
app.logger.info(f"收到ASR文本: {asr_text}")
|
||||
|
||||
if not asr_text:
|
||||
return jsonify({"error": "No ASR text provided"}), 400
|
||||
|
||||
# 构建对话模板
|
||||
messages = [
|
||||
{"role": "system", "content": "你叫千问,是一个18岁的女大学生,性格活泼开朗,说话俏皮"},
|
||||
{"role": "user", "content": asr_text},
|
||||
]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
# 生成回复
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
generated_ids = model.generate(**model_inputs, max_new_tokens=512)
|
||||
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids
|
||||
in zip(model_inputs.input_ids, generated_ids)]
|
||||
answer_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
app.logger.info(f"大模型回复: {answer_text}")
|
||||
|
||||
return jsonify({"answer_text": answer_text})
|
||||
|
||||
except Exception as e:
|
||||
app.logger.error(f"文本生成异常: {str(e)}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
@app.route('/generate_audio', methods=['POST'])
|
||||
def handle_generate_audio():
|
||||
"""处理语音合成请求"""
|
||||
try:
|
||||
data = request.get_json()
|
||||
answer_text = data.get('answer_text', '')
|
||||
app.logger.info(f"收到待合成文本: {answer_text}")
|
||||
|
||||
if not answer_text:
|
||||
return jsonify({"error": "No answer text provided"}), 400
|
||||
|
||||
# 语种识别
|
||||
lang, _ = langid.classify(answer_text)
|
||||
speaker = language_speaker.get(lang, "zh-CN-XiaoyiNeural")
|
||||
app.logger.info(f"识别到语言: {lang}, 使用发音人: {speaker}")
|
||||
|
||||
# 语音合成
|
||||
output_file = os.path.join(OUTPUT_DIR, f"response_{int(time.time())}.mp3")
|
||||
asyncio.run(edge_tts.Communicate(answer_text, speaker).save(output_file))
|
||||
app.logger.info(f"语音合成完成,保存路径: {output_file}")
|
||||
|
||||
return jsonify({
|
||||
"audio_url": f"/audio/{os.path.basename(output_file)}"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
app.logger.error(f"语音合成异常: {str(e)}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
@app.route('/audio/<filename>')
|
||||
def get_audio(filename):
|
||||
"""音频文件下载接口"""
|
||||
return send_from_directory(OUTPUT_DIR, filename)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.logger.info("服务启动,端口:5000")
|
||||
app.run(port=5000, threaded=True)
|
73
3_Inference_edgeTTS.py
Normal file
@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Basic example of edge_tts usage.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import edge_tts
|
||||
|
||||
# TEXT = "今日特别报道,伊朗、舍沙特阿拉伯在北京中方见证下,决定恢复外交关系,重启大使馆并互相派驻外交大使"
|
||||
# VOICE = "zh-CN-YunxiNeural"
|
||||
# OUTPUT_FILE = "test_male.mp3"
|
||||
|
||||
# TEXT = "はそれぞれ"
|
||||
# VOICE = "ja-JP-NanamiNeural"
|
||||
# OUTPUT_FILE = "test_male.mp3"
|
||||
|
||||
# TEXT = "Estoy bien, gracias. ¿Y tú?"
|
||||
# VOICE = "ca-ES-JoanaNeural"
|
||||
# OUTPUT_FILE = "test_male.mp3"
|
||||
|
||||
# TEXT = "Ça va bien, merci. Et toi ?"
|
||||
# VOICE = "fr-FR-DeniseNeural"
|
||||
# OUTPUT_FILE = "test_male.mp3"
|
||||
|
||||
TEXT = "hello, whats your name"
|
||||
VOICE = "en-US-AnaNeural"
|
||||
OUTPUT_FILE = "test_male.mp3"
|
||||
|
||||
async def amain() -> None:
|
||||
"""Main function"""
|
||||
communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
await communicate.save(OUTPUT_FILE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(amain())
|
||||
|
||||
'''
|
||||
Name: zh-CN-XiaoyiNeural
|
||||
Gender: Female
|
||||
|
||||
Name: zh-CN-YunjianNeural
|
||||
Gender: Male
|
||||
|
||||
Name: zh-CN-YunxiNeural
|
||||
Gender: Male
|
||||
|
||||
Name: zh-CN-YunxiaNeural
|
||||
Gender: Male
|
||||
|
||||
Name: zh-CN-YunyangNeural
|
||||
Gender: Male
|
||||
|
||||
Name: zh-CN-liaoning-XiaobeiNeural
|
||||
Gender: Female
|
||||
|
||||
Name: zh-CN-shaanxi-XiaoniNeural
|
||||
Gender: Female
|
||||
|
||||
Name: zh-HK-HiuGaaiNeural
|
||||
Gender: Female
|
||||
|
||||
Name: zh-HK-HiuMaanNeural
|
||||
Gender: Female
|
||||
|
||||
Name: zh-HK-WanLungNeural
|
||||
Gender: Male
|
||||
|
||||
Name: zh-TW-HsiaoChenNeural
|
||||
Gender: Female
|
||||
'''
|
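The voice names listed in the docstring above can also be enumerated programmatically; a short sketch using edge_tts.list_voices(), which (as far as I know) returns one metadata dict per voice with ShortName/Gender/Locale fields:

import asyncio
import edge_tts

async def list_chinese_voices() -> None:
    voices = await edge_tts.list_voices()
    for v in voices:
        if v["Locale"].startswith("zh-"):
            print(v["ShortName"], v["Gender"])

asyncio.run(list_chinese_voices())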
49
4_Inference_QWen2Audio.py
Normal file
@ -0,0 +1,49 @@
|
||||
from io import BytesIO
|
||||
from urllib.request import urlopen
|
||||
import librosa
|
||||
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor
|
||||
|
||||
model_name = r".\QWen\Qwen2-Audio-7B-Instruct"
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, device_map="cuda")
|
||||
|
||||
# conversation = [
|
||||
# {"role": "user", "content": [
|
||||
# {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/guess_age_gender.wav"},
|
||||
# ]},
|
||||
# {"role": "assistant", "content": "Yes, the speaker is female and in her twenties."},
|
||||
# {"role": "user", "content": [
|
||||
# {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav"},
|
||||
# ]},
|
||||
# ]
|
||||
|
||||
# 定义对话
|
||||
conversation = [
|
||||
{"role": "user", "content": [
|
||||
{"type": "audio", "audio_path": r".\QWen\Qwen2-Audio-7B-Instruct\guess_age_gender.wav"},
|
||||
]},
|
||||
{"role": "assistant", "content": "Yes, the speaker is female and in her twenties."},
|
||||
{"role": "user", "content": [
|
||||
{"type": "audio", "audio_path": r".\QWen\Qwen2-Audio-7B-Instruct\translate_to_chinese.wav"},
|
||||
]},
|
||||
]
|
||||
|
||||
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
|
||||
audios = []
|
||||
for message in conversation:
|
||||
if isinstance(message["content"], list):
|
||||
for ele in message["content"]:
|
||||
if ele["type"] == "audio":
|
||||
# the conversation above stores local paths under "audio_path", so read straight
# from disk instead of fetching a non-existent "audio_url" with urlopen
audios.append(librosa.load(
ele['audio_path'],
sr=processor.feature_extractor.sampling_rate)[0]
|
||||
)
|
||||
|
||||
inputs = processor(text=text, audios=audios, return_tensors="pt", padding=True)
|
||||
inputs.input_ids = inputs.input_ids.to("cuda")
|
||||
|
||||
generate_ids = model.generate(**inputs, max_length=256)
|
||||
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
|
||||
|
||||
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
print("Answer:", response)
|
21
5_pyttsx3_test.py
Normal file
@ -0,0 +1,21 @@
|
||||
import pyttsx3
|
||||
# 初始化 TTS 引擎
|
||||
engine = pyttsx3.init()
|
||||
|
||||
voices = engine.getProperty('voices')
|
||||
# 打印出所有的音色信息
|
||||
for voice in voices:
|
||||
print(f'id = {voice.id}----name = {voice.name}')
|
||||
|
||||
# 设置语音属性
|
||||
engine.setProperty('rate', 150)    # speaking rate
engine.setProperty('volume', 0.9)  # volume (0.0 to 1.0)
|
||||
# 选择语音
|
||||
voices = engine.getProperty('voices')
|
||||
print(voices)
|
||||
engine.setProperty('voice', voices[0].id) # 使用第一个语音
|
||||
# 输入文本
|
||||
text = "你好,今天天气很好,适合爬山"
|
||||
# 朗读文本
|
||||
engine.say(text)
|
||||
# 等待朗读完成
|
||||
engine.runAndWait()
|
28
6_Inference_funasr.py
Normal file
@ -0,0 +1,28 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- encoding: utf-8 -*-
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
import sys
|
||||
from funasr import AutoModel
|
||||
|
||||
model_dir = r".\QWen\pretrained_models\SenseVoiceSmall"
|
||||
input_file = (
|
||||
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
|
||||
)
|
||||
|
||||
model = AutoModel(
|
||||
model=model_dir,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
res = model.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
|
||||
print(res)
|
||||
# import pdb; pdb.set_trace()
|
||||
print(res[0]['text'].split(">")[-1])
|
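Instead of splitting the raw output on ">", newer funasr releases ship a helper that strips the SenseVoice language/emotion/event tags in one call; a sketch under that assumption, reusing the res object from above:

from funasr.utils.postprocess_utils import rich_transcription_postprocess

# res[0]['text'] looks like "<|zh|><|NEUTRAL|><|Speech|><|woitn|>欢迎大家来体验..."
clean_text = rich_transcription_postprocess(res[0]["text"])
print(clean_text)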
110
7.1_test_record_AV.py
Normal file
@ -0,0 +1,110 @@
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
import ffmpeg
|
||||
import os
|
||||
|
||||
# 配置音频参数
|
||||
AUDIO_FORMAT = pyaudio.paInt16
|
||||
CHANNELS = 1
|
||||
RATE = 44100
|
||||
CHUNK = 1024
|
||||
|
||||
# 配置视频参数
|
||||
FRAME_WIDTH = 640
|
||||
FRAME_HEIGHT = 480
|
||||
FRAME_RATE = 20.0
|
||||
|
||||
# 文件保存路径
|
||||
TEMP_AUDIO_FILE = "temp_audio.wav"
|
||||
TEMP_VIDEO_FILE = "temp_video.avi"
|
||||
OUTPUT_FILE = "output.mp4"
|
||||
|
||||
# 音频录制线程
|
||||
def record_audio(stop_event):
|
||||
audio = pyaudio.PyAudio()
|
||||
stream = audio.open(format=AUDIO_FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
||||
frames = []
|
||||
print("开始录音...")
|
||||
|
||||
while not stop_event.is_set():
|
||||
data = stream.read(CHUNK)
|
||||
frames.append(data)
|
||||
|
||||
print("录音结束。")
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
audio.terminate()
|
||||
|
||||
# 保存音频
|
||||
with wave.open(TEMP_AUDIO_FILE, 'wb') as wf:
|
||||
wf.setnchannels(CHANNELS)
|
||||
wf.setsampwidth(audio.get_sample_size(AUDIO_FORMAT))
|
||||
wf.setframerate(RATE)
|
||||
wf.writeframes(b''.join(frames))
|
||||
|
||||
# 视频录制线程
|
||||
def record_video(stop_event):
|
||||
cap = cv2.VideoCapture(0)
|
||||
cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH)
|
||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT)
|
||||
cap.set(cv2.CAP_PROP_FPS, FRAME_RATE)
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
||||
out = cv2.VideoWriter(TEMP_VIDEO_FILE, fourcc, FRAME_RATE, (FRAME_WIDTH, FRAME_HEIGHT))
|
||||
print("开始录像...")
|
||||
|
||||
while not stop_event.is_set():
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
out.write(frame)
|
||||
cv2.imshow('Recording Video', frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 退出摄像头窗口
|
||||
stop_event.set()
|
||||
else:
|
||||
break
|
||||
|
||||
print("录像结束。")
|
||||
cap.release()
|
||||
out.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 合并音视频
|
||||
def merge_audio_video(audio_file, video_file, output_file):
|
||||
print("正在合并音频和视频...")
|
||||
# ffmpeg-python expects streams (not plain paths) as the inputs to output();
# feed the video and audio in as two separate input streams and mux them
video_stream = ffmpeg.input(video_file)
audio_stream = ffmpeg.input(audio_file)
ffmpeg.output(video_stream, audio_stream, output_file, vcodec='copy', acodec='aac', strict='experimental').run(overwrite_output=True)
|
||||
print(f"合并完成,文件保存为: {output_file}")
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
stop_event = threading.Event()
|
||||
|
||||
# 启动音频和视频录制线程
|
||||
audio_thread = threading.Thread(target=record_audio, args=(stop_event,))
|
||||
video_thread = threading.Thread(target=record_video, args=(stop_event,))
|
||||
|
||||
print("按 Enter 键开始录制...")
|
||||
input() # 等待用户按下 Enter 键
|
||||
print("录制中... 再次按 Enter 键停止录制。")
|
||||
|
||||
audio_thread.start()
|
||||
video_thread.start()
|
||||
|
||||
input() # 等待用户再次按下 Enter 键
|
||||
stop_event.set()
|
||||
|
||||
audio_thread.join()
|
||||
video_thread.join()
|
||||
|
||||
# # 合并音频和视频
|
||||
# merge_audio_video(TEMP_AUDIO_FILE, TEMP_VIDEO_FILE, OUTPUT_FILE)
|
||||
|
||||
# # 清理临时文件
|
||||
# os.remove(TEMP_AUDIO_FILE)
|
||||
# os.remove(TEMP_VIDEO_FILE)
|
||||
|
||||
print("录制完成!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
189
7.2_test_record_QWen2_VL_AV.py
Normal file
@ -0,0 +1,189 @@
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
|
||||
# 配置音频参数
|
||||
AUDIO_FORMAT = pyaudio.paInt16
|
||||
CHANNELS = 1
|
||||
RATE = 44100
|
||||
CHUNK = 1024
|
||||
|
||||
# 配置视频参数
|
||||
FRAME_WIDTH = 640
|
||||
FRAME_HEIGHT = 480
|
||||
FRAME_RATE = 20.0
|
||||
|
||||
# 文件保存路径
|
||||
TEMP_AUDIO_FILE = "temp_audio.wav"
|
||||
TEMP_VIDEO_FILE = "temp_video.avi"
|
||||
# OUTPUT_FILE = "output.mp4"
|
||||
|
||||
# 音频录制线程
|
||||
def record_audio(stop_event):
|
||||
audio = pyaudio.PyAudio()
|
||||
stream = audio.open(format=AUDIO_FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
||||
frames = []
|
||||
print("开始录音...")
|
||||
|
||||
while not stop_event.is_set():
|
||||
data = stream.read(CHUNK)
|
||||
frames.append(data)
|
||||
|
||||
print("录音结束。")
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
audio.terminate()
|
||||
|
||||
# 保存音频
|
||||
with wave.open(TEMP_AUDIO_FILE, 'wb') as wf:
|
||||
wf.setnchannels(CHANNELS)
|
||||
wf.setsampwidth(audio.get_sample_size(AUDIO_FORMAT))
|
||||
wf.setframerate(RATE)
|
||||
wf.writeframes(b''.join(frames))
|
||||
|
||||
# 视频录制线程
|
||||
def record_video(stop_event):
|
||||
cap = cv2.VideoCapture(0)
|
||||
cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH)
|
||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT)
|
||||
cap.set(cv2.CAP_PROP_FPS, FRAME_RATE)
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
||||
out = cv2.VideoWriter(TEMP_VIDEO_FILE, fourcc, FRAME_RATE, (FRAME_WIDTH, FRAME_HEIGHT))
|
||||
print("开始录像...")
|
||||
|
||||
while not stop_event.is_set():
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
out.write(frame)
|
||||
cv2.imshow('Recording Video', frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 退出摄像头窗口
|
||||
stop_event.set()
|
||||
else:
|
||||
break
|
||||
|
||||
print("录像结束。")
|
||||
cap.release()
|
||||
out.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 合并音视频
|
||||
def merge_audio_video(audio_file, video_file, output_file):
|
||||
print("正在合并音频和视频...")
|
||||
import ffmpeg  # not imported at the top of this file; needed if the merge step is enabled
# ffmpeg-python expects streams (not plain paths) as the inputs to output();
# feed the video and audio in as two separate input streams and mux them
video_stream = ffmpeg.input(video_file)
audio_stream = ffmpeg.input(audio_file)
ffmpeg.output(video_stream, audio_stream, output_file, vcodec='copy', acodec='aac', strict='experimental').run(overwrite_output=True)
|
||||
print(f"合并完成,文件保存为: {output_file}")
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
stop_event = threading.Event()
|
||||
|
||||
# 启动音频和视频录制线程
|
||||
audio_thread = threading.Thread(target=record_audio, args=(stop_event,))
|
||||
video_thread = threading.Thread(target=record_video, args=(stop_event,))
|
||||
|
||||
print("按 Enter 键开始录制...")
|
||||
input() # 等待用户按下 Enter 键
|
||||
print("录制中... 再次按 Enter 键停止录制。")
|
||||
|
||||
audio_thread.start()
|
||||
video_thread.start()
|
||||
|
||||
input() # 等待用户再次按下 Enter 键
|
||||
stop_event.set()
|
||||
|
||||
audio_thread.join()
|
||||
video_thread.join()
|
||||
|
||||
# # 合并音频和视频
|
||||
# merge_audio_video(TEMP_AUDIO_FILE, TEMP_VIDEO_FILE, OUTPUT_FILE)
|
||||
|
||||
# # 清理临时文件
|
||||
# os.remove(TEMP_AUDIO_FILE)
|
||||
# os.remove(TEMP_VIDEO_FILE)
|
||||
|
||||
print("录制完成!")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
file_path = "captured_image.jpg" # 设置保存路径
|
||||
cap = cv2.VideoCapture(TEMP_VIDEO_FILE)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
frame_index = int(total_frames // 2)
|
||||
# 设置视频帧位置
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print(f"无法读取帧索引 {frame_index}")
|
||||
else:
|
||||
# 显示帧
|
||||
cv2.imwrite(file_path, frame)
|
||||
# cv2.imshow(f"Frame {frame_index}", frame)
|
||||
|
||||
# -------------- Load QWen2-VL Model ------------
|
||||
# default: Load the model on the available device(s)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
|
||||
)
|
||||
# ------- 设置分辨率,降低现存占用 -------
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 1280*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
# --------------------------------------
|
||||
|
||||
# -------- SenceVoice 语音识别 -------
|
||||
model_dir = r"E:\2_PYTHON\Project\GPT\QWen\pretrained_models\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
input_file = TEMP_AUDIO_FILE
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
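# SenseVoice prefixes the transcript with tags such as <|zh|><|NEUTRAL|><|Speech|>; splitting on ">" keeps only the plain text used as the Qwen2-VL prompt.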
prompt = res[0]['text'].split(">")[-1]
|
||||
# ---------SenceVoice --end----------
|
||||
|
||||
# -------- QWen2-VL 模型推理 ---------
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": f"{file_path}",
|
||||
},
|
||||
{"type": "text", "text": f"{prompt}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Preparation for inference
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
235
7.3_test_record_QWen2_VL_AV_TTS.py
Normal file
@ -0,0 +1,235 @@
|
||||
import os
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
import pygame
|
||||
import time
|
||||
import edge_tts
|
||||
import asyncio
|
||||
from time import sleep
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
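# edge-tts drives Microsoft Edge's online neural TTS: Communicate(TEXT, VOICE).save() writes the synthesized speech to OUTPUT_FILE (mp3 by default).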
async def amain(TEXT, VOICE, OUTPUT_FILE) -> None:
|
||||
"""Main function"""
|
||||
communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
await communicate.save(OUTPUT_FILE)
|
||||
|
||||
# 配置音频参数
|
||||
AUDIO_FORMAT = pyaudio.paInt16
|
||||
CHANNELS = 1
|
||||
RATE = 44100
|
||||
CHUNK = 1024
|
||||
|
||||
# 配置视频参数
|
||||
FRAME_WIDTH = 640
|
||||
FRAME_HEIGHT = 480
|
||||
FRAME_RATE = 20.0
|
||||
|
||||
# 文件保存路径
|
||||
TEMP_AUDIO_FILE = "temp_audio.wav"
|
||||
TEMP_VIDEO_FILE = "temp_video.avi"
|
||||
# OUTPUT_FILE = "output.mp4"
|
||||
|
||||
# 音频录制线程
|
||||
def record_audio(stop_event):
|
||||
# time.sleep(5)
|
||||
audio = pyaudio.PyAudio()
|
||||
stream = audio.open(format=AUDIO_FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK)
|
||||
frames = []
|
||||
print("开始录音...")
|
||||
|
||||
while not stop_event.is_set():
|
||||
data = stream.read(CHUNK)
|
||||
frames.append(data)
|
||||
|
||||
print("录音结束。")
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
audio.terminate()
|
||||
|
||||
# 保存音频
|
||||
with wave.open(TEMP_AUDIO_FILE, 'wb') as wf:
|
||||
wf.setnchannels(CHANNELS)
|
||||
wf.setsampwidth(audio.get_sample_size(AUDIO_FORMAT))
|
||||
wf.setframerate(RATE)
|
||||
wf.writeframes(b''.join(frames))
|
||||
|
||||
# 视频录制线程
|
||||
def record_video(stop_event):
|
||||
# time.sleep(5)
|
||||
cap = cv2.VideoCapture(0)
|
||||
cap.set(cv2.CAP_PROP_FRAME_WIDTH, FRAME_WIDTH)
|
||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, FRAME_HEIGHT)
|
||||
cap.set(cv2.CAP_PROP_FPS, FRAME_RATE)
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
||||
out = cv2.VideoWriter(TEMP_VIDEO_FILE, fourcc, FRAME_RATE, (FRAME_WIDTH, FRAME_HEIGHT))
|
||||
print("开始录像...")
|
||||
|
||||
while not stop_event.is_set():
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
out.write(frame)
|
||||
cv2.imshow('Recording Video', frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 退出摄像头窗口
|
||||
stop_event.set()
|
||||
else:
|
||||
break
|
||||
|
||||
print("录像结束。")
|
||||
cap.release()
|
||||
out.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 合并音视频
|
||||
def merge_audio_video(audio_file, video_file, output_file):
|
||||
print("正在合并音频和视频...")
|
||||
# NOTE: this helper needs the ffmpeg-python bindings (`import ffmpeg`), which this script does not import;
# it is only reachable through the commented-out merge step in main().
ffmpeg.output(ffmpeg.input(video_file), ffmpeg.input(audio_file), output_file, vcodec='copy', acodec='aac', strict='experimental').run(overwrite_output=True)
|
||||
print(f"合并完成,文件保存为: {output_file}")
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
stop_event = threading.Event()
|
||||
|
||||
# 启动音频和视频录制线程
|
||||
audio_thread = threading.Thread(target=record_audio, args=(stop_event,))
|
||||
video_thread = threading.Thread(target=record_video, args=(stop_event,))
|
||||
|
||||
print("按 Enter 键开始录制...")
|
||||
input() # 等待用户按下 Enter 键
|
||||
print("录制中... 再次按 Enter 键停止录制。")
|
||||
|
||||
audio_thread.start()
|
||||
video_thread.start()
|
||||
|
||||
input() # 等待用户再次按下 Enter 键
|
||||
stop_event.set()
|
||||
|
||||
audio_thread.join()
|
||||
video_thread.join()
|
||||
|
||||
# # 合并音频和视频
|
||||
# merge_audio_video(TEMP_AUDIO_FILE, TEMP_VIDEO_FILE, OUTPUT_FILE)
|
||||
|
||||
# # 清理临时文件
|
||||
# os.remove(TEMP_AUDIO_FILE)
|
||||
# os.remove(TEMP_VIDEO_FILE)
|
||||
|
||||
print("录制完成!")
|
||||
|
||||
|
||||
# -------------- Load QWen2-VL Model ------------
|
||||
# default: Load the model on the available device(s)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
|
||||
)
|
||||
# ------- 设置分辨率,降低显存占用 -------
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 512*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
# --------------------------------------
|
||||
|
||||
# -------- SenceVoice 语音识别 --模型加载-----
|
||||
model_dir = r"E:\2_PYTHON\Project\GPT\QWen\pretrained_models\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
|
||||
if __name__ == "__main__":
|
||||
while True:
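# Each pass records one spoken question (Enter starts and stops the capture), answers it with Qwen2-VL, then reads the reply aloud via edge-tts before looping again.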
|
||||
main()
|
||||
folder_path = "./Test_QWen2_VL/"
|
||||
os.makedirs(folder_path, exist_ok=True)
|
||||
|
||||
file_path = os.path.join(folder_path, "captured_image.jpg") # 设置保存路径
|
||||
cap = cv2.VideoCapture(TEMP_VIDEO_FILE)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
frame_index = int(total_frames // 2)
|
||||
# 设置视频帧位置
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print(f"无法读取帧索引 {frame_index}")
|
||||
else:
|
||||
# 保存帧
|
||||
cv2.imwrite(file_path, frame)
|
||||
# cv2.imshow(f"Frame {frame_index}", frame)
|
||||
|
||||
# -------- SenceVoice 推理 ---------
|
||||
input_file = TEMP_AUDIO_FILE
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
prompt = res[0]['text'].split(">")[-1]
|
||||
# ---------SenceVoice --end----------
|
||||
|
||||
# -------- QWen2-VL 模型推理 ---------
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": f"{file_path}",
|
||||
},
|
||||
{"type": "text", "text": f"{prompt}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Preparation for inference
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
|
||||
# 输入文本
|
||||
text = output_text[0]
|
||||
# asyncio.run(amain(text, "zh-CN-YunxiaNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
asyncio.run(amain(text, "zh-CN-XiaoyiNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# asyncio.run(amain(text, "zh-CN-YunjianNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# asyncio.run(amain(text, "zh-CN-shaanxi-XiaoniNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
316
7.5_realTime_AV.py
Normal file
@ -0,0 +1,316 @@
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
import numpy as np
|
||||
import time
|
||||
from queue import Queue
|
||||
import webrtcvad
|
||||
import os
|
||||
import threading
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
import pygame
|
||||
import edge_tts
|
||||
import asyncio
|
||||
from time import sleep
|
||||
|
||||
# 参数设置
|
||||
AUDIO_RATE = 16000 # 音频采样率
|
||||
AUDIO_CHANNELS = 1 # 单声道
|
||||
CHUNK = 1024 # 音频块大小
|
||||
VAD_MODE = 3 # VAD 模式 (0-3, 数字越大越敏感)
|
||||
OUTPUT_DIR = "./output" # 输出目录
|
||||
NO_SPEECH_THRESHOLD = 1 # 无效语音阈值,单位:秒
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# 队列用于音频和视频同步缓存
|
||||
audio_queue = Queue()
|
||||
video_queue = Queue()
|
||||
|
||||
# 全局变量
|
||||
last_active_time = time.time()
|
||||
recording_active = True
|
||||
segments_to_save = []
|
||||
saved_intervals = []
|
||||
last_vad_end_time = 0 # 上次保存的 VAD 有效段结束时间
|
||||
|
||||
# 初始化 WebRTC VAD
|
||||
vad = webrtcvad.Vad()
|
||||
vad.set_mode(VAD_MODE)
|
||||
|
||||
# 音频录制线程
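# Buffers ~0.5 s of PCM at a time, runs WebRTC VAD on each buffer, and once no speech has been
# detected for NO_SPEECH_THRESHOLD seconds flushes the collected voiced segments via save_audio_video().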
|
||||
def audio_recorder():
|
||||
global audio_queue, recording_active, last_active_time, segments_to_save, last_vad_end_time
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16,
|
||||
channels=AUDIO_CHANNELS,
|
||||
rate=AUDIO_RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK)
|
||||
|
||||
audio_buffer = []
|
||||
print("音频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
data = stream.read(CHUNK)
|
||||
audio_buffer.append(data)
|
||||
|
||||
# 每 0.5 秒检测一次 VAD
|
||||
if len(audio_buffer) * CHUNK / AUDIO_RATE >= 0.5:
|
||||
# 拼接音频数据并检测 VAD
|
||||
raw_audio = b''.join(audio_buffer)
|
||||
vad_result = check_vad_activity(raw_audio)
|
||||
|
||||
if vad_result:
|
||||
print("检测到语音活动")
|
||||
last_active_time = time.time()
|
||||
segments_to_save.append((raw_audio, time.time()))
|
||||
else:
|
||||
print("静音中...")
|
||||
|
||||
audio_buffer = [] # 清空缓冲区
|
||||
|
||||
# 检查无效语音时间
|
||||
if time.time() - last_active_time > NO_SPEECH_THRESHOLD:
|
||||
# 检查是否需要保存
|
||||
if segments_to_save and segments_to_save[-1][1] > last_vad_end_time:
|
||||
save_audio_video()
|
||||
last_active_time = time.time()
|
||||
else:
|
||||
pass
|
||||
# print("无新增语音段,跳过保存")
|
||||
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
# 视频录制线程
|
||||
def video_recorder():
|
||||
global video_queue, recording_active
|
||||
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
print("视频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
video_queue.put((frame, time.time()))
|
||||
|
||||
# 实时显示摄像头画面
|
||||
cv2.imshow("Real Camera", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 键退出
|
||||
break
|
||||
else:
|
||||
print("无法获取摄像头画面")
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 检测 VAD 活动
|
||||
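# A buffer counts as speech when more than rate (40%) of its fixed-size frames are judged voiced by webrtcvad.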
def check_vad_activity(audio_data):
|
||||
# 将音频数据分块检测
|
||||
num, rate = 0, 0.4
|
||||
step = int(AUDIO_RATE * 0.02) # 320 bytes per slice = 160 int16 samples = 10 ms of audio (a frame length webrtcvad accepts)
|
||||
flag_rate = round(rate * len(audio_data) // step)
|
||||
|
||||
for i in range(0, len(audio_data), step):
|
||||
chunk = audio_data[i:i + step]
|
||||
if len(chunk) == step:
|
||||
if vad.is_speech(chunk, sample_rate=AUDIO_RATE):
|
||||
num += 1
|
||||
|
||||
if num > flag_rate:
|
||||
return True
|
||||
return False
|
||||
|
||||
# 保存音频和视频
|
||||
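# Writes the buffered voiced audio to audio_0.wav, keeps only the video frames whose timestamps fall
# inside the same interval for video_0.avi, then runs Inference() on the freshly saved pair.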
def save_audio_video():
|
||||
global segments_to_save, video_queue, last_vad_end_time, saved_intervals
|
||||
|
||||
if not segments_to_save:
|
||||
return
|
||||
|
||||
# 获取有效段的时间范围
|
||||
start_time = segments_to_save[0][1]
|
||||
end_time = segments_to_save[-1][1]
|
||||
|
||||
# 检查是否与之前的片段重叠
|
||||
if saved_intervals and saved_intervals[-1][1] >= start_time:
|
||||
print("当前片段与之前片段重叠,跳过保存")
|
||||
segments_to_save.clear()
|
||||
return
|
||||
|
||||
# 保存音频
|
||||
audio_frames = [seg[0] for seg in segments_to_save]
|
||||
audio_output_path = f"{OUTPUT_DIR}/audio_0.wav"
|
||||
wf = wave.open(audio_output_path, 'wb')
|
||||
wf.setnchannels(AUDIO_CHANNELS)
|
||||
wf.setsampwidth(2) # 16-bit PCM
|
||||
wf.setframerate(AUDIO_RATE)
|
||||
wf.writeframes(b''.join(audio_frames))
|
||||
wf.close()
|
||||
print(f"音频保存至 {audio_output_path}")
|
||||
|
||||
# 保存视频
|
||||
video_frames = []
|
||||
while not video_queue.empty():
|
||||
frame, timestamp = video_queue.get()
|
||||
if start_time <= timestamp <= end_time:
|
||||
video_frames.append(frame)
|
||||
|
||||
if video_frames:
|
||||
video_output_path = f"{OUTPUT_DIR}/video_0.avi"
|
||||
out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*'XVID'), 20.0, (640, 480))
|
||||
for frame in video_frames:
|
||||
out.write(frame)
|
||||
out.release()
|
||||
print(f"视频保存至 {video_output_path}")
|
||||
Inference()
|
||||
else:
|
||||
pass
|
||||
# print("无可保存的视频帧")
|
||||
|
||||
# 记录保存的区间
|
||||
saved_intervals.append((start_time, end_time))
|
||||
|
||||
# 清空缓冲区
|
||||
segments_to_save.clear()
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
async def amain(TEXT, VOICE, OUTPUT_FILE) -> None:
|
||||
"""Main function"""
|
||||
communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
await communicate.save(OUTPUT_FILE)
|
||||
|
||||
# -------------- Load QWen2-VL Model ------------
|
||||
# default: Load the model on the available device(s)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
|
||||
)
|
||||
# ------- 设置分辨率,降低显存占用 -------
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 512*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
# --------------------------------------
|
||||
|
||||
# -------- SenceVoice 语音识别 --模型加载-----
|
||||
model_dir = r"E:\2_PYTHON\Project\GPT\QWen\pretrained_models\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
folder_path = "./Test_QWen2_VL/"
|
||||
|
||||
def Inference(TEMP_VIDEO_FILE=f"{OUTPUT_DIR}/video_0.avi", TEMP_AUDIO_FILE=f"{OUTPUT_DIR}/audio_0.wav"):
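# Pipeline: middle frame of the saved clip -> SenseVoice ASR on the wav -> Qwen2-VL on (image, transcript) -> edge-tts speaks the reply.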
|
||||
file_path = os.path.join(folder_path, "captured_image.jpg") # 设置保存路径
|
||||
cap = cv2.VideoCapture(TEMP_VIDEO_FILE)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
frame_index = int(total_frames // 2)
|
||||
# 设置视频帧位置
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print(f"无法读取帧索引 {frame_index}")
|
||||
else:
|
||||
# 保存帧
|
||||
cv2.imwrite(file_path, frame)
|
||||
# cv2.imshow(f"Frame {frame_index}", frame)
|
||||
|
||||
# -------- SenceVoice 推理 ---------
|
||||
input_file = TEMP_AUDIO_FILE
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
prompt = res[0]['text'].split(">")[-1]
|
||||
# ---------SenceVoice --end----------
|
||||
|
||||
# -------- QWen2-VL 模型推理 ---------
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": f"{file_path}",
|
||||
},
|
||||
{"type": "text", "text": f"{prompt}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Preparation for inference
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
|
||||
# 输入文本
|
||||
text = output_text[0]
|
||||
# asyncio.run(amain(text, "zh-CN-YunxiaNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
asyncio.run(amain(text, "zh-CN-XiaoyiNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# asyncio.run(amain(text, "zh-CN-YunjianNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# asyncio.run(amain(text, "zh-CN-shaanxi-XiaoniNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# 主函数
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# 启动音视频录制线程
|
||||
audio_thread = threading.Thread(target=audio_recorder)
|
||||
video_thread = threading.Thread(target=video_recorder)
|
||||
audio_thread.start()
|
||||
video_thread.start()
|
||||
|
||||
print("按 Ctrl+C 停止录制")
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("录制停止中...")
|
||||
recording_active = False
|
||||
audio_thread.join()
|
||||
video_thread.join()
|
||||
print("录制已停止")
|
316
7.6_realTime_debug.py
Normal file
@ -0,0 +1,316 @@
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
import numpy as np
|
||||
import time
|
||||
from queue import Queue
|
||||
import webrtcvad
|
||||
import os
|
||||
import threading
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
import pygame
|
||||
import edge_tts
|
||||
import asyncio
|
||||
from time import sleep
|
||||
|
||||
# 参数设置
|
||||
AUDIO_RATE = 16000 # 音频采样率
|
||||
AUDIO_CHANNELS = 1 # 单声道
|
||||
CHUNK = 1024 # 音频块大小
|
||||
VAD_MODE = 3 # VAD 模式 (0-3, 数字越大越敏感)
|
||||
OUTPUT_DIR = "./output" # 输出目录
|
||||
NO_SPEECH_THRESHOLD = 1 # 无效语音阈值,单位:秒
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# 队列用于音频和视频同步缓存
|
||||
audio_queue = Queue()
|
||||
video_queue = Queue()
|
||||
|
||||
# 全局变量
|
||||
last_active_time = time.time()
|
||||
recording_active = True
|
||||
segments_to_save = []
|
||||
saved_intervals = []
|
||||
last_vad_end_time = 0 # 上次保存的 VAD 有效段结束时间
|
||||
|
||||
# 初始化 WebRTC VAD
|
||||
vad = webrtcvad.Vad()
|
||||
vad.set_mode(VAD_MODE)
|
||||
|
||||
# 音频录制线程
|
||||
def audio_recorder():
|
||||
global audio_queue, recording_active, last_active_time, segments_to_save, last_vad_end_time
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16,
|
||||
channels=AUDIO_CHANNELS,
|
||||
rate=AUDIO_RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK)
|
||||
|
||||
audio_buffer = []
|
||||
print("音频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
data = stream.read(CHUNK)
|
||||
audio_buffer.append(data)
|
||||
|
||||
# 每 0.5 秒检测一次 VAD
|
||||
if len(audio_buffer) * CHUNK / AUDIO_RATE >= 0.5:
|
||||
# 拼接音频数据并检测 VAD
|
||||
raw_audio = b''.join(audio_buffer)
|
||||
vad_result = check_vad_activity(raw_audio)
|
||||
|
||||
if vad_result:
|
||||
print("检测到语音活动")
|
||||
last_active_time = time.time()
|
||||
segments_to_save.append((raw_audio, time.time()))
|
||||
else:
|
||||
print("静音中...")
|
||||
|
||||
audio_buffer = [] # 清空缓冲区
|
||||
|
||||
# 检查无效语音时间
|
||||
if time.time() - last_active_time > NO_SPEECH_THRESHOLD:
|
||||
# 检查是否需要保存
|
||||
if segments_to_save and segments_to_save[-1][1] > last_vad_end_time:
|
||||
save_audio_video()
|
||||
last_active_time = time.time()
|
||||
else:
|
||||
pass
|
||||
# print("无新增语音段,跳过保存")
|
||||
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
# 视频录制线程
|
||||
def video_recorder():
|
||||
global video_queue, recording_active
|
||||
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
print("视频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
video_queue.put((frame, time.time()))
|
||||
|
||||
# 实时显示摄像头画面
|
||||
cv2.imshow("Real Camera", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 键退出
|
||||
break
|
||||
else:
|
||||
print("无法获取摄像头画面")
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 检测 VAD 活动
|
||||
def check_vad_activity(audio_data):
|
||||
# 将音频数据分块检测
|
||||
num = 0
|
||||
step = int(AUDIO_RATE * 0.02) # 320 bytes per slice = 160 int16 samples = 10 ms of audio (a frame length webrtcvad accepts)
|
||||
flag_rate = round(0.8 * len(audio_data) // step)
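# Debug variant: requires 80% voiced frames per buffer (vs. 40% in 7.5) and leaves the Inference() call disabled, so only the capture/VAD path is exercised.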
|
||||
|
||||
for i in range(0, len(audio_data), step):
|
||||
chunk = audio_data[i:i + step]
|
||||
if len(chunk) == step:
|
||||
if vad.is_speech(chunk, sample_rate=AUDIO_RATE):
|
||||
num += 1
|
||||
|
||||
if num > flag_rate:
|
||||
return True
|
||||
return False
|
||||
|
||||
# 保存音频和视频
|
||||
def save_audio_video():
|
||||
global segments_to_save, video_queue, last_vad_end_time, saved_intervals
|
||||
|
||||
if not segments_to_save:
|
||||
return
|
||||
|
||||
# 获取有效段的时间范围
|
||||
start_time = segments_to_save[0][1]
|
||||
end_time = segments_to_save[-1][1]
|
||||
|
||||
# 检查是否与之前的片段重叠
|
||||
if saved_intervals and saved_intervals[-1][1] >= start_time:
|
||||
print("当前片段与之前片段重叠,跳过保存")
|
||||
segments_to_save.clear()
|
||||
return
|
||||
|
||||
# 保存音频
|
||||
audio_frames = [seg[0] for seg in segments_to_save]
|
||||
audio_output_path = f"{OUTPUT_DIR}/audio_0.wav"
|
||||
wf = wave.open(audio_output_path, 'wb')
|
||||
wf.setnchannels(AUDIO_CHANNELS)
|
||||
wf.setsampwidth(2) # 16-bit PCM
|
||||
wf.setframerate(AUDIO_RATE)
|
||||
wf.writeframes(b''.join(audio_frames))
|
||||
wf.close()
|
||||
print(f"音频保存至 {audio_output_path}")
|
||||
|
||||
# 保存视频
|
||||
video_frames = []
|
||||
while not video_queue.empty():
|
||||
frame, timestamp = video_queue.get()
|
||||
if start_time <= timestamp <= end_time:
|
||||
video_frames.append(frame)
|
||||
|
||||
if video_frames:
|
||||
video_output_path = f"{OUTPUT_DIR}/video_0.avi"
|
||||
out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*'XVID'), 20.0, (640, 480))
|
||||
for frame in video_frames:
|
||||
out.write(frame)
|
||||
out.release()
|
||||
print(f"视频保存至 {video_output_path}")
|
||||
# Inference()
|
||||
else:
|
||||
pass
|
||||
# print("无可保存的视频帧")
|
||||
|
||||
# 记录保存的区间
|
||||
saved_intervals.append((start_time, end_time))
|
||||
|
||||
# 清空缓冲区
|
||||
segments_to_save.clear()
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
async def amain(TEXT, VOICE, OUTPUT_FILE) -> None:
|
||||
"""Main function"""
|
||||
communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
await communicate.save(OUTPUT_FILE)
|
||||
|
||||
# # -------------- Load QWen2-VL Model ------------
|
||||
# # default: Load the model on the available device(s)
|
||||
# model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
# "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
|
||||
# )
|
||||
# # ------- 设置分辨率,降低显存占用 -------
|
||||
# min_pixels = 256*28*28
|
||||
# max_pixels = 512*28*28
|
||||
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
# # --------------------------------------
|
||||
|
||||
# # -------- SenceVoice 语音识别 --模型加载-----
|
||||
# model_dir = r"E:\2_PYTHON\Project\GPT\QWen\pretrained_models\SenseVoiceSmall"
|
||||
# model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
# folder_path = "./Test_QWen2_VL/"
|
||||
|
||||
# def Inference(TEMP_VIDEO_FILE=f"{OUTPUT_DIR}/video_0.avi", TEMP_AUDIO_FILE=f"{OUTPUT_DIR}/audio_0.wav"):
|
||||
# file_path = os.path.join(folder_path, "captured_image.jpg") # 设置保存路径
|
||||
# cap = cv2.VideoCapture(TEMP_VIDEO_FILE)
|
||||
# total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
# frame_index = int(total_frames // 2)
|
||||
# # 设置视频帧位置
|
||||
# cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
|
||||
# ret, frame = cap.read()
|
||||
# if not ret:
|
||||
# print(f"无法读取帧索引 {frame_index}")
|
||||
# else:
|
||||
# # 显示帧
|
||||
# cv2.imwrite(file_path, frame)
|
||||
# # cv2.imshow(f"Frame {frame_index}", frame)
|
||||
|
||||
# # -------- SenceVoice 推理 ---------
|
||||
# input_file = (TEMP_AUDIO_FILE)
|
||||
# res = model_senceVoice.generate(
|
||||
# input=input_file,
|
||||
# cache={},
|
||||
# language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
# use_itn=False,
|
||||
# )
|
||||
# prompt = res[0]['text'].split(">")[-1]
|
||||
# # ---------SenceVoice --end----------
|
||||
|
||||
# # -------- QWen2-VL 模型推理 ---------
|
||||
# messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {
|
||||
# "type": "image",
|
||||
# "image": f"{file_path}",
|
||||
# },
|
||||
# {"type": "text", "text": f"{prompt}"},
|
||||
# ],
|
||||
# }
|
||||
# ]
|
||||
|
||||
# # Preparation for inference
|
||||
# text = processor.apply_chat_template(
|
||||
# messages, tokenize=False, add_generation_prompt=True
|
||||
# )
|
||||
# image_inputs, video_inputs = process_vision_info(messages)
|
||||
# inputs = processor(
|
||||
# text=[text],
|
||||
# images=image_inputs,
|
||||
# videos=video_inputs,
|
||||
# padding=True,
|
||||
# return_tensors="pt",
|
||||
# )
|
||||
# inputs = inputs.to("cuda")
|
||||
|
||||
# # Inference: Generation of the output
|
||||
# generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
# generated_ids_trimmed = [
|
||||
# out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
# ]
|
||||
# output_text = processor.batch_decode(
|
||||
# generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
# )
|
||||
# print(output_text)
|
||||
|
||||
# # 输入文本
|
||||
# text = output_text[0]
|
||||
# # asyncio.run(amain(text, "zh-CN-YunxiaNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# # play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# asyncio.run(amain(text, "zh-CN-XiaoyiNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# # asyncio.run(amain(text, "zh-CN-YunjianNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# # play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# # asyncio.run(amain(text, "zh-CN-shaanxi-XiaoniNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# # play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# 主函数
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# 启动音视频录制线程
|
||||
audio_thread = threading.Thread(target=audio_recorder)
|
||||
video_thread = threading.Thread(target=video_recorder)
|
||||
audio_thread.start()
|
||||
video_thread.start()
|
||||
|
||||
print("按 Ctrl+C 停止录制")
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("录制停止中...")
|
||||
recording_active = False
|
||||
audio_thread.join()
|
||||
video_thread.join()
|
||||
print("录制已停止")
|
319
7.7_realTime_AV_multiImage.py
Normal file
@ -0,0 +1,319 @@
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
import numpy as np
|
||||
import time
|
||||
from queue import Queue
|
||||
import webrtcvad
|
||||
import os
|
||||
import threading
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
import pygame
|
||||
import edge_tts
|
||||
import asyncio
|
||||
from time import sleep
|
||||
|
||||
# 参数设置
|
||||
AUDIO_RATE = 16000 # 音频采样率
|
||||
AUDIO_CHANNELS = 1 # 单声道
|
||||
CHUNK = 1024 # 音频块大小
|
||||
VAD_MODE = 3 # VAD 模式 (0-3, 数字越大越敏感)
|
||||
OUTPUT_DIR = "./output" # 输出目录
|
||||
NO_SPEECH_THRESHOLD = 1 # 无效语音阈值,单位:秒
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# 队列用于音频和视频同步缓存
|
||||
audio_queue = Queue()
|
||||
video_queue = Queue()
|
||||
|
||||
# 全局变量
|
||||
last_active_time = time.time()
|
||||
recording_active = True
|
||||
segments_to_save = []
|
||||
saved_intervals = []
|
||||
last_vad_end_time = 0 # 上次保存的 VAD 有效段结束时间
|
||||
|
||||
# 初始化 WebRTC VAD
|
||||
vad = webrtcvad.Vad()
|
||||
vad.set_mode(VAD_MODE)
|
||||
|
||||
# 音频录制线程
|
||||
def audio_recorder():
|
||||
global audio_queue, recording_active, last_active_time, segments_to_save, last_vad_end_time
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16,
|
||||
channels=AUDIO_CHANNELS,
|
||||
rate=AUDIO_RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK)
|
||||
|
||||
audio_buffer = []
|
||||
print("音频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
data = stream.read(CHUNK)
|
||||
audio_buffer.append(data)
|
||||
|
||||
# 每 0.5 秒检测一次 VAD
|
||||
if len(audio_buffer) * CHUNK / AUDIO_RATE >= 0.5:
|
||||
# 拼接音频数据并检测 VAD
|
||||
raw_audio = b''.join(audio_buffer)
|
||||
vad_result = check_vad_activity(raw_audio)
|
||||
|
||||
if vad_result:
|
||||
print("检测到语音活动")
|
||||
last_active_time = time.time()
|
||||
segments_to_save.append((raw_audio, time.time()))
|
||||
else:
|
||||
print("静音中...")
|
||||
|
||||
audio_buffer = [] # 清空缓冲区
|
||||
|
||||
# 检查无效语音时间
|
||||
if time.time() - last_active_time > NO_SPEECH_THRESHOLD:
|
||||
# 检查是否需要保存
|
||||
if segments_to_save and segments_to_save[-1][1] > last_vad_end_time:
|
||||
save_audio_video()
|
||||
last_active_time = time.time()
|
||||
else:
|
||||
pass
|
||||
# print("无新增语音段,跳过保存")
|
||||
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
# 视频录制线程
|
||||
def video_recorder():
|
||||
global video_queue, recording_active
|
||||
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
print("视频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
video_queue.put((frame, time.time()))
|
||||
|
||||
# 实时显示摄像头画面
|
||||
cv2.imshow("Real Camera", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 键退出
|
||||
break
|
||||
else:
|
||||
print("无法获取摄像头画面")
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 检测 VAD 活动
|
||||
def check_vad_activity(audio_data):
|
||||
# 将音频数据分块检测
|
||||
num, rate = 0, 0.4
|
||||
step = int(AUDIO_RATE * 0.02) # 320 bytes per slice = 160 int16 samples = 10 ms of audio (a frame length webrtcvad accepts)
|
||||
flag_rate = round(rate * len(audio_data) // step)
|
||||
|
||||
for i in range(0, len(audio_data), step):
|
||||
chunk = audio_data[i:i + step]
|
||||
if len(chunk) == step:
|
||||
if vad.is_speech(chunk, sample_rate=AUDIO_RATE):
|
||||
num += 1
|
||||
|
||||
if num > flag_rate:
|
||||
return True
|
||||
return False
|
||||
|
||||
# 保存音频和视频
|
||||
def save_audio_video():
|
||||
global segments_to_save, video_queue, last_vad_end_time, saved_intervals
|
||||
|
||||
if not segments_to_save:
|
||||
return
|
||||
|
||||
# 获取有效段的时间范围
|
||||
start_time = segments_to_save[0][1]
|
||||
end_time = segments_to_save[-1][1]
|
||||
|
||||
# 检查是否与之前的片段重叠
|
||||
if saved_intervals and saved_intervals[-1][1] >= start_time:
|
||||
print("当前片段与之前片段重叠,跳过保存")
|
||||
segments_to_save.clear()
|
||||
return
|
||||
|
||||
# 保存音频
|
||||
audio_frames = [seg[0] for seg in segments_to_save]
|
||||
audio_output_path = f"{OUTPUT_DIR}/audio_0.wav"
|
||||
wf = wave.open(audio_output_path, 'wb')
|
||||
wf.setnchannels(AUDIO_CHANNELS)
|
||||
wf.setsampwidth(2) # 16-bit PCM
|
||||
wf.setframerate(AUDIO_RATE)
|
||||
wf.writeframes(b''.join(audio_frames))
|
||||
wf.close()
|
||||
print(f"音频保存至 {audio_output_path}")
|
||||
|
||||
# 保存视频
|
||||
video_frames = []
|
||||
while not video_queue.empty():
|
||||
frame, timestamp = video_queue.get()
|
||||
if start_time <= timestamp <= end_time:
|
||||
video_frames.append(frame)
|
||||
|
||||
if video_frames:
|
||||
video_output_path = f"{OUTPUT_DIR}/video_0.avi"
|
||||
out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*'XVID'), 20.0, (640, 480))
|
||||
for frame in video_frames:
|
||||
out.write(frame)
|
||||
out.release()
|
||||
print(f"视频保存至 {video_output_path}")
|
||||
Inference()
|
||||
else:
|
||||
pass
|
||||
# print("无可保存的视频帧")
|
||||
|
||||
# 记录保存的区间
|
||||
saved_intervals.append((start_time, end_time))
|
||||
|
||||
# 清空缓冲区
|
||||
segments_to_save.clear()
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
async def amain(TEXT, VOICE, OUTPUT_FILE) -> None:
|
||||
"""Main function"""
|
||||
communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
await communicate.save(OUTPUT_FILE)
|
||||
|
||||
# -------------- Load QWen2-VL Model ------------
|
||||
# default: Load the model on the available device(s)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
|
||||
)
|
||||
# ------- 设置分辨率,降低显存占用 -------
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 512*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
# --------------------------------------
|
||||
|
||||
# -------- SenceVoice 语音识别 --模型加载-----
|
||||
model_dir = r"E:\2_PYTHON\Project\GPT\QWen\pretrained_models\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
folder_path = "./Test_QWen2_VL/"
|
||||
|
||||
def Inference(TEMP_VIDEO_FILE=f"{OUTPUT_DIR}/video_0.avi", TEMP_AUDIO_FILE=f"{OUTPUT_DIR}/audio_0.wav"):
|
||||
|
||||
cap = cv2.VideoCapture(TEMP_VIDEO_FILE)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
S_index = [0.2, 0.4, 0.6, 0.8]
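# Sample frames at 20/40/60/80% of the clip; only the first two are actually passed to Qwen2-VL below (the other two are commented out).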
|
||||
frame_index = [int(total_frames * i) for i in S_index]
|
||||
# 设置视频帧位置
|
||||
for idx in frame_index:
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print(f"无法读取帧索引 {idx}")
|
||||
else:
|
||||
# 保存帧
|
||||
file_path = os.path.join(folder_path, f"captured_image{idx}.jpg") # 设置保存路径
|
||||
cv2.imwrite(file_path, frame)
|
||||
|
||||
# -------- SenceVoice 推理 ---------
|
||||
input_file = TEMP_AUDIO_FILE
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
prompt = res[0]['text'].split(">")[-1]
|
||||
print("ASR OUT:", prompt)
|
||||
# ---------SenceVoice --end----------
|
||||
|
||||
# -------- QWen2-VL 模型推理 ---------
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": f'{os.path.join(folder_path, f"captured_image{frame_index[0]}.jpg")}'},
|
||||
{"type": "image", "image": f'{os.path.join(folder_path, f"captured_image{frame_index[1]}.jpg")}'},
|
||||
# {"type": "image", "image": f'{os.path.join(folder_path, f"captured_image{frame_index[2]}.jpg")}'},
|
||||
# {"type": "image", "image": f'{os.path.join(folder_path, f"captured_image{frame_index[3]}.jpg")}'},
|
||||
{"type": "text", "text": f"{prompt}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Preparation for inference
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
|
||||
# 输入文本
|
||||
text = output_text[0]
|
||||
# asyncio.run(amain(text, "zh-CN-YunxiaNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
asyncio.run(amain(text, "zh-CN-XiaoyiNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# asyncio.run(amain(text, "zh-CN-YunjianNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# asyncio.run(amain(text, "zh-CN-shaanxi-XiaoniNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# 主函数
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# 启动音视频录制线程
|
||||
audio_thread = threading.Thread(target=audio_recorder)
|
||||
video_thread = threading.Thread(target=video_recorder)
|
||||
audio_thread.start()
|
||||
video_thread.start()
|
||||
|
||||
print("按 Ctrl+C 停止录制")
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("录制停止中...")
|
||||
recording_active = False
|
||||
audio_thread.join()
|
||||
video_thread.join()
|
||||
print("录制已停止")
|
342
7.8_realTime_AV_video.py
Normal file
@ -0,0 +1,342 @@
|
||||
import cv2
|
||||
import pyaudio
|
||||
import wave
|
||||
import threading
|
||||
import numpy as np
|
||||
import time
|
||||
from queue import Queue
|
||||
import webrtcvad
|
||||
import os
|
||||
import threading
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
from funasr import AutoModel
|
||||
import pygame
|
||||
import edge_tts
|
||||
import asyncio
|
||||
from time import sleep
|
||||
|
||||
# 参数设置
|
||||
AUDIO_RATE = 16000 # 音频采样率
|
||||
AUDIO_CHANNELS = 1 # 单声道
|
||||
CHUNK = 1024 # 音频块大小
|
||||
VAD_MODE = 3 # VAD 模式 (0-3, 数字越大越敏感)
|
||||
OUTPUT_DIR = "./output" # 输出目录
|
||||
NO_SPEECH_THRESHOLD = 1 # 无效语音阈值,单位:秒
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# 队列用于音频和视频同步缓存
|
||||
audio_queue = Queue()
|
||||
video_queue = Queue()
|
||||
|
||||
# 全局变量
|
||||
last_active_time = time.time()
|
||||
recording_active = True
|
||||
segments_to_save = []
|
||||
saved_intervals = []
|
||||
last_vad_end_time = 0 # 上次保存的 VAD 有效段结束时间
|
||||
|
||||
# 初始化 WebRTC VAD
|
||||
vad = webrtcvad.Vad()
|
||||
vad.set_mode(VAD_MODE)
|
||||
|
||||
# 音频录制线程
|
||||
def audio_recorder():
|
||||
global audio_queue, recording_active, last_active_time, segments_to_save, last_vad_end_time
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
stream = p.open(format=pyaudio.paInt16,
|
||||
channels=AUDIO_CHANNELS,
|
||||
rate=AUDIO_RATE,
|
||||
input=True,
|
||||
frames_per_buffer=CHUNK)
|
||||
|
||||
audio_buffer = []
|
||||
print("音频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
data = stream.read(CHUNK)
|
||||
audio_buffer.append(data)
|
||||
|
||||
# 每 0.5 秒检测一次 VAD
|
||||
if len(audio_buffer) * CHUNK / AUDIO_RATE >= 0.5:
|
||||
# 拼接音频数据并检测 VAD
|
||||
raw_audio = b''.join(audio_buffer)
|
||||
vad_result = check_vad_activity(raw_audio)
|
||||
|
||||
if vad_result:
|
||||
print("检测到语音活动")
|
||||
last_active_time = time.time()
|
||||
segments_to_save.append((raw_audio, time.time()))
|
||||
else:
|
||||
print("静音中...")
|
||||
|
||||
audio_buffer = [] # 清空缓冲区
|
||||
|
||||
# 检查无效语音时间
|
||||
if time.time() - last_active_time > NO_SPEECH_THRESHOLD:
|
||||
# 检查是否需要保存
|
||||
if segments_to_save and segments_to_save[-1][1] > last_vad_end_time:
|
||||
save_audio_video()
|
||||
last_active_time = time.time()
|
||||
else:
|
||||
pass
|
||||
# print("无新增语音段,跳过保存")
|
||||
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
|
||||
# 视频录制线程
|
||||
def video_recorder():
|
||||
global video_queue, recording_active
|
||||
|
||||
cap = cv2.VideoCapture(0) # 使用默认摄像头
|
||||
print("视频录制已开始")
|
||||
|
||||
while recording_active:
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
video_queue.put((frame, time.time()))
|
||||
|
||||
# 实时显示摄像头画面
|
||||
cv2.imshow("Real Camera", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord('q'): # 按 Q 键退出
|
||||
break
|
||||
else:
|
||||
print("无法获取摄像头画面")
|
||||
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
# 检测 VAD 活动
|
||||
def check_vad_activity(audio_data):
|
||||
# 将音频数据分块检测
|
||||
num, rate = 0, 0.4
|
||||
step = int(AUDIO_RATE * 0.02) # 320 bytes per slice = 160 int16 samples = 10 ms of audio (a frame length webrtcvad accepts)
|
||||
flag_rate = round(rate * len(audio_data) // step)
|
||||
|
||||
for i in range(0, len(audio_data), step):
|
||||
chunk = audio_data[i:i + step]
|
||||
if len(chunk) == step:
|
||||
if vad.is_speech(chunk, sample_rate=AUDIO_RATE):
|
||||
num += 1
|
||||
|
||||
if num > flag_rate:
|
||||
return True
|
||||
return False
|
||||
|
||||
# 保存音频和视频
|
||||
def save_audio_video():
|
||||
global segments_to_save, video_queue, last_vad_end_time, saved_intervals
|
||||
|
||||
if not segments_to_save:
|
||||
return
|
||||
|
||||
# 获取有效段的时间范围
|
||||
start_time = segments_to_save[0][1]
|
||||
end_time = segments_to_save[-1][1]
|
||||
|
||||
# 检查是否与之前的片段重叠
|
||||
if saved_intervals and saved_intervals[-1][1] >= start_time:
|
||||
print("当前片段与之前片段重叠,跳过保存")
|
||||
segments_to_save.clear()
|
||||
return
|
||||
|
||||
# 保存音频
|
||||
audio_frames = [seg[0] for seg in segments_to_save]
|
||||
audio_output_path = f"{OUTPUT_DIR}/audio_0.wav"
|
||||
wf = wave.open(audio_output_path, 'wb')
|
||||
wf.setnchannels(AUDIO_CHANNELS)
|
||||
wf.setsampwidth(2) # 16-bit PCM
|
||||
wf.setframerate(AUDIO_RATE)
|
||||
wf.writeframes(b''.join(audio_frames))
|
||||
wf.close()
|
||||
print(f"音频保存至 {audio_output_path}")
|
||||
|
||||
# 保存视频
|
||||
video_frames = []
|
||||
while not video_queue.empty():
|
||||
frame, timestamp = video_queue.get()
|
||||
if start_time <= timestamp <= end_time:
|
||||
video_frames.append(frame)
|
||||
|
||||
if video_frames:
|
||||
video_output_path = f"{OUTPUT_DIR}/video_0.avi"
|
||||
out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*'XVID'), 20.0, (640, 480))
|
||||
for frame in video_frames:
|
||||
out.write(frame)
|
||||
out.release()
|
||||
print(f"视频保存至 {video_output_path}")
|
||||
Inference()
|
||||
else:
|
||||
pass
|
||||
# print("无可保存的视频帧")
|
||||
|
||||
# 记录保存的区间
|
||||
saved_intervals.append((start_time, end_time))
|
||||
|
||||
# 清空缓冲区
|
||||
segments_to_save.clear()
|
||||
|
||||
# --- 播放音频 -
|
||||
def play_audio(file_path):
|
||||
try:
|
||||
pygame.mixer.init()
|
||||
pygame.mixer.music.load(file_path)
|
||||
pygame.mixer.music.play()
|
||||
while pygame.mixer.music.get_busy():
|
||||
time.sleep(1) # 等待音频播放结束
|
||||
print("播放完成!")
|
||||
except Exception as e:
|
||||
print(f"播放失败: {e}")
|
||||
finally:
|
||||
pygame.mixer.quit()
|
||||
|
||||
async def amain(TEXT, VOICE, OUTPUT_FILE) -> None:
|
||||
"""Main function"""
|
||||
communicate = edge_tts.Communicate(TEXT, VOICE)
|
||||
await communicate.save(OUTPUT_FILE)
|
||||
|
||||
# -------------- Load QWen2-VL Model ------------
|
||||
# default: Load the model on the available device(s)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
|
||||
)
|
||||
# ------- 设置分辨率,降低显存占用 -------
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 512*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
# --------------------------------------
|
||||
|
||||
# -------- SenceVoice 语音识别 --模型加载-----
|
||||
model_dir = r"D:\AI\download\SenseVoiceSmall"
|
||||
model_senceVoice = AutoModel( model=model_dir, trust_remote_code=True, )
|
||||
folder_path = "./Test_QWen2_VL/"
|
||||
|
||||
def Inference(TEMP_VIDEO_FILE=f"{OUTPUT_DIR}/video_0.avi", TEMP_AUDIO_FILE=f"{OUTPUT_DIR}/audio_0.wav"):
|
||||
|
||||
cap = cv2.VideoCapture(TEMP_VIDEO_FILE)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
S_index = [0.2, 0.4, 0.6, 0.8]
|
||||
frame_index = [int(total_frames * i) for i in S_index]
|
||||
# 设置视频帧位置
|
||||
for idx in frame_index:
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
print(f"无法读取帧索引 {idx}")
|
||||
else:
|
||||
# 保存帧
|
||||
file_path = os.path.join(folder_path, f"captured_image{idx}.jpg") # 设置保存路径
|
||||
cv2.imwrite(file_path, frame)
|
||||
|
||||
# -------- SenceVoice 推理 ---------
|
||||
input_file = TEMP_AUDIO_FILE
|
||||
res = model_senceVoice.generate(
|
||||
input=input_file,
|
||||
cache={},
|
||||
language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech"
|
||||
use_itn=False,
|
||||
)
|
||||
prompt = res[0]['text'].split(">")[-1]
|
||||
print("ASR OUT:", prompt)
|
||||
# ---------SenceVoice --end----------
|
||||
|
||||
# -------- QWen2-VL 模型推理 ---------
|
||||
# Messages containing a images list as a video and a text query
|
||||
# messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {
|
||||
# "type": "video",
|
||||
# "video": [
|
||||
# f'{os.path.join(folder_path, f"captured_image{frame_index[0]}.jpg")}',
|
||||
# f'{os.path.join(folder_path, f"captured_image{frame_index[1]}.jpg")}',
|
||||
# f'{os.path.join(folder_path, f"captured_image{frame_index[2]}.jpg")}',
|
||||
# f'{os.path.join(folder_path, f"captured_image{frame_index[3]}.jpg")}',
|
||||
# ],
|
||||
# "fps": 1.0,
|
||||
# },
|
||||
# {"type": "text", "text": f"{prompt},同时描述这段视频"},
|
||||
# ],
|
||||
# }
|
||||
# ]
|
||||
|
||||
# Messages containing a video and a text query
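# The whole recorded clip is passed as one "video" content item; qwen_vl_utils samples it at fps frames per second and caps each frame at max_pixels before the processor builds the inputs.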
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "video",
|
||||
"video": f"{OUTPUT_DIR}/video_0.avi",
|
||||
"max_pixels": 360 * 420,
|
||||
"fps": 1.0,
|
||||
},
|
||||
{"type": "text", "text": f"{prompt}"},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Preparation for inference
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
||||
|
||||
# 输入文本
|
||||
text = output_text[0]
|
||||
# asyncio.run(amain(text, "zh-CN-YunxiaNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
asyncio.run(amain(text, "zh-CN-XiaoyiNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# asyncio.run(amain(text, "zh-CN-YunjianNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# asyncio.run(amain(text, "zh-CN-shaanxi-XiaoniNeural", os.path.join(folder_path,"sft_0.mp3")))
|
||||
# play_audio(f'{folder_path}/sft_0.mp3')
|
||||
|
||||
# 主函数
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# 启动音视频录制线程
|
||||
audio_thread = threading.Thread(target=audio_recorder)
|
||||
video_thread = threading.Thread(target=video_recorder)
|
||||
audio_thread.start()
|
||||
video_thread.start()
|
||||
|
||||
print("按 Ctrl+C 停止录制")
|
||||
while True:
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("录制停止中...")
|
||||
recording_active = False
|
||||
audio_thread.join()
|
||||
video_thread.join()
|
||||
print("录制已停止")
|
28
7_0_FunASR.py
Normal file
@ -0,0 +1,28 @@
|
||||
|
||||
import torchaudio
|
||||
from funasr import AutoModel
|
||||
from IPython.display import Audio
|
||||
|
||||
speaker1_wav = r'E:\2_PYTHON\Project\GPT\QWen\CosyVoice\output\audio_0.wav'
|
||||
waveform, sample_rate = torchaudio.load(speaker1_wav)
|
||||
Audio(waveform, rate=sample_rate, autoplay=True)
|
||||
|
||||
# VAD检测
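# fsmn-vad returns the detected speech segments as [start_ms, end_ms] pairs (in res[0]['value']).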
|
||||
from funasr import AutoModel
|
||||
model = AutoModel(model="fsmn-vad")
|
||||
res = model.generate(input=speaker1_wav)
|
||||
print(res)
|
||||
|
||||
|
||||
# # 多说话人语音识别
|
||||
# funasr_model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
# vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
# punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||
# spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
|
||||
# )
|
||||
# res = funasr_model.generate(input=f"multi_speaker.wav",
|
||||
# batch_size_s=300)
|
||||
# print(res[0]['text'])
|
||||
# res_srt = generate_srt(res[0]['sentence_info'])
|
||||
# print(res_srt)
|
||||
|
61
7_Inference_QWen2-VL.py
Normal file
@ -0,0 +1,61 @@
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
|
||||
# default: Load the model on the available device(s)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
|
||||
)
|
||||
|
||||
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
||||
# model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
# "Qwen/Qwen2-VL-2B-Instruct",
|
||||
# torch_dtype=torch.bfloat16,
|
||||
# attn_implementation="flash_attention_2",
|
||||
# device_map="auto",
|
||||
# )
|
||||
|
||||
# default processer
|
||||
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
|
||||
|
||||
# The default range for the number of visual tokens per image in the model is 4-16384. You can set min_pixels and max_pixels according to your needs, such as a token count range of 256-1280, to balance speed and memory usage.
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 1280*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
|
||||
},
|
||||
{"type": "text", "text": "Describe this image."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Preparation for inference
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
50
8_Inference_QWen2-VL_offline_AV.py
Normal file
@ -0,0 +1,50 @@
|
||||
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
import torch
|
||||
|
||||
# default: Load the model on the available device(s)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
"Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto"
|
||||
)
|
||||
|
||||
# The default range for the number of visual tokens per image in the model is 4-16384. You can set min_pixels and max_pixels according to your needs, such as a token count range of 256-1280, to balance speed and memory usage.
|
||||
min_pixels = 256*28*28
|
||||
max_pixels = 1280*28*28
|
||||
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
|
||||
},
|
||||
{"type": "text", "text": "Describe this image."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Preparation for inference
|
||||
text = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=128)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
print(output_text)
|
27
9.1_test_cam++.py
Normal file
@ -0,0 +1,27 @@
|
||||
# --- 声纹识别模型 CAM++测试 3D-Speaker数据训练 ---
|
||||
from modelscope.pipelines import pipeline
|
||||
|
||||
sv_pipeline = pipeline(
|
||||
task='speaker-verification',
|
||||
model='damo/speech_campplus_sv_zh-cn_16k-common',
|
||||
model_revision='v1.0.0'
|
||||
)
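# Each verification call below is expected to return a similarity score plus a same/different-speaker decision (exact result keys depend on the modelscope version).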
|
||||
|
||||
speaker1_a_wav = r'E:\2_PYTHON\Project\GPT\QWen\CosyVoice\SpeakerVerification_DIR\enroll_wav\enroll_0.wav'
|
||||
speaker1_b_wav = r'E:\2_PYTHON\Project\GPT\QWen\CosyVoice\SpeakerVerification_DIR\enroll_wav\enroll_1.wav'
|
||||
speaker2_a_wav = 'https://modelscope.cn/api/v1/models/damo/speech_campplus_sv_zh-cn_16k-common/repo?Revision=master&FilePath=examples/speaker2_a_cn_16k.wav'
|
||||
# 相同说话人语音
|
||||
result = sv_pipeline([speaker1_a_wav, speaker1_b_wav])
|
||||
print(result)
|
||||
# 不同说话人语音
|
||||
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav])
|
||||
print(result)
|
||||
# 可以自定义得分阈值来进行识别,阈值越高,判定为同一人的条件越严格
|
||||
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], thr=0.31)
|
||||
print(result)
|
||||
# 可以传入output_emb参数,输出结果中就会包含提取到的说话人embedding
|
||||
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], output_emb=True)
|
||||
print(result['embs'], result['outputs'])
|
||||
# 可以传入save_dir参数,提取到的说话人embedding会存储在save_dir目录中
|
||||
save_path = r"E:\2_PYTHON\Project\GPT\QWen\CosyVoice\SpeakerVerification_DIR\enroll_ivec"
|
||||
result = sv_pipeline([speaker1_a_wav, speaker2_a_wav], save_dir=save_path)
|
76
CODE_OF_CONDUCT.md
Normal file
@ -0,0 +1,76 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as
|
||||
contributors and maintainers pledge to making participation in our project and
|
||||
our community a harassment-free experience for everyone, regardless of age, body
|
||||
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||
level of experience, education, socio-economic status, nationality, personal
|
||||
appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to creating a positive environment
|
||||
include:
|
||||
|
||||
* Using welcoming and inclusive language
|
||||
* Being respectful of differing viewpoints and experiences
|
||||
* Gracefully accepting constructive criticism
|
||||
* Focusing on what is best for the community
|
||||
* Showing empathy towards other community members
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||
advances
|
||||
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or electronic
|
||||
address, without explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Our Responsibilities
|
||||
|
||||
Project maintainers are responsible for clarifying the standards of acceptable
|
||||
behavior and are expected to take appropriate and fair corrective action in
|
||||
response to any instances of unacceptable behavior.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or
|
||||
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||
permanently any contributor for other behaviors that they deem inappropriate,
|
||||
threatening, offensive, or harmful.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies both within project spaces and in public spaces
|
||||
when an individual is representing the project or its community. Examples of
|
||||
representing a project or community include using an official project e-mail
|
||||
address, posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event. Representation of a project may be
|
||||
further defined and clarified by project maintainers.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported by contacting the project team at mikelei@mobvoi.com. All
|
||||
complaints will be reviewed and investigated and will result in a response that
|
||||
is deemed necessary and appropriate to the circumstances. The project team is
|
||||
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||
Further details of specific enforcement policies may be posted separately.
|
||||
|
||||
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||
faith may face temporary or permanent repercussions as determined by other
|
||||
members of the project's leadership.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see
|
||||
https://www.contributor-covenant.org/faq
|
8
CosyVoice1.py
Normal file
@ -0,0 +1,8 @@
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav
import torchaudio

cosyvoice = CosyVoice(r'D:\AI\download\CosyVoice-300M', load_jit=True, load_onnx=False, fp16=True)
print(cosyvoice.list_avaliable_spks())
prompt_speech_16k = load_wav('example_audio.wav', 16000)
for i, j in enumerate(cosyvoice.inference_zero_shot('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '对,这就是我,万人敬仰的太乙真人', prompt_speech_16k, stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
31
CosyVoice2.py
Normal file
@ -0,0 +1,31 @@
import sys
sys.path.append('third_party/Matcha-TTS')
from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
from cosyvoice.utils.file_utils import load_wav
import torchaudio

cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=False, load_trt=False, fp16=False)

# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
# zero_shot usage
prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

# fine grained control; for the supported control tags, check cosyvoice/tokenizer/tokenizer.py#L248
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', prompt_speech_16k, stream=False)):
    torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

# instruct usage
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话', prompt_speech_16k, stream=False)):
    torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)

# bistream usage: a generator can be used as input, which is useful when a text LLM feeds the TTS
# NOTE you should still have some basic sentence-split logic because the llm cannot handle arbitrary sentence lengths
def text_generator():
    yield '收到好友从远方寄来的生日礼物,'
    yield '那份意外的惊喜与深深的祝福'
    yield '让我心中充满了甜蜜的快乐,'
    yield '笑容如花儿般绽放。'
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
    torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
16
FAQ.md
Normal file
@ -0,0 +1,16 @@
## ModuleNotFoundError: No module named 'matcha'

Matcha-TTS is a third_party module. Please check the `third_party` directory. If there is no `Matcha-TTS`, execute `git submodule update --init --recursive`.

Run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in a Python script.

## cannot find resource.zip or cannot unzip resource.zip

Please make sure you have git-lfs installed. Execute

```sh
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
cd pretrained_models/CosyVoice-ttsfrd/
unzip resource.zip -d .
pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
```
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
119
README.md
Normal file
@ -0,0 +1,119 @@
# Detailed environment setup tutorial [Bilibili](https://www.bilibili.com/video/BV1HucueQEJo/)

0. Install Anaconda / ffmpeg
```
There are plenty of tutorials online; search for one yourself.
```

```
SenseVoiceSmall model download:
Automatic download: set line 215 to model_dir = "iic/SenseVoiceSmall"
Manual download: https://www.modelscope.cn/models/iic/SenseVoiceSmall/files

QWen model download:
Automatic download: set line 220 to model_name = "Qwen/Qwen2.5-1.5B-Instruct"; with a VPN/proxy enabled it downloads automatically from Hugging Face
Manual download: search for QWen at https://www.modelscope.cn/models/ and download a model your GPU memory can support
```

1. Create the virtual environment
```
conda create -n chatAudio python=3.10
conda activate chatAudio
```
2. Install PyTorch with CUDA. Any version above 2.0 passed local testing; here torch 2.3.1 + cuda 11.8 is installed.
```
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118

Other torch+cuda combinations suited to your machine can be found on the PyTorch website:
https://pytorch.org/get-started/previous-versions/
```

3. Lightweight installation; dependencies are fewer when CosyVoice is not used.
```
pip install edge-tts==6.1.17 funasr==1.1.12 ffmpeg==1.4 opencv-python==4.10.0.84 transformers==4.45.2 webrtcvad==2.0.10 qwen-vl-utils==0.0.8 pygame==2.6.1 langid==1.1.6 langdetect==1.0.9 accelerate==0.33.0 PyAudio==0.2.14

To verify, run:
python 13_SenceVoice_QWen2.5_edgeTTS_realTime.py
```

At this point, voice interaction that does not use CosyVoice for synthesis is ready to run.

4. CosyVoice dependencies
```
The pynini / WeTextProcessing installation method that most people asked about:
conda install -c conda-forge pynini=2.1.6
pip install WeTextProcessing --no-deps
```

5. Install the remaining CosyVoice dependencies (if installation fails with a permission error, open the terminal as administrator).
```
pip install HyperPyYAML==1.2.2 modelscope==1.15.0 onnxruntime==1.19.2 openai-whisper==20231117 importlib_resources==6.4.5 sounddevice==0.5.1 matcha-tts==0.0.7.0

To verify, run:
python 10_SenceVoice_QWen2.5_cosyVoice.py
```
# :sparkles: 241130 update

## New: speaker verification (voiceprint recognition)

A fixed directory stores the voiceprint enrollment audio; if the directory is empty, the system automatically enters enrollment mode. By default the enrollment recording must be longer than 3 seconds (customizable); in general, the longer the recording, the more stable the voiceprint.
The voiceprint model is Alibaba's open-source CAM++, trained on the Chinese 3D-Speaker data, which suits Chinese conversation scenarios. A sketch of the enrollment/verification flow is shown below.
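The sketch below shows one possible enrollment/verification flow built on the same ModelScope CAM++ pipeline used in 9.1_test_cam++.py. The directory name, the threshold value and the helper function are illustrative assumptions, not the exact code of the 15.x scripts.

```python
import os
from modelscope.pipelines import pipeline

# Assumed fixed enrollment directory; adjust to your own layout.
ENROLL_DIR = 'SpeakerVerification_DIR/enroll_wav'

sv_pipeline = pipeline(
    task='speaker-verification',
    model='damo/speech_campplus_sv_zh-cn_16k-common',
    model_revision='v1.0.0'
)

def verify_or_enroll(query_wav, thr=0.35):
    """If no voiceprint is enrolled yet, store query_wav as the reference;
    otherwise compare query_wav against the enrolled reference."""
    os.makedirs(ENROLL_DIR, exist_ok=True)
    enrolled = [os.path.join(ENROLL_DIR, f) for f in os.listdir(ENROLL_DIR) if f.endswith('.wav')]
    if not enrolled:
        # Enrollment mode: the recording (longer than 3 s) becomes the reference voiceprint.
        os.replace(query_wav, os.path.join(ENROLL_DIR, 'enroll_0.wav'))
        return None
    # Verification mode: a higher thr means a stricter same-speaker decision.
    result = sv_pipeline([enrolled[0], query_wav], thr=thr)
    print(result)  # contains the similarity score / decision
    return result
```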
## New: freely customizable wake words

Implemented with SenseVoice's speech recognition: the recognized Chinese characters are converted to pinyin and matched against the wake word, so any wake/command word can be defined by setting its pinyin. The default in 15.0_SenceVoice_kws_CAM++.py is 'ni hao xiao qian'; in 15.1_SenceVoice_kws_CAM++.py it is 'zhan qi lai' [the Shadow Monarch really is too cool]. A matching sketch is shown below.
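A minimal matching sketch, assuming the pypinyin package for the character-to-pinyin conversion (the actual scripts may use a different conversion utility):

```python
from pypinyin import lazy_pinyin

WAKE_WORD_PINYIN = 'ni hao xiao qian'  # default wake word of 15.0_SenceVoice_kws_CAM++.py

def is_wake_word(asr_text, wake_pinyin=WAKE_WORD_PINYIN):
    """Convert the recognized Chinese text to pinyin and look for the wake phrase."""
    text_pinyin = ' '.join(lazy_pinyin(asr_text))
    return wake_pinyin in text_pinyin

print(is_wake_word('你好小千,请打开灯'))  # True
print(is_wake_word('今天天气怎么样'))      # False
```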
## New: conversation history memory

Implemented by keeping user/system history queues. When a new round of dialogue starts, the stored history is fetched first and the new user instruction is appended to it. The maximum history length is configurable and defaults to 512; see the sketch after the script list below.

Corresponding scripts:

Without history memory: 15.0_SenceVoice_kws_CAM++.py

With history memory: 15.1_SenceVoice_kws_CAM++.py
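A minimal sketch of the history mechanism; MAX_HISTORY_LEN and the character-based trimming rule are illustrative, and the message list is assumed to feed a chat-template style LLM call as in the other scripts.

```python
MAX_HISTORY_LEN = 512  # maximum number of characters kept in the history (default per above)

history = []  # alternating {"role": "user"/"assistant", "content": ...} entries

def build_messages(user_text, system_prompt='你是一个有帮助的语音助手。'):
    """Fetch the stored history first, then append the new user instruction."""
    messages = [{"role": "system", "content": system_prompt}]
    messages += history
    messages.append({"role": "user", "content": user_text})
    return messages

def update_history(user_text, assistant_text):
    """Store the finished turn and drop the oldest messages once the budget is exceeded."""
    history.append({"role": "user", "content": user_text})
    history.append({"role": "assistant", "content": assistant_text})
    while sum(len(m["content"]) for m in history) > MAX_HISTORY_LEN:
        history.pop(0)
```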
[Demo video, Bilibili](https://www.bilibili.com/video/BV1Q6zpYpEgv)

Have fun! 😊

# :sparkles: 241123 update

## Single-modality voice interaction with free interruption

Real-time VAD uses webrtcvad with a detection window of 0.5 s, a voice-activation ratio of 40%, and a chunk size of 20 ms. In other words there are 500 ms / 20 ms = 25 chunks per window; if at least 25 * 0.4 = 10 chunks are flagged as speech, the 0.5 s window counts as valid audio and is appended to the buffer. A sketch of this decision rule is shown below.
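A minimal sketch of that activation rule with the webrtcvad package listed in the dependencies; audio capture (e.g. via PyAudio) and buffer management are omitted.

```python
import webrtcvad

SAMPLE_RATE = 16000
FRAME_MS = 20                        # each detection chunk is 20 ms
FRAMES_PER_WINDOW = 500 // FRAME_MS  # 0.5 s window -> 25 chunks
ACTIVATION_RATIO = 0.4               # 40% of the chunks must contain speech

vad = webrtcvad.Vad(2)  # aggressiveness 0-3

def window_is_speech(frames):
    """frames: 25 chunks of raw 16-bit mono PCM audio, 20 ms each."""
    voiced = sum(vad.is_speech(f, SAMPLE_RATE) for f in frames)
    return voiced >= int(FRAMES_PER_WINDOW * ACTIVATION_RATIO)  # at least 10 voiced chunks
```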
Possible improvement: use a model-based VAD to better reject noise.

13_SenceVoice_QWen2.5_edgeTTS_realTime.py

## Audio-visual multimodal voice interaction

Building on the same logic, replacing the QWen2.5-1.5B model with QWen2-VL-2B enables audio-visual multimodal interaction. The model accepts two input formats, image and video; the message layouts are sketched below.
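Illustrative message layouts for the two input types, following the qwen_vl_utils conventions used in the offline script above; the file paths are placeholders.

```python
# Image input: a single captured frame plus the user's question.
image_message = [{
    "role": "user",
    "content": [
        {"type": "image", "image": "file:///path/to/captured_frame.jpg"},
        {"type": "text", "text": "描述一下你看到的画面。"},
    ],
}]

# Video input: a short clip; qwen_vl_utils samples frames from it.
video_message = [{
    "role": "user",
    "content": [
        {"type": "video", "video": "file:///path/to/clip.mp4"},
        {"type": "text", "text": "视频里发生了什么?"},
    ],
}]
```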
14_SenceVoice_QWen2VL_edgeTTS_realTime.py

[Demo video, Bilibili](https://www.bilibili.com/video/BV1uQBCYrEYL)
# :sparkles: 241027 - Voice-interaction LLM / SenseVoice-QWen2.5-TTS

## Framework

Built as SenseVoice + QWen2.5 + CosyVoice.

The main code of this project comes from [CosyVoice](https://github.com/FunAudioLLM/CosyVoice).

On top of CosyVoice, [SenseVoice](https://github.com/modelscope/FunASR) is added as the speech recognition model.

[Qwen2.5](https://github.com/QwenLM/Qwen2.5) is added as the large language model for dialogue understanding.

## Three speech synthesis options

CosyVoice inference is slow and severely hurts real-time dialogue, so pyttsx3 and edge-tts are added as alternatives.

edge-tts hit connection errors during testing; upgrading it to version 6.1.17 fixed the issue, and no VPN/proxy is needed.

All dependencies are listed in requirements.txt; the interactive inference scripts are 10/11/12_SenceVoice_QWen2.5_xxx.py.

Have fun! 😊
193
README_CosyVoice.md
Normal file
@ -0,0 +1,193 @@
|
||||
# CosyVoice
|
||||
## 👉🏻 [CosyVoice Demos](https://fun-audio-llm.github.io/) 👈🏻
|
||||
[[CosyVoice Paper](https://fun-audio-llm.github.io/pdf/CosyVoice_v1.pdf)][[CosyVoice Studio](https://www.modelscope.cn/studios/iic/CosyVoice-300M)][[CosyVoice Code](https://github.com/FunAudioLLM/CosyVoice)]
|
||||
|
||||
For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVoice) and [SenseVoice space](https://www.modelscope.cn/studios/iic/SenseVoice).
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [x] 2024/07
|
||||
|
||||
- [x] Flow matching training support
|
||||
- [x] WeTextProcessing support when ttsfrd is not available
|
||||
- [x] Fastapi server and client
|
||||
|
||||
- [x] 2024/08
|
||||
|
||||
- [x] Repetition Aware Sampling(RAS) inference for llm stability
|
||||
- [x] Streaming inference mode support, including kv cache and sdpa for rtf optimization
|
||||
|
||||
- [x] 2024/09
|
||||
|
||||
- [x] 25hz cosyvoice base model
|
||||
- [x] 25hz cosyvoice voice conversion model
|
||||
|
||||
- [ ] TBD
|
||||
|
||||
- [ ] 25hz llama based llm model which supports lora finetune
|
||||
- [ ] Support more instruction mode
|
||||
- [ ] Music generation
|
||||
- [ ] CosyVoice-500M trained with more multi-lingual data
|
||||
- [ ] More...
|
||||
|
||||
## Install
|
||||
|
||||
**Clone and install**
|
||||
|
||||
- Clone the repo
|
||||
``` sh
|
||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
||||
# If you failed to clone submodule due to network failures, please run following command until success
|
||||
cd CosyVoice
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
|
||||
- Create Conda env:
|
||||
|
||||
``` sh
|
||||
conda create -n cosyvoice python=3.8
|
||||
conda activate cosyvoice
|
||||
# pynini is required by WeTextProcessing, use conda to install it as it can be executed on all platforms.
|
||||
conda install -y -c conda-forge pynini==2.1.5
|
||||
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||
|
||||
# If you encounter sox compatibility issues
|
||||
# ubuntu
|
||||
sudo apt-get install sox libsox-dev
|
||||
# centos
|
||||
sudo yum install sox sox-devel
|
||||
```
|
||||
|
||||
**Model download**
|
||||
|
||||
We strongly recommend that you download our pretrained `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
|
||||
|
||||
If you are an expert in this field and are only interested in training your own CosyVoice model from scratch, you can skip this step.
|
||||
|
||||
``` python
|
||||
# SDK模型下载
|
||||
from modelscope import snapshot_download
|
||||
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
|
||||
snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz')
|
||||
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
|
||||
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
|
||||
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
||||
```
|
||||
|
||||
``` sh
|
||||
# git模型下载,请确保已安装git lfs
|
||||
mkdir -p pretrained_models
|
||||
git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
|
||||
git clone https://www.modelscope.cn/iic/CosyVoice-300M-25Hz.git pretrained_models/CosyVoice-300M-25Hz
|
||||
git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
|
||||
git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
|
||||
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
|
||||
```
|
||||
|
||||
Optionally, you can unzip the `ttsfrd` resource and install the `ttsfrd` package for better text normalization performance.
|
||||
|
||||
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use WeTextProcessing by default.
|
||||
|
||||
``` sh
|
||||
cd pretrained_models/CosyVoice-ttsfrd/
|
||||
unzip resource.zip -d .
|
||||
pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
|
||||
```
|
||||
|
||||
**Basic Usage**
|
||||
|
||||
For zero_shot/cross_lingual inference, please use `CosyVoice-300M` model.
|
||||
For sft inference, please use `CosyVoice-300M-SFT` model.
|
||||
For instruct inference, please use `CosyVoice-300M-Instruct` model.
|
||||
First, add `third_party/Matcha-TTS` to your `PYTHONPATH`.
|
||||
|
||||
``` sh
|
||||
export PYTHONPATH=third_party/Matcha-TTS
|
||||
```
|
||||
|
||||
``` python
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
import torchaudio
|
||||
|
||||
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=False, fp16=True)
|
||||
# sft usage
|
||||
print(cosyvoice.list_avaliable_spks())
|
||||
# change stream=True for chunk stream inference
|
||||
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
|
||||
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||
|
||||
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz') # or change to pretrained_models/CosyVoice-300M for 50Hz inference
|
||||
# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
|
||||
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||
# cross_lingual usage
|
||||
prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
|
||||
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
|
||||
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||
# vc usage
|
||||
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
|
||||
source_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
|
||||
for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
|
||||
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||
|
||||
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
|
||||
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
|
||||
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
|
||||
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], 22050)
|
||||
```
|
||||
|
||||
**Start web demo**
|
||||
|
||||
You can use our web demo page to get familiar with CosyVoice quickly.
|
||||
We support sft/zero_shot/cross_lingual/instruct inference in web demo.
|
||||
|
||||
Please see the demo website for details.
|
||||
|
||||
``` python
|
||||
# change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference
|
||||
python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
|
||||
```
|
||||
|
||||
**Advanced Usage**
|
||||
|
||||
For advanced user, we have provided train and inference scripts in `examples/libritts/cosyvoice/run.sh`.
|
||||
You can get familiar with CosyVoice by following this recipe.
|
||||
|
||||
**Build for deployment**
|
||||
|
||||
Optionally, if you want to use grpc for service deployment,
|
||||
you can run the following steps. Otherwise, you can just ignore this step.
|
||||
|
||||
``` sh
|
||||
cd runtime/python
|
||||
docker build -t cosyvoice:v1.0 .
|
||||
# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
|
||||
# for grpc usage
|
||||
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
|
||||
cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
||||
# for fastapi usage
|
||||
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
|
||||
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
||||
```
|
||||
|
||||
## Discussion & Communication
|
||||
|
||||
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
|
||||
|
||||
You can also scan the QR code to join our official Dingding chat group.
|
||||
|
||||
<img src="./asset/dingding.png" width="250px">
|
||||
|
||||
## Acknowledge
|
||||
|
||||
1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
|
||||
2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
|
||||
3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
|
||||
4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
|
||||
5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
|
||||
|
||||
## Disclaimer
|
||||
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
|
BIN
Test_QWen2_VL/captured_image0.jpg
Normal file
After Width: | Height: | Size: 23 KiB |
BIN
Test_QWen2_VL/captured_image11.jpg
Normal file
After Width: | Height: | Size: 43 KiB |
BIN
Test_QWen2_VL/captured_image12.jpg
Normal file
After Width: | Height: | Size: 30 KiB |
BIN
Test_QWen2_VL/captured_image18.jpg
Normal file
After Width: | Height: | Size: 43 KiB |
BIN
Test_QWen2_VL/captured_image2.jpg
Normal file
After Width: | Height: | Size: 39 KiB |
BIN
Test_QWen2_VL/captured_image24.jpg
Normal file
After Width: | Height: | Size: 39 KiB |
BIN
Test_QWen2_VL/captured_image27.jpg
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
Test_QWen2_VL/captured_image3.jpg
Normal file
After Width: | Height: | Size: 40 KiB |
BIN
Test_QWen2_VL/captured_image36.jpg
Normal file
After Width: | Height: | Size: 31 KiB |
BIN
Test_QWen2_VL/captured_image48.jpg
Normal file
After Width: | Height: | Size: 39 KiB |
BIN
Test_QWen2_VL/captured_image5.jpg
Normal file
After Width: | Height: | Size: 41 KiB |
BIN
Test_QWen2_VL/captured_image6.jpg
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
Test_QWen2_VL/captured_image8.jpg
Normal file
After Width: | Height: | Size: 43 KiB |
BIN
Test_QWen2_VL/captured_image9.jpg
Normal file
After Width: | Height: | Size: 43 KiB |
BIN
Test_QWen2_VL/sft_0.mp3
Normal file
BIN
Test_QWen2_VL/sft_1.mp3
Normal file
BIN
Test_QWen2_VL/sft_10.mp3
Normal file
BIN
Test_QWen2_VL/sft_11.mp3
Normal file
BIN
Test_QWen2_VL/sft_12.mp3
Normal file
BIN
Test_QWen2_VL/sft_13.mp3
Normal file
BIN
Test_QWen2_VL/sft_14.mp3
Normal file
BIN
Test_QWen2_VL/sft_15.mp3
Normal file
BIN
Test_QWen2_VL/sft_16.mp3
Normal file
BIN
Test_QWen2_VL/sft_17.mp3
Normal file
BIN
Test_QWen2_VL/sft_18.mp3
Normal file
BIN
Test_QWen2_VL/sft_19.mp3
Normal file
BIN
Test_QWen2_VL/sft_2.mp3
Normal file
BIN
Test_QWen2_VL/sft_20.mp3
Normal file
BIN
Test_QWen2_VL/sft_21.mp3
Normal file
BIN
Test_QWen2_VL/sft_22.mp3
Normal file
BIN
Test_QWen2_VL/sft_3.mp3
Normal file
BIN
Test_QWen2_VL/sft_4.mp3
Normal file
BIN
Test_QWen2_VL/sft_5.mp3
Normal file
BIN
Test_QWen2_VL/sft_6.mp3
Normal file
BIN
Test_QWen2_VL/sft_7.mp3
Normal file
BIN
Test_QWen2_VL/sft_8.mp3
Normal file
BIN
Test_QWen2_VL/sft_9.mp3
Normal file
BIN
Test_QWen2_VL/sft_tmp_0.mp3
Normal file
BIN
Test_QWen2_VL/sft_tmp_1.mp3
Normal file
BIN
Test_QWen2_VL/sft_tmp_2.mp3
Normal file
BIN
Test_QWen2_VL/sft_tmp_3.mp3
Normal file
BIN
Test_QWen2_VL/sft_tmp_4.mp3
Normal file
BIN
Test_QWen2_VL/sft_tmp_5.mp3
Normal file
BIN
Test_QWen2_VL/sft_tmp_6.mp3
Normal file
BIN
Test_QWen2_VL/sft_tmp_7.mp3
Normal file
BIN
asset/dingding.png
Normal file
After Width: | Height: | Size: 94 KiB |
0
cosyvoice/__init__.py
Normal file
92
cosyvoice/bin/average_model.py
Normal file
@ -0,0 +1,92 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='average model')
|
||||
parser.add_argument('--dst_model', required=True, help='averaged model')
|
||||
parser.add_argument('--src_path',
|
||||
required=True,
|
||||
help='src model path for average')
|
||||
parser.add_argument('--val_best',
|
||||
action="store_true",
|
||||
help='averaged model')
|
||||
parser.add_argument('--num',
|
||||
default=5,
|
||||
type=int,
|
||||
help='nums for averaged model')
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
val_scores = []
|
||||
if args.val_best:
|
||||
yamls = glob.glob('{}/*.yaml'.format(args.src_path))
|
||||
yamls = [
|
||||
f for f in yamls
|
||||
if not (os.path.basename(f).startswith('train')
|
||||
or os.path.basename(f).startswith('init'))
|
||||
]
|
||||
for y in yamls:
|
||||
with open(y, 'r') as f:
|
||||
dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
|
||||
loss = float(dic_yaml['loss_dict']['loss'])
|
||||
epoch = int(dic_yaml['epoch'])
|
||||
step = int(dic_yaml['step'])
|
||||
tag = dic_yaml['tag']
|
||||
val_scores += [[epoch, step, loss, tag]]
|
||||
sorted_val_scores = sorted(val_scores,
|
||||
key=lambda x: x[2],
|
||||
reverse=False)
|
||||
print("best val (epoch, step, loss, tag) = " +
|
||||
str(sorted_val_scores[:args.num]))
|
||||
path_list = [
|
||||
args.src_path + '/epoch_{}_whole.pt'.format(score[0])
|
||||
for score in sorted_val_scores[:args.num]
|
||||
]
|
||||
print(path_list)
|
||||
avg = {}
|
||||
num = args.num
|
||||
assert num == len(path_list)
|
||||
for path in path_list:
|
||||
print('Processing {}'.format(path))
|
||||
states = torch.load(path, map_location=torch.device('cpu'))
|
||||
for k in states.keys():
|
||||
if k not in avg.keys():
|
||||
avg[k] = states[k].clone()
|
||||
else:
|
||||
avg[k] += states[k]
|
||||
# average
|
||||
for k in avg.keys():
|
||||
if avg[k] is not None:
|
||||
# pytorch 1.6 use true_divide instead of /=
|
||||
avg[k] = torch.true_divide(avg[k], num)
|
||||
print('Saving to {}'.format(args.dst_model))
|
||||
torch.save(avg, args.dst_model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
74
cosyvoice/bin/export_jit.py
Normal file
@ -0,0 +1,74 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='export your model for deployment')
|
||||
parser.add_argument('--model_dir',
|
||||
type=str,
|
||||
default='pretrained_models/CosyVoice-300M',
|
||||
help='local path')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
|
||||
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
|
||||
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
|
||||
|
||||
# 1. export llm text_encoder
|
||||
llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
|
||||
script = torch.jit.script(llm_text_encoder)
|
||||
script = torch.jit.freeze(script)
|
||||
script = torch.jit.optimize_for_inference(script)
|
||||
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
|
||||
|
||||
# 2. export llm llm
|
||||
llm_llm = cosyvoice.model.llm.llm.half()
|
||||
script = torch.jit.script(llm_llm)
|
||||
script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
|
||||
script = torch.jit.optimize_for_inference(script)
|
||||
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
||||
|
||||
# 3. export flow encoder
|
||||
flow_encoder = cosyvoice.model.flow.encoder
|
||||
script = torch.jit.script(flow_encoder)
|
||||
script = torch.jit.freeze(script)
|
||||
script = torch.jit.optimize_for_inference(script)
|
||||
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
112
cosyvoice/bin/export_onnx.py
Normal file
@ -0,0 +1,112 @@
|
||||
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
import os
|
||||
import sys
|
||||
import onnxruntime
|
||||
import random
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice
|
||||
|
||||
|
||||
def get_dummy_input(batch_size, seq_len, out_channels, device):
|
||||
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
||||
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
|
||||
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
||||
t = torch.rand((batch_size), dtype=torch.float32, device=device)
|
||||
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
|
||||
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
||||
return x, mask, mu, t, spks, cond
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='export your model for deployment')
|
||||
parser.add_argument('--model_dir',
|
||||
type=str,
|
||||
default='pretrained_models/CosyVoice-300M',
|
||||
help='local path')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
|
||||
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
|
||||
|
||||
# 1. export flow decoder estimator
|
||||
estimator = cosyvoice.model.flow.decoder.estimator
|
||||
|
||||
device = cosyvoice.model.device
|
||||
batch_size, seq_len = 1, 256
|
||||
out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
||||
torch.onnx.export(
|
||||
estimator,
|
||||
(x, mask, mu, t, spks, cond),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||
export_params=True,
|
||||
opset_version=18,
|
||||
do_constant_folding=True,
|
||||
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
||||
output_names=['estimator_out'],
|
||||
dynamic_axes={
|
||||
'x': {0: 'batch_size', 2: 'seq_len'},
|
||||
'mask': {0: 'batch_size', 2: 'seq_len'},
|
||||
'mu': {0: 'batch_size', 2: 'seq_len'},
|
||||
'cond': {0: 'batch_size', 2: 'seq_len'},
|
||||
't': {0: 'batch_size'},
|
||||
'spks': {0: 'batch_size'},
|
||||
'estimator_out': {0: 'batch_size', 2: 'seq_len'},
|
||||
}
|
||||
)
|
||||
|
||||
# 2. test computation consistency
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
||||
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||
sess_options=option, providers=providers)
|
||||
|
||||
for _ in tqdm(range(10)):
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
|
||||
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
||||
ort_inputs = {
|
||||
'x': x.cpu().numpy(),
|
||||
'mask': mask.cpu().numpy(),
|
||||
'mu': mu.cpu().numpy(),
|
||||
't': t.cpu().numpy(),
|
||||
'spks': spks.cpu().numpy(),
|
||||
'cond': cond.cpu().numpy()
|
||||
}
|
||||
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
||||
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
115
cosyvoice/bin/inference.py
Normal file
@ -0,0 +1,115 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import torchaudio
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
from tqdm import tqdm
|
||||
from cosyvoice.cli.model import CosyVoiceModel
|
||||
from cosyvoice.dataset.dataset import Dataset
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='inference with your model')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--prompt_data', required=True, help='prompt data file')
|
||||
parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
|
||||
parser.add_argument('--tts_text', required=True, help='tts input file')
|
||||
parser.add_argument('--llm_model', required=True, help='llm model file')
|
||||
parser.add_argument('--flow_model', required=True, help='flow model file')
|
||||
parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
|
||||
parser.add_argument('--gpu',
|
||||
type=int,
|
||||
default=-1,
|
||||
help='gpu id for this rank, -1 for cpu')
|
||||
parser.add_argument('--mode',
|
||||
default='sft',
|
||||
choices=['sft', 'zero_shot'],
|
||||
help='inference mode')
|
||||
parser.add_argument('--result_dir', required=True, help='asr result file')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
|
||||
|
||||
# Init cosyvoice models from configs
|
||||
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
|
||||
device = torch.device('cuda' if use_cuda else 'cpu')
|
||||
with open(args.config, 'r') as f:
|
||||
configs = load_hyperpyyaml(f)
|
||||
|
||||
model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
|
||||
model.load(args.llm_model, args.flow_model, args.hifigan_model)
|
||||
|
||||
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
|
||||
tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
|
||||
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
|
||||
|
||||
del configs
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
fn = os.path.join(args.result_dir, 'wav.scp')
|
||||
f = open(fn, 'w')
|
||||
with torch.no_grad():
|
||||
for _, batch in tqdm(enumerate(test_data_loader)):
|
||||
utts = batch["utts"]
|
||||
assert len(utts) == 1, "inference mode only support batchsize 1"
|
||||
text_token = batch["text_token"].to(device)
|
||||
text_token_len = batch["text_token_len"].to(device)
|
||||
tts_index = batch["tts_index"]
|
||||
tts_text_token = batch["tts_text_token"].to(device)
|
||||
tts_text_token_len = batch["tts_text_token_len"].to(device)
|
||||
speech_token = batch["speech_token"].to(device)
|
||||
speech_token_len = batch["speech_token_len"].to(device)
|
||||
speech_feat = batch["speech_feat"].to(device)
|
||||
speech_feat_len = batch["speech_feat_len"].to(device)
|
||||
utt_embedding = batch["utt_embedding"].to(device)
|
||||
spk_embedding = batch["spk_embedding"].to(device)
|
||||
if args.mode == 'sft':
|
||||
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
||||
'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
|
||||
else:
|
||||
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
|
||||
'prompt_text': text_token, 'prompt_text_len': text_token_len,
|
||||
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
||||
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
||||
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
||||
'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
|
||||
tts_speeches = []
|
||||
for model_output in model.tts(**model_input):
|
||||
tts_speeches.append(model_output['tts_speech'])
|
||||
tts_speeches = torch.concat(tts_speeches, dim=1)
|
||||
tts_key = '{}_{}'.format(utts[0], tts_index[0])
|
||||
tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
|
||||
torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
|
||||
f.write('{} {}\n'.format(tts_key, tts_fn))
|
||||
f.flush()
|
||||
f.close()
|
||||
logging.info('Result wav.scp saved in {}'.format(fn))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
158
cosyvoice/bin/train.py
Normal file
@ -0,0 +1,158 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import datetime
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
from copy import deepcopy
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import deepspeed
|
||||
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
|
||||
from torch.distributed.elastic.multiprocessing.errors import record
|
||||
|
||||
from cosyvoice.utils.executor import Executor
|
||||
from cosyvoice.utils.train_utils import (
|
||||
init_distributed,
|
||||
init_dataset_and_dataloader,
|
||||
init_optimizer_and_scheduler,
|
||||
init_summarywriter, save_model,
|
||||
wrap_cuda_model, check_modify_and_save_config)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='training your network')
|
||||
parser.add_argument('--train_engine',
|
||||
default='torch_ddp',
|
||||
choices=['torch_ddp', 'deepspeed'],
|
||||
help='Engine for paralleled training')
|
||||
parser.add_argument('--model', required=True, help='model which will be trained')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--train_data', required=True, help='train data file')
|
||||
parser.add_argument('--cv_data', required=True, help='cv data file')
|
||||
parser.add_argument('--checkpoint', help='checkpoint model')
|
||||
parser.add_argument('--model_dir', required=True, help='save model dir')
|
||||
parser.add_argument('--tensorboard_dir',
|
||||
default='tensorboard',
|
||||
help='tensorboard log dir')
|
||||
parser.add_argument('--ddp.dist_backend',
|
||||
dest='dist_backend',
|
||||
default='nccl',
|
||||
choices=['nccl', 'gloo'],
|
||||
help='distributed backend')
|
||||
parser.add_argument('--num_workers',
|
||||
default=0,
|
||||
type=int,
|
||||
help='num of subprocess workers for reading')
|
||||
parser.add_argument('--prefetch',
|
||||
default=100,
|
||||
type=int,
|
||||
help='prefetch number')
|
||||
parser.add_argument('--pin_memory',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use pinned memory buffers used for reading')
|
||||
parser.add_argument('--use_amp',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use automatic mixed precision training')
|
||||
parser.add_argument('--deepspeed.save_states',
|
||||
dest='save_states',
|
||||
default='model_only',
|
||||
choices=['model_only', 'model+optimizer'],
|
||||
help='save model/optimizer states')
|
||||
parser.add_argument('--timeout',
|
||||
default=60,
|
||||
type=int,
|
||||
help='timeout (in seconds) of cosyvoice_join.')
|
||||
parser = deepspeed.add_config_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
@record
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
# gan train has some special initialization logic
|
||||
gan = True if args.model == 'hifigan' else False
|
||||
|
||||
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
|
||||
if gan is True:
|
||||
override_dict.pop('hift')
|
||||
with open(args.config, 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides=override_dict)
|
||||
if gan is True:
|
||||
configs['train_conf'] = configs['train_conf_gan']
|
||||
configs['train_conf'].update(vars(args))
|
||||
|
||||
# Init env for ddp
|
||||
init_distributed(args)
|
||||
|
||||
# Get dataset & dataloader
|
||||
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
|
||||
init_dataset_and_dataloader(args, configs, gan)
|
||||
|
||||
# Do some sanity checks and save config to args.model_dir
|
||||
configs = check_modify_and_save_config(args, configs)
|
||||
|
||||
# Tensorboard summary
|
||||
writer = init_summarywriter(args)
|
||||
|
||||
# load checkpoint
|
||||
model = configs[args.model]
|
||||
if args.checkpoint is not None:
|
||||
if os.path.exists(args.checkpoint):
|
||||
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'), strict=False)
|
||||
else:
|
||||
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
|
||||
|
||||
# Dispatch model from cpu to gpu
|
||||
model = wrap_cuda_model(args, model)
|
||||
|
||||
# Get optimizer & scheduler
|
||||
model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
|
||||
|
||||
# Save init checkpoints
|
||||
info_dict = deepcopy(configs['train_conf'])
|
||||
save_model(model, 'init', info_dict)
|
||||
|
||||
# Get executor
|
||||
executor = Executor(gan=gan)
|
||||
|
||||
# Init scaler, used for pytorch amp mixed precision training
|
||||
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
|
||||
|
||||
# Start training loop
|
||||
for epoch in range(info_dict['max_epoch']):
|
||||
executor.epoch = epoch
|
||||
train_dataset.set_epoch(epoch)
|
||||
dist.barrier()
|
||||
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
|
||||
if gan is True:
|
||||
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
||||
writer, info_dict, scaler, group_join)
|
||||
else:
|
||||
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
|
||||
dist.destroy_process_group(group_join)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
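
Editor's note: a minimal sketch of the load_hyperpyyaml override trick used in main() above, where config branches that are not being trained are nulled out before instantiation. The yaml string and class below are stand-ins, not the real cosyvoice config.

# Editor's illustrative sketch, not part of the commit.
from hyperpyyaml import load_hyperpyyaml

demo_yaml = """
llm: !new:collections.OrderedDict
flow: !new:collections.OrderedDict
hift: !new:collections.OrderedDict
"""
# As in main(): when only 'llm' is trained, every other branch is overridden to None
# so hyperpyyaml never constructs it.
override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != 'llm'}
configs = load_hyperpyyaml(demo_yaml, overrides=override_dict)
print(type(configs['llm']))  # instantiated object
print(configs['flow'])       # None, skipped entirely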
0
cosyvoice/cli/__init__.py
Normal file
113
cosyvoice/cli/cosyvoice.py
Normal file
@ -0,0 +1,113 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from tqdm import tqdm
from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel
from cosyvoice.utils.file_utils import logging


class CosyVoice:

    def __init__(self, model_dir, load_jit=True, load_onnx=False, fp16=True):
        instruct = True if '-Instruct' in model_dir else False
        self.model_dir = model_dir
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
            configs = load_hyperpyyaml(f)
        self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
                                          configs['feat_extractor'],
                                          '{}/campplus.onnx'.format(model_dir),
                                          '{}/speech_tokenizer_v1.onnx'.format(model_dir),
                                          '{}/spk2info.pt'.format(model_dir),
                                          instruct,
                                          configs['allowed_special'])
        self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
        self.model.load('{}/llm.pt'.format(model_dir),
                        '{}/flow.pt'.format(model_dir),
                        '{}/hift.pt'.format(model_dir))
        if load_jit:
            self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
                                '{}/llm.llm.fp16.zip'.format(model_dir),
                                '{}/flow.encoder.fp32.zip'.format(model_dir))
        if load_onnx:
            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
        del configs

    def list_avaliable_spks(self):
        spks = list(self.frontend.spk2info.keys())
        return spks

    def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
            model_input = self.frontend.frontend_sft(i, spk_id)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / 22050
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()

    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
        prompt_text = self.frontend.text_normalize(prompt_text, split=False)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
            model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / 22050
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()

    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
        if self.frontend.instruct is True:
            raise ValueError('{} does not support cross_lingual inference'.format(self.model_dir))
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
            model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / 22050
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()

    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
        if self.frontend.instruct is False:
            raise ValueError('{} does not support instruct inference'.format(self.model_dir))
        instruct_text = self.frontend.text_normalize(instruct_text, split=False)
        for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
            model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
            start_time = time.time()
            logging.info('synthesis text {}'.format(i))
            for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
                speech_len = model_output['tts_speech'].shape[1] / 22050
                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
                yield model_output
                start_time = time.time()

    def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
        model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
        start_time = time.time()
        for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
            speech_len = model_output['tts_speech'].shape[1] / 22050
            logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
            yield model_output
            start_time = time.time()
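
Editor's note: a hedged usage sketch of the wrapper above; the model directory and speaker id follow the project README conventions and should be treated as placeholders.

# Editor's illustrative sketch, not part of the commit.
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_onnx=False, fp16=False)
print(cosyvoice.list_avaliable_spks())
# the inference_* methods are generators; each yielded chunk holds a 22050 Hz waveform
for k, out in enumerate(cosyvoice.inference_sft('Hello, this is a test.', '中文女', stream=False)):
    torchaudio.save('sft_{}.wav'.format(k), out['tts_speech'], 22050)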
188
cosyvoice/cli/frontend.py
Normal file
@ -0,0 +1,188 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import onnxruntime
import torch
import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import torchaudio
import os
import re
import inflect
try:
    import ttsfrd
    use_ttsfrd = True
except ImportError:
    print("failed to import ttsfrd, use WeTextProcessing instead")
    from tn.chinese.normalizer import Normalizer as ZhNormalizer
    from tn.english.normalizer import Normalizer as EnNormalizer
    use_ttsfrd = False
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph


class CosyVoiceFrontEnd:

    def __init__(self,
                 get_tokenizer: Callable,
                 feat_extractor: Callable,
                 campplus_model: str,
                 speech_tokenizer_model: str,
                 spk2info: str = '',
                 instruct: bool = False,
                 allowed_special: str = 'all'):
        self.tokenizer = get_tokenizer()
        self.feat_extractor = feat_extractor
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
                                                                                "CPUExecutionProvider"])
        if os.path.exists(spk2info):
            self.spk2info = torch.load(spk2info, map_location=self.device)
        else:
            self.spk2info = {}
        self.instruct = instruct
        self.allowed_special = allowed_special
        self.inflect_parser = inflect.engine()
        self.use_ttsfrd = use_ttsfrd
        if self.use_ttsfrd:
            self.frd = ttsfrd.TtsFrontendEngine()
            ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
                'failed to initialize ttsfrd resource'
            self.frd.set_lang_type('pinyin')
            self.frd.enable_pinyin_mix(True)
            self.frd.set_breakmodel_index(1)
        else:
            self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False)
            self.en_tn_model = EnNormalizer()

    def _extract_text_token(self, text):
        text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
        return text_token, text_token_len

    def _extract_speech_token(self, speech):
        assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(None,
                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
                                                          feat.detach().cpu().numpy(),
                                                          self.speech_tokenizer_session.get_inputs()[1].name:
                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech):
        feat = kaldi.fbank(speech,
                           num_mel_bins=80,
                           dither=0,
                           sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.campplus_session.run(None,
                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech):
        speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def text_normalize(self, text, split=True):
        text = text.strip()
        if contains_chinese(text):
            if self.use_ttsfrd:
                text = self.frd.get_frd_extra_info(text, 'input')
            else:
                text = self.zh_tn_model.normalize(text)
            text = text.replace("\n", "")
            text = replace_blank(text)
            text = replace_corner_mark(text)
            text = text.replace(".", "。")
            text = text.replace(" - ", ",")
            text = remove_bracket(text)
            text = re.sub(r'[,,、]+$', '。', text)
            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
                                         token_min_n=60, merge_len=20, comma_split=False))
        else:
            if self.use_ttsfrd:
                text = self.frd.get_frd_extra_info(text, 'input')
            else:
                text = self.en_tn_model.normalize(text)
                text = spell_out_number(text, self.inflect_parser)
            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
                                         token_min_n=60, merge_len=20, comma_split=False))
        if split is False:
            return text
        return texts

    def frontend_sft(self, tts_text, spk_id):
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        embedding = self.spk2info[spk_id]['embedding']
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
        tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
        prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
                       'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
                       'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
                       'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
                       'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
                       'llm_embedding': embedding, 'flow_embedding': embedding}
        return model_input

    def frontend_cross_lingual(self, tts_text, prompt_speech_16k):
        model_input = self.frontend_zero_shot(tts_text, '', prompt_speech_16k)
        # in cross lingual mode, we remove prompt in llm
        del model_input['prompt_text']
        del model_input['prompt_text_len']
        del model_input['llm_prompt_speech_token']
        del model_input['llm_prompt_speech_token_len']
        return model_input

    def frontend_instruct(self, tts_text, spk_id, instruct_text):
        model_input = self.frontend_sft(tts_text, spk_id)
        # in instruct mode, we remove spk_embedding in llm due to information leakage
        del model_input['llm_embedding']
        instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text + '<endofprompt>')
        model_input['prompt_text'] = instruct_text_token
        model_input['prompt_text_len'] = instruct_text_token_len
        return model_input

    def frontend_vc(self, source_speech_16k, prompt_speech_16k):
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
                       'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
                       'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
                       'flow_embedding': embedding}
        return model_input
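
Editor's note: the zero-shot and voice-conversion frontends above expect a 16 kHz mono prompt and resample to 22.05 kHz only for the mel features. A minimal sketch of preparing such a prompt; the helper name and file name below are stand-ins, not part of this file.

# Editor's illustrative sketch, not part of the commit.
import torchaudio

def load_prompt_16k(path, target_sr=16000):
    speech, sr = torchaudio.load(path)
    speech = speech.mean(dim=0, keepdim=True)  # force mono, shape (1, samples)
    if sr != target_sr:
        speech = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(speech)
    return speech

prompt_speech_16k = load_prompt_16k('zero_shot_prompt.wav')
# frontend_zero_shot() then derives speech tokens and the speaker embedding from the
# 16 kHz waveform, and mel features from a 22.05 kHz resampled copy, as implemented above.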
256
cosyvoice/cli/model.py
Normal file
@ -0,0 +1,256 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out


class CosyVoiceModel:

    def __init__(self,
                 llm: torch.nn.Module,
                 flow: torch.nn.Module,
                 hift: torch.nn.Module,
                 fp16: bool):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        # mel fade in out
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # hift cache
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)
        # rtf and decoding related
        self.stream_scale_factor = 1
        assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
        self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
        self.lock = threading.Lock()
        # dict used to store session related variable
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}

    def load(self, llm_model, flow_model, hift_model):
        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device), strict=False)
        self.llm.to(self.device).eval()
        if self.fp16 is True:
            self.llm.half()
        self.flow.load_state_dict(torch.load(flow_model, map_location=self.device), strict=False)
        self.flow.to(self.device).eval()
        # in case hift_model is a hifigan model
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device).items()}
        self.hift.load_state_dict(hift_state_dict, strict=False)
        self.hift.to(self.device).eval()

    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        assert self.fp16 is True, "we only provide fp16 jit model, set fp16=True if you want to use jit model"
        llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
        self.llm.llm = llm_llm
        flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
        self.flow.encoder = flow_encoder

    def load_onnx(self, flow_decoder_estimator_model):
        import onnxruntime
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
        del self.flow.decoder.estimator
        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)

    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        if self.fp16 is True:
            llm_embedding = llm_embedding.half()
        with self.llm_context:
            for i in self.llm.inference(text=text.to(self.device),
                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                        prompt_text=prompt_text.to(self.device),
                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                        embedding=llm_embedding.to(self.device)):
                self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True

    def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
        tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
                                                  token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                                  prompt_token=prompt_token.to(self.device),
                                                  prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                                  prompt_feat=prompt_feat.to(self.device),
                                                  prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                                  embedding=embedding.to(self.device),
                                                  flow_cache=self.flow_cache_dict[uuid])
        self.flow_cache_dict[uuid] = flow_cache

        # mel overlap fade in out
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
            tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)
        # keep overlap mel and hift cache
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                          'source': tts_source[:, :, -self.source_cache_len:],
                                          'speech': tts_speech[:, -self.source_cache_len:]}
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            if speed != 1.0:
                assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
        return tts_speech

    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.1)
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            p.join()
            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            p.join()
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)

    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
        # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
            self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
                        .unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     finalize=False)
                    yield {'tts_speech': this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    # increase token_hop_len for better speech quality
                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
                    break
            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True)
            yield {'tts_speech': this_tts_speech.cpu()}
        else:
            # deal with all tokens
            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                             prompt_token=flow_prompt_speech_token,
                                             prompt_feat=prompt_speech_feat,
                                             embedding=flow_embedding,
                                             uuid=this_uuid,
                                             finalize=True,
                                             speed=speed)
            yield {'tts_speech': this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
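
Editor's note: the chunk stitching in token2wav above hinges on a Hamming-window cross-fade between the cached tail of the previous chunk and the head of the new one. A minimal sketch of that blend, assuming the same contract as cosyvoice.utils.common.fade_in_out; the names and sizes below are stand-ins.

# Editor's illustrative sketch, not part of the commit.
import numpy as np
import torch

def fade_in_out_sketch(fade_in_mel, fade_out_mel, window):
    # first half of the window ramps the new chunk up, second half ramps the cached tail down
    overlap = int(window.shape[0] / 2)
    win = torch.from_numpy(window).float()
    blended = fade_in_mel.clone()
    blended[:, :, :overlap] = fade_in_mel[:, :, :overlap] * win[:overlap] + \
        fade_out_mel[:, :, -overlap:] * win[overlap:]
    return blended

window = np.hamming(2 * 8)            # 8-frame mel overlap, analogous to mel_window above
prev_tail = torch.zeros(1, 80, 8)     # cached overlap kept from the previous chunk
new_chunk = torch.ones(1, 80, 50)     # freshly generated mel chunk
smoothed = fade_in_out_sketch(new_chunk, prev_tail, window)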
0
cosyvoice/dataset/__init__.py
Normal file
164
cosyvoice/dataset/dataset.py
Normal file
@ -0,0 +1,164 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import json
import math
from functools import partial

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
from cosyvoice.utils.file_utils import read_lists, read_json_lists


class Processor(IterableDataset):

    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        self.source.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self.source is not None
        assert callable(self.f)
        return self.f(iter(self.source), *self.args, **self.kw)

    def apply(self, f):
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)


class DistributedSampler:

    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def sample(self, data):
        """ Sample data according to rank/world_size/num_workers

            Args:
                data(List): input data list

            Returns:
                List: data list after sample
        """
        data = list(range(len(data)))
        # force datalist even
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(data)
            if len(data) < self.world_size:
                data = data * math.ceil(self.world_size / len(data))
                data = data[:self.world_size]
            data = data[self.rank::self.world_size]
        if len(data) < self.num_workers:
            data = data * math.ceil(self.num_workers / len(data))
            data = data[:self.num_workers]
        data = data[self.worker_id::self.num_workers]
        return data


class DataList(IterableDataset):

    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        sampler_info = self.sampler.update()
        indexes = self.sampler.sample(self.lists)
        for index in indexes:
            data = dict(src=self.lists[index])
            data.update(sampler_info)
            yield data


def Dataset(data_list_file,
            data_pipeline,
            mode='train',
            gan=False,
            shuffle=True,
            partition=True,
            tts_file='',
            prompt_utt2data=''):
    """ Construct dataset from arguments

        We have two shuffle stages in the Dataset. The first is global
        shuffle at the shard tar/raw file level. The second is global
        shuffle at the training-sample level.

        Args:
            data_type(str): raw/shard
            tokenizer (BaseTokenizer): tokenizer to tokenize
            partition(bool): whether to do data partition in terms of rank
    """
    assert mode in ['train', 'inference']
    lists = read_lists(data_list_file)
    if mode == 'inference':
        with open(tts_file) as f:
            tts_data = json.load(f)
        utt2lists = read_json_lists(prompt_utt2data)
        # filter unnecessary file in inference mode
        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
    dataset = DataList(lists,
                       shuffle=shuffle,
                       partition=partition)
    if mode == 'inference':
        # map partial arg to parquet_opener func in inference mode
        data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
    if gan is True:
        # map partial arg to padding func in gan mode
        data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
    for func in data_pipeline:
        dataset = Processor(dataset, func, mode=mode)
    return dataset
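
Editor's note: to make the composition pattern concrete, a small sketch of how a data_pipeline of generator stages wraps DataList through Processor; the two stages are toy stand-ins for the real cosyvoice.dataset.processor functions.

# Editor's illustrative sketch, not part of the commit.
def toy_parse(source, mode='train'):
    # each upstream sample is a dict carrying 'src' plus rank/worker info
    for sample in source:
        sample['text'] = str(sample['src']).upper()
        yield sample

def toy_filter(source, mode='train'):
    for sample in source:
        if len(sample['text']) > 0:
            yield sample

data = DataList(['a.tar', 'b.tar'], shuffle=False, partition=False)
for stage in [toy_parse, toy_filter]:
    data = Processor(data, stage, mode='train')

for sample in data:
    print(sample['text'], sample['rank'], sample['worker_id'])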