1package openai
2
3import "github.com/openai/openai-go/shared/constant"
4
// ChatCompletionAccumulator is a helper that accumulates the chunks of a
// streamed chat completion into a single ChatCompletion.
//
// Chunks are fed in order through AddChunk; the embedded ChatCompletion then
// holds everything accumulated so far.
type ChatCompletionAccumulator struct {
	// The up-to-date accumulation of model's responses
	ChatCompletion
	// choiceChatCompletionStates holds one response state per choice. AddChunk
	// indexes it by the Index of the first choice in each chunk.
	choiceChatCompletionStates []chatCompletionResponseState
	// justFinished is the state that the most recent AddChunk call moved away
	// from, or the zero value when that call did not change the state.
	// AddChunk resets it at the start of every call.
	justFinished chatCompletionResponseState
}
12
// FinishedChatCompletionToolCall describes a tool call that
// ChatCompletionAccumulator.JustFinishedToolCall reported as complete.
type FinishedChatCompletionToolCall struct {
	// The accumulated function name and arguments of the tool call.
	ChatCompletionMessageToolCallFunction
	// Index is the position of the tool call in the message's ToolCalls slice.
	Index int
	// ID is the ID recorded for the tool call.
	ID string
}
18
// chatCompletionResponseState records what kind of output a choice was
// producing as of the last chunk processed for it.
type chatCompletionResponseState struct {
	// state is the kind of output being streamed.
	state chatCompletionResponseStateEnum
	// index is the tool call index taken from the chunk. It is only assigned
	// when state is toolResponseState and is zero otherwise.
	index int
}
23
// chatCompletionResponseStateEnum enumerates the kinds of output a streamed
// choice can be producing.
type chatCompletionResponseStateEnum int

const (
	// emptyResponseState is the zero value. It is the state before any chunk
	// has been processed and also marks "nothing just finished".
	emptyResponseState chatCompletionResponseStateEnum = iota
	// contentResponseState means the latest chunk carried a content delta.
	contentResponseState
	// refusalResponseState means the latest chunk carried a refusal delta.
	refusalResponseState
	// toolResponseState means the latest chunk carried a tool call delta.
	toolResponseState
	// finishedResponseState means the latest chunk carried none of the above.
	finishedResponseState
)
33
34// AddChunk incorporates a chunk into the accumulation. Chunks must be added in order.
35// Returns false if the chunk could not be successfully accumulated.
36//
37// The ChatCompletion field JSON does not get accumulated.
38func (acc *ChatCompletionAccumulator) AddChunk(chunk ChatCompletionChunk) bool {
39 acc.justFinished = chatCompletionResponseState{}
40 if !acc.accumulateDelta(chunk) {
41 return false
42 }
43
44 // only chunks with choices can cause finished events
45 if len(chunk.Choices) == 0 {
46 return true
47 }
48
49 chunkIndex := int(chunk.Choices[0].Index)
50 acc.choiceChatCompletionStates = expandToFit(acc.choiceChatCompletionStates, chunkIndex)
51 acc.justFinished = acc.choiceChatCompletionStates[chunkIndex].update(chunk)
52 return true
53}
54
55// JustFinishedContent retrieves the chat completion content when it is known to have just been completed.
56// The content is "just completed" when the last added chunk no longer contains a content
57// delta. If the content is just completed, the content is returned and the boolean is true. Otherwise,
58// an empty string is returned and the boolean will be false.
59func (acc *ChatCompletionAccumulator) JustFinishedContent() (content string, ok bool) {
60 if acc.justFinished.state == contentResponseState {
61 return acc.Choices[0].Message.Content, true
62 }
63 return "", false
64}
65
66// JustFinishedRefusal retrieves the chat completion refusal when it is known to have just been completed.
67// The refusal is "just completed" when the last added chunk no longer contains a refusal
68// delta. If the refusal is just completed, the refusal is returned and the boolean is true. Otherwise,
69// an empty string is returned and the boolean will be false.
70func (acc *ChatCompletionAccumulator) JustFinishedRefusal() (refusal string, ok bool) {
71 if acc.justFinished.state == refusalResponseState {
72 return acc.Choices[0].Message.Refusal, true
73 }
74 return "", false
75}
76
77// JustFinishedToolCall retrieves a tool call when it is known to have just been completed.
78// A tool call is "just completed" when the last added chunk no longer contains a tool call
79// delta or contains a delta for a different tool call. If the tool call is just completed,
80// a FinishedChatCompletionToolCall is returned and the boolean is true. Otherwise, an empty
81// tool call is returned and the boolean will be false.
82//
83// You cannot rely on this with a stream that has ParallelToolCalls enabled.
84func (acc *ChatCompletionAccumulator) JustFinishedToolCall() (toolcall FinishedChatCompletionToolCall, ok bool) {
85 if acc.justFinished.state == toolResponseState {
86 f := acc.Choices[0].Message.ToolCalls[acc.justFinished.index].Function
87 id := acc.Choices[0].Message.ToolCalls[acc.justFinished.index].ID
88 return FinishedChatCompletionToolCall{
89 ID: id,
90 Index: acc.justFinished.index,
91 ChatCompletionMessageToolCallFunction: ChatCompletionMessageToolCallFunction{
92 Name: f.Name,
93 Arguments: f.Arguments,
94 },
95 }, true
96 }
97 return FinishedChatCompletionToolCall{}, false
98}
99
100// Concatenates a ChatCompletionChunk onto a ChatCompletion. Returns false and
101// does nothing if a mismatch is detected.
102//
103// Ignores the JSON field
104func (cc *ChatCompletion) accumulateDelta(chunk ChatCompletionChunk) bool {
105 if len(cc.ID) == 0 {
106 cc.ID = chunk.ID
107 } else if cc.ID != chunk.ID {
108 return false
109 }
110
111 for _, delta := range chunk.Choices {
112 cc.Choices = expandToFit(cc.Choices, int(delta.Index))
113 choice := &cc.Choices[delta.Index]
114
115 choice.Index = delta.Index
116 choice.FinishReason = delta.FinishReason
117
118 if delta.Delta.Role != "" {
119 choice.Message.Role = constant.Assistant(delta.Delta.Role)
120 }
121
122 choice.Message.Content += delta.Delta.Content
123 choice.Message.Refusal += delta.Delta.Refusal
124
125 for j := range delta.Delta.ToolCalls {
126 deltaTool := &delta.Delta.ToolCalls[j]
127
128 choice.Message.ToolCalls = expandToFit(choice.Message.ToolCalls, int(deltaTool.Index))
129 tool := &choice.Message.ToolCalls[deltaTool.Index]
130
131 if deltaTool.ID != "" {
132 tool.ID = deltaTool.ID
133 }
134 if deltaTool.Type != "" {
135 tool.Type = constant.Function(deltaTool.Type)
136 }
137 tool.Function.Name += deltaTool.Function.Name
138 tool.Function.Arguments += deltaTool.Function.Arguments
139 }
140
141 choice.Logprobs.Content = append(choice.Logprobs.Content, delta.Logprobs.Content...)
142 choice.Logprobs.Refusal = append(choice.Logprobs.Refusal, delta.Logprobs.Refusal...)
143 }
144
145 cc.Usage.CompletionTokens += chunk.Usage.CompletionTokens
146 cc.Usage.PromptTokens += chunk.Usage.PromptTokens
147 cc.Usage.TotalTokens += chunk.Usage.TotalTokens
148
149 cc.Model = chunk.Model
150 cc.Created = chunk.Created
151 cc.SystemFingerprint = chunk.SystemFingerprint
152 cc.ServiceTier = ChatCompletionServiceTier(chunk.ServiceTier)
153 if chunk.Object == chunk.Object.Default() {
154 cc.Object = cc.Object.Default()
155 }
156
157 return true
158}
159
160// Updates the internal response state and returns the previous state if
161// the state changed. This ensures that JustFinished events only fire once.
162func (prev *chatCompletionResponseState) update(chunk ChatCompletionChunk) (justFinished chatCompletionResponseState) {
163 delta := chunk.Choices[0].Delta
164 new := chatCompletionResponseState{}
165 switch {
166 case delta.JSON.Content.Valid():
167 new.state = contentResponseState
168 case delta.JSON.Refusal.Valid():
169 new.state = refusalResponseState
170 case delta.JSON.ToolCalls.Valid():
171 new.state = toolResponseState
172 new.index = int(delta.ToolCalls[0].Index)
173 default:
174 new.state = finishedResponseState
175 }
176
177 if *prev != new {
178 justFinished = *prev
179 }
180 *prev = new
181
182 return
183}
184
// expandToFit returns a slice that is long enough for index to be a valid
// position, growing the given slice if necessary.
//
// Existing elements are preserved and every newly exposed element is the zero
// value of T, even when the slice had spare capacity holding old data. As with
// append, the result may or may not share its backing array with the input,
// so callers must use the returned slice.
//
// If index is already in range (or negative), the slice is returned unchanged.
func expandToFit[T any](slice []T, index int) []T {
	if index < len(slice) {
		return slice
	}
	// Appending a zeroed tail both clears any reused capacity and lets the
	// runtime grow the backing array with its amortized strategy.
	return append(slice, make([]T, index+1-len(slice))...)
}