1package openai
  2
  3import "github.com/openai/openai-go/shared/constant"
  4
  5// Helper to accumulate chunks from a stream
  6type ChatCompletionAccumulator struct {
  7	// The up-to-date accumulation of model's responses
  8	ChatCompletion
  9	choiceChatCompletionStates []chatCompletionResponseState
 10	justFinished               chatCompletionResponseState
 11}
 12
 13type FinishedChatCompletionToolCall struct {
 14	ChatCompletionMessageToolCallFunction
 15	Index int
 16	ID    string
 17}
 18
 19type chatCompletionResponseState struct {
 20	state chatCompletionResponseStateEnum
 21	index int
 22}
 23
 24type chatCompletionResponseStateEnum int
 25
 26const (
 27	emptyResponseState chatCompletionResponseStateEnum = iota
 28	contentResponseState
 29	refusalResponseState
 30	toolResponseState
 31	finishedResponseState
 32)
 33
 34// AddChunk incorporates a chunk into the accumulation. Chunks must be added in order.
 35// Returns false if the chunk could not be successfully accumulated.
 36//
 37// The ChatCompletion field JSON does not get accumulated.
 38func (acc *ChatCompletionAccumulator) AddChunk(chunk ChatCompletionChunk) bool {
 39	acc.justFinished = chatCompletionResponseState{}
 40	if !acc.accumulateDelta(chunk) {
 41		return false
 42	}
 43
 44	// only chunks with choices can cause finished events
 45	if len(chunk.Choices) == 0 {
 46		return true
 47	}
 48
 49	chunkIndex := int(chunk.Choices[0].Index)
 50	acc.choiceChatCompletionStates = expandToFit(acc.choiceChatCompletionStates, chunkIndex)
 51	acc.justFinished = acc.choiceChatCompletionStates[chunkIndex].update(chunk)
 52	return true
 53}
 54
 55// JustFinishedContent retrieves the chat completion content when it is known to have just been completed.
 56// The content is "just completed" when the last added chunk no longer contains a content
 57// delta. If the content is just completed, the content is returned and the boolean is true. Otherwise,
 58// an empty string is returned and the boolean will be false.
 59func (acc *ChatCompletionAccumulator) JustFinishedContent() (content string, ok bool) {
 60	if acc.justFinished.state == contentResponseState {
 61		return acc.Choices[0].Message.Content, true
 62	}
 63	return "", false
 64}
 65
 66// JustFinishedRefusal retrieves the chat completion refusal when it is known to have just been completed.
 67// The refusal is "just completed" when the last added chunk no longer contains a refusal
 68// delta. If the refusal is just completed, the refusal is returned and the boolean is true. Otherwise,
 69// an empty string is returned and the boolean will be false.
 70func (acc *ChatCompletionAccumulator) JustFinishedRefusal() (refusal string, ok bool) {
 71	if acc.justFinished.state == refusalResponseState {
 72		return acc.Choices[0].Message.Refusal, true
 73	}
 74	return "", false
 75}
 76
 77// JustFinishedToolCall retrieves a tool call when it is known to have just been completed.
 78// A tool call is "just completed" when the last added chunk no longer contains a tool call
 79// delta or contains a delta for a different tool call. If the tool call is just completed,
 80// a FinishedChatCompletionToolCall is returned and the boolean is true. Otherwise, an empty
 81// tool call is returned and the boolean will be false.
 82//
 83// You cannot rely on this with a stream that has ParallelToolCalls enabled.
 84func (acc *ChatCompletionAccumulator) JustFinishedToolCall() (toolcall FinishedChatCompletionToolCall, ok bool) {
 85	if acc.justFinished.state == toolResponseState {
 86		f := acc.Choices[0].Message.ToolCalls[acc.justFinished.index].Function
 87		id := acc.Choices[0].Message.ToolCalls[acc.justFinished.index].ID
 88		return FinishedChatCompletionToolCall{
 89			ID:    id,
 90			Index: acc.justFinished.index,
 91			ChatCompletionMessageToolCallFunction: ChatCompletionMessageToolCallFunction{
 92				Name:      f.Name,
 93				Arguments: f.Arguments,
 94			},
 95		}, true
 96	}
 97	return FinishedChatCompletionToolCall{}, false
 98}
 99
100// Concatenates a ChatCompletionChunk onto a ChatCompletion. Returns false and
101// does nothing if a mismatch is detected.
102//
103// Ignores the JSON field
104func (cc *ChatCompletion) accumulateDelta(chunk ChatCompletionChunk) bool {
105	if len(cc.ID) == 0 {
106		cc.ID = chunk.ID
107	} else if cc.ID != chunk.ID {
108		return false
109	}
110
111	for _, delta := range chunk.Choices {
112		cc.Choices = expandToFit(cc.Choices, int(delta.Index))
113		choice := &cc.Choices[delta.Index]
114
115		choice.Index = delta.Index
116		choice.FinishReason = delta.FinishReason
117
118		if delta.Delta.Role != "" {
119			choice.Message.Role = constant.Assistant(delta.Delta.Role)
120		}
121
122		choice.Message.Content += delta.Delta.Content
123		choice.Message.Refusal += delta.Delta.Refusal
124
125		for j := range delta.Delta.ToolCalls {
126			deltaTool := &delta.Delta.ToolCalls[j]
127
128			choice.Message.ToolCalls = expandToFit(choice.Message.ToolCalls, int(deltaTool.Index))
129			tool := &choice.Message.ToolCalls[deltaTool.Index]
130
131			if deltaTool.ID != "" {
132				tool.ID = deltaTool.ID
133			}
134			if deltaTool.Type != "" {
135				tool.Type = constant.Function(deltaTool.Type)
136			}
137			tool.Function.Name += deltaTool.Function.Name
138			tool.Function.Arguments += deltaTool.Function.Arguments
139		}
140
141		choice.Logprobs.Content = append(choice.Logprobs.Content, delta.Logprobs.Content...)
142		choice.Logprobs.Refusal = append(choice.Logprobs.Refusal, delta.Logprobs.Refusal...)
143	}
144
145	cc.Usage.CompletionTokens += chunk.Usage.CompletionTokens
146	cc.Usage.PromptTokens += chunk.Usage.PromptTokens
147	cc.Usage.TotalTokens += chunk.Usage.TotalTokens
148
149	cc.Model = chunk.Model
150	cc.Created = chunk.Created
151	cc.SystemFingerprint = chunk.SystemFingerprint
152	cc.ServiceTier = ChatCompletionServiceTier(chunk.ServiceTier)
153	if chunk.Object == chunk.Object.Default() {
154		cc.Object = cc.Object.Default()
155	}
156
157	return true
158}
159
160// Updates the internal response state and returns the previous state if
161// the state changed. This ensures that JustFinished events only fire once.
162func (prev *chatCompletionResponseState) update(chunk ChatCompletionChunk) (justFinished chatCompletionResponseState) {
163	delta := chunk.Choices[0].Delta
164	new := chatCompletionResponseState{}
165	switch {
166	case delta.JSON.Content.Valid():
167		new.state = contentResponseState
168	case delta.JSON.Refusal.Valid():
169		new.state = refusalResponseState
170	case delta.JSON.ToolCalls.Valid():
171		new.state = toolResponseState
172		new.index = int(delta.ToolCalls[0].Index)
173	default:
174		new.state = finishedResponseState
175	}
176
177	if *prev != new {
178		justFinished = *prev
179	}
180	*prev = new
181
182	return
183}
184
// expandToFit returns slice grown, if necessary, so that index is a valid
// position in it. If index is already within bounds, slice is returned
// unchanged.
//
// Newly added elements are always the zero value of T, even when slice has
// spare capacity holding stale data from earlier use of its backing array.
// When that capacity suffices, the result shares slice's backing array.
// Otherwise append allocates a new one with amortized growth.
func expandToFit[T any](slice []T, index int) []T {
	if index < len(slice) {
		return slice
	}
	return append(slice, make([]T, index+1-len(slice))...)
}