llm-asr-tts/3.04backend_service.py
LIRUI ad1fee9c4b
Some checks failed
Lint / quick-checks (push) Has been cancelled
Lint / flake8-py3 (push) Has been cancelled
Close inactive issues / close-issues (push) Failing after 2s
first commit
2025-03-17 00:41:41 +08:00

155 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import os
import tempfile
import time
import uuid

import edge_tts
import langid
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from funasr import AutoModel
from transformers import AutoTokenizer, AutoModelForCausalLM
app = Flask(__name__)
CORS(app)  # allow cross-origin calls from the browser frontend

# Configuration
AUDIO_RATE = 16000  # input audio sample rate (Hz); defined but not read in this file
OUTPUT_DIR = "./output"  # directory where synthesized MP3 replies are written
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize models (loaded once at module import; paths are local Windows paths)
print("Loading models...")
model_dir = "D:/AI/download/SenseVoiceSmall"  # SenseVoice ASR model checkpoint
model_senceVoice = AutoModel(model=model_dir, trust_remote_code=True)

# Load the Qwen2.5 large language model
model_name = "D:/AI/download/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",  # pick dtype automatically from the checkpoint
    device_map="auto"    # place layers on available devices automatically
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("All models loaded!")

# Language code (as reported by langid) -> edge-tts voice for TTS synthesis
language_speaker = {
    "ja": "ja-JP-NanamiNeural",  # Japanese
    "fr": "fr-FR-DeniseNeural",  # French
    "es": "ca-ES-JoanaNeural",   # Spanish (NOTE(review): voice id is Catalan — confirm intended)
    "de": "de-DE-KatjaNeural",   # German
    "zh": "zh-CN-XiaoyiNeural",  # Chinese (also the fallback voice)
    "en": "en-US-AnaNeural",     # English
}

# Pipeline helper
def process_audio(input_path):
    """Run the full voice pipeline on one audio file.

    ASR (SenseVoice) -> chat reply (Qwen2.5) -> TTS (edge-tts).

    Args:
        input_path: path to an audio file readable by the ASR model.

    Returns:
        Path of the synthesized MP3 reply inside OUTPUT_DIR.
    """
    # --- Speech recognition ---
    res = model_senceVoice.generate(
        input=input_path,
        cache={},
        language="auto",  # let the model detect the spoken language
        use_itn=False,
    )
    # SenseVoice prefixes the transcript with "<|...|>" tags; keep only
    # the text after the last '>'.
    prompt = res[0]['text'].split(">")[-1]
    print("ASR OUT:", prompt)

    # --- LLM reply generation ---
    messages = [
        {"role": "system", "content": "你叫千问是一个18岁的女大学生性格活泼开朗说话俏皮"},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=512)
    # Strip the echoed prompt tokens, keeping only the newly generated tail.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    output_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print("Answer:", output_text)

    # --- Language detection & speech synthesis ---
    lang, _ = langid.classify(output_text)
    speaker = language_speaker.get(lang, "zh-CN-XiaoyiNeural")
    # uuid4 guarantees a unique filename; the previous int(time.time())
    # name collided when two requests arrived within the same second.
    output_file = os.path.join(OUTPUT_DIR, f"response_{uuid.uuid4().hex}.mp3")
    asyncio.run(edge_tts.Communicate(output_text, speaker).save(output_file))
    return output_file
# ASR-only endpoint
@app.route('/asr', methods=['POST'])
def handle_asr():
    """Transcribe an uploaded audio file.

    Expects a multipart form field named 'audio'. Returns JSON
    {"asr_text": ...} on success, {"error": ...} with 400/500 otherwise.
    """
    if 'audio' not in request.files:
        return jsonify({"error": "No audio file provided"}), 400
    tmp_path = None
    try:
        audio_file = request.files['audio']
        # Create the temp file and close it before writing: on Windows a
        # NamedTemporaryFile cannot be reopened by name while still open,
        # so saving inside the `with` block would fail there.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp_path = tmp.name
        audio_file.save(tmp_path)
        # Speech recognition
        res = model_senceVoice.generate(
            input=tmp_path,
            cache={},
            language="auto",
            use_itn=False,
        )
        # Keep only the transcript after SenseVoice's "<|...|>" tag prefix.
        asr_text = res[0]['text'].split(">")[-1]
        return jsonify({
            "asr_text": asr_text
        })
    except Exception as e:
        # Endpoint boundary: report the failure to the client.
        return jsonify({"error": str(e)}), 500
    finally:
        # delete=False means we own cleanup — without this every request
        # leaked one temp .wav file.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
# LLM + TTS endpoint
@app.route('/generate', methods=['POST'])
def handle_generate():
    """Generate a chat reply and synthesize it to speech.

    Expects JSON {"asr_text": ...}. Returns JSON with the reply text and
    an "audio_url" served by /audio/<filename>; {"error": ...} with
    400/500 on failure.
    """
    try:
        data = request.get_json()
        # get_json() returns None for a missing/invalid JSON body; guard it
        # so the client gets a 400 instead of an AttributeError-driven 500.
        asr_text = data.get('asr_text', '') if data else ''
        if not asr_text:
            return jsonify({"error": "No ASR text provided"}), 400

        # --- LLM reply generation ---
        messages = [
            {"role": "system", "content": "你叫千问是一个18岁的女大学生性格活泼开朗说话俏皮"},
            {"role": "user", "content": asr_text},
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(**model_inputs, max_new_tokens=512)
        # Strip the echoed prompt tokens, keeping only the generated tail.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        answer_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # --- Voice selection & speech synthesis ---
        lang, _ = langid.classify(answer_text)
        speaker = language_speaker.get(lang, "zh-CN-XiaoyiNeural")
        # uuid4 prevents two same-second requests (server runs threaded)
        # from overwriting each other's MP3.
        output_file = os.path.join(OUTPUT_DIR, f"response_{uuid.uuid4().hex}.mp3")
        asyncio.run(edge_tts.Communicate(answer_text, speaker).save(output_file))
        return jsonify({
            "answer_text": answer_text,
            "audio_url": f"/audio/{os.path.basename(output_file)}"
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/audio/<filename>')
def get_audio(filename):
    """Serve a previously synthesized MP3 from OUTPUT_DIR by its basename.

    send_from_directory refuses paths that escape OUTPUT_DIR, so the
    filename taken from the URL cannot traverse outside it.
    """
    audio_dir = OUTPUT_DIR
    return send_from_directory(audio_dir, filename)
if __name__ == '__main__':
    # Development server on port 5000; threaded=True lets the /asr and
    # /generate requests be handled concurrently.
    app.run(port=5000, threaded=True)