# Voice-chat backend: ASR (SenseVoice) -> LLM (Qwen2.5-Instruct) -> TTS (edge-tts).
import asyncio
import os
import tempfile
import time
import uuid

import edge_tts
import langid
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from funasr import AutoModel
from transformers import AutoTokenizer, AutoModelForCausalLM
app = Flask(__name__)
CORS(app)  # enable cross-origin requests (browser front-end served elsewhere)
# --- Configuration ---
AUDIO_RATE = 16000  # 16 kHz; NOTE(review): not referenced in this file — presumably the rate clients record at, confirm against the front-end
OUTPUT_DIR = "./output"  # synthesized MP3 replies are written here and served by /audio/<filename>
os.makedirs(OUTPUT_DIR, exist_ok=True)
# --- Model initialization (runs once at import time) ---
print("Loading models...")

# SenseVoice ASR model. The path can be overridden via environment variable so
# the server is not hard-wired to one Windows machine; the default preserves
# the previous behavior exactly.
model_dir = os.environ.get("SENSEVOICE_MODEL_DIR", "D:/AI/download/SenseVoiceSmall")
model_senceVoice = AutoModel(model=model_dir, trust_remote_code=True)

# Qwen2.5 instruct LLM used to generate chat replies (path overridable too).
model_name = os.environ.get("QWEN_MODEL_DIR", "D:/AI/download/Qwen2.5-1.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("All models loaded!")
# langid language code -> edge-tts neural voice used for synthesis.
# Fix: "es" (Spanish) previously pointed at "ca-ES-JoanaNeural", which is a
# Catalan voice, so Spanish answers were spoken in Catalan.
language_speaker = {
    "ja": "ja-JP-NanamiNeural",
    "fr": "fr-FR-DeniseNeural",
    "es": "es-ES-ElviraNeural",  # was ca-ES-JoanaNeural (Catalan, not Spanish)
    "de": "de-DE-KatjaNeural",
    "zh": "zh-CN-XiaoyiNeural",
    "en": "en-US-AnaNeural",
}
def process_audio(input_path):
    """Run the full voice-chat pipeline: ASR -> LLM reply -> TTS.

    Parameters
    ----------
    input_path : str
        Path to an audio file with the user's speech (WAV expected by the
        rest of this service — TODO confirm SenseVoice input formats).

    Returns
    -------
    str
        Path of the synthesized MP3 answer under OUTPUT_DIR.
    """
    # Speech recognition (SenseVoice). The raw result embeds tags such as
    # "<|zh|>...>text"; keep only the text after the last '>'.
    res = model_senceVoice.generate(
        input=input_path,
        cache={},
        language="auto",
        use_itn=False,
    )
    prompt = res[0]['text'].split(">")[-1]
    print("ASR OUT:", prompt)

    # Generate the reply with the Qwen2.5 chat template and fixed persona.
    messages = [
        {"role": "system", "content": "你叫千问,是一个18岁的女大学生,性格活泼开朗,说话俏皮"},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=512)
    # Strip the prompt tokens so only the newly generated answer remains.
    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids
                     in zip(model_inputs.input_ids, generated_ids)]
    output_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print("Answer:", output_text)

    # Pick a TTS voice matching the answer's detected language
    # (falls back to the Chinese voice for unmapped languages).
    lang, _ = langid.classify(output_text)
    speaker = language_speaker.get(lang, "zh-CN-XiaoyiNeural")

    # Synthesize speech. Fix: a uuid-based name replaces int(time.time()),
    # which collided (overwrote files) when two requests arrived within the
    # same second in this threaded server.
    output_file = os.path.join(OUTPUT_DIR, f"response_{uuid.uuid4().hex}.mp3")
    asyncio.run(edge_tts.Communicate(output_text, speaker).save(output_file))

    return output_file
# Dedicated ASR endpoint
@app.route('/asr', methods=['POST'])
def handle_asr():
    """POST /asr — transcribe an uploaded 'audio' multipart file.

    Returns JSON {"asr_text": ...} on success, or {"error": ...} with
    400 (no file uploaded) / 500 (any processing failure).
    """
    if 'audio' not in request.files:
        return jsonify({"error": "No audio file provided"}), 400

    tmp_path = None
    try:
        audio_file = request.files['audio']
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp_path = tmp.name
            audio_file.save(tmp_path)

        # Speech recognition; keep only the text after the last '>' tag.
        res = model_senceVoice.generate(
            input=tmp_path,
            cache={},
            language="auto",
            use_itn=False,
        )
        asr_text = res[0]['text'].split(">")[-1]

        return jsonify({
            "asr_text": asr_text
        })

    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # Fix: the temp WAV was never deleted (delete=False), leaking one
        # file per request. Always clean it up.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)
# LLM processing endpoint
@app.route('/generate', methods=['POST'])
def handle_generate():
    """POST /generate — produce a chat reply and TTS audio for given text.

    Expects JSON {"asr_text": "..."} and returns
    {"answer_text": ..., "audio_url": "/audio/<file>.mp3"}.
    400 when the text (or JSON body) is missing, 500 on processing failure.
    """
    try:
        # silent=True: a missing/invalid JSON body yields {} and a clean 400
        # below, instead of an AttributeError surfacing as a 500.
        data = request.get_json(silent=True) or {}
        asr_text = data.get('asr_text', '')

        if not asr_text:
            return jsonify({"error": "No ASR text provided"}), 400

        # Build the chat prompt with the fixed persona system message.
        messages = [
            {"role": "system", "content": "你叫千问,是一个18岁的女大学生,性格活泼开朗,说话俏皮"},
            {"role": "user", "content": asr_text},
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(**model_inputs, max_new_tokens=512)
        # Drop the prompt tokens; keep only the generated continuation.
        generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids
                         in zip(model_inputs.input_ids, generated_ids)]
        answer_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Choose a TTS voice for the detected answer language.
        lang, _ = langid.classify(answer_text)
        speaker = language_speaker.get(lang, "zh-CN-XiaoyiNeural")
        # Fix: uuid filename instead of int(time.time()) — the old name
        # collided when two requests landed within the same second.
        output_file = os.path.join(OUTPUT_DIR, f"response_{uuid.uuid4().hex}.mp3")
        asyncio.run(edge_tts.Communicate(answer_text, speaker).save(output_file))

        return jsonify({
            "answer_text": answer_text,
            "audio_url": f"/audio/{os.path.basename(output_file)}"
        })

    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/audio/<filename>')
def get_audio(filename):
    """Serve a previously synthesized MP3 from OUTPUT_DIR by file name."""
    audio_dir = OUTPUT_DIR
    return send_from_directory(audio_dir, filename)
if __name__ == '__main__':
    # threaded=True lets slow ASR/LLM requests be handled concurrently
    # instead of serializing behind one another.
    app.run(port=5000, threaded=True)