openai_adapter_test.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. package workflow
  2. import (
  3. "encoding/json"
  4. "testing"
  5. )
  6. func TestIsAnthropicModel(t *testing.T) {
  7. tests := []struct {
  8. model string
  9. expected bool
  10. }{
  11. {"claude-3-opus-20240229", true},
  12. {"claude-3-sonnet", true},
  13. {"claude-3-haiku", true},
  14. {"anthropic/claude-3-opus", true},
  15. {"Claude-3-Opus", true},
  16. {"gpt-4", false},
  17. {"gpt-4-turbo", false},
  18. {"gemini-pro", false},
  19. {"mistral-large", false},
  20. }
  21. for _, tc := range tests {
  22. t.Run(tc.model, func(t *testing.T) {
  23. if got := isAnthropicModel(tc.model); got != tc.expected {
  24. t.Errorf("isAnthropicModel(%q) = %v, want %v", tc.model, got, tc.expected)
  25. }
  26. })
  27. }
  28. }
  29. func TestBuildRequestModelOverride(t *testing.T) {
  30. t.Run("ConfigModelOverridesParamModel", func(t *testing.T) {
  31. adapter := NewOpenAIAdapter(OpenAIConfig{Model: "claude-3-opus"})
  32. params := map[string]interface{}{
  33. "model": "gpt-4",
  34. "messages": []interface{}{},
  35. }
  36. req, err := adapter.buildRequest(params)
  37. if err != nil {
  38. t.Fatalf("buildRequest failed: %v", err)
  39. }
  40. if req.Model != "claude-3-opus" {
  41. t.Errorf("Expected model 'claude-3-opus', got %q", req.Model)
  42. }
  43. })
  44. t.Run("NoOverrideUsesParamModel", func(t *testing.T) {
  45. adapter := NewOpenAIAdapter(OpenAIConfig{})
  46. params := map[string]interface{}{
  47. "model": "gpt-4",
  48. "messages": []interface{}{},
  49. }
  50. req, err := adapter.buildRequest(params)
  51. if err != nil {
  52. t.Fatalf("buildRequest failed: %v", err)
  53. }
  54. if req.Model != "gpt-4" {
  55. t.Errorf("Expected model 'gpt-4', got %q", req.Model)
  56. }
  57. })
  58. }
  59. func TestBuildRequestCacheControlOverride(t *testing.T) {
  60. boolPtr := func(b bool) *bool { return &b }
  61. t.Run("ConfigCacheControlOverridesParam", func(t *testing.T) {
  62. adapter := NewOpenAIAdapter(OpenAIConfig{CacheControl: boolPtr(true)})
  63. params := map[string]interface{}{
  64. "model": "claude-3-opus-20240229",
  65. "cache_control": false,
  66. "messages": []interface{}{
  67. map[string]interface{}{"role": "system", "content": "You are helpful."},
  68. },
  69. }
  70. req, err := adapter.buildRequest(params)
  71. if err != nil {
  72. t.Fatalf("buildRequest failed: %v", err)
  73. }
  74. if req.Messages[0].CacheControl == nil {
  75. t.Fatal("Expected cache_control on system message when config override is true")
  76. }
  77. })
  78. t.Run("ConfigCacheControlDisablesParam", func(t *testing.T) {
  79. adapter := NewOpenAIAdapter(OpenAIConfig{CacheControl: boolPtr(false)})
  80. params := map[string]interface{}{
  81. "model": "claude-3-opus-20240229",
  82. "cache_control": true,
  83. "messages": []interface{}{
  84. map[string]interface{}{"role": "system", "content": "You are helpful."},
  85. },
  86. }
  87. req, err := adapter.buildRequest(params)
  88. if err != nil {
  89. t.Fatalf("buildRequest failed: %v", err)
  90. }
  91. if req.Messages[0].CacheControl != nil {
  92. t.Error("Expected no cache_control when config override is false")
  93. }
  94. })
  95. t.Run("NilConfigFallsToParam", func(t *testing.T) {
  96. adapter := NewOpenAIAdapter(OpenAIConfig{})
  97. params := map[string]interface{}{
  98. "model": "claude-3-opus-20240229",
  99. "cache_control": true,
  100. "messages": []interface{}{
  101. map[string]interface{}{"role": "system", "content": "You are helpful."},
  102. },
  103. }
  104. req, err := adapter.buildRequest(params)
  105. if err != nil {
  106. t.Fatalf("buildRequest failed: %v", err)
  107. }
  108. if req.Messages[0].CacheControl == nil {
  109. t.Fatal("Expected cache_control from param when config is nil")
  110. }
  111. })
  112. }
  113. func TestBuildRequestAPIKey(t *testing.T) {
  114. t.Run("RequestAPIKeySetInBody", func(t *testing.T) {
  115. adapter := NewOpenAIAdapter(OpenAIConfig{RequestAPIKey: "sk-user-key-123"})
  116. params := map[string]interface{}{
  117. "model": "gpt-4",
  118. "messages": []interface{}{},
  119. }
  120. req, err := adapter.buildRequest(params)
  121. if err != nil {
  122. t.Fatalf("buildRequest failed: %v", err)
  123. }
  124. if req.APIKey != "sk-user-key-123" {
  125. t.Errorf("Expected api_key 'sk-user-key-123', got %q", req.APIKey)
  126. }
  127. // Verify JSON output includes api_key
  128. body, _ := json.Marshal(req)
  129. if !contains(string(body), `"api_key":"sk-user-key-123"`) {
  130. t.Errorf("Expected api_key in JSON body, got: %s", string(body))
  131. }
  132. })
  133. t.Run("NoRequestAPIKeyOmittedFromBody", func(t *testing.T) {
  134. adapter := NewOpenAIAdapter(OpenAIConfig{})
  135. params := map[string]interface{}{
  136. "model": "gpt-4",
  137. "messages": []interface{}{},
  138. }
  139. req, err := adapter.buildRequest(params)
  140. if err != nil {
  141. t.Fatalf("buildRequest failed: %v", err)
  142. }
  143. if req.APIKey != "" {
  144. t.Errorf("Expected empty api_key, got %q", req.APIKey)
  145. }
  146. // Verify JSON output does NOT include api_key
  147. body, _ := json.Marshal(req)
  148. if contains(string(body), `"api_key"`) {
  149. t.Errorf("Expected no api_key in JSON body, got: %s", string(body))
  150. }
  151. })
  152. }
  153. func TestBuildRequestCacheControl(t *testing.T) {
  154. adapter := NewOpenAIAdapter(OpenAIConfig{})
  155. t.Run("AnthropicWithCacheControl", func(t *testing.T) {
  156. params := map[string]interface{}{
  157. "model": "claude-3-opus-20240229",
  158. "cache_control": true,
  159. "messages": []interface{}{
  160. map[string]interface{}{"role": "system", "content": "You are helpful."},
  161. map[string]interface{}{"role": "user", "content": "Hello"},
  162. },
  163. }
  164. req, err := adapter.buildRequest(params)
  165. if err != nil {
  166. t.Fatalf("buildRequest failed: %v", err)
  167. }
  168. // System message should have cache_control
  169. if req.Messages[0].CacheControl == nil {
  170. t.Fatal("Expected cache_control on system message")
  171. }
  172. if req.Messages[0].CacheControl.Type != "ephemeral" {
  173. t.Errorf("Expected cache_control type 'ephemeral', got %q", req.Messages[0].CacheControl.Type)
  174. }
  175. // User message should NOT have cache_control
  176. if req.Messages[1].CacheControl != nil {
  177. t.Error("Expected no cache_control on user message")
  178. }
  179. // Verify JSON output includes cache_control
  180. body, _ := json.Marshal(req)
  181. if !contains(string(body), `"cache_control":{"type":"ephemeral"}`) {
  182. t.Errorf("Expected cache_control in JSON, got: %s", string(body))
  183. }
  184. })
  185. t.Run("AnthropicWithoutCacheControl", func(t *testing.T) {
  186. params := map[string]interface{}{
  187. "model": "claude-3-opus-20240229",
  188. "messages": []interface{}{
  189. map[string]interface{}{"role": "system", "content": "You are helpful."},
  190. },
  191. }
  192. req, err := adapter.buildRequest(params)
  193. if err != nil {
  194. t.Fatalf("buildRequest failed: %v", err)
  195. }
  196. if req.Messages[0].CacheControl != nil {
  197. t.Error("Expected no cache_control when option not set")
  198. }
  199. })
  200. t.Run("NonAnthropicWithCacheControl", func(t *testing.T) {
  201. params := map[string]interface{}{
  202. "model": "gpt-4",
  203. "cache_control": true,
  204. "messages": []interface{}{
  205. map[string]interface{}{"role": "system", "content": "You are helpful."},
  206. },
  207. }
  208. req, err := adapter.buildRequest(params)
  209. if err != nil {
  210. t.Fatalf("buildRequest failed: %v", err)
  211. }
  212. if req.Messages[0].CacheControl != nil {
  213. t.Error("Expected no cache_control for non-Anthropic model")
  214. }
  215. })
  216. }
  217. func contains(s, substr string) bool {
  218. return len(s) >= len(substr) && searchString(s, substr)
  219. }
  // searchString reports whether substr appears in s. It compares substr
  // byte-for-byte against each len(substr)-wide window of s, left to right,
  // and stops at the first match.
  // It returns true for an empty substr and false when substr is longer than s.
  func searchString(s, substr string) bool {
  	width := len(substr)
  	for start := 0; start+width <= len(s); start++ {
  		if s[start:start+width] == substr {
  			return true
  		}
  	}
  	return false
  }