api2cursor/routes/common.py
2026-03-22 08:24:19 +08:00

345 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""路由层公共辅助
收敛多个数据面路由都会用到的上下文解析、上游目标构造、日志输出和
SSE 消息拼装逻辑,避免 `chat.py` 和 `responses.py` 各自维护重复实现。
"""
from __future__ import annotations
from dataclasses import dataclass
import json
import logging
from typing import Any
import settings
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RouteContext:
"""数据面路由使用的标准请求上下文。
路由层会先根据客户端模型名解析出统一上下文,后续处理函数只需要关心
上游模型、后端类型、目标地址、鉴权信息、流式标记和自定义指令,
而不必重复访问配置层。
"""
client_model: str
upstream_model: str
backend: str
target_url: str
api_key: str
is_stream: bool
custom_instructions: str
instructions_position: str
body_modifications: dict
header_modifications: dict
def build_route_context(client_model: str, is_stream: bool) -> RouteContext:
"""解析模型映射,得到当前请求的统一路由上下文。"""
mapping = settings.resolve_model(client_model)
return RouteContext(
client_model=client_model,
upstream_model=mapping['upstream_model'],
backend=mapping['backend'],
target_url=mapping['target_url'],
api_key=mapping['api_key'],
is_stream=is_stream,
custom_instructions=mapping.get('custom_instructions', ''),
instructions_position=mapping.get('instructions_position', 'prepend'),
body_modifications=mapping.get('body_modifications', {}),
header_modifications=mapping.get('header_modifications', {}),
)
def log_route_context(route_name: str, ctx: RouteContext, *, extra: str = '') -> None:
"""统一输出路由级日志,避免不同入口的日志格式逐渐漂移。"""
parts = [
f'[{route_name}]',
f'模型={ctx.client_model}',
f'上游模型={ctx.upstream_model}',
f'后端={ctx.backend}',
f'流式={ctx.is_stream}',
]
if extra:
parts.append(extra)
logger.info(' '.join(parts))
def log_usage(
route_name: str,
usage: dict[str, Any],
*,
input_key: str,
output_key: str,
) -> None:
"""统一输出令牌统计日志。
不同协议对 usage 字段命名不一致,这里只接收字段名,不在调用方重复拼接日志文案。
"""
logger.info(
'[%s] 请求完成 输入令牌=%s 输出令牌=%s',
route_name,
usage.get(input_key, 0),
usage.get(output_key, 0),
)
def sse_data_message(data: Any) -> str:
"""构造仅包含 data 的 SSE 消息。"""
payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False)
return f'data: {payload}\n\n'
def sse_event_message(event_type: str, data: Any) -> str:
"""构造带 event 名称的 SSE 消息。"""
payload = data if isinstance(data, str) else json.dumps(data, ensure_ascii=False)
return f'event: {event_type}\ndata: {payload}\n\n'
def responses_error_event(message: str) -> str:
"""构造 Responses 流式接口使用的错误事件。"""
return sse_event_message('error', {'error': message})
# ─── 自定义指令注入 ──────────────────────────────
def _merge_text(custom: str, existing: str, position: str) -> str:
"""根据 position 决定自定义指令与原有内容的拼接顺序。"""
if not existing:
return custom
if position == 'append':
return existing + '\n\n' + custom
return custom + '\n\n' + existing
def inject_instructions_cc(payload: dict[str, Any], instructions: str, position: str = 'prepend') -> dict[str, Any]:
"""向 Chat Completions 请求注入自定义指令。
position='prepend' 时放在 system 消息开头,'append' 时放在末尾。
"""
if not instructions:
return payload
messages = payload.get('messages', [])
if messages and messages[0].get('role') == 'system':
first = messages[0]
original = first.get('content') or ''
first['content'] = _merge_text(instructions, original, position)
else:
messages.insert(0, {'role': 'system', 'content': instructions})
payload['messages'] = messages
logger.info('已注入自定义指令到 CC system 消息 (%d 字符, %s)', len(instructions), position)
return payload
def inject_instructions_responses(payload: dict[str, Any], instructions: str, position: str = 'prepend') -> dict[str, Any]:
"""向 Responses 请求注入自定义指令(写入 instructions 字段)。
position='prepend' 时放在 instructions 开头,'append' 时放在末尾。
"""
if not instructions:
return payload
existing = payload.get('instructions') or ''
payload['instructions'] = _merge_text(instructions, existing, position)
logger.info('已注入自定义指令到 Responses instructions (%d 字符, %s)', len(instructions), position)
return payload
def inject_instructions_anthropic(payload: dict[str, Any], instructions: str, position: str = 'prepend') -> dict[str, Any]:
"""向 Anthropic Messages 请求注入自定义指令(写入 system 字段)。
position='prepend' 时放在 system 开头,'append' 时放在末尾。
"""
if not instructions:
return payload
existing = payload.get('system') or ''
if isinstance(existing, list):
existing = '\n'.join(
block.get('text', '') for block in existing
if isinstance(block, dict) and block.get('type') == 'text'
)
payload['system'] = _merge_text(instructions, existing, position)
logger.info('已注入自定义指令到 Anthropic system (%d 字符, %s)', len(instructions), position)
return payload
# ─── Body / Header 修改 ──────────────────────────
def apply_body_modifications(payload: dict[str, Any], modifications: dict[str, Any]) -> dict[str, Any]:
"""对转发请求体应用字段级修改。
规则与 CursorProxy 一致:值为 null 的字段会被删除,其余字段设置/覆盖。
"""
if not modifications:
return payload
for key, value in modifications.items():
if value is None:
payload.pop(key, None)
else:
payload[key] = value
logger.info('已应用 body_modifications: %s', list(modifications.keys()))
return payload
def apply_header_modifications(headers: dict[str, str], modifications: dict[str, Any]) -> dict[str, str]:
"""对转发请求头应用字段级修改。
规则同 body值为 null 删除,其余设置/覆盖。
"""
if not modifications:
return headers
for key, value in modifications.items():
if value is None:
headers.pop(key, None)
else:
headers[key] = str(value)
logger.info('已应用 header_modifications: %s', list(modifications.keys()))
return headers
# ═══════════════════════════════════════════════════════════
# 后端注册表 + ClientFormatter 实现
# ═══════════════════════════════════════════════════════════
def get_outbound(backend: str):
"""根据后端类型获取对应的 OutboundTransformer 实例。"""
from adapters.cc_anthropic_adapter import AnthropicOutbound
from adapters.cc_gemini_adapter import GeminiOutbound
from adapters.openai_compat_fixer import OpenAIChatOutbound
from adapters.responses_cc_adapter import ResponsesOutbound
registry = {
'openai': OpenAIChatOutbound,
'anthropic': AnthropicOutbound,
'gemini': GeminiOutbound,
'responses': ResponsesOutbound,
}
cls = registry.get(backend, OpenAIChatOutbound)
return cls()
class CCClientFormatter:
"""Chat Completions 客户端格式化器。
将通用处理结果格式化为 OpenAI Chat Completions 格式,
供 /v1/chat/completions 端点使用。
"""
def format_response(self, cc_response: dict[str, Any], model: str) -> dict[str, Any]:
cc_response['model'] = model
return cc_response
def wrap_stream_item(self, item: Any) -> str:
payload = item if isinstance(item, str) else json.dumps(item, ensure_ascii=False)
return f'data: {payload}\n\n'
def format_error(self, message: str) -> str:
return sse_data_message({'error': {'message': message, 'type': 'upstream_error'}})
def format_done(self) -> str | None:
return sse_data_message('[DONE]')
def start_events(self) -> list[str]:
return []
@property
def usage_input_key(self) -> str:
return 'prompt_tokens'
@property
def usage_output_key(self) -> str:
return 'completion_tokens'
class ResponsesClientFormatter:
"""Responses API 客户端格式化器。
将通用处理结果格式化为 OpenAI Responses 格式,
供 /v1/responses 端点使用。
流式场景使用 ResponsesStreamConverter 做 CC chunk → Responses SSE 转换。
"""
def __init__(self, model: str = ''):
from adapters.responses_cc_adapter import ResponsesStreamConverter, cc_to_responses
self._model = model
self._converter = ResponsesStreamConverter(model=model)
self._cc_to_responses = cc_to_responses
def format_response(self, cc_response: dict[str, Any], model: str) -> dict[str, Any]:
return self._cc_to_responses(cc_response, model)
def wrap_stream_item(self, item: Any) -> str:
if isinstance(item, str):
return item
events = self._converter.process_cc_chunk(item)
return ''.join(events)
def format_error(self, message: str) -> str:
return responses_error_event(message)
def format_done(self) -> str | None:
events = self._converter.finalize()
return ''.join(events) if events else None
def start_events(self) -> list[str]:
return self._converter.start_events()
@property
def usage_input_key(self) -> str:
return 'input_tokens'
@property
def usage_output_key(self) -> str:
return 'output_tokens'
class ResponsesPassthroughFormatter:
"""Responses 透传格式化器。
当后端本身就是 Responses 格式时使用,做轻量模型名改写。
"""
def __init__(self, model: str = ''):
self._model = model
def format_response(self, response_data: dict[str, Any], model: str) -> dict[str, Any]:
response_data['model'] = model
return response_data
def wrap_stream_item(self, item: Any) -> str:
if isinstance(item, str):
return item
event_type = item.pop('_sse_event_type', None)
if event_type:
return f'event: {event_type}\ndata: {json.dumps(item, ensure_ascii=False)}\n\n'
return f'data: {json.dumps(item, ensure_ascii=False)}\n\n'
def format_error(self, message: str) -> str:
return responses_error_event(message)
def format_done(self) -> str | None:
return None
def start_events(self) -> list[str]:
return []
@property
def usage_input_key(self) -> str:
return 'input_tokens'
@property
def usage_output_key(self) -> str:
return 'output_tokens'