From fd4a3a941a6bf217a08a2edc8d375fb94f61bc70 Mon Sep 17 00:00:00 2001 From: Sanne Raymaekers Date: Thu, 26 Oct 2023 18:16:21 +0200 Subject: [PATCH] worker: let client register itself with the worker server Sends a status update to the worker server every 5 minutes. Also fixes a bug where the body the worker client sent would be empty if it had to refresh the JWT token. Instead of io.Reader use io.ReadSeeker so the body can be reread to create the second request (after the token refresh). --- internal/worker/client.go | 183 ++++++++++++++++++++++++++++----- internal/worker/client_test.go | 5 +- 2 files changed, 161 insertions(+), 27 deletions(-) diff --git a/internal/worker/client.go b/internal/worker/client.go index d35cc6713..32bc93224 100644 --- a/internal/worker/client.go +++ b/internal/worker/client.go @@ -12,9 +12,12 @@ import ( "net/http" "net/url" "sync" + "time" "github.com/google/uuid" + "github.com/sirupsen/logrus" + "github.com/osbuild/osbuild-composer/internal/common" "github.com/osbuild/osbuild-composer/internal/worker/api" ) @@ -26,8 +29,10 @@ type Client struct { accessToken string clientId string clientSecret string + workerID uuid.UUID - tokenMu sync.RWMutex + tokenMu sync.RWMutex + workerIDMu sync.RWMutex } type ClientConfig struct { @@ -49,7 +54,7 @@ type Job interface { NDynamicArgs() int Update(result interface{}) error Canceled() (bool, error) - UploadArtifact(name string, reader io.Reader) error + UploadArtifact(name string, readSeeker io.ReadSeeker) error } var ErrClientRequestJobTimeout = errors.New("Dequeue timed out, retry") @@ -108,14 +113,20 @@ func NewClient(conf ClientConfig) (*Client, error) { } requester.Transport = transport - return &Client{ + client := &Client{ server: server, requester: requester, offlineToken: conf.OfflineToken, oAuthURL: conf.OAuthURL, clientId: conf.ClientId, clientSecret: conf.ClientSecret, - }, nil + } + err = client.registerWorker() + if err != nil { + return client, err + } + go client.workerHeartbeat() + return client, nil } func NewClientUnix(conf ClientConfig) *Client { @@ -138,11 +149,103 @@ func NewClientUnix(conf ClientConfig) *Client { }, }, } - - return &Client{ + client := &Client{ server: server, requester: requester, } + err = client.registerWorker() + if err != nil { + panic(err) + } + go client.workerHeartbeat() + return client +} + +func (c *Client) registerWorker() error { + c.workerIDMu.Lock() + defer c.workerIDMu.Unlock() + + url, err := c.server.Parse("workers") + if err != nil { + return err + } + + var buf bytes.Buffer + err = json.NewEncoder(&buf).Encode(api.PostWorkersRequest{ + Arch: common.CurrentArch(), + }) + if err != nil { + logrus.Errorf("Unable create worker request: %v", err) + return err + } + + resp, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes())) + if err != nil { + logrus.Errorf("Unable to register worker: %v", err) + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + return errorFromResponse(resp, "error registering worker") + } + + var wresp api.PostWorkersResponse + err = json.NewDecoder(resp.Body).Decode(&wresp) + if err != nil { + return err + } + + workerID, err := uuid.Parse(wresp.WorkerId) + if err != nil { + return err + } + c.workerID = workerID + return nil +} + +func (c *Client) workerHeartbeat() { + //nolint:staticcheck // avoid SA1015, this is an endless function + for range time.Tick(time.Minute * 1) { + workerID := func() uuid.UUID { + c.workerIDMu.RLock() + defer c.workerIDMu.RUnlock() + return c.workerID + } + + if workerID() == uuid.Nil { + err := c.registerWorker() + if err != nil { + logrus.Errorf("Error registering worker, %v", err) + continue + } + } + + url, err := c.server.Parse(fmt.Sprintf("workers/%s/status", workerID())) + if err != nil { + logrus.Errorf("Error parsing worker status: %v", err) + continue + } + + var buf bytes.Buffer + resp, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes())) + if err != nil { + logrus.Errorf("Error updating worker status: %v", err) + continue + } + if resp.StatusCode == http.StatusBadRequest { + logrus.Warning("Registering worker again") + err = c.registerWorker() + if err != nil { + logrus.Errorf("Error registering worker, %v", err) + } + continue + } + if resp.StatusCode != http.StatusOK { + logrus.Errorf("Error updating worker status: %d", resp.StatusCode) + continue + } + } } func (c *Client) refreshAccessToken() error { @@ -176,26 +279,41 @@ func (c *Client) refreshAccessToken() error { return nil } -func (c *Client) NewRequest(method, url string, headers map[string]string, body io.Reader) (*http.Response, error) { +func (c *Client) NewRequest(method, url string, headers map[string]string, body io.ReadSeeker) (*http.Response, error) { token := func() string { c.tokenMu.RLock() defer c.tokenMu.RUnlock() return c.accessToken } - req, err := http.NewRequest(method, url, body) + newRequest := func() (*http.Request, error) { + if body != nil { + _, err := body.Seek(0, 0) + if err != nil { + return nil, err + } + } + + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + + if c.oAuthURL != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token())) + } + + for k, v := range headers { + req.Header.Add(k, v) + } + + return req, nil + } + + req, err := newRequest() if err != nil { return nil, err } - - if c.oAuthURL != "" { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token())) - } - - for k, v := range headers { - req.Header.Add(k, v) - } - resp, err := c.requester.Do(req) if err != nil { return nil, err @@ -207,7 +325,10 @@ func (c *Client) NewRequest(method, url string, headers map[string]string, body return nil, err } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token())) + req, err = newRequest() + if err != nil { + return nil, err + } resp, err = c.requester.Do(req) } return resp, err @@ -220,16 +341,28 @@ func (c *Client) RequestJob(types []string, arch string) (Job, error) { panic(err) } - var buf bytes.Buffer - err = json.NewEncoder(&buf).Encode(api.RequestJobJSONRequestBody{ + reqBody := api.RequestJobJSONRequestBody{ Types: types, Arch: arch, - }) + } + + workerID := func() uuid.UUID { + c.workerIDMu.RLock() + defer c.workerIDMu.RUnlock() + return c.workerID + }() + + if workerID != uuid.Nil { + reqBody.WorkerId = common.ToPtr(workerID.String()) + } + + var buf bytes.Buffer + err = json.NewEncoder(&buf).Encode(reqBody) if err != nil { panic(err) } - response, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, &buf) + response, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes())) if err != nil { return nil, err } @@ -320,7 +453,7 @@ func (j *job) Update(result interface{}) error { panic(err) } - response, err := j.client.NewRequest("PATCH", j.location, map[string]string{"Content-Type": "application/json"}, &buf) + response, err := j.client.NewRequest("PATCH", j.location, map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes())) if err != nil { return fmt.Errorf("error fetching job info: %v", err) } @@ -353,7 +486,7 @@ func (j *job) Canceled() (bool, error) { return jr.Canceled, nil } -func (j *job) UploadArtifact(name string, reader io.Reader) error { +func (j *job) UploadArtifact(name string, readSeeker io.ReadSeeker) error { if j.artifactLocation == "" { return fmt.Errorf("server does not accept artifacts for this job") } @@ -368,7 +501,7 @@ func (j *job) UploadArtifact(name string, reader io.Reader) error { panic(err) } - response, err := j.client.NewRequest("PUT", loc.String(), map[string]string{"Content-Type": "application/octet-stream"}, reader) + response, err := j.client.NewRequest("PUT", loc.String(), map[string]string{"Content-Type": "application/octet-stream"}, readSeeker) if err != nil { return fmt.Errorf("error uploading artifact: %v", err) } diff --git a/internal/worker/client_test.go b/internal/worker/client_test.go index 93a9b6b8f..78268e69d 100644 --- a/internal/worker/client_test.go +++ b/internal/worker/client_test.go @@ -128,11 +128,12 @@ func TestProxy(t *testing.T) { require.False(t, c) require.NoError(t, err) - // we expect 5 calls to go through the proxy: + // we expect 6 calls to go through the proxy: + // - register worker // - request job (fails, no oauth token) // - oauth call // - request job (succeeds) // - upload artifact // - cancel - require.Equal(t, 5, proxy.calls) + require.Equal(t, 6, proxy.calls) }