火山语音模型TTS API封装

语音识别智能语音交互大模型

字节的项目组是真的懒呀，连个大模型语音 TTS 的 SDK 都没有提供。

在下不才，提供了一个基于 Python 的 TTS API 封装库。

项目主要基于 websockets 和 aiohttp 进行异步网络请求，大家可以根据需要引入此项目使用。

github

也可以直接复制下面的 Python 类导入项目中使用。我使用的是 Python 3.13，Python 3.9 及以上版本均可正常使用（requirements.txt 中固定的 websockets 15 与 aiohttp 3.11 都要求 Python ≥ 3.9，3.7/3.8 无法安装）。

依赖要求

requirements.txt

aiohappyeyeballs==2.6.1
aiohttp==3.11.18
aiosignal==1.3.2
attrs==25.3.0
frozenlist==1.6.0
idna==3.10
multidict==6.4.3
propcache==0.3.1
websockets==15.0.1
yarl==1.20.0

通过 WebSocket 方式调用

tts_websocket.py

import asyncio
import copy
import gzip
import json
import os
import ssl
import uuid
from pathlib import Path

import websockets


# Lookup tables used only to print human-readable descriptions of the 4-bit
# header fields of binary WebSocket server frames (see parse_response).
# Message type: high nibble of header byte 1.
MESSAGE_TYPES = {
    11: "audio-only server response",
    12: "frontend server response",
    15: "error message from server"}
# Message-type-specific flags: low nibble of header byte 1.
MESSAGE_TYPE_SPECIFIC_FLAGS = {
    0: "no sequence number",
    1: "sequence number > 0",
    2: "last message from server (seq < 0)",
    3: "sequence number < 0"
}
# Payload serialization method: high nibble of header byte 2.
MESSAGE_SERIALIZATION_METHODS = {
    0: "no serialization",
    1: "JSON",
    15: "custom type"
}
# Payload compression: low nibble of header byte 2.
MESSAGE_COMPRESSIONS = {
    0: "no compression",
    1: "gzip",
    15: "custom compression method"
}

# Request body template. generate_params() deep-copies it and fills in the
# credentials, a fresh reqid, the text and the audio options, so this module-
# level dict itself is never modified.
request_json = {
    "app": {
        "appid": "",
        "token": "",
        "cluster": "volcano_tts"
    },
    "user": {
        "uid": "qq_bot"
    },
    "audio": {
        "voice_type": "zh_female_meilinvyou_emo_v2_mars_bigtts",
        "encoding": "mp3",
        "speed_ratio": 1.0,
        "volume_ratio": 1.0,
        "pitch_ratio": 1.0,
    },
    "request": {
        "reqid": "",
        "text": "",
        "text_type": "plain",
        "operation": ""
    }
}


def print_text(*args):
    """
    Print each argument on its own line(s) in a human-readable form.

    - dict: pretty-printed as JSON (2-space indent, non-ASCII kept as-is).
    - str: pretty-printed as JSON if it parses as JSON, otherwise printed verbatim.
    - bytes / bytearray / memoryview: decoded as UTF-8 when possible, otherwise
      the raw object is printed.
    - anything else: printed via str().

    :param args: any number of values to print.
    """
    for arg in args:
        if isinstance(arg, dict):
            print(json.dumps(arg, ensure_ascii=False, indent=2))
        elif isinstance(arg, str):
            try:
                parsed = json.loads(arg)
            except json.JSONDecodeError:
                # Not JSON: print the original text unchanged.
                print(arg)
            else:
                print(json.dumps(parsed, ensure_ascii=False, indent=2))
        elif isinstance(arg, (bytes, bytearray, memoryview)):
            # bytearray/memoryview (e.g. a built request frame) are bytes-like too;
            # treat them the same as bytes instead of falling through to str().
            try:
                print(bytes(arg).decode("utf-8"))
            except UnicodeDecodeError:
                print(arg)
        else:
            print(str(arg))


def generate_params(
    self,
    text="",
    operation="submit",
    encoding="mp3",
    speed_ratio=1.0,
    volume_ratio=1.0,
    pitch_ratio=1.0,
    override=None
):
    """
    Build a fresh request dict from the module template ``request_json``.

    ``self`` is any object exposing ``appid``, ``token``, ``cluster`` and
    ``voice_type``. ``text`` and ``operation`` replace the template values only
    when truthy. Audio options, and the entries of ``override`` (which target
    the ``audio`` section), replace template values only when they are not
    ``None`` or ``""``. The template itself is deep-copied, never mutated.
    """
    body = copy.deepcopy(request_json)

    # Credentials are always taken from the client object.
    body["app"].update(appid=self.appid, token=self.token, cluster=self.cluster)

    req = body["request"]
    req.update(
        reqid=str(uuid.uuid4()),
        operation=operation or req["operation"],
        text=text or req["text"],
    )

    candidates = [
        ("encoding", encoding),
        ("voice_type", self.voice_type),
        ("speed_ratio", speed_ratio),
        ("volume_ratio", volume_ratio),
        ("pitch_ratio", pitch_ratio),
    ]
    # override entries are applied after (and therefore win over) the named options.
    if isinstance(override, dict) and override:
        candidates.extend(override.items())
    body["audio"].update(
        (key, value) for key, value in candidates if value not in (None, "")
    )
    return body



def generate_dir(file_name, file_path, encoding):
    """
    Return the absolute path ``<module dir>/<file_path>/<file_name>.<encoding>``.

    The target directory is created (including parents) if it does not exist.
    Per pathlib joining rules, an absolute *file_path* replaces the module
    directory.

    :param file_name: file name without extension.
    :param file_path: directory, relative to this module's directory.
    :param encoding: audio encoding, used as the file extension.
    :return: resolved ``Path`` of the file to write.
    """
    target_dir = Path(__file__).parent.joinpath(file_path)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir.joinpath(".".join((file_name, encoding))).resolve()


class WebSocketTTSClient:
    """
    Client for the Volcano Engine TTS binary WebSocket API (``/api/v1/tts/ws_binary``).

    Typical use: ``await client.connect()`` followed by
    ``await client.query(text, file_name, file_path)``, which streams the
    synthesized audio into the file path built by ``generate_dir``.
    """

    def __init__(self,
                 appid="",
                 token="",
                 cluster="volcano_tts",
                 voice_type="zh_female_meilinvyou_emo_v2_mars_bigtts",
                 host="openspeech.bytedance.com",
                 encoding="mp3",
                 verify_ssl=True,
                 max_reconnect_attempts=None
                 ):
        """
        :param appid: application id placed in the request's ``app`` section.
        :param token: access token, sent in the ``Authorization`` header and in the request body.
        :param cluster: cluster name placed in the request's ``app`` section.
        :param voice_type: voice id placed in the request's ``audio`` section.
        :param host: API host name.
        :param encoding: audio encoding requested from the server; also the output file extension.
        :param verify_ssl: verify the server's TLS certificate. Disable only for local
            debugging: without verification the bearer token can be intercepted.
        :param max_reconnect_attempts: number of consecutive failed connection attempts
            after which ``ConnectionError`` is raised; ``None`` (default) retries forever.
        """
        self.appid = appid
        self.token = token
        self.cluster = cluster
        self.voice_type = voice_type
        self.host = host
        self.encoding = encoding
        self.verify_ssl = verify_ssl
        self.max_reconnect_attempts = max_reconnect_attempts
        self.api_url = f"wss://{host}/api/v1/tts/ws_binary"
        # 4-byte header of a "full client request" frame:
        #   byte 0: protocol version 0b0001 | header size 0b0001 (in 4-byte units)
        #   byte 1: message type 0b0001 (full client request) | type-specific flags 0b0000
        #   byte 2: serialization 0b0001 (JSON) | compression 0b0001 (gzip)
        #   byte 3: reserved 0x00
        self.default_header = bytearray(b'\x11\x10\x11\x00')
        self.reconnect_attempts = 0
        self.websocket = None
        # (code, message) of the most recent server error frame, or None.
        self.last_error = None

    @staticmethod
    def _redacted(request):
        """Return a deep copy of *request* with ``app.token`` masked, for logging."""
        masked = copy.deepcopy(request)
        app = masked.get("app")
        if isinstance(app, dict) and app.get("token"):
            app["token"] = "***"
        return masked

    def generate_websocket_params(
        self,
        operation="submit",
        text="",
    ):
        """
        Build the binary full-client-request frame for *text*.

        Layout: 4-byte header, 4-byte big-endian payload size, gzip-compressed JSON payload.

        :param operation: request operation, e.g. ``"submit"``.
        :param text: text to synthesize.
        :return: the frame as a ``bytearray``.
        """
        # Pass the client's encoding so the server produces the same format
        # as the output file extension chosen by generate_dir.
        submit_request_json = generate_params(
            self, operation=operation, text=text, encoding=self.encoding)
        payload_bytes = gzip.compress(json.dumps(submit_request_json).encode())
        full_client_request = bytearray(self.default_header)
        full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))  # payload size (4 bytes)
        full_client_request.extend(payload_bytes)  # payload
        print("\n------------------------ test '{}' -------------------------".format(operation))
        # Never log the access token in clear text.
        print_text("request json: ", self._redacted(submit_request_json))
        print_text("request bytes: ", full_client_request)
        return full_client_request

    def parse_response(self, res, file):
        """
        Parse one binary server frame, writing any audio payload to *file*.

        :param res: raw frame bytes received from the WebSocket.
        :param file: binary file object that receives audio payload bytes.
        :return: True when the exchange is finished (final audio frame, error frame,
            frontend frame or unknown type); False when more frames are expected.
            Error frames are also recorded in ``self.last_error`` as ``(code, message)``.
        """
        print("--------------------------- response ---------------------------")
        protocol_version = res[0] >> 4
        header_size = res[0] & 0x0f
        message_type = res[1] >> 4
        message_type_specific_flags = res[1] & 0x0f
        serialization_method = res[2] >> 4
        message_compression = res[2] & 0x0f
        reserved = res[3]
        header_extensions = res[4:header_size * 4]
        payload = res[header_size * 4:]
        # .get(): unknown header values must not raise before the "undefined" branch below.
        print(f"            Protocol version: {protocol_version:#x} - version {protocol_version}")
        print(f"                 Header size: {header_size:#x} - {header_size * 4} bytes ")
        print(f"                Message type: {message_type:#x} - {MESSAGE_TYPES.get(message_type, 'unknown')}")
        print(
            f" Message type specific flags: {message_type_specific_flags:#x} - "
            f"{MESSAGE_TYPE_SPECIFIC_FLAGS.get(message_type_specific_flags, 'unknown')}")
        print(
            f"Message serialization method: {serialization_method:#x} - "
            f"{MESSAGE_SERIALIZATION_METHODS.get(serialization_method, 'unknown')}")
        print(
            f"         Message compression: {message_compression:#x} - "
            f"{MESSAGE_COMPRESSIONS.get(message_compression, 'unknown')}")
        print(f"                    Reserved: {reserved:#04x}")
        if header_size != 1:
            print(f"           Header extensions: {header_extensions}")

        if message_type == 0xb:  # audio-only server response
            if message_type_specific_flags == 0:  # no sequence number: ACK frame
                print("                Payload size: 0")
                return False
            sequence_number = int.from_bytes(payload[:4], "big", signed=True)
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            payload = payload[8:]
            print(f"             Sequence number: {sequence_number}")
            print(f"                Payload size: {payload_size} bytes")
            file.write(payload)
            # A negative sequence number marks the last audio frame.
            return sequence_number < 0

        if message_type == 0xf:  # error message from server
            code = int.from_bytes(payload[:4], "big", signed=False)
            msg_size = int.from_bytes(payload[4:8], "big", signed=False)
            error_msg = payload[8:]
            if message_compression == 1:
                error_msg = gzip.decompress(error_msg)
            error_msg = bytes(error_msg).decode("utf-8", errors="replace")
            self.last_error = (code, error_msg)
            print(f"          Error message code: {code}")
            print(f"          Error message size: {msg_size} bytes")
            print(f"               Error message: {error_msg}")
            return True

        if message_type == 0xc:  # frontend server response
            payload = payload[4:]  # skip the 4-byte size field
            if message_compression == 1:
                payload = gzip.decompress(payload)
            print(f"            Frontend message: {payload}")
            return True

        print("undefined message type!")
        return True

    async def _open(self):
        """Open one WebSocket connection and store it in ``self.websocket``."""
        if self.verify_ssl:
            ssl_context = ssl.create_default_context()
        else:
            # Debug only: skips certificate and host-name checks.
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
        header = {"Authorization": f"Bearer;{self.token}"}
        self.websocket = await websockets.connect(
            self.api_url,
            ssl=ssl_context,
            additional_headers=header,
            ping_interval=None
        )

    async def _wait_before_retry(self):
        """Sleep with exponential backoff (capped at 30 s); raise once the attempt limit is hit."""
        self.reconnect_attempts += 1
        if (self.max_reconnect_attempts is not None
                and self.reconnect_attempts > self.max_reconnect_attempts):
            raise ConnectionError(
                f"giving up after {self.max_reconnect_attempts} reconnect attempts")
        delay = min(2 ** self.reconnect_attempts, 30)
        print(f"{delay}秒后尝试重连...")
        await asyncio.sleep(delay)

    async def connect(self):
        """
        Connect to the API, retrying with exponential backoff until it succeeds.

        Retries run in a loop rather than by recursion, so a long outage cannot
        exhaust the call stack.

        :raises ConnectionError: when ``max_reconnect_attempts`` is set and exceeded.
        """
        while True:
            try:
                await self._open()
            except Exception as e:  # retry boundary: any connection failure triggers a retry
                print(f"连接失败: {e}")
                await self._wait_before_retry()
            else:
                self.reconnect_attempts = 0
                print("连接成功")
                return

    async def query(self, text, file_name, file_path, timeout=10):
        """
        Synthesize *text* and stream the audio into ``<file_path>/<file_name>.<encoding>``.

        :param text: text to synthesize.
        :param file_name: output file name without extension.
        :param file_path: output directory (see generate_dir).
        :param timeout: seconds to wait for each server frame before giving up.
        :return: True if the final audio frame arrived and the server reported no error.
        """
        full_client_request = self.generate_websocket_params(operation="submit", text=text)
        if not self.websocket:
            await self.connect()
        file_gen_path = generate_dir(file_name, file_path, self.encoding)
        print(file_gen_path)
        self.last_error = None
        finished = False
        with open(file_gen_path, "wb") as file_to_save:
            await self.websocket.send(full_client_request)
            print("发送消息成功")
            while not finished:
                try:
                    res = await asyncio.wait_for(self.websocket.recv(), timeout=timeout)
                    finished = self.parse_response(res, file_to_save)
                except asyncio.TimeoutError:
                    print(f"接收超时: {timeout}秒内未收到服务器响应")
                    break
                except websockets.ConnectionClosed as e:
                    print(f"连接失败: {e}")
                    # Drop the dead connection so the next query reconnects.
                    self.websocket = None
                    break
                except Exception as e:  # best-effort: report and stop reading
                    print(f"连接失败: {e}")
                    break
            if finished:
                file_to_save.flush()
                os.fsync(file_to_save.fileno())

        success = finished and self.last_error is None
        # Note: IDEs may not refresh their project tree immediately; check the
        # file manager if the generated file does not show up.
        if success:
            print("文件已保存至:", file_gen_path)
        else:
            print("合成未完成, 文件可能不完整:", file_gen_path)
        return success

    async def handle_reconnect(self):
        """Wait with exponential backoff, then reconnect (see connect)."""
        await self._wait_before_retry()
        await self.connect()

    async def listen(self):
        """Print incoming messages forever, reconnecting whenever the connection ends."""
        while True:
            if self.websocket is None:
                await self.connect()
            try:
                async for message in self.websocket:
                    print(f"收到消息: {message}")
            except websockets.ConnectionClosed:
                print("连接丢失,启动重连...")
            else:
                # A clean close ends the `async for` without raising; reconnect as
                # well instead of looping over the closed connection.
                print("连接已关闭,启动重连...")
            await self.handle_reconnect()


async def main():
    """
    Demo: synthesize one sentence over WebSocket, then keep listening.

    Credentials are read from the VOLC_TTS_APPID / VOLC_TTS_TOKEN environment
    variables so that no secret is hard-coded in the source.
    """
    appid = os.environ.get("VOLC_TTS_APPID", "")
    token = os.environ.get("VOLC_TTS_TOKEN", "")
    if not appid or not token:
        print("请先设置环境变量 VOLC_TTS_APPID 和 VOLC_TTS_TOKEN")
        return
    client = WebSocketTTSClient(
        appid=appid,
        token=token,
        cluster="volcano_tts",
        voice_type="zh_female_meilinvyou_emo_v2_mars_bigtts",
        host="openspeech.bytedance.com",
        encoding="mp3"
    )
    await client.connect()
    await client.query("你好,我是豆包", "test2", "./data")

    # listen() loops forever (reconnecting as needed), so this blocks until interrupted.
    await client.listen()


# Demo entry point
if __name__ == '__main__':
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("客户端已手动终止")

通过 HTTP 方式调用

tts_http.py

import asyncio
import base64
import copy
import json
import uuid
from pathlib import Path

import aiohttp


# Lookup tables describing the 4-bit header fields of binary WebSocket frames.
# They appear unused by the HTTP client in this module (likely kept for parity
# with the WebSocket module) — confirm before removing.
# Message type: high nibble of header byte 1.
MESSAGE_TYPES = {
    11: "audio-only server response",
    12: "frontend server response",
    15: "error message from server"}
# Message-type-specific flags: low nibble of header byte 1.
MESSAGE_TYPE_SPECIFIC_FLAGS = {
    0: "no sequence number",
    1: "sequence number > 0",
    2: "last message from server (seq < 0)",
    3: "sequence number < 0"
}
# Payload serialization method: high nibble of header byte 2.
MESSAGE_SERIALIZATION_METHODS = {
    0: "no serialization",
    1: "JSON",
    15: "custom type"
}
# Payload compression: low nibble of header byte 2.
MESSAGE_COMPRESSIONS = {
    0: "no compression",
    1: "gzip",
    15: "custom compression method"
}

# Request body template. generate_params() deep-copies it and fills in the
# credentials, a fresh reqid, the text and the audio options, so this module-
# level dict itself is never modified.
request_json = {
    "app": {
        "appid": "",
        "token": "",
        "cluster": "volcano_tts"
    },
    "user": {
        "uid": "qq_bot"
    },
    "audio": {
        "voice_type": "zh_female_meilinvyou_emo_v2_mars_bigtts",
        "encoding": "mp3",
        "speed_ratio": 1.0,
        "volume_ratio": 1.0,
        "pitch_ratio": 1.0,
    },
    "request": {
        "reqid": "",
        "text": "",
        "text_type": "plain",
        "operation": ""
    }
}


def print_text(*args):
    """
    Print each argument on its own line(s) in a human-readable form.

    - dict: pretty-printed as JSON (2-space indent, non-ASCII kept as-is).
    - str: pretty-printed as JSON if it parses as JSON, otherwise printed verbatim.
    - bytes / bytearray / memoryview: decoded as UTF-8 when possible, otherwise
      the raw object is printed.
    - anything else: printed via str().

    :param args: any number of values to print.
    """
    for arg in args:
        if isinstance(arg, dict):
            print(json.dumps(arg, ensure_ascii=False, indent=2))
        elif isinstance(arg, str):
            try:
                parsed = json.loads(arg)
            except json.JSONDecodeError:
                # Not JSON: print the original text unchanged.
                print(arg)
            else:
                print(json.dumps(parsed, ensure_ascii=False, indent=2))
        elif isinstance(arg, (bytes, bytearray, memoryview)):
            # bytearray/memoryview are bytes-like too; treat them the same as
            # bytes instead of falling through to str().
            try:
                print(bytes(arg).decode("utf-8"))
            except UnicodeDecodeError:
                print(arg)
        else:
            print(str(arg))



def generate_params(
    self,
    text="",
    operation="submit",
    encoding="mp3",
    speed_ratio=1.0,
    volume_ratio=1.0,
    pitch_ratio=1.0,
    override=None
):
    """
    Build a fresh request dict from the module template ``request_json``.

    ``self`` is any object exposing ``appid``, ``token``, ``cluster`` and
    ``voice_type``. ``text`` and ``operation`` replace the template values only
    when truthy. Audio options, and the entries of ``override`` (which target
    the ``audio`` section), replace template values only when they are not
    ``None`` or ``""``. The template itself is deep-copied, never mutated.
    """
    body = copy.deepcopy(request_json)

    # Credentials are always taken from the client object.
    body["app"].update(appid=self.appid, token=self.token, cluster=self.cluster)

    req = body["request"]
    req.update(
        reqid=str(uuid.uuid4()),
        operation=operation or req["operation"],
        text=text or req["text"],
    )

    candidates = [
        ("encoding", encoding),
        ("voice_type", self.voice_type),
        ("speed_ratio", speed_ratio),
        ("volume_ratio", volume_ratio),
        ("pitch_ratio", pitch_ratio),
    ]
    # override entries are applied after (and therefore win over) the named options.
    if isinstance(override, dict) and override:
        candidates.extend(override.items())
    body["audio"].update(
        (key, value) for key, value in candidates if value not in (None, "")
    )
    return body



def generate_dir(file_name, file_path, encoding):
    """
    Return the absolute path ``<module dir>/<file_path>/<file_name>.<encoding>``.

    The target directory is created (including parents) if it does not exist.
    Per pathlib joining rules, an absolute *file_path* replaces the module
    directory.

    :param file_name: file name without extension.
    :param file_path: directory, relative to this module's directory.
    :param encoding: audio encoding, used as the file extension.
    :return: resolved ``Path`` of the file to write.
    """
    target_dir = Path(__file__).parent.joinpath(file_path)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir.joinpath(".".join((file_name, encoding))).resolve()


class HTTPClient:
    """Client for the Volcano Engine TTS HTTP API (``POST https://<host>/api/v1/tts``)."""

    def __init__(
        self,
        appid="",
        token="",
        cluster="volcano_tts",
        voice_type="zh_female_meilinvyou_emo_v2_mars_bigtts",
        host="openspeech.bytedance.com",
        encoding="mp3"
    ):
        """
        :param appid: application id placed in the request's ``app`` section.
        :param token: access token, sent in the ``Authorization`` header and in the request body.
        :param cluster: cluster name placed in the request's ``app`` section.
        :param voice_type: voice id placed in the request's ``audio`` section.
        :param host: API host name.
        :param encoding: audio encoding requested from the server; also the output file extension.
        """
        self.appid = appid
        self.token = token
        self.cluster = cluster
        self.voice_type = voice_type
        self.host = host
        self.api_url = f"https://{host}/api/v1/tts"
        self.default_header = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer;{token}"
        }
        self.encoding = encoding

    @staticmethod
    def _redacted(request):
        """Return a deep copy of *request* with ``app.token`` masked, for logging."""
        masked = copy.deepcopy(request)
        app = masked.get("app")
        if isinstance(app, dict) and app.get("token"):
            app["token"] = "***"
        return masked

    def save_base64_to_mp3(self, base64_str: str, file_path: str, file_name: str):
        """
        Decode base64 audio and write it to ``<file_path>/<file_name>.<encoding>``.

        Despite the method name, the extension comes from ``self.encoding``.

        :param base64_str: base64 audio data, with or without a ``data:...;base64,`` URI prefix.
        :param file_path: output directory (see generate_dir).
        :param file_name: output file name without extension.
        """
        file_gen_path = generate_dir(file_name, file_path, self.encoding)
        # Strip an optional data-URI prefix.
        if base64_str.startswith("data:"):
            base64_str = base64_str.split(",", 1)[1]
        audio_data = base64.b64decode(base64_str)
        with open(file_gen_path, "wb") as f:
            f.write(audio_data)
        print(f"音频文件已保存至: {file_gen_path}")

    async def query(self, text, file_name, file_path=""):
        """
        Synthesize *text* via the HTTP API and save the returned audio.

        :param text: text to synthesize.
        :param file_name: output file name without extension.
        :param file_path: output directory (see generate_dir).
        :return: True if audio was received and saved, otherwise False.
        """
        # Either value being None would crash later when building the path.
        if file_name is None or file_path is None:
            print("file_name and file_path cannot be None.")
            return False
        # Pass the client's encoding so the audio format matches the file extension.
        request_body = generate_params(
            self, operation="query", text=text, encoding=self.encoding)
        async with aiohttp.ClientSession() as session:
            # Never log the access token in clear text.
            print_text("request:", self._redacted(request_body))
            async with session.post(
                self.api_url,
                headers=self.default_header,
                json=request_body
            ) as response:
                if response.status != 200:
                    print(f"请求失败: {await response.text()}")
                    return False
                data_dict = json.loads(await response.read())
                if isinstance(data_dict, dict):
                    audio_b64 = data_dict.get("data")
                    # Log the response metadata only; the base64 audio can be very large.
                    print_text("response:", {k: v for k, v in data_dict.items() if k != "data"})
                else:
                    audio_b64 = None
                    print_text("response:", data_dict)
                if not audio_b64:
                    print("请求失败: 响应中没有音频数据")
                    return False
                self.save_base64_to_mp3(audio_b64, file_path, file_name)
                return True


async def main():
    """
    Demo: synthesize one sentence via the HTTP API.

    Credentials are read from the VOLC_TTS_APPID / VOLC_TTS_TOKEN environment
    variables so that no secret is hard-coded in the source.
    """
    import os  # local import keeps this module's top-level imports unchanged

    appid = os.environ.get("VOLC_TTS_APPID", "")
    token = os.environ.get("VOLC_TTS_TOKEN", "")
    if not appid or not token:
        print("请先设置环境变量 VOLC_TTS_APPID 和 VOLC_TTS_TOKEN")
        return
    client = HTTPClient(
        appid=appid,
        token=token,
        cluster="volcano_tts",
        voice_type="zh_female_meilinvyou_emo_v2_mars_bigtts",
        host="openspeech.bytedance.com",
        encoding="mp3"
    )
    await client.query("你好,我是豆包", "test_http", "")


if __name__ == '__main__':
    asyncio.run(main())

如果两种方式都需要用到，也可以把两个文件最上方重复的常量和函数抽离到一个公共模块中，例如下面的 tts_config.py。

tts_config.py


import copy
import json
import uuid
from pathlib import Path

# Shared lookup tables describing the 4-bit header fields of binary WebSocket
# server frames, presumably imported by a WebSocket client's frame parser —
# confirm against callers.
# Message type: high nibble of header byte 1.
MESSAGE_TYPES = {
    11: "audio-only server response",
    12: "frontend server response",
    15: "error message from server"}
# Message-type-specific flags: low nibble of header byte 1.
MESSAGE_TYPE_SPECIFIC_FLAGS = {
    0: "no sequence number",
    1: "sequence number > 0",
    2: "last message from server (seq < 0)",
    3: "sequence number < 0"
}
# Payload serialization method: high nibble of header byte 2.
MESSAGE_SERIALIZATION_METHODS = {
    0: "no serialization",
    1: "JSON",
    15: "custom type"
}
# Payload compression: low nibble of header byte 2.
MESSAGE_COMPRESSIONS = {
    0: "no compression",
    1: "gzip",
    15: "custom compression method"
}

# Request body template. generate_params() deep-copies it and fills in the
# credentials, a fresh reqid, the text and the audio options, so this module-
# level dict itself is never modified.
request_json = {
    "app": {
        "appid": "",
        "token": "",
        "cluster": "volcano_tts"
    },
    "user": {
        "uid": "qq_bot"
    },
    "audio": {
        "voice_type": "zh_female_meilinvyou_emo_v2_mars_bigtts",
        "encoding": "mp3",
        "speed_ratio": 1.0,
        "volume_ratio": 1.0,
        "pitch_ratio": 1.0,
    },
    "request": {
        "reqid": "",
        "text": "",
        "text_type": "plain",
        "operation": ""
    }
}


def print_text(*args):
    """
    Print each argument on its own line(s) in a human-readable form.

    - dict: pretty-printed as JSON (2-space indent, non-ASCII kept as-is).
    - str: pretty-printed as JSON if it parses as JSON, otherwise printed verbatim.
    - bytes / bytearray / memoryview: decoded as UTF-8 when possible, otherwise
      the raw object is printed.
    - anything else: printed via str().

    :param args: any number of values to print.
    """
    for arg in args:
        if isinstance(arg, dict):
            print(json.dumps(arg, ensure_ascii=False, indent=2))
        elif isinstance(arg, str):
            try:
                parsed = json.loads(arg)
            except json.JSONDecodeError:
                # Not JSON: print the original text unchanged.
                print(arg)
            else:
                print(json.dumps(parsed, ensure_ascii=False, indent=2))
        elif isinstance(arg, (bytes, bytearray, memoryview)):
            # bytearray/memoryview (e.g. a built request frame) are bytes-like too;
            # treat them the same as bytes instead of falling through to str().
            try:
                print(bytes(arg).decode("utf-8"))
            except UnicodeDecodeError:
                print(arg)
        else:
            print(str(arg))


def generate_params(
    self,
    text="",
    operation="submit",
    encoding="mp3",
    speed_ratio=1.0,
    volume_ratio=1.0,
    pitch_ratio=1.0,
    override=None
):
    """
        根据输入的参数生成请求参数字典。

        参数:
        - text: 要合成语音的文本。
        - operation: 请求的操作类型,默认为"submit"。
        - encoding: 音频编码格式,默认为"mp3"。
        - speed_ratio: 语速比率,默认为1.0。
        - volume_ratio: 音量比率,默认为1.0。
        - pitch_ratio: 语调比率,默认为1.0。
        - override: 是否覆盖请求参数,如果提供,则会用此参数更新请求参数。

        返回:
        - 生成的请求参数字典。
    """
    submit_request_json = copy.deepcopy(request_json)
    submit_request_json["app"]["appid"] = self.appid
    submit_request_json["app"]["token"] = self.token
    submit_request_json["app"]["cluster"] = self.cluster
    submit_request_json["audio"]["voice_type"] = self.voice_type
    submit_request_json["request"]["reqid"] = str(uuid.uuid4())
    submit_request_json["request"]["operation"] = operation
    submit_request_json["request"]["text"] = text
    submit_request_json["audio"]["encoding"] = encoding
    submit_request_json["audio"]["speed_ratio"] = speed_ratio
    submit_request_json["audio"]["volume_ratio"] = volume_ratio
    submit_request_json["audio"]["pitch_ratio"] = pitch_ratio
    if override:
        submit_request_json = self.deep_update(submit_request_json, override)
    return submit_request_json


def generate_dir(file_name, file_path, encoding):
    """
    Return the absolute path ``<module dir>/<file_path>/<file_name>.<encoding>``.

    The target directory is created (including parents) if it does not exist.
    Per pathlib joining rules, an absolute *file_path* replaces the module
    directory.

    :param file_name: file name without extension.
    :param file_path: directory, relative to this module's directory.
    :param encoding: audio encoding, used as the file extension.
    :return: resolved ``Path`` of the file to write.
    """
    target_dir = Path(__file__).parent.joinpath(file_path)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir.joinpath(".".join((file_name, encoding))).resolve()

当然，还是希望官方尽快发布一个健壮的 SDK，此项目权当抛砖引玉。

0
0
0
0
关于作者

文章

0

获赞

0

收藏

0

相关资源
字节跳动大数据容器化构建与落地实践
随着字节跳动旗下业务的快速发展,数据急剧膨胀,原有的大数据架构在面临日趋复杂的业务需求时逐渐显现疲态。而伴随着大数据架构向云原生演进的行业趋势,字节跳动也对大数据体系进行了云原生改造。本次分享将详细介绍字节跳动大数据容器化的演进与实践。
相关产品
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论