error.go

 1package openai
 2
 3import (
 4	"cmp"
 5	"errors"
 6	"io"
 7	"net/http"
 8	"regexp"
 9	"strconv"
10	"strings"
11
12	"charm.land/fantasy"
13	"github.com/charmbracelet/openai-go"
14)
15
// openaiContextPattern recognizes the provider's "context too large" error
// text. Capture group 1 is the model's maximum context length in tokens;
// capture group 2 is the token count reported after "resulted in" or
// "requested". It is compiled once at package scope so matching an error
// message never pays the compilation cost.
var openaiContextPattern = regexp.MustCompile(`maximum context length is (\d+) tokens.*?(?:resulted in|requested) (\d+) tokens`)
17
18func toProviderErr(err error) error {
19	var apiErr *openai.Error
20	if errors.As(err, &apiErr) {
21		message := toProviderErrMessage(apiErr)
22		providerErr := &fantasy.ProviderError{
23			Title:           cmp.Or(fantasy.ErrorTitleForStatusCode(apiErr.StatusCode), "provider request failed"),
24			Message:         message,
25			Cause:           apiErr,
26			URL:             apiErr.Request.URL.String(),
27			StatusCode:      apiErr.StatusCode,
28			RequestBody:     apiErr.DumpRequest(true),
29			ResponseHeaders: toHeaderMap(apiErr.Response.Header),
30			ResponseBody:    apiErr.DumpResponse(true),
31		}
32
33		parseContextTooLargeError(message, providerErr)
34
35		return providerErr
36	}
37	// Wrap in a `ProviderError` so `.IsRetriable()` works.
38	if errors.Is(err, io.ErrUnexpectedEOF) {
39		return &fantasy.ProviderError{
40			Title:   "stream transport error",
41			Message: err.Error(),
42			Cause:   err,
43		}
44	}
45	return err
46}
47
48func parseContextTooLargeError(message string, providerErr *fantasy.ProviderError) {
49	matches := openaiContextPattern.FindStringSubmatch(message)
50	if matches == nil {
51		return
52	}
53	providerErr.ContextTooLargeErr = true
54	providerErr.ContextMaxTokens, _ = strconv.Atoi(matches[1])
55	providerErr.ContextUsedTokens, _ = strconv.Atoi(matches[2])
56}
57
58func toProviderErrMessage(apiErr *openai.Error) string {
59	if apiErr.Message != "" {
60		return apiErr.Message
61	}
62
63	// For some OpenAI-compatible providers, the SDK is not always able to parse
64	// the error message correctly.
65	// Fallback to returning the raw response body in such cases.
66	data, _ := io.ReadAll(apiErr.Response.Body)
67	return string(data)
68}
69
// toHeaderMap flattens in into a single-valued map, keeping the last value of
// every header that has at least one value. Headers with no values are
// omitted.
//
// Each header is stored under its original key and additionally under its
// lower-cased key, so lookups work with either spelling. If in already holds
// a non-empty entry spelled in lower case, that entry wins over the alias.
// in itself is never modified.
func toHeaderMap(in http.Header) (out map[string]string) {
	// Every header may contribute two keys: the original and its alias.
	out = make(map[string]string, 2*len(in))
	for k, v := range in {
		l := len(v)
		if l == 0 {
			continue
		}
		last := v[l-1]
		out[k] = last

		// Write the alias to out, not in: inserting into a map while ranging
		// over it mutates the caller's headers and makes the result depend on
		// whether the new entry happens to be visited.
		if lk := strings.ToLower(k); lk != k && len(in[lk]) == 0 {
			out[lk] = last
		}
	}
	return out
}
79}