requestconfig.go

  1// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
  2
  3package requestconfig
  4
  5import (
  6	"bytes"
  7	"context"
  8	"encoding/json"
  9	"fmt"
 10	"io"
 11	"math"
 12	"math/rand"
 13	"mime"
 14	"net/http"
 15	"net/url"
 16	"runtime"
 17	"strconv"
 18	"strings"
 19	"time"
 20
 21	"github.com/openai/openai-go/internal"
 22	"github.com/openai/openai-go/internal/apierror"
 23	"github.com/openai/openai-go/internal/apiform"
 24	"github.com/openai/openai-go/internal/apiquery"
 25	"github.com/tidwall/gjson"
 26)
 27
 28func getDefaultHeaders() map[string]string {
 29	return map[string]string{
 30		"User-Agent": fmt.Sprintf("OpenAI/Go %s", internal.PackageVersion),
 31	}
 32}
 33
 34func getNormalizedOS() string {
 35	switch runtime.GOOS {
 36	case "ios":
 37		return "iOS"
 38	case "android":
 39		return "Android"
 40	case "darwin":
 41		return "MacOS"
 42	case "window":
 43		return "Windows"
 44	case "freebsd":
 45		return "FreeBSD"
 46	case "openbsd":
 47		return "OpenBSD"
 48	case "linux":
 49		return "Linux"
 50	default:
 51		return fmt.Sprintf("Other:%s", runtime.GOOS)
 52	}
 53}
 54
 55func getNormalizedArchitecture() string {
 56	switch runtime.GOARCH {
 57	case "386":
 58		return "x32"
 59	case "amd64":
 60		return "x64"
 61	case "arm":
 62		return "arm"
 63	case "arm64":
 64		return "arm64"
 65	default:
 66		return fmt.Sprintf("other:%s", runtime.GOARCH)
 67	}
 68}
 69
 70func getPlatformProperties() map[string]string {
 71	return map[string]string{
 72		"X-Stainless-Lang":            "go",
 73		"X-Stainless-Package-Version": internal.PackageVersion,
 74		"X-Stainless-OS":              getNormalizedOS(),
 75		"X-Stainless-Arch":            getNormalizedArchitecture(),
 76		"X-Stainless-Runtime":         "go",
 77		"X-Stainless-Runtime-Version": runtime.Version(),
 78	}
 79}
 80
 81type RequestOption interface {
 82	Apply(*RequestConfig) error
 83}
 84
 85type RequestOptionFunc func(*RequestConfig) error
 86type PreRequestOptionFunc func(*RequestConfig) error
 87
 88func (s RequestOptionFunc) Apply(r *RequestConfig) error    { return s(r) }
 89func (s PreRequestOptionFunc) Apply(r *RequestConfig) error { return s(r) }
 90
 91func NewRequestConfig(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...RequestOption) (*RequestConfig, error) {
 92	var reader io.Reader
 93
 94	contentType := "application/json"
 95	hasSerializationFunc := false
 96
 97	if body, ok := body.(json.Marshaler); ok {
 98		content, err := body.MarshalJSON()
 99		if err != nil {
100			return nil, err
101		}
102		reader = bytes.NewBuffer(content)
103		hasSerializationFunc = true
104	}
105	if body, ok := body.(apiform.Marshaler); ok {
106		var (
107			content []byte
108			err     error
109		)
110		content, contentType, err = body.MarshalMultipart()
111		if err != nil {
112			return nil, err
113		}
114		reader = bytes.NewBuffer(content)
115		hasSerializationFunc = true
116	}
117	if body, ok := body.(apiquery.Queryer); ok {
118		hasSerializationFunc = true
119		params := body.URLQuery().Encode()
120		if params != "" {
121			u = u + "?" + params
122		}
123	}
124	if body, ok := body.([]byte); ok {
125		reader = bytes.NewBuffer(body)
126		hasSerializationFunc = true
127	}
128	if body, ok := body.(io.Reader); ok {
129		reader = body
130		hasSerializationFunc = true
131	}
132
133	// Fallback to json serialization if none of the serialization functions that we expect
134	// to see is present.
135	if body != nil && !hasSerializationFunc {
136		content, err := json.Marshal(body)
137		if err != nil {
138			return nil, err
139		}
140		reader = bytes.NewBuffer(content)
141	}
142
143	req, err := http.NewRequestWithContext(ctx, method, u, nil)
144	if err != nil {
145		return nil, err
146	}
147	if reader != nil {
148		req.Header.Set("Content-Type", contentType)
149	}
150
151	req.Header.Set("Accept", "application/json")
152	req.Header.Set("X-Stainless-Retry-Count", "0")
153	req.Header.Set("X-Stainless-Timeout", "0")
154	for k, v := range getDefaultHeaders() {
155		req.Header.Add(k, v)
156	}
157
158	for k, v := range getPlatformProperties() {
159		req.Header.Add(k, v)
160	}
161	cfg := RequestConfig{
162		MaxRetries: 2,
163		Context:    ctx,
164		Request:    req,
165		HTTPClient: http.DefaultClient,
166		Body:       reader,
167	}
168	cfg.ResponseBodyInto = dst
169	err = cfg.Apply(opts...)
170	if err != nil {
171		return nil, err
172	}
173
174	// This must run after `cfg.Apply(...)` above in case the request timeout gets modified. We also only
175	// apply our own logic for it if it's still "0" from above. If it's not, then it was deleted or modified
176	// by the user and we should respect that.
177	if req.Header.Get("X-Stainless-Timeout") == "0" {
178		if cfg.RequestTimeout == time.Duration(0) {
179			req.Header.Del("X-Stainless-Timeout")
180		} else {
181			req.Header.Set("X-Stainless-Timeout", strconv.Itoa(int(cfg.RequestTimeout.Seconds())))
182		}
183	}
184
185	return &cfg, nil
186}
187
// RequestConfig represents all the state related to one request.
//
// Editing the variables inside RequestConfig directly is unstable api. Prefer
// composing the RequestOption instead if possible.
//
// Field notes, based on how this file uses them:
//   - MaxRetries: the number of retries Execute may make after the first
//     attempt. NewRequestConfig defaults it to 2.
//   - RequestTimeout: a timeout Execute applies to each attempt separately.
//     Zero means no per-attempt timeout.
//   - Context: the ctx passed to NewRequestConfig or Clone. Execute derives
//     attempt contexts from Request.Context(), not from this field.
//   - Request: the request being built. Execute resolves its URL against
//     BaseURL and fails if BaseURL is nil.
//   - HTTPClient: sends each attempt. NewRequestConfig defaults it to
//     http.DefaultClient.
//   - Middlewares: wrap HTTPClient.Do, with Middlewares[0] as the outermost
//     layer.
//   - APIKey, Organization, Project: not read anywhere in this file. They are
//     presumably consumed by options or middleware elsewhere (TODO confirm).
//   - Body: the serialized request body. Execute attaches it to Request when
//     Request.Body is still nil.
type RequestConfig struct {
	MaxRetries     int
	RequestTimeout time.Duration
	Context        context.Context
	Request        *http.Request
	BaseURL        *url.URL
	HTTPClient     *http.Client
	Middlewares    []middleware
	APIKey         string
	Organization   string
	Project        string
	// If ResponseBodyInto is nil or a **http.Response, Execute leaves the
	// response body unread for the caller. Otherwise Execute reads the body.
	// For non-JSON responses the destination must be a *string, **string, or
	// *[]byte, and it receives the raw body. JSON responses are decoded into
	// it with encoding/json.
	ResponseBodyInto interface{}
	// ResponseInto, when non-nil, receives the final *http.Response from
	// Execute, even when Execute returns an error.
	ResponseInto **http.Response
	Body         io.Reader
}
212
type (
	// middleware is exactly the same type as the Middleware type found in the
	// [option] package, but it is redeclared here to avoid a circular
	// dependency.
	middleware = func(*http.Request, middlewareNext) (*http.Response, error)

	// middlewareNext is exactly the same type as the MiddlewareNext type found
	// in the [option] package, but it is redeclared here to avoid a circular
	// dependency.
	middlewareNext = func(*http.Request) (*http.Response, error)
)

// applyMiddleware binds next into mw and returns the result as a plain
// handler. Calling the returned handler runs mw, which decides whether and
// how to call next.
func applyMiddleware(mw middleware, next middlewareNext) middlewareNext {
	wrapped := func(req *http.Request) (*http.Response, error) {
		return mw(req, next)
	}
	return wrapped
}
226
// shouldRetry decides whether the attempt that sent req and produced res
// should be retried. res is nil when the transport failed.
//
// The rules are applied in this order:
//   - a request whose body cannot be rewound is never retried;
//   - a transport failure is always retried;
//   - an explicit x-should-retry: true/false header wins next;
//   - otherwise 408, 409, 429 and every 5xx status are retried.
func shouldRetry(req *http.Request, res *http.Response) bool {
	if req.Body != nil && req.GetBody == nil {
		return false
	}
	if res == nil {
		return true
	}

	switch res.Header.Get("x-should-retry") {
	case "true":
		return true
	case "false":
		return false
	}

	switch code := res.StatusCode; {
	case code == http.StatusRequestTimeout,
		code == http.StatusConflict,
		code == http.StatusTooManyRequests:
		return true
	default:
		return code >= http.StatusInternalServerError
	}
}
253
// parseRetryAfterHeader extracts a server-requested retry delay from resp.
//
// Retry-After-Ms (a possibly fractional number of milliseconds) takes
// precedence over Retry-After. Retry-After may be a possibly fractional number
// of seconds or an HTTP-date. If a header's value is unusable, parsing falls
// through to the next option.
//
// The boolean result is false when resp is nil or neither header holds a
// usable value. Dates in the past yield a negative duration. Callers are
// expected to range-check the result.
func parseRetryAfterHeader(resp *http.Response) (time.Duration, bool) {
	if resp == nil {
		return 0, false
	}

	// numeric interprets v as a (possibly fractional) count of units.
	numeric := func(v string, units time.Duration) (time.Duration, bool) {
		f, err := strconv.ParseFloat(v, 64)
		if err != nil || math.IsNaN(f) || math.IsInf(f, 0) {
			return 0, false
		}
		scaled := f * float64(units)
		// Converting an out-of-range float to an integer type is
		// implementation-dependent in Go. Values a Duration cannot hold are
		// rejected instead of producing an arbitrary result.
		if scaled >= math.MaxInt64 || scaled <= math.MinInt64 {
			return 0, false
		}
		return time.Duration(scaled), true
	}

	if v := resp.Header.Get("Retry-After-Ms"); v != "" {
		if d, ok := numeric(v, time.Millisecond); ok {
			return d, true
		}
	}

	if v := resp.Header.Get("Retry-After"); v != "" {
		if d, ok := numeric(v, time.Second); ok {
			return d, true
		}
		// http.ParseTime accepts all three HTTP-date formats: IMF-fixdate,
		// RFC 850 and asctime. RFC 1123 with an arbitrary zone name is kept
		// as a fallback for lenient servers.
		t, err := http.ParseTime(v)
		if err != nil {
			t, err = time.Parse(time.RFC1123, v)
		}
		if err == nil {
			return time.Until(t), true
		}
	}

	return 0, false
}
308
// isBeforeContextDeadline reports whether t falls strictly before ctx's
// deadline. A context without a deadline never expires, so the answer is then
// always true.
func isBeforeContextDeadline(t time.Time, ctx context.Context) bool {
	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
		return t.Before(deadline)
	}
	return true
}
319
// bodyWithTimeout wraps a response body together with the cancel func of the
// context the request ran under. The timeout stays in force while the caller
// reads the body, and it is released when the caller closes the body.
type bodyWithTimeout struct {
	stop func() // cancels the request's timeout context
	rc   io.ReadCloser
}

var _ io.ReadCloser = (*bodyWithTimeout)(nil)

// Read reads from the wrapped body, passing its results through unchanged.
func (b *bodyWithTimeout) Read(p []byte) (int, error) {
	return b.rc.Read(p)
}

// Close closes the wrapped body and then releases the timeout context. It
// returns the wrapped body's Close error, if any.
func (b *bodyWithTimeout) Close() error {
	err := b.rc.Close()
	b.stop()
	return err
}
343
344func retryDelay(res *http.Response, retryCount int) time.Duration {
345	// If the API asks us to wait a certain amount of time (and it's a reasonable amount),
346	// just do what it says.
347
348	if retryAfterDelay, ok := parseRetryAfterHeader(res); ok && 0 <= retryAfterDelay && retryAfterDelay < time.Minute {
349		return retryAfterDelay
350	}
351
352	maxDelay := 8 * time.Second
353	delay := time.Duration(0.5 * float64(time.Second) * math.Pow(2, float64(retryCount)))
354	if delay > maxDelay {
355		delay = maxDelay
356	}
357
358	jitter := rand.Int63n(int64(delay / 4))
359	delay -= time.Duration(jitter)
360	return delay
361}
362
363func (cfg *RequestConfig) Execute() (err error) {
364	if cfg.BaseURL == nil {
365		return fmt.Errorf("requestconfig: base url is not set")
366	}
367
368	cfg.Request.URL, err = cfg.BaseURL.Parse(strings.TrimLeft(cfg.Request.URL.String(), "/"))
369	if err != nil {
370		return err
371	}
372
373	if cfg.Body != nil && cfg.Request.Body == nil {
374		switch body := cfg.Body.(type) {
375		case *bytes.Buffer:
376			b := body.Bytes()
377			cfg.Request.ContentLength = int64(body.Len())
378			cfg.Request.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(b)), nil }
379			cfg.Request.Body, _ = cfg.Request.GetBody()
380		case *bytes.Reader:
381			cfg.Request.ContentLength = int64(body.Len())
382			cfg.Request.GetBody = func() (io.ReadCloser, error) {
383				_, err := body.Seek(0, 0)
384				return io.NopCloser(body), err
385			}
386			cfg.Request.Body, _ = cfg.Request.GetBody()
387		default:
388			if rc, ok := body.(io.ReadCloser); ok {
389				cfg.Request.Body = rc
390			} else {
391				cfg.Request.Body = io.NopCloser(body)
392			}
393		}
394	}
395
396	handler := cfg.HTTPClient.Do
397	for i := len(cfg.Middlewares) - 1; i >= 0; i -= 1 {
398		handler = applyMiddleware(cfg.Middlewares[i], handler)
399	}
400
401	// Don't send the current retry count in the headers if the caller modified the header defaults.
402	shouldSendRetryCount := cfg.Request.Header.Get("X-Stainless-Retry-Count") == "0"
403
404	var res *http.Response
405	var cancel context.CancelFunc
406	for retryCount := 0; retryCount <= cfg.MaxRetries; retryCount += 1 {
407		ctx := cfg.Request.Context()
408		if cfg.RequestTimeout != time.Duration(0) && isBeforeContextDeadline(time.Now().Add(cfg.RequestTimeout), ctx) {
409			ctx, cancel = context.WithTimeout(ctx, cfg.RequestTimeout)
410			defer func() {
411				// The cancel function is nil if it was handed off to be handled in a different scope.
412				if cancel != nil {
413					cancel()
414				}
415			}()
416		}
417
418		req := cfg.Request.Clone(ctx)
419		if shouldSendRetryCount {
420			req.Header.Set("X-Stainless-Retry-Count", strconv.Itoa(retryCount))
421		}
422
423		res, err = handler(req)
424		if ctx != nil && ctx.Err() != nil {
425			return ctx.Err()
426		}
427		if !shouldRetry(cfg.Request, res) || retryCount >= cfg.MaxRetries {
428			break
429		}
430
431		// Prepare next request and wait for the retry delay
432		if cfg.Request.GetBody != nil {
433			cfg.Request.Body, err = cfg.Request.GetBody()
434			if err != nil {
435				return err
436			}
437		}
438
439		// Can't actually refresh the body, so we don't attempt to retry here
440		if cfg.Request.GetBody == nil && cfg.Request.Body != nil {
441			break
442		}
443
444		time.Sleep(retryDelay(res, retryCount))
445	}
446
447	// Save *http.Response if it is requested to, even if there was an error making the request. This is
448	// useful in cases where you might want to debug by inspecting the response. Note that if err != nil,
449	// the response should be generally be empty, but there are edge cases.
450	if cfg.ResponseInto != nil {
451		*cfg.ResponseInto = res
452	}
453	if responseBodyInto, ok := cfg.ResponseBodyInto.(**http.Response); ok {
454		*responseBodyInto = res
455	}
456
457	// If there was a connection error in the final request or any other transport error,
458	// return that early without trying to coerce into an APIError.
459	if err != nil {
460		return err
461	}
462
463	if res.StatusCode >= 400 {
464		contents, err := io.ReadAll(res.Body)
465		res.Body.Close()
466		if err != nil {
467			return err
468		}
469
470		// If there is an APIError, re-populate the response body so that debugging
471		// utilities can conveniently dump the response without issue.
472		res.Body = io.NopCloser(bytes.NewBuffer(contents))
473
474		// Load the contents into the error format if it is provided.
475		aerr := apierror.Error{Request: cfg.Request, Response: res, StatusCode: res.StatusCode}
476		unwrapped := gjson.GetBytes(contents, "error").Raw
477		err = aerr.UnmarshalJSON([]byte(unwrapped))
478		if err != nil {
479			return err
480		}
481		return &aerr
482	}
483
484	_, intoCustomResponseBody := cfg.ResponseBodyInto.(**http.Response)
485	if cfg.ResponseBodyInto == nil || intoCustomResponseBody {
486		// We aren't reading the response body in this scope, but whoever is will need the
487		// cancel func from the context to observe request timeouts.
488		// Put the cancel function in the response body so it can be handled elsewhere.
489		if cancel != nil {
490			res.Body = &bodyWithTimeout{rc: res.Body, stop: cancel}
491			cancel = nil
492		}
493		return nil
494	}
495
496	contents, err := io.ReadAll(res.Body)
497	if err != nil {
498		return fmt.Errorf("error reading response body: %w", err)
499	}
500
501	// If we are not json, return plaintext
502	contentType := res.Header.Get("content-type")
503	mediaType, _, _ := mime.ParseMediaType(contentType)
504	isJSON := strings.Contains(mediaType, "application/json") || strings.HasSuffix(mediaType, "+json")
505	if !isJSON {
506		switch dst := cfg.ResponseBodyInto.(type) {
507		case *string:
508			*dst = string(contents)
509		case **string:
510			tmp := string(contents)
511			*dst = &tmp
512		case *[]byte:
513			*dst = contents
514		default:
515			return fmt.Errorf("expected destination type of 'string' or '[]byte' for responses with content-type '%s' that is not 'application/json'", contentType)
516		}
517		return nil
518	}
519
520	// If the response happens to be a byte array, deserialize the body as-is.
521	switch dst := cfg.ResponseBodyInto.(type) {
522	case *[]byte:
523		*dst = contents
524	}
525
526	err = json.NewDecoder(bytes.NewReader(contents)).Decode(cfg.ResponseBodyInto)
527	if err != nil {
528		return fmt.Errorf("error parsing response json: %w", err)
529	}
530
531	return nil
532}
533
534func ExecuteNewRequest(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...RequestOption) error {
535	cfg, err := NewRequestConfig(ctx, method, u, body, dst, opts...)
536	if err != nil {
537		return err
538	}
539	return cfg.Execute()
540}
541
542func (cfg *RequestConfig) Clone(ctx context.Context) *RequestConfig {
543	if cfg == nil {
544		return nil
545	}
546	req := cfg.Request.Clone(ctx)
547	var err error
548	if req.Body != nil {
549		req.Body, err = req.GetBody()
550	}
551	if err != nil {
552		return nil
553	}
554	new := &RequestConfig{
555		MaxRetries:     cfg.MaxRetries,
556		RequestTimeout: cfg.RequestTimeout,
557		Context:        ctx,
558		Request:        req,
559		BaseURL:        cfg.BaseURL,
560		HTTPClient:     cfg.HTTPClient,
561		Middlewares:    cfg.Middlewares,
562		APIKey:         cfg.APIKey,
563		Organization:   cfg.Organization,
564		Project:        cfg.Project,
565	}
566
567	return new
568}
569
570func (cfg *RequestConfig) Apply(opts ...RequestOption) error {
571	for _, opt := range opts {
572		err := opt.Apply(cfg)
573		if err != nil {
574			return err
575		}
576	}
577	return nil
578}
579
580func PreRequestOptions(opts ...RequestOption) (RequestConfig, error) {
581	cfg := RequestConfig{}
582	for _, opt := range opts {
583		if _, ok := opt.(PreRequestOptionFunc); !ok {
584			continue
585		}
586
587		err := opt.Apply(&cfg)
588		if err != nil {
589			return cfg, err
590		}
591	}
592	return cfg, nil
593}