// copilot.go — GitHub Copilot provider client for opencode.

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"net/http"
 10	"os"
 11	"time"
 12
 13	"github.com/openai/openai-go"
 14	"github.com/openai/openai-go/option"
 15	"github.com/openai/openai-go/shared"
 16	"github.com/opencode-ai/opencode/internal/config"
 17	"github.com/opencode-ai/opencode/internal/llm/models"
 18	toolsPkg "github.com/opencode-ai/opencode/internal/llm/tools"
 19	"github.com/opencode-ai/opencode/internal/logging"
 20	"github.com/opencode-ai/opencode/internal/message"
 21)
 22
// copilotOptions holds Copilot-specific configuration collected from
// CopilotOption functions before the client is constructed.
type copilotOptions struct {
	reasoningEffort string            // "low", "medium", or "high"; defaults to "medium"
	extraHeaders    map[string]string // extra HTTP headers added to every Copilot request
	bearerToken     string            // Copilot API bearer token; exchanged from a GitHub token when empty
}

// CopilotOption mutates copilotOptions during client construction.
type CopilotOption func(*copilotOptions)

// copilotClient talks to the GitHub Copilot chat API through the
// OpenAI-compatible SDK client, plus a plain HTTP client used for the
// GitHub-token -> Copilot-bearer-token exchange.
type copilotClient struct {
	providerOptions providerClientOptions
	options         copilotOptions
	client          openai.Client
	httpClient      *http.Client // used only for the token exchange endpoint
}

// CopilotClient is the provider-facing interface implemented by copilotClient.
type CopilotClient ProviderClient

// CopilotTokenResponse represents the response from GitHub's token exchange endpoint
type CopilotTokenResponse struct {
	Token     string `json:"token"`
	ExpiresAt int64  `json:"expires_at"`
}
 45
 46func (c *copilotClient) isAnthropicModel() bool {
 47	for _, modelId := range models.CopilotAnthropicModels {
 48		if c.providerOptions.model.ID == modelId {
 49			return true
 50		}
 51	}
 52	return false
 53}
 54
// GitHub OAuth token loading from the standard GitHub CLI/Copilot locations
// is handled by config.LoadGitHubToken.
 56
 57// exchangeGitHubToken exchanges a GitHub token for a Copilot bearer token
 58func (c *copilotClient) exchangeGitHubToken(githubToken string) (string, error) {
 59	req, err := http.NewRequest("GET", "https://api.github.com/copilot_internal/v2/token", nil)
 60	if err != nil {
 61		return "", fmt.Errorf("failed to create token exchange request: %w", err)
 62	}
 63
 64	req.Header.Set("Authorization", "Token "+githubToken)
 65	req.Header.Set("User-Agent", "OpenCode/1.0")
 66
 67	resp, err := c.httpClient.Do(req)
 68	if err != nil {
 69		return "", fmt.Errorf("failed to exchange GitHub token: %w", err)
 70	}
 71	defer resp.Body.Close()
 72
 73	if resp.StatusCode != http.StatusOK {
 74		body, _ := io.ReadAll(resp.Body)
 75		return "", fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
 76	}
 77
 78	var tokenResp CopilotTokenResponse
 79	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
 80		return "", fmt.Errorf("failed to decode token response: %w", err)
 81	}
 82
 83	return tokenResp.Token, nil
 84}
 85
 86func newCopilotClient(opts providerClientOptions) CopilotClient {
 87	copilotOpts := copilotOptions{
 88		reasoningEffort: "medium",
 89	}
 90	// Apply copilot-specific options
 91	for _, o := range opts.copilotOptions {
 92		o(&copilotOpts)
 93	}
 94
 95	// Create HTTP client for token exchange
 96	httpClient := &http.Client{
 97		Timeout: 30 * time.Second,
 98	}
 99
100	var bearerToken string
101
102	// If bearer token is already provided, use it
103	if copilotOpts.bearerToken != "" {
104		bearerToken = copilotOpts.bearerToken
105	} else {
106		// Try to get GitHub token from multiple sources
107		var githubToken string
108
109		// 1. Environment variable
110		githubToken = os.Getenv("GITHUB_TOKEN")
111
112		// 2. API key from options
113		if githubToken == "" {
114			githubToken = opts.apiKey
115		}
116
117		// 3. Standard GitHub CLI/Copilot locations
118		if githubToken == "" {
119			var err error
120			githubToken, err = config.LoadGitHubToken()
121			if err != nil {
122				logging.Debug("Failed to load GitHub token from standard locations", "error", err)
123			}
124		}
125
126		if githubToken == "" {
127			logging.Error("GitHub token is required for Copilot provider. Set GITHUB_TOKEN environment variable, configure it in opencode.json, or ensure GitHub CLI/Copilot is properly authenticated.")
128			return &copilotClient{
129				providerOptions: opts,
130				options:         copilotOpts,
131				httpClient:      httpClient,
132			}
133		}
134
135		// Create a temporary client for token exchange
136		tempClient := &copilotClient{
137			providerOptions: opts,
138			options:         copilotOpts,
139			httpClient:      httpClient,
140		}
141
142		// Exchange GitHub token for bearer token
143		var err error
144		bearerToken, err = tempClient.exchangeGitHubToken(githubToken)
145		if err != nil {
146			logging.Error("Failed to exchange GitHub token for Copilot bearer token", "error", err)
147			return &copilotClient{
148				providerOptions: opts,
149				options:         copilotOpts,
150				httpClient:      httpClient,
151			}
152		}
153	}
154
155	copilotOpts.bearerToken = bearerToken
156
157	// GitHub Copilot API base URL
158	baseURL := "https://api.githubcopilot.com"
159
160	openaiClientOptions := []option.RequestOption{
161		option.WithBaseURL(baseURL),
162		option.WithAPIKey(bearerToken), // Use bearer token as API key
163	}
164
165	// Add GitHub Copilot specific headers
166	openaiClientOptions = append(openaiClientOptions,
167		option.WithHeader("Editor-Version", "OpenCode/1.0"),
168		option.WithHeader("Editor-Plugin-Version", "OpenCode/1.0"),
169		option.WithHeader("Copilot-Integration-Id", "vscode-chat"),
170	)
171
172	// Add any extra headers
173	if copilotOpts.extraHeaders != nil {
174		for key, value := range copilotOpts.extraHeaders {
175			openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
176		}
177	}
178
179	client := openai.NewClient(openaiClientOptions...)
180	// logging.Debug("Copilot client created", "opts", opts, "copilotOpts", copilotOpts, "model", opts.model)
181	return &copilotClient{
182		providerOptions: opts,
183		options:         copilotOpts,
184		client:          client,
185		httpClient:      httpClient,
186	}
187}
188
189func (c *copilotClient) convertMessages(messages []message.Message) (copilotMessages []openai.ChatCompletionMessageParamUnion) {
190	// Add system message first
191	copilotMessages = append(copilotMessages, openai.SystemMessage(c.providerOptions.systemMessage))
192
193	for _, msg := range messages {
194		switch msg.Role {
195		case message.User:
196			var content []openai.ChatCompletionContentPartUnionParam
197			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
198			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
199
200			for _, binaryContent := range msg.BinaryContent() {
201				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderCopilot)}
202				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
203				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
204			}
205
206			copilotMessages = append(copilotMessages, openai.UserMessage(content))
207
208		case message.Assistant:
209			assistantMsg := openai.ChatCompletionAssistantMessageParam{
210				Role: "assistant",
211			}
212
213			if msg.Content().String() != "" {
214				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
215					OfString: openai.String(msg.Content().String()),
216				}
217			}
218
219			if len(msg.ToolCalls()) > 0 {
220				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
221				for i, call := range msg.ToolCalls() {
222					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
223						ID:   call.ID,
224						Type: "function",
225						Function: openai.ChatCompletionMessageToolCallFunctionParam{
226							Name:      call.Name,
227							Arguments: call.Input,
228						},
229					}
230				}
231			}
232
233			copilotMessages = append(copilotMessages, openai.ChatCompletionMessageParamUnion{
234				OfAssistant: &assistantMsg,
235			})
236
237		case message.Tool:
238			for _, result := range msg.ToolResults() {
239				copilotMessages = append(copilotMessages,
240					openai.ToolMessage(result.Content, result.ToolCallID),
241				)
242			}
243		}
244	}
245
246	return
247}
248
249func (c *copilotClient) convertTools(tools []toolsPkg.BaseTool) []openai.ChatCompletionToolParam {
250	copilotTools := make([]openai.ChatCompletionToolParam, len(tools))
251
252	for i, tool := range tools {
253		info := tool.Info()
254		copilotTools[i] = openai.ChatCompletionToolParam{
255			Function: openai.FunctionDefinitionParam{
256				Name:        info.Name,
257				Description: openai.String(info.Description),
258				Parameters: openai.FunctionParameters{
259					"type":       "object",
260					"properties": info.Parameters,
261					"required":   info.Required,
262				},
263			},
264		}
265	}
266
267	return copilotTools
268}
269
270func (c *copilotClient) finishReason(reason string) message.FinishReason {
271	switch reason {
272	case "stop":
273		return message.FinishReasonEndTurn
274	case "length":
275		return message.FinishReasonMaxTokens
276	case "tool_calls":
277		return message.FinishReasonToolUse
278	default:
279		return message.FinishReasonUnknown
280	}
281}
282
283func (c *copilotClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
284	params := openai.ChatCompletionNewParams{
285		Model:    openai.ChatModel(c.providerOptions.model.APIModel),
286		Messages: messages,
287		Tools:    tools,
288	}
289
290	if c.providerOptions.model.CanReason == true {
291		params.MaxCompletionTokens = openai.Int(c.providerOptions.maxTokens)
292		switch c.options.reasoningEffort {
293		case "low":
294			params.ReasoningEffort = shared.ReasoningEffortLow
295		case "medium":
296			params.ReasoningEffort = shared.ReasoningEffortMedium
297		case "high":
298			params.ReasoningEffort = shared.ReasoningEffortHigh
299		default:
300			params.ReasoningEffort = shared.ReasoningEffortMedium
301		}
302	} else {
303		params.MaxTokens = openai.Int(c.providerOptions.maxTokens)
304	}
305
306	return params
307}
308
309func (c *copilotClient) send(ctx context.Context, messages []message.Message, tools []toolsPkg.BaseTool) (response *ProviderResponse, err error) {
310	params := c.preparedParams(c.convertMessages(messages), c.convertTools(tools))
311	cfg := config.Get()
312	var sessionId string
313	requestSeqId := (len(messages) + 1) / 2
314	if cfg.Debug {
315		// jsonData, _ := json.Marshal(params)
316		// logging.Debug("Prepared messages", "messages", string(jsonData))
317		if sid, ok := ctx.Value(toolsPkg.SessionIDContextKey).(string); ok {
318			sessionId = sid
319		}
320		jsonData, _ := json.Marshal(params)
321		if sessionId != "" {
322			filepath := logging.WriteRequestMessageJson(sessionId, requestSeqId, params)
323			logging.Debug("Prepared messages", "filepath", filepath)
324		} else {
325			logging.Debug("Prepared messages", "messages", string(jsonData))
326		}
327	}
328
329	attempts := 0
330	for {
331		attempts++
332		copilotResponse, err := c.client.Chat.Completions.New(
333			ctx,
334			params,
335		)
336
337		// If there is an error we are going to see if we can retry the call
338		if err != nil {
339			retry, after, retryErr := c.shouldRetry(attempts, err)
340			if retryErr != nil {
341				return nil, retryErr
342			}
343			if retry {
344				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
345				select {
346				case <-ctx.Done():
347					return nil, ctx.Err()
348				case <-time.After(time.Duration(after) * time.Millisecond):
349					continue
350				}
351			}
352			return nil, retryErr
353		}
354
355		content := ""
356		if copilotResponse.Choices[0].Message.Content != "" {
357			content = copilotResponse.Choices[0].Message.Content
358		}
359
360		toolCalls := c.toolCalls(*copilotResponse)
361		finishReason := c.finishReason(string(copilotResponse.Choices[0].FinishReason))
362
363		if len(toolCalls) > 0 {
364			finishReason = message.FinishReasonToolUse
365		}
366
367		return &ProviderResponse{
368			Content:      content,
369			ToolCalls:    toolCalls,
370			Usage:        c.usage(*copilotResponse),
371			FinishReason: finishReason,
372		}, nil
373	}
374}
375
// stream performs a streaming chat completion and returns a channel of
// ProviderEvents: content deltas as they arrive, then exactly one terminal
// EventComplete or EventError, after which the channel is closed. Retryable
// API errors restart the stream (up to maxRetries).
func (c *copilotClient) stream(ctx context.Context, messages []message.Message, tools []toolsPkg.BaseTool) <-chan ProviderEvent {
	params := c.preparedParams(c.convertMessages(messages), c.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	var sessionId string
	// One request per user/assistant exchange; used to sequence debug dumps.
	requestSeqId := (len(messages) + 1) / 2
	if cfg.Debug {
		if sid, ok := ctx.Value(toolsPkg.SessionIDContextKey).(string); ok {
			sessionId = sid
		}
		jsonData, _ := json.Marshal(params)
		if sessionId != "" {
			filepath := logging.WriteRequestMessageJson(sessionId, requestSeqId, params)
			logging.Debug("Prepared messages", "filepath", filepath)
		} else {
			logging.Debug("Prepared messages", "messages", string(jsonData))
		}

	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			copilotStream := c.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			// acc folds chunks into a complete ChatCompletion for the
			// final usage/finish-reason/tool-call extraction.
			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			// State for hand-stitching tool-call deltas (Anthropic models only).
			var currentToolCallId string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			for copilotStream.Next() {
				chunk := copilotStream.Current()
				acc.AddChunk(chunk)

				if cfg.Debug {
					logging.AppendToStreamSessionLogJson(sessionId, requestSeqId, chunk)
				}

				// Forward content deltas to the consumer as they arrive.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					}
				}

				if c.isAnthropicModel() {
					// Monkeypatch adapter for Sonnet-4 multi-tool use:
					// rebuild tool calls from raw deltas and patch them into
					// the accumulator at finish time.
					for _, choice := range chunk.Choices {
						if choice.Delta.ToolCalls != nil && len(choice.Delta.ToolCalls) > 0 {
							toolCall := choice.Delta.ToolCalls[0]
							// Detect tool use start
							if currentToolCallId == "" {
								if toolCall.ID != "" {
									currentToolCallId = toolCall.ID
									currentToolCall = openai.ChatCompletionMessageToolCall{
										ID:   toolCall.ID,
										Type: "function",
										Function: openai.ChatCompletionMessageToolCallFunction{
											Name:      toolCall.Function.Name,
											Arguments: toolCall.Function.Arguments,
										},
									}
								}
							} else {
								// Delta tool use: an empty ID means more
								// argument text for the current call.
								if toolCall.ID == "" {
									currentToolCall.Function.Arguments += toolCall.Function.Arguments
								} else {
									// Detect new tool use
									if toolCall.ID != currentToolCallId {
										msgToolCalls = append(msgToolCalls, currentToolCall)
										currentToolCallId = toolCall.ID
										currentToolCall = openai.ChatCompletionMessageToolCall{
											ID:   toolCall.ID,
											Type: "function",
											Function: openai.ChatCompletionMessageToolCallFunction{
												Name:      toolCall.Function.Name,
												Arguments: toolCall.Function.Arguments,
											},
										}
									}
								}
							}
						}
						if choice.FinishReason == "tool_calls" {
							// Flush the in-flight call and overwrite the
							// accumulator's tool calls with the rebuilt set.
							msgToolCalls = append(msgToolCalls, currentToolCall)
							acc.ChatCompletion.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := copilotStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Debug {
					respFilepath := logging.WriteChatResponseJson(sessionId, requestSeqId, acc.ChatCompletion)
					logging.Debug("Chat completion response", "filepath", respFilepath)
				}
				// Stream completed successfully
				// NOTE(review): acc.ChatCompletion.Choices is indexed without
				// a length check here — an empty stream would panic; confirm
				// the API always yields at least one choice.
				finishReason := c.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
				if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, c.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        c.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// If there is an error we are going to see if we can retry the call
			retry, after, retryErr := c.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			// shouldRetry is not catching the max retries...
			// TODO: Figure out why
			if attempts > maxRetries {
				logging.Warn("Maximum retry attempts reached for rate limit", "attempts", attempts, "max_retries", maxRetries)
				retry = false
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d (paused for %d ms)", attempts, maxRetries, after), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					// context cancelled
					// NOTE(review): ctx.Done() fired, so ctx.Err() is non-nil;
					// this `== nil` condition can never hold and the error
					// event is never sent — looks like it should be `!= nil`.
					if ctx.Err() == nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// NOTE(review): retryErr is nil on this path (checked above), so
			// this emits an EventError with a nil Error — verify consumers
			// tolerate that.
			eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
			close(eventChan)
			return
		}
	}()

	return eventChan
}
545
546func (c *copilotClient) shouldRetry(attempts int, err error) (bool, int64, error) {
547	var apierr *openai.Error
548	if !errors.As(err, &apierr) {
549		return false, 0, err
550	}
551
552	// Check for token expiration (401 Unauthorized)
553	if apierr.StatusCode == 401 {
554		// Try to refresh the bearer token
555		var githubToken string
556
557		// 1. Environment variable
558		githubToken = os.Getenv("GITHUB_TOKEN")
559
560		// 2. API key from options
561		if githubToken == "" {
562			githubToken = c.providerOptions.apiKey
563		}
564
565		// 3. Standard GitHub CLI/Copilot locations
566		if githubToken == "" {
567			var err error
568			githubToken, err = config.LoadGitHubToken()
569			if err != nil {
570				logging.Debug("Failed to load GitHub token from standard locations during retry", "error", err)
571			}
572		}
573
574		if githubToken != "" {
575			newBearerToken, tokenErr := c.exchangeGitHubToken(githubToken)
576			if tokenErr == nil {
577				c.options.bearerToken = newBearerToken
578				// Update the client with the new token
579				// Note: This is a simplified approach. In a production system,
580				// you might want to recreate the entire client with the new token
581				logging.Info("Refreshed Copilot bearer token")
582				return true, 1000, nil // Retry immediately with new token
583			}
584			logging.Error("Failed to refresh Copilot bearer token", "error", tokenErr)
585		}
586		return false, 0, fmt.Errorf("authentication failed: %w", err)
587	}
588	logging.Debug("Copilot API Error", "status", apierr.StatusCode, "headers", apierr.Response.Header, "body", apierr.RawJSON())
589
590	if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
591		return false, 0, err
592	}
593
594	if apierr.StatusCode == 500 {
595		logging.Warn("Copilot API returned 500 error, retrying", "error", err)
596	}
597
598	if attempts > maxRetries {
599		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
600	}
601
602	retryMs := 0
603	retryAfterValues := apierr.Response.Header.Values("Retry-After")
604
605	backoffMs := 2000 * (1 << (attempts - 1))
606	jitterMs := int(float64(backoffMs) * 0.2)
607	retryMs = backoffMs + jitterMs
608	if len(retryAfterValues) > 0 {
609		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
610			retryMs = retryMs * 1000
611		}
612	}
613	return true, int64(retryMs), nil
614}
615
616func (c *copilotClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
617	var toolCalls []message.ToolCall
618
619	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
620		for _, call := range completion.Choices[0].Message.ToolCalls {
621			toolCall := message.ToolCall{
622				ID:       call.ID,
623				Name:     call.Function.Name,
624				Input:    call.Function.Arguments,
625				Type:     "function",
626				Finished: true,
627			}
628			toolCalls = append(toolCalls, toolCall)
629		}
630	}
631
632	return toolCalls
633}
634
635func (c *copilotClient) usage(completion openai.ChatCompletion) TokenUsage {
636	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
637	inputTokens := completion.Usage.PromptTokens - cachedTokens
638
639	return TokenUsage{
640		InputTokens:         inputTokens,
641		OutputTokens:        completion.Usage.CompletionTokens,
642		CacheCreationTokens: 0, // GitHub Copilot doesn't provide this directly
643		CacheReadTokens:     cachedTokens,
644	}
645}
646
647func WithCopilotReasoningEffort(effort string) CopilotOption {
648	return func(options *copilotOptions) {
649		defaultReasoningEffort := "medium"
650		switch effort {
651		case "low", "medium", "high":
652			defaultReasoningEffort = effort
653		default:
654			logging.Warn("Invalid reasoning effort, using default: medium")
655		}
656		options.reasoningEffort = defaultReasoningEffort
657	}
658}
659
660func WithCopilotExtraHeaders(headers map[string]string) CopilotOption {
661	return func(options *copilotOptions) {
662		options.extraHeaders = headers
663	}
664}
665
666func WithCopilotBearerToken(bearerToken string) CopilotOption {
667	return func(options *copilotOptions) {
668		options.bearerToken = bearerToken
669	}
670}
671