1package openai
2
3import (
4 "encoding/json"
5 "testing"
6
7 "charm.land/fantasy"
8 "github.com/stretchr/testify/require"
9)
10
11func TestPrepareParams_Store(t *testing.T) {
12 t.Parallel()
13
14 lm := testResponsesLM()
15 prompt := fantasy.Prompt{testTextMessage(fantasy.MessageRoleUser, "hello")}
16
17 tests := []struct {
18 name string
19 opts *ResponsesProviderOptions
20 wantStore bool
21 }{
22 {
23 name: "store true",
24 opts: &ResponsesProviderOptions{Store: new(true)},
25 wantStore: true,
26 },
27 {
28 name: "store false",
29 opts: &ResponsesProviderOptions{Store: new(false)},
30 wantStore: false,
31 },
32 {
33 name: "store default",
34 opts: &ResponsesProviderOptions{},
35 wantStore: false,
36 },
37 {
38 name: "no provider options",
39 opts: nil,
40 wantStore: false,
41 },
42 }
43
44 for _, tt := range tests {
45 t.Run(tt.name, func(t *testing.T) {
46 t.Parallel()
47
48 params, warnings, err := lm.prepareParams(testCall(prompt, tt.opts))
49 require.NoError(t, err)
50 require.Empty(t, warnings)
51 require.True(t, params.Store.Valid())
52 require.Equal(t, tt.wantStore, params.Store.Value)
53 })
54 }
55}
56
57func TestPrepareParams_PreviousResponseID(t *testing.T) {
58 t.Parallel()
59
60 lm := testResponsesLM()
61 prompt := fantasy.Prompt{testTextMessage(fantasy.MessageRoleUser, "hello")}
62
63 t.Run("forwarded", func(t *testing.T) {
64 t.Parallel()
65
66 params, warnings, err := lm.prepareParams(testCall(prompt, &ResponsesProviderOptions{
67 PreviousResponseID: new("resp_abc123"),
68 Store: new(true),
69 }))
70 require.NoError(t, err)
71 require.Empty(t, warnings)
72 require.True(t, params.PreviousResponseID.Valid())
73 require.Equal(t, "resp_abc123", params.PreviousResponseID.Value)
74 })
75
76 t.Run("not set", func(t *testing.T) {
77 t.Parallel()
78
79 params, warnings, err := lm.prepareParams(testCall(prompt, &ResponsesProviderOptions{}))
80 require.NoError(t, err)
81 require.Empty(t, warnings)
82 require.False(t, params.PreviousResponseID.Valid())
83 })
84
85 t.Run("empty string ignored", func(t *testing.T) {
86 t.Parallel()
87
88 params, warnings, err := lm.prepareParams(testCall(prompt, &ResponsesProviderOptions{
89 PreviousResponseID: new(""),
90 }))
91 require.NoError(t, err)
92 require.Empty(t, warnings)
93 require.False(t, params.PreviousResponseID.Valid())
94 })
95}
96
97func TestPrepareParams_PreviousResponseID_Validation(t *testing.T) {
98 t.Parallel()
99
100 lm := testResponsesLM()
101 opts := &ResponsesProviderOptions{
102 PreviousResponseID: new("resp_abc123"),
103 Store: new(true),
104 }
105
106 t.Run("rejects with assistant messages", func(t *testing.T) {
107 t.Parallel()
108
109 _, _, err := lm.prepareParams(testCall(fantasy.Prompt{
110 testTextMessage(fantasy.MessageRoleUser, "hello"),
111 testTextMessage(fantasy.MessageRoleAssistant, "hi there"),
112 }, opts))
113 require.EqualError(t, err, previousResponseIDHistoryError)
114 })
115
116 t.Run("allows user-only prompt", func(t *testing.T) {
117 t.Parallel()
118
119 _, warnings, err := lm.prepareParams(testCall(fantasy.Prompt{
120 testTextMessage(fantasy.MessageRoleUser, "hello"),
121 testTextMessage(fantasy.MessageRoleUser, "follow up"),
122 }, opts))
123 require.NoError(t, err)
124 require.Empty(t, warnings)
125 })
126
127 t.Run("allows system + user prompt", func(t *testing.T) {
128 t.Parallel()
129
130 _, warnings, err := lm.prepareParams(testCall(fantasy.Prompt{
131 testTextMessage(fantasy.MessageRoleSystem, "be concise"),
132 testTextMessage(fantasy.MessageRoleUser, "hello"),
133 }, opts))
134 require.NoError(t, err)
135 require.Empty(t, warnings)
136 })
137
138 t.Run("rejects tool messages", func(t *testing.T) {
139 t.Parallel()
140
141 _, _, err := lm.prepareParams(testCall(fantasy.Prompt{
142 testToolResultMessage("done"),
143 testTextMessage(fantasy.MessageRoleUser, "hello"),
144 }, opts))
145 require.EqualError(t, err, previousResponseIDHistoryError)
146 })
147
148 t.Run("rejects without store", func(t *testing.T) {
149 t.Parallel()
150
151 _, _, err := lm.prepareParams(testCall(fantasy.Prompt{
152 testTextMessage(fantasy.MessageRoleUser, "hello"),
153 }, &ResponsesProviderOptions{
154 PreviousResponseID: new("resp_abc123"),
155 }))
156 require.EqualError(t, err, previousResponseIDStoreError)
157 })
158
159 t.Run("rejects with store false", func(t *testing.T) {
160 t.Parallel()
161
162 _, _, err := lm.prepareParams(testCall(fantasy.Prompt{
163 testTextMessage(fantasy.MessageRoleUser, "hello"),
164 }, &ResponsesProviderOptions{
165 PreviousResponseID: new("resp_abc123"),
166 Store: new(false),
167 }))
168 require.EqualError(t, err, previousResponseIDStoreError)
169 })
170}
171
172func TestValidatePreviousResponseIDPrompt(t *testing.T) {
173 t.Parallel()
174
175 tests := []struct {
176 name string
177 prompt fantasy.Prompt
178 wantErr bool
179 }{
180 {
181 name: "empty prompt",
182 prompt: nil,
183 },
184 {
185 name: "user-only messages",
186 prompt: fantasy.Prompt{
187 testTextMessage(fantasy.MessageRoleUser, "hello"),
188 testTextMessage(fantasy.MessageRoleUser, "follow up"),
189 },
190 },
191 {
192 name: "system + user messages",
193 prompt: fantasy.Prompt{
194 testTextMessage(fantasy.MessageRoleSystem, "be concise"),
195 testTextMessage(fantasy.MessageRoleUser, "hello"),
196 },
197 },
198 {
199 name: "contains assistant message",
200 prompt: fantasy.Prompt{
201 testTextMessage(fantasy.MessageRoleAssistant, "hi there"),
202 },
203 wantErr: true,
204 },
205 {
206 name: "assistant in the middle",
207 prompt: fantasy.Prompt{
208 testTextMessage(fantasy.MessageRoleUser, "hello"),
209 testTextMessage(fantasy.MessageRoleAssistant, "hi there"),
210 testTextMessage(fantasy.MessageRoleUser, "follow up"),
211 },
212 wantErr: true,
213 },
214 {
215 name: "contains tool message",
216 prompt: fantasy.Prompt{
217 testToolResultMessage("done"),
218 testTextMessage(fantasy.MessageRoleUser, "follow up"),
219 },
220 wantErr: true,
221 },
222 }
223
224 for _, tt := range tests {
225 t.Run(tt.name, func(t *testing.T) {
226 t.Parallel()
227
228 err := validatePreviousResponseIDPrompt(tt.prompt)
229 if tt.wantErr {
230 require.EqualError(t, err, previousResponseIDHistoryError)
231 return
232 }
233
234 require.NoError(t, err)
235 })
236 }
237}
238
239func TestResponsesProviderMetadata_Helper(t *testing.T) {
240 t.Parallel()
241
242 t.Run("non-empty id", func(t *testing.T) {
243 t.Parallel()
244
245 metadata := responsesProviderMetadata("resp_123")
246 require.Len(t, metadata, 1)
247
248 providerMetadata, ok := metadata[Name].(*ResponsesProviderMetadata)
249 require.True(t, ok)
250 require.Equal(t, "resp_123", providerMetadata.ResponseID)
251 })
252
253 t.Run("empty id", func(t *testing.T) {
254 t.Parallel()
255
256 metadata := responsesProviderMetadata("")
257 require.Empty(t, metadata)
258 })
259}
260
261func TestResponsesProviderMetadata_JSON(t *testing.T) {
262 t.Parallel()
263
264 encoded, err := json.Marshal(ResponsesProviderMetadata{ResponseID: "resp_123"})
265 require.NoError(t, err)
266 require.Contains(t, string(encoded), `"response_id":"resp_123"`)
267
268 decoded, err := fantasy.UnmarshalProviderMetadata(map[string]json.RawMessage{
269 Name: encoded,
270 })
271 require.NoError(t, err)
272
273 providerMetadata, ok := decoded[Name].(*ResponsesProviderMetadata)
274 require.True(t, ok)
275 require.Equal(t, "resp_123", providerMetadata.ResponseID)
276}
277
278func testCall(prompt fantasy.Prompt, opts *ResponsesProviderOptions) fantasy.Call {
279 call := fantasy.Call{
280 Prompt: prompt,
281 }
282 if opts != nil {
283 call.ProviderOptions = fantasy.ProviderOptions{
284 Name: opts,
285 }
286 }
287 return call
288}
289
290func testResponsesLM() responsesLanguageModel {
291 return responsesLanguageModel{
292 provider: Name,
293 modelID: "gpt-4o",
294 }
295}
296
297func testTextMessage(role fantasy.MessageRole, text string) fantasy.Message {
298 return fantasy.Message{
299 Role: role,
300 Content: []fantasy.MessagePart{
301 fantasy.TextPart{Text: text},
302 },
303 }
304}
305
306func testToolResultMessage(text string) fantasy.Message {
307 return fantasy.Message{
308 Role: fantasy.MessageRoleTool,
309 Content: []fantasy.MessagePart{
310 fantasy.ToolResultPart{
311 ToolCallID: "call_123",
312 Output: fantasy.ToolResultOutputContentText{
313 Text: text,
314 },
315 },
316 },
317 }
318}