diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go
index 6488f7a0546b6015b569d19b67c6910f4bdf1778..af69167b56317e4d5204eb51a4d80089ea316c53 100644
--- a/internal/llm/agent/agent.go
+++ b/internal/llm/agent/agent.go
@@ -420,6 +420,12 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 			msgHistory = append(msgHistory, agentMessage, *toolResults)
 			continue
 		}
+		if agentMessage.FinishReason() == "" {
+			// Kujtim: could not track down where this happens, but it means the request was cancelled
+			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
+			_ = a.messages.Update(context.Background(), agentMessage)
+			return a.err(ErrRequestCancelled)
+		}
 		return AgentEvent{
 			Type:    AgentEventTypeResponse,
 			Message: agentMessage,
diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go
index 152b242312ba5e348ca3f7964b36a85d2d77c56b..97353d6ad5662bbe133583fdba68ac40ea1e7a44 100644
--- a/internal/llm/provider/openai.go
+++ b/internal/llm/provider/openai.go
@@ -319,6 +319,10 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t
 	go func() {
 		for {
 			attempts++
+			// Kujtim: fixes an issue with Anthropic models on OpenRouter
+			if len(params.Tools) == 0 {
+				params.Tools = nil
+			}
 			openaiStream := o.client.Chat.Completions.NewStreaming(
 				ctx,
 				params,
@@ -331,8 +335,14 @@
 			var currentToolCallID string
 			var currentToolCall openai.ChatCompletionMessageToolCall
 			var msgToolCalls []openai.ChatCompletionMessageToolCall
+			currentToolIndex := 0
 			for openaiStream.Next() {
 				chunk := openaiStream.Current()
+				// Kujtim: this is an issue with OpenRouter Qwen; it sends -1 for the tool index
+				if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
+					chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
+					currentToolIndex++
+				}
 				acc.AddChunk(chunk)
 				// This fixes multiple tool calls for some providers
 				for _, choice := range chunk.Choices {
@@ -348,6 +358,14 @@
 					if currentToolCallID == "" {
 						if toolCall.ID != "" {
 							currentToolCallID = toolCall.ID
+							eventChan <- ProviderEvent{
+								Type: EventToolUseStart,
+								ToolCall: &message.ToolCall{
+									ID:       toolCall.ID,
+									Name:     toolCall.Function.Name,
+									Finished: false,
+								},
+							}
 							currentToolCall = openai.ChatCompletionMessageToolCall{
 								ID:   toolCall.ID,
 								Type: "function",
@@ -359,13 +377,21 @@
 						}
 					} else {
 						// Delta tool use
-						if toolCall.ID == "" {
+						if toolCall.ID == "" || toolCall.ID == currentToolCallID {
 							currentToolCall.Function.Arguments += toolCall.Function.Arguments
 						} else {
 							// Detect new tool use
 							if toolCall.ID != currentToolCallID {
 								msgToolCalls = append(msgToolCalls, currentToolCall)
 								currentToolCallID = toolCall.ID
+								eventChan <- ProviderEvent{
+									Type: EventToolUseStart,
+									ToolCall: &message.ToolCall{
+										ID:       toolCall.ID,
+										Name:     toolCall.Function.Name,
+										Finished: false,
+									},
+								}
 								currentToolCall = openai.ChatCompletionMessageToolCall{
 									ID:   toolCall.ID,
 									Type: "function",
@@ -378,7 +404,8 @@
 							}
 						}
 					}
-					if choice.FinishReason == "tool_calls" {
+					// Kujtim: some models send finish reason "stop" even for tool calls
+					if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
 						msgToolCalls = append(msgToolCalls, currentToolCall)
 						if len(acc.Choices) > 0 {
 							acc.Choices[0].Message.ToolCalls = msgToolCalls
@@ -461,31 +488,41 @@
 }
 
 func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
-	var apiErr *openai.Error
-	if !errors.As(err, &apiErr) {
-		return false, 0, err
-	}
-
 	if attempts > maxRetries {
 		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
 	}
 
+	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+		return false, 0, err
+	}
+	var apiErr *openai.Error
+	retryMs := 0
+	retryAfterValues := []string{}
+	if errors.As(err, &apiErr) {
+		// Check for token expiration (401 Unauthorized)
+		if apiErr.StatusCode == 401 {
+			o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
+			if err != nil {
+				return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+			}
+			o.client = createOpenAIClient(o.providerOptions)
+			return true, 0, nil
+		}
 
-	// Check for token expiration (401 Unauthorized)
-	if apiErr.StatusCode == 401 {
-		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
-		if err != nil {
-			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
+		if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
+			return false, 0, err
 		}
-		o.client = createOpenAIClient(o.providerOptions)
-		return true, 0, nil
-	}
-	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
-		return false, 0, err
+		retryAfterValues = apiErr.Response.Header.Values("Retry-After")
 	}
 
-	retryMs := 0
-	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
+	if apiErr != nil {
+		slog.Warn("OpenAI API error", "status_code", apiErr.StatusCode, "message", apiErr.Message, "type", apiErr.Type)
+		if len(retryAfterValues) > 0 {
+			slog.Warn("Retry-After header", "values", retryAfterValues)
+		}
+	} else {
+		slog.Warn("OpenAI API error", "error", err.Error())
+	}
 
 	backoffMs := 2000 * (1 << (attempts - 1))
 	jitterMs := int(float64(backoffMs) * 0.2)