requestconfig.go

  1// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
  2
  3package requestconfig
  4
  5import (
  6	"bytes"
  7	"context"
  8	"encoding/json"
  9	"fmt"
 10	"io"
 11	"math"
 12	"math/rand"
 13	"mime"
 14	"net/http"
 15	"net/url"
 16	"runtime"
 17	"strconv"
 18	"strings"
 19	"time"
 20
 21	"github.com/openai/openai-go/internal"
 22	"github.com/openai/openai-go/internal/apierror"
 23	"github.com/openai/openai-go/internal/apiform"
 24	"github.com/openai/openai-go/internal/apiquery"
 25	"github.com/tidwall/gjson"
 26)
 27
 28func getDefaultHeaders() map[string]string {
 29	return map[string]string{
 30		"User-Agent": fmt.Sprintf("OpenAI/Go %s", internal.PackageVersion),
 31	}
 32}
 33
 34func getNormalizedOS() string {
 35	switch runtime.GOOS {
 36	case "ios":
 37		return "iOS"
 38	case "android":
 39		return "Android"
 40	case "darwin":
 41		return "MacOS"
 42	case "window":
 43		return "Windows"
 44	case "freebsd":
 45		return "FreeBSD"
 46	case "openbsd":
 47		return "OpenBSD"
 48	case "linux":
 49		return "Linux"
 50	default:
 51		return fmt.Sprintf("Other:%s", runtime.GOOS)
 52	}
 53}
 54
 55func getNormalizedArchitecture() string {
 56	switch runtime.GOARCH {
 57	case "386":
 58		return "x32"
 59	case "amd64":
 60		return "x64"
 61	case "arm":
 62		return "arm"
 63	case "arm64":
 64		return "arm64"
 65	default:
 66		return fmt.Sprintf("other:%s", runtime.GOARCH)
 67	}
 68}
 69
 70func getPlatformProperties() map[string]string {
 71	return map[string]string{
 72		"X-Stainless-Lang":            "go",
 73		"X-Stainless-Package-Version": internal.PackageVersion,
 74		"X-Stainless-OS":              getNormalizedOS(),
 75		"X-Stainless-Arch":            getNormalizedArchitecture(),
 76		"X-Stainless-Runtime":         "go",
 77		"X-Stainless-Runtime-Version": runtime.Version(),
 78	}
 79}
 80
 81type RequestOption interface {
 82	Apply(*RequestConfig) error
 83}
 84
 85type RequestOptionFunc func(*RequestConfig) error
 86type PreRequestOptionFunc func(*RequestConfig) error
 87
 88func (s RequestOptionFunc) Apply(r *RequestConfig) error    { return s(r) }
 89func (s PreRequestOptionFunc) Apply(r *RequestConfig) error { return s(r) }
 90
 91func NewRequestConfig(ctx context.Context, method string, u string, body any, dst any, opts ...RequestOption) (*RequestConfig, error) {
 92	var reader io.Reader
 93
 94	contentType := "application/json"
 95	hasSerializationFunc := false
 96
 97	if body, ok := body.(json.Marshaler); ok {
 98		content, err := body.MarshalJSON()
 99		if err != nil {
100			return nil, err
101		}
102		reader = bytes.NewBuffer(content)
103		hasSerializationFunc = true
104	}
105	if body, ok := body.(apiform.Marshaler); ok {
106		var (
107			content []byte
108			err     error
109		)
110		content, contentType, err = body.MarshalMultipart()
111		if err != nil {
112			return nil, err
113		}
114		reader = bytes.NewBuffer(content)
115		hasSerializationFunc = true
116	}
117	if body, ok := body.(apiquery.Queryer); ok {
118		hasSerializationFunc = true
119		q, err := body.URLQuery()
120		if err != nil {
121			return nil, err
122		}
123		params := q.Encode()
124		if params != "" {
125			u = u + "?" + params
126		}
127	}
128	if body, ok := body.([]byte); ok {
129		reader = bytes.NewBuffer(body)
130		hasSerializationFunc = true
131	}
132	if body, ok := body.(io.Reader); ok {
133		reader = body
134		hasSerializationFunc = true
135	}
136
137	// Fallback to json serialization if none of the serialization functions that we expect
138	// to see is present.
139	if body != nil && !hasSerializationFunc {
140		content, err := json.Marshal(body)
141		if err != nil {
142			return nil, err
143		}
144		reader = bytes.NewBuffer(content)
145	}
146
147	req, err := http.NewRequestWithContext(ctx, method, u, nil)
148	if err != nil {
149		return nil, err
150	}
151	if reader != nil {
152		req.Header.Set("Content-Type", contentType)
153	}
154
155	req.Header.Set("Accept", "application/json")
156	req.Header.Set("X-Stainless-Retry-Count", "0")
157	req.Header.Set("X-Stainless-Timeout", "0")
158	for k, v := range getDefaultHeaders() {
159		req.Header.Add(k, v)
160	}
161
162	for k, v := range getPlatformProperties() {
163		req.Header.Add(k, v)
164	}
165	cfg := RequestConfig{
166		MaxRetries: 2,
167		Context:    ctx,
168		Request:    req,
169		HTTPClient: http.DefaultClient,
170		Body:       reader,
171	}
172	cfg.ResponseBodyInto = dst
173	err = cfg.Apply(opts...)
174	if err != nil {
175		return nil, err
176	}
177
178	// This must run after `cfg.Apply(...)` above in case the request timeout gets modified. We also only
179	// apply our own logic for it if it's still "0" from above. If it's not, then it was deleted or modified
180	// by the user and we should respect that.
181	if req.Header.Get("X-Stainless-Timeout") == "0" {
182		if cfg.RequestTimeout == time.Duration(0) {
183			req.Header.Del("X-Stainless-Timeout")
184		} else {
185			req.Header.Set("X-Stainless-Timeout", strconv.Itoa(int(cfg.RequestTimeout.Seconds())))
186		}
187	}
188
189	return &cfg, nil
190}
191
// HTTPDoer is the minimal client interface the request pipeline needs to send a
// request. *http.Client satisfies it. A custom implementation can be supplied
// through RequestConfig.CustomHTTPDoer.
type HTTPDoer interface {
	// Do sends req and returns the server's response.
	Do(req *http.Request) (*http.Response, error)
}

var _ HTTPDoer = (*http.Client)(nil)
197
198// RequestConfig represents all the state related to one request.
199//
200// Editing the variables inside RequestConfig directly is unstable api. Prefer
201// composing the RequestOption instead if possible.
202type RequestConfig struct {
203	MaxRetries     int
204	RequestTimeout time.Duration
205	Context        context.Context
206	Request        *http.Request
207	BaseURL        *url.URL
208	// DefaultBaseURL will be used if BaseURL is not explicitly overridden using
209	// WithBaseURL.
210	DefaultBaseURL *url.URL
211	CustomHTTPDoer HTTPDoer
212	HTTPClient     *http.Client
213	Middlewares    []middleware
214	APIKey         string
215	Organization   string
216	Project        string
217	WebhookSecret  string
218	// If ResponseBodyInto not nil, then we will attempt to deserialize into
219	// ResponseBodyInto. If Destination is a []byte, then it will return the body as
220	// is.
221	ResponseBodyInto any
222	// ResponseInto copies the \*http.Response of the corresponding request into the
223	// given address
224	ResponseInto **http.Response
225	Body         io.Reader
226}
227
// middleware is exactly the same type as the Middleware type found in the [option] package.
// It is redeclared here to avoid a circular dependency.
type middleware = func(*http.Request, middlewareNext) (*http.Response, error)

// middlewareNext is exactly the same type as the MiddlewareNext type found in the [option] package.
// It is redeclared here for the same reason.
type middlewareNext = func(*http.Request) (*http.Response, error)

// applyMiddleware adds one layer to a handler chain. The returned handler calls
// mw and passes next as the continuation that mw may invoke.
func applyMiddleware(mw middleware, next middlewareNext) middlewareNext {
	wrapped := func(r *http.Request) (*http.Response, error) {
		return mw(r, next)
	}
	return wrapped
}
241
// shouldRetry decides whether the request req, which produced res, should be
// attempted again. res is nil when the attempt failed with a transport error.
//
// The rules are applied in order:
//   - A request whose body cannot be re-created (Body set, GetBody nil) is
//     never retried.
//   - A missing response is retried.
//   - An explicit x-should-retry response header of "true" or "false" is
//     obeyed.
//   - Otherwise 408, 409, 429 and any status >= 500 are retried.
func shouldRetry(req *http.Request, res *http.Response) bool {
	// A consumed body can only be re-sent if it can be re-created.
	if canReplay := req.Body == nil || req.GetBody != nil; !canReplay {
		return false
	}

	// No response means the connection itself failed, which is worth retrying.
	if res == nil {
		return true
	}

	// An explicit server hint takes precedence over the status code.
	switch res.Header.Get("x-should-retry") {
	case "true":
		return true
	case "false":
		return false
	}

	switch res.StatusCode {
	case http.StatusRequestTimeout, http.StatusConflict, http.StatusTooManyRequests:
		return true
	default:
		return res.StatusCode >= http.StatusInternalServerError
	}
}
268
// parseRetryAfterHeader reports how long the server asked the client to wait
// before retrying, based on resp's headers.
//
// Headers are consulted in order of preference:
//   - Retry-After-Ms: a (possibly fractional) number of milliseconds.
//   - Retry-After: a (possibly fractional) number of seconds, or an HTTP-date
//     in any format accepted by http.ParseTime, measured from now.
//
// A header that is absent or unparseable is skipped in favour of the next one.
// The bool result is false when resp is nil or no header yields a usable value.
// The duration can be negative (for example, an HTTP-date in the past), so
// callers must range-check it.
func parseRetryAfterHeader(resp *http.Response) (time.Duration, bool) {
	if resp == nil {
		return 0, false
	}

	type retryData struct {
		header string
		units  time.Duration

		// custom is used when the regular algorithm failed and is optional.
		// the returned duration is used verbatim (units is not applied).
		custom func(string) (time.Duration, bool)
	}

	nop := func(string) (time.Duration, bool) { return 0, false }

	// the headers are listed in order of preference
	retries := []retryData{
		{
			header: "Retry-After-Ms",
			units:  time.Millisecond,
			custom: nop,
		},
		{
			header: "Retry-After",
			units:  time.Second,

			// Retry-After values are expressed either as a number of seconds or
			// as an HTTP-date indicating when to try again. http.ParseTime
			// accepts all three date formats HTTP/1.1 allows (IMF-fixdate,
			// RFC 850 and ANSI C asctime), not just RFC 1123.
			custom: func(ra string) (time.Duration, bool) {
				t, err := http.ParseTime(ra)
				if err != nil {
					return 0, false
				}
				return time.Until(t), true
			},
		},
	}

	for _, retry := range retries {
		v := resp.Header.Get(retry.header)
		if v == "" {
			continue
		}
		if retryAfter, err := strconv.ParseFloat(v, 64); err == nil {
			// ParseFloat accepts "NaN" and "Inf", and large values overflow
			// int64 once scaled. Converting such floats to a Duration is
			// implementation-defined, so treat them as unusable.
			if d := retryAfter * float64(retry.units); !math.IsNaN(d) && math.Abs(d) < math.MaxInt64 {
				return time.Duration(d), true
			}
			continue
		}
		if d, ok := retry.custom(v); ok {
			return d, true
		}
	}

	return 0, false
}
323
// isBeforeContextDeadline reports whether t falls strictly before ctx's
// deadline. A context without a deadline is treated as having an infinitely
// distant one, so the result is then always true.
func isBeforeContextDeadline(t time.Time, ctx context.Context) bool {
	deadline, hasDeadline := ctx.Deadline()
	return !hasDeadline || t.Before(deadline)
}
334
// bodyWithTimeout is an io.ReadCloser which can observe a context's cancel func
// to handle timeouts etc. It wraps an existing io.ReadCloser.
//
// Execute uses it when the response body is left for the caller to read. The
// per-attempt timeout context stays alive, and keeps applying, while the body
// is read. Closing the body releases that context.
type bodyWithTimeout struct {
	stop func() // cancels the request's timeout context; called by Close
	rc   io.ReadCloser
}

var _ io.ReadCloser = (*bodyWithTimeout)(nil)

// Read reads from the wrapped body. Errors, including io.EOF, are passed
// through unchanged.
func (b *bodyWithTimeout) Read(p []byte) (n int, err error) {
	return b.rc.Read(p)
}

// Close closes the wrapped body and then calls stop, if set. It returns the
// error from closing the body.
func (b *bodyWithTimeout) Close() error {
	err := b.rc.Close()
	if b.stop != nil {
		b.stop()
	}
	return err
}
358
359func retryDelay(res *http.Response, retryCount int) time.Duration {
360	// If the API asks us to wait a certain amount of time (and it's a reasonable amount),
361	// just do what it says.
362
363	if retryAfterDelay, ok := parseRetryAfterHeader(res); ok && 0 <= retryAfterDelay && retryAfterDelay < time.Minute {
364		return retryAfterDelay
365	}
366
367	maxDelay := 8 * time.Second
368	delay := time.Duration(0.5 * float64(time.Second) * math.Pow(2, float64(retryCount)))
369	if delay > maxDelay {
370		delay = maxDelay
371	}
372
373	jitter := rand.Int63n(int64(delay / 4))
374	delay -= time.Duration(jitter)
375	return delay
376}
377
378func (cfg *RequestConfig) Execute() (err error) {
379	if cfg.BaseURL == nil {
380		if cfg.DefaultBaseURL != nil {
381			cfg.BaseURL = cfg.DefaultBaseURL
382		} else {
383			return fmt.Errorf("requestconfig: base url is not set")
384		}
385	}
386
387	cfg.Request.URL, err = cfg.BaseURL.Parse(strings.TrimLeft(cfg.Request.URL.String(), "/"))
388	if err != nil {
389		return err
390	}
391
392	if cfg.Body != nil && cfg.Request.Body == nil {
393		switch body := cfg.Body.(type) {
394		case *bytes.Buffer:
395			b := body.Bytes()
396			cfg.Request.ContentLength = int64(body.Len())
397			cfg.Request.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(b)), nil }
398			cfg.Request.Body, _ = cfg.Request.GetBody()
399		case *bytes.Reader:
400			cfg.Request.ContentLength = int64(body.Len())
401			cfg.Request.GetBody = func() (io.ReadCloser, error) {
402				_, err := body.Seek(0, 0)
403				return io.NopCloser(body), err
404			}
405			cfg.Request.Body, _ = cfg.Request.GetBody()
406		default:
407			if rc, ok := body.(io.ReadCloser); ok {
408				cfg.Request.Body = rc
409			} else {
410				cfg.Request.Body = io.NopCloser(body)
411			}
412		}
413	}
414
415	handler := cfg.HTTPClient.Do
416	if cfg.CustomHTTPDoer != nil {
417		handler = cfg.CustomHTTPDoer.Do
418	}
419	for i := len(cfg.Middlewares) - 1; i >= 0; i -= 1 {
420		handler = applyMiddleware(cfg.Middlewares[i], handler)
421	}
422
423	// Don't send the current retry count in the headers if the caller modified the header defaults.
424	shouldSendRetryCount := cfg.Request.Header.Get("X-Stainless-Retry-Count") == "0"
425
426	var res *http.Response
427	var cancel context.CancelFunc
428	for retryCount := 0; retryCount <= cfg.MaxRetries; retryCount += 1 {
429		ctx := cfg.Request.Context()
430		if cfg.RequestTimeout != time.Duration(0) && isBeforeContextDeadline(time.Now().Add(cfg.RequestTimeout), ctx) {
431			ctx, cancel = context.WithTimeout(ctx, cfg.RequestTimeout)
432			defer func() {
433				// The cancel function is nil if it was handed off to be handled in a different scope.
434				if cancel != nil {
435					cancel()
436				}
437			}()
438		}
439
440		req := cfg.Request.Clone(ctx)
441		if shouldSendRetryCount {
442			req.Header.Set("X-Stainless-Retry-Count", strconv.Itoa(retryCount))
443		}
444
445		res, err = handler(req)
446		if ctx != nil && ctx.Err() != nil {
447			return ctx.Err()
448		}
449		if !shouldRetry(cfg.Request, res) || retryCount >= cfg.MaxRetries {
450			break
451		}
452
453		// Prepare next request and wait for the retry delay
454		if cfg.Request.GetBody != nil {
455			cfg.Request.Body, err = cfg.Request.GetBody()
456			if err != nil {
457				return err
458			}
459		}
460
461		// Can't actually refresh the body, so we don't attempt to retry here
462		if cfg.Request.GetBody == nil && cfg.Request.Body != nil {
463			break
464		}
465
466		time.Sleep(retryDelay(res, retryCount))
467	}
468
469	// Save *http.Response if it is requested to, even if there was an error making the request. This is
470	// useful in cases where you might want to debug by inspecting the response. Note that if err != nil,
471	// the response should be generally be empty, but there are edge cases.
472	if cfg.ResponseInto != nil {
473		*cfg.ResponseInto = res
474	}
475	if responseBodyInto, ok := cfg.ResponseBodyInto.(**http.Response); ok {
476		*responseBodyInto = res
477	}
478
479	// If there was a connection error in the final request or any other transport error,
480	// return that early without trying to coerce into an APIError.
481	if err != nil {
482		return err
483	}
484
485	if res.StatusCode >= 400 {
486		contents, err := io.ReadAll(res.Body)
487		res.Body.Close()
488		if err != nil {
489			return err
490		}
491
492		// If there is an APIError, re-populate the response body so that debugging
493		// utilities can conveniently dump the response without issue.
494		res.Body = io.NopCloser(bytes.NewBuffer(contents))
495
496		// Load the contents into the error format if it is provided.
497		aerr := apierror.Error{Request: cfg.Request, Response: res, StatusCode: res.StatusCode}
498		unwrapped := gjson.GetBytes(contents, "error").Raw
499		err = aerr.UnmarshalJSON([]byte(unwrapped))
500		if err != nil {
501			return err
502		}
503		return &aerr
504	}
505
506	_, intoCustomResponseBody := cfg.ResponseBodyInto.(**http.Response)
507	if cfg.ResponseBodyInto == nil || intoCustomResponseBody {
508		// We aren't reading the response body in this scope, but whoever is will need the
509		// cancel func from the context to observe request timeouts.
510		// Put the cancel function in the response body so it can be handled elsewhere.
511		if cancel != nil {
512			res.Body = &bodyWithTimeout{rc: res.Body, stop: cancel}
513			cancel = nil
514		}
515		return nil
516	}
517
518	contents, err := io.ReadAll(res.Body)
519	res.Body.Close()
520	if err != nil {
521		return fmt.Errorf("error reading response body: %w", err)
522	}
523
524	// If we are not json, return plaintext
525	contentType := res.Header.Get("content-type")
526	mediaType, _, _ := mime.ParseMediaType(contentType)
527	isJSON := strings.Contains(mediaType, "application/json") || strings.HasSuffix(mediaType, "+json")
528	if !isJSON {
529		switch dst := cfg.ResponseBodyInto.(type) {
530		case *string:
531			*dst = string(contents)
532		case **string:
533			tmp := string(contents)
534			*dst = &tmp
535		case *[]byte:
536			*dst = contents
537		default:
538			return fmt.Errorf("expected destination type of 'string' or '[]byte' for responses with content-type '%s' that is not 'application/json'", contentType)
539		}
540		return nil
541	}
542
543	switch dst := cfg.ResponseBodyInto.(type) {
544	// If the response happens to be a byte array, deserialize the body as-is.
545	case *[]byte:
546		*dst = contents
547	default:
548		err = json.NewDecoder(bytes.NewReader(contents)).Decode(cfg.ResponseBodyInto)
549		if err != nil {
550			return fmt.Errorf("error parsing response json: %w", err)
551		}
552	}
553
554	return nil
555}
556
557func ExecuteNewRequest(ctx context.Context, method string, u string, body any, dst any, opts ...RequestOption) error {
558	cfg, err := NewRequestConfig(ctx, method, u, body, dst, opts...)
559	if err != nil {
560		return err
561	}
562	return cfg.Execute()
563}
564
565func (cfg *RequestConfig) Clone(ctx context.Context) *RequestConfig {
566	if cfg == nil {
567		return nil
568	}
569	req := cfg.Request.Clone(ctx)
570	var err error
571	if req.Body != nil {
572		req.Body, err = req.GetBody()
573	}
574	if err != nil {
575		return nil
576	}
577	new := &RequestConfig{
578		MaxRetries:     cfg.MaxRetries,
579		RequestTimeout: cfg.RequestTimeout,
580		Context:        ctx,
581		Request:        req,
582		BaseURL:        cfg.BaseURL,
583		HTTPClient:     cfg.HTTPClient,
584		Middlewares:    cfg.Middlewares,
585		APIKey:         cfg.APIKey,
586		Organization:   cfg.Organization,
587		Project:        cfg.Project,
588		WebhookSecret:  cfg.WebhookSecret,
589	}
590
591	return new
592}
593
594func (cfg *RequestConfig) Apply(opts ...RequestOption) error {
595	for _, opt := range opts {
596		err := opt.Apply(cfg)
597		if err != nil {
598			return err
599		}
600	}
601	return nil
602}
603
604// PreRequestOptions is used to collect all the options which need to be known before
605// a call to [RequestConfig.ExecuteNewRequest], such as path parameters
606// or global defaults.
607// PreRequestOptions will return a [RequestConfig] with the options applied.
608//
609// Only request option functions of type [PreRequestOptionFunc] are applied.
610func PreRequestOptions(opts ...RequestOption) (RequestConfig, error) {
611	cfg := RequestConfig{}
612	for _, opt := range opts {
613		if opt, ok := opt.(PreRequestOptionFunc); ok {
614			err := opt.Apply(&cfg)
615			if err != nil {
616				return cfg, err
617			}
618		}
619	}
620	return cfg, nil
621}
622
623// WithDefaultBaseURL returns a RequestOption that sets the client's default Base URL.
624// This is always overridden by setting a base URL with WithBaseURL.
625// WithBaseURL should be used instead of WithDefaultBaseURL except in internal code.
626func WithDefaultBaseURL(baseURL string) RequestOption {
627	u, err := url.Parse(baseURL)
628	return RequestOptionFunc(func(r *RequestConfig) error {
629		if err != nil {
630			return err
631		}
632		r.DefaultBaseURL = u
633		return nil
634	})
635}