"""统一中间格式与转换器接口 定义项目中所有 API 格式共用的中间表示和转换器协议: - UnifiedRequest / UnifiedResponse: 统一的请求/响应数据结构 - InboundTransformer / OutboundTransformer: 入站/出站转换器接口 - StreamProcessor: 流式事件处理器接口 - ClientFormatter: 客户端响应格式化接口 """ from __future__ import annotations import json import logging from dataclasses import dataclass, field from typing import Any, Iterator, Protocol from flask import Response, jsonify import settings from utils.http import forward_request, gen_id, sse_response from utils.request_logger import ( append_client_event, append_upstream_event, attach_client_response, attach_error, attach_upstream_request, attach_upstream_response, finalize_turn, set_stream_summary, ) from utils.usage_tracker import usage_tracker logger = logging.getLogger(__name__) JsonDict = dict[str, Any] # ═══════════════════════════════════════════════════════════ # 统一数据模型 # ═══════════════════════════════════════════════════════════ @dataclass class UnifiedUsage: """标准化的令牌用量统计。""" input_tokens: int = 0 output_tokens: int = 0 total_tokens: int = 0 def to_cc_dict(self) -> JsonDict: return { 'prompt_tokens': self.input_tokens, 'completion_tokens': self.output_tokens, 'total_tokens': self.total_tokens, } def to_responses_dict(self) -> JsonDict: return { 'input_tokens': self.input_tokens, 'output_tokens': self.output_tokens, 'total_tokens': self.total_tokens, } @classmethod def from_cc_dict(cls, d: JsonDict) -> UnifiedUsage: return cls( input_tokens=d.get('prompt_tokens', 0), output_tokens=d.get('completion_tokens', 0), total_tokens=d.get('total_tokens', 0), ) @classmethod def from_responses_dict(cls, d: JsonDict) -> UnifiedUsage: return cls( input_tokens=d.get('input_tokens', 0), output_tokens=d.get('output_tokens', 0), total_tokens=d.get('total_tokens', 0), ) # ═══════════════════════════════════════════════════════════ # 转换器接口 # ═══════════════════════════════════════════════════════════ class OutboundTransformer(Protocol): """出站转换器:将 CC 中间格式转换为上游后端格式。 所有后端(OpenAI Chat / Responses / Anthropic / Gemini)各实现一套, 内部复用各自现有的适配器函数。 """ def build_request(self, payload: JsonDict) -> JsonDict: """将 CC 格式请求体转换为上游格式请求体。""" ... def build_url(self, ctx: Any) -> str: """根据路由上下文构建上游请求 URL。""" ... def build_headers(self, ctx: Any) -> JsonDict: """根据路由上下文构建上游请求头。""" ... def parse_response(self, raw: JsonDict) -> JsonDict: """将上游非流式响应转换回 CC 格式。""" ... def create_stream_processor(self) -> StreamProcessor: """创建该后端对应的流式事件处理器。""" ... class StreamProcessor(Protocol): """流式事件处理器接口。 每个后端的 SSE 格式不同,StreamProcessor 封装了具体的迭代与转换逻辑, 让通用流式处理器不必关心后端差异。 """ def iter_events(self, response: Any) -> Iterator: """从上游 HTTP 响应中迭代原始事件。""" ... def process_event(self, event: Any) -> list: """将单个上游事件转换为输出项列表。 返回值通常是 list[JsonDict](CC chunk), 但 Anthropic→Responses 路径返回 list[str](SSE 字符串)。 """ ... def extract_usage(self, event: Any) -> JsonDict | None: """从上游事件中提取用量信息(如果有的话)。""" ... def finalize(self) -> list: """流结束时产出的收尾项。""" ... class ClientFormatter(Protocol): """客户端响应格式化器。 根据客户端期望的 API 格式(CC 或 Responses),将通用的处理结果 格式化为最终返回给客户端的形态。 """ def format_response(self, cc_response: JsonDict, model: str) -> JsonDict: """格式化非流式响应。""" ... def wrap_stream_item(self, item: Any) -> str: """将单个流式输出项包装为 SSE 字符串。""" ... def format_error(self, message: str) -> str: """构造流式错误消息。""" ... def format_done(self) -> str | None: """构造流结束标记(CC 返回 [DONE],Responses 返回 None)。""" ... def start_events(self) -> list[str]: """流开始前的初始事件(Responses 返回 response.created)。""" ... @property def usage_input_key(self) -> str: """usage 中输入令牌的字段名。""" ... @property def usage_output_key(self) -> str: """usage 中输出令牌的字段名。""" ... # ═══════════════════════════════════════════════════════════ # 通用请求/响应处理器 # ═══════════════════════════════════════════════════════════ def _dbg(message: str) -> None: if settings.get_debug_mode() in ('simple', 'verbose'): logger.info('[通用调试] %s', message) def extract_responses_usage(event_data: JsonDict) -> JsonDict | None: """从原生 Responses 事件中提取 usage(公共辅助)。""" if not isinstance(event_data, dict): return None usage = event_data.get('usage') if isinstance(usage, dict): return usage response_obj = event_data.get('response') if isinstance(response_obj, dict): nested_usage = response_obj.get('usage') if isinstance(nested_usage, dict): return nested_usage return None def handle_non_stream( ctx: Any, outbound: OutboundTransformer, client_fmt: ClientFormatter, payload: JsonDict, turn: JsonDict | None, ) -> Response: """通用非流式处理器。 替代 chat.py 和 responses.py 中的 8 个 _handle_xxx_non_stream 函数。 """ from routes.common import apply_body_modifications, apply_header_modifications, log_usage upstream_payload = outbound.build_request(payload) url = outbound.build_url(ctx) headers = outbound.build_headers(ctx) upstream_payload = apply_body_modifications(upstream_payload, ctx.body_modifications) headers = apply_header_modifications(headers, ctx.header_modifications) upstream_payload['stream'] = False attach_upstream_request(turn, upstream_payload, headers) resp, err = forward_request(url, headers, upstream_payload) if err: attach_error(turn, {'stage': 'forward_request', 'message': 'upstream request failed'}) finalize_turn(turn) return err raw = resp.json() attach_upstream_response(turn, raw) _dbg('上游原始响应=' + json.dumps(raw, ensure_ascii=False, default=str)[:1000]) cc_response = outbound.parse_response(raw) result = client_fmt.format_response(cc_response, ctx.client_model) _dbg('格式化后响应=' + json.dumps(result, ensure_ascii=False, default=str)[:1000]) usage_data = result.get('usage', {}) log_usage('通用', usage_data, input_key=client_fmt.usage_input_key, output_key=client_fmt.usage_output_key) usage_tracker.record( ctx.client_model, usage_data, input_key=client_fmt.usage_input_key, output_key=client_fmt.usage_output_key, ) attach_client_response(turn, result) finalize_turn(turn, usage=usage_data) return jsonify(result) def handle_stream( ctx: Any, outbound: OutboundTransformer, client_fmt: ClientFormatter, payload: JsonDict, turn: JsonDict | None, ) -> Response: """通用流式处理器。 替代 chat.py 和 responses.py 中的 8 个 _handle_xxx_stream 函数。 """ from routes.common import apply_body_modifications, apply_header_modifications upstream_payload = outbound.build_request(payload) url = outbound.build_url(ctx) headers = outbound.build_headers(ctx) upstream_payload = apply_body_modifications(upstream_payload, ctx.body_modifications) headers = apply_header_modifications(headers, ctx.header_modifications) upstream_payload['stream'] = True processor = outbound.create_stream_processor() def generate(): for start_evt in client_fmt.start_events(): yield start_evt attach_upstream_request(turn, upstream_payload, headers) resp, err = forward_request(url, headers, upstream_payload, stream=True) if err: attach_error(turn, {'stage': 'forward_request', 'message': str(err)}) set_stream_summary(turn, {'status': 'error'}) finalize_turn(turn) yield client_fmt.format_error(str(err)) return event_count = 0 client_items: list[str] = [] last_usage: JsonDict | None = None for event in processor.iter_events(resp): append_upstream_event(turn, {'type': 'upstream_event', 'data': event}) extracted = processor.extract_usage(event) if extracted is not None: last_usage = extracted if event_count < 10: _dbg( f'上游事件#{event_count}=' + json.dumps(event, ensure_ascii=False, default=str)[:500] ) for chunk in processor.process_event(event): if isinstance(chunk, dict): chunk['model'] = ctx.client_model wrapped = client_fmt.wrap_stream_item(chunk) client_items.append(wrapped) append_client_event(turn, {'type': 'stream_item', 'data': chunk}) if event_count < 10: _dbg( f'返回片段#{event_count}=' + json.dumps(chunk, ensure_ascii=False, default=str)[:500] ) yield wrapped event_count += 1 for chunk in processor.finalize(): if isinstance(chunk, dict): chunk['model'] = ctx.client_model wrapped = client_fmt.wrap_stream_item(chunk) client_items.append(wrapped) append_client_event(turn, {'type': 'stream_item', 'data': chunk}) yield wrapped done = client_fmt.format_done() if done: append_client_event(turn, {'type': 'done'}) yield done _dbg(f'流式响应结束,共 {event_count} 个事件') usage_tracker.record( ctx.client_model, last_usage, input_key=client_fmt.usage_input_key, output_key=client_fmt.usage_output_key, ) set_stream_summary(turn, { 'event_count': event_count, 'client_item_count': len(client_items), 'usage': last_usage, }) attach_client_response(turn, { 'type': 'stream.summary', 'model': ctx.client_model, 'event_count': len(client_items), 'usage': last_usage, }) finalize_turn(turn, usage=last_usage) return sse_response(generate())