package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

type LlamaClient ProviderClient

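// newLlamaClient builds a Llama provider client backed by the
// OpenAI-compatible SDK client configured in createLlamaClient.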
func newLlamaClient(opts providerClientOptions) LlamaClient {
	return &llamaClient{
		providerOptions: opts,
		client:          createLlamaClient(opts),
	}
}

type llamaClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

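// createLlamaClient constructs the underlying openai.Client, pointing it at
// Llama's OpenAI-compatible endpoint (or a custom base URL) and attaching the
// API key and any extra headers from the provider options.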
func createLlamaClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}

	baseURL := "https://api.llama.com/compat/v1/"
	if opts.baseURL != "" {
		baseURL = opts.baseURL
	}
	openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(baseURL))

	if opts.extraHeaders != nil {
		for key, value := range opts.extraHeaders {
			openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
		}
	}
	return openai.NewClient(openaiClientOptions...)
}

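// send performs a single, non-streaming chat completion request, retrying on
// retryable errors, and converts the result into a ProviderResponse.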
func (l *llamaClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	openaiMessages := l.convertMessages(messages)
	openaiTools := l.convertTools(tools)
	params := l.preparedParams(openaiMessages, openaiTools)
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := l.client.Chat.Completions.New(ctx, params)
		// If there is an error we are going to see if we can retry the call
		if err != nil {
			retry, after, retryErr := l.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Not retryable and no explicit retry error: surface the original error.
			return nil, err
		}

		content := ""
		if openaiResponse.Choices[0].Message.Content != "" {
			content = openaiResponse.Choices[0].Message.Content
		}
		toolCalls := l.toolCalls(*openaiResponse)
		finishReason := l.finishReason(string(openaiResponse.Choices[0].FinishReason))
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}
		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        l.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

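// stream starts a streaming chat completion request and returns a channel of
// ProviderEvents: content deltas while the stream is running, followed by a
// final EventComplete (or EventError). The channel is closed once the stream
// finishes or fails permanently.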
func (l *llamaClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	openaiMessages := l.convertMessages(messages)
	openaiTools := l.convertTools(tools)
	params := l.preparedParams(openaiMessages, openaiTools)
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
	eventChan := make(chan ProviderEvent)
	go func() {
		attempts := 0
		for {
			attempts++
			openaiStream := l.client.Chat.Completions.NewStreaming(ctx, params)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					logging.Debug("Response", "messages", string(jsonData))
				}
				resultFinishReason := acc.ChatCompletion.Choices[0].FinishReason
				if resultFinishReason == "" {
					// If the finish reason is empty, we assume it was a successful completion
					resultFinishReason = "stop"
				}
				// Stream completed successfully
				finishReason := l.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, l.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        l.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// If there is an error we are going to see if we can retry the call
			retry, after, retryErr := l.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					// context cancelled
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Not retryable and no explicit retry error: surface the original error.
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()
	return eventChan
}

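// shouldRetry reports whether a failed request should be retried and, if so,
// after how many milliseconds. A 401 triggers an API key refresh and an
// immediate retry; 429 and 500 responses are retried with exponential backoff
// unless a Retry-After header dictates the delay.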
func (l *llamaClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *openai.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Check for token expiration (401 Unauthorized)
	if apiErr.StatusCode == 401 {
		var err error
		l.providerOptions.apiKey, err = config.ResolveAPIKey(l.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		l.client = createLlamaClient(l.providerOptions)
		return true, 0, nil
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
		return false, 0, err
	}

	retryMs := 0
	retryAfterValues := apiErr.Response.Header.Values("Retry-After")

	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

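// convertMessages maps internal messages to the OpenAI chat format, prepending
// the system prompt and translating user text/images, assistant tool calls,
// and tool results.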
func (l *llamaClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add system message first
	openaiMessages = append(openaiMessages, openai.SystemMessage(l.providerOptions.systemMessage))
	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderLlama)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}
			openaiMessages = append(openaiMessages, openai.UserMessage(content))
		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}
			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}
			if len(msg.ToolCalls()) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{OfAssistant: &assistantMsg})
		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages, openai.ToolMessage(result.Content, result.ToolCallID))
			}
		}
	}
	return
}

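// convertTools maps the internal tool definitions to OpenAI function tool
// parameters.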
func (l *llamaClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}
	return openaiTools
}

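// preparedParams assembles the chat completion parameters: model ID, messages,
// tools, token limits, and (for reasoning-capable models) the reasoning effort.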
func (l *llamaClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := l.providerOptions.model(l.providerOptions.modelType)
	cfg := config.Get()
	modelConfig := cfg.Models.Large
	if l.providerOptions.modelType == config.SmallModel {
		modelConfig = cfg.Models.Small
	}
	reasoningEffort := model.ReasoningEffort
	if modelConfig.ReasoningEffort != "" {
		reasoningEffort = modelConfig.ReasoningEffort
	}
	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}
	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}
	if l.providerOptions.maxTokens > 0 {
		maxTokens = l.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}
	return params
}

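// toolCalls extracts tool calls from the first choice of a completion.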
func (l *llamaClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall
	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}
	return toolCalls
}

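// finishReason converts an OpenAI finish reason string to the internal
// FinishReason type.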
func (l *llamaClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

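// usage converts completion token counts to TokenUsage, treating cached prompt
// tokens as cache reads.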
func (l *llamaClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens
	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // the OpenAI-compatible API doesn't report cache creation separately
		CacheReadTokens:     cachedTokens,
	}
}

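// Model returns the resolved model configuration for this client.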
func (l *llamaClient) Model() config.Model {
	return l.providerOptions.model(l.providerOptions.modelType)
}