llm-asr-tts/3.07backend_service.py
LIRUI ad1fee9c4b
Some checks failed
Lint / quick-checks (push) Has been cancelled
Lint / flake8-py3 (push) Has been cancelled
Close inactive issues / close-issues (push) Failing after 2s
first commit
2025-03-17 00:41:41 +08:00

158 lines
5.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
import time
import os
import logging
from transformers import AutoTokenizer, AutoModelForCausalLM
from funasr import AutoModel
import edge_tts
import asyncio
import langid
import tempfile
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the web frontend

# Logging configuration: send records to both a file and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('app.log'),  # file log
        logging.StreamHandler()          # console log
    ]
)

# Runtime parameters
AUDIO_RATE = 16000  # NOTE(review): not referenced in this chunk — presumably the expected input sample rate (Hz); confirm callers use it
OUTPUT_DIR = "./output"  # directory where synthesized audio files are written
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ---- Model initialization (runs as a module-import side effect) ----
app.logger.info("Loading models...")
# SenseVoice ASR model, loaded from a local path.
model_dir = "D:/AI/download/SenseVoiceSmall"
model_senceVoice = AutoModel(model=model_dir, trust_remote_code=True)

# Qwen2.5 causal LM used to generate chat replies, loaded from a local path.
model_name = "D:/AI/download/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
app.logger.info("All models loaded!")

# Map langid language codes to edge-tts voice names.
# NOTE(review): "es" maps to a Catalan voice (ca-ES-JoanaNeural), not a
# Spanish one — confirm this is intentional.
language_speaker = {
    "ja": "ja-JP-NanamiNeural",
    "fr": "fr-FR-DeniseNeural",
    "es": "ca-ES-JoanaNeural",
    "de": "de-DE-KatjaNeural",
    "zh": "zh-CN-XiaoyiNeural",
    "en": "en-US-AnaNeural",
}
# ---------------------- API routes ----------------------
@app.route('/asr', methods=['POST'])
def handle_asr():
    """Speech-recognition endpoint.

    Expects a multipart/form-data POST with an 'audio' file field.
    Returns JSON {"asr_text": <recognized text>} on success, or
    {"error": ...} with HTTP 400/500 on failure.
    """
    if 'audio' not in request.files:
        return jsonify({"error": "No audio file provided"}), 400
    tmp_path = None
    try:
        audio_file = request.files['audio']
        app.logger.info(f"Received audio file: {audio_file.filename}")
        # Persist the upload to a temporary .wav so the ASR model can read
        # it by path. delete=False is needed because the model reopens the
        # file (and Windows forbids reopening an open NamedTemporaryFile);
        # we therefore remove it ourselves in `finally`.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp_path = tmp.name
            audio_file.save(tmp_path)
        # Run speech recognition with automatic language detection.
        res = model_senceVoice.generate(
            input=tmp_path,
            cache={},
            language="auto",
            use_itn=False,
        )
        # SenseVoice prefixes the transcript with tags such as "<|zh|>...";
        # keep only the text after the last ">".
        asr_text = res[0]['text'].split(">")[-1]
        app.logger.info(f"ASR识别结果: {asr_text}")
        return jsonify({"asr_text": asr_text})
    except Exception as e:
        app.logger.error(f"ASR处理异常: {str(e)}")
        return jsonify({"error": str(e)}), 500
    finally:
        # Bug fix: the original never deleted the temp file, leaking one
        # file per request under delete=False.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
@app.route('/generate_text', methods=['POST'])
def handle_generate_text():
    """LLM text-generation endpoint.

    Expects JSON {"asr_text": <user utterance>}.
    Returns JSON {"answer_text": <model reply>} on success, or
    {"error": ...} with HTTP 400/500 on failure.
    """
    try:
        # Robustness fix: get_json() raises on a missing/invalid JSON body,
        # which the broad except turned into a 500. silent=True yields None
        # instead, so a bad request is answered with a proper 400 below.
        data = request.get_json(silent=True) or {}
        asr_text = data.get('asr_text', '')
        app.logger.info(f"收到ASR文本: {asr_text}")
        if not asr_text:
            return jsonify({"error": "No ASR text provided"}), 400
        # Build the chat prompt: fixed persona system message + user text.
        messages = [
            {"role": "system", "content": "你叫千问是一个18岁的女大学生性格活泼开朗说话俏皮"},
            {"role": "user", "content": asr_text},
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # Generate the reply, then strip the prompt tokens from each output
        # sequence so only the newly generated tokens are decoded.
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(**model_inputs, max_new_tokens=512)
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        answer_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        app.logger.info(f"大模型回复: {answer_text}")
        return jsonify({"answer_text": answer_text})
    except Exception as e:
        app.logger.error(f"文本生成异常: {str(e)}")
        return jsonify({"error": str(e)}), 500
@app.route('/generate_audio', methods=['POST'])
def handle_generate_audio():
    """Text-to-speech endpoint.

    Expects JSON {"answer_text": <text to synthesize>}.
    Detects the text language, picks a matching edge-tts voice, writes an
    mp3 under OUTPUT_DIR, and returns JSON {"audio_url": "/audio/<file>"}.
    Returns {"error": ...} with HTTP 400/500 on failure.
    """
    try:
        # Robustness fix: silent=True turns a missing/invalid JSON body into
        # None (then {}) so the client gets a 400, not a 500 from get_json().
        data = request.get_json(silent=True) or {}
        answer_text = data.get('answer_text', '')
        app.logger.info(f"收到待合成文本: {answer_text}")
        if not answer_text:
            return jsonify({"error": "No answer text provided"}), 400
        # Language identification; fall back to the Chinese voice for
        # languages without a configured speaker.
        lang, _ = langid.classify(answer_text)
        speaker = language_speaker.get(lang, "zh-CN-XiaoyiNeural")
        app.logger.info(f"识别到语言: {lang}, 使用发音人: {speaker}")
        # Synthesize to a timestamped mp3. edge_tts is async-only, so drive
        # it with asyncio.run (one fresh event loop per request).
        output_file = os.path.join(OUTPUT_DIR, f"response_{int(time.time())}.mp3")
        asyncio.run(edge_tts.Communicate(answer_text, speaker).save(output_file))
        app.logger.info(f"语音合成完成,保存路径: {output_file}")
        return jsonify({
            "audio_url": f"/audio/{os.path.basename(output_file)}"
        })
    except Exception as e:
        app.logger.error(f"语音合成异常: {str(e)}")
        return jsonify({"error": str(e)}), 500
@app.route('/audio/<filename>')
def get_audio(filename):
    """Download endpoint for synthesized audio files.

    Streams the named file out of OUTPUT_DIR; send_from_directory refuses
    paths that escape the directory, so `filename` is confined to it.
    """
    audio_dir = OUTPUT_DIR
    return send_from_directory(audio_dir, filename)
# Script entry point: start the Flask development server on port 5000.
if __name__ == '__main__':
    app.logger.info("服务启动端口5000")
    # threaded=True lets the dev server handle concurrent requests.
    app.run(port=5000, threaded=True)