158 lines
5.2 KiB
Python
158 lines
5.2 KiB
Python
from flask import Flask, request, jsonify, send_from_directory
|
||
from flask_cors import CORS
|
||
import time
|
||
import os
|
||
import logging
|
||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||
from funasr import AutoModel
|
||
import edge_tts
|
||
import asyncio
|
||
import langid
|
||
import tempfile
|
||
|
||
# Flask application object that every route below is registered on.
app = Flask(import_name=__name__)
# Enable CORS for the whole app with flask-cors' default options.
# NOTE(review): flask-cors' defaults allow any origin — confirm that is
# acceptable before exposing this service beyond localhost.
CORS(app)
|
||
|
||
# Logging configuration: INFO and above go both to app.log and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: the handlers log Chinese text and free-form LLM
        # output (which may contain emoji). Without an encoding the file uses
        # the platform default (e.g. cp1252/GBK on Windows), which cannot
        # encode all of it, and such records would fail to be written.
        logging.FileHandler('app.log', encoding='utf-8'),  # file log
        logging.StreamHandler(),  # console log
    ],
)
|
||
|
||
# Configuration parameters
AUDIO_RATE = 16000  # NOTE(review): not referenced in the visible code — confirm it is still needed

# Directory for synthesized audio. Resolve it to an absolute path once, at
# startup: the files are written relative to the current working directory,
# but Flask's send_from_directory resolves a *relative* directory against the
# application's root path. With a relative value, /audio/<filename> returns
# 404 whenever the server is launched from a directory other than the
# script's own.
OUTPUT_DIR = os.path.abspath("./output")
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||
|
||
# Model initialization: both models are loaded once at import time and the
# route handlers reuse these module-level objects for every request.
app.logger.info("Loading models...")

# SenseVoice model used by /asr. The location can be overridden with the
# SENSEVOICE_MODEL_DIR environment variable; the default is the original path.
model_dir = os.environ.get("SENSEVOICE_MODEL_DIR", "D:/AI/download/SenseVoiceSmall")
model_senceVoice = AutoModel(model=model_dir, trust_remote_code=True)

# Qwen2.5 chat LLM and its tokenizer, used by /generate_text. The location
# can be overridden with QWEN_MODEL_DIR; the default is the original path.
model_name = os.environ.get("QWEN_MODEL_DIR", "D:/AI/download/Qwen2.5-1.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
app.logger.info("All models loaded!")
|
||
|
||
# Language -> edge-tts voice map, keyed by the language codes returned by
# langid.classify. For codes not listed here, /generate_audio falls back to
# "zh-CN-XiaoyiNeural".
language_speaker = dict(
    ja="ja-JP-NanamiNeural",
    fr="fr-FR-DeniseNeural",
    # NOTE(review): "ca-ES" is the Catalan locale tag — confirm that reading
    # Spanish ("es") text with a Catalan voice is intentional.
    es="ca-ES-JoanaNeural",
    de="de-DE-KatjaNeural",
    zh="zh-CN-XiaoyiNeural",
    en="en-US-AnaNeural",
)
|
||
|
||
# ---------------------- 接口路由 ----------------------
|
||
@app.route('/asr', methods=['POST'])
def handle_asr():
    """Transcribe an uploaded audio clip with the SenseVoice model.

    Expects a multipart/form-data upload with the audio under the ``audio``
    field. The model is given a file path, so the upload is first written to
    a temporary ``.wav`` file, which is always deleted afterwards.

    Returns:
        200 with ``{"asr_text": ...}`` on success;
        400 with ``{"error": ...}`` when no ``audio`` file was uploaded;
        500 with ``{"error": ...}`` when saving or recognition fails.
    """
    if 'audio' not in request.files:
        return jsonify({"error": "No audio file provided"}), 400

    tmp_path = None
    try:
        audio_file = request.files['audio']
        app.logger.info(f"Received audio file: {audio_file.filename}")

        # Reserve a unique temp path and close our handle right away, so the
        # file is written and read purely by name (no second open handle).
        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        audio_file.save(tmp_path)

        # Speech recognition
        res = model_senceVoice.generate(
            input=tmp_path,
            cache={},
            language="auto",
            use_itn=False,
        )
        # Keep only the text after the last '>' — presumably strips the
        # model's leading <|...|> tags (TODO confirm tag format). Note this
        # also truncates transcripts that themselves contain '>'.
        asr_text = res[0]['text'].split(">")[-1]
        app.logger.info(f"ASR识别结果: {asr_text}")

        return jsonify({"asr_text": asr_text})
    except Exception as e:
        # Top-level request boundary: log with traceback, report as 500.
        app.logger.exception(f"ASR处理异常: {str(e)}")
        return jsonify({"error": str(e)}), 500
    finally:
        # Always delete the temp file; otherwise each request leaks one file.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                app.logger.warning(f"Failed to remove temp audio file: {tmp_path}")
|
||
|
||
@app.route('/generate_text', methods=['POST'])
def handle_generate_text():
    """Generate a chat reply from the LLM for recognized speech text.

    Expects a JSON object body with a non-empty ``asr_text`` string.

    Returns:
        200 with ``{"answer_text": ...}`` on success;
        400 with ``{"error": ...}`` when the body is not a JSON object or
        ``asr_text`` is missing, empty, or not a string;
        500 with ``{"error": ...}`` when templating or generation fails.
    """
    # silent=True yields None for a missing/malformed JSON body instead of
    # raising, and the isinstance check rejects non-object JSON (null, lists).
    # Both are client errors and are answered with 400 below rather than
    # surfacing as a 500 from the generation error handler.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    asr_text = data.get('asr_text', '')
    app.logger.info(f"收到ASR文本: {asr_text}")

    if not isinstance(asr_text, str) or not asr_text:
        return jsonify({"error": "No ASR text provided"}), 400

    try:
        # Build the chat prompt: fixed persona system message + the user's utterance
        messages = [
            {"role": "system", "content": "你叫千问,是一个18岁的女大学生,性格活泼开朗,说话俏皮"},
            {"role": "user", "content": asr_text},
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Generate the reply
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(**model_inputs, max_new_tokens=512)
        # The output sequences start with the prompt tokens; slice them off so
        # only the newly generated reply is decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        answer_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        app.logger.info(f"大模型回复: {answer_text}")

        return jsonify({"answer_text": answer_text})
    except Exception as e:
        # Top-level request boundary: log with traceback, report as 500.
        app.logger.exception(f"文本生成异常: {str(e)}")
        return jsonify({"error": str(e)}), 500
|
||
|
||
@app.route('/generate_audio', methods=['POST'])
def handle_generate_audio():
    """Synthesize speech for a reply text with edge-tts.

    Expects a JSON object body with a non-empty ``answer_text`` string. The
    voice is looked up in ``language_speaker`` by the language langid
    detects, falling back to "zh-CN-XiaoyiNeural".

    Returns:
        200 with ``{"audio_url": "/audio/<file>"}`` for an MP3 in OUTPUT_DIR;
        400 with ``{"error": ...}`` when the body is not a JSON object or
        ``answer_text`` is missing, empty, or not a string;
        500 with ``{"error": ...}`` when detection or synthesis fails.
    """
    # Malformed / non-object JSON is a client error (400), not a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    answer_text = data.get('answer_text', '')
    app.logger.info(f"收到待合成文本: {answer_text}")

    if not isinstance(answer_text, str) or not answer_text:
        return jsonify({"error": "No answer text provided"}), 400

    output_file = None
    try:
        # Language identification -> voice selection
        lang, _ = langid.classify(answer_text)
        speaker = language_speaker.get(lang, "zh-CN-XiaoyiNeural")
        app.logger.info(f"识别到语言: {lang}, 使用发音人: {speaker}")

        # Reserve a unique file name atomically. A second-resolution timestamp
        # name collides when two requests finish within the same second (the
        # server runs threaded), and one response's audio overwrites another's.
        fd, output_file = tempfile.mkstemp(prefix="response_", suffix=".mp3", dir=OUTPUT_DIR)
        os.close(fd)  # edge_tts writes by path; only the reserved name is needed
        asyncio.run(edge_tts.Communicate(answer_text, speaker).save(output_file))
        app.logger.info(f"语音合成完成,保存路径: {output_file}")

        return jsonify({
            "audio_url": f"/audio/{os.path.basename(output_file)}"
        })
    except Exception as e:
        # Top-level request boundary: log with traceback, drop any partial
        # output file, report as 500.
        app.logger.exception(f"语音合成异常: {str(e)}")
        if output_file is not None:
            try:
                os.remove(output_file)
            except OSError:
                app.logger.warning(f"Failed to remove partial audio file: {output_file}")
        return jsonify({"error": str(e)}), 500
|
||
|
||
@app.route('/audio/<filename>')
def get_audio(filename):
    """Serve a synthesized audio file from OUTPUT_DIR by its base name.

    Flask's send_from_directory is documented to refuse paths that would
    escape the given directory.
    """
    audio_dir = OUTPUT_DIR
    return send_from_directory(audio_dir, filename)
|
||
|
||
if __name__ == '__main__':
    # Script entry point: start Flask's development server on port 5000,
    # handling requests in separate threads.
    app.logger.info("服务启动,端口:5000")
    app.run(threaded=True, port=5000)