diff --git a/internal/worker/client.go b/internal/worker/client.go index 0b753d2d7..833d45cd4 100644 --- a/internal/worker/client.go +++ b/internal/worker/client.go @@ -12,27 +12,20 @@ import ( "net/http" "net/url" "sync" - "time" "github.com/google/uuid" "github.com/osbuild/osbuild-composer/internal/worker/api" ) -type bearerToken struct { - AccessToken string `json:"access_token"` - ValidForSeconds int `json:"expires_in"` -} - type Client struct { - server *url.URL - requester *http.Client - offlineToken string - oAuthURL string - lastTokenRefresh *time.Time - bearerToken *bearerToken + server *url.URL + requester *http.Client + offlineToken string + oAuthURL string + accessToken string - tokenMu *sync.Mutex + tokenMu sync.RWMutex } type ClientConfig struct { @@ -66,6 +59,10 @@ type job struct { dynamicArgs []json.RawMessage } +type tokenResponse struct { + AccessToken string `json:"access_token"` +} + func NewClient(conf ClientConfig) (*Client, error) { server, err := url.Parse(conf.BaseURL) if err != nil { @@ -91,7 +88,6 @@ func NewClient(conf ClientConfig) (*Client, error) { requester: requester, offlineToken: conf.OfflineToken, oAuthURL: conf.OAuthURL, - tokenMu: &sync.Mutex{}, }, nil } @@ -123,7 +119,10 @@ func NewClientUnix(conf ClientConfig) *Client { } // Note: Only call this function with Client.tokenMu locked! -func (c *Client) refreshBearerToken() error { +func (c *Client) refreshAccessToken() error { + c.tokenMu.Lock() + defer c.tokenMu.Unlock() + if c.offlineToken == "" || c.oAuthURL == "" { return fmt.Errorf("No offline token or oauth url available") } @@ -133,24 +132,28 @@ func (c *Client) refreshBearerToken() error { data.Set("client_id", "rhsm-api") data.Set("refresh_token", c.offlineToken) - t := time.Now() resp, err := http.PostForm(c.oAuthURL, data) if err != nil { return err } - var bt bearerToken - err = json.NewDecoder(resp.Body).Decode(&bt) + var tr tokenResponse + err = json.NewDecoder(resp.Body).Decode(&tr) if err != nil { return err } - c.bearerToken = &bt - c.lastTokenRefresh = &t + c.accessToken = tr.AccessToken return nil } -func (c *Client) NewRequest(method, url string, body io.Reader) (*http.Request, error) { +func (c *Client) NewRequest(method, url string, headers map[string]string, body io.Reader) (*http.Response, error) { + token := func() string { + c.tokenMu.RLock() + defer c.tokenMu.RUnlock() + return c.accessToken + } + req, err := http.NewRequest(method, url, body) if err != nil { return nil, err @@ -158,23 +161,28 @@ func (c *Client) NewRequest(method, url string, body io.Reader) (*http.Request, // If we're using OAUTH, add the Bearer token if c.offlineToken != "" { - // make sure we have a valid token - var d time.Duration - c.tokenMu.Lock() - defer c.tokenMu.Unlock() - if c.lastTokenRefresh != nil { - d = time.Since(*c.lastTokenRefresh) - } - if c.bearerToken == nil || d.Seconds() >= (float64(c.bearerToken.ValidForSeconds)*0.8) { - err = c.refreshBearerToken() - if err != nil { - return nil, err - } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token())) + } + + for k, v := range headers { + req.Header.Add(k, v) + } + + resp, err := c.requester.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusUnauthorized { + err = c.refreshAccessToken() + if err != nil { + return nil, err } - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", c.bearerToken.AccessToken)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token())) + resp, err = c.requester.Do(req) } - return req, nil + return resp, err } func (c *Client) RequestJob(types []string, arch string) (Job, error) { @@ -193,16 +201,10 @@ func (c *Client) RequestJob(types []string, arch string) (Job, error) { panic(err) } - req, err := c.NewRequest("POST", url.String(), &buf) + response, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, &buf) if err != nil { return nil, err } - req.Header.Add("Content-Type", "application/json") - - response, err := c.requester.Do(req) - if err != nil { - return nil, fmt.Errorf("error requesting job: %v", err) - } defer response.Body.Close() if response.StatusCode == http.StatusNoContent { @@ -290,14 +292,7 @@ func (j *job) Update(result interface{}) error { panic(err) } - req, err := j.client.NewRequest("PATCH", j.location, &buf) - if err != nil { - panic(err) - } - - req.Header.Add("Content-Type", "application/json") - - response, err := j.client.requester.Do(req) + response, err := j.client.NewRequest("PATCH", j.location, map[string]string{"Content-Type": "application/json"}, &buf) if err != nil { return fmt.Errorf("error fetching job info: %v", err) } @@ -311,12 +306,7 @@ func (j *job) Update(result interface{}) error { } func (j *job) Canceled() (bool, error) { - req, err := j.client.NewRequest("GET", j.location, nil) - if err != nil { - return false, err - } - - response, err := j.client.requester.Do(req) + response, err := j.client.NewRequest("GET", j.location, map[string]string{}, nil) if err != nil { return false, fmt.Errorf("error fetching job info: %v", err) } @@ -350,14 +340,7 @@ func (j *job) UploadArtifact(name string, reader io.Reader) error { panic(err) } - req, err := j.client.NewRequest("PUT", loc.String(), reader) - if err != nil { - return fmt.Errorf("cannot create request: %v", err) - } - - req.Header.Add("Content-Type", "application/octet-stream") - - response, err := j.client.requester.Do(req) + response, err := j.client.NewRequest("PUT", loc.String(), map[string]string{"Content-Type": "application/octet-stream"}, reader) if err != nil { return fmt.Errorf("error uploading artifact: %v", err) } diff --git a/internal/worker/server_test.go b/internal/worker/server_test.go index 57068a2cd..2fd0658ec 100644 --- a/internal/worker/server_test.go +++ b/internal/worker/server_test.go @@ -342,7 +342,13 @@ func TestOAuth(t *testing.T) { defer workSrv.Close() /* Check that the worker supplies the access token */ + calls := 0 proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if calls == 0 { + require.Equal(t, "Bearer", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusUnauthorized) + return + } require.Equal(t, "Bearer accessToken!", r.Header.Get("Authorization")) handler.ServeHTTP(w, r) })) @@ -351,6 +357,7 @@ func TestOAuth(t *testing.T) { offlineToken := "someOfflineToken" /* Start oauth srv supplying the bearer token */ oauthSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls += 1 require.Equal(t, "POST", r.Method) err = r.ParseForm() require.NoError(t, err) @@ -360,11 +367,9 @@ func TestOAuth(t *testing.T) { require.Equal(t, offlineToken, r.FormValue("refresh_token")) bt := struct { - AccessToken string `json:"access_token"` - ValidForSeconds int `json:"expires_in"` + AccessToken string `json:"access_token"` }{ "accessToken!", - 900, } w.WriteHeader(http.StatusOK)