anthropic_adapter.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. package workflow
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "encoding/json"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "strings"
  11. "time"
  12. )
  // schemaToInstruction turns a JSON Schema object into a strict system-prompt
// instruction that tells the model to emit bare JSON.
//
// The imperative "enforced by runtime" wording is deliberate: it reduces the
// model's tendency to wrap output in markdown fences or add explanatory text.
// When schema is non-nil and can be marshalled, it is appended pretty-printed
// after the rules; otherwise only the rules are returned.
func schemaToInstruction(schema map[string]interface{}) string {
	rules := [...]string{
		"OUTPUT FORMAT REQUIREMENT (enforced by runtime, do not deviate):",
		"You must output ONLY a raw JSON object or array. No other text is allowed.",
		"- Start your response with { or [, end with } or ]",
		"- Do NOT wrap in markdown code fences (no ```json, no ```)",
		"- Do NOT include any preamble, explanation, or commentary before or after the JSON",
		"- Use double quotes for all keys and string values",
		"- Do not include trailing commas",
		"- The output must be directly parseable by JSON.parse() with zero preprocessing",
	}
	instruction := rules[0]
	for _, rule := range rules[1:] {
		instruction += "\n" + rule
	}

	if schema != nil {
		// An unmarshalable schema degrades to the bare rules rather than failing.
		if pretty, err := json.MarshalIndent(schema, "", " "); err == nil {
			return instruction + "\n\nJSON Schema to conform to:\n" + string(pretty)
		}
	}
	return instruction
}
  // AnthropicAdapter calls the Anthropic Messages API directly (no OpenAI proxy needed).
//
// Activated automatically in cmd/ when LLM_URL contains "anthropic.com".
//
// Config (.env):
//
//	LLM_URL=https://api.anthropic.com
//	LLM_KEY=sk-ant-api03-...
//	LLM_MODEL=claude-3-5-sonnet-20241022
//
// Construct it with NewAnthropicAdapter, which fills in defaults for empty
// configuration fields.
type AnthropicAdapter struct {
	baseURL string       // API root without a trailing "/"; "/v1/messages" is appended per request
	apiKey  string       // sent as the x-api-key header
	model   string       // default model; a non-empty params["model"] overrides it per call
	client  *http.Client // shared client; its Timeout bounds each whole request, including a streamed body
}
  // AnthropicConfig holds configuration for AnthropicAdapter.
// NewAnthropicAdapter replaces an empty Model or BaseURL and a zero Timeout
// with defaults; APIKey is used as given.
type AnthropicConfig struct {
	APIKey  string        // Anthropic API key (sk-ant-...)
	Model   string        // e.g. "claude-3-5-sonnet-20241022" (also the default when empty)
	BaseURL string        // default: "https://api.anthropic.com"; a trailing "/" is trimmed
	Timeout time.Duration // default: 5 minutes (applied when zero)
}
  56. // NewAnthropicAdapter creates a new AnthropicAdapter.
  57. func NewAnthropicAdapter(cfg AnthropicConfig) *AnthropicAdapter {
  58. baseURL := strings.TrimSuffix(cfg.BaseURL, "/")
  59. if baseURL == "" {
  60. baseURL = "https://api.anthropic.com"
  61. }
  62. model := cfg.Model
  63. if model == "" {
  64. model = "claude-3-5-sonnet-20241022"
  65. }
  66. timeout := cfg.Timeout
  67. if timeout == 0 {
  68. timeout = 5 * time.Minute
  69. }
  70. return &AnthropicAdapter{
  71. baseURL: baseURL,
  72. apiKey: cfg.APIKey,
  73. model: model,
  74. client: &http.Client{Timeout: timeout},
  75. }
  76. }
  // anthropicMsg is a message in the Anthropic API format.
// Only plain-text content is modelled. Role is passed through from the caller;
// system prompts are not sent as messages but via anthropicReq.System.
type anthropicMsg struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
  // anthropicReq is the POST body for /v1/messages.
type anthropicReq struct {
	Model     string         `json:"model"`
	MaxTokens int            `json:"max_tokens"` // always sent (no omitempty)
	Messages  []anthropicMsg `json:"messages"`
	System    string         `json:"system,omitempty"` // top-level system prompt; omitted when empty
	Stream    bool           `json:"stream,omitempty"` // true requests an SSE response; omitted when false
}
  // anthropicResp is the response from /v1/messages (non-streaming).
// Fields not declared here are ignored by encoding/json when decoding.
type anthropicResp struct {
	ID      string `json:"id"`
	Model   string `json:"model"`
	Content []struct {
		Type string `json:"type"` // callNonStreaming only consumes blocks of type "text"
		Text string `json:"text"`
	} `json:"content"`
	StopReason string `json:"stop_reason"`
	Usage      struct {
		InputTokens  int `json:"input_tokens"`
		OutputTokens int `json:"output_tokens"`
	} `json:"usage"`
	// Error is non-nil only when the decoded body contains an "error" object.
	Error *struct {
		Type    string `json:"type"`
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
  // anthropicCallParams holds parsed parameters for a single Call invocation.
// It is produced by parseCallParams.
type anthropicCallParams struct {
	model           string                 // params["model"] when non-empty, else the adapter default
	maxTokens       int                    // max_tokens sent to the API (4096 unless params override it)
	msgs            []anthropicMsg         // non-system messages, in caller order
	system          string                 // system prompt, including any injected schema instruction
	shouldParseJSON bool                   // output_config requested json_schema; Call decodes "content" as JSON
	jsonSchema      map[string]interface{} // output_config.format.schema, if provided
}
  117. // Call implements LLMAdapter. Supports both streaming (SSE) and non-streaming modes.
  118. func (a *AnthropicAdapter) Call(ctx context.Context, params map[string]interface{}, stream chan<- string) (map[string]interface{}, error) {
  119. p, err := a.parseCallParams(params)
  120. if err != nil {
  121. return nil, err
  122. }
  123. // Check if streaming is requested
  124. isStreaming := false
  125. if streamVal, ok := params["stream"].(bool); ok {
  126. isStreaming = streamVal
  127. }
  128. var result map[string]interface{}
  129. if isStreaming && stream != nil {
  130. result, err = a.callStreaming(ctx, p, stream)
  131. } else {
  132. result, err = a.callNonStreaming(ctx, p)
  133. }
  134. if err != nil {
  135. return nil, err
  136. }
  137. // Parse structured JSON output if requested
  138. if p.shouldParseJSON {
  139. if err := a.parseStructuredJSON(result); err != nil {
  140. return nil, err
  141. }
  142. }
  143. return result, nil
  144. }
  145. // parseCallParams extracts and validates parameters from the generic params map.
  146. func (a *AnthropicAdapter) parseCallParams(params map[string]interface{}) (*anthropicCallParams, error) {
  147. p := &anthropicCallParams{
  148. model: a.model,
  149. maxTokens: 4096,
  150. }
  151. if m, ok := params["model"].(string); ok && m != "" {
  152. p.model = m
  153. }
  154. switch v := params["max_tokens"].(type) {
  155. case int:
  156. p.maxTokens = v
  157. case float64:
  158. p.maxTokens = int(v)
  159. }
  160. // Detect structured JSON request from output_config
  161. if outputConfig, ok := params["output_config"].(map[string]interface{}); ok {
  162. if format, ok := outputConfig["format"].(map[string]interface{}); ok {
  163. if ftype, ok := format["type"].(string); ok && ftype == "json_schema" {
  164. p.shouldParseJSON = true
  165. if s, ok := format["schema"].(map[string]interface{}); ok {
  166. p.jsonSchema = s
  167. }
  168. }
  169. }
  170. }
  171. // Parse messages + system
  172. if raw, ok := params["messages"].([]interface{}); ok {
  173. for _, item := range raw {
  174. m, ok := item.(map[string]interface{})
  175. if !ok {
  176. continue
  177. }
  178. role, _ := m["role"].(string)
  179. content, _ := m["content"].(string)
  180. if role == "system" {
  181. p.system = content
  182. } else {
  183. p.msgs = append(p.msgs, anthropicMsg{Role: role, Content: content})
  184. }
  185. }
  186. } else if prompt, ok := params["prompt"].(string); ok {
  187. p.msgs = []anthropicMsg{{Role: "user", Content: prompt}}
  188. } else {
  189. return nil, fmt.Errorf("AnthropicAdapter: params must include 'messages' or 'prompt'")
  190. }
  191. if len(p.msgs) == 0 {
  192. return nil, fmt.Errorf("AnthropicAdapter: no user/assistant messages")
  193. }
  194. // Inject schema instruction into system prompt
  195. if p.shouldParseJSON {
  196. instruction := schemaToInstruction(p.jsonSchema)
  197. if p.system == "" {
  198. p.system = instruction
  199. } else {
  200. p.system = p.system + "\n\n" + instruction
  201. }
  202. }
  203. return p, nil
  204. }
  205. // buildHTTPRequest creates the HTTP request for the Anthropic API.
  206. func (a *AnthropicAdapter) buildHTTPRequest(ctx context.Context, p *anthropicCallParams, streaming bool) (*http.Request, error) {
  207. reqBody := anthropicReq{
  208. Model: p.model,
  209. MaxTokens: p.maxTokens,
  210. Messages: p.msgs,
  211. System: p.system,
  212. Stream: streaming,
  213. }
  214. bodyBytes, err := json.Marshal(reqBody)
  215. if err != nil {
  216. return nil, fmt.Errorf("AnthropicAdapter: marshal: %w", err)
  217. }
  218. httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
  219. a.baseURL+"/v1/messages", bytes.NewReader(bodyBytes))
  220. if err != nil {
  221. return nil, fmt.Errorf("AnthropicAdapter: new request: %w", err)
  222. }
  223. httpReq.Header.Set("Content-Type", "application/json")
  224. httpReq.Header.Set("x-api-key", a.apiKey)
  225. httpReq.Header.Set("anthropic-version", "2023-06-01")
  226. if streaming {
  227. httpReq.Header.Set("Accept", "text/event-stream")
  228. }
  229. return httpReq, nil
  230. }
  231. // callNonStreaming performs a standard (non-streaming) API call.
  232. func (a *AnthropicAdapter) callNonStreaming(ctx context.Context, p *anthropicCallParams) (map[string]interface{}, error) {
  233. httpReq, err := a.buildHTTPRequest(ctx, p, false)
  234. if err != nil {
  235. return nil, err
  236. }
  237. resp, err := a.client.Do(httpReq)
  238. if err != nil {
  239. return nil, fmt.Errorf("AnthropicAdapter: HTTP: %w", err)
  240. }
  241. defer resp.Body.Close()
  242. respBytes, err := io.ReadAll(resp.Body)
  243. if err != nil {
  244. return nil, fmt.Errorf("AnthropicAdapter: read body: %w", err)
  245. }
  246. if resp.StatusCode != http.StatusOK {
  247. return nil, parseAnthropicError(resp.StatusCode, respBytes, p.model)
  248. }
  249. var apiResp anthropicResp
  250. if err := json.Unmarshal(respBytes, &apiResp); err != nil {
  251. return nil, fmt.Errorf("AnthropicAdapter: parse response: %w", err)
  252. }
  253. if apiResp.Error != nil {
  254. return nil, &LLMError{
  255. Type: apiResp.Error.Type,
  256. Message: fmt.Sprintf("AnthropicAdapter: API error: %s", apiResp.Error.Message),
  257. Model: apiResp.Model,
  258. }
  259. }
  260. // Extract content text
  261. var text strings.Builder
  262. for _, block := range apiResp.Content {
  263. if block.Type == "text" {
  264. text.WriteString(block.Text)
  265. }
  266. }
  267. return map[string]interface{}{
  268. "content": text.String(),
  269. "model": apiResp.Model,
  270. "finish_reason": apiResp.StopReason,
  271. "response_id": apiResp.ID,
  272. "usage": map[string]interface{}{
  273. "prompt_tokens": apiResp.Usage.InputTokens,
  274. "completion_tokens": apiResp.Usage.OutputTokens,
  275. "total_tokens": apiResp.Usage.InputTokens + apiResp.Usage.OutputTokens,
  276. },
  277. }, nil
  278. }
  279. // callStreaming performs a streaming API call using Anthropic's SSE protocol.
  280. //
  281. // Anthropic SSE event types:
  282. // - message_start: contains message metadata (id, model, usage.input_tokens)
  283. // - content_block_delta: contains incremental text (delta.text)
  284. // - message_delta: contains stop_reason and usage.output_tokens
  285. // - message_stop: signals end of stream
  286. // - error: API error during streaming
  287. //
  288. // Channel contract (same as OpenAI adapter):
  289. // - Does not close the channel (engine is responsible)
  290. // - Uses select for non-blocking send
  291. // - All sends complete before Call returns (no goroutines)
  292. func (a *AnthropicAdapter) callStreaming(ctx context.Context, p *anthropicCallParams, stream chan<- string) (map[string]interface{}, error) {
  293. httpReq, err := a.buildHTTPRequest(ctx, p, true)
  294. if err != nil {
  295. return nil, err
  296. }
  297. resp, err := a.client.Do(httpReq)
  298. if err != nil {
  299. return nil, fmt.Errorf("AnthropicAdapter: HTTP: %w", err)
  300. }
  301. defer resp.Body.Close()
  302. if resp.StatusCode != http.StatusOK {
  303. bodyBytes, _ := io.ReadAll(resp.Body)
  304. return nil, parseAnthropicError(resp.StatusCode, bodyBytes, p.model)
  305. }
  306. // Process SSE stream
  307. var fullContent strings.Builder
  308. var model, responseID, stopReason string
  309. var inputTokens, outputTokens int
  310. scanner := bufio.NewScanner(resp.Body)
  311. scanner.Buffer(make([]byte, 64*1024), 1024*1024)
  312. var eventType string
  313. for scanner.Scan() {
  314. line := scanner.Text()
  315. // Track event type from "event:" lines
  316. if strings.HasPrefix(line, "event: ") {
  317. eventType = strings.TrimPrefix(line, "event: ")
  318. continue
  319. }
  320. // Skip empty lines, comments, non-data lines
  321. if line == "" || strings.HasPrefix(line, ":") || !strings.HasPrefix(line, "data: ") {
  322. continue
  323. }
  324. data := strings.TrimPrefix(line, "data: ")
  325. switch eventType {
  326. case "message_start":
  327. // {"type":"message_start","message":{"id":"...","model":"...","usage":{"input_tokens":N}}}
  328. var event struct {
  329. Message struct {
  330. ID string `json:"id"`
  331. Model string `json:"model"`
  332. Usage struct {
  333. InputTokens int `json:"input_tokens"`
  334. } `json:"usage"`
  335. } `json:"message"`
  336. }
  337. if json.Unmarshal([]byte(data), &event) == nil {
  338. responseID = event.Message.ID
  339. model = event.Message.Model
  340. inputTokens = event.Message.Usage.InputTokens
  341. }
  342. case "content_block_delta":
  343. // {"type":"content_block_delta","delta":{"type":"text_delta","text":"..."}}
  344. var event struct {
  345. Delta struct {
  346. Text string `json:"text"`
  347. } `json:"delta"`
  348. }
  349. if json.Unmarshal([]byte(data), &event) == nil && event.Delta.Text != "" {
  350. fullContent.WriteString(event.Delta.Text)
  351. select {
  352. case stream <- event.Delta.Text:
  353. case <-ctx.Done():
  354. return nil, ctx.Err()
  355. }
  356. }
  357. case "message_delta":
  358. // {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":N}}
  359. var event struct {
  360. Delta struct {
  361. StopReason string `json:"stop_reason"`
  362. } `json:"delta"`
  363. Usage struct {
  364. OutputTokens int `json:"output_tokens"`
  365. } `json:"usage"`
  366. }
  367. if json.Unmarshal([]byte(data), &event) == nil {
  368. stopReason = event.Delta.StopReason
  369. outputTokens = event.Usage.OutputTokens
  370. }
  371. case "message_stop":
  372. // End of stream — break out
  373. goto done
  374. case "error":
  375. // {"type":"error","error":{"type":"...","message":"..."}}
  376. var event struct {
  377. Error struct {
  378. Type string `json:"type"`
  379. Message string `json:"message"`
  380. } `json:"error"`
  381. }
  382. if json.Unmarshal([]byte(data), &event) == nil {
  383. return nil, &LLMError{
  384. Type: event.Error.Type,
  385. Message: fmt.Sprintf("AnthropicAdapter: stream error: %s", event.Error.Message),
  386. Model: model,
  387. }
  388. }
  389. return nil, fmt.Errorf("AnthropicAdapter: stream error (unparseable): %s", data)
  390. }
  391. eventType = "" // Reset for next event
  392. }
  393. done:
  394. if err := scanner.Err(); err != nil {
  395. return nil, fmt.Errorf("AnthropicAdapter: error reading stream: %w", err)
  396. }
  397. return map[string]interface{}{
  398. "content": fullContent.String(),
  399. "model": model,
  400. "finish_reason": stopReason,
  401. "response_id": responseID,
  402. "usage": map[string]interface{}{
  403. "prompt_tokens": inputTokens,
  404. "completion_tokens": outputTokens,
  405. "total_tokens": inputTokens + outputTokens,
  406. },
  407. }, nil
  408. }
  409. // parseStructuredJSON attempts to parse the "content" field of the result as JSON.
  410. // Called when output_config requested json_schema format.
  411. func (a *AnthropicAdapter) parseStructuredJSON(result map[string]interface{}) error {
  412. text, ok := result["content"].(string)
  413. if !ok || text == "" {
  414. return nil
  415. }
  416. // Strip markdown code fences if present (e.g. ```json ... ``` or ``` ... ```)
  417. raw := strings.TrimSpace(text)
  418. if strings.HasPrefix(raw, "```") {
  419. if idx := strings.Index(raw, "\n"); idx != -1 {
  420. raw = raw[idx+1:]
  421. }
  422. if idx := strings.LastIndex(raw, "```"); idx != -1 {
  423. raw = strings.TrimSpace(raw[:idx])
  424. }
  425. }
  426. var parsed interface{}
  427. if err := json.Unmarshal([]byte(raw), &parsed); err != nil {
  428. model, _ := result["model"].(string)
  429. return &LLMError{
  430. Type: "json_parse_error",
  431. Code: "JSON_PARSE_ERROR",
  432. Message: fmt.Sprintf("AnthropicAdapter: failed to parse structured output as JSON: %v", err),
  433. Retryable: false,
  434. Model: model,
  435. }
  436. }
  437. result["content"] = parsed
  438. return nil
  439. }
  440. // parseAnthropicError constructs a structured LLMError from an HTTP error response.
  441. func parseAnthropicError(statusCode int, body []byte, model string) *LLMError {
  442. llmErr := &LLMError{
  443. StatusCode: statusCode,
  444. Model: model,
  445. Message: string(body),
  446. }
  447. // Try to parse body as JSON for richer error info
  448. var errResp struct {
  449. Error struct {
  450. Type string `json:"type"`
  451. Message string `json:"message"`
  452. } `json:"error"`
  453. }
  454. if json.Unmarshal(body, &errResp) == nil && errResp.Error.Type != "" {
  455. llmErr.Type = errResp.Error.Type
  456. llmErr.Message = fmt.Sprintf("AnthropicAdapter: API error %s: %s", errResp.Error.Type, errResp.Error.Message)
  457. }
  458. // Classify retryability and error type based on status code
  459. switch {
  460. case statusCode == 429:
  461. llmErr.Type = "rate_limit_error"
  462. llmErr.Code = "RATE_LIMITED"
  463. llmErr.Retryable = true
  464. case statusCode == 529, statusCode == 503:
  465. llmErr.Type = "overloaded_error"
  466. llmErr.Code = "OVERLOADED"
  467. llmErr.Retryable = true
  468. case statusCode >= 500:
  469. llmErr.Type = "api_error"
  470. llmErr.Code = "SERVER_ERROR"
  471. llmErr.Retryable = true
  472. case statusCode == 401:
  473. llmErr.Type = "authentication_error"
  474. llmErr.Code = "AUTH_ERROR"
  475. llmErr.Retryable = false
  476. default:
  477. llmErr.Type = "invalid_request_error"
  478. llmErr.Code = "INVALID_REQUEST"
  479. llmErr.Retryable = false
  480. }
  481. return llmErr
  482. }