1// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3package requestconfig
4
5import (
6 "bytes"
7 "context"
8 "encoding/json"
9 "fmt"
10 "io"
11 "math"
12 "math/rand"
13 "mime"
14 "net/http"
15 "net/url"
16 "runtime"
17 "strconv"
18 "strings"
19 "time"
20
21 "github.com/openai/openai-go/internal"
22 "github.com/openai/openai-go/internal/apierror"
23 "github.com/openai/openai-go/internal/apiform"
24 "github.com/openai/openai-go/internal/apiquery"
25 "github.com/tidwall/gjson"
26)
27
28func getDefaultHeaders() map[string]string {
29 return map[string]string{
30 "User-Agent": fmt.Sprintf("OpenAI/Go %s", internal.PackageVersion),
31 }
32}
33
// getNormalizedOS reports the current operating system (runtime.GOOS) using
// the spelling sent in the X-Stainless-OS header.
func getNormalizedOS() string {
	return normalizeOSName(runtime.GOOS)
}

// normalizeOSName maps a GOOS value to its X-Stainless-OS spelling.
// Unrecognized values are reported as "Other:<goos>".
func normalizeOSName(goos string) string {
	switch goos {
	case "ios":
		return "iOS"
	case "android":
		return "Android"
	case "darwin":
		return "MacOS"
	case "windows": // GOOS spells it "windows"; a "window" case would never match.
		return "Windows"
	case "freebsd":
		return "FreeBSD"
	case "openbsd":
		return "OpenBSD"
	case "linux":
		return "Linux"
	default:
		return fmt.Sprintf("Other:%s", goos)
	}
}
54
// getNormalizedArchitecture translates runtime.GOARCH into the architecture
// name reported in the X-Stainless-Arch header. Architectures without a
// dedicated name are reported as "other:<GOARCH>".
func getNormalizedArchitecture() string {
	arch := runtime.GOARCH
	normalized, known := map[string]string{
		"386":   "x32",
		"amd64": "x64",
		"arm":   "arm",
		"arm64": "arm64",
	}[arch]
	if !known {
		normalized = fmt.Sprintf("other:%s", arch)
	}
	return normalized
}
69
70func getPlatformProperties() map[string]string {
71 return map[string]string{
72 "X-Stainless-Lang": "go",
73 "X-Stainless-Package-Version": internal.PackageVersion,
74 "X-Stainless-OS": getNormalizedOS(),
75 "X-Stainless-Arch": getNormalizedArchitecture(),
76 "X-Stainless-Runtime": "go",
77 "X-Stainless-Runtime-Version": runtime.Version(),
78 }
79}
80
81type RequestOption interface {
82 Apply(*RequestConfig) error
83}
84
85type RequestOptionFunc func(*RequestConfig) error
86type PreRequestOptionFunc func(*RequestConfig) error
87
88func (s RequestOptionFunc) Apply(r *RequestConfig) error { return s(r) }
89func (s PreRequestOptionFunc) Apply(r *RequestConfig) error { return s(r) }
90
91func NewRequestConfig(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...RequestOption) (*RequestConfig, error) {
92 var reader io.Reader
93
94 contentType := "application/json"
95 hasSerializationFunc := false
96
97 if body, ok := body.(json.Marshaler); ok {
98 content, err := body.MarshalJSON()
99 if err != nil {
100 return nil, err
101 }
102 reader = bytes.NewBuffer(content)
103 hasSerializationFunc = true
104 }
105 if body, ok := body.(apiform.Marshaler); ok {
106 var (
107 content []byte
108 err error
109 )
110 content, contentType, err = body.MarshalMultipart()
111 if err != nil {
112 return nil, err
113 }
114 reader = bytes.NewBuffer(content)
115 hasSerializationFunc = true
116 }
117 if body, ok := body.(apiquery.Queryer); ok {
118 hasSerializationFunc = true
119 params := body.URLQuery().Encode()
120 if params != "" {
121 u = u + "?" + params
122 }
123 }
124 if body, ok := body.([]byte); ok {
125 reader = bytes.NewBuffer(body)
126 hasSerializationFunc = true
127 }
128 if body, ok := body.(io.Reader); ok {
129 reader = body
130 hasSerializationFunc = true
131 }
132
133 // Fallback to json serialization if none of the serialization functions that we expect
134 // to see is present.
135 if body != nil && !hasSerializationFunc {
136 content, err := json.Marshal(body)
137 if err != nil {
138 return nil, err
139 }
140 reader = bytes.NewBuffer(content)
141 }
142
143 req, err := http.NewRequestWithContext(ctx, method, u, nil)
144 if err != nil {
145 return nil, err
146 }
147 if reader != nil {
148 req.Header.Set("Content-Type", contentType)
149 }
150
151 req.Header.Set("Accept", "application/json")
152 req.Header.Set("X-Stainless-Retry-Count", "0")
153 req.Header.Set("X-Stainless-Timeout", "0")
154 for k, v := range getDefaultHeaders() {
155 req.Header.Add(k, v)
156 }
157
158 for k, v := range getPlatformProperties() {
159 req.Header.Add(k, v)
160 }
161 cfg := RequestConfig{
162 MaxRetries: 2,
163 Context: ctx,
164 Request: req,
165 HTTPClient: http.DefaultClient,
166 Body: reader,
167 }
168 cfg.ResponseBodyInto = dst
169 err = cfg.Apply(opts...)
170 if err != nil {
171 return nil, err
172 }
173
174 // This must run after `cfg.Apply(...)` above in case the request timeout gets modified. We also only
175 // apply our own logic for it if it's still "0" from above. If it's not, then it was deleted or modified
176 // by the user and we should respect that.
177 if req.Header.Get("X-Stainless-Timeout") == "0" {
178 if cfg.RequestTimeout == time.Duration(0) {
179 req.Header.Del("X-Stainless-Timeout")
180 } else {
181 req.Header.Set("X-Stainless-Timeout", strconv.Itoa(int(cfg.RequestTimeout.Seconds())))
182 }
183 }
184
185 return &cfg, nil
186}
187
// RequestConfig represents all the state related to one request.
//
// Editing the variables inside RequestConfig directly is unstable api. Prefer
// composing the RequestOption instead if possible.
type RequestConfig struct {
	// MaxRetries caps how many additional attempts Execute makes after the
	// first one. NewRequestConfig defaults it to 2.
	MaxRetries int
	// RequestTimeout, when non-zero, bounds each individual attempt in Execute
	// through a context timeout. NewRequestConfig also reports it, in whole
	// seconds, in the X-Stainless-Timeout header.
	RequestTimeout time.Duration
	// Context is the context the config was created or cloned with. Execute
	// itself derives each attempt's context from Request.Context().
	Context context.Context
	// Request is the outgoing request. Execute resolves its URL against BaseURL.
	Request *http.Request
	// BaseURL must be set before Execute runs; Execute returns an error otherwise.
	BaseURL *url.URL
	// HTTPClient's Do method is the innermost handler that Execute calls.
	HTTPClient *http.Client
	// Middlewares wrap HTTPClient.Do. Middlewares[0] is the outermost layer.
	Middlewares []middleware
	// APIKey, Organization and Project are not read by the request logic in
	// this file. NOTE(review): presumably consumed by RequestOptions that set
	// auth/identification headers — confirm.
	APIKey       string
	Organization string
	Project      string
	// If ResponseBodyInto is not nil, Execute will attempt to deserialize the
	// response body into it. A *[]byte destination receives the raw body bytes.
	// A **http.Response destination receives the response with its body left
	// unread for the caller.
	ResponseBodyInto interface{}
	// ResponseInto copies the *http.Response of the corresponding request into
	// the given address. Execute fills it in before checking transport,
	// status-code and decoding errors, so it is available for debugging failed
	// requests. It is not set when Execute stops early because the context ended.
	ResponseInto **http.Response
	// Body is the serialized request payload. Execute moves it onto
	// Request.Body if the request does not already have a body.
	Body io.Reader
}
212
// The two function types below are exactly the Middleware and MiddlewareNext
// types of the [option] package. They are redeclared here as aliases, which
// makes them the identical types, to avoid a circular dependency between the
// packages.
type (
	// middleware receives the outgoing request and the next handler in the chain.
	middleware = func(*http.Request, middlewareNext) (*http.Response, error)

	// middlewareNext is the continuation a middleware calls to pass the request on.
	middlewareNext = func(*http.Request) (*http.Response, error)
)

// applyMiddleware binds next into middleware. It returns a handler that, given
// a request, runs middleware with that request and next.
func applyMiddleware(middleware middleware, next middlewareNext) middlewareNext {
	chained := func(req *http.Request) (*http.Response, error) {
		return middleware(req, next)
	}
	return chained
}
226
// shouldRetry decides whether a request whose latest attempt produced res
// (nil on a transport error) may be sent again.
func shouldRetry(req *http.Request, res *http.Response) bool {
	// A body that has been consumed and cannot be regenerated rules out a retry.
	bodyReplayable := req.Body == nil || req.GetBody != nil
	if !bodyReplayable {
		return false
	}

	// No response at all means a connection-level failure, which is worth retrying.
	if res == nil {
		return true
	}

	// An explicit server hint takes precedence over the status code.
	switch res.Header.Get("x-should-retry") {
	case "true":
		return true
	case "false":
		return false
	}

	switch res.StatusCode {
	case http.StatusRequestTimeout, http.StatusConflict, http.StatusTooManyRequests:
		return true
	}
	return res.StatusCode >= http.StatusInternalServerError
}
253
// parseRetryAfterHeader extracts a server-requested retry delay from resp.
//
// Headers are consulted in order of preference:
//   - "Retry-After-Ms": a (possibly fractional) number of milliseconds.
//   - "Retry-After": a number of seconds, or an HTTP-date at which to retry.
//
// The boolean reports whether a usable value was found. The duration is
// returned unvalidated; it is negative for an HTTP-date in the past. Callers
// must range-check it.
func parseRetryAfterHeader(resp *http.Response) (time.Duration, bool) {
	if resp == nil {
		return 0, false
	}

	type retryData struct {
		header string
		units  time.Duration

		// custom is used when the numeric parse failed and is optional.
		// The returned duration is used verbatim (units is not applied).
		custom func(string) (time.Duration, bool)
	}

	nop := func(string) (time.Duration, bool) { return 0, false }

	// the headers are listed in order of preference
	retries := []retryData{
		{
			header: "Retry-After-Ms",
			units:  time.Millisecond,
			custom: nop,
		},
		{
			header: "Retry-After",
			units:  time.Second,

			// A Retry-After value is either a number of seconds or an HTTP-date.
			// http.ParseTime accepts all three HTTP-date formats (IMF-fixdate,
			// RFC 850, ANSI C), but only with a literal "GMT" zone. Fall back to
			// time.RFC1123 so other zone abbreviations are still accepted.
			custom: func(ra string) (time.Duration, bool) {
				t, err := http.ParseTime(ra)
				if err != nil {
					t, err = time.Parse(time.RFC1123, ra)
				}
				if err != nil {
					return 0, false
				}
				return time.Until(t), true
			},
		},
	}

	for _, retry := range retries {
		v := resp.Header.Get(retry.header)
		if v == "" {
			continue
		}
		if amount, err := strconv.ParseFloat(v, 64); err == nil {
			d := amount * float64(retry.units)
			// ParseFloat accepts "NaN", "Inf" and huge exponents. Converting
			// those to time.Duration is implementation-dependent, so skip this
			// header instead of returning a garbage delay.
			if math.IsNaN(d) || d >= math.MaxInt64 || d <= math.MinInt64 {
				continue
			}
			return time.Duration(d), true
		}
		if d, ok := retry.custom(v); ok {
			return d, true
		}
	}

	return 0, false
}
308
// isBeforeContextDeadline reports whether t falls strictly before ctx's
// deadline. A context without a deadline is treated as having an infinitely
// distant one, so the result is then always true.
func isBeforeContextDeadline(t time.Time, ctx context.Context) bool {
	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
		return t.Before(deadline)
	}
	return true
}
319
// bodyWithTimeout is an io.ReadCloser that wraps a response body and runs a
// stop callback when it is closed. Execute uses stop to hand over the cancel
// func of the attempt's timeout context, so whoever reads the body also
// releases that context.
type bodyWithTimeout struct {
	stop func()        // invoked once the wrapped body has been closed
	rc   io.ReadCloser // the wrapped body
}

// Read passes straight through to the wrapped body, including io.EOF and any
// other error it reports.
func (b *bodyWithTimeout) Read(p []byte) (n int, err error) {
	return b.rc.Read(p)
}

// Close closes the wrapped body, then calls stop. It returns the error from the
// wrapped body's Close.
func (b *bodyWithTimeout) Close() error {
	closeErr := b.rc.Close()
	b.stop()
	return closeErr
}
343
344func retryDelay(res *http.Response, retryCount int) time.Duration {
345 // If the API asks us to wait a certain amount of time (and it's a reasonable amount),
346 // just do what it says.
347
348 if retryAfterDelay, ok := parseRetryAfterHeader(res); ok && 0 <= retryAfterDelay && retryAfterDelay < time.Minute {
349 return retryAfterDelay
350 }
351
352 maxDelay := 8 * time.Second
353 delay := time.Duration(0.5 * float64(time.Second) * math.Pow(2, float64(retryCount)))
354 if delay > maxDelay {
355 delay = maxDelay
356 }
357
358 jitter := rand.Int63n(int64(delay / 4))
359 delay -= time.Duration(jitter)
360 return delay
361}
362
363func (cfg *RequestConfig) Execute() (err error) {
364 if cfg.BaseURL == nil {
365 return fmt.Errorf("requestconfig: base url is not set")
366 }
367
368 cfg.Request.URL, err = cfg.BaseURL.Parse(strings.TrimLeft(cfg.Request.URL.String(), "/"))
369 if err != nil {
370 return err
371 }
372
373 if cfg.Body != nil && cfg.Request.Body == nil {
374 switch body := cfg.Body.(type) {
375 case *bytes.Buffer:
376 b := body.Bytes()
377 cfg.Request.ContentLength = int64(body.Len())
378 cfg.Request.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(b)), nil }
379 cfg.Request.Body, _ = cfg.Request.GetBody()
380 case *bytes.Reader:
381 cfg.Request.ContentLength = int64(body.Len())
382 cfg.Request.GetBody = func() (io.ReadCloser, error) {
383 _, err := body.Seek(0, 0)
384 return io.NopCloser(body), err
385 }
386 cfg.Request.Body, _ = cfg.Request.GetBody()
387 default:
388 if rc, ok := body.(io.ReadCloser); ok {
389 cfg.Request.Body = rc
390 } else {
391 cfg.Request.Body = io.NopCloser(body)
392 }
393 }
394 }
395
396 handler := cfg.HTTPClient.Do
397 for i := len(cfg.Middlewares) - 1; i >= 0; i -= 1 {
398 handler = applyMiddleware(cfg.Middlewares[i], handler)
399 }
400
401 // Don't send the current retry count in the headers if the caller modified the header defaults.
402 shouldSendRetryCount := cfg.Request.Header.Get("X-Stainless-Retry-Count") == "0"
403
404 var res *http.Response
405 var cancel context.CancelFunc
406 for retryCount := 0; retryCount <= cfg.MaxRetries; retryCount += 1 {
407 ctx := cfg.Request.Context()
408 if cfg.RequestTimeout != time.Duration(0) && isBeforeContextDeadline(time.Now().Add(cfg.RequestTimeout), ctx) {
409 ctx, cancel = context.WithTimeout(ctx, cfg.RequestTimeout)
410 defer func() {
411 // The cancel function is nil if it was handed off to be handled in a different scope.
412 if cancel != nil {
413 cancel()
414 }
415 }()
416 }
417
418 req := cfg.Request.Clone(ctx)
419 if shouldSendRetryCount {
420 req.Header.Set("X-Stainless-Retry-Count", strconv.Itoa(retryCount))
421 }
422
423 res, err = handler(req)
424 if ctx != nil && ctx.Err() != nil {
425 return ctx.Err()
426 }
427 if !shouldRetry(cfg.Request, res) || retryCount >= cfg.MaxRetries {
428 break
429 }
430
431 // Prepare next request and wait for the retry delay
432 if cfg.Request.GetBody != nil {
433 cfg.Request.Body, err = cfg.Request.GetBody()
434 if err != nil {
435 return err
436 }
437 }
438
439 // Can't actually refresh the body, so we don't attempt to retry here
440 if cfg.Request.GetBody == nil && cfg.Request.Body != nil {
441 break
442 }
443
444 time.Sleep(retryDelay(res, retryCount))
445 }
446
447 // Save *http.Response if it is requested to, even if there was an error making the request. This is
448 // useful in cases where you might want to debug by inspecting the response. Note that if err != nil,
449 // the response should be generally be empty, but there are edge cases.
450 if cfg.ResponseInto != nil {
451 *cfg.ResponseInto = res
452 }
453 if responseBodyInto, ok := cfg.ResponseBodyInto.(**http.Response); ok {
454 *responseBodyInto = res
455 }
456
457 // If there was a connection error in the final request or any other transport error,
458 // return that early without trying to coerce into an APIError.
459 if err != nil {
460 return err
461 }
462
463 if res.StatusCode >= 400 {
464 contents, err := io.ReadAll(res.Body)
465 res.Body.Close()
466 if err != nil {
467 return err
468 }
469
470 // If there is an APIError, re-populate the response body so that debugging
471 // utilities can conveniently dump the response without issue.
472 res.Body = io.NopCloser(bytes.NewBuffer(contents))
473
474 // Load the contents into the error format if it is provided.
475 aerr := apierror.Error{Request: cfg.Request, Response: res, StatusCode: res.StatusCode}
476 unwrapped := gjson.GetBytes(contents, "error").Raw
477 err = aerr.UnmarshalJSON([]byte(unwrapped))
478 if err != nil {
479 return err
480 }
481 return &aerr
482 }
483
484 _, intoCustomResponseBody := cfg.ResponseBodyInto.(**http.Response)
485 if cfg.ResponseBodyInto == nil || intoCustomResponseBody {
486 // We aren't reading the response body in this scope, but whoever is will need the
487 // cancel func from the context to observe request timeouts.
488 // Put the cancel function in the response body so it can be handled elsewhere.
489 if cancel != nil {
490 res.Body = &bodyWithTimeout{rc: res.Body, stop: cancel}
491 cancel = nil
492 }
493 return nil
494 }
495
496 contents, err := io.ReadAll(res.Body)
497 if err != nil {
498 return fmt.Errorf("error reading response body: %w", err)
499 }
500
501 // If we are not json, return plaintext
502 contentType := res.Header.Get("content-type")
503 mediaType, _, _ := mime.ParseMediaType(contentType)
504 isJSON := strings.Contains(mediaType, "application/json") || strings.HasSuffix(mediaType, "+json")
505 if !isJSON {
506 switch dst := cfg.ResponseBodyInto.(type) {
507 case *string:
508 *dst = string(contents)
509 case **string:
510 tmp := string(contents)
511 *dst = &tmp
512 case *[]byte:
513 *dst = contents
514 default:
515 return fmt.Errorf("expected destination type of 'string' or '[]byte' for responses with content-type '%s' that is not 'application/json'", contentType)
516 }
517 return nil
518 }
519
520 // If the response happens to be a byte array, deserialize the body as-is.
521 switch dst := cfg.ResponseBodyInto.(type) {
522 case *[]byte:
523 *dst = contents
524 }
525
526 err = json.NewDecoder(bytes.NewReader(contents)).Decode(cfg.ResponseBodyInto)
527 if err != nil {
528 return fmt.Errorf("error parsing response json: %w", err)
529 }
530
531 return nil
532}
533
534func ExecuteNewRequest(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...RequestOption) error {
535 cfg, err := NewRequestConfig(ctx, method, u, body, dst, opts...)
536 if err != nil {
537 return err
538 }
539 return cfg.Execute()
540}
541
542func (cfg *RequestConfig) Clone(ctx context.Context) *RequestConfig {
543 if cfg == nil {
544 return nil
545 }
546 req := cfg.Request.Clone(ctx)
547 var err error
548 if req.Body != nil {
549 req.Body, err = req.GetBody()
550 }
551 if err != nil {
552 return nil
553 }
554 new := &RequestConfig{
555 MaxRetries: cfg.MaxRetries,
556 RequestTimeout: cfg.RequestTimeout,
557 Context: ctx,
558 Request: req,
559 BaseURL: cfg.BaseURL,
560 HTTPClient: cfg.HTTPClient,
561 Middlewares: cfg.Middlewares,
562 APIKey: cfg.APIKey,
563 Organization: cfg.Organization,
564 Project: cfg.Project,
565 }
566
567 return new
568}
569
570func (cfg *RequestConfig) Apply(opts ...RequestOption) error {
571 for _, opt := range opts {
572 err := opt.Apply(cfg)
573 if err != nil {
574 return err
575 }
576 }
577 return nil
578}
579
580func PreRequestOptions(opts ...RequestOption) (RequestConfig, error) {
581 cfg := RequestConfig{}
582 for _, opt := range opts {
583 if _, ok := opt.(PreRequestOptionFunc); !ok {
584 continue
585 }
586
587 err := opt.Apply(&cfg)
588 if err != nil {
589 return cfg, err
590 }
591 }
592 return cfg, nil
593}